Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add MPI_Rput, Rget, Raccumulate and Rget_accumulate calls.
[simgrid.git] / src / smpi / smpi_pmpi.cpp
index d09c579..d68f54c 100644 (file)
@@ -2639,6 +2639,43 @@ int PMPI_Get( void *origin_addr, int origin_count, MPI_Datatype origin_datatype,
   return retval;
 }
 
+int PMPI_Rget( void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank,
+              MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Win win, MPI_Request* request){
+  int retval = 0;
+  smpi_bench_end();
+  if (win == MPI_WIN_NULL) {
+    retval = MPI_ERR_WIN;
+  } else if (target_rank == MPI_PROC_NULL) {
+    *request = MPI_REQUEST_NULL;
+    retval = MPI_SUCCESS;
+  } else if (target_rank <0){
+    retval = MPI_ERR_RANK;
+  } else if (win->dynamic()==0 && target_disp <0){ 
+    //in case of dynamic window, target_disp can be mistakenly seen as negative, as it is an address
+    retval = MPI_ERR_ARG;
+  } else if ((origin_count < 0 || target_count < 0) ||
+             (origin_addr==nullptr && origin_count > 0)){
+    retval = MPI_ERR_COUNT;
+  } else if ((!origin_datatype->is_valid()) || (!target_datatype->is_valid())) {
+    retval = MPI_ERR_TYPE;
+  } else if(request == nullptr){
+    retval = MPI_ERR_REQUEST;
+  } else {
+    int rank = smpi_process()->index();
+    MPI_Group group;
+    win->get_group(&group);
+    int src_traced = group->index(target_rank);
+    TRACE_smpi_ptp_in(rank, src_traced, rank, __FUNCTION__, nullptr);
+
+    retval = win->get( origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count,
+                           target_datatype, request);
+
+    TRACE_smpi_ptp_out(rank, src_traced, rank, __FUNCTION__);
+  }
+  smpi_bench_begin();
+  return retval;
+}
+
 int PMPI_Put( void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank,
               MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Win win){
   int retval = 0;
@@ -2674,6 +2711,44 @@ int PMPI_Put( void *origin_addr, int origin_count, MPI_Datatype origin_datatype,
   return retval;
 }
 
+int PMPI_Rput( void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank,
+              MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Win win, MPI_Request* request){
+  int retval = 0;
+  smpi_bench_end();
+  if (win == MPI_WIN_NULL) {
+    retval = MPI_ERR_WIN;
+  } else if (target_rank == MPI_PROC_NULL) {
+    *request = MPI_REQUEST_NULL;
+    retval = MPI_SUCCESS;
+  } else if (target_rank <0){
+    retval = MPI_ERR_RANK;
+  } else if (win->dynamic()==0 && target_disp <0){ 
+    //in case of dynamic window, target_disp can be mistakenly seen as negative, as it is an address
+    retval = MPI_ERR_ARG;
+  } else if ((origin_count < 0 || target_count < 0) ||
+            (origin_addr==nullptr && origin_count > 0)){
+    retval = MPI_ERR_COUNT;
+  } else if ((!origin_datatype->is_valid()) || (!target_datatype->is_valid())) {
+    retval = MPI_ERR_TYPE;
+  } else if(request == nullptr){
+    retval = MPI_ERR_REQUEST;
+  } else {
+    int rank = smpi_process()->index();
+    MPI_Group group;
+    win->get_group(&group);
+    int dst_traced = group->index(target_rank);
+    TRACE_smpi_ptp_in(rank, rank, dst_traced, __FUNCTION__, nullptr);
+    TRACE_smpi_send(rank, rank, dst_traced, SMPI_RMA_TAG, origin_count*origin_datatype->size());
+
+    retval = win->put( origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count,
+                           target_datatype, request);
+
+    TRACE_smpi_ptp_out(rank, rank, dst_traced, __FUNCTION__);
+  }
+  smpi_bench_begin();
+  return retval;
+}
+
 int PMPI_Accumulate( void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank,
               MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Op op, MPI_Win win){
   int retval = 0;
@@ -2711,6 +2786,46 @@ int PMPI_Accumulate( void *origin_addr, int origin_count, MPI_Datatype origin_da
   return retval;
 }
 
+int PMPI_Raccumulate( void *origin_addr, int origin_count, MPI_Datatype origin_datatype, int target_rank,
+              MPI_Aint target_disp, int target_count, MPI_Datatype target_datatype, MPI_Op op, MPI_Win win, MPI_Request* request){
+  int retval = 0;
+  smpi_bench_end();
+  if (win == MPI_WIN_NULL) {
+    retval = MPI_ERR_WIN;
+  } else if (target_rank == MPI_PROC_NULL) {
+    *request = MPI_REQUEST_NULL;
+    retval = MPI_SUCCESS;
+  } else if (target_rank <0){
+    retval = MPI_ERR_RANK;
+  } else if (win->dynamic()==0 && target_disp <0){ 
+    //in case of dynamic window, target_disp can be mistakenly seen as negative, as it is an address
+    retval = MPI_ERR_ARG;
+  } else if ((origin_count < 0 || target_count < 0) ||
+             (origin_addr==nullptr && origin_count > 0)){
+    retval = MPI_ERR_COUNT;
+  } else if ((!origin_datatype->is_valid()) ||
+            (!target_datatype->is_valid())) {
+    retval = MPI_ERR_TYPE;
+  } else if (op == MPI_OP_NULL) {
+    retval = MPI_ERR_OP;
+  } else if(request == nullptr){
+    retval = MPI_ERR_REQUEST;
+  } else {
+    int rank = smpi_process()->index();
+    MPI_Group group;
+    win->get_group(&group);
+    int src_traced = group->index(target_rank);
+    TRACE_smpi_ptp_in(rank, src_traced, rank, __FUNCTION__, nullptr);
+
+    retval = win->accumulate( origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count,
+                                  target_datatype, op, request);
+
+    TRACE_smpi_ptp_out(rank, src_traced, rank, __FUNCTION__);
+  }
+  smpi_bench_begin();
+  return retval;
+}
+
 int PMPI_Get_accumulate(void *origin_addr, int origin_count, MPI_Datatype origin_datatype, void *result_addr, 
 int result_count, MPI_Datatype result_datatype, int target_rank, MPI_Aint target_disp, int target_count, 
 MPI_Datatype target_datatype, MPI_Op op, MPI_Win win){
@@ -2752,10 +2867,89 @@ MPI_Datatype target_datatype, MPI_Op op, MPI_Win win){
   return retval;
 }
 
+
+int PMPI_Rget_accumulate(void *origin_addr, int origin_count, MPI_Datatype origin_datatype, void *result_addr, 
+int result_count, MPI_Datatype result_datatype, int target_rank, MPI_Aint target_disp, int target_count, 
+MPI_Datatype target_datatype, MPI_Op op, MPI_Win win, MPI_Request* request){
+  int retval = 0;
+  smpi_bench_end();
+  if (win == MPI_WIN_NULL) {
+    retval = MPI_ERR_WIN;
+  } else if (target_rank == MPI_PROC_NULL) {
+    *request = MPI_REQUEST_NULL;
+    retval = MPI_SUCCESS;
+  } else if (target_rank <0){
+    retval = MPI_ERR_RANK;
+  } else if (win->dynamic()==0 && target_disp <0){ 
+    //in case of dynamic window, target_disp can be mistakenly seen as negative, as it is an address
+    retval = MPI_ERR_ARG;
+  } else if ((origin_count < 0 || target_count < 0 || result_count <0) ||
+             (origin_addr==nullptr && origin_count > 0) ||
+             (result_addr==nullptr && result_count > 0)){
+    retval = MPI_ERR_COUNT;
+  } else if ((!origin_datatype->is_valid()) ||
+            (!target_datatype->is_valid())||
+            (!result_datatype->is_valid())) {
+    retval = MPI_ERR_TYPE;
+  } else if (op == MPI_OP_NULL) {
+    retval = MPI_ERR_OP;
+  } else if(request == nullptr){
+    retval = MPI_ERR_REQUEST;
+  } else {
+    int rank = smpi_process()->index();
+    MPI_Group group;
+    win->get_group(&group);
+    int src_traced = group->index(target_rank);
+    TRACE_smpi_ptp_in(rank, src_traced, rank, __FUNCTION__, nullptr);
+
+    retval = win->get_accumulate( origin_addr, origin_count, origin_datatype, result_addr, 
+                                  result_count, result_datatype, target_rank, target_disp, 
+                                  target_count, target_datatype, op, request);
+
+    TRACE_smpi_ptp_out(rank, src_traced, rank, __FUNCTION__);
+  }
+  smpi_bench_begin();
+  return retval;
+}
+
 int PMPI_Fetch_and_op(void *origin_addr, void *result_addr, MPI_Datatype dtype, int target_rank, MPI_Aint target_disp, MPI_Op op, MPI_Win win){
   return PMPI_Get_accumulate(origin_addr, origin_addr==nullptr?0:1, dtype, result_addr, 1, dtype, target_rank, target_disp, 1, dtype, op, win);
 }
 
+int PMPI_Compare_and_swap(void *origin_addr, void *compare_addr,
+        void *result_addr, MPI_Datatype datatype, int target_rank,
+        MPI_Aint target_disp, MPI_Win win){
+  int retval = 0;
+  smpi_bench_end();
+  if (win == MPI_WIN_NULL) {
+    retval = MPI_ERR_WIN;
+  } else if (target_rank == MPI_PROC_NULL) {
+    retval = MPI_SUCCESS;
+  } else if (target_rank <0){
+    retval = MPI_ERR_RANK;
+  } else if (win->dynamic()==0 && target_disp <0){ 
+    //in case of dynamic window, target_disp can be mistakenly seen as negative, as it is an address
+    retval = MPI_ERR_ARG;
+  } else if (origin_addr==nullptr || result_addr==nullptr || compare_addr==nullptr){
+    retval = MPI_ERR_COUNT;
+  } else if (!datatype->is_valid()) {
+    retval = MPI_ERR_TYPE;
+  } else {
+    int rank = smpi_process()->index();
+    MPI_Group group;
+    win->get_group(&group);
+    int src_traced = group->index(target_rank);
+    TRACE_smpi_ptp_in(rank, src_traced, rank, __FUNCTION__, nullptr);
+
+    retval = win->compare_and_swap( origin_addr, compare_addr, result_addr, datatype, 
+                                  target_rank, target_disp);
+
+    TRACE_smpi_ptp_out(rank, src_traced, rank, __FUNCTION__);
+  }
+  smpi_bench_begin();
+  return retval;
+}
+
 int PMPI_Win_post(MPI_Group group, int assert, MPI_Win win){
   int retval = 0;
   smpi_bench_end();
@@ -2861,6 +3055,100 @@ int PMPI_Win_unlock(int rank, MPI_Win win){
   return retval;
 }
 
+int PMPI_Win_lock_all(int assert, MPI_Win win){
+  int retval = 0;
+  smpi_bench_end();
+  if (win == MPI_WIN_NULL) {
+    retval = MPI_ERR_WIN;
+  } else {
+    int myrank = smpi_process()->index();
+    TRACE_smpi_collective_in(myrank, -1, __FUNCTION__, nullptr);
+    retval = win->lock_all(assert);
+    TRACE_smpi_collective_out(myrank, -1, __FUNCTION__);
+  }
+  smpi_bench_begin();
+  return retval;
+}
+
+int PMPI_Win_unlock_all(MPI_Win win){
+  int retval = 0;
+  smpi_bench_end();
+  if (win == MPI_WIN_NULL) {
+    retval = MPI_ERR_WIN;
+  } else {
+    int myrank = smpi_process()->index();
+    TRACE_smpi_collective_in(myrank, -1, __FUNCTION__, nullptr);
+    retval = win->unlock_all();
+    TRACE_smpi_collective_out(myrank, -1, __FUNCTION__);
+  }
+  smpi_bench_begin();
+  return retval;
+}
+
+int PMPI_Win_flush(int rank, MPI_Win win){
+  int retval = 0;
+  smpi_bench_end();
+  if (win == MPI_WIN_NULL) {
+    retval = MPI_ERR_WIN;
+  } else if (rank == MPI_PROC_NULL){ 
+    retval = MPI_SUCCESS;
+  } else {
+    int myrank = smpi_process()->index();
+    TRACE_smpi_collective_in(myrank, -1, __FUNCTION__, nullptr);
+    retval = win->flush(rank);
+    TRACE_smpi_collective_out(myrank, -1, __FUNCTION__);
+  }
+  smpi_bench_begin();
+  return retval;
+}
+
+int PMPI_Win_flush_local(int rank, MPI_Win win){
+  int retval = 0;
+  smpi_bench_end();
+  if (win == MPI_WIN_NULL) {
+    retval = MPI_ERR_WIN;
+  } else if (rank == MPI_PROC_NULL){ 
+    retval = MPI_SUCCESS;
+  } else {
+    int myrank = smpi_process()->index();
+    TRACE_smpi_collective_in(myrank, -1, __FUNCTION__, nullptr);
+    retval = win->flush_local(rank);
+    TRACE_smpi_collective_out(myrank, -1, __FUNCTION__);
+  }
+  smpi_bench_begin();
+  return retval;
+}
+
+int PMPI_Win_flush_all(MPI_Win win){
+  int retval = 0;
+  smpi_bench_end();
+  if (win == MPI_WIN_NULL) {
+    retval = MPI_ERR_WIN;
+  } else {
+    int myrank = smpi_process()->index();
+    TRACE_smpi_collective_in(myrank, -1, __FUNCTION__, nullptr);
+    retval = win->flush_all();
+    TRACE_smpi_collective_out(myrank, -1, __FUNCTION__);
+  }
+  smpi_bench_begin();
+  return retval;
+}
+
+int PMPI_Win_flush_local_all(MPI_Win win){
+  int retval = 0;
+  smpi_bench_end();
+  if (win == MPI_WIN_NULL) {
+    retval = MPI_ERR_WIN;
+  } else {
+    int myrank = smpi_process()->index();
+    TRACE_smpi_collective_in(myrank, -1, __FUNCTION__, nullptr);
+    retval = win->flush_local_all();
+    TRACE_smpi_collective_out(myrank, -1, __FUNCTION__);
+  }
+  smpi_bench_begin();
+  return retval;
+}
+
 int PMPI_Alloc_mem(MPI_Aint size, MPI_Info info, void *baseptr){
   void *ptr = xbt_malloc(size);
   if(ptr==nullptr)