src/smpi/colls/reduce_scatter/reduce_scatter-ompi.cpp

/* Copyright (c) 2013-2019. The SimGrid Team.
 * All rights reserved.                                                     */

/* This program is free software; you can redistribute it and/or modify it
 * under the terms of the license (GNU LGPL) which comes with this package. */

/*
 * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
 *                         University Research and Technology
 *                         Corporation.  All rights reserved.
 * Copyright (c) 2004-2012 The University of Tennessee and The University
 *                         of Tennessee Research Foundation.  All rights
 *                         reserved.
 * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
 *                         University of Stuttgart.  All rights reserved.
 * Copyright (c) 2004-2005 The Regents of the University of California.
 *                         All rights reserved.
 * Copyright (c) 2008      Sun Microsystems, Inc.  All rights reserved.
 * Copyright (c) 2009      University of Houston. All rights reserved.
 *
 * Additional copyrights may follow
 */

#include "../coll_tuned_topo.hpp"
#include "../colls_private.hpp"

/*
 * Recursive-halving function is (*mostly*) copied from the BASIC coll module.
 * I have removed the part which handles "large" message sizes
 * (non-overlapping version of reduce_scatter).
 */

/* copied function (with appropriate renaming) starts here */

/*
 *  reduce_scatter_ompi_basic_recursivehalving
 *
 *  Function:   - reduce scatter implementation using recursive-halving
 *                algorithm
 *  Accepts:    - same as MPI_Reduce_scatter()
 *  Returns:    - MPI_SUCCESS or error code
 *  Limitation: - Works only for commutative operations.
 */
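/*
 *  Overview of the implementation below:
 *    1) if the communicator size is not a power of two, the first "remain"
 *       even-ranked processes fold their data into their odd neighbours and
 *       drop out;
 *    2) the remaining power-of-two group runs the recursive-halving
 *       reduction, each step halving the range of blocks a process is
 *       responsible for;
 *    3) the dropped-out processes get their final block back from the odd
 *       neighbour that absorbed them.
 */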
namespace simgrid{
namespace smpi{
int
Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(const void *sbuf,
                                                            void *rbuf,
                                                            const int *rcounts,
                                                            MPI_Datatype dtype,
                                                            MPI_Op op,
                                                            MPI_Comm comm
                                                            )
{
    int i, rank, size, count, err = MPI_SUCCESS;
    int tmp_size = 1, remain = 0, tmp_rank;
    ptrdiff_t true_lb, true_extent, lb, extent, buf_size;
    char *recv_buf = NULL, *recv_buf_free = NULL;
    char *result_buf = NULL, *result_buf_free = NULL;

    /* Initialize */
    rank = comm->rank();
    size = comm->size();

    XBT_DEBUG("coll:tuned:reduce_scatter_ompi_basic_recursivehalving, rank %d", rank);
    if ((op != MPI_OP_NULL && not op->is_commutative()))
      THROWF(arg_error,0, " reduce_scatter ompi_basic_recursivehalving can only be used for commutative operations! ");

    /* Find displacements and the like */
    int* disps = new int[size];

    disps[0] = 0;
    for (i = 0; i < (size - 1); ++i) {
        disps[i + 1] = disps[i] + rcounts[i];
    }
    count = disps[size - 1] + rcounts[size - 1];

    /* short cut the trivial case */
    if (0 == count) {
      delete[] disps;
      return MPI_SUCCESS;
    }

    /* get datatype information */
    dtype->extent(&lb, &extent);
    dtype->extent(&true_lb, &true_extent);
    buf_size = true_extent + (ptrdiff_t)(count - 1) * extent;
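    /* Note: true_extent + (count - 1) * extent is the number of bytes spanned
       by `count` elements of the datatype, which is what the temporary
       reduction buffers below must be able to hold. */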

    /* Handle MPI_IN_PLACE */
    if (MPI_IN_PLACE == sbuf) {
        sbuf = rbuf;
    }

    /* Allocate temporary receive buffer. */
    recv_buf_free = (char*) smpi_get_tmp_recvbuffer(buf_size);

    recv_buf = recv_buf_free - lb;
    if (NULL == recv_buf_free) {
        err = MPI_ERR_OTHER;
        goto cleanup;
    }

    /* allocate temporary buffer for results */
    result_buf_free = (char*) smpi_get_tmp_sendbuffer(buf_size);

    result_buf = result_buf_free - lb;

    /* copy local buffer into the temporary results */
    err = Datatype::copy(sbuf, count, dtype, result_buf, count, dtype);
    if (MPI_SUCCESS != err) goto cleanup;

    /* figure out power of two mapping: grow until larger than
       comm size, then go back one, to get the largest power of
       two less than or equal to comm size */
    while (tmp_size <= size) tmp_size <<= 1;
    tmp_size >>= 1;
    remain = size - tmp_size;
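    /* Example: with size == 5 the loop gives tmp_size == 4 and remain == 1,
       so one even/odd pair (ranks 0 and 1) is folded below and four
       processes run the power-of-two part of the algorithm. */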

    /* If comm size is not a power of two, have the first "remain"
       procs with an even rank send to rank + 1, leaving a power of
       two procs to do the rest of the algorithm */
    if (rank < 2 * remain) {
        if ((rank & 1) == 0) {
            Request::send(result_buf, count, dtype, rank + 1,
                                    COLL_TAG_REDUCE_SCATTER,
                                    comm);
            /* we don't participate from here on out */
            tmp_rank = -1;
        } else {
            Request::recv(recv_buf, count, dtype, rank - 1,
                                    COLL_TAG_REDUCE_SCATTER,
                                    comm, MPI_STATUS_IGNORE);

            /* integrate their results into our temp results */
            if(op!=MPI_OP_NULL) op->apply( recv_buf, result_buf, &count, dtype);

            /* adjust rank to be the bottom "remain" ranks */
            tmp_rank = rank / 2;
        }
    } else {
        /* just need to adjust rank to show that the bottom "even
           remain" ranks dropped out */
        tmp_rank = rank - remain;
    }
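    /* From here on the surviving processes (tmp_rank != -1) behave like a
       communicator of exactly tmp_size ranks, a power of two, numbered by
       tmp_rank. */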

    /* For ranks not kicked out by the above code, perform the
       recursive halving */
    if (tmp_rank >= 0) {
        int mask, send_index, recv_index, last_index;

        /* recalculate disps and rcounts to account for the
           special "remainder" processes that are no longer doing
           anything */
        int* tmp_rcounts = new int[tmp_size];
        int* tmp_disps   = new int[tmp_size];

        for (i = 0 ; i < tmp_size ; ++i) {
            if (i < remain) {
                /* need to include old neighbor as well */
                tmp_rcounts[i] = rcounts[i * 2 + 1] + rcounts[i * 2];
            } else {
                tmp_rcounts[i] = rcounts[i + remain];
            }
        }

        tmp_disps[0] = 0;
        for (i = 0; i < tmp_size - 1; ++i) {
            tmp_disps[i + 1] = tmp_disps[i] + tmp_rcounts[i];
        }
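        /* Continuing the size == 5 example: tmp_rcounts[0] holds the merged
           block of ranks 0 and 1 (rcounts[0] + rcounts[1]), while
           tmp_rcounts[1..3] are simply rcounts[2..4]; tmp_disps are the
           prefix sums of these merged counts. */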

        /* do the recursive halving communication.  Don't use the
           dimension information on the communicator because I
           think the information is invalidated by our "shrinking"
           of the communicator */
        mask = tmp_size >> 1;
        send_index = recv_index = 0;
        last_index = tmp_size;
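        /* Loop invariant: at the top of each iteration this process still
           owns the partial results for the 2*mask blocks in
           [send_index, last_index); each step keeps the half of that window
           containing tmp_rank's own block and hands the other half to the
           peer. */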
        while (mask > 0) {
            int tmp_peer, peer, send_count, recv_count;
            MPI_Request request;

            tmp_peer = tmp_rank ^ mask;
            peer = (tmp_peer < remain) ? tmp_peer * 2 + 1 : tmp_peer + remain;

            /* figure out if we're sending, receiving, or both */
            send_count = recv_count = 0;
            if (tmp_rank < tmp_peer) {
                send_index = recv_index + mask;
                for (i = send_index ; i < last_index ; ++i) {
                    send_count += tmp_rcounts[i];
                }
                for (i = recv_index ; i < send_index ; ++i) {
                    recv_count += tmp_rcounts[i];
                }
            } else {
                recv_index = send_index + mask;
                for (i = send_index ; i < recv_index ; ++i) {
                    send_count += tmp_rcounts[i];
                }
                for (i = recv_index ; i < last_index ; ++i) {
                    recv_count += tmp_rcounts[i];
                }
            }

            /* actual data transfer.  Send from result_buf,
               receive into recv_buf */
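            /* Both guards below require send_count and recv_count to be
               non-zero.  Since the peer's recv_count equals our send_count
               and vice versa, the peer applies exactly the mirrored guard,
               so an irecv is posted here if and only if the peer posts the
               matching send. */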
            if (send_count > 0 && recv_count != 0) {
                request=Request::irecv(recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
                                         recv_count, dtype, peer,
                                         COLL_TAG_REDUCE_SCATTER,
                                         comm);
                if (MPI_SUCCESS != err) {
                  delete[] tmp_rcounts;
                  delete[] tmp_disps;
                  goto cleanup;
                }
            }
            if (recv_count > 0 && send_count != 0) {
                Request::send(result_buf + (ptrdiff_t)tmp_disps[send_index] * extent,
                                        send_count, dtype, peer,
                                        COLL_TAG_REDUCE_SCATTER,
                                        comm);
                if (MPI_SUCCESS != err) {
                  delete[] tmp_rcounts;
                  delete[] tmp_disps;
                  goto cleanup;
                }
            }
            if (send_count > 0 && recv_count != 0) {
                Request::wait(&request, MPI_STATUS_IGNORE);
            }

            /* if we received something on this step, push it into
               the results buffer */
            if (recv_count > 0) {
                if(op!=MPI_OP_NULL) op->apply(
                               recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
                               result_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
                               &recv_count, dtype);
            }

            /* update for next iteration */
            send_index = recv_index;
            last_index = recv_index + mask;
            mask >>= 1;
        }
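        /* At this point result_buf holds the fully reduced values for the
           final one-block window, i.e. this process's own block (or, for the
           first "remain" pairs, the merged block shared with the even
           neighbour that dropped out earlier). */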

        /* copy local results from results buffer into real receive buffer */
        if (0 != rcounts[rank]) {
            err = Datatype::copy(result_buf + disps[rank] * extent,
                                       rcounts[rank], dtype,
                                       rbuf, rcounts[rank], dtype);
            if (MPI_SUCCESS != err) {
              delete[] tmp_rcounts;
              delete[] tmp_disps;
              goto cleanup;
            }
        }

        delete[] tmp_rcounts;
        delete[] tmp_disps;
    }

    /* Now fix up the non-power of two case, by having the odd
       procs send the even procs the proper results */
    if (rank < (2 * remain)) {
        if ((rank & 1) == 0) {
            if (rcounts[rank]) {
                Request::recv(rbuf, rcounts[rank], dtype, rank + 1,
                                        COLL_TAG_REDUCE_SCATTER,
                                        comm, MPI_STATUS_IGNORE);
            }
        } else {
            if (rcounts[rank - 1]) {
                Request::send(result_buf + disps[rank - 1] * extent,
                                        rcounts[rank - 1], dtype, rank - 1,
                                        COLL_TAG_REDUCE_SCATTER,
                                        comm);
            }
        }
    }

 cleanup:
    delete[] disps;
    if (NULL != recv_buf_free) smpi_free_tmp_buffer(recv_buf_free);
    if (NULL != result_buf_free) smpi_free_tmp_buffer(result_buf_free);

    return err;
}

/* copied function (with appropriate renaming) ends here */


/*
 *   Coll_reduce_scatter_ompi_ring::reduce_scatter
 *
 *   Function:       Ring algorithm for reduce_scatter operation
 *   Accepts:        Same as MPI_Reduce_scatter()
 *   Returns:        MPI_SUCCESS or error code
 *
 *   Description:    Implements ring algorithm for reduce_scatter:
 *                   the block sizes defined in rcounts are exchanged and
 *                   updated until they reach proper destination.
 *                   Algorithm requires 2 * max(rcounts) extra buffering
 *
 *   Limitations:    The algorithm DOES NOT preserve order of operations so it
 *                   can be used only for commutative operations.
 *         Example on 5 nodes:
 *         Initial state
 *   #      0              1             2              3             4
 *        [00]           [10]   ->     [20]           [30]           [40]
 *        [01]           [11]          [21]  ->       [31]           [41]
 *        [02]           [12]          [22]           [32]  ->       [42]
 *    ->  [03]           [13]          [23]           [33]           [43] --> ..
 *        [04]  ->       [14]          [24]           [34]           [44]
 *
 *        COMPUTATION PHASE
 *         Step 0: rank r sends block (r-1) to rank (r+1) and
 *                 receives block (r+1) from rank (r-1) [with wraparound].
 *   #      0              1             2              3             4
 *        [00]           [10]        [10+20]   ->     [30]           [40]
 *        [01]           [11]          [21]          [21+31]  ->     [41]
 *    ->  [02]           [12]          [22]           [32]         [32+42] -->..
 *      [43+03] ->       [13]          [23]           [33]           [43]
 *        [04]         [04+14]  ->     [24]           [34]           [44]
 *
 *         Step 1:
 *   #      0              1             2              3             4
 *        [00]           [10]        [10+20]       [10+20+30] ->     [40]
 *    ->  [01]           [11]          [21]          [21+31]      [21+31+41] ->
 *     [32+42+02] ->     [12]          [22]           [32]         [32+42]
 *        [03]        [43+03+13] ->    [23]           [33]           [43]
 *        [04]         [04+14]      [04+14+24]  ->    [34]           [44]
 *
 *         Step 2:
 *   #      0              1             2              3             4
 *     -> [00]           [10]        [10+20]       [10+20+30]   [10+20+30+40] ->
 *   [21+31+41+01]->     [11]          [21]          [21+31]      [21+31+41]
 *     [32+42+02]   [32+42+02+12]->    [22]           [32]         [32+42]
 *        [03]        [43+03+13]   [43+03+13+23]->    [33]           [43]
 *        [04]         [04+14]      [04+14+24]    [04+14+24+34] ->   [44]
 *
 *         Step 3:
 *   #      0             1              2              3             4
 * [10+20+30+40+00]     [10]         [10+20]       [10+20+30]   [10+20+30+40]
 *  [21+31+41+01] [21+31+41+01+11]     [21]          [21+31]      [21+31+41]
 *    [32+42+02]   [32+42+02+12] [32+42+02+12+22]     [32]         [32+42]
 *       [03]        [43+03+13]    [43+03+13+23] [43+03+13+23+33]    [43]
 *       [04]         [04+14]       [04+14+24]    [04+14+24+34] [04+14+24+34+44]
 *    DONE :)
 *
 */
int
Coll_reduce_scatter_ompi_ring::reduce_scatter(const void *sbuf, void *rbuf, const int *rcounts,
                                          MPI_Datatype dtype,
                                          MPI_Op op,
                                          MPI_Comm comm
                                          )
{
    int ret, line, rank, size, i, k, recv_from, send_to, total_count, max_block_count;
    int inbi;
    char *tmpsend = NULL, *tmprecv = NULL, *accumbuf = NULL, *accumbuf_free = NULL;
    char *inbuf_free[2] = {NULL, NULL}, *inbuf[2] = {NULL, NULL};
    ptrdiff_t true_lb, true_extent, lb, extent, max_real_segsize;
    MPI_Request reqs[2] = {NULL, NULL};

    size = comm->size();
    rank = comm->rank();

    XBT_DEBUG(  "coll:tuned:reduce_scatter_ompi_ring rank %d, size %d",
                 rank, size);

    /* Determine the maximum number of elements per node,
       corresponding block size, and displacements array.
    */
    int* displs = new int[size];

    displs[0] = 0;
    total_count = rcounts[0];
    max_block_count = rcounts[0];
    for (i = 1; i < size; i++) {
        displs[i] = total_count;
        total_count += rcounts[i];
        if (max_block_count < rcounts[i]) max_block_count = rcounts[i];
    }
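    /* After this loop displs[i] is the element offset of block i inside the
       accumulation buffer, total_count is the sum of all rcounts, and
       max_block_count (the largest single block) sizes the two staging
       buffers used below. */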

    /* Special case for size == 1 */
    if (1 == size) {
        if (MPI_IN_PLACE != sbuf) {
            ret = Datatype::copy((char*)sbuf, total_count, dtype, (char*)rbuf, total_count, dtype);
            if (ret < 0) { line = __LINE__; goto error_hndl; }
        }
        delete[] displs;
        return MPI_SUCCESS;
    }

    /* Allocate and initialize temporary buffers, we need:
       - a temporary buffer to perform reduction (size total_count) since
       rbuf can be of rcounts[rank] size.
       - up to two temporary buffers used for communication/computation overlap.
    */
    dtype->extent(&lb, &extent);
    dtype->extent(&true_lb, &true_extent);

    max_real_segsize = true_extent + (ptrdiff_t)(max_block_count - 1) * extent;

    accumbuf_free = (char*)smpi_get_tmp_recvbuffer(true_extent + (ptrdiff_t)(total_count - 1) * extent);
    if (NULL == accumbuf_free) { ret = -1; line = __LINE__; goto error_hndl; }
    accumbuf = accumbuf_free - lb;

    inbuf_free[0] = (char*)smpi_get_tmp_sendbuffer(max_real_segsize);
    if (NULL == inbuf_free[0]) { ret = -1; line = __LINE__; goto error_hndl; }
    inbuf[0] = inbuf_free[0] - lb;
    if (size > 2) {
        inbuf_free[1] = (char*)smpi_get_tmp_sendbuffer(max_real_segsize);
        if (NULL == inbuf_free[1]) { ret = -1; line = __LINE__; goto error_hndl; }
        inbuf[1] = inbuf_free[1] - lb;
    }
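    /* Two staging buffers allow the receive of block k+1 to overlap with the
       reduction and forwarding of block k; when size == 2 there is only one
       exchange, so a single buffer suffices. */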

    /* Handle MPI_IN_PLACE for size > 1 */
    if (MPI_IN_PLACE == sbuf) {
        sbuf = rbuf;
    }

    ret = Datatype::copy((char*)sbuf, total_count, dtype, accumbuf, total_count, dtype);
    if (ret < 0) { line = __LINE__; goto error_hndl; }

    /* Computation loop */

    /*
       For each of the remote nodes:
       - post irecv for block (r-2) from (r-1) with wrap around
       - send block (r-1) to (r+1)
       - in loop for every step k = 2 .. n
       - post irecv for block (r - 1 + n - k) % n
       - wait on block (r + n - k) % n to arrive
       - compute on block (r + n - k ) % n
       - send block (r + n - k) % n
       - wait on block (r)
       - compute on block (r)
       - copy block (r) to rbuf
       Note that we must be careful when computing the beginning of buffers,
       and for send operations and computations we must use the exact block size.
    */
    send_to = (rank + 1) % size;
    recv_from = (rank + size - 1) % size;

    inbi = 0;
    /* Initialize first receive from the neighbor on the left */
    reqs[inbi]=Request::irecv(inbuf[inbi], max_block_count, dtype, recv_from,
                             COLL_TAG_REDUCE_SCATTER, comm
                             );
    tmpsend = accumbuf + (ptrdiff_t)displs[recv_from] * extent;
    Request::send(tmpsend, rcounts[recv_from], dtype, send_to,
                            COLL_TAG_REDUCE_SCATTER,
                             comm);
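    /* First exchange: block (rank-1) is on its way to the right neighbour
       while block (rank-2) is being received from the left; the loop below
       then walks blocks (rank-k) % size for k = 2 .. size-1. */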

    for (k = 2; k < size; k++) {
        const int prevblock = (rank + size - k) % size;

        inbi = inbi ^ 0x1;

        /* Post irecv for the current block */
        reqs[inbi]=Request::irecv(inbuf[inbi], max_block_count, dtype, recv_from,
                                 COLL_TAG_REDUCE_SCATTER, comm
                                 );

        /* Wait on previous block to arrive */
        Request::wait(&reqs[inbi ^ 0x1], MPI_STATUS_IGNORE);

        /* Apply operation on previous block: result goes to rbuf
           rbuf[prevblock] = inbuf[inbi ^ 0x1] (op) rbuf[prevblock]
        */
        tmprecv = accumbuf + (ptrdiff_t)displs[prevblock] * extent;
        if (op != MPI_OP_NULL)
          op->apply(inbuf[inbi ^ 0x1], tmprecv, &rcounts[prevblock], dtype);

        /* send previous block to send_to */
        Request::send(tmprecv, rcounts[prevblock], dtype, send_to,
                                COLL_TAG_REDUCE_SCATTER,
                                 comm);
    }

    /* Wait on the last block to arrive */
    Request::wait(&reqs[inbi], MPI_STATUS_IGNORE);

    /* Apply operation on the last block (my block)
       rbuf[rank] = inbuf[inbi] (op) rbuf[rank] */
    tmprecv = accumbuf + (ptrdiff_t)displs[rank] * extent;
    if (op != MPI_OP_NULL)
      op->apply(inbuf[inbi], tmprecv, &rcounts[rank], dtype);

    /* Copy result from tmprecv to rbuf */
    ret = Datatype::copy(tmprecv, rcounts[rank], dtype, (char*)rbuf, rcounts[rank], dtype);
    if (ret < 0) { line = __LINE__; goto error_hndl; }

    delete[] displs;
    if (NULL != accumbuf_free) smpi_free_tmp_buffer(accumbuf_free);
    if (NULL != inbuf_free[0]) smpi_free_tmp_buffer(inbuf_free[0]);
    if (NULL != inbuf_free[1]) smpi_free_tmp_buffer(inbuf_free[1]);

    return MPI_SUCCESS;

 error_hndl:
    XBT_DEBUG( "%s:%4d\tRank %d Error occurred %d\n",
                 __FILE__, line, rank, ret);
    delete[] displs;
    if (NULL != accumbuf_free) smpi_free_tmp_buffer(accumbuf_free);
    if (NULL != inbuf_free[0]) smpi_free_tmp_buffer(inbuf_free[0]);
    if (NULL != inbuf_free[1]) smpi_free_tmp_buffer(inbuf_free[1]);
    return ret;
}
}
}