Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Update copyright lines for 2022.
[simgrid.git] / src / smpi / colls / reduce_scatter / reduce_scatter-mpich.cpp
1 /* Copyright (c) 2013-2022. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "../colls_private.hpp"
8 #include <algorithm>
9
10 static inline int MPIU_Mirror_permutation(unsigned int x, int bits)
11 {
12     /* a mask for the high order bits that should be copied as-is */
13     int high_mask = ~((0x1 << bits) - 1);
14     int retval = x & high_mask;
15     int i;
16
17     for (i = 0; i < bits; ++i) {
18         unsigned int bitval = (x & (0x1 << i)) >> i; /* 0x1 or 0x0 */
19         retval |= bitval << ((bits - i) - 1);
20     }
21
22     return retval;
23 }
24 namespace simgrid{
25 namespace smpi{
26
27 int reduce_scatter__mpich_pair(const void *sendbuf, void *recvbuf, const int recvcounts[],
28                                MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
29 {
30     int   rank, comm_size, i;
31     MPI_Aint extent, true_extent, true_lb;
32     unsigned char* tmp_recvbuf;
33     int mpi_errno = MPI_SUCCESS;
34     int total_count, dst, src;
35     comm_size = comm->size();
36     rank = comm->rank();
37
38     extent =datatype->get_extent();
39     datatype->extent(&true_lb, &true_extent);
40
41     bool is_commutative = (op == MPI_OP_NULL || op->is_commutative());
42
43     int* disps = new int[comm_size];
44
45     total_count = 0;
46     for (i=0; i<comm_size; i++) {
47         disps[i] = total_count;
48         total_count += recvcounts[i];
49     }
50
51     if (total_count == 0) {
52       delete[] disps;
53       return MPI_ERR_COUNT;
54     }
55
56         if (sendbuf != MPI_IN_PLACE) {
57             /* copy local data into recvbuf */
58             Datatype::copy(((char *)sendbuf+disps[rank]*extent),
59                                        recvcounts[rank], datatype, recvbuf,
60                                        recvcounts[rank], datatype);
61         }
62
63         /* allocate temporary buffer to store incoming data */
64         tmp_recvbuf = smpi_get_tmp_recvbuffer(recvcounts[rank] * std::max(true_extent, extent) + 1);
65         /* adjust for potential negative lower bound in datatype */
66         tmp_recvbuf = tmp_recvbuf - true_lb;
67
68         for (i=1; i<comm_size; i++) {
69             src = (rank - i + comm_size) % comm_size;
70             dst = (rank + i) % comm_size;
71
72             /* send the data that dst needs. recv data that this process
73                needs from src into tmp_recvbuf */
74             if (sendbuf != MPI_IN_PLACE)
75                 Request::sendrecv(((char *)sendbuf+disps[dst]*extent),
76                                              recvcounts[dst], datatype, dst,
77                                              COLL_TAG_REDUCE_SCATTER, tmp_recvbuf,
78                                              recvcounts[rank], datatype, src,
79                                              COLL_TAG_REDUCE_SCATTER, comm,
80                                              MPI_STATUS_IGNORE);
81             else
82                 Request::sendrecv(((char *)recvbuf+disps[dst]*extent),
83                                              recvcounts[dst], datatype, dst,
84                                              COLL_TAG_REDUCE_SCATTER, tmp_recvbuf,
85                                              recvcounts[rank], datatype, src,
86                                              COLL_TAG_REDUCE_SCATTER, comm,
87                                              MPI_STATUS_IGNORE);
88
89             if (is_commutative || (src < rank)) {
90                 if (sendbuf != MPI_IN_PLACE) {
91                   if (op != MPI_OP_NULL)
92                     op->apply(tmp_recvbuf, recvbuf, &recvcounts[rank], datatype);
93                 }
94                 else {
95                   if (op != MPI_OP_NULL)
96                     op->apply(tmp_recvbuf, ((char*)recvbuf + disps[rank] * extent), &recvcounts[rank], datatype);
97                   /* we can't store the result at the beginning of
98                      recvbuf right here because there is useful data
99                      there that other process/processes need. at the
100                      end, we will copy back the result to the
101                      beginning of recvbuf. */
102                 }
103             }
104             else {
105                 if (sendbuf != MPI_IN_PLACE) {
106                   if (op != MPI_OP_NULL)
107                     op->apply(recvbuf, tmp_recvbuf, &recvcounts[rank], datatype);
108                   /* copy result back into recvbuf */
109                   mpi_errno =
110                       Datatype::copy(tmp_recvbuf, recvcounts[rank], datatype, recvbuf, recvcounts[rank], datatype);
111                   if (mpi_errno) {
112                     delete[] disps;
113                     smpi_free_tmp_buffer(tmp_recvbuf);
114                     return (mpi_errno);
115                   }
116                 }
117                 else {
118                   if (op != MPI_OP_NULL)
119                     op->apply(((char*)recvbuf + disps[rank] * extent), tmp_recvbuf, &recvcounts[rank], datatype);
120                   /* copy result back into recvbuf */
121                   mpi_errno = Datatype::copy(tmp_recvbuf, recvcounts[rank], datatype,
122                                              ((char*)recvbuf + disps[rank] * extent), recvcounts[rank], datatype);
123                   if (mpi_errno) {
124                     delete[] disps;
125                     smpi_free_tmp_buffer(tmp_recvbuf);
126                     return (mpi_errno);
127                   }
128                 }
129             }
130         }
131
132         /* if MPI_IN_PLACE, move output data to the beginning of
133            recvbuf. already done for rank 0. */
134         if ((sendbuf == MPI_IN_PLACE) && (rank != 0)) {
135             mpi_errno = Datatype::copy(((char *)recvbuf +
136                                         disps[rank]*extent),
137                                        recvcounts[rank], datatype,
138                                        recvbuf,
139                                        recvcounts[rank], datatype );
140             if (mpi_errno) {
141               delete[] disps;
142               smpi_free_tmp_buffer(tmp_recvbuf);
143               return (mpi_errno);
144             }
145         }
146
147         delete[] disps;
148         smpi_free_tmp_buffer(tmp_recvbuf);
149
150         return MPI_SUCCESS;
151 }
152
153
154 int reduce_scatter__mpich_noncomm(const void *sendbuf, void *recvbuf, const int recvcounts[],
155                                   MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
156 {
157     int mpi_errno = MPI_SUCCESS;
158     int comm_size = comm->size() ;
159     int rank = comm->rank();
160     int pof2;
161     int log2_comm_size;
162     int i, k;
163     int recv_offset, send_offset;
164     int block_size, total_count, size;
165     MPI_Aint true_extent, true_lb;
166     int buf0_was_inout;
167     unsigned char* tmp_buf0;
168     unsigned char* tmp_buf1;
169     unsigned char* result_ptr;
170
171     datatype->extent(&true_lb, &true_extent);
172
173     pof2 = 1;
174     log2_comm_size = 0;
175     while (pof2 < comm_size) {
176         pof2 <<= 1;
177         ++log2_comm_size;
178     }
179
180     /* begin error checking */
181     xbt_assert(pof2 == comm_size); /* FIXME this version only works for power of 2 procs */
182
183     for (i = 0; i < (comm_size - 1); ++i) {
184         xbt_assert(recvcounts[i] == recvcounts[i+1]);
185     }
186     /* end error checking */
187
188     /* size of a block (count of datatype per block, NOT bytes per block) */
189     block_size = recvcounts[0];
190     total_count = block_size * comm_size;
191
192     tmp_buf0                     = smpi_get_tmp_sendbuffer(true_extent * total_count);
193     tmp_buf1                     = smpi_get_tmp_recvbuffer(true_extent * total_count);
194     unsigned char* tmp_buf0_save = tmp_buf0;
195     unsigned char* tmp_buf1_save = tmp_buf1;
196
197     /* adjust for potential negative lower bound in datatype */
198     tmp_buf0 = tmp_buf0 - true_lb;
199     tmp_buf1 = tmp_buf1 - true_lb;
200
201     /* Copy our send data to tmp_buf0.  We do this one block at a time and
202        permute the blocks as we go according to the mirror permutation. */
203     for (i = 0; i < comm_size; ++i) {
204       mpi_errno = Datatype::copy(
205           static_cast<const char*>(sendbuf == MPI_IN_PLACE ? recvbuf : sendbuf) + (i * true_extent * block_size), block_size,
206           datatype, tmp_buf0 + (MPIU_Mirror_permutation(i, log2_comm_size) * true_extent * block_size), block_size,
207           datatype);
208       if (mpi_errno)
209         return mpi_errno;
210     }
211     buf0_was_inout = 1;
212
213     send_offset = 0;
214     recv_offset = 0;
215     size = total_count;
216     for (k = 0; k < log2_comm_size; ++k) {
217         /* use a double-buffering scheme to avoid local copies */
218         unsigned char* incoming_data = buf0_was_inout ? tmp_buf1 : tmp_buf0;
219         unsigned char* outgoing_data = buf0_was_inout ? tmp_buf0 : tmp_buf1;
220         int peer = rank ^ (0x1 << k);
221         size /= 2;
222
223         if (rank > peer) {
224             /* we have the higher rank: send top half, recv bottom half */
225             recv_offset += size;
226         }
227         else {
228             /* we have the lower rank: recv top half, send bottom half */
229             send_offset += size;
230         }
231
232         Request::sendrecv(outgoing_data + send_offset*true_extent,
233                                      size, datatype, peer, COLL_TAG_REDUCE_SCATTER,
234                                      incoming_data + recv_offset*true_extent,
235                                      size, datatype, peer, COLL_TAG_REDUCE_SCATTER,
236                                      comm, MPI_STATUS_IGNORE);
237         /* always perform the reduction at recv_offset, the data at send_offset
238            is now our peer's responsibility */
239         if (rank > peer) {
240             /* higher ranked value so need to call op(received_data, my_data) */
241             if(op!=MPI_OP_NULL) op->apply(
242                    incoming_data + recv_offset*true_extent,
243                      outgoing_data + recv_offset*true_extent,
244                      &size, datatype );
245             /* buf0_was_inout = buf0_was_inout; */
246         }
247         else {
248             /* lower ranked value so need to call op(my_data, received_data) */
249             if (op != MPI_OP_NULL)
250               op->apply(outgoing_data + recv_offset * true_extent, incoming_data + recv_offset * true_extent, &size,
251                         datatype);
252             buf0_was_inout = not buf0_was_inout;
253         }
254
255         /* the next round of send/recv needs to happen within the block (of size
256            "size") that we just received and reduced */
257         send_offset = recv_offset;
258     }
259
260     xbt_assert(size == recvcounts[rank]);
261
262     /* copy the reduced data to the recvbuf */
263     result_ptr = (buf0_was_inout ? tmp_buf0 : tmp_buf1) + recv_offset * true_extent;
264     mpi_errno = Datatype::copy(result_ptr, size, datatype,
265                                recvbuf, size, datatype);
266     smpi_free_tmp_buffer(tmp_buf0_save);
267     smpi_free_tmp_buffer(tmp_buf1_save);
268     if (mpi_errno) return(mpi_errno);
269     return MPI_SUCCESS;
270 }
271
272
273
274 int reduce_scatter__mpich_rdb(const void *sendbuf, void *recvbuf, const int recvcounts[],
275                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
276 {
277     int   rank, comm_size, i;
278     MPI_Aint extent, true_extent, true_lb;
279     int mpi_errno = MPI_SUCCESS;
280     int dis[2], blklens[2], total_count, dst;
281     int mask, dst_tree_root, my_tree_root, j, k;
282     int received;
283     MPI_Datatype sendtype, recvtype;
284     int nprocs_completed, tmp_mask, tree_root;
285     comm_size = comm->size();
286     rank = comm->rank();
287
288     extent =datatype->get_extent();
289     datatype->extent(&true_lb, &true_extent);
290
291     bool is_commutative = (op == MPI_OP_NULL || op->is_commutative());
292
293     int* disps = new int[comm_size];
294
295     total_count = 0;
296     for (i=0; i<comm_size; i++) {
297         disps[i] = total_count;
298         total_count += recvcounts[i];
299     }
300
301             /* noncommutative and (non-pof2 or block irregular), use recursive doubling. */
302
303             /* need to allocate temporary buffer to receive incoming data*/
304     unsigned char* tmp_recvbuf = smpi_get_tmp_recvbuffer(total_count * std::max(true_extent, extent));
305     /* adjust for potential negative lower bound in datatype */
306     tmp_recvbuf = tmp_recvbuf - true_lb;
307
308     /* need to allocate another temporary buffer to accumulate
309        results */
310     unsigned char* tmp_results = smpi_get_tmp_sendbuffer(total_count * std::max(true_extent, extent));
311     /* adjust for potential negative lower bound in datatype */
312     tmp_results = tmp_results - true_lb;
313
314     /* copy sendbuf into tmp_results */
315     if (sendbuf != MPI_IN_PLACE)
316       mpi_errno = Datatype::copy(sendbuf, total_count, datatype, tmp_results, total_count, datatype);
317     else
318       mpi_errno = Datatype::copy(recvbuf, total_count, datatype, tmp_results, total_count, datatype);
319
320     if (mpi_errno) {
321       delete[] disps;
322       return (mpi_errno);
323     }
324
325     mask = 0x1;
326     i    = 0;
327     while (mask < comm_size) {
328       dst = rank ^ mask;
329
330       dst_tree_root = dst >> i;
331       dst_tree_root <<= i;
332
333       my_tree_root = rank >> i;
334       my_tree_root <<= i;
335
336       /* At step 1, processes exchange (n-n/p) amount of
337          data; at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p)
338          amount of data, and so forth. We use derived datatypes for this.
339
340          At each step, a process does not need to send data
341          indexed from my_tree_root to
342          my_tree_root+mask-1. Similarly, a process won't receive
343          data indexed from dst_tree_root to dst_tree_root+mask-1. */
344
345       /* calculate sendtype */
346       blklens[0] = blklens[1] = 0;
347       for (j = 0; j < my_tree_root; j++)
348         blklens[0] += recvcounts[j];
349       for (j = my_tree_root + mask; j < comm_size; j++)
350         blklens[1] += recvcounts[j];
351
352       dis[0] = 0;
353       dis[1] = blklens[0];
354       for (j = my_tree_root; (j < my_tree_root + mask) && (j < comm_size); j++)
355         dis[1] += recvcounts[j];
356
357       mpi_errno = Datatype::create_indexed(2, blklens, dis, datatype, &sendtype);
358       if (mpi_errno)
359         return (mpi_errno);
360
361       sendtype->commit();
362
363       /* calculate recvtype */
364       blklens[0] = blklens[1] = 0;
365       for (j = 0; j < dst_tree_root && j < comm_size; j++)
366         blklens[0] += recvcounts[j];
367       for (j = dst_tree_root + mask; j < comm_size; j++)
368         blklens[1] += recvcounts[j];
369
370       dis[0] = 0;
371       dis[1] = blklens[0];
372       for (j = dst_tree_root; (j < dst_tree_root + mask) && (j < comm_size); j++)
373         dis[1] += recvcounts[j];
374
375       mpi_errno = Datatype::create_indexed(2, blklens, dis, datatype, &recvtype);
376       if (mpi_errno)
377         return (mpi_errno);
378
379       recvtype->commit();
380
381       received = 0;
382       if (dst < comm_size) {
383         /* tmp_results contains data to be sent in each step. Data is
384            received in tmp_recvbuf and then accumulated into
385            tmp_results. accumulation is done later below.   */
386
387         Request::sendrecv(tmp_results, 1, sendtype, dst, COLL_TAG_REDUCE_SCATTER, tmp_recvbuf, 1, recvtype, dst,
388                           COLL_TAG_REDUCE_SCATTER, comm, MPI_STATUS_IGNORE);
389         received = 1;
390       }
391
392       /* if some processes in this process's subtree in this step
393          did not have any destination process to communicate with
394          because of non-power-of-two, we need to send them the
395          result. We use a logarithmic recursive-halfing algorithm
396          for this. */
397
398       if (dst_tree_root + mask > comm_size) {
399         nprocs_completed = comm_size - my_tree_root - mask;
400         /* nprocs_completed is the number of processes in this
401            subtree that have all the data. Send data to others
402            in a tree fashion. First find root of current tree
403            that is being divided into two. k is the number of
404            least-significant bits in this process's rank that
405            must be zeroed out to find the rank of the root */
406         j = mask;
407         k = 0;
408         while (j) {
409           j >>= 1;
410           k++;
411         }
412         k--;
413
414         tmp_mask = mask >> 1;
415         while (tmp_mask) {
416           dst = rank ^ tmp_mask;
417
418           tree_root = rank >> k;
419           tree_root <<= k;
420
421           /* send only if this proc has data and destination
422              doesn't have data. at any step, multiple processes
423              can send if they have the data */
424           if ((dst > rank) && (rank < tree_root + nprocs_completed) && (dst >= tree_root + nprocs_completed)) {
425             /* send the current result */
426             Request::send(tmp_recvbuf, 1, recvtype, dst, COLL_TAG_REDUCE_SCATTER, comm);
427           }
428           /* recv only if this proc. doesn't have data and sender
429              has data */
430           else if ((dst < rank) && (dst < tree_root + nprocs_completed) && (rank >= tree_root + nprocs_completed)) {
431             Request::recv(tmp_recvbuf, 1, recvtype, dst, COLL_TAG_REDUCE_SCATTER, comm, MPI_STATUS_IGNORE);
432             received = 1;
433           }
434           tmp_mask >>= 1;
435           k--;
436         }
437       }
438
439       /* The following reduction is done here instead of after
440          the MPIC_Sendrecv_ft or MPIC_Recv_ft above. This is
441          because to do it above, in the noncommutative
442          case, we would need an extra temp buffer so as not to
443          overwrite temp_recvbuf, because temp_recvbuf may have
444          to be communicated to other processes in the
445          non-power-of-two case. To avoid that extra allocation,
446          we do the reduce here. */
447       if (received) {
448         if (is_commutative || (dst_tree_root < my_tree_root)) {
449           {
450             if (op != MPI_OP_NULL)
451               op->apply(tmp_recvbuf, tmp_results, &blklens[0], datatype);
452             if (op != MPI_OP_NULL)
453               op->apply(tmp_recvbuf + dis[1] * extent, tmp_results + dis[1] * extent, &blklens[1], datatype);
454           }
455         } else {
456           {
457             if (op != MPI_OP_NULL)
458               op->apply(tmp_results, tmp_recvbuf, &blklens[0], datatype);
459             if (op != MPI_OP_NULL)
460               op->apply(tmp_results + dis[1] * extent, tmp_recvbuf + dis[1] * extent, &blklens[1], datatype);
461           }
462           /* copy result back into tmp_results */
463           mpi_errno = Datatype::copy(tmp_recvbuf, 1, recvtype, tmp_results, 1, recvtype);
464           if (mpi_errno)
465             return (mpi_errno);
466         }
467       }
468
469       Datatype::unref(sendtype);
470       Datatype::unref(recvtype);
471
472       mask <<= 1;
473       i++;
474             }
475
476             /* now copy final results from tmp_results to recvbuf */
477             mpi_errno = Datatype::copy(tmp_results + disps[rank] * extent, recvcounts[rank], datatype, recvbuf,
478                                        recvcounts[rank], datatype);
479             if (mpi_errno) return(mpi_errno);
480
481     delete[] disps;
482     smpi_free_tmp_buffer(tmp_recvbuf);
483     smpi_free_tmp_buffer(tmp_results);
484     return MPI_SUCCESS;
485         }
486 }
487 }
488