/* smpi_coll.c -- various optimized routines for collectives */
3 /* Copyright (c) 2009-2019. The SimGrid Team. All rights reserved. */
5 /* This program is free software; you can redistribute it and/or modify it
6 * under the terms of the license (GNU LGPL) which comes with this package. */
8 #include "smpi_coll.hpp"
10 #include "smpi_comm.hpp"
11 #include "smpi_datatype.hpp"
12 #include "smpi_op.hpp"
13 #include "smpi_request.hpp"
14 #include "xbt/config.hpp"
16 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI (coll)");
18 #define COLL_SETTER(cat, ret, args, args2) \
19 int(*Colls::cat) args; \
20 void Colls::set_##cat(const std::string& name) \
22 int id = find_coll_description(mpi_coll_##cat##_description, name, #cat); \
23 cat = reinterpret_cast<ret(*) args>(mpi_coll_##cat##_description[id].coll); \
25 xbt_die("Collective " #cat " set to nullptr!"); \
31 void (*Colls::smpi_coll_cleanup_callback)();
33 /* these arrays must be nullptr terminated */
34 s_mpi_coll_description_t Colls::mpi_coll_gather_description[] = {
35 COLL_GATHERS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
36 s_mpi_coll_description_t Colls::mpi_coll_allgather_description[] = {
37 COLL_ALLGATHERS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
38 s_mpi_coll_description_t Colls::mpi_coll_allgatherv_description[] = {
39 COLL_ALLGATHERVS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
40 s_mpi_coll_description_t Colls::mpi_coll_allreduce_description[] ={
41 COLL_ALLREDUCES(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
42 s_mpi_coll_description_t Colls::mpi_coll_reduce_scatter_description[] = {
43 COLL_REDUCE_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
44 s_mpi_coll_description_t Colls::mpi_coll_scatter_description[] ={
45 COLL_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
46 s_mpi_coll_description_t Colls::mpi_coll_barrier_description[] ={
47 COLL_BARRIERS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
48 s_mpi_coll_description_t Colls::mpi_coll_alltoall_description[] = {
49 COLL_ALLTOALLS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
50 s_mpi_coll_description_t Colls::mpi_coll_alltoallv_description[] = {
51 COLL_ALLTOALLVS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
52 s_mpi_coll_description_t Colls::mpi_coll_bcast_description[] = {
53 COLL_BCASTS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
54 s_mpi_coll_description_t Colls::mpi_coll_reduce_description[] = {
55 COLL_REDUCES(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr} };
57 /** Displays the long description of all registered models, and quit */
58 void Colls::coll_help(const char *category, s_mpi_coll_description_t * table)
60 XBT_WARN("Long description of the %s models accepted by this simulator:\n", category);
61 for (int i = 0; not table[i].name.empty(); i++)
62 XBT_WARN(" %s: %s\n", table[i].name.c_str(), table[i].description.c_str());
65 int Colls::find_coll_description(s_mpi_coll_description_t* table, const std::string& name, const char* desc)
67 for (int i = 0; not table[i].name.empty(); i++)
68 if (name == table[i].name) {
69 if (table[i].name != "default")
70 XBT_INFO("Switch to algorithm %s for collective %s",table[i].name.c_str(),desc);
74 if (table[0].name.empty())
75 xbt_die("No collective is valid for '%s'! This is a bug.", name.c_str());
76 std::string name_list = table[0].name;
77 for (int i = 1; not table[i].name.empty(); i++)
78 name_list = name_list + ", " + table[i].name;
80 xbt_die("Collective '%s' is invalid! Valid collectives are: %s.", name.c_str(), name_list.c_str());
84 COLL_APPLY(COLL_SETTER,COLL_GATHER_SIG,"");
85 COLL_APPLY(COLL_SETTER,COLL_ALLGATHER_SIG,"");
86 COLL_APPLY(COLL_SETTER,COLL_ALLGATHERV_SIG,"");
87 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SIG,"");
88 COLL_APPLY(COLL_SETTER,COLL_ALLREDUCE_SIG,"");
89 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SCATTER_SIG,"");
90 COLL_APPLY(COLL_SETTER,COLL_SCATTER_SIG,"");
91 COLL_APPLY(COLL_SETTER,COLL_BARRIER_SIG,"");
92 COLL_APPLY(COLL_SETTER,COLL_BCAST_SIG,"");
93 COLL_APPLY(COLL_SETTER,COLL_ALLTOALL_SIG,"");
94 COLL_APPLY(COLL_SETTER,COLL_ALLTOALLV_SIG,"");
96 void Colls::set_collectives(){
97 std::string selector_name = simgrid::config::get_value<std::string>("smpi/coll-selector");
98 if (selector_name.empty())
99 selector_name = "default";
101 std::pair<std::string, std::function<void(std::string)>> setter_callbacks[] = {
102 {"gather", &Colls::set_gather}, {"allgather", &Colls::set_allgather},
103 {"allgatherv", &Colls::set_allgatherv}, {"allreduce", &Colls::set_allreduce},
104 {"alltoall", &Colls::set_alltoall}, {"alltoallv", &Colls::set_alltoallv},
105 {"reduce", &Colls::set_reduce}, {"reduce_scatter", &Colls::set_reduce_scatter},
106 {"scatter", &Colls::set_scatter}, {"bcast", &Colls::set_bcast},
107 {"barrier", &Colls::set_barrier}};
109 for (auto& elem : setter_callbacks) {
110 std::string name = simgrid::config::get_value<std::string>(("smpi/" + elem.first).c_str());
112 name = selector_name;
// Implementations of the collectives that have a single algorithm (no runtime selection)
121 int Colls::gatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *displs,
122 MPI_Datatype recvtype, int root, MPI_Comm comm)
125 Colls::igatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, &request);
126 MPI_Request* requests = request->get_nbc_requests();
127 int count = request->get_nbc_requests_size();
128 Request::waitall(count, requests, MPI_STATUS_IGNORE);
129 for (int i = 0; i < count; i++) {
130 if(requests[i]!=MPI_REQUEST_NULL)
131 Request::unref(&requests[i]);
134 Request::unref(&request);
139 int Colls::scatterv(void *sendbuf, int *sendcounts, int *displs, MPI_Datatype sendtype, void *recvbuf, int recvcount,
140 MPI_Datatype recvtype, int root, MPI_Comm comm)
143 Colls::iscatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm, &request);
144 MPI_Request* requests = request->get_nbc_requests();
145 int count = request->get_nbc_requests_size();
146 Request::waitall(count, requests, MPI_STATUS_IGNORE);
147 for (int dst = 0; dst < count; dst++) {
148 if(requests[dst]!=MPI_REQUEST_NULL)
149 Request::unref(&requests[dst]);
152 Request::unref(&request);
157 int Colls::scan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
159 int system_tag = -888;
161 MPI_Aint dataext = 0;
163 int rank = comm->rank();
164 int size = comm->size();
166 datatype->extent(&lb, &dataext);
168 // Local copy from self
169 Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);
171 // Send/Recv buffers to/from others
172 MPI_Request *requests = xbt_new(MPI_Request, size - 1);
173 void **tmpbufs = xbt_new(void *, rank);
175 for (int other = 0; other < rank; other++) {
176 tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
177 requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
180 for (int other = rank + 1; other < size; other++) {
181 requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
184 // Wait for completion of all comms.
185 Request::startall(size - 1, requests);
187 if(op != MPI_OP_NULL && op->is_commutative()){
188 for (int other = 0; other < size - 1; other++) {
189 index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
190 if(index == MPI_UNDEFINED) {
194 // #Request is below rank: it's a irecv
195 op->apply( tmpbufs[index], recvbuf, &count, datatype);
199 //non commutative case, wait in order
200 for (int other = 0; other < size - 1; other++) {
201 Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
202 if(index < rank && op!=MPI_OP_NULL) {
203 op->apply( tmpbufs[other], recvbuf, &count, datatype);
207 for(index = 0; index < rank; index++) {
208 smpi_free_tmp_buffer(tmpbufs[index]);
210 for(index = 0; index < size-1; index++) {
211 Request::unref(&requests[index]);
218 int Colls::exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
220 int system_tag = -888;
222 MPI_Aint dataext = 0;
223 int recvbuf_is_empty=1;
224 int rank = comm->rank();
225 int size = comm->size();
227 datatype->extent(&lb, &dataext);
229 // Send/Recv buffers to/from others
230 MPI_Request *requests = xbt_new(MPI_Request, size - 1);
231 void **tmpbufs = xbt_new(void *, rank);
233 for (int other = 0; other < rank; other++) {
234 tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
235 requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
238 for (int other = rank + 1; other < size; other++) {
239 requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
242 // Wait for completion of all comms.
243 Request::startall(size - 1, requests);
245 if(op != MPI_OP_NULL && op->is_commutative()){
246 for (int other = 0; other < size - 1; other++) {
247 index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
248 if(index == MPI_UNDEFINED) {
252 if(recvbuf_is_empty){
253 Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
256 // #Request is below rank: it's a irecv
257 op->apply( tmpbufs[index], recvbuf, &count, datatype);
261 //non commutative case, wait in order
262 for (int other = 0; other < size - 1; other++) {
263 Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
265 if (recvbuf_is_empty) {
266 Datatype::copy(tmpbufs[other], count, datatype, recvbuf, count, datatype);
267 recvbuf_is_empty = 0;
270 op->apply( tmpbufs[other], recvbuf, &count, datatype);
274 for(index = 0; index < rank; index++) {
275 smpi_free_tmp_buffer(tmpbufs[index]);
277 for(index = 0; index < size-1; index++) {
278 Request::unref(&requests[index]);
285 int Colls::alltoallw(void *sendbuf, int *sendcounts, int *senddisps, MPI_Datatype* sendtypes,
286 void *recvbuf, int *recvcounts, int *recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm)
289 int err = Colls::ialltoallw(sendbuf, sendcounts, senddisps, sendtypes, recvbuf, recvcounts, recvdisps, recvtypes, comm, &request);
290 MPI_Request* requests = request->get_nbc_requests();
291 int count = request->get_nbc_requests_size();
292 XBT_DEBUG("<%d> wait for %d requests", comm->rank(), count);
293 Request::waitall(count, requests, MPI_STATUS_IGNORE);
294 for (int i = 0; i < count; i++) {
295 if(requests[i]!=MPI_REQUEST_NULL)
296 Request::unref(&requests[i]);
299 Request::unref(&request);