Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Update copyright lines for 2022.
[simgrid.git] / src / smpi / colls / allreduce / allreduce-mvapich-rs.cpp
1 /* Copyright (c) 2013-2022. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  *  (C) 2001 by Argonne National Laboratory.
9  *      See COPYRIGHT in top-level directory.
10  */
11
12 /* Copyright (c) 2001-2014, The Ohio State University. All rights
13  * reserved.
14  *
15  * This file is part of the MVAPICH2 software package developed by the
16  * team members of The Ohio State University's Network-Based Computing
17  * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
18  *
19  * For detailed copyright and licensing information, please refer to the
20  * copyright file COPYRIGHT in the top level MVAPICH2 directory.
21  *
22  */
23
24 #include "../colls_private.hpp"
25 #include <algorithm>
26
27 namespace simgrid {
28 namespace smpi {
29 int allreduce__mvapich2_rs(const void *sendbuf,
30                            void *recvbuf,
31                            int count,
32                            MPI_Datatype datatype,
33                            MPI_Op op, MPI_Comm comm)
34 {
35     int mpi_errno = MPI_SUCCESS;
36     int newrank = 0;
37     int mask, pof2, i, send_idx, recv_idx, last_idx, send_cnt;
38     int dst, rem, newdst, recv_cnt;
39     MPI_Aint true_lb, true_extent, extent;
40
41     if (count == 0) {
42         return MPI_SUCCESS;
43     }
44
45     /* homogeneous */
46
47     int comm_size =  comm->size();
48     int rank = comm->rank();
49
50     bool is_commutative = (op == MPI_OP_NULL || op->is_commutative());
51
52     /* need to allocate temporary buffer to store incoming data */
53     datatype->extent(&true_lb, &true_extent);
54     extent = datatype->get_extent();
55
56     unsigned char* tmp_buf_free = smpi_get_tmp_recvbuffer(count * std::max(extent, true_extent));
57
58     /* adjust for potential negative lower bound in datatype */
59     unsigned char* tmp_buf = tmp_buf_free - true_lb;
60
61     /* copy local data into recvbuf */
62     if (sendbuf != MPI_IN_PLACE) {
63         mpi_errno =
64             Datatype::copy(sendbuf, count, datatype, recvbuf, count,
65                            datatype);
66     }
67
68     /* find nearest power-of-two less than or equal to comm_size */
69     for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
70     pof2 >>=1;
71
72     rem = comm_size - pof2;
73
74     /* In the non-power-of-two case, all even-numbered
75        processes of rank < 2*rem send their data to
76        (rank+1). These even-numbered processes no longer
77        participate in the algorithm until the very end. The
78        remaining processes form a nice power-of-two. */
79
80     if (rank < 2 * rem) {
81         if (rank % 2 == 0) {
82             /* even */
83             Request::send(recvbuf, count, datatype, rank + 1,
84                                      COLL_TAG_ALLREDUCE, comm);
85
86             /* temporarily set the rank to -1 so that this
87                process does not participate in recursive
88                doubling */
89             newrank = -1;
90         } else {
91             /* odd */
92             Request::recv(tmp_buf, count, datatype, rank - 1,
93                                      COLL_TAG_ALLREDUCE, comm,
94                                      MPI_STATUS_IGNORE);
95             /* do the reduction on received data. since the
96                ordering is right, it doesn't matter whether
97                the operation is commutative or not. */
98                if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
99                 /* change the rank */
100                 newrank = rank / 2;
101         }
102     } else {                /* rank >= 2*rem */
103         newrank = rank - rem;
104     }
105
106     /* If op is user-defined or count is less than pof2, use
107        recursive doubling algorithm. Otherwise do a reduce-scatter
108        followed by allgather. (If op is user-defined,
109        derived datatypes are allowed and the user could pass basic
110        datatypes on one process and derived on another as long as
111        the type maps are the same. Breaking up derived
112        datatypes to do the reduce-scatter is tricky, therefore
113        using recursive doubling in that case.) */
114
115     if (newrank != -1) {
116         if (/*(HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) ||*/ (count < pof2)) {  /* use recursive doubling */
117             mask = 0x1;
118             while (mask < pof2) {
119                 newdst = newrank ^ mask;
120                 /* find real rank of dest */
121                 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
122
123                 /* Send the most current data, which is in recvbuf. Recv
124                    into tmp_buf */
125                 Request::sendrecv(recvbuf, count, datatype,
126                                              dst, COLL_TAG_ALLREDUCE,
127                                              tmp_buf, count, datatype, dst,
128                                              COLL_TAG_ALLREDUCE, comm,
129                                              MPI_STATUS_IGNORE);
130
131                 /* tmp_buf contains data received in this step.
132                    recvbuf contains data accumulated so far */
133
134                 if (is_commutative || (dst < rank)) {
135                     /* op is commutative OR the order is already right */
136                      if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
137                 } else {
138                     /* op is noncommutative and the order is not right */
139                     if(op!=MPI_OP_NULL) op->apply( recvbuf, tmp_buf, &count, datatype);
140                     /* copy result back into recvbuf */
141                     mpi_errno = Datatype::copy(tmp_buf, count, datatype,
142                                                recvbuf, count, datatype);
143                 }
144                 mask <<= 1;
145             }
146         } else {
147
148             /* do a reduce-scatter followed by allgather */
149
150             /* for the reduce-scatter, calculate the count that
151                each process receives and the displacement within
152                the buffer */
153             int* cnts  = new int[pof2];
154             int* disps = new int[pof2];
155
156             for (i = 0; i < (pof2 - 1); i++) {
157                 cnts[i] = count / pof2;
158             }
159             cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
160
161             disps[0] = 0;
162             for (i = 1; i < pof2; i++) {
163                 disps[i] = disps[i - 1] + cnts[i - 1];
164             }
165
166             mask = 0x1;
167             send_idx = recv_idx = 0;
168             last_idx = pof2;
169             while (mask < pof2) {
170                 newdst = newrank ^ mask;
171                 /* find real rank of dest */
172                 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
173
174                 send_cnt = recv_cnt = 0;
175                 if (newrank < newdst) {
176                     send_idx = recv_idx + pof2 / (mask * 2);
177                     for (i = send_idx; i < last_idx; i++)
178                         send_cnt += cnts[i];
179                     for (i = recv_idx; i < send_idx; i++)
180                         recv_cnt += cnts[i];
181                 } else {
182                     recv_idx = send_idx + pof2 / (mask * 2);
183                     for (i = send_idx; i < recv_idx; i++)
184                         send_cnt += cnts[i];
185                     for (i = recv_idx; i < last_idx; i++)
186                         recv_cnt += cnts[i];
187                 }
188
189                 /* Send data from recvbuf. Recv into tmp_buf */
190                 Request::sendrecv(static_cast<char*>(recvbuf) + disps[send_idx] * extent, send_cnt, datatype, dst,
191                                   COLL_TAG_ALLREDUCE, tmp_buf + disps[recv_idx] * extent, recv_cnt, datatype, dst,
192                                   COLL_TAG_ALLREDUCE, comm, MPI_STATUS_IGNORE);
193
194                 /* tmp_buf contains data received in this step.
195                    recvbuf contains data accumulated so far */
196
197                 /* This algorithm is used only for predefined ops
198                    and predefined ops are always commutative. */
199
200                 if (op != MPI_OP_NULL)
201                   op->apply(tmp_buf + disps[recv_idx] * extent, static_cast<char*>(recvbuf) + disps[recv_idx] * extent,
202                             &recv_cnt, datatype);
203
204                 /* update send_idx for next iteration */
205                 send_idx = recv_idx;
206                 mask <<= 1;
207
208                 /* update last_idx, but not in last iteration
209                    because the value is needed in the allgather
210                    step below. */
211                 if (mask < pof2)
212                     last_idx = recv_idx + pof2 / mask;
213             }
214
215             /* now do the allgather */
216
217             mask >>= 1;
218             while (mask > 0) {
219                 newdst = newrank ^ mask;
220                 /* find real rank of dest */
221                 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
222
223                 send_cnt = recv_cnt = 0;
224                 if (newrank < newdst) {
225                     /* update last_idx except on first iteration */
226                     if (mask != pof2 / 2) {
227                         last_idx = last_idx + pof2 / (mask * 2);
228                     }
229
230                     recv_idx = send_idx + pof2 / (mask * 2);
231                     for (i = send_idx; i < recv_idx; i++) {
232                         send_cnt += cnts[i];
233                     }
234                     for (i = recv_idx; i < last_idx; i++) {
235                         recv_cnt += cnts[i];
236                     }
237                 } else {
238                     recv_idx = send_idx - pof2 / (mask * 2);
239                     for (i = send_idx; i < last_idx; i++) {
240                         send_cnt += cnts[i];
241                     }
242                     for (i = recv_idx; i < send_idx; i++) {
243                         recv_cnt += cnts[i];
244                     }
245                 }
246
247                Request::sendrecv((char *) recvbuf +
248                                              disps[send_idx] * extent,
249                                              send_cnt, datatype,
250                                              dst, COLL_TAG_ALLREDUCE,
251                                              (char *) recvbuf +
252                                              disps[recv_idx] * extent,
253                                              recv_cnt, datatype, dst,
254                                              COLL_TAG_ALLREDUCE, comm,
255                                              MPI_STATUS_IGNORE);
256                 if (newrank > newdst) {
257                     send_idx = recv_idx;
258                 }
259
260                 mask >>= 1;
261             }
262             delete[] disps;
263             delete[] cnts;
264         }
265     }
266
267     /* In the non-power-of-two case, all odd-numbered
268        processes of rank < 2*rem send the result to
269        (rank-1), the ranks who didn't participate above. */
270     if (rank < 2 * rem) {
271         if (rank % 2) {     /* odd */
272             Request::send(recvbuf, count,
273                                      datatype, rank - 1,
274                                      COLL_TAG_ALLREDUCE, comm);
275         } else {            /* even */
276             Request::recv(recvbuf, count,
277                                   datatype, rank + 1,
278                                   COLL_TAG_ALLREDUCE, comm,
279                                   MPI_STATUS_IGNORE);
280         }
281     }
282     smpi_free_tmp_buffer(tmp_buf_free);
283     return (mpi_errno);
284
285 }
286
287 }
288 }