1 /* Copyright (c) 2013-2023. The SimGrid Team.
2 * All rights reserved. */
4 /* This program is free software; you can redistribute it and/or modify it
5 * under the terms of the license (GNU LGPL) which comes with this package. */
7 #include "../colls_private.hpp"
13 namespace simgrid::smpi {
/**
 * Reduce `count` elements of `datatype` from the ranks of `comm` toward `root`.
 *
 * Two code paths are selected on `count`:
 *  - count <  comm_size: the data is staged in temporary buffers sized for
 *    `comm_size` elements (`new_count`), and the first `count` elements of the
 *    result are copied back into `recvbuf` at the end of the branch;
 *  - count >= comm_size: the exchange works directly in `recvbuf`, with one
 *    scratch buffer (`tmp_buf`) of `count` elements.
 *
 * In both paths the data is cut into `pof2` chunks (`cnts` / `disps`), partners
 * `newrank ^ mask` swap chunk ranges and fold the received range with `op`, and
 * a second loop then does "the gather to root".
 *
 * @param sendbuf  input data; when it equals MPI_IN_PLACE the input is read from `recvbuf`
 * @param recvbuf  buffer holding the accumulated data; a non-root rank that passes a
 *                 null pointer gets a temporary buffer of count * extent bytes
 * @param count    number of elements to reduce
 * @param datatype element type; its extent converts element counts/offsets to bytes
 * @param op       combining operation; the apply steps are skipped when it is MPI_OP_NULL
 * @param root     rank targeted by the gather phase
 * @param comm     communicator the exchange runs on
 * @return status code -- NOTE(review): presumably an MPI error code (MPI_SUCCESS on
 *         success); confirm against the function's return statements.
 */
14 int reduce__scatter_gather(const void *sendbuf, void *recvbuf,
15 int count, MPI_Datatype datatype,
16 MPI_Op op, int root, MPI_Comm comm)
/* cnts[i] is the element count of chunk i, disps[i] its element offset (both filled below).
   send_idx / recv_idx / last_idx are chunk indices delimiting the ranges exchanged per step. */
19 int comm_size, rank, pof2, rem, newrank;
20 int mask, *cnts, *disps, i, j, send_idx = 0;
21 int recv_idx, last_idx = 0, newdst;
22 int dst, send_cnt, recv_cnt, newroot, newdst_tree_root;
23 int newroot_tree_root, new_count;
24 int tag = COLL_TAG_REDUCE,temporary_buffer=0;
25 unsigned char *send_ptr, *recv_ptr, *tmp_buf;
35 comm_size = comm->size();
/* extent is the per-element byte stride used for every buffer size and offset below */
39 extent = datatype->get_extent();
40 /* If I'm not the root, then my recvbuf may not be valid, therefore
41 I have to allocate a temporary one */
42 if (rank != root && not recvbuf) {
44 recvbuf = (void *)smpi_get_tmp_recvbuffer(count * extent);
46 /* find nearest power-of-two less than or equal to comm_size */
48 while (pof2 <= comm_size)
/* ---- small-message path: fewer elements than ranks ---- */
52 if (count < comm_size) {
/* work on comm_size elements; three scratch buffers of new_count * extent bytes each */
53 new_count = comm_size;
54 send_ptr = smpi_get_tmp_sendbuffer(new_count * extent);
55 recv_ptr = smpi_get_tmp_recvbuffer(new_count * extent);
56 tmp_buf = smpi_get_tmp_sendbuffer(new_count * extent);
/* only the first count elements of send_ptr are filled from the caller's input */
57 memcpy(send_ptr, sendbuf != MPI_IN_PLACE ? sendbuf : recvbuf, extent * count);
/* self send/receive: moves the staged input from send_ptr into recv_ptr */
60 Request::sendrecv(send_ptr, new_count, datatype, rank, tag,
61 recv_ptr, new_count, datatype, rank, tag, comm, &status);
/* rem = number of ranks in excess of the power of two */
63 rem = comm_size - pof2;
/* hand this rank's data to its lower neighbour (rank - 1) */
67 Request::send(recv_ptr, new_count, datatype, rank - 1, tag, comm);
/* NOTE(review): the matching send above transfers new_count elements, while this receive is
   posted for count (< new_count in this branch) and the apply below covers new_count
   elements -- confirm the size mismatch is intended. */
70 Request::recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
71 if(op!=MPI_OP_NULL) op->apply( tmp_buf, recv_ptr, &new_count, datatype);
74 } else /* rank >= 2*rem */
78 disps = new int[pof2];
/* split new_count into pof2 chunks: equal shares, the last chunk takes the remainder */
81 for (i = 0; i < (pof2 - 1); i++)
82 cnts[i] = new_count / pof2;
83 cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
/* disps = running sum of cnts, i.e. the element offset of each chunk */
86 for (i = 1; i < pof2; i++)
87 disps[i] = disps[i - 1] + cnts[i - 1];
90 send_idx = recv_idx = 0;
/* partner for this step, in the renumbered (newrank) space */
93 newdst = newrank ^ mask;
94 /* find real rank of dest */
95 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
97 send_cnt = recv_cnt = 0;
/* lower-numbered partner: chunk range [send_idx, last_idx) is sent,
   chunk range [recv_idx, send_idx) is received.
   Presumably the two loops total cnts[] over those ranges into send_cnt / recv_cnt -- TODO confirm. */
98 if (newrank < newdst) {
99 send_idx = recv_idx + pof2 / (mask * 2);
100 for (i = send_idx; i < last_idx; i++)
102 for (i = recv_idx; i < send_idx; i++)
/* mirrored ranges (presumably the newrank > newdst case): [send_idx, recv_idx) is sent,
   [recv_idx, last_idx) is received */
105 recv_idx = send_idx + pof2 / (mask * 2);
106 for (i = send_idx; i < recv_idx; i++)
108 for (i = recv_idx; i < last_idx; i++)
112 /* Send data from recv_ptr. Recv into tmp_buf */
113 Request::sendrecv(recv_ptr + disps[send_idx] * extent, send_cnt, datatype, dst, tag,
114 tmp_buf + disps[recv_idx] * extent, recv_cnt, datatype, dst, tag, comm, &status);
116 /* tmp_buf contains data received in this step.
117 recv_ptr contains data accumulated so far */
119 if (op != MPI_OP_NULL)
120 op->apply(tmp_buf + disps[recv_idx] * extent, recv_ptr + disps[recv_idx] * extent, &recv_cnt, datatype);
122 /* update send_idx for next iteration */
127 last_idx = recv_idx + pof2 / mask;
131 /* now do the gather to root */
133 if (root < 2 * rem) {
/* chunk layout is computed again here, with the same formula as above */
137 for (i = 0; i < (pof2 - 1); i++)
138 cnts[i] = new_count / pof2;
139 cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
142 for (i = 1; i < pof2; i++)
143 disps[i] = disps[i - 1] + cnts[i - 1];
/* chunk 0 (cnts[0] elements) is received from rank 0 ... */
145 Request::recv(recv_ptr, cnts[0], datatype, 0, tag, comm, &status);
150 } else if (newrank == 0) {
/* ... and this is the matching send of chunk 0 to root */
151 Request::send(recv_ptr, cnts[0], datatype, root, tag, comm);
158 newroot = root - rem;
163 while (mask < pof2) {
170 newdst = newrank ^ mask;
172 /* find real rank of dest */
173 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
175 if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
/* shifting right then left by j clears the low j bits: both values are rounded
   down to a multiple of 2^j so they can be compared at L202 of the original numbering below */
177 newdst_tree_root = newdst >> j;
178 newdst_tree_root <<= j;
180 newroot_tree_root = newroot >> j;
181 newroot_tree_root <<= j;
183 send_cnt = recv_cnt = 0;
184 if (newrank < newdst) {
185 /* update last_idx except on first iteration */
186 if (mask != pof2 / 2)
187 last_idx = last_idx + pof2 / (mask * 2);
/* lower-numbered partner: sends [send_idx, recv_idx), receives [recv_idx, last_idx) */
189 recv_idx = send_idx + pof2 / (mask * 2);
190 for (i = send_idx; i < recv_idx; i++)
192 for (i = recv_idx; i < last_idx; i++)
/* mirrored ranges: sends [send_idx, last_idx), receives [recv_idx, send_idx) */
195 recv_idx = send_idx - pof2 / (mask * 2);
196 for (i = send_idx; i < last_idx; i++)
198 for (i = recv_idx; i < send_idx; i++)
/* partner's rounded rank equals the root's rounded rank: this rank sends its range to the partner */
202 if (newdst_tree_root == newroot_tree_root) {
203 Request::send(recv_ptr + disps[send_idx] * extent, send_cnt, datatype, dst, tag, comm);
/* receive side: the partner's range lands directly at its offset in recv_ptr */
206 Request::recv(recv_ptr + disps[recv_idx] * extent, recv_cnt, datatype, dst, tag, comm, &status);
209 if (newrank > newdst)
/* copy the first count elements of the staged result to the caller's buffer, then drop the staging buffers */
216 memcpy(recvbuf, recv_ptr, extent * count);
217 smpi_free_tmp_buffer(send_ptr);
218 smpi_free_tmp_buffer(recv_ptr);
/* ---- large-message path: same scheme as above, operating directly in recvbuf ---- */
222 else /* (count >= comm_size) */ {
/* scratch buffer for incoming chunk ranges, count * extent bytes */
223 tmp_buf = smpi_get_tmp_sendbuffer(count * extent);
225 //if ((rank != root))
/* self send/receive: moves the input (sendbuf, or recvbuf for MPI_IN_PLACE) into recvbuf */
226 Request::sendrecv(sendbuf != MPI_IN_PLACE ? sendbuf : recvbuf, count, datatype, rank, tag,
227 recvbuf, count, datatype, rank, tag, comm, &status);
/* rem = number of ranks in excess of the power of two; ranks below 2*rem pair up first */
229 rem = comm_size - pof2;
230 if (rank < 2 * rem) {
231 if (rank % 2 != 0) { /* odd */
/* odd rank hands its data to rank - 1 */
232 Request::send(recvbuf, count, datatype, rank - 1, tag, comm);
/* receive the neighbour's (rank + 1) data and fold it into recvbuf */
237 Request::recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
238 if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
241 } else /* rank >= 2*rem */
242 newrank = rank - rem;
244 cnts = new int[pof2];
245 disps = new int[pof2];
/* split count into pof2 chunks: equal shares, the last chunk takes the remainder */
248 for (i = 0; i < (pof2 - 1); i++)
249 cnts[i] = count / pof2;
250 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
/* disps = running sum of cnts, i.e. the element offset of each chunk */
253 for (i = 1; i < pof2; i++)
254 disps[i] = disps[i - 1] + cnts[i - 1];
257 send_idx = recv_idx = 0;
259 while (mask < pof2) {
/* partner for this step, in the renumbered (newrank) space */
260 newdst = newrank ^ mask;
261 /* find real rank of dest */
262 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
264 send_cnt = recv_cnt = 0;
/* lower-numbered partner: sends [send_idx, last_idx), receives [recv_idx, send_idx).
   Presumably the loops total cnts[] over those ranges into send_cnt / recv_cnt -- TODO confirm. */
265 if (newrank < newdst) {
266 send_idx = recv_idx + pof2 / (mask * 2);
267 for (i = send_idx; i < last_idx; i++)
269 for (i = recv_idx; i < send_idx; i++)
/* mirrored ranges: sends [send_idx, recv_idx), receives [recv_idx, last_idx) */
272 recv_idx = send_idx + pof2 / (mask * 2);
273 for (i = send_idx; i < recv_idx; i++)
275 for (i = recv_idx; i < last_idx; i++)
279 /* Send data from recvbuf. Recv into tmp_buf */
280 Request::sendrecv(static_cast<char*>(recvbuf) + disps[send_idx] * extent, send_cnt, datatype, dst, tag,
281 tmp_buf + disps[recv_idx] * extent, recv_cnt, datatype, dst, tag, comm, &status);
283 /* tmp_buf contains data received in this step.
284 recvbuf contains data accumulated so far */
286 if (op != MPI_OP_NULL)
287 op->apply(tmp_buf + disps[recv_idx] * extent, static_cast<char*>(recvbuf) + disps[recv_idx] * extent,
288 &recv_cnt, datatype);
290 /* update send_idx for next iteration */
295 last_idx = recv_idx + pof2 / mask;
299 /* now do the gather to root */
301 if (root < 2 * rem) {
303 if (rank == root) { /* recv */
/* chunk layout is computed again on the root, with the same formula as above */
304 for (i = 0; i < (pof2 - 1); i++)
305 cnts[i] = count / pof2;
306 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
309 for (i = 1; i < pof2; i++)
310 disps[i] = disps[i - 1] + cnts[i - 1];
/* root receives chunk 0 (cnts[0] elements) from rank 0 ... */
312 Request::recv(recvbuf, cnts[0], datatype, 0, tag, comm, &status);
317 } else if (newrank == 0) {
/* ... and this is the matching send of chunk 0 to root */
318 Request::send(recvbuf, cnts[0], datatype, root, tag, comm);
325 newroot = root - rem;
330 while (mask < pof2) {
337 newdst = newrank ^ mask;
339 /* find real rank of dest */
340 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
342 if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
/* clear the low j bits of both ranks (round down to a multiple of 2^j) before comparing them */
344 newdst_tree_root = newdst >> j;
345 newdst_tree_root <<= j;
347 newroot_tree_root = newroot >> j;
348 newroot_tree_root <<= j;
350 send_cnt = recv_cnt = 0;
351 if (newrank < newdst) {
352 /* update last_idx except on first iteration */
353 if (mask != pof2 / 2)
354 last_idx = last_idx + pof2 / (mask * 2);
/* lower-numbered partner: sends [send_idx, recv_idx), receives [recv_idx, last_idx) */
356 recv_idx = send_idx + pof2 / (mask * 2);
357 for (i = send_idx; i < recv_idx; i++)
359 for (i = recv_idx; i < last_idx; i++)
/* mirrored ranges: sends [send_idx, last_idx), receives [recv_idx, send_idx) */
362 recv_idx = send_idx - pof2 / (mask * 2);
363 for (i = send_idx; i < last_idx; i++)
365 for (i = recv_idx; i < send_idx; i++)
/* partner's rounded rank equals the root's rounded rank: this rank sends its range to the partner */
369 if (newdst_tree_root == newroot_tree_root) {
370 Request::send((char *) recvbuf +
371 disps[send_idx] * extent,
372 send_cnt, datatype, dst, tag, comm);
/* receive side: the partner's range lands directly at its offset in recvbuf */
375 Request::recv((char *) recvbuf +
376 disps[recv_idx] * extent,
377 recv_cnt, datatype, dst, tag, comm, &status);
380 if (newrank > newdst)
/* release the scratch buffer, and recvbuf too when it was the temporary one
   (flagged through temporary_buffer) */
389 smpi_free_tmp_buffer(tmp_buf);
390 if (temporary_buffer == 1)
391 smpi_free_tmp_buffer(static_cast<unsigned char*>(recvbuf));
397 } // namespace simgrid::smpi