Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'surf++'
[simgrid.git] / src / smpi / colls / reduce-scatter-gather.c
1 #include "colls_private.h"
2
3 /*
4   reduce
5   Author: MPICH
6  */
7
8 int smpi_coll_tuned_reduce_scatter_gather(void *sendbuf, void *recvbuf,
9                                           int count, MPI_Datatype datatype,
10                                           MPI_Op op, int root, MPI_Comm comm)
11 {
12   MPI_Status status;
13   int comm_size, rank, pof2, rem, newrank;
14   int mask, *cnts, *disps, i, j, send_idx = 0;
15   int recv_idx, last_idx = 0, newdst;
16   int dst, send_cnt, recv_cnt, newroot, newdst_tree_root;
17   int newroot_tree_root, new_count;
18   int tag = COLL_TAG_REDUCE;
19   void *send_ptr, *recv_ptr, *tmp_buf;
20
21   cnts = NULL;
22   disps = NULL;
23
24   MPI_Aint extent;
25
26   if (count == 0)
27     return 0;
28   rank = smpi_comm_rank(comm);
29   comm_size = smpi_comm_size(comm);
30
31   extent = smpi_datatype_get_extent(datatype);
32
33   /* find nearest power-of-two less than or equal to comm_size */
34   pof2 = 1;
35   while (pof2 <= comm_size)
36     pof2 <<= 1;
37   pof2 >>= 1;
38
39   if (count < comm_size) {
40     new_count = comm_size;
41     send_ptr = (void *) xbt_malloc(new_count * extent);
42     recv_ptr = (void *) xbt_malloc(new_count * extent);
43     tmp_buf = (void *) xbt_malloc(new_count * extent);
44     memcpy(send_ptr, sendbuf, extent * count);
45
46     //if ((rank != root))
47     smpi_mpi_sendrecv(send_ptr, new_count, datatype, rank, tag,
48                  recv_ptr, new_count, datatype, rank, tag, comm, &status);
49
50     rem = comm_size - pof2;
51     if (rank < 2 * rem) {
52       if (rank % 2 != 0) {
53         /* odd */
54         smpi_mpi_send(recv_ptr, new_count, datatype, rank - 1, tag, comm);
55         newrank = -1;
56       } else {
57         smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
58         smpi_op_apply(op, tmp_buf, recv_ptr, &new_count, &datatype);
59         newrank = rank / 2;
60       }
61     } else                      /* rank >= 2*rem */
62       newrank = rank - rem;
63
64     cnts = (int *) xbt_malloc(pof2 * sizeof(int));
65     disps = (int *) xbt_malloc(pof2 * sizeof(int));
66
67     if (newrank != -1) {
68       for (i = 0; i < (pof2 - 1); i++)
69         cnts[i] = new_count / pof2;
70       cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
71
72       disps[0] = 0;
73       for (i = 1; i < pof2; i++)
74         disps[i] = disps[i - 1] + cnts[i - 1];
75
76       mask = 0x1;
77       send_idx = recv_idx = 0;
78       last_idx = pof2;
79       while (mask < pof2) {
80         newdst = newrank ^ mask;
81         /* find real rank of dest */
82         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
83
84         send_cnt = recv_cnt = 0;
85         if (newrank < newdst) {
86           send_idx = recv_idx + pof2 / (mask * 2);
87           for (i = send_idx; i < last_idx; i++)
88             send_cnt += cnts[i];
89           for (i = recv_idx; i < send_idx; i++)
90             recv_cnt += cnts[i];
91         } else {
92           recv_idx = send_idx + pof2 / (mask * 2);
93           for (i = send_idx; i < recv_idx; i++)
94             send_cnt += cnts[i];
95           for (i = recv_idx; i < last_idx; i++)
96             recv_cnt += cnts[i];
97         }
98
99         /* Send data from recvbuf. Recv into tmp_buf */
100         smpi_mpi_sendrecv((char *) recv_ptr +
101                      disps[send_idx] * extent,
102                      send_cnt, datatype,
103                      dst, tag,
104                      (char *) tmp_buf +
105                      disps[recv_idx] * extent,
106                      recv_cnt, datatype, dst, tag, comm, &status);
107
108         /* tmp_buf contains data received in this step.
109            recvbuf contains data accumulated so far */
110
111         smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
112                        (char *) recv_ptr + disps[recv_idx] * extent,
113                        &recv_cnt, &datatype);
114
115         /* update send_idx for next iteration */
116         send_idx = recv_idx;
117         mask <<= 1;
118
119         if (mask < pof2)
120           last_idx = recv_idx + pof2 / mask;
121       }
122     }
123
124     /* now do the gather to root */
125
126     if (root < 2 * rem) {
127       if (root % 2 != 0) {
128         if (rank == root) {
129           /* recv */
130           for (i = 0; i < (pof2 - 1); i++)
131             cnts[i] = new_count / pof2;
132           cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
133
134           disps[0] = 0;
135           for (i = 1; i < pof2; i++)
136             disps[i] = disps[i - 1] + cnts[i - 1];
137
138           smpi_mpi_recv(recv_ptr, cnts[0], datatype, 0, tag, comm, &status);
139
140           newrank = 0;
141           send_idx = 0;
142           last_idx = 2;
143         } else if (newrank == 0) {
144           smpi_mpi_send(recv_ptr, cnts[0], datatype, root, tag, comm);
145           newrank = -1;
146         }
147         newroot = 0;
148       } else
149         newroot = root / 2;
150     } else
151       newroot = root - rem;
152
153     if (newrank != -1) {
154       j = 0;
155       mask = 0x1;
156       while (mask < pof2) {
157         mask <<= 1;
158         j++;
159       }
160       mask >>= 1;
161       j--;
162       while (mask > 0) {
163         newdst = newrank ^ mask;
164
165         /* find real rank of dest */
166         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
167
168         if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
169           dst = root;
170         newdst_tree_root = newdst >> j;
171         newdst_tree_root <<= j;
172
173         newroot_tree_root = newroot >> j;
174         newroot_tree_root <<= j;
175
176         send_cnt = recv_cnt = 0;
177         if (newrank < newdst) {
178           /* update last_idx except on first iteration */
179           if (mask != pof2 / 2)
180             last_idx = last_idx + pof2 / (mask * 2);
181
182           recv_idx = send_idx + pof2 / (mask * 2);
183           for (i = send_idx; i < recv_idx; i++)
184             send_cnt += cnts[i];
185           for (i = recv_idx; i < last_idx; i++)
186             recv_cnt += cnts[i];
187         } else {
188           recv_idx = send_idx - pof2 / (mask * 2);
189           for (i = send_idx; i < last_idx; i++)
190             send_cnt += cnts[i];
191           for (i = recv_idx; i < send_idx; i++)
192             recv_cnt += cnts[i];
193         }
194
195         if (newdst_tree_root == newroot_tree_root) {
196           smpi_mpi_send((char *) recv_ptr +
197                    disps[send_idx] * extent,
198                    send_cnt, datatype, dst, tag, comm);
199           break;
200         } else {
201           smpi_mpi_recv((char *) recv_ptr +
202                    disps[recv_idx] * extent,
203                    recv_cnt, datatype, dst, tag, comm, &status);
204         }
205
206         if (newrank > newdst)
207           send_idx = recv_idx;
208
209         mask >>= 1;
210         j--;
211       }
212     }
213     memcpy(recvbuf, recv_ptr, extent * count);
214     free(send_ptr);
215     free(recv_ptr);
216   }
217
218
219   else /* (count >= comm_size) */ {
220     tmp_buf = (void *) xbt_malloc(count * extent);
221
222     //if ((rank != root))
223     smpi_mpi_sendrecv(sendbuf, count, datatype, rank, tag,
224                  recvbuf, count, datatype, rank, tag, comm, &status);
225
226     rem = comm_size - pof2;
227     if (rank < 2 * rem) {
228       if (rank % 2 != 0) {      /* odd */
229         smpi_mpi_send(recvbuf, count, datatype, rank - 1, tag, comm);
230         newrank = -1;
231       }
232
233       else {
234         smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
235         smpi_op_apply(op, tmp_buf, recvbuf, &count, &datatype);
236         newrank = rank / 2;
237       }
238     } else                      /* rank >= 2*rem */
239       newrank = rank - rem;
240
241     cnts = (int *) xbt_malloc(pof2 * sizeof(int));
242     disps = (int *) xbt_malloc(pof2 * sizeof(int));
243
244     if (newrank != -1) {
245       for (i = 0; i < (pof2 - 1); i++)
246         cnts[i] = count / pof2;
247       cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
248
249       disps[0] = 0;
250       for (i = 1; i < pof2; i++)
251         disps[i] = disps[i - 1] + cnts[i - 1];
252
253       mask = 0x1;
254       send_idx = recv_idx = 0;
255       last_idx = pof2;
256       while (mask < pof2) {
257         newdst = newrank ^ mask;
258         /* find real rank of dest */
259         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
260
261         send_cnt = recv_cnt = 0;
262         if (newrank < newdst) {
263           send_idx = recv_idx + pof2 / (mask * 2);
264           for (i = send_idx; i < last_idx; i++)
265             send_cnt += cnts[i];
266           for (i = recv_idx; i < send_idx; i++)
267             recv_cnt += cnts[i];
268         } else {
269           recv_idx = send_idx + pof2 / (mask * 2);
270           for (i = send_idx; i < recv_idx; i++)
271             send_cnt += cnts[i];
272           for (i = recv_idx; i < last_idx; i++)
273             recv_cnt += cnts[i];
274         }
275
276         /* Send data from recvbuf. Recv into tmp_buf */
277         smpi_mpi_sendrecv((char *) recvbuf +
278                      disps[send_idx] * extent,
279                      send_cnt, datatype,
280                      dst, tag,
281                      (char *) tmp_buf +
282                      disps[recv_idx] * extent,
283                      recv_cnt, datatype, dst, tag, comm, &status);
284
285         /* tmp_buf contains data received in this step.
286            recvbuf contains data accumulated so far */
287
288         smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
289                        (char *) recvbuf + disps[recv_idx] * extent,
290                        &recv_cnt, &datatype);
291
292         /* update send_idx for next iteration */
293         send_idx = recv_idx;
294         mask <<= 1;
295
296         if (mask < pof2)
297           last_idx = recv_idx + pof2 / mask;
298       }
299     }
300
301     /* now do the gather to root */
302
303     if (root < 2 * rem) {
304       if (root % 2 != 0) {
305         if (rank == root) {     /* recv */
306           for (i = 0; i < (pof2 - 1); i++)
307             cnts[i] = count / pof2;
308           cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
309
310           disps[0] = 0;
311           for (i = 1; i < pof2; i++)
312             disps[i] = disps[i - 1] + cnts[i - 1];
313
314           smpi_mpi_recv(recvbuf, cnts[0], datatype, 0, tag, comm, &status);
315
316           newrank = 0;
317           send_idx = 0;
318           last_idx = 2;
319         } else if (newrank == 0) {
320           smpi_mpi_send(recvbuf, cnts[0], datatype, root, tag, comm);
321           newrank = -1;
322         }
323         newroot = 0;
324       } else
325         newroot = root / 2;
326     } else
327       newroot = root - rem;
328
329     if (newrank != -1) {
330       j = 0;
331       mask = 0x1;
332       while (mask < pof2) {
333         mask <<= 1;
334         j++;
335       }
336       mask >>= 1;
337       j--;
338       while (mask > 0) {
339         newdst = newrank ^ mask;
340
341         /* find real rank of dest */
342         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
343
344         if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
345           dst = root;
346         newdst_tree_root = newdst >> j;
347         newdst_tree_root <<= j;
348
349         newroot_tree_root = newroot >> j;
350         newroot_tree_root <<= j;
351
352         send_cnt = recv_cnt = 0;
353         if (newrank < newdst) {
354           /* update last_idx except on first iteration */
355           if (mask != pof2 / 2)
356             last_idx = last_idx + pof2 / (mask * 2);
357
358           recv_idx = send_idx + pof2 / (mask * 2);
359           for (i = send_idx; i < recv_idx; i++)
360             send_cnt += cnts[i];
361           for (i = recv_idx; i < last_idx; i++)
362             recv_cnt += cnts[i];
363         } else {
364           recv_idx = send_idx - pof2 / (mask * 2);
365           for (i = send_idx; i < last_idx; i++)
366             send_cnt += cnts[i];
367           for (i = recv_idx; i < send_idx; i++)
368             recv_cnt += cnts[i];
369         }
370
371         if (newdst_tree_root == newroot_tree_root) {
372           smpi_mpi_send((char *) recvbuf +
373                    disps[send_idx] * extent,
374                    send_cnt, datatype, dst, tag, comm);
375           break;
376         } else {
377           smpi_mpi_recv((char *) recvbuf +
378                    disps[recv_idx] * extent,
379                    recv_cnt, datatype, dst, tag, comm, &status);
380         }
381
382         if (newrank > newdst)
383           send_idx = recv_idx;
384
385         mask >>= 1;
386         j--;
387       }
388     }
389   }
390   if (tmp_buf)
391     free(tmp_buf);
392   if (cnts)
393     free(cnts);
394   if (disps)
395     free(disps);
396
397   return 0;
398 }