Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
hide this from users
[simgrid.git] / src / smpi / smpi_win.cpp
index 803776f..c12c458 100644 (file)
@@ -17,7 +17,7 @@ Win::Win(void *base, MPI_Aint size, int disp_unit, MPI_Info info, MPI_Comm comm)
   int rank      = comm->rank();
   XBT_DEBUG("Creating window");
   if(info!=MPI_INFO_NULL)
-    info->refcount++;
+    info->ref();
   name_ = nullptr;
   opened_ = 0;
   group_ = MPI_GROUP_NULL;
@@ -29,12 +29,12 @@ Win::Win(void *base, MPI_Aint size, int disp_unit, MPI_Info info, MPI_Comm comm)
   if(rank==0){
     bar_ = MSG_barrier_init(comm_size);
   }
-  mpi_coll_allgather_fun(&(connected_wins_[rank]), sizeof(MPI_Win), MPI_BYTE, connected_wins_, sizeof(MPI_Win),
+  Colls::allgather(&(connected_wins_[rank]), sizeof(MPI_Win), MPI_BYTE, connected_wins_, sizeof(MPI_Win),
                          MPI_BYTE, comm);
 
-  mpi_coll_bcast_fun(&(bar_), sizeof(msg_bar_t), MPI_BYTE, 0, comm);
+  Colls::bcast(&(bar_), sizeof(msg_bar_t), MPI_BYTE, 0, comm);
 
-  mpi_coll_barrier_fun(comm);
+  Colls::barrier(comm);
 }
 
 Win::~Win(){
@@ -51,7 +51,7 @@ Win::~Win(){
     MPI_Info_free(&info_);
   }
 
-  mpi_coll_barrier_fun(comm_);
+  Colls::barrier(comm_);
   int rank=comm_->rank();
   if(rank == 0)
     MSG_barrier_destroy(bar_);
@@ -91,20 +91,20 @@ int Win::fence(int assert)
     xbt_mutex_acquire(mut_);
     // This (simulated) mutex ensures that no process pushes to the vector of requests during the waitall.
     // Without this, the vector could get redimensionned when another process pushes.
-    // This would result in the array used by smpi_mpi_waitall() to be invalidated.
-    // Another solution would be to copy the data and cleanup the vector *before* smpi_mpi_waitall
+    // This would result in the array used by Request::waitall() to be invalidated.
+    // Another solution would be to copy the data and cleanup the vector *before* Request::waitall
     std::vector<MPI_Request> *reqs = requests_;
     int size = static_cast<int>(reqs->size());
     // start all requests that have been prepared by another process
     if (size > 0) {
       for (const auto& req : *reqs) {
-        if (req && (req->flags & PREPARED))
-          smpi_mpi_start(req);
+        if (req && (req->flags() & PREPARED))
+          req->start();
       }
 
       MPI_Request* treqs = &(*reqs)[0];
 
-      smpi_mpi_waitall(size, treqs, MPI_STATUSES_IGNORE);
+      Request::waitall(size, treqs, MPI_STATUSES_IGNORE);
     }
     count_=0;
     xbt_mutex_release(mut_);
@@ -125,16 +125,19 @@ int Win::put( void *origin_addr, int origin_count, MPI_Datatype origin_datatype,
   //get receiver pointer
   MPI_Win recv_win = connected_wins_[target_rank];
 
+  if(target_count*target_datatype->get_extent()>recv_win->size_)
+    return MPI_ERR_ARG;
+
   void* recv_addr = static_cast<void*> ( static_cast<char*>(recv_win->base_) + target_disp * recv_win->disp_unit_);
   XBT_DEBUG("Entering MPI_Put to %d", target_rank);
 
   if(target_rank != comm_->rank()){
     //prepare send_request
-    MPI_Request sreq = smpi_rma_send_init(origin_addr, origin_count, origin_datatype, smpi_process_index(),
+    MPI_Request sreq = Request::rma_send_init(origin_addr, origin_count, origin_datatype, smpi_process_index(),
         comm_->group()->index(target_rank), SMPI_RMA_TAG+1, comm_, MPI_OP_NULL);
 
     //prepare receiver request
-    MPI_Request rreq = smpi_rma_recv_init(recv_addr, target_count, target_datatype, smpi_process_index(),
+    MPI_Request rreq = Request::rma_recv_init(recv_addr, target_count, target_datatype, smpi_process_index(),
         comm_->group()->index(target_rank), SMPI_RMA_TAG+1, recv_win->comm_, MPI_OP_NULL);
 
     //push request to receiver's win
@@ -142,14 +145,14 @@ int Win::put( void *origin_addr, int origin_count, MPI_Datatype origin_datatype,
     recv_win->requests_->push_back(rreq);
     xbt_mutex_release(recv_win->mut_);
     //start send
-    smpi_mpi_start(sreq);
+    sreq->start();
 
     //push request to sender's win
     xbt_mutex_acquire(mut_);
     requests_->push_back(sreq);
     xbt_mutex_release(mut_);
   }else{
-    smpi_datatype_copy(origin_addr, origin_count, origin_datatype, recv_addr, target_count, target_datatype);
+    Datatype::copy(origin_addr, origin_count, origin_datatype, recv_addr, target_count, target_datatype);
   }
 
   return MPI_SUCCESS;
@@ -163,35 +166,38 @@ int Win::get( void *origin_addr, int origin_count, MPI_Datatype origin_datatype,
   //get sender pointer
   MPI_Win send_win = connected_wins_[target_rank];
 
+  if(target_count*target_datatype->get_extent()>send_win->size_)
+    return MPI_ERR_ARG;
+
   void* send_addr = static_cast<void*>(static_cast<char*>(send_win->base_) + target_disp * send_win->disp_unit_);
   XBT_DEBUG("Entering MPI_Get from %d", target_rank);
 
   if(target_rank != comm_->rank()){
     //prepare send_request
-    MPI_Request sreq = smpi_rma_send_init(send_addr, target_count, target_datatype,
+    MPI_Request sreq = Request::rma_send_init(send_addr, target_count, target_datatype,
         comm_->group()->index(target_rank), smpi_process_index(), SMPI_RMA_TAG+2, send_win->comm_,
         MPI_OP_NULL);
 
     //prepare receiver request
-    MPI_Request rreq = smpi_rma_recv_init(origin_addr, origin_count, origin_datatype,
+    MPI_Request rreq = Request::rma_recv_init(origin_addr, origin_count, origin_datatype,
         comm_->group()->index(target_rank), smpi_process_index(), SMPI_RMA_TAG+2, comm_,
         MPI_OP_NULL);
 
     //start the send, with another process than us as sender. 
-    smpi_mpi_start(sreq);
+    sreq->start();
     //push request to receiver's win
     xbt_mutex_acquire(send_win->mut_);
     send_win->requests_->push_back(sreq);
     xbt_mutex_release(send_win->mut_);
 
     //start recv
-    smpi_mpi_start(rreq);
+    rreq->start();
     //push request to sender's win
     xbt_mutex_acquire(mut_);
     requests_->push_back(rreq);
     xbt_mutex_release(mut_);
   }else{
-    smpi_datatype_copy(send_addr, target_count, target_datatype, origin_addr, origin_count, origin_datatype);
+    Datatype::copy(send_addr, target_count, target_datatype, origin_addr, origin_count, origin_datatype);
   }
 
   return MPI_SUCCESS;
@@ -207,15 +213,18 @@ int Win::accumulate( void *origin_addr, int origin_count, MPI_Datatype origin_da
   //get receiver pointer
   MPI_Win recv_win = connected_wins_[target_rank];
 
+  if(target_count*target_datatype->get_extent()>recv_win->size_)
+    return MPI_ERR_ARG;
+
   void* recv_addr = static_cast<void*>(static_cast<char*>(recv_win->base_) + target_disp * recv_win->disp_unit_);
   XBT_DEBUG("Entering MPI_Accumulate to %d", target_rank);
     //As the tag will be used for ordering of the operations, add count to it
     //prepare send_request
-    MPI_Request sreq = smpi_rma_send_init(origin_addr, origin_count, origin_datatype,
+    MPI_Request sreq = Request::rma_send_init(origin_addr, origin_count, origin_datatype,
         smpi_process_index(), comm_->group()->index(target_rank), SMPI_RMA_TAG+3+count_, comm_, op);
 
     //prepare receiver request
-    MPI_Request rreq = smpi_rma_recv_init(recv_addr, target_count, target_datatype,
+    MPI_Request rreq = Request::rma_recv_init(recv_addr, target_count, target_datatype,
         smpi_process_index(), comm_->group()->index(target_rank), SMPI_RMA_TAG+3+count_, recv_win->comm_, op);
 
     count_++;
@@ -224,7 +233,7 @@ int Win::accumulate( void *origin_addr, int origin_count, MPI_Datatype origin_da
     recv_win->requests_->push_back(rreq);
     xbt_mutex_release(recv_win->mut_);
     //start send
-    smpi_mpi_start(sreq);
+    sreq->start();
 
     //push request to sender's win
     xbt_mutex_acquire(mut_);
@@ -256,21 +265,21 @@ int Win::start(MPI_Group group, int assert){
     while (j != size) {
       int src = group->index(j);
       if (src != smpi_process_index() && src != MPI_UNDEFINED) {
-        reqs[i] = smpi_irecv_init(nullptr, 0, MPI_CHAR, src, SMPI_RMA_TAG + 4, MPI_COMM_WORLD);
+        reqs[i] = Request::irecv_init(nullptr, 0, MPI_CHAR, src, SMPI_RMA_TAG + 4, MPI_COMM_WORLD);
         i++;
       }
       j++;
   }
   size=i;
-  smpi_mpi_startall(size, reqs);
-  smpi_mpi_waitall(size, reqs, MPI_STATUSES_IGNORE);
+  Request::startall(size, reqs);
+  Request::waitall(size, reqs, MPI_STATUSES_IGNORE);
   for(i=0;i<size;i++){
-    smpi_mpi_request_free(&reqs[i]);
+    Request::unref(&reqs[i]);
   }
   xbt_free(reqs);
   opened_++; //we're open for business !
   group_=group;
-  group->use();
+  group->ref();
   return MPI_SUCCESS;
 }
 
@@ -284,22 +293,22 @@ int Win::post(MPI_Group group, int assert){
   while(j!=size){
     int dst=group->index(j);
     if(dst!=smpi_process_index() && dst!=MPI_UNDEFINED){
-      reqs[i]=smpi_mpi_send_init(nullptr, 0, MPI_CHAR, dst, SMPI_RMA_TAG+4, MPI_COMM_WORLD);
+      reqs[i]=Request::send_init(nullptr, 0, MPI_CHAR, dst, SMPI_RMA_TAG+4, MPI_COMM_WORLD);
       i++;
     }
     j++;
   }
   size=i;
 
-  smpi_mpi_startall(size, reqs);
-  smpi_mpi_waitall(size, reqs, MPI_STATUSES_IGNORE);
+  Request::startall(size, reqs);
+  Request::waitall(size, reqs, MPI_STATUSES_IGNORE);
   for(i=0;i<size;i++){
-    smpi_mpi_request_free(&reqs[i]);
+    Request::unref(&reqs[i]);
   }
   xbt_free(reqs);
   opened_++; //we're open for business !
   group_=group;
-  group->use();
+  group->ref();
   return MPI_SUCCESS;
 }
 
@@ -316,18 +325,18 @@ int Win::complete(){
   while(j!=size){
     int dst=group_->index(j);
     if(dst!=smpi_process_index() && dst!=MPI_UNDEFINED){
-      reqs[i]=smpi_mpi_send_init(nullptr, 0, MPI_CHAR, dst, SMPI_RMA_TAG+5, MPI_COMM_WORLD);
+      reqs[i]=Request::send_init(nullptr, 0, MPI_CHAR, dst, SMPI_RMA_TAG+5, MPI_COMM_WORLD);
       i++;
     }
     j++;
   }
   size=i;
   XBT_DEBUG("Win_complete - Sending sync messages to %d processes", size);
-  smpi_mpi_startall(size, reqs);
-  smpi_mpi_waitall(size, reqs, MPI_STATUSES_IGNORE);
+  Request::startall(size, reqs);
+  Request::waitall(size, reqs, MPI_STATUSES_IGNORE);
 
   for(i=0;i<size;i++){
-    smpi_mpi_request_free(&reqs[i]);
+    Request::unref(&reqs[i]);
   }
   xbt_free(reqs);
 
@@ -340,17 +349,17 @@ int Win::complete(){
   if (size > 0) {
     // start all requests that have been prepared by another process
     for (const auto& req : *reqqs) {
-      if (req && (req->flags & PREPARED))
-        smpi_mpi_start(req);
+      if (req && (req->flags() & PREPARED))
+        req->start();
     }
 
     MPI_Request* treqs = &(*reqqs)[0];
-    smpi_mpi_waitall(size, treqs, MPI_STATUSES_IGNORE);
+    Request::waitall(size, treqs, MPI_STATUSES_IGNORE);
     reqqs->clear();
   }
   xbt_mutex_release(mut_);
 
-  group_->unuse();
+  Group::unref(group_);
   opened_--; //we're closed for business !
   return MPI_SUCCESS;
 }
@@ -365,17 +374,17 @@ int Win::wait(){
   while(j!=size){
     int src=group_->index(j);
     if(src!=smpi_process_index() && src!=MPI_UNDEFINED){
-      reqs[i]=smpi_irecv_init(nullptr, 0, MPI_CHAR, src,SMPI_RMA_TAG+5, MPI_COMM_WORLD);
+      reqs[i]=Request::irecv_init(nullptr, 0, MPI_CHAR, src,SMPI_RMA_TAG+5, MPI_COMM_WORLD);
       i++;
     }
     j++;
   }
   size=i;
   XBT_DEBUG("Win_wait - Receiving sync messages from %d processes", size);
-  smpi_mpi_startall(size, reqs);
-  smpi_mpi_waitall(size, reqs, MPI_STATUSES_IGNORE);
+  Request::startall(size, reqs);
+  Request::waitall(size, reqs, MPI_STATUSES_IGNORE);
   for(i=0;i<size;i++){
-    smpi_mpi_request_free(&reqs[i]);
+    Request::unref(&reqs[i]);
   }
   xbt_free(reqs);
   xbt_mutex_acquire(mut_);
@@ -386,20 +395,24 @@ int Win::wait(){
   if (size > 0) {
     // start all requests that have been prepared by another process
     for (const auto& req : *reqqs) {
-      if (req && (req->flags & PREPARED))
-        smpi_mpi_start(req);
+      if (req && (req->flags() & PREPARED))
+        req->start();
     }
 
     MPI_Request* treqs = &(*reqqs)[0];
-    smpi_mpi_waitall(size, treqs, MPI_STATUSES_IGNORE);
+    Request::waitall(size, treqs, MPI_STATUSES_IGNORE);
     reqqs->clear();
   }
   xbt_mutex_release(mut_);
 
-  group_->unuse();
+  Group::unref(group_);
   opened_--; //we're opened for business !
   return MPI_SUCCESS;
 }
 
+Win* Win::f2c(int id){
+  return static_cast<Win*>(F2C::f2c(id));
+}
+
 }
 }