Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add new entry in Release_Notes.
[simgrid.git] / src / smpi / colls / reduce_scatter / reduce_scatter-mpich.cpp
1 /* Copyright (c) 2013-2023. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 #include "../colls_private.hpp"
8 #include <algorithm>
9
10 static inline int MPIU_Mirror_permutation(unsigned int x, int bits)
11 {
12     /* a mask for the high order bits that should be copied as-is */
13     int high_mask = ~((0x1 << bits) - 1);
14     int retval = x & high_mask;
15     int i;
16
17     for (i = 0; i < bits; ++i) {
18         unsigned int bitval = (x & (0x1 << i)) >> i; /* 0x1 or 0x0 */
19         retval |= bitval << ((bits - i) - 1);
20     }
21
22     return retval;
23 }
24 namespace simgrid::smpi {
25
26 int reduce_scatter__mpich_pair(const void *sendbuf, void *recvbuf, const int recvcounts[],
27                                MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
28 {
29     int   rank, comm_size, i;
30     MPI_Aint extent, true_extent, true_lb;
31     unsigned char* tmp_recvbuf;
32     int mpi_errno = MPI_SUCCESS;
33     int total_count, dst, src;
34     comm_size = comm->size();
35     rank = comm->rank();
36
37     extent =datatype->get_extent();
38     datatype->extent(&true_lb, &true_extent);
39
40     bool is_commutative = (op == MPI_OP_NULL || op->is_commutative());
41
42     int* disps = new int[comm_size];
43
44     total_count = 0;
45     for (i=0; i<comm_size; i++) {
46         disps[i] = total_count;
47         total_count += recvcounts[i];
48     }
49
50     if (total_count == 0) {
51       delete[] disps;
52       return MPI_ERR_COUNT;
53     }
54
55         if (sendbuf != MPI_IN_PLACE) {
56             /* copy local data into recvbuf */
57             Datatype::copy(((char *)sendbuf+disps[rank]*extent),
58                                        recvcounts[rank], datatype, recvbuf,
59                                        recvcounts[rank], datatype);
60         }
61
62         /* allocate temporary buffer to store incoming data */
63         tmp_recvbuf = smpi_get_tmp_recvbuffer(recvcounts[rank] * std::max(true_extent, extent) + 1);
64         /* adjust for potential negative lower bound in datatype */
65         tmp_recvbuf = tmp_recvbuf - true_lb;
66
67         for (i=1; i<comm_size; i++) {
68             src = (rank - i + comm_size) % comm_size;
69             dst = (rank + i) % comm_size;
70
71             /* send the data that dst needs. recv data that this process
72                needs from src into tmp_recvbuf */
73             if (sendbuf != MPI_IN_PLACE)
74                 Request::sendrecv(((char *)sendbuf+disps[dst]*extent),
75                                              recvcounts[dst], datatype, dst,
76                                              COLL_TAG_REDUCE_SCATTER, tmp_recvbuf,
77                                              recvcounts[rank], datatype, src,
78                                              COLL_TAG_REDUCE_SCATTER, comm,
79                                              MPI_STATUS_IGNORE);
80             else
81                 Request::sendrecv(((char *)recvbuf+disps[dst]*extent),
82                                              recvcounts[dst], datatype, dst,
83                                              COLL_TAG_REDUCE_SCATTER, tmp_recvbuf,
84                                              recvcounts[rank], datatype, src,
85                                              COLL_TAG_REDUCE_SCATTER, comm,
86                                              MPI_STATUS_IGNORE);
87
88             if (is_commutative || (src < rank)) {
89                 if (sendbuf != MPI_IN_PLACE) {
90                   if (op != MPI_OP_NULL)
91                     op->apply(tmp_recvbuf, recvbuf, &recvcounts[rank], datatype);
92                 }
93                 else {
94                   if (op != MPI_OP_NULL)
95                     op->apply(tmp_recvbuf, ((char*)recvbuf + disps[rank] * extent), &recvcounts[rank], datatype);
96                   /* we can't store the result at the beginning of
97                      recvbuf right here because there is useful data
98                      there that other process/processes need. at the
99                      end, we will copy back the result to the
100                      beginning of recvbuf. */
101                 }
102             }
103             else {
104                 if (sendbuf != MPI_IN_PLACE) {
105                   if (op != MPI_OP_NULL)
106                     op->apply(recvbuf, tmp_recvbuf, &recvcounts[rank], datatype);
107                   /* copy result back into recvbuf */
108                   mpi_errno =
109                       Datatype::copy(tmp_recvbuf, recvcounts[rank], datatype, recvbuf, recvcounts[rank], datatype);
110                   if (mpi_errno) {
111                     delete[] disps;
112                     smpi_free_tmp_buffer(tmp_recvbuf);
113                     return (mpi_errno);
114                   }
115                 }
116                 else {
117                   if (op != MPI_OP_NULL)
118                     op->apply(((char*)recvbuf + disps[rank] * extent), tmp_recvbuf, &recvcounts[rank], datatype);
119                   /* copy result back into recvbuf */
120                   mpi_errno = Datatype::copy(tmp_recvbuf, recvcounts[rank], datatype,
121                                              ((char*)recvbuf + disps[rank] * extent), recvcounts[rank], datatype);
122                   if (mpi_errno) {
123                     delete[] disps;
124                     smpi_free_tmp_buffer(tmp_recvbuf);
125                     return (mpi_errno);
126                   }
127                 }
128             }
129         }
130
131         /* if MPI_IN_PLACE, move output data to the beginning of
132            recvbuf. already done for rank 0. */
133         if ((sendbuf == MPI_IN_PLACE) && (rank != 0)) {
134             mpi_errno = Datatype::copy(((char *)recvbuf +
135                                         disps[rank]*extent),
136                                        recvcounts[rank], datatype,
137                                        recvbuf,
138                                        recvcounts[rank], datatype );
139             if (mpi_errno) {
140               delete[] disps;
141               smpi_free_tmp_buffer(tmp_recvbuf);
142               return (mpi_errno);
143             }
144         }
145
146         delete[] disps;
147         smpi_free_tmp_buffer(tmp_recvbuf);
148
149         return MPI_SUCCESS;
150 }
151
152
153 int reduce_scatter__mpich_noncomm(const void *sendbuf, void *recvbuf, const int recvcounts[],
154                                   MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
155 {
156     int mpi_errno = MPI_SUCCESS;
157     int comm_size = comm->size() ;
158     int rank = comm->rank();
159     int pof2;
160     int log2_comm_size;
161     int i, k;
162     int recv_offset, send_offset;
163     int block_size, total_count, size;
164     MPI_Aint true_extent, true_lb;
165     int buf0_was_inout;
166     unsigned char* tmp_buf0;
167     unsigned char* tmp_buf1;
168     unsigned char* result_ptr;
169
170     datatype->extent(&true_lb, &true_extent);
171
172     pof2 = 1;
173     log2_comm_size = 0;
174     while (pof2 < comm_size) {
175         pof2 <<= 1;
176         ++log2_comm_size;
177     }
178
179     /* begin error checking */
180     xbt_assert(pof2 == comm_size); /* FIXME this version only works for power of 2 procs */
181
182     for (i = 0; i < (comm_size - 1); ++i) {
183         xbt_assert(recvcounts[i] == recvcounts[i+1]);
184     }
185     /* end error checking */
186
187     /* size of a block (count of datatype per block, NOT bytes per block) */
188     block_size = recvcounts[0];
189     total_count = block_size * comm_size;
190
191     tmp_buf0                     = smpi_get_tmp_sendbuffer(true_extent * total_count);
192     tmp_buf1                     = smpi_get_tmp_recvbuffer(true_extent * total_count);
193     unsigned char* tmp_buf0_save = tmp_buf0;
194     unsigned char* tmp_buf1_save = tmp_buf1;
195
196     /* adjust for potential negative lower bound in datatype */
197     tmp_buf0 = tmp_buf0 - true_lb;
198     tmp_buf1 = tmp_buf1 - true_lb;
199
200     /* Copy our send data to tmp_buf0.  We do this one block at a time and
201        permute the blocks as we go according to the mirror permutation. */
202     for (i = 0; i < comm_size; ++i) {
203       mpi_errno = Datatype::copy(
204           static_cast<const char*>(sendbuf == MPI_IN_PLACE ? recvbuf : sendbuf) + (i * true_extent * block_size), block_size,
205           datatype, tmp_buf0 + (MPIU_Mirror_permutation(i, log2_comm_size) * true_extent * block_size), block_size,
206           datatype);
207       if (mpi_errno)
208         return mpi_errno;
209     }
210     buf0_was_inout = 1;
211
212     send_offset = 0;
213     recv_offset = 0;
214     size = total_count;
215     for (k = 0; k < log2_comm_size; ++k) {
216         /* use a double-buffering scheme to avoid local copies */
217         unsigned char* incoming_data = buf0_was_inout ? tmp_buf1 : tmp_buf0;
218         unsigned char* outgoing_data = buf0_was_inout ? tmp_buf0 : tmp_buf1;
219         int peer = rank ^ (0x1 << k);
220         size /= 2;
221
222         if (rank > peer) {
223             /* we have the higher rank: send top half, recv bottom half */
224             recv_offset += size;
225         }
226         else {
227             /* we have the lower rank: recv top half, send bottom half */
228             send_offset += size;
229         }
230
231         Request::sendrecv(outgoing_data + send_offset*true_extent,
232                                      size, datatype, peer, COLL_TAG_REDUCE_SCATTER,
233                                      incoming_data + recv_offset*true_extent,
234                                      size, datatype, peer, COLL_TAG_REDUCE_SCATTER,
235                                      comm, MPI_STATUS_IGNORE);
236         /* always perform the reduction at recv_offset, the data at send_offset
237            is now our peer's responsibility */
238         if (rank > peer) {
239             /* higher ranked value so need to call op(received_data, my_data) */
240             if(op!=MPI_OP_NULL) op->apply(
241                    incoming_data + recv_offset*true_extent,
242                      outgoing_data + recv_offset*true_extent,
243                      &size, datatype );
244             /* buf0_was_inout = buf0_was_inout; */
245         }
246         else {
247             /* lower ranked value so need to call op(my_data, received_data) */
248             if (op != MPI_OP_NULL)
249               op->apply(outgoing_data + recv_offset * true_extent, incoming_data + recv_offset * true_extent, &size,
250                         datatype);
251             buf0_was_inout = not buf0_was_inout;
252         }
253
254         /* the next round of send/recv needs to happen within the block (of size
255            "size") that we just received and reduced */
256         send_offset = recv_offset;
257     }
258
259     xbt_assert(size == recvcounts[rank]);
260
261     /* copy the reduced data to the recvbuf */
262     result_ptr = (buf0_was_inout ? tmp_buf0 : tmp_buf1) + recv_offset * true_extent;
263     mpi_errno = Datatype::copy(result_ptr, size, datatype,
264                                recvbuf, size, datatype);
265     smpi_free_tmp_buffer(tmp_buf0_save);
266     smpi_free_tmp_buffer(tmp_buf1_save);
267     if (mpi_errno) return(mpi_errno);
268     return MPI_SUCCESS;
269 }
270
271
272
273 int reduce_scatter__mpich_rdb(const void *sendbuf, void *recvbuf, const int recvcounts[],
274                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
275 {
276     int   rank, comm_size, i;
277     MPI_Aint extent, true_extent, true_lb;
278     int mpi_errno = MPI_SUCCESS;
279     int dis[2], blklens[2], total_count, dst;
280     int mask, dst_tree_root, my_tree_root, j, k;
281     int received;
282     MPI_Datatype sendtype, recvtype;
283     int nprocs_completed, tmp_mask, tree_root;
284     comm_size = comm->size();
285     rank = comm->rank();
286
287     extent =datatype->get_extent();
288     datatype->extent(&true_lb, &true_extent);
289
290     bool is_commutative = (op == MPI_OP_NULL || op->is_commutative());
291
292     int* disps = new int[comm_size];
293
294     total_count = 0;
295     for (i=0; i<comm_size; i++) {
296         disps[i] = total_count;
297         total_count += recvcounts[i];
298     }
299
300             /* noncommutative and (non-pof2 or block irregular), use recursive doubling. */
301
302             /* need to allocate temporary buffer to receive incoming data*/
303     unsigned char* tmp_recvbuf = smpi_get_tmp_recvbuffer(total_count * std::max(true_extent, extent));
304     /* adjust for potential negative lower bound in datatype */
305     tmp_recvbuf = tmp_recvbuf - true_lb;
306
307     /* need to allocate another temporary buffer to accumulate
308        results */
309     unsigned char* tmp_results = smpi_get_tmp_sendbuffer(total_count * std::max(true_extent, extent));
310     /* adjust for potential negative lower bound in datatype */
311     tmp_results = tmp_results - true_lb;
312
313     /* copy sendbuf into tmp_results */
314     if (sendbuf != MPI_IN_PLACE)
315       mpi_errno = Datatype::copy(sendbuf, total_count, datatype, tmp_results, total_count, datatype);
316     else
317       mpi_errno = Datatype::copy(recvbuf, total_count, datatype, tmp_results, total_count, datatype);
318
319     if (mpi_errno) {
320       delete[] disps;
321       return (mpi_errno);
322     }
323
324     mask = 0x1;
325     i    = 0;
326     while (mask < comm_size) {
327       dst = rank ^ mask;
328
329       dst_tree_root = dst >> i;
330       dst_tree_root <<= i;
331
332       my_tree_root = rank >> i;
333       my_tree_root <<= i;
334
335       /* At step 1, processes exchange (n-n/p) amount of
336          data; at step 2, (n-2n/p) amount of data; at step 3, (n-4n/p)
337          amount of data, and so forth. We use derived datatypes for this.
338
339          At each step, a process does not need to send data
340          indexed from my_tree_root to
341          my_tree_root+mask-1. Similarly, a process won't receive
342          data indexed from dst_tree_root to dst_tree_root+mask-1. */
343
344       /* calculate sendtype */
345       blklens[0] = blklens[1] = 0;
346       for (j = 0; j < my_tree_root; j++)
347         blklens[0] += recvcounts[j];
348       for (j = my_tree_root + mask; j < comm_size; j++)
349         blklens[1] += recvcounts[j];
350
351       dis[0] = 0;
352       dis[1] = blklens[0];
353       for (j = my_tree_root; (j < my_tree_root + mask) && (j < comm_size); j++)
354         dis[1] += recvcounts[j];
355
356       mpi_errno = Datatype::create_indexed(2, blklens, dis, datatype, &sendtype);
357       if (mpi_errno)
358         return (mpi_errno);
359
360       sendtype->commit();
361
362       /* calculate recvtype */
363       blklens[0] = blklens[1] = 0;
364       for (j = 0; j < dst_tree_root && j < comm_size; j++)
365         blklens[0] += recvcounts[j];
366       for (j = dst_tree_root + mask; j < comm_size; j++)
367         blklens[1] += recvcounts[j];
368
369       dis[0] = 0;
370       dis[1] = blklens[0];
371       for (j = dst_tree_root; (j < dst_tree_root + mask) && (j < comm_size); j++)
372         dis[1] += recvcounts[j];
373
374       mpi_errno = Datatype::create_indexed(2, blklens, dis, datatype, &recvtype);
375       if (mpi_errno)
376         return (mpi_errno);
377
378       recvtype->commit();
379
380       received = 0;
381       if (dst < comm_size) {
382         /* tmp_results contains data to be sent in each step. Data is
383            received in tmp_recvbuf and then accumulated into
384            tmp_results. accumulation is done later below.   */
385
386         Request::sendrecv(tmp_results, 1, sendtype, dst, COLL_TAG_REDUCE_SCATTER, tmp_recvbuf, 1, recvtype, dst,
387                           COLL_TAG_REDUCE_SCATTER, comm, MPI_STATUS_IGNORE);
388         received = 1;
389       }
390
391       /* if some processes in this process's subtree in this step
392          did not have any destination process to communicate with
393          because of non-power-of-two, we need to send them the
394          result. We use a logarithmic recursive-halfing algorithm
395          for this. */
396
397       if (dst_tree_root + mask > comm_size) {
398         nprocs_completed = comm_size - my_tree_root - mask;
399         /* nprocs_completed is the number of processes in this
400            subtree that have all the data. Send data to others
401            in a tree fashion. First find root of current tree
402            that is being divided into two. k is the number of
403            least-significant bits in this process's rank that
404            must be zeroed out to find the rank of the root */
405         j = mask;
406         k = 0;
407         while (j) {
408           j >>= 1;
409           k++;
410         }
411         k--;
412
413         tmp_mask = mask >> 1;
414         while (tmp_mask) {
415           dst = rank ^ tmp_mask;
416
417           tree_root = rank >> k;
418           tree_root <<= k;
419
420           /* send only if this proc has data and destination
421              doesn't have data. at any step, multiple processes
422              can send if they have the data */
423           if ((dst > rank) && (rank < tree_root + nprocs_completed) && (dst >= tree_root + nprocs_completed)) {
424             /* send the current result */
425             Request::send(tmp_recvbuf, 1, recvtype, dst, COLL_TAG_REDUCE_SCATTER, comm);
426           }
427           /* recv only if this proc. doesn't have data and sender
428              has data */
429           else if ((dst < rank) && (dst < tree_root + nprocs_completed) && (rank >= tree_root + nprocs_completed)) {
430             Request::recv(tmp_recvbuf, 1, recvtype, dst, COLL_TAG_REDUCE_SCATTER, comm, MPI_STATUS_IGNORE);
431             received = 1;
432           }
433           tmp_mask >>= 1;
434           k--;
435         }
436       }
437
438       /* The following reduction is done here instead of after
439          the MPIC_Sendrecv_ft or MPIC_Recv_ft above. This is
440          because to do it above, in the noncommutative
441          case, we would need an extra temp buffer so as not to
442          overwrite temp_recvbuf, because temp_recvbuf may have
443          to be communicated to other processes in the
444          non-power-of-two case. To avoid that extra allocation,
445          we do the reduce here. */
446       if (received) {
447         if (is_commutative || (dst_tree_root < my_tree_root)) {
448           {
449             if (op != MPI_OP_NULL)
450               op->apply(tmp_recvbuf, tmp_results, &blklens[0], datatype);
451             if (op != MPI_OP_NULL)
452               op->apply(tmp_recvbuf + dis[1] * extent, tmp_results + dis[1] * extent, &blklens[1], datatype);
453           }
454         } else {
455           {
456             if (op != MPI_OP_NULL)
457               op->apply(tmp_results, tmp_recvbuf, &blklens[0], datatype);
458             if (op != MPI_OP_NULL)
459               op->apply(tmp_results + dis[1] * extent, tmp_recvbuf + dis[1] * extent, &blklens[1], datatype);
460           }
461           /* copy result back into tmp_results */
462           mpi_errno = Datatype::copy(tmp_recvbuf, 1, recvtype, tmp_results, 1, recvtype);
463           if (mpi_errno)
464             return (mpi_errno);
465         }
466       }
467
468       Datatype::unref(sendtype);
469       Datatype::unref(recvtype);
470
471       mask <<= 1;
472       i++;
473             }
474
475             /* now copy final results from tmp_results to recvbuf */
476             mpi_errno = Datatype::copy(tmp_results + disps[rank] * extent, recvcounts[rank], datatype, recvbuf,
477                                        recvcounts[rank], datatype);
478             if (mpi_errno) return(mpi_errno);
479
480     delete[] disps;
481     smpi_free_tmp_buffer(tmp_recvbuf);
482     smpi_free_tmp_buffer(tmp_results);
483     return MPI_SUCCESS;
484 }
485 } // namespace simgrid::smpi