Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Cleanups in smpi::Win (use std::vector, and simplify constructor).
[simgrid.git] / src / smpi / mpi / smpi_win.cpp
index 73bd32d..b90cf57 100644 (file)
@@ -14,6 +14,8 @@
 #include "smpi_request.hpp"
 #include "src/smpi/include/smpi_actor.hpp"
 
+#include <algorithm>
+
 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_rma, smpi, "Logging specific to SMPI (RMA operations)");
 
 
@@ -28,6 +30,7 @@ Win::Win(void* base, MPI_Aint size, int disp_unit, MPI_Info info, MPI_Comm comm,
     , disp_unit_(disp_unit)
     , info_(info)
     , comm_(comm)
+    , connected_wins_(comm->size())
     , rank_(comm->rank())
     , allocated_(allocated)
     , dynamic_(dynamic)
@@ -35,27 +38,16 @@ Win::Win(void* base, MPI_Aint size, int disp_unit, MPI_Info info, MPI_Comm comm,
   XBT_DEBUG("Creating window");
   if(info!=MPI_INFO_NULL)
     info->ref();
-  int comm_size          = comm->size();
-  opened_                = 0;
-  group_                 = MPI_GROUP_NULL;
-  requests_              = new std::vector<MPI_Request>();
-  mut_                   = s4u::Mutex::create();
-  lock_mut_              = s4u::Mutex::create();
-  atomic_mut_            = s4u::Mutex::create();
-  connected_wins_        = new MPI_Win[comm_size];
   connected_wins_[rank_] = this;
-  count_                 = 0;
   if(rank_==0){
-    bar_ = new s4u::Barrier(comm_size);
+    bar_ = new s4u::Barrier(comm->size());
   }
-  mode_=0;
-  errhandler_=MPI_ERRORS_ARE_FATAL;
   errhandler_->ref();
   comm->add_rma_win(this);
   comm->ref();
 
-  colls::allgather(&(connected_wins_[rank_]), sizeof(MPI_Win), MPI_BYTE, connected_wins_, sizeof(MPI_Win), MPI_BYTE,
-                   comm);
+  colls::allgather(&connected_wins_[rank_], sizeof(MPI_Win), MPI_BYTE, connected_wins_.data(), sizeof(MPI_Win),
+                   MPI_BYTE, comm);
 
   colls::bcast(&(bar_), sizeof(s4u::Barrier*), MPI_BYTE, 0, comm);
 
@@ -69,8 +61,6 @@ Win::~Win(){
   int finished = finish_comms();
   XBT_DEBUG("Win destructor - Finished %d RMA calls", finished);
 
-  delete requests_;
-  delete[] connected_wins_;
   if (info_ != MPI_INFO_NULL)
     simgrid::smpi::Info::unref(info_);
   if (errhandler_ != MPI_ERRHANDLER_NULL)
@@ -182,11 +172,11 @@ int Win::fence(int assert)
     // Without this, the vector could get redimensioned when another process pushes.
     // This would result in the array used by Request::waitall() to be invalidated.
     // Another solution would be to copy the data and cleanup the vector *before* Request::waitall
-    std::vector<MPI_Request> *reqs = requests_;
-    int size = static_cast<int>(reqs->size());
+
     // start all requests that have been prepared by another process
-    if (size > 0) {
-      MPI_Request* treqs = &(*reqs)[0];
+    if (not requests_.empty()) {
+      int size           = static_cast<int>(requests_.size());
+      MPI_Request* treqs = requests_.data();
       Request::waitall(size, treqs, MPI_STATUSES_IGNORE);
     }
     count_=0;
@@ -207,7 +197,7 @@ int Win::put(const void *origin_addr, int origin_count, MPI_Datatype origin_data
               MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Request* request)
 {
   //get receiver pointer
-  const Win* recv_win = connected_wins_[target_rank];
+  Win* recv_win = connected_wins_[target_rank];
 
   if(opened_==0){//check that post/start has been done
     // no fence or start .. lock ok ?
@@ -244,13 +234,13 @@ int Win::put(const void *origin_addr, int origin_count, MPI_Datatype origin_data
       *request=sreq;
     }else{
       mut_->lock();
-      requests_->push_back(sreq);
+      requests_.push_back(sreq);
       mut_->unlock();
     }
 
     //push request to receiver's win
     recv_win->mut_->lock();
-    recv_win->requests_->push_back(rreq);
+    recv_win->requests_.push_back(rreq);
     rreq->start();
     recv_win->mut_->unlock();
   } else {
@@ -267,7 +257,7 @@ int Win::get( void *origin_addr, int origin_count, MPI_Datatype origin_datatype,
               MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Request* request)
 {
   //get sender pointer
-  const Win* send_win = connected_wins_[target_rank];
+  Win* send_win = connected_wins_[target_rank];
 
   if(opened_==0){//check that post/start has been done
     // no fence or start .. lock ok ?
@@ -300,7 +290,7 @@ int Win::get( void *origin_addr, int origin_count, MPI_Datatype origin_datatype,
     sreq->start();
     //push request to receiver's win
     send_win->mut_->lock();
-    send_win->requests_->push_back(sreq);
+    send_win->requests_.push_back(sreq);
     send_win->mut_->unlock();
 
     //start recv
@@ -310,7 +300,7 @@ int Win::get( void *origin_addr, int origin_count, MPI_Datatype origin_datatype,
       *request=rreq;
     }else{
       mut_->lock();
-      requests_->push_back(rreq);
+      requests_.push_back(rreq);
       mut_->unlock();
     }
   } else {
@@ -326,7 +316,7 @@ int Win::accumulate(const void *origin_addr, int origin_count, MPI_Datatype orig
 {
   XBT_DEBUG("Entering MPI_Win_Accumulate");
   //get receiver pointer
-  const Win* recv_win = connected_wins_[target_rank];
+  Win* recv_win = connected_wins_[target_rank];
 
   if(opened_==0){//check that post/start has been done
     // no fence or start .. lock ok ?
@@ -361,7 +351,7 @@ int Win::accumulate(const void *origin_addr, int origin_count, MPI_Datatype orig
   sreq->start();
   // push request to receiver's win
   recv_win->mut_->lock();
-  recv_win->requests_->push_back(rreq);
+  recv_win->requests_.push_back(rreq);
   rreq->start();
   recv_win->mut_->unlock();
 
@@ -369,7 +359,7 @@ int Win::accumulate(const void *origin_addr, int origin_count, MPI_Datatype orig
     *request = sreq;
   } else {
     mut_->lock();
-    requests_->push_back(sreq);
+    requests_.push_back(sreq);
     mut_->unlock();
   }
 
@@ -680,12 +670,11 @@ Win* Win::f2c(int id){
 int Win::finish_comms(){
   mut_->lock();
   //Finish own requests
-  std::vector<MPI_Request> *reqqs = requests_;
-  int size = static_cast<int>(reqqs->size());
+  int size = static_cast<int>(requests_.size());
   if (size > 0) {
-    MPI_Request* treqs = &(*reqqs)[0];
+    MPI_Request* treqs = requests_.data();
     Request::waitall(size, treqs, MPI_STATUSES_IGNORE);
-    reqqs->clear();
+    requests_.clear();
   }
   mut_->unlock();
   return size;
@@ -693,32 +682,22 @@ int Win::finish_comms(){
 
 int Win::finish_comms(int rank){
   mut_->lock();
-  //Finish own requests
-  std::vector<MPI_Request> *reqqs = requests_;
-  int size = static_cast<int>(reqqs->size());
+  // Finish own requests
+  // Let's see if we're either the destination or the sender of this request
+  // because we only wait for requests that we are responsible for.
+  // Also use the process id here since the request itself returns from src()
+  // and dst() the process id, NOT the rank (which only exists in the context of a communicator).
+  int proc_id = comm_->group()->actor(rank)->get_pid();
+  auto it     = std::stable_partition(begin(requests_), end(requests_), [proc_id](const MPI_Request& req) {
+    return (req == MPI_REQUEST_NULL || (req->src() != proc_id && req->dst() != proc_id));
+  });
+  std::vector<MPI_Request> myreqqs(it, end(requests_));
+  requests_.erase(it, end(requests_));
+  int size = static_cast<int>(myreqqs.size());
   if (size > 0) {
-    size = 0;
-    std::vector<MPI_Request> myreqqs;
-    auto iter                               = reqqs->begin();
-    int proc_id                             = comm_->group()->actor(rank)->get_pid();
-    while (iter != reqqs->end()){
-      // Let's see if we're either the destination or the sender of this request
-      // because we only wait for requests that we are responsible for.
-      // Also use the process id here since the request itself returns from src()
-      // and dst() the process id, NOT the rank (which only exists in the context of a communicator).
-      if (((*iter) != MPI_REQUEST_NULL) && (((*iter)->src() == proc_id) || ((*iter)->dst() == proc_id))) {
-        myreqqs.push_back(*iter);
-        iter = reqqs->erase(iter);
-        size++;
-      } else {
-        ++iter;
-      }
-    }
-    if(size >0){
-      MPI_Request* treqs = &myreqqs[0];
-      Request::waitall(size, treqs, MPI_STATUSES_IGNORE);
-      myreqqs.clear();
-    }
+    MPI_Request* treqs = myreqqs.data();
+    Request::waitall(size, treqs, MPI_STATUSES_IGNORE);
+    myreqqs.clear();
   }
   mut_->unlock();
   return size;