Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
(painfully) constify colls.
[simgrid.git] / src / smpi / colls / reduce / reduce-scatter-gather.cpp
1 /* Copyright (c) 2013-2019. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "../colls_private.hpp"
8
9 /*
10   reduce
11   Author: MPICH
12  */
13 namespace simgrid{
14 namespace smpi{
15 int Coll_reduce_scatter_gather::reduce(const void *sendbuf, void *recvbuf,
16                                           int count, MPI_Datatype datatype,
17                                           MPI_Op op, int root, MPI_Comm comm)
18 {
19   MPI_Status status;
20   int comm_size, rank, pof2, rem, newrank;
21   int mask, *cnts, *disps, i, j, send_idx = 0;
22   int recv_idx, last_idx = 0, newdst;
23   int dst, send_cnt, recv_cnt, newroot, newdst_tree_root;
24   int newroot_tree_root, new_count;
25   int tag = COLL_TAG_REDUCE,temporary_buffer=0;
26   void *send_ptr, *recv_ptr, *tmp_buf;
27
28   cnts = NULL;
29   disps = NULL;
30
31   MPI_Aint extent;
32
33   if (count == 0)
34     return 0;
35   rank = comm->rank();
36   comm_size = comm->size();
37
38
39
40   extent = datatype->get_extent();
41   /* If I'm not the root, then my recvbuf may not be valid, therefore
42   I have to allocate a temporary one */
43   if (rank != root && not recvbuf) {
44     temporary_buffer=1;
45     recvbuf = (void *)smpi_get_tmp_recvbuffer(count * extent);
46   }
47   /* find nearest power-of-two less than or equal to comm_size */
48   pof2 = 1;
49   while (pof2 <= comm_size)
50     pof2 <<= 1;
51   pof2 >>= 1;
52
53   if (count < comm_size) {
54     new_count = comm_size;
55     send_ptr = (void *) smpi_get_tmp_sendbuffer(new_count * extent);
56     recv_ptr = (void *) smpi_get_tmp_recvbuffer(new_count * extent);
57     tmp_buf = (void *) smpi_get_tmp_sendbuffer(new_count * extent);
58     memcpy(send_ptr, sendbuf != MPI_IN_PLACE ? sendbuf : recvbuf, extent * count);
59
60     //if ((rank != root))
61     Request::sendrecv(send_ptr, new_count, datatype, rank, tag,
62                  recv_ptr, new_count, datatype, rank, tag, comm, &status);
63
64     rem = comm_size - pof2;
65     if (rank < 2 * rem) {
66       if (rank % 2 != 0) {
67         /* odd */
68         Request::send(recv_ptr, new_count, datatype, rank - 1, tag, comm);
69         newrank = -1;
70       } else {
71         Request::recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
72         if(op!=MPI_OP_NULL) op->apply( tmp_buf, recv_ptr, &new_count, datatype);
73         newrank = rank / 2;
74       }
75     } else                      /* rank >= 2*rem */
76       newrank = rank - rem;
77
78     cnts = (int *) xbt_malloc(pof2 * sizeof(int));
79     disps = (int *) xbt_malloc(pof2 * sizeof(int));
80
81     if (newrank != -1) {
82       for (i = 0; i < (pof2 - 1); i++)
83         cnts[i] = new_count / pof2;
84       cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
85
86       disps[0] = 0;
87       for (i = 1; i < pof2; i++)
88         disps[i] = disps[i - 1] + cnts[i - 1];
89
90       mask = 0x1;
91       send_idx = recv_idx = 0;
92       last_idx = pof2;
93       while (mask < pof2) {
94         newdst = newrank ^ mask;
95         /* find real rank of dest */
96         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
97
98         send_cnt = recv_cnt = 0;
99         if (newrank < newdst) {
100           send_idx = recv_idx + pof2 / (mask * 2);
101           for (i = send_idx; i < last_idx; i++)
102             send_cnt += cnts[i];
103           for (i = recv_idx; i < send_idx; i++)
104             recv_cnt += cnts[i];
105         } else {
106           recv_idx = send_idx + pof2 / (mask * 2);
107           for (i = send_idx; i < recv_idx; i++)
108             send_cnt += cnts[i];
109           for (i = recv_idx; i < last_idx; i++)
110             recv_cnt += cnts[i];
111         }
112
113         /* Send data from recvbuf. Recv into tmp_buf */
114         Request::sendrecv((char *) recv_ptr +
115                      disps[send_idx] * extent,
116                      send_cnt, datatype,
117                      dst, tag,
118                      (char *) tmp_buf +
119                      disps[recv_idx] * extent,
120                      recv_cnt, datatype, dst, tag, comm, &status);
121
122         /* tmp_buf contains data received in this step.
123            recvbuf contains data accumulated so far */
124
125         if(op!=MPI_OP_NULL) op->apply( (char *) tmp_buf + disps[recv_idx] * extent,
126                        (char *) recv_ptr + disps[recv_idx] * extent,
127                        &recv_cnt, datatype);
128
129         /* update send_idx for next iteration */
130         send_idx = recv_idx;
131         mask <<= 1;
132
133         if (mask < pof2)
134           last_idx = recv_idx + pof2 / mask;
135       }
136     }
137
138     /* now do the gather to root */
139
140     if (root < 2 * rem) {
141       if (root % 2 != 0) {
142         if (rank == root) {
143           /* recv */
144           for (i = 0; i < (pof2 - 1); i++)
145             cnts[i] = new_count / pof2;
146           cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
147
148           disps[0] = 0;
149           for (i = 1; i < pof2; i++)
150             disps[i] = disps[i - 1] + cnts[i - 1];
151
152           Request::recv(recv_ptr, cnts[0], datatype, 0, tag, comm, &status);
153
154           newrank = 0;
155           send_idx = 0;
156           last_idx = 2;
157         } else if (newrank == 0) {
158           Request::send(recv_ptr, cnts[0], datatype, root, tag, comm);
159           newrank = -1;
160         }
161         newroot = 0;
162       } else
163         newroot = root / 2;
164     } else
165       newroot = root - rem;
166
167     if (newrank != -1) {
168       j = 0;
169       mask = 0x1;
170       while (mask < pof2) {
171         mask <<= 1;
172         j++;
173       }
174       mask >>= 1;
175       j--;
176       while (mask > 0) {
177         newdst = newrank ^ mask;
178
179         /* find real rank of dest */
180         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
181
182         if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
183           dst = root;
184         newdst_tree_root = newdst >> j;
185         newdst_tree_root <<= j;
186
187         newroot_tree_root = newroot >> j;
188         newroot_tree_root <<= j;
189
190         send_cnt = recv_cnt = 0;
191         if (newrank < newdst) {
192           /* update last_idx except on first iteration */
193           if (mask != pof2 / 2)
194             last_idx = last_idx + pof2 / (mask * 2);
195
196           recv_idx = send_idx + pof2 / (mask * 2);
197           for (i = send_idx; i < recv_idx; i++)
198             send_cnt += cnts[i];
199           for (i = recv_idx; i < last_idx; i++)
200             recv_cnt += cnts[i];
201         } else {
202           recv_idx = send_idx - pof2 / (mask * 2);
203           for (i = send_idx; i < last_idx; i++)
204             send_cnt += cnts[i];
205           for (i = recv_idx; i < send_idx; i++)
206             recv_cnt += cnts[i];
207         }
208
209         if (newdst_tree_root == newroot_tree_root) {
210           Request::send((char *) recv_ptr +
211                    disps[send_idx] * extent,
212                    send_cnt, datatype, dst, tag, comm);
213           break;
214         } else {
215           Request::recv((char *) recv_ptr +
216                    disps[recv_idx] * extent,
217                    recv_cnt, datatype, dst, tag, comm, &status);
218         }
219
220         if (newrank > newdst)
221           send_idx = recv_idx;
222
223         mask >>= 1;
224         j--;
225       }
226     }
227     memcpy(recvbuf, recv_ptr, extent * count);
228     smpi_free_tmp_buffer(send_ptr);
229     smpi_free_tmp_buffer(recv_ptr);
230   }
231
232
233   else /* (count >= comm_size) */ {
234     tmp_buf = (void *) smpi_get_tmp_sendbuffer(count * extent);
235
236     //if ((rank != root))
237     Request::sendrecv(sendbuf != MPI_IN_PLACE ? sendbuf : recvbuf, count, datatype, rank, tag,
238                  recvbuf, count, datatype, rank, tag, comm, &status);
239
240     rem = comm_size - pof2;
241     if (rank < 2 * rem) {
242       if (rank % 2 != 0) {      /* odd */
243         Request::send(recvbuf, count, datatype, rank - 1, tag, comm);
244         newrank = -1;
245       }
246
247       else {
248         Request::recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
249         if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
250         newrank = rank / 2;
251       }
252     } else                      /* rank >= 2*rem */
253       newrank = rank - rem;
254
255     cnts = (int *) xbt_malloc(pof2 * sizeof(int));
256     disps = (int *) xbt_malloc(pof2 * sizeof(int));
257
258     if (newrank != -1) {
259       for (i = 0; i < (pof2 - 1); i++)
260         cnts[i] = count / pof2;
261       cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
262
263       disps[0] = 0;
264       for (i = 1; i < pof2; i++)
265         disps[i] = disps[i - 1] + cnts[i - 1];
266
267       mask = 0x1;
268       send_idx = recv_idx = 0;
269       last_idx = pof2;
270       while (mask < pof2) {
271         newdst = newrank ^ mask;
272         /* find real rank of dest */
273         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
274
275         send_cnt = recv_cnt = 0;
276         if (newrank < newdst) {
277           send_idx = recv_idx + pof2 / (mask * 2);
278           for (i = send_idx; i < last_idx; i++)
279             send_cnt += cnts[i];
280           for (i = recv_idx; i < send_idx; i++)
281             recv_cnt += cnts[i];
282         } else {
283           recv_idx = send_idx + pof2 / (mask * 2);
284           for (i = send_idx; i < recv_idx; i++)
285             send_cnt += cnts[i];
286           for (i = recv_idx; i < last_idx; i++)
287             recv_cnt += cnts[i];
288         }
289
290         /* Send data from recvbuf. Recv into tmp_buf */
291         Request::sendrecv((char *) recvbuf +
292                      disps[send_idx] * extent,
293                      send_cnt, datatype,
294                      dst, tag,
295                      (char *) tmp_buf +
296                      disps[recv_idx] * extent,
297                      recv_cnt, datatype, dst, tag, comm, &status);
298
299         /* tmp_buf contains data received in this step.
300            recvbuf contains data accumulated so far */
301
302         if(op!=MPI_OP_NULL) op->apply( (char *) tmp_buf + disps[recv_idx] * extent,
303                        (char *) recvbuf + disps[recv_idx] * extent,
304                        &recv_cnt, datatype);
305
306         /* update send_idx for next iteration */
307         send_idx = recv_idx;
308         mask <<= 1;
309
310         if (mask < pof2)
311           last_idx = recv_idx + pof2 / mask;
312       }
313     }
314
315     /* now do the gather to root */
316
317     if (root < 2 * rem) {
318       if (root % 2 != 0) {
319         if (rank == root) {     /* recv */
320           for (i = 0; i < (pof2 - 1); i++)
321             cnts[i] = count / pof2;
322           cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
323
324           disps[0] = 0;
325           for (i = 1; i < pof2; i++)
326             disps[i] = disps[i - 1] + cnts[i - 1];
327
328           Request::recv(recvbuf, cnts[0], datatype, 0, tag, comm, &status);
329
330           newrank = 0;
331           send_idx = 0;
332           last_idx = 2;
333         } else if (newrank == 0) {
334           Request::send(recvbuf, cnts[0], datatype, root, tag, comm);
335           newrank = -1;
336         }
337         newroot = 0;
338       } else
339         newroot = root / 2;
340     } else
341       newroot = root - rem;
342
343     if (newrank != -1) {
344       j = 0;
345       mask = 0x1;
346       while (mask < pof2) {
347         mask <<= 1;
348         j++;
349       }
350       mask >>= 1;
351       j--;
352       while (mask > 0) {
353         newdst = newrank ^ mask;
354
355         /* find real rank of dest */
356         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
357
358         if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
359           dst = root;
360         newdst_tree_root = newdst >> j;
361         newdst_tree_root <<= j;
362
363         newroot_tree_root = newroot >> j;
364         newroot_tree_root <<= j;
365
366         send_cnt = recv_cnt = 0;
367         if (newrank < newdst) {
368           /* update last_idx except on first iteration */
369           if (mask != pof2 / 2)
370             last_idx = last_idx + pof2 / (mask * 2);
371
372           recv_idx = send_idx + pof2 / (mask * 2);
373           for (i = send_idx; i < recv_idx; i++)
374             send_cnt += cnts[i];
375           for (i = recv_idx; i < last_idx; i++)
376             recv_cnt += cnts[i];
377         } else {
378           recv_idx = send_idx - pof2 / (mask * 2);
379           for (i = send_idx; i < last_idx; i++)
380             send_cnt += cnts[i];
381           for (i = recv_idx; i < send_idx; i++)
382             recv_cnt += cnts[i];
383         }
384
385         if (newdst_tree_root == newroot_tree_root) {
386           Request::send((char *) recvbuf +
387                    disps[send_idx] * extent,
388                    send_cnt, datatype, dst, tag, comm);
389           break;
390         } else {
391           Request::recv((char *) recvbuf +
392                    disps[recv_idx] * extent,
393                    recv_cnt, datatype, dst, tag, comm, &status);
394         }
395
396         if (newrank > newdst)
397           send_idx = recv_idx;
398
399         mask >>= 1;
400         j--;
401       }
402     }
403   }
404   if (tmp_buf)
405     smpi_free_tmp_buffer(tmp_buf);
406   if(temporary_buffer==1) smpi_free_tmp_buffer(recvbuf);
407   if (cnts)
408     free(cnts);
409   if (disps)
410     free(disps);
411
412   return 0;
413 }
414 }
415 }