1 /* smpi_coll.c -- various optimized routing for collectives */
3 /* Copyright (c) 2009-2017. The SimGrid Team. All rights reserved. */
5 /* This program is free software; you can redistribute it and/or modify it
6 * under the terms of the license (GNU LGPL) which comes with this package. */
9 #include "smpi_coll.hpp"
10 #include "smpi_comm.hpp"
11 #include "smpi_datatype.hpp"
12 #include "smpi_op.hpp"
13 #include "smpi_request.hpp"
15 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI (coll)");
/* COLL_SETTER(cat, ret, args, args2)
 *
 * For the collective named `cat`, emits:
 *   - the definition of the static function pointer Colls::cat (signature `args`);
 *   - Colls::set_<cat>(name), which looks `name` up in mpi_coll_<cat>_description
 *     through find_coll_description() and installs that entry's `coll` pointer,
 *     aborting through xbt_die() when the installed pointer is null.
 *
 * `args2` is not used by this macro.
 *
 * NOTE(review): the trailing arguments of the lookup call and the null guard were
 * restored from find_coll_description()'s signature and from the xbt_die message --
 * confirm against the repository version.
 */
#define COLL_SETTER(cat, ret, args, args2)\
int (*Colls::cat ) args;\
void Colls::set_##cat (const char * name){\
    int id = find_coll_description(mpi_coll_## cat ##_description,\
                                   name, #cat);\
    cat = reinterpret_cast<ret (*) args>\
        (mpi_coll_## cat ##_description[id].coll);\
    if (cat == nullptr)\
      xbt_die("Collective "#cat" set to nullptr!");\
}
/* SET_COLL(coll)
 *
 * Picks the algorithm for collective `coll`: reads the "smpi/<coll>" configuration
 * string into the caller's local `name`, falls back to the caller's local
 * `selector_name` when that string is null or empty, and hands the result to
 * set_<coll>().
 *
 * Both `name` and `selector_name` must be declared in the calling scope.
 * The body is wrapped in do { } while (0) so that `SET_COLL(x);` is a single
 * statement and stays safe inside an unbraced if/else.
 *
 * NOTE(review): the final set_<coll>(name) dispatch was restored from the setters
 * generated by COLL_SETTER -- confirm against the repository version.
 */
#define SET_COLL(coll)\
  do {\
    name = xbt_cfg_get_string("smpi/"#coll);\
    if (name==nullptr || name[0] == '\0')\
        name = selector_name;\
    set_##coll(name);\
  } while (0)
37 void (*Colls::smpi_coll_cleanup_callback)();
39 /* these arrays must be nullptr terminated */
40 s_mpi_coll_description_t Colls::mpi_coll_gather_description[] = {
41 COLL_GATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
42 s_mpi_coll_description_t Colls::mpi_coll_allgather_description[] = {
43 COLL_ALLGATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
44 s_mpi_coll_description_t Colls::mpi_coll_allgatherv_description[] = {
45 COLL_ALLGATHERVS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
46 s_mpi_coll_description_t Colls::mpi_coll_allreduce_description[] ={
47 COLL_ALLREDUCES(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
48 s_mpi_coll_description_t Colls::mpi_coll_reduce_scatter_description[] = {
49 COLL_REDUCE_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
50 s_mpi_coll_description_t Colls::mpi_coll_scatter_description[] ={
51 COLL_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
52 s_mpi_coll_description_t Colls::mpi_coll_barrier_description[] ={
53 COLL_BARRIERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
54 s_mpi_coll_description_t Colls::mpi_coll_alltoall_description[] = {
55 COLL_ALLTOALLS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
56 s_mpi_coll_description_t Colls::mpi_coll_alltoallv_description[] = {
57 COLL_ALLTOALLVS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
58 s_mpi_coll_description_t Colls::mpi_coll_bcast_description[] = {
59 COLL_BCASTS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
60 s_mpi_coll_description_t Colls::mpi_coll_reduce_description[] = {
61 COLL_REDUCES(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
63 /** Displays the long description of all registered models, and quit */
64 void Colls::coll_help(const char *category, s_mpi_coll_description_t * table)
66 XBT_WARN("Long description of the %s models accepted by this simulator:\n", category);
67 for (int i = 0; table[i].name; i++)
68 XBT_WARN(" %s: %s\n", table[i].name, table[i].description);
71 int Colls::find_coll_description(s_mpi_coll_description_t * table, const char *name, const char *desc)
73 char *name_list = nullptr;
74 for (int i = 0; table[i].name; i++)
75 if (not strcmp(name, table[i].name)) {
76 if (strcmp(table[i].name,"default"))
77 XBT_INFO("Switch to algorithm %s for collective %s",table[i].name,desc);
81 if (not table[0].name)
82 xbt_die("No collective is valid for '%s'! This is a bug.",name);
83 name_list = xbt_strdup(table[0].name);
84 for (int i = 1; table[i].name; i++) {
85 name_list = static_cast<char*>(xbt_realloc(name_list, strlen(name_list) + strlen(table[i].name) + 3));
86 strncat(name_list, ", ",2);
87 strncat(name_list, table[i].name, strlen(table[i].name));
89 xbt_die("Collective '%s' is invalid! Valid collectives are: %s.", name, name_list);
95 COLL_APPLY(COLL_SETTER,COLL_GATHER_SIG,"");
96 COLL_APPLY(COLL_SETTER,COLL_ALLGATHER_SIG,"");
97 COLL_APPLY(COLL_SETTER,COLL_ALLGATHERV_SIG,"");
98 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SIG,"");
99 COLL_APPLY(COLL_SETTER,COLL_ALLREDUCE_SIG,"");
100 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SCATTER_SIG,"");
101 COLL_APPLY(COLL_SETTER,COLL_SCATTER_SIG,"");
102 COLL_APPLY(COLL_SETTER,COLL_BARRIER_SIG,"");
103 COLL_APPLY(COLL_SETTER,COLL_BCAST_SIG,"");
104 COLL_APPLY(COLL_SETTER,COLL_ALLTOALL_SIG,"");
105 COLL_APPLY(COLL_SETTER,COLL_ALLTOALLV_SIG,"");
108 void Colls::set_collectives(){
109 const char* selector_name = static_cast<char*>(xbt_cfg_get_string("smpi/coll-selector"));
110 if (selector_name==nullptr || selector_name[0] == '\0')
111 selector_name = "default";
117 SET_COLL(allgatherv);
122 SET_COLL(reduce_scatter);
// Implementations of the single-algorithm collectives
131 int Colls::gatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *displs,
132 MPI_Datatype recvtype, int root, MPI_Comm comm)
134 int system_tag = COLL_TAG_GATHERV;
136 MPI_Aint recvext = 0;
138 int rank = comm->rank();
139 int size = comm->size();
141 // Send buffer to root
142 Request::send(sendbuf, sendcount, sendtype, root, system_tag, comm);
144 recvtype->extent(&lb, &recvext);
145 // Local copy from root
146 Datatype::copy(sendbuf, sendcount, sendtype, static_cast<char*>(recvbuf) + displs[root] * recvext,
147 recvcounts[root], recvtype);
148 // Receive buffers from senders
149 MPI_Request *requests = xbt_new(MPI_Request, size - 1);
151 for (int src = 0; src < size; src++) {
153 requests[index] = Request::irecv_init(static_cast<char*>(recvbuf) + displs[src] * recvext,
154 recvcounts[src], recvtype, src, system_tag, comm);
158 // Wait for completion of irecv's.
159 Request::startall(size - 1, requests);
160 Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
161 for (int src = 0; src < size-1; src++) {
162 Request::unref(&requests[src]);
170 int Colls::scatterv(void *sendbuf, int *sendcounts, int *displs, MPI_Datatype sendtype, void *recvbuf, int recvcount,
171 MPI_Datatype recvtype, int root, MPI_Comm comm)
173 int system_tag = COLL_TAG_SCATTERV;
175 MPI_Aint sendext = 0;
177 int rank = comm->rank();
178 int size = comm->size();
180 // Recv buffer from root
181 Request::recv(recvbuf, recvcount, recvtype, root, system_tag, comm, MPI_STATUS_IGNORE);
183 sendtype->extent(&lb, &sendext);
184 // Local copy from root
185 if(recvbuf!=MPI_IN_PLACE){
186 Datatype::copy(static_cast<char *>(sendbuf) + displs[root] * sendext, sendcounts[root],
187 sendtype, recvbuf, recvcount, recvtype);
189 // Send buffers to receivers
190 MPI_Request *requests = xbt_new(MPI_Request, size - 1);
192 for (int dst = 0; dst < size; dst++) {
194 requests[index] = Request::isend_init(static_cast<char *>(sendbuf) + displs[dst] * sendext, sendcounts[dst],
195 sendtype, dst, system_tag, comm);
199 // Wait for completion of isend's.
200 Request::startall(size - 1, requests);
201 Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
202 for (int dst = 0; dst < size-1; dst++) {
203 Request::unref(&requests[dst]);
211 int Colls::scan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
213 int system_tag = -888;
215 MPI_Aint dataext = 0;
217 int rank = comm->rank();
218 int size = comm->size();
220 datatype->extent(&lb, &dataext);
222 // Local copy from self
223 Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);
225 // Send/Recv buffers to/from others
226 MPI_Request *requests = xbt_new(MPI_Request, size - 1);
227 void **tmpbufs = xbt_new(void *, rank);
229 for (int other = 0; other < rank; other++) {
230 tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
231 requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
234 for (int other = rank + 1; other < size; other++) {
235 requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
238 // Wait for completion of all comms.
239 Request::startall(size - 1, requests);
241 if(op != MPI_OP_NULL && op->is_commutative()){
242 for (int other = 0; other < size - 1; other++) {
243 index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
244 if(index == MPI_UNDEFINED) {
248 // #Request is below rank: it's a irecv
249 op->apply( tmpbufs[index], recvbuf, &count, datatype);
253 //non commutative case, wait in order
254 for (int other = 0; other < size - 1; other++) {
255 Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
256 if(index < rank && op!=MPI_OP_NULL) {
257 op->apply( tmpbufs[other], recvbuf, &count, datatype);
261 for(index = 0; index < rank; index++) {
262 smpi_free_tmp_buffer(tmpbufs[index]);
264 for(index = 0; index < size-1; index++) {
265 Request::unref(&requests[index]);
272 int Colls::exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
274 int system_tag = -888;
276 MPI_Aint dataext = 0;
277 int recvbuf_is_empty=1;
278 int rank = comm->rank();
279 int size = comm->size();
281 datatype->extent(&lb, &dataext);
283 // Send/Recv buffers to/from others
284 MPI_Request *requests = xbt_new(MPI_Request, size - 1);
285 void **tmpbufs = xbt_new(void *, rank);
287 for (int other = 0; other < rank; other++) {
288 tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
289 requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
292 for (int other = rank + 1; other < size; other++) {
293 requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
296 // Wait for completion of all comms.
297 Request::startall(size - 1, requests);
299 if(op != MPI_OP_NULL && op->is_commutative()){
300 for (int other = 0; other < size - 1; other++) {
301 index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
302 if(index == MPI_UNDEFINED) {
306 if(recvbuf_is_empty){
307 Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
310 // #Request is below rank: it's a irecv
311 op->apply( tmpbufs[index], recvbuf, &count, datatype);
315 //non commutative case, wait in order
316 for (int other = 0; other < size - 1; other++) {
317 Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
319 if (recvbuf_is_empty) {
320 Datatype::copy(tmpbufs[other], count, datatype, recvbuf, count, datatype);
321 recvbuf_is_empty = 0;
324 op->apply( tmpbufs[other], recvbuf, &count, datatype);
328 for(index = 0; index < rank; index++) {
329 smpi_free_tmp_buffer(tmpbufs[index]);
331 for(index = 0; index < size-1; index++) {
332 Request::unref(&requests[index]);