Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
MPI_Comm -> C++
[simgrid.git] / src / smpi / colls / reduce-scatter-gather.cpp
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "colls_private.h"
8
9 /*
10   reduce
11   Author: MPICH
12  */
13
14 int smpi_coll_tuned_reduce_scatter_gather(void *sendbuf, void *recvbuf,
15                                           int count, MPI_Datatype datatype,
16                                           MPI_Op op, int root, MPI_Comm comm)
17 {
18   MPI_Status status;
19   int comm_size, rank, pof2, rem, newrank;
20   int mask, *cnts, *disps, i, j, send_idx = 0;
21   int recv_idx, last_idx = 0, newdst;
22   int dst, send_cnt, recv_cnt, newroot, newdst_tree_root;
23   int newroot_tree_root, new_count;
24   int tag = COLL_TAG_REDUCE,temporary_buffer=0;
25   void *send_ptr, *recv_ptr, *tmp_buf;
26
27   cnts = NULL;
28   disps = NULL;
29
30   MPI_Aint extent;
31
32   if (count == 0)
33     return 0;
34   rank = comm->rank();
35   comm_size = comm->size();
36   
37
38
39   extent = smpi_datatype_get_extent(datatype);
40   /* If I'm not the root, then my recvbuf may not be valid, therefore
41   I have to allocate a temporary one */
42   if (rank != root && !recvbuf) {
43     temporary_buffer=1;
44     recvbuf = (void *)smpi_get_tmp_recvbuffer(count * extent);
45   }
46   /* find nearest power-of-two less than or equal to comm_size */
47   pof2 = 1;
48   while (pof2 <= comm_size)
49     pof2 <<= 1;
50   pof2 >>= 1;
51
52   if (count < comm_size) {
53     new_count = comm_size;
54     send_ptr = (void *) smpi_get_tmp_sendbuffer(new_count * extent);
55     recv_ptr = (void *) smpi_get_tmp_recvbuffer(new_count * extent);
56     tmp_buf = (void *) smpi_get_tmp_sendbuffer(new_count * extent);
57     memcpy(send_ptr, sendbuf != MPI_IN_PLACE ? sendbuf : recvbuf, extent * count);
58
59     //if ((rank != root))
60     smpi_mpi_sendrecv(send_ptr, new_count, datatype, rank, tag,
61                  recv_ptr, new_count, datatype, rank, tag, comm, &status);
62
63     rem = comm_size - pof2;
64     if (rank < 2 * rem) {
65       if (rank % 2 != 0) {
66         /* odd */
67         smpi_mpi_send(recv_ptr, new_count, datatype, rank - 1, tag, comm);
68         newrank = -1;
69       } else {
70         smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
71         smpi_op_apply(op, tmp_buf, recv_ptr, &new_count, &datatype);
72         newrank = rank / 2;
73       }
74     } else                      /* rank >= 2*rem */
75       newrank = rank - rem;
76
77     cnts = (int *) xbt_malloc(pof2 * sizeof(int));
78     disps = (int *) xbt_malloc(pof2 * sizeof(int));
79
80     if (newrank != -1) {
81       for (i = 0; i < (pof2 - 1); i++)
82         cnts[i] = new_count / pof2;
83       cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
84
85       disps[0] = 0;
86       for (i = 1; i < pof2; i++)
87         disps[i] = disps[i - 1] + cnts[i - 1];
88
89       mask = 0x1;
90       send_idx = recv_idx = 0;
91       last_idx = pof2;
92       while (mask < pof2) {
93         newdst = newrank ^ mask;
94         /* find real rank of dest */
95         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
96
97         send_cnt = recv_cnt = 0;
98         if (newrank < newdst) {
99           send_idx = recv_idx + pof2 / (mask * 2);
100           for (i = send_idx; i < last_idx; i++)
101             send_cnt += cnts[i];
102           for (i = recv_idx; i < send_idx; i++)
103             recv_cnt += cnts[i];
104         } else {
105           recv_idx = send_idx + pof2 / (mask * 2);
106           for (i = send_idx; i < recv_idx; i++)
107             send_cnt += cnts[i];
108           for (i = recv_idx; i < last_idx; i++)
109             recv_cnt += cnts[i];
110         }
111
112         /* Send data from recvbuf. Recv into tmp_buf */
113         smpi_mpi_sendrecv((char *) recv_ptr +
114                      disps[send_idx] * extent,
115                      send_cnt, datatype,
116                      dst, tag,
117                      (char *) tmp_buf +
118                      disps[recv_idx] * extent,
119                      recv_cnt, datatype, dst, tag, comm, &status);
120
121         /* tmp_buf contains data received in this step.
122            recvbuf contains data accumulated so far */
123
124         smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
125                        (char *) recv_ptr + disps[recv_idx] * extent,
126                        &recv_cnt, &datatype);
127
128         /* update send_idx for next iteration */
129         send_idx = recv_idx;
130         mask <<= 1;
131
132         if (mask < pof2)
133           last_idx = recv_idx + pof2 / mask;
134       }
135     }
136
137     /* now do the gather to root */
138
139     if (root < 2 * rem) {
140       if (root % 2 != 0) {
141         if (rank == root) {
142           /* recv */
143           for (i = 0; i < (pof2 - 1); i++)
144             cnts[i] = new_count / pof2;
145           cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
146
147           disps[0] = 0;
148           for (i = 1; i < pof2; i++)
149             disps[i] = disps[i - 1] + cnts[i - 1];
150
151           smpi_mpi_recv(recv_ptr, cnts[0], datatype, 0, tag, comm, &status);
152
153           newrank = 0;
154           send_idx = 0;
155           last_idx = 2;
156         } else if (newrank == 0) {
157           smpi_mpi_send(recv_ptr, cnts[0], datatype, root, tag, comm);
158           newrank = -1;
159         }
160         newroot = 0;
161       } else
162         newroot = root / 2;
163     } else
164       newroot = root - rem;
165
166     if (newrank != -1) {
167       j = 0;
168       mask = 0x1;
169       while (mask < pof2) {
170         mask <<= 1;
171         j++;
172       }
173       mask >>= 1;
174       j--;
175       while (mask > 0) {
176         newdst = newrank ^ mask;
177
178         /* find real rank of dest */
179         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
180
181         if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
182           dst = root;
183         newdst_tree_root = newdst >> j;
184         newdst_tree_root <<= j;
185
186         newroot_tree_root = newroot >> j;
187         newroot_tree_root <<= j;
188
189         send_cnt = recv_cnt = 0;
190         if (newrank < newdst) {
191           /* update last_idx except on first iteration */
192           if (mask != pof2 / 2)
193             last_idx = last_idx + pof2 / (mask * 2);
194
195           recv_idx = send_idx + pof2 / (mask * 2);
196           for (i = send_idx; i < recv_idx; i++)
197             send_cnt += cnts[i];
198           for (i = recv_idx; i < last_idx; i++)
199             recv_cnt += cnts[i];
200         } else {
201           recv_idx = send_idx - pof2 / (mask * 2);
202           for (i = send_idx; i < last_idx; i++)
203             send_cnt += cnts[i];
204           for (i = recv_idx; i < send_idx; i++)
205             recv_cnt += cnts[i];
206         }
207
208         if (newdst_tree_root == newroot_tree_root) {
209           smpi_mpi_send((char *) recv_ptr +
210                    disps[send_idx] * extent,
211                    send_cnt, datatype, dst, tag, comm);
212           break;
213         } else {
214           smpi_mpi_recv((char *) recv_ptr +
215                    disps[recv_idx] * extent,
216                    recv_cnt, datatype, dst, tag, comm, &status);
217         }
218
219         if (newrank > newdst)
220           send_idx = recv_idx;
221
222         mask >>= 1;
223         j--;
224       }
225     }
226     memcpy(recvbuf, recv_ptr, extent * count);
227     smpi_free_tmp_buffer(send_ptr);
228     smpi_free_tmp_buffer(recv_ptr);
229   }
230
231
232   else /* (count >= comm_size) */ {
233     tmp_buf = (void *) smpi_get_tmp_sendbuffer(count * extent);
234
235     //if ((rank != root))
236     smpi_mpi_sendrecv(sendbuf != MPI_IN_PLACE ? sendbuf : recvbuf, count, datatype, rank, tag,
237                  recvbuf, count, datatype, rank, tag, comm, &status);
238
239     rem = comm_size - pof2;
240     if (rank < 2 * rem) {
241       if (rank % 2 != 0) {      /* odd */
242         smpi_mpi_send(recvbuf, count, datatype, rank - 1, tag, comm);
243         newrank = -1;
244       }
245
246       else {
247         smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
248         smpi_op_apply(op, tmp_buf, recvbuf, &count, &datatype);
249         newrank = rank / 2;
250       }
251     } else                      /* rank >= 2*rem */
252       newrank = rank - rem;
253
254     cnts = (int *) xbt_malloc(pof2 * sizeof(int));
255     disps = (int *) xbt_malloc(pof2 * sizeof(int));
256
257     if (newrank != -1) {
258       for (i = 0; i < (pof2 - 1); i++)
259         cnts[i] = count / pof2;
260       cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
261
262       disps[0] = 0;
263       for (i = 1; i < pof2; i++)
264         disps[i] = disps[i - 1] + cnts[i - 1];
265
266       mask = 0x1;
267       send_idx = recv_idx = 0;
268       last_idx = pof2;
269       while (mask < pof2) {
270         newdst = newrank ^ mask;
271         /* find real rank of dest */
272         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
273
274         send_cnt = recv_cnt = 0;
275         if (newrank < newdst) {
276           send_idx = recv_idx + pof2 / (mask * 2);
277           for (i = send_idx; i < last_idx; i++)
278             send_cnt += cnts[i];
279           for (i = recv_idx; i < send_idx; i++)
280             recv_cnt += cnts[i];
281         } else {
282           recv_idx = send_idx + pof2 / (mask * 2);
283           for (i = send_idx; i < recv_idx; i++)
284             send_cnt += cnts[i];
285           for (i = recv_idx; i < last_idx; i++)
286             recv_cnt += cnts[i];
287         }
288
289         /* Send data from recvbuf. Recv into tmp_buf */
290         smpi_mpi_sendrecv((char *) recvbuf +
291                      disps[send_idx] * extent,
292                      send_cnt, datatype,
293                      dst, tag,
294                      (char *) tmp_buf +
295                      disps[recv_idx] * extent,
296                      recv_cnt, datatype, dst, tag, comm, &status);
297
298         /* tmp_buf contains data received in this step.
299            recvbuf contains data accumulated so far */
300
301         smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
302                        (char *) recvbuf + disps[recv_idx] * extent,
303                        &recv_cnt, &datatype);
304
305         /* update send_idx for next iteration */
306         send_idx = recv_idx;
307         mask <<= 1;
308
309         if (mask < pof2)
310           last_idx = recv_idx + pof2 / mask;
311       }
312     }
313
314     /* now do the gather to root */
315
316     if (root < 2 * rem) {
317       if (root % 2 != 0) {
318         if (rank == root) {     /* recv */
319           for (i = 0; i < (pof2 - 1); i++)
320             cnts[i] = count / pof2;
321           cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
322
323           disps[0] = 0;
324           for (i = 1; i < pof2; i++)
325             disps[i] = disps[i - 1] + cnts[i - 1];
326
327           smpi_mpi_recv(recvbuf, cnts[0], datatype, 0, tag, comm, &status);
328
329           newrank = 0;
330           send_idx = 0;
331           last_idx = 2;
332         } else if (newrank == 0) {
333           smpi_mpi_send(recvbuf, cnts[0], datatype, root, tag, comm);
334           newrank = -1;
335         }
336         newroot = 0;
337       } else
338         newroot = root / 2;
339     } else
340       newroot = root - rem;
341
342     if (newrank != -1) {
343       j = 0;
344       mask = 0x1;
345       while (mask < pof2) {
346         mask <<= 1;
347         j++;
348       }
349       mask >>= 1;
350       j--;
351       while (mask > 0) {
352         newdst = newrank ^ mask;
353
354         /* find real rank of dest */
355         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
356
357         if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
358           dst = root;
359         newdst_tree_root = newdst >> j;
360         newdst_tree_root <<= j;
361
362         newroot_tree_root = newroot >> j;
363         newroot_tree_root <<= j;
364
365         send_cnt = recv_cnt = 0;
366         if (newrank < newdst) {
367           /* update last_idx except on first iteration */
368           if (mask != pof2 / 2)
369             last_idx = last_idx + pof2 / (mask * 2);
370
371           recv_idx = send_idx + pof2 / (mask * 2);
372           for (i = send_idx; i < recv_idx; i++)
373             send_cnt += cnts[i];
374           for (i = recv_idx; i < last_idx; i++)
375             recv_cnt += cnts[i];
376         } else {
377           recv_idx = send_idx - pof2 / (mask * 2);
378           for (i = send_idx; i < last_idx; i++)
379             send_cnt += cnts[i];
380           for (i = recv_idx; i < send_idx; i++)
381             recv_cnt += cnts[i];
382         }
383
384         if (newdst_tree_root == newroot_tree_root) {
385           smpi_mpi_send((char *) recvbuf +
386                    disps[send_idx] * extent,
387                    send_cnt, datatype, dst, tag, comm);
388           break;
389         } else {
390           smpi_mpi_recv((char *) recvbuf +
391                    disps[recv_idx] * extent,
392                    recv_cnt, datatype, dst, tag, comm, &status);
393         }
394
395         if (newrank > newdst)
396           send_idx = recv_idx;
397
398         mask >>= 1;
399         j--;
400       }
401     }
402   }
403   if (tmp_buf)
404     smpi_free_tmp_buffer(tmp_buf);
405   if(temporary_buffer==1) smpi_free_tmp_buffer(recvbuf);
406   if (cnts)
407     free(cnts);
408   if (disps)
409     free(disps);
410
411   return 0;
412 }