Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
f8c1d751ef287e6f935062c8b25d2b5a4296d54f
[simgrid.git] / src / smpi / colls / smpi_coll.cpp
1 /* smpi_coll.c -- various optimized routing for collectives                 */
2
3 /* Copyright (c) 2009-2019. The SimGrid Team. All rights reserved.          */
4
5 /* This program is free software; you can redistribute it and/or modify it
6  * under the terms of the license (GNU LGPL) which comes with this package. */
7
8 #include "smpi_coll.hpp"
9 #include "private.hpp"
10 #include "smpi_comm.hpp"
11 #include "smpi_datatype.hpp"
12 #include "smpi_op.hpp"
13 #include "smpi_request.hpp"
14 #include "xbt/config.hpp"
15
16 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI (coll)");
17
18 #define COLL_SETTER(cat, ret, args, args2)                                                                             \
19   int(*Colls::cat) args;                                                                                               \
20   void Colls::_XBT_CONCAT(set_, cat)(const std::string& name)                                                          \
21   {                                                                                                                    \
22     int id = find_coll_description(_XBT_CONCAT3(mpi_coll_, cat, _description), name, _XBT_STRINGIFY(cat));             \
23     cat    = reinterpret_cast<ret(*) args>(_XBT_CONCAT3(mpi_coll_, cat, _description)[id].coll);                       \
24     if (cat == nullptr)                                                                                                \
25       xbt_die("Collective " _XBT_STRINGIFY(cat) " set to nullptr!");                                                   \
26   }
27
28 namespace simgrid{
29 namespace smpi{
30
31 void (*Colls::smpi_coll_cleanup_callback)();
32
33 /* these arrays must be nullptr terminated */
34 s_mpi_coll_description_t Colls::mpi_coll_gather_description[] = {
35     COLL_GATHERS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
36 s_mpi_coll_description_t Colls::mpi_coll_allgather_description[] = {
37     COLL_ALLGATHERS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
38 s_mpi_coll_description_t Colls::mpi_coll_allgatherv_description[] = {
39     COLL_ALLGATHERVS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
40 s_mpi_coll_description_t Colls::mpi_coll_allreduce_description[] ={
41     COLL_ALLREDUCES(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
42 s_mpi_coll_description_t Colls::mpi_coll_reduce_scatter_description[] = {
43     COLL_REDUCE_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
44 s_mpi_coll_description_t Colls::mpi_coll_scatter_description[] ={
45     COLL_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
46 s_mpi_coll_description_t Colls::mpi_coll_barrier_description[] ={
47     COLL_BARRIERS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
48 s_mpi_coll_description_t Colls::mpi_coll_alltoall_description[] = {
49     COLL_ALLTOALLS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
50 s_mpi_coll_description_t Colls::mpi_coll_alltoallv_description[] = {
51     COLL_ALLTOALLVS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
52 s_mpi_coll_description_t Colls::mpi_coll_bcast_description[] = {
53     COLL_BCASTS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
54 s_mpi_coll_description_t Colls::mpi_coll_reduce_description[] = {
55     COLL_REDUCES(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
56
57 /** Displays the long description of all registered models, and quit */
58 void Colls::coll_help(const char *category, s_mpi_coll_description_t * table)
59 {
60   XBT_WARN("Long description of the %s models accepted by this simulator:\n", category);
61   for (int i = 0; not table[i].name.empty(); i++)
62     XBT_WARN("  %s: %s\n", table[i].name.c_str(), table[i].description.c_str());
63 }
64
65 int Colls::find_coll_description(s_mpi_coll_description_t* table, const std::string& name, const char* desc)
66 {
67   for (int i = 0; not table[i].name.empty(); i++)
68     if (name == table[i].name) {
69       if (table[i].name != "default")
70         XBT_INFO("Switch to algorithm %s for collective %s",table[i].name.c_str(),desc);
71       return i;
72     }
73
74   if (table[0].name.empty())
75     xbt_die("No collective is valid for '%s'! This is a bug.", name.c_str());
76   std::string name_list = table[0].name;
77   for (int i = 1; not table[i].name.empty(); i++)
78     name_list = name_list + ", " + table[i].name;
79
80   xbt_die("Collective '%s' is invalid! Valid collectives are: %s.", name.c_str(), name_list.c_str());
81   return -1;
82 }
83
84 COLL_APPLY(COLL_SETTER,COLL_GATHER_SIG,"");
85 COLL_APPLY(COLL_SETTER,COLL_ALLGATHER_SIG,"");
86 COLL_APPLY(COLL_SETTER,COLL_ALLGATHERV_SIG,"");
87 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SIG,"");
88 COLL_APPLY(COLL_SETTER,COLL_ALLREDUCE_SIG,"");
89 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SCATTER_SIG,"");
90 COLL_APPLY(COLL_SETTER,COLL_SCATTER_SIG,"");
91 COLL_APPLY(COLL_SETTER,COLL_BARRIER_SIG,"");
92 COLL_APPLY(COLL_SETTER,COLL_BCAST_SIG,"");
93 COLL_APPLY(COLL_SETTER,COLL_ALLTOALL_SIG,"");
94 COLL_APPLY(COLL_SETTER,COLL_ALLTOALLV_SIG,"");
95
96 void Colls::set_collectives(){
97   std::string selector_name = simgrid::config::get_value<std::string>("smpi/coll-selector");
98   if (selector_name.empty())
99     selector_name = "default";
100
101   std::pair<std::string, std::function<void(std::string)>> setter_callbacks[] = {
102       {"gather", &Colls::set_gather},         {"allgather", &Colls::set_allgather},
103       {"allgatherv", &Colls::set_allgatherv}, {"allreduce", &Colls::set_allreduce},
104       {"alltoall", &Colls::set_alltoall},     {"alltoallv", &Colls::set_alltoallv},
105       {"reduce", &Colls::set_reduce},         {"reduce_scatter", &Colls::set_reduce_scatter},
106       {"scatter", &Colls::set_scatter},       {"bcast", &Colls::set_bcast},
107       {"barrier", &Colls::set_barrier}};
108
109   for (auto& elem : setter_callbacks) {
110     std::string name = simgrid::config::get_value<std::string>(("smpi/" + elem.first).c_str());
111     if (name.empty())
112       name = selector_name;
113
114     (elem.second)(name);
115   }
116 }
117
118 //Implementations of the single algorithm collectives
119
120 int Colls::gatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, const int *recvcounts, const int *displs,
121                       MPI_Datatype recvtype, int root, MPI_Comm comm)
122 {
123   MPI_Request request;
124   Colls::igatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, &request);
125   return Request::wait(&request, MPI_STATUS_IGNORE);
126 }
127
128
129 int Colls::scatterv(const void *sendbuf, const int *sendcounts, const int *displs, MPI_Datatype sendtype, void *recvbuf, int recvcount,
130                        MPI_Datatype recvtype, int root, MPI_Comm comm)
131 {
132   MPI_Request request;
133   Colls::iscatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm, &request);
134   return Request::wait(&request, MPI_STATUS_IGNORE);
135 }
136
137
138 int Colls::scan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
139 {
140   int system_tag = -888;
141   MPI_Aint lb      = 0;
142   MPI_Aint dataext = 0;
143
144   int rank = comm->rank();
145   int size = comm->size();
146
147   datatype->extent(&lb, &dataext);
148
149   // Local copy from self
150   Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);
151
152   // Send/Recv buffers to/from others
153   MPI_Request* requests = new MPI_Request[size - 1];
154   unsigned char** tmpbufs = new unsigned char*[rank];
155   int index = 0;
156   for (int other = 0; other < rank; other++) {
157     tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
158     requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
159     index++;
160   }
161   for (int other = rank + 1; other < size; other++) {
162     requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
163     index++;
164   }
165   // Wait for completion of all comms.
166   Request::startall(size - 1, requests);
167
168   if(op != MPI_OP_NULL && op->is_commutative()){
169     for (int other = 0; other < size - 1; other++) {
170       index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
171       if(index == MPI_UNDEFINED) {
172         break;
173       }
174       if(index < rank) {
175         // #Request is below rank: it's a irecv
176         op->apply( tmpbufs[index], recvbuf, &count, datatype);
177       }
178     }
179   }else{
180     //non commutative case, wait in order
181     for (int other = 0; other < size - 1; other++) {
182       Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
183       if(index < rank && op!=MPI_OP_NULL) {
184         op->apply( tmpbufs[other], recvbuf, &count, datatype);
185       }
186     }
187   }
188   for(index = 0; index < rank; index++) {
189     smpi_free_tmp_buffer(tmpbufs[index]);
190   }
191   for(index = 0; index < size-1; index++) {
192     Request::unref(&requests[index]);
193   }
194   delete[] tmpbufs;
195   delete[] requests;
196   return MPI_SUCCESS;
197 }
198
199 int Colls::exscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
200 {
201   int system_tag = -888;
202   MPI_Aint lb         = 0;
203   MPI_Aint dataext    = 0;
204   int recvbuf_is_empty=1;
205   int rank = comm->rank();
206   int size = comm->size();
207
208   datatype->extent(&lb, &dataext);
209
210   // Send/Recv buffers to/from others
211   MPI_Request* requests = new MPI_Request[size - 1];
212   unsigned char** tmpbufs = new unsigned char*[rank];
213   int index = 0;
214   for (int other = 0; other < rank; other++) {
215     tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
216     requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
217     index++;
218   }
219   for (int other = rank + 1; other < size; other++) {
220     requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
221     index++;
222   }
223   // Wait for completion of all comms.
224   Request::startall(size - 1, requests);
225
226   if(op != MPI_OP_NULL && op->is_commutative()){
227     for (int other = 0; other < size - 1; other++) {
228       index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
229       if(index == MPI_UNDEFINED) {
230         break;
231       }
232       if(index < rank) {
233         if(recvbuf_is_empty){
234           Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
235           recvbuf_is_empty=0;
236         } else
237           // #Request is below rank: it's a irecv
238           op->apply( tmpbufs[index], recvbuf, &count, datatype);
239       }
240     }
241   }else{
242     //non commutative case, wait in order
243     for (int other = 0; other < size - 1; other++) {
244      Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
245       if(index < rank) {
246         if (recvbuf_is_empty) {
247           Datatype::copy(tmpbufs[other], count, datatype, recvbuf, count, datatype);
248           recvbuf_is_empty = 0;
249         } else
250           if(op!=MPI_OP_NULL)
251             op->apply( tmpbufs[other], recvbuf, &count, datatype);
252       }
253     }
254   }
255   for(index = 0; index < rank; index++) {
256     smpi_free_tmp_buffer(tmpbufs[index]);
257   }
258   for(index = 0; index < size-1; index++) {
259     Request::unref(&requests[index]);
260   }
261   delete[] tmpbufs;
262   delete[] requests;
263   return MPI_SUCCESS;
264 }
265
266 int Colls::alltoallw(const void *sendbuf, const int *sendcounts, const int *senddisps, const MPI_Datatype* sendtypes,
267                               void *recvbuf, const int *recvcounts, const int *recvdisps, const MPI_Datatype* recvtypes, MPI_Comm comm)
268 {
269   MPI_Request request;
270   Colls::ialltoallw(sendbuf, sendcounts, senddisps, sendtypes, recvbuf, recvcounts, recvdisps, recvtypes, comm, &request);  
271   return Request::wait(&request, MPI_STATUS_IGNORE);
272 }
273
274 }
275 }