1 /* smpi_coll.c -- various optimized routing for collectives */
3 /* Copyright (c) 2009-2015. The SimGrid Team.
4 * All rights reserved. */
6 /* This program is free software; you can redistribute it and/or modify it
7 * under the terms of the license (GNU LGPL) which comes with this package. */
14 #include "simgrid/sg_config.h"
16 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI (coll)");
// Definition of Colls::smpi_coll_cleanup_callback: a pointer to a function that takes no
// arguments and returns void. Nothing in the visible part of this file assigns or calls it.
22 void (*Colls::smpi_coll_cleanup_callback)();
// Description tables, one per collective operation. Each table is produced by expanding the
// matching COLL_<OPERATION>S list macro with COLL_DESCRIPTION and COLL_COMMA (all defined
// elsewhere), followed by a sentinel entry whose three fields are nullptr.
// coll_help() and find_coll_description() walk a table until they meet an entry whose `name`
// is null, which is why the sentinel is required.
24 /* these arrays must be nullptr terminated */
25 s_mpi_coll_description_t Colls::mpi_coll_gather_description[] = {
26 COLL_GATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
27 s_mpi_coll_description_t Colls::mpi_coll_allgather_description[] = {
28 COLL_ALLGATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
29 s_mpi_coll_description_t Colls::mpi_coll_allgatherv_description[] = {
30 COLL_ALLGATHERVS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
31 s_mpi_coll_description_t Colls::mpi_coll_allreduce_description[] ={
32 COLL_ALLREDUCES(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
33 s_mpi_coll_description_t Colls::mpi_coll_reduce_scatter_description[] = {
34 COLL_REDUCE_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
35 s_mpi_coll_description_t Colls::mpi_coll_scatter_description[] ={
36 COLL_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
37 s_mpi_coll_description_t Colls::mpi_coll_barrier_description[] ={
38 COLL_BARRIERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
39 s_mpi_coll_description_t Colls::mpi_coll_alltoall_description[] = {
40 COLL_ALLTOALLS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
41 s_mpi_coll_description_t Colls::mpi_coll_alltoallv_description[] = {
42 COLL_ALLTOALLVS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
43 s_mpi_coll_description_t Colls::mpi_coll_bcast_description[] = {
44 COLL_BCASTS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
45 s_mpi_coll_description_t Colls::mpi_coll_reduce_description[] = {
46 COLL_REDUCES(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
48 /** Prints a header naming `category`, then one "name: description" line for every entry of `table`. */
// `table` must end with an entry whose `name` is null; the loop stops there.
// NOTE(review): the previous comment said this function also quits, but no exit call is
// visible in this listing -- confirm against the complete source.
49 void Colls::coll_help(const char *category, s_mpi_coll_description_t * table)
51 printf("Long description of the %s models accepted by this simulator:\n", category);
52 for (int i = 0; table[i].name; i++)
53 printf(" %s: %s\n", table[i].name, table[i].description);
/**
 * Searches the null-name-terminated `table` for an entry whose name equals `name`.
 *
 * @param table  description table to search
 * @param name   algorithm name requested by the user
 * @param desc   label of the collective, used only in the informational log message
 *
 * On a match, an XBT_INFO message is logged unless the matching name is "default".
 * When nothing matches, the function aborts through xbt_die() with the list of valid names.
 *
 * NOTE(review): several lines are missing from this listing, including the statement that
 * returns from the match branch (presumably the index of the match -- confirm) and the
 * condition guarding the first xbt_die().
 */
56 int Colls::find_coll_description(s_mpi_coll_description_t * table, const char *name, const char *desc)
58 char *name_list = nullptr;
59 for (int i = 0; table[i].name; i++)
60 if (!strcmp(name, table[i].name)) {
// strcmp() is non-zero when the names differ: only non-"default" selections are logged.
61 if (strcmp(table[i].name,"default"))
62 XBT_INFO("Switch to algorithm %s for collective %s",table[i].name,desc);
67 xbt_die("No collective is valid for '%s'! This is a bug.",name);
// No match: build "name0, name1, ..." from every entry of the table for the error message.
68 name_list = xbt_strdup(table[0].name);
69 for (int i = 1; table[i].name; i++) {
// +3 leaves room for the ", " separator (2 bytes) and the terminating NUL (1 byte).
70 name_list = static_cast<char*>(xbt_realloc(name_list, strlen(name_list) + strlen(table[i].name) + 3));
71 strncat(name_list, ", ",2);
// The bound equals the length of the source string, so the whole name is appended.
72 strncat(name_list, table[i].name, strlen(table[i].name));
74 xbt_die("Collective '%s' is invalid! Valid collectives are: %s.", name, name_list);
// COLL_SETTER(cat, ret, args, args2): for the collective `cat`, the visible lines of this macro
//  - define the function pointer Colls::cat with parameter list `args`
//    (declared with an `int` return type here, while the cast below uses `ret`);
//  - define Colls::set_<cat>(name), which looks `name` up in mpi_coll_<cat>_description through
//    find_coll_description() and stores that entry's `coll` member in Colls::cat after a
//    reinterpret_cast to `ret (*) args`.
// `args2` is not used in the visible lines.
// NOTE(review): some continuation lines of the macro are missing from this listing (the
// remaining arguments of find_coll_description() and the closing of set_<cat>).
80 #define COLL_SETTER(cat, ret, args, args2)\
81 int (*Colls::cat ) args;\
82 void Colls::set_##cat (const char * name){\
83 int id = find_coll_description(mpi_coll_## cat ##_description,\
85 cat = reinterpret_cast<ret (*) args>\
86 (mpi_coll_## cat ##_description[id].coll);\
// Instantiate COLL_SETTER once per collective signature macro. COLL_APPLY and the COLL_*_SIG
// macros are defined elsewhere; presumably each line yields the Colls::<collective> function
// pointer and its Colls::set_<collective>() setter -- confirm against their definitions.
89 COLL_APPLY(COLL_SETTER,COLL_GATHER_SIG,"");
90 COLL_APPLY(COLL_SETTER,COLL_ALLGATHER_SIG,"");
91 COLL_APPLY(COLL_SETTER,COLL_ALLGATHERV_SIG,"");
92 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SIG,"");
93 COLL_APPLY(COLL_SETTER,COLL_ALLREDUCE_SIG,"");
94 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SCATTER_SIG,"");
95 COLL_APPLY(COLL_SETTER,COLL_SCATTER_SIG,"");
96 COLL_APPLY(COLL_SETTER,COLL_BARRIER_SIG,"");
97 COLL_APPLY(COLL_SETTER,COLL_BCAST_SIG,"");
98 COLL_APPLY(COLL_SETTER,COLL_ALLTOALL_SIG,"");
99 COLL_APPLY(COLL_SETTER,COLL_ALLTOALLV_SIG,"");
/**
 * Chooses the algorithm name of each collective from the configuration.
 *
 * The "smpi/coll-selector" option provides the fallback name ("default" when it is unset or
 * empty). For each collective, the dedicated "smpi/<collective>" option is read; when it is
 * unset or empty, the fallback name is used instead.
 *
 * NOTE(review): only the set_allgatherv() and set_reduce_scatter() calls are visible in this
 * listing; the other collectives presumably pass `name` to their own setter on lines that are
 * missing here -- confirm against the complete source.
 */
102 void Colls::set_collectives(){
103 const char* selector_name = static_cast<char*>(xbt_cfg_get_string("smpi/coll-selector"));
104 if (selector_name==nullptr || selector_name[0] == '\0')
105 selector_name = "default";
// gather
107 const char* name = xbt_cfg_get_string("smpi/gather");
108 if (name==nullptr || name[0] == '\0')
109 name = selector_name;
// allgather
113 name = xbt_cfg_get_string("smpi/allgather");
114 if (name==nullptr || name[0] == '\0')
115 name = selector_name;
// allgatherv
119 name = xbt_cfg_get_string("smpi/allgatherv");
120 if (name==nullptr || name[0] == '\0')
121 name = selector_name;
123 set_allgatherv(name);
// allreduce
125 name = xbt_cfg_get_string("smpi/allreduce");
126 if (name==nullptr || name[0] == '\0')
127 name = selector_name;
// alltoall
131 name = xbt_cfg_get_string("smpi/alltoall");
132 if (name==nullptr || name[0] == '\0')
133 name = selector_name;
// alltoallv
137 name = xbt_cfg_get_string("smpi/alltoallv");
138 if (name==nullptr || name[0] == '\0')
139 name = selector_name;
// reduce
143 name = xbt_cfg_get_string("smpi/reduce");
144 if (name==nullptr || name[0] == '\0')
145 name = selector_name;
// reduce-scatter
149 name = xbt_cfg_get_string("smpi/reduce-scatter");
150 if (name==nullptr || name[0] == '\0')
151 name = selector_name;
153 set_reduce_scatter(name);
// scatter
155 name = xbt_cfg_get_string("smpi/scatter");
156 if (name==nullptr || name[0] == '\0')
157 name = selector_name;
// bcast
161 name = xbt_cfg_get_string("smpi/bcast");
162 if (name==nullptr || name[0] == '\0')
163 name = selector_name;
// barrier
167 name = xbt_cfg_get_string("smpi/barrier");
168 if (name==nullptr || name[0] == '\0')
169 name = selector_name;
175 // Implementations of the single-algorithm collectives
/**
 * Gathers per-rank contributions of varying size into `recvbuf`, using point-to-point
 * messages tagged COLL_TAG_GATHERV.
 *
 * @param sendbuf     data contributed by the calling rank
 * @param sendcount   number of `sendtype` elements in `sendbuf`
 * @param sendtype    datatype of the elements sent
 * @param recvbuf     destination buffer, indexed through `displs`
 * @param recvcounts  per-rank element counts (indexed by source rank)
 * @param displs      per-rank displacements into `recvbuf`, expressed in units of the extent
 *                    of `recvtype`
 * @param recvtype    datatype of the elements received
 * @param root        rank that collects the data
 * @param comm        communicator providing rank() and size()
 *
 * NOTE(review): several lines are missing from this listing: the declarations of `lb` and
 * `index`, the branch separating the root from the other ranks, the body of the receive loop
 * beyond the irecv_init call (presumably it skips `root`, since only size - 1 request slots
 * exist), the release of `requests` and the return statement. The comments below describe
 * the visible statements only.
 */
177 int Colls::gatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *displs,
178 MPI_Datatype recvtype, int root, MPI_Comm comm)
180 int system_tag = COLL_TAG_GATHERV;
182 MPI_Aint recvext = 0;
184 int rank = comm->rank();
185 int size = comm->size();
187 // Send buffer to root
188 Request::send(sendbuf, sendcount, sendtype, root, system_tag, comm);
// recvext receives the extent of recvtype; it scales the displacements into byte offsets.
190 recvtype->extent(&lb, &recvext);
191 // Local copy from root
192 Datatype::copy(sendbuf, sendcount, sendtype, static_cast<char*>(recvbuf) + displs[root] * recvext,
193 recvcounts[root], recvtype);
194 // Receive buffers from senders
// One request slot per rank other than one (size - 1 slots in total).
195 MPI_Request *requests = xbt_new(MPI_Request, size - 1);
197 for (int src = 0; src < size; src++) {
// The contribution of `src` lands at byte offset displs[src] * recvext of recvbuf.
199 requests[index] = Request::irecv_init(static_cast<char*>(recvbuf) + displs[src] * recvext,
200 recvcounts[src], recvtype, src, system_tag, comm);
204 // Wait for completion of irecv's.
205 Request::startall(size - 1, requests);
206 Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
// Drop the reference held on each of the size - 1 requests.
207 for (int src = 0; src < size-1; src++) {
208 Request::unref(&requests[src]);
/**
 * Distributes per-rank pieces of varying size from `sendbuf` to the ranks of `comm`, using
 * point-to-point messages tagged COLL_TAG_SCATTERV.
 *
 * @param sendbuf     source buffer, indexed through `displs`
 * @param sendcounts  per-rank element counts (indexed by destination rank)
 * @param displs      per-rank displacements into `sendbuf`, expressed in units of the extent
 *                    of `sendtype`
 * @param sendtype    datatype of the elements sent
 * @param recvbuf     buffer receiving the calling rank's piece; the local copy is skipped when
 *                    it equals MPI_IN_PLACE
 * @param recvcount   number of `recvtype` elements expected in `recvbuf`
 * @param recvtype    datatype of the elements received
 * @param root        rank that owns the data to distribute
 * @param comm        communicator providing rank() and size()
 *
 * NOTE(review): several lines are missing from this listing: the declarations of `lb` and
 * `index`, the branch separating the root from the other ranks, the body of the send loop
 * beyond the isend_init call (presumably it skips `root`, since only size - 1 request slots
 * exist), the release of `requests` and the return statement. The comments below describe
 * the visible statements only.
 */
216 int Colls::scatterv(void *sendbuf, int *sendcounts, int *displs, MPI_Datatype sendtype, void *recvbuf, int recvcount,
217 MPI_Datatype recvtype, int root, MPI_Comm comm)
219 int system_tag = COLL_TAG_SCATTERV;
221 MPI_Aint sendext = 0;
223 int rank = comm->rank();
224 int size = comm->size();
226 // Recv buffer from root
227 Request::recv(recvbuf, recvcount, recvtype, root, system_tag, comm, MPI_STATUS_IGNORE);
// sendext receives the extent of sendtype; it scales the displacements into byte offsets.
229 sendtype->extent(&lb, &sendext);
230 // Local copy from root
231 if(recvbuf!=MPI_IN_PLACE){
232 Datatype::copy(static_cast<char *>(sendbuf) + displs[root] * sendext, sendcounts[root],
233 sendtype, recvbuf, recvcount, recvtype);
235 // Send buffers to receivers
// One request slot per rank other than one (size - 1 slots in total).
236 MPI_Request *requests = xbt_new(MPI_Request, size - 1);
238 for (int dst = 0; dst < size; dst++) {
// The piece for `dst` starts at byte offset displs[dst] * sendext of sendbuf.
240 requests[index] = Request::isend_init(static_cast<char *>(sendbuf) + displs[dst] * sendext, sendcounts[dst],
241 sendtype, dst, system_tag, comm);
245 // Wait for completion of isend's.
246 Request::startall(size - 1, requests);
247 Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
// Drop the reference held on each of the size - 1 requests.
248 for (int dst = 0; dst < size-1; dst++) {
249 Request::unref(&requests[dst]);
/**
 * Scan: combines into `recvbuf` the calling rank's own `sendbuf` and the buffers received
 * from every lower rank, using `op`.
 *
 * Visible steps: `sendbuf` is copied into `recvbuf`; one receive into a temporary buffer is
 * prepared for each lower rank and one send of `sendbuf` for each higher rank; all requests
 * are started together. When `op` is non-null and commutative, contributions are applied in
 * completion order (waitany); otherwise requests are waited for in index order. Temporary
 * buffers and requests are released afterwards.
 *
 * @param sendbuf   contribution of the calling rank
 * @param recvbuf   result buffer
 * @param count     number of `datatype` elements in each buffer
 * @param datatype  datatype of the elements
 * @param op        reduction operation; it is not applied when equal to MPI_OP_NULL
 * @param comm      communicator providing rank() and size()
 *
 * NOTE(review): several lines are missing from this listing: the declarations of `lb` and
 * `index`, the increments of `index`, the tests that restrict op->apply() to receive requests
 * (tmpbufs only has `rank` entries), the release of the two arrays and the return statement.
 */
257 int Colls::scan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
// Tag shared by every message of this operation (exscan uses the same value).
259 int system_tag = -888;
261 MPI_Aint dataext = 0;
263 int rank = comm->rank();
264 int size = comm->size();
266 datatype->extent(&lb, &dataext);
268 // Local copy from self
269 Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);
271 // Send/Recv buffers to/from others;
// size - 1 requests in total: `rank` receives followed by the sends to the higher ranks.
272 MPI_Request *requests = xbt_new(MPI_Request, size - 1);
273 void **tmpbufs = xbt_new(void *, rank);
// Lower ranks: receive each contribution into its own temporary buffer of count * dataext bytes.
275 for (int other = 0; other < rank; other++) {
276 tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
277 requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
// Higher ranks: send them this rank's original contribution.
280 for (int other = rank + 1; other < size; other++) {
281 requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
284 // Wait for completion of all comms.
285 Request::startall(size - 1, requests);
// Commutative op: contributions may be combined in whatever order they complete.
287 if(op != MPI_OP_NULL && op->is_commutative()){
288 for (int other = 0; other < size - 1; other++) {
289 index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
290 if(index == MPI_UNDEFINED) {
294 // #Request is below rank: it's an irecv
// NOTE(review): op was already checked against MPI_OP_NULL by the enclosing condition.
295 if(op!=MPI_OP_NULL) op->apply( tmpbufs[index], recvbuf, &count, datatype);
299 //non commutative case, wait in order
300 for (int other = 0; other < size - 1; other++) {
301 Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
303 if(op!=MPI_OP_NULL) op->apply( tmpbufs[other], recvbuf, &count, datatype);
// Release the `rank` temporary receive buffers.
307 for(index = 0; index < rank; index++) {
308 smpi_free_tmp_buffer(tmpbufs[index]);
// Drop the reference held on each of the size - 1 requests.
310 for(index = 0; index < size-1; index++) {
311 Request::unref(&requests[index]);
/**
 * Exclusive scan: combines into `recvbuf` the buffers received from every lower rank, using
 * `op`. Unlike scan(), the visible code never copies the caller's own `sendbuf` into
 * `recvbuf`.
 *
 * Visible steps: one receive into a temporary buffer is prepared for each lower rank and one
 * send of `sendbuf` for each higher rank; all requests are started together. While
 * `recvbuf_is_empty` is set, a received buffer is copied into `recvbuf`; op->apply() is the
 * other visible way of folding a contribution in. When `op` is non-null and commutative,
 * requests are handled in completion order (waitany); otherwise in index order. Temporary
 * buffers and requests are released afterwards.
 *
 * @param sendbuf   contribution of the calling rank (only sent to higher ranks)
 * @param recvbuf   result buffer
 * @param count     number of `datatype` elements in each buffer
 * @param datatype  datatype of the elements
 * @param op        reduction operation; it is not applied when equal to MPI_OP_NULL
 * @param comm      communicator providing rank() and size()
 *
 * NOTE(review): several lines are missing from this listing: the declarations of `lb` and
 * `index`, the increments of `index`, the tests that restrict the copy/apply to receive
 * requests (tmpbufs only has `rank` entries), the `else` parts pairing copy and apply, the
 * reset of `recvbuf_is_empty` in the commutative branch, the release of the two arrays and
 * the return statement.
 */
318 int Colls::exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
// Tag shared by every message of this operation (scan uses the same value).
320 int system_tag = -888;
322 MPI_Aint dataext = 0;
// Set until a first contribution has been stored in recvbuf.
323 int recvbuf_is_empty=1;
324 int rank = comm->rank();
325 int size = comm->size();
327 datatype->extent(&lb, &dataext);
329 // Send/Recv buffers to/from others;
// size - 1 requests in total: `rank` receives followed by the sends to the higher ranks.
330 MPI_Request *requests = xbt_new(MPI_Request, size - 1);
331 void **tmpbufs = xbt_new(void *, rank);
// Lower ranks: receive each contribution into its own temporary buffer of count * dataext bytes.
333 for (int other = 0; other < rank; other++) {
334 tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
335 requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
// Higher ranks: send them this rank's original contribution.
338 for (int other = rank + 1; other < size; other++) {
339 requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
342 // Wait for completion of all comms.
343 Request::startall(size - 1, requests);
// Commutative op: contributions may be combined in whatever order they complete.
345 if(op != MPI_OP_NULL && op->is_commutative()){
346 for (int other = 0; other < size - 1; other++) {
347 index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
348 if(index == MPI_UNDEFINED) {
// While recvbuf is still empty, a received buffer is copied into it.
352 if(recvbuf_is_empty){
353 Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
356 // #Request is below rank: it's an irecv
// NOTE(review): op was already checked against MPI_OP_NULL by the enclosing condition.
357 if(op!=MPI_OP_NULL) op->apply( tmpbufs[index], recvbuf, &count, datatype);
361 //non commutative case, wait in order
362 for (int other = 0; other < size - 1; other++) {
363 Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
365 if (recvbuf_is_empty) {
366 Datatype::copy(tmpbufs[other], count, datatype, recvbuf, count, datatype);
367 recvbuf_is_empty = 0;
369 if(op!=MPI_OP_NULL) op->apply( tmpbufs[other], recvbuf, &count, datatype);
// Release the `rank` temporary receive buffers.
373 for(index = 0; index < rank; index++) {
374 smpi_free_tmp_buffer(tmpbufs[index]);
// Drop the reference held on each of the size - 1 requests.
376 for(index = 0; index < size-1; index++) {
377 Request::unref(&requests[index]);