Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
bcd2dce83d9ea8bac983f91910b8b15ea0b62e97
[simgrid.git] / src / smpi / colls / reduce_scatter / reduce_scatter-ompi.cpp
1 /* Copyright (c) 2013-2019. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
9  *                         University Research and Technology
10  *                         Corporation.  All rights reserved.
11  * Copyright (c) 2004-2012 The University of Tennessee and The University
12  *                         of Tennessee Research Foundation.  All rights
13  *                         reserved.
14  * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
15  *                         University of Stuttgart.  All rights reserved.
16  * Copyright (c) 2004-2005 The Regents of the University of California.
17  *                         All rights reserved.
18  * Copyright (c) 2008      Sun Microsystems, Inc.  All rights reserved.
19  * Copyright (c) 2009      University of Houston. All rights reserved.
20  *
21  * Additional copyrights may follow
22  */
23
24 #include "../coll_tuned_topo.hpp"
25 #include "../colls_private.hpp"
26
27 /*
28  * Recursive-halving function is (*mostly*) copied from the BASIC coll module.
29  * I have removed the part which handles "large" message sizes
30  * (non-overlapping version of reduce_Scatter).
31  */
32
33 /* copied function (with appropriate renaming) starts here */
34
35 /*
36  *  reduce_scatter_ompi_basic_recursivehalving
37  *
38  *  Function:   - reduce scatter implementation using recursive-halving
39  *                algorithm
40  *  Accepts:    - same as MPI_Reduce_scatter()
41  *  Returns:    - MPI_SUCCESS or error code
42  *  Limitation: - Works only for commutative operations.
43  */
44 namespace simgrid{
45 namespace smpi{
46 int
47 Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(const void *sbuf,
48                                                             void *rbuf,
49                                                             const int *rcounts,
50                                                             MPI_Datatype dtype,
51                                                             MPI_Op op,
52                                                             MPI_Comm comm
53                                                             )
54 {
55     int i, rank, size, count, err = MPI_SUCCESS;
56     int tmp_size=1, remain = 0, tmp_rank, *disps = NULL;
57     ptrdiff_t true_lb, true_extent, lb, extent, buf_size;
58     char *recv_buf = NULL, *recv_buf_free = NULL;
59     char *result_buf = NULL, *result_buf_free = NULL;
60
61     /* Initialize */
62     rank = comm->rank();
63     size = comm->size();
64
65     XBT_DEBUG("coll:tuned:reduce_scatter_ompi_basic_recursivehalving, rank %d", rank);
66     if ((op != MPI_OP_NULL && not op->is_commutative()))
67       THROWF(arg_error,0, " reduce_scatter ompi_basic_recursivehalving can only be used for commutative operations! ");
68
69     /* Find displacements and the like */
70     disps = (int*) xbt_malloc(sizeof(int) * size);
71     if (NULL == disps) return MPI_ERR_OTHER;
72
73     disps[0] = 0;
74     for (i = 0; i < (size - 1); ++i) {
75         disps[i + 1] = disps[i] + rcounts[i];
76     }
77     count = disps[size - 1] + rcounts[size - 1];
78
79     /* short cut the trivial case */
80     if (0 == count) {
81         xbt_free(disps);
82         return MPI_SUCCESS;
83     }
84
85     /* get datatype information */
86     dtype->extent(&lb, &extent);
87     dtype->extent(&true_lb, &true_extent);
88     buf_size = true_extent + (ptrdiff_t)(count - 1) * extent;
89
90     /* Handle MPI_IN_PLACE */
91     if (MPI_IN_PLACE == sbuf) {
92         sbuf = rbuf;
93     }
94
95     /* Allocate temporary receive buffer. */
96     recv_buf_free = (char*) smpi_get_tmp_recvbuffer(buf_size);
97
98     recv_buf = recv_buf_free - lb;
99     if (NULL == recv_buf_free) {
100         err = MPI_ERR_OTHER;
101         goto cleanup;
102     }
103
104     /* allocate temporary buffer for results */
105     result_buf_free = (char*) smpi_get_tmp_sendbuffer(buf_size);
106
107     result_buf = result_buf_free - lb;
108
109     /* copy local buffer into the temporary results */
110     err =Datatype::copy(sbuf, count, dtype, result_buf, count, dtype);
111     if (MPI_SUCCESS != err) goto cleanup;
112
113     /* figure out power of two mapping: grow until larger than
114        comm size, then go back one, to get the largest power of
115        two less than comm size */
116     while (tmp_size <= size) tmp_size <<= 1;
117     tmp_size >>= 1;
118     remain = size - tmp_size;
119
120     /* If comm size is not a power of two, have the first "remain"
121        procs with an even rank send to rank + 1, leaving a power of
122        two procs to do the rest of the algorithm */
123     if (rank < 2 * remain) {
124         if ((rank & 1) == 0) {
125             Request::send(result_buf, count, dtype, rank + 1,
126                                     COLL_TAG_REDUCE_SCATTER,
127                                     comm);
128             /* we don't participate from here on out */
129             tmp_rank = -1;
130         } else {
131             Request::recv(recv_buf, count, dtype, rank - 1,
132                                     COLL_TAG_REDUCE_SCATTER,
133                                     comm, MPI_STATUS_IGNORE);
134
135             /* integrate their results into our temp results */
136             if(op!=MPI_OP_NULL) op->apply( recv_buf, result_buf, &count, dtype);
137
138             /* adjust rank to be the bottom "remain" ranks */
139             tmp_rank = rank / 2;
140         }
141     } else {
142         /* just need to adjust rank to show that the bottom "even
143            remain" ranks dropped out */
144         tmp_rank = rank - remain;
145     }
146
147     /* For ranks not kicked out by the above code, perform the
148        recursive halving */
149     if (tmp_rank >= 0) {
150         int *tmp_disps = NULL, *tmp_rcounts = NULL;
151         int mask, send_index, recv_index, last_index;
152
153         /* recalculate disps and rcounts to account for the
154            special "remainder" processes that are no longer doing
155            anything */
156         tmp_rcounts = (int*) xbt_malloc(tmp_size * sizeof(int));
157         if (NULL == tmp_rcounts) {
158             err = MPI_ERR_OTHER;
159             goto cleanup;
160         }
161         tmp_disps = (int*) xbt_malloc(tmp_size * sizeof(int));
162         if (NULL == tmp_disps) {
163             xbt_free(tmp_rcounts);
164             err = MPI_ERR_OTHER;
165             goto cleanup;
166         }
167
168         for (i = 0 ; i < tmp_size ; ++i) {
169             if (i < remain) {
170                 /* need to include old neighbor as well */
171                 tmp_rcounts[i] = rcounts[i * 2 + 1] + rcounts[i * 2];
172             } else {
173                 tmp_rcounts[i] = rcounts[i + remain];
174             }
175         }
176
177         tmp_disps[0] = 0;
178         for (i = 0; i < tmp_size - 1; ++i) {
179             tmp_disps[i + 1] = tmp_disps[i] + tmp_rcounts[i];
180         }
181
182         /* do the recursive halving communication.  Don't use the
183            dimension information on the communicator because I
184            think the information is invalidated by our "shrinking"
185            of the communicator */
186         mask = tmp_size >> 1;
187         send_index = recv_index = 0;
188         last_index = tmp_size;
189         while (mask > 0) {
190             int tmp_peer, peer, send_count, recv_count;
191             MPI_Request request;
192
193             tmp_peer = tmp_rank ^ mask;
194             peer = (tmp_peer < remain) ? tmp_peer * 2 + 1 : tmp_peer + remain;
195
196             /* figure out if we're sending, receiving, or both */
197             send_count = recv_count = 0;
198             if (tmp_rank < tmp_peer) {
199                 send_index = recv_index + mask;
200                 for (i = send_index ; i < last_index ; ++i) {
201                     send_count += tmp_rcounts[i];
202                 }
203                 for (i = recv_index ; i < send_index ; ++i) {
204                     recv_count += tmp_rcounts[i];
205                 }
206             } else {
207                 recv_index = send_index + mask;
208                 for (i = send_index ; i < recv_index ; ++i) {
209                     send_count += tmp_rcounts[i];
210                 }
211                 for (i = recv_index ; i < last_index ; ++i) {
212                     recv_count += tmp_rcounts[i];
213                 }
214             }
215
216             /* actual data transfer.  Send from result_buf,
217                receive into recv_buf */
218             if (send_count > 0 && recv_count != 0) {
219                 request=Request::irecv(recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
220                                          recv_count, dtype, peer,
221                                          COLL_TAG_REDUCE_SCATTER,
222                                          comm);
223                 if (MPI_SUCCESS != err) {
224                     xbt_free(tmp_rcounts);
225                     xbt_free(tmp_disps);
226                     goto cleanup;
227                 }
228             }
229             if (recv_count > 0 && send_count != 0) {
230                 Request::send(result_buf + (ptrdiff_t)tmp_disps[send_index] * extent,
231                                         send_count, dtype, peer,
232                                         COLL_TAG_REDUCE_SCATTER,
233                                         comm);
234                 if (MPI_SUCCESS != err) {
235                     xbt_free(tmp_rcounts);
236                     xbt_free(tmp_disps);
237                     goto cleanup;
238                 }
239             }
240             if (send_count > 0 && recv_count != 0) {
241                 Request::wait(&request, MPI_STATUS_IGNORE);
242             }
243
244             /* if we received something on this step, push it into
245                the results buffer */
246             if (recv_count > 0) {
247                 if(op!=MPI_OP_NULL) op->apply(
248                                recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
249                                result_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
250                                &recv_count, dtype);
251             }
252
253             /* update for next iteration */
254             send_index = recv_index;
255             last_index = recv_index + mask;
256             mask >>= 1;
257         }
258
259         /* copy local results from results buffer into real receive buffer */
260         if (0 != rcounts[rank]) {
261             err = Datatype::copy(result_buf + disps[rank] * extent,
262                                        rcounts[rank], dtype,
263                                        rbuf, rcounts[rank], dtype);
264             if (MPI_SUCCESS != err) {
265                 xbt_free(tmp_rcounts);
266                 xbt_free(tmp_disps);
267                 goto cleanup;
268             }
269         }
270
271         xbt_free(tmp_rcounts);
272         xbt_free(tmp_disps);
273     }
274
275     /* Now fix up the non-power of two case, by having the odd
276        procs send the even procs the proper results */
277     if (rank < (2 * remain)) {
278         if ((rank & 1) == 0) {
279             if (rcounts[rank]) {
280                 Request::recv(rbuf, rcounts[rank], dtype, rank + 1,
281                                         COLL_TAG_REDUCE_SCATTER,
282                                         comm, MPI_STATUS_IGNORE);
283             }
284         } else {
285             if (rcounts[rank - 1]) {
286                 Request::send(result_buf + disps[rank - 1] * extent,
287                                         rcounts[rank - 1], dtype, rank - 1,
288                                         COLL_TAG_REDUCE_SCATTER,
289                                         comm);
290             }
291         }
292     }
293
294  cleanup:
295     if (NULL != disps) xbt_free(disps);
296     if (NULL != recv_buf_free) smpi_free_tmp_buffer(recv_buf_free);
297     if (NULL != result_buf_free) smpi_free_tmp_buffer(result_buf_free);
298
299     return err;
300 }
301
302 /* copied function (with appropriate renaming) ends here */
303
304
305 /*
306  *   Coll_reduce_scatter_ompi_ring::reduce_scatter
307  *
308  *   Function:       Ring algorithm for reduce_scatter operation
309  *   Accepts:        Same as MPI_Reduce_scatter()
310  *   Returns:        MPI_SUCCESS or error code
311  *
312  *   Description:    Implements ring algorithm for reduce_scatter:
313  *                   the block sizes defined in rcounts are exchanged and
 *                   updated until they reach proper destination.
315  *                   Algorithm requires 2 * max(rcounts) extra buffering
316  *
317  *   Limitations:    The algorithm DOES NOT preserve order of operations so it
318  *                   can be used only for commutative operations.
319  *         Example on 5 nodes:
320  *         Initial state
321  *   #      0              1             2              3             4
322  *        [00]           [10]   ->     [20]           [30]           [40]
323  *        [01]           [11]          [21]  ->       [31]           [41]
324  *        [02]           [12]          [22]           [32]  ->       [42]
325  *    ->  [03]           [13]          [23]           [33]           [43] --> ..
326  *        [04]  ->       [14]          [24]           [34]           [44]
327  *
328  *        COMPUTATION PHASE
329  *         Step 0: rank r sends block (r-1) to rank (r+1) and
330  *                 receives block (r+1) from rank (r-1) [with wraparound].
331  *   #      0              1             2              3             4
332  *        [00]           [10]        [10+20]   ->     [30]           [40]
333  *        [01]           [11]          [21]          [21+31]  ->     [41]
334  *    ->  [02]           [12]          [22]           [32]         [32+42] -->..
335  *      [43+03] ->       [13]          [23]           [33]           [43]
336  *        [04]         [04+14]  ->     [24]           [34]           [44]
337  *
338  *         Step 1:
339  *   #      0              1             2              3             4
340  *        [00]           [10]        [10+20]       [10+20+30] ->     [40]
341  *    ->  [01]           [11]          [21]          [21+31]      [21+31+41] ->
342  *     [32+42+02] ->     [12]          [22]           [32]         [32+42]
343  *        [03]        [43+03+13] ->    [23]           [33]           [43]
344  *        [04]         [04+14]      [04+14+24]  ->    [34]           [44]
345  *
346  *         Step 2:
347  *   #      0              1             2              3             4
348  *     -> [00]           [10]        [10+20]       [10+20+30]   [10+20+30+40] ->
349  *   [21+31+41+01]->     [11]          [21]          [21+31]      [21+31+41]
350  *     [32+42+02]   [32+42+02+12]->    [22]           [32]         [32+42]
351  *        [03]        [43+03+13]   [43+03+13+23]->    [33]           [43]
352  *        [04]         [04+14]      [04+14+24]    [04+14+24+34] ->   [44]
353  *
354  *         Step 3:
355  *   #      0             1              2              3             4
356  * [10+20+30+40+00]     [10]         [10+20]       [10+20+30]   [10+20+30+40]
357  *  [21+31+41+01] [21+31+41+01+11]     [21]          [21+31]      [21+31+41]
358  *    [32+42+02]   [32+42+02+12] [32+42+02+12+22]     [32]         [32+42]
359  *       [03]        [43+03+13]    [43+03+13+23] [43+03+13+23+33]    [43]
360  *       [04]         [04+14]       [04+14+24]    [04+14+24+34] [04+14+24+34+44]
361  *    DONE :)
362  *
363  */
364 int
365 Coll_reduce_scatter_ompi_ring::reduce_scatter(const void *sbuf, void *rbuf, const int *rcounts,
366                                           MPI_Datatype dtype,
367                                           MPI_Op op,
368                                           MPI_Comm comm
369                                           )
370 {
371     int ret, line, rank, size, i, k, recv_from, send_to, total_count, max_block_count;
372     int inbi, *displs = NULL;
373     char *tmpsend = NULL, *tmprecv = NULL, *accumbuf = NULL, *accumbuf_free = NULL;
374     char *inbuf_free[2] = {NULL, NULL}, *inbuf[2] = {NULL, NULL};
375     ptrdiff_t true_lb, true_extent, lb, extent, max_real_segsize;
376     MPI_Request reqs[2] = {NULL, NULL};
377
378     size = comm->size();
379     rank = comm->rank();
380
381     XBT_DEBUG(  "coll:tuned:reduce_scatter_ompi_ring rank %d, size %d",
382                  rank, size);
383
384     /* Determine the maximum number of elements per node,
385        corresponding block size, and displacements array.
386     */
387     displs = (int*) xbt_malloc(size * sizeof(int));
388     if (NULL == displs) { ret = -1; line = __LINE__; goto error_hndl; }
389     displs[0] = 0;
390     total_count = rcounts[0];
391     max_block_count = rcounts[0];
392     for (i = 1; i < size; i++) {
393         displs[i] = total_count;
394         total_count += rcounts[i];
395         if (max_block_count < rcounts[i]) max_block_count = rcounts[i];
396     }
397
398     /* Special case for size == 1 */
399     if (1 == size) {
400         if (MPI_IN_PLACE != sbuf) {
401             ret = Datatype::copy((char*)sbuf, total_count, dtype, (char*)rbuf, total_count, dtype);
402             if (ret < 0) { line = __LINE__; goto error_hndl; }
403         }
404         xbt_free(displs);
405         return MPI_SUCCESS;
406     }
407
408     /* Allocate and initialize temporary buffers, we need:
409        - a temporary buffer to perform reduction (size total_count) since
410        rbuf can be of rcounts[rank] size.
411        - up to two temporary buffers used for communication/computation overlap.
412     */
413     dtype->extent(&lb, &extent);
414     dtype->extent(&true_lb, &true_extent);
415
416     max_real_segsize = true_extent + (ptrdiff_t)(max_block_count - 1) * extent;
417
418     accumbuf_free = (char*)smpi_get_tmp_recvbuffer(true_extent + (ptrdiff_t)(total_count - 1) * extent);
419     if (NULL == accumbuf_free) { ret = -1; line = __LINE__; goto error_hndl; }
420     accumbuf = accumbuf_free - lb;
421
422     inbuf_free[0] = (char*)smpi_get_tmp_sendbuffer(max_real_segsize);
423     if (NULL == inbuf_free[0]) { ret = -1; line = __LINE__; goto error_hndl; }
424     inbuf[0] = inbuf_free[0] - lb;
425     if (size > 2) {
426         inbuf_free[1] = (char*)smpi_get_tmp_sendbuffer(max_real_segsize);
427         if (NULL == inbuf_free[1]) { ret = -1; line = __LINE__; goto error_hndl; }
428         inbuf[1] = inbuf_free[1] - lb;
429     }
430
431     /* Handle MPI_IN_PLACE for size > 1 */
432     if (MPI_IN_PLACE == sbuf) {
433         sbuf = rbuf;
434     }
435
436     ret = Datatype::copy((char*)sbuf, total_count, dtype, accumbuf, total_count, dtype);
437     if (ret < 0) { line = __LINE__; goto error_hndl; }
438
439     /* Computation loop */
440
441     /*
442        For each of the remote nodes:
443        - post irecv for block (r-2) from (r-1) with wrap around
444        - send block (r-1) to (r+1)
445        - in loop for every step k = 2 .. n
446        - post irecv for block (r - 1 + n - k) % n
447        - wait on block (r + n - k) % n to arrive
448        - compute on block (r + n - k ) % n
449        - send block (r + n - k) % n
450        - wait on block (r)
451        - compute on block (r)
452        - copy block (r) to rbuf
453        Note that we must be careful when computing the begining of buffers and
454        for send operations and computation we must compute the exact block size.
455     */
456     send_to = (rank + 1) % size;
457     recv_from = (rank + size - 1) % size;
458
459     inbi = 0;
460     /* Initialize first receive from the neighbor on the left */
461     reqs[inbi]=Request::irecv(inbuf[inbi], max_block_count, dtype, recv_from,
462                              COLL_TAG_REDUCE_SCATTER, comm
463                              );
464     tmpsend = accumbuf + (ptrdiff_t)displs[recv_from] * extent;
465     Request::send(tmpsend, rcounts[recv_from], dtype, send_to,
466                             COLL_TAG_REDUCE_SCATTER,
467                              comm);
468
469     for (k = 2; k < size; k++) {
470         const int prevblock = (rank + size - k) % size;
471
472         inbi = inbi ^ 0x1;
473
474         /* Post irecv for the current block */
475         reqs[inbi]=Request::irecv(inbuf[inbi], max_block_count, dtype, recv_from,
476                                  COLL_TAG_REDUCE_SCATTER, comm
477                                  );
478
479         /* Wait on previous block to arrive */
480         Request::wait(&reqs[inbi ^ 0x1], MPI_STATUS_IGNORE);
481
482         /* Apply operation on previous block: result goes to rbuf
483            rbuf[prevblock] = inbuf[inbi ^ 0x1] (op) rbuf[prevblock]
484         */
485         tmprecv = accumbuf + (ptrdiff_t)displs[prevblock] * extent;
486         if (op != MPI_OP_NULL)
487           op->apply(inbuf[inbi ^ 0x1], tmprecv, &rcounts[prevblock], dtype);
488
489         /* send previous block to send_to */
490         Request::send(tmprecv, rcounts[prevblock], dtype, send_to,
491                                 COLL_TAG_REDUCE_SCATTER,
492                                  comm);
493     }
494
495     /* Wait on the last block to arrive */
496     Request::wait(&reqs[inbi], MPI_STATUS_IGNORE);
497
498     /* Apply operation on the last block (my block)
499        rbuf[rank] = inbuf[inbi] (op) rbuf[rank] */
500     tmprecv = accumbuf + (ptrdiff_t)displs[rank] * extent;
501     if (op != MPI_OP_NULL)
502       op->apply(inbuf[inbi], tmprecv, &rcounts[rank], dtype);
503
504     /* Copy result from tmprecv to rbuf */
505     ret = Datatype::copy(tmprecv, rcounts[rank], dtype, (char*)rbuf, rcounts[rank], dtype);
506     if (ret < 0) { line = __LINE__; goto error_hndl; }
507
508     if (NULL != displs) xbt_free(displs);
509     if (NULL != accumbuf_free) smpi_free_tmp_buffer(accumbuf_free);
510     if (NULL != inbuf_free[0]) smpi_free_tmp_buffer(inbuf_free[0]);
511     if (NULL != inbuf_free[1]) smpi_free_tmp_buffer(inbuf_free[1]);
512
513     return MPI_SUCCESS;
514
515  error_hndl:
516     XBT_DEBUG( "%s:%4d\tRank %d Error occurred %d\n",
517                  __FILE__, line, rank, ret);
518     if (NULL != displs) xbt_free(displs);
519     if (NULL != accumbuf_free) smpi_free_tmp_buffer(accumbuf_free);
520     if (NULL != inbuf_free[0]) smpi_free_tmp_buffer(inbuf_free[0]);
521     if (NULL != inbuf_free[1]) smpi_free_tmp_buffer(inbuf_free[1]);
522     return ret;
523 }
524 }
525 }
526