Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
MC: complete workaround in the error msg seen on modern systems
[simgrid.git] / src / smpi / colls / reduce / reduce-scatter-gather.cpp
1 /* Copyright (c) 2013-2019. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "../colls_private.hpp"
8
9 /*
10   reduce
11   Author: MPICH
12  */
13 namespace simgrid{
14 namespace smpi{
15 int Coll_reduce_scatter_gather::reduce(const void *sendbuf, void *recvbuf,
16                                           int count, MPI_Datatype datatype,
17                                           MPI_Op op, int root, MPI_Comm comm)
18 {
19   MPI_Status status;
20   int comm_size, rank, pof2, rem, newrank;
21   int mask, *cnts, *disps, i, j, send_idx = 0;
22   int recv_idx, last_idx = 0, newdst;
23   int dst, send_cnt, recv_cnt, newroot, newdst_tree_root;
24   int newroot_tree_root, new_count;
25   int tag = COLL_TAG_REDUCE,temporary_buffer=0;
26   unsigned char *send_ptr, *recv_ptr, *tmp_buf;
27
28   cnts = NULL;
29   disps = NULL;
30
31   MPI_Aint extent;
32
33   if (count == 0)
34     return 0;
35   rank = comm->rank();
36   comm_size = comm->size();
37
38
39
40   extent = datatype->get_extent();
41   /* If I'm not the root, then my recvbuf may not be valid, therefore
42   I have to allocate a temporary one */
43   if (rank != root && not recvbuf) {
44     temporary_buffer=1;
45     recvbuf = (void *)smpi_get_tmp_recvbuffer(count * extent);
46   }
47   /* find nearest power-of-two less than or equal to comm_size */
48   pof2 = 1;
49   while (pof2 <= comm_size)
50     pof2 <<= 1;
51   pof2 >>= 1;
52
53   if (count < comm_size) {
54     new_count = comm_size;
55     send_ptr  = smpi_get_tmp_sendbuffer(new_count * extent);
56     recv_ptr  = smpi_get_tmp_recvbuffer(new_count * extent);
57     tmp_buf   = smpi_get_tmp_sendbuffer(new_count * extent);
58     memcpy(send_ptr, sendbuf != MPI_IN_PLACE ? sendbuf : recvbuf, extent * count);
59
60     //if ((rank != root))
61     Request::sendrecv(send_ptr, new_count, datatype, rank, tag,
62                  recv_ptr, new_count, datatype, rank, tag, comm, &status);
63
64     rem = comm_size - pof2;
65     if (rank < 2 * rem) {
66       if (rank % 2 != 0) {
67         /* odd */
68         Request::send(recv_ptr, new_count, datatype, rank - 1, tag, comm);
69         newrank = -1;
70       } else {
71         Request::recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
72         if(op!=MPI_OP_NULL) op->apply( tmp_buf, recv_ptr, &new_count, datatype);
73         newrank = rank / 2;
74       }
75     } else                      /* rank >= 2*rem */
76       newrank = rank - rem;
77
78     cnts  = new int[pof2];
79     disps = new int[pof2];
80
81     if (newrank != -1) {
82       for (i = 0; i < (pof2 - 1); i++)
83         cnts[i] = new_count / pof2;
84       cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
85
86       disps[0] = 0;
87       for (i = 1; i < pof2; i++)
88         disps[i] = disps[i - 1] + cnts[i - 1];
89
90       mask = 0x1;
91       send_idx = recv_idx = 0;
92       last_idx = pof2;
93       while (mask < pof2) {
94         newdst = newrank ^ mask;
95         /* find real rank of dest */
96         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
97
98         send_cnt = recv_cnt = 0;
99         if (newrank < newdst) {
100           send_idx = recv_idx + pof2 / (mask * 2);
101           for (i = send_idx; i < last_idx; i++)
102             send_cnt += cnts[i];
103           for (i = recv_idx; i < send_idx; i++)
104             recv_cnt += cnts[i];
105         } else {
106           recv_idx = send_idx + pof2 / (mask * 2);
107           for (i = send_idx; i < recv_idx; i++)
108             send_cnt += cnts[i];
109           for (i = recv_idx; i < last_idx; i++)
110             recv_cnt += cnts[i];
111         }
112
113         /* Send data from recvbuf. Recv into tmp_buf */
114         Request::sendrecv(recv_ptr + disps[send_idx] * extent, send_cnt, datatype, dst, tag,
115                           tmp_buf + disps[recv_idx] * extent, recv_cnt, datatype, dst, tag, comm, &status);
116
117         /* tmp_buf contains data received in this step.
118            recvbuf contains data accumulated so far */
119
120         if (op != MPI_OP_NULL)
121           op->apply(tmp_buf + disps[recv_idx] * extent, recv_ptr + disps[recv_idx] * extent, &recv_cnt, datatype);
122
123         /* update send_idx for next iteration */
124         send_idx = recv_idx;
125         mask <<= 1;
126
127         if (mask < pof2)
128           last_idx = recv_idx + pof2 / mask;
129       }
130     }
131
132     /* now do the gather to root */
133
134     if (root < 2 * rem) {
135       if (root % 2 != 0) {
136         if (rank == root) {
137           /* recv */
138           for (i = 0; i < (pof2 - 1); i++)
139             cnts[i] = new_count / pof2;
140           cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
141
142           disps[0] = 0;
143           for (i = 1; i < pof2; i++)
144             disps[i] = disps[i - 1] + cnts[i - 1];
145
146           Request::recv(recv_ptr, cnts[0], datatype, 0, tag, comm, &status);
147
148           newrank = 0;
149           send_idx = 0;
150           last_idx = 2;
151         } else if (newrank == 0) {
152           Request::send(recv_ptr, cnts[0], datatype, root, tag, comm);
153           newrank = -1;
154         }
155         newroot = 0;
156       } else
157         newroot = root / 2;
158     } else
159       newroot = root - rem;
160
161     if (newrank != -1) {
162       j = 0;
163       mask = 0x1;
164       while (mask < pof2) {
165         mask <<= 1;
166         j++;
167       }
168       mask >>= 1;
169       j--;
170       while (mask > 0) {
171         newdst = newrank ^ mask;
172
173         /* find real rank of dest */
174         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
175
176         if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
177           dst = root;
178         newdst_tree_root = newdst >> j;
179         newdst_tree_root <<= j;
180
181         newroot_tree_root = newroot >> j;
182         newroot_tree_root <<= j;
183
184         send_cnt = recv_cnt = 0;
185         if (newrank < newdst) {
186           /* update last_idx except on first iteration */
187           if (mask != pof2 / 2)
188             last_idx = last_idx + pof2 / (mask * 2);
189
190           recv_idx = send_idx + pof2 / (mask * 2);
191           for (i = send_idx; i < recv_idx; i++)
192             send_cnt += cnts[i];
193           for (i = recv_idx; i < last_idx; i++)
194             recv_cnt += cnts[i];
195         } else {
196           recv_idx = send_idx - pof2 / (mask * 2);
197           for (i = send_idx; i < last_idx; i++)
198             send_cnt += cnts[i];
199           for (i = recv_idx; i < send_idx; i++)
200             recv_cnt += cnts[i];
201         }
202
203         if (newdst_tree_root == newroot_tree_root) {
204           Request::send(recv_ptr + disps[send_idx] * extent, send_cnt, datatype, dst, tag, comm);
205           break;
206         } else {
207           Request::recv(recv_ptr + disps[recv_idx] * extent, recv_cnt, datatype, dst, tag, comm, &status);
208         }
209
210         if (newrank > newdst)
211           send_idx = recv_idx;
212
213         mask >>= 1;
214         j--;
215       }
216     }
217     memcpy(recvbuf, recv_ptr, extent * count);
218     smpi_free_tmp_buffer(send_ptr);
219     smpi_free_tmp_buffer(recv_ptr);
220   }
221
222
223   else /* (count >= comm_size) */ {
224     tmp_buf = smpi_get_tmp_sendbuffer(count * extent);
225
226     //if ((rank != root))
227     Request::sendrecv(sendbuf != MPI_IN_PLACE ? sendbuf : recvbuf, count, datatype, rank, tag,
228                  recvbuf, count, datatype, rank, tag, comm, &status);
229
230     rem = comm_size - pof2;
231     if (rank < 2 * rem) {
232       if (rank % 2 != 0) {      /* odd */
233         Request::send(recvbuf, count, datatype, rank - 1, tag, comm);
234         newrank = -1;
235       }
236
237       else {
238         Request::recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
239         if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
240         newrank = rank / 2;
241       }
242     } else                      /* rank >= 2*rem */
243       newrank = rank - rem;
244
245     cnts  = new int[pof2];
246     disps = new int[pof2];
247
248     if (newrank != -1) {
249       for (i = 0; i < (pof2 - 1); i++)
250         cnts[i] = count / pof2;
251       cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
252
253       disps[0] = 0;
254       for (i = 1; i < pof2; i++)
255         disps[i] = disps[i - 1] + cnts[i - 1];
256
257       mask = 0x1;
258       send_idx = recv_idx = 0;
259       last_idx = pof2;
260       while (mask < pof2) {
261         newdst = newrank ^ mask;
262         /* find real rank of dest */
263         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
264
265         send_cnt = recv_cnt = 0;
266         if (newrank < newdst) {
267           send_idx = recv_idx + pof2 / (mask * 2);
268           for (i = send_idx; i < last_idx; i++)
269             send_cnt += cnts[i];
270           for (i = recv_idx; i < send_idx; i++)
271             recv_cnt += cnts[i];
272         } else {
273           recv_idx = send_idx + pof2 / (mask * 2);
274           for (i = send_idx; i < recv_idx; i++)
275             send_cnt += cnts[i];
276           for (i = recv_idx; i < last_idx; i++)
277             recv_cnt += cnts[i];
278         }
279
280         /* Send data from recvbuf. Recv into tmp_buf */
281         Request::sendrecv(static_cast<char*>(recvbuf) + disps[send_idx] * extent, send_cnt, datatype, dst, tag,
282                           tmp_buf + disps[recv_idx] * extent, recv_cnt, datatype, dst, tag, comm, &status);
283
284         /* tmp_buf contains data received in this step.
285            recvbuf contains data accumulated so far */
286
287         if (op != MPI_OP_NULL)
288           op->apply(tmp_buf + disps[recv_idx] * extent, static_cast<char*>(recvbuf) + disps[recv_idx] * extent,
289                     &recv_cnt, datatype);
290
291         /* update send_idx for next iteration */
292         send_idx = recv_idx;
293         mask <<= 1;
294
295         if (mask < pof2)
296           last_idx = recv_idx + pof2 / mask;
297       }
298     }
299
300     /* now do the gather to root */
301
302     if (root < 2 * rem) {
303       if (root % 2 != 0) {
304         if (rank == root) {     /* recv */
305           for (i = 0; i < (pof2 - 1); i++)
306             cnts[i] = count / pof2;
307           cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
308
309           disps[0] = 0;
310           for (i = 1; i < pof2; i++)
311             disps[i] = disps[i - 1] + cnts[i - 1];
312
313           Request::recv(recvbuf, cnts[0], datatype, 0, tag, comm, &status);
314
315           newrank = 0;
316           send_idx = 0;
317           last_idx = 2;
318         } else if (newrank == 0) {
319           Request::send(recvbuf, cnts[0], datatype, root, tag, comm);
320           newrank = -1;
321         }
322         newroot = 0;
323       } else
324         newroot = root / 2;
325     } else
326       newroot = root - rem;
327
328     if (newrank != -1) {
329       j = 0;
330       mask = 0x1;
331       while (mask < pof2) {
332         mask <<= 1;
333         j++;
334       }
335       mask >>= 1;
336       j--;
337       while (mask > 0) {
338         newdst = newrank ^ mask;
339
340         /* find real rank of dest */
341         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
342
343         if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
344           dst = root;
345         newdst_tree_root = newdst >> j;
346         newdst_tree_root <<= j;
347
348         newroot_tree_root = newroot >> j;
349         newroot_tree_root <<= j;
350
351         send_cnt = recv_cnt = 0;
352         if (newrank < newdst) {
353           /* update last_idx except on first iteration */
354           if (mask != pof2 / 2)
355             last_idx = last_idx + pof2 / (mask * 2);
356
357           recv_idx = send_idx + pof2 / (mask * 2);
358           for (i = send_idx; i < recv_idx; i++)
359             send_cnt += cnts[i];
360           for (i = recv_idx; i < last_idx; i++)
361             recv_cnt += cnts[i];
362         } else {
363           recv_idx = send_idx - pof2 / (mask * 2);
364           for (i = send_idx; i < last_idx; i++)
365             send_cnt += cnts[i];
366           for (i = recv_idx; i < send_idx; i++)
367             recv_cnt += cnts[i];
368         }
369
370         if (newdst_tree_root == newroot_tree_root) {
371           Request::send((char *) recvbuf +
372                    disps[send_idx] * extent,
373                    send_cnt, datatype, dst, tag, comm);
374           break;
375         } else {
376           Request::recv((char *) recvbuf +
377                    disps[recv_idx] * extent,
378                    recv_cnt, datatype, dst, tag, comm, &status);
379         }
380
381         if (newrank > newdst)
382           send_idx = recv_idx;
383
384         mask >>= 1;
385         j--;
386       }
387     }
388   }
389   if (tmp_buf)
390     smpi_free_tmp_buffer(tmp_buf);
391   if (temporary_buffer == 1)
392     smpi_free_tmp_buffer(static_cast<unsigned char*>(recvbuf));
393   delete[] cnts;
394   delete[] disps;
395
396   return 0;
397 }
398 }
399 }