Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Use std::min/std::max instead of MIN/MAX in C++ files.
[simgrid.git] / src / smpi / colls / reduce_scatter / reduce_scatter-mpich.cpp
1 /* Copyright (c) 2013-2017. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "../colls_private.hpp"
8 #include <algorithm>
9
10 static inline int MPIU_Mirror_permutation(unsigned int x, int bits)
11 {
12     /* a mask for the high order bits that should be copied as-is */
13     int high_mask = ~((0x1 << bits) - 1);
14     int retval = x & high_mask;
15     int i;
16
17     for (i = 0; i < bits; ++i) {
18         unsigned int bitval = (x & (0x1 << i)) >> i; /* 0x1 or 0x0 */
19         retval |= bitval << ((bits - i) - 1);
20     }
21
22     return retval;
23 }
24 namespace simgrid{
25 namespace smpi{
26
27 int Coll_reduce_scatter_mpich_pair::reduce_scatter(void *sendbuf, void *recvbuf, int recvcounts[],
28                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
29 {
30     int   rank, comm_size, i;
31     MPI_Aint extent, true_extent, true_lb;
32     int  *disps;
33     void *tmp_recvbuf;
34     int mpi_errno = MPI_SUCCESS;
35     int total_count, dst, src;
36     int is_commutative;
37     comm_size = comm->size();
38     rank = comm->rank();
39
40     extent =datatype->get_extent();
41     datatype->extent(&true_lb, &true_extent);
42
43     if (op->is_commutative()) {
44         is_commutative = 1;
45     }
46
47     disps = (int*)xbt_malloc( comm_size * sizeof(int));
48
49     total_count = 0;
50     for (i=0; i<comm_size; i++) {
51         disps[i] = total_count;
52         total_count += recvcounts[i];
53     }
54
55     if (total_count == 0) {
56         xbt_free(disps);
57         return MPI_ERR_COUNT;
58     }
59
60         if (sendbuf != MPI_IN_PLACE) {
61             /* copy local data into recvbuf */
62             Datatype::copy(((char *)sendbuf+disps[rank]*extent),
63                                        recvcounts[rank], datatype, recvbuf,
64                                        recvcounts[rank], datatype);
65         }
66
67         /* allocate temporary buffer to store incoming data */
68         tmp_recvbuf = (void*)smpi_get_tmp_recvbuffer(recvcounts[rank] * std::max(true_extent, extent) + 1);
69         /* adjust for potential negative lower bound in datatype */
70         tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
71
72         for (i=1; i<comm_size; i++) {
73             src = (rank - i + comm_size) % comm_size;
74             dst = (rank + i) % comm_size;
75
76             /* send the data that dst needs. recv data that this process
77                needs from src into tmp_recvbuf */
78             if (sendbuf != MPI_IN_PLACE)
79                 Request::sendrecv(((char *)sendbuf+disps[dst]*extent),
80                                              recvcounts[dst], datatype, dst,
81                                              COLL_TAG_SCATTER, tmp_recvbuf,
82                                              recvcounts[rank], datatype, src,
83                                              COLL_TAG_SCATTER, comm,
84                                              MPI_STATUS_IGNORE);
85             else
86                 Request::sendrecv(((char *)recvbuf+disps[dst]*extent),
87                                              recvcounts[dst], datatype, dst,
88                                              COLL_TAG_SCATTER, tmp_recvbuf,
89                                              recvcounts[rank], datatype, src,
90                                              COLL_TAG_SCATTER, comm,
91                                              MPI_STATUS_IGNORE);
92
93             if (is_commutative || (src < rank)) {
94                 if (sendbuf != MPI_IN_PLACE) {
95                   if (op != MPI_OP_NULL)
96                     op->apply(tmp_recvbuf, recvbuf, &recvcounts[rank], datatype);
97                 }
98                 else {
99                   if (op != MPI_OP_NULL)
100                     op->apply(tmp_recvbuf, ((char*)recvbuf + disps[rank] * extent), &recvcounts[rank], datatype);
101                   /* we can't store the result at the beginning of
102                      recvbuf right here because there is useful data
103                      there that other process/processes need. at the
104                      end, we will copy back the result to the
105                      beginning of recvbuf. */
106                 }
107             }
108             else {
109                 if (sendbuf != MPI_IN_PLACE) {
110                   if (op != MPI_OP_NULL)
111                     op->apply(recvbuf, tmp_recvbuf, &recvcounts[rank], datatype);
112                   /* copy result back into recvbuf */
113                   mpi_errno =
114                       Datatype::copy(tmp_recvbuf, recvcounts[rank], datatype, recvbuf, recvcounts[rank], datatype);
115                   if (mpi_errno)
116                     return (mpi_errno);
117                 }
118                 else {
119                   if (op != MPI_OP_NULL)
120                     op->apply(((char*)recvbuf + disps[rank] * extent), tmp_recvbuf, &recvcounts[rank], datatype);
121                   /* copy result back into recvbuf */
122                   mpi_errno = Datatype::copy(tmp_recvbuf, recvcounts[rank], datatype,
123                                              ((char*)recvbuf + disps[rank] * extent), recvcounts[rank], datatype);
124                   if (mpi_errno)
125                     return (mpi_errno);
126                 }
127             }
128         }
129
130         /* if MPI_IN_PLACE, move output data to the beginning of
131            recvbuf. already done for rank 0. */
132         if ((sendbuf == MPI_IN_PLACE) && (rank != 0)) {
133             mpi_errno = Datatype::copy(((char *)recvbuf +
134                                         disps[rank]*extent),
135                                        recvcounts[rank], datatype,
136                                        recvbuf,
137                                        recvcounts[rank], datatype );
138             if (mpi_errno) return(mpi_errno);
139         }
140
141         xbt_free(disps);
142         smpi_free_tmp_buffer(tmp_recvbuf);
143
144         return MPI_SUCCESS;
145 }
146
147
148 int Coll_reduce_scatter_mpich_noncomm::reduce_scatter(void *sendbuf, void *recvbuf, int recvcounts[],
149                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
150 {
151     int mpi_errno = MPI_SUCCESS;
152     int comm_size = comm->size() ;
153     int rank = comm->rank();
154     int pof2;
155     int log2_comm_size;
156     int i, k;
157     int recv_offset, send_offset;
158     int block_size, total_count, size;
159     MPI_Aint true_extent, true_lb;
160     int buf0_was_inout;
161     void *tmp_buf0;
162     void *tmp_buf1;
163     void *result_ptr;
164
165     datatype->extent(&true_lb, &true_extent);
166
167     pof2 = 1;
168     log2_comm_size = 0;
169     while (pof2 < comm_size) {
170         pof2 <<= 1;
171         ++log2_comm_size;
172     }
173
174     /* begin error checking */
175     xbt_assert(pof2 == comm_size); /* FIXME this version only works for power of 2 procs */
176
177     for (i = 0; i < (comm_size - 1); ++i) {
178         xbt_assert(recvcounts[i] == recvcounts[i+1]);
179     }
180     /* end error checking */
181
182     /* size of a block (count of datatype per block, NOT bytes per block) */
183     block_size = recvcounts[0];
184     total_count = block_size * comm_size;
185
186     tmp_buf0=( void *)smpi_get_tmp_sendbuffer( true_extent * total_count);
187     tmp_buf1=( void *)smpi_get_tmp_recvbuffer( true_extent * total_count);
188     void *tmp_buf0_save=tmp_buf0;
189     void *tmp_buf1_save=tmp_buf1;
190
191     /* adjust for potential negative lower bound in datatype */
192     tmp_buf0 = (void *)((char*)tmp_buf0 - true_lb);
193     tmp_buf1 = (void *)((char*)tmp_buf1 - true_lb);
194
195     /* Copy our send data to tmp_buf0.  We do this one block at a time and
196        permute the blocks as we go according to the mirror permutation. */
197     for (i = 0; i < comm_size; ++i) {
198         mpi_errno = Datatype::copy((char *)(sendbuf == MPI_IN_PLACE ? recvbuf : sendbuf) + (i * true_extent * block_size), block_size, datatype,
199                                    (char *)tmp_buf0 + (MPIU_Mirror_permutation(i, log2_comm_size) * true_extent * block_size), block_size, datatype);
200         if (mpi_errno) return(mpi_errno);
201     }
202     buf0_was_inout = 1;
203
204     send_offset = 0;
205     recv_offset = 0;
206     size = total_count;
207     for (k = 0; k < log2_comm_size; ++k) {
208         /* use a double-buffering scheme to avoid local copies */
209         char *incoming_data = static_cast<char*>(buf0_was_inout ? tmp_buf1 : tmp_buf0);
210         char *outgoing_data = static_cast<char*>(buf0_was_inout ? tmp_buf0 : tmp_buf1);
211         int peer = rank ^ (0x1 << k);
212         size /= 2;
213
214         if (rank > peer) {
215             /* we have the higher rank: send top half, recv bottom half */
216             recv_offset += size;
217         }
218         else {
219             /* we have the lower rank: recv top half, send bottom half */
220             send_offset += size;
221         }
222
223         Request::sendrecv(outgoing_data + send_offset*true_extent,
224                                      size, datatype, peer, COLL_TAG_SCATTER,
225                                      incoming_data + recv_offset*true_extent,
226                                      size, datatype, peer, COLL_TAG_SCATTER,
227                                      comm, MPI_STATUS_IGNORE);
228         /* always perform the reduction at recv_offset, the data at send_offset
229            is now our peer's responsibility */
230         if (rank > peer) {
231             /* higher ranked value so need to call op(received_data, my_data) */
232             if(op!=MPI_OP_NULL) op->apply(
233                    incoming_data + recv_offset*true_extent,
234                      outgoing_data + recv_offset*true_extent,
235                      &size, datatype );
236             /* buf0_was_inout = buf0_was_inout; */
237         }
238         else {
239             /* lower ranked value so need to call op(my_data, received_data) */
240             if (op != MPI_OP_NULL)
241               op->apply(outgoing_data + recv_offset * true_extent, incoming_data + recv_offset * true_extent, &size,
242                         datatype);
243             buf0_was_inout = not buf0_was_inout;
244         }
245
246         /* the next round of send/recv needs to happen within the block (of size
247            "size") that we just received and reduced */
248         send_offset = recv_offset;
249     }
250
251     xbt_assert(size == recvcounts[rank]);
252
253     /* copy the reduced data to the recvbuf */
254     result_ptr = (char *)(buf0_was_inout ? tmp_buf0 : tmp_buf1) + recv_offset * true_extent;
255     mpi_errno = Datatype::copy(result_ptr, size, datatype,
256                                recvbuf, size, datatype);
257     smpi_free_tmp_buffer(tmp_buf0_save);
258     smpi_free_tmp_buffer(tmp_buf1_save);
259     if (mpi_errno) return(mpi_errno);
260     return MPI_SUCCESS;
261 }
262
263
264
265 int Coll_reduce_scatter_mpich_rdb::reduce_scatter(void *sendbuf, void *recvbuf, int recvcounts[],
266                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
267 {
268     int   rank, comm_size, i;
269     MPI_Aint extent, true_extent, true_lb;
270     int  *disps;
271     void *tmp_recvbuf, *tmp_results;
272     int mpi_errno = MPI_SUCCESS;
273     int dis[2], blklens[2], total_count, dst;
274     int mask, dst_tree_root, my_tree_root, j, k;
275     int received;
276     MPI_Datatype sendtype, recvtype;
277     int nprocs_completed, tmp_mask, tree_root, is_commutative=0;
278     comm_size = comm->size();
279     rank = comm->rank();
280
281     extent =datatype->get_extent();
282     datatype->extent(&true_lb, &true_extent);
283
284     if ((op==MPI_OP_NULL) || op->is_commutative()) {
285         is_commutative = 1;
286     }
287
288     disps = (int*)xbt_malloc( comm_size * sizeof(int));
289
290     total_count = 0;
291     for (i=0; i<comm_size; i++) {
292         disps[i] = total_count;
293         total_count += recvcounts[i];
294     }
295
296             /* noncommutative and (non-pof2 or block irregular), use recursive doubling. */
297
298             /* need to allocate temporary buffer to receive incoming data*/
299             tmp_recvbuf= (void*)smpi_get_tmp_recvbuffer(total_count * std::max(true_extent, extent));
300             /* adjust for potential negative lower bound in datatype */
301             tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
302
303             /* need to allocate another temporary buffer to accumulate
304                results */
305             tmp_results = (void*)smpi_get_tmp_sendbuffer(total_count * std::max(true_extent, extent));
306             /* adjust for potential negative lower bound in datatype */
307             tmp_results = (void *)((char*)tmp_results - true_lb);
308
309             /* copy sendbuf into tmp_results */
310             if (sendbuf != MPI_IN_PLACE)
311                 mpi_errno = Datatype::copy(sendbuf, total_count, datatype,
312                                            tmp_results, total_count, datatype);
313             else
314                 mpi_errno = Datatype::copy(recvbuf, total_count, datatype,
315                                            tmp_results, total_count, datatype);
316
317             if (mpi_errno) return(mpi_errno);
318
319             mask = 0x1;
320             i = 0;
321             while (mask < comm_size) {
322                 dst = rank ^ mask;
323
324                 dst_tree_root = dst >> i;
325                 dst_tree_root <<= i;
326
327                 my_tree_root = rank >> i;
328                 my_tree_root <<= i;
329
330                 /* At step 1, processes exchange (n-n/p) amount of
331                    data; at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p)
332                    amount of data, and so forth. We use derived datatypes for this.
333
334                    At each step, a process does not need to send data
335                    indexed from my_tree_root to
336                    my_tree_root+mask-1. Similarly, a process won't receive
337                    data indexed from dst_tree_root to dst_tree_root+mask-1. */
338
339                 /* calculate sendtype */
340                 blklens[0] = blklens[1] = 0;
341                 for (j=0; j<my_tree_root; j++)
342                     blklens[0] += recvcounts[j];
343                 for (j=my_tree_root+mask; j<comm_size; j++)
344                     blklens[1] += recvcounts[j];
345
346                 dis[0] = 0;
347                 dis[1] = blklens[0];
348                 for (j=my_tree_root; (j<my_tree_root+mask) && (j<comm_size); j++)
349                     dis[1] += recvcounts[j];
350
351                 mpi_errno = Datatype::create_indexed(2, blklens, dis, datatype, &sendtype);
352                 if (mpi_errno) return(mpi_errno);
353
354                 sendtype->commit();
355
356                 /* calculate recvtype */
357                 blklens[0] = blklens[1] = 0;
358                 for (j=0; j<dst_tree_root && j<comm_size; j++)
359                     blklens[0] += recvcounts[j];
360                 for (j=dst_tree_root+mask; j<comm_size; j++)
361                     blklens[1] += recvcounts[j];
362
363                 dis[0] = 0;
364                 dis[1] = blklens[0];
365                 for (j=dst_tree_root; (j<dst_tree_root+mask) && (j<comm_size); j++)
366                     dis[1] += recvcounts[j];
367
368                 mpi_errno = Datatype::create_indexed(2, blklens, dis, datatype, &recvtype);
369                 if (mpi_errno) return(mpi_errno);
370
371                 recvtype->commit();
372
373                 received = 0;
374                 if (dst < comm_size) {
375                     /* tmp_results contains data to be sent in each step. Data is
376                        received in tmp_recvbuf and then accumulated into
377                        tmp_results. accumulation is done later below.   */
378
379                     Request::sendrecv(tmp_results, 1, sendtype, dst,
380                                                  COLL_TAG_SCATTER,
381                                                  tmp_recvbuf, 1, recvtype, dst,
382                                                  COLL_TAG_SCATTER, comm,
383                                                  MPI_STATUS_IGNORE);
384                     received = 1;
385                 }
386
387                 /* if some processes in this process's subtree in this step
388                    did not have any destination process to communicate with
389                    because of non-power-of-two, we need to send them the
390                    result. We use a logarithmic recursive-halfing algorithm
391                    for this. */
392
393                 if (dst_tree_root + mask > comm_size) {
394                     nprocs_completed = comm_size - my_tree_root - mask;
395                     /* nprocs_completed is the number of processes in this
396                        subtree that have all the data. Send data to others
397                        in a tree fashion. First find root of current tree
398                        that is being divided into two. k is the number of
399                        least-significant bits in this process's rank that
400                        must be zeroed out to find the rank of the root */
401                     j = mask;
402                     k = 0;
403                     while (j) {
404                         j >>= 1;
405                         k++;
406                     }
407                     k--;
408
409                     tmp_mask = mask >> 1;
410                     while (tmp_mask) {
411                         dst = rank ^ tmp_mask;
412
413                         tree_root = rank >> k;
414                         tree_root <<= k;
415
416                         /* send only if this proc has data and destination
417                            doesn't have data. at any step, multiple processes
418                            can send if they have the data */
419                         if ((dst > rank) &&
420                             (rank < tree_root + nprocs_completed)
421                             && (dst >= tree_root + nprocs_completed)) {
422                             /* send the current result */
423                             Request::send(tmp_recvbuf, 1, recvtype,
424                                                      dst, COLL_TAG_SCATTER,
425                                                      comm);
426                         }
427                         /* recv only if this proc. doesn't have data and sender
428                            has data */
429                         else if ((dst < rank) &&
430                                  (dst < tree_root + nprocs_completed) &&
431                                  (rank >= tree_root + nprocs_completed)) {
432                             Request::recv(tmp_recvbuf, 1, recvtype, dst,
433                                                      COLL_TAG_SCATTER,
434                                                      comm, MPI_STATUS_IGNORE);
435                             received = 1;
436                         }
437                         tmp_mask >>= 1;
438                         k--;
439                     }
440                 }
441
442                 /* The following reduction is done here instead of after
443                    the MPIC_Sendrecv_ft or MPIC_Recv_ft above. This is
444                    because to do it above, in the noncommutative
445                    case, we would need an extra temp buffer so as not to
446                    overwrite temp_recvbuf, because temp_recvbuf may have
447                    to be communicated to other processes in the
448                    non-power-of-two case. To avoid that extra allocation,
449                    we do the reduce here. */
450                 if (received) {
451                     if (is_commutative || (dst_tree_root < my_tree_root)) {
452                         {
453                           if (op != MPI_OP_NULL)
454                             op->apply(tmp_recvbuf, tmp_results, &blklens[0], datatype);
455                           if (op != MPI_OP_NULL)
456                             op->apply(((char*)tmp_recvbuf + dis[1] * extent), ((char*)tmp_results + dis[1] * extent),
457                                       &blklens[1], datatype);
458                         }
459                     }
460                     else {
461                         {
462                           if (op != MPI_OP_NULL)
463                             op->apply(tmp_results, tmp_recvbuf, &blklens[0], datatype);
464                           if (op != MPI_OP_NULL)
465                             op->apply(((char*)tmp_results + dis[1] * extent), ((char*)tmp_recvbuf + dis[1] * extent),
466                                       &blklens[1], datatype);
467                         }
468                         /* copy result back into tmp_results */
469                         mpi_errno = Datatype::copy(tmp_recvbuf, 1, recvtype,
470                                                    tmp_results, 1, recvtype);
471                         if (mpi_errno) return(mpi_errno);
472                     }
473                 }
474
475                 Datatype::unref(sendtype);
476                 Datatype::unref(recvtype);
477
478                 mask <<= 1;
479                 i++;
480             }
481
482             /* now copy final results from tmp_results to recvbuf */
483             mpi_errno = Datatype::copy(((char *)tmp_results+disps[rank]*extent),
484                                        recvcounts[rank], datatype, recvbuf,
485                                        recvcounts[rank], datatype);
486             if (mpi_errno) return(mpi_errno);
487
488     xbt_free(disps);
489     smpi_free_tmp_buffer(tmp_recvbuf);
490     smpi_free_tmp_buffer(tmp_results);
491     return MPI_SUCCESS;
492         }
493 }
494 }
495