Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
37a3612e437efe4a77ea23eac7511c157689b3d9
[simgrid.git] / src / smpi / smpi_coll.cpp
1 /* smpi_coll.c -- various optimized routing for collectives                   */
2
3 /* Copyright (c) 2009-2015. The SimGrid Team.
4  * All rights reserved.                                                     */
5
6 /* This program is free software; you can redistribute it and/or modify it
7  * under the terms of the license (GNU LGPL) which comes with this package. */
8
9 #include <stdio.h>
10 #include <string.h>
11 #include <assert.h>
12
13 #include "private.h"
14 #include "colls/colls.h"
15 #include "simgrid/sg_config.h"
16
17 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI (coll)");
18
19 s_mpi_coll_description_t mpi_coll_gather_description[] = {
20    COLL_GATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr}  /* this array must be nullptr terminated */
21 };
22
23 s_mpi_coll_description_t mpi_coll_allgather_description[] = {
24    COLL_ALLGATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr}
25 };
26
27 s_mpi_coll_description_t mpi_coll_allgatherv_description[] = { COLL_ALLGATHERVS(COLL_DESCRIPTION, COLL_COMMA),
28    {nullptr, nullptr, nullptr}      /* this array must be nullptr terminated */
29 };
30
31 s_mpi_coll_description_t mpi_coll_allreduce_description[] ={ COLL_ALLREDUCES(COLL_DESCRIPTION, COLL_COMMA),
32   {nullptr, nullptr, nullptr}      /* this array must be nullptr terminated */
33 };
34
35 s_mpi_coll_description_t mpi_coll_reduce_scatter_description[] = {COLL_REDUCE_SCATTERS(COLL_DESCRIPTION, COLL_COMMA),
36     {nullptr, nullptr, nullptr}      /* this array must be nullptr terminated */
37 };
38
39 s_mpi_coll_description_t mpi_coll_scatter_description[] ={COLL_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr}};
40
41 s_mpi_coll_description_t mpi_coll_barrier_description[] ={COLL_BARRIERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr}};
42
43 s_mpi_coll_description_t mpi_coll_alltoall_description[] = {COLL_ALLTOALLS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr}};
44
45 s_mpi_coll_description_t mpi_coll_alltoallv_description[] = {COLL_ALLTOALLVS(COLL_DESCRIPTION, COLL_COMMA),
46    {nullptr, nullptr, nullptr}      /* this array must be nullptr terminated */
47 };
48
49 s_mpi_coll_description_t mpi_coll_bcast_description[] = {COLL_BCASTS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr}};
50
51 s_mpi_coll_description_t mpi_coll_reduce_description[] = {COLL_REDUCES(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
52
53
54
55 /** Displays the long description of all registered models, and quit */
56 void coll_help(const char *category, s_mpi_coll_description_t * table)
57 {
58   printf("Long description of the %s models accepted by this simulator:\n", category);
59   for (int i = 0; table[i].name; i++)
60     printf("  %s: %s\n", table[i].name, table[i].description);
61 }
62
63 int find_coll_description(s_mpi_coll_description_t * table, const char *name, const char *desc)
64 {
65   char *name_list = nullptr;
66   int selector_on=0;
67   if (name==nullptr || name[0] == '\0') {
68     //no argument provided, use active selector's algorithm
69     name=static_cast<char*>(xbt_cfg_get_string("smpi/coll-selector"));
70     selector_on=1;
71   }
72   for (int i = 0; table[i].name; i++)
73     if (!strcmp(name, table[i].name)) {
74       if (strcmp(table[i].name,"default"))
75         XBT_INFO("Switch to algorithm %s for collective %s",table[i].name,desc);
76       return i;
77     }
78
79   if(selector_on){
80     // collective seems not handled by the active selector, try with default one
81     for (int i = 0; table[i].name; i++)
82       if (!strcmp("default", table[i].name)) {
83         return i;
84     }
85   }
86   if (!table[0].name)
87     xbt_die("No collective is valid for '%s'! This is a bug.",name);
88   name_list = xbt_strdup(table[0].name);
89   for (int i = 1; table[i].name; i++) {
90     name_list = static_cast<char*>(xbt_realloc(name_list, strlen(name_list) + strlen(table[i].name) + 3));
91     strncat(name_list, ", ",2);
92     strncat(name_list, table[i].name, strlen(table[i].name));
93   }
94   xbt_die("Collective '%s' is invalid! Valid collectives are: %s.", name, name_list);
95   return -1;
96 }
97
98 void (*smpi_coll_cleanup_callback)();
99
100 namespace simgrid{
101 namespace smpi{
102
103 int (*Colls::gather)(void *, int, MPI_Datatype, void*, int, MPI_Datatype, int root, MPI_Comm);
104 int (*Colls::allgather)(void *, int, MPI_Datatype, void*, int, MPI_Datatype, MPI_Comm);
105 int (*Colls::allgatherv)(void *, int, MPI_Datatype, void*, int*, int*, MPI_Datatype, MPI_Comm);
106 int (*Colls::allreduce)(void *sbuf, void *rbuf, int rcount, MPI_Datatype dtype, MPI_Op op, MPI_Comm comm);
107 int (*Colls::alltoall)(void *, int, MPI_Datatype, void*, int, MPI_Datatype, MPI_Comm);
108 int (*Colls::alltoallv)(void *, int*, int*, MPI_Datatype, void*, int*, int*, MPI_Datatype, MPI_Comm);
109 int (*Colls::bcast)(void *buf, int count, MPI_Datatype datatype, int root, MPI_Comm com);
110 int (*Colls::reduce)(void *buf, void *rbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm);
111 int (*Colls::reduce_scatter)(void *sbuf, void *rbuf, int *rcounts,MPI_Datatype dtype,MPI_Op  op,MPI_Comm  comm);
112 int (*Colls::scatter)(void *sendbuf, int sendcount, MPI_Datatype sendtype,void *recvbuf, int recvcount, MPI_Datatype recvtype,int root, MPI_Comm comm);
113 int (*Colls::barrier)(MPI_Comm comm);
114
115
116 #define COLL_SETTER(cat, ret, args, args2)\
117 void Colls::set_##cat (const char * name){\
118     int id = find_coll_description(mpi_coll_## cat ##_description,\
119                                              name,#cat);\
120     cat = reinterpret_cast<ret (*) args>\
121         (mpi_coll_## cat ##_description[id].coll);\
122 }
123
124 COLL_APPLY(COLL_SETTER,COLL_GATHER_SIG,"");
125 COLL_APPLY(COLL_SETTER,COLL_ALLGATHER_SIG,"");
126 COLL_APPLY(COLL_SETTER,COLL_ALLGATHERV_SIG,"");
127 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SIG,"");
128 COLL_APPLY(COLL_SETTER,COLL_ALLREDUCE_SIG,"");
129 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SCATTER_SIG,"");
130 COLL_APPLY(COLL_SETTER,COLL_SCATTER_SIG,"");
131 COLL_APPLY(COLL_SETTER,COLL_BARRIER_SIG,"");
132 COLL_APPLY(COLL_SETTER,COLL_BCAST_SIG,"");
133 COLL_APPLY(COLL_SETTER,COLL_ALLTOALL_SIG,"");
134 COLL_APPLY(COLL_SETTER,COLL_ALLTOALLV_SIG,"");
135
136
137 void Colls::set_collectives(){
138     const char* selector_name = static_cast<char*>(xbt_cfg_get_string("smpi/coll-selector"));
139     if (selector_name==nullptr || selector_name[0] == '\0')
140         selector_name = "default";
141
142     const char* name = xbt_cfg_get_string("smpi/gather");
143     if (name==nullptr || name[0] == '\0')
144         name = selector_name;
145       
146     set_gather(name);
147
148     name = xbt_cfg_get_string("smpi/allgather");
149     if (name==nullptr || name[0] == '\0')
150         name = selector_name;
151
152     set_allgather(name);
153
154     name = xbt_cfg_get_string("smpi/allgatherv");
155     if (name==nullptr || name[0] == '\0')
156         name = selector_name;
157
158     set_allgatherv(name);
159
160     name = xbt_cfg_get_string("smpi/allreduce");
161     if (name==nullptr || name[0] == '\0')
162         name = selector_name;
163
164     set_allreduce(name);
165
166     name = xbt_cfg_get_string("smpi/alltoall");
167     if (name==nullptr || name[0] == '\0')
168         name = selector_name;
169
170     set_alltoall(name);
171
172     name = xbt_cfg_get_string("smpi/alltoallv");
173     if (name==nullptr || name[0] == '\0')
174         name = selector_name;
175
176     set_alltoallv(name);
177
178     name = xbt_cfg_get_string("smpi/reduce");
179     if (name==nullptr || name[0] == '\0')
180         name = selector_name;
181
182     set_reduce(name);
183
184     name = xbt_cfg_get_string("smpi/reduce-scatter");
185     if (name==nullptr || name[0] == '\0')
186         name = selector_name;
187
188     set_reduce_scatter(name);
189
190     name = xbt_cfg_get_string("smpi/scatter");
191     if (name==nullptr || name[0] == '\0')
192         name = selector_name;
193
194     set_scatter(name);
195
196     name = xbt_cfg_get_string("smpi/bcast");
197     if (name==nullptr || name[0] == '\0')
198         name = selector_name;
199
200     set_bcast(name);
201
202     name = xbt_cfg_get_string("smpi/barrier");
203     if (name==nullptr || name[0] == '\0')
204         name = selector_name;
205
206     set_barrier(name);
207 }
208
209
210 int Colls::gatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *displs,
211                       MPI_Datatype recvtype, int root, MPI_Comm comm)
212 {
213   int system_tag = COLL_TAG_GATHERV;
214   MPI_Aint lb = 0;
215   MPI_Aint recvext = 0;
216
217   int rank = comm->rank();
218   int size = comm->size();
219   if (rank != root) {
220     // Send buffer to root
221     Request::send(sendbuf, sendcount, sendtype, root, system_tag, comm);
222   } else {
223     recvtype->extent(&lb, &recvext);
224     // Local copy from root
225     Datatype::copy(sendbuf, sendcount, sendtype, static_cast<char*>(recvbuf) + displs[root] * recvext,
226                        recvcounts[root], recvtype);
227     // Receive buffers from senders
228     MPI_Request *requests = xbt_new(MPI_Request, size - 1);
229     int index = 0;
230     for (int src = 0; src < size; src++) {
231       if(src != root) {
232         requests[index] = Request::irecv_init(static_cast<char*>(recvbuf) + displs[src] * recvext,
233                           recvcounts[src], recvtype, src, system_tag, comm);
234         index++;
235       }
236     }
237     // Wait for completion of irecv's.
238     Request::startall(size - 1, requests);
239     Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
240     for (int src = 0; src < size-1; src++) {
241       Request::unref(&requests[src]);
242     }
243     xbt_free(requests);
244   }
245   return MPI_SUCCESS;
246 }
247
248
249 int Colls::scatterv(void *sendbuf, int *sendcounts, int *displs, MPI_Datatype sendtype, void *recvbuf, int recvcount,
250                        MPI_Datatype recvtype, int root, MPI_Comm comm)
251 {
252   int system_tag = COLL_TAG_SCATTERV;
253   MPI_Aint lb = 0;
254   MPI_Aint sendext = 0;
255
256   int rank = comm->rank();
257   int size = comm->size();
258   if(rank != root) {
259     // Recv buffer from root
260     Request::recv(recvbuf, recvcount, recvtype, root, system_tag, comm, MPI_STATUS_IGNORE);
261   } else {
262     sendtype->extent(&lb, &sendext);
263     // Local copy from root
264     if(recvbuf!=MPI_IN_PLACE){
265       Datatype::copy(static_cast<char *>(sendbuf) + displs[root] * sendext, sendcounts[root],
266                        sendtype, recvbuf, recvcount, recvtype);
267     }
268     // Send buffers to receivers
269     MPI_Request *requests = xbt_new(MPI_Request, size - 1);
270     int index = 0;
271     for (int dst = 0; dst < size; dst++) {
272       if (dst != root) {
273         requests[index] = Request::isend_init(static_cast<char *>(sendbuf) + displs[dst] * sendext, sendcounts[dst],
274                             sendtype, dst, system_tag, comm);
275         index++;
276       }
277     }
278     // Wait for completion of isend's.
279     Request::startall(size - 1, requests);
280     Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
281     for (int dst = 0; dst < size-1; dst++) {
282       Request::unref(&requests[dst]);
283     }
284     xbt_free(requests);
285   }
286   return MPI_SUCCESS;
287 }
288
289
290 int Colls::scan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
291 {
292   int system_tag = -888;
293   MPI_Aint lb      = 0;
294   MPI_Aint dataext = 0;
295
296   int rank = comm->rank();
297   int size = comm->size();
298
299   datatype->extent(&lb, &dataext);
300
301   // Local copy from self
302   Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);
303
304   // Send/Recv buffers to/from others;
305   MPI_Request *requests = xbt_new(MPI_Request, size - 1);
306   void **tmpbufs = xbt_new(void *, rank);
307   int index = 0;
308   for (int other = 0; other < rank; other++) {
309     tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
310     requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
311     index++;
312   }
313   for (int other = rank + 1; other < size; other++) {
314     requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
315     index++;
316   }
317   // Wait for completion of all comms.
318   Request::startall(size - 1, requests);
319
320   if(op != MPI_OP_NULL && op->is_commutative()){
321     for (int other = 0; other < size - 1; other++) {
322       index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
323       if(index == MPI_UNDEFINED) {
324         break;
325       }
326       if(index < rank) {
327         // #Request is below rank: it's a irecv
328         if(op!=MPI_OP_NULL) op->apply( tmpbufs[index], recvbuf, &count, datatype);
329       }
330     }
331   }else{
332     //non commutative case, wait in order
333     for (int other = 0; other < size - 1; other++) {
334       Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
335       if(index < rank) {
336         if(op!=MPI_OP_NULL) op->apply( tmpbufs[other], recvbuf, &count, datatype);
337       }
338     }
339   }
340   for(index = 0; index < rank; index++) {
341     smpi_free_tmp_buffer(tmpbufs[index]);
342   }
343   for(index = 0; index < size-1; index++) {
344     Request::unref(&requests[index]);
345   }
346   xbt_free(tmpbufs);
347   xbt_free(requests);
348   return MPI_SUCCESS;
349 }
350
351 int Colls::exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
352 {
353   int system_tag = -888;
354   MPI_Aint lb         = 0;
355   MPI_Aint dataext    = 0;
356   int recvbuf_is_empty=1;
357   int rank = comm->rank();
358   int size = comm->size();
359
360   datatype->extent(&lb, &dataext);
361
362   // Send/Recv buffers to/from others;
363   MPI_Request *requests = xbt_new(MPI_Request, size - 1);
364   void **tmpbufs = xbt_new(void *, rank);
365   int index = 0;
366   for (int other = 0; other < rank; other++) {
367     tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
368     requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
369     index++;
370   }
371   for (int other = rank + 1; other < size; other++) {
372     requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
373     index++;
374   }
375   // Wait for completion of all comms.
376   Request::startall(size - 1, requests);
377
378   if(op != MPI_OP_NULL && op->is_commutative()){
379     for (int other = 0; other < size - 1; other++) {
380       index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
381       if(index == MPI_UNDEFINED) {
382         break;
383       }
384       if(index < rank) {
385         if(recvbuf_is_empty){
386           Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
387           recvbuf_is_empty=0;
388         } else
389           // #Request is below rank: it's a irecv
390           if(op!=MPI_OP_NULL) op->apply( tmpbufs[index], recvbuf, &count, datatype);
391       }
392     }
393   }else{
394     //non commutative case, wait in order
395     for (int other = 0; other < size - 1; other++) {
396      Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
397       if(index < rank) {
398         if (recvbuf_is_empty) {
399           Datatype::copy(tmpbufs[other], count, datatype, recvbuf, count, datatype);
400           recvbuf_is_empty = 0;
401         } else
402           if(op!=MPI_OP_NULL) op->apply( tmpbufs[other], recvbuf, &count, datatype);
403       }
404     }
405   }
406   for(index = 0; index < rank; index++) {
407     smpi_free_tmp_buffer(tmpbufs[index]);
408   }
409   for(index = 0; index < size-1; index++) {
410     Request::unref(&requests[index]);
411   }
412   xbt_free(tmpbufs);
413   xbt_free(requests);
414   return MPI_SUCCESS;
415 }
416
417 }
418 }
419
420
421
422
423