Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Change malloc/free to new/delete.
[simgrid.git] / src / smpi / colls / allreduce / allreduce-mvapich-rs.cpp
1 /* Copyright (c) 2013-2019. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  *  (C) 2001 by Argonne National Laboratory.
9  *      See COPYRIGHT in top-level directory.
10  */
11
12 /* Copyright (c) 2001-2014, The Ohio State University. All rights
13  * reserved.
14  *
15  * This file is part of the MVAPICH2 software package developed by the
16  * team members of The Ohio State University's Network-Based Computing
17  * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
18  *
19  * For detailed copyright and licensing information, please refer to the
20  * copyright file COPYRIGHT in the top level MVAPICH2 directory.
21  *
22  */
23
24 #include "../colls_private.hpp"
25 #include <algorithm>
26
27 namespace simgrid{
28 namespace smpi{
29 int Coll_allreduce_mvapich2_rs::allreduce(const void *sendbuf,
30                             void *recvbuf,
31                             int count,
32                             MPI_Datatype datatype,
33                             MPI_Op op, MPI_Comm comm)
34 {
35     int mpi_errno = MPI_SUCCESS;
36     int newrank = 0;
37     int mask, pof2, i, send_idx, recv_idx, last_idx, send_cnt;
38     int dst, is_commutative, rem, newdst, recv_cnt;
39     MPI_Aint true_lb, true_extent, extent;
40     void *tmp_buf, *tmp_buf_free;
41
42     if (count == 0) {
43         return MPI_SUCCESS;
44     }
45
46     /* homogeneous */
47
48     int comm_size =  comm->size();
49     int rank = comm->rank();
50
51     is_commutative = (op==MPI_OP_NULL || op->is_commutative());
52
53     /* need to allocate temporary buffer to store incoming data */
54     datatype->extent(&true_lb, &true_extent);
55     extent = datatype->get_extent();
56
57     tmp_buf_free = smpi_get_tmp_recvbuffer(count * std::max(extent, true_extent));
58
59     /* adjust for potential negative lower bound in datatype */
60     tmp_buf = (void *) ((char *) tmp_buf_free - true_lb);
61
62     /* copy local data into recvbuf */
63     if (sendbuf != MPI_IN_PLACE) {
64         mpi_errno =
65             Datatype::copy(sendbuf, count, datatype, recvbuf, count,
66                            datatype);
67     }
68
69     /* find nearest power-of-two less than or equal to comm_size */
70     for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
71     pof2 >>=1;
72
73     rem = comm_size - pof2;
74
75     /* In the non-power-of-two case, all even-numbered
76        processes of rank < 2*rem send their data to
77        (rank+1). These even-numbered processes no longer
78        participate in the algorithm until the very end. The
79        remaining processes form a nice power-of-two. */
80
81     if (rank < 2 * rem) {
82         if (rank % 2 == 0) {
83             /* even */
84             Request::send(recvbuf, count, datatype, rank + 1,
85                                      COLL_TAG_ALLREDUCE, comm);
86
87             /* temporarily set the rank to -1 so that this
88                process does not pariticipate in recursive
89                doubling */
90             newrank = -1;
91         } else {
92             /* odd */
93             Request::recv(tmp_buf, count, datatype, rank - 1,
94                                      COLL_TAG_ALLREDUCE, comm,
95                                      MPI_STATUS_IGNORE);
96             /* do the reduction on received data. since the
97                ordering is right, it doesn't matter whether
98                the operation is commutative or not. */
99                if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
100                 /* change the rank */
101                 newrank = rank / 2;
102         }
103     } else {                /* rank >= 2*rem */
104         newrank = rank - rem;
105     }
106
107     /* If op is user-defined or count is less than pof2, use
108        recursive doubling algorithm. Otherwise do a reduce-scatter
109        followed by allgather. (If op is user-defined,
110        derived datatypes are allowed and the user could pass basic
111        datatypes on one process and derived on another as long as
112        the type maps are the same. Breaking up derived
113        datatypes to do the reduce-scatter is tricky, therefore
114        using recursive doubling in that case.) */
115
116     if (newrank != -1) {
117         if (/*(HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) ||*/ (count < pof2)) {  /* use recursive doubling */
118             mask = 0x1;
119             while (mask < pof2) {
120                 newdst = newrank ^ mask;
121                 /* find real rank of dest */
122                 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
123
124                 /* Send the most current data, which is in recvbuf. Recv
125                    into tmp_buf */
126                 Request::sendrecv(recvbuf, count, datatype,
127                                              dst, COLL_TAG_ALLREDUCE,
128                                              tmp_buf, count, datatype, dst,
129                                              COLL_TAG_ALLREDUCE, comm,
130                                              MPI_STATUS_IGNORE);
131
132                 /* tmp_buf contains data received in this step.
133                    recvbuf contains data accumulated so far */
134
135                 if (is_commutative || (dst < rank)) {
136                     /* op is commutative OR the order is already right */
137                      if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
138                 } else {
139                     /* op is noncommutative and the order is not right */
140                     if(op!=MPI_OP_NULL) op->apply( recvbuf, tmp_buf, &count, datatype);
141                     /* copy result back into recvbuf */
142                     mpi_errno = Datatype::copy(tmp_buf, count, datatype,
143                                                recvbuf, count, datatype);
144                 }
145                 mask <<= 1;
146             }
147         } else {
148
149             /* do a reduce-scatter followed by allgather */
150
151             /* for the reduce-scatter, calculate the count that
152                each process receives and the displacement within
153                the buffer */
154             int* cnts  = new int[pof2];
155             int* disps = new int[pof2];
156
157             for (i = 0; i < (pof2 - 1); i++) {
158                 cnts[i] = count / pof2;
159             }
160             cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
161
162             disps[0] = 0;
163             for (i = 1; i < pof2; i++) {
164                 disps[i] = disps[i - 1] + cnts[i - 1];
165             }
166
167             mask = 0x1;
168             send_idx = recv_idx = 0;
169             last_idx = pof2;
170             while (mask < pof2) {
171                 newdst = newrank ^ mask;
172                 /* find real rank of dest */
173                 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
174
175                 send_cnt = recv_cnt = 0;
176                 if (newrank < newdst) {
177                     send_idx = recv_idx + pof2 / (mask * 2);
178                     for (i = send_idx; i < last_idx; i++)
179                         send_cnt += cnts[i];
180                     for (i = recv_idx; i < send_idx; i++)
181                         recv_cnt += cnts[i];
182                 } else {
183                     recv_idx = send_idx + pof2 / (mask * 2);
184                     for (i = send_idx; i < recv_idx; i++)
185                         send_cnt += cnts[i];
186                     for (i = recv_idx; i < last_idx; i++)
187                         recv_cnt += cnts[i];
188                 }
189
190                 /* Send data from recvbuf. Recv into tmp_buf */
191                 Request::sendrecv((char *) recvbuf +
192                                              disps[send_idx] * extent,
193                                              send_cnt, datatype,
194                                              dst, COLL_TAG_ALLREDUCE,
195                                              (char *) tmp_buf +
196                                              disps[recv_idx] * extent,
197                                              recv_cnt, datatype, dst,
198                                              COLL_TAG_ALLREDUCE, comm,
199                                              MPI_STATUS_IGNORE);
200
201                 /* tmp_buf contains data received in this step.
202                    recvbuf contains data accumulated so far */
203
204                 /* This algorithm is used only for predefined ops
205                    and predefined ops are always commutative. */
206
207                 if(op!=MPI_OP_NULL) op->apply( (char *) tmp_buf + disps[recv_idx] * extent,
208                         (char *) recvbuf + disps[recv_idx] * extent,
209                         &recv_cnt, datatype);
210
211                 /* update send_idx for next iteration */
212                 send_idx = recv_idx;
213                 mask <<= 1;
214
215                 /* update last_idx, but not in last iteration
216                    because the value is needed in the allgather
217                    step below. */
218                 if (mask < pof2)
219                     last_idx = recv_idx + pof2 / mask;
220             }
221
222             /* now do the allgather */
223
224             mask >>= 1;
225             while (mask > 0) {
226                 newdst = newrank ^ mask;
227                 /* find real rank of dest */
228                 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
229
230                 send_cnt = recv_cnt = 0;
231                 if (newrank < newdst) {
232                     /* update last_idx except on first iteration */
233                     if (mask != pof2 / 2) {
234                         last_idx = last_idx + pof2 / (mask * 2);
235                     }
236
237                     recv_idx = send_idx + pof2 / (mask * 2);
238                     for (i = send_idx; i < recv_idx; i++) {
239                         send_cnt += cnts[i];
240                     }
241                     for (i = recv_idx; i < last_idx; i++) {
242                         recv_cnt += cnts[i];
243                     }
244                 } else {
245                     recv_idx = send_idx - pof2 / (mask * 2);
246                     for (i = send_idx; i < last_idx; i++) {
247                         send_cnt += cnts[i];
248                     }
249                     for (i = recv_idx; i < send_idx; i++) {
250                         recv_cnt += cnts[i];
251                     }
252                 }
253
254                Request::sendrecv((char *) recvbuf +
255                                              disps[send_idx] * extent,
256                                              send_cnt, datatype,
257                                              dst, COLL_TAG_ALLREDUCE,
258                                              (char *) recvbuf +
259                                              disps[recv_idx] * extent,
260                                              recv_cnt, datatype, dst,
261                                              COLL_TAG_ALLREDUCE, comm,
262                                              MPI_STATUS_IGNORE);
263                 if (newrank > newdst) {
264                     send_idx = recv_idx;
265                 }
266
267                 mask >>= 1;
268             }
269             delete[] disps;
270             delete[] cnts;
271         }
272     }
273
274     /* In the non-power-of-two case, all odd-numbered
275        processes of rank < 2*rem send the result to
276        (rank-1), the ranks who didn't participate above. */
277     if (rank < 2 * rem) {
278         if (rank % 2) {     /* odd */
279             Request::send(recvbuf, count,
280                                      datatype, rank - 1,
281                                      COLL_TAG_ALLREDUCE, comm);
282         } else {            /* even */
283             Request::recv(recvbuf, count,
284                                   datatype, rank + 1,
285                                   COLL_TAG_ALLREDUCE, comm,
286                                   MPI_STATUS_IGNORE);
287         }
288     }
289     smpi_free_tmp_buffer(tmp_buf_free);
290     return (mpi_errno);
291
292 }
293
294 }
295 }