/* smpi_coll.c -- various optimized routines for collectives */

/* Copyright (c) 2009-2019. The SimGrid Team. All rights reserved. */

/* This program is free software; you can redistribute it and/or modify it
 * under the terms of the license (GNU LGPL) which comes with this package. */

#include "smpi_coll.hpp"
#include "smpi_comm.hpp"
#include "smpi_datatype.hpp"
#include "smpi_op.hpp"
#include "smpi_request.hpp"
#include "xbt/config.hpp"

XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI collectives.");

/* These tables must be terminated by a sentinel entry (empty name, nullptr implementation). */
s_mpi_coll_description_t mpi_coll_gather_description[] = {
    {"default", "gather default collective", (void*)gather__default},
    {"ompi", "gather ompi collective", (void*)gather__ompi},
    {"ompi_basic_linear", "gather ompi_basic_linear collective", (void*)gather__ompi_basic_linear},
    {"ompi_binomial", "gather ompi_binomial collective", (void*)gather__ompi_binomial},
    {"ompi_linear_sync", "gather ompi_linear_sync collective", (void*)gather__ompi_linear_sync},
    {"mpich", "gather mpich collective", (void*)gather__mpich},
    {"mvapich2", "gather mvapich2 collective", (void*)gather__mvapich2},
    {"mvapich2_two_level", "gather mvapich2_two_level collective", (void*)gather__mvapich2_two_level},
    {"impi", "gather impi collective", (void*)gather__impi},
    {"automatic", "gather automatic collective", (void*)gather__automatic},
    {"", "", nullptr}};

s_mpi_coll_description_t mpi_coll_allgather_description[] = {
    {"default", "allgather default collective", (void*)allgather__default},
    {"2dmesh", "allgather 2dmesh collective", (void*)allgather__2dmesh},
    {"3dmesh", "allgather 3dmesh collective", (void*)allgather__3dmesh},
    {"bruck", "allgather bruck collective", (void*)allgather__bruck},
    {"GB", "allgather GB collective", (void*)allgather__GB},
    {"loosely_lr", "allgather loosely_lr collective", (void*)allgather__loosely_lr},
    {"NTSLR", "allgather NTSLR collective", (void*)allgather__NTSLR},
    {"NTSLR_NB", "allgather NTSLR_NB collective", (void*)allgather__NTSLR_NB},
    {"pair", "allgather pair collective", (void*)allgather__pair},
    {"rdb", "allgather rdb collective", (void*)allgather__rdb},
    {"rhv", "allgather rhv collective", (void*)allgather__rhv},
    {"ring", "allgather ring collective", (void*)allgather__ring},
    {"SMP_NTS", "allgather SMP_NTS collective", (void*)allgather__SMP_NTS},
    {"smp_simple", "allgather smp_simple collective", (void*)allgather__smp_simple},
    {"spreading_simple", "allgather spreading_simple collective", (void*)allgather__spreading_simple},
    {"ompi", "allgather ompi collective", (void*)allgather__ompi},
    {"ompi_neighborexchange", "allgather ompi_neighborexchange collective", (void*)allgather__ompi_neighborexchange},
    {"mvapich2", "allgather mvapich2 collective", (void*)allgather__mvapich2},
    {"mvapich2_smp", "allgather mvapich2_smp collective", (void*)allgather__mvapich2_smp},
    {"mpich", "allgather mpich collective", (void*)allgather__mpich},
    {"impi", "allgather impi collective", (void*)allgather__impi},
    {"automatic", "allgather automatic collective", (void*)allgather__automatic},
    {"", "", nullptr}};

s_mpi_coll_description_t mpi_coll_allgatherv_description[] = {
    {"default", "allgatherv default collective", (void*)allgatherv__default},
    {"GB", "allgatherv GB collective", (void*)allgatherv__GB},
    {"pair", "allgatherv pair collective", (void*)allgatherv__pair},
    {"ring", "allgatherv ring collective", (void*)allgatherv__ring},
    {"ompi", "allgatherv ompi collective", (void*)allgatherv__ompi},
    {"ompi_neighborexchange", "allgatherv ompi_neighborexchange collective", (void*)allgatherv__ompi_neighborexchange},
    {"ompi_bruck", "allgatherv ompi_bruck collective", (void*)allgatherv__ompi_bruck},
    {"mpich", "allgatherv mpich collective", (void*)allgatherv__mpich},
    {"mpich_rdb", "allgatherv mpich_rdb collective", (void*)allgatherv__mpich_rdb},
    {"mpich_ring", "allgatherv mpich_ring collective", (void*)allgatherv__mpich_ring},
    {"mvapich2", "allgatherv mvapich2 collective", (void*)allgatherv__mvapich2},
    {"impi", "allgatherv impi collective", (void*)allgatherv__impi},
    {"automatic", "allgatherv automatic collective", (void*)allgatherv__automatic},
    {"", "", nullptr}};

s_mpi_coll_description_t mpi_coll_allreduce_description[] = {
    {"default", "allreduce default collective", (void*)allreduce__default},
    {"lr", "allreduce lr collective", (void*)allreduce__lr},
    {"rab1", "allreduce rab1 collective", (void*)allreduce__rab1},
    {"rab2", "allreduce rab2 collective", (void*)allreduce__rab2},
    {"rab_rdb", "allreduce rab_rdb collective", (void*)allreduce__rab_rdb},
    {"rdb", "allreduce rdb collective", (void*)allreduce__rdb},
    {"smp_binomial", "allreduce smp_binomial collective", (void*)allreduce__smp_binomial},
    {"smp_binomial_pipeline", "allreduce smp_binomial_pipeline collective", (void*)allreduce__smp_binomial_pipeline},
    {"smp_rdb", "allreduce smp_rdb collective", (void*)allreduce__smp_rdb},
    {"smp_rsag", "allreduce smp_rsag collective", (void*)allreduce__smp_rsag},
    {"smp_rsag_lr", "allreduce smp_rsag_lr collective", (void*)allreduce__smp_rsag_lr},
    {"smp_rsag_rab", "allreduce smp_rsag_rab collective", (void*)allreduce__smp_rsag_rab},
    {"redbcast", "allreduce redbcast collective", (void*)allreduce__redbcast},
    {"ompi", "allreduce ompi collective", (void*)allreduce__ompi},
    {"ompi_ring_segmented", "allreduce ompi_ring_segmented collective", (void*)allreduce__ompi_ring_segmented},
    {"mpich", "allreduce mpich collective", (void*)allreduce__mpich},
    {"mvapich2", "allreduce mvapich2 collective", (void*)allreduce__mvapich2},
    {"mvapich2_rs", "allreduce mvapich2_rs collective", (void*)allreduce__mvapich2_rs},
    {"mvapich2_two_level", "allreduce mvapich2_two_level collective", (void*)allreduce__mvapich2_two_level},
    {"impi", "allreduce impi collective", (void*)allreduce__impi},
    {"rab", "allreduce rab collective", (void*)allreduce__rab},
    {"automatic", "allreduce automatic collective", (void*)allreduce__automatic},
    {"", "", nullptr}};

s_mpi_coll_description_t mpi_coll_reduce_scatter_description[] = {
    {"default", "reduce_scatter default collective", (void*)reduce_scatter__default},
    {"ompi", "reduce_scatter ompi collective", (void*)reduce_scatter__ompi},
    {"ompi_basic_recursivehalving", "reduce_scatter ompi_basic_recursivehalving collective", (void*)reduce_scatter__ompi_basic_recursivehalving},
    {"ompi_ring", "reduce_scatter ompi_ring collective", (void*)reduce_scatter__ompi_ring},
    {"mpich", "reduce_scatter mpich collective", (void*)reduce_scatter__mpich},
    {"mpich_pair", "reduce_scatter mpich_pair collective", (void*)reduce_scatter__mpich_pair},
    {"mpich_rdb", "reduce_scatter mpich_rdb collective", (void*)reduce_scatter__mpich_rdb},
    {"mpich_noncomm", "reduce_scatter mpich_noncomm collective", (void*)reduce_scatter__mpich_noncomm},
    {"mvapich2", "reduce_scatter mvapich2 collective", (void*)reduce_scatter__mvapich2},
    {"impi", "reduce_scatter impi collective", (void*)reduce_scatter__impi},
    {"automatic", "reduce_scatter automatic collective", (void*)reduce_scatter__automatic},
    {"", "", nullptr}};

s_mpi_coll_description_t mpi_coll_scatter_description[] = {
    {"default", "scatter default collective", (void*)scatter__default},
    {"ompi", "scatter ompi collective", (void*)scatter__ompi},
    {"ompi_basic_linear", "scatter ompi_basic_linear collective", (void*)scatter__ompi_basic_linear},
    {"ompi_binomial", "scatter ompi_binomial collective", (void*)scatter__ompi_binomial},
    {"mpich", "scatter mpich collective", (void*)scatter__mpich},
    {"mvapich2", "scatter mvapich2 collective", (void*)scatter__mvapich2},
    {"mvapich2_two_level_binomial", "scatter mvapich2_two_level_binomial collective", (void*)scatter__mvapich2_two_level_binomial},
    {"mvapich2_two_level_direct", "scatter mvapich2_two_level_direct collective", (void*)scatter__mvapich2_two_level_direct},
    {"impi", "scatter impi collective", (void*)scatter__impi},
    {"automatic", "scatter automatic collective", (void*)scatter__automatic},
    {"", "", nullptr}};

s_mpi_coll_description_t mpi_coll_barrier_description[] = {
    {"default", "barrier default collective", (void*)barrier__default},
    {"ompi", "barrier ompi collective", (void*)barrier__ompi},
    {"ompi_basic_linear", "barrier ompi_basic_linear collective", (void*)barrier__ompi_basic_linear},
    {"ompi_two_procs", "barrier ompi_two_procs collective", (void*)barrier__ompi_two_procs},
    {"ompi_tree", "barrier ompi_tree collective", (void*)barrier__ompi_tree},
    {"ompi_bruck", "barrier ompi_bruck collective", (void*)barrier__ompi_bruck},
    {"ompi_recursivedoubling", "barrier ompi_recursivedoubling collective", (void*)barrier__ompi_recursivedoubling},
    {"ompi_doublering", "barrier ompi_doublering collective", (void*)barrier__ompi_doublering},
    {"mpich_smp", "barrier mpich_smp collective", (void*)barrier__mpich_smp},
    {"mpich", "barrier mpich collective", (void*)barrier__mpich},
    {"mvapich2_pair", "barrier mvapich2_pair collective", (void*)barrier__mvapich2_pair},
    {"mvapich2", "barrier mvapich2 collective", (void*)barrier__mvapich2},
    {"impi", "barrier impi collective", (void*)barrier__impi},
    {"automatic", "barrier automatic collective", (void*)barrier__automatic},
    {"", "", nullptr}};

s_mpi_coll_description_t mpi_coll_alltoall_description[] = {
    {"default", "alltoall default collective", (void*)alltoall__default},
    {"2dmesh", "alltoall 2dmesh collective", (void*)alltoall__2dmesh},
    {"3dmesh", "alltoall 3dmesh collective", (void*)alltoall__3dmesh},
    {"basic_linear", "alltoall basic_linear collective", (void*)alltoall__basic_linear},
    {"bruck", "alltoall bruck collective", (void*)alltoall__bruck},
    {"pair", "alltoall pair collective", (void*)alltoall__pair},
    {"pair_rma", "alltoall pair_rma collective", (void*)alltoall__pair_rma},
    {"pair_light_barrier", "alltoall pair_light_barrier collective", (void*)alltoall__pair_light_barrier},
    {"pair_mpi_barrier", "alltoall pair_mpi_barrier collective", (void*)alltoall__pair_mpi_barrier},
    {"pair_one_barrier", "alltoall pair_one_barrier collective", (void*)alltoall__pair_one_barrier},
    {"rdb", "alltoall rdb collective", (void*)alltoall__rdb},
    {"ring", "alltoall ring collective", (void*)alltoall__ring},
    {"ring_light_barrier", "alltoall ring_light_barrier collective", (void*)alltoall__ring_light_barrier},
    {"ring_mpi_barrier", "alltoall ring_mpi_barrier collective", (void*)alltoall__ring_mpi_barrier},
    {"ring_one_barrier", "alltoall ring_one_barrier collective", (void*)alltoall__ring_one_barrier},
    {"mvapich2", "alltoall mvapich2 collective", (void*)alltoall__mvapich2},
    {"mvapich2_scatter_dest", "alltoall mvapich2_scatter_dest collective", (void*)alltoall__mvapich2_scatter_dest},
    {"ompi", "alltoall ompi collective", (void*)alltoall__ompi},
    {"mpich", "alltoall mpich collective", (void*)alltoall__mpich},
    {"impi", "alltoall impi collective", (void*)alltoall__impi},
    {"automatic", "alltoall automatic collective", (void*)alltoall__automatic},
    {"", "", nullptr}};

s_mpi_coll_description_t mpi_coll_alltoallv_description[] = {
    {"default", "alltoallv default collective", (void*)alltoallv__default},
    {"bruck", "alltoallv bruck collective", (void*)alltoallv__bruck},
    {"pair", "alltoallv pair collective", (void*)alltoallv__pair},
    {"pair_light_barrier", "alltoallv pair_light_barrier collective", (void*)alltoallv__pair_light_barrier},
    {"pair_mpi_barrier", "alltoallv pair_mpi_barrier collective", (void*)alltoallv__pair_mpi_barrier},
    {"pair_one_barrier", "alltoallv pair_one_barrier collective", (void*)alltoallv__pair_one_barrier},
    {"ring", "alltoallv ring collective", (void*)alltoallv__ring},
    {"ring_light_barrier", "alltoallv ring_light_barrier collective", (void*)alltoallv__ring_light_barrier},
    {"ring_mpi_barrier", "alltoallv ring_mpi_barrier collective", (void*)alltoallv__ring_mpi_barrier},
    {"ring_one_barrier", "alltoallv ring_one_barrier collective", (void*)alltoallv__ring_one_barrier},
    {"ompi", "alltoallv ompi collective", (void*)alltoallv__ompi},
    {"mpich", "alltoallv mpich collective", (void*)alltoallv__mpich},
    {"ompi_basic_linear", "alltoallv ompi_basic_linear collective", (void*)alltoallv__ompi_basic_linear},
    {"mvapich2", "alltoallv mvapich2 collective", (void*)alltoallv__mvapich2},
    {"impi", "alltoallv impi collective", (void*)alltoallv__impi},
    {"automatic", "alltoallv automatic collective", (void*)alltoallv__automatic},
    {"", "", nullptr}};

s_mpi_coll_description_t mpi_coll_bcast_description[] = {
    {"default", "bcast default collective", (void*)bcast__default},
    {"arrival_pattern_aware", "bcast arrival_pattern_aware collective", (void*)bcast__arrival_pattern_aware},
    {"arrival_pattern_aware_wait", "bcast arrival_pattern_aware_wait collective", (void*)bcast__arrival_pattern_aware_wait},
    {"arrival_scatter", "bcast arrival_scatter collective", (void*)bcast__arrival_scatter},
    {"binomial_tree", "bcast binomial_tree collective", (void*)bcast__binomial_tree},
    {"flattree", "bcast flattree collective", (void*)bcast__flattree},
    {"flattree_pipeline", "bcast flattree_pipeline collective", (void*)bcast__flattree_pipeline},
    {"NTSB", "bcast NTSB collective", (void*)bcast__NTSB},
    {"NTSL", "bcast NTSL collective", (void*)bcast__NTSL},
    {"NTSL_Isend", "bcast NTSL_Isend collective", (void*)bcast__NTSL_Isend},
    {"scatter_LR_allgather", "bcast scatter_LR_allgather collective", (void*)bcast__scatter_LR_allgather},
    {"scatter_rdb_allgather", "bcast scatter_rdb_allgather collective", (void*)bcast__scatter_rdb_allgather},
    {"SMP_binary", "bcast SMP_binary collective", (void*)bcast__SMP_binary},
    {"SMP_binomial", "bcast SMP_binomial collective", (void*)bcast__SMP_binomial},
    {"SMP_linear", "bcast SMP_linear collective", (void*)bcast__SMP_linear},
    {"ompi", "bcast ompi collective", (void*)bcast__ompi},
    {"ompi_split_bintree", "bcast ompi_split_bintree collective", (void*)bcast__ompi_split_bintree},
    {"ompi_pipeline", "bcast ompi_pipeline collective", (void*)bcast__ompi_pipeline},
    {"mpich", "bcast mpich collective", (void*)bcast__mpich},
    {"mvapich2", "bcast mvapich2 collective", (void*)bcast__mvapich2},
    {"mvapich2_inter_node", "bcast mvapich2_inter_node collective", (void*)bcast__mvapich2_inter_node},
    {"mvapich2_intra_node", "bcast mvapich2_intra_node collective", (void*)bcast__mvapich2_intra_node},
    {"mvapich2_knomial_intra_node", "bcast mvapich2_knomial_intra_node collective", (void*)bcast__mvapich2_knomial_intra_node},
    {"impi", "bcast impi collective", (void*)bcast__impi},
    {"automatic", "bcast automatic collective", (void*)bcast__automatic},
    {"", "", nullptr}};

s_mpi_coll_description_t mpi_coll_reduce_description[] = {
    {"default", "reduce default collective", (void*)reduce__default},
    {"arrival_pattern_aware", "reduce arrival_pattern_aware collective", (void*)reduce__arrival_pattern_aware},
    {"binomial", "reduce binomial collective", (void*)reduce__binomial},
    {"flat_tree", "reduce flat_tree collective", (void*)reduce__flat_tree},
    {"NTSL", "reduce NTSL collective", (void*)reduce__NTSL},
    {"scatter_gather", "reduce scatter_gather collective", (void*)reduce__scatter_gather},
    {"ompi", "reduce ompi collective", (void*)reduce__ompi},
    {"ompi_chain", "reduce ompi_chain collective", (void*)reduce__ompi_chain},
    {"ompi_pipeline", "reduce ompi_pipeline collective", (void*)reduce__ompi_pipeline},
    {"ompi_basic_linear", "reduce ompi_basic_linear collective", (void*)reduce__ompi_basic_linear},
    {"ompi_in_order_binary", "reduce ompi_in_order_binary collective", (void*)reduce__ompi_in_order_binary},
    {"ompi_binary", "reduce ompi_binary collective", (void*)reduce__ompi_binary},
    {"ompi_binomial", "reduce ompi_binomial collective", (void*)reduce__ompi_binomial},
    {"mpich", "reduce mpich collective", (void*)reduce__mpich},
    {"mvapich2", "reduce mvapich2 collective", (void*)reduce__mvapich2},
    {"mvapich2_knomial", "reduce mvapich2_knomial collective", (void*)reduce__mvapich2_knomial},
    {"mvapich2_two_level", "reduce mvapich2_two_level collective", (void*)reduce__mvapich2_two_level},
    {"impi", "reduce impi collective", (void*)reduce__impi},
    {"rab", "reduce rab collective", (void*)reduce__rab},
    {"automatic", "reduce automatic collective", (void*)reduce__automatic},
    {"", "", nullptr}};

// Needed by the automatic selector's weird implementation
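// (the 'rank' parameter below is really an index into the matching description table, presumably so that the
// "automatic" algorithms can walk every candidate in turn)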
s_mpi_coll_description_t* colls::get_smpi_coll_description(const char* name, int rank)
{
  if (strcmp(name, "gather") == 0)
    return &mpi_coll_gather_description[rank];
  if (strcmp(name, "allgather") == 0)
    return &mpi_coll_allgather_description[rank];
  if (strcmp(name, "allgatherv") == 0)
    return &mpi_coll_allgatherv_description[rank];
  if (strcmp(name, "allreduce") == 0)
    return &mpi_coll_allreduce_description[rank];
  if (strcmp(name, "reduce_scatter") == 0)
    return &mpi_coll_reduce_scatter_description[rank];
  if (strcmp(name, "scatter") == 0)
    return &mpi_coll_scatter_description[rank];
  if (strcmp(name, "barrier") == 0)
    return &mpi_coll_barrier_description[rank];
  if (strcmp(name, "alltoall") == 0)
    return &mpi_coll_alltoall_description[rank];
  if (strcmp(name, "alltoallv") == 0)
    return &mpi_coll_alltoallv_description[rank];
  if (strcmp(name, "bcast") == 0)
    return &mpi_coll_bcast_description[rank];
  if (strcmp(name, "reduce") == 0)
    return &mpi_coll_reduce_description[rank];
  XBT_INFO("You requested an unknown collective: %s", name);
  return nullptr;
}

/** Display the long description of every algorithm registered for the given collective */
void colls::coll_help(const char* category, s_mpi_coll_description_t* table)
{
  XBT_WARN("Long description of the %s models accepted by this simulator:\n", category);
  for (int i = 0; not table[i].name.empty(); i++)
    XBT_WARN(" %s: %s\n", table[i].name.c_str(), table[i].description.c_str());
}
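
/** Return the index in 'table' of the algorithm called 'name', aborting with the list of valid names if it is not
 *  found. 'desc' is the name of the collective, only used in the log messages. */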
int colls::find_coll_description(s_mpi_coll_description_t* table, const std::string& name, const char* desc)
{
  for (int i = 0; not table[i].name.empty(); i++)
    if (name == table[i].name) {
      if (table[i].name != "default")
        XBT_INFO("Switch to algorithm %s for collective %s", table[i].name.c_str(), desc);
      return i;
    }

  if (table[0].name.empty())
    xbt_die("No collective is valid for '%s'! This is a bug.", name.c_str());
  std::string name_list = table[0].name;
  for (int i = 1; not table[i].name.empty(); i++)
    name_list = name_list + ", " + table[i].name;

  xbt_die("Collective '%s' is invalid! Valid collectives are: %s.", name.c_str(), name_list.c_str());
}
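
// Pointers to the algorithm currently selected for each collective. They are filled from the description tables
// above by the setters generated below.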
int (*colls::gather)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff, int recv_count,
                     MPI_Datatype recv_type, int root, MPI_Comm comm);
int (*colls::allgather)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff, int recv_count,
                        MPI_Datatype recv_type, MPI_Comm comm);
int (*colls::allgatherv)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff,
                         const int* recv_count, const int* recv_disps, MPI_Datatype recv_type, MPI_Comm comm);
int (*colls::alltoall)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff, int recv_count,
                       MPI_Datatype recv_type, MPI_Comm comm);
int (*colls::alltoallv)(const void* send_buff, const int* send_counts, const int* send_disps, MPI_Datatype send_type,
                        void* recv_buff, const int* recv_counts, const int* recv_disps, MPI_Datatype recv_type,
                        MPI_Comm comm);
int (*colls::bcast)(void* buf, int count, MPI_Datatype datatype, int root, MPI_Comm comm);
int (*colls::reduce)(const void* buf, void* rbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm);
int (*colls::allreduce)(const void* sbuf, void* rbuf, int rcount, MPI_Datatype dtype, MPI_Op op, MPI_Comm comm);
int (*colls::reduce_scatter)(const void* sbuf, void* rbuf, const int* rcounts, MPI_Datatype dtype, MPI_Op op,
                             MPI_Comm comm);
int (*colls::scatter)(const void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, int recvcount,
                      MPI_Datatype recvtype, int root, MPI_Comm comm);
int (*colls::barrier)(MPI_Comm comm);

void (*colls::smpi_coll_cleanup_callback)();
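
// Generate one setter per collective (colls::set_gather, colls::set_allgather, ...): each setter resolves the
// requested algorithm name in the corresponding description table and installs the matching function pointer.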
#define COLL_SETTER(cat, ret, args, args2)                                                                   \
  void colls::_XBT_CONCAT(set_, cat)(const std::string& name)                                                \
  {                                                                                                          \
    int id = find_coll_description(_XBT_CONCAT3(mpi_coll_, cat, _description), name, _XBT_STRINGIFY(cat));   \
    cat    = reinterpret_cast<ret(*) args>(_XBT_CONCAT3(mpi_coll_, cat, _description)[id].coll);             \
    if (cat == nullptr)                                                                                      \
      xbt_die("Collective " _XBT_STRINGIFY(cat) " set to nullptr!");                                         \
  }

COLL_APPLY(COLL_SETTER, COLL_GATHER_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_ALLGATHER_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_ALLGATHERV_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_REDUCE_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_ALLREDUCE_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_REDUCE_SCATTER_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_SCATTER_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_BARRIER_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_BCAST_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_ALLTOALL_SIG, "");
COLL_APPLY(COLL_SETTER, COLL_ALLTOALLV_SIG, "");
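
// Select the algorithm used for each collective from the configuration: the per-collective option
// (e.g. --cfg=smpi/alltoall:pair) takes precedence over the global selector (e.g. --cfg=smpi/coll-selector:mpich).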
void colls::set_collectives()
{
  std::string selector_name = simgrid::config::get_value<std::string>("smpi/coll-selector");
  if (selector_name.empty())
    selector_name = "default";

  std::pair<std::string, std::function<void(std::string)>> setter_callbacks[] = {
      {"gather", &colls::set_gather},         {"allgather", &colls::set_allgather},
      {"allgatherv", &colls::set_allgatherv}, {"allreduce", &colls::set_allreduce},
      {"alltoall", &colls::set_alltoall},     {"alltoallv", &colls::set_alltoallv},
      {"reduce", &colls::set_reduce},         {"reduce_scatter", &colls::set_reduce_scatter},
      {"scatter", &colls::set_scatter},       {"bcast", &colls::set_bcast},
      {"barrier", &colls::set_barrier}};

  for (auto& elem : setter_callbacks) {
    std::string name = simgrid::config::get_value<std::string>(("smpi/" + elem.first).c_str());
    if (name.empty())
      name = selector_name;
    (elem.second)(name);
  }
}

// Implementations of the single-algorithm collectives
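// gatherv, scatterv and alltoallw below simply start their non-blocking counterpart and wait for it to complete.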
int colls::gatherv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, const int* recvcounts,
                   const int* displs, MPI_Datatype recvtype, int root, MPI_Comm comm)
{
  MPI_Request request;
  colls::igatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, &request, 0);
  return Request::wait(&request, MPI_STATUS_IGNORE);
}

int colls::scatterv(const void* sendbuf, const int* sendcounts, const int* displs, MPI_Datatype sendtype,
                    void* recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
{
  MPI_Request request;
  colls::iscatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm, &request, 0);
  return Request::wait(&request, MPI_STATUS_IGNORE);
}
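
// Naive scan implementation: every rank receives a contribution from each lower rank and sends its own buffer to
// each higher rank, then folds the received contributions into recvbuf (which already holds the local data).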
int colls::scan(const void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
{
  int system_tag = -888;
  MPI_Aint lb      = 0;
  MPI_Aint dataext = 0;

  int rank = comm->rank();
  int size = comm->size();

  datatype->extent(&lb, &dataext);

  // Local copy from self
  Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);

  // Send/Recv buffers to/from others
  MPI_Request* requests   = new MPI_Request[size - 1];
  unsigned char** tmpbufs = new unsigned char*[rank];
  int index = 0;
  for (int other = 0; other < rank; other++) {
    tmpbufs[index]  = smpi_get_tmp_sendbuffer(count * dataext);
    requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
    index++;
  }
  for (int other = rank + 1; other < size; other++) {
    requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
    index++;
  }
  // Wait for completion of all comms.
  Request::startall(size - 1, requests);

  if (op != MPI_OP_NULL && op->is_commutative()) {
    for (int other = 0; other < size - 1; other++) {
      index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
      if (index == MPI_UNDEFINED)
        break;
      if (index < rank) {
        // Request index is below rank: it's an irecv, fold the received chunk into recvbuf
        op->apply(tmpbufs[index], recvbuf, &count, datatype);
      }
    }
  } else {
    // non commutative case, wait in order
    for (int other = 0; other < size - 1; other++) {
      Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
      if (other < rank && op != MPI_OP_NULL)
        op->apply(tmpbufs[other], recvbuf, &count, datatype);
    }
  }
  for (index = 0; index < rank; index++)
    smpi_free_tmp_buffer(tmpbufs[index]);
  for (index = 0; index < size - 1; index++)
    Request::unref(&requests[index]);
  delete[] tmpbufs;
  delete[] requests;
  return MPI_SUCCESS;
}
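
// Exclusive scan: same communication pattern as scan, but the local contribution is not included in the result.
// recvbuf is left untouched on rank 0; on the other ranks it is first seeded with a received contribution and the
// remaining ones are folded in with op.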
int colls::exscan(const void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
{
  int system_tag = -888;
  MPI_Aint lb      = 0;
  MPI_Aint dataext = 0;
  int recvbuf_is_empty = 1;
  int rank = comm->rank();
  int size = comm->size();

  datatype->extent(&lb, &dataext);

  // Send/Recv buffers to/from others
  MPI_Request* requests   = new MPI_Request[size - 1];
  unsigned char** tmpbufs = new unsigned char*[rank];
  int index = 0;
  for (int other = 0; other < rank; other++) {
    tmpbufs[index]  = smpi_get_tmp_sendbuffer(count * dataext);
    requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
    index++;
  }
  for (int other = rank + 1; other < size; other++) {
    requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
    index++;
  }
  // Wait for completion of all comms.
  Request::startall(size - 1, requests);

  if (op != MPI_OP_NULL && op->is_commutative()) {
    for (int other = 0; other < size - 1; other++) {
      index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
      if (index == MPI_UNDEFINED)
        break;
      if (index < rank) {
        if (recvbuf_is_empty) {
          Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
          recvbuf_is_empty = 0;
        } else {
          // Request index is below rank: it's an irecv, fold the received chunk into recvbuf
          op->apply(tmpbufs[index], recvbuf, &count, datatype);
        }
      }
    }
  } else {
    // non commutative case, wait in order
    for (int other = 0; other < size - 1; other++) {
      Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
      if (other < rank) {
        if (recvbuf_is_empty) {
          Datatype::copy(tmpbufs[other], count, datatype, recvbuf, count, datatype);
          recvbuf_is_empty = 0;
        } else if (op != MPI_OP_NULL) {
          op->apply(tmpbufs[other], recvbuf, &count, datatype);
        }
      }
    }
  }
  for (index = 0; index < rank; index++)
    smpi_free_tmp_buffer(tmpbufs[index]);
  for (index = 0; index < size - 1; index++)
    Request::unref(&requests[index]);
  delete[] tmpbufs;
  delete[] requests;
  return MPI_SUCCESS;
}
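
// Blocking alltoallw, built on top of ialltoallw (see the note above the single-algorithm collectives).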
int colls::alltoallw(const void* sendbuf, const int* sendcounts, const int* senddisps, const MPI_Datatype* sendtypes,
                     void* recvbuf, const int* recvcounts, const int* recvdisps, const MPI_Datatype* recvtypes,
                     MPI_Comm comm)
{
  MPI_Request request;
  colls::ialltoallw(sendbuf, sendcounts, senddisps, sendtypes, recvbuf, recvcounts, recvdisps, recvtypes, comm, &request, 0);
  return Request::wait(&request, MPI_STATUS_IGNORE);
}