Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'master' of https://framagit.org/simgrid/simgrid
[simgrid.git] / src / smpi / colls / reduce_scatter / reduce_scatter-ompi.cpp
1 /* Copyright (c) 2013-2019. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
9  *                         University Research and Technology
10  *                         Corporation.  All rights reserved.
11  * Copyright (c) 2004-2012 The University of Tennessee and The University
12  *                         of Tennessee Research Foundation.  All rights
13  *                         reserved.
14  * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
15  *                         University of Stuttgart.  All rights reserved.
16  * Copyright (c) 2004-2005 The Regents of the University of California.
17  *                         All rights reserved.
18  * Copyright (c) 2008      Sun Microsystems, Inc.  All rights reserved.
19  * Copyright (c) 2009      University of Houston. All rights reserved.
20  *
21  * Additional copyrights may follow
22  */
23
24 #include "../coll_tuned_topo.hpp"
25 #include "../colls_private.hpp"
26
27 /*
28  * Recursive-halving function is (*mostly*) copied from the BASIC coll module.
29  * I have removed the part which handles "large" message sizes
30  * (non-overlapping version of reduce_Scatter).
31  */
32
33 /* copied function (with appropriate renaming) starts here */
34
35 /*
36  *  reduce_scatter_ompi_basic_recursivehalving
37  *
38  *  Function:   - reduce scatter implementation using recursive-halving
39  *                algorithm
40  *  Accepts:    - same as MPI_Reduce_scatter()
41  *  Returns:    - MPI_SUCCESS or error code
42  *  Limitation: - Works only for commutative operations.
43  */
44 namespace simgrid{
45 namespace smpi{
46 int
47 Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(const void *sbuf,
48                                                             void *rbuf,
49                                                             const int *rcounts,
50                                                             MPI_Datatype dtype,
51                                                             MPI_Op op,
52                                                             MPI_Comm comm
53                                                             )
54 {
55     int i, rank, size, count, err = MPI_SUCCESS;
56     int tmp_size = 1, remain = 0, tmp_rank;
57     ptrdiff_t true_lb, true_extent, lb, extent, buf_size;
58     unsigned char *result_buf = nullptr, *result_buf_free = nullptr;
59
60     /* Initialize */
61     rank = comm->rank();
62     size = comm->size();
63
64     XBT_DEBUG("coll:tuned:reduce_scatter_ompi_basic_recursivehalving, rank %d", rank);
65     if ((op != MPI_OP_NULL && not op->is_commutative()))
66       THROWF(arg_error,0, " reduce_scatter ompi_basic_recursivehalving can only be used for commutative operations! ");
67
68     /* Find displacements and the like */
69     int* disps = new int[size];
70
71     disps[0] = 0;
72     for (i = 0; i < (size - 1); ++i) {
73         disps[i + 1] = disps[i] + rcounts[i];
74     }
75     count = disps[size - 1] + rcounts[size - 1];
76
77     /* short cut the trivial case */
78     if (0 == count) {
79       delete[] disps;
80       return MPI_SUCCESS;
81     }
82
83     /* get datatype information */
84     dtype->extent(&lb, &extent);
85     dtype->extent(&true_lb, &true_extent);
86     buf_size = true_extent + (ptrdiff_t)(count - 1) * extent;
87
88     /* Handle MPI_IN_PLACE */
89     if (MPI_IN_PLACE == sbuf) {
90         sbuf = rbuf;
91     }
92
93     /* Allocate temporary receive buffer. */
94     unsigned char* recv_buf_free = smpi_get_tmp_recvbuffer(buf_size);
95     unsigned char* recv_buf      = recv_buf_free - lb;
96     if (NULL == recv_buf_free) {
97         err = MPI_ERR_OTHER;
98         goto cleanup;
99     }
100
101     /* allocate temporary buffer for results */
102     result_buf_free = smpi_get_tmp_sendbuffer(buf_size);
103     result_buf = result_buf_free - lb;
104
105     /* copy local buffer into the temporary results */
106     err =Datatype::copy(sbuf, count, dtype, result_buf, count, dtype);
107     if (MPI_SUCCESS != err) goto cleanup;
108
109     /* figure out power of two mapping: grow until larger than
110        comm size, then go back one, to get the largest power of
111        two less than comm size */
112     while (tmp_size <= size) tmp_size <<= 1;
113     tmp_size >>= 1;
114     remain = size - tmp_size;
115
116     /* If comm size is not a power of two, have the first "remain"
117        procs with an even rank send to rank + 1, leaving a power of
118        two procs to do the rest of the algorithm */
119     if (rank < 2 * remain) {
120         if ((rank & 1) == 0) {
121             Request::send(result_buf, count, dtype, rank + 1,
122                                     COLL_TAG_REDUCE_SCATTER,
123                                     comm);
124             /* we don't participate from here on out */
125             tmp_rank = -1;
126         } else {
127             Request::recv(recv_buf, count, dtype, rank - 1,
128                                     COLL_TAG_REDUCE_SCATTER,
129                                     comm, MPI_STATUS_IGNORE);
130
131             /* integrate their results into our temp results */
132             if(op!=MPI_OP_NULL) op->apply( recv_buf, result_buf, &count, dtype);
133
134             /* adjust rank to be the bottom "remain" ranks */
135             tmp_rank = rank / 2;
136         }
137     } else {
138         /* just need to adjust rank to show that the bottom "even
139            remain" ranks dropped out */
140         tmp_rank = rank - remain;
141     }
142
143     /* For ranks not kicked out by the above code, perform the
144        recursive halving */
145     if (tmp_rank >= 0) {
146         int mask, send_index, recv_index, last_index;
147
148         /* recalculate disps and rcounts to account for the
149            special "remainder" processes that are no longer doing
150            anything */
151         int* tmp_rcounts = new int[tmp_size];
152         int* tmp_disps   = new int[tmp_size];
153
154         for (i = 0 ; i < tmp_size ; ++i) {
155             if (i < remain) {
156                 /* need to include old neighbor as well */
157                 tmp_rcounts[i] = rcounts[i * 2 + 1] + rcounts[i * 2];
158             } else {
159                 tmp_rcounts[i] = rcounts[i + remain];
160             }
161         }
162
163         tmp_disps[0] = 0;
164         for (i = 0; i < tmp_size - 1; ++i) {
165             tmp_disps[i + 1] = tmp_disps[i] + tmp_rcounts[i];
166         }
167
168         /* do the recursive halving communication.  Don't use the
169            dimension information on the communicator because I
170            think the information is invalidated by our "shrinking"
171            of the communicator */
172         mask = tmp_size >> 1;
173         send_index = recv_index = 0;
174         last_index = tmp_size;
175         while (mask > 0) {
176             int tmp_peer, peer, send_count, recv_count;
177             MPI_Request request;
178
179             tmp_peer = tmp_rank ^ mask;
180             peer = (tmp_peer < remain) ? tmp_peer * 2 + 1 : tmp_peer + remain;
181
182             /* figure out if we're sending, receiving, or both */
183             send_count = recv_count = 0;
184             if (tmp_rank < tmp_peer) {
185                 send_index = recv_index + mask;
186                 for (i = send_index ; i < last_index ; ++i) {
187                     send_count += tmp_rcounts[i];
188                 }
189                 for (i = recv_index ; i < send_index ; ++i) {
190                     recv_count += tmp_rcounts[i];
191                 }
192             } else {
193                 recv_index = send_index + mask;
194                 for (i = send_index ; i < recv_index ; ++i) {
195                     send_count += tmp_rcounts[i];
196                 }
197                 for (i = recv_index ; i < last_index ; ++i) {
198                     recv_count += tmp_rcounts[i];
199                 }
200             }
201
202             /* actual data transfer.  Send from result_buf,
203                receive into recv_buf */
204             if (send_count > 0 && recv_count != 0) {
205                 request=Request::irecv(recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
206                                          recv_count, dtype, peer,
207                                          COLL_TAG_REDUCE_SCATTER,
208                                          comm);
209                 if (MPI_SUCCESS != err) {
210                   delete[] tmp_rcounts;
211                   delete[] tmp_disps;
212                   goto cleanup;
213                 }
214             }
215             if (recv_count > 0 && send_count != 0) {
216                 Request::send(result_buf + (ptrdiff_t)tmp_disps[send_index] * extent,
217                                         send_count, dtype, peer,
218                                         COLL_TAG_REDUCE_SCATTER,
219                                         comm);
220                 if (MPI_SUCCESS != err) {
221                   delete[] tmp_rcounts;
222                   delete[] tmp_disps;
223                   goto cleanup;
224                 }
225             }
226             if (send_count > 0 && recv_count != 0) {
227                 Request::wait(&request, MPI_STATUS_IGNORE);
228             }
229
230             /* if we received something on this step, push it into
231                the results buffer */
232             if (recv_count > 0) {
233                 if(op!=MPI_OP_NULL) op->apply(
234                                recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
235                                result_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
236                                &recv_count, dtype);
237             }
238
239             /* update for next iteration */
240             send_index = recv_index;
241             last_index = recv_index + mask;
242             mask >>= 1;
243         }
244
245         /* copy local results from results buffer into real receive buffer */
246         if (0 != rcounts[rank]) {
247             err = Datatype::copy(result_buf + disps[rank] * extent,
248                                        rcounts[rank], dtype,
249                                        rbuf, rcounts[rank], dtype);
250             if (MPI_SUCCESS != err) {
251               delete[] tmp_rcounts;
252               delete[] tmp_disps;
253               goto cleanup;
254             }
255         }
256
257         delete[] tmp_rcounts;
258         delete[] tmp_disps;
259     }
260
261     /* Now fix up the non-power of two case, by having the odd
262        procs send the even procs the proper results */
263     if (rank < (2 * remain)) {
264         if ((rank & 1) == 0) {
265             if (rcounts[rank]) {
266                 Request::recv(rbuf, rcounts[rank], dtype, rank + 1,
267                                         COLL_TAG_REDUCE_SCATTER,
268                                         comm, MPI_STATUS_IGNORE);
269             }
270         } else {
271             if (rcounts[rank - 1]) {
272                 Request::send(result_buf + disps[rank - 1] * extent,
273                                         rcounts[rank - 1], dtype, rank - 1,
274                                         COLL_TAG_REDUCE_SCATTER,
275                                         comm);
276             }
277         }
278     }
279
280  cleanup:
281     delete[] disps;
282     if (NULL != recv_buf_free) smpi_free_tmp_buffer(recv_buf_free);
283     if (NULL != result_buf_free) smpi_free_tmp_buffer(result_buf_free);
284
285     return err;
286 }
287
288 /* copied function (with appropriate renaming) ends here */
289
290
291 /*
292  *   Coll_reduce_scatter_ompi_ring::reduce_scatter
293  *
294  *   Function:       Ring algorithm for reduce_scatter operation
295  *   Accepts:        Same as MPI_Reduce_scatter()
296  *   Returns:        MPI_SUCCESS or error code
297  *
298  *   Description:    Implements ring algorithm for reduce_scatter:
299  *                   the block sizes defined in rcounts are exchanged and
300  *                   updated until they reach proper destination.
301  *                   Algorithm requires 2 * max(rcounts) extra buffering
302  *
303  *   Limitations:    The algorithm DOES NOT preserve order of operations so it
304  *                   can be used only for commutative operations.
305  *         Example on 5 nodes:
306  *         Initial state
307  *   #      0              1             2              3             4
308  *        [00]           [10]   ->     [20]           [30]           [40]
309  *        [01]           [11]          [21]  ->       [31]           [41]
310  *        [02]           [12]          [22]           [32]  ->       [42]
311  *    ->  [03]           [13]          [23]           [33]           [43] --> ..
312  *        [04]  ->       [14]          [24]           [34]           [44]
313  *
314  *        COMPUTATION PHASE
315  *         Step 0: rank r sends block (r-1) to rank (r+1) and
316  *                 receives block (r+1) from rank (r-1) [with wraparound].
317  *   #      0              1             2              3             4
318  *        [00]           [10]        [10+20]   ->     [30]           [40]
319  *        [01]           [11]          [21]          [21+31]  ->     [41]
320  *    ->  [02]           [12]          [22]           [32]         [32+42] -->..
321  *      [43+03] ->       [13]          [23]           [33]           [43]
322  *        [04]         [04+14]  ->     [24]           [34]           [44]
323  *
324  *         Step 1:
325  *   #      0              1             2              3             4
326  *        [00]           [10]        [10+20]       [10+20+30] ->     [40]
327  *    ->  [01]           [11]          [21]          [21+31]      [21+31+41] ->
328  *     [32+42+02] ->     [12]          [22]           [32]         [32+42]
329  *        [03]        [43+03+13] ->    [23]           [33]           [43]
330  *        [04]         [04+14]      [04+14+24]  ->    [34]           [44]
331  *
332  *         Step 2:
333  *   #      0              1             2              3             4
334  *     -> [00]           [10]        [10+20]       [10+20+30]   [10+20+30+40] ->
335  *   [21+31+41+01]->     [11]          [21]          [21+31]      [21+31+41]
336  *     [32+42+02]   [32+42+02+12]->    [22]           [32]         [32+42]
337  *        [03]        [43+03+13]   [43+03+13+23]->    [33]           [43]
338  *        [04]         [04+14]      [04+14+24]    [04+14+24+34] ->   [44]
339  *
340  *         Step 3:
341  *   #      0             1              2              3             4
342  * [10+20+30+40+00]     [10]         [10+20]       [10+20+30]   [10+20+30+40]
343  *  [21+31+41+01] [21+31+41+01+11]     [21]          [21+31]      [21+31+41]
344  *    [32+42+02]   [32+42+02+12] [32+42+02+12+22]     [32]         [32+42]
345  *       [03]        [43+03+13]    [43+03+13+23] [43+03+13+23+33]    [43]
346  *       [04]         [04+14]       [04+14+24]    [04+14+24+34] [04+14+24+34+44]
347  *    DONE :)
348  *
349  */
350 int
351 Coll_reduce_scatter_ompi_ring::reduce_scatter(const void *sbuf, void *rbuf, const int *rcounts,
352                                           MPI_Datatype dtype,
353                                           MPI_Op op,
354                                           MPI_Comm comm
355                                           )
356 {
357     int ret, line, rank, size, i, k, recv_from, send_to, total_count, max_block_count;
358     int inbi;
359     unsigned char *tmpsend = NULL, *tmprecv = NULL, *accumbuf = NULL, *accumbuf_free = NULL;
360     unsigned char *inbuf_free[2] = {NULL, NULL}, *inbuf[2] = {NULL, NULL};
361     ptrdiff_t true_lb, true_extent, lb, extent, max_real_segsize;
362     MPI_Request reqs[2] = {NULL, NULL};
363
364     size = comm->size();
365     rank = comm->rank();
366
367     XBT_DEBUG(  "coll:tuned:reduce_scatter_ompi_ring rank %d, size %d",
368                  rank, size);
369
370     /* Determine the maximum number of elements per node,
371        corresponding block size, and displacements array.
372     */
373     int* displs = new int[size];
374
375     displs[0] = 0;
376     total_count = rcounts[0];
377     max_block_count = rcounts[0];
378     for (i = 1; i < size; i++) {
379         displs[i] = total_count;
380         total_count += rcounts[i];
381         if (max_block_count < rcounts[i]) max_block_count = rcounts[i];
382     }
383
384     /* Special case for size == 1 */
385     if (1 == size) {
386         if (MPI_IN_PLACE != sbuf) {
387             ret = Datatype::copy((char*)sbuf, total_count, dtype, (char*)rbuf, total_count, dtype);
388             if (ret < 0) { line = __LINE__; goto error_hndl; }
389         }
390         delete[] displs;
391         return MPI_SUCCESS;
392     }
393
394     /* Allocate and initialize temporary buffers, we need:
395        - a temporary buffer to perform reduction (size total_count) since
396        rbuf can be of rcounts[rank] size.
397        - up to two temporary buffers used for communication/computation overlap.
398     */
399     dtype->extent(&lb, &extent);
400     dtype->extent(&true_lb, &true_extent);
401
402     max_real_segsize = true_extent + (ptrdiff_t)(max_block_count - 1) * extent;
403
404     accumbuf_free = smpi_get_tmp_recvbuffer(true_extent + (ptrdiff_t)(total_count - 1) * extent);
405     if (NULL == accumbuf_free) { ret = -1; line = __LINE__; goto error_hndl; }
406     accumbuf = accumbuf_free - lb;
407
408     inbuf_free[0] = smpi_get_tmp_sendbuffer(max_real_segsize);
409     if (NULL == inbuf_free[0]) { ret = -1; line = __LINE__; goto error_hndl; }
410     inbuf[0] = inbuf_free[0] - lb;
411     if (size > 2) {
412       inbuf_free[1] = smpi_get_tmp_sendbuffer(max_real_segsize);
413       if (NULL == inbuf_free[1]) {
414         ret  = -1;
415         line = __LINE__;
416         goto error_hndl;
417       }
418       inbuf[1] = inbuf_free[1] - lb;
419     }
420
421     /* Handle MPI_IN_PLACE for size > 1 */
422     if (MPI_IN_PLACE == sbuf) {
423         sbuf = rbuf;
424     }
425
426     ret = Datatype::copy((char*)sbuf, total_count, dtype, accumbuf, total_count, dtype);
427     if (ret < 0) { line = __LINE__; goto error_hndl; }
428
429     /* Computation loop */
430
431     /*
432        For each of the remote nodes:
433        - post irecv for block (r-2) from (r-1) with wrap around
434        - send block (r-1) to (r+1)
435        - in loop for every step k = 2 .. n
436        - post irecv for block (r - 1 + n - k) % n
437        - wait on block (r + n - k) % n to arrive
438        - compute on block (r + n - k ) % n
439        - send block (r + n - k) % n
440        - wait on block (r)
441        - compute on block (r)
442        - copy block (r) to rbuf
443        Note that we must be careful when computing the begining of buffers and
444        for send operations and computation we must compute the exact block size.
445     */
446     send_to = (rank + 1) % size;
447     recv_from = (rank + size - 1) % size;
448
449     inbi = 0;
450     /* Initialize first receive from the neighbor on the left */
451     reqs[inbi]=Request::irecv(inbuf[inbi], max_block_count, dtype, recv_from,
452                              COLL_TAG_REDUCE_SCATTER, comm
453                              );
454     tmpsend = accumbuf + (ptrdiff_t)displs[recv_from] * extent;
455     Request::send(tmpsend, rcounts[recv_from], dtype, send_to,
456                             COLL_TAG_REDUCE_SCATTER,
457                              comm);
458
459     for (k = 2; k < size; k++) {
460         const int prevblock = (rank + size - k) % size;
461
462         inbi = inbi ^ 0x1;
463
464         /* Post irecv for the current block */
465         reqs[inbi]=Request::irecv(inbuf[inbi], max_block_count, dtype, recv_from,
466                                  COLL_TAG_REDUCE_SCATTER, comm
467                                  );
468
469         /* Wait on previous block to arrive */
470         Request::wait(&reqs[inbi ^ 0x1], MPI_STATUS_IGNORE);
471
472         /* Apply operation on previous block: result goes to rbuf
473            rbuf[prevblock] = inbuf[inbi ^ 0x1] (op) rbuf[prevblock]
474         */
475         tmprecv = accumbuf + (ptrdiff_t)displs[prevblock] * extent;
476         if (op != MPI_OP_NULL)
477           op->apply(inbuf[inbi ^ 0x1], tmprecv, &rcounts[prevblock], dtype);
478
479         /* send previous block to send_to */
480         Request::send(tmprecv, rcounts[prevblock], dtype, send_to,
481                                 COLL_TAG_REDUCE_SCATTER,
482                                  comm);
483     }
484
485     /* Wait on the last block to arrive */
486     Request::wait(&reqs[inbi], MPI_STATUS_IGNORE);
487
488     /* Apply operation on the last block (my block)
489        rbuf[rank] = inbuf[inbi] (op) rbuf[rank] */
490     tmprecv = accumbuf + (ptrdiff_t)displs[rank] * extent;
491     if (op != MPI_OP_NULL)
492       op->apply(inbuf[inbi], tmprecv, &rcounts[rank], dtype);
493
494     /* Copy result from tmprecv to rbuf */
495     ret = Datatype::copy(tmprecv, rcounts[rank], dtype, (char*)rbuf, rcounts[rank], dtype);
496     if (ret < 0) { line = __LINE__; goto error_hndl; }
497
498     delete[] displs;
499     if (NULL != accumbuf_free) smpi_free_tmp_buffer(accumbuf_free);
500     if (NULL != inbuf_free[0]) smpi_free_tmp_buffer(inbuf_free[0]);
501     if (NULL != inbuf_free[1]) smpi_free_tmp_buffer(inbuf_free[1]);
502
503     return MPI_SUCCESS;
504
505  error_hndl:
506     XBT_DEBUG( "%s:%4d\tRank %d Error occurred %d\n",
507                  __FILE__, line, rank, ret);
508     delete[] displs;
509     if (NULL != accumbuf_free) smpi_free_tmp_buffer(accumbuf_free);
510     if (NULL != inbuf_free[0]) smpi_free_tmp_buffer(inbuf_free[0]);
511     if (NULL != inbuf_free[1]) smpi_free_tmp_buffer(inbuf_free[1]);
512     return ret;
513 }
514 }
515 }
516