Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Leak-- (seen in maestro-set).
[simgrid.git] / src / smpi / colls / allreduce / allreduce-mvapich-rs.cpp
1 /* Copyright (c) 2013-2017. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  *  (C) 2001 by Argonne National Laboratory.
9  *      See COPYRIGHT in top-level directory.
10  */
11
12 /* Copyright (c) 2001-2014, The Ohio State University. All rights
13  * reserved.
14  *
15  * This file is part of the MVAPICH2 software package developed by the
16  * team members of The Ohio State University's Network-Based Computing
17  * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
18  *
19  * For detailed copyright and licensing information, please refer to the
20  * copyright file COPYRIGHT in the top level MVAPICH2 directory.
21  *
22  */
23
24 #include "../colls_private.h"
25 namespace simgrid{
26 namespace smpi{
27 int Coll_allreduce_mvapich2_rs::allreduce(void *sendbuf,
28                             void *recvbuf,
29                             int count,
30                             MPI_Datatype datatype,
31                             MPI_Op op, MPI_Comm comm)
32 {
33     int mpi_errno = MPI_SUCCESS;
34     int newrank = 0;
35     int mask, pof2, i, send_idx, recv_idx, last_idx, send_cnt;
36     int dst, is_commutative, rem, newdst,
37         recv_cnt, *cnts, *disps;
38     MPI_Aint true_lb, true_extent, extent;
39     void *tmp_buf, *tmp_buf_free;
40
41     if (count == 0) {
42         return MPI_SUCCESS;
43     }
44
45     /* homogeneous */
46
47     int comm_size =  comm->size();
48     int rank = comm->rank();
49
50     is_commutative = (op==MPI_OP_NULL || op->is_commutative());
51
52     /* need to allocate temporary buffer to store incoming data */
53     datatype->extent(&true_lb, &true_extent);
54     extent = datatype->get_extent();
55
56     tmp_buf_free= smpi_get_tmp_recvbuffer(count * (MAX(extent, true_extent)));
57
58     /* adjust for potential negative lower bound in datatype */
59     tmp_buf = (void *) ((char *) tmp_buf_free - true_lb);
60
61     /* copy local data into recvbuf */
62     if (sendbuf != MPI_IN_PLACE) {
63         mpi_errno =
64             Datatype::copy(sendbuf, count, datatype, recvbuf, count,
65                            datatype);
66     }
67
68     /* find nearest power-of-two less than or equal to comm_size */
69     for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
70     pof2 >>=1;
71
72     rem = comm_size - pof2;
73
74     /* In the non-power-of-two case, all even-numbered
75        processes of rank < 2*rem send their data to
76        (rank+1). These even-numbered processes no longer
77        participate in the algorithm until the very end. The
78        remaining processes form a nice power-of-two. */
79
80     if (rank < 2 * rem) {
81         if (rank % 2 == 0) {
82             /* even */
83             Request::send(recvbuf, count, datatype, rank + 1,
84                                      COLL_TAG_ALLREDUCE, comm);
85
86             /* temporarily set the rank to -1 so that this
87                process does not pariticipate in recursive
88                doubling */
89             newrank = -1;
90         } else {
91             /* odd */
92             Request::recv(tmp_buf, count, datatype, rank - 1,
93                                      COLL_TAG_ALLREDUCE, comm,
94                                      MPI_STATUS_IGNORE);
95             /* do the reduction on received data. since the
96                ordering is right, it doesn't matter whether
97                the operation is commutative or not. */
98                if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
99                 /* change the rank */
100                 newrank = rank / 2;
101         }
102     } else {                /* rank >= 2*rem */
103         newrank = rank - rem;
104     }
105
106     /* If op is user-defined or count is less than pof2, use
107        recursive doubling algorithm. Otherwise do a reduce-scatter
108        followed by allgather. (If op is user-defined,
109        derived datatypes are allowed and the user could pass basic
110        datatypes on one process and derived on another as long as
111        the type maps are the same. Breaking up derived
112        datatypes to do the reduce-scatter is tricky, therefore
113        using recursive doubling in that case.) */
114
115     if (newrank != -1) {
116         if (/*(HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) ||*/ (count < pof2)) {  /* use recursive doubling */
117             mask = 0x1;
118             while (mask < pof2) {
119                 newdst = newrank ^ mask;
120                 /* find real rank of dest */
121                 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
122
123                 /* Send the most current data, which is in recvbuf. Recv
124                    into tmp_buf */
125                 Request::sendrecv(recvbuf, count, datatype,
126                                              dst, COLL_TAG_ALLREDUCE,
127                                              tmp_buf, count, datatype, dst,
128                                              COLL_TAG_ALLREDUCE, comm,
129                                              MPI_STATUS_IGNORE);
130
131                 /* tmp_buf contains data received in this step.
132                    recvbuf contains data accumulated so far */
133
134                 if (is_commutative || (dst < rank)) {
135                     /* op is commutative OR the order is already right */
136                      if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
137                 } else {
138                     /* op is noncommutative and the order is not right */
139                     if(op!=MPI_OP_NULL) op->apply( recvbuf, tmp_buf, &count, datatype);
140                     /* copy result back into recvbuf */
141                     mpi_errno = Datatype::copy(tmp_buf, count, datatype,
142                                                recvbuf, count, datatype);
143                 }
144                 mask <<= 1;
145             }
146         } else {
147
148             /* do a reduce-scatter followed by allgather */
149
150             /* for the reduce-scatter, calculate the count that
151                each process receives and the displacement within
152                the buffer */
153             cnts = (int *)xbt_malloc(pof2 * sizeof (int));
154             disps = (int *)xbt_malloc(pof2 * sizeof (int));
155
156             for (i = 0; i < (pof2 - 1); i++) {
157                 cnts[i] = count / pof2;
158             }
159             cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
160
161             disps[0] = 0;
162             for (i = 1; i < pof2; i++) {
163                 disps[i] = disps[i - 1] + cnts[i - 1];
164             }
165
166             mask = 0x1;
167             send_idx = recv_idx = 0;
168             last_idx = pof2;
169             while (mask < pof2) {
170                 newdst = newrank ^ mask;
171                 /* find real rank of dest */
172                 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
173
174                 send_cnt = recv_cnt = 0;
175                 if (newrank < newdst) {
176                     send_idx = recv_idx + pof2 / (mask * 2);
177                     for (i = send_idx; i < last_idx; i++)
178                         send_cnt += cnts[i];
179                     for (i = recv_idx; i < send_idx; i++)
180                         recv_cnt += cnts[i];
181                 } else {
182                     recv_idx = send_idx + pof2 / (mask * 2);
183                     for (i = send_idx; i < recv_idx; i++)
184                         send_cnt += cnts[i];
185                     for (i = recv_idx; i < last_idx; i++)
186                         recv_cnt += cnts[i];
187                 }
188
189                 /* Send data from recvbuf. Recv into tmp_buf */
190                 Request::sendrecv((char *) recvbuf +
191                                              disps[send_idx] * extent,
192                                              send_cnt, datatype,
193                                              dst, COLL_TAG_ALLREDUCE,
194                                              (char *) tmp_buf +
195                                              disps[recv_idx] * extent,
196                                              recv_cnt, datatype, dst,
197                                              COLL_TAG_ALLREDUCE, comm,
198                                              MPI_STATUS_IGNORE);
199
200                 /* tmp_buf contains data received in this step.
201                    recvbuf contains data accumulated so far */
202
203                 /* This algorithm is used only for predefined ops
204                    and predefined ops are always commutative. */
205
206                 if(op!=MPI_OP_NULL) op->apply( (char *) tmp_buf + disps[recv_idx] * extent,
207                         (char *) recvbuf + disps[recv_idx] * extent,
208                         &recv_cnt, datatype);
209
210                 /* update send_idx for next iteration */
211                 send_idx = recv_idx;
212                 mask <<= 1;
213
214                 /* update last_idx, but not in last iteration
215                    because the value is needed in the allgather
216                    step below. */
217                 if (mask < pof2)
218                     last_idx = recv_idx + pof2 / mask;
219             }
220
221             /* now do the allgather */
222
223             mask >>= 1;
224             while (mask > 0) {
225                 newdst = newrank ^ mask;
226                 /* find real rank of dest */
227                 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
228
229                 send_cnt = recv_cnt = 0;
230                 if (newrank < newdst) {
231                     /* update last_idx except on first iteration */
232                     if (mask != pof2 / 2) {
233                         last_idx = last_idx + pof2 / (mask * 2);
234                     }
235
236                     recv_idx = send_idx + pof2 / (mask * 2);
237                     for (i = send_idx; i < recv_idx; i++) {
238                         send_cnt += cnts[i];
239                     }
240                     for (i = recv_idx; i < last_idx; i++) {
241                         recv_cnt += cnts[i];
242                     }
243                 } else {
244                     recv_idx = send_idx - pof2 / (mask * 2);
245                     for (i = send_idx; i < last_idx; i++) {
246                         send_cnt += cnts[i];
247                     }
248                     for (i = recv_idx; i < send_idx; i++) {
249                         recv_cnt += cnts[i];
250                     }
251                 }
252
253                Request::sendrecv((char *) recvbuf +
254                                              disps[send_idx] * extent,
255                                              send_cnt, datatype,
256                                              dst, COLL_TAG_ALLREDUCE,
257                                              (char *) recvbuf +
258                                              disps[recv_idx] * extent,
259                                              recv_cnt, datatype, dst,
260                                              COLL_TAG_ALLREDUCE, comm,
261                                              MPI_STATUS_IGNORE);
262                 if (newrank > newdst) {
263                     send_idx = recv_idx;
264                 }
265
266                 mask >>= 1;
267             }
268             xbt_free(disps);
269             xbt_free(cnts);
270         }
271     }
272
273     /* In the non-power-of-two case, all odd-numbered
274        processes of rank < 2*rem send the result to
275        (rank-1), the ranks who didn't participate above. */
276     if (rank < 2 * rem) {
277         if (rank % 2) {     /* odd */
278             Request::send(recvbuf, count,
279                                      datatype, rank - 1,
280                                      COLL_TAG_ALLREDUCE, comm);
281         } else {            /* even */
282             Request::recv(recvbuf, count,
283                                   datatype, rank + 1,
284                                   COLL_TAG_ALLREDUCE, comm,
285                                   MPI_STATUS_IGNORE);
286         }
287     }
288     smpi_free_tmp_buffer(tmp_buf_free);
289     return (mpi_errno);
290
291 }
292
293 }
294 }