Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
use tuned barrier here if provided
[simgrid.git] / src / smpi / colls / reduce_scatter-mpich.c
1 #include "colls_private.h"
2
3 static inline int MPIU_Mirror_permutation(unsigned int x, int bits)
4 {
5     /* a mask for the high order bits that should be copied as-is */
6     int high_mask = ~((0x1 << bits) - 1);
7     int retval = x & high_mask;
8     int i;
9
10     for (i = 0; i < bits; ++i) {
11         unsigned int bitval = (x & (0x1 << i)) >> i; /* 0x1 or 0x0 */
12         retval |= bitval << ((bits - i) - 1);
13     }
14
15     return retval;
16 }
17
18
19 int smpi_coll_tuned_reduce_scatter_mpich_pair(void *sendbuf, void *recvbuf, int recvcounts[],
20                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
21 {
22     int   rank, comm_size, i;
23     MPI_Aint extent, true_extent, true_lb; 
24     int  *disps;
25     void *tmp_recvbuf;
26     int mpi_errno = MPI_SUCCESS;
27     int total_count, dst, src;
28     int is_commutative;
29     comm_size = smpi_comm_size(comm);
30     rank = smpi_comm_rank(comm);
31
32     extent =smpi_datatype_get_extent(datatype);
33     smpi_datatype_extent(datatype, &true_lb, &true_extent);
34     
35     if (smpi_op_is_commute(op)) {
36         is_commutative = 1;
37     }
38
39     disps = (int*)xbt_malloc( comm_size * sizeof(int));
40
41     total_count = 0;
42     for (i=0; i<comm_size; i++) {
43         disps[i] = total_count;
44         total_count += recvcounts[i];
45     }
46     
47     if (total_count == 0) {
48         xbt_free(disps);
49         return MPI_ERR_COUNT;
50     }
51
52         if (sendbuf != MPI_IN_PLACE) {
53             /* copy local data into recvbuf */
54             smpi_datatype_copy(((char *)sendbuf+disps[rank]*extent),
55                                        recvcounts[rank], datatype, recvbuf,
56                                        recvcounts[rank], datatype);
57         }
58         
59         /* allocate temporary buffer to store incoming data */
60         tmp_recvbuf = (void*)xbt_malloc(recvcounts[rank]*(max(true_extent,extent))+1);
61         /* adjust for potential negative lower bound in datatype */
62         tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
63         
64         for (i=1; i<comm_size; i++) {
65             src = (rank - i + comm_size) % comm_size;
66             dst = (rank + i) % comm_size;
67             
68             /* send the data that dst needs. recv data that this process
69                needs from src into tmp_recvbuf */
70             if (sendbuf != MPI_IN_PLACE) 
71                 smpi_mpi_sendrecv(((char *)sendbuf+disps[dst]*extent), 
72                                              recvcounts[dst], datatype, dst,
73                                              COLL_TAG_SCATTER, tmp_recvbuf,
74                                              recvcounts[rank], datatype, src,
75                                              COLL_TAG_SCATTER, comm,
76                                              MPI_STATUS_IGNORE);
77             else
78                 smpi_mpi_sendrecv(((char *)recvbuf+disps[dst]*extent), 
79                                              recvcounts[dst], datatype, dst,
80                                              COLL_TAG_SCATTER, tmp_recvbuf,
81                                              recvcounts[rank], datatype, src,
82                                              COLL_TAG_SCATTER, comm,
83                                              MPI_STATUS_IGNORE);
84             
85             if (is_commutative || (src < rank)) {
86                 if (sendbuf != MPI_IN_PLACE) {
87                      smpi_op_apply( op,
88                                                   tmp_recvbuf, recvbuf, &recvcounts[rank],
89                                &datatype); 
90                 }
91                 else {
92                     smpi_op_apply(op, 
93                         tmp_recvbuf, ((char *)recvbuf+disps[rank]*extent), 
94                         &recvcounts[rank], &datatype);
95                     /* we can't store the result at the beginning of
96                        recvbuf right here because there is useful data
97                        there that other process/processes need. at the
98                        end, we will copy back the result to the
99                        beginning of recvbuf. */
100                 }
101             }
102             else {
103                 if (sendbuf != MPI_IN_PLACE) {
104                     smpi_op_apply(op, 
105                        recvbuf, tmp_recvbuf, &recvcounts[rank], &datatype);
106                     /* copy result back into recvbuf */
107                     mpi_errno = smpi_datatype_copy(tmp_recvbuf, recvcounts[rank],
108                                                datatype, recvbuf,
109                                                recvcounts[rank], datatype);
110                     if (mpi_errno) return(mpi_errno);
111                 }
112                 else {
113                     smpi_op_apply(op, 
114                         ((char *)recvbuf+disps[rank]*extent),
115                         tmp_recvbuf, &recvcounts[rank], &datatype);
116                     /* copy result back into recvbuf */
117                     mpi_errno = smpi_datatype_copy(tmp_recvbuf, recvcounts[rank],
118                                                datatype, 
119                                                ((char *)recvbuf +
120                                                 disps[rank]*extent), 
121                                                recvcounts[rank], datatype);
122                     if (mpi_errno) return(mpi_errno);
123                 }
124             }
125         }
126         
127         /* if MPI_IN_PLACE, move output data to the beginning of
128            recvbuf. already done for rank 0. */
129         if ((sendbuf == MPI_IN_PLACE) && (rank != 0)) {
130             mpi_errno = smpi_datatype_copy(((char *)recvbuf +
131                                         disps[rank]*extent),  
132                                        recvcounts[rank], datatype,
133                                        recvbuf, 
134                                        recvcounts[rank], datatype );
135             if (mpi_errno) return(mpi_errno);
136         }
137     
138         xbt_free(disps);
139         xbt_free(tmp_recvbuf);
140
141         return MPI_SUCCESS;
142 }
143     
144
145 int smpi_coll_tuned_reduce_scatter_mpich_noncomm(void *sendbuf, void *recvbuf, int recvcounts[],
146                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
147 {
148     int mpi_errno = MPI_SUCCESS;
149     int comm_size = smpi_comm_size(comm) ;
150     int rank = smpi_comm_rank(comm);
151     int pof2;
152     int log2_comm_size;
153     int i, k;
154     int recv_offset, send_offset;
155     int block_size, total_count, size;
156     MPI_Aint true_extent, true_lb;
157     int buf0_was_inout;
158     void *tmp_buf0;
159     void *tmp_buf1;
160     void *result_ptr;
161
162     smpi_datatype_extent(datatype, &true_lb, &true_extent);
163
164     pof2 = 1;
165     log2_comm_size = 0;
166     while (pof2 < comm_size) {
167         pof2 <<= 1;
168         ++log2_comm_size;
169     }
170
171     /* begin error checking */
172     xbt_assert(pof2 == comm_size); /* FIXME this version only works for power of 2 procs */
173
174     for (i = 0; i < (comm_size - 1); ++i) {
175         xbt_assert(recvcounts[i] == recvcounts[i+1]);
176     }
177     /* end error checking */
178
179     /* size of a block (count of datatype per block, NOT bytes per block) */
180     block_size = recvcounts[0];
181     total_count = block_size * comm_size;
182
183     tmp_buf0=( void *)xbt_malloc( true_extent * total_count);
184     tmp_buf1=( void *)xbt_malloc( true_extent * total_count);
185     /* adjust for potential negative lower bound in datatype */
186     tmp_buf0 = (void *)((char*)tmp_buf0 - true_lb);
187     tmp_buf1 = (void *)((char*)tmp_buf1 - true_lb);
188
189     /* Copy our send data to tmp_buf0.  We do this one block at a time and
190        permute the blocks as we go according to the mirror permutation. */
191     for (i = 0; i < comm_size; ++i) {
192         mpi_errno = smpi_datatype_copy((char *)(sendbuf == MPI_IN_PLACE ? recvbuf : sendbuf) + (i * true_extent * block_size), block_size, datatype,
193                                    (char *)tmp_buf0 + (MPIU_Mirror_permutation(i, log2_comm_size) * true_extent * block_size), block_size, datatype);
194         if (mpi_errno) return(mpi_errno);
195     }
196     buf0_was_inout = 1;
197
198     send_offset = 0;
199     recv_offset = 0;
200     size = total_count;
201     for (k = 0; k < log2_comm_size; ++k) {
202         /* use a double-buffering scheme to avoid local copies */
203         char *incoming_data = (buf0_was_inout ? tmp_buf1 : tmp_buf0);
204         char *outgoing_data = (buf0_was_inout ? tmp_buf0 : tmp_buf1);
205         int peer = rank ^ (0x1 << k);
206         size /= 2;
207
208         if (rank > peer) {
209             /* we have the higher rank: send top half, recv bottom half */
210             recv_offset += size;
211         }
212         else {
213             /* we have the lower rank: recv top half, send bottom half */
214             send_offset += size;
215         }
216
217         smpi_mpi_sendrecv(outgoing_data + send_offset*true_extent,
218                                      size, datatype, peer, COLL_TAG_SCATTER,
219                                      incoming_data + recv_offset*true_extent,
220                                      size, datatype, peer, COLL_TAG_SCATTER,
221                                      comm, MPI_STATUS_IGNORE);
222         /* always perform the reduction at recv_offset, the data at send_offset
223            is now our peer's responsibility */
224         if (rank > peer) {
225             /* higher ranked value so need to call op(received_data, my_data) */
226             smpi_op_apply(op, 
227                    incoming_data + recv_offset*true_extent,
228                      outgoing_data + recv_offset*true_extent,
229                      &size, &datatype );
230             /* buf0_was_inout = buf0_was_inout; */
231         }
232         else {
233             /* lower ranked value so need to call op(my_data, received_data) */
234             smpi_op_apply( op,
235                      outgoing_data + recv_offset*true_extent,
236                      incoming_data + recv_offset*true_extent,
237                      &size, &datatype);
238             buf0_was_inout = !buf0_was_inout;
239         }
240
241         /* the next round of send/recv needs to happen within the block (of size
242            "size") that we just received and reduced */
243         send_offset = recv_offset;
244     }
245
246     xbt_assert(size == recvcounts[rank]);
247
248     /* copy the reduced data to the recvbuf */
249     result_ptr = (char *)(buf0_was_inout ? tmp_buf0 : tmp_buf1) + recv_offset * true_extent;
250     mpi_errno = smpi_datatype_copy(result_ptr, size, datatype,
251                                recvbuf, size, datatype);
252     if (mpi_errno) return(mpi_errno);
253     return MPI_SUCCESS;
254 }
255
256
257
258 int smpi_coll_tuned_reduce_scatter_mpich_rdb(void *sendbuf, void *recvbuf, int recvcounts[],
259                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
260 {
261     int   rank, comm_size, i;
262     MPI_Aint extent, true_extent, true_lb; 
263     int  *disps;
264     void *tmp_recvbuf, *tmp_results;
265     int mpi_errno = MPI_SUCCESS;
266     int dis[2], blklens[2], total_count, dst;
267     int mask, dst_tree_root, my_tree_root, j, k;
268     int received;
269     MPI_Datatype sendtype, recvtype;
270     int nprocs_completed, tmp_mask, tree_root, is_commutative=0;
271     comm_size = smpi_comm_size(comm);
272     rank = smpi_comm_rank(comm);
273
274     extent =smpi_datatype_get_extent(datatype);
275     smpi_datatype_extent(datatype, &true_lb, &true_extent);
276     
277     if (smpi_op_is_commute(op)) {
278         is_commutative = 1;
279     }
280
281     disps = (int*)xbt_malloc( comm_size * sizeof(int));
282
283     total_count = 0;
284     for (i=0; i<comm_size; i++) {
285         disps[i] = total_count;
286         total_count += recvcounts[i];
287     }
288     
289             /* noncommutative and (non-pof2 or block irregular), use recursive doubling. */
290
291             /* need to allocate temporary buffer to receive incoming data*/
292             tmp_recvbuf= (void *) xbt_malloc( total_count*(max(true_extent,extent)));
293             /* adjust for potential negative lower bound in datatype */
294             tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
295
296             /* need to allocate another temporary buffer to accumulate
297                results */
298             tmp_results = (void *)xbt_malloc( total_count*(max(true_extent,extent)));
299             /* adjust for potential negative lower bound in datatype */
300             tmp_results = (void *)((char*)tmp_results - true_lb);
301
302             /* copy sendbuf into tmp_results */
303             if (sendbuf != MPI_IN_PLACE)
304                 mpi_errno = smpi_datatype_copy(sendbuf, total_count, datatype,
305                                            tmp_results, total_count, datatype);
306             else
307                 mpi_errno = smpi_datatype_copy(recvbuf, total_count, datatype,
308                                            tmp_results, total_count, datatype);
309
310             if (mpi_errno) return(mpi_errno);
311
312             mask = 0x1;
313             i = 0;
314             while (mask < comm_size) {
315                 dst = rank ^ mask;
316
317                 dst_tree_root = dst >> i;
318                 dst_tree_root <<= i;
319
320                 my_tree_root = rank >> i;
321                 my_tree_root <<= i;
322
323                 /* At step 1, processes exchange (n-n/p) amount of
324                    data; at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p)
325                    amount of data, and so forth. We use derived datatypes for this.
326
327                    At each step, a process does not need to send data
328                    indexed from my_tree_root to
329                    my_tree_root+mask-1. Similarly, a process won't receive
330                    data indexed from dst_tree_root to dst_tree_root+mask-1. */
331
332                 /* calculate sendtype */
333                 blklens[0] = blklens[1] = 0;
334                 for (j=0; j<my_tree_root; j++)
335                     blklens[0] += recvcounts[j];
336                 for (j=my_tree_root+mask; j<comm_size; j++)
337                     blklens[1] += recvcounts[j];
338
339                 dis[0] = 0;
340                 dis[1] = blklens[0];
341                 for (j=my_tree_root; (j<my_tree_root+mask) && (j<comm_size); j++)
342                     dis[1] += recvcounts[j];
343
344                 mpi_errno = smpi_datatype_indexed(2, blklens, dis, datatype, &sendtype);
345                 if (mpi_errno) return(mpi_errno);
346                 
347                 smpi_datatype_commit(&sendtype);
348
349                 /* calculate recvtype */
350                 blklens[0] = blklens[1] = 0;
351                 for (j=0; j<dst_tree_root && j<comm_size; j++)
352                     blklens[0] += recvcounts[j];
353                 for (j=dst_tree_root+mask; j<comm_size; j++)
354                     blklens[1] += recvcounts[j];
355
356                 dis[0] = 0;
357                 dis[1] = blklens[0];
358                 for (j=dst_tree_root; (j<dst_tree_root+mask) && (j<comm_size); j++)
359                     dis[1] += recvcounts[j];
360
361                 mpi_errno = smpi_datatype_indexed(2, blklens, dis, datatype, &recvtype);
362                 if (mpi_errno) return(mpi_errno);
363                 
364                 smpi_datatype_commit(&recvtype);
365
366                 received = 0;
367                 if (dst < comm_size) {
368                     /* tmp_results contains data to be sent in each step. Data is
369                        received in tmp_recvbuf and then accumulated into
370                        tmp_results. accumulation is done later below.   */ 
371
372                     smpi_mpi_sendrecv(tmp_results, 1, sendtype, dst,
373                                                  COLL_TAG_SCATTER,
374                                                  tmp_recvbuf, 1, recvtype, dst,
375                                                  COLL_TAG_SCATTER, comm,
376                                                  MPI_STATUS_IGNORE);
377                     received = 1;
378                 }
379
380                 /* if some processes in this process's subtree in this step
381                    did not have any destination process to communicate with
382                    because of non-power-of-two, we need to send them the
383                    result. We use a logarithmic recursive-halfing algorithm
384                    for this. */
385
386                 if (dst_tree_root + mask > comm_size) {
387                     nprocs_completed = comm_size - my_tree_root - mask;
388                     /* nprocs_completed is the number of processes in this
389                        subtree that have all the data. Send data to others
390                        in a tree fashion. First find root of current tree
391                        that is being divided into two. k is the number of
392                        least-significant bits in this process's rank that
393                        must be zeroed out to find the rank of the root */ 
394                     j = mask;
395                     k = 0;
396                     while (j) {
397                         j >>= 1;
398                         k++;
399                     }
400                     k--;
401
402                     tmp_mask = mask >> 1;
403                     while (tmp_mask) {
404                         dst = rank ^ tmp_mask;
405
406                         tree_root = rank >> k;
407                         tree_root <<= k;
408
409                         /* send only if this proc has data and destination
410                            doesn't have data. at any step, multiple processes
411                            can send if they have the data */
412                         if ((dst > rank) && 
413                             (rank < tree_root + nprocs_completed)
414                             && (dst >= tree_root + nprocs_completed)) {
415                             /* send the current result */
416                             smpi_mpi_send(tmp_recvbuf, 1, recvtype,
417                                                      dst, COLL_TAG_SCATTER,
418                                                      comm);
419                         }
420                         /* recv only if this proc. doesn't have data and sender
421                            has data */
422                         else if ((dst < rank) && 
423                                  (dst < tree_root + nprocs_completed) &&
424                                  (rank >= tree_root + nprocs_completed)) {
425                             smpi_mpi_recv(tmp_recvbuf, 1, recvtype, dst,
426                                                      COLL_TAG_SCATTER,
427                                                      comm, MPI_STATUS_IGNORE); 
428                             received = 1;
429                         }
430                         tmp_mask >>= 1;
431                         k--;
432                     }
433                 }
434
435                 /* The following reduction is done here instead of after 
436                    the MPIC_Sendrecv_ft or MPIC_Recv_ft above. This is
437                    because to do it above, in the noncommutative 
438                    case, we would need an extra temp buffer so as not to
439                    overwrite temp_recvbuf, because temp_recvbuf may have
440                    to be communicated to other processes in the
441                    non-power-of-two case. To avoid that extra allocation,
442                    we do the reduce here. */
443                 if (received) {
444                     if (is_commutative || (dst_tree_root < my_tree_root)) {
445                         {
446                                  smpi_op_apply(op, 
447                                tmp_recvbuf, tmp_results, &blklens[0],
448                                &datatype); 
449                                 smpi_op_apply(op, 
450                                ((char *)tmp_recvbuf + dis[1]*extent),
451                                ((char *)tmp_results + dis[1]*extent),
452                                &blklens[1], &datatype); 
453                         }
454                     }
455                     else {
456                         {
457                                  smpi_op_apply(op,
458                                    tmp_results, tmp_recvbuf, &blklens[0],
459                                    &datatype); 
460                                  smpi_op_apply(op,
461                                    ((char *)tmp_results + dis[1]*extent),
462                                    ((char *)tmp_recvbuf + dis[1]*extent),
463                                    &blklens[1], &datatype); 
464                         }
465                         /* copy result back into tmp_results */
466                         mpi_errno = smpi_datatype_copy(tmp_recvbuf, 1, recvtype, 
467                                                    tmp_results, 1, recvtype);
468                         if (mpi_errno) return(mpi_errno);
469                     }
470                 }
471
472                 //smpi_datatype_free(&sendtype);
473                 //smpi_datatype_free(&recvtype);
474
475                 mask <<= 1;
476                 i++;
477             }
478
479             /* now copy final results from tmp_results to recvbuf */
480             mpi_errno = smpi_datatype_copy(((char *)tmp_results+disps[rank]*extent),
481                                        recvcounts[rank], datatype, recvbuf,
482                                        recvcounts[rank], datatype);
483             if (mpi_errno) return(mpi_errno);
484     xbt_free(disps);
485     xbt_free(tmp_recvbuf);
486     xbt_free(tmp_results);
487     return MPI_SUCCESS;
488         }
489
490