Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
7afa4714e0422d54525d055008f80d60bbe02f4e
[simgrid.git] / src / smpi / colls / reduce_scatter-mpich.c
1 #include "colls_private.h"
2 #define MPIR_REDUCE_SCATTER_TAG 222
3
4 static inline int MPIU_Mirror_permutation(unsigned int x, int bits)
5 {
6     /* a mask for the high order bits that should be copied as-is */
7     int high_mask = ~((0x1 << bits) - 1);
8     int retval = x & high_mask;
9     int i;
10
11     for (i = 0; i < bits; ++i) {
12         unsigned int bitval = (x & (0x1 << i)) >> i; /* 0x1 or 0x0 */
13         retval |= bitval << ((bits - i) - 1);
14     }
15
16     return retval;
17 }
18
19
20 int smpi_coll_tuned_reduce_scatter_mpich_pair(void *sendbuf, void *recvbuf, int recvcounts[],
21                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
22 {
23     int   rank, comm_size, i;
24     MPI_Aint extent, true_extent, true_lb; 
25     int  *disps;
26     void *tmp_recvbuf;
27     int mpi_errno = MPI_SUCCESS;
28     int type_size, total_count, nbytes, dst, src;
29     int is_commutative;
30     comm_size = smpi_comm_size(comm);
31     rank = smpi_comm_rank(comm);
32
33     extent =smpi_datatype_get_extent(datatype);
34     smpi_datatype_extent(datatype, &true_lb, &true_extent);
35     
36     if (smpi_op_is_commute(op)) {
37         is_commutative = 1;
38     }
39
40     disps = (int*)xbt_malloc( comm_size * sizeof(int));
41
42     total_count = 0;
43     for (i=0; i<comm_size; i++) {
44         disps[i] = total_count;
45         total_count += recvcounts[i];
46     }
47     
48     if (total_count == 0) {
49         return MPI_ERR_COUNT;
50     }
51
52     type_size= smpi_datatype_size(datatype);
53     nbytes = total_count * type_size;
54     
55
56         if (sendbuf != MPI_IN_PLACE) {
57             /* copy local data into recvbuf */
58             smpi_datatype_copy(((char *)sendbuf+disps[rank]*extent),
59                                        recvcounts[rank], datatype, recvbuf,
60                                        recvcounts[rank], datatype);
61         }
62         
63         /* allocate temporary buffer to store incoming data */
64         tmp_recvbuf = (void*)xbt_malloc(recvcounts[rank]*(max(true_extent,extent))+1);
65         /* adjust for potential negative lower bound in datatype */
66         tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
67         
68         for (i=1; i<comm_size; i++) {
69             src = (rank - i + comm_size) % comm_size;
70             dst = (rank + i) % comm_size;
71             
72             /* send the data that dst needs. recv data that this process
73                needs from src into tmp_recvbuf */
74             if (sendbuf != MPI_IN_PLACE) 
75                 smpi_mpi_sendrecv(((char *)sendbuf+disps[dst]*extent), 
76                                              recvcounts[dst], datatype, dst,
77                                              MPIR_REDUCE_SCATTER_TAG, tmp_recvbuf,
78                                              recvcounts[rank], datatype, src,
79                                              MPIR_REDUCE_SCATTER_TAG, comm,
80                                              MPI_STATUS_IGNORE);
81             else
82                 smpi_mpi_sendrecv(((char *)recvbuf+disps[dst]*extent), 
83                                              recvcounts[dst], datatype, dst,
84                                              MPIR_REDUCE_SCATTER_TAG, tmp_recvbuf,
85                                              recvcounts[rank], datatype, src,
86                                              MPIR_REDUCE_SCATTER_TAG, comm,
87                                              MPI_STATUS_IGNORE);
88             
89             if (is_commutative || (src < rank)) {
90                 if (sendbuf != MPI_IN_PLACE) {
91                      smpi_op_apply( op,
92                                                   tmp_recvbuf, recvbuf, &recvcounts[rank],
93                                &datatype); 
94                 }
95                 else {
96                     smpi_op_apply(op, 
97                         tmp_recvbuf, ((char *)recvbuf+disps[rank]*extent), 
98                         &recvcounts[rank], &datatype);
99                     /* we can't store the result at the beginning of
100                        recvbuf right here because there is useful data
101                        there that other process/processes need. at the
102                        end, we will copy back the result to the
103                        beginning of recvbuf. */
104                 }
105             }
106             else {
107                 if (sendbuf != MPI_IN_PLACE) {
108                     smpi_op_apply(op, 
109                        recvbuf, tmp_recvbuf, &recvcounts[rank], &datatype);
110                     /* copy result back into recvbuf */
111                     mpi_errno = smpi_datatype_copy(tmp_recvbuf, recvcounts[rank],
112                                                datatype, recvbuf,
113                                                recvcounts[rank], datatype);
114                     if (mpi_errno) return(mpi_errno);
115                 }
116                 else {
117                     smpi_op_apply(op, 
118                         ((char *)recvbuf+disps[rank]*extent),
119                         tmp_recvbuf, &recvcounts[rank], &datatype);
120                     /* copy result back into recvbuf */
121                     mpi_errno = smpi_datatype_copy(tmp_recvbuf, recvcounts[rank],
122                                                datatype, 
123                                                ((char *)recvbuf +
124                                                 disps[rank]*extent), 
125                                                recvcounts[rank], datatype);
126                     if (mpi_errno) return(mpi_errno);
127                 }
128             }
129         }
130         
131         /* if MPI_IN_PLACE, move output data to the beginning of
132            recvbuf. already done for rank 0. */
133         if ((sendbuf == MPI_IN_PLACE) && (rank != 0)) {
134             mpi_errno = smpi_datatype_copy(((char *)recvbuf +
135                                         disps[rank]*extent),  
136                                        recvcounts[rank], datatype,
137                                        recvbuf, 
138                                        recvcounts[rank], datatype );
139             if (mpi_errno) return(mpi_errno);
140         }
141     
142 return MPI_SUCCESS;
143 }
144     
145
146 int smpi_coll_tuned_reduce_scatter_mpich_noncomm(void *sendbuf, void *recvbuf, int recvcounts[],
147                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
148 {
149     int mpi_errno = MPI_SUCCESS;
150     int comm_size = smpi_comm_size(comm) ;
151     int rank = smpi_comm_rank(comm);
152     int pof2;
153     int log2_comm_size;
154     int i, k;
155     int recv_offset, send_offset;
156     int block_size, total_count, size;
157     MPI_Aint true_extent, true_lb;
158     int buf0_was_inout;
159     void *tmp_buf0;
160     void *tmp_buf1;
161     void *result_ptr;
162
163     smpi_datatype_extent(datatype, &true_lb, &true_extent);
164
165     pof2 = 1;
166     log2_comm_size = 0;
167     while (pof2 < comm_size) {
168         pof2 <<= 1;
169         ++log2_comm_size;
170     }
171
172     /* begin error checking */
173     xbt_assert(pof2 == comm_size); /* FIXME this version only works for power of 2 procs */
174
175     for (i = 0; i < (comm_size - 1); ++i) {
176         xbt_assert(recvcounts[i] == recvcounts[i+1]);
177     }
178     /* end error checking */
179
180     /* size of a block (count of datatype per block, NOT bytes per block) */
181     block_size = recvcounts[0];
182     total_count = block_size * comm_size;
183
184     tmp_buf0=( void *)xbt_malloc( true_extent * total_count);
185     tmp_buf1=( void *)xbt_malloc( true_extent * total_count);
186     /* adjust for potential negative lower bound in datatype */
187     tmp_buf0 = (void *)((char*)tmp_buf0 - true_lb);
188     tmp_buf1 = (void *)((char*)tmp_buf1 - true_lb);
189
190     /* Copy our send data to tmp_buf0.  We do this one block at a time and
191        permute the blocks as we go according to the mirror permutation. */
192     for (i = 0; i < comm_size; ++i) {
193         mpi_errno = smpi_datatype_copy((char *)(sendbuf == MPI_IN_PLACE ? recvbuf : sendbuf) + (i * true_extent * block_size), block_size, datatype,
194                                    (char *)tmp_buf0 + (MPIU_Mirror_permutation(i, log2_comm_size) * true_extent * block_size), block_size, datatype);
195         if (mpi_errno) return(mpi_errno);
196     }
197     buf0_was_inout = 1;
198
199     send_offset = 0;
200     recv_offset = 0;
201     size = total_count;
202     for (k = 0; k < log2_comm_size; ++k) {
203         /* use a double-buffering scheme to avoid local copies */
204         char *incoming_data = (buf0_was_inout ? tmp_buf1 : tmp_buf0);
205         char *outgoing_data = (buf0_was_inout ? tmp_buf0 : tmp_buf1);
206         int peer = rank ^ (0x1 << k);
207         size /= 2;
208
209         if (rank > peer) {
210             /* we have the higher rank: send top half, recv bottom half */
211             recv_offset += size;
212         }
213         else {
214             /* we have the lower rank: recv top half, send bottom half */
215             send_offset += size;
216         }
217
218         smpi_mpi_sendrecv(outgoing_data + send_offset*true_extent,
219                                      size, datatype, peer, MPIR_REDUCE_SCATTER_TAG,
220                                      incoming_data + recv_offset*true_extent,
221                                      size, datatype, peer, MPIR_REDUCE_SCATTER_TAG,
222                                      comm, MPI_STATUS_IGNORE);
223         /* always perform the reduction at recv_offset, the data at send_offset
224            is now our peer's responsibility */
225         if (rank > peer) {
226             /* higher ranked value so need to call op(received_data, my_data) */
227             smpi_op_apply(op, 
228                    incoming_data + recv_offset*true_extent,
229                      outgoing_data + recv_offset*true_extent,
230                      &size, &datatype );
231             /* buf0_was_inout = buf0_was_inout; */
232         }
233         else {
234             /* lower ranked value so need to call op(my_data, received_data) */
235             smpi_op_apply( op,
236                      outgoing_data + recv_offset*true_extent,
237                      incoming_data + recv_offset*true_extent,
238                      &size, &datatype);
239             buf0_was_inout = !buf0_was_inout;
240         }
241
242         /* the next round of send/recv needs to happen within the block (of size
243            "size") that we just received and reduced */
244         send_offset = recv_offset;
245     }
246
247     xbt_assert(size == recvcounts[rank]);
248
249     /* copy the reduced data to the recvbuf */
250     result_ptr = (char *)(buf0_was_inout ? tmp_buf0 : tmp_buf1) + recv_offset * true_extent;
251     mpi_errno = smpi_datatype_copy(result_ptr, size, datatype,
252                                recvbuf, size, datatype);
253     if (mpi_errno) return(mpi_errno);
254     return MPI_SUCCESS;
255 }
256
257
258
259 int smpi_coll_tuned_reduce_scatter_mpich_rdb(void *sendbuf, void *recvbuf, int recvcounts[],
260                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
261 {
262     int   rank, comm_size, i;
263     MPI_Aint extent, true_extent, true_lb; 
264     int  *disps;
265     void *tmp_recvbuf, *tmp_results;
266     int mpi_errno = MPI_SUCCESS;
267     int type_size, dis[2], blklens[2], total_count, nbytes, dst;
268     int mask, dst_tree_root, my_tree_root, j, k;
269     int received;
270     MPI_Datatype sendtype, recvtype;
271     int nprocs_completed, tmp_mask, tree_root, is_commutative;
272     comm_size = smpi_comm_size(comm);
273     rank = smpi_comm_rank(comm);
274
275     extent =smpi_datatype_get_extent(datatype);
276     smpi_datatype_extent(datatype, &true_lb, &true_extent);
277     
278     if (smpi_op_is_commute(op)) {
279         is_commutative = 1;
280     }
281
282     disps = (int*)xbt_malloc( comm_size * sizeof(int));
283
284     total_count = 0;
285     for (i=0; i<comm_size; i++) {
286         disps[i] = total_count;
287         total_count += recvcounts[i];
288     }
289     
290     type_size= smpi_datatype_size(datatype);
291     nbytes = total_count * type_size;
292     
293
294             /* noncommutative and (non-pof2 or block irregular), use recursive doubling. */
295
296             /* need to allocate temporary buffer to receive incoming data*/
297             tmp_recvbuf= (void *) xbt_malloc( total_count*(max(true_extent,extent)));
298             /* adjust for potential negative lower bound in datatype */
299             tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
300
301             /* need to allocate another temporary buffer to accumulate
302                results */
303             tmp_results = (void *)xbt_malloc( total_count*(max(true_extent,extent)));
304             /* adjust for potential negative lower bound in datatype */
305             tmp_results = (void *)((char*)tmp_results - true_lb);
306
307             /* copy sendbuf into tmp_results */
308             if (sendbuf != MPI_IN_PLACE)
309                 mpi_errno = smpi_datatype_copy(sendbuf, total_count, datatype,
310                                            tmp_results, total_count, datatype);
311             else
312                 mpi_errno = smpi_datatype_copy(recvbuf, total_count, datatype,
313                                            tmp_results, total_count, datatype);
314
315             if (mpi_errno) return(mpi_errno);
316
317             mask = 0x1;
318             i = 0;
319             while (mask < comm_size) {
320                 dst = rank ^ mask;
321
322                 dst_tree_root = dst >> i;
323                 dst_tree_root <<= i;
324
325                 my_tree_root = rank >> i;
326                 my_tree_root <<= i;
327
328                 /* At step 1, processes exchange (n-n/p) amount of
329                    data; at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p)
330                    amount of data, and so forth. We use derived datatypes for this.
331
332                    At each step, a process does not need to send data
333                    indexed from my_tree_root to
334                    my_tree_root+mask-1. Similarly, a process won't receive
335                    data indexed from dst_tree_root to dst_tree_root+mask-1. */
336
337                 /* calculate sendtype */
338                 blklens[0] = blklens[1] = 0;
339                 for (j=0; j<my_tree_root; j++)
340                     blklens[0] += recvcounts[j];
341                 for (j=my_tree_root+mask; j<comm_size; j++)
342                     blklens[1] += recvcounts[j];
343
344                 dis[0] = 0;
345                 dis[1] = blklens[0];
346                 for (j=my_tree_root; (j<my_tree_root+mask) && (j<comm_size); j++)
347                     dis[1] += recvcounts[j];
348
349                 mpi_errno = smpi_datatype_indexed(2, blklens, dis, datatype, &sendtype);
350                 if (mpi_errno) return(mpi_errno);
351                 
352                 smpi_datatype_commit(&sendtype);
353
354                 /* calculate recvtype */
355                 blklens[0] = blklens[1] = 0;
356                 for (j=0; j<dst_tree_root && j<comm_size; j++)
357                     blklens[0] += recvcounts[j];
358                 for (j=dst_tree_root+mask; j<comm_size; j++)
359                     blklens[1] += recvcounts[j];
360
361                 dis[0] = 0;
362                 dis[1] = blklens[0];
363                 for (j=dst_tree_root; (j<dst_tree_root+mask) && (j<comm_size); j++)
364                     dis[1] += recvcounts[j];
365
366                 mpi_errno = smpi_datatype_indexed(2, blklens, dis, datatype, &recvtype);
367                 if (mpi_errno) return(mpi_errno);
368                 
369                 smpi_datatype_commit(&recvtype);
370
371                 received = 0;
372                 if (dst < comm_size) {
373                     /* tmp_results contains data to be sent in each step. Data is
374                        received in tmp_recvbuf and then accumulated into
375                        tmp_results. accumulation is done later below.   */ 
376
377                     smpi_mpi_sendrecv(tmp_results, 1, sendtype, dst,
378                                                  MPIR_REDUCE_SCATTER_TAG, 
379                                                  tmp_recvbuf, 1, recvtype, dst,
380                                                  MPIR_REDUCE_SCATTER_TAG, comm,
381                                                  MPI_STATUS_IGNORE);
382                     received = 1;
383                 }
384
385                 /* if some processes in this process's subtree in this step
386                    did not have any destination process to communicate with
387                    because of non-power-of-two, we need to send them the
388                    result. We use a logarithmic recursive-halfing algorithm
389                    for this. */
390
391                 if (dst_tree_root + mask > comm_size) {
392                     nprocs_completed = comm_size - my_tree_root - mask;
393                     /* nprocs_completed is the number of processes in this
394                        subtree that have all the data. Send data to others
395                        in a tree fashion. First find root of current tree
396                        that is being divided into two. k is the number of
397                        least-significant bits in this process's rank that
398                        must be zeroed out to find the rank of the root */ 
399                     j = mask;
400                     k = 0;
401                     while (j) {
402                         j >>= 1;
403                         k++;
404                     }
405                     k--;
406
407                     tmp_mask = mask >> 1;
408                     while (tmp_mask) {
409                         dst = rank ^ tmp_mask;
410
411                         tree_root = rank >> k;
412                         tree_root <<= k;
413
414                         /* send only if this proc has data and destination
415                            doesn't have data. at any step, multiple processes
416                            can send if they have the data */
417                         if ((dst > rank) && 
418                             (rank < tree_root + nprocs_completed)
419                             && (dst >= tree_root + nprocs_completed)) {
420                             /* send the current result */
421                             smpi_mpi_send(tmp_recvbuf, 1, recvtype,
422                                                      dst, MPIR_REDUCE_SCATTER_TAG,
423                                                      comm);
424                         }
425                         /* recv only if this proc. doesn't have data and sender
426                            has data */
427                         else if ((dst < rank) && 
428                                  (dst < tree_root + nprocs_completed) &&
429                                  (rank >= tree_root + nprocs_completed)) {
430                             smpi_mpi_recv(tmp_recvbuf, 1, recvtype, dst,
431                                                      MPIR_REDUCE_SCATTER_TAG,
432                                                      comm, MPI_STATUS_IGNORE); 
433                             received = 1;
434                         }
435                         tmp_mask >>= 1;
436                         k--;
437                     }
438                 }
439
440                 /* The following reduction is done here instead of after 
441                    the MPIC_Sendrecv_ft or MPIC_Recv_ft above. This is
442                    because to do it above, in the noncommutative 
443                    case, we would need an extra temp buffer so as not to
444                    overwrite temp_recvbuf, because temp_recvbuf may have
445                    to be communicated to other processes in the
446                    non-power-of-two case. To avoid that extra allocation,
447                    we do the reduce here. */
448                 if (received) {
449                     if (is_commutative || (dst_tree_root < my_tree_root)) {
450                         {
451                                  smpi_op_apply(op, 
452                                tmp_recvbuf, tmp_results, &blklens[0],
453                                &datatype); 
454                                 smpi_op_apply(op, 
455                                ((char *)tmp_recvbuf + dis[1]*extent),
456                                ((char *)tmp_results + dis[1]*extent),
457                                &blklens[1], &datatype); 
458                         }
459                     }
460                     else {
461                         {
462                                  smpi_op_apply(op,
463                                    tmp_results, tmp_recvbuf, &blklens[0],
464                                    &datatype); 
465                                  smpi_op_apply(op,
466                                    ((char *)tmp_results + dis[1]*extent),
467                                    ((char *)tmp_recvbuf + dis[1]*extent),
468                                    &blklens[1], &datatype); 
469                         }
470                         /* copy result back into tmp_results */
471                         mpi_errno = smpi_datatype_copy(tmp_recvbuf, 1, recvtype, 
472                                                    tmp_results, 1, recvtype);
473                         if (mpi_errno) return(mpi_errno);
474                     }
475                 }
476
477                 //smpi_datatype_free(&sendtype);
478                 //smpi_datatype_free(&recvtype);
479
480                 mask <<= 1;
481                 i++;
482             }
483
484             /* now copy final results from tmp_results to recvbuf */
485             mpi_errno = smpi_datatype_copy(((char *)tmp_results+disps[rank]*extent),
486                                        recvcounts[rank], datatype, recvbuf,
487                                        recvcounts[rank], datatype);
488             if (mpi_errno) return(mpi_errno);
489
490     return MPI_SUCCESS;
491         }
492
493