Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Change smpi::Colls static class into a namespace of functions
[simgrid.git] / src / smpi / colls / smpi_coll.cpp
1 /* smpi_coll.c -- various optimized routing for collectives                 */
2
3 /* Copyright (c) 2009-2019. The SimGrid Team. All rights reserved.          */
4
5 /* This program is free software; you can redistribute it and/or modify it
6  * under the terms of the license (GNU LGPL) which comes with this package. */
7
8 #include "smpi_coll.hpp"
9 #include "private.hpp"
10 #include "smpi_comm.hpp"
11 #include "smpi_datatype.hpp"
12 #include "smpi_op.hpp"
13 #include "smpi_request.hpp"
14 #include "xbt/config.hpp"
15
16 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI collectives.");
17
18 #define COLL_SETTER(cat, ret, args, args2)                                                                             \
19   void colls::_XBT_CONCAT(set_, cat)(const std::string& name)                                                          \
20   {                                                                                                                    \
21     int id = find_coll_description(_XBT_CONCAT3(mpi_coll_, cat, _description), name, _XBT_STRINGIFY(cat));             \
22     cat    = reinterpret_cast<ret(*) args>(_XBT_CONCAT3(mpi_coll_, cat, _description)[id].coll);                       \
23     if (cat == nullptr)                                                                                                \
24       xbt_die("Collective " _XBT_STRINGIFY(cat) " set to nullptr!");                                                   \
25   }
26
27 namespace simgrid{
28 namespace smpi{
29
30 /* these arrays must be nullptr terminated */
31 s_mpi_coll_description_t mpi_coll_gather_description[]         = {COLL_GATHERS(COLL_DESCRIPTION, COLL_COMMA),
32                                                           {"", "", nullptr}};
33 s_mpi_coll_description_t mpi_coll_allgather_description[]      = {COLL_ALLGATHERS(COLL_DESCRIPTION, COLL_COMMA),
34                                                              {"", "", nullptr}};
35 s_mpi_coll_description_t mpi_coll_allgatherv_description[]     = {COLL_ALLGATHERVS(COLL_DESCRIPTION, COLL_COMMA),
36                                                               {"", "", nullptr}};
37 s_mpi_coll_description_t mpi_coll_allreduce_description[]      = {COLL_ALLREDUCES(COLL_DESCRIPTION, COLL_COMMA),
38                                                              {"", "", nullptr}};
39 s_mpi_coll_description_t mpi_coll_reduce_scatter_description[] = {COLL_REDUCE_SCATTERS(COLL_DESCRIPTION, COLL_COMMA),
40                                                                   {"", "", nullptr}};
41 s_mpi_coll_description_t mpi_coll_scatter_description[]        = {COLL_SCATTERS(COLL_DESCRIPTION, COLL_COMMA),
42                                                            {"", "", nullptr}};
43 s_mpi_coll_description_t mpi_coll_barrier_description[]        = {COLL_BARRIERS(COLL_DESCRIPTION, COLL_COMMA),
44                                                            {"", "", nullptr}};
45 s_mpi_coll_description_t mpi_coll_alltoall_description[]       = {COLL_ALLTOALLS(COLL_DESCRIPTION, COLL_COMMA),
46                                                             {"", "", nullptr}};
47 s_mpi_coll_description_t mpi_coll_alltoallv_description[]      = {COLL_ALLTOALLVS(COLL_DESCRIPTION, COLL_COMMA),
48                                                              {"", "", nullptr}};
49 s_mpi_coll_description_t mpi_coll_bcast_description[]  = {COLL_BCASTS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr}};
50 s_mpi_coll_description_t mpi_coll_reduce_description[] = {COLL_REDUCES(COLL_DESCRIPTION, COLL_COMMA),
51                                                           {"", "", nullptr}};
52
53 // Needed by the automatic selector weird implementation
54 s_mpi_coll_description_t* colls::get_smpi_coll_description(const char* name, int rank)
55 {
56   if (strcmp(name, "gather") == 0)
57     return &mpi_coll_gather_description[rank];
58   if (strcmp(name, "allgather") == 0)
59     return &mpi_coll_allgather_description[rank];
60   if (strcmp(name, "allgatherv") == 0)
61     return &mpi_coll_allgatherv_description[rank];
62   if (strcmp(name, "allreduce") == 0)
63     return &mpi_coll_allreduce_description[rank];
64   if (strcmp(name, "reduce_scatter") == 0)
65     return &mpi_coll_reduce_scatter_description[rank];
66   if (strcmp(name, "scatter") == 0)
67     return &mpi_coll_scatter_description[rank];
68   if (strcmp(name, "barrier") == 0)
69     return &mpi_coll_barrier_description[rank];
70   if (strcmp(name, "alltoall") == 0)
71     return &mpi_coll_alltoall_description[rank];
72   if (strcmp(name, "alltoallv") == 0)
73     return &mpi_coll_alltoallv_description[rank];
74   if (strcmp(name, "bcast") == 0)
75     return &mpi_coll_bcast_description[rank];
76   if (strcmp(name, "reduce") == 0)
77     return &mpi_coll_reduce_description[rank];
78   XBT_INFO("You requested an unknown collective: %s", name);
79   return nullptr;
80 }
81
82 /** Displays the long description of all registered models, and quit */
83 void colls::coll_help(const char* category, s_mpi_coll_description_t* table)
84 {
85   XBT_WARN("Long description of the %s models accepted by this simulator:\n", category);
86   for (int i = 0; not table[i].name.empty(); i++)
87     XBT_WARN("  %s: %s\n", table[i].name.c_str(), table[i].description.c_str());
88 }
89
90 int colls::find_coll_description(s_mpi_coll_description_t* table, const std::string& name, const char* desc)
91 {
92   for (int i = 0; not table[i].name.empty(); i++)
93     if (name == table[i].name) {
94       if (table[i].name != "default")
95         XBT_INFO("Switch to algorithm %s for collective %s",table[i].name.c_str(),desc);
96       return i;
97     }
98
99   if (table[0].name.empty())
100     xbt_die("No collective is valid for '%s'! This is a bug.", name.c_str());
101   std::string name_list = table[0].name;
102   for (int i = 1; not table[i].name.empty(); i++)
103     name_list = name_list + ", " + table[i].name;
104
105   xbt_die("Collective '%s' is invalid! Valid collectives are: %s.", name.c_str(), name_list.c_str());
106   return -1;
107 }
108
109 int (*colls::gather)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff, int recv_count,
110                      MPI_Datatype recv_type, int root, MPI_Comm comm);
111 int (*colls::allgather)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff, int recv_count,
112                         MPI_Datatype recv_type, MPI_Comm comm);
113 int (*colls::allgatherv)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff,
114                          const int* recv_count, const int* recv_disps, MPI_Datatype recv_type, MPI_Comm comm);
115 int (*colls::alltoall)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff, int recv_count,
116                        MPI_Datatype recv_type, MPI_Comm comm);
117 int (*colls::alltoallv)(const void* send_buff, const int* send_counts, const int* send_disps, MPI_Datatype send_type,
118                         void* recv_buff, const int* recv_counts, const int* recv_disps, MPI_Datatype recv_type,
119                         MPI_Comm comm);
120 int (*colls::bcast)(void* buf, int count, MPI_Datatype datatype, int root, MPI_Comm comm);
121 int (*colls::reduce)(const void* buf, void* rbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm);
122 int (*colls::allreduce)(const void* sbuf, void* rbuf, int rcount, MPI_Datatype dtype, MPI_Op op, MPI_Comm comm);
123 int (*colls::reduce_scatter)(const void* sbuf, void* rbuf, const int* rcounts, MPI_Datatype dtype, MPI_Op op,
124                              MPI_Comm comm);
125 int (*colls::scatter)(const void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, int recvcount,
126                       MPI_Datatype recvtype, int root, MPI_Comm comm);
127 int (*colls::barrier)(MPI_Comm comm);
128
129 void (*colls::smpi_coll_cleanup_callback)();
130
131 void colls::set_gather(const std::string& name)
132 {
133   int id = find_coll_description(mpi_coll_gather_description, name, "gather");
134   gather = reinterpret_cast<int(*)(const void *send_buff, int send_count, MPI_Datatype send_type,
135                                     void *recv_buff, int recv_count, MPI_Datatype recv_type,
136                                         int root, MPI_Comm comm)>(mpi_coll_gather_description[id].coll);
137   if (gather == nullptr)
138     xbt_die("Collective gather set to nullptr!");
139 }
140
141 //COLL_APPLY(COLL_SETTER,COLL_GATHER_SIG,"");
142 COLL_APPLY(COLL_SETTER,COLL_ALLGATHER_SIG,"");
143 COLL_APPLY(COLL_SETTER,COLL_ALLGATHERV_SIG,"");
144 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SIG,"");
145 COLL_APPLY(COLL_SETTER,COLL_ALLREDUCE_SIG,"");
146 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SCATTER_SIG,"");
147 COLL_APPLY(COLL_SETTER,COLL_SCATTER_SIG,"");
148 COLL_APPLY(COLL_SETTER,COLL_BARRIER_SIG,"");
149 COLL_APPLY(COLL_SETTER,COLL_BCAST_SIG,"");
150 COLL_APPLY(COLL_SETTER,COLL_ALLTOALL_SIG,"");
151 COLL_APPLY(COLL_SETTER,COLL_ALLTOALLV_SIG,"");
152
153 void colls::set_collectives()
154 {
155   std::string selector_name = simgrid::config::get_value<std::string>("smpi/coll-selector");
156   if (selector_name.empty())
157     selector_name = "default";
158
159   std::pair<std::string, std::function<void(std::string)>> setter_callbacks[] = {
160       {"gather", &colls::set_gather},         {"allgather", &colls::set_allgather},
161       {"allgatherv", &colls::set_allgatherv}, {"allreduce", &colls::set_allreduce},
162       {"alltoall", &colls::set_alltoall},     {"alltoallv", &colls::set_alltoallv},
163       {"reduce", &colls::set_reduce},         {"reduce_scatter", &colls::set_reduce_scatter},
164       {"scatter", &colls::set_scatter},       {"bcast", &colls::set_bcast},
165       {"barrier", &colls::set_barrier}};
166
167   for (auto& elem : setter_callbacks) {
168     std::string name = simgrid::config::get_value<std::string>(("smpi/" + elem.first).c_str());
169     if (name.empty())
170       name = selector_name;
171
172     (elem.second)(name);
173   }
174 }
175
176 //Implementations of the single algorithm collectives
177
178 int colls::gatherv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, const int* recvcounts,
179                    const int* displs, MPI_Datatype recvtype, int root, MPI_Comm comm)
180 {
181   MPI_Request request;
182   colls::igatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, &request, 0);
183   return Request::wait(&request, MPI_STATUS_IGNORE);
184 }
185
186 int colls::scatterv(const void* sendbuf, const int* sendcounts, const int* displs, MPI_Datatype sendtype, void* recvbuf,
187                     int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
188 {
189   MPI_Request request;
190   colls::iscatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm, &request, 0);
191   return Request::wait(&request, MPI_STATUS_IGNORE);
192 }
193
194 int colls::scan(const void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
195 {
196   int system_tag = -888;
197   MPI_Aint lb      = 0;
198   MPI_Aint dataext = 0;
199
200   int rank = comm->rank();
201   int size = comm->size();
202
203   datatype->extent(&lb, &dataext);
204
205   // Local copy from self
206   Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);
207
208   // Send/Recv buffers to/from others
209   MPI_Request* requests = new MPI_Request[size - 1];
210   unsigned char** tmpbufs = new unsigned char*[rank];
211   int index = 0;
212   for (int other = 0; other < rank; other++) {
213     tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
214     requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
215     index++;
216   }
217   for (int other = rank + 1; other < size; other++) {
218     requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
219     index++;
220   }
221   // Wait for completion of all comms.
222   Request::startall(size - 1, requests);
223
224   if(op != MPI_OP_NULL && op->is_commutative()){
225     for (int other = 0; other < size - 1; other++) {
226       index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
227       if(index == MPI_UNDEFINED) {
228         break;
229       }
230       if(index < rank) {
231         // #Request is below rank: it's a irecv
232         op->apply( tmpbufs[index], recvbuf, &count, datatype);
233       }
234     }
235   }else{
236     //non commutative case, wait in order
237     for (int other = 0; other < size - 1; other++) {
238       Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
239       if(index < rank && op!=MPI_OP_NULL) {
240         op->apply( tmpbufs[other], recvbuf, &count, datatype);
241       }
242     }
243   }
244   for(index = 0; index < rank; index++) {
245     smpi_free_tmp_buffer(tmpbufs[index]);
246   }
247   for(index = 0; index < size-1; index++) {
248     Request::unref(&requests[index]);
249   }
250   delete[] tmpbufs;
251   delete[] requests;
252   return MPI_SUCCESS;
253 }
254
255 int colls::exscan(const void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
256 {
257   int system_tag = -888;
258   MPI_Aint lb         = 0;
259   MPI_Aint dataext    = 0;
260   int recvbuf_is_empty=1;
261   int rank = comm->rank();
262   int size = comm->size();
263
264   datatype->extent(&lb, &dataext);
265
266   // Send/Recv buffers to/from others
267   MPI_Request* requests = new MPI_Request[size - 1];
268   unsigned char** tmpbufs = new unsigned char*[rank];
269   int index = 0;
270   for (int other = 0; other < rank; other++) {
271     tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
272     requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
273     index++;
274   }
275   for (int other = rank + 1; other < size; other++) {
276     requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
277     index++;
278   }
279   // Wait for completion of all comms.
280   Request::startall(size - 1, requests);
281
282   if(op != MPI_OP_NULL && op->is_commutative()){
283     for (int other = 0; other < size - 1; other++) {
284       index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
285       if(index == MPI_UNDEFINED) {
286         break;
287       }
288       if(index < rank) {
289         if(recvbuf_is_empty){
290           Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
291           recvbuf_is_empty=0;
292         } else
293           // #Request is below rank: it's a irecv
294           op->apply( tmpbufs[index], recvbuf, &count, datatype);
295       }
296     }
297   }else{
298     //non commutative case, wait in order
299     for (int other = 0; other < size - 1; other++) {
300      Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
301       if(index < rank) {
302         if (recvbuf_is_empty) {
303           Datatype::copy(tmpbufs[other], count, datatype, recvbuf, count, datatype);
304           recvbuf_is_empty = 0;
305         } else
306           if(op!=MPI_OP_NULL)
307             op->apply( tmpbufs[other], recvbuf, &count, datatype);
308       }
309     }
310   }
311   for(index = 0; index < rank; index++) {
312     smpi_free_tmp_buffer(tmpbufs[index]);
313   }
314   for(index = 0; index < size-1; index++) {
315     Request::unref(&requests[index]);
316   }
317   delete[] tmpbufs;
318   delete[] requests;
319   return MPI_SUCCESS;
320 }
321
322 int colls::alltoallw(const void* sendbuf, const int* sendcounts, const int* senddisps, const MPI_Datatype* sendtypes,
323                      void* recvbuf, const int* recvcounts, const int* recvdisps, const MPI_Datatype* recvtypes,
324                      MPI_Comm comm)
325 {
326   MPI_Request request;
327   colls::ialltoallw(sendbuf, sendcounts, senddisps, sendtypes, recvbuf, recvcounts, recvdisps, recvtypes, comm,
328                     &request, 0);
329   return Request::wait(&request, MPI_STATUS_IGNORE);
330 }
331
332 }
333 }