1 /* smpi_coll.c -- various optimized routines for collectives */
3 /* Copyright (c) 2009-2015. The SimGrid Team.
4 * All rights reserved. */
6 /* This program is free software; you can redistribute it and/or modify it
7 * under the terms of the license (GNU LGPL) which comes with this package. */
14 #include "colls/colls.h"
15 #include "simgrid/sg_config.h"
// Creates the "smpi_coll" log category as a subcategory of "smpi" and makes it
// the default category for the XBT_* logging calls in this file.
17 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI (coll)");
// Algorithm tables, one per collective operation. Each table is expanded from the
// matching COLL_<OP>S(COLL_DESCRIPTION, COLL_COMMA) list (presumably defined in
// colls/colls.h) and ends with an all-nullptr sentinel entry. coll_help() and
// find_coll_description() iterate while table[i].name is non-null, so the sentinel
// must remain the last element. The entries' name/description/coll fields are read
// by those functions and by COLL_SETTER.
19 s_mpi_coll_description_t mpi_coll_gather_description[] = {
20 COLL_GATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} /* this array must be nullptr terminated */
23 s_mpi_coll_description_t mpi_coll_allgather_description[] = {
24 COLL_ALLGATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr}
27 s_mpi_coll_description_t mpi_coll_allgatherv_description[] = { COLL_ALLGATHERVS(COLL_DESCRIPTION, COLL_COMMA),
28 {nullptr, nullptr, nullptr} /* this array must be nullptr terminated */
31 s_mpi_coll_description_t mpi_coll_allreduce_description[] ={ COLL_ALLREDUCES(COLL_DESCRIPTION, COLL_COMMA),
32 {nullptr, nullptr, nullptr} /* this array must be nullptr terminated */
35 s_mpi_coll_description_t mpi_coll_reduce_scatter_description[] = {COLL_REDUCE_SCATTERS(COLL_DESCRIPTION, COLL_COMMA),
36 {nullptr, nullptr, nullptr} /* this array must be nullptr terminated */
39 s_mpi_coll_description_t mpi_coll_scatter_description[] ={COLL_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr}};
41 s_mpi_coll_description_t mpi_coll_barrier_description[] ={COLL_BARRIERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr}};
43 s_mpi_coll_description_t mpi_coll_alltoall_description[] = {COLL_ALLTOALLS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr}};
45 s_mpi_coll_description_t mpi_coll_alltoallv_description[] = {COLL_ALLTOALLVS(COLL_DESCRIPTION, COLL_COMMA),
46 {nullptr, nullptr, nullptr} /* this array must be nullptr terminated */
49 s_mpi_coll_description_t mpi_coll_bcast_description[] = {COLL_BCASTS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr}};
51 s_mpi_coll_description_t mpi_coll_reduce_description[] = {COLL_REDUCES(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
55 /** Prints (with printf) the long description of every algorithm listed in
 * @a table for the collective family @a category: one header line, then one
 * "name: description" line per entry, stopping at the nullptr-named sentinel.
 * @param category  collective name used in the header line (e.g. "gather")
 * @param table     nullptr-terminated algorithm table to list
 * NOTE(review): no exit/quit call is visible in this body; callers expecting the
 * process to stop after the help text should verify that. */
56 void coll_help(const char *category, s_mpi_coll_description_t * table)
58 printf("Long description of the %s models accepted by this simulator:\n", category);
59 for (int i = 0; table[i].name; i++)
60 printf(" %s: %s\n", table[i].name, table[i].description);
// Resolves the algorithm called @a name in the nullptr-terminated @a table.
// @a desc names the collective and is only used in log/error messages.
//  - A null or empty @a name is replaced by the "smpi/coll-selector" setting.
//  - An exact name match is looked for first (logged with XBT_INFO unless the
//    match is "default"); failing that, the table's "default" entry is looked for.
//  - If neither exists, the process is aborted through xbt_die(), listing the
//    valid names.
// NOTE(review): the return statements are not part of this listing; COLL_SETTER
// uses the returned value as an index into @a table.
63 int find_coll_description(s_mpi_coll_description_t * table, const char *name, const char *desc)
65 char *name_list = nullptr;
67 if (name==nullptr || name[0] == '\0') {
68 //no argument provided, use active selector's algorithm
69 name=static_cast<char*>(xbt_cfg_get_string("smpi/coll-selector"));
72 for (int i = 0; table[i].name; i++)
73 if (!strcmp(name, table[i].name)) {
74 if (strcmp(table[i].name,"default"))
75 XBT_INFO("Switch to algorithm %s for collective %s",table[i].name,desc);
80 // collective seems not handled by the active selector, try with default one
81 for (int i = 0; table[i].name; i++)
82 if (!strcmp("default", table[i].name)) {
// NOTE(review): the guard for the next call is not shown; presumably it fires when
// the table is empty (table[0].name == nullptr), since table[0].name is used below.
87 xbt_die("No collective is valid for '%s'! This is a bug.",name);
// Build "name0, name1, ..." for the error message. Each realloc reserves room for
// ", " (2 bytes) plus the terminating NUL (1 byte). The strncat bounds are the
// source lengths, not the destination size, so this sizing is what keeps the
// appends in bounds.
88 name_list = xbt_strdup(table[0].name);
89 for (int i = 1; table[i].name; i++) {
90 name_list = static_cast<char*>(xbt_realloc(name_list, strlen(name_list) + strlen(table[i].name) + 3));
91 strncat(name_list, ", ",2);
92 strncat(name_list, table[i].name, strlen(table[i].name));
// xbt_die() aborts the process, so name_list is intentionally not freed.
94 xbt_die("Collective '%s' is invalid! Valid collectives are: %s.", name, name_list);
// Global hook taking no arguments and returning nothing. Nothing in this listing
// sets or calls it; presumably a collective selector assigns it when it needs
// cleanup (verify at the call sites).
98 void (*smpi_coll_cleanup_callback)();
// Definitions of the Colls:: function pointers, one per collective. Each pointer
// is assigned by the matching Colls::set_<collective>() produced by COLL_SETTER,
// from the entry chosen in its mpi_coll_<collective>_description table.
103 int (*Colls::gather)(void *, int, MPI_Datatype, void*, int, MPI_Datatype, int root, MPI_Comm);
104 int (*Colls::allgather)(void *, int, MPI_Datatype, void*, int, MPI_Datatype, MPI_Comm);
105 int (*Colls::allgatherv)(void *, int, MPI_Datatype, void*, int*, int*, MPI_Datatype, MPI_Comm);
106 int (*Colls::allreduce)(void *sbuf, void *rbuf, int rcount, MPI_Datatype dtype, MPI_Op op, MPI_Comm comm);
107 int (*Colls::alltoall)(void *, int, MPI_Datatype, void*, int, MPI_Datatype, MPI_Comm);
108 int (*Colls::alltoallv)(void *, int*, int*, MPI_Datatype, void*, int*, int*, MPI_Datatype, MPI_Comm);
109 int (*Colls::bcast)(void *buf, int count, MPI_Datatype datatype, int root, MPI_Comm com);
110 int (*Colls::reduce)(void *buf, void *rbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm);
111 int (*Colls::reduce_scatter)(void *sbuf, void *rbuf, int *rcounts,MPI_Datatype dtype,MPI_Op op,MPI_Comm comm);
112 int (*Colls::scatter)(void *sendbuf, int sendcount, MPI_Datatype sendtype,void *recvbuf, int recvcount, MPI_Datatype recvtype,int root, MPI_Comm comm);
113 int (*Colls::barrier)(MPI_Comm comm);
// COLL_SETTER(cat, ret, args, args2) defines Colls::set_<cat>(name). The setter
// looks @a name up in mpi_coll_<cat>_description via find_coll_description() and
// stores that entry's .coll, reinterpret_cast to `ret (*) args`, in the Colls::<cat>
// pointer. args2 is not used in the visible part of the macro.
116 #define COLL_SETTER(cat, ret, args, args2)\
117 void Colls::set_##cat (const char * name){\
118 int id = find_coll_description(mpi_coll_## cat ##_description,\
120 cat = reinterpret_cast<ret (*) args>\
121 (mpi_coll_## cat ##_description[id].coll);\
// Instantiate one Colls::set_<collective>() per collective. Each COLL_*_SIG
// presumably supplies COLL_SETTER's (cat, ret, args, args2) arguments, and
// COLL_APPLY is expected to come from colls/colls.h (verify there).
124 COLL_APPLY(COLL_SETTER,COLL_GATHER_SIG,"");
125 COLL_APPLY(COLL_SETTER,COLL_ALLGATHER_SIG,"");
126 COLL_APPLY(COLL_SETTER,COLL_ALLGATHERV_SIG,"");
127 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SIG,"");
128 COLL_APPLY(COLL_SETTER,COLL_ALLREDUCE_SIG,"");
129 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SCATTER_SIG,"");
130 COLL_APPLY(COLL_SETTER,COLL_SCATTER_SIG,"");
131 COLL_APPLY(COLL_SETTER,COLL_BARRIER_SIG,"");
132 COLL_APPLY(COLL_SETTER,COLL_BCAST_SIG,"");
133 COLL_APPLY(COLL_SETTER,COLL_ALLTOALL_SIG,"");
134 COLL_APPLY(COLL_SETTER,COLL_ALLTOALLV_SIG,"");
// Chooses an algorithm for every collective from the configuration.
// The "smpi/coll-selector" option is the fallback; it becomes "default" when unset
// or empty. A non-empty per-collective option such as "smpi/gather" or
// "smpi/reduce-scatter" overrides the fallback. The chosen name is then handed to
// the matching set_<collective>() setter, e.g. set_allgatherv / set_reduce_scatter.
// NOTE(review): the read-option/fall-back pattern repeats for every collective and
// could be factored into a small helper.
137 void Colls::set_collectives(){
138 const char* selector_name = static_cast<char*>(xbt_cfg_get_string("smpi/coll-selector"));
139 if (selector_name==nullptr || selector_name[0] == '\0')
140 selector_name = "default";
142 const char* name = xbt_cfg_get_string("smpi/gather");
143 if (name==nullptr || name[0] == '\0')
144 name = selector_name;
148 name = xbt_cfg_get_string("smpi/allgather");
149 if (name==nullptr || name[0] == '\0')
150 name = selector_name;
154 name = xbt_cfg_get_string("smpi/allgatherv");
155 if (name==nullptr || name[0] == '\0')
156 name = selector_name;
158 set_allgatherv(name);
160 name = xbt_cfg_get_string("smpi/allreduce");
161 if (name==nullptr || name[0] == '\0')
162 name = selector_name;
166 name = xbt_cfg_get_string("smpi/alltoall");
167 if (name==nullptr || name[0] == '\0')
168 name = selector_name;
172 name = xbt_cfg_get_string("smpi/alltoallv");
173 if (name==nullptr || name[0] == '\0')
174 name = selector_name;
178 name = xbt_cfg_get_string("smpi/reduce");
179 if (name==nullptr || name[0] == '\0')
180 name = selector_name;
184 name = xbt_cfg_get_string("smpi/reduce-scatter");
185 if (name==nullptr || name[0] == '\0')
186 name = selector_name;
188 set_reduce_scatter(name);
190 name = xbt_cfg_get_string("smpi/scatter");
191 if (name==nullptr || name[0] == '\0')
192 name = selector_name;
196 name = xbt_cfg_get_string("smpi/bcast");
197 if (name==nullptr || name[0] == '\0')
198 name = selector_name;
202 name = xbt_cfg_get_string("smpi/barrier");
203 if (name==nullptr || name[0] == '\0')
204 name = selector_name;
// Linear gatherv using the COLL_TAG_GATHERV tag.
// Sender side: a single blocking send of sendbuf to root (presumably taken by
// every rank != root; the rank test is not shown here).
// Root side: copies its own contribution locally, then posts one receive request
// per other rank into recvbuf + displs[src] * extent(recvtype). Displacements are
// therefore in units of the receive type's extent. Root starts all requests,
// waits for them, and releases them.
210 int Colls::gatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *displs,
211 MPI_Datatype recvtype, int root, MPI_Comm comm)
213 int system_tag = COLL_TAG_GATHERV;
215 MPI_Aint recvext = 0;
217 int rank = comm->rank();
218 int size = comm->size();
220 // Send buffer to root
221 Request::send(sendbuf, sendcount, sendtype, root, system_tag, comm);
223 recvtype->extent(&lb, &recvext);
224 // Local copy from root
// NOTE(review): scatterv skips its local copy when recvbuf == MPI_IN_PLACE, but
// sendbuf is copied here with no such guard; confirm MPI_IN_PLACE at root is
// handled before this point.
225 Datatype::copy(sendbuf, sendcount, sendtype, static_cast<char*>(recvbuf) + displs[root] * recvext,
226 recvcounts[root], recvtype);
227 // Receive buffers from senders
// size - 1 slots: presumably one per rank other than root, whose data was copied above.
228 MPI_Request *requests = xbt_new(MPI_Request, size - 1);
230 for (int src = 0; src < size; src++) {
232 requests[index] = Request::irecv_init(static_cast<char*>(recvbuf) + displs[src] * recvext,
233 recvcounts[src], recvtype, src, system_tag, comm);
237 // Wait for completion of irecv's.
238 Request::startall(size - 1, requests);
239 Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
// Release every request once all receives have completed.
240 for (int src = 0; src < size-1; src++) {
241 Request::unref(&requests[src]);
// Linear scatterv using the COLL_TAG_SCATTERV tag.
// Receiver side: a single blocking receive of recvbuf from root (presumably taken
// by every rank != root; the rank test is not shown here).
// Root side: copies its own slice, sendbuf + displs[root] * extent(sendtype),
// into recvbuf unless recvbuf is MPI_IN_PLACE. It then posts one send request per
// other rank from sendbuf + displs[dst] * extent(sendtype), starts them all, waits
// for completion, and releases them.
249 int Colls::scatterv(void *sendbuf, int *sendcounts, int *displs, MPI_Datatype sendtype, void *recvbuf, int recvcount,
250 MPI_Datatype recvtype, int root, MPI_Comm comm)
252 int system_tag = COLL_TAG_SCATTERV;
254 MPI_Aint sendext = 0;
256 int rank = comm->rank();
257 int size = comm->size();
259 // Recv buffer from root
260 Request::recv(recvbuf, recvcount, recvtype, root, system_tag, comm, MPI_STATUS_IGNORE);
262 sendtype->extent(&lb, &sendext);
263 // Local copy from root
// With MPI_IN_PLACE, root's slice is already where it belongs, so no copy is made.
264 if(recvbuf!=MPI_IN_PLACE){
265 Datatype::copy(static_cast<char *>(sendbuf) + displs[root] * sendext, sendcounts[root],
266 sendtype, recvbuf, recvcount, recvtype);
268 // Send buffers to receivers
// size - 1 slots: presumably one per rank other than root.
269 MPI_Request *requests = xbt_new(MPI_Request, size - 1);
271 for (int dst = 0; dst < size; dst++) {
273 requests[index] = Request::isend_init(static_cast<char *>(sendbuf) + displs[dst] * sendext, sendcounts[dst],
274 sendtype, dst, system_tag, comm);
278 // Wait for completion of isend's.
279 Request::startall(size - 1, requests);
280 Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
// Release every request once all sends have completed.
281 for (int dst = 0; dst < size-1; dst++) {
282 Request::unref(&requests[dst]);
// Linear inclusive prefix reduction. recvbuf starts as a copy of this rank's own
// data. Each rank receives the buffers of all lower ranks into temporaries, sends
// its own buffer to all higher ranks, and folds the received buffers into recvbuf
// with op->apply.
// Request layout: receives from lower ranks come first (indices below rank),
// followed by sends to higher ranks.
// NOTE(review): the tag is the hard-coded literal -888, also used by exscan,
// whereas gatherv/scatterv use COLL_TAG_* constants.
290 int Colls::scan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
292 int system_tag = -888;
294 MPI_Aint dataext = 0;
296 int rank = comm->rank();
297 int size = comm->size();
299 datatype->extent(&lb, &dataext);
301 // Local copy from self
302 Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);
304 // Send/Recv buffers to/from others;
305 MPI_Request *requests = xbt_new(MPI_Request, size - 1);
// One temporary receive buffer of count * extent bytes per lower rank.
306 void **tmpbufs = xbt_new(void *, rank);
308 for (int other = 0; other < rank; other++) {
309 tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
310 requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
313 for (int other = rank + 1; other < size; other++) {
314 requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
317 // Wait for completion of all comms.
318 Request::startall(size - 1, requests);
// Commutative op: combine contributions in whatever order they complete (waitany).
320 if(op != MPI_OP_NULL && op->is_commutative()){
321 for (int other = 0; other < size - 1; other++) {
322 index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
// MPI_UNDEFINED presumably means no active request remains.
323 if(index == MPI_UNDEFINED) {
327 // #Request is below rank: it's a irecv
328 if(op!=MPI_OP_NULL) op->apply( tmpbufs[index], recvbuf, &count, datatype);
332 //non commutative case, wait in order
// NOTE(review): if Op::apply follows the MPI_User_function convention
// (inoutvec = invec op inoutvec), folding lower ranks in ascending order onto this
// rank's own data yields x[rank-1] op ... op x[0] op x[rank] rather than
// x[0] op ... op x[rank] for non-commutative ops. Folding from rank-1 down to 0
// would give the MPI order; verify.
333 for (int other = 0; other < size - 1; other++) {
334 Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
336 if(op!=MPI_OP_NULL) op->apply( tmpbufs[other], recvbuf, &count, datatype);
// Cleanup: free the per-lower-rank temporaries, then release all requests.
340 for(index = 0; index < rank; index++) {
341 smpi_free_tmp_buffer(tmpbufs[index]);
343 for(index = 0; index < size-1; index++) {
344 Request::unref(&requests[index]);
// Linear exclusive prefix reduction. It works like scan, except that this rank's
// own data is not copied into recvbuf. recvbuf starts "empty"
// (recvbuf_is_empty == 1): the first received buffer is copied into it, and later
// ones are folded in with op->apply. Requests are laid out as in scan: receives
// from lower ranks first, then sends to higher ranks.
// NOTE(review): shares the hard-coded tag -888 with scan.
351 int Colls::exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
353 int system_tag = -888;
355 MPI_Aint dataext = 0;
356 int recvbuf_is_empty=1;
357 int rank = comm->rank();
358 int size = comm->size();
360 datatype->extent(&lb, &dataext);
362 // Send/Recv buffers to/from others;
363 MPI_Request *requests = xbt_new(MPI_Request, size - 1);
// One temporary receive buffer of count * extent bytes per lower rank.
364 void **tmpbufs = xbt_new(void *, rank);
366 for (int other = 0; other < rank; other++) {
367 tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
368 requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
371 for (int other = rank + 1; other < size; other++) {
372 requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
375 // Wait for completion of all comms.
376 Request::startall(size - 1, requests);
// Commutative op: combine contributions in whatever order they complete (waitany).
378 if(op != MPI_OP_NULL && op->is_commutative()){
379 for (int other = 0; other < size - 1; other++) {
380 index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
381 if(index == MPI_UNDEFINED) {
// First contribution is copied rather than combined; the flag is presumably
// cleared afterwards, as in the non-commutative branch below.
385 if(recvbuf_is_empty){
386 Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
389 // #Request is below rank: it's a irecv
390 if(op!=MPI_OP_NULL) op->apply( tmpbufs[index], recvbuf, &count, datatype);
394 //non commutative case, wait in order
// NOTE(review): if Op::apply follows the MPI_User_function convention
// (inoutvec = invec op inoutvec), copying x[0] first and then folding higher ranks
// in ascending order yields x[rank-1] op ... op x[1] op x[0] rather than
// x[0] op ... op x[rank-1] for non-commutative ops. Starting from x[rank-1] and
// folding downward would give the MPI order; verify.
395 for (int other = 0; other < size - 1; other++) {
396 Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
398 if (recvbuf_is_empty) {
399 Datatype::copy(tmpbufs[other], count, datatype, recvbuf, count, datatype);
400 recvbuf_is_empty = 0;
402 if(op!=MPI_OP_NULL) op->apply( tmpbufs[other], recvbuf, &count, datatype);
// Cleanup: free the per-lower-rank temporaries, then release all requests.
406 for(index = 0; index < rank; index++) {
407 smpi_free_tmp_buffer(tmpbufs[index]);
409 for(index = 0; index < size-1; index++) {
410 Request::unref(&requests[index]);