1 /* Copyright (c) 2013-2014. The SimGrid Team.
2 * All rights reserved. */
4 /* This program is free software; you can redistribute it and/or modify it
5 * under the terms of the license (GNU LGPL) which comes with this package. */
7 #include "colls_private.h"
14 int smpi_coll_tuned_reduce_scatter_gather(void *sendbuf, void *recvbuf,
15 int count, MPI_Datatype datatype,
16 MPI_Op op, int root, MPI_Comm comm)
19 int comm_size, rank, pof2, rem, newrank;
20 int mask, *cnts, *disps, i, j, send_idx = 0;
21 int recv_idx, last_idx = 0, newdst;
22 int dst, send_cnt, recv_cnt, newroot, newdst_tree_root;
23 int newroot_tree_root, new_count;
24 int tag = COLL_TAG_REDUCE,temporary_buffer=0;
25 void *send_ptr, *recv_ptr, *tmp_buf;
34 rank = smpi_comm_rank(comm);
35 comm_size = smpi_comm_size(comm);
39 extent = smpi_datatype_get_extent(datatype);
40 /* If I'm not the root, then my recvbuf may not be valid, therefore
41 I have to allocate a temporary one */
42 if (rank != root && !recvbuf) {
44 recvbuf = (void *)smpi_get_tmp_recvbuffer(count * extent);
46 /* find nearest power-of-two less than or equal to comm_size */
48 while (pof2 <= comm_size)
52 if (count < comm_size) {
53 new_count = comm_size;
54 send_ptr = (void *) smpi_get_tmp_sendbuffer(new_count * extent);
55 recv_ptr = (void *) smpi_get_tmp_recvbuffer(new_count * extent);
56 tmp_buf = (void *) smpi_get_tmp_sendbuffer(new_count * extent);
57 memcpy(send_ptr, sendbuf, extent * count);
60 smpi_mpi_sendrecv(send_ptr, new_count, datatype, rank, tag,
61 recv_ptr, new_count, datatype, rank, tag, comm, &status);
63 rem = comm_size - pof2;
67 smpi_mpi_send(recv_ptr, new_count, datatype, rank - 1, tag, comm);
70 smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
71 smpi_op_apply(op, tmp_buf, recv_ptr, &new_count, &datatype);
74 } else /* rank >= 2*rem */
77 cnts = (int *) xbt_malloc(pof2 * sizeof(int));
78 disps = (int *) xbt_malloc(pof2 * sizeof(int));
81 for (i = 0; i < (pof2 - 1); i++)
82 cnts[i] = new_count / pof2;
83 cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
86 for (i = 1; i < pof2; i++)
87 disps[i] = disps[i - 1] + cnts[i - 1];
90 send_idx = recv_idx = 0;
93 newdst = newrank ^ mask;
94 /* find real rank of dest */
95 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
97 send_cnt = recv_cnt = 0;
98 if (newrank < newdst) {
99 send_idx = recv_idx + pof2 / (mask * 2);
100 for (i = send_idx; i < last_idx; i++)
102 for (i = recv_idx; i < send_idx; i++)
105 recv_idx = send_idx + pof2 / (mask * 2);
106 for (i = send_idx; i < recv_idx; i++)
108 for (i = recv_idx; i < last_idx; i++)
112 /* Send data from recvbuf. Recv into tmp_buf */
113 smpi_mpi_sendrecv((char *) recv_ptr +
114 disps[send_idx] * extent,
118 disps[recv_idx] * extent,
119 recv_cnt, datatype, dst, tag, comm, &status);
121 /* tmp_buf contains data received in this step.
122 recvbuf contains data accumulated so far */
124 smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
125 (char *) recv_ptr + disps[recv_idx] * extent,
126 &recv_cnt, &datatype);
128 /* update send_idx for next iteration */
133 last_idx = recv_idx + pof2 / mask;
137 /* now do the gather to root */
139 if (root < 2 * rem) {
143 for (i = 0; i < (pof2 - 1); i++)
144 cnts[i] = new_count / pof2;
145 cnts[pof2 - 1] = new_count - (new_count / pof2) * (pof2 - 1);
148 for (i = 1; i < pof2; i++)
149 disps[i] = disps[i - 1] + cnts[i - 1];
151 smpi_mpi_recv(recv_ptr, cnts[0], datatype, 0, tag, comm, &status);
156 } else if (newrank == 0) {
157 smpi_mpi_send(recv_ptr, cnts[0], datatype, root, tag, comm);
164 newroot = root - rem;
169 while (mask < pof2) {
176 newdst = newrank ^ mask;
178 /* find real rank of dest */
179 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
181 if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
183 newdst_tree_root = newdst >> j;
184 newdst_tree_root <<= j;
186 newroot_tree_root = newroot >> j;
187 newroot_tree_root <<= j;
189 send_cnt = recv_cnt = 0;
190 if (newrank < newdst) {
191 /* update last_idx except on first iteration */
192 if (mask != pof2 / 2)
193 last_idx = last_idx + pof2 / (mask * 2);
195 recv_idx = send_idx + pof2 / (mask * 2);
196 for (i = send_idx; i < recv_idx; i++)
198 for (i = recv_idx; i < last_idx; i++)
201 recv_idx = send_idx - pof2 / (mask * 2);
202 for (i = send_idx; i < last_idx; i++)
204 for (i = recv_idx; i < send_idx; i++)
208 if (newdst_tree_root == newroot_tree_root) {
209 smpi_mpi_send((char *) recv_ptr +
210 disps[send_idx] * extent,
211 send_cnt, datatype, dst, tag, comm);
214 smpi_mpi_recv((char *) recv_ptr +
215 disps[recv_idx] * extent,
216 recv_cnt, datatype, dst, tag, comm, &status);
219 if (newrank > newdst)
226 memcpy(recvbuf, recv_ptr, extent * count);
227 smpi_free_tmp_buffer(send_ptr);
228 smpi_free_tmp_buffer(recv_ptr);
232 else /* (count >= comm_size) */ {
233 tmp_buf = (void *) smpi_get_tmp_sendbuffer(count * extent);
235 //if ((rank != root))
236 smpi_mpi_sendrecv(sendbuf, count, datatype, rank, tag,
237 recvbuf, count, datatype, rank, tag, comm, &status);
239 rem = comm_size - pof2;
240 if (rank < 2 * rem) {
241 if (rank % 2 != 0) { /* odd */
242 smpi_mpi_send(recvbuf, count, datatype, rank - 1, tag, comm);
247 smpi_mpi_recv(tmp_buf, count, datatype, rank + 1, tag, comm, &status);
248 smpi_op_apply(op, tmp_buf, recvbuf, &count, &datatype);
251 } else /* rank >= 2*rem */
252 newrank = rank - rem;
254 cnts = (int *) xbt_malloc(pof2 * sizeof(int));
255 disps = (int *) xbt_malloc(pof2 * sizeof(int));
258 for (i = 0; i < (pof2 - 1); i++)
259 cnts[i] = count / pof2;
260 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
263 for (i = 1; i < pof2; i++)
264 disps[i] = disps[i - 1] + cnts[i - 1];
267 send_idx = recv_idx = 0;
269 while (mask < pof2) {
270 newdst = newrank ^ mask;
271 /* find real rank of dest */
272 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
274 send_cnt = recv_cnt = 0;
275 if (newrank < newdst) {
276 send_idx = recv_idx + pof2 / (mask * 2);
277 for (i = send_idx; i < last_idx; i++)
279 for (i = recv_idx; i < send_idx; i++)
282 recv_idx = send_idx + pof2 / (mask * 2);
283 for (i = send_idx; i < recv_idx; i++)
285 for (i = recv_idx; i < last_idx; i++)
289 /* Send data from recvbuf. Recv into tmp_buf */
290 smpi_mpi_sendrecv((char *) recvbuf +
291 disps[send_idx] * extent,
295 disps[recv_idx] * extent,
296 recv_cnt, datatype, dst, tag, comm, &status);
298 /* tmp_buf contains data received in this step.
299 recvbuf contains data accumulated so far */
301 smpi_op_apply(op, (char *) tmp_buf + disps[recv_idx] * extent,
302 (char *) recvbuf + disps[recv_idx] * extent,
303 &recv_cnt, &datatype);
305 /* update send_idx for next iteration */
310 last_idx = recv_idx + pof2 / mask;
314 /* now do the gather to root */
316 if (root < 2 * rem) {
318 if (rank == root) { /* recv */
319 for (i = 0; i < (pof2 - 1); i++)
320 cnts[i] = count / pof2;
321 cnts[pof2 - 1] = count - (count / pof2) * (pof2 - 1);
324 for (i = 1; i < pof2; i++)
325 disps[i] = disps[i - 1] + cnts[i - 1];
327 smpi_mpi_recv(recvbuf, cnts[0], datatype, 0, tag, comm, &status);
332 } else if (newrank == 0) {
333 smpi_mpi_send(recvbuf, cnts[0], datatype, root, tag, comm);
340 newroot = root - rem;
345 while (mask < pof2) {
352 newdst = newrank ^ mask;
354 /* find real rank of dest */
355 dst = (newdst < rem) ? newdst * 2 : newdst + rem;
357 if ((newdst == 0) && (root < 2 * rem) && (root % 2 != 0))
359 newdst_tree_root = newdst >> j;
360 newdst_tree_root <<= j;
362 newroot_tree_root = newroot >> j;
363 newroot_tree_root <<= j;
365 send_cnt = recv_cnt = 0;
366 if (newrank < newdst) {
367 /* update last_idx except on first iteration */
368 if (mask != pof2 / 2)
369 last_idx = last_idx + pof2 / (mask * 2);
371 recv_idx = send_idx + pof2 / (mask * 2);
372 for (i = send_idx; i < recv_idx; i++)
374 for (i = recv_idx; i < last_idx; i++)
377 recv_idx = send_idx - pof2 / (mask * 2);
378 for (i = send_idx; i < last_idx; i++)
380 for (i = recv_idx; i < send_idx; i++)
384 if (newdst_tree_root == newroot_tree_root) {
385 smpi_mpi_send((char *) recvbuf +
386 disps[send_idx] * extent,
387 send_cnt, datatype, dst, tag, comm);
390 smpi_mpi_recv((char *) recvbuf +
391 disps[recv_idx] * extent,
392 recv_cnt, datatype, dst, tag, comm, &status);
395 if (newrank > newdst)
404 smpi_free_tmp_buffer(tmp_buf);
405 if(temporary_buffer==1) smpi_free_tmp_buffer(recvbuf);