Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add Reduce SMP collective from MVAPICH2
[simgrid.git] / src / smpi / colls / reduce-scatter-gather.c
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "colls_private.h"
8
9 /*
10   reduce
11   Author: MPICH
12  */
13
14 int smpi_coll_tuned_reduce_scatter_gather(void *sendbuf, void *recvbuf,
15                                           int count, MPI_Datatype datatype,
16                                           MPI_Op op, int root, MPI_Comm comm)
17 {
18   MPI_Status status;
19   int comm_size, rank, pof2, rem, newrank;
20   int mask, *cnts, *disps, i, j, send_idx = 0;
21   int recv_idx, last_idx = 0, newdst;
22   int dst, send_cnt, recv_cnt, newroot, newdst_tree_root;
23   int newroot_tree_root, new_count;
24   int tag = COLL_TAG_REDUCE;
25   void *send_ptr, *recv_ptr, *tmp_buf;
26
27   cnts = NULL;
28   disps = NULL;
29
30   MPI_Aint extent;
31
32   if (count == 0)
33     return 0;
34   rank = smpi_comm_rank(comm);
35   comm_size = smpi_comm_size(comm);
36   
37
38
39   extent = smpi_datatype_get_extent(datatype);
40   /* If I'm not the root, then my recvbuf may not be valid, therefore
41   I have to allocate a temporary one */
42   if (rank != root && !recvbuf) {
43     recvbuf = (void *)xbt_malloc(count * extent);
44   }
45   /* find nearest power-of-two less than or equal to comm_size */
46   pof2 = 1;
47   while (pof2 <= comm_size)
48     pof2 <<= 1;
49   pof2 >>= 1;
50
51   if (count < comm_size) {
52     new_count = comm_size;
53     send_ptr = (void *) xbt_malloc(new_count * extent);
54     recv_ptr = (void *) xbt_malloc(new_count * extent);
55     tmp_buf = (void *) xbt_malloc(new_count * extent);
56     memcpy(send_ptr, sendbuf, extent * count);
57
58     //if ((rank != root))
59     smpi_mpi_sendrecv(send_ptr, new_count, datatype, rank, tag,
60                  recv_ptr, new_count, datatype, rank, tag, comm, &status);
61
62     rem = comm_size - pof2;
63     if (rank < 2 * rem) {
64       if (rank % 2 != 0) {
65         /* odd */
66         smpi_mpi_send(recv_ptr, new_count, datatype, rank - 1, tag, comm);
67         newrank = -1;
68       } else {
69         smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
70         smpi_op_apply(op, tmp_buf, recv_ptr, &new_count, &datatype);
71         newrank = rank / 2;
72       }
73     } else                      /* rank >= 2*rem */
74       newrank = rank - rem;
75
76     cnts = (int *) xbt_malloc(pof2 * sizeof(int));
77     disps = (int *) xbt_malloc(pof2 * sizeof(int));
78
79     if (newrank != -1) {
80       for (i = 0; i < (pof2 - 1); i++)
81         cnts[i] = new_count / pof2;
82       cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
83
84       disps[0] = 0;
85       for (i = 1; i < pof2; i++)
86         disps[i] = disps[i - 1] + cnts[i - 1];
87
88       mask = 0x1;
89       send_idx = recv_idx = 0;
90       last_idx = pof2;
91       while (mask < pof2) {
92         newdst = newrank ^ mask;
93         /* find real rank of dest */
94         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
95
96         send_cnt = recv_cnt = 0;
97         if (newrank < newdst) {
98           send_idx = recv_idx + pof2 / (mask * 2);
99           for (i = send_idx; i < last_idx; i++)
100             send_cnt += cnts[i];
101           for (i = recv_idx; i < send_idx; i++)
102             recv_cnt += cnts[i];
103         } else {
104           recv_idx = send_idx + pof2 / (mask * 2);
105           for (i = send_idx; i < recv_idx; i++)
106             send_cnt += cnts[i];
107           for (i = recv_idx; i < last_idx; i++)
108             recv_cnt += cnts[i];
109         }
110
111         /* Send data from recvbuf. Recv into tmp_buf */
112         smpi_mpi_sendrecv((char *) recv_ptr +
113                      disps[send_idx] * extent,
114                      send_cnt, datatype,
115                      dst, tag,
116                      (char *) tmp_buf +
117                      disps[recv_idx] * extent,
118                      recv_cnt, datatype, dst, tag, comm, &status);
119
120         /* tmp_buf contains data received in this step.
121            recvbuf contains data accumulated so far */
122
123         smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
124                        (char *) recv_ptr + disps[recv_idx] * extent,
125                        &recv_cnt, &datatype);
126
127         /* update send_idx for next iteration */
128         send_idx = recv_idx;
129         mask <<= 1;
130
131         if (mask < pof2)
132           last_idx = recv_idx + pof2 / mask;
133       }
134     }
135
136     /* now do the gather to root */
137
138     if (root < 2 * rem) {
139       if (root % 2 != 0) {
140         if (rank == root) {
141           /* recv */
142           for (i = 0; i < (pof2 - 1); i++)
143             cnts[i] = new_count / pof2;
144           cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
145
146           disps[0] = 0;
147           for (i = 1; i < pof2; i++)
148             disps[i] = disps[i - 1] + cnts[i - 1];
149
150           smpi_mpi_recv(recv_ptr, cnts[0], datatype, 0, tag, comm, &status);
151
152           newrank = 0;
153           send_idx = 0;
154           last_idx = 2;
155         } else if (newrank == 0) {
156           smpi_mpi_send(recv_ptr, cnts[0], datatype, root, tag, comm);
157           newrank = -1;
158         }
159         newroot = 0;
160       } else
161         newroot = root / 2;
162     } else
163       newroot = root - rem;
164
165     if (newrank != -1) {
166       j = 0;
167       mask = 0x1;
168       while (mask < pof2) {
169         mask <<= 1;
170         j++;
171       }
172       mask >>= 1;
173       j--;
174       while (mask > 0) {
175         newdst = newrank ^ mask;
176
177         /* find real rank of dest */
178         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
179
180         if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
181           dst = root;
182         newdst_tree_root = newdst >> j;
183         newdst_tree_root <<= j;
184
185         newroot_tree_root = newroot >> j;
186         newroot_tree_root <<= j;
187
188         send_cnt = recv_cnt = 0;
189         if (newrank < newdst) {
190           /* update last_idx except on first iteration */
191           if (mask != pof2 / 2)
192             last_idx = last_idx + pof2 / (mask * 2);
193
194           recv_idx = send_idx + pof2 / (mask * 2);
195           for (i = send_idx; i < recv_idx; i++)
196             send_cnt += cnts[i];
197           for (i = recv_idx; i < last_idx; i++)
198             recv_cnt += cnts[i];
199         } else {
200           recv_idx = send_idx - pof2 / (mask * 2);
201           for (i = send_idx; i < last_idx; i++)
202             send_cnt += cnts[i];
203           for (i = recv_idx; i < send_idx; i++)
204             recv_cnt += cnts[i];
205         }
206
207         if (newdst_tree_root == newroot_tree_root) {
208           smpi_mpi_send((char *) recv_ptr +
209                    disps[send_idx] * extent,
210                    send_cnt, datatype, dst, tag, comm);
211           break;
212         } else {
213           smpi_mpi_recv((char *) recv_ptr +
214                    disps[recv_idx] * extent,
215                    recv_cnt, datatype, dst, tag, comm, &status);
216         }
217
218         if (newrank > newdst)
219           send_idx = recv_idx;
220
221         mask >>= 1;
222         j--;
223       }
224     }
225     memcpy(recvbuf, recv_ptr, extent * count);
226     free(send_ptr);
227     free(recv_ptr);
228   }
229
230
231   else /* (count >= comm_size) */ {
232     tmp_buf = (void *) xbt_malloc(count * extent);
233
234     //if ((rank != root))
235     smpi_mpi_sendrecv(sendbuf, count, datatype, rank, tag,
236                  recvbuf, count, datatype, rank, tag, comm, &status);
237
238     rem = comm_size - pof2;
239     if (rank < 2 * rem) {
240       if (rank % 2 != 0) {      /* odd */
241         smpi_mpi_send(recvbuf, count, datatype, rank - 1, tag, comm);
242         newrank = -1;
243       }
244
245       else {
246         smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
247         smpi_op_apply(op, tmp_buf, recvbuf, &count, &datatype);
248         newrank = rank / 2;
249       }
250     } else                      /* rank >= 2*rem */
251       newrank = rank - rem;
252
253     cnts = (int *) xbt_malloc(pof2 * sizeof(int));
254     disps = (int *) xbt_malloc(pof2 * sizeof(int));
255
256     if (newrank != -1) {
257       for (i = 0; i < (pof2 - 1); i++)
258         cnts[i] = count / pof2;
259       cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
260
261       disps[0] = 0;
262       for (i = 1; i < pof2; i++)
263         disps[i] = disps[i - 1] + cnts[i - 1];
264
265       mask = 0x1;
266       send_idx = recv_idx = 0;
267       last_idx = pof2;
268       while (mask < pof2) {
269         newdst = newrank ^ mask;
270         /* find real rank of dest */
271         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
272
273         send_cnt = recv_cnt = 0;
274         if (newrank < newdst) {
275           send_idx = recv_idx + pof2 / (mask * 2);
276           for (i = send_idx; i < last_idx; i++)
277             send_cnt += cnts[i];
278           for (i = recv_idx; i < send_idx; i++)
279             recv_cnt += cnts[i];
280         } else {
281           recv_idx = send_idx + pof2 / (mask * 2);
282           for (i = send_idx; i < recv_idx; i++)
283             send_cnt += cnts[i];
284           for (i = recv_idx; i < last_idx; i++)
285             recv_cnt += cnts[i];
286         }
287
288         /* Send data from recvbuf. Recv into tmp_buf */
289         smpi_mpi_sendrecv((char *) recvbuf +
290                      disps[send_idx] * extent,
291                      send_cnt, datatype,
292                      dst, tag,
293                      (char *) tmp_buf +
294                      disps[recv_idx] * extent,
295                      recv_cnt, datatype, dst, tag, comm, &status);
296
297         /* tmp_buf contains data received in this step.
298            recvbuf contains data accumulated so far */
299
300         smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
301                        (char *) recvbuf + disps[recv_idx] * extent,
302                        &recv_cnt, &datatype);
303
304         /* update send_idx for next iteration */
305         send_idx = recv_idx;
306         mask <<= 1;
307
308         if (mask < pof2)
309           last_idx = recv_idx + pof2 / mask;
310       }
311     }
312
313     /* now do the gather to root */
314
315     if (root < 2 * rem) {
316       if (root % 2 != 0) {
317         if (rank == root) {     /* recv */
318           for (i = 0; i < (pof2 - 1); i++)
319             cnts[i] = count / pof2;
320           cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
321
322           disps[0] = 0;
323           for (i = 1; i < pof2; i++)
324             disps[i] = disps[i - 1] + cnts[i - 1];
325
326           smpi_mpi_recv(recvbuf, cnts[0], datatype, 0, tag, comm, &status);
327
328           newrank = 0;
329           send_idx = 0;
330           last_idx = 2;
331         } else if (newrank == 0) {
332           smpi_mpi_send(recvbuf, cnts[0], datatype, root, tag, comm);
333           newrank = -1;
334         }
335         newroot = 0;
336       } else
337         newroot = root / 2;
338     } else
339       newroot = root - rem;
340
341     if (newrank != -1) {
342       j = 0;
343       mask = 0x1;
344       while (mask < pof2) {
345         mask <<= 1;
346         j++;
347       }
348       mask >>= 1;
349       j--;
350       while (mask > 0) {
351         newdst = newrank ^ mask;
352
353         /* find real rank of dest */
354         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
355
356         if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
357           dst = root;
358         newdst_tree_root = newdst >> j;
359         newdst_tree_root <<= j;
360
361         newroot_tree_root = newroot >> j;
362         newroot_tree_root <<= j;
363
364         send_cnt = recv_cnt = 0;
365         if (newrank < newdst) {
366           /* update last_idx except on first iteration */
367           if (mask != pof2 / 2)
368             last_idx = last_idx + pof2 / (mask * 2);
369
370           recv_idx = send_idx + pof2 / (mask * 2);
371           for (i = send_idx; i < recv_idx; i++)
372             send_cnt += cnts[i];
373           for (i = recv_idx; i < last_idx; i++)
374             recv_cnt += cnts[i];
375         } else {
376           recv_idx = send_idx - pof2 / (mask * 2);
377           for (i = send_idx; i < last_idx; i++)
378             send_cnt += cnts[i];
379           for (i = recv_idx; i < send_idx; i++)
380             recv_cnt += cnts[i];
381         }
382
383         if (newdst_tree_root == newroot_tree_root) {
384           smpi_mpi_send((char *) recvbuf +
385                    disps[send_idx] * extent,
386                    send_cnt, datatype, dst, tag, comm);
387           break;
388         } else {
389           smpi_mpi_recv((char *) recvbuf +
390                    disps[recv_idx] * extent,
391                    recv_cnt, datatype, dst, tag, comm, &status);
392         }
393
394         if (newrank > newdst)
395           send_idx = recv_idx;
396
397         mask >>= 1;
398         j--;
399       }
400     }
401   }
402   if (tmp_buf)
403     free(tmp_buf);
404   if (cnts)
405     free(cnts);
406   if (disps)
407     free(disps);
408
409   return 0;
410 }