Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'master' of https://framagit.org/simgrid/simgrid
[simgrid.git] / src / smpi / colls / smpi_default_selector.cpp
index 4b00cd1..d3419b2 100644 (file)
@@ -1,6 +1,6 @@
-/* selector with default/naive Simgrid algorithms. These should not be trusted for performance evaluations */
+/* selector with default/naive SimGrid algorithms. These should not be trusted for performance evaluations */
 
-/* Copyright (c) 2009-2019. The SimGrid Team. All rights reserved.          */
+/* Copyright (c) 2009-2023. The SimGrid Team. All rights reserved.          */
 
 /* This program is free software; you can redistribute it and/or modify it
  * under the terms of the license (GNU LGPL) which comes with this package. */
@@ -8,8 +8,7 @@
 #include "colls_private.hpp"
 #include "src/smpi/include/smpi_actor.hpp"
 
-namespace simgrid{
-namespace smpi{
+namespace simgrid::smpi {
 
 int bcast__default(void *buf, int count, MPI_Datatype datatype, int root, MPI_Comm comm)
 {
@@ -26,7 +25,7 @@ int gather__default(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                      void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
 {
   MPI_Request request;
-  Colls::igather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, &request, 0);
+  colls::igather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, &request, 0);
   return Request::wait(&request, MPI_STATUS_IGNORE);
 }
 
@@ -39,15 +38,25 @@ int reduce_scatter__default(const void *sendbuf, void *recvbuf, const int *recvc
   int size = comm->size();
   int count = 0;
   int* displs = new int[size];
+  int regular=1;
   for (int i = 0; i < size; i++) {
+    if(recvcounts[i]!=recvcounts[0]){
+      regular=0;
+      break;
+    }
     displs[i] = count;
     count += recvcounts[i];
   }
+  if(not regular){
+    delete[] displs;
+    return reduce_scatter__mpich(sendbuf, recvbuf, recvcounts, datatype, op, comm);
+  }
+
   unsigned char* tmpbuf = smpi_get_tmp_sendbuffer(count * datatype->get_extent());
 
   int ret = reduce__default(sendbuf, tmpbuf, count, datatype, op, 0, comm);
   if(ret==MPI_SUCCESS)
-    ret = Colls::scatterv(tmpbuf, recvcounts, displs, datatype, recvbuf, recvcounts[rank], datatype, 0, comm);
+    ret = colls::scatterv(tmpbuf, recvcounts, displs, datatype, recvbuf, recvcounts[rank], datatype, 0, comm);
   delete[] displs;
   smpi_free_tmp_buffer(tmpbuf);
   return ret;
@@ -58,7 +67,7 @@ int allgather__default(const void *sendbuf, int sendcount, MPI_Datatype sendtype
                        void *recvbuf,int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
 {
   MPI_Request request;
-  Colls::iallgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, &request);
+  colls::iallgather(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, &request);
   return Request::wait(&request, MPI_STATUS_IGNORE);
 }
 
@@ -66,14 +75,11 @@ int allgatherv__default(const void *sendbuf, int sendcount, MPI_Datatype sendtyp
                         const int *recvcounts, const int *displs, MPI_Datatype recvtype, MPI_Comm comm)
 {
   MPI_Request request;
-  Colls::iallgatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, &request, 0);
-  MPI_Request* requests = request->get_nbc_requests();
-  int count = request->get_nbc_requests_size();
-  Request::waitall(count, requests, MPI_STATUS_IGNORE);
-  for (int other = 0; other < count; other++) {
-    Request::unref(&requests[other]);
-  }
-  delete[] requests;
+  colls::iallgatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm, &request, 0);
+  auto requests = request->get_nbc_requests();
+  Request::waitall(requests.size(), requests.data(), MPI_STATUS_IGNORE);
+  for(auto& req: requests)
+    Request::unref(&req);
   Request::unref(&request);
   return MPI_SUCCESS;
 }
@@ -82,7 +88,7 @@ int scatter__default(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
                      void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)
 {
   MPI_Request request;
-  Colls::iscatter(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, &request, 0);
+  colls::iscatter(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, &request, 0);
   return Request::wait(&request, MPI_STATUS_IGNORE);
 }
 
@@ -94,7 +100,7 @@ int reduce__default(const void *sendbuf, void *recvbuf, int count, MPI_Datatype
     return reduce__ompi_basic_linear(sendbuf, recvbuf, count, datatype, op, root, comm);
   }
   MPI_Request request;
-  Colls::ireduce(sendbuf, recvbuf, count, datatype, op, root, comm, &request, 0);
+  colls::ireduce(sendbuf, recvbuf, count, datatype, op, root, comm, &request, 0);
   return Request::wait(&request, MPI_STATUS_IGNORE);
 }
 
@@ -119,10 +125,9 @@ int alltoallv__default(const void *sendbuf, const int *sendcounts, const int *se
                        void *recvbuf, const int *recvcounts, const int *recvdisps, MPI_Datatype recvtype, MPI_Comm comm)
 {
   MPI_Request request;
-  Colls::ialltoallv(sendbuf, sendcounts, senddisps, sendtype, recvbuf, recvcounts, recvdisps, recvtype, comm, &request, 0);
+  colls::ialltoallv(sendbuf, sendcounts, senddisps, sendtype, recvbuf, recvcounts, recvdisps, recvtype, comm, &request,
+                    0);
   return Request::wait(&request, MPI_STATUS_IGNORE);
 }
 
-}
-}
-
+} // namespace simgrid::smpi