1 #include "colls_private.h"
8 int smpi_coll_tuned_reduce_scatter_gather(void *sendbuf, void *recvbuf,
9 int count, MPI_Datatype datatype,
10 MPI_Op op, int root, MPI_Comm comm)
13 int comm_size, rank, pof2, rem, newrank;
14 int mask, *cnts, *disps, i, j, send_idx = 0;
15 int recv_idx, last_idx = 0, newdst;
16 int dst, send_cnt, recv_cnt, newroot, newdst_tree_root;
17 int newroot_tree_root, new_count;
18 int tag = COLL_TAG_REDUCE;
19 void *send_ptr, *recv_ptr, *tmp_buf;
28 rank = smpi_comm_rank(comm);
29 comm_size = smpi_comm_size(comm);
31 extent = smpi_datatype_get_extent(datatype);
33 /* find nearest power-of-two less than or equal to comm_size */
35 while (pof2 <= comm_size)
39 if (count < comm_size) {
40 new_count = comm_size;
41 send_ptr = (void *) xbt_malloc(new_count * extent);
42 recv_ptr = (void *) xbt_malloc(new_count * extent);
43 tmp_buf = (void *) xbt_malloc(new_count * extent);
44 memcpy(send_ptr, sendbuf, extent * count);
47 smpi_mpi_sendrecv(send_ptr, new_count, datatype, rank, tag,
48 recv_ptr, new_count, datatype, rank, tag, comm, &status);
50 rem = comm_size - pof2;
54 smpi_mpi_send(recv_ptr, new_count, datatype, rank - 1, tag, comm);
57 smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
58 smpi_op_apply(op, tmp_buf, recv_ptr, &new_count, &datatype);
61 } else /* rank >= 2*rem */
64 cnts = (int *) xbt_malloc(pof2 * sizeof(int));
65 disps = (int *) xbt_malloc(pof2 * sizeof(int));
68 for (i = 0; i < (pof2 - 1); i++)
69 cnts[i] = new_count / pof2;
70 cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
73 for (i = 1; i < pof2; i++)
74 disps[i] = disps[i - 1] + cnts[i - 1];
77 send_idx = recv_idx = 0;
80 newdst = newrank ^ mask;
81 /* find real rank of dest */
82 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
84 send_cnt = recv_cnt = 0;
85 if (newrank < newdst) {
86 send_idx = recv_idx + pof2 / (mask * 2);
87 for (i = send_idx; i < last_idx; i++)
89 for (i = recv_idx; i < send_idx; i++)
92 recv_idx = send_idx + pof2 / (mask * 2);
93 for (i = send_idx; i < recv_idx; i++)
95 for (i = recv_idx; i < last_idx; i++)
99 /* Send data from recv_ptr. Recv into tmp_buf */
100 smpi_mpi_sendrecv((char *) recv_ptr +
101 disps[send_idx] * extent,
105 disps[recv_idx] * extent,
106 recv_cnt, datatype, dst, tag, comm, &status);
108 /* tmp_buf contains data received in this step.
109 recv_ptr contains data accumulated so far */
111 smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
112 (char *) recv_ptr + disps[recv_idx] * extent,
113 &recv_cnt, &datatype);
115 /* update send_idx for next iteration */
120 last_idx = recv_idx + pof2 / mask;
124 /* now do the gather to root */
126 if (root < 2 * rem) {
130 for (i = 0; i < (pof2 - 1); i++)
131 cnts[i] = new_count / pof2;
132 cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
135 for (i = 1; i < pof2; i++)
136 disps[i] = disps[i - 1] + cnts[i - 1];
138 smpi_mpi_recv(recv_ptr, cnts[0], datatype, 0, tag, comm, &status);
143 } else if (newrank == 0) {
144 smpi_mpi_send(recv_ptr, cnts[0], datatype, root, tag, comm);
151 newroot = root - rem;
156 while (mask < pof2) {
163 newdst = newrank ^ mask;
165 /* find real rank of dest */
166 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
168 if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
170 newdst_tree_root = newdst >> j;
171 newdst_tree_root <<= j;
173 newroot_tree_root = newroot >> j;
174 newroot_tree_root <<= j;
176 send_cnt = recv_cnt = 0;
177 if (newrank < newdst) {
178 /* update last_idx except on first iteration */
179 if (mask != pof2 / 2)
180 last_idx = last_idx + pof2 / (mask * 2);
182 recv_idx = send_idx + pof2 / (mask * 2);
183 for (i = send_idx; i < recv_idx; i++)
185 for (i = recv_idx; i < last_idx; i++)
188 recv_idx = send_idx - pof2 / (mask * 2);
189 for (i = send_idx; i < last_idx; i++)
191 for (i = recv_idx; i < send_idx; i++)
195 if (newdst_tree_root == newroot_tree_root) {
196 smpi_mpi_send((char *) recv_ptr +
197 disps[send_idx] * extent,
198 send_cnt, datatype, dst, tag, comm);
201 smpi_mpi_recv((char *) recv_ptr +
202 disps[recv_idx] * extent,
203 recv_cnt, datatype, dst, tag, comm, &status);
206 if (newrank > newdst)
213 memcpy(recvbuf, recv_ptr, extent * count);
219 else /* (count >= comm_size) */ {
220 tmp_buf = (void *) xbt_malloc(count * extent);
222 //if ((rank != root))
223 smpi_mpi_sendrecv(sendbuf, count, datatype, rank, tag,
224 recvbuf, count, datatype, rank, tag, comm, &status);
226 rem = comm_size - pof2;
227 if (rank < 2 * rem) {
228 if (rank % 2 != 0) { /* odd */
229 smpi_mpi_send(recvbuf, count, datatype, rank - 1, tag, comm);
234 smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
235 smpi_op_apply(op, tmp_buf, recvbuf, &count, &datatype);
238 } else /* rank >= 2*rem */
239 newrank = rank - rem;
241 cnts = (int *) xbt_malloc(pof2 * sizeof(int));
242 disps = (int *) xbt_malloc(pof2 * sizeof(int));
245 for (i = 0; i < (pof2 - 1); i++)
246 cnts[i] = count / pof2;
247 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
250 for (i = 1; i < pof2; i++)
251 disps[i] = disps[i - 1] + cnts[i - 1];
254 send_idx = recv_idx = 0;
256 while (mask < pof2) {
257 newdst = newrank ^ mask;
258 /* find real rank of dest */
259 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
261 send_cnt = recv_cnt = 0;
262 if (newrank < newdst) {
263 send_idx = recv_idx + pof2 / (mask * 2);
264 for (i = send_idx; i < last_idx; i++)
266 for (i = recv_idx; i < send_idx; i++)
269 recv_idx = send_idx + pof2 / (mask * 2);
270 for (i = send_idx; i < recv_idx; i++)
272 for (i = recv_idx; i < last_idx; i++)
276 /* Send data from recvbuf. Recv into tmp_buf */
277 smpi_mpi_sendrecv((char *) recvbuf +
278 disps[send_idx] * extent,
282 disps[recv_idx] * extent,
283 recv_cnt, datatype, dst, tag, comm, &status);
285 /* tmp_buf contains data received in this step.
286 recvbuf contains data accumulated so far */
288 smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
289 (char *) recvbuf + disps[recv_idx] * extent,
290 &recv_cnt, &datatype);
292 /* update send_idx for next iteration */
297 last_idx = recv_idx + pof2 / mask;
301 /* now do the gather to root */
303 if (root < 2 * rem) {
305 if (rank == root) { /* recv */
306 for (i = 0; i < (pof2 - 1); i++)
307 cnts[i] = count / pof2;
308 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
311 for (i = 1; i < pof2; i++)
312 disps[i] = disps[i - 1] + cnts[i - 1];
314 smpi_mpi_recv(recvbuf, cnts[0], datatype, 0, tag, comm, &status);
319 } else if (newrank == 0) {
320 smpi_mpi_send(recvbuf, cnts[0], datatype, root, tag, comm);
327 newroot = root - rem;
332 while (mask < pof2) {
339 newdst = newrank ^ mask;
341 /* find real rank of dest */
342 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
344 if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
346 newdst_tree_root = newdst >> j;
347 newdst_tree_root <<= j;
349 newroot_tree_root = newroot >> j;
350 newroot_tree_root <<= j;
352 send_cnt = recv_cnt = 0;
353 if (newrank < newdst) {
354 /* update last_idx except on first iteration */
355 if (mask != pof2 / 2)
356 last_idx = last_idx + pof2 / (mask * 2);
358 recv_idx = send_idx + pof2 / (mask * 2);
359 for (i = send_idx; i < recv_idx; i++)
361 for (i = recv_idx; i < last_idx; i++)
364 recv_idx = send_idx - pof2 / (mask * 2);
365 for (i = send_idx; i < last_idx; i++)
367 for (i = recv_idx; i < send_idx; i++)
371 if (newdst_tree_root == newroot_tree_root) {
372 smpi_mpi_send((char *) recvbuf +
373 disps[send_idx] * extent,
374 send_cnt, datatype, dst, tag, comm);
377 smpi_mpi_recv((char *) recvbuf +
378 disps[recv_idx] * extent,
379 recv_cnt, datatype, dst, tag, comm, &status);
382 if (newrank > newdst)