/* Copyright (c) 2013-2023. The SimGrid Team.
 * All rights reserved. */

/* This program is free software; you can redistribute it and/or modify it
 * under the terms of the license (GNU LGPL) which comes with this package. */

/*
 * (C) 2001 by Argonne National Laboratory.
 * See COPYRIGHT in top-level directory.
 */

/* Copyright (c) 2001-2014, The Ohio State University. All rights
 * reserved.
 *
 * This file is part of the MVAPICH2 software package developed by the
 * team members of The Ohio State University's Network-Based Computing
 * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
 *
 * For detailed copyright and licensing information, please refer to the
 * copyright file COPYRIGHT in the top level MVAPICH2 directory.
 */
24 #include "../colls_private.hpp"
27 namespace simgrid::smpi {
28 int allreduce__mvapich2_rs(const void *sendbuf,
31 MPI_Datatype datatype,
32 MPI_Op op, MPI_Comm comm)
34 int mpi_errno = MPI_SUCCESS;
36 int mask, pof2, i, send_idx, recv_idx, last_idx, send_cnt;
37 int dst, rem, newdst, recv_cnt;
38 MPI_Aint true_lb, true_extent, extent;
46 int comm_size = comm->size();
47 int rank = comm->rank();
49 bool is_commutative = (op == MPI_OP_NULL || op->is_commutative());
51 /* need to allocate temporary buffer to store incoming data */
52 datatype->extent(&true_lb, &true_extent);
53 extent = datatype->get_extent();
55 unsigned char* tmp_buf_free = smpi_get_tmp_recvbuffer(count * std::max(extent, true_extent));
57 /* adjust for potential negative lower bound in datatype */
58 unsigned char* tmp_buf = tmp_buf_free - true_lb;
60 /* copy local data into recvbuf */
61 if (sendbuf != MPI_IN_PLACE) {
63 Datatype::copy(sendbuf, count, datatype, recvbuf, count,
67 /* find nearest power-of-two less than or equal to comm_size */
68 for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
71 rem = comm_size - pof2;
73 /* In the non-power-of-two case, all even-numbered
74 processes of rank < 2*rem send their data to
75 (rank+1). These even-numbered processes no longer
76 participate in the algorithm until the very end. The
77 remaining processes form a nice power-of-two. */
82 Request::send(recvbuf, count, datatype, rank + 1,
83 COLL_TAG_ALLREDUCE, comm);
85 /* temporarily set the rank to -1 so that this
86 process does not participate in recursive
91 Request::recv(tmp_buf, count, datatype, rank - 1,
92 COLL_TAG_ALLREDUCE, comm,
94 /* do the reduction on received data. since the
95 ordering is right, it doesn't matter whether
96 the operation is commutative or not. */
97 if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
101 } else { /* rank >= 2*rem */
102 newrank = rank - rem;
105 /* If op is user-defined or count is less than pof2, use
106 recursive doubling algorithm. Otherwise do a reduce-scatter
107 followed by allgather. (If op is user-defined,
108 derived datatypes are allowed and the user could pass basic
109 datatypes on one process and derived on another as long as
110 the type maps are the same. Breaking up derived
111 datatypes to do the reduce-scatter is tricky, therefore
112 using recursive doubling in that case.) */
115 if (/*(HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) ||*/ (count < pof2)) { /* use recursive doubling */
117 while (mask < pof2) {
118 newdst = newrank ^ mask;
119 /* find real rank of dest */
120 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
122 /* Send the most current data, which is in recvbuf. Recv
124 Request::sendrecv(recvbuf, count, datatype,
125 dst, COLL_TAG_ALLREDUCE,
126 tmp_buf, count, datatype, dst,
127 COLL_TAG_ALLREDUCE, comm,
130 /* tmp_buf contains data received in this step.
131 recvbuf contains data accumulated so far */
133 if (is_commutative || (dst < rank)) {
134 /* op is commutative OR the order is already right */
135 if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
137 /* op is noncommutative and the order is not right */
138 if(op!=MPI_OP_NULL) op->apply( recvbuf, tmp_buf, &count, datatype);
139 /* copy result back into recvbuf */
140 mpi_errno = Datatype::copy(tmp_buf, count, datatype,
141 recvbuf, count, datatype);
147 /* do a reduce-scatter followed by allgather */
149 /* for the reduce-scatter, calculate the count that
150 each process receives and the displacement within
152 int* cnts = new int[pof2];
153 int* disps = new int[pof2];
155 for (i = 0; i < (pof2 - 1); i++) {
156 cnts[i] = count / pof2;
158 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
161 for (i = 1; i < pof2; i++) {
162 disps[i] = disps[i - 1] + cnts[i - 1];
166 send_idx = recv_idx = 0;
168 while (mask < pof2) {
169 newdst = newrank ^ mask;
170 /* find real rank of dest */
171 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
173 send_cnt = recv_cnt = 0;
174 if (newrank < newdst) {
175 send_idx = recv_idx + pof2 / (mask * 2);
176 for (i = send_idx; i < last_idx; i++)
178 for (i = recv_idx; i < send_idx; i++)
181 recv_idx = send_idx + pof2 / (mask * 2);
182 for (i = send_idx; i < recv_idx; i++)
184 for (i = recv_idx; i < last_idx; i++)
188 /* Send data from recvbuf. Recv into tmp_buf */
189 Request::sendrecv(static_cast<char*>(recvbuf) + disps[send_idx] * extent, send_cnt, datatype, dst,
190 COLL_TAG_ALLREDUCE, tmp_buf + disps[recv_idx] * extent, recv_cnt, datatype, dst,
191 COLL_TAG_ALLREDUCE, comm, MPI_STATUS_IGNORE);
193 /* tmp_buf contains data received in this step.
194 recvbuf contains data accumulated so far */
196 /* This algorithm is used only for predefined ops
197 and predefined ops are always commutative. */
199 if (op != MPI_OP_NULL)
200 op->apply(tmp_buf + disps[recv_idx] * extent, static_cast<char*>(recvbuf) + disps[recv_idx] * extent,
201 &recv_cnt, datatype);
203 /* update send_idx for next iteration */
207 /* update last_idx, but not in last iteration
208 because the value is needed in the allgather
211 last_idx = recv_idx + pof2 / mask;
214 /* now do the allgather */
218 newdst = newrank ^ mask;
219 /* find real rank of dest */
220 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
222 send_cnt = recv_cnt = 0;
223 if (newrank < newdst) {
224 /* update last_idx except on first iteration */
225 if (mask != pof2 / 2) {
226 last_idx = last_idx + pof2 / (mask * 2);
229 recv_idx = send_idx + pof2 / (mask * 2);
230 for (i = send_idx; i < recv_idx; i++) {
233 for (i = recv_idx; i < last_idx; i++) {
237 recv_idx = send_idx - pof2 / (mask * 2);
238 for (i = send_idx; i < last_idx; i++) {
241 for (i = recv_idx; i < send_idx; i++) {
246 Request::sendrecv((char *) recvbuf +
247 disps[send_idx] * extent,
249 dst, COLL_TAG_ALLREDUCE,
251 disps[recv_idx] * extent,
252 recv_cnt, datatype, dst,
253 COLL_TAG_ALLREDUCE, comm,
255 if (newrank > newdst) {
266 /* In the non-power-of-two case, all odd-numbered
267 processes of rank < 2*rem send the result to
268 (rank-1), the ranks who didn't participate above. */
269 if (rank < 2 * rem) {
270 if (rank % 2) { /* odd */
271 Request::send(recvbuf, count,
273 COLL_TAG_ALLREDUCE, comm);
275 Request::recv(recvbuf, count,
277 COLL_TAG_ALLREDUCE, comm,
281 smpi_free_tmp_buffer(tmp_buf_free);
286 } // namespace simgrid::smpi