Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
add gather collectives from ompi
[simgrid.git] / src / smpi / smpi_coll.c
1 /* smpi_coll.c -- various optimized routing for collectives                   */
2
3 /* Copyright (c) 2009, 2010. The SimGrid Team.
4  * All rights reserved.                                                     */
5
6 /* This program is free software; you can redistribute it and/or modify it
7  * under the terms of the license (GNU LGPL) which comes with this package. */
8
9 #include <stdio.h>
10 #include <string.h>
11 #include <assert.h>
12
13 #include "private.h"
14 #include "colls/colls.h"
15 #include "simgrid/sg_config.h"
16
17 s_mpi_coll_description_t mpi_coll_gather_description[] = {
18   {"default",
19    "gather default collective",
20    smpi_mpi_gather},
21 COLL_GATHERS(COLL_DESCRIPTION, COLL_COMMA),
22   {NULL, NULL, NULL}      /* this array must be NULL terminated */
23 };
24
25
26 s_mpi_coll_description_t mpi_coll_allgather_description[] = {
27   {"default",
28    "allgather default collective",
29    smpi_mpi_allgather},
30 COLL_ALLGATHERS(COLL_DESCRIPTION, COLL_COMMA),
31   {NULL, NULL, NULL}      /* this array must be NULL terminated */
32 };
33
34 s_mpi_coll_description_t mpi_coll_allgatherv_description[] = {
35   {"default",
36    "allgatherv default collective",
37    smpi_mpi_allgatherv},
38 COLL_ALLGATHERVS(COLL_DESCRIPTION, COLL_COMMA),
39   {NULL, NULL, NULL}      /* this array must be NULL terminated */
40 };
41
42 s_mpi_coll_description_t mpi_coll_allreduce_description[] = {
43   {"default",
44    "allreduce default collective",
45    smpi_mpi_allreduce},
46 COLL_ALLREDUCES(COLL_DESCRIPTION, COLL_COMMA),
47   {NULL, NULL, NULL}      /* this array must be NULL terminated */
48 };
49
50 s_mpi_coll_description_t mpi_coll_alltoall_description[] = {
51   {"default",
52    "Ompi alltoall default collective",
53    smpi_coll_tuned_alltoall_ompi2},
54 COLL_ALLTOALLS(COLL_DESCRIPTION, COLL_COMMA),
55   {"bruck",
56    "Alltoall Bruck (SG) collective",
57    smpi_coll_tuned_alltoall_bruck},
58   {"basic_linear",
59    "Alltoall basic linear (SG) collective",
60    smpi_coll_tuned_alltoall_basic_linear},
61   {NULL, NULL, NULL}      /* this array must be NULL terminated */
62 };
63
64 s_mpi_coll_description_t mpi_coll_alltoallv_description[] = {
65   {"default",
66    "Ompi alltoallv default collective",
67    smpi_coll_basic_alltoallv},
68 COLL_ALLTOALLVS(COLL_DESCRIPTION, COLL_COMMA),
69   {NULL, NULL, NULL}      /* this array must be NULL terminated */
70 };
71
72 s_mpi_coll_description_t mpi_coll_bcast_description[] = {
73   {"default",
74    "bcast default collective",
75    smpi_mpi_bcast},
76 COLL_BCASTS(COLL_DESCRIPTION, COLL_COMMA),
77   {NULL, NULL, NULL}      /* this array must be NULL terminated */
78 };
79
80 s_mpi_coll_description_t mpi_coll_reduce_description[] = {
81   {"default",
82    "reduce default collective",
83    smpi_mpi_reduce},
84 COLL_REDUCES(COLL_DESCRIPTION, COLL_COMMA),
85   {NULL, NULL, NULL}      /* this array must be NULL terminated */
86 };
87
88
89
90 /** Displays the long description of all registered models, and quit */
91 void coll_help(const char *category, s_mpi_coll_description_t * table)
92 {
93   int i;
94   printf("Long description of the %s models accepted by this simulator:\n",
95          category);
96   for (i = 0; table[i].name; i++)
97     printf("  %s: %s\n", table[i].name, table[i].description);
98 }
99
100 int find_coll_description(s_mpi_coll_description_t * table,
101                            char *name)
102 {
103   int i;
104   char *name_list = NULL;
105   int selector_on=0;
106   if(name==NULL){//no argument provided, use active selector's algorithm
107     name=(char*)sg_cfg_get_string("smpi/coll_selector");
108     selector_on=1;
109   }
110   for (i = 0; table[i].name; i++)
111     if (!strcmp(name, table[i].name)) {
112       return i;
113     }
114
115   if(selector_on){
116     // collective seems not handled by the active selector, try with default one
117     name=(char*)"default";
118     for (i = 0; table[i].name; i++)
119       if (!strcmp(name, table[i].name)) {
120         return i;
121     }
122   }
123   name_list = strdup(table[0].name);
124   for (i = 1; table[i].name; i++) {
125     name_list =
126         xbt_realloc(name_list,
127                     strlen(name_list) + strlen(table[i].name) + 3);
128     strcat(name_list, ", ");
129     strcat(name_list, table[i].name);
130   }
131   xbt_die("Model '%s' is invalid! Valid models are: %s.", name, name_list);
132   return -1;
133 }
134
135 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi,
136                                 "Logging specific to SMPI (coll)");
137
138 int (*mpi_coll_gather_fun)(void *, int, MPI_Datatype, void*, int, MPI_Datatype, int root, MPI_Comm);
139 int (*mpi_coll_allgather_fun)(void *, int, MPI_Datatype, void*, int, MPI_Datatype, MPI_Comm);
140 int (*mpi_coll_allgatherv_fun)(void *, int, MPI_Datatype, void*, int*, int*, MPI_Datatype, MPI_Comm);
141 int (*mpi_coll_allreduce_fun)(void *sbuf, void *rbuf, int rcount, MPI_Datatype dtype, MPI_Op op, MPI_Comm comm);
142 int (*mpi_coll_alltoall_fun)(void *, int, MPI_Datatype, void*, int, MPI_Datatype, MPI_Comm);
143 int (*mpi_coll_alltoallv_fun)(void *, int*, int*, MPI_Datatype, void*, int*, int*, MPI_Datatype, MPI_Comm);
144 int (*mpi_coll_bcast_fun)(void *buf, int count, MPI_Datatype datatype, int root, MPI_Comm com);
145 int (*mpi_coll_reduce_fun)(void *buf, void *rbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm);
146
147 struct s_proc_tree {
148   int PROCTREE_A;
149   int numChildren;
150   int *child;
151   int parent;
152   int me;
153   int root;
154   int isRoot;
155 };
156 typedef struct s_proc_tree *proc_tree_t;
157
158 /**
159  * alloc and init
160  **/
161 static proc_tree_t alloc_tree(int arity)
162 {
163   proc_tree_t tree;
164   int i;
165
166   tree = xbt_new(struct s_proc_tree, 1);
167   tree->PROCTREE_A = arity;
168   tree->isRoot = 0;
169   tree->numChildren = 0;
170   tree->child = xbt_new(int, arity);
171   for (i = 0; i < arity; i++) {
172     tree->child[i] = -1;
173   }
174   tree->root = -1;
175   tree->parent = -1;
176   return tree;
177 }
178
179 /**
180  * free
181  **/
182 static void free_tree(proc_tree_t tree)
183 {
184   xbt_free(tree->child);
185   xbt_free(tree);
186 }
187
188 /**
189  * Build the tree depending on a process rank (index) and the group size (extent)
190  * @param root the rank of the tree root
191  * @param rank the rank of the calling process
192  * @param size the total number of processes
193  **/
194 static void build_tree(int root, int rank, int size, proc_tree_t * tree)
195 {
196   int index = (rank - root + size) % size;
197   int firstChildIdx = index * (*tree)->PROCTREE_A + 1;
198   int i;
199
200   (*tree)->me = rank;
201   (*tree)->root = root;
202
203   for (i = 0; i < (*tree)->PROCTREE_A && firstChildIdx + i < size; i++) {
204     (*tree)->child[i] = (firstChildIdx + i + root) % size;
205     (*tree)->numChildren++;
206   }
207   if (rank == root) {
208     (*tree)->isRoot = 1;
209   } else {
210     (*tree)->isRoot = 0;
211     (*tree)->parent = (((index - 1) / (*tree)->PROCTREE_A) + root) % size;
212   }
213 }
214
215 /**
216  * bcast
217  **/
218 static void tree_bcast(void *buf, int count, MPI_Datatype datatype,
219                        MPI_Comm comm, proc_tree_t tree)
220 {
221   int system_tag = 999;         // used negative int but smpi_create_request() declares this illegal (to be checked)
222   int rank, i;
223   MPI_Request *requests;
224
225   rank = smpi_comm_rank(comm);
226   /* wait for data from my parent in the tree */
227   if (!tree->isRoot) {
228     XBT_DEBUG("<%d> tree_bcast(): i am not root: recv from %d, tag=%d)",
229            rank, tree->parent, system_tag + rank);
230     smpi_mpi_recv(buf, count, datatype, tree->parent, system_tag + rank,
231                   comm, MPI_STATUS_IGNORE);
232   }
233   requests = xbt_new(MPI_Request, tree->numChildren);
234   XBT_DEBUG("<%d> creates %d requests (1 per child)", rank,
235          tree->numChildren);
236   /* iniates sends to ranks lower in the tree */
237   for (i = 0; i < tree->numChildren; i++) {
238     if (tree->child[i] == -1) {
239       requests[i] = MPI_REQUEST_NULL;
240     } else {
241       XBT_DEBUG("<%d> send to <%d>, tag=%d", rank, tree->child[i],
242              system_tag + tree->child[i]);
243       requests[i] =
244           smpi_isend_init(buf, count, datatype, tree->child[i],
245                           system_tag + tree->child[i], comm);
246     }
247   }
248   smpi_mpi_startall(tree->numChildren, requests);
249   smpi_mpi_waitall(tree->numChildren, requests, MPI_STATUS_IGNORE);
250   xbt_free(requests);
251 }
252
253 /**
254  * anti-bcast
255  **/
256 static void tree_antibcast(void *buf, int count, MPI_Datatype datatype,
257                            MPI_Comm comm, proc_tree_t tree)
258 {
259   int system_tag = 999;         // used negative int but smpi_create_request() declares this illegal (to be checked)
260   int rank, i;
261   MPI_Request *requests;
262
263   rank = smpi_comm_rank(comm);
264   // everyone sends to its parent, except root.
265   if (!tree->isRoot) {
266     XBT_DEBUG("<%d> tree_antibcast(): i am not root: send to %d, tag=%d)",
267            rank, tree->parent, system_tag + rank);
268     smpi_mpi_send(buf, count, datatype, tree->parent, system_tag + rank,
269                   comm);
270   }
271   //every one receives as many messages as it has children
272   requests = xbt_new(MPI_Request, tree->numChildren);
273   XBT_DEBUG("<%d> creates %d requests (1 per child)", rank,
274          tree->numChildren);
275   for (i = 0; i < tree->numChildren; i++) {
276     if (tree->child[i] == -1) {
277       requests[i] = MPI_REQUEST_NULL;
278     } else {
279       XBT_DEBUG("<%d> recv from <%d>, tag=%d", rank, tree->child[i],
280              system_tag + tree->child[i]);
281       requests[i] =
282           smpi_irecv_init(buf, count, datatype, tree->child[i],
283                           system_tag + tree->child[i], comm);
284     }
285   }
286   smpi_mpi_startall(tree->numChildren, requests);
287   smpi_mpi_waitall(tree->numChildren, requests, MPI_STATUS_IGNORE);
288   xbt_free(requests);
289 }
290
291 /**
292  * bcast with a binary, ternary, or whatever tree ..
293  **/
294 void nary_tree_bcast(void *buf, int count, MPI_Datatype datatype, int root,
295                      MPI_Comm comm, int arity)
296 {
297   proc_tree_t tree = alloc_tree(arity);
298   int rank, size;
299
300   rank = smpi_comm_rank(comm);
301   size = smpi_comm_size(comm);
302   build_tree(root, rank, size, &tree);
303   tree_bcast(buf, count, datatype, comm, tree);
304   free_tree(tree);
305 }
306
307 /**
308  * barrier with a binary, ternary, or whatever tree ..
309  **/
310 void nary_tree_barrier(MPI_Comm comm, int arity)
311 {
312   proc_tree_t tree = alloc_tree(arity);
313   int rank, size;
314   char dummy = '$';
315
316   rank = smpi_comm_rank(comm);
317   size = smpi_comm_size(comm);
318   build_tree(0, rank, size, &tree);
319   tree_antibcast(&dummy, 1, MPI_CHAR, comm, tree);
320   tree_bcast(&dummy, 1, MPI_CHAR, comm, tree);
321   free_tree(tree);
322 }
323
324 int smpi_coll_tuned_alltoall_ompi2(void *sendbuf, int sendcount,
325                                    MPI_Datatype sendtype, void *recvbuf,
326                                    int recvcount, MPI_Datatype recvtype,
327                                    MPI_Comm comm)
328 {
329   int size, sendsize;   
330   size = smpi_comm_size(comm);  
331   sendsize = smpi_datatype_size(sendtype) * sendcount;  
332   if (sendsize < 200 && size > 12) {
333     return
334         smpi_coll_tuned_alltoall_bruck(sendbuf, sendcount, sendtype,
335                                        recvbuf, recvcount, recvtype,
336                                        comm);
337   } else if (sendsize < 3000) {
338     return
339         smpi_coll_tuned_alltoall_basic_linear(sendbuf, sendcount,
340                                               sendtype, recvbuf,
341                                               recvcount, recvtype, comm);
342   } else {
343     return
344         smpi_coll_tuned_alltoall_ring(sendbuf, sendcount, sendtype,
345                                       recvbuf, recvcount, recvtype,
346                                       comm);
347   }
348 }
349
350 /**
351  * Alltoall Bruck
352  *
353  * Openmpi calls this routine when the message size sent to each rank < 2000 bytes and size < 12
354  * FIXME: uh, check smpi_pmpi again, but this routine is called for > 12, not
355  * less...
356  **/
357 int smpi_coll_tuned_alltoall_bruck(void *sendbuf, int sendcount,
358                                    MPI_Datatype sendtype, void *recvbuf,
359                                    int recvcount, MPI_Datatype recvtype,
360                                    MPI_Comm comm)
361 {
362   int system_tag = 777;
363   int i, rank, size, err, count;
364   MPI_Aint lb;
365   MPI_Aint sendext = 0;
366   MPI_Aint recvext = 0;
367   MPI_Request *requests;
368
369   // FIXME: check implementation
370   rank = smpi_comm_rank(comm);
371   size = smpi_comm_size(comm);
372   XBT_DEBUG("<%d> algorithm alltoall_bruck() called.", rank);
373   err = smpi_datatype_extent(sendtype, &lb, &sendext);
374   err = smpi_datatype_extent(recvtype, &lb, &recvext);
375   /* Local copy from self */
376   err =
377       smpi_datatype_copy((char *)sendbuf + rank * sendcount * sendext, 
378                          sendcount, sendtype, 
379                          (char *)recvbuf + rank * recvcount * recvext,
380                          recvcount, recvtype);
381   if (err == MPI_SUCCESS && size > 1) {
382     /* Initiate all send/recv to/from others. */
383     requests = xbt_new(MPI_Request, 2 * (size - 1));
384     count = 0;
385     /* Create all receives that will be posted first */
386     for (i = 0; i < size; ++i) {
387       if (i == rank) {
388         XBT_DEBUG("<%d> skip request creation [src = %d, recvcount = %d]",
389                rank, i, recvcount);
390         continue;
391       }
392       requests[count] =
393           smpi_irecv_init((char *)recvbuf + i * recvcount * recvext, recvcount,
394                           recvtype, i, system_tag, comm);
395       count++;
396     }
397     /* Now create all sends  */
398     for (i = 0; i < size; ++i) {
399       if (i == rank) {
400         XBT_DEBUG("<%d> skip request creation [dst = %d, sendcount = %d]",
401                rank, i, sendcount);
402         continue;
403       }
404       requests[count] =
405           smpi_isend_init((char *)sendbuf + i * sendcount * sendext, sendcount,
406                           sendtype, i, system_tag, comm);
407       count++;
408     }
409     /* Wait for them all. */
410     smpi_mpi_startall(count, requests);
411     XBT_DEBUG("<%d> wait for %d requests", rank, count);
412     smpi_mpi_waitall(count, requests, MPI_STATUS_IGNORE);
413     xbt_free(requests);
414   }
415   return MPI_SUCCESS;
416 }
417
418 /**
419  * Alltoall basic_linear (STARMPI:alltoall-simple)
420  **/
421 int smpi_coll_tuned_alltoall_basic_linear(void *sendbuf, int sendcount,
422                                           MPI_Datatype sendtype,
423                                           void *recvbuf, int recvcount,
424                                           MPI_Datatype recvtype,
425                                           MPI_Comm comm)
426 {
427   int system_tag = 888;
428   int i, rank, size, err, count;
429   MPI_Aint lb = 0, sendext = 0, recvext = 0;
430   MPI_Request *requests;
431
432   /* Initialize. */
433   rank = smpi_comm_rank(comm);
434   size = smpi_comm_size(comm);
435   XBT_DEBUG("<%d> algorithm alltoall_basic_linear() called.", rank);
436   err = smpi_datatype_extent(sendtype, &lb, &sendext);
437   err = smpi_datatype_extent(recvtype, &lb, &recvext);
438   /* simple optimization */
439   err = smpi_datatype_copy((char *)sendbuf + rank * sendcount * sendext, 
440                            sendcount, sendtype, 
441                            (char *)recvbuf + rank * recvcount * recvext, 
442                            recvcount, recvtype);
443   if (err == MPI_SUCCESS && size > 1) {
444     /* Initiate all send/recv to/from others. */
445     requests = xbt_new(MPI_Request, 2 * (size - 1));
446     /* Post all receives first -- a simple optimization */
447     count = 0;
448     for (i = (rank + 1) % size; i != rank; i = (i + 1) % size) {
449       requests[count] =
450           smpi_irecv_init((char *)recvbuf + i * recvcount * recvext, recvcount, 
451                           recvtype, i, system_tag, comm);
452       count++;
453     }
454     /* Now post all sends in reverse order
455      *   - We would like to minimize the search time through message queue
456      *     when messages actually arrive in the order in which they were posted.
457      * TODO: check the previous assertion
458      */
459     for (i = (rank + size - 1) % size; i != rank; i = (i + size - 1) % size) {
460       requests[count] =
461           smpi_isend_init((char *)sendbuf + i * sendcount * sendext, sendcount,
462                           sendtype, i, system_tag, comm);
463       count++;
464     }
465     /* Wait for them all. */
466     smpi_mpi_startall(count, requests);
467     XBT_DEBUG("<%d> wait for %d requests", rank, count);
468     smpi_mpi_waitall(count, requests, MPI_STATUS_IGNORE);
469     xbt_free(requests);
470   }
471   return err;
472 }
473
474 int smpi_coll_basic_alltoallv(void *sendbuf, int *sendcounts,
475                               int *senddisps, MPI_Datatype sendtype,
476                               void *recvbuf, int *recvcounts,
477                               int *recvdisps, MPI_Datatype recvtype,
478                               MPI_Comm comm)
479 {
480   int system_tag = 889;
481   int i, rank, size, err, count;
482   MPI_Aint lb = 0, sendext = 0, recvext = 0;
483   MPI_Request *requests;
484
485   /* Initialize. */
486   rank = smpi_comm_rank(comm);
487   size = smpi_comm_size(comm);
488   XBT_DEBUG("<%d> algorithm basic_alltoallv() called.", rank);
489   err = smpi_datatype_extent(sendtype, &lb, &sendext);
490   err = smpi_datatype_extent(recvtype, &lb, &recvext);
491   /* Local copy from self */
492   err =
493       smpi_datatype_copy((char *)sendbuf + senddisps[rank] * sendext, 
494                          sendcounts[rank], sendtype,
495                          (char *)recvbuf + recvdisps[rank] * recvext, 
496                          recvcounts[rank], recvtype);
497   if (err == MPI_SUCCESS && size > 1) {
498     /* Initiate all send/recv to/from others. */
499     requests = xbt_new(MPI_Request, 2 * (size - 1));
500     count = 0;
501     /* Create all receives that will be posted first */
502     for (i = 0; i < size; ++i) {
503       if (i == rank || recvcounts[i] == 0) {
504         XBT_DEBUG
505             ("<%d> skip request creation [src = %d, recvcounts[src] = %d]",
506              rank, i, recvcounts[i]);
507         continue;
508       }
509       requests[count] =
510           smpi_irecv_init((char *)recvbuf + recvdisps[i] * recvext, 
511                           recvcounts[i], recvtype, i, system_tag, comm);
512       count++;
513     }
514     /* Now create all sends  */
515     for (i = 0; i < size; ++i) {
516       if (i == rank || sendcounts[i] == 0) {
517         XBT_DEBUG
518             ("<%d> skip request creation [dst = %d, sendcounts[dst] = %d]",
519              rank, i, sendcounts[i]);
520         continue;
521       }
522       requests[count] =
523           smpi_isend_init((char *)sendbuf + senddisps[i] * sendext, 
524                           sendcounts[i], sendtype, i, system_tag, comm);
525       count++;
526     }
527     /* Wait for them all. */
528     smpi_mpi_startall(count, requests);
529     XBT_DEBUG("<%d> wait for %d requests", rank, count);
530     smpi_mpi_waitall(count, requests, MPI_STATUS_IGNORE);
531     xbt_free(requests);
532   }
533   return err;
534 }