src/smpi/colls/reduce_scatter/reduce_scatter-ompi.cpp
/* Copyright (c) 2013-2019. The SimGrid Team.
 * All rights reserved.                                                     */

/* This program is free software; you can redistribute it and/or modify it
 * under the terms of the license (GNU LGPL) which comes with this package. */

/*
 * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
 *                         University Research and Technology
 *                         Corporation.  All rights reserved.
 * Copyright (c) 2004-2012 The University of Tennessee and The University
 *                         of Tennessee Research Foundation.  All rights
 *                         reserved.
 * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
 *                         University of Stuttgart.  All rights reserved.
 * Copyright (c) 2004-2005 The Regents of the University of California.
 *                         All rights reserved.
 * Copyright (c) 2008      Sun Microsystems, Inc.  All rights reserved.
 * Copyright (c) 2009      University of Houston. All rights reserved.
 *
 * Additional copyrights may follow
 */

#include "../coll_tuned_topo.hpp"
#include "../colls_private.hpp"

/*
 * Recursive-halving function is (*mostly*) copied from the BASIC coll module.
 * I have removed the part which handles "large" message sizes
 * (non-overlapping version of reduce_scatter).
 */

/* copied function (with appropriate renaming) starts here */

/*
 *  reduce_scatter_ompi_basic_recursivehalving
 *
 *  Function:   - reduce scatter implementation using recursive-halving
 *                algorithm
 *  Accepts:    - same as MPI_Reduce_scatter()
 *  Returns:    - MPI_SUCCESS or error code
 *  Limitation: - Works only for commutative operations.
 */
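/*
 * Illustrative usage (not part of the original file): this routine backs the
 * standard MPI call below; variable names are placeholders.
 *
 *     // every rank contributes all blocks; rank i keeps rcounts[i] reduced elements
 *     MPI_Reduce_scatter(sendbuf, recvbuf, rcounts, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
 *
 * In SMPI the concrete algorithm is typically chosen at run time through the
 * per-collective configuration item (assumed name), e.g.
 *     --cfg=smpi/reduce_scatter:ompi_basic_recursivehalving
 */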
namespace simgrid {
namespace smpi {
int
Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(const void *sbuf,
                                                            void *rbuf,
                                                            const int *rcounts,
                                                            MPI_Datatype dtype,
                                                            MPI_Op op,
                                                            MPI_Comm comm
                                                            )
{
    int i, rank, size, count, err = MPI_SUCCESS;
    int tmp_size = 1, remain = 0, tmp_rank;
    ptrdiff_t true_lb, true_extent, lb, extent, buf_size;
    unsigned char *result_buf = nullptr, *result_buf_free = nullptr;

    /* Initialize */
    rank = comm->rank();
    size = comm->size();

    XBT_DEBUG("coll:tuned:reduce_scatter_ompi_basic_recursivehalving, rank %d", rank);
    if ((op != MPI_OP_NULL && not op->is_commutative()))
      throw std::invalid_argument(
          "reduce_scatter ompi_basic_recursivehalving can only be used for commutative operations!");

    /* Find displacements and the like */
    int* disps = new int[size];

    disps[0] = 0;
    for (i = 0; i < (size - 1); ++i) {
        disps[i + 1] = disps[i] + rcounts[i];
    }
    count = disps[size - 1] + rcounts[size - 1];
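    /* e.g. (illustrative) rcounts = {2, 3, 1} on 3 ranks gives disps = {0, 2, 5} and count = 6 */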

    /* short cut the trivial case */
    if (0 == count) {
      delete[] disps;
      return MPI_SUCCESS;
    }

    /* get datatype information */
    dtype->extent(&lb, &extent);
    dtype->extent(&true_lb, &true_extent);
    buf_size = true_extent + (ptrdiff_t)(count - 1) * extent;
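    /* for a plain contiguous datatype true_extent == extent, so buf_size == count * extent (illustrative) */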

    /* Handle MPI_IN_PLACE */
    if (MPI_IN_PLACE == sbuf) {
        sbuf = rbuf;
    }

    /* Allocate temporary receive buffer. */
    unsigned char* recv_buf_free = smpi_get_tmp_recvbuffer(buf_size);
    unsigned char* recv_buf      = recv_buf_free - lb;
    if (NULL == recv_buf_free) {
        err = MPI_ERR_OTHER;
        goto cleanup;
    }

    /* allocate temporary buffer for results */
    result_buf_free = smpi_get_tmp_sendbuffer(buf_size);
    result_buf = result_buf_free - lb;

    /* copy local buffer into the temporary results */
    err = Datatype::copy(sbuf, count, dtype, result_buf, count, dtype);
    if (MPI_SUCCESS != err) goto cleanup;

    /* figure out power of two mapping: grow until larger than
       comm size, then go back one, to get the largest power of
       two less than comm size */
    while (tmp_size <= size) tmp_size <<= 1;
    tmp_size >>= 1;
    remain = size - tmp_size;
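    /* e.g. (illustrative) size = 5: tmp_size grows 1 -> 2 -> 4 -> 8, backs off to 4, so remain = 1 */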

    /* If comm size is not a power of two, have the first "remain"
       procs with an even rank send to rank + 1, leaving a power of
       two procs to do the rest of the algorithm */
    if (rank < 2 * remain) {
        if ((rank & 1) == 0) {
            Request::send(result_buf, count, dtype, rank + 1,
                                    COLL_TAG_REDUCE_SCATTER,
                                    comm);
            /* we don't participate from here on out */
            tmp_rank = -1;
        } else {
            Request::recv(recv_buf, count, dtype, rank - 1,
                                    COLL_TAG_REDUCE_SCATTER,
                                    comm, MPI_STATUS_IGNORE);

            /* integrate their results into our temp results */
            if (op != MPI_OP_NULL) op->apply(recv_buf, result_buf, &count, dtype);

            /* adjust rank to be the bottom "remain" ranks */
            tmp_rank = rank / 2;
        }
    } else {
        /* just need to adjust rank to show that the bottom "even
           remain" ranks dropped out */
        tmp_rank = rank - remain;
    }
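    /* e.g. (illustrative) size = 5, remain = 1: rank 0 folds its data into rank 1 and drops
       out (tmp_rank = -1); ranks 1..4 continue with tmp_rank = 0, 1, 2, 3 */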

    /* For ranks not kicked out by the above code, perform the
       recursive halving */
    if (tmp_rank >= 0) {
        int mask, send_index, recv_index, last_index;

        /* recalculate disps and rcounts to account for the
           special "remainder" processes that are no longer doing
           anything */
        int* tmp_rcounts = new int[tmp_size];
        int* tmp_disps   = new int[tmp_size];

        for (i = 0 ; i < tmp_size ; ++i) {
            if (i < remain) {
                /* need to include old neighbor as well */
                tmp_rcounts[i] = rcounts[i * 2 + 1] + rcounts[i * 2];
            } else {
                tmp_rcounts[i] = rcounts[i + remain];
            }
        }

        tmp_disps[0] = 0;
        for (i = 0; i < tmp_size - 1; ++i) {
            tmp_disps[i + 1] = tmp_disps[i] + tmp_rcounts[i];
        }
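        /* e.g. (illustrative) size = 5, remain = 1, rcounts = {1, 1, 2, 2, 3}:
           tmp_rcounts = {1+1, 2, 2, 3} = {2, 2, 2, 3} and tmp_disps = {0, 2, 4, 6} */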

        /* do the recursive halving communication.  Don't use the
           dimension information on the communicator because I
           think the information is invalidated by our "shrinking"
           of the communicator */
        mask = tmp_size >> 1;
        send_index = recv_index = 0;
        last_index = tmp_size;
        while (mask > 0) {
            int tmp_peer, peer, send_count, recv_count;
            MPI_Request request;

            tmp_peer = tmp_rank ^ mask;
            peer = (tmp_peer < remain) ? tmp_peer * 2 + 1 : tmp_peer + remain;
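            /* map the peer back to a real rank, e.g. (illustrative) size = 5, remain = 1:
               tmp_peer 0 is real rank 1, tmp_peers 1..3 are real ranks 2..4 */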

            /* figure out if we're sending, receiving, or both */
            send_count = recv_count = 0;
            if (tmp_rank < tmp_peer) {
                send_index = recv_index + mask;
                for (i = send_index ; i < last_index ; ++i) {
                    send_count += tmp_rcounts[i];
                }
                for (i = recv_index ; i < send_index ; ++i) {
                    recv_count += tmp_rcounts[i];
                }
            } else {
                recv_index = send_index + mask;
                for (i = send_index ; i < recv_index ; ++i) {
                    send_count += tmp_rcounts[i];
                }
                for (i = recv_index ; i < last_index ; ++i) {
                    recv_count += tmp_rcounts[i];
                }
            }

            /* actual data transfer.  Send from result_buf,
               receive into recv_buf */
            if (send_count > 0 && recv_count != 0) {
                request = Request::irecv(recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
                                         recv_count, dtype, peer,
                                         COLL_TAG_REDUCE_SCATTER,
                                         comm);
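                /* note: 'err' is not updated by Request::irecv/Request::send in this port,
                   so the error checks below are effectively no-ops kept from the original
                   Open MPI code, where the PML calls returned an error code */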
                if (MPI_SUCCESS != err) {
                  delete[] tmp_rcounts;
                  delete[] tmp_disps;
                  goto cleanup;
                }
            }
            if (recv_count > 0 && send_count != 0) {
                Request::send(result_buf + (ptrdiff_t)tmp_disps[send_index] * extent,
                                        send_count, dtype, peer,
                                        COLL_TAG_REDUCE_SCATTER,
                                        comm);
                if (MPI_SUCCESS != err) {
                  delete[] tmp_rcounts;
                  delete[] tmp_disps;
                  goto cleanup;
                }
            }
            if (send_count > 0 && recv_count != 0) {
                Request::wait(&request, MPI_STATUS_IGNORE);
            }

            /* if we received something on this step, push it into
               the results buffer */
            if (recv_count > 0) {
                if (op != MPI_OP_NULL) op->apply(
                               recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
                               result_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
                               &recv_count, dtype);
            }

            /* update for next iteration */
            send_index = recv_index;
            last_index = recv_index + mask;
            mask >>= 1;
        }

        /* copy local results from results buffer into real receive buffer */
        if (0 != rcounts[rank]) {
            err = Datatype::copy(result_buf + disps[rank] * extent,
                                       rcounts[rank], dtype,
                                       rbuf, rcounts[rank], dtype);
            if (MPI_SUCCESS != err) {
              delete[] tmp_rcounts;
              delete[] tmp_disps;
              goto cleanup;
            }
        }

        delete[] tmp_rcounts;
        delete[] tmp_disps;
    }

    /* Now fix up the non-power of two case, by having the odd
       procs send the even procs the proper results */
    if (rank < (2 * remain)) {
        if ((rank & 1) == 0) {
            if (rcounts[rank]) {
                Request::recv(rbuf, rcounts[rank], dtype, rank + 1,
                                        COLL_TAG_REDUCE_SCATTER,
                                        comm, MPI_STATUS_IGNORE);
            }
        } else {
            if (rcounts[rank - 1]) {
                Request::send(result_buf + disps[rank - 1] * extent,
                                        rcounts[rank - 1], dtype, rank - 1,
                                        COLL_TAG_REDUCE_SCATTER,
                                        comm);
            }
        }
    }

 cleanup:
    delete[] disps;
    if (NULL != recv_buf_free) smpi_free_tmp_buffer(recv_buf_free);
    if (NULL != result_buf_free) smpi_free_tmp_buffer(result_buf_free);

    return err;
}

/* copied function (with appropriate renaming) ends here */


/*
 *   Coll_reduce_scatter_ompi_ring::reduce_scatter
 *
 *   Function:       Ring algorithm for reduce_scatter operation
 *   Accepts:        Same as MPI_Reduce_scatter()
 *   Returns:        MPI_SUCCESS or error code
 *
 *   Description:    Implements ring algorithm for reduce_scatter:
 *                   the block sizes defined in rcounts are exchanged and
 *                   updated until they reach proper destination.
 *                   Algorithm requires 2 * max(rcounts) extra buffering
 *
 *   Limitations:    The algorithm DOES NOT preserve order of operations so it
 *                   can be used only for commutative operations.
 *         Example on 5 nodes:
 *         Initial state
 *   #      0              1             2              3             4
 *        [00]           [10]   ->     [20]           [30]           [40]
 *        [01]           [11]          [21]  ->       [31]           [41]
 *        [02]           [12]          [22]           [32]  ->       [42]
 *    ->  [03]           [13]          [23]           [33]           [43] --> ..
 *        [04]  ->       [14]          [24]           [34]           [44]
 *
 *        COMPUTATION PHASE
 *         Step 0: rank r sends block (r-1) to rank (r+1) and
 *                 receives block (r+1) from rank (r-1) [with wraparound].
 *   #      0              1             2              3             4
 *        [00]           [10]        [10+20]   ->     [30]           [40]
 *        [01]           [11]          [21]          [21+31]  ->     [41]
 *    ->  [02]           [12]          [22]           [32]         [32+42] -->..
 *      [43+03] ->       [13]          [23]           [33]           [43]
 *        [04]         [04+14]  ->     [24]           [34]           [44]
 *
 *         Step 1:
 *   #      0              1             2              3             4
 *        [00]           [10]        [10+20]       [10+20+30] ->     [40]
 *    ->  [01]           [11]          [21]          [21+31]      [21+31+41] ->
 *     [32+42+02] ->     [12]          [22]           [32]         [32+42]
 *        [03]        [43+03+13] ->    [23]           [33]           [43]
 *        [04]         [04+14]      [04+14+24]  ->    [34]           [44]
 *
 *         Step 2:
 *   #      0              1             2              3             4
 *     -> [00]           [10]        [10+20]       [10+20+30]   [10+20+30+40] ->
 *   [21+31+41+01]->     [11]          [21]          [21+31]      [21+31+41]
 *     [32+42+02]   [32+42+02+12]->    [22]           [32]         [32+42]
 *        [03]        [43+03+13]   [43+03+13+23]->    [33]           [43]
 *        [04]         [04+14]      [04+14+24]    [04+14+24+34] ->   [44]
 *
 *         Step 3:
 *   #      0             1              2              3             4
 * [10+20+30+40+00]     [10]         [10+20]       [10+20+30]   [10+20+30+40]
 *  [21+31+41+01] [21+31+41+01+11]     [21]          [21+31]      [21+31+41]
 *    [32+42+02]   [32+42+02+12] [32+42+02+12+22]     [32]         [32+42]
 *       [03]        [43+03+13]    [43+03+13+23] [43+03+13+23+33]    [43]
 *       [04]         [04+14]       [04+14+24]    [04+14+24+34] [04+14+24+34+44]
 *    DONE :)
 *
 */
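/* e.g. (illustrative, not in the original file) with rcounts = {1, 2, 3, 2} and
   MPI_DOUBLE, the code below allocates an accumulation buffer of total_count = 8
   doubles plus two scratch receive buffers of max(rcounts) = 3 doubles each */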
int
Coll_reduce_scatter_ompi_ring::reduce_scatter(const void *sbuf, void *rbuf, const int *rcounts,
                                          MPI_Datatype dtype,
                                          MPI_Op op,
                                          MPI_Comm comm
                                          )
{
    int ret, line, rank, size, i, k, recv_from, send_to, total_count, max_block_count;
    int inbi;
    unsigned char *tmpsend = NULL, *tmprecv = NULL, *accumbuf = NULL, *accumbuf_free = NULL;
    unsigned char *inbuf_free[2] = {NULL, NULL}, *inbuf[2] = {NULL, NULL};
    ptrdiff_t true_lb, true_extent, lb, extent, max_real_segsize;
    MPI_Request reqs[2] = {NULL, NULL};

    size = comm->size();
    rank = comm->rank();

    XBT_DEBUG(  "coll:tuned:reduce_scatter_ompi_ring rank %d, size %d",
                 rank, size);

    /* Determine the maximum number of elements per node,
       corresponding block size, and displacements array.
    */
    int* displs = new int[size];

    displs[0] = 0;
    total_count = rcounts[0];
    max_block_count = rcounts[0];
    for (i = 1; i < size; i++) {
        displs[i] = total_count;
        total_count += rcounts[i];
        if (max_block_count < rcounts[i]) max_block_count = rcounts[i];
    }
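    /* e.g. (illustrative) rcounts = {1, 2, 3, 2} gives displs = {0, 1, 3, 6},
       total_count = 8 and max_block_count = 3 */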

    /* Special case for size == 1 */
    if (1 == size) {
        if (MPI_IN_PLACE != sbuf) {
            ret = Datatype::copy((char*)sbuf, total_count, dtype, (char*)rbuf, total_count, dtype);
            if (ret < 0) { line = __LINE__; goto error_hndl; }
        }
        delete[] displs;
        return MPI_SUCCESS;
    }

    /* Allocate and initialize temporary buffers, we need:
       - a temporary buffer to perform reduction (size total_count) since
       rbuf can be of rcounts[rank] size.
       - up to two temporary buffers used for communication/computation overlap.
    */
    dtype->extent(&lb, &extent);
    dtype->extent(&true_lb, &true_extent);

    max_real_segsize = true_extent + (ptrdiff_t)(max_block_count - 1) * extent;
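    /* for a contiguous datatype this is simply max_block_count * extent bytes (illustrative) */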

    accumbuf_free = smpi_get_tmp_recvbuffer(true_extent + (ptrdiff_t)(total_count - 1) * extent);
    if (NULL == accumbuf_free) { ret = -1; line = __LINE__; goto error_hndl; }
    accumbuf = accumbuf_free - lb;

    inbuf_free[0] = smpi_get_tmp_sendbuffer(max_real_segsize);
    if (NULL == inbuf_free[0]) { ret = -1; line = __LINE__; goto error_hndl; }
    inbuf[0] = inbuf_free[0] - lb;
    if (size > 2) {
      inbuf_free[1] = smpi_get_tmp_sendbuffer(max_real_segsize);
      if (NULL == inbuf_free[1]) {
        ret  = -1;
        line = __LINE__;
        goto error_hndl;
      }
      inbuf[1] = inbuf_free[1] - lb;
    }

    /* Handle MPI_IN_PLACE for size > 1 */
    if (MPI_IN_PLACE == sbuf) {
        sbuf = rbuf;
    }

    ret = Datatype::copy((char*)sbuf, total_count, dtype, accumbuf, total_count, dtype);
    if (ret < 0) { line = __LINE__; goto error_hndl; }

    /* Computation loop */

    /*
       For each of the remote nodes:
       - post irecv for block (r-2) from (r-1) with wrap around
       - send block (r-1) to (r+1)
       - in loop for every step k = 2 .. n
       - post irecv for block (r - 1 + n - k) % n
       - wait on block (r + n - k) % n to arrive
       - compute on block (r + n - k ) % n
       - send block (r + n - k) % n
       - wait on block (r)
       - compute on block (r)
       - copy block (r) to rbuf
       Note that we must be careful when computing the beginning of buffers and
       for send operations and computation we must compute the exact block size.
    */
    send_to = (rank + 1) % size;
    recv_from = (rank + size - 1) % size;
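    /* e.g. (illustrative) size = 4, rank = 2: send_to = 3, recv_from = 1 */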

    inbi = 0;
    /* Initialize first receive from the neighbor on the left */
    reqs[inbi] = Request::irecv(inbuf[inbi], max_block_count, dtype, recv_from,
                             COLL_TAG_REDUCE_SCATTER, comm
                             );
    tmpsend = accumbuf + (ptrdiff_t)displs[recv_from] * extent;
    Request::send(tmpsend, rcounts[recv_from], dtype, send_to,
                            COLL_TAG_REDUCE_SCATTER,
                             comm);

    for (k = 2; k < size; k++) {
        const int prevblock = (rank + size - k) % size;
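        /* block whose partial result arrives and is forwarded at this step,
           e.g. (illustrative) size = 4, rank = 2: prevblock = 0 at k = 2, prevblock = 3 at k = 3 */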

        inbi = inbi ^ 0x1;

        /* Post irecv for the current block */
        reqs[inbi] = Request::irecv(inbuf[inbi], max_block_count, dtype, recv_from,
                                 COLL_TAG_REDUCE_SCATTER, comm
                                 );

        /* Wait on previous block to arrive */
        Request::wait(&reqs[inbi ^ 0x1], MPI_STATUS_IGNORE);

        /* Apply operation on previous block: result goes to rbuf
           rbuf[prevblock] = inbuf[inbi ^ 0x1] (op) rbuf[prevblock]
        */
        tmprecv = accumbuf + (ptrdiff_t)displs[prevblock] * extent;
        if (op != MPI_OP_NULL)
          op->apply(inbuf[inbi ^ 0x1], tmprecv, &rcounts[prevblock], dtype);

        /* send previous block to send_to */
        Request::send(tmprecv, rcounts[prevblock], dtype, send_to,
                                COLL_TAG_REDUCE_SCATTER,
                                 comm);
    }

    /* Wait on the last block to arrive */
    Request::wait(&reqs[inbi], MPI_STATUS_IGNORE);

    /* Apply operation on the last block (my block)
       rbuf[rank] = inbuf[inbi] (op) rbuf[rank] */
    tmprecv = accumbuf + (ptrdiff_t)displs[rank] * extent;
    if (op != MPI_OP_NULL)
      op->apply(inbuf[inbi], tmprecv, &rcounts[rank], dtype);

    /* Copy result from tmprecv to rbuf */
    ret = Datatype::copy(tmprecv, rcounts[rank], dtype, (char*)rbuf, rcounts[rank], dtype);
    if (ret < 0) { line = __LINE__; goto error_hndl; }

    delete[] displs;
    if (NULL != accumbuf_free) smpi_free_tmp_buffer(accumbuf_free);
    if (NULL != inbuf_free[0]) smpi_free_tmp_buffer(inbuf_free[0]);
    if (NULL != inbuf_free[1]) smpi_free_tmp_buffer(inbuf_free[1]);

    return MPI_SUCCESS;

 error_hndl:
    XBT_DEBUG( "%s:%4d\tRank %d Error occurred %d\n",
                 __FILE__, line, rank, ret);
    delete[] displs;
    if (NULL != accumbuf_free) smpi_free_tmp_buffer(accumbuf_free);
    if (NULL != inbuf_free[0]) smpi_free_tmp_buffer(inbuf_free[0]);
    if (NULL != inbuf_free[1]) smpi_free_tmp_buffer(inbuf_free[1]);
    return ret;
}
}
}