1 /* smpi_coll.c -- various optimized routines for collectives */
3 /* Copyright (c) 2009-2019. The SimGrid Team. All rights reserved. */
5 /* This program is free software; you can redistribute it and/or modify it
6 * under the terms of the license (GNU LGPL) which comes with this package. */
8 #include "smpi_coll.hpp"
10 #include "smpi_comm.hpp"
11 #include "smpi_datatype.hpp"
12 #include "smpi_op.hpp"
13 #include "smpi_request.hpp"
14 #include "xbt/config.hpp"
16 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI collectives.");
/* COLL_SETTER(cat, ret, args, args2) expands to the definition of
 * colls::set_<cat>(name): it looks `name` up in the matching
 * mpi_coll_<cat>_description table, casts the stored generic function pointer
 * to the collective's real signature, and stores it in the global `cat`
 * function pointer; the xbt_die below aborts when the table held a null
 * pointer for the selected algorithm.
 * NOTE(review): the macro's opening brace, the null-pointer check guarding
 * xbt_die, and the closing brace are not visible in this view (numbering
 * gaps 19->21 and 22->24 indicate missing lines) -- do not edit blindly. */
18 #define COLL_SETTER(cat, ret, args, args2) \
19 void colls::_XBT_CONCAT(set_, cat)(const std::string& name) \
21 int id = find_coll_description(_XBT_CONCAT3(mpi_coll_, cat, _description), name, _XBT_STRINGIFY(cat)); \
22 cat = reinterpret_cast<ret(*) args>(_XBT_CONCAT3(mpi_coll_, cat, _description)[id].coll); \
24 xbt_die("Collective " _XBT_STRINGIFY(cat) " set to nullptr!"); \
30 /* these arrays must be nullptr terminated */
/* One registry table per collective operation. Each table is populated by the
 * corresponding COLL_* x-macro, producing one {name, description, function}
 * entry per registered algorithm, and is closed by an empty sentinel entry --
 * visible explicitly on the bcast line ({"", "", nullptr}).
 * NOTE(review): for every table except bcast the sentinel/closing line is not
 * visible in this view (odd-numbered gaps between declarations); assume the
 * full file terminates each array the same way. */
31 s_mpi_coll_description_t mpi_coll_gather_description[] = {COLL_GATHERS(COLL_DESCRIPTION, COLL_COMMA),
33 s_mpi_coll_description_t mpi_coll_allgather_description[] = {COLL_ALLGATHERS(COLL_DESCRIPTION, COLL_COMMA),
35 s_mpi_coll_description_t mpi_coll_allgatherv_description[] = {COLL_ALLGATHERVS(COLL_DESCRIPTION, COLL_COMMA),
37 s_mpi_coll_description_t mpi_coll_allreduce_description[] = {COLL_ALLREDUCES(COLL_DESCRIPTION, COLL_COMMA),
39 s_mpi_coll_description_t mpi_coll_reduce_scatter_description[] = {COLL_REDUCE_SCATTERS(COLL_DESCRIPTION, COLL_COMMA),
41 s_mpi_coll_description_t mpi_coll_scatter_description[] = {COLL_SCATTERS(COLL_DESCRIPTION, COLL_COMMA),
43 s_mpi_coll_description_t mpi_coll_barrier_description[] = {COLL_BARRIERS(COLL_DESCRIPTION, COLL_COMMA),
45 s_mpi_coll_description_t mpi_coll_alltoall_description[] = {COLL_ALLTOALLS(COLL_DESCRIPTION, COLL_COMMA),
47 s_mpi_coll_description_t mpi_coll_alltoallv_description[] = {COLL_ALLTOALLVS(COLL_DESCRIPTION, COLL_COMMA),
49 s_mpi_coll_description_t mpi_coll_bcast_description[] = {COLL_BCASTS(COLL_DESCRIPTION, COLL_COMMA), {"", "", nullptr}};
50 s_mpi_coll_description_t mpi_coll_reduce_description[] = {COLL_REDUCES(COLL_DESCRIPTION, COLL_COMMA),
53 // Needed by the automatic selector, whose implementation indexes these
53 // description tables directly.
/* Returns a pointer to the `rank`-th entry of the description table for the
 * collective named `name`. `rank` here is an index into the algorithm table,
 * NOT an MPI process rank, despite the parameter name. Unknown collective
 * names are reported via XBT_INFO.
 * NOTE(review): the function's opening brace and its final return path after
 * the XBT_INFO (numbering jumps 78->82) are not visible in this view --
 * presumably it returns nullptr for unknown names; confirm in the full file.
 * NOTE(review): no bounds check on `rank` is visible; an out-of-range index
 * would read past the table's sentinel. */
54 s_mpi_coll_description_t* colls::get_smpi_coll_description(const char* name, int rank)
56 if (strcmp(name, "gather") == 0)
57 return &mpi_coll_gather_description[rank];
58 if (strcmp(name, "allgather") == 0)
59 return &mpi_coll_allgather_description[rank];
60 if (strcmp(name, "allgatherv") == 0)
61 return &mpi_coll_allgatherv_description[rank];
62 if (strcmp(name, "allreduce") == 0)
63 return &mpi_coll_allreduce_description[rank];
64 if (strcmp(name, "reduce_scatter") == 0)
65 return &mpi_coll_reduce_scatter_description[rank];
66 if (strcmp(name, "scatter") == 0)
67 return &mpi_coll_scatter_description[rank];
68 if (strcmp(name, "barrier") == 0)
69 return &mpi_coll_barrier_description[rank];
70 if (strcmp(name, "alltoall") == 0)
71 return &mpi_coll_alltoall_description[rank];
72 if (strcmp(name, "alltoallv") == 0)
73 return &mpi_coll_alltoallv_description[rank];
74 if (strcmp(name, "bcast") == 0)
75 return &mpi_coll_bcast_description[rank];
76 if (strcmp(name, "reduce") == 0)
77 return &mpi_coll_reduce_description[rank];
78 XBT_INFO("You requested an unknown collective: %s", name);
82 /** Displays the long description of all algorithms registered in `table`
82  *  for the given collective `category`. The visible body only logs the
82  *  entries via XBT_WARN; it does not itself exit the program (the original
82  *  comment said "and quit" -- NOTE(review): no exit call is visible here;
82  *  confirm against the full file whether the caller quits afterwards). */
83 void colls::coll_help(const char* category, s_mpi_coll_description_t* table)
85 XBT_WARN("Long description of the %s models accepted by this simulator:\n", category);
86 for (int i = 0; not table[i].name.empty(); i++)
87 XBT_WARN(" %s: %s\n", table[i].name.c_str(), table[i].description.c_str());
/* Searches `table` (a sentinel-terminated description table) for the
 * algorithm called `name`. On a match it logs the switch -- unless the
 * algorithm is "default" -- and returns the entry's index; on failure it
 * builds the comma-separated list of valid algorithm names and aborts with
 * xbt_die. `desc` is the collective's name, used only in log messages.
 * NOTE(review): the `return i;` inside the match branch and the closing
 * braces are not visible in this view (numbering jump 95->99); the described
 * return behavior is inferred from the callers (COLL_SETTER / set_gather),
 * which use the result as a table index. */
90 int colls::find_coll_description(s_mpi_coll_description_t* table, const std::string& name, const char* desc)
92 for (int i = 0; not table[i].name.empty(); i++)
93 if (name == table[i].name) {
94 if (table[i].name != "default")
95 XBT_INFO("Switch to algorithm %s for collective %s",table[i].name.c_str(),desc);
99 if (table[0].name.empty())
100 xbt_die("No collective is valid for '%s'! This is a bug.", name.c_str());
101 std::string name_list = table[0].name;
102 for (int i = 1; not table[i].name.empty(); i++)
103 name_list = name_list + ", " + table[i].name;
105 xbt_die("Collective '%s' is invalid! Valid collectives are: %s.", name.c_str(), name_list.c_str());
/* Definitions of the global function pointers that hold the currently
 * selected implementation of each collective. They are assigned by the
 * colls::set_* functions below and invoked by the SMPI MPI_* entry points.
 * smpi_coll_cleanup_callback lets a selected algorithm register teardown work.
 * NOTE(review): the alltoallv and reduce_scatter declarations are truncated
 * mid-signature in this view (their trailing "MPI_Comm comm);" lines are
 * missing); the parameter lists shown are otherwise complete. */
109 int (*colls::gather)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff, int recv_count,
110 MPI_Datatype recv_type, int root, MPI_Comm comm);
111 int (*colls::allgather)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff, int recv_count,
112 MPI_Datatype recv_type, MPI_Comm comm);
113 int (*colls::allgatherv)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff,
114 const int* recv_count, const int* recv_disps, MPI_Datatype recv_type, MPI_Comm comm);
115 int (*colls::alltoall)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff, int recv_count,
116 MPI_Datatype recv_type, MPI_Comm comm);
117 int (*colls::alltoallv)(const void* send_buff, const int* send_counts, const int* send_disps, MPI_Datatype send_type,
118 void* recv_buff, const int* recv_counts, const int* recv_disps, MPI_Datatype recv_type,
120 int (*colls::bcast)(void* buf, int count, MPI_Datatype datatype, int root, MPI_Comm comm);
121 int (*colls::reduce)(const void* buf, void* rbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm);
122 int (*colls::allreduce)(const void* sbuf, void* rbuf, int rcount, MPI_Datatype dtype, MPI_Op op, MPI_Comm comm);
123 int (*colls::reduce_scatter)(const void* sbuf, void* rbuf, const int* rcounts, MPI_Datatype dtype, MPI_Op op,
125 int (*colls::scatter)(const void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, int recvcount,
126 MPI_Datatype recvtype, int root, MPI_Comm comm);
127 int (*colls::barrier)(MPI_Comm comm);
129 void (*colls::smpi_coll_cleanup_callback)();
/* Hand-written equivalent of what COLL_SETTER would generate for "gather"
 * (the COLL_APPLY line for gather is commented out below): resolves `name`
 * in the gather description table, casts the stored generic pointer to the
 * gather signature, and aborts if the selected implementation is null.
 * NOTE(review): the opening/closing braces are not visible in this view. */
131 void colls::set_gather(const std::string& name)
133 int id = find_coll_description(mpi_coll_gather_description, name, "gather");
134 gather = reinterpret_cast<int(*)(const void *send_buff, int send_count, MPI_Datatype send_type,
135 void *recv_buff, int recv_count, MPI_Datatype recv_type,
136 int root, MPI_Comm comm)>(mpi_coll_gather_description[id].coll);
137 if (gather == nullptr)
138 xbt_die("Collective gather set to nullptr!");
/* Generate the colls::set_<coll>() functions for the remaining collectives
 * via COLL_SETTER. The gather setter is written out by hand above, hence its
 * COLL_APPLY line stays commented out. */
141 //COLL_APPLY(COLL_SETTER,COLL_GATHER_SIG,"");
142 COLL_APPLY(COLL_SETTER,COLL_ALLGATHER_SIG,"");
143 COLL_APPLY(COLL_SETTER,COLL_ALLGATHERV_SIG,"");
144 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SIG,"");
145 COLL_APPLY(COLL_SETTER,COLL_ALLREDUCE_SIG,"");
146 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SCATTER_SIG,"");
147 COLL_APPLY(COLL_SETTER,COLL_SCATTER_SIG,"");
148 COLL_APPLY(COLL_SETTER,COLL_BARRIER_SIG,"");
149 COLL_APPLY(COLL_SETTER,COLL_BCAST_SIG,"");
150 COLL_APPLY(COLL_SETTER,COLL_ALLTOALL_SIG,"");
151 COLL_APPLY(COLL_SETTER,COLL_ALLTOALLV_SIG,"");
/* Selects the implementation of every collective at startup. The global
 * selector comes from the "smpi/coll-selector" config option (empty means
 * "default"); each individual collective can then be overridden through its
 * own "smpi/<collective>" option, read per entry of setter_callbacks.
 * NOTE(review): the end of the loop is not visible in this view -- the guard
 * that falls back to selector_name when the per-collective option is empty,
 * the invocation of elem.second(name), and the closing braces are missing
 * (numbering jump 168->170); the fallback assignment on the last visible
 * line implies such a guard exists. */
153 void colls::set_collectives()
155 std::string selector_name = simgrid::config::get_value<std::string>("smpi/coll-selector");
156 if (selector_name.empty())
157 selector_name = "default";
159 std::pair<std::string, std::function<void(std::string)>> setter_callbacks[] = {
160 {"gather", &colls::set_gather}, {"allgather", &colls::set_allgather},
161 {"allgatherv", &colls::set_allgatherv}, {"allreduce", &colls::set_allreduce},
162 {"alltoall", &colls::set_alltoall}, {"alltoallv", &colls::set_alltoallv},
163 {"reduce", &colls::set_reduce}, {"reduce_scatter", &colls::set_reduce_scatter},
164 {"scatter", &colls::set_scatter}, {"bcast", &colls::set_bcast},
165 {"barrier", &colls::set_barrier}};
167 for (auto& elem : setter_callbacks) {
168 std::string name = simgrid::config::get_value<std::string>(("smpi/" + elem.first).c_str());
170 name = selector_name;
176 //Implementations of the single algorithm collectives
/* Blocking gatherv, implemented as the non-blocking igatherv followed by a
 * wait on the resulting request. Returns the MPI error code from the wait.
 * NOTE(review): the declaration of `request` (an MPI_Request) and the braces
 * are not visible in this view (numbering jump 179->182). The trailing 0 to
 * igatherv is presumably a tag or flag argument -- confirm against
 * colls::igatherv's signature. */
178 int colls::gatherv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, const int* recvcounts,
179 const int* displs, MPI_Datatype recvtype, int root, MPI_Comm comm)
182 colls::igatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, &request, 0);
183 return Request::wait(&request, MPI_STATUS_IGNORE);
/* Blocking scatterv, implemented as the non-blocking iscatterv followed by a
 * wait on the resulting request (mirrors colls::gatherv above).
 * NOTE(review): the declaration of `request` and the braces are not visible
 * in this view (numbering jump 187->190). */
186 int colls::scatterv(const void* sendbuf, const int* sendcounts, const int* displs, MPI_Datatype sendtype, void* recvbuf,
187 int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
190 colls::iscatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm, &request, 0);
191 return Request::wait(&request, MPI_STATUS_IGNORE);
/* MPI_Scan: inclusive prefix reduction. Each rank starts from its own
 * contribution (copied into recvbuf), posts receives from every lower rank
 * into temporary buffers and sends its sendbuf to every higher rank, then
 * folds the received contributions into recvbuf with `op`. Commutative ops
 * may be applied in completion order (waitany); non-commutative ops are
 * applied strictly in rank order.
 * NOTE(review): this view is heavily truncated -- the declarations of `lb`
 * and `index`, the index increments in the init loops, several closing
 * braces, the delete[] of `requests`/`tmpbufs`, and the return statement are
 * all missing (see numbering gaps, e.g. 214->217, 248->255). Comments below
 * describe only what the visible lines establish.
 * NOTE(review): in the non-commutative branch the guard tests `index` while
 * the apply uses tmpbufs[other]; `other < rank` looks like the intended
 * condition -- verify against the full file before changing anything. */
194 int colls::scan(const void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
196 int system_tag = -888;
198 MPI_Aint dataext = 0;
200 int rank = comm->rank();
201 int size = comm->size();
203 datatype->extent(&lb, &dataext);
205 // Local copy from self
206 Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);
208 // Send/Recv buffers to/from others
209 MPI_Request* requests = new MPI_Request[size - 1];
210 unsigned char** tmpbufs = new unsigned char*[rank];
// One persistent irecv per lower rank (contributions we must fold in) ...
212 for (int other = 0; other < rank; other++) {
213 tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
214 requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
// ... and one persistent isend per higher rank (they fold in our sendbuf).
217 for (int other = rank + 1; other < size; other++) {
218 requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
221 // Wait for completion of all comms.
222 Request::startall(size - 1, requests);
224 if(op != MPI_OP_NULL && op->is_commutative()){
225 for (int other = 0; other < size - 1; other++) {
226 index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
227 if(index == MPI_UNDEFINED) {
231 // #Request is below rank: it's a irecv
232 op->apply( tmpbufs[index], recvbuf, &count, datatype);
236 //non commutative case, wait in order
237 for (int other = 0; other < size - 1; other++) {
238 Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
239 if(index < rank && op!=MPI_OP_NULL) {
240 op->apply( tmpbufs[other], recvbuf, &count, datatype);
// Release the temporary receive buffers and the persistent requests.
244 for(index = 0; index < rank; index++) {
245 smpi_free_tmp_buffer(tmpbufs[index]);
247 for(index = 0; index < size-1; index++) {
248 Request::unref(&requests[index]);
/* MPI_Exscan: exclusive prefix reduction. Same communication pattern as
 * colls::scan above (receive from every lower rank, send to every higher
 * rank), but the local contribution is NOT included: recvbuf starts empty
 * (recvbuf_is_empty) and the first received contribution is copied into it
 * instead of reduced, so rank 0's recvbuf is left untouched.
 * NOTE(review): this view is heavily truncated like scan -- declarations of
 * `lb` and `index`, loop-index increments, closing braces, the else around
 * the non-commutative copy/apply choice, buffer/request teardown and the
 * return statement are missing (numbering gaps, e.g. 272->275, 304->307).
 * NOTE(review): in the commutative branch, recvbuf_is_empty is presumably
 * cleared after the first Datatype::copy (the matching assignment is only
 * visible in the non-commutative branch) -- confirm in the full file. */
255 int colls::exscan(const void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
257 int system_tag = -888;
259 MPI_Aint dataext = 0;
260 int recvbuf_is_empty=1;
261 int rank = comm->rank();
262 int size = comm->size();
264 datatype->extent(&lb, &dataext);
266 // Send/Recv buffers to/from others
267 MPI_Request* requests = new MPI_Request[size - 1];
268 unsigned char** tmpbufs = new unsigned char*[rank];
// One persistent irecv per lower rank, one persistent isend per higher rank.
270 for (int other = 0; other < rank; other++) {
271 tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
272 requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
275 for (int other = rank + 1; other < size; other++) {
276 requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
279 // Wait for completion of all comms.
280 Request::startall(size - 1, requests);
282 if(op != MPI_OP_NULL && op->is_commutative()){
283 for (int other = 0; other < size - 1; other++) {
284 index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
285 if(index == MPI_UNDEFINED) {
// First contribution initializes recvbuf; later ones are reduced into it.
289 if(recvbuf_is_empty){
290 Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
293 // #Request is below rank: it's a irecv
294 op->apply( tmpbufs[index], recvbuf, &count, datatype);
298 //non commutative case, wait in order
299 for (int other = 0; other < size - 1; other++) {
300 Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
302 if (recvbuf_is_empty) {
303 Datatype::copy(tmpbufs[other], count, datatype, recvbuf, count, datatype);
304 recvbuf_is_empty = 0;
307 op->apply( tmpbufs[other], recvbuf, &count, datatype);
// Release the temporary receive buffers and the persistent requests.
311 for(index = 0; index < rank; index++) {
312 smpi_free_tmp_buffer(tmpbufs[index]);
314 for(index = 0; index < size-1; index++) {
315 Request::unref(&requests[index]);
/* Blocking alltoallw, implemented as the non-blocking ialltoallw followed by
 * a wait on the resulting request (same pattern as gatherv/scatterv above).
 * NOTE(review): the declaration of `request`, part of the ialltoallw call's
 * argument list, and the closing brace lie outside this view of the file. */
322 int colls::alltoallw(const void* sendbuf, const int* sendcounts, const int* senddisps, const MPI_Datatype* sendtypes,
323 void* recvbuf, const int* recvcounts, const int* recvdisps, const MPI_Datatype* recvtypes,
327 colls::ialltoallw(sendbuf, sendcounts, senddisps, sendtypes, recvbuf, recvcounts, recvdisps, recvtypes, comm,
329 return Request::wait(&request, MPI_STATUS_IGNORE);