Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Move collective algorithms to separate folders
[simgrid.git] / src / smpi / colls / allreduce / allreduce-mvapich-rs.cpp
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  *  (C) 2001 by Argonne National Laboratory.
9  *      See COPYRIGHT in top-level directory.
10  */
11
12 /* Copyright (c) 2001-2014, The Ohio State University. All rights
13  * reserved.
14  *
15  * This file is part of the MVAPICH2 software package developed by the
16  * team members of The Ohio State University's Network-Based Computing
17  * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
18  *
19  * For detailed copyright and licensing information, please refer to the
20  * copyright file COPYRIGHT in the top level MVAPICH2 directory.
21  *
22  */
23
24 #include "../colls_private.h"
25
26 int smpi_coll_tuned_allreduce_mvapich2_rs(void *sendbuf,
27                             void *recvbuf,
28                             int count,
29                             MPI_Datatype datatype,
30                             MPI_Op op, MPI_Comm comm)
31 {
32     int mpi_errno = MPI_SUCCESS;
33     int newrank = 0;
34     int mask, pof2, i, send_idx, recv_idx, last_idx, send_cnt;
35     int dst, is_commutative, rem, newdst,
36         recv_cnt, *cnts, *disps;
37     MPI_Aint true_lb, true_extent, extent;
38     void *tmp_buf, *tmp_buf_free;
39
40     if (count == 0) {
41         return MPI_SUCCESS;
42     }
43
44     /* homogeneous */
45
46     int comm_size =  comm->size();
47     int rank = comm->rank();
48
49     is_commutative = (op==MPI_OP_NULL || op->is_commutative());
50
51     /* need to allocate temporary buffer to store incoming data */
52     datatype->extent(&true_lb, &true_extent);
53     extent = datatype->get_extent();
54
55     tmp_buf_free= smpi_get_tmp_recvbuffer(count * (MAX(extent, true_extent)));
56
57     /* adjust for potential negative lower bound in datatype */
58     tmp_buf = (void *) ((char *) tmp_buf_free - true_lb);
59
60     /* copy local data into recvbuf */
61     if (sendbuf != MPI_IN_PLACE) {
62         mpi_errno =
63             Datatype::copy(sendbuf, count, datatype, recvbuf, count,
64                            datatype);
65     }
66
67     /* find nearest power-of-two less than or equal to comm_size */
68     for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
69     pof2 >>=1;
70
71     rem = comm_size - pof2;
72
73     /* In the non-power-of-two case, all even-numbered
74        processes of rank < 2*rem send their data to
75        (rank+1). These even-numbered processes no longer
76        participate in the algorithm until the very end. The
77        remaining processes form a nice power-of-two. */
78
79     if (rank < 2 * rem) {
80         if (rank % 2 == 0) {
81             /* even */
82             Request::send(recvbuf, count, datatype, rank + 1,
83                                      COLL_TAG_ALLREDUCE, comm);
84
85             /* temporarily set the rank to -1 so that this
86                process does not pariticipate in recursive
87                doubling */
88             newrank = -1;
89         } else {
90             /* odd */
91             Request::recv(tmp_buf, count, datatype, rank - 1,
92                                      COLL_TAG_ALLREDUCE, comm,
93                                      MPI_STATUS_IGNORE);
94             /* do the reduction on received data. since the
95                ordering is right, it doesn't matter whether
96                the operation is commutative or not. */
97                if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
98                 /* change the rank */
99                 newrank = rank / 2;
100         }
101     } else {                /* rank >= 2*rem */
102         newrank = rank - rem;
103     }
104
105     /* If op is user-defined or count is less than pof2, use
106        recursive doubling algorithm. Otherwise do a reduce-scatter
107        followed by allgather. (If op is user-defined,
108        derived datatypes are allowed and the user could pass basic
109        datatypes on one process and derived on another as long as
110        the type maps are the same. Breaking up derived
111        datatypes to do the reduce-scatter is tricky, therefore
112        using recursive doubling in that case.) */
113
114     if (newrank != -1) {
115         if (/*(HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) ||*/ (count < pof2)) {  /* use recursive doubling */
116             mask = 0x1;
117             while (mask < pof2) {
118                 newdst = newrank ^ mask;
119                 /* find real rank of dest */
120                 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
121
122                 /* Send the most current data, which is in recvbuf. Recv
123                    into tmp_buf */
124                 Request::sendrecv(recvbuf, count, datatype,
125                                              dst, COLL_TAG_ALLREDUCE,
126                                              tmp_buf, count, datatype, dst,
127                                              COLL_TAG_ALLREDUCE, comm,
128                                              MPI_STATUS_IGNORE);
129
130                 /* tmp_buf contains data received in this step.
131                    recvbuf contains data accumulated so far */
132
133                 if (is_commutative || (dst < rank)) {
134                     /* op is commutative OR the order is already right */
135                      if(op!=MPI_OP_NULL) op->apply( tmp_buf, recvbuf, &count, datatype);
136                 } else {
137                     /* op is noncommutative and the order is not right */
138                     if(op!=MPI_OP_NULL) op->apply( recvbuf, tmp_buf, &count, datatype);
139                     /* copy result back into recvbuf */
140                     mpi_errno = Datatype::copy(tmp_buf, count, datatype,
141                                                recvbuf, count, datatype);
142                 }
143                 mask <<= 1;
144             }
145         } else {
146
147             /* do a reduce-scatter followed by allgather */
148
149             /* for the reduce-scatter, calculate the count that
150                each process receives and the displacement within
151                the buffer */
152             cnts = (int *)xbt_malloc(pof2 * sizeof (int));
153             disps = (int *)xbt_malloc(pof2 * sizeof (int));
154
155             for (i = 0; i < (pof2 - 1); i++) {
156                 cnts[i] = count / pof2;
157             }
158             cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
159
160             disps[0] = 0;
161             for (i = 1; i < pof2; i++) {
162                 disps[i] = disps[i - 1] + cnts[i - 1];
163             }
164
165             mask = 0x1;
166             send_idx = recv_idx = 0;
167             last_idx = pof2;
168             while (mask < pof2) {
169                 newdst = newrank ^ mask;
170                 /* find real rank of dest */
171                 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
172
173                 send_cnt = recv_cnt = 0;
174                 if (newrank < newdst) {
175                     send_idx = recv_idx + pof2 / (mask * 2);
176                     for (i = send_idx; i < last_idx; i++)
177                         send_cnt += cnts[i];
178                     for (i = recv_idx; i < send_idx; i++)
179                         recv_cnt += cnts[i];
180                 } else {
181                     recv_idx = send_idx + pof2 / (mask * 2);
182                     for (i = send_idx; i < recv_idx; i++)
183                         send_cnt += cnts[i];
184                     for (i = recv_idx; i < last_idx; i++)
185                         recv_cnt += cnts[i];
186                 }
187
188                 /* Send data from recvbuf. Recv into tmp_buf */
189                 Request::sendrecv((char *) recvbuf +
190                                              disps[send_idx] * extent,
191                                              send_cnt, datatype,
192                                              dst, COLL_TAG_ALLREDUCE,
193                                              (char *) tmp_buf +
194                                              disps[recv_idx] * extent,
195                                              recv_cnt, datatype, dst,
196                                              COLL_TAG_ALLREDUCE, comm,
197                                              MPI_STATUS_IGNORE);
198
199                 /* tmp_buf contains data received in this step.
200                    recvbuf contains data accumulated so far */
201
202                 /* This algorithm is used only for predefined ops
203                    and predefined ops are always commutative. */
204
205                 if(op!=MPI_OP_NULL) op->apply( (char *) tmp_buf + disps[recv_idx] * extent,
206                         (char *) recvbuf + disps[recv_idx] * extent,
207                         &recv_cnt, datatype);
208
209                 /* update send_idx for next iteration */
210                 send_idx = recv_idx;
211                 mask <<= 1;
212
213                 /* update last_idx, but not in last iteration
214                    because the value is needed in the allgather
215                    step below. */
216                 if (mask < pof2)
217                     last_idx = recv_idx + pof2 / mask;
218             }
219
220             /* now do the allgather */
221
222             mask >>= 1;
223             while (mask > 0) {
224                 newdst = newrank ^ mask;
225                 /* find real rank of dest */
226                 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
227
228                 send_cnt = recv_cnt = 0;
229                 if (newrank < newdst) {
230                     /* update last_idx except on first iteration */
231                     if (mask != pof2 / 2) {
232                         last_idx = last_idx + pof2 / (mask * 2);
233                     }
234
235                     recv_idx = send_idx + pof2 / (mask * 2);
236                     for (i = send_idx; i < recv_idx; i++) {
237                         send_cnt += cnts[i];
238                     }
239                     for (i = recv_idx; i < last_idx; i++) {
240                         recv_cnt += cnts[i];
241                     }
242                 } else {
243                     recv_idx = send_idx - pof2 / (mask * 2);
244                     for (i = send_idx; i < last_idx; i++) {
245                         send_cnt += cnts[i];
246                     }
247                     for (i = recv_idx; i < send_idx; i++) {
248                         recv_cnt += cnts[i];
249                     }
250                 }
251
252                Request::sendrecv((char *) recvbuf +
253                                              disps[send_idx] * extent,
254                                              send_cnt, datatype,
255                                              dst, COLL_TAG_ALLREDUCE,
256                                              (char *) recvbuf +
257                                              disps[recv_idx] * extent,
258                                              recv_cnt, datatype, dst,
259                                              COLL_TAG_ALLREDUCE, comm,
260                                              MPI_STATUS_IGNORE);
261                 if (newrank > newdst) {
262                     send_idx = recv_idx;
263                 }
264
265                 mask >>= 1;
266             }
267             xbt_free(disps);
268             xbt_free(cnts);
269         }
270     }
271
272     /* In the non-power-of-two case, all odd-numbered
273        processes of rank < 2*rem send the result to
274        (rank-1), the ranks who didn't participate above. */
275     if (rank < 2 * rem) {
276         if (rank % 2) {     /* odd */
277             Request::send(recvbuf, count,
278                                      datatype, rank - 1,
279                                      COLL_TAG_ALLREDUCE, comm);
280         } else {            /* even */
281             Request::recv(recvbuf, count,
282                                   datatype, rank + 1,
283                                   COLL_TAG_ALLREDUCE, comm,
284                                   MPI_STATUS_IGNORE);
285         }
286     }
287     smpi_free_tmp_buffer(tmp_buf_free);
288     return (mpi_errno);
289
290 }