[simgrid.git] / src / smpi / colls / smpi_coll.cpp
/* smpi_coll.cpp -- various optimized routines for collectives                */

/* Copyright (c) 2009-2023. The SimGrid Team. All rights reserved.          */

/* This program is free software; you can redistribute it and/or modify it
 * under the terms of the license (GNU LGPL) which comes with this package. */

#include "smpi_coll.hpp"
#include "private.hpp"
#include "smpi_comm.hpp"
#include "smpi_datatype.hpp"
#include "smpi_op.hpp"
#include "smpi_request.hpp"
#include "xbt/config.hpp"

#include <iomanip>
#include <map>
#include <numeric>
#include <sstream>

XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI collectives.");

namespace simgrid::smpi {

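// Registry of all available algorithms for each collective. Every algorithm is stored with a
// name, a one-line description, and a type-erased pointer to its implementation; the pointer is
// looked up by name in find_coll_description() below and cast back to its proper signature by the
// COLL_SETTER macro at the end of this file.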
std::map<std::string, std::vector<s_mpi_coll_description_t>, std::less<>> smpi_coll_descriptions(
    {{"gather",
      {{"default", "gather default collective", (void*)gather__default},
       {"ompi", "gather ompi collective", (void*)gather__ompi},
       {"ompi_basic_linear", "gather ompi_basic_linear collective", (void*)gather__ompi_basic_linear},
       {"ompi_binomial", "gather ompi_binomial collective", (void*)gather__ompi_binomial},
       {"ompi_linear_sync", "gather ompi_linear_sync collective", (void*)gather__ompi_linear_sync},
       {"mpich", "gather mpich collective", (void*)gather__mpich},
       {"mvapich2", "gather mvapich2 collective", (void*)gather__mvapich2},
       {"mvapich2_two_level", "gather mvapich2_two_level collective", (void*)gather__mvapich2_two_level},
       {"impi", "gather impi collective", (void*)gather__impi},
       {"automatic", "gather automatic collective", (void*)gather__automatic}}},

     {"allgather",
      {{"default", "allgather default collective", (void*)allgather__default},
       {"2dmesh", "allgather 2dmesh collective", (void*)allgather__2dmesh},
       {"3dmesh", "allgather 3dmesh collective", (void*)allgather__3dmesh},
       {"bruck", "allgather bruck collective", (void*)allgather__bruck},
       {"GB", "allgather GB collective", (void*)allgather__GB},
       {"loosely_lr", "allgather loosely_lr collective", (void*)allgather__loosely_lr},
       {"NTSLR", "allgather NTSLR collective", (void*)allgather__NTSLR},
       {"NTSLR_NB", "allgather NTSLR_NB collective", (void*)allgather__NTSLR_NB},
       {"pair", "allgather pair collective", (void*)allgather__pair},
       {"rdb", "allgather rdb collective", (void*)allgather__rdb},
       {"rhv", "allgather rhv collective", (void*)allgather__rhv},
       {"ring", "allgather ring collective", (void*)allgather__ring},
       {"SMP_NTS", "allgather SMP_NTS collective", (void*)allgather__SMP_NTS},
       {"smp_simple", "allgather smp_simple collective", (void*)allgather__smp_simple},
       {"spreading_simple", "allgather spreading_simple collective", (void*)allgather__spreading_simple},
       {"ompi", "allgather ompi collective", (void*)allgather__ompi},
       {"ompi_neighborexchange", "allgather ompi_neighborexchange collective", (void*)allgather__ompi_neighborexchange},
       {"mvapich2", "allgather mvapich2 collective", (void*)allgather__mvapich2},
       {"mvapich2_smp", "allgather mvapich2_smp collective", (void*)allgather__mvapich2_smp},
       {"mpich", "allgather mpich collective", (void*)allgather__mpich},
       {"impi", "allgather impi collective", (void*)allgather__impi},
       {"automatic", "allgather automatic collective", (void*)allgather__automatic}}},

     {"allgatherv",
      {{"default", "allgatherv default collective", (void*)allgatherv__default},
       {"GB", "allgatherv GB collective", (void*)allgatherv__GB},
       {"pair", "allgatherv pair collective", (void*)allgatherv__pair},
       {"ring", "allgatherv ring collective", (void*)allgatherv__ring},
       {"ompi", "allgatherv ompi collective", (void*)allgatherv__ompi},
       {"ompi_neighborexchange", "allgatherv ompi_neighborexchange collective",
        (void*)allgatherv__ompi_neighborexchange},
       {"ompi_bruck", "allgatherv ompi_bruck collective", (void*)allgatherv__ompi_bruck},
       {"mpich", "allgatherv mpich collective", (void*)allgatherv__mpich},
       {"mpich_rdb", "allgatherv mpich_rdb collective", (void*)allgatherv__mpich_rdb},
       {"mpich_ring", "allgatherv mpich_ring collective", (void*)allgatherv__mpich_ring},
       {"mvapich2", "allgatherv mvapich2 collective", (void*)allgatherv__mvapich2},
       {"impi", "allgatherv impi collective", (void*)allgatherv__impi},
       {"automatic", "allgatherv automatic collective", (void*)allgatherv__automatic}}},

     {"allreduce",
      {{"default", "allreduce default collective", (void*)allreduce__default},
       {"lr", "allreduce lr collective", (void*)allreduce__lr},
       {"rab1", "allreduce rab1 collective", (void*)allreduce__rab1},
       {"rab2", "allreduce rab2 collective", (void*)allreduce__rab2},
       {"rab_rdb", "allreduce rab_rdb collective", (void*)allreduce__rab_rdb},
       {"rdb", "allreduce rdb collective", (void*)allreduce__rdb},
       {"smp_binomial", "allreduce smp_binomial collective", (void*)allreduce__smp_binomial},
       {"smp_binomial_pipeline", "allreduce smp_binomial_pipeline collective", (void*)allreduce__smp_binomial_pipeline},
       {"smp_rdb", "allreduce smp_rdb collective", (void*)allreduce__smp_rdb},
       {"smp_rsag", "allreduce smp_rsag collective", (void*)allreduce__smp_rsag},
       {"smp_rsag_lr", "allreduce smp_rsag_lr collective", (void*)allreduce__smp_rsag_lr},
       {"smp_rsag_rab", "allreduce smp_rsag_rab collective", (void*)allreduce__smp_rsag_rab},
       {"redbcast", "allreduce redbcast collective", (void*)allreduce__redbcast},
       {"ompi", "allreduce ompi collective", (void*)allreduce__ompi},
       {"ompi_ring_segmented", "allreduce ompi_ring_segmented collective", (void*)allreduce__ompi_ring_segmented},
       {"mpich", "allreduce mpich collective", (void*)allreduce__mpich},
       {"mvapich2", "allreduce mvapich2 collective", (void*)allreduce__mvapich2},
       {"mvapich2_rs", "allreduce mvapich2_rs collective", (void*)allreduce__mvapich2_rs},
       {"mvapich2_two_level", "allreduce mvapich2_two_level collective", (void*)allreduce__mvapich2_two_level},
       {"impi", "allreduce impi collective", (void*)allreduce__impi},
       {"rab", "allreduce rab collective", (void*)allreduce__rab},
       {"automatic", "allreduce automatic collective", (void*)allreduce__automatic}}},

     {"reduce_scatter",
      {{"default", "reduce_scatter default collective", (void*)reduce_scatter__default},
       {"ompi", "reduce_scatter ompi collective", (void*)reduce_scatter__ompi},
       {"ompi_basic_recursivehalving", "reduce_scatter ompi_basic_recursivehalving collective",
        (void*)reduce_scatter__ompi_basic_recursivehalving},
       {"ompi_ring", "reduce_scatter ompi_ring collective", (void*)reduce_scatter__ompi_ring},
       {"ompi_butterfly", "reduce_scatter ompi_butterfly collective", (void*)reduce_scatter__ompi_butterfly},
       {"mpich", "reduce_scatter mpich collective", (void*)reduce_scatter__mpich},
       {"mpich_pair", "reduce_scatter mpich_pair collective", (void*)reduce_scatter__mpich_pair},
       {"mpich_rdb", "reduce_scatter mpich_rdb collective", (void*)reduce_scatter__mpich_rdb},
       {"mpich_noncomm", "reduce_scatter mpich_noncomm collective", (void*)reduce_scatter__mpich_noncomm},
       {"mvapich2", "reduce_scatter mvapich2 collective", (void*)reduce_scatter__mvapich2},
       {"impi", "reduce_scatter impi collective", (void*)reduce_scatter__impi},
       {"automatic", "reduce_scatter automatic collective", (void*)reduce_scatter__automatic}}},

     {"scatter",
      {{"default", "scatter default collective", (void*)scatter__default},
       {"ompi", "scatter ompi collective", (void*)scatter__ompi},
       {"ompi_basic_linear", "scatter ompi_basic_linear collective", (void*)scatter__ompi_basic_linear},
       {"ompi_linear_nb", "scatter ompi_linear nonblocking collective", (void*)scatter__ompi_linear_nb},
       {"ompi_binomial", "scatter ompi_binomial collective", (void*)scatter__ompi_binomial},
       {"mpich", "scatter mpich collective", (void*)scatter__mpich},
       {"mvapich2", "scatter mvapich2 collective", (void*)scatter__mvapich2},
       {"mvapich2_two_level_binomial", "scatter mvapich2_two_level_binomial collective",
        (void*)scatter__mvapich2_two_level_binomial},
       {"mvapich2_two_level_direct", "scatter mvapich2_two_level_direct collective",
        (void*)scatter__mvapich2_two_level_direct},
       {"impi", "scatter impi collective", (void*)scatter__impi},
       {"automatic", "scatter automatic collective", (void*)scatter__automatic}}},

     {"barrier",
      {{"default", "barrier default collective", (void*)barrier__default},
       {"ompi", "barrier ompi collective", (void*)barrier__ompi},
       {"ompi_basic_linear", "barrier ompi_basic_linear collective", (void*)barrier__ompi_basic_linear},
       {"ompi_two_procs", "barrier ompi_two_procs collective", (void*)barrier__ompi_two_procs},
       {"ompi_tree", "barrier ompi_tree collective", (void*)barrier__ompi_tree},
       {"ompi_bruck", "barrier ompi_bruck collective", (void*)barrier__ompi_bruck},
       {"ompi_recursivedoubling", "barrier ompi_recursivedoubling collective", (void*)barrier__ompi_recursivedoubling},
       {"ompi_doublering", "barrier ompi_doublering collective", (void*)barrier__ompi_doublering},
       {"mpich_smp", "barrier mpich_smp collective", (void*)barrier__mpich_smp},
       {"mpich", "barrier mpich collective", (void*)barrier__mpich},
       {"mvapich2_pair", "barrier mvapich2_pair collective", (void*)barrier__mvapich2_pair},
       {"mvapich2", "barrier mvapich2 collective", (void*)barrier__mvapich2},
       {"impi", "barrier impi collective", (void*)barrier__impi},
       {"automatic", "barrier automatic collective", (void*)barrier__automatic}}},

     {"alltoall",
      {{"default", "alltoall default collective", (void*)alltoall__default},
       {"2dmesh", "alltoall 2dmesh collective", (void*)alltoall__2dmesh},
       {"3dmesh", "alltoall 3dmesh collective", (void*)alltoall__3dmesh},
       {"basic_linear", "alltoall basic_linear collective", (void*)alltoall__basic_linear},
       {"bruck", "alltoall bruck collective", (void*)alltoall__bruck},
       {"pair", "alltoall pair collective", (void*)alltoall__pair},
       {"pair_rma", "alltoall pair_rma collective", (void*)alltoall__pair_rma},
       {"pair_light_barrier", "alltoall pair_light_barrier collective", (void*)alltoall__pair_light_barrier},
       {"pair_mpi_barrier", "alltoall pair_mpi_barrier collective", (void*)alltoall__pair_mpi_barrier},
       {"pair_one_barrier", "alltoall pair_one_barrier collective", (void*)alltoall__pair_one_barrier},
       {"rdb", "alltoall rdb collective", (void*)alltoall__rdb},
       {"ring", "alltoall ring collective", (void*)alltoall__ring},
       {"ring_light_barrier", "alltoall ring_light_barrier collective", (void*)alltoall__ring_light_barrier},
       {"ring_mpi_barrier", "alltoall ring_mpi_barrier collective", (void*)alltoall__ring_mpi_barrier},
       {"ring_one_barrier", "alltoall ring_one_barrier collective", (void*)alltoall__ring_one_barrier},
       {"mvapich2", "alltoall mvapich2 collective", (void*)alltoall__mvapich2},
       {"mvapich2_scatter_dest", "alltoall mvapich2_scatter_dest collective", (void*)alltoall__mvapich2_scatter_dest},
       {"ompi", "alltoall ompi collective", (void*)alltoall__ompi},
       {"mpich", "alltoall mpich collective", (void*)alltoall__mpich},
       {"impi", "alltoall impi collective", (void*)alltoall__impi},
       {"automatic", "alltoall automatic collective", (void*)alltoall__automatic}}},

     {"alltoallv",
      {{"default", "alltoallv default collective", (void*)alltoallv__default},
       {"bruck", "alltoallv bruck collective", (void*)alltoallv__bruck},
       {"pair", "alltoallv pair collective", (void*)alltoallv__pair},
       {"pair_light_barrier", "alltoallv pair_light_barrier collective", (void*)alltoallv__pair_light_barrier},
       {"pair_mpi_barrier", "alltoallv pair_mpi_barrier collective", (void*)alltoallv__pair_mpi_barrier},
       {"pair_one_barrier", "alltoallv pair_one_barrier collective", (void*)alltoallv__pair_one_barrier},
       {"ring", "alltoallv ring collective", (void*)alltoallv__ring},
       {"ring_light_barrier", "alltoallv ring_light_barrier collective", (void*)alltoallv__ring_light_barrier},
       {"ring_mpi_barrier", "alltoallv ring_mpi_barrier collective", (void*)alltoallv__ring_mpi_barrier},
       {"ring_one_barrier", "alltoallv ring_one_barrier collective", (void*)alltoallv__ring_one_barrier},
       {"ompi", "alltoallv ompi collective", (void*)alltoallv__ompi},
       {"mpich", "alltoallv mpich collective", (void*)alltoallv__mpich},
       {"ompi_basic_linear", "alltoallv ompi_basic_linear collective", (void*)alltoallv__ompi_basic_linear},
       {"mvapich2", "alltoallv mvapich2 collective", (void*)alltoallv__mvapich2},
       {"impi", "alltoallv impi collective", (void*)alltoallv__impi},
       {"automatic", "alltoallv automatic collective", (void*)alltoallv__automatic}}},

     {"bcast",
      {{"default", "bcast default collective", (void*)bcast__default},
       {"arrival_pattern_aware", "bcast arrival_pattern_aware collective", (void*)bcast__arrival_pattern_aware},
       {"arrival_pattern_aware_wait", "bcast arrival_pattern_aware_wait collective",
        (void*)bcast__arrival_pattern_aware_wait},
       {"arrival_scatter", "bcast arrival_scatter collective", (void*)bcast__arrival_scatter},
       {"binomial_tree", "bcast binomial_tree collective", (void*)bcast__binomial_tree},
       {"flattree", "bcast flattree collective", (void*)bcast__flattree},
       {"flattree_pipeline", "bcast flattree_pipeline collective", (void*)bcast__flattree_pipeline},
       {"NTSB", "bcast NTSB collective", (void*)bcast__NTSB},
       {"NTSL", "bcast NTSL collective", (void*)bcast__NTSL},
       {"NTSL_Isend", "bcast NTSL_Isend collective", (void*)bcast__NTSL_Isend},
       {"scatter_LR_allgather", "bcast scatter_LR_allgather collective", (void*)bcast__scatter_LR_allgather},
       {"scatter_rdb_allgather", "bcast scatter_rdb_allgather collective", (void*)bcast__scatter_rdb_allgather},
       {"SMP_binary", "bcast SMP_binary collective", (void*)bcast__SMP_binary},
       {"SMP_binomial", "bcast SMP_binomial collective", (void*)bcast__SMP_binomial},
       {"SMP_linear", "bcast SMP_linear collective", (void*)bcast__SMP_linear},
       {"ompi", "bcast ompi collective", (void*)bcast__ompi},
       {"ompi_split_bintree", "bcast ompi_split_bintree collective", (void*)bcast__ompi_split_bintree},
       {"ompi_pipeline", "bcast ompi_pipeline collective", (void*)bcast__ompi_pipeline},
       {"mpich", "bcast mpich collective", (void*)bcast__mpich},
       {"mvapich2", "bcast mvapich2 collective", (void*)bcast__mvapich2},
       {"mvapich2_inter_node", "bcast mvapich2_inter_node collective", (void*)bcast__mvapich2_inter_node},
       {"mvapich2_intra_node", "bcast mvapich2_intra_node collective", (void*)bcast__mvapich2_intra_node},
       {"mvapich2_knomial_intra_node", "bcast mvapich2_knomial_intra_node collective",
        (void*)bcast__mvapich2_knomial_intra_node},
       {"impi", "bcast impi collective", (void*)bcast__impi},
       {"automatic", "bcast automatic collective", (void*)bcast__automatic}}},

     {"reduce",
      {{"default", "reduce default collective", (void*)reduce__default},
       {"arrival_pattern_aware", "reduce arrival_pattern_aware collective", (void*)reduce__arrival_pattern_aware},
       {"binomial", "reduce binomial collective", (void*)reduce__binomial},
       {"flat_tree", "reduce flat_tree collective", (void*)reduce__flat_tree},
       {"NTSL", "reduce NTSL collective", (void*)reduce__NTSL},
       {"scatter_gather", "reduce scatter_gather collective", (void*)reduce__scatter_gather},
       {"ompi", "reduce ompi collective", (void*)reduce__ompi},
       {"ompi_chain", "reduce ompi_chain collective", (void*)reduce__ompi_chain},
       {"ompi_pipeline", "reduce ompi_pipeline collective", (void*)reduce__ompi_pipeline},
       {"ompi_basic_linear", "reduce ompi_basic_linear collective", (void*)reduce__ompi_basic_linear},
       {"ompi_in_order_binary", "reduce ompi_in_order_binary collective", (void*)reduce__ompi_in_order_binary},
       {"ompi_binary", "reduce ompi_binary collective", (void*)reduce__ompi_binary},
       {"ompi_binomial", "reduce ompi_binomial collective", (void*)reduce__ompi_binomial},
       {"mpich", "reduce mpich collective", (void*)reduce__mpich},
       {"mvapich2", "reduce mvapich2 collective", (void*)reduce__mvapich2},
       {"mvapich2_knomial", "reduce mvapich2_knomial collective", (void*)reduce__mvapich2_knomial},
       {"mvapich2_two_level", "reduce mvapich2_two_level collective", (void*)reduce__mvapich2_two_level},
       {"impi", "reduce impi collective", (void*)reduce__impi},
       {"rab", "reduce rab collective", (void*)reduce__rab},
       {"automatic", "reduce automatic collective", (void*)reduce__automatic}}}});

// Needed by the somewhat unusual implementation of the automatic selector
std::vector<s_mpi_coll_description_t>* colls::get_smpi_coll_descriptions(const std::string& name)
{
  auto iter = smpi_coll_descriptions.find(name);
  xbt_assert(iter != smpi_coll_descriptions.end(), "No collective named %s. This is a bug.", name.c_str());
  return &iter->second;
}

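// The generated help text roughly looks like this (layout sketch only, not verbatim output):
//
//   Available collective algorithms (select them with "smpi/collective_name:algo_name"):
//   =====================================================================================
//   Collective: "allgather"
//   -----------------------
//     default                allgather default collective
//     2dmesh                 allgather 2dmesh collective
//     ...
//
//   Please see https://simgrid.org/doc/latest/app_smpi.html#available-algorithms for more information.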
std::string colls::get_smpi_coll_help()
{
  size_t max_name_len =
      std::accumulate(begin(smpi_coll_descriptions), end(smpi_coll_descriptions), 0, [](auto len, auto const& coll) {
        return std::max(len, std::accumulate(begin(coll.second), end(coll.second), 0, [](auto len, auto const& descr) {
                          return std::max<size_t>(len, descr.name.length());
                        }));
      });
  std::ostringstream oss;
  std::string title = "Available collective algorithms (select them with \"smpi/collective_name:algo_name\"):";
  oss << title << '\n' << std::setfill('=') << std::setw(title.length() + 1) << '\n';
  for (auto const& [coll, algos] : smpi_coll_descriptions) {
    std::string line = "Collective: \"" + coll + "\"";
    oss << line << '\n' << std::setfill('-') << std::setw(line.length() + 1) << '\n';
    oss << std::setfill(' ') << std::left;
    for (auto const& [name, descr, _] : algos)
      oss << "  " << std::setw(max_name_len) << name << " " << descr << "\n";
    oss << std::right << '\n';
  }
  oss << "Please see https://simgrid.org/doc/latest/app_smpi.html#available-algorithms for more information.\n";
  return oss.str();
}

static s_mpi_coll_description_t* find_coll_description(const std::string& collective, const std::string& algo)
{
  std::vector<s_mpi_coll_description_t>* table = colls::get_smpi_coll_descriptions(collective);
  xbt_assert(not table->empty(), "No registered algorithm for collective '%s'! This is a bug.", collective.c_str());

  for (auto& desc : *table) {
    if (algo == desc.name) {
      if (desc.name != "default")
        XBT_INFO("Switch to algorithm %s for collective %s", desc.name.c_str(), collective.c_str());
      return &desc;
    }
  }

  std::string name_list = table->at(0).name;
  for (unsigned long i = 1; i < table->size(); i++)
    name_list = name_list + ", " + table->at(i).name;
  xbt_die("Collective '%s' has no algorithm '%s'! Valid algorithms: %s. Please use --help-coll for details.",
          collective.c_str(), algo.c_str(), name_list.c_str());
}

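// Dispatch pointers through which SMPI invokes the selected implementation of each collective.
// They are filled in by the generated set_<collective>() functions (see COLL_SETTER below) when
// set_collectives() runs.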
int (*colls::gather)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff, int recv_count,
                     MPI_Datatype recv_type, int root, MPI_Comm comm);
int (*colls::allgather)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff, int recv_count,
                        MPI_Datatype recv_type, MPI_Comm comm);
int (*colls::allgatherv)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff,
                         const int* recv_count, const int* recv_disps, MPI_Datatype recv_type, MPI_Comm comm);
int (*colls::alltoall)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff, int recv_count,
                       MPI_Datatype recv_type, MPI_Comm comm);
int (*colls::alltoallv)(const void* send_buff, const int* send_counts, const int* send_disps, MPI_Datatype send_type,
                        void* recv_buff, const int* recv_counts, const int* recv_disps, MPI_Datatype recv_type,
                        MPI_Comm comm);
int (*colls::bcast)(void* buf, int count, MPI_Datatype datatype, int root, MPI_Comm comm);
int (*colls::reduce)(const void* buf, void* rbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm);
int (*colls::allreduce)(const void* sbuf, void* rbuf, int rcount, MPI_Datatype dtype, MPI_Op op, MPI_Comm comm);
int (*colls::reduce_scatter)(const void* sbuf, void* rbuf, const int* rcounts, MPI_Datatype dtype, MPI_Op op,
                             MPI_Comm comm);
int (*colls::scatter)(const void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, int recvcount,
                      MPI_Datatype recvtype, int root, MPI_Comm comm);
int (*colls::barrier)(MPI_Comm comm);

void (*colls::smpi_coll_cleanup_callback)();

#define COLL_SETTER(cat, ret, args, args2)                                                                             \
  void colls::_XBT_CONCAT(set_, cat)(const std::string& name)                                                          \
  {                                                                                                                    \
    auto desc = find_coll_description(_XBT_STRINGIFY(cat), name);                                                      \
    cat       = reinterpret_cast<ret(*) args>(desc->coll);                                                             \
    xbt_assert(cat != nullptr, "Collective " _XBT_STRINGIFY(cat) " set to nullptr!");                                  \
  }
COLL_APPLY(COLL_SETTER, COLL_GATHER_SIG, "")
COLL_APPLY(COLL_SETTER, COLL_ALLGATHER_SIG, "")
COLL_APPLY(COLL_SETTER, COLL_ALLGATHERV_SIG, "")
COLL_APPLY(COLL_SETTER, COLL_REDUCE_SIG, "")
COLL_APPLY(COLL_SETTER, COLL_ALLREDUCE_SIG, "")
COLL_APPLY(COLL_SETTER, COLL_REDUCE_SCATTER_SIG, "")
COLL_APPLY(COLL_SETTER, COLL_SCATTER_SIG, "")
COLL_APPLY(COLL_SETTER, COLL_BARRIER_SIG, "")
COLL_APPLY(COLL_SETTER, COLL_BCAST_SIG, "")
COLL_APPLY(COLL_SETTER, COLL_ALLTOALL_SIG, "")
COLL_APPLY(COLL_SETTER, COLL_ALLTOALLV_SIG, "")
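// As an illustration, the gather instantiation above roughly expands to the following
// (hand-expanded sketch, assuming COLL_GATHER_SIG carries the signature declared for
// colls::gather above; the real code is produced by the COLL_APPLY/COLL_SETTER macros):
//
//   void colls::set_gather(const std::string& name)
//   {
//     auto desc = find_coll_description("gather", name);
//     gather    = reinterpret_cast<int (*)(const void*, int, MPI_Datatype, void*, int, MPI_Datatype, int, MPI_Comm)>(
//         desc->coll);
//     xbt_assert(gather != nullptr, "Collective gather set to nullptr!");
//   }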

void colls::set_collectives()
{
  std::string selector_name = simgrid::config::get_value<std::string>("smpi/coll-selector");
  if (selector_name.empty())
    selector_name = "default";

  std::pair<std::string, std::function<void(std::string)>> setter_callbacks[] = {
      {"gather", &colls::set_gather},         {"allgather", &colls::set_allgather},
      {"allgatherv", &colls::set_allgatherv}, {"allreduce", &colls::set_allreduce},
      {"alltoall", &colls::set_alltoall},     {"alltoallv", &colls::set_alltoallv},
      {"reduce", &colls::set_reduce},         {"reduce_scatter", &colls::set_reduce_scatter},
      {"scatter", &colls::set_scatter},       {"bcast", &colls::set_bcast},
      {"barrier", &colls::set_barrier}};

  for (auto& elem : setter_callbacks) {
    std::string name = simgrid::config::get_value<std::string>(("smpi/" + elem.first).c_str());
    if (name.empty())
      name = selector_name;

    (elem.second)(name);
  }
}
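// Illustration (hypothetical command line; the option names are the ones read above, i.e.
// "smpi/coll-selector" and "smpi/<collective>"):
//
//   smpirun --cfg=smpi/coll-selector:mpich --cfg=smpi/alltoall:pair ... ./my_app
//
// Any collective whose dedicated "smpi/<collective>" option is left empty falls back to the
// selector ("default" when the selector itself is unset).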

// Implementations of the single-algorithm collectives

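// gatherv(), scatterv() and alltoallw() below are thin wrappers that start the corresponding
// nonblocking collective and wait for its completion, while scan() and exscan() are built
// directly on top of point-to-point requests.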
int colls::gatherv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, const int* recvcounts,
                   const int* displs, MPI_Datatype recvtype, int root, MPI_Comm comm)
{
  MPI_Request request;
  colls::igatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, &request, 0);
  return Request::wait(&request, MPI_STATUS_IGNORE);
}

int colls::scatterv(const void* sendbuf, const int* sendcounts, const int* displs, MPI_Datatype sendtype, void* recvbuf,
                    int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
{
  MPI_Request request;
  colls::iscatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm, &request, 0);
  return Request::wait(&request, MPI_STATUS_IGNORE);
}

int colls::scan(const void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
{
  int system_tag = -888;
  MPI_Aint lb      = 0;
  MPI_Aint dataext = 0;

  int rank = comm->rank();
  int size = comm->size();

  datatype->extent(&lb, &dataext);

  // Local copy from self
  Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);

  // Send/Recv buffers to/from others
  auto* requests = new MPI_Request[size - 1];
  auto** tmpbufs = new unsigned char*[rank];
  int index = 0;
  for (int other = 0; other < rank; other++) {
    tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
    requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
    index++;
  }
  for (int other = rank + 1; other < size; other++) {
    requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
    index++;
  }
  // Start all communications; they are waited for below.
  Request::startall(size - 1, requests);

  if(op != MPI_OP_NULL && op->is_commutative()){
    for (int other = 0; other < size - 1; other++) {
      index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
      if(index == MPI_UNDEFINED) {
        break;
      }
      if(index < rank) {
        // Request index is below rank: it's an irecv from a lower rank, fold its data in
        op->apply( tmpbufs[index], recvbuf, &count, datatype);
      }
    }
  }else{
    //non commutative case, wait in order
    for (int other = 0; other < size - 1; other++) {
      Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
      if(other < rank && op!=MPI_OP_NULL) { // requests[other] with other < rank are the irecvs
        op->apply( tmpbufs[other], recvbuf, &count, datatype);
      }
    }
  }
  for(index = 0; index < rank; index++) {
    smpi_free_tmp_buffer(tmpbufs[index]);
  }
  for(index = 0; index < size-1; index++) {
    Request::unref(&requests[index]);
  }
  delete[] tmpbufs;
  delete[] requests;
  return MPI_SUCCESS;
}
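// Worked example of the scan semantics implemented above (assuming op = MPI_SUM and one value v_r
// contributed per rank): with 4 ranks, rank 0 ends with v_0, rank 1 with v_0+v_1, rank 2 with
// v_0+v_1+v_2 and rank 3 with v_0+v_1+v_2+v_3, i.e. an inclusive prefix reduction.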

int colls::exscan(const void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
{
  int system_tag = -888;
  MPI_Aint lb         = 0;
  MPI_Aint dataext    = 0;
  int recvbuf_is_empty = 1;
  int rank = comm->rank();
  int size = comm->size();

  datatype->extent(&lb, &dataext);

  // Send/Recv buffers to/from others
  auto* requests = new MPI_Request[size - 1];
  auto** tmpbufs = new unsigned char*[rank];
  int index = 0;
  for (int other = 0; other < rank; other++) {
    tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
    requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
    index++;
  }
  for (int other = rank + 1; other < size; other++) {
    requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
    index++;
  }
  // Start all communications; they are waited for below.
  Request::startall(size - 1, requests);

  if(op != MPI_OP_NULL && op->is_commutative()){
    for (int other = 0; other < size - 1; other++) {
      index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
      if(index == MPI_UNDEFINED) {
        break;
      }
      if(index < rank) {
        if(recvbuf_is_empty){
          Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
          recvbuf_is_empty = 0;
        } else
          // Request index is below rank: it's an irecv from a lower rank, fold its data in
          op->apply( tmpbufs[index], recvbuf, &count, datatype);
      }
    }
  }else{
    //non commutative case, wait in order
    for (int other = 0; other < size - 1; other++) {
      Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
      if(other < rank) { // requests[other] with other < rank are the irecvs
        if (recvbuf_is_empty) {
          Datatype::copy(tmpbufs[other], count, datatype, recvbuf, count, datatype);
          recvbuf_is_empty = 0;
        } else
          if(op!=MPI_OP_NULL)
            op->apply( tmpbufs[other], recvbuf, &count, datatype);
      }
    }
  }
  for(index = 0; index < rank; index++) {
    smpi_free_tmp_buffer(tmpbufs[index]);
  }
  for(index = 0; index < size-1; index++) {
    Request::unref(&requests[index]);
  }
  delete[] tmpbufs;
  delete[] requests;
  return MPI_SUCCESS;
}
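// Worked example of the exscan semantics implemented above (assuming op = MPI_SUM): exscan is the
// exclusive prefix reduction, so rank r ends with v_0+...+v_{r-1}. Rank 0 receives nothing and its
// recvbuf is left untouched (recvbuf_is_empty never gets cleared there), matching the MPI standard,
// which leaves the result on rank 0 undefined.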

int colls::alltoallw(const void* sendbuf, const int* sendcounts, const int* senddisps, const MPI_Datatype* sendtypes,
                     void* recvbuf, const int* recvcounts, const int* recvdisps, const MPI_Datatype* recvtypes,
                     MPI_Comm comm)
{
  MPI_Request request;
  colls::ialltoallw(sendbuf, sendcounts, senddisps, sendtypes, recvbuf, recvcounts, recvdisps, recvtypes, comm,
                    &request, 0);
  return Request::wait(&request, MPI_STATUS_IGNORE);
}

} // namespace simgrid::smpi