1 /* Copyright (c) 2013-2014. The SimGrid Team.
2 * All rights reserved. */
4 /* This program is free software; you can redistribute it and/or modify it
5 * under the terms of the license (GNU LGPL) which comes with this package. */
8 * (C) 2001 by Argonne National Laboratory.
9 * See COPYRIGHT in top-level directory.
12 /* Copyright (c) 2001-2014, The Ohio State University. All rights
15 * This file is part of the MVAPICH2 software package developed by the
16 * team members of The Ohio State University's Network-Based Computing
17 * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
19 * For detailed copyright and licensing information, please refer to the
20 * copyright file COPYRIGHT in the top level MVAPICH2 directory.
24 #include "colls_private.h"
26 int smpi_coll_tuned_allreduce_mvapich2_rs(void *sendbuf,
29 MPI_Datatype datatype,
30 MPI_Op op, MPI_Comm comm)
32 int mpi_errno = MPI_SUCCESS;
34 unsigned int mask, pof2;
35 int dst, is_commutative, rem, newdst, i,
36 send_idx, recv_idx, last_idx, send_cnt, recv_cnt, *cnts, *disps;
37 MPI_Aint true_lb, true_extent, extent;
38 void *tmp_buf, *tmp_buf_free;
46 unsigned int comm_size = smpi_comm_size(comm);
47 unsigned int rank = smpi_comm_rank(comm);
49 is_commutative = smpi_op_is_commute(op);
51 /* need to allocate temporary buffer to store incoming data */
52 smpi_datatype_extent(datatype, &true_lb, &true_extent);
53 extent = smpi_datatype_get_extent(datatype);
55 tmp_buf_free= smpi_get_tmp_recvbuffer(count * (MAX(extent, true_extent)));
57 /* adjust for potential negative lower bound in datatype */
58 tmp_buf = (void *) ((char *) tmp_buf_free - true_lb);
60 /* copy local data into recvbuf */
61 if (sendbuf != MPI_IN_PLACE) {
63 smpi_datatype_copy(sendbuf, count, datatype, recvbuf, count,
67 /* find nearest power-of-two less than or equal to comm_size */
68 for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
71 rem = comm_size - pof2;
73 /* In the non-power-of-two case, all even-numbered
74 processes of rank < 2*rem send their data to
75 (rank+1). These even-numbered processes no longer
76 participate in the algorithm until the very end. The
77 remaining processes form a nice power-of-two. */
82 smpi_mpi_send(recvbuf, count, datatype, rank + 1,
83 COLL_TAG_ALLREDUCE, comm);
85 /* temporarily set the rank to -1 so that this
86 process does not participate in recursive
91 smpi_mpi_recv(tmp_buf, count, datatype, rank - 1,
92 COLL_TAG_ALLREDUCE, comm,
94 /* do the reduction on received data. since the
95 ordering is right, it doesn't matter whether
96 the operation is commutative or not. */
97 smpi_op_apply(op, tmp_buf, recvbuf, &count, &datatype);
101 } else { /* rank >= 2*rem */
102 newrank = rank - rem;
105 /* If op is user-defined or count is less than pof2, use
106 recursive doubling algorithm. Otherwise do a reduce-scatter
107 followed by allgather. (If op is user-defined,
108 derived datatypes are allowed and the user could pass basic
109 datatypes on one process and derived on another as long as
110 the type maps are the same. Breaking up derived
111 datatypes to do the reduce-scatter is tricky, therefore
112 using recursive doubling in that case.) */
115 if (/*(HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) ||*/ (count < pof2)) { /* use recursive doubling */
117 while (mask < pof2) {
118 newdst = newrank ^ mask;
119 /* find real rank of dest */
120 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
122 /* Send the most current data, which is in recvbuf. Recv
124 smpi_mpi_sendrecv(recvbuf, count, datatype,
125 dst, COLL_TAG_ALLREDUCE,
126 tmp_buf, count, datatype, dst,
127 COLL_TAG_ALLREDUCE, comm,
130 /* tmp_buf contains data received in this step.
131 recvbuf contains data accumulated so far */
133 if (is_commutative || (dst < rank)) {
134 /* op is commutative OR the order is already right */
135 smpi_op_apply(op, tmp_buf, recvbuf, &count, &datatype);
137 /* op is noncommutative and the order is not right */
138 smpi_op_apply(op, recvbuf, tmp_buf, &count, &datatype);
139 /* copy result back into recvbuf */
140 mpi_errno = smpi_datatype_copy(tmp_buf, count, datatype,
141 recvbuf, count, datatype);
147 /* do a reduce-scatter followed by allgather */
149 /* for the reduce-scatter, calculate the count that
150 each process receives and the displacement within
152 cnts = (int *)xbt_malloc(pof2 * sizeof (int));
153 disps = (int *)xbt_malloc(pof2 * sizeof (int));
155 for (i = 0; i < (pof2 - 1); i++) {
156 cnts[i] = count / pof2;
158 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
161 for (i = 1; i < pof2; i++) {
162 disps[i] = disps[i - 1] + cnts[i - 1];
166 send_idx = recv_idx = 0;
168 while (mask < pof2) {
169 newdst = newrank ^ mask;
170 /* find real rank of dest */
171 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
173 send_cnt = recv_cnt = 0;
174 if (newrank < newdst) {
175 send_idx = recv_idx + pof2 / (mask * 2);
176 for (i = send_idx; i < last_idx; i++)
178 for (i = recv_idx; i < send_idx; i++)
181 recv_idx = send_idx + pof2 / (mask * 2);
182 for (i = send_idx; i < recv_idx; i++)
184 for (i = recv_idx; i < last_idx; i++)
188 /* Send data from recvbuf. Recv into tmp_buf */
189 smpi_mpi_sendrecv((char *) recvbuf +
190 disps[send_idx] * extent,
192 dst, COLL_TAG_ALLREDUCE,
194 disps[recv_idx] * extent,
195 recv_cnt, datatype, dst,
196 COLL_TAG_ALLREDUCE, comm,
199 /* tmp_buf contains data received in this step.
200 recvbuf contains data accumulated so far */
202 /* This algorithm is used only for predefined ops
203 and predefined ops are always commutative. */
205 smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
206 (char *) recvbuf + disps[recv_idx] * extent,
207 &recv_cnt, &datatype);
209 /* update send_idx for next iteration */
213 /* update last_idx, but not in last iteration
214 because the value is needed in the allgather
217 last_idx = recv_idx + pof2 / mask;
220 /* now do the allgather */
224 newdst = newrank ^ mask;
225 /* find real rank of dest */
226 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
228 send_cnt = recv_cnt = 0;
229 if (newrank < newdst) {
230 /* update last_idx except on first iteration */
231 if (mask != pof2 / 2) {
232 last_idx = last_idx + pof2 / (mask * 2);
235 recv_idx = send_idx + pof2 / (mask * 2);
236 for (i = send_idx; i < recv_idx; i++) {
239 for (i = recv_idx; i < last_idx; i++) {
243 recv_idx = send_idx - pof2 / (mask * 2);
244 for (i = send_idx; i < last_idx; i++) {
247 for (i = recv_idx; i < send_idx; i++) {
252 smpi_mpi_sendrecv((char *) recvbuf +
253 disps[send_idx] * extent,
255 dst, COLL_TAG_ALLREDUCE,
257 disps[recv_idx] * extent,
258 recv_cnt, datatype, dst,
259 COLL_TAG_ALLREDUCE, comm,
261 if (newrank > newdst) {
272 /* In the non-power-of-two case, all odd-numbered
273 processes of rank < 2*rem send the result to
274 (rank-1), the ranks who didn't participate above. */
275 if (rank < 2 * rem) {
276 if (rank % 2) { /* odd */
277 smpi_mpi_send(recvbuf, count,
279 COLL_TAG_ALLREDUCE, comm);
281 smpi_mpi_recv(recvbuf, count,
283 COLL_TAG_ALLREDUCE, comm,
287 smpi_free_tmp_buffer(tmp_buf_free);