1 #include "colls_private.h"
3 int smpi_coll_tuned_allreduce_rab_rdb(void *sbuff, void *rbuff, int count,
4 MPI_Datatype dtype, MPI_Op op,
7 int nprocs, rank, type_size, tag = 543;
8 int mask, dst, pof2, newrank, rem, newdst, i,
9 send_idx, recv_idx, last_idx, send_cnt, recv_cnt, *cnts, *disps;
14 #ifdef MPICH2_REDUCTION
15 MPI_User_function *uop = MPIR_Op_table[op % 16 - 1];
17 MPI_User_function *uop;
18 struct MPIR_OP *op_ptr;
19 op_ptr = (MPI_User_function *) MPIR_ToPointer(op);
23 nprocs = smpi_comm_size(comm);
24 rank = smpi_comm_rank(comm);
26 extent = smpi_datatype_get_extent(dtype);
27 tmp_buf = (void *) xbt_malloc(count * extent);
29 MPIR_Localcopy(sbuff, count, dtype, rbuff, count, dtype);
31 type_size = smpi_datatype_size(dtype);
33 // find nearest power-of-two less than or equal to comm_size
35 while (pof2 <= nprocs)
41 // In the non-power-of-two case, all even-numbered
42 // processes of rank < 2*rem send their data to
43 // (rank+1). These even-numbered processes no longer
44 // participate in the algorithm until the very end. The
45 // remaining processes form a nice power-of-two.
51 smpi_mpi_send(rbuff, count, dtype, rank + 1, tag, comm);
53 // temporarily set the rank to -1 so that this
54 // process does not participate in recursive
59 smpi_mpi_recv(tmp_buf, count, dtype, rank - 1, tag, comm, &status);
60 // do the reduction on received data. since the
61 // ordering is right, it doesn't matter whether
62 // the operation is commutative or not.
63 (*uop) (tmp_buf, rbuff, &count, &dtype);
70 else // rank >= 2 * rem
73 // If op is user-defined or count is less than pof2, use
74 // recursive doubling algorithm. Otherwise do a reduce-scatter
75 // followed by allgather. (If op is user-defined,
76 // derived datatypes are allowed and the user could pass basic
77 // datatypes on one process and derived on another as long as
78 // the type maps are the same. Breaking up derived
79 // datatypes to do the reduce-scatter is tricky, therefore
80 // using recursive doubling in that case.)
83 // do a reduce-scatter followed by allgather. for the
84 // reduce-scatter, calculate the count that each process receives
85 // and the displacement within the buffer
87 cnts = (int *) xbt_malloc(pof2 * sizeof(int));
88 disps = (int *) xbt_malloc(pof2 * sizeof(int));
90 for (i = 0; i < (pof2 - 1); i++)
91 cnts[i] = count / pof2;
92 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
95 for (i = 1; i < pof2; i++)
96 disps[i] = disps[i - 1] + cnts[i - 1];
99 send_idx = recv_idx = 0;
101 while (mask < pof2) {
102 newdst = newrank ^ mask;
103 // find real rank of dest
104 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
106 send_cnt = recv_cnt = 0;
107 if (newrank < newdst) {
108 send_idx = recv_idx + pof2 / (mask * 2);
109 for (i = send_idx; i < last_idx; i++)
111 for (i = recv_idx; i < send_idx; i++)
114 recv_idx = send_idx + pof2 / (mask * 2);
115 for (i = send_idx; i < recv_idx; i++)
117 for (i = recv_idx; i < last_idx; i++)
121 // Send data from recvbuf. Recv into tmp_buf
122 smpi_mpi_sendrecv((char *) rbuff + disps[send_idx] * extent, send_cnt,
124 (char *) tmp_buf + disps[recv_idx] * extent, recv_cnt,
125 dtype, dst, tag, comm, &status);
127 // tmp_buf contains data received in this step.
128 // recvbuf contains data accumulated so far
130 // This algorithm is used only for predefined ops
131 // and predefined ops are always commutative.
132 (*uop) ((char *) tmp_buf + disps[recv_idx] * extent,
133 (char *) rbuff + disps[recv_idx] * extent, &recv_cnt, &dtype);
135 // update send_idx for next iteration
139 // update last_idx, but not in last iteration because the value
140 // is needed in the allgather step below.
142 last_idx = recv_idx + pof2 / mask;
145 // now do the allgather
149 newdst = newrank ^ mask;
150 // find real rank of dest
151 dst = (newdst < rem) ? newdst * 2 + 1 : newdst + rem;
153 send_cnt = recv_cnt = 0;
154 if (newrank < newdst) {
155 // update last_idx except on first iteration
156 if (mask != pof2 / 2)
157 last_idx = last_idx + pof2 / (mask * 2);
159 recv_idx = send_idx + pof2 / (mask * 2);
160 for (i = send_idx; i < recv_idx; i++)
162 for (i = recv_idx; i < last_idx; i++)
165 recv_idx = send_idx - pof2 / (mask * 2);
166 for (i = send_idx; i < last_idx; i++)
168 for (i = recv_idx; i < send_idx; i++)
172 smpi_mpi_sendrecv((char *) rbuff + disps[send_idx] * extent, send_cnt,
174 (char *) rbuff + disps[recv_idx] * extent, recv_cnt,
175 dtype, dst, tag, comm, &status);
177 if (newrank > newdst)
187 // In the non-power-of-two case, all odd-numbered processes of
188 // rank < 2 * rem send the result to (rank-1), the ranks who didn't
189 // participate above.
191 if (rank < 2 * rem) {
193 smpi_mpi_send(rbuff, count, dtype, rank - 1, tag, comm);
195 smpi_mpi_recv(rbuff, count, dtype, rank + 1, tag, comm, &status);