Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'actor-yield' of github.com:Takishipp/simgrid into actor-yield
[simgrid.git] / src / smpi / colls / allreduce / allreduce-mvapich-rs.cpp
1 /* Copyright (c) 2013-2017. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  *  (C) 2001 by Argonne National Laboratory.
9  *      See COPYRIGHT in top-level directory.
10  */
11
12 /* Copyright (c) 2001-2014, The Ohio State University. All rights
13  * reserved.
14  *
15  * This file is part of the MVAPICH2 software package developed by the
16  * team members of The Ohio State University's Network-Based Computing
17  * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
18  *
19  * For detailed copyright and licensing information, please refer to the
20  * copyright file COPYRIGHT in the top level MVAPICH2 directory.
21  *
22  */
23
24 #include "../colls_private.hpp"
25 #include <algorithm>
26
27 namespace simgrid{
28 namespace smpi{
29 int Coll_allreduce_mvapich2_rs::allreduce(void *sendbuf,
30                             void *recvbuf,
31                             int count,
32                             MPI_Datatype datatype,
33                             MPI_Op op, MPI_Comm comm)
34 {
35     int mpi_errno = MPI_SUCCESS;
36     int newrank = 0;
37     int mask, pof2, i, send_idx, recv_idx, last_idx, send_cnt;
38     int dst, is_commutative, rem, newdst,
39         recv_cnt, *cnts, *disps;
40     MPI_Aint true_lb, true_extent, extent;
41     void *tmp_buf, *tmp_buf_free;
42
43     if (count == 0) {
44         return MPI_SUCCESS;
45     }
46
47     /* homogeneous */
48
49     int comm_size =  comm->size();
50     int rank = comm->rank();
51
52     is_commutative = (op==MPI_OP_NULL || op->is_commutative());
53
54     /* need to allocate temporary buffer to store incoming data */
55     datatype->extent(&true_lb, &true_extent);
56     extent = datatype->get_extent();
57
58     tmp_buf_free = smpi_get_tmp_recvbuffer(count * std::max(extent, true_extent));
59
60     /* adjust for potential negative lower bound in datatype */
61     tmp_buf = (void *) ((char *) tmp_buf_free - true_lb);
62
63     /* copy local data into recvbuf */
64     if (sendbuf != MPI_IN_PLACE) {
65         mpi_errno =
66             Datatype::copy(sendbuf, count, datatype, recvbuf, count,
67                            datatype);
68     }
69
70     /* find nearest power-of-two less than or equal to comm_size */
71     for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
72     pof2 >>=1;
73
74     rem = comm_size - pof2;
75
76     /* In the non-power-of-two case, all even-numbered
77        processes of rank < 2*rem send their data to
78        (rank+1). These even-numbered processes no longer
79        participate in the algorithm until the very end. The
80        remaining processes form a nice power-of-two. */
81
82     if (rank < 2 * rem) {
83         if (rank % 2 == 0) {
84             /* even */
85             Request::send(recvbuf, count, datatype, rank + 1,
86                                      COLL_TAG_ALLREDUCE, comm);
87
88             /* temporarily set the rank to -1 so that this
89                process does not pariticipate in recursive
90                doubling */
91             newrank = -1;
92         } else {
93             /* odd */
94             Request::recv(tmp_buf, count, datatype, rank - 1,
95                                      COLL_TAG_ALLREDUCE, comm,
96                                      MPI_STATUS_IGNORE);
97             /* do the reduction on received data. since the
98                ordering is right, it doesn't matter whether
99                the operation is commutative or not. */
100                if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
101                 /* change the rank */
102                 newrank = rank / 2;
103         }
104     } else {                /* rank >= 2*rem */
105         newrank = rank - rem;
106     }
107
108     /* If op is user-defined or count is less than pof2, use
109        recursive doubling algorithm. Otherwise do a reduce-scatter
110        followed by allgather. (If op is user-defined,
111        derived datatypes are allowed and the user could pass basic
112        datatypes on one process and derived on another as long as
113        the type maps are the same. Breaking up derived
114        datatypes to do the reduce-scatter is tricky, therefore
115        using recursive doubling in that case.) */
116
117     if (newrank != -1) {
118         if (/*(HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) ||*/ (count < pof2)) {  /* use recursive doubling */
119             mask = 0x1;
120             while (mask < pof2) {
121                 newdst = newrank ^ mask;
122                 /* find real rank of dest */
123                 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
124
125                 /* Send the most current data, which is in recvbuf. Recv
126                    into tmp_buf */
127                 Request::sendrecv(recvbuf, count, datatype,
128                                              dst, COLL_TAG_ALLREDUCE,
129                                              tmp_buf, count, datatype, dst,
130                                              COLL_TAG_ALLREDUCE, comm,
131                                              MPI_STATUS_IGNORE);
132
133                 /* tmp_buf contains data received in this step.
134                    recvbuf contains data accumulated so far */
135
136                 if (is_commutative || (dst < rank)) {
137                     /* op is commutative OR the order is already right */
138                      if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
139                 } else {
140                     /* op is noncommutative and the order is not right */
141                     if(op!=MPI_OP_NULL) op->apply( recvbuf, tmp_buf, &count, datatype);
142                     /* copy result back into recvbuf */
143                     mpi_errno = Datatype::copy(tmp_buf, count, datatype,
144                                                recvbuf, count, datatype);
145                 }
146                 mask <<= 1;
147             }
148         } else {
149
150             /* do a reduce-scatter followed by allgather */
151
152             /* for the reduce-scatter, calculate the count that
153                each process receives and the displacement within
154                the buffer */
155             cnts = (int *)xbt_malloc(pof2 * sizeof (int));
156             disps = (int *)xbt_malloc(pof2 * sizeof (int));
157
158             for (i = 0; i < (pof2 - 1); i++) {
159                 cnts[i] = count / pof2;
160             }
161             cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
162
163             disps[0] = 0;
164             for (i = 1; i < pof2; i++) {
165                 disps[i] = disps[i - 1] + cnts[i - 1];
166             }
167
168             mask = 0x1;
169             send_idx = recv_idx = 0;
170             last_idx = pof2;
171             while (mask < pof2) {
172                 newdst = newrank ^ mask;
173                 /* find real rank of dest */
174                 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
175
176                 send_cnt = recv_cnt = 0;
177                 if (newrank < newdst) {
178                     send_idx = recv_idx + pof2 / (mask * 2);
179                     for (i = send_idx; i < last_idx; i++)
180                         send_cnt += cnts[i];
181                     for (i = recv_idx; i < send_idx; i++)
182                         recv_cnt += cnts[i];
183                 } else {
184                     recv_idx = send_idx + pof2 / (mask * 2);
185                     for (i = send_idx; i < recv_idx; i++)
186                         send_cnt += cnts[i];
187                     for (i = recv_idx; i < last_idx; i++)
188                         recv_cnt += cnts[i];
189                 }
190
191                 /* Send data from recvbuf. Recv into tmp_buf */
192                 Request::sendrecv((char *) recvbuf +
193                                              disps[send_idx] * extent,
194                                              send_cnt, datatype,
195                                              dst, COLL_TAG_ALLREDUCE,
196                                              (char *) tmp_buf +
197                                              disps[recv_idx] * extent,
198                                              recv_cnt, datatype, dst,
199                                              COLL_TAG_ALLREDUCE, comm,
200                                              MPI_STATUS_IGNORE);
201
202                 /* tmp_buf contains data received in this step.
203                    recvbuf contains data accumulated so far */
204
205                 /* This algorithm is used only for predefined ops
206                    and predefined ops are always commutative. */
207
208                 if(op!=MPI_OP_NULL) op->apply( (char *) tmp_buf + disps[recv_idx] * extent,
209                         (char *) recvbuf + disps[recv_idx] * extent,
210                         &recv_cnt, datatype);
211
212                 /* update send_idx for next iteration */
213                 send_idx = recv_idx;
214                 mask <<= 1;
215
216                 /* update last_idx, but not in last iteration
217                    because the value is needed in the allgather
218                    step below. */
219                 if (mask < pof2)
220                     last_idx = recv_idx + pof2 / mask;
221             }
222
223             /* now do the allgather */
224
225             mask >>= 1;
226             while (mask > 0) {
227                 newdst = newrank ^ mask;
228                 /* find real rank of dest */
229                 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
230
231                 send_cnt = recv_cnt = 0;
232                 if (newrank < newdst) {
233                     /* update last_idx except on first iteration */
234                     if (mask != pof2 / 2) {
235                         last_idx = last_idx + pof2 / (mask * 2);
236                     }
237
238                     recv_idx = send_idx + pof2 / (mask * 2);
239                     for (i = send_idx; i < recv_idx; i++) {
240                         send_cnt += cnts[i];
241                     }
242                     for (i = recv_idx; i < last_idx; i++) {
243                         recv_cnt += cnts[i];
244                     }
245                 } else {
246                     recv_idx = send_idx - pof2 / (mask * 2);
247                     for (i = send_idx; i < last_idx; i++) {
248                         send_cnt += cnts[i];
249                     }
250                     for (i = recv_idx; i < send_idx; i++) {
251                         recv_cnt += cnts[i];
252                     }
253                 }
254
255                Request::sendrecv((char *) recvbuf +
256                                              disps[send_idx] * extent,
257                                              send_cnt, datatype,
258                                              dst, COLL_TAG_ALLREDUCE,
259                                              (char *) recvbuf +
260                                              disps[recv_idx] * extent,
261                                              recv_cnt, datatype, dst,
262                                              COLL_TAG_ALLREDUCE, comm,
263                                              MPI_STATUS_IGNORE);
264                 if (newrank > newdst) {
265                     send_idx = recv_idx;
266                 }
267
268                 mask >>= 1;
269             }
270             xbt_free(disps);
271             xbt_free(cnts);
272         }
273     }
274
275     /* In the non-power-of-two case, all odd-numbered
276        processes of rank < 2*rem send the result to
277        (rank-1), the ranks who didn't participate above. */
278     if (rank < 2 * rem) {
279         if (rank % 2) {     /* odd */
280             Request::send(recvbuf, count,
281                                      datatype, rank - 1,
282                                      COLL_TAG_ALLREDUCE, comm);
283         } else {            /* even */
284             Request::recv(recvbuf, count,
285                                   datatype, rank + 1,
286                                   COLL_TAG_ALLREDUCE, comm,
287                                   MPI_STATUS_IGNORE);
288         }
289     }
290     smpi_free_tmp_buffer(tmp_buf_free);
291     return (mpi_errno);
292
293 }
294
295 }
296 }