1 /* Copyright (c) 2013-2017. The SimGrid Team.
2 * All rights reserved. */
4 /* This program is free software; you can redistribute it and/or modify it
5 * under the terms of the license (GNU LGPL) which comes with this package. */
/* (C) 2001 by Argonne National Laboratory.
 * See COPYRIGHT in top-level directory. */
/* Copyright (c) 2001-2014, The Ohio State University. All rights
 * reserved.
 *
 * This file is part of the MVAPICH2 software package developed by the
 * team members of The Ohio State University's Network-Based Computing
 * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
 *
 * For detailed copyright and licensing information, please refer to the
 * copyright file COPYRIGHT in the top level MVAPICH2 directory.
 */
24 #include "../colls_private.hpp"
27 int Coll_allreduce_mvapich2_rs::allreduce(void *sendbuf,
30 MPI_Datatype datatype,
31 MPI_Op op, MPI_Comm comm)
33 int mpi_errno = MPI_SUCCESS;
35 int mask, pof2, i, send_idx, recv_idx, last_idx, send_cnt;
36 int dst, is_commutative, rem, newdst,
37 recv_cnt, *cnts, *disps;
38 MPI_Aint true_lb, true_extent, extent;
39 void *tmp_buf, *tmp_buf_free;
47 int comm_size = comm->size();
48 int rank = comm->rank();
50 is_commutative = (op==MPI_OP_NULL || op->is_commutative());
52 /* need to allocate temporary buffer to store incoming data */
53 datatype->extent(&true_lb, &true_extent);
54 extent = datatype->get_extent();
56 tmp_buf_free= smpi_get_tmp_recvbuffer(count * (MAX(extent, true_extent)));
58 /* adjust for potential negative lower bound in datatype */
59 tmp_buf = (void *) ((char *) tmp_buf_free - true_lb);
61 /* copy local data into recvbuf */
62 if (sendbuf != MPI_IN_PLACE) {
64 Datatype::copy(sendbuf, count, datatype, recvbuf, count,
68 /* find nearest power-of-two less than or equal to comm_size */
69 for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
72 rem = comm_size - pof2;
74 /* In the non-power-of-two case, all even-numbered
75 processes of rank < 2*rem send their data to
76 (rank+1). These even-numbered processes no longer
77 participate in the algorithm until the very end. The
78 remaining processes form a nice power-of-two. */
83 Request::send(recvbuf, count, datatype, rank + 1,
84 COLL_TAG_ALLREDUCE, comm);
86 /* temporarily set the rank to -1 so that this
87 process does not pariticipate in recursive
92 Request::recv(tmp_buf, count, datatype, rank - 1,
93 COLL_TAG_ALLREDUCE, comm,
95 /* do the reduction on received data. since the
96 ordering is right, it doesn't matter whether
97 the operation is commutative or not. */
98 if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
102 } else { /* rank >= 2*rem */
103 newrank = rank - rem;
106 /* If op is user-defined or count is less than pof2, use
107 recursive doubling algorithm. Otherwise do a reduce-scatter
108 followed by allgather. (If op is user-defined,
109 derived datatypes are allowed and the user could pass basic
110 datatypes on one process and derived on another as long as
111 the type maps are the same. Breaking up derived
112 datatypes to do the reduce-scatter is tricky, therefore
113 using recursive doubling in that case.) */
116 if (/*(HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) ||*/ (count < pof2)) { /* use recursive doubling */
118 while (mask < pof2) {
119 newdst = newrank ^ mask;
120 /* find real rank of dest */
121 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
123 /* Send the most current data, which is in recvbuf. Recv
125 Request::sendrecv(recvbuf, count, datatype,
126 dst, COLL_TAG_ALLREDUCE,
127 tmp_buf, count, datatype, dst,
128 COLL_TAG_ALLREDUCE, comm,
131 /* tmp_buf contains data received in this step.
132 recvbuf contains data accumulated so far */
134 if (is_commutative || (dst < rank)) {
135 /* op is commutative OR the order is already right */
136 if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
138 /* op is noncommutative and the order is not right */
139 if(op!=MPI_OP_NULL) op->apply( recvbuf, tmp_buf, &count, datatype);
140 /* copy result back into recvbuf */
141 mpi_errno = Datatype::copy(tmp_buf, count, datatype,
142 recvbuf, count, datatype);
148 /* do a reduce-scatter followed by allgather */
150 /* for the reduce-scatter, calculate the count that
151 each process receives and the displacement within
153 cnts = (int *)xbt_malloc(pof2 * sizeof (int));
154 disps = (int *)xbt_malloc(pof2 * sizeof (int));
156 for (i = 0; i < (pof2 - 1); i++) {
157 cnts[i] = count / pof2;
159 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
162 for (i = 1; i < pof2; i++) {
163 disps[i] = disps[i - 1] + cnts[i - 1];
167 send_idx = recv_idx = 0;
169 while (mask < pof2) {
170 newdst = newrank ^ mask;
171 /* find real rank of dest */
172 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
174 send_cnt = recv_cnt = 0;
175 if (newrank < newdst) {
176 send_idx = recv_idx + pof2 / (mask * 2);
177 for (i = send_idx; i < last_idx; i++)
179 for (i = recv_idx; i < send_idx; i++)
182 recv_idx = send_idx + pof2 / (mask * 2);
183 for (i = send_idx; i < recv_idx; i++)
185 for (i = recv_idx; i < last_idx; i++)
189 /* Send data from recvbuf. Recv into tmp_buf */
190 Request::sendrecv((char *) recvbuf +
191 disps[send_idx] * extent,
193 dst, COLL_TAG_ALLREDUCE,
195 disps[recv_idx] * extent,
196 recv_cnt, datatype, dst,
197 COLL_TAG_ALLREDUCE, comm,
200 /* tmp_buf contains data received in this step.
201 recvbuf contains data accumulated so far */
203 /* This algorithm is used only for predefined ops
204 and predefined ops are always commutative. */
206 if(op!=MPI_OP_NULL) op->apply( (char *) tmp_buf + disps[recv_idx] * extent,
207 (char *) recvbuf + disps[recv_idx] * extent,
208 &recv_cnt, datatype);
210 /* update send_idx for next iteration */
214 /* update last_idx, but not in last iteration
215 because the value is needed in the allgather
218 last_idx = recv_idx + pof2 / mask;
221 /* now do the allgather */
225 newdst = newrank ^ mask;
226 /* find real rank of dest */
227 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
229 send_cnt = recv_cnt = 0;
230 if (newrank < newdst) {
231 /* update last_idx except on first iteration */
232 if (mask != pof2 / 2) {
233 last_idx = last_idx + pof2 / (mask * 2);
236 recv_idx = send_idx + pof2 / (mask * 2);
237 for (i = send_idx; i < recv_idx; i++) {
240 for (i = recv_idx; i < last_idx; i++) {
244 recv_idx = send_idx - pof2 / (mask * 2);
245 for (i = send_idx; i < last_idx; i++) {
248 for (i = recv_idx; i < send_idx; i++) {
253 Request::sendrecv((char *) recvbuf +
254 disps[send_idx] * extent,
256 dst, COLL_TAG_ALLREDUCE,
258 disps[recv_idx] * extent,
259 recv_cnt, datatype, dst,
260 COLL_TAG_ALLREDUCE, comm,
262 if (newrank > newdst) {
273 /* In the non-power-of-two case, all odd-numbered
274 processes of rank < 2*rem send the result to
275 (rank-1), the ranks who didn't participate above. */
276 if (rank < 2 * rem) {
277 if (rank % 2) { /* odd */
278 Request::send(recvbuf, count,
280 COLL_TAG_ALLREDUCE, comm);
282 Request::recv(recvbuf, count,
284 COLL_TAG_ALLREDUCE, comm,
288 smpi_free_tmp_buffer(tmp_buf_free);