Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Update copyright lines.
[simgrid.git] / src / smpi / colls / gather / gather-mvapich.cpp
index 21c2142..ac80816 100644 (file)
@@ -1,4 +1,4 @@
-/* Copyright (c) 2013-2017. The SimGrid Team.
+/* Copyright (c) 2013-2021. The SimGrid Team.
  * All rights reserved.                                                     */
 
 /* This program is free software; you can redistribute it and/or modify it
  *      See COPYRIGHT in top-level directory.
  */
 
-#include "../colls_private.h"
+#include "../colls_private.hpp"
+#include <algorithm>
 
-
-
-
-
-#define MPIR_Gather_MV2_Direct Coll_gather_ompi_basic_linear::gather
-#define MPIR_Gather_MV2_two_level_Direct Coll_gather_ompi_basic_linear::gather
-#define MPIR_Gather_intra Coll_gather_mpich::gather
-typedef int (*MV2_Gather_function_ptr) (void *sendbuf,
+#define MPIR_Gather_MV2_Direct gather__ompi_basic_linear
+#define MPIR_Gather_MV2_two_level_Direct gather__ompi_basic_linear
+#define MPIR_Gather_intra gather__mpich
+typedef int (*MV2_Gather_function_ptr) (const void *sendbuf,
     int sendcnt,
     MPI_Datatype sendtype,
     void *recvbuf,
     int recvcnt,
     MPI_Datatype recvtype,
     int root, MPI_Comm comm);
-    
+
 extern MV2_Gather_function_ptr MV2_Gather_inter_leader_function;
 extern MV2_Gather_function_ptr MV2_Gather_intra_node_function;
 
-#define TEMP_BUF_HAS_NO_DATA (0)
-#define TEMP_BUF_HAS_DATA (1)
-
+#define TEMP_BUF_HAS_NO_DATA (false)
+#define TEMP_BUF_HAS_DATA (true)
 
 namespace simgrid{
 namespace smpi{
@@ -80,16 +76,12 @@ namespace smpi{
  *                     (shmem_comm or intra_sock_comm or
  *                     inter-sock_leader_comm)
  * intra_node_fn_ptr - (in) Function ptr to choose the
- *                      intra node gather function  
+ *                      intra node gather function
  * errflag           - (out) to record errors
  */
-static int MPIR_pt_pt_intra_gather( void *sendbuf, int sendcnt, MPI_Datatype sendtype,
-                            void *recvbuf, int recvcnt, MPI_Datatype recvtype,
-                            int root, int rank, 
-                            void *tmp_buf, int nbytes,
-                            int is_data_avail,
-                            MPI_Comm comm,  
-                            MV2_Gather_function_ptr intra_node_fn_ptr)
+static int MPIR_pt_pt_intra_gather(const void* sendbuf, int sendcnt, MPI_Datatype sendtype, void* recvbuf, int recvcnt,
+                                   MPI_Datatype recvtype, int root, int rank, void* tmp_buf, int nbytes,
+                                   bool is_data_avail, MPI_Comm comm, MV2_Gather_function_ptr intra_node_fn_ptr)
 {
     int mpi_errno = MPI_SUCCESS;
     MPI_Aint recvtype_extent = 0;  /* Datatype extent */
@@ -105,10 +97,10 @@ static int MPIR_pt_pt_intra_gather( void *sendbuf, int sendcnt, MPI_Datatype sen
         recvtype->extent(&true_lb,
                                        &recvtype_true_extent);
     }
-    
+
     /* Special case, when tmp_buf itself has data */
     if (rank == root && sendbuf == MPI_IN_PLACE && is_data_avail) {
-         
+
          mpi_errno = intra_node_fn_ptr(MPI_IN_PLACE,
                                        sendcnt, sendtype, tmp_buf, nbytes,
                                        MPI_BYTE, 0, comm);
@@ -130,35 +122,34 @@ static int MPIR_pt_pt_intra_gather( void *sendbuf, int sendcnt, MPI_Datatype sen
 
 
 
-int Coll_gather_mvapich2_two_level::gather(void *sendbuf,
-                                            int sendcnt,
-                                            MPI_Datatype sendtype,
-                                            void *recvbuf,
-                                            int recvcnt,
-                                            MPI_Datatype recvtype,
-                                            int root,
-                                            MPI_Comm comm)
+int gather__mvapich2_two_level(const void *sendbuf,
+                               int sendcnt,
+                               MPI_Datatype sendtype,
+                               void *recvbuf,
+                               int recvcnt,
+                               MPI_Datatype recvtype,
+                               int root,
+                               MPI_Comm comm)
 {
-    void *leader_gather_buf = NULL;
-    int comm_size, rank;
-    int local_rank, local_size;
-    int leader_comm_rank = -1, leader_comm_size = 0;
-    int mpi_errno = MPI_SUCCESS;
-    int recvtype_size = 0, sendtype_size = 0, nbytes=0;
-    int leader_root, leader_of_root;
-    MPI_Status status;
-    MPI_Aint sendtype_extent = 0, recvtype_extent = 0;  /* Datatype extent */
-    MPI_Aint true_lb = 0, sendtype_true_extent = 0, recvtype_true_extent = 0;
-    MPI_Comm shmem_comm, leader_comm;
-    void* tmp_buf = NULL;
-    
-
-    //if not set (use of the algo directly, without mvapich2 selector)
-    if(MV2_Gather_intra_node_function==NULL)
-      MV2_Gather_intra_node_function= Coll_gather_mpich::gather;
-    
-    if(comm->get_leaders_comm()==MPI_COMM_NULL){
-      comm->init_smp();
+  unsigned char* leader_gather_buf = nullptr;
+  int comm_size, rank;
+  int local_rank, local_size;
+  int leader_comm_rank = -1, leader_comm_size = 0;
+  int mpi_errno     = MPI_SUCCESS;
+  int recvtype_size = 0, sendtype_size = 0, nbytes = 0;
+  int leader_root, leader_of_root;
+  MPI_Status status;
+  MPI_Aint sendtype_extent = 0, recvtype_extent = 0; /* Datatype extent */
+  MPI_Aint true_lb = 0, sendtype_true_extent = 0, recvtype_true_extent = 0;
+  MPI_Comm shmem_comm, leader_comm;
+  unsigned char* tmp_buf = nullptr;
+
+  // if not set (use of the algo directly, without mvapich2 selector)
+  if (MV2_Gather_intra_node_function == nullptr)
+    MV2_Gather_intra_node_function = gather__mpich;
+
+  if (comm->get_leaders_comm() == MPI_COMM_NULL) {
+    comm->init_smp();
     }
     comm_size = comm->size();
     rank = comm->rank();
@@ -186,7 +177,7 @@ int Coll_gather_mvapich2_two_level::gather(void *sendbuf,
     shmem_comm = comm->get_intra_comm();
     local_rank = shmem_comm->rank();
     local_size = shmem_comm->size();
-    
+
     if (local_rank == 0) {
         /* Node leader. Extract the rank, size information for the leader
          * communicator */
@@ -206,52 +197,49 @@ int Coll_gather_mvapich2_two_level::gather(void *sendbuf,
     }
 
 #if defined(_SMP_LIMIC_)
-     if((g_use_limic2_coll) && (shmem_commptr->ch.use_intra_sock_comm == 1) 
+     if((g_use_limic2_coll) && (shmem_commptr->ch.use_intra_sock_comm == 1)
          && (use_limic_gather)
-         &&((num_scheme == USE_GATHER_PT_PT_BINOMIAL) 
+         &&((num_scheme == USE_GATHER_PT_PT_BINOMIAL)
             || (num_scheme == USE_GATHER_PT_PT_DIRECT)
-            ||(num_scheme == USE_GATHER_PT_LINEAR_BINOMIAL) 
+            ||(num_scheme == USE_GATHER_PT_LINEAR_BINOMIAL)
             || (num_scheme == USE_GATHER_PT_LINEAR_DIRECT)
             || (num_scheme == USE_GATHER_LINEAR_PT_BINOMIAL)
             || (num_scheme == USE_GATHER_LINEAR_PT_DIRECT)
             || (num_scheme == USE_GATHER_LINEAR_LINEAR)
             || (num_scheme == USE_GATHER_SINGLE_LEADER))) {
-            
+
             mpi_errno = MV2_Gather_intra_node_function(sendbuf, sendcnt, sendtype,
-                                                    recvbuf, recvcnt,recvtype, 
+                                                    recvbuf, recvcnt,recvtype,
                                                     root, comm);
      } else
 
-#endif/*#if defined(_SMP_LIMIC_)*/    
+#endif/*#if defined(_SMP_LIMIC_)*/
     {
         if (local_rank == 0) {
             /* Node leader, allocate tmp_buffer */
             if (rank == root) {
-                tmp_buf = smpi_get_tmp_recvbuffer(recvcnt * MAX(recvtype_extent,
-                            recvtype_true_extent) * local_size);
+              tmp_buf = smpi_get_tmp_recvbuffer(recvcnt * std::max(recvtype_extent, recvtype_true_extent) * local_size);
             } else {
-                tmp_buf = smpi_get_tmp_sendbuffer(sendcnt * MAX(sendtype_extent,
-                            sendtype_true_extent) *
-                        local_size);
+              tmp_buf = smpi_get_tmp_sendbuffer(sendcnt * std::max(sendtype_extent, sendtype_true_extent) * local_size);
             }
-            if (tmp_buf == NULL) {
-                mpi_errno = MPI_ERR_OTHER;
-                return mpi_errno;
+            if (tmp_buf == nullptr) {
+              mpi_errno = MPI_ERR_OTHER;
+              return mpi_errno;
             }
         }
          /*while testing mpich2 gather test, we see that
          * which basically splits the comm, and we come to
-         * a point, where use_intra_sock_comm == 0, but if the 
+         * a point, where use_intra_sock_comm == 0, but if the
          * intra node function is MPIR_Intra_node_LIMIC_Gather_MV2,
-         * it would use the intra sock comm. In such cases, we 
+         * it would use the intra sock comm. In such cases, we
          * fallback to binomial as a default case.*/
-#if defined(_SMP_LIMIC_)         
+#if defined(_SMP_LIMIC_)
         if(*MV2_Gather_intra_node_function == MPIR_Intra_node_LIMIC_Gather_MV2) {
 
             mpi_errno  = MPIR_pt_pt_intra_gather(sendbuf,sendcnt, sendtype,
                                                  recvbuf, recvcnt, recvtype,
-                                                 root, rank, 
-                                                 tmp_buf, nbytes, 
+                                                 root, rank,
+                                                 tmp_buf, nbytes,
                                                  TEMP_BUF_HAS_NO_DATA,
                                                  shmem_commptr,
                                                  MPIR_Gather_intra);
@@ -263,8 +251,8 @@ int Coll_gather_mvapich2_two_level::gather(void *sendbuf,
              * local data, we pass is_data_avail = TEMP_BUF_HAS_NO_DATA*/
             mpi_errno  = MPIR_pt_pt_intra_gather(sendbuf,sendcnt, sendtype,
                                                  recvbuf, recvcnt, recvtype,
-                                                 root, rank, 
-                                                 tmp_buf, nbytes, 
+                                                 root, rank,
+                                                 tmp_buf, nbytes,
                                                  TEMP_BUF_HAS_NO_DATA,
                                                  shmem_comm,
                                                  MV2_Gather_intra_node_function
@@ -275,104 +263,91 @@ int Coll_gather_mvapich2_two_level::gather(void *sendbuf,
     int* leaders_map = comm->get_leaders_map();
     leader_of_root = comm->group()->rank(leaders_map[root]);
     leader_root = leader_comm->group()->rank(leaders_map[root]);
-    /* leader_root is the rank of the leader of the root in leader_comm. 
-     * leader_root is to be used as the root of the inter-leader gather ops 
+    /* leader_root is the rank of the leader of the root in leader_comm.
+     * leader_root is to be used as the root of the inter-leader gather ops
      */
-    if (!comm->is_uniform()) {
-        if (local_rank == 0) {
-            int *displs = NULL;
-            int *recvcnts = NULL;
-            int *node_sizes;
-            int i = 0;
-            /* Node leaders have all the data. But, different nodes can have
-             * different number of processes. Do a Gather first to get the 
-             * buffer lengths at each leader, followed by a Gatherv to move
-             * the actual data */
+    if (not comm->is_uniform()) {
+      if (local_rank == 0) {
+        int* displs   = nullptr;
+        int* recvcnts = nullptr;
+        int* node_sizes;
+        int i = 0;
+        /* Node leaders have all the data. But, different nodes can have
+         * different number of processes. Do a Gather first to get the
+         * buffer lengths at each leader, followed by a Gatherv to move
+         * the actual data */
+
+        if (leader_comm_rank == leader_root && root != leader_of_root) {
+          /* The root of the Gather operation is not a node-level
+           * leader and this process's rank in the leader_comm
+           * is the same as leader_root */
+          if (rank == root) {
+            leader_gather_buf =
+                smpi_get_tmp_recvbuffer(recvcnt * std::max(recvtype_extent, recvtype_true_extent) * comm_size);
+          } else {
+            leader_gather_buf =
+                smpi_get_tmp_sendbuffer(sendcnt * std::max(sendtype_extent, sendtype_true_extent) * comm_size);
+          }
+          if (leader_gather_buf == nullptr) {
+            mpi_errno = MPI_ERR_OTHER;
+            return mpi_errno;
+          }
+        }
 
-            if (leader_comm_rank == leader_root && root != leader_of_root) {
-                /* The root of the Gather operation is not a node-level 
-                 * leader and this process's rank in the leader_comm 
-                 * is the same as leader_root */
-                if(rank == root) { 
-                    leader_gather_buf = smpi_get_tmp_recvbuffer(recvcnt *
-                                                MAX(recvtype_extent,
-                                                recvtype_true_extent) *
-                                                comm_size);
-                } else { 
-                    leader_gather_buf = smpi_get_tmp_sendbuffer(sendcnt *
-                                                MAX(sendtype_extent,
-                                                sendtype_true_extent) *
-                                                comm_size);
-                } 
-                if (leader_gather_buf == NULL) {
-                    mpi_errno =  MPI_ERR_OTHER;
-                    return mpi_errno;
-                }
-            }
+        node_sizes = comm->get_non_uniform_map();
 
-            node_sizes = comm->get_non_uniform_map();
+        if (leader_comm_rank == leader_root) {
+          displs   = new int[leader_comm_size];
+          recvcnts = new int[leader_comm_size];
+        }
 
-            if (leader_comm_rank == leader_root) {
-                displs =  static_cast<int *>(xbt_malloc(sizeof (int) * leader_comm_size));
-                recvcnts =  static_cast<int *>(xbt_malloc(sizeof (int) * leader_comm_size));
-                if (!displs || !recvcnts) {
-                    mpi_errno = MPI_ERR_OTHER;
-                    return mpi_errno;
-                }
-            }
+        if (root == leader_of_root) {
+          /* The root of the gather operation is also the node
+           * leader. Receive into recvbuf and we are done */
+          if (leader_comm_rank == leader_root) {
+            recvcnts[0] = node_sizes[0] * recvcnt;
+            displs[0]   = 0;
 
-            if (root == leader_of_root) {
-                /* The root of the gather operation is also the node 
-                 * leader. Receive into recvbuf and we are done */
-                if (leader_comm_rank == leader_root) {
-                    recvcnts[0] = node_sizes[0] * recvcnt;
-                    displs[0] = 0;
-
-                    for (i = 1; i < leader_comm_size; i++) {
-                        displs[i] = displs[i - 1] + node_sizes[i - 1] * recvcnt;
-                        recvcnts[i] = node_sizes[i] * recvcnt;
-                    }
-                } 
-                Colls::gatherv(tmp_buf,
-                                         local_size * nbytes,
-                                         MPI_BYTE, recvbuf, recvcnts,
-                                         displs, recvtype,
-                                         leader_root, leader_comm);
-            } else {
-                /* The root of the gather operation is not the node leader. 
-                 * Receive into leader_gather_buf and then send 
-                 * to the root */
-                if (leader_comm_rank == leader_root) {
-                    recvcnts[0] = node_sizes[0] * nbytes;
-                    displs[0] = 0;
-
-                    for (i = 1; i < leader_comm_size; i++) {
-                        displs[i] = displs[i - 1] + node_sizes[i - 1] * nbytes;
-                        recvcnts[i] = node_sizes[i] * nbytes;
-                    }
-                } 
-                Colls::gatherv(tmp_buf, local_size * nbytes,
-                                         MPI_BYTE, leader_gather_buf,
-                                         recvcnts, displs, MPI_BYTE,
-                                         leader_root, leader_comm);
+            for (i = 1; i < leader_comm_size; i++) {
+              displs[i]   = displs[i - 1] + node_sizes[i - 1] * recvcnt;
+              recvcnts[i] = node_sizes[i] * recvcnt;
             }
-            if (leader_comm_rank == leader_root) {
-                xbt_free(displs);
-                xbt_free(recvcnts);
+          }
+          colls::gatherv(tmp_buf, local_size * nbytes, MPI_BYTE, recvbuf, recvcnts, displs, recvtype, leader_root,
+                         leader_comm);
+        } else {
+          /* The root of the gather operation is not the node leader.
+           * Receive into leader_gather_buf and then send
+           * to the root */
+          if (leader_comm_rank == leader_root) {
+            recvcnts[0] = node_sizes[0] * nbytes;
+            displs[0]   = 0;
+
+            for (i = 1; i < leader_comm_size; i++) {
+              displs[i]   = displs[i - 1] + node_sizes[i - 1] * nbytes;
+              recvcnts[i] = node_sizes[i] * nbytes;
             }
+          }
+          colls::gatherv(tmp_buf, local_size * nbytes, MPI_BYTE, leader_gather_buf, recvcnts, displs, MPI_BYTE,
+                         leader_root, leader_comm);
+        }
+        if (leader_comm_rank == leader_root) {
+          delete[] displs;
+          delete[] recvcnts;
         }
+      }
     } else {
-        /* All nodes have the same number of processes. 
-         * Just do one Gather to get all 
+        /* All nodes have the same number of processes.
+         * Just do one Gather to get all
          * the data at the leader of the root process */
         if (local_rank == 0) {
             if (leader_comm_rank == leader_root && root != leader_of_root) {
                 /* The root of the Gather operation is not a node-level leader
                  */
                 leader_gather_buf = smpi_get_tmp_sendbuffer(nbytes * comm_size);
-                if (leader_gather_buf == NULL) {
-                    mpi_errno = MPI_ERR_OTHER;
-                    return mpi_errno;
+                if (leader_gather_buf == nullptr) {
+                  mpi_errno = MPI_ERR_OTHER;
+                  return mpi_errno;
                 }
             }
             if (root == leader_of_root) {
@@ -382,7 +357,7 @@ int Coll_gather_mvapich2_two_level::gather(void *sendbuf,
                                                    recvcnt * local_size,
                                                    recvtype, leader_root,
                                                    leader_comm);
-                 
+
             } else {
                 mpi_errno = MPIR_Gather_MV2_Direct(tmp_buf, nbytes * local_size,
                                                    MPI_BYTE, leader_gather_buf,
@@ -409,12 +384,12 @@ int Coll_gather_mvapich2_two_level::gather(void *sendbuf,
 
     /* check if multiple threads are calling this collective function */
     if (local_rank == 0 ) {
-        if (tmp_buf != NULL) {
-            smpi_free_tmp_buffer(tmp_buf);
-        }
-        if (leader_gather_buf != NULL) {
-            smpi_free_tmp_buffer(leader_gather_buf);
-        }
+      if (tmp_buf != nullptr) {
+        smpi_free_tmp_buffer(tmp_buf);
+      }
+      if (leader_gather_buf != nullptr) {
+        smpi_free_tmp_buffer(leader_gather_buf);
+      }
     }
 
     return (mpi_errno);