1 /* Copyright (c) 2013-2014. The SimGrid Team.
2 * All rights reserved. */
4 /* This program is free software; you can redistribute it and/or modify it
5 * under the terms of the license (GNU LGPL) which comes with this package. */
7 #include "colls_private.h"
14 int smpi_coll_tuned_reduce_scatter_gather(void *sendbuf, void *recvbuf,
15 int count, MPI_Datatype datatype,
16 MPI_Op op, int root, MPI_Comm comm)
19 int comm_size, rank, pof2, rem, newrank;
20 int mask, *cnts, *disps, i, j, send_idx = 0;
21 int recv_idx, last_idx = 0, newdst;
22 int dst, send_cnt, recv_cnt, newroot, newdst_tree_root;
23 int newroot_tree_root, new_count;
24 int tag = COLL_TAG_REDUCE;
25 void *send_ptr, *recv_ptr, *tmp_buf;
34 rank = smpi_comm_rank(comm);
35 comm_size = smpi_comm_size(comm);
39 extent = smpi_datatype_get_extent(datatype);
40 /* If I'm not the root, then my recvbuf may not be valid, therefore
41 I have to allocate a temporary one */
42 if (rank != root && !recvbuf) {
43 recvbuf = (void *)smpi_get_tmp_recvbuffer(count * extent);
45 /* find nearest power-of-two less than or equal to comm_size */
47 while (pof2 <= comm_size)
51 if (count < comm_size) {
52 new_count = comm_size;
53 send_ptr = (void *) smpi_get_tmp_sendbuffer(new_count * extent);
54 recv_ptr = (void *) smpi_get_tmp_recvbuffer(new_count * extent);
55 tmp_buf = (void *) smpi_get_tmp_sendbuffer(new_count * extent);
56 memcpy(send_ptr, sendbuf, extent * count);
59 smpi_mpi_sendrecv(send_ptr, new_count, datatype, rank, tag,
60 recv_ptr, new_count, datatype, rank, tag, comm, &status);
62 rem = comm_size - pof2;
66 smpi_mpi_send(recv_ptr, new_count, datatype, rank - 1, tag, comm);
69 smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
70 smpi_op_apply(op, tmp_buf, recv_ptr, &new_count, &datatype);
73 } else /* rank >= 2*rem */
76 cnts = (int *) xbt_malloc(pof2 * sizeof(int));
77 disps = (int *) xbt_malloc(pof2 * sizeof(int));
80 for (i = 0; i < (pof2 - 1); i++)
81 cnts[i] = new_count / pof2;
82 cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
85 for (i = 1; i < pof2; i++)
86 disps[i] = disps[i - 1] + cnts[i - 1];
89 send_idx = recv_idx = 0;
92 newdst = newrank ^ mask;
93 /* find real rank of dest */
94 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
96 send_cnt = recv_cnt = 0;
97 if (newrank < newdst) {
98 send_idx = recv_idx + pof2 / (mask * 2);
99 for (i = send_idx; i < last_idx; i++)
101 for (i = recv_idx; i < send_idx; i++)
104 recv_idx = send_idx + pof2 / (mask * 2);
105 for (i = send_idx; i < recv_idx; i++)
107 for (i = recv_idx; i < last_idx; i++)
111 /* Send data from recvbuf. Recv into tmp_buf */
112 smpi_mpi_sendrecv((char *) recv_ptr +
113 disps[send_idx] * extent,
117 disps[recv_idx] * extent,
118 recv_cnt, datatype, dst, tag, comm, &status);
120 /* tmp_buf contains data received in this step.
121 recvbuf contains data accumulated so far */
123 smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
124 (char *) recv_ptr + disps[recv_idx] * extent,
125 &recv_cnt, &datatype);
127 /* update send_idx for next iteration */
132 last_idx = recv_idx + pof2 / mask;
136 /* now do the gather to root */
138 if (root < 2 * rem) {
142 for (i = 0; i < (pof2 - 1); i++)
143 cnts[i] = new_count / pof2;
144 cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
147 for (i = 1; i < pof2; i++)
148 disps[i] = disps[i - 1] + cnts[i - 1];
150 smpi_mpi_recv(recv_ptr, cnts[0], datatype, 0, tag, comm, &status);
155 } else if (newrank == 0) {
156 smpi_mpi_send(recv_ptr, cnts[0], datatype, root, tag, comm);
163 newroot = root - rem;
168 while (mask < pof2) {
175 newdst = newrank ^ mask;
177 /* find real rank of dest */
178 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
180 if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
182 newdst_tree_root = newdst >> j;
183 newdst_tree_root <<= j;
185 newroot_tree_root = newroot >> j;
186 newroot_tree_root <<= j;
188 send_cnt = recv_cnt = 0;
189 if (newrank < newdst) {
190 /* update last_idx except on first iteration */
191 if (mask != pof2 / 2)
192 last_idx = last_idx + pof2 / (mask * 2);
194 recv_idx = send_idx + pof2 / (mask * 2);
195 for (i = send_idx; i < recv_idx; i++)
197 for (i = recv_idx; i < last_idx; i++)
200 recv_idx = send_idx - pof2 / (mask * 2);
201 for (i = send_idx; i < last_idx; i++)
203 for (i = recv_idx; i < send_idx; i++)
207 if (newdst_tree_root == newroot_tree_root) {
208 smpi_mpi_send((char *) recv_ptr +
209 disps[send_idx] * extent,
210 send_cnt, datatype, dst, tag, comm);
213 smpi_mpi_recv((char *) recv_ptr +
214 disps[recv_idx] * extent,
215 recv_cnt, datatype, dst, tag, comm, &status);
218 if (newrank > newdst)
225 memcpy(recvbuf, recv_ptr, extent * count);
226 smpi_free_tmp_buffer(send_ptr);
227 smpi_free_tmp_buffer(recv_ptr);
231 else /* (count >= comm_size) */ {
232 tmp_buf = (void *) smpi_get_tmp_sendbuffer(count * extent);
234 //if ((rank != root))
235 smpi_mpi_sendrecv(sendbuf, count, datatype, rank, tag,
236 recvbuf, count, datatype, rank, tag, comm, &status);
238 rem = comm_size - pof2;
239 if (rank < 2 * rem) {
240 if (rank % 2 != 0) { /* odd */
241 smpi_mpi_send(recvbuf, count, datatype, rank - 1, tag, comm);
246 smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
247 smpi_op_apply(op, tmp_buf, recvbuf, &count, &datatype);
250 } else /* rank >= 2*rem */
251 newrank = rank - rem;
253 cnts = (int *) xbt_malloc(pof2 * sizeof(int));
254 disps = (int *) xbt_malloc(pof2 * sizeof(int));
257 for (i = 0; i < (pof2 - 1); i++)
258 cnts[i] = count / pof2;
259 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
262 for (i = 1; i < pof2; i++)
263 disps[i] = disps[i - 1] + cnts[i - 1];
266 send_idx = recv_idx = 0;
268 while (mask < pof2) {
269 newdst = newrank ^ mask;
270 /* find real rank of dest */
271 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
273 send_cnt = recv_cnt = 0;
274 if (newrank < newdst) {
275 send_idx = recv_idx + pof2 / (mask * 2);
276 for (i = send_idx; i < last_idx; i++)
278 for (i = recv_idx; i < send_idx; i++)
281 recv_idx = send_idx + pof2 / (mask * 2);
282 for (i = send_idx; i < recv_idx; i++)
284 for (i = recv_idx; i < last_idx; i++)
288 /* Send data from recvbuf. Recv into tmp_buf */
289 smpi_mpi_sendrecv((char *) recvbuf +
290 disps[send_idx] * extent,
294 disps[recv_idx] * extent,
295 recv_cnt, datatype, dst, tag, comm, &status);
297 /* tmp_buf contains data received in this step.
298 recvbuf contains data accumulated so far */
300 smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
301 (char *) recvbuf + disps[recv_idx] * extent,
302 &recv_cnt, &datatype);
304 /* update send_idx for next iteration */
309 last_idx = recv_idx + pof2 / mask;
313 /* now do the gather to root */
315 if (root < 2 * rem) {
317 if (rank == root) { /* recv */
318 for (i = 0; i < (pof2 - 1); i++)
319 cnts[i] = count / pof2;
320 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
323 for (i = 1; i < pof2; i++)
324 disps[i] = disps[i - 1] + cnts[i - 1];
326 smpi_mpi_recv(recvbuf, cnts[0], datatype, 0, tag, comm, &status);
331 } else if (newrank == 0) {
332 smpi_mpi_send(recvbuf, cnts[0], datatype, root, tag, comm);
339 newroot = root - rem;
344 while (mask < pof2) {
351 newdst = newrank ^ mask;
353 /* find real rank of dest */
354 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
356 if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
358 newdst_tree_root = newdst >> j;
359 newdst_tree_root <<= j;
361 newroot_tree_root = newroot >> j;
362 newroot_tree_root <<= j;
364 send_cnt = recv_cnt = 0;
365 if (newrank < newdst) {
366 /* update last_idx except on first iteration */
367 if (mask != pof2 / 2)
368 last_idx = last_idx + pof2 / (mask * 2);
370 recv_idx = send_idx + pof2 / (mask * 2);
371 for (i = send_idx; i < recv_idx; i++)
373 for (i = recv_idx; i < last_idx; i++)
376 recv_idx = send_idx - pof2 / (mask * 2);
377 for (i = send_idx; i < last_idx; i++)
379 for (i = recv_idx; i < send_idx; i++)
383 if (newdst_tree_root == newroot_tree_root) {
384 smpi_mpi_send((char *) recvbuf +
385 disps[send_idx] * extent,
386 send_cnt, datatype, dst, tag, comm);
389 smpi_mpi_recv((char *) recvbuf +
390 disps[recv_idx] * extent,
391 recv_cnt, datatype, dst, tag, comm, &status);
394 if (newrank > newdst)
403 smpi_free_tmp_buffer(tmp_buf);