[Logo] AND — Algorithmique Numérique Distribuée

Public Git repository
Add Scatter SMP collective from MVAPICH2
[simgrid.git] / src / smpi / colls / smpi_mvapich2_selector.c
index 54b6782..c423a5b 100644 (file)
@@ -690,7 +690,7 @@ int smpi_coll_tuned_scatter_mvapich2(void *sendbuf,
     void *recvbuf,
     int recvcnt,
     MPI_Datatype recvtype,
-    int root, MPI_Comm comm_ptr)
+    int root, MPI_Comm comm)
 {
   int range = 0, range_threshold = 0, range_threshold_intra = 0;
   int mpi_errno = MPI_SUCCESS;
@@ -699,16 +699,20 @@ int smpi_coll_tuned_scatter_mvapich2(void *sendbuf,
   int recvtype_size, sendtype_size;
   int partial_sub_ok = 0;
   int conf_index = 0;
-  //  int local_size = -1;
-  //  int i;
-  //   MPI_Comm shmem_comm;
+    int local_size = -1;
+    int i;
+     MPI_Comm shmem_comm;
   //    MPID_Comm *shmem_commptr=NULL;
   if(mv2_scatter_thresholds_table==NULL)
     init_mv2_scatter_tables_stampede();
 
-  comm_size = smpi_comm_size(comm_ptr);
+  if(smpi_comm_get_leaders_comm(comm)==MPI_COMM_NULL){
+    smpi_comm_init_smp(comm);
+  }
+  
+  comm_size = smpi_comm_size(comm);
 
-  rank = smpi_comm_rank(comm_ptr);
+  rank = smpi_comm_rank(comm);
 
   if (rank == root) {
       sendtype_size=smpi_datatype_size(sendtype);
@@ -717,29 +721,28 @@ int smpi_coll_tuned_scatter_mvapich2(void *sendbuf,
       recvtype_size=smpi_datatype_size(recvtype);
       nbytes = recvcnt * recvtype_size;
   }
-  /*
+  
     // check if safe to use partial subscription mode 
-    if (comm_ptr->ch.shmem_coll_ok == 1 && comm_ptr->ch.is_uniform) {
+    if (smpi_comm_is_uniform(comm)) {
 
-        shmem_comm = comm_ptr->ch.shmem_comm;
-        MPID_Comm_get_ptr(shmem_comm, shmem_commptr);
-        local_size = shmem_commptr->local_size;
+        shmem_comm = smpi_comm_get_intra_comm(comm);
+        local_size = smpi_comm_size(shmem_comm);
         i = 0;
         if (mv2_scatter_table_ppn_conf[0] == -1) {
             // Indicating user defined tuning 
             conf_index = 0;
-            goto conf_check_end;
+        }else{
+            do {
+                if (local_size == mv2_scatter_table_ppn_conf[i]) {
+                    conf_index = i;
+                    partial_sub_ok = 1;
+                    break;
+                }
+                i++;
+            } while(i < mv2_scatter_num_ppn_conf);
         }
-        do {
-            if (local_size == mv2_scatter_table_ppn_conf[i]) {
-                conf_index = i;
-                partial_sub_ok = 1;
-                break;
-            }
-            i++;
-        } while(i < mv2_scatter_num_ppn_conf);
     }
-   */
+   
   if (partial_sub_ok != 1) {
       conf_index = 0;
   }
@@ -772,9 +775,9 @@ int smpi_coll_tuned_scatter_mvapich2(void *sendbuf,
 
   if(MV2_Scatter_function == &MPIR_Scatter_mcst_wrap_MV2) {
 #if defined(_MCST_SUPPORT_)
-      if(comm_ptr->ch.is_mcast_ok == 1
+      if(comm->ch.is_mcast_ok == 1
           && mv2_use_mcast_scatter == 1
-          && comm_ptr->ch.shmem_coll_ok == 1) {
+          && comm->ch.shmem_coll_ok == 1) {
           MV2_Scatter_function = &MPIR_Scatter_mcst_MV2;
       } else
 #endif /*#if defined(_MCST_SUPPORT_) */
@@ -792,25 +795,24 @@ int smpi_coll_tuned_scatter_mvapich2(void *sendbuf,
 
   if( (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Direct) ||
       (MV2_Scatter_function == &MPIR_Scatter_MV2_two_level_Binomial)) {
-      /* if( comm_ptr->ch.shmem_coll_ok == 1 &&
-             comm_ptr->ch.is_global_block == 1 ) {
+       if( smpi_comm_is_blocked(comm)) {
              MV2_Scatter_intra_function = mv2_scatter_thresholds_table[conf_index][range].intra_node[range_threshold_intra]
                                 .MV2_pt_Scatter_function;
 
              mpi_errno =
                    MV2_Scatter_function(sendbuf, sendcnt, sendtype,
                                         recvbuf, recvcnt, recvtype, root,
-                                        comm_ptr);
-         } else {*/
+                                        comm);
+         } else {
       mpi_errno = MPIR_Scatter_MV2_Binomial(sendbuf, sendcnt, sendtype,
           recvbuf, recvcnt, recvtype, root,
-          comm_ptr);
+          comm);
 
-      //}
+      }
   } else {
       mpi_errno = MV2_Scatter_function(sendbuf, sendcnt, sendtype,
           recvbuf, recvcnt, recvtype, root,
-          comm_ptr);
+          comm);
   }
   return (mpi_errno);
 }