Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
ok, I stop trying to please sonar.
[simgrid.git] / src / smpi / smpi_coll.cpp
1 /* smpi_coll.c -- various optimized routing for collectives                   */
2
3 /* Copyright (c) 2009-2015. The SimGrid Team.
4  * All rights reserved.                                                     */
5
6 /* This program is free software; you can redistribute it and/or modify it
7  * under the terms of the license (GNU LGPL) which comes with this package. */
8
9 #include <stdio.h>
10 #include <string.h>
11 #include <assert.h>
12
13 #include "private.h"
14 #include "simgrid/sg_config.h"
15
16 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI (coll)");
17
18 #define COLL_SETTER(cat, ret, args, args2)\
19 int (*Colls::cat ) args;\
20 void Colls::set_##cat (const char * name){\
21     int id = find_coll_description(mpi_coll_## cat ##_description,\
22                                              name,#cat);\
23     cat = reinterpret_cast<ret (*) args>\
24         (mpi_coll_## cat ##_description[id].coll);\
25     if (cat == nullptr)\
26       xbt_die("Collective "#cat" set to nullptr!");\
27 }
28
29 #define SET_COLL(coll)\
30     name = xbt_cfg_get_string("smpi/"#coll);\
31     if (name==nullptr || name[0] == '\0')\
32         name = selector_name;\
33     set_##coll(name);
34
35 namespace simgrid{
36 namespace smpi{
37
38 void (*Colls::smpi_coll_cleanup_callback)();
39
40 /* these arrays must be nullptr terminated */
41 s_mpi_coll_description_t Colls::mpi_coll_gather_description[] = {
42     COLL_GATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
43 s_mpi_coll_description_t Colls::mpi_coll_allgather_description[] = {
44     COLL_ALLGATHERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
45 s_mpi_coll_description_t Colls::mpi_coll_allgatherv_description[] = {
46     COLL_ALLGATHERVS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
47 s_mpi_coll_description_t Colls::mpi_coll_allreduce_description[] ={
48     COLL_ALLREDUCES(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
49 s_mpi_coll_description_t Colls::mpi_coll_reduce_scatter_description[] = {
50     COLL_REDUCE_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
51 s_mpi_coll_description_t Colls::mpi_coll_scatter_description[] ={
52     COLL_SCATTERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
53 s_mpi_coll_description_t Colls::mpi_coll_barrier_description[] ={
54     COLL_BARRIERS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
55 s_mpi_coll_description_t Colls::mpi_coll_alltoall_description[] = {
56     COLL_ALLTOALLS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
57 s_mpi_coll_description_t Colls::mpi_coll_alltoallv_description[] = {
58     COLL_ALLTOALLVS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
59 s_mpi_coll_description_t Colls::mpi_coll_bcast_description[] = {
60     COLL_BCASTS(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
61 s_mpi_coll_description_t Colls::mpi_coll_reduce_description[] = {
62     COLL_REDUCES(COLL_DESCRIPTION, COLL_COMMA), {nullptr, nullptr, nullptr} };
63
64 /** Displays the long description of all registered models, and quit */
65 void Colls::coll_help(const char *category, s_mpi_coll_description_t * table)
66 {
67   XBT_WARN("Long description of the %s models accepted by this simulator:\n", category);
68   for (int i = 0; table[i].name; i++)
69     XBT_WARN("  %s: %s\n", table[i].name, table[i].description);
70 }
71
72 int Colls::find_coll_description(s_mpi_coll_description_t * table, const char *name, const char *desc)
73 {
74   char *name_list = nullptr;
75   for (int i = 0; table[i].name; i++)
76     if (!strcmp(name, table[i].name)) {
77       if (strcmp(table[i].name,"default"))
78         XBT_INFO("Switch to algorithm %s for collective %s",table[i].name,desc);
79       return i;
80     }
81
82   if (!table[0].name)
83     xbt_die("No collective is valid for '%s'! This is a bug.",name);
84   name_list = xbt_strdup(table[0].name);
85   for (int i = 1; table[i].name; i++) {
86     name_list = static_cast<char*>(xbt_realloc(name_list, strlen(name_list) + strlen(table[i].name) + 3));
87     strncat(name_list, ", ",2);
88     strncat(name_list, table[i].name, strlen(table[i].name));
89   }
90   xbt_die("Collective '%s' is invalid! Valid collectives are: %s.", name, name_list);
91   return -1;
92 }
93
94
95
96 COLL_APPLY(COLL_SETTER,COLL_GATHER_SIG,"");
97 COLL_APPLY(COLL_SETTER,COLL_ALLGATHER_SIG,"");
98 COLL_APPLY(COLL_SETTER,COLL_ALLGATHERV_SIG,"");
99 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SIG,"");
100 COLL_APPLY(COLL_SETTER,COLL_ALLREDUCE_SIG,"");
101 COLL_APPLY(COLL_SETTER,COLL_REDUCE_SCATTER_SIG,"");
102 COLL_APPLY(COLL_SETTER,COLL_SCATTER_SIG,"");
103 COLL_APPLY(COLL_SETTER,COLL_BARRIER_SIG,"");
104 COLL_APPLY(COLL_SETTER,COLL_BCAST_SIG,"");
105 COLL_APPLY(COLL_SETTER,COLL_ALLTOALL_SIG,"");
106 COLL_APPLY(COLL_SETTER,COLL_ALLTOALLV_SIG,"");
107
108
109 void Colls::set_collectives(){
110     const char* selector_name = static_cast<char*>(xbt_cfg_get_string("smpi/coll-selector"));
111     if (selector_name==nullptr || selector_name[0] == '\0')
112         selector_name = "default";
113
114     const char* name;
115
116     SET_COLL(gather);
117     SET_COLL(allgather);
118     SET_COLL(allgatherv);
119     SET_COLL(allreduce);
120     SET_COLL(alltoall);
121     SET_COLL(alltoallv);
122     SET_COLL(reduce);
123     SET_COLL(reduce_scatter);
124     SET_COLL(scatter);
125     SET_COLL(bcast);
126     SET_COLL(barrier);
127 }
128
129
130 //Implementations of the single algorith collectives
131
132 int Colls::gatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *displs,
133                       MPI_Datatype recvtype, int root, MPI_Comm comm)
134 {
135   int system_tag = COLL_TAG_GATHERV;
136   MPI_Aint lb = 0;
137   MPI_Aint recvext = 0;
138
139   int rank = comm->rank();
140   int size = comm->size();
141   if (rank != root) {
142     // Send buffer to root
143     Request::send(sendbuf, sendcount, sendtype, root, system_tag, comm);
144   } else {
145     recvtype->extent(&lb, &recvext);
146     // Local copy from root
147     Datatype::copy(sendbuf, sendcount, sendtype, static_cast<char*>(recvbuf) + displs[root] * recvext,
148                        recvcounts[root], recvtype);
149     // Receive buffers from senders
150     MPI_Request *requests = xbt_new(MPI_Request, size - 1);
151     int index = 0;
152     for (int src = 0; src < size; src++) {
153       if(src != root) {
154         requests[index] = Request::irecv_init(static_cast<char*>(recvbuf) + displs[src] * recvext,
155                           recvcounts[src], recvtype, src, system_tag, comm);
156         index++;
157       }
158     }
159     // Wait for completion of irecv's.
160     Request::startall(size - 1, requests);
161     Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
162     for (int src = 0; src < size-1; src++) {
163       Request::unref(&requests[src]);
164     }
165     xbt_free(requests);
166   }
167   return MPI_SUCCESS;
168 }
169
170
171 int Colls::scatterv(void *sendbuf, int *sendcounts, int *displs, MPI_Datatype sendtype, void *recvbuf, int recvcount,
172                        MPI_Datatype recvtype, int root, MPI_Comm comm)
173 {
174   int system_tag = COLL_TAG_SCATTERV;
175   MPI_Aint lb = 0;
176   MPI_Aint sendext = 0;
177
178   int rank = comm->rank();
179   int size = comm->size();
180   if(rank != root) {
181     // Recv buffer from root
182     Request::recv(recvbuf, recvcount, recvtype, root, system_tag, comm, MPI_STATUS_IGNORE);
183   } else {
184     sendtype->extent(&lb, &sendext);
185     // Local copy from root
186     if(recvbuf!=MPI_IN_PLACE){
187       Datatype::copy(static_cast<char *>(sendbuf) + displs[root] * sendext, sendcounts[root],
188                        sendtype, recvbuf, recvcount, recvtype);
189     }
190     // Send buffers to receivers
191     MPI_Request *requests = xbt_new(MPI_Request, size - 1);
192     int index = 0;
193     for (int dst = 0; dst < size; dst++) {
194       if (dst != root) {
195         requests[index] = Request::isend_init(static_cast<char *>(sendbuf) + displs[dst] * sendext, sendcounts[dst],
196                             sendtype, dst, system_tag, comm);
197         index++;
198       }
199     }
200     // Wait for completion of isend's.
201     Request::startall(size - 1, requests);
202     Request::waitall(size - 1, requests, MPI_STATUS_IGNORE);
203     for (int dst = 0; dst < size-1; dst++) {
204       Request::unref(&requests[dst]);
205     }
206     xbt_free(requests);
207   }
208   return MPI_SUCCESS;
209 }
210
211
212 int Colls::scan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
213 {
214   int system_tag = -888;
215   MPI_Aint lb      = 0;
216   MPI_Aint dataext = 0;
217
218   int rank = comm->rank();
219   int size = comm->size();
220
221   datatype->extent(&lb, &dataext);
222
223   // Local copy from self
224   Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype);
225
226   // Send/Recv buffers to/from others
227   MPI_Request *requests = xbt_new(MPI_Request, size - 1);
228   void **tmpbufs = xbt_new(void *, rank);
229   int index = 0;
230   for (int other = 0; other < rank; other++) {
231     tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
232     requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
233     index++;
234   }
235   for (int other = rank + 1; other < size; other++) {
236     requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
237     index++;
238   }
239   // Wait for completion of all comms.
240   Request::startall(size - 1, requests);
241
242   if(op != MPI_OP_NULL && op->is_commutative()){
243     for (int other = 0; other < size - 1; other++) {
244       index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
245       if(index == MPI_UNDEFINED) {
246         break;
247       }
248       if(index < rank) {
249         // #Request is below rank: it's a irecv
250         op->apply( tmpbufs[index], recvbuf, &count, datatype);
251       }
252     }
253   }else{
254     //non commutative case, wait in order
255     for (int other = 0; other < size - 1; other++) {
256       Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
257       if(index < rank && op!=MPI_OP_NULL) {
258         op->apply( tmpbufs[other], recvbuf, &count, datatype);
259       }
260     }
261   }
262   for(index = 0; index < rank; index++) {
263     smpi_free_tmp_buffer(tmpbufs[index]);
264   }
265   for(index = 0; index < size-1; index++) {
266     Request::unref(&requests[index]);
267   }
268   xbt_free(tmpbufs);
269   xbt_free(requests);
270   return MPI_SUCCESS;
271 }
272
273 int Colls::exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
274 {
275   int system_tag = -888;
276   MPI_Aint lb         = 0;
277   MPI_Aint dataext    = 0;
278   int recvbuf_is_empty=1;
279   int rank = comm->rank();
280   int size = comm->size();
281
282   datatype->extent(&lb, &dataext);
283
284   // Send/Recv buffers to/from others
285   MPI_Request *requests = xbt_new(MPI_Request, size - 1);
286   void **tmpbufs = xbt_new(void *, rank);
287   int index = 0;
288   for (int other = 0; other < rank; other++) {
289     tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
290     requests[index] = Request::irecv_init(tmpbufs[index], count, datatype, other, system_tag, comm);
291     index++;
292   }
293   for (int other = rank + 1; other < size; other++) {
294     requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag, comm);
295     index++;
296   }
297   // Wait for completion of all comms.
298   Request::startall(size - 1, requests);
299
300   if(op != MPI_OP_NULL && op->is_commutative()){
301     for (int other = 0; other < size - 1; other++) {
302       index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
303       if(index == MPI_UNDEFINED) {
304         break;
305       }
306       if(index < rank) {
307         if(recvbuf_is_empty){
308           Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
309           recvbuf_is_empty=0;
310         } else
311           // #Request is below rank: it's a irecv
312           op->apply( tmpbufs[index], recvbuf, &count, datatype);
313       }
314     }
315   }else{
316     //non commutative case, wait in order
317     for (int other = 0; other < size - 1; other++) {
318      Request::wait(&(requests[other]), MPI_STATUS_IGNORE);
319       if(index < rank) {
320         if (recvbuf_is_empty) {
321           Datatype::copy(tmpbufs[other], count, datatype, recvbuf, count, datatype);
322           recvbuf_is_empty = 0;
323         } else
324           if(op!=MPI_OP_NULL) 
325             op->apply( tmpbufs[other], recvbuf, &count, datatype);
326       }
327     }
328   }
329   for(index = 0; index < rank; index++) {
330     smpi_free_tmp_buffer(tmpbufs[index]);
331   }
332   for(index = 0; index < size-1; index++) {
333     Request::unref(&requests[index]);
334   }
335   xbt_free(tmpbufs);
336   xbt_free(requests);
337   return MPI_SUCCESS;
338 }
339
340 }
341 }
342
343
344
345
346