[simgrid.git] / src/smpi/colls/smpi_coll.cpp (commit 1d0f166183188e46d543ec57e56becbf64027a3c)
/* smpi_coll.cpp -- various optimized routines for collectives                */

/* Copyright (c) 2009-2019. The SimGrid Team. All rights reserved.          */

/* This program is free software; you can redistribute it and/or modify it
 * under the terms of the license (GNU LGPL) which comes with this package. */

#include "smpi_coll.hpp"
#include "private.hpp"
#include "smpi_comm.hpp"
#include "smpi_datatype.hpp"
#include "smpi_op.hpp"
#include "smpi_request.hpp"
#include "xbt/config.hpp"

XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI collectives.");

namespace simgrid {
namespace smpi {

/* these arrays must be terminated by a sentinel entry (empty names, nullptr function pointer) */
s_mpi_coll_description_t mpi_coll_gather_description[] = {
    {"default", "gather default collective", (void*)gather__default},
    {"ompi", "gather ompi collective", (void*)gather__ompi},
    {"ompi_basic_linear", "gather ompi_basic_linear collective", (void*)gather__ompi_basic_linear},
    {"ompi_binomial", "gather ompi_binomial collective", (void*)gather__ompi_binomial},
    {"ompi_linear_sync", "gather ompi_linear_sync collective", (void*)gather__ompi_linear_sync},
    {"mpich", "gather mpich collective", (void*)gather__mpich},
    {"mvapich2", "gather mvapich2 collective", (void*)gather__mvapich2},
    {"mvapich2_two_level", "gather mvapich2_two_level collective", (void*)gather__mvapich2_two_level},
    {"impi", "gather impi collective", (void*)gather__impi},
    {"automatic", "gather automatic collective", (void*)gather__automatic},
    {"", "", nullptr}};
s_mpi_coll_description_t mpi_coll_allgather_description[] = {
    {"default", "allgather default collective", (void*)allgather__default},
    {"2dmesh", "allgather 2dmesh collective", (void*)allgather__2dmesh},
    {"3dmesh", "allgather 3dmesh collective", (void*)allgather__3dmesh},
    {"bruck", "allgather bruck collective", (void*)allgather__bruck},
    {"GB", "allgather GB collective", (void*)allgather__GB},
    {"loosely_lr", "allgather loosely_lr collective", (void*)allgather__loosely_lr},
    {"NTSLR", "allgather NTSLR collective", (void*)allgather__NTSLR},
    {"NTSLR_NB", "allgather NTSLR_NB collective", (void*)allgather__NTSLR_NB},
    {"pair", "allgather pair collective", (void*)allgather__pair},
    {"rdb", "allgather rdb collective", (void*)allgather__rdb},
    {"rhv", "allgather rhv collective", (void*)allgather__rhv},
    {"ring", "allgather ring collective", (void*)allgather__ring},
    {"SMP_NTS", "allgather SMP_NTS collective", (void*)allgather__SMP_NTS},
    {"smp_simple", "allgather smp_simple collective", (void*)allgather__smp_simple},
    {"spreading_simple", "allgather spreading_simple collective", (void*)allgather__spreading_simple},
    {"ompi", "allgather ompi collective", (void*)allgather__ompi},
    {"ompi_neighborexchange", "allgather ompi_neighborexchange collective", (void*)allgather__ompi_neighborexchange},
    {"mvapich2", "allgather mvapich2 collective", (void*)allgather__mvapich2},
    {"mvapich2_smp", "allgather mvapich2_smp collective", (void*)allgather__mvapich2_smp},
    {"mpich", "allgather mpich collective", (void*)allgather__mpich},
    {"impi", "allgather impi collective", (void*)allgather__impi},
    {"automatic", "allgather automatic collective", (void*)allgather__automatic},
    {"", "", nullptr}};
s_mpi_coll_description_t mpi_coll_allgatherv_description[] = {
    {"default", "allgatherv default collective", (void*)allgatherv__default},
    {"GB", "allgatherv GB collective", (void*)allgatherv__GB},
    {"pair", "allgatherv pair collective", (void*)allgatherv__pair},
    {"ring", "allgatherv ring collective", (void*)allgatherv__ring},
    {"ompi", "allgatherv ompi collective", (void*)allgatherv__ompi},
    {"ompi_neighborexchange", "allgatherv ompi_neighborexchange collective", (void*)allgatherv__ompi_neighborexchange},
    {"ompi_bruck", "allgatherv ompi_bruck collective", (void*)allgatherv__ompi_bruck},
    {"mpich", "allgatherv mpich collective", (void*)allgatherv__mpich},
    {"mpich_rdb", "allgatherv mpich_rdb collective", (void*)allgatherv__mpich_rdb},
    {"mpich_ring", "allgatherv mpich_ring collective", (void*)allgatherv__mpich_ring},
    {"mvapich2", "allgatherv mvapich2 collective", (void*)allgatherv__mvapich2},
    {"impi", "allgatherv impi collective", (void*)allgatherv__impi},
    {"automatic", "allgatherv automatic collective", (void*)allgatherv__automatic},
    {"", "", nullptr}};
s_mpi_coll_description_t mpi_coll_allreduce_description[] = {
    {"default", "allreduce default collective", (void*)allreduce__default},
    {"lr", "allreduce lr collective", (void*)allreduce__lr},
    {"rab1", "allreduce rab1 collective", (void*)allreduce__rab1},
    {"rab2", "allreduce rab2 collective", (void*)allreduce__rab2},
    {"rab_rdb", "allreduce rab_rdb collective", (void*)allreduce__rab_rdb},
    {"rdb", "allreduce rdb collective", (void*)allreduce__rdb},
    {"smp_binomial", "allreduce smp_binomial collective", (void*)allreduce__smp_binomial},
    {"smp_binomial_pipeline", "allreduce smp_binomial_pipeline collective", (void*)allreduce__smp_binomial_pipeline},
    {"smp_rdb", "allreduce smp_rdb collective", (void*)allreduce__smp_rdb},
    {"smp_rsag", "allreduce smp_rsag collective", (void*)allreduce__smp_rsag},
    {"smp_rsag_lr", "allreduce smp_rsag_lr collective", (void*)allreduce__smp_rsag_lr},
    {"smp_rsag_rab", "allreduce smp_rsag_rab collective", (void*)allreduce__smp_rsag_rab},
    {"redbcast", "allreduce redbcast collective", (void*)allreduce__redbcast},
    {"ompi", "allreduce ompi collective", (void*)allreduce__ompi},
    {"ompi_ring_segmented", "allreduce ompi_ring_segmented collective", (void*)allreduce__ompi_ring_segmented},
    {"mpich", "allreduce mpich collective", (void*)allreduce__mpich},
    {"mvapich2", "allreduce mvapich2 collective", (void*)allreduce__mvapich2},
    {"mvapich2_rs", "allreduce mvapich2_rs collective", (void*)allreduce__mvapich2_rs},
    {"mvapich2_two_level", "allreduce mvapich2_two_level collective", (void*)allreduce__mvapich2_two_level},
    {"impi", "allreduce impi collective", (void*)allreduce__impi},
    {"rab", "allreduce rab collective", (void*)allreduce__rab},
    {"automatic", "allreduce automatic collective", (void*)allreduce__automatic},
    {"", "", nullptr}};
s_mpi_coll_description_t mpi_coll_reduce_scatter_description[] = {
    {"default", "reduce_scatter default collective", (void*)reduce_scatter__default},
    {"ompi", "reduce_scatter ompi collective", (void*)reduce_scatter__ompi},
    {"ompi_basic_recursivehalving", "reduce_scatter ompi_basic_recursivehalving collective", (void*)reduce_scatter__ompi_basic_recursivehalving},
    {"ompi_ring", "reduce_scatter ompi_ring collective", (void*)reduce_scatter__ompi_ring},
    {"mpich", "reduce_scatter mpich collective", (void*)reduce_scatter__mpich},
    {"mpich_pair", "reduce_scatter mpich_pair collective", (void*)reduce_scatter__mpich_pair},
    {"mpich_rdb", "reduce_scatter mpich_rdb collective", (void*)reduce_scatter__mpich_rdb},
    {"mpich_noncomm", "reduce_scatter mpich_noncomm collective", (void*)reduce_scatter__mpich_noncomm},
    {"mvapich2", "reduce_scatter mvapich2 collective", (void*)reduce_scatter__mvapich2},
    {"impi", "reduce_scatter impi collective", (void*)reduce_scatter__impi},
    {"automatic", "reduce_scatter automatic collective", (void*)reduce_scatter__automatic},
    {"", "", nullptr}};
s_mpi_coll_description_t mpi_coll_scatter_description[] = {
    {"default", "scatter default collective", (void*)scatter__default},
    {"ompi", "scatter ompi collective", (void*)scatter__ompi},
    {"ompi_basic_linear", "scatter ompi_basic_linear collective", (void*)scatter__ompi_basic_linear},
    {"ompi_binomial", "scatter ompi_binomial collective", (void*)scatter__ompi_binomial},
    {"mpich", "scatter mpich collective", (void*)scatter__mpich},
    {"mvapich2", "scatter mvapich2 collective", (void*)scatter__mvapich2},
    {"mvapich2_two_level_binomial", "scatter mvapich2_two_level_binomial collective", (void*)scatter__mvapich2_two_level_binomial},
    {"mvapich2_two_level_direct", "scatter mvapich2_two_level_direct collective", (void*)scatter__mvapich2_two_level_direct},
    {"impi", "scatter impi collective", (void*)scatter__impi},
    {"automatic", "scatter automatic collective", (void*)scatter__automatic},
    {"", "", nullptr}};
s_mpi_coll_description_t mpi_coll_barrier_description[] = {
    {"default", "barrier default collective", (void*)barrier__default},
    {"ompi", "barrier ompi collective", (void*)barrier__ompi},
    {"ompi_basic_linear", "barrier ompi_basic_linear collective", (void*)barrier__ompi_basic_linear},
    {"ompi_two_procs", "barrier ompi_two_procs collective", (void*)barrier__ompi_two_procs},
    {"ompi_tree", "barrier ompi_tree collective", (void*)barrier__ompi_tree},
    {"ompi_bruck", "barrier ompi_bruck collective", (void*)barrier__ompi_bruck},
    {"ompi_recursivedoubling", "barrier ompi_recursivedoubling collective", (void*)barrier__ompi_recursivedoubling},
    {"ompi_doublering", "barrier ompi_doublering collective", (void*)barrier__ompi_doublering},
    {"mpich_smp", "barrier mpich_smp collective", (void*)barrier__mpich_smp},
    {"mpich", "barrier mpich collective", (void*)barrier__mpich},
    {"mvapich2_pair", "barrier mvapich2_pair collective", (void*)barrier__mvapich2_pair},
    {"mvapich2", "barrier mvapich2 collective", (void*)barrier__mvapich2},
    {"impi", "barrier impi collective", (void*)barrier__impi},
    {"automatic", "barrier automatic collective", (void*)barrier__automatic},
    {"", "", nullptr}};
s_mpi_coll_description_t mpi_coll_alltoall_description[] = {
    {"default", "alltoall default collective", (void*)alltoall__default},
    {"2dmesh", "alltoall 2dmesh collective", (void*)alltoall__2dmesh},
    {"3dmesh", "alltoall 3dmesh collective", (void*)alltoall__3dmesh},
    {"basic_linear", "alltoall basic_linear collective", (void*)alltoall__basic_linear},
    {"bruck", "alltoall bruck collective", (void*)alltoall__bruck},
    {"pair", "alltoall pair collective", (void*)alltoall__pair},
    {"pair_rma", "alltoall pair_rma collective", (void*)alltoall__pair_rma},
    {"pair_light_barrier", "alltoall pair_light_barrier collective", (void*)alltoall__pair_light_barrier},
    {"pair_mpi_barrier", "alltoall pair_mpi_barrier collective", (void*)alltoall__pair_mpi_barrier},
    {"pair_one_barrier", "alltoall pair_one_barrier collective", (void*)alltoall__pair_one_barrier},
    {"rdb", "alltoall rdb collective", (void*)alltoall__rdb},
    {"ring", "alltoall ring collective", (void*)alltoall__ring},
    {"ring_light_barrier", "alltoall ring_light_barrier collective", (void*)alltoall__ring_light_barrier},
    {"ring_mpi_barrier", "alltoall ring_mpi_barrier collective", (void*)alltoall__ring_mpi_barrier},
    {"ring_one_barrier", "alltoall ring_one_barrier collective", (void*)alltoall__ring_one_barrier},
    {"mvapich2", "alltoall mvapich2 collective", (void*)alltoall__mvapich2},
    {"mvapich2_scatter_dest", "alltoall mvapich2_scatter_dest collective", (void*)alltoall__mvapich2_scatter_dest},
    {"ompi", "alltoall ompi collective", (void*)alltoall__ompi},
    {"mpich", "alltoall mpich collective", (void*)alltoall__mpich},
    {"impi", "alltoall impi collective", (void*)alltoall__impi},
    {"automatic", "alltoall automatic collective", (void*)alltoall__automatic},
    {"", "", nullptr}};
s_mpi_coll_description_t mpi_coll_alltoallv_description[] = {
    {"default", "alltoallv default collective", (void*)alltoallv__default},
    {"bruck", "alltoallv bruck collective", (void*)alltoallv__bruck},
    {"pair", "alltoallv pair collective", (void*)alltoallv__pair},
    {"pair_light_barrier", "alltoallv pair_light_barrier collective", (void*)alltoallv__pair_light_barrier},
    {"pair_mpi_barrier", "alltoallv pair_mpi_barrier collective", (void*)alltoallv__pair_mpi_barrier},
    {"pair_one_barrier", "alltoallv pair_one_barrier collective", (void*)alltoallv__pair_one_barrier},
    {"ring", "alltoallv ring collective", (void*)alltoallv__ring},
    {"ring_light_barrier", "alltoallv ring_light_barrier collective", (void*)alltoallv__ring_light_barrier},
    {"ring_mpi_barrier", "alltoallv ring_mpi_barrier collective", (void*)alltoallv__ring_mpi_barrier},
    {"ring_one_barrier", "alltoallv ring_one_barrier collective", (void*)alltoallv__ring_one_barrier},
    {"ompi", "alltoallv ompi collective", (void*)alltoallv__ompi},
    {"mpich", "alltoallv mpich collective", (void*)alltoallv__mpich},
    {"ompi_basic_linear", "alltoallv ompi_basic_linear collective", (void*)alltoallv__ompi_basic_linear},
    {"mvapich2", "alltoallv mvapich2 collective", (void*)alltoallv__mvapich2},
    {"impi", "alltoallv impi collective", (void*)alltoallv__impi},
    {"automatic", "alltoallv automatic collective", (void*)alltoallv__automatic},
    {"", "", nullptr}};
s_mpi_coll_description_t mpi_coll_bcast_description[] = {
    {"default", "bcast default collective", (void*)bcast__default},
    {"arrival_pattern_aware", "bcast arrival_pattern_aware collective", (void*)bcast__arrival_pattern_aware},
    {"arrival_pattern_aware_wait", "bcast arrival_pattern_aware_wait collective", (void*)bcast__arrival_pattern_aware_wait},
    {"arrival_scatter", "bcast arrival_scatter collective", (void*)bcast__arrival_scatter},
    {"binomial_tree", "bcast binomial_tree collective", (void*)bcast__binomial_tree},
    {"flattree", "bcast flattree collective", (void*)bcast__flattree},
    {"flattree_pipeline", "bcast flattree_pipeline collective", (void*)bcast__flattree_pipeline},
    {"NTSB", "bcast NTSB collective", (void*)bcast__NTSB},
    {"NTSL", "bcast NTSL collective", (void*)bcast__NTSL},
    {"NTSL_Isend", "bcast NTSL_Isend collective", (void*)bcast__NTSL_Isend},
    {"scatter_LR_allgather", "bcast scatter_LR_allgather collective", (void*)bcast__scatter_LR_allgather},
    {"scatter_rdb_allgather", "bcast scatter_rdb_allgather collective", (void*)bcast__scatter_rdb_allgather},
    {"SMP_binary", "bcast SMP_binary collective", (void*)bcast__SMP_binary},
    {"SMP_binomial", "bcast SMP_binomial collective", (void*)bcast__SMP_binomial},
    {"SMP_linear", "bcast SMP_linear collective", (void*)bcast__SMP_linear},
    {"ompi", "bcast ompi collective", (void*)bcast__ompi},
    {"ompi_split_bintree", "bcast ompi_split_bintree collective", (void*)bcast__ompi_split_bintree},
    {"ompi_pipeline", "bcast ompi_pipeline collective", (void*)bcast__ompi_pipeline},
    {"mpich", "bcast mpich collective", (void*)bcast__mpich},
    {"mvapich2", "bcast mvapich2 collective", (void*)bcast__mvapich2},
    {"mvapich2_inter_node", "bcast mvapich2_inter_node collective", (void*)bcast__mvapich2_inter_node},
    {"mvapich2_intra_node", "bcast mvapich2_intra_node collective", (void*)bcast__mvapich2_intra_node},
    {"mvapich2_knomial_intra_node", "bcast mvapich2_knomial_intra_node collective", (void*)bcast__mvapich2_knomial_intra_node},
    {"impi", "bcast impi collective", (void*)bcast__impi},
    {"automatic", "bcast automatic collective", (void*)bcast__automatic},
    {"", "", nullptr}};
s_mpi_coll_description_t mpi_coll_reduce_description[] = {
    {"default", "reduce default collective", (void*)reduce__default},
    {"arrival_pattern_aware", "reduce arrival_pattern_aware collective", (void*)reduce__arrival_pattern_aware},
    {"binomial", "reduce binomial collective", (void*)reduce__binomial},
    {"flat_tree", "reduce flat_tree collective", (void*)reduce__flat_tree},
    {"NTSL", "reduce NTSL collective", (void*)reduce__NTSL},
    {"scatter_gather", "reduce scatter_gather collective", (void*)reduce__scatter_gather},
    {"ompi", "reduce ompi collective", (void*)reduce__ompi},
    {"ompi_chain", "reduce ompi_chain collective", (void*)reduce__ompi_chain},
    {"ompi_pipeline", "reduce ompi_pipeline collective", (void*)reduce__ompi_pipeline},
    {"ompi_basic_linear", "reduce ompi_basic_linear collective", (void*)reduce__ompi_basic_linear},
    {"ompi_in_order_binary", "reduce ompi_in_order_binary collective", (void*)reduce__ompi_in_order_binary},
    {"ompi_binary", "reduce ompi_binary collective", (void*)reduce__ompi_binary},
    {"ompi_binomial", "reduce ompi_binomial collective", (void*)reduce__ompi_binomial},
    {"mpich", "reduce mpich collective", (void*)reduce__mpich},
    {"mvapich2", "reduce mvapich2 collective", (void*)reduce__mvapich2},
    {"mvapich2_knomial", "reduce mvapich2_knomial collective", (void*)reduce__mvapich2_knomial},
    {"mvapich2_two_level", "reduce mvapich2_two_level collective", (void*)reduce__mvapich2_two_level},
    {"impi", "reduce impi collective", (void*)reduce__impi},
    {"rab", "reduce rab collective", (void*)reduce__rab},
    {"automatic", "reduce automatic collective", (void*)reduce__automatic},
    {"", "", nullptr}};

// Needed by the 'automatic' selector, whose implementation indexes these description tables directly:
// the 'rank' argument is used as a table index, not as an MPI rank.
s_mpi_coll_description_t* colls::get_smpi_coll_description(const char* name, int rank)
{
  if (strcmp(name, "gather") == 0)
    return &mpi_coll_gather_description[rank];
  if (strcmp(name, "allgather") == 0)
    return &mpi_coll_allgather_description[rank];
  if (strcmp(name, "allgatherv") == 0)
    return &mpi_coll_allgatherv_description[rank];
  if (strcmp(name, "allreduce") == 0)
    return &mpi_coll_allreduce_description[rank];
  if (strcmp(name, "reduce_scatter") == 0)
    return &mpi_coll_reduce_scatter_description[rank];
  if (strcmp(name, "scatter") == 0)
    return &mpi_coll_scatter_description[rank];
  if (strcmp(name, "barrier") == 0)
    return &mpi_coll_barrier_description[rank];
  if (strcmp(name, "alltoall") == 0)
    return &mpi_coll_alltoall_description[rank];
  if (strcmp(name, "alltoallv") == 0)
    return &mpi_coll_alltoallv_description[rank];
  if (strcmp(name, "bcast") == 0)
    return &mpi_coll_bcast_description[rank];
  if (strcmp(name, "reduce") == 0)
    return &mpi_coll_reduce_description[rank];
  XBT_INFO("You requested an unknown collective: %s", name);
  return nullptr;
}

/** Display the long description of every collective algorithm registered in the given table */
void colls::coll_help(const char* category, s_mpi_coll_description_t* table)
{
  XBT_WARN("Long description of the %s models accepted by this simulator:\n", category);
  for (int i = 0; not table[i].name.empty(); i++)
    XBT_WARN("  %s: %s\n", table[i].name.c_str(), table[i].description.c_str());
}

int colls::find_coll_description(s_mpi_coll_description_t* table, const std::string& name, const char* desc)
{
  for (int i = 0; not table[i].name.empty(); i++)
    if (name == table[i].name) {
      if (table[i].name != "default")
        XBT_INFO("Switch to algorithm %s for collective %s", table[i].name.c_str(), desc);
      return i;
    }

  if (table[0].name.empty())
    xbt_die("No collective is valid for '%s'! This is a bug.", name.c_str());
  std::string name_list = table[0].name;
  for (int i = 1; not table[i].name.empty(); i++)
    name_list = name_list + ", " + table[i].name;

  xbt_die("Collective '%s' is invalid! Valid collectives are: %s.", name.c_str(), name_list.c_str());
  return -1;
}
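
/* Illustrative example (not called anywhere in this file): looking the "pair" algorithm up in the
 * alltoall table returns its index, and the stored pointer is what the setters below install:
 *
 *   int id = colls::find_coll_description(mpi_coll_alltoall_description, "pair", "alltoall");
 *   // mpi_coll_alltoall_description[id].coll == (void*)alltoall__pair
 */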

int (*colls::gather)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff, int recv_count,
                     MPI_Datatype recv_type, int root, MPI_Comm comm);
int (*colls::allgather)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff, int recv_count,
                        MPI_Datatype recv_type, MPI_Comm comm);
int (*colls::allgatherv)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff,
                         const int* recv_count, const int* recv_disps, MPI_Datatype recv_type, MPI_Comm comm);
int (*colls::alltoall)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff, int recv_count,
                       MPI_Datatype recv_type, MPI_Comm comm);
int (*colls::alltoallv)(const void* send_buff, const int* send_counts, const int* send_disps, MPI_Datatype send_type,
                        void* recv_buff, const int* recv_counts, const int* recv_disps, MPI_Datatype recv_type,
                        MPI_Comm comm);
int (*colls::bcast)(void* buf, int count, MPI_Datatype datatype, int root, MPI_Comm comm);
int (*colls::reduce)(const void* buf, void* rbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm);
int (*colls::allreduce)(const void* sbuf, void* rbuf, int rcount, MPI_Datatype dtype, MPI_Op op, MPI_Comm comm);
int (*colls::reduce_scatter)(const void* sbuf, void* rbuf, const int* rcounts, MPI_Datatype dtype, MPI_Op op,
                             MPI_Comm comm);
int (*colls::scatter)(const void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, int recvcount,
                      MPI_Datatype recvtype, int root, MPI_Comm comm);
int (*colls::barrier)(MPI_Comm comm);

void (*colls::smpi_coll_cleanup_callback)();

#define COLL_SETTER(cat, ret, args, args2)                                                                             \
  void colls::_XBT_CONCAT(set_, cat)(const std::string& name)                                                          \
  {                                                                                                                    \
    int id = find_coll_description(_XBT_CONCAT3(mpi_coll_, cat, _description), name, _XBT_STRINGIFY(cat));             \
    cat    = reinterpret_cast<ret(*) args>(_XBT_CONCAT3(mpi_coll_, cat, _description)[id].coll);                       \
    if (cat == nullptr)                                                                                                \
      xbt_die("Collective " _XBT_STRINGIFY(cat) " set to nullptr!");                                                   \
  }
COLL_APPLY(COLL_SETTER, COLL_GATHER_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_ALLGATHER_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_ALLGATHERV_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_REDUCE_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_ALLREDUCE_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_REDUCE_SCATTER_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_SCATTER_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_BARRIER_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_BCAST_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_ALLTOALL_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_ALLTOALLV_SIG, "");
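
/* For reference, one of these expansions looks roughly as follows; the exact signature macros
 * (COLL_GATHER_SIG and friends) are defined in smpi_coll.hpp, so this is only a sketch of
 * COLL_APPLY(COLL_SETTER, COLL_GATHER_SIG, ""):
 *
 *   void colls::set_gather(const std::string& name)
 *   {
 *     int id = find_coll_description(mpi_coll_gather_description, name, "gather");
 *     gather = reinterpret_cast<int (*)(const void*, int, MPI_Datatype, void*, int, MPI_Datatype, int, MPI_Comm)>(
 *         mpi_coll_gather_description[id].coll);
 *     if (gather == nullptr)
 *       xbt_die("Collective gather set to nullptr!");
 *   }
 */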

void colls::set_collectives()
{
  std::string selector_name = simgrid::config::get_value<std::string>("smpi/coll-selector");
  if (selector_name.empty())
    selector_name = "default";

  std::pair<std::string, std::function<void(std::string)>> setter_callbacks[] = {
      {"gather", &colls::set_gather},         {"allgather", &colls::set_allgather},
      {"allgatherv", &colls::set_allgatherv}, {"allreduce", &colls::set_allreduce},
      {"alltoall", &colls::set_alltoall},     {"alltoallv", &colls::set_alltoallv},
      {"reduce", &colls::set_reduce},         {"reduce_scatter", &colls::set_reduce_scatter},
      {"scatter", &colls::set_scatter},       {"bcast", &colls::set_bcast},
      {"barrier", &colls::set_barrier}};

  for (auto& elem : setter_callbacks) {
    std::string name = simgrid::config::get_value<std::string>(("smpi/" + elem.first).c_str());
    if (name.empty())
      name = selector_name;

    (elem.second)(name);
  }
}
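
/* Typical way to drive the options read above from the command line (any SimGrid configuration
 * mechanism works): "smpi/coll-selector" picks a whole selector, and the per-collective keys
 * ("smpi/alltoall", "smpi/bcast", ...) override it for a single collective, e.g.:
 *
 *   smpirun --cfg=smpi/coll-selector:mpich --cfg=smpi/alltoall:pair ...
 */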

// Implementations of the collectives that have a single algorithm (nothing to select here)

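/* gatherv, scatterv and alltoallw have no selectable algorithm: each blocking call simply starts
 * its nonblocking counterpart (igatherv, iscatterv, ialltoallw) and waits for its completion. */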
int colls::gatherv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, const int* recvcounts,
                   const int* displs, MPI_Datatype recvtype, int root, MPI_Comm comm)
{
  MPI_Request request;
  colls::igatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, &request, 0);
  return Request::wait(&request, MPI_STATUS_IGNORE);
}

int colls::scatterv(const void* sendbuf, const int* sendcounts, const int* displs, MPI_Datatype sendtype, void* recvbuf,
                    int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
{
  MPI_Request request;
  colls::iscatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm, &request, 0);
  return Request::wait(&request, MPI_STATUS_IGNORE);
}

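/* Straightforward linear implementation of MPI_Scan: each rank first copies its own contribution
 * into recvbuf, then posts receives from every lower rank and sends to every higher rank, and
 * finally folds the received contributions into recvbuf with 'op' (in rank order when 'op' is not
 * commutative). */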
int colls::scan(const void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
{
  int system_tag = -888;
  MPI_Aint lb      = 0;
  MPI_Aint dataext = 0;

  int rank = comm->rank();
  int size = comm->size();

  datatype->extent(&lb, &dataext);

  // Local copy from self
  Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);

  // Send/Recv buffers to/from others
  MPI_Request* requests = new MPI_Request[size - 1];
  unsigned char** tmpbufs = new unsigned char*[rank];
  int index = 0;
  for (int other = 0; other < rank; other++) {
    tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
    requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
    index++;
  }
  for (int other = rank + 1; other < size; other++) {
    requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
    index++;
  }
  // Wait for completion of all comms.
  Request::startall(size - 1, requests);

  if (op != MPI_OP_NULL && op->is_commutative()) {
    for (int other = 0; other < size - 1; other++) {
      index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
      if (index == MPI_UNDEFINED) {
        break;
      }
      if (index < rank) {
        // Request index is below our rank: it is an irecv from a lower rank, so fold its data in
        op->apply(tmpbufs[index], recvbuf, &count, datatype);
      }
    }
  } else {
    // Non-commutative case: wait for the requests in rank order
    for (int other = 0; other < size - 1; other++) {
      Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
      if (other < rank && op != MPI_OP_NULL) { // only receives from lower ranks contribute
        op->apply(tmpbufs[other], recvbuf, &count, datatype);
      }
    }
  }
  for (index = 0; index < rank; index++) {
    smpi_free_tmp_buffer(tmpbufs[index]);
  }
  for (index = 0; index < size - 1; index++) {
    Request::unref(&requests[index]);
  }
  delete[] tmpbufs;
  delete[] requests;
  return MPI_SUCCESS;
}

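/* Exclusive scan: same linear communication pattern as scan() above, except that the local
 * contribution is not copied into recvbuf, so rank i only combines the values of ranks 0..i-1
 * (recvbuf on rank 0 is left untouched; MPI leaves it undefined for MPI_Exscan). */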
int colls::exscan(const void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
{
  int system_tag = -888;
  MPI_Aint lb         = 0;
  MPI_Aint dataext    = 0;
  int recvbuf_is_empty = 1;
  int rank = comm->rank();
  int size = comm->size();

  datatype->extent(&lb, &dataext);

  // Send/Recv buffers to/from others
  MPI_Request* requests = new MPI_Request[size - 1];
  unsigned char** tmpbufs = new unsigned char*[rank];
  int index = 0;
  for (int other = 0; other < rank; other++) {
    tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
    requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
    index++;
  }
  for (int other = rank + 1; other < size; other++) {
    requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
    index++;
  }
  // Wait for completion of all comms.
  Request::startall(size - 1, requests);

  if (op != MPI_OP_NULL && op->is_commutative()) {
    for (int other = 0; other < size - 1; other++) {
      index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
      if (index == MPI_UNDEFINED) {
        break;
      }
      if (index < rank) {
        // Request index is below our rank: it is an irecv from a lower rank
        if (recvbuf_is_empty) {
          Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
          recvbuf_is_empty = 0;
        } else
          op->apply(tmpbufs[index], recvbuf, &count, datatype);
      }
    }
  } else {
    // Non-commutative case: wait for the requests in rank order
    for (int other = 0; other < size - 1; other++) {
      Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
      if (other < rank) { // only receives from lower ranks contribute
        if (recvbuf_is_empty) {
          Datatype::copy(tmpbufs[other], count, datatype, recvbuf, count, datatype);
          recvbuf_is_empty = 0;
        } else if (op != MPI_OP_NULL)
          op->apply(tmpbufs[other], recvbuf, &count, datatype);
      }
    }
  }
  for (index = 0; index < rank; index++) {
    smpi_free_tmp_buffer(tmpbufs[index]);
  }
  for (index = 0; index < size - 1; index++) {
    Request::unref(&requests[index]);
  }
  delete[] tmpbufs;
  delete[] requests;
  return MPI_SUCCESS;
}

int colls::alltoallw(const void* sendbuf, const int* sendcounts, const int* senddisps, const MPI_Datatype* sendtypes,
                     void* recvbuf, const int* recvcounts, const int* recvdisps, const MPI_Datatype* recvtypes,
                     MPI_Comm comm)
{
  MPI_Request request;
  colls::ialltoallw(sendbuf, sendcounts, senddisps, sendtypes, recvbuf, recvcounts, recvdisps, recvtypes, comm,
                    &request, 0);
  return Request::wait(&request, MPI_STATUS_IGNORE);
}

} // namespace smpi
} // namespace simgrid