Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add/update copyright notices.
[simgrid.git] / src / smpi / colls / reduce-scatter-gather.c
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "colls_private.h"
8
9 /*
10   reduce
11   Author: MPICH
12  */
13
14 int smpi_coll_tuned_reduce_scatter_gather(void *sendbuf, void *recvbuf,
15                                           int count, MPI_Datatype datatype,
16                                           MPI_Op op, int root, MPI_Comm comm)
17 {
18   MPI_Status status;
19   int comm_size, rank, pof2, rem, newrank;
20   int mask, *cnts, *disps, i, j, send_idx = 0;
21   int recv_idx, last_idx = 0, newdst;
22   int dst, send_cnt, recv_cnt, newroot, newdst_tree_root;
23   int newroot_tree_root, new_count;
24   int tag = COLL_TAG_REDUCE;
25   void *send_ptr, *recv_ptr, *tmp_buf;
26
27   cnts = NULL;
28   disps = NULL;
29
30   MPI_Aint extent;
31
32   if (count == 0)
33     return 0;
34   rank = smpi_comm_rank(comm);
35   comm_size = smpi_comm_size(comm);
36
37   extent = smpi_datatype_get_extent(datatype);
38
39   /* find nearest power-of-two less than or equal to comm_size */
40   pof2 = 1;
41   while (pof2 <= comm_size)
42     pof2 <<= 1;
43   pof2 >>= 1;
44
45   if (count < comm_size) {
46     new_count = comm_size;
47     send_ptr = (void *) xbt_malloc(new_count * extent);
48     recv_ptr = (void *) xbt_malloc(new_count * extent);
49     tmp_buf = (void *) xbt_malloc(new_count * extent);
50     memcpy(send_ptr, sendbuf, extent * count);
51
52     //if ((rank != root))
53     smpi_mpi_sendrecv(send_ptr, new_count, datatype, rank, tag,
54                  recv_ptr, new_count, datatype, rank, tag, comm, &status);
55
56     rem = comm_size - pof2;
57     if (rank < 2 * rem) {
58       if (rank % 2 != 0) {
59         /* odd */
60         smpi_mpi_send(recv_ptr, new_count, datatype, rank - 1, tag, comm);
61         newrank = -1;
62       } else {
63         smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
64         smpi_op_apply(op, tmp_buf, recv_ptr, &new_count, &datatype);
65         newrank = rank / 2;
66       }
67     } else                      /* rank >= 2*rem */
68       newrank = rank - rem;
69
70     cnts = (int *) xbt_malloc(pof2 * sizeof(int));
71     disps = (int *) xbt_malloc(pof2 * sizeof(int));
72
73     if (newrank != -1) {
74       for (i = 0; i < (pof2 - 1); i++)
75         cnts[i] = new_count / pof2;
76       cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
77
78       disps[0] = 0;
79       for (i = 1; i < pof2; i++)
80         disps[i] = disps[i - 1] + cnts[i - 1];
81
82       mask = 0x1;
83       send_idx = recv_idx = 0;
84       last_idx = pof2;
85       while (mask < pof2) {
86         newdst = newrank ^ mask;
87         /* find real rank of dest */
88         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
89
90         send_cnt = recv_cnt = 0;
91         if (newrank < newdst) {
92           send_idx = recv_idx + pof2 / (mask * 2);
93           for (i = send_idx; i < last_idx; i++)
94             send_cnt += cnts[i];
95           for (i = recv_idx; i < send_idx; i++)
96             recv_cnt += cnts[i];
97         } else {
98           recv_idx = send_idx + pof2 / (mask * 2);
99           for (i = send_idx; i < recv_idx; i++)
100             send_cnt += cnts[i];
101           for (i = recv_idx; i < last_idx; i++)
102             recv_cnt += cnts[i];
103         }
104
105         /* Send data from recvbuf. Recv into tmp_buf */
106         smpi_mpi_sendrecv((char *) recv_ptr +
107                      disps[send_idx] * extent,
108                      send_cnt, datatype,
109                      dst, tag,
110                      (char *) tmp_buf +
111                      disps[recv_idx] * extent,
112                      recv_cnt, datatype, dst, tag, comm, &status);
113
114         /* tmp_buf contains data received in this step.
115            recvbuf contains data accumulated so far */
116
117         smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
118                        (char *) recv_ptr + disps[recv_idx] * extent,
119                        &recv_cnt, &datatype);
120
121         /* update send_idx for next iteration */
122         send_idx = recv_idx;
123         mask <<= 1;
124
125         if (mask < pof2)
126           last_idx = recv_idx + pof2 / mask;
127       }
128     }
129
130     /* now do the gather to root */
131
132     if (root < 2 * rem) {
133       if (root % 2 != 0) {
134         if (rank == root) {
135           /* recv */
136           for (i = 0; i < (pof2 - 1); i++)
137             cnts[i] = new_count / pof2;
138           cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
139
140           disps[0] = 0;
141           for (i = 1; i < pof2; i++)
142             disps[i] = disps[i - 1] + cnts[i - 1];
143
144           smpi_mpi_recv(recv_ptr, cnts[0], datatype, 0, tag, comm, &status);
145
146           newrank = 0;
147           send_idx = 0;
148           last_idx = 2;
149         } else if (newrank == 0) {
150           smpi_mpi_send(recv_ptr, cnts[0], datatype, root, tag, comm);
151           newrank = -1;
152         }
153         newroot = 0;
154       } else
155         newroot = root / 2;
156     } else
157       newroot = root - rem;
158
159     if (newrank != -1) {
160       j = 0;
161       mask = 0x1;
162       while (mask < pof2) {
163         mask <<= 1;
164         j++;
165       }
166       mask >>= 1;
167       j--;
168       while (mask > 0) {
169         newdst = newrank ^ mask;
170
171         /* find real rank of dest */
172         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
173
174         if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
175           dst = root;
176         newdst_tree_root = newdst >> j;
177         newdst_tree_root <<= j;
178
179         newroot_tree_root = newroot >> j;
180         newroot_tree_root <<= j;
181
182         send_cnt = recv_cnt = 0;
183         if (newrank < newdst) {
184           /* update last_idx except on first iteration */
185           if (mask != pof2 / 2)
186             last_idx = last_idx + pof2 / (mask * 2);
187
188           recv_idx = send_idx + pof2 / (mask * 2);
189           for (i = send_idx; i < recv_idx; i++)
190             send_cnt += cnts[i];
191           for (i = recv_idx; i < last_idx; i++)
192             recv_cnt += cnts[i];
193         } else {
194           recv_idx = send_idx - pof2 / (mask * 2);
195           for (i = send_idx; i < last_idx; i++)
196             send_cnt += cnts[i];
197           for (i = recv_idx; i < send_idx; i++)
198             recv_cnt += cnts[i];
199         }
200
201         if (newdst_tree_root == newroot_tree_root) {
202           smpi_mpi_send((char *) recv_ptr +
203                    disps[send_idx] * extent,
204                    send_cnt, datatype, dst, tag, comm);
205           break;
206         } else {
207           smpi_mpi_recv((char *) recv_ptr +
208                    disps[recv_idx] * extent,
209                    recv_cnt, datatype, dst, tag, comm, &status);
210         }
211
212         if (newrank > newdst)
213           send_idx = recv_idx;
214
215         mask >>= 1;
216         j--;
217       }
218     }
219     memcpy(recvbuf, recv_ptr, extent * count);
220     free(send_ptr);
221     free(recv_ptr);
222   }
223
224
225   else /* (count >= comm_size) */ {
226     tmp_buf = (void *) xbt_malloc(count * extent);
227
228     //if ((rank != root))
229     smpi_mpi_sendrecv(sendbuf, count, datatype, rank, tag,
230                  recvbuf, count, datatype, rank, tag, comm, &status);
231
232     rem = comm_size - pof2;
233     if (rank < 2 * rem) {
234       if (rank % 2 != 0) {      /* odd */
235         smpi_mpi_send(recvbuf, count, datatype, rank - 1, tag, comm);
236         newrank = -1;
237       }
238
239       else {
240         smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
241         smpi_op_apply(op, tmp_buf, recvbuf, &count, &datatype);
242         newrank = rank / 2;
243       }
244     } else                      /* rank >= 2*rem */
245       newrank = rank - rem;
246
247     cnts = (int *) xbt_malloc(pof2 * sizeof(int));
248     disps = (int *) xbt_malloc(pof2 * sizeof(int));
249
250     if (newrank != -1) {
251       for (i = 0; i < (pof2 - 1); i++)
252         cnts[i] = count / pof2;
253       cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
254
255       disps[0] = 0;
256       for (i = 1; i < pof2; i++)
257         disps[i] = disps[i - 1] + cnts[i - 1];
258
259       mask = 0x1;
260       send_idx = recv_idx = 0;
261       last_idx = pof2;
262       while (mask < pof2) {
263         newdst = newrank ^ mask;
264         /* find real rank of dest */
265         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
266
267         send_cnt = recv_cnt = 0;
268         if (newrank < newdst) {
269           send_idx = recv_idx + pof2 / (mask * 2);
270           for (i = send_idx; i < last_idx; i++)
271             send_cnt += cnts[i];
272           for (i = recv_idx; i < send_idx; i++)
273             recv_cnt += cnts[i];
274         } else {
275           recv_idx = send_idx + pof2 / (mask * 2);
276           for (i = send_idx; i < recv_idx; i++)
277             send_cnt += cnts[i];
278           for (i = recv_idx; i < last_idx; i++)
279             recv_cnt += cnts[i];
280         }
281
282         /* Send data from recvbuf. Recv into tmp_buf */
283         smpi_mpi_sendrecv((char *) recvbuf +
284                      disps[send_idx] * extent,
285                      send_cnt, datatype,
286                      dst, tag,
287                      (char *) tmp_buf +
288                      disps[recv_idx] * extent,
289                      recv_cnt, datatype, dst, tag, comm, &status);
290
291         /* tmp_buf contains data received in this step.
292            recvbuf contains data accumulated so far */
293
294         smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
295                        (char *) recvbuf + disps[recv_idx] * extent,
296                        &recv_cnt, &datatype);
297
298         /* update send_idx for next iteration */
299         send_idx = recv_idx;
300         mask <<= 1;
301
302         if (mask < pof2)
303           last_idx = recv_idx + pof2 / mask;
304       }
305     }
306
307     /* now do the gather to root */
308
309     if (root < 2 * rem) {
310       if (root % 2 != 0) {
311         if (rank == root) {     /* recv */
312           for (i = 0; i < (pof2 - 1); i++)
313             cnts[i] = count / pof2;
314           cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
315
316           disps[0] = 0;
317           for (i = 1; i < pof2; i++)
318             disps[i] = disps[i - 1] + cnts[i - 1];
319
320           smpi_mpi_recv(recvbuf, cnts[0], datatype, 0, tag, comm, &status);
321
322           newrank = 0;
323           send_idx = 0;
324           last_idx = 2;
325         } else if (newrank == 0) {
326           smpi_mpi_send(recvbuf, cnts[0], datatype, root, tag, comm);
327           newrank = -1;
328         }
329         newroot = 0;
330       } else
331         newroot = root / 2;
332     } else
333       newroot = root - rem;
334
335     if (newrank != -1) {
336       j = 0;
337       mask = 0x1;
338       while (mask < pof2) {
339         mask <<= 1;
340         j++;
341       }
342       mask >>= 1;
343       j--;
344       while (mask > 0) {
345         newdst = newrank ^ mask;
346
347         /* find real rank of dest */
348         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
349
350         if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
351           dst = root;
352         newdst_tree_root = newdst >> j;
353         newdst_tree_root <<= j;
354
355         newroot_tree_root = newroot >> j;
356         newroot_tree_root <<= j;
357
358         send_cnt = recv_cnt = 0;
359         if (newrank < newdst) {
360           /* update last_idx except on first iteration */
361           if (mask != pof2 / 2)
362             last_idx = last_idx + pof2 / (mask * 2);
363
364           recv_idx = send_idx + pof2 / (mask * 2);
365           for (i = send_idx; i < recv_idx; i++)
366             send_cnt += cnts[i];
367           for (i = recv_idx; i < last_idx; i++)
368             recv_cnt += cnts[i];
369         } else {
370           recv_idx = send_idx - pof2 / (mask * 2);
371           for (i = send_idx; i < last_idx; i++)
372             send_cnt += cnts[i];
373           for (i = recv_idx; i < send_idx; i++)
374             recv_cnt += cnts[i];
375         }
376
377         if (newdst_tree_root == newroot_tree_root) {
378           smpi_mpi_send((char *) recvbuf +
379                    disps[send_idx] * extent,
380                    send_cnt, datatype, dst, tag, comm);
381           break;
382         } else {
383           smpi_mpi_recv((char *) recvbuf +
384                    disps[recv_idx] * extent,
385                    recv_cnt, datatype, dst, tag, comm, &status);
386         }
387
388         if (newrank > newdst)
389           send_idx = recv_idx;
390
391         mask >>= 1;
392         j--;
393       }
394     }
395   }
396   if (tmp_buf)
397     free(tmp_buf);
398   if (cnts)
399     free(cnts);
400   if (disps)
401     free(disps);
402
403   return 0;
404 }