1 /* Copyright (c) 2013-2018. The SimGrid Team.
2 * All rights reserved. */
4 /* This program is free software; you can redistribute it and/or modify it
5 * under the terms of the license (GNU LGPL) which comes with this package. */
8 /* (C) 2001 by Argonne National Laboratory.
9 * See COPYRIGHT in top-level directory. */
12 /* Copyright (c) 2001-2014, The Ohio State University. All rights reserved.
15 * This file is part of the MVAPICH2 software package developed by the
16 * team members of The Ohio State University's Network-Based Computing
17 * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
19 * For detailed copyright and licensing information, please refer to the
20 * copyright file COPYRIGHT in the top level MVAPICH2 directory. */
24 #include "../colls_private.hpp"
29 int Coll_allreduce_mvapich2_rs::allreduce(void *sendbuf,
32 MPI_Datatype datatype,
33 MPI_Op op, MPI_Comm comm)
35 int mpi_errno = MPI_SUCCESS;
37 int mask, pof2, i, send_idx, recv_idx, last_idx, send_cnt;
38 int dst, is_commutative, rem, newdst,
39 recv_cnt, *cnts, *disps;
40 MPI_Aint true_lb, true_extent, extent;
41 void *tmp_buf, *tmp_buf_free;
49 int comm_size = comm->size();
50 int rank = comm->rank();
52 is_commutative = (op==MPI_OP_NULL || op->is_commutative());
54 /* need to allocate temporary buffer to store incoming data */
55 datatype->extent(&true_lb, &true_extent);
56 extent = datatype->get_extent();
58 tmp_buf_free = smpi_get_tmp_recvbuffer(count * std::max(extent, true_extent));
60 /* adjust for potential negative lower bound in datatype */
61 tmp_buf = (void *) ((char *) tmp_buf_free - true_lb);
63 /* copy local data into recvbuf */
64 if (sendbuf != MPI_IN_PLACE) {
66 Datatype::copy(sendbuf, count, datatype, recvbuf, count,
70 /* find nearest power-of-two less than or equal to comm_size */
71 for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
74 rem = comm_size - pof2;
76 /* In the non-power-of-two case, all even-numbered
77 processes of rank < 2*rem send their data to
78 (rank+1). These even-numbered processes no longer
79 participate in the algorithm until the very end. The
80 remaining processes form a nice power-of-two. */
85 Request::send(recvbuf, count, datatype, rank + 1,
86 COLL_TAG_ALLREDUCE, comm);
88 /* temporarily set the rank to -1 so that this
89 process does not pariticipate in recursive
94 Request::recv(tmp_buf, count, datatype, rank - 1,
95 COLL_TAG_ALLREDUCE, comm,
97 /* do the reduction on received data. since the
98 ordering is right, it doesn't matter whether
99 the operation is commutative or not. */
100 if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
101 /* change the rank */
104 } else { /* rank >= 2*rem */
105 newrank = rank - rem;
108 /* If op is user-defined or count is less than pof2, use
109 recursive doubling algorithm. Otherwise do a reduce-scatter
110 followed by allgather. (If op is user-defined,
111 derived datatypes are allowed and the user could pass basic
112 datatypes on one process and derived on another as long as
113 the type maps are the same. Breaking up derived
114 datatypes to do the reduce-scatter is tricky, therefore
115 using recursive doubling in that case.) */
118 if (/*(HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) ||*/ (count < pof2)) { /* use recursive doubling */
120 while (mask < pof2) {
121 newdst = newrank ^ mask;
122 /* find real rank of dest */
123 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
125 /* Send the most current data, which is in recvbuf. Recv
127 Request::sendrecv(recvbuf, count, datatype,
128 dst, COLL_TAG_ALLREDUCE,
129 tmp_buf, count, datatype, dst,
130 COLL_TAG_ALLREDUCE, comm,
133 /* tmp_buf contains data received in this step.
134 recvbuf contains data accumulated so far */
136 if (is_commutative || (dst < rank)) {
137 /* op is commutative OR the order is already right */
138 if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
140 /* op is noncommutative and the order is not right */
141 if(op!=MPI_OP_NULL) op->apply( recvbuf, tmp_buf, &count, datatype);
142 /* copy result back into recvbuf */
143 mpi_errno = Datatype::copy(tmp_buf, count, datatype,
144 recvbuf, count, datatype);
150 /* do a reduce-scatter followed by allgather */
152 /* for the reduce-scatter, calculate the count that
153 each process receives and the displacement within
155 cnts = (int *)xbt_malloc(pof2 * sizeof (int));
156 disps = (int *)xbt_malloc(pof2 * sizeof (int));
158 for (i = 0; i < (pof2 - 1); i++) {
159 cnts[i] = count / pof2;
161 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
164 for (i = 1; i < pof2; i++) {
165 disps[i] = disps[i - 1] + cnts[i - 1];
169 send_idx = recv_idx = 0;
171 while (mask < pof2) {
172 newdst = newrank ^ mask;
173 /* find real rank of dest */
174 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
176 send_cnt = recv_cnt = 0;
177 if (newrank < newdst) {
178 send_idx = recv_idx + pof2 / (mask * 2);
179 for (i = send_idx; i < last_idx; i++)
181 for (i = recv_idx; i < send_idx; i++)
184 recv_idx = send_idx + pof2 / (mask * 2);
185 for (i = send_idx; i < recv_idx; i++)
187 for (i = recv_idx; i < last_idx; i++)
191 /* Send data from recvbuf. Recv into tmp_buf */
192 Request::sendrecv((char *) recvbuf +
193 disps[send_idx] * extent,
195 dst, COLL_TAG_ALLREDUCE,
197 disps[recv_idx] * extent,
198 recv_cnt, datatype, dst,
199 COLL_TAG_ALLREDUCE, comm,
202 /* tmp_buf contains data received in this step.
203 recvbuf contains data accumulated so far */
205 /* This algorithm is used only for predefined ops
206 and predefined ops are always commutative. */
208 if(op!=MPI_OP_NULL) op->apply( (char *) tmp_buf + disps[recv_idx] * extent,
209 (char *) recvbuf + disps[recv_idx] * extent,
210 &recv_cnt, datatype);
212 /* update send_idx for next iteration */
216 /* update last_idx, but not in last iteration
217 because the value is needed in the allgather
220 last_idx = recv_idx + pof2 / mask;
223 /* now do the allgather */
227 newdst = newrank ^ mask;
228 /* find real rank of dest */
229 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
231 send_cnt = recv_cnt = 0;
232 if (newrank < newdst) {
233 /* update last_idx except on first iteration */
234 if (mask != pof2 / 2) {
235 last_idx = last_idx + pof2 / (mask * 2);
238 recv_idx = send_idx + pof2 / (mask * 2);
239 for (i = send_idx; i < recv_idx; i++) {
242 for (i = recv_idx; i < last_idx; i++) {
246 recv_idx = send_idx - pof2 / (mask * 2);
247 for (i = send_idx; i < last_idx; i++) {
250 for (i = recv_idx; i < send_idx; i++) {
255 Request::sendrecv((char *) recvbuf +
256 disps[send_idx] * extent,
258 dst, COLL_TAG_ALLREDUCE,
260 disps[recv_idx] * extent,
261 recv_cnt, datatype, dst,
262 COLL_TAG_ALLREDUCE, comm,
264 if (newrank > newdst) {
275 /* In the non-power-of-two case, all odd-numbered
276 processes of rank < 2*rem send the result to
277 (rank-1), the ranks who didn't participate above. */
278 if (rank < 2 * rem) {
279 if (rank % 2) { /* odd */
280 Request::send(recvbuf, count,
282 COLL_TAG_ALLREDUCE, comm);
284 Request::recv(recvbuf, count,
286 COLL_TAG_ALLREDUCE, comm,
290 smpi_free_tmp_buffer(tmp_buf_free);