Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
oops
[simgrid.git] / src / smpi / colls / allreduce-mvapich-rs.c
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  *  (C) 2001 by Argonne National Laboratory.
9  *      See COPYRIGHT in top-level directory.
10  */
11
12 /* Copyright (c) 2001-2014, The Ohio State University. All rights
13  * reserved.
14  *
15  * This file is part of the MVAPICH2 software package developed by the
16  * team members of The Ohio State University's Network-Based Computing
17  * Laboratory (NBCL), headed by Professor Dhabaleswar K. (DK) Panda.
18  *
19  * For detailed copyright and licensing information, please refer to the
20  * copyright file COPYRIGHT in the top level MVAPICH2 directory.
21  *
22  */
23  
24  #include "colls_private.h"
25  
26  int smpi_coll_tuned_allreduce_mvapich2_rs(void *sendbuf,
27                              void *recvbuf,
28                              int count,
29                              MPI_Datatype datatype,
30                              MPI_Op op, MPI_Comm comm)
31 {
32     int comm_size, rank;
33     int mpi_errno = MPI_SUCCESS;
34     int mask, dst, is_commutative, pof2, newrank = 0, rem, newdst, i,
35         send_idx, recv_idx, last_idx, send_cnt, recv_cnt, *cnts, *disps;
36     MPI_Aint true_lb, true_extent, extent;
37     void *tmp_buf, *tmp_buf_free;
38
39     if (count == 0) {
40         return MPI_SUCCESS;
41     }
42
43     /* homogeneous */
44
45     comm_size =  smpi_comm_size(comm);
46     rank = smpi_comm_rank(comm);
47
48     is_commutative = smpi_op_is_commute(op);
49
50     /* need to allocate temporary buffer to store incoming data */
51     smpi_datatype_extent(datatype, &true_lb, &true_extent);
52     extent = smpi_datatype_get_extent(datatype);
53
54     tmp_buf_free= smpi_get_tmp_recvbuffer(count * (MAX(extent, true_extent)));
55
56     /* adjust for potential negative lower bound in datatype */
57     tmp_buf = (void *) ((char *) tmp_buf_free - true_lb);
58
59     /* copy local data into recvbuf */
60     if (sendbuf != MPI_IN_PLACE) {
61         mpi_errno =
62             smpi_datatype_copy(sendbuf, count, datatype, recvbuf, count,
63                            datatype);
64     }
65
66     /* find nearest power-of-two less than or equal to comm_size */
67     for( pof2 = 1; pof2 <= comm_size; pof2 <<= 1 );
68     pof2 >>=1;
69
70     rem = comm_size - pof2;
71
72     /* In the non-power-of-two case, all even-numbered
73        processes of rank < 2*rem send their data to
74        (rank+1). These even-numbered processes no longer
75        participate in the algorithm until the very end. The
76        remaining processes form a nice power-of-two. */
77
78     if (rank < 2 * rem) {
79         if (rank % 2 == 0) {
80             /* even */
81             smpi_mpi_send(recvbuf, count, datatype, rank + 1,
82                                      COLL_TAG_ALLREDUCE, comm);
83
84             /* temporarily set the rank to -1 so that this
85                process does not pariticipate in recursive
86                doubling */
87             newrank = -1;
88         } else {
89             /* odd */
90             smpi_mpi_recv(tmp_buf, count, datatype, rank - 1,
91                                      COLL_TAG_ALLREDUCE, comm,
92                                      MPI_STATUS_IGNORE);
93             /* do the reduction on received data. since the
94                ordering is right, it doesn't matter whether
95                the operation is commutative or not. */
96                smpi_op_apply(op, tmp_buf, recvbuf, &count, &datatype);
97                 /* change the rank */
98                 newrank = rank / 2;
99         }
100     } else {                /* rank >= 2*rem */
101         newrank = rank - rem;
102     }
103
104     /* If op is user-defined or count is less than pof2, use
105        recursive doubling algorithm. Otherwise do a reduce-scatter
106        followed by allgather. (If op is user-defined,
107        derived datatypes are allowed and the user could pass basic
108        datatypes on one process and derived on another as long as
109        the type maps are the same. Breaking up derived
110        datatypes to do the reduce-scatter is tricky, therefore
111        using recursive doubling in that case.) */
112
113     if (newrank != -1) {
114         if (/*(HANDLE_GET_KIND(op) != HANDLE_KIND_BUILTIN) ||*/ (count < pof2)) {  /* use recursive doubling */
115             mask = 0x1;
116             while (mask < pof2) {
117                 newdst = newrank ^ mask;
118                 /* find real rank of dest */
119                 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
120
121                 /* Send the most current data, which is in recvbuf. Recv
122                    into tmp_buf */
123                 smpi_mpi_sendrecv(recvbuf, count, datatype,
124                                              dst, COLL_TAG_ALLREDUCE,
125                                              tmp_buf, count, datatype, dst,
126                                              COLL_TAG_ALLREDUCE, comm,
127                                              MPI_STATUS_IGNORE);
128
129                 /* tmp_buf contains data received in this step.
130                    recvbuf contains data accumulated so far */
131
132                 if (is_commutative || (dst < rank)) {
133                     /* op is commutative OR the order is already right */
134                      smpi_op_apply(op, tmp_buf, recvbuf, &count, &datatype);
135                 } else {
136                     /* op is noncommutative and the order is not right */
137                     smpi_op_apply(op, recvbuf, tmp_buf, &count, &datatype);
138                     /* copy result back into recvbuf */
139                     mpi_errno = smpi_datatype_copy(tmp_buf, count, datatype,
140                                                recvbuf, count, datatype);
141                 }
142                 mask <<= 1;
143             }
144         } else {
145
146             /* do a reduce-scatter followed by allgather */
147
148             /* for the reduce-scatter, calculate the count that
149                each process receives and the displacement within
150                the buffer */
151             cnts = (int *)xbt_malloc(pof2 * sizeof (int));
152             disps = (int *)xbt_malloc(pof2 * sizeof (int));
153
154             for (i = 0; i < (pof2 - 1); i++) {
155                 cnts[i] = count / pof2;
156             }
157             cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
158
159             disps[0] = 0;
160             for (i = 1; i < pof2; i++) {
161                 disps[i] = disps[i - 1] + cnts[i - 1];
162             }
163
164             mask = 0x1;
165             send_idx = recv_idx = 0;
166             last_idx = pof2;
167             while (mask < pof2) {
168                 newdst = newrank ^ mask;
169                 /* find real rank of dest */
170                 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
171
172                 send_cnt = recv_cnt = 0;
173                 if (newrank < newdst) {
174                     send_idx = recv_idx + pof2 / (mask * 2);
175                     for (i = send_idx; i < last_idx; i++)
176                         send_cnt += cnts[i];
177                     for (i = recv_idx; i < send_idx; i++)
178                         recv_cnt += cnts[i];
179                 } else {
180                     recv_idx = send_idx + pof2 / (mask * 2);
181                     for (i = send_idx; i < recv_idx; i++)
182                         send_cnt += cnts[i];
183                     for (i = recv_idx; i < last_idx; i++)
184                         recv_cnt += cnts[i];
185                 }
186
187                 /* Send data from recvbuf. Recv into tmp_buf */
188                 smpi_mpi_sendrecv((char *) recvbuf +
189                                              disps[send_idx] * extent,
190                                              send_cnt, datatype,
191                                              dst, COLL_TAG_ALLREDUCE,
192                                              (char *) tmp_buf +
193                                              disps[recv_idx] * extent,
194                                              recv_cnt, datatype, dst,
195                                              COLL_TAG_ALLREDUCE, comm,
196                                              MPI_STATUS_IGNORE);
197
198                 /* tmp_buf contains data received in this step.
199                    recvbuf contains data accumulated so far */
200
201                 /* This algorithm is used only for predefined ops
202                    and predefined ops are always commutative. */
203
204                 smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
205                         (char *) recvbuf + disps[recv_idx] * extent,
206                         &recv_cnt, &datatype);
207
208                 /* update send_idx for next iteration */
209                 send_idx = recv_idx;
210                 mask <<= 1;
211
212                 /* update last_idx, but not in last iteration
213                    because the value is needed in the allgather
214                    step below. */
215                 if (mask < pof2)
216                     last_idx = recv_idx + pof2 / mask;
217             }
218
219             /* now do the allgather */
220
221             mask >>= 1;
222             while (mask > 0) {
223                 newdst = newrank ^ mask;
224                 /* find real rank of dest */
225                 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
226
227                 send_cnt = recv_cnt = 0;
228                 if (newrank < newdst) {
229                     /* update last_idx except on first iteration */
230                     if (mask != pof2 / 2) {
231                         last_idx = last_idx + pof2 / (mask * 2);
232                     }
233
234                     recv_idx = send_idx + pof2 / (mask * 2);
235                     for (i = send_idx; i < recv_idx; i++) {
236                         send_cnt += cnts[i];
237                     }
238                     for (i = recv_idx; i < last_idx; i++) {
239                         recv_cnt += cnts[i];
240                     }
241                 } else {
242                     recv_idx = send_idx - pof2 / (mask * 2);
243                     for (i = send_idx; i < last_idx; i++) {
244                         send_cnt += cnts[i];
245                     }
246                     for (i = recv_idx; i < send_idx; i++) {
247                         recv_cnt += cnts[i];
248                     }
249                 }
250
251                smpi_mpi_sendrecv((char *) recvbuf +
252                                              disps[send_idx] * extent,
253                                              send_cnt, datatype,
254                                              dst, COLL_TAG_ALLREDUCE,
255                                              (char *) recvbuf +
256                                              disps[recv_idx] * extent,
257                                              recv_cnt, datatype, dst,
258                                              COLL_TAG_ALLREDUCE, comm,
259                                              MPI_STATUS_IGNORE);
260                 if (newrank > newdst) {
261                     send_idx = recv_idx;
262                 }
263
264                 mask >>= 1;
265             }
266             xbt_free(disps);
267             xbt_free(cnts);
268         }
269     }
270
271     /* In the non-power-of-two case, all odd-numbered
272        processes of rank < 2*rem send the result to
273        (rank-1), the ranks who didn't participate above. */
274     if (rank < 2 * rem) {
275         if (rank % 2) {     /* odd */
276             smpi_mpi_send(recvbuf, count,
277                                      datatype, rank - 1,
278                                      COLL_TAG_ALLREDUCE, comm);
279         } else {            /* even */
280             smpi_mpi_recv(recvbuf, count,
281                                   datatype, rank + 1,
282                                   COLL_TAG_ALLREDUCE, comm,
283                                   MPI_STATUS_IGNORE);
284         }
285     }
286     smpi_free_tmp_buffer(tmp_buf_free);
287     return (mpi_errno);
288
289 }