Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'mc++' into mc-merge
[simgrid.git] / src / smpi / colls / reduce_scatter-ompi.c
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
9  *                         University Research and Technology
10  *                         Corporation.  All rights reserved.
11  * Copyright (c) 2004-2012 The University of Tennessee and The University
12  *                         of Tennessee Research Foundation.  All rights
13  *                         reserved.
14  * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
15  *                         University of Stuttgart.  All rights reserved.
16  * Copyright (c) 2004-2005 The Regents of the University of California.
17  *                         All rights reserved.
18  * Copyright (c) 2008      Sun Microsystems, Inc.  All rights reserved.
19  * Copyright (c) 2009      University of Houston. All rights reserved.
20  *
21  * Additional copyrights may follow
22  */
23
24 #include "colls_private.h"
25 #include "coll_tuned_topo.h"
26
27 /*
28  * Recursive-halving function is (*mostly*) copied from the BASIC coll module.
29  * I have removed the part which handles "large" message sizes 
30  * (non-overlapping version of reduce_Scatter).
31  */
32
33 /* copied function (with appropriate renaming) starts here */
34
35 /*
36  *  reduce_scatter_ompi_basic_recursivehalving
37  *
38  *  Function:   - reduce scatter implementation using recursive-halving 
39  *                algorithm
40  *  Accepts:    - same as MPI_Reduce_scatter()
41  *  Returns:    - MPI_SUCCESS or error code
42  *  Limitation: - Works only for commutative operations.
43  */
44 int
45 smpi_coll_tuned_reduce_scatter_ompi_basic_recursivehalving(void *sbuf, 
46                                                             void *rbuf, 
47                                                             int *rcounts,
48                                                             MPI_Datatype dtype,
49                                                             MPI_Op op,
50                                                             MPI_Comm comm
51                                                             )
52 {
53     int i, rank, size, count, err = MPI_SUCCESS;
54     int tmp_size=1, remain = 0, tmp_rank, *disps = NULL;
55     ptrdiff_t true_lb, true_extent, lb, extent, buf_size;
56     char *recv_buf = NULL, *recv_buf_free = NULL;
57     char *result_buf = NULL, *result_buf_free = NULL;
58    
59     /* Initialize */
60     rank = smpi_comm_rank(comm);
61     size = smpi_comm_size(comm);
62    
63     XBT_DEBUG("coll:tuned:reduce_scatter_ompi_basic_recursivehalving, rank %d", rank);
64
65     /* Find displacements and the like */
66     disps = (int*) xbt_malloc(sizeof(int) * size);
67     if (NULL == disps) return MPI_ERR_OTHER;
68
69     disps[0] = 0;
70     for (i = 0; i < (size - 1); ++i) {
71         disps[i + 1] = disps[i] + rcounts[i];
72     }
73     count = disps[size - 1] + rcounts[size - 1];
74
75     /* short cut the trivial case */
76     if (0 == count) {
77         xbt_free(disps);
78         return MPI_SUCCESS;
79     }
80
81     /* get datatype information */
82     smpi_datatype_extent(dtype, &lb, &extent);
83     smpi_datatype_extent(dtype, &true_lb, &true_extent);
84     buf_size = true_extent + (ptrdiff_t)(count - 1) * extent;
85
86     /* Handle MPI_IN_PLACE */
87     if (MPI_IN_PLACE == sbuf) {
88         sbuf = rbuf;
89     }
90
91     /* Allocate temporary receive buffer. */
92     recv_buf_free = (char*) xbt_malloc(buf_size);
93     recv_buf = recv_buf_free - lb;
94     if (NULL == recv_buf_free) {
95         err = MPI_ERR_OTHER;
96         goto cleanup;
97     }
98    
99     /* allocate temporary buffer for results */
100     result_buf_free = (char*) xbt_malloc(buf_size);
101     result_buf = result_buf_free - lb;
102    
103     /* copy local buffer into the temporary results */
104     err =smpi_datatype_copy(sbuf, count, dtype, result_buf, count, dtype);
105     if (MPI_SUCCESS != err) goto cleanup;
106    
107     /* figure out power of two mapping: grow until larger than
108        comm size, then go back one, to get the largest power of
109        two less than comm size */
110     while (tmp_size <= size) tmp_size <<= 1;
111     tmp_size >>= 1;
112     remain = size - tmp_size;
113    
114     /* If comm size is not a power of two, have the first "remain"
115        procs with an even rank send to rank + 1, leaving a power of
116        two procs to do the rest of the algorithm */
117     if (rank < 2 * remain) {
118         if ((rank & 1) == 0) {
119             smpi_mpi_send(result_buf, count, dtype, rank + 1, 
120                                     COLL_TAG_REDUCE_SCATTER,
121                                     comm);
122             /* we don't participate from here on out */
123             tmp_rank = -1;
124         } else {
125             smpi_mpi_recv(recv_buf, count, dtype, rank - 1,
126                                     COLL_TAG_REDUCE_SCATTER,
127                                     comm, MPI_STATUS_IGNORE);
128          
129             /* integrate their results into our temp results */
130             smpi_op_apply(op, recv_buf, result_buf, &count, &dtype);
131          
132             /* adjust rank to be the bottom "remain" ranks */
133             tmp_rank = rank / 2;
134         }
135     } else {
136         /* just need to adjust rank to show that the bottom "even
137            remain" ranks dropped out */
138         tmp_rank = rank - remain;
139     }
140    
141     /* For ranks not kicked out by the above code, perform the
142        recursive halving */
143     if (tmp_rank >= 0) {
144         int *tmp_disps = NULL, *tmp_rcounts = NULL;
145         int mask, send_index, recv_index, last_index;
146       
147         /* recalculate disps and rcounts to account for the
148            special "remainder" processes that are no longer doing
149            anything */
150         tmp_rcounts = (int*) xbt_malloc(tmp_size * sizeof(int));
151         if (NULL == tmp_rcounts) {
152             err = MPI_ERR_OTHER;
153             goto cleanup;
154         }
155         tmp_disps = (int*) xbt_malloc(tmp_size * sizeof(int));
156         if (NULL == tmp_disps) {
157             xbt_free(tmp_rcounts);
158             err = MPI_ERR_OTHER;
159             goto cleanup;
160         }
161
162         for (i = 0 ; i < tmp_size ; ++i) {
163             if (i < remain) {
164                 /* need to include old neighbor as well */
165                 tmp_rcounts[i] = rcounts[i * 2 + 1] + rcounts[i * 2];
166             } else {
167                 tmp_rcounts[i] = rcounts[i + remain];
168             }
169         }
170
171         tmp_disps[0] = 0;
172         for (i = 0; i < tmp_size - 1; ++i) {
173             tmp_disps[i + 1] = tmp_disps[i] + tmp_rcounts[i];
174         }
175
176         /* do the recursive halving communication.  Don't use the
177            dimension information on the communicator because I
178            think the information is invalidated by our "shrinking"
179            of the communicator */
180         mask = tmp_size >> 1;
181         send_index = recv_index = 0;
182         last_index = tmp_size;
183         while (mask > 0) {
184             int tmp_peer, peer, send_count, recv_count;
185             MPI_Request request;
186
187             tmp_peer = tmp_rank ^ mask;
188             peer = (tmp_peer < remain) ? tmp_peer * 2 + 1 : tmp_peer + remain;
189
190             /* figure out if we're sending, receiving, or both */
191             send_count = recv_count = 0;
192             if (tmp_rank < tmp_peer) {
193                 send_index = recv_index + mask;
194                 for (i = send_index ; i < last_index ; ++i) {
195                     send_count += tmp_rcounts[i];
196                 }
197                 for (i = recv_index ; i < send_index ; ++i) {
198                     recv_count += tmp_rcounts[i];
199                 }
200             } else {
201                 recv_index = send_index + mask;
202                 for (i = send_index ; i < recv_index ; ++i) {
203                     send_count += tmp_rcounts[i];
204                 }
205                 for (i = recv_index ; i < last_index ; ++i) {
206                     recv_count += tmp_rcounts[i];
207                 }
208             }
209
210             /* actual data transfer.  Send from result_buf,
211                receive into recv_buf */
212             if (send_count > 0 && recv_count != 0) {
213                 request=smpi_mpi_irecv(recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
214                                          recv_count, dtype, peer,
215                                          COLL_TAG_REDUCE_SCATTER,
216                                          comm);
217                 if (MPI_SUCCESS != err) {
218                     xbt_free(tmp_rcounts);
219                     xbt_free(tmp_disps);
220                     goto cleanup;
221                 }                                             
222             }
223             if (recv_count > 0 && send_count != 0) {
224                 smpi_mpi_send(result_buf + (ptrdiff_t)tmp_disps[send_index] * extent,
225                                         send_count, dtype, peer, 
226                                         COLL_TAG_REDUCE_SCATTER,
227                                         comm);
228                 if (MPI_SUCCESS != err) {
229                     xbt_free(tmp_rcounts);
230                     xbt_free(tmp_disps);
231                     goto cleanup;
232                 }                                             
233             }
234             if (send_count > 0 && recv_count != 0) {
235                 smpi_mpi_wait(&request, MPI_STATUS_IGNORE);
236             }
237
238             /* if we received something on this step, push it into
239                the results buffer */
240             if (recv_count > 0) {
241                 smpi_op_apply(op, 
242                                recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent, 
243                                result_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
244                                &recv_count, &dtype);
245             }
246
247             /* update for next iteration */
248             send_index = recv_index;
249             last_index = recv_index + mask;
250             mask >>= 1;
251         }
252
253         /* copy local results from results buffer into real receive buffer */
254         if (0 != rcounts[rank]) {
255             err = smpi_datatype_copy(result_buf + disps[rank] * extent,
256                                        rcounts[rank], dtype, 
257                                        rbuf, rcounts[rank], dtype);
258             if (MPI_SUCCESS != err) {
259                 xbt_free(tmp_rcounts);
260                 xbt_free(tmp_disps);
261                 goto cleanup;
262             }                                             
263         }
264
265         xbt_free(tmp_rcounts);
266         xbt_free(tmp_disps);
267     }
268
269     /* Now fix up the non-power of two case, by having the odd
270        procs send the even procs the proper results */
271     if (rank < (2 * remain)) {
272         if ((rank & 1) == 0) {
273             if (rcounts[rank]) {
274                 smpi_mpi_recv(rbuf, rcounts[rank], dtype, rank + 1,
275                                         COLL_TAG_REDUCE_SCATTER,
276                                         comm, MPI_STATUS_IGNORE);
277             }
278         } else {
279             if (rcounts[rank - 1]) {
280                 smpi_mpi_send(result_buf + disps[rank - 1] * extent,
281                                         rcounts[rank - 1], dtype, rank - 1,
282                                         COLL_TAG_REDUCE_SCATTER,
283                                         comm);
284             }
285         }            
286     }
287
288  cleanup:
289     if (NULL != disps) xbt_free(disps);
290     if (NULL != recv_buf_free) xbt_free(recv_buf_free);
291     if (NULL != result_buf_free) xbt_free(result_buf_free);
292
293     return err;
294 }
295
296 /* copied function (with appropriate renaming) ends here */
297
298
299 /*
300  *   smpi_coll_tuned_reduce_scatter_ompi_ring
301  *
302  *   Function:       Ring algorithm for reduce_scatter operation
303  *   Accepts:        Same as MPI_Reduce_scatter()
304  *   Returns:        MPI_SUCCESS or error code
305  *
306  *   Description:    Implements ring algorithm for reduce_scatter: 
307  *                   the block sizes defined in rcounts are exchanged and 
308  8                    updated until they reach proper destination.
309  *                   Algorithm requires 2 * max(rcounts) extra buffering
310  *
311  *   Limitations:    The algorithm DOES NOT preserve order of operations so it 
312  *                   can be used only for commutative operations.
313  *         Example on 5 nodes:
314  *         Initial state
315  *   #      0              1             2              3             4
316  *        [00]           [10]   ->     [20]           [30]           [40]
317  *        [01]           [11]          [21]  ->       [31]           [41]
318  *        [02]           [12]          [22]           [32]  ->       [42]
319  *    ->  [03]           [13]          [23]           [33]           [43] --> ..
320  *        [04]  ->       [14]          [24]           [34]           [44]
321  *
322  *        COMPUTATION PHASE
323  *         Step 0: rank r sends block (r-1) to rank (r+1) and 
324  *                 receives block (r+1) from rank (r-1) [with wraparound].
325  *   #      0              1             2              3             4
326  *        [00]           [10]        [10+20]   ->     [30]           [40]
327  *        [01]           [11]          [21]          [21+31]  ->     [41]
328  *    ->  [02]           [12]          [22]           [32]         [32+42] -->..
329  *      [43+03] ->       [13]          [23]           [33]           [43]
330  *        [04]         [04+14]  ->     [24]           [34]           [44]
331  *         
332  *         Step 1:
333  *   #      0              1             2              3             4
334  *        [00]           [10]        [10+20]       [10+20+30] ->     [40]
335  *    ->  [01]           [11]          [21]          [21+31]      [21+31+41] ->
336  *     [32+42+02] ->     [12]          [22]           [32]         [32+42] 
337  *        [03]        [43+03+13] ->    [23]           [33]           [43]
338  *        [04]         [04+14]      [04+14+24]  ->    [34]           [44]
339  *
340  *         Step 2:
341  *   #      0              1             2              3             4
342  *     -> [00]           [10]        [10+20]       [10+20+30]   [10+20+30+40] ->
343  *   [21+31+41+01]->     [11]          [21]          [21+31]      [21+31+41]
344  *     [32+42+02]   [32+42+02+12]->    [22]           [32]         [32+42] 
345  *        [03]        [43+03+13]   [43+03+13+23]->    [33]           [43]
346  *        [04]         [04+14]      [04+14+24]    [04+14+24+34] ->   [44]
347  *
348  *         Step 3:
349  *   #      0             1              2              3             4
350  * [10+20+30+40+00]     [10]         [10+20]       [10+20+30]   [10+20+30+40]
351  *  [21+31+41+01] [21+31+41+01+11]     [21]          [21+31]      [21+31+41]
352  *    [32+42+02]   [32+42+02+12] [32+42+02+12+22]     [32]         [32+42] 
353  *       [03]        [43+03+13]    [43+03+13+23] [43+03+13+23+33]    [43]
354  *       [04]         [04+14]       [04+14+24]    [04+14+24+34] [04+14+24+34+44]
355  *    DONE :)
356  *
357  */
358 int 
359 smpi_coll_tuned_reduce_scatter_ompi_ring(void *sbuf, void *rbuf, int *rcounts,
360                                           MPI_Datatype dtype,
361                                           MPI_Op op,
362                                           MPI_Comm comm
363                                           )
364 {
365     int ret, line, rank, size, i, k, recv_from, send_to, total_count, max_block_count;
366     int inbi, *displs = NULL;
367     char *tmpsend = NULL, *tmprecv = NULL, *accumbuf = NULL, *accumbuf_free = NULL;
368     char *inbuf_free[2] = {NULL, NULL}, *inbuf[2] = {NULL, NULL};
369     ptrdiff_t true_lb, true_extent, lb, extent, max_real_segsize;
370     MPI_Request reqs[2] = {NULL, NULL};
371
372     size = smpi_comm_size(comm);
373     rank = smpi_comm_rank(comm);
374
375     XBT_DEBUG(  "coll:tuned:reduce_scatter_ompi_ring rank %d, size %d", 
376                  rank, size);
377
378     /* Determine the maximum number of elements per node, 
379        corresponding block size, and displacements array.
380     */
381     displs = (int*) xbt_malloc(size * sizeof(int));
382     if (NULL == displs) { ret = -1; line = __LINE__; goto error_hndl; }
383     displs[0] = 0;
384     total_count = rcounts[0];
385     max_block_count = rcounts[0];
386     for (i = 1; i < size; i++) { 
387         displs[i] = total_count;
388         total_count += rcounts[i];
389         if (max_block_count < rcounts[i]) max_block_count = rcounts[i];
390     }
391       
392     /* Special case for size == 1 */
393     if (1 == size) {
394         if (MPI_IN_PLACE != sbuf) {
395             ret = smpi_datatype_copy((char*)sbuf, total_count, dtype, (char*)rbuf, total_count, dtype);
396             if (ret < 0) { line = __LINE__; goto error_hndl; }
397         }
398         xbt_free(displs);
399         return MPI_SUCCESS;
400     }
401
402     /* Allocate and initialize temporary buffers, we need:
403        - a temporary buffer to perform reduction (size total_count) since
404        rbuf can be of rcounts[rank] size.
405        - up to two temporary buffers used for communication/computation overlap.
406     */
407     smpi_datatype_extent(dtype, &lb, &extent);
408     smpi_datatype_extent(dtype, &true_lb, &true_extent);
409
410     max_real_segsize = true_extent + (ptrdiff_t)(max_block_count - 1) * extent;
411
412     accumbuf_free = (char*)xbt_malloc(true_extent + (ptrdiff_t)(total_count - 1) * extent);
413     if (NULL == accumbuf_free) { ret = -1; line = __LINE__; goto error_hndl; }
414     accumbuf = accumbuf_free - lb;
415
416     inbuf_free[0] = (char*)xbt_malloc(max_real_segsize);
417     if (NULL == inbuf_free[0]) { ret = -1; line = __LINE__; goto error_hndl; }
418     inbuf[0] = inbuf_free[0] - lb;
419     if (size > 2) {
420         inbuf_free[1] = (char*)xbt_malloc(max_real_segsize);
421         if (NULL == inbuf_free[1]) { ret = -1; line = __LINE__; goto error_hndl; }
422         inbuf[1] = inbuf_free[1] - lb;
423     }
424
425     /* Handle MPI_IN_PLACE for size > 1 */
426     if (MPI_IN_PLACE == sbuf) {
427         sbuf = rbuf;
428     }
429
430     ret = smpi_datatype_copy((char*)sbuf, total_count, dtype, accumbuf, total_count, dtype);
431     if (ret < 0) { line = __LINE__; goto error_hndl; }
432
433     /* Computation loop */
434
435     /* 
436        For each of the remote nodes:
437        - post irecv for block (r-2) from (r-1) with wrap around
438        - send block (r-1) to (r+1)
439        - in loop for every step k = 2 .. n
440        - post irecv for block (r - 1 + n - k) % n
441        - wait on block (r + n - k) % n to arrive
442        - compute on block (r + n - k ) % n
443        - send block (r + n - k) % n
444        - wait on block (r)
445        - compute on block (r)
446        - copy block (r) to rbuf
447        Note that we must be careful when computing the begining of buffers and
448        for send operations and computation we must compute the exact block size.
449     */
450     send_to = (rank + 1) % size;
451     recv_from = (rank + size - 1) % size;
452
453     inbi = 0;
454     /* Initialize first receive from the neighbor on the left */
455     reqs[inbi]=smpi_mpi_irecv(inbuf[inbi], max_block_count, dtype, recv_from,
456                              COLL_TAG_REDUCE_SCATTER, comm
457                              );
458     tmpsend = accumbuf + (ptrdiff_t)displs[recv_from] * extent;
459     smpi_mpi_send(tmpsend, rcounts[recv_from], dtype, send_to,
460                             COLL_TAG_REDUCE_SCATTER,
461                              comm);
462
463     for (k = 2; k < size; k++) {
464         const int prevblock = (rank + size - k) % size;
465       
466         inbi = inbi ^ 0x1;
467
468         /* Post irecv for the current block */
469         reqs[inbi]=smpi_mpi_irecv(inbuf[inbi], max_block_count, dtype, recv_from,
470                                  COLL_TAG_REDUCE_SCATTER, comm
471                                  );
472       
473         /* Wait on previous block to arrive */
474         smpi_mpi_wait(&reqs[inbi ^ 0x1], MPI_STATUS_IGNORE);
475       
476         /* Apply operation on previous block: result goes to rbuf
477            rbuf[prevblock] = inbuf[inbi ^ 0x1] (op) rbuf[prevblock]
478         */
479         tmprecv = accumbuf + (ptrdiff_t)displs[prevblock] * extent;
480         smpi_op_apply(op, inbuf[inbi ^ 0x1], tmprecv, &(rcounts[prevblock]), &dtype);
481       
482         /* send previous block to send_to */
483         smpi_mpi_send(tmprecv, rcounts[prevblock], dtype, send_to,
484                                 COLL_TAG_REDUCE_SCATTER,
485                                  comm);
486     }
487
488     /* Wait on the last block to arrive */
489     smpi_mpi_wait(&reqs[inbi], MPI_STATUS_IGNORE);
490
491     /* Apply operation on the last block (my block)
492        rbuf[rank] = inbuf[inbi] (op) rbuf[rank] */
493     tmprecv = accumbuf + (ptrdiff_t)displs[rank] * extent;
494     smpi_op_apply(op, inbuf[inbi], tmprecv, &(rcounts[rank]), &dtype);
495    
496     /* Copy result from tmprecv to rbuf */
497     ret = smpi_datatype_copy(tmprecv, rcounts[rank], dtype, (char*)rbuf, rcounts[rank], dtype);
498     if (ret < 0) { line = __LINE__; goto error_hndl; }
499
500     if (NULL != displs) xbt_free(displs);
501     if (NULL != accumbuf_free) xbt_free(accumbuf_free);
502     if (NULL != inbuf_free[0]) xbt_free(inbuf_free[0]);
503     if (NULL != inbuf_free[1]) xbt_free(inbuf_free[1]);
504
505     return MPI_SUCCESS;
506
507  error_hndl:
508     XBT_DEBUG( "%s:%4d\tRank %d Error occurred %d\n",
509                  __FILE__, line, rank, ret);
510     if (NULL != displs) xbt_free(displs);
511     if (NULL != accumbuf_free) xbt_free(accumbuf_free);
512     if (NULL != inbuf_free[0]) xbt_free(inbuf_free[0]);
513     if (NULL != inbuf_free[1]) xbt_free(inbuf_free[1]);
514     return ret;
515 }
516