Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Use type bool for boolean variables in smpi/colls/.
[simgrid.git] / src / smpi / colls / gather / gather-mvapich.cpp
index 9cc9a36..5eb527f 100644 (file)
@@ -1,4 +1,4 @@
-/* Copyright (c) 2013-2017. The SimGrid Team.
+/* Copyright (c) 2013-2019. The SimGrid Team.
  * All rights reserved.                                                     */
 
 /* This program is free software; you can redistribute it and/or modify it
  *      See COPYRIGHT in top-level directory.
  */
 
-#include "../colls_private.h"
+#include "../colls_private.hpp"
+#include <algorithm>
 
-
-
-
-
-#define MPIR_Gather_MV2_Direct Coll_gather_ompi_basic_linear::gather
-#define MPIR_Gather_MV2_two_level_Direct Coll_gather_ompi_basic_linear::gather
-#define MPIR_Gather_intra Coll_gather_mpich::gather
-typedef int (*MV2_Gather_function_ptr) (void *sendbuf,
+#define MPIR_Gather_MV2_Direct gather__ompi_basic_linear
+#define MPIR_Gather_MV2_two_level_Direct gather__ompi_basic_linear
+#define MPIR_Gather_intra gather__mpich
+typedef int (*MV2_Gather_function_ptr) (const void *sendbuf,
     int sendcnt,
     MPI_Datatype sendtype,
     void *recvbuf,
@@ -55,9 +52,8 @@ typedef int (*MV2_Gather_function_ptr) (void *sendbuf,
 extern MV2_Gather_function_ptr MV2_Gather_inter_leader_function;
 extern MV2_Gather_function_ptr MV2_Gather_intra_node_function;
 
-#define TEMP_BUF_HAS_NO_DATA (0)
-#define TEMP_BUF_HAS_DATA (1)
-
+#define TEMP_BUF_HAS_NO_DATA (false)
+#define TEMP_BUF_HAS_DATA (true)
 
 namespace simgrid{
 namespace smpi{
@@ -83,13 +79,9 @@ namespace smpi{
  *                      intra node gather function
  * errflag           - (out) to record errors
  */
-static int MPIR_pt_pt_intra_gather( void *sendbuf, int sendcnt, MPI_Datatype sendtype,
-                            void *recvbuf, int recvcnt, MPI_Datatype recvtype,
-                            int root, int rank,
-                            void *tmp_buf, int nbytes,
-                            int is_data_avail,
-                            MPI_Comm comm,
-                            MV2_Gather_function_ptr intra_node_fn_ptr)
+static int MPIR_pt_pt_intra_gather(const void* sendbuf, int sendcnt, MPI_Datatype sendtype, void* recvbuf, int recvcnt,
+                                   MPI_Datatype recvtype, int root, int rank, void* tmp_buf, int nbytes,
+                                   bool is_data_avail, MPI_Comm comm, MV2_Gather_function_ptr intra_node_fn_ptr)
 {
     int mpi_errno = MPI_SUCCESS;
     MPI_Aint recvtype_extent = 0;  /* Datatype extent */
@@ -130,35 +122,34 @@ static int MPIR_pt_pt_intra_gather( void *sendbuf, int sendcnt, MPI_Datatype sen
 
 
 
-int Coll_gather_mvapich2_two_level::gather(void *sendbuf,
-                                            int sendcnt,
-                                            MPI_Datatype sendtype,
-                                            void *recvbuf,
-                                            int recvcnt,
-                                            MPI_Datatype recvtype,
-                                            int root,
-                                            MPI_Comm comm)
+int gather__mvapich2_two_level(const void *sendbuf,
+                               int sendcnt,
+                               MPI_Datatype sendtype,
+                               void *recvbuf,
+                               int recvcnt,
+                               MPI_Datatype recvtype,
+                               int root,
+                               MPI_Comm comm)
 {
-    void *leader_gather_buf = NULL;
-    int comm_size, rank;
-    int local_rank, local_size;
-    int leader_comm_rank = -1, leader_comm_size = 0;
-    int mpi_errno = MPI_SUCCESS;
-    int recvtype_size = 0, sendtype_size = 0, nbytes=0;
-    int leader_root, leader_of_root;
-    MPI_Status status;
-    MPI_Aint sendtype_extent = 0, recvtype_extent = 0;  /* Datatype extent */
-    MPI_Aint true_lb = 0, sendtype_true_extent = 0, recvtype_true_extent = 0;
-    MPI_Comm shmem_comm, leader_comm;
-    void* tmp_buf = NULL;
-
-
-    //if not set (use of the algo directly, without mvapich2 selector)
-    if(MV2_Gather_intra_node_function==NULL)
-      MV2_Gather_intra_node_function= Coll_gather_mpich::gather;
-
-    if(comm->get_leaders_comm()==MPI_COMM_NULL){
-      comm->init_smp();
+  unsigned char* leader_gather_buf = NULL;
+  int comm_size, rank;
+  int local_rank, local_size;
+  int leader_comm_rank = -1, leader_comm_size = 0;
+  int mpi_errno     = MPI_SUCCESS;
+  int recvtype_size = 0, sendtype_size = 0, nbytes = 0;
+  int leader_root, leader_of_root;
+  MPI_Status status;
+  MPI_Aint sendtype_extent = 0, recvtype_extent = 0; /* Datatype extent */
+  MPI_Aint true_lb = 0, sendtype_true_extent = 0, recvtype_true_extent = 0;
+  MPI_Comm shmem_comm, leader_comm;
+  unsigned char* tmp_buf = NULL;
+
+  // if not set (use of the algo directly, without mvapich2 selector)
+  if (MV2_Gather_intra_node_function == NULL)
+    MV2_Gather_intra_node_function = gather__mpich;
+
+  if (comm->get_leaders_comm() == MPI_COMM_NULL) {
+    comm->init_smp();
     }
     comm_size = comm->size();
     rank = comm->rank();
@@ -227,12 +218,9 @@ int Coll_gather_mvapich2_two_level::gather(void *sendbuf,
         if (local_rank == 0) {
             /* Node leader, allocate tmp_buffer */
             if (rank == root) {
-                tmp_buf = smpi_get_tmp_recvbuffer(recvcnt * MAX(recvtype_extent,
-                            recvtype_true_extent) * local_size);
+              tmp_buf = smpi_get_tmp_recvbuffer(recvcnt * std::max(recvtype_extent, recvtype_true_extent) * local_size);
             } else {
-                tmp_buf = smpi_get_tmp_sendbuffer(sendcnt * MAX(sendtype_extent,
-                            sendtype_true_extent) *
-                        local_size);
+              tmp_buf = smpi_get_tmp_sendbuffer(sendcnt * std::max(sendtype_extent, sendtype_true_extent) * local_size);
             }
             if (tmp_buf == NULL) {
                 mpi_errno = MPI_ERR_OTHER;
@@ -295,10 +283,10 @@ int Coll_gather_mvapich2_two_level::gather(void *sendbuf,
            * is the same as leader_root */
           if (rank == root) {
             leader_gather_buf =
-                smpi_get_tmp_recvbuffer(recvcnt * MAX(recvtype_extent, recvtype_true_extent) * comm_size);
+                smpi_get_tmp_recvbuffer(recvcnt * std::max(recvtype_extent, recvtype_true_extent) * comm_size);
           } else {
             leader_gather_buf =
-                smpi_get_tmp_sendbuffer(sendcnt * MAX(sendtype_extent, sendtype_true_extent) * comm_size);
+                smpi_get_tmp_sendbuffer(sendcnt * std::max(sendtype_extent, sendtype_true_extent) * comm_size);
           }
           if (leader_gather_buf == NULL) {
             mpi_errno = MPI_ERR_OTHER;
@@ -309,12 +297,8 @@ int Coll_gather_mvapich2_two_level::gather(void *sendbuf,
         node_sizes = comm->get_non_uniform_map();
 
         if (leader_comm_rank == leader_root) {
-          displs   = static_cast<int*>(xbt_malloc(sizeof(int) * leader_comm_size));
-          recvcnts = static_cast<int*>(xbt_malloc(sizeof(int) * leader_comm_size));
-          if (not displs || not recvcnts) {
-            mpi_errno = MPI_ERR_OTHER;
-            return mpi_errno;
-          }
+          displs   = new int[leader_comm_size];
+          recvcnts = new int[leader_comm_size];
         }
 
         if (root == leader_of_root) {
@@ -329,7 +313,7 @@ int Coll_gather_mvapich2_two_level::gather(void *sendbuf,
               recvcnts[i] = node_sizes[i] * recvcnt;
             }
           }
-          Colls::gatherv(tmp_buf, local_size * nbytes, MPI_BYTE, recvbuf, recvcnts, displs, recvtype, leader_root,
+          colls::gatherv(tmp_buf, local_size * nbytes, MPI_BYTE, recvbuf, recvcnts, displs, recvtype, leader_root,
                          leader_comm);
         } else {
           /* The root of the gather operation is not the node leader.
@@ -344,12 +328,12 @@ int Coll_gather_mvapich2_two_level::gather(void *sendbuf,
               recvcnts[i] = node_sizes[i] * nbytes;
             }
           }
-          Colls::gatherv(tmp_buf, local_size * nbytes, MPI_BYTE, leader_gather_buf, recvcnts, displs, MPI_BYTE,
+          colls::gatherv(tmp_buf, local_size * nbytes, MPI_BYTE, leader_gather_buf, recvcnts, displs, MPI_BYTE,
                          leader_root, leader_comm);
         }
         if (leader_comm_rank == leader_root) {
-          xbt_free(displs);
-          xbt_free(recvcnts);
+          delete[] displs;
+          delete[] recvcnts;
         }
       }
     } else {