/* Copyright (c) 2013-2019. The SimGrid Team.
 * All rights reserved. */

/* This program is free software; you can redistribute it and/or modify it
 * under the terms of the license (GNU LGPL) which comes with this package. */
/*
 * (C) 2001 by Argonne National Laboratory.
 * See COPYRIGHT in top-level directory.
 */
/* Copyright (c) 2001-2014, The Ohio State University. All rights
 * reserved.
 *
 * This file is part of the MVAPICH2 software package developed by the
 * team members of The Ohio State University's Network-Based Computing
 * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
 *
 * For detailed copyright and licensing information, please refer to the
 * copyright file COPYRIGHT in the top level MVAPICH2 directory.
 */
24 #include "../colls_private.hpp"
29 int Coll_allreduce_mvapich2_rs::allreduce(const void *sendbuf,
32 MPI_Datatype datatype,
33 MPI_Op op, MPI_Comm comm)
35 int mpi_errno = MPI_SUCCESS;
37 int mask, pof2, i, send_idx, recv_idx, last_idx, send_cnt;
38 int dst, is_commutative, rem, newdst, recv_cnt;
39 MPI_Aint true_lb, true_extent, extent;
40 void *tmp_buf, *tmp_buf_free;
48 int comm_size = comm->size();
49 int rank = comm->rank();
51 is_commutative = (op==MPI_OP_NULL || op->is_commutative());
53 /* need to allocate temporary buffer to store incoming data */
54 datatype->extent(&true_lb, &true_extent);
55 extent = datatype->get_extent();
57 tmp_buf_free = smpi_get_tmp_recvbuffer(count * std::max(extent, true_extent));
59 /* adjust for potential negative lower bound in datatype */
60 tmp_buf = (void *) ((char *) tmp_buf_free - true_lb);
62 /* copy local data into recvbuf */
63 if (sendbuf != MPI_IN_PLACE) {
65 Datatype::copy(sendbuf, count, datatype, recvbuf, count,
69 /* find nearest power-of-two less than or equal to comm_size */
70 for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
73 rem = comm_size - pof2;
75 /* In the non-power-of-two case, all even-numbered
76 processes of rank < 2*rem send their data to
77 (rank+1). These even-numbered processes no longer
78 participate in the algorithm until the very end. The
79 remaining processes form a nice power-of-two. */
84 Request::send(recvbuf, count, datatype, rank + 1,
85 COLL_TAG_ALLREDUCE, comm);
87 /* temporarily set the rank to -1 so that this
88 process does not pariticipate in recursive
93 Request::recv(tmp_buf, count, datatype, rank - 1,
94 COLL_TAG_ALLREDUCE, comm,
96 /* do the reduction on received data. since the
97 ordering is right, it doesn't matter whether
98 the operation is commutative or not. */
99 if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
100 /* change the rank */
103 } else { /* rank >= 2*rem */
104 newrank = rank - rem;
107 /* If op is user-defined or count is less than pof2, use
108 recursive doubling algorithm. Otherwise do a reduce-scatter
109 followed by allgather. (If op is user-defined,
110 derived datatypes are allowed and the user could pass basic
111 datatypes on one process and derived on another as long as
112 the type maps are the same. Breaking up derived
113 datatypes to do the reduce-scatter is tricky, therefore
114 using recursive doubling in that case.) */
117 if (/*(HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) ||*/ (count < pof2)) { /* use recursive doubling */
119 while (mask < pof2) {
120 newdst = newrank ^ mask;
121 /* find real rank of dest */
122 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
124 /* Send the most current data, which is in recvbuf. Recv
126 Request::sendrecv(recvbuf, count, datatype,
127 dst, COLL_TAG_ALLREDUCE,
128 tmp_buf, count, datatype, dst,
129 COLL_TAG_ALLREDUCE, comm,
132 /* tmp_buf contains data received in this step.
133 recvbuf contains data accumulated so far */
135 if (is_commutative || (dst < rank)) {
136 /* op is commutative OR the order is already right */
137 if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
139 /* op is noncommutative and the order is not right */
140 if(op!=MPI_OP_NULL) op->apply( recvbuf, tmp_buf, &count, datatype);
141 /* copy result back into recvbuf */
142 mpi_errno = Datatype::copy(tmp_buf, count, datatype,
143 recvbuf, count, datatype);
149 /* do a reduce-scatter followed by allgather */
151 /* for the reduce-scatter, calculate the count that
152 each process receives and the displacement within
154 int* cnts = new int[pof2];
155 int* disps = new int[pof2];
157 for (i = 0; i < (pof2 - 1); i++) {
158 cnts[i] = count / pof2;
160 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
163 for (i = 1; i < pof2; i++) {
164 disps[i] = disps[i - 1] + cnts[i - 1];
168 send_idx = recv_idx = 0;
170 while (mask < pof2) {
171 newdst = newrank ^ mask;
172 /* find real rank of dest */
173 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
175 send_cnt = recv_cnt = 0;
176 if (newrank < newdst) {
177 send_idx = recv_idx + pof2 / (mask * 2);
178 for (i = send_idx; i < last_idx; i++)
180 for (i = recv_idx; i < send_idx; i++)
183 recv_idx = send_idx + pof2 / (mask * 2);
184 for (i = send_idx; i < recv_idx; i++)
186 for (i = recv_idx; i < last_idx; i++)
190 /* Send data from recvbuf. Recv into tmp_buf */
191 Request::sendrecv((char *) recvbuf +
192 disps[send_idx] * extent,
194 dst, COLL_TAG_ALLREDUCE,
196 disps[recv_idx] * extent,
197 recv_cnt, datatype, dst,
198 COLL_TAG_ALLREDUCE, comm,
201 /* tmp_buf contains data received in this step.
202 recvbuf contains data accumulated so far */
204 /* This algorithm is used only for predefined ops
205 and predefined ops are always commutative. */
207 if(op!=MPI_OP_NULL) op->apply( (char *) tmp_buf + disps[recv_idx] * extent,
208 (char *) recvbuf + disps[recv_idx] * extent,
209 &recv_cnt, datatype);
211 /* update send_idx for next iteration */
215 /* update last_idx, but not in last iteration
216 because the value is needed in the allgather
219 last_idx = recv_idx + pof2 / mask;
222 /* now do the allgather */
226 newdst = newrank ^ mask;
227 /* find real rank of dest */
228 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
230 send_cnt = recv_cnt = 0;
231 if (newrank < newdst) {
232 /* update last_idx except on first iteration */
233 if (mask != pof2 / 2) {
234 last_idx = last_idx + pof2 / (mask * 2);
237 recv_idx = send_idx + pof2 / (mask * 2);
238 for (i = send_idx; i < recv_idx; i++) {
241 for (i = recv_idx; i < last_idx; i++) {
245 recv_idx = send_idx - pof2 / (mask * 2);
246 for (i = send_idx; i < last_idx; i++) {
249 for (i = recv_idx; i < send_idx; i++) {
254 Request::sendrecv((char *) recvbuf +
255 disps[send_idx] * extent,
257 dst, COLL_TAG_ALLREDUCE,
259 disps[recv_idx] * extent,
260 recv_cnt, datatype, dst,
261 COLL_TAG_ALLREDUCE, comm,
263 if (newrank > newdst) {
274 /* In the non-power-of-two case, all odd-numbered
275 processes of rank < 2*rem send the result to
276 (rank-1), the ranks who didn't participate above. */
277 if (rank < 2 * rem) {
278 if (rank % 2) { /* odd */
279 Request::send(recvbuf, count,
281 COLL_TAG_ALLREDUCE, comm);
283 Request::recv(recvbuf, count,
285 COLL_TAG_ALLREDUCE, comm,
289 smpi_free_tmp_buffer(tmp_buf_free);