Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
More template based factorization
[simgrid.git] / src / smpi / smpi_win.cpp
index 0f8f5c2..67612c3 100644 (file)
@@ -11,13 +11,15 @@ XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_rma, smpi, "Logging specific to SMPI (RMA o
 
 namespace simgrid{
 namespace smpi{
+std::unordered_map<int, smpi_key_elem> Win::keyvals_;
+int Win::keyval_id_=0;
 
 Win::Win(void *base, MPI_Aint size, int disp_unit, MPI_Info info, MPI_Comm comm): base_(base), size_(size), disp_unit_(disp_unit), assert_(0), info_(info), comm_(comm){
   int comm_size = comm->size();
   int rank      = comm->rank();
   XBT_DEBUG("Creating window");
   if(info!=MPI_INFO_NULL)
-    info->refcount++;
+    info->ref();
   name_ = nullptr;
   opened_ = 0;
   group_ = MPI_GROUP_NULL;
@@ -29,12 +31,12 @@ Win::Win(void *base, MPI_Aint size, int disp_unit, MPI_Info info, MPI_Comm comm)
   if(rank==0){
     bar_ = MSG_barrier_init(comm_size);
   }
-  mpi_coll_allgather_fun(&(connected_wins_[rank]), sizeof(MPI_Win), MPI_BYTE, connected_wins_, sizeof(MPI_Win),
+  Colls::allgather(&(connected_wins_[rank]), sizeof(MPI_Win), MPI_BYTE, connected_wins_, sizeof(MPI_Win),
                          MPI_BYTE, comm);
 
-  mpi_coll_bcast_fun(&(bar_), sizeof(msg_bar_t), MPI_BYTE, 0, comm);
+  Colls::bcast(&(bar_), sizeof(msg_bar_t), MPI_BYTE, 0, comm);
 
-  mpi_coll_barrier_fun(comm);
+  Colls::barrier(comm);
 }
 
 Win::~Win(){
@@ -51,11 +53,13 @@ Win::~Win(){
     MPI_Info_free(&info_);
   }
 
-  mpi_coll_barrier_fun(comm_);
+  Colls::barrier(comm_);
   int rank=comm_->rank();
   if(rank == 0)
     MSG_barrier_destroy(bar_);
   xbt_mutex_destroy(mut_);
+
+  cleanup_attr<Win>();
 }
 
 void Win::get_name(char* name, int* length){
@@ -76,6 +80,19 @@ void Win::get_group(MPI_Group* group){
   }
 }
 
+MPI_Aint Win::size(){
+  return size_;
+}
+
+void* Win::base(){
+  return base_;
+}
+
+int Win::disp_unit(){
+  return disp_unit_;
+}
+
+
 void Win::set_name(char* name){
   name_ = xbt_strdup(name);
 }
@@ -125,6 +142,9 @@ int Win::put( void *origin_addr, int origin_count, MPI_Datatype origin_datatype,
   //get receiver pointer
   MPI_Win recv_win = connected_wins_[target_rank];
 
+  if(target_count*target_datatype->get_extent()>recv_win->size_)
+    return MPI_ERR_ARG;
+
   void* recv_addr = static_cast<void*> ( static_cast<char*>(recv_win->base_) + target_disp * recv_win->disp_unit_);
   XBT_DEBUG("Entering MPI_Put to %d", target_rank);
 
@@ -149,7 +169,7 @@ int Win::put( void *origin_addr, int origin_count, MPI_Datatype origin_datatype,
     requests_->push_back(sreq);
     xbt_mutex_release(mut_);
   }else{
-    smpi_datatype_copy(origin_addr, origin_count, origin_datatype, recv_addr, target_count, target_datatype);
+    Datatype::copy(origin_addr, origin_count, origin_datatype, recv_addr, target_count, target_datatype);
   }
 
   return MPI_SUCCESS;
@@ -163,6 +183,9 @@ int Win::get( void *origin_addr, int origin_count, MPI_Datatype origin_datatype,
   //get sender pointer
   MPI_Win send_win = connected_wins_[target_rank];
 
+  if(target_count*target_datatype->get_extent()>send_win->size_)
+    return MPI_ERR_ARG;
+
   void* send_addr = static_cast<void*>(static_cast<char*>(send_win->base_) + target_disp * send_win->disp_unit_);
   XBT_DEBUG("Entering MPI_Get from %d", target_rank);
 
@@ -191,7 +214,7 @@ int Win::get( void *origin_addr, int origin_count, MPI_Datatype origin_datatype,
     requests_->push_back(rreq);
     xbt_mutex_release(mut_);
   }else{
-    smpi_datatype_copy(send_addr, target_count, target_datatype, origin_addr, origin_count, origin_datatype);
+    Datatype::copy(send_addr, target_count, target_datatype, origin_addr, origin_count, origin_datatype);
   }
 
   return MPI_SUCCESS;
@@ -207,6 +230,9 @@ int Win::accumulate( void *origin_addr, int origin_count, MPI_Datatype origin_da
   //get receiver pointer
   MPI_Win recv_win = connected_wins_[target_rank];
 
+  if(target_count*target_datatype->get_extent()>recv_win->size_)
+    return MPI_ERR_ARG;
+
   void* recv_addr = static_cast<void*>(static_cast<char*>(recv_win->base_) + target_disp * recv_win->disp_unit_);
   XBT_DEBUG("Entering MPI_Accumulate to %d", target_rank);
     //As the tag will be used for ordering of the operations, add count to it
@@ -265,12 +291,12 @@ int Win::start(MPI_Group group, int assert){
   Request::startall(size, reqs);
   Request::waitall(size, reqs, MPI_STATUSES_IGNORE);
   for(i=0;i<size;i++){
-    Request::unuse(&reqs[i]);
+    Request::unref(&reqs[i]);
   }
   xbt_free(reqs);
   opened_++; //we're open for business !
   group_=group;
-  group->use();
+  group->ref();
   return MPI_SUCCESS;
 }
 
@@ -294,12 +320,12 @@ int Win::post(MPI_Group group, int assert){
   Request::startall(size, reqs);
   Request::waitall(size, reqs, MPI_STATUSES_IGNORE);
   for(i=0;i<size;i++){
-    Request::unuse(&reqs[i]);
+    Request::unref(&reqs[i]);
   }
   xbt_free(reqs);
   opened_++; //we're open for business !
   group_=group;
-  group->use();
+  group->ref();
   return MPI_SUCCESS;
 }
 
@@ -327,7 +353,7 @@ int Win::complete(){
   Request::waitall(size, reqs, MPI_STATUSES_IGNORE);
 
   for(i=0;i<size;i++){
-    Request::unuse(&reqs[i]);
+    Request::unref(&reqs[i]);
   }
   xbt_free(reqs);
 
@@ -350,7 +376,7 @@ int Win::complete(){
   }
   xbt_mutex_release(mut_);
 
-  group_->unuse();
+  Group::unref(group_);
   opened_--; //we're closed for business !
   return MPI_SUCCESS;
 }
@@ -375,7 +401,7 @@ int Win::wait(){
   Request::startall(size, reqs);
   Request::waitall(size, reqs, MPI_STATUSES_IGNORE);
   for(i=0;i<size;i++){
-    Request::unuse(&reqs[i]);
+    Request::unref(&reqs[i]);
   }
   xbt_free(reqs);
   xbt_mutex_acquire(mut_);
@@ -396,10 +422,14 @@ int Win::wait(){
   }
   xbt_mutex_release(mut_);
 
-  group_->unuse();
+  Group::unref(group_);
   opened_--; //we're opened for business !
   return MPI_SUCCESS;
 }
 
+Win* Win::f2c(int id){
+  return static_cast<Win*>(F2C::f2c(id));
+}
+
 }
 }