Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Move collective algorithms to separate folders
[simgrid.git] / src / smpi / colls / reduce_scatter / reduce_scatter-ompi.cpp
1 /* Copyright (c) 2013-2014. The SimGrid Team.
2  * All rights reserved.                                                     */
3
4 /* This program is free software; you can redistribute it and/or modify it
5  * under the terms of the license (GNU LGPL) which comes with this package. */
6
7 /*
8  * Copyright (c) 2004-2005 The Trustees of Indiana University and Indiana
9  *                         University Research and Technology
10  *                         Corporation.  All rights reserved.
11  * Copyright (c) 2004-2012 The University of Tennessee and The University
12  *                         of Tennessee Research Foundation.  All rights
13  *                         reserved.
14  * Copyright (c) 2004-2005 High Performance Computing Center Stuttgart,
15  *                         University of Stuttgart.  All rights reserved.
16  * Copyright (c) 2004-2005 The Regents of the University of California.
17  *                         All rights reserved.
18  * Copyright (c) 2008      Sun Microsystems, Inc.  All rights reserved.
19  * Copyright (c) 2009      University of Houston. All rights reserved.
20  *
21  * Additional copyrights may follow
22  */
23
24 #include "colls_private.h"
25 #include "coll_tuned_topo.h"
26
27 /*
28  * Recursive-halving function is (*mostly*) copied from the BASIC coll module.
29  * I have removed the part which handles "large" message sizes 
30  * (non-overlapping version of reduce_Scatter).
31  */
32
33 /* copied function (with appropriate renaming) starts here */
34
35 /*
36  *  reduce_scatter_ompi_basic_recursivehalving
37  *
38  *  Function:   - reduce scatter implementation using recursive-halving 
39  *                algorithm
40  *  Accepts:    - same as MPI_Reduce_scatter()
41  *  Returns:    - MPI_SUCCESS or error code
42  *  Limitation: - Works only for commutative operations.
43  */
44 int
45 smpi_coll_tuned_reduce_scatter_ompi_basic_recursivehalving(void *sbuf, 
46                                                             void *rbuf, 
47                                                             int *rcounts,
48                                                             MPI_Datatype dtype,
49                                                             MPI_Op op,
50                                                             MPI_Comm comm
51                                                             )
52 {
53     int i, rank, size, count, err = MPI_SUCCESS;
54     int tmp_size=1, remain = 0, tmp_rank, *disps = NULL;
55     ptrdiff_t true_lb, true_extent, lb, extent, buf_size;
56     char *recv_buf = NULL, *recv_buf_free = NULL;
57     char *result_buf = NULL, *result_buf_free = NULL;
58    
59     /* Initialize */
60     rank = comm->rank();
61     size = comm->size();
62    
63     XBT_DEBUG("coll:tuned:reduce_scatter_ompi_basic_recursivehalving, rank %d", rank);
64     if( (op!=MPI_OP_NULL && !op->is_commutative()))
65       THROWF(arg_error,0, " reduce_scatter ompi_basic_recursivehalving can only be used for commutative operations! ");
66
67     /* Find displacements and the like */
68     disps = (int*) xbt_malloc(sizeof(int) * size);
69     if (NULL == disps) return MPI_ERR_OTHER;
70
71     disps[0] = 0;
72     for (i = 0; i < (size - 1); ++i) {
73         disps[i + 1] = disps[i] + rcounts[i];
74     }
75     count = disps[size - 1] + rcounts[size - 1];
76
77     /* short cut the trivial case */
78     if (0 == count) {
79         xbt_free(disps);
80         return MPI_SUCCESS;
81     }
82
83     /* get datatype information */
84     dtype->extent(&lb, &extent);
85     dtype->extent(&true_lb, &true_extent);
86     buf_size = true_extent + (ptrdiff_t)(count - 1) * extent;
87
88     /* Handle MPI_IN_PLACE */
89     if (MPI_IN_PLACE == sbuf) {
90         sbuf = rbuf;
91     }
92
93     /* Allocate temporary receive buffer. */
94     recv_buf_free = (char*) smpi_get_tmp_recvbuffer(buf_size);
95
96     recv_buf = recv_buf_free - lb;
97     if (NULL == recv_buf_free) {
98         err = MPI_ERR_OTHER;
99         goto cleanup;
100     }
101    
102     /* allocate temporary buffer for results */
103     result_buf_free = (char*) smpi_get_tmp_sendbuffer(buf_size);
104
105     result_buf = result_buf_free - lb;
106    
107     /* copy local buffer into the temporary results */
108     err =Datatype::copy(sbuf, count, dtype, result_buf, count, dtype);
109     if (MPI_SUCCESS != err) goto cleanup;
110    
111     /* figure out power of two mapping: grow until larger than
112        comm size, then go back one, to get the largest power of
113        two less than comm size */
114     while (tmp_size <= size) tmp_size <<= 1;
115     tmp_size >>= 1;
116     remain = size - tmp_size;
117    
118     /* If comm size is not a power of two, have the first "remain"
119        procs with an even rank send to rank + 1, leaving a power of
120        two procs to do the rest of the algorithm */
121     if (rank < 2 * remain) {
122         if ((rank & 1) == 0) {
123             Request::send(result_buf, count, dtype, rank + 1, 
124                                     COLL_TAG_REDUCE_SCATTER,
125                                     comm);
126             /* we don't participate from here on out */
127             tmp_rank = -1;
128         } else {
129             Request::recv(recv_buf, count, dtype, rank - 1,
130                                     COLL_TAG_REDUCE_SCATTER,
131                                     comm, MPI_STATUS_IGNORE);
132          
133             /* integrate their results into our temp results */
134             if(op!=MPI_OP_NULL) op->apply( recv_buf, result_buf, &count, dtype);
135          
136             /* adjust rank to be the bottom "remain" ranks */
137             tmp_rank = rank / 2;
138         }
139     } else {
140         /* just need to adjust rank to show that the bottom "even
141            remain" ranks dropped out */
142         tmp_rank = rank - remain;
143     }
144    
145     /* For ranks not kicked out by the above code, perform the
146        recursive halving */
147     if (tmp_rank >= 0) {
148         int *tmp_disps = NULL, *tmp_rcounts = NULL;
149         int mask, send_index, recv_index, last_index;
150       
151         /* recalculate disps and rcounts to account for the
152            special "remainder" processes that are no longer doing
153            anything */
154         tmp_rcounts = (int*) xbt_malloc(tmp_size * sizeof(int));
155         if (NULL == tmp_rcounts) {
156             err = MPI_ERR_OTHER;
157             goto cleanup;
158         }
159         tmp_disps = (int*) xbt_malloc(tmp_size * sizeof(int));
160         if (NULL == tmp_disps) {
161             xbt_free(tmp_rcounts);
162             err = MPI_ERR_OTHER;
163             goto cleanup;
164         }
165
166         for (i = 0 ; i < tmp_size ; ++i) {
167             if (i < remain) {
168                 /* need to include old neighbor as well */
169                 tmp_rcounts[i] = rcounts[i * 2 + 1] + rcounts[i * 2];
170             } else {
171                 tmp_rcounts[i] = rcounts[i + remain];
172             }
173         }
174
175         tmp_disps[0] = 0;
176         for (i = 0; i < tmp_size - 1; ++i) {
177             tmp_disps[i + 1] = tmp_disps[i] + tmp_rcounts[i];
178         }
179
180         /* do the recursive halving communication.  Don't use the
181            dimension information on the communicator because I
182            think the information is invalidated by our "shrinking"
183            of the communicator */
184         mask = tmp_size >> 1;
185         send_index = recv_index = 0;
186         last_index = tmp_size;
187         while (mask > 0) {
188             int tmp_peer, peer, send_count, recv_count;
189             MPI_Request request;
190
191             tmp_peer = tmp_rank ^ mask;
192             peer = (tmp_peer < remain) ? tmp_peer * 2 + 1 : tmp_peer + remain;
193
194             /* figure out if we're sending, receiving, or both */
195             send_count = recv_count = 0;
196             if (tmp_rank < tmp_peer) {
197                 send_index = recv_index + mask;
198                 for (i = send_index ; i < last_index ; ++i) {
199                     send_count += tmp_rcounts[i];
200                 }
201                 for (i = recv_index ; i < send_index ; ++i) {
202                     recv_count += tmp_rcounts[i];
203                 }
204             } else {
205                 recv_index = send_index + mask;
206                 for (i = send_index ; i < recv_index ; ++i) {
207                     send_count += tmp_rcounts[i];
208                 }
209                 for (i = recv_index ; i < last_index ; ++i) {
210                     recv_count += tmp_rcounts[i];
211                 }
212             }
213
214             /* actual data transfer.  Send from result_buf,
215                receive into recv_buf */
216             if (send_count > 0 && recv_count != 0) {
217                 request=Request::irecv(recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
218                                          recv_count, dtype, peer,
219                                          COLL_TAG_REDUCE_SCATTER,
220                                          comm);
221                 if (MPI_SUCCESS != err) {
222                     xbt_free(tmp_rcounts);
223                     xbt_free(tmp_disps);
224                     goto cleanup;
225                 }                                             
226             }
227             if (recv_count > 0 && send_count != 0) {
228                 Request::send(result_buf + (ptrdiff_t)tmp_disps[send_index] * extent,
229                                         send_count, dtype, peer, 
230                                         COLL_TAG_REDUCE_SCATTER,
231                                         comm);
232                 if (MPI_SUCCESS != err) {
233                     xbt_free(tmp_rcounts);
234                     xbt_free(tmp_disps);
235                     goto cleanup;
236                 }                                             
237             }
238             if (send_count > 0 && recv_count != 0) {
239                 Request::wait(&request, MPI_STATUS_IGNORE);
240             }
241
242             /* if we received something on this step, push it into
243                the results buffer */
244             if (recv_count > 0) {
245                 if(op!=MPI_OP_NULL) op->apply( 
246                                recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent, 
247                                result_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
248                                &recv_count, dtype);
249             }
250
251             /* update for next iteration */
252             send_index = recv_index;
253             last_index = recv_index + mask;
254             mask >>= 1;
255         }
256
257         /* copy local results from results buffer into real receive buffer */
258         if (0 != rcounts[rank]) {
259             err = Datatype::copy(result_buf + disps[rank] * extent,
260                                        rcounts[rank], dtype, 
261                                        rbuf, rcounts[rank], dtype);
262             if (MPI_SUCCESS != err) {
263                 xbt_free(tmp_rcounts);
264                 xbt_free(tmp_disps);
265                 goto cleanup;
266             }                                             
267         }
268
269         xbt_free(tmp_rcounts);
270         xbt_free(tmp_disps);
271     }
272
273     /* Now fix up the non-power of two case, by having the odd
274        procs send the even procs the proper results */
275     if (rank < (2 * remain)) {
276         if ((rank & 1) == 0) {
277             if (rcounts[rank]) {
278                 Request::recv(rbuf, rcounts[rank], dtype, rank + 1,
279                                         COLL_TAG_REDUCE_SCATTER,
280                                         comm, MPI_STATUS_IGNORE);
281             }
282         } else {
283             if (rcounts[rank - 1]) {
284                 Request::send(result_buf + disps[rank - 1] * extent,
285                                         rcounts[rank - 1], dtype, rank - 1,
286                                         COLL_TAG_REDUCE_SCATTER,
287                                         comm);
288             }
289         }            
290     }
291
292  cleanup:
293     if (NULL != disps) xbt_free(disps);
294     if (NULL != recv_buf_free) smpi_free_tmp_buffer(recv_buf_free);
295     if (NULL != result_buf_free) smpi_free_tmp_buffer(result_buf_free);
296
297     return err;
298 }
299
300 /* copied function (with appropriate renaming) ends here */
301
302
303 /*
304  *   smpi_coll_tuned_reduce_scatter_ompi_ring
305  *
306  *   Function:       Ring algorithm for reduce_scatter operation
307  *   Accepts:        Same as MPI_Reduce_scatter()
308  *   Returns:        MPI_SUCCESS or error code
309  *
310  *   Description:    Implements ring algorithm for reduce_scatter: 
311  *                   the block sizes defined in rcounts are exchanged and 
312  8                    updated until they reach proper destination.
313  *                   Algorithm requires 2 * max(rcounts) extra buffering
314  *
315  *   Limitations:    The algorithm DOES NOT preserve order of operations so it 
316  *                   can be used only for commutative operations.
317  *         Example on 5 nodes:
318  *         Initial state
319  *   #      0              1             2              3             4
320  *        [00]           [10]   ->     [20]           [30]           [40]
321  *        [01]           [11]          [21]  ->       [31]           [41]
322  *        [02]           [12]          [22]           [32]  ->       [42]
323  *    ->  [03]           [13]          [23]           [33]           [43] --> ..
324  *        [04]  ->       [14]          [24]           [34]           [44]
325  *
326  *        COMPUTATION PHASE
327  *         Step 0: rank r sends block (r-1) to rank (r+1) and 
328  *                 receives block (r+1) from rank (r-1) [with wraparound].
329  *   #      0              1             2              3             4
330  *        [00]           [10]        [10+20]   ->     [30]           [40]
331  *        [01]           [11]          [21]          [21+31]  ->     [41]
332  *    ->  [02]           [12]          [22]           [32]         [32+42] -->..
333  *      [43+03] ->       [13]          [23]           [33]           [43]
334  *        [04]         [04+14]  ->     [24]           [34]           [44]
335  *         
336  *         Step 1:
337  *   #      0              1             2              3             4
338  *        [00]           [10]        [10+20]       [10+20+30] ->     [40]
339  *    ->  [01]           [11]          [21]          [21+31]      [21+31+41] ->
340  *     [32+42+02] ->     [12]          [22]           [32]         [32+42] 
341  *        [03]        [43+03+13] ->    [23]           [33]           [43]
342  *        [04]         [04+14]      [04+14+24]  ->    [34]           [44]
343  *
344  *         Step 2:
345  *   #      0              1             2              3             4
346  *     -> [00]           [10]        [10+20]       [10+20+30]   [10+20+30+40] ->
347  *   [21+31+41+01]->     [11]          [21]          [21+31]      [21+31+41]
348  *     [32+42+02]   [32+42+02+12]->    [22]           [32]         [32+42] 
349  *        [03]        [43+03+13]   [43+03+13+23]->    [33]           [43]
350  *        [04]         [04+14]      [04+14+24]    [04+14+24+34] ->   [44]
351  *
352  *         Step 3:
353  *   #      0             1              2              3             4
354  * [10+20+30+40+00]     [10]         [10+20]       [10+20+30]   [10+20+30+40]
355  *  [21+31+41+01] [21+31+41+01+11]     [21]          [21+31]      [21+31+41]
356  *    [32+42+02]   [32+42+02+12] [32+42+02+12+22]     [32]         [32+42] 
357  *       [03]        [43+03+13]    [43+03+13+23] [43+03+13+23+33]    [43]
358  *       [04]         [04+14]       [04+14+24]    [04+14+24+34] [04+14+24+34+44]
359  *    DONE :)
360  *
361  */
362 int 
363 smpi_coll_tuned_reduce_scatter_ompi_ring(void *sbuf, void *rbuf, int *rcounts,
364                                           MPI_Datatype dtype,
365                                           MPI_Op op,
366                                           MPI_Comm comm
367                                           )
368 {
369     int ret, line, rank, size, i, k, recv_from, send_to, total_count, max_block_count;
370     int inbi, *displs = NULL;
371     char *tmpsend = NULL, *tmprecv = NULL, *accumbuf = NULL, *accumbuf_free = NULL;
372     char *inbuf_free[2] = {NULL, NULL}, *inbuf[2] = {NULL, NULL};
373     ptrdiff_t true_lb, true_extent, lb, extent, max_real_segsize;
374     MPI_Request reqs[2] = {NULL, NULL};
375
376     size = comm->size();
377     rank = comm->rank();
378
379     XBT_DEBUG(  "coll:tuned:reduce_scatter_ompi_ring rank %d, size %d", 
380                  rank, size);
381
382     /* Determine the maximum number of elements per node, 
383        corresponding block size, and displacements array.
384     */
385     displs = (int*) xbt_malloc(size * sizeof(int));
386     if (NULL == displs) { ret = -1; line = __LINE__; goto error_hndl; }
387     displs[0] = 0;
388     total_count = rcounts[0];
389     max_block_count = rcounts[0];
390     for (i = 1; i < size; i++) { 
391         displs[i] = total_count;
392         total_count += rcounts[i];
393         if (max_block_count < rcounts[i]) max_block_count = rcounts[i];
394     }
395       
396     /* Special case for size == 1 */
397     if (1 == size) {
398         if (MPI_IN_PLACE != sbuf) {
399             ret = Datatype::copy((char*)sbuf, total_count, dtype, (char*)rbuf, total_count, dtype);
400             if (ret < 0) { line = __LINE__; goto error_hndl; }
401         }
402         xbt_free(displs);
403         return MPI_SUCCESS;
404     }
405
406     /* Allocate and initialize temporary buffers, we need:
407        - a temporary buffer to perform reduction (size total_count) since
408        rbuf can be of rcounts[rank] size.
409        - up to two temporary buffers used for communication/computation overlap.
410     */
411     dtype->extent(&lb, &extent);
412     dtype->extent(&true_lb, &true_extent);
413
414     max_real_segsize = true_extent + (ptrdiff_t)(max_block_count - 1) * extent;
415
416     accumbuf_free = (char*)smpi_get_tmp_recvbuffer(true_extent + (ptrdiff_t)(total_count - 1) * extent);
417     if (NULL == accumbuf_free) { ret = -1; line = __LINE__; goto error_hndl; }
418     accumbuf = accumbuf_free - lb;
419
420     inbuf_free[0] = (char*)smpi_get_tmp_sendbuffer(max_real_segsize);
421     if (NULL == inbuf_free[0]) { ret = -1; line = __LINE__; goto error_hndl; }
422     inbuf[0] = inbuf_free[0] - lb;
423     if (size > 2) {
424         inbuf_free[1] = (char*)smpi_get_tmp_sendbuffer(max_real_segsize);
425         if (NULL == inbuf_free[1]) { ret = -1; line = __LINE__; goto error_hndl; }
426         inbuf[1] = inbuf_free[1] - lb;
427     }
428
429     /* Handle MPI_IN_PLACE for size > 1 */
430     if (MPI_IN_PLACE == sbuf) {
431         sbuf = rbuf;
432     }
433
434     ret = Datatype::copy((char*)sbuf, total_count, dtype, accumbuf, total_count, dtype);
435     if (ret < 0) { line = __LINE__; goto error_hndl; }
436
437     /* Computation loop */
438
439     /* 
440        For each of the remote nodes:
441        - post irecv for block (r-2) from (r-1) with wrap around
442        - send block (r-1) to (r+1)
443        - in loop for every step k = 2 .. n
444        - post irecv for block (r - 1 + n - k) % n
445        - wait on block (r + n - k) % n to arrive
446        - compute on block (r + n - k ) % n
447        - send block (r + n - k) % n
448        - wait on block (r)
449        - compute on block (r)
450        - copy block (r) to rbuf
451        Note that we must be careful when computing the begining of buffers and
452        for send operations and computation we must compute the exact block size.
453     */
454     send_to = (rank + 1) % size;
455     recv_from = (rank + size - 1) % size;
456
457     inbi = 0;
458     /* Initialize first receive from the neighbor on the left */
459     reqs[inbi]=Request::irecv(inbuf[inbi], max_block_count, dtype, recv_from,
460                              COLL_TAG_REDUCE_SCATTER, comm
461                              );
462     tmpsend = accumbuf + (ptrdiff_t)displs[recv_from] * extent;
463     Request::send(tmpsend, rcounts[recv_from], dtype, send_to,
464                             COLL_TAG_REDUCE_SCATTER,
465                              comm);
466
467     for (k = 2; k < size; k++) {
468         const int prevblock = (rank + size - k) % size;
469       
470         inbi = inbi ^ 0x1;
471
472         /* Post irecv for the current block */
473         reqs[inbi]=Request::irecv(inbuf[inbi], max_block_count, dtype, recv_from,
474                                  COLL_TAG_REDUCE_SCATTER, comm
475                                  );
476       
477         /* Wait on previous block to arrive */
478         Request::wait(&reqs[inbi ^ 0x1], MPI_STATUS_IGNORE);
479       
480         /* Apply operation on previous block: result goes to rbuf
481            rbuf[prevblock] = inbuf[inbi ^ 0x1] (op) rbuf[prevblock]
482         */
483         tmprecv = accumbuf + (ptrdiff_t)displs[prevblock] * extent;
484         if(op!=MPI_OP_NULL) op->apply( inbuf[inbi ^ 0x1], tmprecv, &(rcounts[prevblock]), dtype);
485       
486         /* send previous block to send_to */
487         Request::send(tmprecv, rcounts[prevblock], dtype, send_to,
488                                 COLL_TAG_REDUCE_SCATTER,
489                                  comm);
490     }
491
492     /* Wait on the last block to arrive */
493     Request::wait(&reqs[inbi], MPI_STATUS_IGNORE);
494
495     /* Apply operation on the last block (my block)
496        rbuf[rank] = inbuf[inbi] (op) rbuf[rank] */
497     tmprecv = accumbuf + (ptrdiff_t)displs[rank] * extent;
498     if(op!=MPI_OP_NULL) op->apply( inbuf[inbi], tmprecv, &(rcounts[rank]), dtype);
499    
500     /* Copy result from tmprecv to rbuf */
501     ret = Datatype::copy(tmprecv, rcounts[rank], dtype, (char*)rbuf, rcounts[rank], dtype);
502     if (ret < 0) { line = __LINE__; goto error_hndl; }
503
504     if (NULL != displs) xbt_free(displs);
505     if (NULL != accumbuf_free) smpi_free_tmp_buffer(accumbuf_free);
506     if (NULL != inbuf_free[0]) smpi_free_tmp_buffer(inbuf_free[0]);
507     if (NULL != inbuf_free[1]) smpi_free_tmp_buffer(inbuf_free[1]);
508
509     return MPI_SUCCESS;
510
511  error_hndl:
512     XBT_DEBUG( "%s:%4d\tRank %d Error occurred %d\n",
513                  __FILE__, line, rank, ret);
514     if (NULL != displs) xbt_free(displs);
515     if (NULL != accumbuf_free) smpi_free_tmp_buffer(accumbuf_free);
516     if (NULL != inbuf_free[0]) smpi_free_tmp_buffer(inbuf_free[0]);
517     if (NULL != inbuf_free[1]) smpi_free_tmp_buffer(inbuf_free[1]);
518     return ret;
519 }
520