Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
1e89798c6ed355d4bcdb3fe67e0b1f35085964f8
[simgrid.git] / src / smpi / colls / smpi_coll.cpp
1 /* smpi_coll.c -- various optimized routing for collectives                 */
2
3 /* Copyright (c) 2009-2017. The SimGrid Team. All rights reserved.          */
4
5 /* This program is free software; you can redistribute it and/or modify it
6  * under the terms of the license (GNU LGPL) which comes with this package. */
7
8 #include "private.hpp"
9 #include "smpi_coll.hpp"
10 #include "smpi_comm.hpp"
11 #include "smpi_datatype.hpp"
12 #include "smpi_op.hpp"
13 #include "smpi_request.hpp"
14
15 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI (coll)");
16
17 #define COLL_SETTER(cat, ret, args, args2)\
18 int (*Colls::cat ) args;\
19 void Colls::set_##cat (const char * name){\
20     int id = find_coll_description(mpi_coll_## cat ##_description,\
21                                              name,#cat);\
22     cat = reinterpret_cast<ret (*) args>\
23         (mpi_coll_## cat ##_description[id].coll);\
24     if (cat == nullptr)\
25       xbt_die("Collective "#cat" set to nullptr!");\
26 }
27
28 #define SET_COLL(coll)\
29     name = xbt_cfg_get_string("smpi/"#coll);\
30     if (name==nullptr || name[0] == '\0')\
31         name = selector_name;\
32     set_##coll(name);
33
34 namespace simgrid{
35 namespace smpi{
36
37 void (*Colls::smpi_coll_cleanup_callback)();
38
39 /* these arrays must be nullptr terminated */
40 s_mpi_coll_description_t Colls::mpi_coll_gather_description[] = {
41     COLL_GATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
42 s_mpi_coll_description_t Colls::mpi_coll_allgather_description[] = {
43     COLL_ALLGATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
44 s_mpi_coll_description_t Colls::mpi_coll_allgatherv_description[] = {
45     COLL_ALLGATHERVS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
46 s_mpi_coll_description_t Colls::mpi_coll_allreduce_description[] ={
47     COLL_ALLREDUCES(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
48 s_mpi_coll_description_t Colls::mpi_coll_reduce_scatter_description[] = {
49     COLL_REDUCE_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
50 s_mpi_coll_description_t Colls::mpi_coll_scatter_description[] ={
51     COLL_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
52 s_mpi_coll_description_t Colls::mpi_coll_barrier_description[] ={
53     COLL_BARRIERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
54 s_mpi_coll_description_t Colls::mpi_coll_alltoall_description[] = {
55     COLL_ALLTOALLS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
56 s_mpi_coll_description_t Colls::mpi_coll_alltoallv_description[] = {
57     COLL_ALLTOALLVS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
58 s_mpi_coll_description_t Colls::mpi_coll_bcast_description[] = {
59     COLL_BCASTS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
60 s_mpi_coll_description_t Colls::mpi_coll_reduce_description[] = {
61     COLL_REDUCES(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
62
63 /** Displays the long description of all registered models, and quit */
64 void Colls::coll_help(const char *category, s_mpi_coll_description_t * table)
65 {
66   XBT_WARN("Long description of the %s models accepted by this simulator:\n", category);
67   for (int i = 0; table[i].name; i++)
68     XBT_WARN("  %s: %s\n", table[i].name, table[i].description);
69 }
70
71 int Colls::find_coll_description(s_mpi_coll_description_t * table, const char *name, const char *desc)
72 {
73   char *name_list = nullptr;
74   for (int i = 0; table[i].name; i++)
75     if (not strcmp(name, table[i].name)) {
76       if (strcmp(table[i].name,"default"))
77         XBT_INFO("Switch to algorithm %s for collective %s",table[i].name,desc);
78       return i;
79     }
80
81   if (not table[0].name)
82     xbt_die("No collective is valid for '%s'! This is a bug.",name);
83   name_list = xbt_strdup(table[0].name);
84   for (int i = 1; table[i].name; i++) {
85     name_list = static_cast<char*>(xbt_realloc(name_list, strlen(name_list) + strlen(table[i].name) + 3));
86     strncat(name_list, ", ",2);
87     strncat(name_list, table[i].name, strlen(table[i].name));
88   }
89   xbt_die("Collective '%s' is invalid! Valid collectives are: %s.", name, name_list);
90   return -1;
91 }
92
93
94
95 COLL_APPLY(COLL_SETTER,COLL_GATHER_SIG,"");
96 COLL_APPLY(COLL_SETTER,COLL_ALLGATHER_SIG,"");
97 COLL_APPLY(COLL_SETTER,COLL_ALLGATHERV_SIG,"");
98 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SIG,"");
99 COLL_APPLY(COLL_SETTER,COLL_ALLREDUCE_SIG,"");
100 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SCATTER_SIG,"");
101 COLL_APPLY(COLL_SETTER,COLL_SCATTER_SIG,"");
102 COLL_APPLY(COLL_SETTER,COLL_BARRIER_SIG,"");
103 COLL_APPLY(COLL_SETTER,COLL_BCAST_SIG,"");
104 COLL_APPLY(COLL_SETTER,COLL_ALLTOALL_SIG,"");
105 COLL_APPLY(COLL_SETTER,COLL_ALLTOALLV_SIG,"");
106
107
108 void Colls::set_collectives(){
109     const char* selector_name = static_cast<char*>(xbt_cfg_get_string("smpi/coll-selector"));
110     if (selector_name==nullptr || selector_name[0] == '\0')
111         selector_name = "default";
112
113     const char* name;
114
115     SET_COLL(gather);
116     SET_COLL(allgather);
117     SET_COLL(allgatherv);
118     SET_COLL(allreduce);
119     SET_COLL(alltoall);
120     SET_COLL(alltoallv);
121     SET_COLL(reduce);
122     SET_COLL(reduce_scatter);
123     SET_COLL(scatter);
124     SET_COLL(bcast);
125     SET_COLL(barrier);
126 }
127
128
129 //Implementations of the single algorith collectives
130
131 int Colls::gatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *displs,
132                       MPI_Datatype recvtype, int root, MPI_Comm comm)
133 {
134   int system_tag = COLL_TAG_GATHERV;
135   MPI_Aint lb = 0;
136   MPI_Aint recvext = 0;
137
138   int rank = comm->rank();
139   int size = comm->size();
140   if (rank != root) {
141     // Send buffer to root
142     Request::send(sendbuf, sendcount, sendtype, root, system_tag, comm);
143   } else {
144     recvtype->extent(&lb, &recvext);
145     // Local copy from root
146     Datatype::copy(sendbuf, sendcount, sendtype, static_cast<char*>(recvbuf) + displs[root] * recvext,
147                        recvcounts[root], recvtype);
148     // Receive buffers from senders
149     MPI_Request *requests = xbt_new(MPI_Request, size - 1);
150     int index = 0;
151     for (int src = 0; src < size; src++) {
152       if(src != root) {
153         requests[index] = Request::irecv_init(static_cast<char*>(recvbuf) + displs[src] * recvext,
154                           recvcounts[src], recvtype, src, system_tag, comm);
155         index++;
156       }
157     }
158     // Wait for completion of irecv's.
159     Request::startall(size - 1, requests);
160     Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
161     for (int src = 0; src < size-1; src++) {
162       Request::unref(&requests[src]);
163     }
164     xbt_free(requests);
165   }
166   return MPI_SUCCESS;
167 }
168
169
170 int Colls::scatterv(void *sendbuf, int *sendcounts, int *displs, MPI_Datatype sendtype, void *recvbuf, int recvcount,
171                        MPI_Datatype recvtype, int root, MPI_Comm comm)
172 {
173   int system_tag = COLL_TAG_SCATTERV;
174   MPI_Aint lb = 0;
175   MPI_Aint sendext = 0;
176
177   int rank = comm->rank();
178   int size = comm->size();
179   if(rank != root) {
180     // Recv buffer from root
181     Request::recv(recvbuf, recvcount, recvtype, root, system_tag, comm, MPI_STATUS_IGNORE);
182   } else {
183     sendtype->extent(&lb, &sendext);
184     // Local copy from root
185     if(recvbuf!=MPI_IN_PLACE){
186       Datatype::copy(static_cast<char *>(sendbuf) + displs[root] * sendext, sendcounts[root],
187                        sendtype, recvbuf, recvcount, recvtype);
188     }
189     // Send buffers to receivers
190     MPI_Request *requests = xbt_new(MPI_Request, size - 1);
191     int index = 0;
192     for (int dst = 0; dst < size; dst++) {
193       if (dst != root) {
194         requests[index] = Request::isend_init(static_cast<char *>(sendbuf) + displs[dst] * sendext, sendcounts[dst],
195                             sendtype, dst, system_tag, comm);
196         index++;
197       }
198     }
199     // Wait for completion of isend's.
200     Request::startall(size - 1, requests);
201     Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
202     for (int dst = 0; dst < size-1; dst++) {
203       Request::unref(&requests[dst]);
204     }
205     xbt_free(requests);
206   }
207   return MPI_SUCCESS;
208 }
209
210
211 int Colls::scan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
212 {
213   int system_tag = -888;
214   MPI_Aint lb      = 0;
215   MPI_Aint dataext = 0;
216
217   int rank = comm->rank();
218   int size = comm->size();
219
220   datatype->extent(&lb, &dataext);
221
222   // Local copy from self
223   Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);
224
225   // Send/Recv buffers to/from others
226   MPI_Request *requests = xbt_new(MPI_Request, size - 1);
227   void **tmpbufs = xbt_new(void *, rank);
228   int index = 0;
229   for (int other = 0; other < rank; other++) {
230     tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
231     requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
232     index++;
233   }
234   for (int other = rank + 1; other < size; other++) {
235     requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
236     index++;
237   }
238   // Wait for completion of all comms.
239   Request::startall(size - 1, requests);
240
241   if(op != MPI_OP_NULL && op->is_commutative()){
242     for (int other = 0; other < size - 1; other++) {
243       index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
244       if(index == MPI_UNDEFINED) {
245         break;
246       }
247       if(index < rank) {
248         // #Request is below rank: it's a irecv
249         op->apply( tmpbufs[index], recvbuf, &count, datatype);
250       }
251     }
252   }else{
253     //non commutative case, wait in order
254     for (int other = 0; other < size - 1; other++) {
255       Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
256       if(index < rank && op!=MPI_OP_NULL) {
257         op->apply( tmpbufs[other], recvbuf, &count, datatype);
258       }
259     }
260   }
261   for(index = 0; index < rank; index++) {
262     smpi_free_tmp_buffer(tmpbufs[index]);
263   }
264   for(index = 0; index < size-1; index++) {
265     Request::unref(&requests[index]);
266   }
267   xbt_free(tmpbufs);
268   xbt_free(requests);
269   return MPI_SUCCESS;
270 }
271
272 int Colls::exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
273 {
274   int system_tag = -888;
275   MPI_Aint lb         = 0;
276   MPI_Aint dataext    = 0;
277   int recvbuf_is_empty=1;
278   int rank = comm->rank();
279   int size = comm->size();
280
281   datatype->extent(&lb, &dataext);
282
283   // Send/Recv buffers to/from others
284   MPI_Request *requests = xbt_new(MPI_Request, size - 1);
285   void **tmpbufs = xbt_new(void *, rank);
286   int index = 0;
287   for (int other = 0; other < rank; other++) {
288     tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
289     requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
290     index++;
291   }
292   for (int other = rank + 1; other < size; other++) {
293     requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
294     index++;
295   }
296   // Wait for completion of all comms.
297   Request::startall(size - 1, requests);
298
299   if(op != MPI_OP_NULL && op->is_commutative()){
300     for (int other = 0; other < size - 1; other++) {
301       index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
302       if(index == MPI_UNDEFINED) {
303         break;
304       }
305       if(index < rank) {
306         if(recvbuf_is_empty){
307           Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
308           recvbuf_is_empty=0;
309         } else
310           // #Request is below rank: it's a irecv
311           op->apply( tmpbufs[index], recvbuf, &count, datatype);
312       }
313     }
314   }else{
315     //non commutative case, wait in order
316     for (int other = 0; other < size - 1; other++) {
317      Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
318       if(index < rank) {
319         if (recvbuf_is_empty) {
320           Datatype::copy(tmpbufs[other], count, datatype, recvbuf, count, datatype);
321           recvbuf_is_empty = 0;
322         } else
323           if(op!=MPI_OP_NULL)
324             op->apply( tmpbufs[other], recvbuf, &count, datatype);
325       }
326     }
327   }
328   for(index = 0; index < rank; index++) {
329     smpi_free_tmp_buffer(tmpbufs[index]);
330   }
331   for(index = 0; index < size-1; index++) {
332     Request::unref(&requests[index]);
333   }
334   xbt_free(tmpbufs);
335   xbt_free(requests);
336   return MPI_SUCCESS;
337 }
338
339 }
340 }