Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
oops
[simgrid.git] / src / smpi / colls / allreduce-rab-rdb.c
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "colls_private.h"
8
9 int smpi_coll_tuned_allreduce_rab_rdb(void *sbuff, void *rbuff, int count,
10                                       MPI_Datatype dtype, MPI_Op op,
11                                       MPI_Comm comm)
12 {
13   int nprocs, rank, tag = COLL_TAG_ALLREDUCE;
14   int mask, dst, pof2, newrank, rem, newdst, i,
15       send_idx, recv_idx, last_idx, send_cnt, recv_cnt, *cnts, *disps;
16   MPI_Aint extent;
17   MPI_Status status;
18   void *tmp_buf = NULL;
19
20   nprocs = smpi_comm_size(comm);
21   rank = smpi_comm_rank(comm);
22
23   extent = smpi_datatype_get_extent(dtype);
24   tmp_buf = (void *) smpi_get_tmp_sendbuffer(count * extent);
25
26   smpi_datatype_copy(sbuff, count, dtype, rbuff, count, dtype);
27
28   // find nearest power-of-two less than or equal to comm_size
29   pof2 = 1;
30   while (pof2 <= nprocs)
31     pof2 <<= 1;
32   pof2 >>= 1;
33
34   rem = nprocs - pof2;
35
36   // In the non-power-of-two case, all even-numbered
37   // processes of rank < 2*rem send their data to
38   // (rank+1). These even-numbered processes no longer
39   // participate in the algorithm until the very end. The
40   // remaining processes form a nice power-of-two. 
41
42   if (rank < 2 * rem) {
43     // even       
44     if (rank % 2 == 0) {
45
46       smpi_mpi_send(rbuff, count, dtype, rank + 1, tag, comm);
47
48       // temporarily set the rank to -1 so that this
49       // process does not pariticipate in recursive
50       // doubling
51       newrank = -1;
52     } else                      // odd
53     {
54       smpi_mpi_recv(tmp_buf, count, dtype, rank - 1, tag, comm, &status);
55       // do the reduction on received data. since the
56       // ordering is right, it doesn't matter whether
57       // the operation is commutative or not.
58        smpi_op_apply(op, tmp_buf, rbuff, &count, &dtype);
59
60       // change the rank 
61       newrank = rank / 2;
62     }
63   }
64
65   else                          // rank >= 2 * rem 
66     newrank = rank - rem;
67
68   // If op is user-defined or count is less than pof2, use
69   // recursive doubling algorithm. Otherwise do a reduce-scatter
70   // followed by allgather. (If op is user-defined,
71   // derived datatypes are allowed and the user could pass basic
72   // datatypes on one process and derived on another as long as
73   // the type maps are the same. Breaking up derived
74   // datatypes to do the reduce-scatter is tricky, therefore
75   // using recursive doubling in that case.) 
76
77   if (newrank != -1) {
78     // do a reduce-scatter followed by allgather. for the
79     // reduce-scatter, calculate the count that each process receives
80     // and the displacement within the buffer 
81
82     cnts = (int *) xbt_malloc(pof2 * sizeof(int));
83     disps = (int *) xbt_malloc(pof2 * sizeof(int));
84
85     for (i = 0; i < (pof2 - 1); i++)
86       cnts[i] = count / pof2;
87     cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
88
89     disps[0] = 0;
90     for (i = 1; i < pof2; i++)
91       disps[i] = disps[i - 1] + cnts[i - 1];
92
93     mask = 0x1;
94     send_idx = recv_idx = 0;
95     last_idx = pof2;
96     while (mask < pof2) {
97       newdst = newrank ^ mask;
98       // find real rank of dest 
99       dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
100
101       send_cnt = recv_cnt = 0;
102       if (newrank < newdst) {
103         send_idx = recv_idx + pof2 / (mask * 2);
104         for (i = send_idx; i < last_idx; i++)
105           send_cnt += cnts[i];
106         for (i = recv_idx; i < send_idx; i++)
107           recv_cnt += cnts[i];
108       } else {
109         recv_idx = send_idx + pof2 / (mask * 2);
110         for (i = send_idx; i < recv_idx; i++)
111           send_cnt += cnts[i];
112         for (i = recv_idx; i < last_idx; i++)
113           recv_cnt += cnts[i];
114       }
115
116       // Send data from recvbuf. Recv into tmp_buf 
117       smpi_mpi_sendrecv((char *) rbuff + disps[send_idx] * extent, send_cnt,
118                    dtype, dst, tag,
119                    (char *) tmp_buf + disps[recv_idx] * extent, recv_cnt,
120                    dtype, dst, tag, comm, &status);
121
122       // tmp_buf contains data received in this step.
123       // recvbuf contains data accumulated so far 
124
125       // This algorithm is used only for predefined ops
126       // and predefined ops are always commutative.
127       smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
128                         (char *) rbuff + disps[recv_idx] * extent, &recv_cnt, &dtype);
129
130       // update send_idx for next iteration 
131       send_idx = recv_idx;
132       mask <<= 1;
133
134       // update last_idx, but not in last iteration because the value
135       // is needed in the allgather step below. 
136       if (mask < pof2)
137         last_idx = recv_idx + pof2 / mask;
138     }
139
140     // now do the allgather 
141
142     mask >>= 1;
143     while (mask > 0) {
144       newdst = newrank ^ mask;
145       // find real rank of dest
146       dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
147
148       send_cnt = recv_cnt = 0;
149       if (newrank < newdst) {
150         // update last_idx except on first iteration 
151         if (mask != pof2 / 2)
152           last_idx = last_idx + pof2 / (mask * 2);
153
154         recv_idx = send_idx + pof2 / (mask * 2);
155         for (i = send_idx; i < recv_idx; i++)
156           send_cnt += cnts[i];
157         for (i = recv_idx; i < last_idx; i++)
158           recv_cnt += cnts[i];
159       } else {
160         recv_idx = send_idx - pof2 / (mask * 2);
161         for (i = send_idx; i < last_idx; i++)
162           send_cnt += cnts[i];
163         for (i = recv_idx; i < send_idx; i++)
164           recv_cnt += cnts[i];
165       }
166
167       smpi_mpi_sendrecv((char *) rbuff + disps[send_idx] * extent, send_cnt,
168                    dtype, dst, tag,
169                    (char *) rbuff + disps[recv_idx] * extent, recv_cnt,
170                    dtype, dst, tag, comm, &status);
171
172       if (newrank > newdst)
173         send_idx = recv_idx;
174
175       mask >>= 1;
176     }
177
178     free(cnts);
179     free(disps);
180
181   }
182   // In the non-power-of-two case, all odd-numbered processes of
183   // rank < 2 * rem send the result to (rank-1), the ranks who didn't
184   // participate above.
185
186   if (rank < 2 * rem) {
187     if (rank % 2)               // odd 
188       smpi_mpi_send(rbuff, count, dtype, rank - 1, tag, comm);
189     else                        // even 
190       smpi_mpi_recv(rbuff, count, dtype, rank + 1, tag, comm, &status);
191   }
192
193   smpi_free_tmp_buffer(tmp_buf);
194   return MPI_SUCCESS;
195 }