Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
59425a37778878f16159e85d6159c27a69d65d46
[simgrid.git] / src / smpi / colls / smpi_coll.cpp
1 /* smpi_coll.c -- various optimized routing for collectives                 */
2
3 /* Copyright (c) 2009-2018. The SimGrid Team. All rights reserved.          */
4
5 /* This program is free software; you can redistribute it and/or modify it
6  * under the terms of the license (GNU LGPL) which comes with this package. */
7
8 #include "smpi_coll.hpp"
9 #include "private.hpp"
10 #include "smpi_comm.hpp"
11 #include "smpi_datatype.hpp"
12 #include "smpi_op.hpp"
13 #include "smpi_request.hpp"
14 #include "xbt/config.hpp"
15
16 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI (coll)");
17
18 #define COLL_SETTER(cat, ret, args, args2)                                                                             \
19   int(*Colls::cat) args;                                                                                               \
20   void Colls::set_##cat(std::string name)                                                                              \
21   {                                                                                                                    \
22     int id = find_coll_description(mpi_coll_##cat##_description, name, #cat);                                          \
23     cat    = reinterpret_cast<ret(*) args>(mpi_coll_##cat##_description[id].coll);                                     \
24     if (cat == nullptr)                                                                                                \
25       xbt_die("Collective " #cat " set to nullptr!");                                                                  \
26   }
27
28 #define SET_COLL(coll)                                                                                                 \
29   name = xbt_cfg_get_string("smpi/" #coll);                                                                            \
30   if (name.empty())                                                                                                    \
31     name = selector_name;                                                                                              \
32   set_##coll(name);
33
34 namespace simgrid{
35 namespace smpi{
36
37 void (*Colls::smpi_coll_cleanup_callback)();
38
39 /* these arrays must be nullptr terminated */
40 s_mpi_coll_description_t Colls::mpi_coll_gather_description[] = {
41     COLL_GATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
42 s_mpi_coll_description_t Colls::mpi_coll_allgather_description[] = {
43     COLL_ALLGATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
44 s_mpi_coll_description_t Colls::mpi_coll_allgatherv_description[] = {
45     COLL_ALLGATHERVS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
46 s_mpi_coll_description_t Colls::mpi_coll_allreduce_description[] ={
47     COLL_ALLREDUCES(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
48 s_mpi_coll_description_t Colls::mpi_coll_reduce_scatter_description[] = {
49     COLL_REDUCE_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
50 s_mpi_coll_description_t Colls::mpi_coll_scatter_description[] ={
51     COLL_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
52 s_mpi_coll_description_t Colls::mpi_coll_barrier_description[] ={
53     COLL_BARRIERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
54 s_mpi_coll_description_t Colls::mpi_coll_alltoall_description[] = {
55     COLL_ALLTOALLS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
56 s_mpi_coll_description_t Colls::mpi_coll_alltoallv_description[] = {
57     COLL_ALLTOALLVS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
58 s_mpi_coll_description_t Colls::mpi_coll_bcast_description[] = {
59     COLL_BCASTS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
60 s_mpi_coll_description_t Colls::mpi_coll_reduce_description[] = {
61     COLL_REDUCES(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
62
63 /** Displays the long description of all registered models, and quit */
64 void Colls::coll_help(const char *category, s_mpi_coll_description_t * table)
65 {
66   XBT_WARN("Long description of the %s models accepted by this simulator:\n", category);
67   for (int i = 0; table[i].name; i++)
68     XBT_WARN("  %s: %s\n", table[i].name, table[i].description);
69 }
70
71 int Colls::find_coll_description(s_mpi_coll_description_t* table, std::string name, const char* desc)
72 {
73   for (int i = 0; table[i].name; i++)
74     if (name == table[i].name) {
75       if (strcmp(table[i].name,"default"))
76         XBT_INFO("Switch to algorithm %s for collective %s",table[i].name,desc);
77       return i;
78     }
79
80   if (not table[0].name)
81     xbt_die("No collective is valid for '%s'! This is a bug.", name.c_str());
82   std::string name_list = std::string(table[0].name);
83   for (int i = 1; table[i].name; i++)
84     name_list = name_list + ", " + table[i].name;
85
86   xbt_die("Collective '%s' is invalid! Valid collectives are: %s.", name.c_str(), name_list.c_str());
87   return -1;
88 }
89
90 COLL_APPLY(COLL_SETTER,COLL_GATHER_SIG,"");
91 COLL_APPLY(COLL_SETTER,COLL_ALLGATHER_SIG,"");
92 COLL_APPLY(COLL_SETTER,COLL_ALLGATHERV_SIG,"");
93 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SIG,"");
94 COLL_APPLY(COLL_SETTER,COLL_ALLREDUCE_SIG,"");
95 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SCATTER_SIG,"");
96 COLL_APPLY(COLL_SETTER,COLL_SCATTER_SIG,"");
97 COLL_APPLY(COLL_SETTER,COLL_BARRIER_SIG,"");
98 COLL_APPLY(COLL_SETTER,COLL_BCAST_SIG,"");
99 COLL_APPLY(COLL_SETTER,COLL_ALLTOALL_SIG,"");
100 COLL_APPLY(COLL_SETTER,COLL_ALLTOALLV_SIG,"");
101
102
103 void Colls::set_collectives(){
104   std::string selector_name = xbt_cfg_get_string("smpi/coll-selector");
105   if (selector_name.empty())
106     selector_name = "default";
107
108   std::string name;
109
110   SET_COLL(gather);
111   SET_COLL(allgather);
112   SET_COLL(allgatherv);
113   SET_COLL(allreduce);
114   SET_COLL(alltoall);
115   SET_COLL(alltoallv);
116   SET_COLL(reduce);
117   SET_COLL(reduce_scatter);
118   SET_COLL(scatter);
119   SET_COLL(bcast);
120   SET_COLL(barrier);
121 }
122
123
124 //Implementations of the single algorith collectives
125
126 int Colls::gatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *displs,
127                       MPI_Datatype recvtype, int root, MPI_Comm comm)
128 {
129   int system_tag = COLL_TAG_GATHERV;
130   MPI_Aint lb = 0;
131   MPI_Aint recvext = 0;
132
133   int rank = comm->rank();
134   int size = comm->size();
135   if (rank != root) {
136     // Send buffer to root
137     Request::send(sendbuf, sendcount, sendtype, root, system_tag, comm);
138   } else {
139     recvtype->extent(&lb, &recvext);
140     // Local copy from root
141     Datatype::copy(sendbuf, sendcount, sendtype, static_cast<char*>(recvbuf) + displs[root] * recvext,
142                        recvcounts[root], recvtype);
143     // Receive buffers from senders
144     MPI_Request *requests = xbt_new(MPI_Request, size - 1);
145     int index = 0;
146     for (int src = 0; src < size; src++) {
147       if(src != root) {
148         requests[index] = Request::irecv_init(static_cast<char*>(recvbuf) + displs[src] * recvext,
149                           recvcounts[src], recvtype, src, system_tag, comm);
150         index++;
151       }
152     }
153     // Wait for completion of irecv's.
154     Request::startall(size - 1, requests);
155     Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
156     for (int src = 0; src < size-1; src++) {
157       Request::unref(&requests[src]);
158     }
159     xbt_free(requests);
160   }
161   return MPI_SUCCESS;
162 }
163
164
165 int Colls::scatterv(void *sendbuf, int *sendcounts, int *displs, MPI_Datatype sendtype, void *recvbuf, int recvcount,
166                        MPI_Datatype recvtype, int root, MPI_Comm comm)
167 {
168   int system_tag = COLL_TAG_SCATTERV;
169   MPI_Aint lb = 0;
170   MPI_Aint sendext = 0;
171
172   int rank = comm->rank();
173   int size = comm->size();
174   if(rank != root) {
175     // Recv buffer from root
176     Request::recv(recvbuf, recvcount, recvtype, root, system_tag, comm, MPI_STATUS_IGNORE);
177   } else {
178     sendtype->extent(&lb, &sendext);
179     // Local copy from root
180     if(recvbuf!=MPI_IN_PLACE){
181       Datatype::copy(static_cast<char *>(sendbuf) + displs[root] * sendext, sendcounts[root],
182                        sendtype, recvbuf, recvcount, recvtype);
183     }
184     // Send buffers to receivers
185     MPI_Request *requests = xbt_new(MPI_Request, size - 1);
186     int index = 0;
187     for (int dst = 0; dst < size; dst++) {
188       if (dst != root) {
189         requests[index] = Request::isend_init(static_cast<char *>(sendbuf) + displs[dst] * sendext, sendcounts[dst],
190                             sendtype, dst, system_tag, comm);
191         index++;
192       }
193     }
194     // Wait for completion of isend's.
195     Request::startall(size - 1, requests);
196     Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
197     for (int dst = 0; dst < size-1; dst++) {
198       Request::unref(&requests[dst]);
199     }
200     xbt_free(requests);
201   }
202   return MPI_SUCCESS;
203 }
204
205
206 int Colls::scan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
207 {
208   int system_tag = -888;
209   MPI_Aint lb      = 0;
210   MPI_Aint dataext = 0;
211
212   int rank = comm->rank();
213   int size = comm->size();
214
215   datatype->extent(&lb, &dataext);
216
217   // Local copy from self
218   Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);
219
220   // Send/Recv buffers to/from others
221   MPI_Request *requests = xbt_new(MPI_Request, size - 1);
222   void **tmpbufs = xbt_new(void *, rank);
223   int index = 0;
224   for (int other = 0; other < rank; other++) {
225     tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
226     requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
227     index++;
228   }
229   for (int other = rank + 1; other < size; other++) {
230     requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
231     index++;
232   }
233   // Wait for completion of all comms.
234   Request::startall(size - 1, requests);
235
236   if(op != MPI_OP_NULL && op->is_commutative()){
237     for (int other = 0; other < size - 1; other++) {
238       index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
239       if(index == MPI_UNDEFINED) {
240         break;
241       }
242       if(index < rank) {
243         // #Request is below rank: it's a irecv
244         op->apply( tmpbufs[index], recvbuf, &count, datatype);
245       }
246     }
247   }else{
248     //non commutative case, wait in order
249     for (int other = 0; other < size - 1; other++) {
250       Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
251       if(index < rank && op!=MPI_OP_NULL) {
252         op->apply( tmpbufs[other], recvbuf, &count, datatype);
253       }
254     }
255   }
256   for(index = 0; index < rank; index++) {
257     smpi_free_tmp_buffer(tmpbufs[index]);
258   }
259   for(index = 0; index < size-1; index++) {
260     Request::unref(&requests[index]);
261   }
262   xbt_free(tmpbufs);
263   xbt_free(requests);
264   return MPI_SUCCESS;
265 }
266
267 int Colls::exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
268 {
269   int system_tag = -888;
270   MPI_Aint lb         = 0;
271   MPI_Aint dataext    = 0;
272   int recvbuf_is_empty=1;
273   int rank = comm->rank();
274   int size = comm->size();
275
276   datatype->extent(&lb, &dataext);
277
278   // Send/Recv buffers to/from others
279   MPI_Request *requests = xbt_new(MPI_Request, size - 1);
280   void **tmpbufs = xbt_new(void *, rank);
281   int index = 0;
282   for (int other = 0; other < rank; other++) {
283     tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
284     requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
285     index++;
286   }
287   for (int other = rank + 1; other < size; other++) {
288     requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
289     index++;
290   }
291   // Wait for completion of all comms.
292   Request::startall(size - 1, requests);
293
294   if(op != MPI_OP_NULL && op->is_commutative()){
295     for (int other = 0; other < size - 1; other++) {
296       index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
297       if(index == MPI_UNDEFINED) {
298         break;
299       }
300       if(index < rank) {
301         if(recvbuf_is_empty){
302           Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
303           recvbuf_is_empty=0;
304         } else
305           // #Request is below rank: it's a irecv
306           op->apply( tmpbufs[index], recvbuf, &count, datatype);
307       }
308     }
309   }else{
310     //non commutative case, wait in order
311     for (int other = 0; other < size - 1; other++) {
312      Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
313       if(index < rank) {
314         if (recvbuf_is_empty) {
315           Datatype::copy(tmpbufs[other], count, datatype, recvbuf, count, datatype);
316           recvbuf_is_empty = 0;
317         } else
318           if(op!=MPI_OP_NULL)
319             op->apply( tmpbufs[other], recvbuf, &count, datatype);
320       }
321     }
322   }
323   for(index = 0; index < rank; index++) {
324     smpi_free_tmp_buffer(tmpbufs[index]);
325   }
326   for(index = 0; index < size-1; index++) {
327     Request::unref(&requests[index]);
328   }
329   xbt_free(tmpbufs);
330   xbt_free(requests);
331   return MPI_SUCCESS;
332 }
333
334 }
335 }