Merge branch 'master' of git+ssh://scm.gforge.inria.fr//gitroot/simgrid/simgrid
[simgrid.git] src/smpi/colls/reduce_scatter/reduce_scatter-mpich.cpp
index 8481ab7..995f078 100644
@@ -1,10 +1,10 @@
-/* Copyright (c) 2013-2014. The SimGrid Team.
+/* Copyright (c) 2013-2017. The SimGrid Team.
  * All rights reserved.                                                     */
 
 /* This program is free software; you can redistribute it and/or modify it
  * under the terms of the license (GNU LGPL) which comes with this package. */
 
-#include "../colls_private.h"
+#include "../colls_private.hpp"
 
 static inline int MPIU_Mirror_permutation(unsigned int x, int bits)
 {
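
Note: the helper's body is elided by the diff context. MPIU_Mirror_permutation reverses the order of the low `bits` bits of `x`; the noncommutative recursive-halving variant further down uses it to map result blocks onto ranks. A minimal sketch of such a bit mirror, assuming it matches the MPICH helper's semantics:

    /* Illustrative sketch only (the diffed helper's body is not shown):
       mirror the lowest `bits` bits of x, so bit 0 swaps with bit bits-1,
       bit 1 with bit bits-2, and so on. */
    static inline unsigned int mirror_bits_sketch(unsigned int x, int bits)
    {
      unsigned int r = 0;
      for (int b = 0; b < bits; b++)
        r |= ((x >> b) & 1u) << (bits - 1 - b);
      return r;
    }
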
@@ -27,7 +27,7 @@ int Coll_reduce_scatter_mpich_pair::reduce_scatter(void *sendbuf, void *recvbuf,
                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
 {
     int   rank, comm_size, i;
-    MPI_Aint extent, true_extent, true_lb; 
+    MPI_Aint extent, true_extent, true_lb;
     int  *disps;
     void *tmp_recvbuf;
     int mpi_errno = MPI_SUCCESS;
@@ -38,7 +38,7 @@ int Coll_reduce_scatter_mpich_pair::reduce_scatter(void *sendbuf, void *recvbuf,
 
     extent =datatype->get_extent();
     datatype->extent(&true_lb, &true_extent);
-    
+
     if (op->is_commutative()) {
         is_commutative = 1;
     }
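
Note on the two extents queried in this hunk: get_extent() returns the stride between consecutive elements, while extent(&true_lb, &true_extent) returns the span of memory the type actually touches, whose lower bound can be negative for resized types. That is why the temporary buffers later in this diff are sized with MAX(true_extent, extent) and then shifted by -true_lb before use. A sketch of the same pattern using standard MPI calls in place of SMPI's Datatype methods (the helper name is hypothetical):

    /* Needs <mpi.h> and <stdlib.h>. Allocate scratch space for `count`
       elements and compensate for a potentially negative true lower bound. */
    static void* alloc_tmp_buffer_sketch(int count, MPI_Datatype datatype)
    {
      MPI_Aint lb, extent, true_lb, true_extent;
      MPI_Type_get_extent(datatype, &lb, &extent);
      MPI_Type_get_true_extent(datatype, &true_lb, &true_extent);
      MPI_Aint slot = (true_extent > extent) ? true_extent : extent;
      char* buf = (char*)malloc((size_t)(count * slot));
      /* shift so that buf + i*extent addresses element i even if true_lb < 0 */
      return buf - true_lb;
    }
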
@@ -50,7 +50,7 @@ int Coll_reduce_scatter_mpich_pair::reduce_scatter(void *sendbuf, void *recvbuf,
         disps[i] = total_count;
         total_count += recvcounts[i];
     }
-    
+
     if (total_count == 0) {
         xbt_free(disps);
         return MPI_ERR_COUNT;
@@ -62,92 +62,87 @@ int Coll_reduce_scatter_mpich_pair::reduce_scatter(void *sendbuf, void *recvbuf,
                                        recvcounts[rank], datatype, recvbuf,
                                        recvcounts[rank], datatype);
         }
-        
+
         /* allocate temporary buffer to store incoming data */
         tmp_recvbuf = (void*)smpi_get_tmp_recvbuffer(recvcounts[rank]*(MAX(true_extent,extent))+1);
         /* adjust for potential negative lower bound in datatype */
         tmp_recvbuf = (void *)((char*)tmp_recvbuf - true_lb);
-        
+
         for (i=1; i<comm_size; i++) {
             src = (rank - i + comm_size) % comm_size;
             dst = (rank + i) % comm_size;
-            
+
             /* send the data that dst needs. recv data that this process
                needs from src into tmp_recvbuf */
-            if (sendbuf != MPI_IN_PLACE) 
-                Request::sendrecv(((char *)sendbuf+disps[dst]*extent), 
+            if (sendbuf != MPI_IN_PLACE)
+                Request::sendrecv(((char *)sendbuf+disps[dst]*extent),
                                              recvcounts[dst], datatype, dst,
                                              COLL_TAG_SCATTER, tmp_recvbuf,
                                              recvcounts[rank], datatype, src,
                                              COLL_TAG_SCATTER, comm,
                                              MPI_STATUS_IGNORE);
             else
-                Request::sendrecv(((char *)recvbuf+disps[dst]*extent), 
+                Request::sendrecv(((char *)recvbuf+disps[dst]*extent),
                                              recvcounts[dst], datatype, dst,
                                              COLL_TAG_SCATTER, tmp_recvbuf,
                                              recvcounts[rank], datatype, src,
                                              COLL_TAG_SCATTER, comm,
                                              MPI_STATUS_IGNORE);
-            
+
             if (is_commutative || (src < rank)) {
                 if (sendbuf != MPI_IN_PLACE) {
-                    if(op!=MPI_OP_NULL) op->apply(
-                                                 tmp_recvbuf, recvbuf, &recvcounts[rank],
-                               datatype); 
+                  if (op != MPI_OP_NULL)
+                    op->apply(tmp_recvbuf, recvbuf, &recvcounts[rank], datatype);
                 }
                 else {
-                   if(op!=MPI_OP_NULL) op->apply( 
-                       tmp_recvbuf, ((char *)recvbuf+disps[rank]*extent), 
-                       &recvcounts[rank], datatype);
-                    /* we can't store the result at the beginning of
-                       recvbuf right here because there is useful data
-                       there that other process/processes need. at the
-                       end, we will copy back the result to the
-                       beginning of recvbuf. */
+                  if (op != MPI_OP_NULL)
+                    op->apply(tmp_recvbuf, ((char*)recvbuf + disps[rank] * extent), &recvcounts[rank], datatype);
+                  /* we can't store the result at the beginning of
+                     recvbuf right here because there is useful data
+                     there that other process/processes need. at the
+                     end, we will copy back the result to the
+                     beginning of recvbuf. */
                 }
             }
             else {
                 if (sendbuf != MPI_IN_PLACE) {
-                   if(op!=MPI_OP_NULL) op->apply( 
-                      recvbuf, tmp_recvbuf, &recvcounts[rank], datatype);
-                    /* copy result back into recvbuf */
-                    mpi_errno = Datatype::copy(tmp_recvbuf, recvcounts[rank],
-                                               datatype, recvbuf,
-                                               recvcounts[rank], datatype);
-                    if (mpi_errno) return(mpi_errno);
+                  if (op != MPI_OP_NULL)
+                    op->apply(recvbuf, tmp_recvbuf, &recvcounts[rank], datatype);
+                  /* copy result back into recvbuf */
+                  mpi_errno =
+                      Datatype::copy(tmp_recvbuf, recvcounts[rank], datatype, recvbuf, recvcounts[rank], datatype);
+                  if (mpi_errno)
+                    return (mpi_errno);
                 }
                 else {
-                   if(op!=MPI_OP_NULL) op->apply( 
-                        ((char *)recvbuf+disps[rank]*extent),
-                       tmp_recvbuf, &recvcounts[rank], datatype);
-                    /* copy result back into recvbuf */
-                    mpi_errno = Datatype::copy(tmp_recvbuf, recvcounts[rank],
-                                               datatype, 
-                                               ((char *)recvbuf +
-                                                disps[rank]*extent), 
-                                               recvcounts[rank], datatype);
-                    if (mpi_errno) return(mpi_errno);
+                  if (op != MPI_OP_NULL)
+                    op->apply(((char*)recvbuf + disps[rank] * extent), tmp_recvbuf, &recvcounts[rank], datatype);
+                  /* copy result back into recvbuf */
+                  mpi_errno = Datatype::copy(tmp_recvbuf, recvcounts[rank], datatype,
+                                             ((char*)recvbuf + disps[rank] * extent), recvcounts[rank], datatype);
+                  if (mpi_errno)
+                    return (mpi_errno);
                 }
             }
         }
-        
+
         /* if MPI_IN_PLACE, move output data to the beginning of
            recvbuf. already done for rank 0. */
         if ((sendbuf == MPI_IN_PLACE) && (rank != 0)) {
             mpi_errno = Datatype::copy(((char *)recvbuf +
-                                        disps[rank]*extent),  
+                                        disps[rank]*extent),
                                        recvcounts[rank], datatype,
-                                       recvbuf, 
+                                       recvbuf,
                                        recvcounts[rank], datatype );
             if (mpi_errno) return(mpi_errno);
         }
-    
+
         xbt_free(disps);
         smpi_free_tmp_buffer(tmp_recvbuf);
 
         return MPI_SUCCESS;
 }
-    
+
 
 int Coll_reduce_scatter_mpich_noncomm::reduce_scatter(void *sendbuf, void *recvbuf, int recvcounts[],
                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
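
Stepping back from the hunk just shown: the pair variant implements the classic pairwise exchange. In step i, every rank sends block dst = (rank + i) mod comm_size to that rank while receiving its own block from src = (rank - i + comm_size) mod comm_size, then folds the received piece into its result; the commutativity test (or src < rank) fixes the operand order. A compressed sketch in standard MPI, assuming a commutative op, a contiguous datatype, and sendbuf != MPI_IN_PLACE, with variable names following the hunk:

    /* Needs <mpi.h> and <string.h>. Each rank ends up holding the
       reduction of its own block in recvbuf. */
    static void pairwise_sketch(const void* sendbuf, void* recvbuf, void* tmp_recvbuf,
                                const int* recvcounts, const int* disps,
                                MPI_Aint extent, MPI_Datatype datatype,
                                MPI_Op op, int rank, int comm_size, MPI_Comm comm)
    {
      /* start from this rank's own contribution */
      memcpy(recvbuf, (const char*)sendbuf + disps[rank] * extent,
             (size_t)(recvcounts[rank] * extent));
      for (int i = 1; i < comm_size; i++) {
        int src = (rank - i + comm_size) % comm_size; /* sends my block to me   */
        int dst = (rank + i) % comm_size;             /* gets its block from me */
        MPI_Sendrecv((const char*)sendbuf + disps[dst] * extent, recvcounts[dst],
                     datatype, dst, 0, tmp_recvbuf, recvcounts[rank], datatype,
                     src, 0, comm, MPI_STATUS_IGNORE);
        /* fold in: recvbuf = op(tmp_recvbuf, recvbuf); order is free here
           because op is assumed commutative */
        MPI_Reduce_local(tmp_recvbuf, recvbuf, recvcounts[rank], datatype, op);
      }
    }
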
@@ -233,7 +228,7 @@ int Coll_reduce_scatter_mpich_noncomm::reduce_scatter(void *sendbuf, void *recvb
            is now our peer's responsibility */
         if (rank > peer) {
             /* higher ranked value so need to call op(received_data, my_data) */
-            if(op!=MPI_OP_NULL) op->apply( 
+            if(op!=MPI_OP_NULL) op->apply(
                    incoming_data + recv_offset*true_extent,
                      outgoing_data + recv_offset*true_extent,
                      &size, datatype );
@@ -241,11 +236,10 @@ int Coll_reduce_scatter_mpich_noncomm::reduce_scatter(void *sendbuf, void *recvb
         }
         else {
             /* lower ranked value so need to call op(my_data, received_data) */
-           if(op!=MPI_OP_NULL) op->apply(
-                    outgoing_data + recv_offset*true_extent,
-                     incoming_data + recv_offset*true_extent,
-                     &size, datatype);
-            buf0_was_inout = !buf0_was_inout;
+            if (op != MPI_OP_NULL)
+              op->apply(outgoing_data + recv_offset * true_extent, incoming_data + recv_offset * true_extent, &size,
+                        datatype);
+            buf0_was_inout = not buf0_was_inout;
         }
 
         /* the next round of send/recv needs to happen within the block (of size
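
Two things are worth spelling out in this hunk. First, operand order: with a noncommutative op the higher rank computes op(received, mine) while the lower rank computes op(mine, received), so every block is reduced in one globally consistent order. Second, the lower rank's result lands in the incoming buffer, which is why only that branch flips buf0_was_inout, the flag recording which scratch buffer currently holds the running result (the rewrite also swaps ! for the alternative token not, a SimGrid style preference). A self-contained sketch of one such step in plain MPI:

    /* One reduction step of the noncommutative recursive-halving scheme
       (needs <mpi.h>; returns the updated buf0_was_inout flag). */
    static int noncomm_reduce_step_sketch(int rank, int peer, char* outgoing_data,
                                          char* incoming_data, MPI_Aint recv_offset,
                                          MPI_Aint true_extent, int size,
                                          MPI_Datatype datatype, MPI_Op op,
                                          int buf0_was_inout)
    {
      if (rank > peer) {
        /* higher rank: result = op(received, mine); stays in outgoing */
        MPI_Reduce_local(incoming_data + recv_offset * true_extent,
                         outgoing_data + recv_offset * true_extent,
                         size, datatype, op);
        return buf0_was_inout;
      }
      /* lower rank: result = op(mine, received); moves to incoming, so the
         buffer roles swap for the next round */
      MPI_Reduce_local(outgoing_data + recv_offset * true_extent,
                       incoming_data + recv_offset * true_extent,
                       size, datatype, op);
      return !buf0_was_inout;
    }
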
@@ -271,7 +265,7 @@ int Coll_reduce_scatter_mpich_rdb::reduce_scatter(void *sendbuf, void *recvbuf,
                               MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
 {
     int   rank, comm_size, i;
-    MPI_Aint extent, true_extent, true_lb; 
+    MPI_Aint extent, true_extent, true_lb;
     int  *disps;
     void *tmp_recvbuf, *tmp_results;
     int mpi_errno = MPI_SUCCESS;
@@ -285,7 +279,7 @@ int Coll_reduce_scatter_mpich_rdb::reduce_scatter(void *sendbuf, void *recvbuf,
 
     extent =datatype->get_extent();
     datatype->extent(&true_lb, &true_extent);
-    
+
     if ((op==MPI_OP_NULL) || op->is_commutative()) {
         is_commutative = 1;
     }
@@ -297,7 +291,7 @@ int Coll_reduce_scatter_mpich_rdb::reduce_scatter(void *sendbuf, void *recvbuf,
         disps[i] = total_count;
         total_count += recvcounts[i];
     }
-    
+
             /* noncommutative and (non-pof2 or block irregular), use recursive doubling. */
 
             /* need to allocate temporary buffer to receive incoming data*/
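
For orientation before the datatype hunks: the rdb variant pairs each process with dst = rank XOR mask for mask = 1, 2, 4, ..., and at every step the pair exchanges the two non-contiguous slices of the result that the other side's subtree will eventually need; the indexed types built below describe exactly those slices. A hedged skeleton of the pairing loop (the real code additionally handles partners beyond comm_size via the tree_root bookkeeping shown further down):

    /* Recursive-doubling pairing sketch (illustrative skeleton only). */
    for (int mask = 1; mask < comm_size; mask <<= 1) {
      int dst = rank ^ mask;   /* partner at distance mask */
      if (dst < comm_size) {
        /* build sendtype/recvtype for the two slices, sendrecv them,
           then fold tmp_recvbuf into tmp_results (see the hunks below) */
      }
    }
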
@@ -355,7 +349,7 @@ int Coll_reduce_scatter_mpich_rdb::reduce_scatter(void *sendbuf, void *recvbuf,
 
                 mpi_errno = Datatype::create_indexed(2, blklens, dis, datatype, &sendtype);
                 if (mpi_errno) return(mpi_errno);
-                
+
                 sendtype->commit();
 
                 /* calculate recvtype */
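
Datatype::create_indexed(2, blklens, dis, ...) above is SMPI's internal counterpart of MPI_Type_indexed: a derived type with two blocks (blklens[0] elements at displacement dis[0], blklens[1] at dis[1]) lets a single sendrecv move both disjoint slices at once. The standard-MPI equivalent, with hypothetical placeholder values for the blocks, would be:

    /* Standard-MPI equivalent of the create_indexed call above; the block
       lengths and displacements are hypothetical placeholders. */
    int blklens[2] = {4, 4};   /* elements in each slice           */
    int dis[2]     = {0, 12};  /* slice displacements, in elements */
    MPI_Datatype sendtype;
    MPI_Type_indexed(2, blklens, dis, datatype, &sendtype);
    MPI_Type_commit(&sendtype);
    /* ... use in MPI_Sendrecv, then: */
    MPI_Type_free(&sendtype);
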
@@ -372,14 +366,14 @@ int Coll_reduce_scatter_mpich_rdb::reduce_scatter(void *sendbuf, void *recvbuf,
 
                 mpi_errno = Datatype::create_indexed(2, blklens, dis, datatype, &recvtype);
                 if (mpi_errno) return(mpi_errno);
-                
+
                 recvtype->commit();
 
                 received = 0;
                 if (dst < comm_size) {
                     /* tmp_results contains data to be sent in each step. Data is
                        received in tmp_recvbuf and then accumulated into
-                       tmp_results. accumulation is done later below.   */ 
+                       tmp_results. accumulation is done later below.   */
 
                     Request::sendrecv(tmp_results, 1, sendtype, dst,
                                                  COLL_TAG_SCATTER,
@@ -402,7 +396,7 @@ int Coll_reduce_scatter_mpich_rdb::reduce_scatter(void *sendbuf, void *recvbuf,
                        in a tree fashion. First find root of current tree
                        that is being divided into two. k is the number of
                        least-significant bits in this process's rank that
-                       must be zeroed out to find the rank of the root */ 
+                       must be zeroed out to find the rank of the root */
                     j = mask;
                     k = 0;
                     while (j) {
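
The loop begun above (its body is elided by the diff context; in the MPICH original it is j >>= 1; k++;, followed by k--) leaves k as the index of mask's set bit. Clearing those k low-order bits of rank, as the comment says, yields the root of the subtree currently being split. In compact form, assuming that continuation:

    /* Compact equivalent of the tree-root computation, assuming the k--
       that follows the loop in the MPICH original: */
    int tree_root = (rank >> k) << k;  /* zero the k least-significant bits */
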
@@ -421,7 +415,7 @@ int Coll_reduce_scatter_mpich_rdb::reduce_scatter(void *sendbuf, void *recvbuf,
                         /* send only if this proc has data and destination
                            doesn't have data. at any step, multiple processes
                            can send if they have the data */
-                        if ((dst > rank) && 
+                        if ((dst > rank) &&
                             (rank < tree_root + nprocs_completed)
                             && (dst >= tree_root + nprocs_completed)) {
                             /* send the current result */
@@ -431,12 +425,12 @@ int Coll_reduce_scatter_mpich_rdb::reduce_scatter(void *sendbuf, void *recvbuf,
                         }
                         /* recv only if this proc. doesn't have data and sender
                            has data */
-                        else if ((dst < rank) && 
+                        else if ((dst < rank) &&
                                  (dst < tree_root + nprocs_completed) &&
                                  (rank >= tree_root + nprocs_completed)) {
                             Request::recv(tmp_recvbuf, 1, recvtype, dst,
                                                      COLL_TAG_SCATTER,
-                                                     comm, MPI_STATUS_IGNORE); 
+                                                     comm, MPI_STATUS_IGNORE);
                             received = 1;
                         }
                         tmp_mask >>= 1;
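
The two guards in this hunk encode a single invariant: inside the subtree rooted at tree_root, the first nprocs_completed ranks already hold the partial result and the remaining ones still need it, so data always flows from the completed prefix to the uncompleted suffix. Restated as predicates over the diff's own variables:

    /* Restating the send/recv guards above (same variable names): */
    int i_have_data  = (rank < tree_root + nprocs_completed);
    int dst_has_data = (dst  < tree_root + nprocs_completed);
    int do_send = (dst > rank) &&  i_have_data && !dst_has_data;
    int do_recv = (dst < rank) && !i_have_data &&  dst_has_data;
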
@@ -444,9 +438,9 @@ int Coll_reduce_scatter_mpich_rdb::reduce_scatter(void *sendbuf, void *recvbuf,
                     }
                 }
 
-                /* The following reduction is done here instead of after 
+                /* The following reduction is done here instead of after
                    the MPIC_Sendrecv_ft or MPIC_Recv_ft above. This is
-                   because to do it above, in the noncommutative 
+                   because to do it above, in the noncommutative
                    case, we would need an extra temp buffer so as not to
                    overwrite temp_recvbuf, because temp_recvbuf may have
                    to be communicated to other processes in the
@@ -455,27 +449,23 @@ int Coll_reduce_scatter_mpich_rdb::reduce_scatter(void *sendbuf, void *recvbuf,
                 if (received) {
                     if (is_commutative || (dst_tree_root < my_tree_root)) {
                         {
-                                if(op!=MPI_OP_NULL) op->apply( 
-                               tmp_recvbuf, tmp_results, &blklens[0],
-                              datatype); 
-                               if(op!=MPI_OP_NULL) op->apply( 
-                               ((char *)tmp_recvbuf + dis[1]*extent),
-                              ((char *)tmp_results + dis[1]*extent),
-                              &blklens[1], datatype); 
+                          if (op != MPI_OP_NULL)
+                            op->apply(tmp_recvbuf, tmp_results, &blklens[0], datatype);
+                          if (op != MPI_OP_NULL)
+                            op->apply(((char*)tmp_recvbuf + dis[1] * extent), ((char*)tmp_results + dis[1] * extent),
+                                      &blklens[1], datatype);
                         }
                     }
                     else {
                         {
-                                if(op!=MPI_OP_NULL) op->apply(
-                                   tmp_results, tmp_recvbuf, &blklens[0],
-                                   datatype); 
-                                if(op!=MPI_OP_NULL) op->apply(
-                                   ((char *)tmp_results + dis[1]*extent),
-                                   ((char *)tmp_recvbuf + dis[1]*extent),
-                                   &blklens[1], datatype); 
+                          if (op != MPI_OP_NULL)
+                            op->apply(tmp_results, tmp_recvbuf, &blklens[0], datatype);
+                          if (op != MPI_OP_NULL)
+                            op->apply(((char*)tmp_results + dis[1] * extent), ((char*)tmp_recvbuf + dis[1] * extent),
+                                      &blklens[1], datatype);
                         }
                         /* copy result back into tmp_results */
-                        mpi_errno = Datatype::copy(tmp_recvbuf, 1, recvtype, 
+                        mpi_errno = Datatype::copy(tmp_recvbuf, 1, recvtype,
                                                    tmp_results, 1, recvtype);
                         if (mpi_errno) return(mpi_errno);
                     }