1 /* smpi_coll.c -- various optimized routines for collectives */
3 /* Copyright (c) 2009-2017. The SimGrid Team. All rights reserved. */
5 /* This program is free software; you can redistribute it and/or modify it
6 * under the terms of the license (GNU LGPL) which comes with this package. */
8 #include "smpi_coll.hpp"
10 #include "smpi_comm.hpp"
11 #include "smpi_datatype.hpp"
12 #include "smpi_op.hpp"
13 #include "smpi_request.hpp"
15 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI (coll)");
17 #define COLL_SETTER(cat, ret, args, args2) \
18 int(*Colls::cat) args; \
19 void Colls::set_##cat(std::string name) \
21 int id = find_coll_description(mpi_coll_##cat##_description, name, #cat); \
22 cat = reinterpret_cast<ret(*) args>(mpi_coll_##cat##_description[id].coll); \
24 xbt_die("Collective " #cat " set to nullptr!"); \
/* SET_COLL(coll) configures one collective from inside Colls::set_collectives(): the per-collective option
 * "smpi/<coll>" is read first and, when it is empty, the global selector name is used instead; the result is handed
 * to Colls::set_<coll>(). Expects `name` (std::string) and `selector_name` to be in scope at the expansion site.
 * NOTE(review): the emptiness check and the set_<coll>() call were missing from this copy and have been restored. */
#define SET_COLL(coll) \
  name = xbt_cfg_get_string("smpi/" #coll); \
  if (name.empty()) \
    name = selector_name; \
  set_##coll(name);
36 void (*Colls::smpi_coll_cleanup_callback)();
38 /* these arrays must be nullptr terminated */
39 s_mpi_coll_description_t Colls::mpi_coll_gather_description[] = {
40 COLL_GATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
41 s_mpi_coll_description_t Colls::mpi_coll_allgather_description[] = {
42 COLL_ALLGATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
43 s_mpi_coll_description_t Colls::mpi_coll_allgatherv_description[] = {
44 COLL_ALLGATHERVS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
45 s_mpi_coll_description_t Colls::mpi_coll_allreduce_description[] ={
46 COLL_ALLREDUCES(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
47 s_mpi_coll_description_t Colls::mpi_coll_reduce_scatter_description[] = {
48 COLL_REDUCE_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
49 s_mpi_coll_description_t Colls::mpi_coll_scatter_description[] ={
50 COLL_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
51 s_mpi_coll_description_t Colls::mpi_coll_barrier_description[] ={
52 COLL_BARRIERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
53 s_mpi_coll_description_t Colls::mpi_coll_alltoall_description[] = {
54 COLL_ALLTOALLS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
55 s_mpi_coll_description_t Colls::mpi_coll_alltoallv_description[] = {
56 COLL_ALLTOALLVS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
57 s_mpi_coll_description_t Colls::mpi_coll_bcast_description[] = {
58 COLL_BCASTS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
59 s_mpi_coll_description_t Colls::mpi_coll_reduce_description[] = {
60 COLL_REDUCES(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
62 /** Displays the long description of all registered models, and quit */
63 void Colls::coll_help(const char *category, s_mpi_coll_description_t * table)
65 XBT_WARN("Long description of the %s models accepted by this simulator:\n", category);
66 for (int i = 0; table[i].name; i++)
67 XBT_WARN(" %s: %s\n", table[i].name, table[i].description);
70 int Colls::find_coll_description(s_mpi_coll_description_t* table, std::string name, const char* desc)
72 for (int i = 0; table[i].name; i++)
73 if (name == table[i].name) {
74 if (strcmp(table[i].name,"default"))
75 XBT_INFO("Switch to algorithm %s for collective %s",table[i].name,desc);
79 if (not table[0].name)
80 xbt_die("No collective is valid for '%s'! This is a bug.", name.c_str());
81 std::string name_list = std::string(table[0].name);
82 for (int i = 1; table[i].name; i++)
83 name_list = name_list + ", " + table[i].name;
85 xbt_die("Collective '%s' is invalid! Valid collectives are: %s.", name.c_str(), name_list.c_str());
89 COLL_APPLY(COLL_SETTER,COLL_GATHER_SIG,"");
90 COLL_APPLY(COLL_SETTER,COLL_ALLGATHER_SIG,"");
91 COLL_APPLY(COLL_SETTER,COLL_ALLGATHERV_SIG,"");
92 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SIG,"");
93 COLL_APPLY(COLL_SETTER,COLL_ALLREDUCE_SIG,"");
94 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SCATTER_SIG,"");
95 COLL_APPLY(COLL_SETTER,COLL_SCATTER_SIG,"");
96 COLL_APPLY(COLL_SETTER,COLL_BARRIER_SIG,"");
97 COLL_APPLY(COLL_SETTER,COLL_BCAST_SIG,"");
98 COLL_APPLY(COLL_SETTER,COLL_ALLTOALL_SIG,"");
99 COLL_APPLY(COLL_SETTER,COLL_ALLTOALLV_SIG,"");
/* Applies the configured collective algorithms. The global "smpi/coll-selector" option (or "default" when it is
 * empty) provides the fallback name that SET_COLL uses for each collective.
 * NOTE(review): this definition looks incomplete in this copy. Judging by the gaps in the original line numbering,
 * the declaration of `name` that SET_COLL assigns, most SET_COLL(...) invocations and the closing brace are missing;
 * only allgatherv and reduce_scatter remain. Restore from version control — TODO confirm. */
102 void Colls::set_collectives(){
103 std::string selector_name = xbt_cfg_get_string("smpi/coll-selector");
104 if (selector_name.empty())
105 selector_name = "default";
111 SET_COLL(allgatherv);
116 SET_COLL(reduce_scatter);
123 // Implementations of the collectives that have a single algorithm
125 int Colls::gatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *displs,
126 MPI_Datatype recvtype, int root, MPI_Comm comm)
128 int system_tag = COLL_TAG_GATHERV;
130 MPI_Aint recvext = 0;
132 int rank = comm->rank();
133 int size = comm->size();
135 // Send buffer to root
136 Request::send(sendbuf, sendcount, sendtype, root, system_tag, comm);
138 recvtype->extent(&lb, &recvext);
139 // Local copy from root
140 Datatype::copy(sendbuf, sendcount, sendtype, static_cast<char*>(recvbuf) + displs[root] * recvext,
141 recvcounts[root], recvtype);
142 // Receive buffers from senders
143 MPI_Request *requests = xbt_new(MPI_Request, size - 1);
145 for (int src = 0; src < size; src++) {
147 requests[index] = Request::irecv_init(static_cast<char*>(recvbuf) + displs[src] * recvext,
148 recvcounts[src], recvtype, src, system_tag, comm);
152 // Wait for completion of irecv's.
153 Request::startall(size - 1, requests);
154 Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
155 for (int src = 0; src < size-1; src++) {
156 Request::unref(&requests[src]);
164 int Colls::scatterv(void *sendbuf, int *sendcounts, int *displs, MPI_Datatype sendtype, void *recvbuf, int recvcount,
165 MPI_Datatype recvtype, int root, MPI_Comm comm)
167 int system_tag = COLL_TAG_SCATTERV;
169 MPI_Aint sendext = 0;
171 int rank = comm->rank();
172 int size = comm->size();
174 // Recv buffer from root
175 Request::recv(recvbuf, recvcount, recvtype, root, system_tag, comm, MPI_STATUS_IGNORE);
177 sendtype->extent(&lb, &sendext);
178 // Local copy from root
179 if(recvbuf!=MPI_IN_PLACE){
180 Datatype::copy(static_cast<char *>(sendbuf) + displs[root] * sendext, sendcounts[root],
181 sendtype, recvbuf, recvcount, recvtype);
183 // Send buffers to receivers
184 MPI_Request *requests = xbt_new(MPI_Request, size - 1);
186 for (int dst = 0; dst < size; dst++) {
188 requests[index] = Request::isend_init(static_cast<char *>(sendbuf) + displs[dst] * sendext, sendcounts[dst],
189 sendtype, dst, system_tag, comm);
193 // Wait for completion of isend's.
194 Request::startall(size - 1, requests);
195 Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
196 for (int dst = 0; dst < size-1; dst++) {
197 Request::unref(&requests[dst]);
205 int Colls::scan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
207 int system_tag = -888;
209 MPI_Aint dataext = 0;
211 int rank = comm->rank();
212 int size = comm->size();
214 datatype->extent(&lb, &dataext);
216 // Local copy from self
217 Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);
219 // Send/Recv buffers to/from others
220 MPI_Request *requests = xbt_new(MPI_Request, size - 1);
221 void **tmpbufs = xbt_new(void *, rank);
223 for (int other = 0; other < rank; other++) {
224 tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
225 requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
228 for (int other = rank + 1; other < size; other++) {
229 requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
232 // Wait for completion of all comms.
233 Request::startall(size - 1, requests);
235 if(op != MPI_OP_NULL && op->is_commutative()){
236 for (int other = 0; other < size - 1; other++) {
237 index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
238 if(index == MPI_UNDEFINED) {
242 // #Request is below rank: it's a irecv
243 op->apply( tmpbufs[index], recvbuf, &count, datatype);
247 //non commutative case, wait in order
248 for (int other = 0; other < size - 1; other++) {
249 Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
250 if(index < rank && op!=MPI_OP_NULL) {
251 op->apply( tmpbufs[other], recvbuf, &count, datatype);
255 for(index = 0; index < rank; index++) {
256 smpi_free_tmp_buffer(tmpbufs[index]);
258 for(index = 0; index < size-1; index++) {
259 Request::unref(&requests[index]);
266 int Colls::exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
268 int system_tag = -888;
270 MPI_Aint dataext = 0;
271 int recvbuf_is_empty=1;
272 int rank = comm->rank();
273 int size = comm->size();
275 datatype->extent(&lb, &dataext);
277 // Send/Recv buffers to/from others
278 MPI_Request *requests = xbt_new(MPI_Request, size - 1);
279 void **tmpbufs = xbt_new(void *, rank);
281 for (int other = 0; other < rank; other++) {
282 tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
283 requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
286 for (int other = rank + 1; other < size; other++) {
287 requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
290 // Wait for completion of all comms.
291 Request::startall(size - 1, requests);
293 if(op != MPI_OP_NULL && op->is_commutative()){
294 for (int other = 0; other < size - 1; other++) {
295 index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
296 if(index == MPI_UNDEFINED) {
300 if(recvbuf_is_empty){
301 Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
304 // #Request is below rank: it's a irecv
305 op->apply( tmpbufs[index], recvbuf, &count, datatype);
309 //non commutative case, wait in order
310 for (int other = 0; other < size - 1; other++) {
311 Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
313 if (recvbuf_is_empty) {
314 Datatype::copy(tmpbufs[other], count, datatype, recvbuf, count, datatype);
315 recvbuf_is_empty = 0;
318 op->apply( tmpbufs[other], recvbuf, &count, datatype);
322 for(index = 0; index < rank; index++) {
323 smpi_free_tmp_buffer(tmpbufs[index]);
325 for(index = 0; index < size-1; index++) {
326 Request::unref(&requests[index]);