1 #include "../colls_private.h"
/* Binomial-tree scatter phase of the scatter+allgather broadcast.
 * The root's packed buffer (tmp_buf, nbytes bytes) is divided into
 * comm_size pieces of scatter_size bytes (ceiling division); after this
 * phase each process holds its own piece at byte offset
 * relative_rank*scatter_size inside tmp_buf.
 * NOTE(review): this listing is elided -- the parameter list (orig. lines
 * 6-11), the mask initialization (presumably mask = 0x1 -- TODO confirm),
 * src/dst computations and several closing braces are not visible here.
 * Comments below describe only what the visible lines establish. */
5 static int scatter_for_bcast(
12 int rank, comm_size, src, dst;
13 int relative_rank, mask;
14 int mpi_errno = MPI_SUCCESS;
15 int scatter_size, curr_size, recv_size = 0, send_size;
17 comm_size = comm->size();
/* Renumber ranks so the broadcast root becomes rank 0 of the virtual tree. */
19 relative_rank = (rank >= root) ? rank - root : rank - root + comm_size;
21 /* use long message algorithm: binomial tree scatter followed by an allgather */
23 /* The scatter algorithm divides the buffer into nprocs pieces and
24 scatters them among the processes. Root gets the first piece,
25 root+1 gets the second piece, and so forth. Uses the same binomial
26 tree algorithm as above. Ceiling division
27 is used to compute the size of each piece. This means some
28 processes may not get any data. For example if bufsize = 97 and
29 nprocs = 16, ranks 15 and 16 will get 0 data. On each process, the
30 scattered data is stored at the same offset in the buffer as it is
31 on the root process. */
33 scatter_size = (nbytes + comm_size - 1)/comm_size; /* ceiling division */
/* Only the root has data initially; everyone else starts empty. */
34 curr_size = (rank == root) ? nbytes : 0; /* root starts with all the
/* Walk up the binomial tree: receive once from the parent, then fan data
   out to children in the (elided) sending half of the loop below. */
38 while (mask < comm_size)
/* A set bit at 'mask' means this process receives from its parent in
   this round (src computation is elided from this listing). */
40 if (relative_rank & mask)
43 if (src < 0) src += comm_size;
/* Upper bound on what the parent may send for our subtree; may exceed
   the actual amount for the trailing ranks under ceiling division. */
44 recv_size = nbytes - relative_rank*scatter_size;
45 /* recv_size is larger than what might actually be sent by the
46 sender. We don't need compute the exact value because MPI
47 allows you to post a larger recv.*/
/* Elided branch: when recv_size <= 0 this rank owns no bytes at all. */
50 curr_size = 0; /* this process doesn't receive any data
51 because of uneven division */
/* Receive our subtree's chunk at its final offset in tmp_buf. */
55 Request::recv(((char *)tmp_buf +
56 relative_rank*scatter_size),
57 recv_size, MPI_BYTE, src,
58 COLL_TAG_BCAST, comm, &status);
59 /* query actual size of data received */
60 curr_size=Status::get_count(&status, MPI_BYTE);
67 /* This process is responsible for all processes that have bits
68 set from the LSB upto (but not including) mask. Because of
69 the "not including", we start by shifting mask back down
/* Sending half: forward the upper part of our chunk to the child at
   relative_rank+mask, if that child exists. */
75 if (relative_rank + mask < comm_size)
/* Keep the first scatter_size*mask bytes for our own subtree; the rest
   belongs to the child's subtree. */
77 send_size = curr_size - scatter_size * mask;
78 /* mask is also the size of this process's subtree */
/* Map the child's relative rank back to a real rank (wrap around root). */
83 if (dst >= comm_size) dst -= comm_size;
84 Request::send(((char *)tmp_buf +
85 scatter_size*(relative_rank+mask)),
86 send_size, MPI_BYTE, dst,
87 COLL_TAG_BCAST, comm);
88 curr_size -= send_size;
/* Broadcast for long messages: binomial-tree scatter (scatter_for_bcast)
 * followed by a recursive-doubling allgather over the scattered pieces.
 * Non-contiguous or heterogeneous data is packed into tmp_buf first and
 * unpacked at the end.
 * NOTE(review): this listing is elided -- the parameter list (orig. lines
 * 100-106), the is_contig/is_homogeneous initialization, the mask/i loop
 * updates, several braces and the fn_exit label are not visible. Comments
 * describe only what the visible lines establish. */
99 Coll_bcast_scatter_rdb_allgather::bcast (
102 MPI_Datatype datatype,
107 int rank, comm_size, dst;
108 int relative_rank, mask;
109 int mpi_errno = MPI_SUCCESS;
110 int scatter_size, curr_size, recv_size = 0;
111 int j, k, i, tmp_mask, is_contig, is_homogeneous;
112 MPI_Aint type_size = 0, nbytes = 0;
113 int relative_dst, dst_tree_root, my_tree_root, send_offset;
114 int recv_offset, tree_root, nprocs_completed, offset;
116 MPI_Aint true_extent, true_lb;
119 comm_size = comm->size();
/* Renumber ranks so the root becomes rank 0 of the virtual topology. */
121 relative_rank = (rank >= root) ? rank - root : rank - root + comm_size;
123 /* If there is only one process, return */
124 if (comm_size == 1) goto fn_exit;
126 //if (HANDLE_GET_KIND(datatype) == HANDLE_KIND_BUILTIN)
/* Contiguity decided from the datatype flags (MPICH's handle-kind test
   above was replaced by SimGrid's flag check). */
127 if(datatype->flags() & DT_FLAG_CONTIGUOUS)
135 /* MPI_Type_size() might not give the accurate size of the packed
136 * datatype for heterogeneous systems (because of padding, encoding,
137 * etc). On the other hand, MPI_Pack_size() can become very
138 * expensive, depending on the implementation, especially for
139 * heterogeneous systems. We want to use MPI_Type_size() wherever
140 * possible, and MPI_Pack_size() in other places.
143 type_size=datatype->size();
145 nbytes = type_size * count;
/* Elided guard: presumably taken when nbytes == 0 -- TODO confirm. */
147 goto fn_exit; /* nothing to do */
149 if (is_contig && is_homogeneous)
151 /* contiguous and homogeneous. no need to pack. */
152 datatype->extent(&true_lb, &true_extent);
/* Operate in place: tmp_buf aliases the user buffer (no copy, no free). */
154 tmp_buf = (char *) buffer + true_lb;
/* Non-contiguous/heterogeneous path: pack into a scratch buffer.
   NOTE(review): the matching xbt_free appears only commented out at the
   end of this function -- possible leak on this path, verify elsewhere. */
158 tmp_buf=(void*)xbt_malloc(nbytes);
160 /* TODO: Pipeline the packing and communication */
163 mpi_errno = datatype->pack(buffer, count, tmp_buf, nbytes,
165 if (mpi_errno) xbt_die("crash while packing %d", mpi_errno);
170 scatter_size = (nbytes + comm_size - 1)/comm_size; /* ceiling division */
/* Phase 1: binomial-tree scatter of tmp_buf (see scatter_for_bcast). */
172 mpi_errno = scatter_for_bcast(root, comm,
175 xbt_die("crash while scattering %d", mpi_errno);
178 /* curr_size is the amount of data that this process now has stored in
179 * buffer at byte offset (relative_rank*scatter_size) */
/* min(scatter_size, remaining bytes): trailing ranks may own a short or
   empty piece because of the ceiling division. */
180 curr_size = scatter_size < (nbytes - (relative_rank * scatter_size)) ? scatter_size : (nbytes - (relative_rank * scatter_size)) ;
184 /* medium size allgather and pof2 comm_size. use recurive doubling. */
/* Phase 2: recursive-doubling allgather; each round doubles the span of
   contiguous data every process holds (mask/i updates are elided). */
188 while (mask < comm_size)
190 relative_dst = relative_rank ^ mask;
192 dst = (relative_dst + root) % comm_size;
194 /* find offset into send and recv buffers.
195 zero out the least significant "i" bits of relative_rank and
196 relative_dst to find root of src and dst
197 subtrees. Use ranks of roots as index to send from
198 and recv into buffer */
200 dst_tree_root = relative_dst >> i;
203 my_tree_root = relative_rank >> i;
206 send_offset = my_tree_root * scatter_size;
207 recv_offset = dst_tree_root * scatter_size;
209 if (relative_dst < comm_size)
/* Exchange blocks with the partner; posted recv size is clamped to the
   bytes that actually exist past recv_offset (0 if none). */
211 Request::sendrecv(((char *)tmp_buf + send_offset),
212 curr_size, MPI_BYTE, dst, COLL_TAG_BCAST,
213 ((char *)tmp_buf + recv_offset),
214 (nbytes-recv_offset < 0 ? 0 : nbytes-recv_offset),
215 MPI_BYTE, dst, COLL_TAG_BCAST, comm, &status);
216 recv_size=Status::get_count(&status, MPI_BYTE);
217 curr_size += recv_size;
220 /* if some processes in this process's subtree in this step
221 did not have any destination process to communicate with
222 because of non-power-of-two, we need to send them the
223 data that they would normally have received from those
224 processes. That is, the haves in this subtree must send to
225 the havenots. We use a logarithmic recursive-halfing algorithm
228 /* This part of the code will not currently be
229 executed because we are not using recursive
230 doubling for non power of two. Mark it as experimental
231 so that it doesn't show up as red in the coverage tests. */
233 /* --BEGIN EXPERIMENTAL-- */
/* Partner subtree extends past comm_size: fix up the have-nots. */
234 if (dst_tree_root + mask > comm_size)
236 nprocs_completed = comm_size - my_tree_root - mask;
237 /* nprocs_completed is the number of processes in this
238 subtree that have all the data. Send data to others
239 in a tree fashion. First find root of current tree
240 that is being divided into two. k is the number of
241 least-significant bits in this process's rank that
242 must be zeroed out to find the rank of the root */
252 offset = scatter_size * (my_tree_root + mask);
253 tmp_mask = mask >> 1;
257 relative_dst = relative_rank ^ tmp_mask;
258 dst = (relative_dst + root) % comm_size;
260 tree_root = relative_rank >> k;
263 /* send only if this proc has data and destination
264 doesn't have data. */
/* NOTE(review): leftover debug printf -- unconditional stdout output in a
   collective; should be removed or put behind a debug/log macro. */
267 printf("rank %d, dst %d, root %d, nprocs_completed %d\n", relative_rank, relative_dst, tree_root, nprocs_completed);
271 if ((relative_dst > relative_rank) &&
272 (relative_rank < tree_root + nprocs_completed)
273 && (relative_dst >= tree_root + nprocs_completed))
276 /* printf("Rank %d, send to %d, offset %d, size %d\n", rank, dst, offset, recv_size);
278 Request::send(((char *)tmp_buf + offset),
279 recv_size, MPI_BYTE, dst,
280 COLL_TAG_BCAST, comm);
281 /* recv_size was set in the previous
282 receive. that's the amount of data to be
285 /* recv only if this proc. doesn't have data and sender
287 else if ((relative_dst < relative_rank) &&
288 (relative_dst < tree_root + nprocs_completed) &&
289 (relative_rank >= tree_root + nprocs_completed))
291 /* printf("Rank %d waiting to recv from rank %d\n",
292 relative_rank, dst); */
293 Request::recv(((char *)tmp_buf + offset),
295 MPI_BYTE, dst, COLL_TAG_BCAST,
297 /* nprocs_completed is also equal to the no. of processes
298 whose data we don't have */
299 recv_size=Status::get_count(&status, MPI_BYTE);
300 curr_size += recv_size;
301 /* printf("Rank %d, recv from %d, offset %d, size %d\n", rank, dst, offset, recv_size);
308 /* --END EXPERIMENTAL-- */
314 /* check that we received as much as we expected */
315 /* recvd_size may not be accurate for packed heterogeneous data */
/* Sanity check: after the allgather every rank must hold all nbytes. */
316 if (is_homogeneous && curr_size != nbytes) {
317 xbt_die("we didn't receive enough !");
/* Unpack the scratch buffer back into the user buffer when we packed. */
320 if (!is_contig || !is_homogeneous)
325 mpi_errno = MPI_Unpack(tmp_buf, nbytes, &position, buffer,
326 count, datatype, comm);
327 if (mpi_errno) xbt_die("error when unpacking %d", mpi_errno);
/* NOTE(review): the free of the packed tmp_buf is commented out here; if
   it is not released elsewhere, the !is_contig path leaks nbytes. */
332 /* xbt_free(tmp_buf);*/