Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'dvfs'
[simgrid.git] / src / smpi / colls / reduce_scatter-ompi.c
1 /* -*- Mode: C; c-basic-offset:4 ; indent-tabs-mode:nil -*- */
2 /*
3  * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
4  *                         University Research and Technology
5  *                         Corporation.  All rights reserved.
6  * Copyright (c) 2004-2012 The University of Tennessee and The University
7  *                         of Tennessee Research Foundation.  All rights
8  *                         reserved.
9  * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
10  *                         University of Stuttgart.  All rights reserved.
11  * Copyright (c) 2004-2005 The Regents of the University of California.
12  *                         All rights reserved.
13  * Copyright (c) 2008      Sun Microsystems, Inc.  All rights reserved.
14  * Copyright (c) 2009      University of Houston. All rights reserved.
15  * $COPYRIGHT$
16  *
17  * Additional copyrights may follow
18  *
19  * $HEADER$
20  */
21
22 #include "colls_private.h"
23 #include "coll_tuned_topo.h"
24
25 /*
26  * Recursive-halving function is (*mostly*) copied from the BASIC coll module.
27  * I have removed the part which handles "large" message sizes 
28  * (non-overlapping version of reduce_Scatter).
29  */
30
31 /* copied function (with appropriate renaming) starts here */
32
33 /*
34  *  reduce_scatter_ompi_basic_recursivehalving
35  *
36  *  Function:   - reduce scatter implementation using recursive-halving 
37  *                algorithm
38  *  Accepts:    - same as MPI_Reduce_scatter()
39  *  Returns:    - MPI_SUCCESS or error code
40  *  Limitation: - Works only for commutative operations.
41  */
42 int
43 smpi_coll_tuned_reduce_scatter_ompi_basic_recursivehalving(void *sbuf, 
44                                                             void *rbuf, 
45                                                             int *rcounts,
46                                                             MPI_Datatype dtype,
47                                                             MPI_Op op,
48                                                             MPI_Comm comm
49                                                             )
50 {
51     int i, rank, size, count, err = MPI_SUCCESS;
52     int tmp_size=1, remain = 0, tmp_rank, *disps = NULL;
53     ptrdiff_t true_lb, true_extent, lb, extent, buf_size;
54     char *recv_buf = NULL, *recv_buf_free = NULL;
55     char *result_buf = NULL, *result_buf_free = NULL;
56    
57     /* Initialize */
58     rank = smpi_comm_rank(comm);
59     size = smpi_comm_size(comm);
60    
61     XBT_DEBUG("coll:tuned:reduce_scatter_ompi_basic_recursivehalving, rank %d", rank);
62
63     /* Find displacements and the like */
64     disps = (int*) xbt_malloc(sizeof(int) * size);
65     if (NULL == disps) return MPI_ERR_OTHER;
66
67     disps[0] = 0;
68     for (i = 0; i < (size - 1); ++i) {
69         disps[i + 1] = disps[i] + rcounts[i];
70     }
71     count = disps[size - 1] + rcounts[size - 1];
72
73     /* short cut the trivial case */
74     if (0 == count) {
75         xbt_free(disps);
76         return MPI_SUCCESS;
77     }
78
79     /* get datatype information */
80     smpi_datatype_extent(dtype, &lb, &extent);
81     smpi_datatype_extent(dtype, &true_lb, &true_extent);
82     buf_size = true_extent + (ptrdiff_t)(count - 1) * extent;
83
84     /* Handle MPI_IN_PLACE */
85     if (MPI_IN_PLACE == sbuf) {
86         sbuf = rbuf;
87     }
88
89     /* Allocate temporary receive buffer. */
90     recv_buf_free = (char*) xbt_malloc(buf_size);
91     recv_buf = recv_buf_free - lb;
92     if (NULL == recv_buf_free) {
93         err = MPI_ERR_OTHER;
94         goto cleanup;
95     }
96    
97     /* allocate temporary buffer for results */
98     result_buf_free = (char*) xbt_malloc(buf_size);
99     result_buf = result_buf_free - lb;
100    
101     /* copy local buffer into the temporary results */
102     err =smpi_datatype_copy(sbuf, count, dtype, result_buf, count, dtype);
103     if (MPI_SUCCESS != err) goto cleanup;
104    
105     /* figure out power of two mapping: grow until larger than
106        comm size, then go back one, to get the largest power of
107        two less than comm size */
108     while (tmp_size <= size) tmp_size <<= 1;
109     tmp_size >>= 1;
110     remain = size - tmp_size;
111    
112     /* If comm size is not a power of two, have the first "remain"
113        procs with an even rank send to rank + 1, leaving a power of
114        two procs to do the rest of the algorithm */
115     if (rank < 2 * remain) {
116         if ((rank & 1) == 0) {
117             smpi_mpi_send(result_buf, count, dtype, rank + 1, 
118                                     COLL_TAG_REDUCE_SCATTER,
119                                     comm);
120             /* we don't participate from here on out */
121             tmp_rank = -1;
122         } else {
123             smpi_mpi_recv(recv_buf, count, dtype, rank - 1,
124                                     COLL_TAG_REDUCE_SCATTER,
125                                     comm, MPI_STATUS_IGNORE);
126          
127             /* integrate their results into our temp results */
128             smpi_op_apply(op, recv_buf, result_buf, &count, &dtype);
129          
130             /* adjust rank to be the bottom "remain" ranks */
131             tmp_rank = rank / 2;
132         }
133     } else {
134         /* just need to adjust rank to show that the bottom "even
135            remain" ranks dropped out */
136         tmp_rank = rank - remain;
137     }
138    
139     /* For ranks not kicked out by the above code, perform the
140        recursive halving */
141     if (tmp_rank >= 0) {
142         int *tmp_disps = NULL, *tmp_rcounts = NULL;
143         int mask, send_index, recv_index, last_index;
144       
145         /* recalculate disps and rcounts to account for the
146            special "remainder" processes that are no longer doing
147            anything */
148         tmp_rcounts = (int*) xbt_malloc(tmp_size * sizeof(int));
149         if (NULL == tmp_rcounts) {
150             err = MPI_ERR_OTHER;
151             goto cleanup;
152         }
153         tmp_disps = (int*) xbt_malloc(tmp_size * sizeof(int));
154         if (NULL == tmp_disps) {
155             xbt_free(tmp_rcounts);
156             err = MPI_ERR_OTHER;
157             goto cleanup;
158         }
159
160         for (i = 0 ; i < tmp_size ; ++i) {
161             if (i < remain) {
162                 /* need to include old neighbor as well */
163                 tmp_rcounts[i] = rcounts[i * 2 + 1] + rcounts[i * 2];
164             } else {
165                 tmp_rcounts[i] = rcounts[i + remain];
166             }
167         }
168
169         tmp_disps[0] = 0;
170         for (i = 0; i < tmp_size - 1; ++i) {
171             tmp_disps[i + 1] = tmp_disps[i] + tmp_rcounts[i];
172         }
173
174         /* do the recursive halving communication.  Don't use the
175            dimension information on the communicator because I
176            think the information is invalidated by our "shrinking"
177            of the communicator */
178         mask = tmp_size >> 1;
179         send_index = recv_index = 0;
180         last_index = tmp_size;
181         while (mask > 0) {
182             int tmp_peer, peer, send_count, recv_count;
183             MPI_Request request;
184
185             tmp_peer = tmp_rank ^ mask;
186             peer = (tmp_peer < remain) ? tmp_peer * 2 + 1 : tmp_peer + remain;
187
188             /* figure out if we're sending, receiving, or both */
189             send_count = recv_count = 0;
190             if (tmp_rank < tmp_peer) {
191                 send_index = recv_index + mask;
192                 for (i = send_index ; i < last_index ; ++i) {
193                     send_count += tmp_rcounts[i];
194                 }
195                 for (i = recv_index ; i < send_index ; ++i) {
196                     recv_count += tmp_rcounts[i];
197                 }
198             } else {
199                 recv_index = send_index + mask;
200                 for (i = send_index ; i < recv_index ; ++i) {
201                     send_count += tmp_rcounts[i];
202                 }
203                 for (i = recv_index ; i < last_index ; ++i) {
204                     recv_count += tmp_rcounts[i];
205                 }
206             }
207
208             /* actual data transfer.  Send from result_buf,
209                receive into recv_buf */
210             if (send_count > 0 && recv_count != 0) {
211                 request=smpi_mpi_irecv(recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
212                                          recv_count, dtype, peer,
213                                          COLL_TAG_REDUCE_SCATTER,
214                                          comm);
215                 if (MPI_SUCCESS != err) {
216                     xbt_free(tmp_rcounts);
217                     xbt_free(tmp_disps);
218                     goto cleanup;
219                 }                                             
220             }
221             if (recv_count > 0 && send_count != 0) {
222                 smpi_mpi_send(result_buf + (ptrdiff_t)tmp_disps[send_index] * extent,
223                                         send_count, dtype, peer, 
224                                         COLL_TAG_REDUCE_SCATTER,
225                                         comm);
226                 if (MPI_SUCCESS != err) {
227                     xbt_free(tmp_rcounts);
228                     xbt_free(tmp_disps);
229                     goto cleanup;
230                 }                                             
231             }
232             if (send_count > 0 && recv_count != 0) {
233                 smpi_mpi_wait(&request, MPI_STATUS_IGNORE);
234             }
235
236             /* if we received something on this step, push it into
237                the results buffer */
238             if (recv_count > 0) {
239                 smpi_op_apply(op, 
240                                recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent, 
241                                result_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
242                                &recv_count, &dtype);
243             }
244
245             /* update for next iteration */
246             send_index = recv_index;
247             last_index = recv_index + mask;
248             mask >>= 1;
249         }
250
251         /* copy local results from results buffer into real receive buffer */
252         if (0 != rcounts[rank]) {
253             err = smpi_datatype_copy(result_buf + disps[rank] * extent,
254                                        rcounts[rank], dtype, 
255                                        rbuf, rcounts[rank], dtype);
256             if (MPI_SUCCESS != err) {
257                 xbt_free(tmp_rcounts);
258                 xbt_free(tmp_disps);
259                 goto cleanup;
260             }                                             
261         }
262
263         xbt_free(tmp_rcounts);
264         xbt_free(tmp_disps);
265     }
266
267     /* Now fix up the non-power of two case, by having the odd
268        procs send the even procs the proper results */
269     if (rank < (2 * remain)) {
270         if ((rank & 1) == 0) {
271             if (rcounts[rank]) {
272                 smpi_mpi_recv(rbuf, rcounts[rank], dtype, rank + 1,
273                                         COLL_TAG_REDUCE_SCATTER,
274                                         comm, MPI_STATUS_IGNORE);
275             }
276         } else {
277             if (rcounts[rank - 1]) {
278                 smpi_mpi_send(result_buf + disps[rank - 1] * extent,
279                                         rcounts[rank - 1], dtype, rank - 1,
280                                         COLL_TAG_REDUCE_SCATTER,
281                                         comm);
282             }
283         }            
284     }
285
286  cleanup:
287     if (NULL != disps) xbt_free(disps);
288     if (NULL != recv_buf_free) xbt_free(recv_buf_free);
289     if (NULL != result_buf_free) xbt_free(result_buf_free);
290
291     return err;
292 }
293
294 /* copied function (with appropriate renaming) ends here */
295
296
297 /*
298  *   smpi_coll_tuned_reduce_scatter_ompi_ring
299  *
300  *   Function:       Ring algorithm for reduce_scatter operation
301  *   Accepts:        Same as MPI_Reduce_scatter()
302  *   Returns:        MPI_SUCCESS or error code
303  *
304  *   Description:    Implements ring algorithm for reduce_scatter: 
305  *                   the block sizes defined in rcounts are exchanged and 
306  8                    updated until they reach proper destination.
307  *                   Algorithm requires 2 * max(rcounts) extra buffering
308  *
309  *   Limitations:    The algorithm DOES NOT preserve order of operations so it 
310  *                   can be used only for commutative operations.
311  *         Example on 5 nodes:
312  *         Initial state
313  *   #      0              1             2              3             4
314  *        [00]           [10]   ->     [20]           [30]           [40]
315  *        [01]           [11]          [21]  ->       [31]           [41]
316  *        [02]           [12]          [22]           [32]  ->       [42]
317  *    ->  [03]           [13]          [23]           [33]           [43] --> ..
318  *        [04]  ->       [14]          [24]           [34]           [44]
319  *
320  *        COMPUTATION PHASE
321  *         Step 0: rank r sends block (r-1) to rank (r+1) and 
322  *                 receives block (r+1) from rank (r-1) [with wraparound].
323  *   #      0              1             2              3             4
324  *        [00]           [10]        [10+20]   ->     [30]           [40]
325  *        [01]           [11]          [21]          [21+31]  ->     [41]
326  *    ->  [02]           [12]          [22]           [32]         [32+42] -->..
327  *      [43+03] ->       [13]          [23]           [33]           [43]
328  *        [04]         [04+14]  ->     [24]           [34]           [44]
329  *         
330  *         Step 1:
331  *   #      0              1             2              3             4
332  *        [00]           [10]        [10+20]       [10+20+30] ->     [40]
333  *    ->  [01]           [11]          [21]          [21+31]      [21+31+41] ->
334  *     [32+42+02] ->     [12]          [22]           [32]         [32+42] 
335  *        [03]        [43+03+13] ->    [23]           [33]           [43]
336  *        [04]         [04+14]      [04+14+24]  ->    [34]           [44]
337  *
338  *         Step 2:
339  *   #      0              1             2              3             4
340  *     -> [00]           [10]        [10+20]       [10+20+30]   [10+20+30+40] ->
341  *   [21+31+41+01]->     [11]          [21]          [21+31]      [21+31+41]
342  *     [32+42+02]   [32+42+02+12]->    [22]           [32]         [32+42] 
343  *        [03]        [43+03+13]   [43+03+13+23]->    [33]           [43]
344  *        [04]         [04+14]      [04+14+24]    [04+14+24+34] ->   [44]
345  *
346  *         Step 3:
347  *   #      0             1              2              3             4
348  * [10+20+30+40+00]     [10]         [10+20]       [10+20+30]   [10+20+30+40]
349  *  [21+31+41+01] [21+31+41+01+11]     [21]          [21+31]      [21+31+41]
350  *    [32+42+02]   [32+42+02+12] [32+42+02+12+22]     [32]         [32+42] 
351  *       [03]        [43+03+13]    [43+03+13+23] [43+03+13+23+33]    [43]
352  *       [04]         [04+14]       [04+14+24]    [04+14+24+34] [04+14+24+34+44]
353  *    DONE :)
354  *
355  */
356 int 
357 smpi_coll_tuned_reduce_scatter_ompi_ring(void *sbuf, void *rbuf, int *rcounts,
358                                           MPI_Datatype dtype,
359                                           MPI_Op op,
360                                           MPI_Comm comm
361                                           )
362 {
363     int ret, line, rank, size, i, k, recv_from, send_to, total_count, max_block_count;
364     int inbi, *displs = NULL;
365     char *tmpsend = NULL, *tmprecv = NULL, *accumbuf = NULL, *accumbuf_free = NULL;
366     char *inbuf_free[2] = {NULL, NULL}, *inbuf[2] = {NULL, NULL};
367     ptrdiff_t true_lb, true_extent, lb, extent, max_real_segsize;
368     MPI_Request reqs[2] = {NULL, NULL};
369
370     size = smpi_comm_size(comm);
371     rank = smpi_comm_rank(comm);
372
373     XBT_DEBUG(  "coll:tuned:reduce_scatter_ompi_ring rank %d, size %d", 
374                  rank, size);
375
376     /* Determine the maximum number of elements per node, 
377        corresponding block size, and displacements array.
378     */
379     displs = (int*) xbt_malloc(size * sizeof(int));
380     if (NULL == displs) { ret = -1; line = __LINE__; goto error_hndl; }
381     displs[0] = 0;
382     total_count = rcounts[0];
383     max_block_count = rcounts[0];
384     for (i = 1; i < size; i++) { 
385         displs[i] = total_count;
386         total_count += rcounts[i];
387         if (max_block_count < rcounts[i]) max_block_count = rcounts[i];
388     }
389       
390     /* Special case for size == 1 */
391     if (1 == size) {
392         if (MPI_IN_PLACE != sbuf) {
393             ret = smpi_datatype_copy((char*)sbuf, total_count, dtype, (char*)rbuf, total_count, dtype);
394             if (ret < 0) { line = __LINE__; goto error_hndl; }
395         }
396         xbt_free(displs);
397         return MPI_SUCCESS;
398     }
399
400     /* Allocate and initialize temporary buffers, we need:
401        - a temporary buffer to perform reduction (size total_count) since
402        rbuf can be of rcounts[rank] size.
403        - up to two temporary buffers used for communication/computation overlap.
404     */
405     smpi_datatype_extent(dtype, &lb, &extent);
406     smpi_datatype_extent(dtype, &true_lb, &true_extent);
407
408     max_real_segsize = true_extent + (ptrdiff_t)(max_block_count - 1) * extent;
409
410     accumbuf_free = (char*)xbt_malloc(true_extent + (ptrdiff_t)(total_count - 1) * extent);
411     if (NULL == accumbuf_free) { ret = -1; line = __LINE__; goto error_hndl; }
412     accumbuf = accumbuf_free - lb;
413
414     inbuf_free[0] = (char*)xbt_malloc(max_real_segsize);
415     if (NULL == inbuf_free[0]) { ret = -1; line = __LINE__; goto error_hndl; }
416     inbuf[0] = inbuf_free[0] - lb;
417     if (size > 2) {
418         inbuf_free[1] = (char*)xbt_malloc(max_real_segsize);
419         if (NULL == inbuf_free[1]) { ret = -1; line = __LINE__; goto error_hndl; }
420         inbuf[1] = inbuf_free[1] - lb;
421     }
422
423     /* Handle MPI_IN_PLACE for size > 1 */
424     if (MPI_IN_PLACE == sbuf) {
425         sbuf = rbuf;
426     }
427
428     ret = smpi_datatype_copy((char*)sbuf, total_count, dtype, accumbuf, total_count, dtype);
429     if (ret < 0) { line = __LINE__; goto error_hndl; }
430
431     /* Computation loop */
432
433     /* 
434        For each of the remote nodes:
435        - post irecv for block (r-2) from (r-1) with wrap around
436        - send block (r-1) to (r+1)
437        - in loop for every step k = 2 .. n
438        - post irecv for block (r - 1 + n - k) % n
439        - wait on block (r + n - k) % n to arrive
440        - compute on block (r + n - k ) % n
441        - send block (r + n - k) % n
442        - wait on block (r)
443        - compute on block (r)
444        - copy block (r) to rbuf
445        Note that we must be careful when computing the begining of buffers and
446        for send operations and computation we must compute the exact block size.
447     */
448     send_to = (rank + 1) % size;
449     recv_from = (rank + size - 1) % size;
450
451     inbi = 0;
452     /* Initialize first receive from the neighbor on the left */
453     reqs[inbi]=smpi_mpi_irecv(inbuf[inbi], max_block_count, dtype, recv_from,
454                              COLL_TAG_REDUCE_SCATTER, comm
455                              );
456     tmpsend = accumbuf + (ptrdiff_t)displs[recv_from] * extent;
457     smpi_mpi_send(tmpsend, rcounts[recv_from], dtype, send_to,
458                             COLL_TAG_REDUCE_SCATTER,
459                              comm);
460
461     for (k = 2; k < size; k++) {
462         const int prevblock = (rank + size - k) % size;
463       
464         inbi = inbi ^ 0x1;
465
466         /* Post irecv for the current block */
467         reqs[inbi]=smpi_mpi_irecv(inbuf[inbi], max_block_count, dtype, recv_from,
468                                  COLL_TAG_REDUCE_SCATTER, comm
469                                  );
470       
471         /* Wait on previous block to arrive */
472         smpi_mpi_wait(&reqs[inbi ^ 0x1], MPI_STATUS_IGNORE);
473       
474         /* Apply operation on previous block: result goes to rbuf
475            rbuf[prevblock] = inbuf[inbi ^ 0x1] (op) rbuf[prevblock]
476         */
477         tmprecv = accumbuf + (ptrdiff_t)displs[prevblock] * extent;
478         smpi_op_apply(op, inbuf[inbi ^ 0x1], tmprecv, &(rcounts[prevblock]), &dtype);
479       
480         /* send previous block to send_to */
481         smpi_mpi_send(tmprecv, rcounts[prevblock], dtype, send_to,
482                                 COLL_TAG_REDUCE_SCATTER,
483                                  comm);
484     }
485
486     /* Wait on the last block to arrive */
487     smpi_mpi_wait(&reqs[inbi], MPI_STATUS_IGNORE);
488
489     /* Apply operation on the last block (my block)
490        rbuf[rank] = inbuf[inbi] (op) rbuf[rank] */
491     tmprecv = accumbuf + (ptrdiff_t)displs[rank] * extent;
492     smpi_op_apply(op, inbuf[inbi], tmprecv, &(rcounts[rank]), &dtype);
493    
494     /* Copy result from tmprecv to rbuf */
495     ret = smpi_datatype_copy(tmprecv, rcounts[rank], dtype, (char*)rbuf, rcounts[rank], dtype);
496     if (ret < 0) { line = __LINE__; goto error_hndl; }
497
498     if (NULL != displs) xbt_free(displs);
499     if (NULL != accumbuf_free) xbt_free(accumbuf_free);
500     if (NULL != inbuf_free[0]) xbt_free(inbuf_free[0]);
501     if (NULL != inbuf_free[1]) xbt_free(inbuf_free[1]);
502
503     return MPI_SUCCESS;
504
505  error_hndl:
506     XBT_DEBUG( "%s:%4d\tRank %d Error occurred %d\n",
507                  __FILE__, line, rank, ret);
508     if (NULL != displs) xbt_free(displs);
509     if (NULL != accumbuf_free) xbt_free(accumbuf_free);
510     if (NULL != inbuf_free[0]) xbt_free(inbuf_free[0]);
511     if (NULL != inbuf_free[1]) xbt_free(inbuf_free[1]);
512     return ret;
513 }
514