Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
8481ab78e63ba3e6d97601c2080568f3370ef0d2
[simgrid.git] / src / smpi / colls / reduce_scatter / reduce_scatter-mpich.cpp
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "../colls_private.h"
8
9 static inline int MPIU_Mirror_permutation(unsigned int x, int bits)
10 {
11     /* a mask for the high order bits that should be copied as-is */
12     int high_mask = ~((0x1 << bits) - 1);
13     int retval = x & high_mask;
14     int i;
15
16     for (i = 0; i < bits; ++i) {
17         unsigned int bitval = (x & (0x1 << i)) >> i; /* 0x1 or 0x0 */
18         retval |= bitval << ((bits - i) - 1);
19     }
20
21     return retval;
22 }
23 namespace simgrid{
24 namespace smpi{
25
26 int Coll_reduce_scatter_mpich_pair::reduce_scatter(void *sendbuf, void *recvbuf, int recvcounts[],
27                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
28 {
29     int   rank, comm_size, i;
30     MPI_Aint extent, true_extent, true_lb; 
31     int  *disps;
32     void *tmp_recvbuf;
33     int mpi_errno = MPI_SUCCESS;
34     int total_count, dst, src;
35     int is_commutative;
36     comm_size = comm->size();
37     rank = comm->rank();
38
39     extent =datatype->get_extent();
40     datatype->extent(&true_lb, &true_extent);
41     
42     if (op->is_commutative()) {
43         is_commutative = 1;
44     }
45
46     disps = (int*)xbt_malloc( comm_size * sizeof(int));
47
48     total_count = 0;
49     for (i=0; i<comm_size; i++) {
50         disps[i] = total_count;
51         total_count += recvcounts[i];
52     }
53     
54     if (total_count == 0) {
55         xbt_free(disps);
56         return MPI_ERR_COUNT;
57     }
58
59         if (sendbuf != MPI_IN_PLACE) {
60             /* copy local data into recvbuf */
61             Datatype::copy(((char *)sendbuf+disps[rank]*extent),
62                                        recvcounts[rank], datatype, recvbuf,
63                                        recvcounts[rank], datatype);
64         }
65         
66         /* allocate temporary buffer to store incoming data */
67         tmp_recvbuf = (void*)smpi_get_tmp_recvbuffer(recvcounts[rank]*(MAX(true_extent,extent))+1);
68         /* adjust for potential negative lower bound in datatype */
69         tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
70         
71         for (i=1; i<comm_size; i++) {
72             src = (rank - i + comm_size) % comm_size;
73             dst = (rank + i) % comm_size;
74             
75             /* send the data that dst needs. recv data that this process
76                needs from src into tmp_recvbuf */
77             if (sendbuf != MPI_IN_PLACE) 
78                 Request::sendrecv(((char *)sendbuf+disps[dst]*extent), 
79                                              recvcounts[dst], datatype, dst,
80                                              COLL_TAG_SCATTER, tmp_recvbuf,
81                                              recvcounts[rank], datatype, src,
82                                              COLL_TAG_SCATTER, comm,
83                                              MPI_STATUS_IGNORE);
84             else
85                 Request::sendrecv(((char *)recvbuf+disps[dst]*extent), 
86                                              recvcounts[dst], datatype, dst,
87                                              COLL_TAG_SCATTER, tmp_recvbuf,
88                                              recvcounts[rank], datatype, src,
89                                              COLL_TAG_SCATTER, comm,
90                                              MPI_STATUS_IGNORE);
91             
92             if (is_commutative || (src < rank)) {
93                 if (sendbuf != MPI_IN_PLACE) {
94                      if(op!=MPI_OP_NULL) op->apply(
95                                                   tmp_recvbuf, recvbuf, &recvcounts[rank],
96                                datatype); 
97                 }
98                 else {
99                     if(op!=MPI_OP_NULL) op->apply( 
100                         tmp_recvbuf, ((char *)recvbuf+disps[rank]*extent), 
101                         &recvcounts[rank], datatype);
102                     /* we can't store the result at the beginning of
103                        recvbuf right here because there is useful data
104                        there that other process/processes need. at the
105                        end, we will copy back the result to the
106                        beginning of recvbuf. */
107                 }
108             }
109             else {
110                 if (sendbuf != MPI_IN_PLACE) {
111                     if(op!=MPI_OP_NULL) op->apply( 
112                        recvbuf, tmp_recvbuf, &recvcounts[rank], datatype);
113                     /* copy result back into recvbuf */
114                     mpi_errno = Datatype::copy(tmp_recvbuf, recvcounts[rank],
115                                                datatype, recvbuf,
116                                                recvcounts[rank], datatype);
117                     if (mpi_errno) return(mpi_errno);
118                 }
119                 else {
120                     if(op!=MPI_OP_NULL) op->apply( 
121                         ((char *)recvbuf+disps[rank]*extent),
122                         tmp_recvbuf, &recvcounts[rank], datatype);
123                     /* copy result back into recvbuf */
124                     mpi_errno = Datatype::copy(tmp_recvbuf, recvcounts[rank],
125                                                datatype, 
126                                                ((char *)recvbuf +
127                                                 disps[rank]*extent), 
128                                                recvcounts[rank], datatype);
129                     if (mpi_errno) return(mpi_errno);
130                 }
131             }
132         }
133         
134         /* if MPI_IN_PLACE, move output data to the beginning of
135            recvbuf. already done for rank 0. */
136         if ((sendbuf == MPI_IN_PLACE) && (rank != 0)) {
137             mpi_errno = Datatype::copy(((char *)recvbuf +
138                                         disps[rank]*extent),  
139                                        recvcounts[rank], datatype,
140                                        recvbuf, 
141                                        recvcounts[rank], datatype );
142             if (mpi_errno) return(mpi_errno);
143         }
144     
145         xbt_free(disps);
146         smpi_free_tmp_buffer(tmp_recvbuf);
147
148         return MPI_SUCCESS;
149 }
150     
151
152 int Coll_reduce_scatter_mpich_noncomm::reduce_scatter(void *sendbuf, void *recvbuf, int recvcounts[],
153                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
154 {
155     int mpi_errno = MPI_SUCCESS;
156     int comm_size = comm->size() ;
157     int rank = comm->rank();
158     int pof2;
159     int log2_comm_size;
160     int i, k;
161     int recv_offset, send_offset;
162     int block_size, total_count, size;
163     MPI_Aint true_extent, true_lb;
164     int buf0_was_inout;
165     void *tmp_buf0;
166     void *tmp_buf1;
167     void *result_ptr;
168
169     datatype->extent(&true_lb, &true_extent);
170
171     pof2 = 1;
172     log2_comm_size = 0;
173     while (pof2 < comm_size) {
174         pof2 <<= 1;
175         ++log2_comm_size;
176     }
177
178     /* begin error checking */
179     xbt_assert(pof2 == comm_size); /* FIXME this version only works for power of 2 procs */
180
181     for (i = 0; i < (comm_size - 1); ++i) {
182         xbt_assert(recvcounts[i] == recvcounts[i+1]);
183     }
184     /* end error checking */
185
186     /* size of a block (count of datatype per block, NOT bytes per block) */
187     block_size = recvcounts[0];
188     total_count = block_size * comm_size;
189
190     tmp_buf0=( void *)smpi_get_tmp_sendbuffer( true_extent * total_count);
191     tmp_buf1=( void *)smpi_get_tmp_recvbuffer( true_extent * total_count);
192     void *tmp_buf0_save=tmp_buf0;
193     void *tmp_buf1_save=tmp_buf1;
194
195     /* adjust for potential negative lower bound in datatype */
196     tmp_buf0 = (void *)((char*)tmp_buf0 - true_lb);
197     tmp_buf1 = (void *)((char*)tmp_buf1 - true_lb);
198
199     /* Copy our send data to tmp_buf0.  We do this one block at a time and
200        permute the blocks as we go according to the mirror permutation. */
201     for (i = 0; i < comm_size; ++i) {
202         mpi_errno = Datatype::copy((char *)(sendbuf == MPI_IN_PLACE ? recvbuf : sendbuf) + (i * true_extent * block_size), block_size, datatype,
203                                    (char *)tmp_buf0 + (MPIU_Mirror_permutation(i, log2_comm_size) * true_extent * block_size), block_size, datatype);
204         if (mpi_errno) return(mpi_errno);
205     }
206     buf0_was_inout = 1;
207
208     send_offset = 0;
209     recv_offset = 0;
210     size = total_count;
211     for (k = 0; k < log2_comm_size; ++k) {
212         /* use a double-buffering scheme to avoid local copies */
213         char *incoming_data = static_cast<char*>(buf0_was_inout ? tmp_buf1 : tmp_buf0);
214         char *outgoing_data = static_cast<char*>(buf0_was_inout ? tmp_buf0 : tmp_buf1);
215         int peer = rank ^ (0x1 << k);
216         size /= 2;
217
218         if (rank > peer) {
219             /* we have the higher rank: send top half, recv bottom half */
220             recv_offset += size;
221         }
222         else {
223             /* we have the lower rank: recv top half, send bottom half */
224             send_offset += size;
225         }
226
227         Request::sendrecv(outgoing_data + send_offset*true_extent,
228                                      size, datatype, peer, COLL_TAG_SCATTER,
229                                      incoming_data + recv_offset*true_extent,
230                                      size, datatype, peer, COLL_TAG_SCATTER,
231                                      comm, MPI_STATUS_IGNORE);
232         /* always perform the reduction at recv_offset, the data at send_offset
233            is now our peer's responsibility */
234         if (rank > peer) {
235             /* higher ranked value so need to call op(received_data, my_data) */
236             if(op!=MPI_OP_NULL) op->apply( 
237                    incoming_data + recv_offset*true_extent,
238                      outgoing_data + recv_offset*true_extent,
239                      &size, datatype );
240             /* buf0_was_inout = buf0_was_inout; */
241         }
242         else {
243             /* lower ranked value so need to call op(my_data, received_data) */
244             if(op!=MPI_OP_NULL) op->apply(
245                      outgoing_data + recv_offset*true_extent,
246                      incoming_data + recv_offset*true_extent,
247                      &size, datatype);
248             buf0_was_inout = !buf0_was_inout;
249         }
250
251         /* the next round of send/recv needs to happen within the block (of size
252            "size") that we just received and reduced */
253         send_offset = recv_offset;
254     }
255
256     xbt_assert(size == recvcounts[rank]);
257
258     /* copy the reduced data to the recvbuf */
259     result_ptr = (char *)(buf0_was_inout ? tmp_buf0 : tmp_buf1) + recv_offset * true_extent;
260     mpi_errno = Datatype::copy(result_ptr, size, datatype,
261                                recvbuf, size, datatype);
262     smpi_free_tmp_buffer(tmp_buf0_save);
263     smpi_free_tmp_buffer(tmp_buf1_save);
264     if (mpi_errno) return(mpi_errno);
265     return MPI_SUCCESS;
266 }
267
268
269
270 int Coll_reduce_scatter_mpich_rdb::reduce_scatter(void *sendbuf, void *recvbuf, int recvcounts[],
271                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
272 {
273     int   rank, comm_size, i;
274     MPI_Aint extent, true_extent, true_lb; 
275     int  *disps;
276     void *tmp_recvbuf, *tmp_results;
277     int mpi_errno = MPI_SUCCESS;
278     int dis[2], blklens[2], total_count, dst;
279     int mask, dst_tree_root, my_tree_root, j, k;
280     int received;
281     MPI_Datatype sendtype, recvtype;
282     int nprocs_completed, tmp_mask, tree_root, is_commutative=0;
283     comm_size = comm->size();
284     rank = comm->rank();
285
286     extent =datatype->get_extent();
287     datatype->extent(&true_lb, &true_extent);
288     
289     if ((op==MPI_OP_NULL) || op->is_commutative()) {
290         is_commutative = 1;
291     }
292
293     disps = (int*)xbt_malloc( comm_size * sizeof(int));
294
295     total_count = 0;
296     for (i=0; i<comm_size; i++) {
297         disps[i] = total_count;
298         total_count += recvcounts[i];
299     }
300     
301             /* noncommutative and (non-pof2 or block irregular), use recursive doubling. */
302
303             /* need to allocate temporary buffer to receive incoming data*/
304             tmp_recvbuf= (void *) smpi_get_tmp_recvbuffer( total_count*(MAX(true_extent,extent)));
305             /* adjust for potential negative lower bound in datatype */
306             tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
307
308             /* need to allocate another temporary buffer to accumulate
309                results */
310             tmp_results = (void *)smpi_get_tmp_sendbuffer( total_count*(MAX(true_extent,extent)));
311             /* adjust for potential negative lower bound in datatype */
312             tmp_results = (void *)((char*)tmp_results - true_lb);
313
314             /* copy sendbuf into tmp_results */
315             if (sendbuf != MPI_IN_PLACE)
316                 mpi_errno = Datatype::copy(sendbuf, total_count, datatype,
317                                            tmp_results, total_count, datatype);
318             else
319                 mpi_errno = Datatype::copy(recvbuf, total_count, datatype,
320                                            tmp_results, total_count, datatype);
321
322             if (mpi_errno) return(mpi_errno);
323
324             mask = 0x1;
325             i = 0;
326             while (mask < comm_size) {
327                 dst = rank ^ mask;
328
329                 dst_tree_root = dst >> i;
330                 dst_tree_root <<= i;
331
332                 my_tree_root = rank >> i;
333                 my_tree_root <<= i;
334
335                 /* At step 1, processes exchange (n-n/p) amount of
336                    data; at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p)
337                    amount of data, and so forth. We use derived datatypes for this.
338
339                    At each step, a process does not need to send data
340                    indexed from my_tree_root to
341                    my_tree_root+mask-1. Similarly, a process won't receive
342                    data indexed from dst_tree_root to dst_tree_root+mask-1. */
343
344                 /* calculate sendtype */
345                 blklens[0] = blklens[1] = 0;
346                 for (j=0; j<my_tree_root; j++)
347                     blklens[0] += recvcounts[j];
348                 for (j=my_tree_root+mask; j<comm_size; j++)
349                     blklens[1] += recvcounts[j];
350
351                 dis[0] = 0;
352                 dis[1] = blklens[0];
353                 for (j=my_tree_root; (j<my_tree_root+mask) && (j<comm_size); j++)
354                     dis[1] += recvcounts[j];
355
356                 mpi_errno = Datatype::create_indexed(2, blklens, dis, datatype, &sendtype);
357                 if (mpi_errno) return(mpi_errno);
358                 
359                 sendtype->commit();
360
361                 /* calculate recvtype */
362                 blklens[0] = blklens[1] = 0;
363                 for (j=0; j<dst_tree_root && j<comm_size; j++)
364                     blklens[0] += recvcounts[j];
365                 for (j=dst_tree_root+mask; j<comm_size; j++)
366                     blklens[1] += recvcounts[j];
367
368                 dis[0] = 0;
369                 dis[1] = blklens[0];
370                 for (j=dst_tree_root; (j<dst_tree_root+mask) && (j<comm_size); j++)
371                     dis[1] += recvcounts[j];
372
373                 mpi_errno = Datatype::create_indexed(2, blklens, dis, datatype, &recvtype);
374                 if (mpi_errno) return(mpi_errno);
375                 
376                 recvtype->commit();
377
378                 received = 0;
379                 if (dst < comm_size) {
380                     /* tmp_results contains data to be sent in each step. Data is
381                        received in tmp_recvbuf and then accumulated into
382                        tmp_results. accumulation is done later below.   */ 
383
384                     Request::sendrecv(tmp_results, 1, sendtype, dst,
385                                                  COLL_TAG_SCATTER,
386                                                  tmp_recvbuf, 1, recvtype, dst,
387                                                  COLL_TAG_SCATTER, comm,
388                                                  MPI_STATUS_IGNORE);
389                     received = 1;
390                 }
391
392                 /* if some processes in this process's subtree in this step
393                    did not have any destination process to communicate with
394                    because of non-power-of-two, we need to send them the
395                    result. We use a logarithmic recursive-halfing algorithm
396                    for this. */
397
398                 if (dst_tree_root + mask > comm_size) {
399                     nprocs_completed = comm_size - my_tree_root - mask;
400                     /* nprocs_completed is the number of processes in this
401                        subtree that have all the data. Send data to others
402                        in a tree fashion. First find root of current tree
403                        that is being divided into two. k is the number of
404                        least-significant bits in this process's rank that
405                        must be zeroed out to find the rank of the root */ 
406                     j = mask;
407                     k = 0;
408                     while (j) {
409                         j >>= 1;
410                         k++;
411                     }
412                     k--;
413
414                     tmp_mask = mask >> 1;
415                     while (tmp_mask) {
416                         dst = rank ^ tmp_mask;
417
418                         tree_root = rank >> k;
419                         tree_root <<= k;
420
421                         /* send only if this proc has data and destination
422                            doesn't have data. at any step, multiple processes
423                            can send if they have the data */
424                         if ((dst > rank) && 
425                             (rank < tree_root + nprocs_completed)
426                             && (dst >= tree_root + nprocs_completed)) {
427                             /* send the current result */
428                             Request::send(tmp_recvbuf, 1, recvtype,
429                                                      dst, COLL_TAG_SCATTER,
430                                                      comm);
431                         }
432                         /* recv only if this proc. doesn't have data and sender
433                            has data */
434                         else if ((dst < rank) && 
435                                  (dst < tree_root + nprocs_completed) &&
436                                  (rank >= tree_root + nprocs_completed)) {
437                             Request::recv(tmp_recvbuf, 1, recvtype, dst,
438                                                      COLL_TAG_SCATTER,
439                                                      comm, MPI_STATUS_IGNORE); 
440                             received = 1;
441                         }
442                         tmp_mask >>= 1;
443                         k--;
444                     }
445                 }
446
447                 /* The following reduction is done here instead of after 
448                    the MPIC_Sendrecv_ft or MPIC_Recv_ft above. This is
449                    because to do it above, in the noncommutative 
450                    case, we would need an extra temp buffer so as not to
451                    overwrite temp_recvbuf, because temp_recvbuf may have
452                    to be communicated to other processes in the
453                    non-power-of-two case. To avoid that extra allocation,
454                    we do the reduce here. */
455                 if (received) {
456                     if (is_commutative || (dst_tree_root < my_tree_root)) {
457                         {
458                                  if(op!=MPI_OP_NULL) op->apply( 
459                                tmp_recvbuf, tmp_results, &blklens[0],
460                                datatype); 
461                                 if(op!=MPI_OP_NULL) op->apply( 
462                                ((char *)tmp_recvbuf + dis[1]*extent),
463                                ((char *)tmp_results + dis[1]*extent),
464                                &blklens[1], datatype); 
465                         }
466                     }
467                     else {
468                         {
469                                  if(op!=MPI_OP_NULL) op->apply(
470                                    tmp_results, tmp_recvbuf, &blklens[0],
471                                    datatype); 
472                                  if(op!=MPI_OP_NULL) op->apply(
473                                    ((char *)tmp_results + dis[1]*extent),
474                                    ((char *)tmp_recvbuf + dis[1]*extent),
475                                    &blklens[1], datatype); 
476                         }
477                         /* copy result back into tmp_results */
478                         mpi_errno = Datatype::copy(tmp_recvbuf, 1, recvtype, 
479                                                    tmp_results, 1, recvtype);
480                         if (mpi_errno) return(mpi_errno);
481                     }
482                 }
483
484                 Datatype::unref(sendtype);
485                 Datatype::unref(recvtype);
486
487                 mask <<= 1;
488                 i++;
489             }
490
491             /* now copy final results from tmp_results to recvbuf */
492             mpi_errno = Datatype::copy(((char *)tmp_results+disps[rank]*extent),
493                                        recvcounts[rank], datatype, recvbuf,
494                                        recvcounts[rank], datatype);
495             if (mpi_errno) return(mpi_errno);
496
497     xbt_free(disps);
498     smpi_free_tmp_buffer(tmp_recvbuf);
499     smpi_free_tmp_buffer(tmp_results);
500     return MPI_SUCCESS;
501         }
502 }
503 }
504