Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
more leaks
[simgrid.git] / src / smpi / colls / reduce_scatter-mpich.c
1 #include "colls_private.h"
2
3 static inline int MPIU_Mirror_permutation(unsigned int x, int bits)
4 {
5     /* a mask for the high order bits that should be copied as-is */
6     int high_mask = ~((0x1 << bits) - 1);
7     int retval = x & high_mask;
8     int i;
9
10     for (i = 0; i < bits; ++i) {
11         unsigned int bitval = (x & (0x1 << i)) >> i; /* 0x1 or 0x0 */
12         retval |= bitval << ((bits - i) - 1);
13     }
14
15     return retval;
16 }
17
18
19 int smpi_coll_tuned_reduce_scatter_mpich_pair(void *sendbuf, void *recvbuf, int recvcounts[],
20                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
21 {
22     int   rank, comm_size, i;
23     MPI_Aint extent, true_extent, true_lb; 
24     int  *disps;
25     void *tmp_recvbuf;
26     int mpi_errno = MPI_SUCCESS;
27     int total_count, dst, src;
28     int is_commutative;
29     comm_size = smpi_comm_size(comm);
30     rank = smpi_comm_rank(comm);
31
32     extent =smpi_datatype_get_extent(datatype);
33     smpi_datatype_extent(datatype, &true_lb, &true_extent);
34     
35     if (smpi_op_is_commute(op)) {
36         is_commutative = 1;
37     }
38
39     disps = (int*)xbt_malloc( comm_size * sizeof(int));
40
41     total_count = 0;
42     for (i=0; i<comm_size; i++) {
43         disps[i] = total_count;
44         total_count += recvcounts[i];
45     }
46     
47     if (total_count == 0) {
48         xbt_free(disps);
49         return MPI_ERR_COUNT;
50     }
51
52         if (sendbuf != MPI_IN_PLACE) {
53             /* copy local data into recvbuf */
54             smpi_datatype_copy(((char *)sendbuf+disps[rank]*extent),
55                                        recvcounts[rank], datatype, recvbuf,
56                                        recvcounts[rank], datatype);
57         }
58         
59         /* allocate temporary buffer to store incoming data */
60         tmp_recvbuf = (void*)xbt_malloc(recvcounts[rank]*(max(true_extent,extent))+1);
61         /* adjust for potential negative lower bound in datatype */
62         tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
63         
64         for (i=1; i<comm_size; i++) {
65             src = (rank - i + comm_size) % comm_size;
66             dst = (rank + i) % comm_size;
67             
68             /* send the data that dst needs. recv data that this process
69                needs from src into tmp_recvbuf */
70             if (sendbuf != MPI_IN_PLACE) 
71                 smpi_mpi_sendrecv(((char *)sendbuf+disps[dst]*extent), 
72                                              recvcounts[dst], datatype, dst,
73                                              COLL_TAG_SCATTER, tmp_recvbuf,
74                                              recvcounts[rank], datatype, src,
75                                              COLL_TAG_SCATTER, comm,
76                                              MPI_STATUS_IGNORE);
77             else
78                 smpi_mpi_sendrecv(((char *)recvbuf+disps[dst]*extent), 
79                                              recvcounts[dst], datatype, dst,
80                                              COLL_TAG_SCATTER, tmp_recvbuf,
81                                              recvcounts[rank], datatype, src,
82                                              COLL_TAG_SCATTER, comm,
83                                              MPI_STATUS_IGNORE);
84             
85             if (is_commutative || (src < rank)) {
86                 if (sendbuf != MPI_IN_PLACE) {
87                      smpi_op_apply( op,
88                                                   tmp_recvbuf, recvbuf, &recvcounts[rank],
89                                &datatype); 
90                 }
91                 else {
92                     smpi_op_apply(op, 
93                         tmp_recvbuf, ((char *)recvbuf+disps[rank]*extent), 
94                         &recvcounts[rank], &datatype);
95                     /* we can't store the result at the beginning of
96                        recvbuf right here because there is useful data
97                        there that other process/processes need. at the
98                        end, we will copy back the result to the
99                        beginning of recvbuf. */
100                 }
101             }
102             else {
103                 if (sendbuf != MPI_IN_PLACE) {
104                     smpi_op_apply(op, 
105                        recvbuf, tmp_recvbuf, &recvcounts[rank], &datatype);
106                     /* copy result back into recvbuf */
107                     mpi_errno = smpi_datatype_copy(tmp_recvbuf, recvcounts[rank],
108                                                datatype, recvbuf,
109                                                recvcounts[rank], datatype);
110                     if (mpi_errno) return(mpi_errno);
111                 }
112                 else {
113                     smpi_op_apply(op, 
114                         ((char *)recvbuf+disps[rank]*extent),
115                         tmp_recvbuf, &recvcounts[rank], &datatype);
116                     /* copy result back into recvbuf */
117                     mpi_errno = smpi_datatype_copy(tmp_recvbuf, recvcounts[rank],
118                                                datatype, 
119                                                ((char *)recvbuf +
120                                                 disps[rank]*extent), 
121                                                recvcounts[rank], datatype);
122                     if (mpi_errno) return(mpi_errno);
123                 }
124             }
125         }
126         
127         /* if MPI_IN_PLACE, move output data to the beginning of
128            recvbuf. already done for rank 0. */
129         if ((sendbuf == MPI_IN_PLACE) && (rank != 0)) {
130             mpi_errno = smpi_datatype_copy(((char *)recvbuf +
131                                         disps[rank]*extent),  
132                                        recvcounts[rank], datatype,
133                                        recvbuf, 
134                                        recvcounts[rank], datatype );
135             if (mpi_errno) return(mpi_errno);
136         }
137     
138         xbt_free(disps);
139         xbt_free(tmp_recvbuf);
140
141         return MPI_SUCCESS;
142 }
143     
144
145 int smpi_coll_tuned_reduce_scatter_mpich_noncomm(void *sendbuf, void *recvbuf, int recvcounts[],
146                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
147 {
148     int mpi_errno = MPI_SUCCESS;
149     int comm_size = smpi_comm_size(comm) ;
150     int rank = smpi_comm_rank(comm);
151     int pof2;
152     int log2_comm_size;
153     int i, k;
154     int recv_offset, send_offset;
155     int block_size, total_count, size;
156     MPI_Aint true_extent, true_lb;
157     int buf0_was_inout;
158     void *tmp_buf0;
159     void *tmp_buf1;
160     void *result_ptr;
161
162     smpi_datatype_extent(datatype, &true_lb, &true_extent);
163
164     pof2 = 1;
165     log2_comm_size = 0;
166     while (pof2 < comm_size) {
167         pof2 <<= 1;
168         ++log2_comm_size;
169     }
170
171     /* begin error checking */
172     xbt_assert(pof2 == comm_size); /* FIXME this version only works for power of 2 procs */
173
174     for (i = 0; i < (comm_size - 1); ++i) {
175         xbt_assert(recvcounts[i] == recvcounts[i+1]);
176     }
177     /* end error checking */
178
179     /* size of a block (count of datatype per block, NOT bytes per block) */
180     block_size = recvcounts[0];
181     total_count = block_size * comm_size;
182
183     tmp_buf0=( void *)xbt_malloc( true_extent * total_count);
184     tmp_buf1=( void *)xbt_malloc( true_extent * total_count);
185     void *tmp_buf0_save=tmp_buf0;
186     void *tmp_buf1_save=tmp_buf1;
187
188     /* adjust for potential negative lower bound in datatype */
189     tmp_buf0 = (void *)((char*)tmp_buf0 - true_lb);
190     tmp_buf1 = (void *)((char*)tmp_buf1 - true_lb);
191
192     /* Copy our send data to tmp_buf0.  We do this one block at a time and
193        permute the blocks as we go according to the mirror permutation. */
194     for (i = 0; i < comm_size; ++i) {
195         mpi_errno = smpi_datatype_copy((char *)(sendbuf == MPI_IN_PLACE ? recvbuf : sendbuf) + (i * true_extent * block_size), block_size, datatype,
196                                    (char *)tmp_buf0 + (MPIU_Mirror_permutation(i, log2_comm_size) * true_extent * block_size), block_size, datatype);
197         if (mpi_errno) return(mpi_errno);
198     }
199     buf0_was_inout = 1;
200
201     send_offset = 0;
202     recv_offset = 0;
203     size = total_count;
204     for (k = 0; k < log2_comm_size; ++k) {
205         /* use a double-buffering scheme to avoid local copies */
206         char *incoming_data = (buf0_was_inout ? tmp_buf1 : tmp_buf0);
207         char *outgoing_data = (buf0_was_inout ? tmp_buf0 : tmp_buf1);
208         int peer = rank ^ (0x1 << k);
209         size /= 2;
210
211         if (rank > peer) {
212             /* we have the higher rank: send top half, recv bottom half */
213             recv_offset += size;
214         }
215         else {
216             /* we have the lower rank: recv top half, send bottom half */
217             send_offset += size;
218         }
219
220         smpi_mpi_sendrecv(outgoing_data + send_offset*true_extent,
221                                      size, datatype, peer, COLL_TAG_SCATTER,
222                                      incoming_data + recv_offset*true_extent,
223                                      size, datatype, peer, COLL_TAG_SCATTER,
224                                      comm, MPI_STATUS_IGNORE);
225         /* always perform the reduction at recv_offset, the data at send_offset
226            is now our peer's responsibility */
227         if (rank > peer) {
228             /* higher ranked value so need to call op(received_data, my_data) */
229             smpi_op_apply(op, 
230                    incoming_data + recv_offset*true_extent,
231                      outgoing_data + recv_offset*true_extent,
232                      &size, &datatype );
233             /* buf0_was_inout = buf0_was_inout; */
234         }
235         else {
236             /* lower ranked value so need to call op(my_data, received_data) */
237             smpi_op_apply( op,
238                      outgoing_data + recv_offset*true_extent,
239                      incoming_data + recv_offset*true_extent,
240                      &size, &datatype);
241             buf0_was_inout = !buf0_was_inout;
242         }
243
244         /* the next round of send/recv needs to happen within the block (of size
245            "size") that we just received and reduced */
246         send_offset = recv_offset;
247     }
248
249     xbt_assert(size == recvcounts[rank]);
250
251     /* copy the reduced data to the recvbuf */
252     result_ptr = (char *)(buf0_was_inout ? tmp_buf0 : tmp_buf1) + recv_offset * true_extent;
253     mpi_errno = smpi_datatype_copy(result_ptr, size, datatype,
254                                recvbuf, size, datatype);
255     xbt_free(tmp_buf0_save);
256     xbt_free(tmp_buf1_save);
257     if (mpi_errno) return(mpi_errno);
258     return MPI_SUCCESS;
259 }
260
261
262
263 int smpi_coll_tuned_reduce_scatter_mpich_rdb(void *sendbuf, void *recvbuf, int recvcounts[],
264                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
265 {
266     int   rank, comm_size, i;
267     MPI_Aint extent, true_extent, true_lb; 
268     int  *disps;
269     void *tmp_recvbuf, *tmp_results;
270     int mpi_errno = MPI_SUCCESS;
271     int dis[2], blklens[2], total_count, dst;
272     int mask, dst_tree_root, my_tree_root, j, k;
273     int received;
274     MPI_Datatype sendtype, recvtype;
275     int nprocs_completed, tmp_mask, tree_root, is_commutative=0;
276     comm_size = smpi_comm_size(comm);
277     rank = smpi_comm_rank(comm);
278
279     extent =smpi_datatype_get_extent(datatype);
280     smpi_datatype_extent(datatype, &true_lb, &true_extent);
281     
282     if (smpi_op_is_commute(op)) {
283         is_commutative = 1;
284     }
285
286     disps = (int*)xbt_malloc( comm_size * sizeof(int));
287
288     total_count = 0;
289     for (i=0; i<comm_size; i++) {
290         disps[i] = total_count;
291         total_count += recvcounts[i];
292     }
293     
294             /* noncommutative and (non-pof2 or block irregular), use recursive doubling. */
295
296             /* need to allocate temporary buffer to receive incoming data*/
297             tmp_recvbuf= (void *) xbt_malloc( total_count*(max(true_extent,extent)));
298             /* adjust for potential negative lower bound in datatype */
299             tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
300
301             /* need to allocate another temporary buffer to accumulate
302                results */
303             tmp_results = (void *)xbt_malloc( total_count*(max(true_extent,extent)));
304             /* adjust for potential negative lower bound in datatype */
305             tmp_results = (void *)((char*)tmp_results - true_lb);
306
307             /* copy sendbuf into tmp_results */
308             if (sendbuf != MPI_IN_PLACE)
309                 mpi_errno = smpi_datatype_copy(sendbuf, total_count, datatype,
310                                            tmp_results, total_count, datatype);
311             else
312                 mpi_errno = smpi_datatype_copy(recvbuf, total_count, datatype,
313                                            tmp_results, total_count, datatype);
314
315             if (mpi_errno) return(mpi_errno);
316
317             mask = 0x1;
318             i = 0;
319             while (mask < comm_size) {
320                 dst = rank ^ mask;
321
322                 dst_tree_root = dst >> i;
323                 dst_tree_root <<= i;
324
325                 my_tree_root = rank >> i;
326                 my_tree_root <<= i;
327
328                 /* At step 1, processes exchange (n-n/p) amount of
329                    data; at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p)
330                    amount of data, and so forth. We use derived datatypes for this.
331
332                    At each step, a process does not need to send data
333                    indexed from my_tree_root to
334                    my_tree_root+mask-1. Similarly, a process won't receive
335                    data indexed from dst_tree_root to dst_tree_root+mask-1. */
336
337                 /* calculate sendtype */
338                 blklens[0] = blklens[1] = 0;
339                 for (j=0; j<my_tree_root; j++)
340                     blklens[0] += recvcounts[j];
341                 for (j=my_tree_root+mask; j<comm_size; j++)
342                     blklens[1] += recvcounts[j];
343
344                 dis[0] = 0;
345                 dis[1] = blklens[0];
346                 for (j=my_tree_root; (j<my_tree_root+mask) && (j<comm_size); j++)
347                     dis[1] += recvcounts[j];
348
349                 mpi_errno = smpi_datatype_indexed(2, blklens, dis, datatype, &sendtype);
350                 if (mpi_errno) return(mpi_errno);
351                 
352                 smpi_datatype_commit(&sendtype);
353
354                 /* calculate recvtype */
355                 blklens[0] = blklens[1] = 0;
356                 for (j=0; j<dst_tree_root && j<comm_size; j++)
357                     blklens[0] += recvcounts[j];
358                 for (j=dst_tree_root+mask; j<comm_size; j++)
359                     blklens[1] += recvcounts[j];
360
361                 dis[0] = 0;
362                 dis[1] = blklens[0];
363                 for (j=dst_tree_root; (j<dst_tree_root+mask) && (j<comm_size); j++)
364                     dis[1] += recvcounts[j];
365
366                 mpi_errno = smpi_datatype_indexed(2, blklens, dis, datatype, &recvtype);
367                 if (mpi_errno) return(mpi_errno);
368                 
369                 smpi_datatype_commit(&recvtype);
370
371                 received = 0;
372                 if (dst < comm_size) {
373                     /* tmp_results contains data to be sent in each step. Data is
374                        received in tmp_recvbuf and then accumulated into
375                        tmp_results. accumulation is done later below.   */ 
376
377                     smpi_mpi_sendrecv(tmp_results, 1, sendtype, dst,
378                                                  COLL_TAG_SCATTER,
379                                                  tmp_recvbuf, 1, recvtype, dst,
380                                                  COLL_TAG_SCATTER, comm,
381                                                  MPI_STATUS_IGNORE);
382                     received = 1;
383                 }
384
385                 /* if some processes in this process's subtree in this step
386                    did not have any destination process to communicate with
387                    because of non-power-of-two, we need to send them the
388                    result. We use a logarithmic recursive-halfing algorithm
389                    for this. */
390
391                 if (dst_tree_root + mask > comm_size) {
392                     nprocs_completed = comm_size - my_tree_root - mask;
393                     /* nprocs_completed is the number of processes in this
394                        subtree that have all the data. Send data to others
395                        in a tree fashion. First find root of current tree
396                        that is being divided into two. k is the number of
397                        least-significant bits in this process's rank that
398                        must be zeroed out to find the rank of the root */ 
399                     j = mask;
400                     k = 0;
401                     while (j) {
402                         j >>= 1;
403                         k++;
404                     }
405                     k--;
406
407                     tmp_mask = mask >> 1;
408                     while (tmp_mask) {
409                         dst = rank ^ tmp_mask;
410
411                         tree_root = rank >> k;
412                         tree_root <<= k;
413
414                         /* send only if this proc has data and destination
415                            doesn't have data. at any step, multiple processes
416                            can send if they have the data */
417                         if ((dst > rank) && 
418                             (rank < tree_root + nprocs_completed)
419                             && (dst >= tree_root + nprocs_completed)) {
420                             /* send the current result */
421                             smpi_mpi_send(tmp_recvbuf, 1, recvtype,
422                                                      dst, COLL_TAG_SCATTER,
423                                                      comm);
424                         }
425                         /* recv only if this proc. doesn't have data and sender
426                            has data */
427                         else if ((dst < rank) && 
428                                  (dst < tree_root + nprocs_completed) &&
429                                  (rank >= tree_root + nprocs_completed)) {
430                             smpi_mpi_recv(tmp_recvbuf, 1, recvtype, dst,
431                                                      COLL_TAG_SCATTER,
432                                                      comm, MPI_STATUS_IGNORE); 
433                             received = 1;
434                         }
435                         tmp_mask >>= 1;
436                         k--;
437                     }
438                 }
439
440                 /* The following reduction is done here instead of after 
441                    the MPIC_Sendrecv_ft or MPIC_Recv_ft above. This is
442                    because to do it above, in the noncommutative 
443                    case, we would need an extra temp buffer so as not to
444                    overwrite temp_recvbuf, because temp_recvbuf may have
445                    to be communicated to other processes in the
446                    non-power-of-two case. To avoid that extra allocation,
447                    we do the reduce here. */
448                 if (received) {
449                     if (is_commutative || (dst_tree_root < my_tree_root)) {
450                         {
451                                  smpi_op_apply(op, 
452                                tmp_recvbuf, tmp_results, &blklens[0],
453                                &datatype); 
454                                 smpi_op_apply(op, 
455                                ((char *)tmp_recvbuf + dis[1]*extent),
456                                ((char *)tmp_results + dis[1]*extent),
457                                &blklens[1], &datatype); 
458                         }
459                     }
460                     else {
461                         {
462                                  smpi_op_apply(op,
463                                    tmp_results, tmp_recvbuf, &blklens[0],
464                                    &datatype); 
465                                  smpi_op_apply(op,
466                                    ((char *)tmp_results + dis[1]*extent),
467                                    ((char *)tmp_recvbuf + dis[1]*extent),
468                                    &blklens[1], &datatype); 
469                         }
470                         /* copy result back into tmp_results */
471                         mpi_errno = smpi_datatype_copy(tmp_recvbuf, 1, recvtype, 
472                                                    tmp_results, 1, recvtype);
473                         if (mpi_errno) return(mpi_errno);
474                     }
475                 }
476
477                 smpi_datatype_free(&sendtype);
478                 smpi_datatype_free(&recvtype);
479
480                 mask <<= 1;
481                 i++;
482             }
483
484             /* now copy final results from tmp_results to recvbuf */
485             mpi_errno = smpi_datatype_copy(((char *)tmp_results+disps[rank]*extent),
486                                        recvcounts[rank], datatype, recvbuf,
487                                        recvcounts[rank], datatype);
488             if (mpi_errno) return(mpi_errno);
489
490     xbt_free(disps);
491     xbt_free(tmp_recvbuf);
492     xbt_free(tmp_results);
493     return MPI_SUCCESS;
494         }
495
496