1 /* Copyright (c) 2013-2019. The SimGrid Team.
2 * All rights reserved. */
4 /* This program is free software; you can redistribute it and/or modify it
5 * under the terms of the license (GNU LGPL) which comes with this package. */
7 #include "../colls_private.hpp"
15 int Coll_reduce_scatter_gather::reduce(void *sendbuf, void *recvbuf,
16 int count, MPI_Datatype datatype,
17 MPI_Op op, int root, MPI_Comm comm)
20 int comm_size, rank, pof2, rem, newrank;
21 int mask, *cnts, *disps, i, j, send_idx = 0;
22 int recv_idx, last_idx = 0, newdst;
23 int dst, send_cnt, recv_cnt, newroot, newdst_tree_root;
24 int newroot_tree_root, new_count;
25 int tag = COLL_TAG_REDUCE,temporary_buffer=0;
26 void *send_ptr, *recv_ptr, *tmp_buf;
36 comm_size = comm->size();
40 extent = datatype->get_extent();
41 /* If I'm not the root, then my recvbuf may not be valid, therefore
42 I have to allocate a temporary one */
43 if (rank != root && not recvbuf) {
45 recvbuf = (void *)smpi_get_tmp_recvbuffer(count * extent);
47 /* find nearest power-of-two less than or equal to comm_size */
49 while (pof2 <= comm_size)
53 if (count < comm_size) {
54 new_count = comm_size;
55 send_ptr = (void *) smpi_get_tmp_sendbuffer(new_count * extent);
56 recv_ptr = (void *) smpi_get_tmp_recvbuffer(new_count * extent);
57 tmp_buf = (void *) smpi_get_tmp_sendbuffer(new_count * extent);
58 memcpy(send_ptr, sendbuf != MPI_IN_PLACE ? sendbuf : recvbuf, extent * count);
61 Request::sendrecv(send_ptr, new_count, datatype, rank, tag,
62 recv_ptr, new_count, datatype, rank, tag, comm, &status);
64 rem = comm_size - pof2;
68 Request::send(recv_ptr, new_count, datatype, rank - 1, tag, comm);
71 Request::recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
72 if(op!=MPI_OP_NULL) op->apply( tmp_buf, recv_ptr, &new_count, datatype);
75 } else /* rank >= 2*rem */
78 cnts = (int *) xbt_malloc(pof2 * sizeof(int));
79 disps = (int *) xbt_malloc(pof2 * sizeof(int));
82 for (i = 0; i < (pof2 - 1); i++)
83 cnts[i] = new_count / pof2;
84 cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
87 for (i = 1; i < pof2; i++)
88 disps[i] = disps[i - 1] + cnts[i - 1];
91 send_idx = recv_idx = 0;
94 newdst = newrank ^ mask;
95 /* find real rank of dest */
96 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
98 send_cnt = recv_cnt = 0;
99 if (newrank < newdst) {
100 send_idx = recv_idx + pof2 / (mask * 2);
101 for (i = send_idx; i < last_idx; i++)
103 for (i = recv_idx; i < send_idx; i++)
106 recv_idx = send_idx + pof2 / (mask * 2);
107 for (i = send_idx; i < recv_idx; i++)
109 for (i = recv_idx; i < last_idx; i++)
113 /* Send data from recv_ptr. Recv into tmp_buf */
114 Request::sendrecv((char *) recv_ptr +
115 disps[send_idx] * extent,
119 disps[recv_idx] * extent,
120 recv_cnt, datatype, dst, tag, comm, &status);
122 /* tmp_buf contains data received in this step.
123 recv_ptr contains data accumulated so far */
125 if(op!=MPI_OP_NULL) op->apply( (char *) tmp_buf + disps[recv_idx] * extent,
126 (char *) recv_ptr + disps[recv_idx] * extent,
127 &recv_cnt, datatype);
129 /* update send_idx for next iteration */
134 last_idx = recv_idx + pof2 / mask;
138 /* now do the gather to root */
140 if (root < 2 * rem) {
144 for (i = 0; i < (pof2 - 1); i++)
145 cnts[i] = new_count / pof2;
146 cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
149 for (i = 1; i < pof2; i++)
150 disps[i] = disps[i - 1] + cnts[i - 1];
152 Request::recv(recv_ptr, cnts[0], datatype, 0, tag, comm, &status);
157 } else if (newrank == 0) {
158 Request::send(recv_ptr, cnts[0], datatype, root, tag, comm);
165 newroot = root - rem;
170 while (mask < pof2) {
177 newdst = newrank ^ mask;
179 /* find real rank of dest */
180 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
182 if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
184 newdst_tree_root = newdst >> j;
185 newdst_tree_root <<= j;
187 newroot_tree_root = newroot >> j;
188 newroot_tree_root <<= j;
190 send_cnt = recv_cnt = 0;
191 if (newrank < newdst) {
192 /* update last_idx except on first iteration */
193 if (mask != pof2 / 2)
194 last_idx = last_idx + pof2 / (mask * 2);
196 recv_idx = send_idx + pof2 / (mask * 2);
197 for (i = send_idx; i < recv_idx; i++)
199 for (i = recv_idx; i < last_idx; i++)
202 recv_idx = send_idx - pof2 / (mask * 2);
203 for (i = send_idx; i < last_idx; i++)
205 for (i = recv_idx; i < send_idx; i++)
209 if (newdst_tree_root == newroot_tree_root) {
210 Request::send((char *) recv_ptr +
211 disps[send_idx] * extent,
212 send_cnt, datatype, dst, tag, comm);
215 Request::recv((char *) recv_ptr +
216 disps[recv_idx] * extent,
217 recv_cnt, datatype, dst, tag, comm, &status);
220 if (newrank > newdst)
227 memcpy(recvbuf, recv_ptr, extent * count);
228 smpi_free_tmp_buffer(send_ptr);
229 smpi_free_tmp_buffer(recv_ptr);
233 else /* (count >= comm_size) */ {
234 tmp_buf = (void *) smpi_get_tmp_sendbuffer(count * extent);
236 //if ((rank != root))
237 Request::sendrecv(sendbuf != MPI_IN_PLACE ? sendbuf : recvbuf, count, datatype, rank, tag,
238 recvbuf, count, datatype, rank, tag, comm, &status);
240 rem = comm_size - pof2;
241 if (rank < 2 * rem) {
242 if (rank % 2 != 0) { /* odd */
243 Request::send(recvbuf, count, datatype, rank - 1, tag, comm);
248 Request::recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
249 if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
252 } else /* rank >= 2*rem */
253 newrank = rank - rem;
255 cnts = (int *) xbt_malloc(pof2 * sizeof(int));
256 disps = (int *) xbt_malloc(pof2 * sizeof(int));
259 for (i = 0; i < (pof2 - 1); i++)
260 cnts[i] = count / pof2;
261 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
264 for (i = 1; i < pof2; i++)
265 disps[i] = disps[i - 1] + cnts[i - 1];
268 send_idx = recv_idx = 0;
270 while (mask < pof2) {
271 newdst = newrank ^ mask;
272 /* find real rank of dest */
273 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
275 send_cnt = recv_cnt = 0;
276 if (newrank < newdst) {
277 send_idx = recv_idx + pof2 / (mask * 2);
278 for (i = send_idx; i < last_idx; i++)
280 for (i = recv_idx; i < send_idx; i++)
283 recv_idx = send_idx + pof2 / (mask * 2);
284 for (i = send_idx; i < recv_idx; i++)
286 for (i = recv_idx; i < last_idx; i++)
290 /* Send data from recvbuf. Recv into tmp_buf */
291 Request::sendrecv((char *) recvbuf +
292 disps[send_idx] * extent,
296 disps[recv_idx] * extent,
297 recv_cnt, datatype, dst, tag, comm, &status);
299 /* tmp_buf contains data received in this step.
300 recvbuf contains data accumulated so far */
302 if(op!=MPI_OP_NULL) op->apply( (char *) tmp_buf + disps[recv_idx] * extent,
303 (char *) recvbuf + disps[recv_idx] * extent,
304 &recv_cnt, datatype);
306 /* update send_idx for next iteration */
311 last_idx = recv_idx + pof2 / mask;
315 /* now do the gather to root */
317 if (root < 2 * rem) {
319 if (rank == root) { /* recv */
320 for (i = 0; i < (pof2 - 1); i++)
321 cnts[i] = count / pof2;
322 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
325 for (i = 1; i < pof2; i++)
326 disps[i] = disps[i - 1] + cnts[i - 1];
328 Request::recv(recvbuf, cnts[0], datatype, 0, tag, comm, &status);
333 } else if (newrank == 0) {
334 Request::send(recvbuf, cnts[0], datatype, root, tag, comm);
341 newroot = root - rem;
346 while (mask < pof2) {
353 newdst = newrank ^ mask;
355 /* find real rank of dest */
356 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
358 if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
360 newdst_tree_root = newdst >> j;
361 newdst_tree_root <<= j;
363 newroot_tree_root = newroot >> j;
364 newroot_tree_root <<= j;
366 send_cnt = recv_cnt = 0;
367 if (newrank < newdst) {
368 /* update last_idx except on first iteration */
369 if (mask != pof2 / 2)
370 last_idx = last_idx + pof2 / (mask * 2);
372 recv_idx = send_idx + pof2 / (mask * 2);
373 for (i = send_idx; i < recv_idx; i++)
375 for (i = recv_idx; i < last_idx; i++)
378 recv_idx = send_idx - pof2 / (mask * 2);
379 for (i = send_idx; i < last_idx; i++)
381 for (i = recv_idx; i < send_idx; i++)
385 if (newdst_tree_root == newroot_tree_root) {
386 Request::send((char *) recvbuf +
387 disps[send_idx] * extent,
388 send_cnt, datatype, dst, tag, comm);
391 Request::recv((char *) recvbuf +
392 disps[recv_idx] * extent,
393 recv_cnt, datatype, dst, tag, comm, &status);
396 if (newrank > newdst)
405 smpi_free_tmp_buffer(tmp_buf);
406 if(temporary_buffer==1) smpi_free_tmp_buffer(recvbuf);