Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
4dbcbf1327f788438a16490ff8ff54bfd5737092
[simgrid.git] / src / smpi / colls / reduce-scatter-gather.c
1 #include "colls_private.h"
2
3 /*
4   reduce
5   Author: MPICH
6  */
7
8 int smpi_coll_tuned_reduce_scatter_gather(void *sendbuf, void *recvbuf,
9                                           int count, MPI_Datatype datatype,
10                                           MPI_Op op, int root, MPI_Comm comm)
11 {
12   MPI_Status status;
13   int comm_size, rank, type_size, pof2, rem, newrank;
14   int mask, *cnts, *disps, i, j, send_idx = 0;
15   int recv_idx, last_idx = 0, newdst;
16   int dst, send_cnt, recv_cnt, newroot, newdst_tree_root;
17   int newroot_tree_root, new_count;
18   int tag = 4321;
19   void *send_ptr, *recv_ptr, *tmp_buf;
20
21   cnts = NULL;
22   disps = NULL;
23
24   MPI_Aint extent;
25
26   if (count == 0)
27     return 0;
28   rank = smpi_comm_rank(comm);
29   comm_size = smpi_comm_size(comm);
30
31   extent = smpi_datatype_get_extent(datatype);
32   type_size = smpi_datatype_size(datatype);
33
34   /* find nearest power-of-two less than or equal to comm_size */
35   pof2 = 1;
36   while (pof2 <= comm_size)
37     pof2 <<= 1;
38   pof2 >>= 1;
39
40   if (count < comm_size) {
41     new_count = comm_size;
42     send_ptr = (void *) xbt_malloc(new_count * extent);
43     recv_ptr = (void *) xbt_malloc(new_count * extent);
44     tmp_buf = (void *) xbt_malloc(new_count * extent);
45     memcpy(send_ptr, sendbuf, extent * new_count);
46
47     //if ((rank != root))
48     smpi_mpi_sendrecv(send_ptr, new_count, datatype, rank, tag,
49                  recv_ptr, new_count, datatype, rank, tag, comm, &status);
50
51     rem = comm_size - pof2;
52     if (rank < 2 * rem) {
53       if (rank % 2 != 0) {
54         /* odd */
55         smpi_mpi_send(recv_ptr, new_count, datatype, rank - 1, tag, comm);
56         newrank = -1;
57       } else {
58         smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
59         star_reduction(op, tmp_buf, recv_ptr, &new_count, &datatype);
60         newrank = rank / 2;
61       }
62     } else                      /* rank >= 2*rem */
63       newrank = rank - rem;
64
65     cnts = (int *) xbt_malloc(pof2 * sizeof(int));
66     disps = (int *) xbt_malloc(pof2 * sizeof(int));
67
68     if (newrank != -1) {
69       for (i = 0; i < (pof2 - 1); i++)
70         cnts[i] = new_count / pof2;
71       cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
72
73       disps[0] = 0;
74       for (i = 1; i < pof2; i++)
75         disps[i] = disps[i - 1] + cnts[i - 1];
76
77       mask = 0x1;
78       send_idx = recv_idx = 0;
79       last_idx = pof2;
80       while (mask < pof2) {
81         newdst = newrank ^ mask;
82         /* find real rank of dest */
83         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
84
85         send_cnt = recv_cnt = 0;
86         if (newrank < newdst) {
87           send_idx = recv_idx + pof2 / (mask * 2);
88           for (i = send_idx; i < last_idx; i++)
89             send_cnt += cnts[i];
90           for (i = recv_idx; i < send_idx; i++)
91             recv_cnt += cnts[i];
92         } else {
93           recv_idx = send_idx + pof2 / (mask * 2);
94           for (i = send_idx; i < recv_idx; i++)
95             send_cnt += cnts[i];
96           for (i = recv_idx; i < last_idx; i++)
97             recv_cnt += cnts[i];
98         }
99
100         /* Send data from recvbuf. Recv into tmp_buf */
101         smpi_mpi_sendrecv((char *) recv_ptr +
102                      disps[send_idx] * extent,
103                      send_cnt, datatype,
104                      dst, tag,
105                      (char *) tmp_buf +
106                      disps[recv_idx] * extent,
107                      recv_cnt, datatype, dst, tag, comm, &status);
108
109         /* tmp_buf contains data received in this step.
110            recvbuf contains data accumulated so far */
111
112         star_reduction(op, (char *) tmp_buf + disps[recv_idx] * extent,
113                        (char *) recv_ptr + disps[recv_idx] * extent,
114                        &recv_cnt, &datatype);
115
116         /* update send_idx for next iteration */
117         send_idx = recv_idx;
118         mask <<= 1;
119
120         if (mask < pof2)
121           last_idx = recv_idx + pof2 / mask;
122       }
123     }
124
125     /* now do the gather to root */
126
127     if (root < 2 * rem) {
128       if (root % 2 != 0) {
129         if (rank == root) {
130           /* recv */
131           for (i = 0; i < (pof2 - 1); i++)
132             cnts[i] = new_count / pof2;
133           cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
134
135           disps[0] = 0;
136           for (i = 1; i < pof2; i++)
137             disps[i] = disps[i - 1] + cnts[i - 1];
138
139           smpi_mpi_recv(recv_ptr, cnts[0], datatype, 0, tag, comm, &status);
140
141           newrank = 0;
142           send_idx = 0;
143           last_idx = 2;
144         } else if (newrank == 0) {
145           smpi_mpi_send(recv_ptr, cnts[0], datatype, root, tag, comm);
146           newrank = -1;
147         }
148         newroot = 0;
149       } else
150         newroot = root / 2;
151     } else
152       newroot = root - rem;
153
154     if (newrank != -1) {
155       j = 0;
156       mask = 0x1;
157       while (mask < pof2) {
158         mask <<= 1;
159         j++;
160       }
161       mask >>= 1;
162       j--;
163       while (mask > 0) {
164         newdst = newrank ^ mask;
165
166         /* find real rank of dest */
167         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
168
169         if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
170           dst = root;
171         newdst_tree_root = newdst >> j;
172         newdst_tree_root <<= j;
173
174         newroot_tree_root = newroot >> j;
175         newroot_tree_root <<= j;
176
177         send_cnt = recv_cnt = 0;
178         if (newrank < newdst) {
179           /* update last_idx except on first iteration */
180           if (mask != pof2 / 2)
181             last_idx = last_idx + pof2 / (mask * 2);
182
183           recv_idx = send_idx + pof2 / (mask * 2);
184           for (i = send_idx; i < recv_idx; i++)
185             send_cnt += cnts[i];
186           for (i = recv_idx; i < last_idx; i++)
187             recv_cnt += cnts[i];
188         } else {
189           recv_idx = send_idx - pof2 / (mask * 2);
190           for (i = send_idx; i < last_idx; i++)
191             send_cnt += cnts[i];
192           for (i = recv_idx; i < send_idx; i++)
193             recv_cnt += cnts[i];
194         }
195
196         if (newdst_tree_root == newroot_tree_root) {
197           smpi_mpi_send((char *) recv_ptr +
198                    disps[send_idx] * extent,
199                    send_cnt, datatype, dst, tag, comm);
200           break;
201         } else {
202           smpi_mpi_recv((char *) recv_ptr +
203                    disps[recv_idx] * extent,
204                    recv_cnt, datatype, dst, tag, comm, &status);
205         }
206
207         if (newrank > newdst)
208           send_idx = recv_idx;
209
210         mask >>= 1;
211         j--;
212       }
213     }
214     memcpy(recvbuf, recv_ptr, extent * count);
215     free(send_ptr);
216     free(recv_ptr);
217   }
218
219
220   else if (count >= comm_size) {
221     tmp_buf = (void *) xbt_malloc(count * extent);
222
223     //if ((rank != root))
224     smpi_mpi_sendrecv(sendbuf, count, datatype, rank, tag,
225                  recvbuf, count, datatype, rank, tag, comm, &status);
226
227     rem = comm_size - pof2;
228     if (rank < 2 * rem) {
229       if (rank % 2 != 0) {      /* odd */
230         smpi_mpi_send(recvbuf, count, datatype, rank - 1, tag, comm);
231         newrank = -1;
232       }
233
234       else {
235         smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
236         star_reduction(op, tmp_buf, recvbuf, &count, &datatype);
237         newrank = rank / 2;
238       }
239     } else                      /* rank >= 2*rem */
240       newrank = rank - rem;
241
242     cnts = (int *) xbt_malloc(pof2 * sizeof(int));
243     disps = (int *) xbt_malloc(pof2 * sizeof(int));
244
245     if (newrank != -1) {
246       for (i = 0; i < (pof2 - 1); i++)
247         cnts[i] = count / pof2;
248       cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
249
250       disps[0] = 0;
251       for (i = 1; i < pof2; i++)
252         disps[i] = disps[i - 1] + cnts[i - 1];
253
254       mask = 0x1;
255       send_idx = recv_idx = 0;
256       last_idx = pof2;
257       while (mask < pof2) {
258         newdst = newrank ^ mask;
259         /* find real rank of dest */
260         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
261
262         send_cnt = recv_cnt = 0;
263         if (newrank < newdst) {
264           send_idx = recv_idx + pof2 / (mask * 2);
265           for (i = send_idx; i < last_idx; i++)
266             send_cnt += cnts[i];
267           for (i = recv_idx; i < send_idx; i++)
268             recv_cnt += cnts[i];
269         } else {
270           recv_idx = send_idx + pof2 / (mask * 2);
271           for (i = send_idx; i < recv_idx; i++)
272             send_cnt += cnts[i];
273           for (i = recv_idx; i < last_idx; i++)
274             recv_cnt += cnts[i];
275         }
276
277         /* Send data from recvbuf. Recv into tmp_buf */
278         smpi_mpi_sendrecv((char *) recvbuf +
279                      disps[send_idx] * extent,
280                      send_cnt, datatype,
281                      dst, tag,
282                      (char *) tmp_buf +
283                      disps[recv_idx] * extent,
284                      recv_cnt, datatype, dst, tag, comm, &status);
285
286         /* tmp_buf contains data received in this step.
287            recvbuf contains data accumulated so far */
288
289         star_reduction(op, (char *) tmp_buf + disps[recv_idx] * extent,
290                        (char *) recvbuf + disps[recv_idx] * extent,
291                        &recv_cnt, &datatype);
292
293         /* update send_idx for next iteration */
294         send_idx = recv_idx;
295         mask <<= 1;
296
297         if (mask < pof2)
298           last_idx = recv_idx + pof2 / mask;
299       }
300     }
301
302     /* now do the gather to root */
303
304     if (root < 2 * rem) {
305       if (root % 2 != 0) {
306         if (rank == root) {     /* recv */
307           for (i = 0; i < (pof2 - 1); i++)
308             cnts[i] = count / pof2;
309           cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
310
311           disps[0] = 0;
312           for (i = 1; i < pof2; i++)
313             disps[i] = disps[i - 1] + cnts[i - 1];
314
315           smpi_mpi_recv(recvbuf, cnts[0], datatype, 0, tag, comm, &status);
316
317           newrank = 0;
318           send_idx = 0;
319           last_idx = 2;
320         } else if (newrank == 0) {
321           smpi_mpi_send(recvbuf, cnts[0], datatype, root, tag, comm);
322           newrank = -1;
323         }
324         newroot = 0;
325       } else
326         newroot = root / 2;
327     } else
328       newroot = root - rem;
329
330     if (newrank != -1) {
331       j = 0;
332       mask = 0x1;
333       while (mask < pof2) {
334         mask <<= 1;
335         j++;
336       }
337       mask >>= 1;
338       j--;
339       while (mask > 0) {
340         newdst = newrank ^ mask;
341
342         /* find real rank of dest */
343         dst = (newdst < rem) ? newdst * 2 : newdst + rem;
344
345         if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
346           dst = root;
347         newdst_tree_root = newdst >> j;
348         newdst_tree_root <<= j;
349
350         newroot_tree_root = newroot >> j;
351         newroot_tree_root <<= j;
352
353         send_cnt = recv_cnt = 0;
354         if (newrank < newdst) {
355           /* update last_idx except on first iteration */
356           if (mask != pof2 / 2)
357             last_idx = last_idx + pof2 / (mask * 2);
358
359           recv_idx = send_idx + pof2 / (mask * 2);
360           for (i = send_idx; i < recv_idx; i++)
361             send_cnt += cnts[i];
362           for (i = recv_idx; i < last_idx; i++)
363             recv_cnt += cnts[i];
364         } else {
365           recv_idx = send_idx - pof2 / (mask * 2);
366           for (i = send_idx; i < last_idx; i++)
367             send_cnt += cnts[i];
368           for (i = recv_idx; i < send_idx; i++)
369             recv_cnt += cnts[i];
370         }
371
372         if (newdst_tree_root == newroot_tree_root) {
373           smpi_mpi_send((char *) recvbuf +
374                    disps[send_idx] * extent,
375                    send_cnt, datatype, dst, tag, comm);
376           break;
377         } else {
378           smpi_mpi_recv((char *) recvbuf +
379                    disps[recv_idx] * extent,
380                    recv_cnt, datatype, dst, tag, comm, &status);
381         }
382
383         if (newrank > newdst)
384           send_idx = recv_idx;
385
386         mask >>= 1;
387         j--;
388       }
389     }
390   }
391   if (cnts)
392     free(cnts);
393   if (disps)
394     free(disps);
395
396   return 0;
397 }