AND (Algorithmique Numérique Distribuée)

Public Git Repository
Remove another bunch of const_casts.
[simgrid.git] / src / smpi / colls / smpi_mpich_selector.cpp
index 89440a1..11de914 100644 (file)
@@ -1,6 +1,6 @@
 /* selector for collective algorithms based on mpich decision logic */
 
-/* Copyright (c) 2009-2018. The SimGrid Team.
+/* Copyright (c) 2009-2019. The SimGrid Team.
  * All rights reserved.                                                     */
 
 /* This program is free software; you can redistribute it and/or modify it
@@ -58,7 +58,7 @@
 */
 namespace simgrid{
 namespace smpi{
-int Coll_allreduce_mpich::allreduce(void *sbuf, void *rbuf, int count,
+int Coll_allreduce_mpich::allreduce(const void *sbuf, void *rbuf, int count,
                         MPI_Datatype dtype, MPI_Op op, MPI_Comm comm)
 {
     size_t dsize, block_dsize;
@@ -68,6 +68,14 @@ int Coll_allreduce_mpich::allreduce(void *sbuf, void *rbuf, int count,
     dsize = dtype->size();
     block_dsize = dsize * count;
 
+    /*MPICH uses SMP algorithms for all commutative ops now*/
+    if(!comm->is_smp_comm()){
+      if(comm->get_leaders_comm()==MPI_COMM_NULL){
+        comm->init_smp();
+      }
+      if(op->is_commutative())
+        return Coll_allreduce_mvapich2_two_level::allreduce (sbuf, rbuf,count, dtype, op, comm);
+    }
 
     /* find nearest power-of-two less than or equal to comm_size */
     int pof2 = 1;
@@ -76,14 +84,10 @@ int Coll_allreduce_mpich::allreduce(void *sbuf, void *rbuf, int count,
 
     if (block_dsize > large_message && count >= pof2 && (op==MPI_OP_NULL || op->is_commutative())) {
       //for long messages
-       return (Coll_allreduce_rab_rdb::allreduce (sbuf, rbuf,
-                                                                   count, dtype,
-                                                                   op, comm));
+       return Coll_allreduce_rab_rdb::allreduce (sbuf, rbuf, count, dtype, op, comm);
     }else {
       //for short ones and count < pof2
-      return (Coll_allreduce_rdb::allreduce (sbuf, rbuf,
-                                                                   count, dtype,
-                                                                   op, comm));
+      return Coll_allreduce_rdb::allreduce (sbuf, rbuf, count, dtype, op, comm);
     }
 }
 
@@ -134,7 +138,7 @@ int Coll_allreduce_mpich::allreduce(void *sbuf, void *rbuf, int count,
    End Algorithm: MPI_Alltoall
 */
 
-int Coll_alltoall_mpich::alltoall( void *sbuf, int scount,
+int Coll_alltoall_mpich::alltoall(const void *sbuf, int scount,
                                              MPI_Datatype sdtype,
                                              void* rbuf, int rcount,
                                              MPI_Datatype rdtype,
@@ -169,7 +173,7 @@ int Coll_alltoall_mpich::alltoall( void *sbuf, int scount,
                                                     comm);
 
     } else if (block_dsize < medium_size) {
-        return Coll_alltoall_basic_linear::alltoall(sbuf, scount, sdtype,
+        return Coll_alltoall_mvapich2_scatter_dest::alltoall(sbuf, scount, sdtype,
                                                            rbuf, rcount, rdtype,
                                                            comm);
     }else if (communicator_size%2){
@@ -183,9 +187,9 @@ int Coll_alltoall_mpich::alltoall( void *sbuf, int scount,
                                                     comm);
 }
 
-int Coll_alltoallv_mpich::alltoallv(void *sbuf, int *scounts, int *sdisps,
+int Coll_alltoallv_mpich::alltoallv(const void *sbuf, const int *scounts, const int *sdisps,
                                               MPI_Datatype sdtype,
-                                              void *rbuf, int *rcounts, int *rdisps,
+                                              void *rbuf, const int *rcounts, const int *rdisps,
                                               MPI_Datatype rdtype,
                                               MPI_Comm  comm
                                               )
@@ -259,6 +263,14 @@ int Coll_bcast_mpich::bcast(void *buff, int count,
     //int segsize = 0;
     size_t message_size, dsize;
 
+    if(!comm->is_smp_comm()){
+      if(comm->get_leaders_comm()==MPI_COMM_NULL){
+        comm->init_smp();
+      }
+      if(comm->is_uniform())
+        return Coll_bcast_SMP_binomial::bcast(buff, count, datatype, root, comm);
+    }
+
     communicator_size = comm->size();
 
     /* else we need data size for decision function */
@@ -341,15 +353,23 @@ int Coll_bcast_mpich::bcast(void *buff, int count,
 */
 
 
-int Coll_reduce_mpich::reduce( void *sendbuf, void *recvbuf,
+int Coll_reduce_mpich::reduce(const void *sendbuf, void *recvbuf,
                                             int count, MPI_Datatype  datatype,
                                             MPI_Op   op, int root,
                                             MPI_Comm   comm
                                             )
 {
     int communicator_size=0;
-    //int segsize = 0;
     size_t message_size, dsize;
+
+    if(!comm->is_smp_comm()){
+      if(comm->get_leaders_comm()==MPI_COMM_NULL){
+        comm->init_smp();
+      }
+      if (op->is_commutative() == 1)
+        return Coll_reduce_mvapich2_two_level::reduce(sendbuf, recvbuf, count, datatype, op, root, comm);
+    }
+
     communicator_size = comm->size();
 
     /* need data size for decision function */
@@ -363,8 +383,7 @@ int Coll_reduce_mpich::reduce( void *sendbuf, void *recvbuf,
     if ((count < pof2) || (message_size < 2048) || (op != MPI_OP_NULL && not op->is_commutative())) {
       return Coll_reduce_binomial::reduce(sendbuf, recvbuf, count, datatype, op, root, comm);
     }
-        return Coll_reduce_scatter_gather::reduce(sendbuf, recvbuf, count, datatype, op, root, comm/*, module,
-                                                     segsize, max_requests*/);
+        return Coll_reduce_scatter_gather::reduce(sendbuf, recvbuf, count, datatype, op, root, comm);
 }
 
 
@@ -417,8 +436,8 @@ int Coll_reduce_mpich::reduce( void *sendbuf, void *recvbuf,
 */
 
 
-int Coll_reduce_scatter_mpich::reduce_scatter( void *sbuf, void *rbuf,
-                                                    int *rcounts,
+int Coll_reduce_scatter_mpich::reduce_scatter(const void *sbuf, void *rbuf,
+                                                    const int *rcounts,
                                                     MPI_Datatype dtype,
                                                     MPI_Op  op,
                                                     MPI_Comm  comm
@@ -513,7 +532,7 @@ int Coll_reduce_scatter_mpich::reduce_scatter( void *sbuf, void *rbuf,
    End Algorithm: MPI_Allgather
 */
 
-int Coll_allgather_mpich::allgather(void *sbuf, int scount,
+int Coll_allgather_mpich::allgather(const void *sbuf, int scount,
                                               MPI_Datatype sdtype,
                                               void* rbuf, int rcount,
                                               MPI_Datatype rdtype,
@@ -591,10 +610,10 @@ int Coll_allgather_mpich::allgather(void *sbuf, int scount,
 
    End Algorithm: MPI_Allgatherv
 */
-int Coll_allgatherv_mpich::allgatherv(void *sbuf, int scount,
+int Coll_allgatherv_mpich::allgatherv(const void *sbuf, int scount,
                                                MPI_Datatype sdtype,
-                                               void* rbuf, int *rcounts,
-                                               int *rdispls,
+                                               void* rbuf, const int *rcounts,
+                                               const int *rdispls,
                                                MPI_Datatype rdtype,
                                                MPI_Comm  comm
                                                )
@@ -649,7 +668,7 @@ int Coll_allgatherv_mpich::allgatherv(void *sbuf, int scount,
    End Algorithm: MPI_Gather
 */
 
-int Coll_gather_mpich::gather(void *sbuf, int scount,
+int Coll_gather_mpich::gather(const void *sbuf, int scount,
                                            MPI_Datatype sdtype,
                                            void* rbuf, int rcount,
                                            MPI_Datatype rdtype,
@@ -684,25 +703,21 @@ int Coll_gather_mpich::gather(void *sbuf, int scount,
 */
 
 
-int Coll_scatter_mpich::scatter(void *sbuf, int scount,
+int Coll_scatter_mpich::scatter(const void *sbuf, int scount,
                                             MPI_Datatype sdtype,
                                             void* rbuf, int rcount,
                                             MPI_Datatype rdtype,
                                             int root, MPI_Comm  comm
                                             )
 {
+  std::unique_ptr<unsigned char[]> tmp_buf;
   if(comm->rank()!=root){
-      sbuf=xbt_malloc(rcount*rdtype->get_extent());
-      scount=rcount;
-      sdtype=rdtype;
-  }
-  int ret= Coll_scatter_ompi_binomial::scatter (sbuf, scount, sdtype,
-                                                       rbuf, rcount, rdtype,
-                                                       root, comm);
-  if(comm->rank()!=root){
-      xbt_free(sbuf);
+    tmp_buf.reset(new unsigned char[rcount * rdtype->get_extent()]);
+    sbuf   = tmp_buf.get();
+    scount = rcount;
+    sdtype = rdtype;
   }
-  return ret;
+  return Coll_scatter_ompi_binomial::scatter(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm);
 }
 }
 }