Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add new entry in Release_Notes.
[simgrid.git] / src / smpi / colls / reduce / reduce-scatter-gather.cpp
1 /* Copyright (c) 2013-2023. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "../colls_private.hpp"
8
9 /*
10   reduce
11   Author: MPICH
12  */
13 namespace simgrid::smpi {
14 int reduce__scatter_gather(const void *sendbuf, void *recvbuf,
15                            int count, MPI_Datatype datatype,
16                            MPI_Op op, int root, MPI_Comm comm)
17 {
18   MPI_Status status;
19   int comm_size, rank, pof2, rem, newrank;
20   int mask, *cnts, *disps, i, j, send_idx = 0;
21   int recv_idx, last_idx = 0, newdst;
22   int dst, send_cnt, recv_cnt, newroot, newdst_tree_root;
23   int newroot_tree_root, new_count;
24   int tag = COLL_TAG_REDUCE,temporary_buffer=0;
25   unsigned char *send_ptr, *recv_ptr, *tmp_buf;
26
27   cnts  = nullptr;
28   disps = nullptr;
29
30   MPI_Aint extent;
31
32   if (count == 0)
33     return 0;
34   rank = comm->rank();
35   comm_size = comm->size();
36
37
38
39   extent = datatype->get_extent();
40   /* If I'm not the root, then my recvbuf may not be valid, therefore
41   I have to allocate a temporary one */
42   if (rank != root && not recvbuf) {
43     temporary_buffer=1;
44     recvbuf = (void *)smpi_get_tmp_recvbuffer(count * extent);
45   }
46   /* find nearest power-of-two less than or equal to comm_size */
47   pof2 = 1;
48   while (pof2 <= comm_size)
49     pof2 <<= 1;
50   pof2 >>= 1;
51
52   if (count < comm_size) {
53     new_count = comm_size;
54     send_ptr  = smpi_get_tmp_sendbuffer(new_count * extent);
55     recv_ptr  = smpi_get_tmp_recvbuffer(new_count * extent);
56     tmp_buf   = smpi_get_tmp_sendbuffer(new_count * extent);
57     memcpy(send_ptr, sendbuf != MPI_IN_PLACE ? sendbuf : recvbuf, extent * count);
58
59     //if ((rank != root))
60     Request::sendrecv(send_ptr, new_count, datatype, rank, tag,
61                  recv_ptr, new_count, datatype, rank, tag, comm, &status);
62
63     rem = comm_size - pof2;
64     if (rank < 2 * rem) {
65       if (rank % 2 != 0) {
66         /* odd */
67         Request::send(recv_ptr, new_count, datatype, rank - 1, tag, comm);
68         newrank = -1;
69       } else {
70         Request::recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
71         if(op!=MPI_OP_NULL) op->apply( tmp_buf, recv_ptr, &new_count, datatype);
72         newrank = rank / 2;
73       }
74     } else                      /* rank >= 2*rem */
75       newrank = rank - rem;
76
77     cnts  = new int[pof2];
78     disps = new int[pof2];
79
80     if (newrank != -1) {
81       for (i = 0; i < (pof2 - 1); i++)
82         cnts[i] = new_count / pof2;
83       cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
84
85       disps[0] = 0;
86       for (i = 1; i < pof2; i++)
87         disps[i] = disps[i - 1] + cnts[i - 1];
88
89       mask = 0x1;
90       send_idx = recv_idx = 0;
91       last_idx = pof2;
92       while (mask < pof2) {
93         newdst = newrank ^ mask;
94         /* find real rank of dest */
95         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
96
97         send_cnt = recv_cnt = 0;
98         if (newrank < newdst) {
99           send_idx = recv_idx + pof2 / (mask * 2);
100           for (i = send_idx; i < last_idx; i++)
101             send_cnt += cnts[i];
102           for (i = recv_idx; i < send_idx; i++)
103             recv_cnt += cnts[i];
104         } else {
105           recv_idx = send_idx + pof2 / (mask * 2);
106           for (i = send_idx; i < recv_idx; i++)
107             send_cnt += cnts[i];
108           for (i = recv_idx; i < last_idx; i++)
109             recv_cnt += cnts[i];
110         }
111
112         /* Send data from recvbuf. Recv into tmp_buf */
113         Request::sendrecv(recv_ptr + disps[send_idx] * extent, send_cnt, datatype, dst, tag,
114                           tmp_buf + disps[recv_idx] * extent, recv_cnt, datatype, dst, tag, comm, &status);
115
116         /* tmp_buf contains data received in this step.
117            recvbuf contains data accumulated so far */
118
119         if (op != MPI_OP_NULL)
120           op->apply(tmp_buf + disps[recv_idx] * extent, recv_ptr + disps[recv_idx] * extent, &recv_cnt, datatype);
121
122         /* update send_idx for next iteration */
123         send_idx = recv_idx;
124         mask <<= 1;
125
126         if (mask < pof2)
127           last_idx = recv_idx + pof2 / mask;
128       }
129     }
130
131     /* now do the gather to root */
132
133     if (root < 2 * rem) {
134       if (root % 2 != 0) {
135         if (rank == root) {
136           /* recv */
137           for (i = 0; i < (pof2 - 1); i++)
138             cnts[i] = new_count / pof2;
139           cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
140
141           disps[0] = 0;
142           for (i = 1; i < pof2; i++)
143             disps[i] = disps[i - 1] + cnts[i - 1];
144
145           Request::recv(recv_ptr, cnts[0], datatype, 0, tag, comm, &status);
146
147           newrank = 0;
148           send_idx = 0;
149           last_idx = 2;
150         } else if (newrank == 0) {
151           Request::send(recv_ptr, cnts[0], datatype, root, tag, comm);
152           newrank = -1;
153         }
154         newroot = 0;
155       } else
156         newroot = root / 2;
157     } else
158       newroot = root - rem;
159
160     if (newrank != -1) {
161       j = 0;
162       mask = 0x1;
163       while (mask < pof2) {
164         mask <<= 1;
165         j++;
166       }
167       mask >>= 1;
168       j--;
169       while (mask > 0) {
170         newdst = newrank ^ mask;
171
172         /* find real rank of dest */
173         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
174
175         if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
176           dst = root;
177         newdst_tree_root = newdst >> j;
178         newdst_tree_root <<= j;
179
180         newroot_tree_root = newroot >> j;
181         newroot_tree_root <<= j;
182
183         send_cnt = recv_cnt = 0;
184         if (newrank < newdst) {
185           /* update last_idx except on first iteration */
186           if (mask != pof2 / 2)
187             last_idx = last_idx + pof2 / (mask * 2);
188
189           recv_idx = send_idx + pof2 / (mask * 2);
190           for (i = send_idx; i < recv_idx; i++)
191             send_cnt += cnts[i];
192           for (i = recv_idx; i < last_idx; i++)
193             recv_cnt += cnts[i];
194         } else {
195           recv_idx = send_idx - pof2 / (mask * 2);
196           for (i = send_idx; i < last_idx; i++)
197             send_cnt += cnts[i];
198           for (i = recv_idx; i < send_idx; i++)
199             recv_cnt += cnts[i];
200         }
201
202         if (newdst_tree_root == newroot_tree_root) {
203           Request::send(recv_ptr + disps[send_idx] * extent, send_cnt, datatype, dst, tag, comm);
204           break;
205         } else {
206           Request::recv(recv_ptr + disps[recv_idx] * extent, recv_cnt, datatype, dst, tag, comm, &status);
207         }
208
209         if (newrank > newdst)
210           send_idx = recv_idx;
211
212         mask >>= 1;
213         j--;
214       }
215     }
216     memcpy(recvbuf, recv_ptr, extent * count);
217     smpi_free_tmp_buffer(send_ptr);
218     smpi_free_tmp_buffer(recv_ptr);
219   }
220
221
222   else /* (count >= comm_size) */ {
223     tmp_buf = smpi_get_tmp_sendbuffer(count * extent);
224
225     //if ((rank != root))
226     Request::sendrecv(sendbuf != MPI_IN_PLACE ? sendbuf : recvbuf, count, datatype, rank, tag,
227                  recvbuf, count, datatype, rank, tag, comm, &status);
228
229     rem = comm_size - pof2;
230     if (rank < 2 * rem) {
231       if (rank % 2 != 0) {      /* odd */
232         Request::send(recvbuf, count, datatype, rank - 1, tag, comm);
233         newrank = -1;
234       }
235
236       else {
237         Request::recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
238         if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
239         newrank = rank / 2;
240       }
241     } else                      /* rank >= 2*rem */
242       newrank = rank - rem;
243
244     cnts  = new int[pof2];
245     disps = new int[pof2];
246
247     if (newrank != -1) {
248       for (i = 0; i < (pof2 - 1); i++)
249         cnts[i] = count / pof2;
250       cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
251
252       disps[0] = 0;
253       for (i = 1; i < pof2; i++)
254         disps[i] = disps[i - 1] + cnts[i - 1];
255
256       mask = 0x1;
257       send_idx = recv_idx = 0;
258       last_idx = pof2;
259       while (mask < pof2) {
260         newdst = newrank ^ mask;
261         /* find real rank of dest */
262         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
263
264         send_cnt = recv_cnt = 0;
265         if (newrank < newdst) {
266           send_idx = recv_idx + pof2 / (mask * 2);
267           for (i = send_idx; i < last_idx; i++)
268             send_cnt += cnts[i];
269           for (i = recv_idx; i < send_idx; i++)
270             recv_cnt += cnts[i];
271         } else {
272           recv_idx = send_idx + pof2 / (mask * 2);
273           for (i = send_idx; i < recv_idx; i++)
274             send_cnt += cnts[i];
275           for (i = recv_idx; i < last_idx; i++)
276             recv_cnt += cnts[i];
277         }
278
279         /* Send data from recvbuf. Recv into tmp_buf */
280         Request::sendrecv(static_cast<char*>(recvbuf) + disps[send_idx] * extent, send_cnt, datatype, dst, tag,
281                           tmp_buf + disps[recv_idx] * extent, recv_cnt, datatype, dst, tag, comm, &status);
282
283         /* tmp_buf contains data received in this step.
284            recvbuf contains data accumulated so far */
285
286         if (op != MPI_OP_NULL)
287           op->apply(tmp_buf + disps[recv_idx] * extent, static_cast<char*>(recvbuf) + disps[recv_idx] * extent,
288                     &recv_cnt, datatype);
289
290         /* update send_idx for next iteration */
291         send_idx = recv_idx;
292         mask <<= 1;
293
294         if (mask < pof2)
295           last_idx = recv_idx + pof2 / mask;
296       }
297     }
298
299     /* now do the gather to root */
300
301     if (root < 2 * rem) {
302       if (root % 2 != 0) {
303         if (rank == root) {     /* recv */
304           for (i = 0; i < (pof2 - 1); i++)
305             cnts[i] = count / pof2;
306           cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
307
308           disps[0] = 0;
309           for (i = 1; i < pof2; i++)
310             disps[i] = disps[i - 1] + cnts[i - 1];
311
312           Request::recv(recvbuf, cnts[0], datatype, 0, tag, comm, &status);
313
314           newrank = 0;
315           send_idx = 0;
316           last_idx = 2;
317         } else if (newrank == 0) {
318           Request::send(recvbuf, cnts[0], datatype, root, tag, comm);
319           newrank = -1;
320         }
321         newroot = 0;
322       } else
323         newroot = root / 2;
324     } else
325       newroot = root - rem;
326
327     if (newrank != -1) {
328       j = 0;
329       mask = 0x1;
330       while (mask < pof2) {
331         mask <<= 1;
332         j++;
333       }
334       mask >>= 1;
335       j--;
336       while (mask > 0) {
337         newdst = newrank ^ mask;
338
339         /* find real rank of dest */
340         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
341
342         if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
343           dst = root;
344         newdst_tree_root = newdst >> j;
345         newdst_tree_root <<= j;
346
347         newroot_tree_root = newroot >> j;
348         newroot_tree_root <<= j;
349
350         send_cnt = recv_cnt = 0;
351         if (newrank < newdst) {
352           /* update last_idx except on first iteration */
353           if (mask != pof2 / 2)
354             last_idx = last_idx + pof2 / (mask * 2);
355
356           recv_idx = send_idx + pof2 / (mask * 2);
357           for (i = send_idx; i < recv_idx; i++)
358             send_cnt += cnts[i];
359           for (i = recv_idx; i < last_idx; i++)
360             recv_cnt += cnts[i];
361         } else {
362           recv_idx = send_idx - pof2 / (mask * 2);
363           for (i = send_idx; i < last_idx; i++)
364             send_cnt += cnts[i];
365           for (i = recv_idx; i < send_idx; i++)
366             recv_cnt += cnts[i];
367         }
368
369         if (newdst_tree_root == newroot_tree_root) {
370           Request::send((char *) recvbuf +
371                    disps[send_idx] * extent,
372                    send_cnt, datatype, dst, tag, comm);
373           break;
374         } else {
375           Request::recv((char *) recvbuf +
376                    disps[recv_idx] * extent,
377                    recv_cnt, datatype, dst, tag, comm, &status);
378         }
379
380         if (newrank > newdst)
381           send_idx = recv_idx;
382
383         mask >>= 1;
384         j--;
385       }
386     }
387   }
388   if (tmp_buf)
389     smpi_free_tmp_buffer(tmp_buf);
390   if (temporary_buffer == 1)
391     smpi_free_tmp_buffer(static_cast<unsigned char*>(recvbuf));
392   delete[] cnts;
393   delete[] disps;
394
395   return 0;
396 }
397 } // namespace simgrid::smpi