Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
hide this from users
[simgrid.git] / src / smpi / smpi_coll.cpp
1 /* smpi_coll.c -- various optimized routing for collectives                   */
2
3 /* Copyright (c) 2009-2015. The SimGrid Team.
4  * All rights reserved.                                                     */
5
6 /* This program is free software; you can redistribute it and/or modify it
7  * under the terms of the license (GNU LGPL) which comes with this package. */
8
9 #include <stdio.h>
10 #include <string.h>
11 #include <assert.h>
12
13 #include "private.h"
14 #include "simgrid/sg_config.h"
15
16 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI (coll)");
17
18
19 namespace simgrid{
20 namespace smpi{
21
22 void (*Colls::smpi_coll_cleanup_callback)();
23
24 /* these arrays must be nullptr terminated */
25 s_mpi_coll_description_t Colls::mpi_coll_gather_description[] = {
26     COLL_GATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
27 s_mpi_coll_description_t Colls::mpi_coll_allgather_description[] = {
28     COLL_ALLGATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
29 s_mpi_coll_description_t Colls::mpi_coll_allgatherv_description[] = {
30     COLL_ALLGATHERVS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
31 s_mpi_coll_description_t Colls::mpi_coll_allreduce_description[] ={
32     COLL_ALLREDUCES(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
33 s_mpi_coll_description_t Colls::mpi_coll_reduce_scatter_description[] = {
34     COLL_REDUCE_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
35 s_mpi_coll_description_t Colls::mpi_coll_scatter_description[] ={
36     COLL_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
37 s_mpi_coll_description_t Colls::mpi_coll_barrier_description[] ={
38     COLL_BARRIERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
39 s_mpi_coll_description_t Colls::mpi_coll_alltoall_description[] = {
40     COLL_ALLTOALLS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
41 s_mpi_coll_description_t Colls::mpi_coll_alltoallv_description[] = {
42     COLL_ALLTOALLVS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
43 s_mpi_coll_description_t Colls::mpi_coll_bcast_description[] = {
44     COLL_BCASTS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
45 s_mpi_coll_description_t Colls::mpi_coll_reduce_description[] = {
46     COLL_REDUCES(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
47
48 /** Displays the long description of all registered models, and quit */
49 void Colls::coll_help(const char *category, s_mpi_coll_description_t * table)
50 {
51   printf("Long description of the %s models accepted by this simulator:\n", category);
52   for (int i = 0; table[i].name; i++)
53     printf("  %s: %s\n", table[i].name, table[i].description);
54 }
55
56 int Colls::find_coll_description(s_mpi_coll_description_t * table, const char *name, const char *desc)
57 {
58   char *name_list = nullptr;
59   for (int i = 0; table[i].name; i++)
60     if (!strcmp(name, table[i].name)) {
61       if (strcmp(table[i].name,"default"))
62         XBT_INFO("Switch to algorithm %s for collective %s",table[i].name,desc);
63       return i;
64     }
65
66   if (!table[0].name)
67     xbt_die("No collective is valid for '%s'! This is a bug.",name);
68   name_list = xbt_strdup(table[0].name);
69   for (int i = 1; table[i].name; i++) {
70     name_list = static_cast<char*>(xbt_realloc(name_list, strlen(name_list) + strlen(table[i].name) + 3));
71     strncat(name_list, ", ",2);
72     strncat(name_list, table[i].name, strlen(table[i].name));
73   }
74   xbt_die("Collective '%s' is invalid! Valid collectives are: %s.", name, name_list);
75   return -1;
76 }
77
78
79
80 #define COLL_SETTER(cat, ret, args, args2)\
81 int (*Colls::cat ) args;\
82 void Colls::set_##cat (const char * name){\
83     int id = find_coll_description(mpi_coll_## cat ##_description,\
84                                              name,#cat);\
85     cat = reinterpret_cast<ret (*) args>\
86         (mpi_coll_## cat ##_description[id].coll);\
87 }
88
89 COLL_APPLY(COLL_SETTER,COLL_GATHER_SIG,"");
90 COLL_APPLY(COLL_SETTER,COLL_ALLGATHER_SIG,"");
91 COLL_APPLY(COLL_SETTER,COLL_ALLGATHERV_SIG,"");
92 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SIG,"");
93 COLL_APPLY(COLL_SETTER,COLL_ALLREDUCE_SIG,"");
94 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SCATTER_SIG,"");
95 COLL_APPLY(COLL_SETTER,COLL_SCATTER_SIG,"");
96 COLL_APPLY(COLL_SETTER,COLL_BARRIER_SIG,"");
97 COLL_APPLY(COLL_SETTER,COLL_BCAST_SIG,"");
98 COLL_APPLY(COLL_SETTER,COLL_ALLTOALL_SIG,"");
99 COLL_APPLY(COLL_SETTER,COLL_ALLTOALLV_SIG,"");
100
101
102 void Colls::set_collectives(){
103     const char* selector_name = static_cast<char*>(xbt_cfg_get_string("smpi/coll-selector"));
104     if (selector_name==nullptr || selector_name[0] == '\0')
105         selector_name = "default";
106
107     const char* name = xbt_cfg_get_string("smpi/gather");
108     if (name==nullptr || name[0] == '\0')
109         name = selector_name;
110       
111     set_gather(name);
112
113     name = xbt_cfg_get_string("smpi/allgather");
114     if (name==nullptr || name[0] == '\0')
115         name = selector_name;
116
117     set_allgather(name);
118
119     name = xbt_cfg_get_string("smpi/allgatherv");
120     if (name==nullptr || name[0] == '\0')
121         name = selector_name;
122
123     set_allgatherv(name);
124
125     name = xbt_cfg_get_string("smpi/allreduce");
126     if (name==nullptr || name[0] == '\0')
127         name = selector_name;
128
129     set_allreduce(name);
130
131     name = xbt_cfg_get_string("smpi/alltoall");
132     if (name==nullptr || name[0] == '\0')
133         name = selector_name;
134
135     set_alltoall(name);
136
137     name = xbt_cfg_get_string("smpi/alltoallv");
138     if (name==nullptr || name[0] == '\0')
139         name = selector_name;
140
141     set_alltoallv(name);
142
143     name = xbt_cfg_get_string("smpi/reduce");
144     if (name==nullptr || name[0] == '\0')
145         name = selector_name;
146
147     set_reduce(name);
148
149     name = xbt_cfg_get_string("smpi/reduce-scatter");
150     if (name==nullptr || name[0] == '\0')
151         name = selector_name;
152
153     set_reduce_scatter(name);
154
155     name = xbt_cfg_get_string("smpi/scatter");
156     if (name==nullptr || name[0] == '\0')
157         name = selector_name;
158
159     set_scatter(name);
160
161     name = xbt_cfg_get_string("smpi/bcast");
162     if (name==nullptr || name[0] == '\0')
163         name = selector_name;
164
165     set_bcast(name);
166
167     name = xbt_cfg_get_string("smpi/barrier");
168     if (name==nullptr || name[0] == '\0')
169         name = selector_name;
170
171     set_barrier(name);
172 }
173
174
175 //Implementations of the single algorith collectives
176
177 int Colls::gatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *displs,
178                       MPI_Datatype recvtype, int root, MPI_Comm comm)
179 {
180   int system_tag = COLL_TAG_GATHERV;
181   MPI_Aint lb = 0;
182   MPI_Aint recvext = 0;
183
184   int rank = comm->rank();
185   int size = comm->size();
186   if (rank != root) {
187     // Send buffer to root
188     Request::send(sendbuf, sendcount, sendtype, root, system_tag, comm);
189   } else {
190     recvtype->extent(&lb, &recvext);
191     // Local copy from root
192     Datatype::copy(sendbuf, sendcount, sendtype, static_cast<char*>(recvbuf) + displs[root] * recvext,
193                        recvcounts[root], recvtype);
194     // Receive buffers from senders
195     MPI_Request *requests = xbt_new(MPI_Request, size - 1);
196     int index = 0;
197     for (int src = 0; src < size; src++) {
198       if(src != root) {
199         requests[index] = Request::irecv_init(static_cast<char*>(recvbuf) + displs[src] * recvext,
200                           recvcounts[src], recvtype, src, system_tag, comm);
201         index++;
202       }
203     }
204     // Wait for completion of irecv's.
205     Request::startall(size - 1, requests);
206     Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
207     for (int src = 0; src < size-1; src++) {
208       Request::unref(&requests[src]);
209     }
210     xbt_free(requests);
211   }
212   return MPI_SUCCESS;
213 }
214
215
216 int Colls::scatterv(void *sendbuf, int *sendcounts, int *displs, MPI_Datatype sendtype, void *recvbuf, int recvcount,
217                        MPI_Datatype recvtype, int root, MPI_Comm comm)
218 {
219   int system_tag = COLL_TAG_SCATTERV;
220   MPI_Aint lb = 0;
221   MPI_Aint sendext = 0;
222
223   int rank = comm->rank();
224   int size = comm->size();
225   if(rank != root) {
226     // Recv buffer from root
227     Request::recv(recvbuf, recvcount, recvtype, root, system_tag, comm, MPI_STATUS_IGNORE);
228   } else {
229     sendtype->extent(&lb, &sendext);
230     // Local copy from root
231     if(recvbuf!=MPI_IN_PLACE){
232       Datatype::copy(static_cast<char *>(sendbuf) + displs[root] * sendext, sendcounts[root],
233                        sendtype, recvbuf, recvcount, recvtype);
234     }
235     // Send buffers to receivers
236     MPI_Request *requests = xbt_new(MPI_Request, size - 1);
237     int index = 0;
238     for (int dst = 0; dst < size; dst++) {
239       if (dst != root) {
240         requests[index] = Request::isend_init(static_cast<char *>(sendbuf) + displs[dst] * sendext, sendcounts[dst],
241                             sendtype, dst, system_tag, comm);
242         index++;
243       }
244     }
245     // Wait for completion of isend's.
246     Request::startall(size - 1, requests);
247     Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
248     for (int dst = 0; dst < size-1; dst++) {
249       Request::unref(&requests[dst]);
250     }
251     xbt_free(requests);
252   }
253   return MPI_SUCCESS;
254 }
255
256
257 int Colls::scan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
258 {
259   int system_tag = -888;
260   MPI_Aint lb      = 0;
261   MPI_Aint dataext = 0;
262
263   int rank = comm->rank();
264   int size = comm->size();
265
266   datatype->extent(&lb, &dataext);
267
268   // Local copy from self
269   Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);
270
271   // Send/Recv buffers to/from others;
272   MPI_Request *requests = xbt_new(MPI_Request, size - 1);
273   void **tmpbufs = xbt_new(void *, rank);
274   int index = 0;
275   for (int other = 0; other < rank; other++) {
276     tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
277     requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
278     index++;
279   }
280   for (int other = rank + 1; other < size; other++) {
281     requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
282     index++;
283   }
284   // Wait for completion of all comms.
285   Request::startall(size - 1, requests);
286
287   if(op != MPI_OP_NULL && op->is_commutative()){
288     for (int other = 0; other < size - 1; other++) {
289       index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
290       if(index == MPI_UNDEFINED) {
291         break;
292       }
293       if(index < rank) {
294         // #Request is below rank: it's a irecv
295         if(op!=MPI_OP_NULL) op->apply( tmpbufs[index], recvbuf, &count, datatype);
296       }
297     }
298   }else{
299     //non commutative case, wait in order
300     for (int other = 0; other < size - 1; other++) {
301       Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
302       if(index < rank) {
303         if(op!=MPI_OP_NULL) op->apply( tmpbufs[other], recvbuf, &count, datatype);
304       }
305     }
306   }
307   for(index = 0; index < rank; index++) {
308     smpi_free_tmp_buffer(tmpbufs[index]);
309   }
310   for(index = 0; index < size-1; index++) {
311     Request::unref(&requests[index]);
312   }
313   xbt_free(tmpbufs);
314   xbt_free(requests);
315   return MPI_SUCCESS;
316 }
317
318 int Colls::exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
319 {
320   int system_tag = -888;
321   MPI_Aint lb         = 0;
322   MPI_Aint dataext    = 0;
323   int recvbuf_is_empty=1;
324   int rank = comm->rank();
325   int size = comm->size();
326
327   datatype->extent(&lb, &dataext);
328
329   // Send/Recv buffers to/from others;
330   MPI_Request *requests = xbt_new(MPI_Request, size - 1);
331   void **tmpbufs = xbt_new(void *, rank);
332   int index = 0;
333   for (int other = 0; other < rank; other++) {
334     tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
335     requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
336     index++;
337   }
338   for (int other = rank + 1; other < size; other++) {
339     requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
340     index++;
341   }
342   // Wait for completion of all comms.
343   Request::startall(size - 1, requests);
344
345   if(op != MPI_OP_NULL && op->is_commutative()){
346     for (int other = 0; other < size - 1; other++) {
347       index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
348       if(index == MPI_UNDEFINED) {
349         break;
350       }
351       if(index < rank) {
352         if(recvbuf_is_empty){
353           Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
354           recvbuf_is_empty=0;
355         } else
356           // #Request is below rank: it's a irecv
357           if(op!=MPI_OP_NULL) op->apply( tmpbufs[index], recvbuf, &count, datatype);
358       }
359     }
360   }else{
361     //non commutative case, wait in order
362     for (int other = 0; other < size - 1; other++) {
363      Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
364       if(index < rank) {
365         if (recvbuf_is_empty) {
366           Datatype::copy(tmpbufs[other], count, datatype, recvbuf, count, datatype);
367           recvbuf_is_empty = 0;
368         } else
369           if(op!=MPI_OP_NULL) op->apply( tmpbufs[other], recvbuf, &count, datatype);
370       }
371     }
372   }
373   for(index = 0; index < rank; index++) {
374     smpi_free_tmp_buffer(tmpbufs[index]);
375   }
376   for(index = 0; index < size-1; index++) {
377     Request::unref(&requests[index]);
378   }
379   xbt_free(tmpbufs);
380   xbt_free(requests);
381   return MPI_SUCCESS;
382 }
383
384 }
385 }
386
387
388
389
390