Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'master' into energy-pstate
[simgrid.git] / src / smpi / colls / smpi_coll.cpp
1 /* smpi_coll.c -- various optimized routing for collectives                 */
2
3 /* Copyright (c) 2009-2017. The SimGrid Team. All rights reserved.          */
4
5 /* This program is free software; you can redistribute it and/or modify it
6  * under the terms of the license (GNU LGPL) which comes with this package. */
7
8 #include "smpi_coll.hpp"
9 #include "private.hpp"
10 #include "smpi_comm.hpp"
11 #include "smpi_datatype.hpp"
12 #include "smpi_op.hpp"
13 #include "smpi_request.hpp"
14
15 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI (coll)");
16
17 #define COLL_SETTER(cat, ret, args, args2)                                                                             \
18   int(*Colls::cat) args;                                                                                               \
19   void Colls::set_##cat(std::string name)                                                                              \
20   {                                                                                                                    \
21     int id = find_coll_description(mpi_coll_##cat##_description, name, #cat);                                          \
22     cat    = reinterpret_cast<ret(*) args>(mpi_coll_##cat##_description[id].coll);                                     \
23     if (cat == nullptr)                                                                                                \
24       xbt_die("Collective " #cat " set to nullptr!");                                                                  \
25   }
26
27 #define SET_COLL(coll)                                                                                                 \
28   name = xbt_cfg_get_string("smpi/" #coll);                                                                            \
29   if (name.empty())                                                                                                    \
30     name = selector_name;                                                                                              \
31   set_##coll(name);
32
33 namespace simgrid{
34 namespace smpi{
35
36 void (*Colls::smpi_coll_cleanup_callback)();
37
38 /* these arrays must be nullptr terminated */
39 s_mpi_coll_description_t Colls::mpi_coll_gather_description[] = {
40     COLL_GATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
41 s_mpi_coll_description_t Colls::mpi_coll_allgather_description[] = {
42     COLL_ALLGATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
43 s_mpi_coll_description_t Colls::mpi_coll_allgatherv_description[] = {
44     COLL_ALLGATHERVS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
45 s_mpi_coll_description_t Colls::mpi_coll_allreduce_description[] ={
46     COLL_ALLREDUCES(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
47 s_mpi_coll_description_t Colls::mpi_coll_reduce_scatter_description[] = {
48     COLL_REDUCE_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
49 s_mpi_coll_description_t Colls::mpi_coll_scatter_description[] ={
50     COLL_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
51 s_mpi_coll_description_t Colls::mpi_coll_barrier_description[] ={
52     COLL_BARRIERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
53 s_mpi_coll_description_t Colls::mpi_coll_alltoall_description[] = {
54     COLL_ALLTOALLS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
55 s_mpi_coll_description_t Colls::mpi_coll_alltoallv_description[] = {
56     COLL_ALLTOALLVS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
57 s_mpi_coll_description_t Colls::mpi_coll_bcast_description[] = {
58     COLL_BCASTS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
59 s_mpi_coll_description_t Colls::mpi_coll_reduce_description[] = {
60     COLL_REDUCES(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
61
62 /** Displays the long description of all registered models, and quit */
63 void Colls::coll_help(const char *category, s_mpi_coll_description_t * table)
64 {
65   XBT_WARN("Long description of the %s models accepted by this simulator:\n", category);
66   for (int i = 0; table[i].name; i++)
67     XBT_WARN("  %s: %s\n", table[i].name, table[i].description);
68 }
69
70 int Colls::find_coll_description(s_mpi_coll_description_t* table, std::string name, const char* desc)
71 {
72   for (int i = 0; table[i].name; i++)
73     if (name == table[i].name) {
74       if (strcmp(table[i].name,"default"))
75         XBT_INFO("Switch to algorithm %s for collective %s",table[i].name,desc);
76       return i;
77     }
78
79   if (not table[0].name)
80     xbt_die("No collective is valid for '%s'! This is a bug.", name.c_str());
81   std::string name_list = std::string(table[0].name);
82   for (int i = 1; table[i].name; i++)
83     name_list = name_list + ", " + table[i].name;
84
85   xbt_die("Collective '%s' is invalid! Valid collectives are: %s.", name.c_str(), name_list.c_str());
86   return -1;
87 }
88
89 COLL_APPLY(COLL_SETTER,COLL_GATHER_SIG,"");
90 COLL_APPLY(COLL_SETTER,COLL_ALLGATHER_SIG,"");
91 COLL_APPLY(COLL_SETTER,COLL_ALLGATHERV_SIG,"");
92 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SIG,"");
93 COLL_APPLY(COLL_SETTER,COLL_ALLREDUCE_SIG,"");
94 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SCATTER_SIG,"");
95 COLL_APPLY(COLL_SETTER,COLL_SCATTER_SIG,"");
96 COLL_APPLY(COLL_SETTER,COLL_BARRIER_SIG,"");
97 COLL_APPLY(COLL_SETTER,COLL_BCAST_SIG,"");
98 COLL_APPLY(COLL_SETTER,COLL_ALLTOALL_SIG,"");
99 COLL_APPLY(COLL_SETTER,COLL_ALLTOALLV_SIG,"");
100
101
102 void Colls::set_collectives(){
103   std::string selector_name = xbt_cfg_get_string("smpi/coll-selector");
104   if (selector_name.empty())
105     selector_name = "default";
106
107   std::string name;
108
109   SET_COLL(gather);
110   SET_COLL(allgather);
111   SET_COLL(allgatherv);
112   SET_COLL(allreduce);
113   SET_COLL(alltoall);
114   SET_COLL(alltoallv);
115   SET_COLL(reduce);
116   SET_COLL(reduce_scatter);
117   SET_COLL(scatter);
118   SET_COLL(bcast);
119   SET_COLL(barrier);
120 }
121
122
123 //Implementations of the single algorith collectives
124
125 int Colls::gatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *displs,
126                       MPI_Datatype recvtype, int root, MPI_Comm comm)
127 {
128   int system_tag = COLL_TAG_GATHERV;
129   MPI_Aint lb = 0;
130   MPI_Aint recvext = 0;
131
132   int rank = comm->rank();
133   int size = comm->size();
134   if (rank != root) {
135     // Send buffer to root
136     Request::send(sendbuf, sendcount, sendtype, root, system_tag, comm);
137   } else {
138     recvtype->extent(&lb, &recvext);
139     // Local copy from root
140     Datatype::copy(sendbuf, sendcount, sendtype, static_cast<char*>(recvbuf) + displs[root] * recvext,
141                        recvcounts[root], recvtype);
142     // Receive buffers from senders
143     MPI_Request *requests = xbt_new(MPI_Request, size - 1);
144     int index = 0;
145     for (int src = 0; src < size; src++) {
146       if(src != root) {
147         requests[index] = Request::irecv_init(static_cast<char*>(recvbuf) + displs[src] * recvext,
148                           recvcounts[src], recvtype, src, system_tag, comm);
149         index++;
150       }
151     }
152     // Wait for completion of irecv's.
153     Request::startall(size - 1, requests);
154     Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
155     for (int src = 0; src < size-1; src++) {
156       Request::unref(&requests[src]);
157     }
158     xbt_free(requests);
159   }
160   return MPI_SUCCESS;
161 }
162
163
164 int Colls::scatterv(void *sendbuf, int *sendcounts, int *displs, MPI_Datatype sendtype, void *recvbuf, int recvcount,
165                        MPI_Datatype recvtype, int root, MPI_Comm comm)
166 {
167   int system_tag = COLL_TAG_SCATTERV;
168   MPI_Aint lb = 0;
169   MPI_Aint sendext = 0;
170
171   int rank = comm->rank();
172   int size = comm->size();
173   if(rank != root) {
174     // Recv buffer from root
175     Request::recv(recvbuf, recvcount, recvtype, root, system_tag, comm, MPI_STATUS_IGNORE);
176   } else {
177     sendtype->extent(&lb, &sendext);
178     // Local copy from root
179     if(recvbuf!=MPI_IN_PLACE){
180       Datatype::copy(static_cast<char *>(sendbuf) + displs[root] * sendext, sendcounts[root],
181                        sendtype, recvbuf, recvcount, recvtype);
182     }
183     // Send buffers to receivers
184     MPI_Request *requests = xbt_new(MPI_Request, size - 1);
185     int index = 0;
186     for (int dst = 0; dst < size; dst++) {
187       if (dst != root) {
188         requests[index] = Request::isend_init(static_cast<char *>(sendbuf) + displs[dst] * sendext, sendcounts[dst],
189                             sendtype, dst, system_tag, comm);
190         index++;
191       }
192     }
193     // Wait for completion of isend's.
194     Request::startall(size - 1, requests);
195     Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
196     for (int dst = 0; dst < size-1; dst++) {
197       Request::unref(&requests[dst]);
198     }
199     xbt_free(requests);
200   }
201   return MPI_SUCCESS;
202 }
203
204
205 int Colls::scan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
206 {
207   int system_tag = -888;
208   MPI_Aint lb      = 0;
209   MPI_Aint dataext = 0;
210
211   int rank = comm->rank();
212   int size = comm->size();
213
214   datatype->extent(&lb, &dataext);
215
216   // Local copy from self
217   Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);
218
219   // Send/Recv buffers to/from others
220   MPI_Request *requests = xbt_new(MPI_Request, size - 1);
221   void **tmpbufs = xbt_new(void *, rank);
222   int index = 0;
223   for (int other = 0; other < rank; other++) {
224     tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
225     requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
226     index++;
227   }
228   for (int other = rank + 1; other < size; other++) {
229     requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
230     index++;
231   }
232   // Wait for completion of all comms.
233   Request::startall(size - 1, requests);
234
235   if(op != MPI_OP_NULL && op->is_commutative()){
236     for (int other = 0; other < size - 1; other++) {
237       index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
238       if(index == MPI_UNDEFINED) {
239         break;
240       }
241       if(index < rank) {
242         // #Request is below rank: it's a irecv
243         op->apply( tmpbufs[index], recvbuf, &count, datatype);
244       }
245     }
246   }else{
247     //non commutative case, wait in order
248     for (int other = 0; other < size - 1; other++) {
249       Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
250       if(index < rank && op!=MPI_OP_NULL) {
251         op->apply( tmpbufs[other], recvbuf, &count, datatype);
252       }
253     }
254   }
255   for(index = 0; index < rank; index++) {
256     smpi_free_tmp_buffer(tmpbufs[index]);
257   }
258   for(index = 0; index < size-1; index++) {
259     Request::unref(&requests[index]);
260   }
261   xbt_free(tmpbufs);
262   xbt_free(requests);
263   return MPI_SUCCESS;
264 }
265
266 int Colls::exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
267 {
268   int system_tag = -888;
269   MPI_Aint lb         = 0;
270   MPI_Aint dataext    = 0;
271   int recvbuf_is_empty=1;
272   int rank = comm->rank();
273   int size = comm->size();
274
275   datatype->extent(&lb, &dataext);
276
277   // Send/Recv buffers to/from others
278   MPI_Request *requests = xbt_new(MPI_Request, size - 1);
279   void **tmpbufs = xbt_new(void *, rank);
280   int index = 0;
281   for (int other = 0; other < rank; other++) {
282     tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
283     requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
284     index++;
285   }
286   for (int other = rank + 1; other < size; other++) {
287     requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
288     index++;
289   }
290   // Wait for completion of all comms.
291   Request::startall(size - 1, requests);
292
293   if(op != MPI_OP_NULL && op->is_commutative()){
294     for (int other = 0; other < size - 1; other++) {
295       index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
296       if(index == MPI_UNDEFINED) {
297         break;
298       }
299       if(index < rank) {
300         if(recvbuf_is_empty){
301           Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
302           recvbuf_is_empty=0;
303         } else
304           // #Request is below rank: it's a irecv
305           op->apply( tmpbufs[index], recvbuf, &count, datatype);
306       }
307     }
308   }else{
309     //non commutative case, wait in order
310     for (int other = 0; other < size - 1; other++) {
311      Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
312       if(index < rank) {
313         if (recvbuf_is_empty) {
314           Datatype::copy(tmpbufs[other], count, datatype, recvbuf, count, datatype);
315           recvbuf_is_empty = 0;
316         } else
317           if(op!=MPI_OP_NULL)
318             op->apply( tmpbufs[other], recvbuf, &count, datatype);
319       }
320     }
321   }
322   for(index = 0; index < rank; index++) {
323     smpi_free_tmp_buffer(tmpbufs[index]);
324   }
325   for(index = 0; index < size-1; index++) {
326     Request::unref(&requests[index]);
327   }
328   xbt_free(tmpbufs);
329   xbt_free(requests);
330   return MPI_SUCCESS;
331 }
332
333 }
334 }