Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
add reduce scatter collectives from openmpi, and fix existing one
[simgrid.git] / src / smpi / colls / reduce_scatter-ompi.c
1 /* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil -*- */
2 /*
3  * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
4  *                         University Research and Technology
5  *                         Corporation.  All rights reserved.
6  * Copyright (c) 2004-2012 The University of Tennessee and The University
7  *                         of Tennessee Research Foundation.  All rights
8  *                         reserved.
9  * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
10  *                         University of Stuttgart.  All rights reserved.
11  * Copyright (c) 2004-2005 The Regents of the University of California.
12  *                         All rights reserved.
13  * Copyright (c) 2008      Sun Microsystems, Inc.  All rights reserved.
14  * Copyright (c) 2009      University of Houston. All rights reserved.
15  * $COPYRIGHT$
16  *
17  * Additional copyrights may follow
18  *
19  * $HEADER$
20  */
21
22 #include "colls_private.h"
23 #include "coll_tuned_topo.h"
24
25 #define MCA_COLL_BASE_TAG_REDUCE_SCATTER 222
26 /*
27  * Recursive-halving function is (*mostly*) copied from the BASIC coll module.
28  * I have removed the part which handles "large" message sizes 
29  * (non-overlapping version of reduce_Scatter).
30  */
31
32 /* copied function (with appropriate renaming) starts here */
33
34 /*
35  *  reduce_scatter_ompi_basic_recursivehalving
36  *
37  *  Function:   - reduce scatter implementation using recursive-halving 
38  *                algorithm
39  *  Accepts:    - same as MPI_Reduce_scatter()
40  *  Returns:    - MPI_SUCCESS or error code
41  *  Limitation: - Works only for commutative operations.
42  */
43 int
44 smpi_coll_tuned_reduce_scatter_ompi_basic_recursivehalving(void *sbuf, 
45                                                             void *rbuf, 
46                                                             int *rcounts,
47                                                             MPI_Datatype dtype,
48                                                             MPI_Op op,
49                                                             MPI_Comm comm
50                                                             )
51 {
52     int i, rank, size, count, err = MPI_SUCCESS;
53     int tmp_size=1, remain = 0, tmp_rank, *disps = NULL;
54     ptrdiff_t true_lb, true_extent, lb, extent, buf_size;
55     char *recv_buf = NULL, *recv_buf_free = NULL;
56     char *result_buf = NULL, *result_buf_free = NULL;
57    
58     /* Initialize */
59     rank = smpi_comm_rank(comm);
60     size = smpi_comm_size(comm);
61    
62     XBT_DEBUG("coll:tuned:reduce_scatter_ompi_basic_recursivehalving, rank %d", rank);
63
64     /* Find displacements and the like */
65     disps = (int*) xbt_malloc(sizeof(int) * size);
66     if (NULL == disps) return MPI_ERR_OTHER;
67
68     disps[0] = 0;
69     for (i = 0; i < (size - 1); ++i) {
70         disps[i + 1] = disps[i] + rcounts[i];
71     }
72     count = disps[size - 1] + rcounts[size - 1];
73
74     /* short cut the trivial case */
75     if (0 == count) {
76         xbt_free(disps);
77         return MPI_SUCCESS;
78     }
79
80     /* get datatype information */
81     smpi_datatype_extent(dtype, &lb, &extent);
82     smpi_datatype_extent(dtype, &true_lb, &true_extent);
83     buf_size = true_extent + (ptrdiff_t)(count - 1) * extent;
84
85     /* Handle MPI_IN_PLACE */
86     if (MPI_IN_PLACE == sbuf) {
87         sbuf = rbuf;
88     }
89
90     /* Allocate temporary receive buffer. */
91     recv_buf_free = (char*) xbt_malloc(buf_size);
92     recv_buf = recv_buf_free - lb;
93     if (NULL == recv_buf_free) {
94         err = MPI_ERR_OTHER;
95         goto cleanup;
96     }
97    
98     /* allocate temporary buffer for results */
99     result_buf_free = (char*) xbt_malloc(buf_size);
100     result_buf = result_buf_free - lb;
101    
102     /* copy local buffer into the temporary results */
103     err =smpi_datatype_copy(sbuf, count, dtype, result_buf, count, dtype);
104     if (MPI_SUCCESS != err) goto cleanup;
105    
106     /* figure out power of two mapping: grow until larger than
107        comm size, then go back one, to get the largest power of
108        two less than comm size */
109     while (tmp_size <= size) tmp_size <<= 1;
110     tmp_size >>= 1;
111     remain = size - tmp_size;
112    
113     /* If comm size is not a power of two, have the first "remain"
114        procs with an even rank send to rank + 1, leaving a power of
115        two procs to do the rest of the algorithm */
116     if (rank < 2 * remain) {
117         if ((rank & 1) == 0) {
118             smpi_mpi_send(result_buf, count, dtype, rank + 1, 
119                                     MCA_COLL_BASE_TAG_REDUCE_SCATTER,
120                                     comm);
121             /* we don't participate from here on out */
122             tmp_rank = -1;
123         } else {
124             smpi_mpi_recv(recv_buf, count, dtype, rank - 1,
125                                     MCA_COLL_BASE_TAG_REDUCE_SCATTER,
126                                     comm, MPI_STATUS_IGNORE);
127          
128             /* integrate their results into our temp results */
129             smpi_op_apply(op, recv_buf, result_buf, &count, &dtype);
130          
131             /* adjust rank to be the bottom "remain" ranks */
132             tmp_rank = rank / 2;
133         }
134     } else {
135         /* just need to adjust rank to show that the bottom "even
136            remain" ranks dropped out */
137         tmp_rank = rank - remain;
138     }
139    
140     /* For ranks not kicked out by the above code, perform the
141        recursive halving */
142     if (tmp_rank >= 0) {
143         int *tmp_disps = NULL, *tmp_rcounts = NULL;
144         int mask, send_index, recv_index, last_index;
145       
146         /* recalculate disps and rcounts to account for the
147            special "remainder" processes that are no longer doing
148            anything */
149         tmp_rcounts = (int*) xbt_malloc(tmp_size * sizeof(int));
150         if (NULL == tmp_rcounts) {
151             err = MPI_ERR_OTHER;
152             goto cleanup;
153         }
154         tmp_disps = (int*) xbt_malloc(tmp_size * sizeof(int));
155         if (NULL == tmp_disps) {
156             xbt_free(tmp_rcounts);
157             err = MPI_ERR_OTHER;
158             goto cleanup;
159         }
160
161         for (i = 0 ; i < tmp_size ; ++i) {
162             if (i < remain) {
163                 /* need to include old neighbor as well */
164                 tmp_rcounts[i] = rcounts[i * 2 + 1] + rcounts[i * 2];
165             } else {
166                 tmp_rcounts[i] = rcounts[i + remain];
167             }
168         }
169
170         tmp_disps[0] = 0;
171         for (i = 0; i < tmp_size - 1; ++i) {
172             tmp_disps[i + 1] = tmp_disps[i] + tmp_rcounts[i];
173         }
174
175         /* do the recursive halving communication.  Don't use the
176            dimension information on the communicator because I
177            think the information is invalidated by our "shrinking"
178            of the communicator */
179         mask = tmp_size >> 1;
180         send_index = recv_index = 0;
181         last_index = tmp_size;
182         while (mask > 0) {
183             int tmp_peer, peer, send_count, recv_count;
184             MPI_Request request;
185
186             tmp_peer = tmp_rank ^ mask;
187             peer = (tmp_peer < remain) ? tmp_peer * 2 + 1 : tmp_peer + remain;
188
189             /* figure out if we're sending, receiving, or both */
190             send_count = recv_count = 0;
191             if (tmp_rank < tmp_peer) {
192                 send_index = recv_index + mask;
193                 for (i = send_index ; i < last_index ; ++i) {
194                     send_count += tmp_rcounts[i];
195                 }
196                 for (i = recv_index ; i < send_index ; ++i) {
197                     recv_count += tmp_rcounts[i];
198                 }
199             } else {
200                 recv_index = send_index + mask;
201                 for (i = send_index ; i < recv_index ; ++i) {
202                     send_count += tmp_rcounts[i];
203                 }
204                 for (i = recv_index ; i < last_index ; ++i) {
205                     recv_count += tmp_rcounts[i];
206                 }
207             }
208
209             /* actual data transfer.  Send from result_buf,
210                receive into recv_buf */
211             if (send_count > 0 && recv_count != 0) {
212                 request=smpi_mpi_irecv(recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
213                                          recv_count, dtype, peer,
214                                          MCA_COLL_BASE_TAG_REDUCE_SCATTER,
215                                          comm);
216                 if (MPI_SUCCESS != err) {
217                     xbt_free(tmp_rcounts);
218                     xbt_free(tmp_disps);
219                     goto cleanup;
220                 }                                             
221             }
222             if (recv_count > 0 && send_count != 0) {
223                 smpi_mpi_send(result_buf + (ptrdiff_t)tmp_disps[send_index] * extent,
224                                         send_count, dtype, peer, 
225                                         MCA_COLL_BASE_TAG_REDUCE_SCATTER,
226                                         comm);
227                 if (MPI_SUCCESS != err) {
228                     xbt_free(tmp_rcounts);
229                     xbt_free(tmp_disps);
230                     goto cleanup;
231                 }                                             
232             }
233             if (send_count > 0 && recv_count != 0) {
234                 smpi_mpi_wait(&request, MPI_STATUS_IGNORE);
235             }
236
237             /* if we received something on this step, push it into
238                the results buffer */
239             if (recv_count > 0) {
240                 smpi_op_apply(op, 
241                                recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent, 
242                                result_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
243                                &recv_count, &dtype);
244             }
245
246             /* update for next iteration */
247             send_index = recv_index;
248             last_index = recv_index + mask;
249             mask >>= 1;
250         }
251
252         /* copy local results from results buffer into real receive buffer */
253         if (0 != rcounts[rank]) {
254             err = smpi_datatype_copy(result_buf + disps[rank] * extent,
255                                        rcounts[rank], dtype, 
256                                        rbuf, rcounts[rank], dtype);
257             if (MPI_SUCCESS != err) {
258                 xbt_free(tmp_rcounts);
259                 xbt_free(tmp_disps);
260                 goto cleanup;
261             }                                             
262         }
263
264         xbt_free(tmp_rcounts);
265         xbt_free(tmp_disps);
266     }
267
268     /* Now fix up the non-power of two case, by having the odd
269        procs send the even procs the proper results */
270     if (rank < (2 * remain)) {
271         if ((rank & 1) == 0) {
272             if (rcounts[rank]) {
273                 smpi_mpi_recv(rbuf, rcounts[rank], dtype, rank + 1,
274                                         MCA_COLL_BASE_TAG_REDUCE_SCATTER,
275                                         comm, MPI_STATUS_IGNORE);
276             }
277         } else {
278             if (rcounts[rank - 1]) {
279                 smpi_mpi_send(result_buf + disps[rank - 1] * extent,
280                                         rcounts[rank - 1], dtype, rank - 1,
281                                         MCA_COLL_BASE_TAG_REDUCE_SCATTER,
282                                         comm);
283             }
284         }            
285     }
286
287  cleanup:
288     if (NULL != disps) xbt_free(disps);
289     if (NULL != recv_buf_free) xbt_free(recv_buf_free);
290     if (NULL != result_buf_free) xbt_free(result_buf_free);
291
292     return err;
293 }
294
295 /* copied function (with appropriate renaming) ends here */
296
297
298 /*
299  *   smpi_coll_tuned_reduce_scatter_ompi_ring
300  *
301  *   Function:       Ring algorithm for reduce_scatter operation
302  *   Accepts:        Same as MPI_Reduce_scatter()
303  *   Returns:        MPI_SUCCESS or error code
304  *
305  *   Description:    Implements ring algorithm for reduce_scatter: 
306  *                   the block sizes defined in rcounts are exchanged and 
307  8                    updated until they reach proper destination.
308  *                   Algorithm requires 2 * max(rcounts) extra buffering
309  *
310  *   Limitations:    The algorithm DOES NOT preserve order of operations so it 
311  *                   can be used only for commutative operations.
312  *         Example on 5 nodes:
313  *         Initial state
314  *   #      0              1             2              3             4
315  *        [00]           [10]   ->     [20]           [30]           [40]
316  *        [01]           [11]          [21]  ->       [31]           [41]
317  *        [02]           [12]          [22]           [32]  ->       [42]
318  *    ->  [03]           [13]          [23]           [33]           [43] --> ..
319  *        [04]  ->       [14]          [24]           [34]           [44]
320  *
321  *        COMPUTATION PHASE
322  *         Step 0: rank r sends block (r-1) to rank (r+1) and 
323  *                 receives block (r+1) from rank (r-1) [with wraparound].
324  *   #      0              1             2              3             4
325  *        [00]           [10]        [10+20]   ->     [30]           [40]
326  *        [01]           [11]          [21]          [21+31]  ->     [41]
327  *    ->  [02]           [12]          [22]           [32]         [32+42] -->..
328  *      [43+03] ->       [13]          [23]           [33]           [43]
329  *        [04]         [04+14]  ->     [24]           [34]           [44]
330  *         
331  *         Step 1:
332  *   #      0              1             2              3             4
333  *        [00]           [10]        [10+20]       [10+20+30] ->     [40]
334  *    ->  [01]           [11]          [21]          [21+31]      [21+31+41] ->
335  *     [32+42+02] ->     [12]          [22]           [32]         [32+42] 
336  *        [03]        [43+03+13] ->    [23]           [33]           [43]
337  *        [04]         [04+14]      [04+14+24]  ->    [34]           [44]
338  *
339  *         Step 2:
340  *   #      0              1             2              3             4
341  *     -> [00]           [10]        [10+20]       [10+20+30]   [10+20+30+40] ->
342  *   [21+31+41+01]->     [11]          [21]          [21+31]      [21+31+41]
343  *     [32+42+02]   [32+42+02+12]->    [22]           [32]         [32+42] 
344  *        [03]        [43+03+13]   [43+03+13+23]->    [33]           [43]
345  *        [04]         [04+14]      [04+14+24]    [04+14+24+34] ->   [44]
346  *
347  *         Step 3:
348  *   #      0             1              2              3             4
349  * [10+20+30+40+00]     [10]         [10+20]       [10+20+30]   [10+20+30+40]
350  *  [21+31+41+01] [21+31+41+01+11]     [21]          [21+31]      [21+31+41]
351  *    [32+42+02]   [32+42+02+12] [32+42+02+12+22]     [32]         [32+42] 
352  *       [03]        [43+03+13]    [43+03+13+23] [43+03+13+23+33]    [43]
353  *       [04]         [04+14]       [04+14+24]    [04+14+24+34] [04+14+24+34+44]
354  *    DONE :)
355  *
356  */
357 int 
358 smpi_coll_tuned_reduce_scatter_ompi_ring(void *sbuf, void *rbuf, int *rcounts,
359                                           MPI_Datatype dtype,
360                                           MPI_Op op,
361                                           MPI_Comm comm
362                                           )
363 {
364     int ret, line, rank, size, i, k, recv_from, send_to, total_count, max_block_count;
365     int inbi, *displs = NULL;
366     char *tmpsend = NULL, *tmprecv = NULL, *accumbuf = NULL, *accumbuf_free = NULL;
367     char *inbuf_free[2] = {NULL, NULL}, *inbuf[2] = {NULL, NULL};
368     ptrdiff_t true_lb, true_extent, lb, extent, max_real_segsize;
369     MPI_Request reqs[2] = {NULL, NULL};
370     size_t typelng;
371
372     size = smpi_comm_size(comm);
373     rank = smpi_comm_rank(comm);
374
375     XBT_DEBUG(  "coll:tuned:reduce_scatter_ompi_ring rank %d, size %d", 
376                  rank, size);
377
378     /* Determine the maximum number of elements per node, 
379        corresponding block size, and displacements array.
380     */
381     displs = (int*) xbt_malloc(size * sizeof(int));
382     if (NULL == displs) { ret = -1; line = __LINE__; goto error_hndl; }
383     displs[0] = 0;
384     total_count = rcounts[0];
385     max_block_count = rcounts[0];
386     for (i = 1; i < size; i++) { 
387         displs[i] = total_count;
388         total_count += rcounts[i];
389         if (max_block_count < rcounts[i]) max_block_count = rcounts[i];
390     }
391       
392     /* Special case for size == 1 */
393     if (1 == size) {
394         if (MPI_IN_PLACE != sbuf) {
395             ret = smpi_datatype_copy((char*)sbuf, total_count, dtype, (char*)rbuf, total_count, dtype);
396             if (ret < 0) { line = __LINE__; goto error_hndl; }
397         }
398         xbt_free(displs);
399         return MPI_SUCCESS;
400     }
401
402     /* Allocate and initialize temporary buffers, we need:
403        - a temporary buffer to perform reduction (size total_count) since
404        rbuf can be of rcounts[rank] size.
405        - up to two temporary buffers used for communication/computation overlap.
406     */
407     smpi_datatype_extent(dtype, &lb, &extent);
408     smpi_datatype_extent(dtype, &true_lb, &true_extent);
409     typelng = smpi_datatype_size(dtype);
410
411     max_real_segsize = true_extent + (ptrdiff_t)(max_block_count - 1) * extent;
412
413     accumbuf_free = (char*)xbt_malloc(true_extent + (ptrdiff_t)(total_count - 1) * extent);
414     if (NULL == accumbuf_free) { ret = -1; line = __LINE__; goto error_hndl; }
415     accumbuf = accumbuf_free - lb;
416
417     inbuf_free[0] = (char*)xbt_malloc(max_real_segsize);
418     if (NULL == inbuf_free[0]) { ret = -1; line = __LINE__; goto error_hndl; }
419     inbuf[0] = inbuf_free[0] - lb;
420     if (size > 2) {
421         inbuf_free[1] = (char*)xbt_malloc(max_real_segsize);
422         if (NULL == inbuf_free[1]) { ret = -1; line = __LINE__; goto error_hndl; }
423         inbuf[1] = inbuf_free[1] - lb;
424     }
425
426     /* Handle MPI_IN_PLACE for size > 1 */
427     if (MPI_IN_PLACE == sbuf) {
428         sbuf = rbuf;
429     }
430
431     ret = smpi_datatype_copy((char*)sbuf, total_count, dtype, accumbuf, total_count, dtype);
432     if (ret < 0) { line = __LINE__; goto error_hndl; }
433
434     /* Computation loop */
435
436     /* 
437        For each of the remote nodes:
438        - post irecv for block (r-2) from (r-1) with wrap around
439        - send block (r-1) to (r+1)
440        - in loop for every step k = 2 .. n
441        - post irecv for block (r - 1 + n - k) % n
442        - wait on block (r + n - k) % n to arrive
443        - compute on block (r + n - k ) % n
444        - send block (r + n - k) % n
445        - wait on block (r)
446        - compute on block (r)
447        - copy block (r) to rbuf
448        Note that we must be careful when computing the begining of buffers and
449        for send operations and computation we must compute the exact block size.
450     */
451     send_to = (rank + 1) % size;
452     recv_from = (rank + size - 1) % size;
453
454     inbi = 0;
455     /* Initialize first receive from the neighbor on the left */
456     reqs[inbi]=smpi_mpi_irecv(inbuf[inbi], max_block_count, dtype, recv_from,
457                              MCA_COLL_BASE_TAG_REDUCE_SCATTER, comm
458                              );
459     tmpsend = accumbuf + (ptrdiff_t)displs[recv_from] * extent;
460     smpi_mpi_send(tmpsend, rcounts[recv_from], dtype, send_to,
461                             MCA_COLL_BASE_TAG_REDUCE_SCATTER,
462                              comm);
463
464     for (k = 2; k < size; k++) {
465         const int prevblock = (rank + size - k) % size;
466       
467         inbi = inbi ^ 0x1;
468
469         /* Post irecv for the current block */
470         reqs[inbi]=smpi_mpi_irecv(inbuf[inbi], max_block_count, dtype, recv_from,
471                                  MCA_COLL_BASE_TAG_REDUCE_SCATTER, comm 
472                                  );
473       
474         /* Wait on previous block to arrive */
475         smpi_mpi_wait(&reqs[inbi ^ 0x1], MPI_STATUS_IGNORE);
476       
477         /* Apply operation on previous block: result goes to rbuf
478            rbuf[prevblock] = inbuf[inbi ^ 0x1] (op) rbuf[prevblock]
479         */
480         tmprecv = accumbuf + (ptrdiff_t)displs[prevblock] * extent;
481         smpi_op_apply(op, inbuf[inbi ^ 0x1], tmprecv, &(rcounts[prevblock]), &dtype);
482       
483         /* send previous block to send_to */
484         smpi_mpi_send(tmprecv, rcounts[prevblock], dtype, send_to,
485                                 MCA_COLL_BASE_TAG_REDUCE_SCATTER,
486                                  comm);
487     }
488
489     /* Wait on the last block to arrive */
490     smpi_mpi_wait(&reqs[inbi], MPI_STATUS_IGNORE);
491
492     /* Apply operation on the last block (my block)
493        rbuf[rank] = inbuf[inbi] (op) rbuf[rank] */
494     tmprecv = accumbuf + (ptrdiff_t)displs[rank] * extent;
495     smpi_op_apply(op, inbuf[inbi], tmprecv, &(rcounts[rank]), &dtype);
496    
497     /* Copy result from tmprecv to rbuf */
498     ret = smpi_datatype_copy(tmprecv, rcounts[rank], dtype, (char*)rbuf, rcounts[rank], dtype);
499     if (ret < 0) { line = __LINE__; goto error_hndl; }
500
501     if (NULL != displs) xbt_free(displs);
502     if (NULL != accumbuf_free) xbt_free(accumbuf_free);
503     if (NULL != inbuf_free[0]) xbt_free(inbuf_free[0]);
504     if (NULL != inbuf_free[1]) xbt_free(inbuf_free[1]);
505
506     return MPI_SUCCESS;
507
508  error_hndl:
509     XBT_DEBUG( "%s:%4d\tRank %d Error occurred %d\n",
510                  __FILE__, line, rank, ret);
511     if (NULL != displs) xbt_free(displs);
512     if (NULL != accumbuf_free) xbt_free(accumbuf_free);
513     if (NULL != inbuf_free[0]) xbt_free(inbuf_free[0]);
514     if (NULL != inbuf_free[1]) xbt_free(inbuf_free[1]);
515     return ret;
516 }
517