Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'master' of git+ssh://scm.gforge.inria.fr//gitroot/simgrid/simgrid
[simgrid.git] / src / smpi / colls / reduce_scatter / reduce_scatter-ompi.cpp
index 6eea9d6..1bc913c 100644 (file)
  * Additional copyrights may follow
  */
 
-#include "../colls_private.h"
-#include "../coll_tuned_topo.h"
+#include "../coll_tuned_topo.hpp"
+#include "../colls_private.hpp"
 
 /*
  * Recursive-halving function is (*mostly*) copied from the BASIC coll module.
- * I have removed the part which handles "large" message sizes 
+ * I have removed the part which handles "large" message sizes
  * (non-overlapping version of reduce_Scatter).
  */
 
@@ -35,7 +35,7 @@
 /*
  *  reduce_scatter_ompi_basic_recursivehalving
  *
- *  Function:   - reduce scatter implementation using recursive-halving 
+ *  Function:   - reduce scatter implementation using recursive-halving
  *                algorithm
  *  Accepts:    - same as MPI_Reduce_scatter()
  *  Returns:    - MPI_SUCCESS or error code
@@ -44,8 +44,8 @@
 namespace simgrid{
 namespace smpi{
 int
-Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(void *sbuf, 
-                                                            void *rbuf, 
+Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(void *sbuf,
+                                                            void *rbuf,
                                                             int *rcounts,
                                                             MPI_Datatype dtype,
                                                             MPI_Op op,
@@ -57,13 +57,13 @@ Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(void *sbuf,
     ptrdiff_t true_lb, true_extent, lb, extent, buf_size;
     char *recv_buf = NULL, *recv_buf_free = NULL;
     char *result_buf = NULL, *result_buf_free = NULL;
-   
+
     /* Initialize */
     rank = comm->rank();
     size = comm->size();
-   
+
     XBT_DEBUG("coll:tuned:reduce_scatter_ompi_basic_recursivehalving, rank %d", rank);
-    if( (op!=MPI_OP_NULL && !op->is_commutative()))
+    if ((op != MPI_OP_NULL && not op->is_commutative()))
       THROWF(arg_error,0, " reduce_scatter ompi_basic_recursivehalving can only be used for commutative operations! ");
 
     /* Find displacements and the like */
@@ -100,29 +100,29 @@ Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(void *sbuf,
         err = MPI_ERR_OTHER;
         goto cleanup;
     }
-   
+
     /* allocate temporary buffer for results */
     result_buf_free = (char*) smpi_get_tmp_sendbuffer(buf_size);
 
     result_buf = result_buf_free - lb;
-   
+
     /* copy local buffer into the temporary results */
     err =Datatype::copy(sbuf, count, dtype, result_buf, count, dtype);
     if (MPI_SUCCESS != err) goto cleanup;
-   
+
     /* figure out power of two mapping: grow until larger than
        comm size, then go back one, to get the largest power of
        two less than or equal to comm size */
     while (tmp_size <= size) tmp_size <<= 1;
     tmp_size >>= 1;
     remain = size - tmp_size;
-   
+
     /* If comm size is not a power of two, have the first "remain"
        procs with an even rank send to rank + 1, leaving a power of
        two procs to do the rest of the algorithm */
     if (rank < 2 * remain) {
         if ((rank & 1) == 0) {
-            Request::send(result_buf, count, dtype, rank + 1, 
+            Request::send(result_buf, count, dtype, rank + 1,
                                     COLL_TAG_REDUCE_SCATTER,
                                     comm);
             /* we don't participate from here on out */
@@ -131,10 +131,10 @@ Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(void *sbuf,
             Request::recv(recv_buf, count, dtype, rank - 1,
                                     COLL_TAG_REDUCE_SCATTER,
                                     comm, MPI_STATUS_IGNORE);
-         
+
             /* integrate their results into our temp results */
             if(op!=MPI_OP_NULL) op->apply( recv_buf, result_buf, &count, dtype);
-         
+
             /* adjust rank to be the bottom "remain" ranks */
             tmp_rank = rank / 2;
         }
@@ -143,13 +143,13 @@ Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(void *sbuf,
            remain" ranks dropped out */
         tmp_rank = rank - remain;
     }
-   
+
     /* For ranks not kicked out by the above code, perform the
        recursive halving */
     if (tmp_rank >= 0) {
         int *tmp_disps = NULL, *tmp_rcounts = NULL;
         int mask, send_index, recv_index, last_index;
-      
+
         /* recalculate disps and rcounts to account for the
            special "remainder" processes that are no longer doing
            anything */
@@ -224,18 +224,18 @@ Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(void *sbuf,
                     xbt_free(tmp_rcounts);
                     xbt_free(tmp_disps);
                     goto cleanup;
-                }                                             
+                }
             }
             if (recv_count > 0 && send_count != 0) {
                 Request::send(result_buf + (ptrdiff_t)tmp_disps[send_index] * extent,
-                                        send_count, dtype, peer, 
+                                        send_count, dtype, peer,
                                         COLL_TAG_REDUCE_SCATTER,
                                         comm);
                 if (MPI_SUCCESS != err) {
                     xbt_free(tmp_rcounts);
                     xbt_free(tmp_disps);
                     goto cleanup;
-                }                                             
+                }
             }
             if (send_count > 0 && recv_count != 0) {
                 Request::wait(&request, MPI_STATUS_IGNORE);
@@ -244,8 +244,8 @@ Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(void *sbuf,
             /* if we received something on this step, push it into
                the results buffer */
             if (recv_count > 0) {
-                if(op!=MPI_OP_NULL) op->apply( 
-                               recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent, 
+                if(op!=MPI_OP_NULL) op->apply(
+                               recv_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
                                result_buf + (ptrdiff_t)tmp_disps[recv_index] * extent,
                                &recv_count, dtype);
             }
@@ -259,13 +259,13 @@ Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(void *sbuf,
         /* copy local results from results buffer into real receive buffer */
         if (0 != rcounts[rank]) {
             err = Datatype::copy(result_buf + disps[rank] * extent,
-                                       rcounts[rank], dtype, 
+                                       rcounts[rank], dtype,
                                        rbuf, rcounts[rank], dtype);
             if (MPI_SUCCESS != err) {
                 xbt_free(tmp_rcounts);
                 xbt_free(tmp_disps);
                 goto cleanup;
-            }                                             
+            }
         }
 
         xbt_free(tmp_rcounts);
@@ -288,7 +288,7 @@ Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(void *sbuf,
                                         COLL_TAG_REDUCE_SCATTER,
                                         comm);
             }
-        }            
+        }
     }
 
  cleanup:
@@ -309,12 +309,12 @@ Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(void *sbuf,
  *   Accepts:        Same as MPI_Reduce_scatter()
  *   Returns:        MPI_SUCCESS or error code
  *
- *   Description:    Implements ring algorithm for reduce_scatter: 
- *                   the block sizes defined in rcounts are exchanged and 
+ *   Description:    Implements ring algorithm for reduce_scatter:
+ *                   the block sizes defined in rcounts are exchanged and
  *                   updated until they reach proper destination.
  *                   Algorithm requires 2 * max(rcounts) extra buffering
  *
- *   Limitations:    The algorithm DOES NOT preserve order of operations so it 
+ *   Limitations:    The algorithm DOES NOT preserve order of operations so it
  *                   can be used only for commutative operations.
  *         Example on 5 nodes:
  *         Initial state
@@ -326,7 +326,7 @@ Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(void *sbuf,
  *        [04]  ->       [14]          [24]           [34]           [44]
  *
  *        COMPUTATION PHASE
- *         Step 0: rank r sends block (r-1) to rank (r+1) and 
+ *         Step 0: rank r sends block (r-1) to rank (r+1) and
  *                 receives block (r-2) from rank (r-1) [with wraparound].
  *   #      0              1             2              3             4
  *        [00]           [10]        [10+20]   ->     [30]           [40]
@@ -334,12 +334,12 @@ Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(void *sbuf,
  *    ->  [02]           [12]          [22]           [32]         [32+42] -->..
  *      [43+03] ->       [13]          [23]           [33]           [43]
  *        [04]         [04+14]  ->     [24]           [34]           [44]
- *         
+ *
  *         Step 1:
  *   #      0              1             2              3             4
  *        [00]           [10]        [10+20]       [10+20+30] ->     [40]
  *    ->  [01]           [11]          [21]          [21+31]      [21+31+41] ->
- *     [32+42+02] ->     [12]          [22]           [32]         [32+42] 
+ *     [32+42+02] ->     [12]          [22]           [32]         [32+42]
  *        [03]        [43+03+13] ->    [23]           [33]           [43]
  *        [04]         [04+14]      [04+14+24]  ->    [34]           [44]
  *
@@ -347,7 +347,7 @@ Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(void *sbuf,
  *   #      0              1             2              3             4
  *     -> [00]           [10]        [10+20]       [10+20+30]   [10+20+30+40] ->
  *   [21+31+41+01]->     [11]          [21]          [21+31]      [21+31+41]
- *     [32+42+02]   [32+42+02+12]->    [22]           [32]         [32+42] 
+ *     [32+42+02]   [32+42+02+12]->    [22]           [32]         [32+42]
  *        [03]        [43+03+13]   [43+03+13+23]->    [33]           [43]
  *        [04]         [04+14]      [04+14+24]    [04+14+24+34] ->   [44]
  *
@@ -355,13 +355,13 @@ Coll_reduce_scatter_ompi_basic_recursivehalving::reduce_scatter(void *sbuf,
  *   #      0             1              2              3             4
  * [10+20+30+40+00]     [10]         [10+20]       [10+20+30]   [10+20+30+40]
  *  [21+31+41+01] [21+31+41+01+11]     [21]          [21+31]      [21+31+41]
- *    [32+42+02]   [32+42+02+12] [32+42+02+12+22]     [32]         [32+42] 
+ *    [32+42+02]   [32+42+02+12] [32+42+02+12+22]     [32]         [32+42]
  *       [03]        [43+03+13]    [43+03+13+23] [43+03+13+23+33]    [43]
  *       [04]         [04+14]       [04+14+24]    [04+14+24+34] [04+14+24+34+44]
  *    DONE :)
  *
  */
-int 
+int
 Coll_reduce_scatter_ompi_ring::reduce_scatter(void *sbuf, void *rbuf, int *rcounts,
                                           MPI_Datatype dtype,
                                           MPI_Op op,
@@ -378,10 +378,10 @@ Coll_reduce_scatter_ompi_ring::reduce_scatter(void *sbuf, void *rbuf, int *rcoun
     size = comm->size();
     rank = comm->rank();
 
-    XBT_DEBUG(  "coll:tuned:reduce_scatter_ompi_ring rank %d, size %d", 
+    XBT_DEBUG(  "coll:tuned:reduce_scatter_ompi_ring rank %d, size %d",
                  rank, size);
 
-    /* Determine the maximum number of elements per node, 
+    /* Determine the maximum number of elements per node,
        corresponding block size, and displacements array.
     */
     displs = (int*) xbt_malloc(size * sizeof(int));
@@ -389,12 +389,12 @@ Coll_reduce_scatter_ompi_ring::reduce_scatter(void *sbuf, void *rbuf, int *rcoun
     displs[0] = 0;
     total_count = rcounts[0];
     max_block_count = rcounts[0];
-    for (i = 1; i < size; i++) { 
+    for (i = 1; i < size; i++) {
         displs[i] = total_count;
         total_count += rcounts[i];
         if (max_block_count < rcounts[i]) max_block_count = rcounts[i];
     }
-      
+
     /* Special case for size == 1 */
     if (1 == size) {
         if (MPI_IN_PLACE != sbuf) {
@@ -438,7 +438,7 @@ Coll_reduce_scatter_ompi_ring::reduce_scatter(void *sbuf, void *rbuf, int *rcoun
 
     /* Computation loop */
 
-    /* 
+    /*
        For each of the remote nodes:
        - post irecv for block (r-2) from (r-1) with wrap around
        - send block (r-1) to (r+1)
@@ -468,23 +468,23 @@ Coll_reduce_scatter_ompi_ring::reduce_scatter(void *sbuf, void *rbuf, int *rcoun
 
     for (k = 2; k < size; k++) {
         const int prevblock = (rank + size - k) % size;
-      
+
         inbi = inbi ^ 0x1;
 
         /* Post irecv for the current block */
         reqs[inbi]=Request::irecv(inbuf[inbi], max_block_count, dtype, recv_from,
                                  COLL_TAG_REDUCE_SCATTER, comm
                                  );
-      
+
         /* Wait on previous block to arrive */
         Request::wait(&reqs[inbi ^ 0x1], MPI_STATUS_IGNORE);
-      
+
         /* Apply operation on previous block: result goes to rbuf
            rbuf[prevblock] = inbuf[inbi ^ 0x1] (op) rbuf[prevblock]
         */
         tmprecv = accumbuf + (ptrdiff_t)displs[prevblock] * extent;
         if(op!=MPI_OP_NULL) op->apply( inbuf[inbi ^ 0x1], tmprecv, &(rcounts[prevblock]), dtype);
-      
+
         /* send previous block to send_to */
         Request::send(tmprecv, rcounts[prevblock], dtype, send_to,
                                 COLL_TAG_REDUCE_SCATTER,
@@ -498,7 +498,7 @@ Coll_reduce_scatter_ompi_ring::reduce_scatter(void *sbuf, void *rbuf, int *rcoun
        rbuf[rank] = inbuf[inbi] (op) rbuf[rank] */
     tmprecv = accumbuf + (ptrdiff_t)displs[rank] * extent;
     if(op!=MPI_OP_NULL) op->apply( inbuf[inbi], tmprecv, &(rcounts[rank]), dtype);
-   
+
     /* Copy result from tmprecv to rbuf */
     ret = Datatype::copy(tmprecv, rcounts[rank], dtype, (char*)rbuf, rcounts[rank], dtype);
     if (ret < 0) { line = __LINE__; goto error_hndl; }