Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add Reduce SMP collective from MVAPICH2
[simgrid.git] / src / smpi / colls / smpi_mvapich2_selector.c
index fa6c7b2..54b6782 100644 (file)
@@ -92,12 +92,12 @@ int smpi_coll_tuned_allgather_mvapich2(void *sendbuf, int sendcount, MPI_Datatyp
   int mpi_errno = MPI_SUCCESS;
   int nbytes = 0, comm_size, recvtype_size;
   int range = 0;
-  //int partial_sub_ok = 0;
+  int partial_sub_ok = 0;
   int conf_index = 0;
   int range_threshold = 0;
   int is_two_level = 0;
-  //int local_size = -1;
-  //MPI_Comm shmem_comm;
+  int local_size = -1;
+  MPI_Comm shmem_comm;
   //MPI_Comm *shmem_commptr=NULL;
   /* Get the size of the communicator */
   comm_size = smpi_comm_size(comm);
@@ -106,34 +106,35 @@ int smpi_coll_tuned_allgather_mvapich2(void *sendbuf, int sendcount, MPI_Datatyp
 
   if(mv2_allgather_table_ppn_conf==NULL)
     init_mv2_allgather_tables_stampede();
+    
+  if(smpi_comm_get_leaders_comm(comm)==MPI_COMM_NULL){
+    smpi_comm_init_smp(comm);
+  }
 
-  //int i;
-  /* check if safe to use partial subscription mode */
-  /*  if (comm->ch.shmem_coll_ok == 1 && comm->ch.is_uniform) {
-
-        shmem_comm = comm->ch.shmem_comm;
-        MPID_Comm_get_ptr(shmem_comm, shmem_commptr);
-        local_size = shmem_commptr->local_size;
-        i = 0;
-        if (mv2_allgather_table_ppn_conf[0] == -1) {
-            // Indicating user defined tuning
-            conf_index = 0;
-            goto conf_check_end;
-        }
-        do {
-            if (local_size == mv2_allgather_table_ppn_conf[i]) {
-                conf_index = i;
-                partial_sub_ok = 1;
-                break;
-            }
-            i++;
-        } while(i < mv2_allgather_num_ppn_conf);
+  int i;
+  if (smpi_comm_is_uniform(comm)){
+    shmem_comm = smpi_comm_get_intra_comm(comm);
+    local_size = smpi_comm_size(shmem_comm);
+    i = 0;
+    if (mv2_allgather_table_ppn_conf[0] == -1) {
+      // Indicating user defined tuning
+      conf_index = 0;
+      goto conf_check_end;
     }
-
+    do {
+      if (local_size == mv2_allgather_table_ppn_conf[i]) {
+        conf_index = i;
+        partial_sub_ok = 1;
+        break;
+      }
+      i++;
+    } while(i < mv2_allgather_num_ppn_conf);
+  }
   conf_check_end:
-    if (partial_sub_ok != 1) {
-        conf_index = 0;
-    }*/
+  if (partial_sub_ok != 1) {
+    conf_index = 0;
+  }
+  
   /* Search for the corresponding system size inside the tuning table */
   while ((range < (mv2_size_allgather_tuning_table[conf_index] - 1)) &&
       (comm_size >
@@ -158,24 +159,21 @@ int smpi_coll_tuned_allgather_mvapich2(void *sendbuf, int sendcount, MPI_Datatyp
 
   /* intracommunicator */
   if(is_two_level ==1){
-
-      /*       if(comm->ch.shmem_coll_ok == 1){
-            MPIR_T_PVAR_COUNTER_INC(MV2, mv2_num_shmem_coll_calls, 1);
-          if (1 == comm->ch.is_blocked) {
-                mpi_errno = MPIR_2lvl_Allgather_MV2(sendbuf, sendcount, sendtype,
-                                                   recvbuf, recvcount, recvtype,
-                                                   comm, errflag);
-          }
-          else {
-              mpi_errno = MPIR_Allgather_intra(sendbuf, sendcount, sendtype,
-                                               recvbuf, recvcount, recvtype,
-                                               comm, errflag);
-          }
-        } else {*/
+    if(partial_sub_ok ==1){
+      if (smpi_comm_is_blocked(comm)){
+      mpi_errno = MPIR_2lvl_Allgather_MV2(sendbuf, sendcount, sendtype,
+                            recvbuf, recvcount, recvtype,
+                            comm);
+      }else{
+      mpi_errno = smpi_coll_tuned_allgather_mpich(sendbuf, sendcount, sendtype,
+                            recvbuf, recvcount, recvtype,
+                            comm);
+      }
+    } else {
       mpi_errno = MPIR_Allgather_RD_MV2(sendbuf, sendcount, sendtype,
           recvbuf, recvcount, recvtype,
           comm);
-      //     }
+    }
   } else if(MV2_Allgather_function == &MPIR_Allgather_Bruck_MV2
       || MV2_Allgather_function == &MPIR_Allgather_RD_MV2
       || MV2_Allgather_function == &MPIR_Allgather_Ring_MV2) {
@@ -242,9 +240,8 @@ int smpi_coll_tuned_gather_mvapich2(void *sendbuf,
       -1)) {
       range_intra_threshold++;
   }
-  /*
-    if (comm->ch.is_global_block == 1 && mv2_use_direct_gather == 1 &&
-            mv2_use_two_level_gather == 1 && comm->ch.shmem_coll_ok == 1) {
+  
+    if (smpi_comm_is_blocked(comm) ) {
         // Set intra-node function pt for gather_two_level 
         MV2_Gather_intra_node_function = 
                               mv2_gather_thresholds_table[range].intra_node[range_intra_threshold].
@@ -258,12 +255,12 @@ int smpi_coll_tuned_gather_mvapich2(void *sendbuf,
             MV2_Gather_inter_leader_function(sendbuf, sendcnt, sendtype, recvbuf, recvcnt,
                                              recvtype, root, comm);
 
-    } else {*/
-  // Indded, direct (non SMP-aware)gather is MPICH one
+    } else {
+  // Indeed, direct (non SMP-aware)gather is MPICH one
   mpi_errno = smpi_coll_tuned_gather_mpich(sendbuf, sendcnt, sendtype,
       recvbuf, recvcnt, recvtype,
       root, comm);
-  //}
+  }
 
   return mpi_errno;
 }
@@ -362,7 +359,7 @@ int smpi_coll_tuned_allreduce_mvapich2(void *sendbuf,
   int nbytes = 0;
   int range = 0, range_threshold = 0, range_threshold_intra = 0;
   int is_two_level = 0;
-  //int is_commutative = 0;
+  int is_commutative = 0;
   MPI_Aint true_lb, true_extent;
 
   sendtype_size=smpi_datatype_size(datatype);
@@ -430,16 +427,16 @@ int smpi_coll_tuned_allreduce_mvapich2(void *sendbuf,
 
     if(is_two_level == 1){
         // check if shm is ready, if not use other algorithm first
-        /*if ((comm->ch.shmem_coll_ok == 1)
-                    && (mv2_enable_shmem_allreduce)
-                    && (is_commutative)
-                    && (mv2_enable_shmem_collectives)) {
-                    mpi_errno = MPIR_Allreduce_two_level_MV2(sendbuf, recvbuf, count,
+        if (is_commutative) {
+          if(smpi_comm_get_leaders_comm(comm)==MPI_COMM_NULL){
+            smpi_comm_init_smp(comm);
+          }
+          mpi_errno = MPIR_Allreduce_two_level_MV2(sendbuf, recvbuf, count,
                                                      datatype, op, comm);
-                } else {*/
+                } else {
         mpi_errno = MPIR_Allreduce_pt2pt_rd_MV2(sendbuf, recvbuf, count,
             datatype, op, comm);
-        // }
+        }
     } else {
         mpi_errno = MV2_Allreduce_function(sendbuf, recvbuf, count,
             datatype, op, comm);
@@ -573,14 +570,16 @@ int smpi_coll_tuned_reduce_mvapich2( void *sendbuf,
   /* We call Reduce function */
   if(is_two_level == 1)
     {
-      /* if (comm->ch.shmem_coll_ok == 1
-            && is_commutative == 1) {
-            mpi_errno = MPIR_Reduce_two_level_helper_MV2(sendbuf, recvbuf, count, 
-                                           datatype, op, root, comm, errflag);
-        } else {*/
+       if (is_commutative == 1) {
+         if(smpi_comm_get_leaders_comm(comm)==MPI_COMM_NULL){
+           smpi_comm_init_smp(comm);
+         }
+         mpi_errno = MPIR_Reduce_two_level_helper_MV2(sendbuf, recvbuf, count, 
+                                           datatype, op, root, comm);
+        } else {
       mpi_errno = MPIR_Reduce_binomial_MV2(sendbuf, recvbuf, count,
           datatype, op, root, comm);
-      //}
+      }
     } else if(MV2_Reduce_function == &MPIR_Reduce_inter_knomial_wrapper_MV2 ){
         if(is_commutative ==1)
           {