Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
[EXAMPLES] Added an example for the HostLoad plugin
[simgrid.git] / src / smpi / colls / bcast-scatter-rdb-allgather.cpp
1 #include "colls_private.h"
2
3 static int scatter_for_bcast(
4     int root,
5     MPI_Comm comm,
6     int nbytes,
7     void *tmp_buf)
8 {
9     MPI_Status status;
10     int        rank, comm_size, src, dst;
11     int        relative_rank, mask;
12     int mpi_errno = MPI_SUCCESS;
13     int scatter_size, curr_size, recv_size = 0, send_size;
14
15     comm_size = comm->size();
16     rank = comm->rank();
17     relative_rank = (rank >= root) ? rank - root : rank - root + comm_size;
18
19     /* use long message algorithm: binomial tree scatter followed by an allgather */
20
21     /* The scatter algorithm divides the buffer into nprocs pieces and
22        scatters them among the processes. Root gets the first piece,
23        root+1 gets the second piece, and so forth. Uses the same binomial
24        tree algorithm as above. Ceiling division
25        is used to compute the size of each piece. This means some
26        processes may not get any data. For example if bufsize = 97 and
27        nprocs = 16, ranks 15 and 16 will get 0 data. On each process, the
28        scattered data is stored at the same offset in the buffer as it is
29        on the root process. */ 
30
31     scatter_size = (nbytes + comm_size - 1)/comm_size; /* ceiling division */
32     curr_size = (rank == root) ? nbytes : 0; /* root starts with all the
33                                                 data */
34
35     mask = 0x1;
36     while (mask < comm_size)
37     {
38         if (relative_rank & mask)
39         {
40             src = rank - mask; 
41             if (src < 0) src += comm_size;
42             recv_size = nbytes - relative_rank*scatter_size;
43             /* recv_size is larger than what might actually be sent by the
44                sender. We don't need compute the exact value because MPI
45                allows you to post a larger recv.*/ 
46             if (recv_size <= 0)
47             {
48                 curr_size = 0; /* this process doesn't receive any data
49                                   because of uneven division */
50             }
51             else
52             {
53                 Request::recv(((char *)tmp_buf +
54                                           relative_rank*scatter_size),
55                                          recv_size, MPI_BYTE, src,
56                                          COLL_TAG_BCAST, comm, &status);
57                 /* query actual size of data received */
58                 curr_size=smpi_mpi_get_count(&status, MPI_BYTE);
59             }
60             break;
61         }
62         mask <<= 1;
63     }
64
65     /* This process is responsible for all processes that have bits
66        set from the LSB upto (but not including) mask.  Because of
67        the "not including", we start by shifting mask back down
68        one. */
69
70     mask >>= 1;
71     while (mask > 0)
72     {
73         if (relative_rank + mask < comm_size)
74         {
75             send_size = curr_size - scatter_size * mask; 
76             /* mask is also the size of this process's subtree */
77
78             if (send_size > 0)
79             {
80                 dst = rank + mask;
81                 if (dst >= comm_size) dst -= comm_size;
82                 Request::send(((char *)tmp_buf +
83                                           scatter_size*(relative_rank+mask)),
84                                          send_size, MPI_BYTE, dst,
85                                          COLL_TAG_BCAST, comm);
86                 curr_size -= send_size;
87             }
88         }
89         mask >>= 1;
90     }
91
92     return mpi_errno;
93 }
94
95 int
96 smpi_coll_tuned_bcast_scatter_rdb_allgather (
97     void *buffer, 
98     int count, 
99     MPI_Datatype datatype, 
100     int root, 
101     MPI_Comm comm)
102 {
103     MPI_Status status;
104     int rank, comm_size, dst;
105     int relative_rank, mask;
106     int mpi_errno = MPI_SUCCESS;
107     int scatter_size, curr_size, recv_size = 0;
108     int j, k, i, tmp_mask, is_contig, is_homogeneous;
109     MPI_Aint type_size = 0, nbytes = 0;
110     int relative_dst, dst_tree_root, my_tree_root, send_offset;
111     int recv_offset, tree_root, nprocs_completed, offset;
112     int position;
113     MPI_Aint true_extent, true_lb;
114     void *tmp_buf;
115
116     comm_size = comm->size();
117     rank = comm->rank();
118     relative_rank = (rank >= root) ? rank - root : rank - root + comm_size;
119
120     /* If there is only one process, return */
121     if (comm_size == 1) goto fn_exit;
122
123     //if (HANDLE_GET_KIND(datatype) == HANDLE_KIND_BUILTIN)
124     if(datatype->flags() & DT_FLAG_CONTIGUOUS)
125         is_contig = 1;
126     else {
127         is_contig = 0;
128     }
129
130     is_homogeneous = 1;
131
132     /* MPI_Type_size() might not give the accurate size of the packed
133      * datatype for heterogeneous systems (because of padding, encoding,
134      * etc). On the other hand, MPI_Pack_size() can become very
135      * expensive, depending on the implementation, especially for
136      * heterogeneous systems. We want to use MPI_Type_size() wherever
137      * possible, and MPI_Pack_size() in other places.
138      */
139     if (is_homogeneous)
140         type_size=datatype->size();
141
142     nbytes = type_size * count;
143     if (nbytes == 0)
144         goto fn_exit; /* nothing to do */
145
146     if (is_contig && is_homogeneous)
147     {
148         /* contiguous and homogeneous. no need to pack. */
149         datatype->extent(&true_lb, &true_extent);
150
151         tmp_buf = (char *) buffer + true_lb;
152     }
153     else
154     {
155         tmp_buf=(void*)xbt_malloc(nbytes);
156
157         /* TODO: Pipeline the packing and communication */
158         position = 0;
159         if (rank == root) {
160             mpi_errno = datatype->pack(buffer, count, tmp_buf, nbytes,
161                                        &position, comm);
162             if (mpi_errno) xbt_die("crash while packing %d", mpi_errno);
163         }
164     }
165
166
167     scatter_size = (nbytes + comm_size - 1)/comm_size; /* ceiling division */
168
169     mpi_errno = scatter_for_bcast(root, comm,
170                                   nbytes, tmp_buf);
171     if (mpi_errno) {
172       xbt_die("crash while scattering %d", mpi_errno);
173     }
174
175     /* curr_size is the amount of data that this process now has stored in
176      * buffer at byte offset (relative_rank*scatter_size) */
177     curr_size = scatter_size < (nbytes - (relative_rank * scatter_size)) ? scatter_size :  (nbytes - (relative_rank * scatter_size)) ;
178     if (curr_size < 0)
179         curr_size = 0;
180
181     /* medium size allgather and pof2 comm_size. use recurive doubling. */
182
183     mask = 0x1;
184     i = 0;
185     while (mask < comm_size)
186     {
187         relative_dst = relative_rank ^ mask;
188
189         dst = (relative_dst + root) % comm_size; 
190
191         /* find offset into send and recv buffers.
192            zero out the least significant "i" bits of relative_rank and
193            relative_dst to find root of src and dst
194            subtrees. Use ranks of roots as index to send from
195            and recv into  buffer */ 
196
197         dst_tree_root = relative_dst >> i;
198         dst_tree_root <<= i;
199
200         my_tree_root = relative_rank >> i;
201         my_tree_root <<= i;
202
203         send_offset = my_tree_root * scatter_size;
204         recv_offset = dst_tree_root * scatter_size;
205
206         if (relative_dst < comm_size)
207         {
208             Request::sendrecv(((char *)tmp_buf + send_offset),
209                                          curr_size, MPI_BYTE, dst, COLL_TAG_BCAST, 
210                                          ((char *)tmp_buf + recv_offset),
211                                          (nbytes-recv_offset < 0 ? 0 : nbytes-recv_offset), 
212                                          MPI_BYTE, dst, COLL_TAG_BCAST, comm, &status);
213             recv_size=smpi_mpi_get_count(&status, MPI_BYTE);
214             curr_size += recv_size;
215         }
216
217         /* if some processes in this process's subtree in this step
218            did not have any destination process to communicate with
219            because of non-power-of-two, we need to send them the
220            data that they would normally have received from those
221            processes. That is, the haves in this subtree must send to
222            the havenots. We use a logarithmic recursive-halfing algorithm
223            for this. */
224
225         /* This part of the code will not currently be
226            executed because we are not using recursive
227            doubling for non power of two. Mark it as experimental
228            so that it doesn't show up as red in the coverage tests. */  
229
230         /* --BEGIN EXPERIMENTAL-- */
231         if (dst_tree_root + mask > comm_size)
232         {
233             nprocs_completed = comm_size - my_tree_root - mask;
234             /* nprocs_completed is the number of processes in this
235                subtree that have all the data. Send data to others
236                in a tree fashion. First find root of current tree
237                that is being divided into two. k is the number of
238                least-significant bits in this process's rank that
239                must be zeroed out to find the rank of the root */ 
240             j = mask;
241             k = 0;
242             while (j)
243             {
244                 j >>= 1;
245                 k++;
246             }
247             k--;
248
249             offset = scatter_size * (my_tree_root + mask);
250             tmp_mask = mask >> 1;
251
252             while (tmp_mask)
253             {
254                 relative_dst = relative_rank ^ tmp_mask;
255                 dst = (relative_dst + root) % comm_size; 
256
257                 tree_root = relative_rank >> k;
258                 tree_root <<= k;
259
260                 /* send only if this proc has data and destination
261                    doesn't have data. */
262
263                 /* if (rank == 3) { 
264                    printf("rank %d, dst %d, root %d, nprocs_completed %d\n", relative_rank, relative_dst, tree_root, nprocs_completed);
265                    fflush(stdout);
266                    }*/
267
268                 if ((relative_dst > relative_rank) && 
269                     (relative_rank < tree_root + nprocs_completed)
270                     && (relative_dst >= tree_root + nprocs_completed))
271                 {
272
273                     /* printf("Rank %d, send to %d, offset %d, size %d\n", rank, dst, offset, recv_size);
274                        fflush(stdout); */
275                     Request::send(((char *)tmp_buf + offset),
276                                              recv_size, MPI_BYTE, dst,
277                                              COLL_TAG_BCAST, comm);
278                     /* recv_size was set in the previous
279                        receive. that's the amount of data to be
280                        sent now. */
281                 }
282                 /* recv only if this proc. doesn't have data and sender
283                    has data */
284                 else if ((relative_dst < relative_rank) && 
285                          (relative_dst < tree_root + nprocs_completed) &&
286                          (relative_rank >= tree_root + nprocs_completed))
287                 {
288                     /* printf("Rank %d waiting to recv from rank %d\n",
289                        relative_rank, dst); */
290                     Request::recv(((char *)tmp_buf + offset),
291                                              nbytes - offset, 
292                                              MPI_BYTE, dst, COLL_TAG_BCAST,
293                                              comm, &status);
294                     /* nprocs_completed is also equal to the no. of processes
295                        whose data we don't have */
296                     recv_size=smpi_mpi_get_count(&status, MPI_BYTE);
297                     curr_size += recv_size;
298                     /* printf("Rank %d, recv from %d, offset %d, size %d\n", rank, dst, offset, recv_size);
299                        fflush(stdout);*/
300                 }
301                 tmp_mask >>= 1;
302                 k--;
303             }
304         }
305         /* --END EXPERIMENTAL-- */
306
307         mask <<= 1;
308         i++;
309     }
310
311     /* check that we received as much as we expected */
312     /* recvd_size may not be accurate for packed heterogeneous data */
313     if (is_homogeneous && curr_size != nbytes) {
314       xbt_die("we didn't receive enough !");
315     }
316
317     if (!is_contig || !is_homogeneous)
318     {
319         if (rank != root)
320         {
321             position = 0;
322             mpi_errno = MPI_Unpack(tmp_buf, nbytes, &position, buffer,
323                                          count, datatype, comm);
324             if (mpi_errno) xbt_die("error when unpacking %d", mpi_errno);
325         }
326     }
327
328 fn_exit:
329 /*    xbt_free(tmp_buf);*/
330     return mpi_errno;
331 }