Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
[SMPI] Colls: Move description & name for colls to std::string
[simgrid.git] / src / smpi / colls / smpi_coll.cpp
1 /* smpi_coll.c -- various optimized routing for collectives                 */
2
3 /* Copyright (c) 2009-2018. The SimGrid Team. All rights reserved.          */
4
5 /* This program is free software; you can redistribute it and/or modify it
6  * under the terms of the license (GNU LGPL) which comes with this package. */
7
8 #include "smpi_coll.hpp"
9 #include "private.hpp"
10 #include "smpi_comm.hpp"
11 #include "smpi_datatype.hpp"
12 #include "smpi_op.hpp"
13 #include "smpi_request.hpp"
14 #include "xbt/config.hpp"
15
16 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI (coll)");
17
18 #define COLL_SETTER(cat, ret, args, args2)                                                                             \
19   int(*Colls::cat) args;                                                                                               \
20   void Colls::set_##cat(std::string name)                                                                              \
21   {                                                                                                                    \
22     int id = find_coll_description(mpi_coll_##cat##_description, name, #cat);                                          \
23     cat    = reinterpret_cast<ret(*) args>(mpi_coll_##cat##_description[id].coll);                                     \
24     if (cat == nullptr)                                                                                                \
25       xbt_die("Collective " #cat " set to nullptr!");                                                                  \
26   }
27
28 namespace simgrid{
29 namespace smpi{
30
31 void (*Colls::smpi_coll_cleanup_callback)();
32
33 /* these arrays must be nullptr terminated */
34 s_mpi_coll_description_t Colls::mpi_coll_gather_description[] = {
35     COLL_GATHERS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
36 s_mpi_coll_description_t Colls::mpi_coll_allgather_description[] = {
37     COLL_ALLGATHERS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
38 s_mpi_coll_description_t Colls::mpi_coll_allgatherv_description[] = {
39     COLL_ALLGATHERVS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
40 s_mpi_coll_description_t Colls::mpi_coll_allreduce_description[] ={
41     COLL_ALLREDUCES(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
42 s_mpi_coll_description_t Colls::mpi_coll_reduce_scatter_description[] = {
43     COLL_REDUCE_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
44 s_mpi_coll_description_t Colls::mpi_coll_scatter_description[] ={
45     COLL_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
46 s_mpi_coll_description_t Colls::mpi_coll_barrier_description[] ={
47     COLL_BARRIERS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
48 s_mpi_coll_description_t Colls::mpi_coll_alltoall_description[] = {
49     COLL_ALLTOALLS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
50 s_mpi_coll_description_t Colls::mpi_coll_alltoallv_description[] = {
51     COLL_ALLTOALLVS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
52 s_mpi_coll_description_t Colls::mpi_coll_bcast_description[] = {
53     COLL_BCASTS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
54 s_mpi_coll_description_t Colls::mpi_coll_reduce_description[] = {
55     COLL_REDUCES(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
56
57 /** Displays the long description of all registered models, and quit */
58 void Colls::coll_help(const char *category, s_mpi_coll_description_t * table)
59 {
60   XBT_WARN("Long description of the %s models accepted by this simulator:\n", category);
61   for (int i = 0; not table[i].name.empty(); i++)
62     XBT_WARN("  %s: %s\n", table[i].name.c_str(), table[i].description.c_str());
63 }
64
65 int Colls::find_coll_description(s_mpi_coll_description_t* table, std::string name, const char* desc)
66 {
67   for (int i = 0; not table[i].name.empty(); i++)
68     if (name == table[i].name) {
69       if (table[i].name != "default")
70         XBT_INFO("Switch to algorithm %s for collective %s",table[i].name.c_str(),desc);
71       return i;
72     }
73
74   if (table[0].name.empty())
75     xbt_die("No collective is valid for '%s'! This is a bug.", name.c_str());
76   std::string name_list = table[0].name;
77   for (int i = 1; not table[i].name.empty(); i++)
78     name_list = name_list + ", " + table[i].name;
79
80   xbt_die("Collective '%s' is invalid! Valid collectives are: %s.", name.c_str(), name_list.c_str());
81   return -1;
82 }
83
84 COLL_APPLY(COLL_SETTER,COLL_GATHER_SIG,"");
85 COLL_APPLY(COLL_SETTER,COLL_ALLGATHER_SIG,"");
86 COLL_APPLY(COLL_SETTER,COLL_ALLGATHERV_SIG,"");
87 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SIG,"");
88 COLL_APPLY(COLL_SETTER,COLL_ALLREDUCE_SIG,"");
89 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SCATTER_SIG,"");
90 COLL_APPLY(COLL_SETTER,COLL_SCATTER_SIG,"");
91 COLL_APPLY(COLL_SETTER,COLL_BARRIER_SIG,"");
92 COLL_APPLY(COLL_SETTER,COLL_BCAST_SIG,"");
93 COLL_APPLY(COLL_SETTER,COLL_ALLTOALL_SIG,"");
94 COLL_APPLY(COLL_SETTER,COLL_ALLTOALLV_SIG,"");
95
96 void Colls::set_collectives(){
97   std::string selector_name = simgrid::config::get_value<std::string>("smpi/coll-selector");
98   if (selector_name.empty())
99     selector_name = "default";
100
101   std::map<std::string, std::function<void(std::string)>> setter_callbacks = {
102       {"gather", &Colls::set_gather},         {"allgather", &Colls::set_allgather},
103       {"allgatherv", &Colls::set_allgatherv}, {"allreduce", &Colls::set_allreduce},
104       {"alltoall", &Colls::set_alltoall},     {"alltoallv", &Colls::set_alltoallv},
105       {"reduce", &Colls::set_reduce},         {"reduce_scatter", &Colls::set_reduce_scatter},
106       {"scatter", &Colls::set_scatter},       {"bcast", &Colls::set_bcast},
107       {"barrier", &Colls::set_barrier}};
108
109   // This only prevents code duplication
110   std::function<void(std::string)> setup = [selector_name, &setter_callbacks](std::string coll) {
111     std::string name = simgrid::config::get_value<std::string>(("smpi/" + coll).c_str());
112     if (name.empty())
113       name = selector_name;
114
115     setter_callbacks[coll](name);
116   };
117
118   setup("gather");
119   setup("allgather");
120   setup("allgatherv");
121   setup("allreduce");
122   setup("alltoall");
123   setup("alltoallv");
124   setup("reduce");
125   setup("reduce_scatter");
126   setup("scatter");
127   setup("bcast");
128   setup("barrier");
129 }
130
131
132 //Implementations of the single algorith collectives
133
134 int Colls::gatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *displs,
135                       MPI_Datatype recvtype, int root, MPI_Comm comm)
136 {
137   int system_tag = COLL_TAG_GATHERV;
138   MPI_Aint lb = 0;
139   MPI_Aint recvext = 0;
140
141   int rank = comm->rank();
142   int size = comm->size();
143   if (rank != root) {
144     // Send buffer to root
145     Request::send(sendbuf, sendcount, sendtype, root, system_tag, comm);
146   } else {
147     recvtype->extent(&lb, &recvext);
148     // Local copy from root
149     Datatype::copy(sendbuf, sendcount, sendtype, static_cast<char*>(recvbuf) + displs[root] * recvext,
150                        recvcounts[root], recvtype);
151     // Receive buffers from senders
152     MPI_Request *requests = xbt_new(MPI_Request, size - 1);
153     int index = 0;
154     for (int src = 0; src < size; src++) {
155       if(src != root) {
156         requests[index] = Request::irecv_init(static_cast<char*>(recvbuf) + displs[src] * recvext,
157                           recvcounts[src], recvtype, src, system_tag, comm);
158         index++;
159       }
160     }
161     // Wait for completion of irecv's.
162     Request::startall(size - 1, requests);
163     Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
164     for (int src = 0; src < size-1; src++) {
165       Request::unref(&requests[src]);
166     }
167     xbt_free(requests);
168   }
169   return MPI_SUCCESS;
170 }
171
172
173 int Colls::scatterv(void *sendbuf, int *sendcounts, int *displs, MPI_Datatype sendtype, void *recvbuf, int recvcount,
174                        MPI_Datatype recvtype, int root, MPI_Comm comm)
175 {
176   int system_tag = COLL_TAG_SCATTERV;
177   MPI_Aint lb = 0;
178   MPI_Aint sendext = 0;
179
180   int rank = comm->rank();
181   int size = comm->size();
182   if(rank != root) {
183     // Recv buffer from root
184     Request::recv(recvbuf, recvcount, recvtype, root, system_tag, comm, MPI_STATUS_IGNORE);
185   } else {
186     sendtype->extent(&lb, &sendext);
187     // Local copy from root
188     if(recvbuf!=MPI_IN_PLACE){
189       Datatype::copy(static_cast<char *>(sendbuf) + displs[root] * sendext, sendcounts[root],
190                        sendtype, recvbuf, recvcount, recvtype);
191     }
192     // Send buffers to receivers
193     MPI_Request *requests = xbt_new(MPI_Request, size - 1);
194     int index = 0;
195     for (int dst = 0; dst < size; dst++) {
196       if (dst != root) {
197         requests[index] = Request::isend_init(static_cast<char *>(sendbuf) + displs[dst] * sendext, sendcounts[dst],
198                             sendtype, dst, system_tag, comm);
199         index++;
200       }
201     }
202     // Wait for completion of isend's.
203     Request::startall(size - 1, requests);
204     Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
205     for (int dst = 0; dst < size-1; dst++) {
206       Request::unref(&requests[dst]);
207     }
208     xbt_free(requests);
209   }
210   return MPI_SUCCESS;
211 }
212
213
214 int Colls::scan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
215 {
216   int system_tag = -888;
217   MPI_Aint lb      = 0;
218   MPI_Aint dataext = 0;
219
220   int rank = comm->rank();
221   int size = comm->size();
222
223   datatype->extent(&lb, &dataext);
224
225   // Local copy from self
226   Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);
227
228   // Send/Recv buffers to/from others
229   MPI_Request *requests = xbt_new(MPI_Request, size - 1);
230   void **tmpbufs = xbt_new(void *, rank);
231   int index = 0;
232   for (int other = 0; other < rank; other++) {
233     tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
234     requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
235     index++;
236   }
237   for (int other = rank + 1; other < size; other++) {
238     requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
239     index++;
240   }
241   // Wait for completion of all comms.
242   Request::startall(size - 1, requests);
243
244   if(op != MPI_OP_NULL && op->is_commutative()){
245     for (int other = 0; other < size - 1; other++) {
246       index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
247       if(index == MPI_UNDEFINED) {
248         break;
249       }
250       if(index < rank) {
251         // #Request is below rank: it's a irecv
252         op->apply( tmpbufs[index], recvbuf, &count, datatype);
253       }
254     }
255   }else{
256     //non commutative case, wait in order
257     for (int other = 0; other < size - 1; other++) {
258       Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
259       if(index < rank && op!=MPI_OP_NULL) {
260         op->apply( tmpbufs[other], recvbuf, &count, datatype);
261       }
262     }
263   }
264   for(index = 0; index < rank; index++) {
265     smpi_free_tmp_buffer(tmpbufs[index]);
266   }
267   for(index = 0; index < size-1; index++) {
268     Request::unref(&requests[index]);
269   }
270   xbt_free(tmpbufs);
271   xbt_free(requests);
272   return MPI_SUCCESS;
273 }
274
275 int Colls::exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
276 {
277   int system_tag = -888;
278   MPI_Aint lb         = 0;
279   MPI_Aint dataext    = 0;
280   int recvbuf_is_empty=1;
281   int rank = comm->rank();
282   int size = comm->size();
283
284   datatype->extent(&lb, &dataext);
285
286   // Send/Recv buffers to/from others
287   MPI_Request *requests = xbt_new(MPI_Request, size - 1);
288   void **tmpbufs = xbt_new(void *, rank);
289   int index = 0;
290   for (int other = 0; other < rank; other++) {
291     tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
292     requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
293     index++;
294   }
295   for (int other = rank + 1; other < size; other++) {
296     requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
297     index++;
298   }
299   // Wait for completion of all comms.
300   Request::startall(size - 1, requests);
301
302   if(op != MPI_OP_NULL && op->is_commutative()){
303     for (int other = 0; other < size - 1; other++) {
304       index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
305       if(index == MPI_UNDEFINED) {
306         break;
307       }
308       if(index < rank) {
309         if(recvbuf_is_empty){
310           Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
311           recvbuf_is_empty=0;
312         } else
313           // #Request is below rank: it's a irecv
314           op->apply( tmpbufs[index], recvbuf, &count, datatype);
315       }
316     }
317   }else{
318     //non commutative case, wait in order
319     for (int other = 0; other < size - 1; other++) {
320      Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
321       if(index < rank) {
322         if (recvbuf_is_empty) {
323           Datatype::copy(tmpbufs[other], count, datatype, recvbuf, count, datatype);
324           recvbuf_is_empty = 0;
325         } else
326           if(op!=MPI_OP_NULL)
327             op->apply( tmpbufs[other], recvbuf, &count, datatype);
328       }
329     }
330   }
331   for(index = 0; index < rank; index++) {
332     smpi_free_tmp_buffer(tmpbufs[index]);
333   }
334   for(index = 0; index < size-1; index++) {
335     Request::unref(&requests[index]);
336   }
337   xbt_free(tmpbufs);
338   xbt_free(requests);
339   return MPI_SUCCESS;
340 }
341
342 }
343 }