From fe635536c45ee6f1475520ef1eeb3823e6859909 Mon Sep 17 00:00:00 2001
From: degomme
Date: Wed, 5 Apr 2017 17:01:59 +0200
Subject: [PATCH 1/1] Add MPI_Compare_and_swap (normally atomic) call

Declare MPI_Compare_and_swap in smpi.h, wrap it in smpi_mpi.cpp and add
PMPI_Compare_and_swap, which validates its arguments, traces the call and
forwards to the new Win::compare_and_swap.

Win::compare_and_swap fetches the target value into result_addr with
get(), waits for that request, and put()s origin_addr to the target only
when the fetched bytes equal compare_addr. The sequence runs while holding
a new per-window mutex (atomic_mut_) of the target window.

Win::get_accumulate now takes the same mutex and waits for its get and
accumulate requests before releasing it, so both calls are serialized on
a given target window.

---
Review notes (ignored by git am):
 - Line breaks and indentation were restored from a whitespace-collapsed
   copy of this patch; use `git apply --ignore-whitespace` if the context
   does not match exactly.
 - The template arguments in the context lines `std::vector<MPI_Request>`
   and `std::list<int>` were restored from usage - TODO confirm against
   the tree.
 - `req` is now initialized to MPI_REQUEST_NULL and the stray `;` after
   the new WRAPPED_PMPI_CALL line was dropped; hunk sizes are unchanged.

 include/smpi/smpi.h    |  3 +++
 src/smpi/smpi_mpi.cpp  |  2 ++
 src/smpi/smpi_pmpi.cpp | 34 ++++++++++++++++++++++++++++++
 src/smpi/smpi_win.cpp  | 47 ++++++++++++++++++++++++++++++++++++++----
 src/smpi/smpi_win.hpp  |  4 ++++
 5 files changed, 86 insertions(+), 4 deletions(-)

diff --git a/include/smpi/smpi.h b/include/smpi/smpi.h
index ad6c4d5f36..500b64bb57 100644
--- a/include/smpi/smpi.h
+++ b/include/smpi/smpi.h
@@ -568,6 +568,9 @@ MPI_CALL(XBT_PUBLIC(int), MPI_Get_accumulate,( void *origin_addr, int origin_cou
     int target_count, MPI_Datatype target_datatype, MPI_Op op, MPI_Win win));
 MPI_CALL(XBT_PUBLIC(int), MPI_Fetch_and_op,( void *origin_addr, void* result_addr, MPI_Datatype datatype,
     int target_rank, MPI_Aint target_disp, MPI_Op op, MPI_Win win));
+MPI_CALL(XBT_PUBLIC(int), MPI_Compare_and_swap, (void *origin_addr, void *compare_addr,
+    void *result_addr, MPI_Datatype datatype, int target_rank, MPI_Aint target_disp, MPI_Win win));
+
 MPI_CALL(XBT_PUBLIC(int), MPI_Alloc_mem, (MPI_Aint size, MPI_Info info, void *baseptr));
 MPI_CALL(XBT_PUBLIC(int), MPI_Free_mem, (void *base));
 
diff --git a/src/smpi/smpi_mpi.cpp b/src/smpi/smpi_mpi.cpp
index be3cecee17..53f8c64f33 100644
--- a/src/smpi/smpi_mpi.cpp
+++ b/src/smpi/smpi_mpi.cpp
@@ -71,6 +71,8 @@ WRAPPED_PMPI_CALL(int,MPI_Comm_set_attr ,(MPI_Comm comm, int comm_keyval, void *
 WRAPPED_PMPI_CALL(int,MPI_Comm_size,(MPI_Comm comm, int *size),(comm, size))
 WRAPPED_PMPI_CALL(int,MPI_Comm_split,(MPI_Comm comm, int color, int key, MPI_Comm* comm_out),(comm, color, key, comm_out))
 WRAPPED_PMPI_CALL(int,MPI_Comm_create_group,(MPI_Comm comm, MPI_Group group, int tag, MPI_Comm* comm_out),(comm, group, tag, comm_out))
+WRAPPED_PMPI_CALL(int,MPI_Compare_and_swap,(void *origin_addr, void *compare_addr,
+    void *result_addr, MPI_Datatype datatype, int target_rank, MPI_Aint target_disp, MPI_Win win), (origin_addr, compare_addr, result_addr, datatype, target_rank, target_disp, win))
 WRAPPED_PMPI_CALL(int,MPI_Dims_create,(int nnodes, int ndims, int* dims) ,(nnodes, ndims, dims))
 WRAPPED_PMPI_CALL(int,MPI_Error_class,(int errorcode, int* errorclass) ,(errorcode, errorclass))
 WRAPPED_PMPI_CALL(int,MPI_Exscan,(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm),(sendbuf, recvbuf, count, datatype, op, comm))
diff --git a/src/smpi/smpi_pmpi.cpp b/src/smpi/smpi_pmpi.cpp
index 5c5915263b..65c0abd74b 100644
--- a/src/smpi/smpi_pmpi.cpp
+++ b/src/smpi/smpi_pmpi.cpp
@@ -2756,6 +2756,40 @@ int PMPI_Fetch_and_op(void *origin_addr, void *result_addr, MPI_Datatype dtype,
   return PMPI_Get_accumulate(origin_addr, origin_addr==nullptr?0:1, dtype, result_addr, 1, dtype, target_rank, target_disp, 1, dtype, op, win);
 }
 
+int PMPI_Compare_and_swap(void *origin_addr, void *compare_addr,
+                  void *result_addr, MPI_Datatype datatype, int target_rank,
+                  MPI_Aint target_disp, MPI_Win win){
+  int retval = 0;
+  smpi_bench_end();
+  if (win == MPI_WIN_NULL) {
+    retval = MPI_ERR_WIN;
+  } else if (target_rank == MPI_PROC_NULL) {
+    retval = MPI_SUCCESS;
+  } else if (target_rank <0){
+    retval = MPI_ERR_RANK;
+  } else if (win->dynamic()==0 && target_disp <0){
+    // for a dynamic window target_disp is an address and may look negative, so only reject it for non-dynamic windows
+    retval = MPI_ERR_ARG;
+  } else if (origin_addr==nullptr || result_addr==nullptr || compare_addr==nullptr){
+    retval = MPI_ERR_COUNT;
+  } else if (!datatype->is_valid()) {
+    retval = MPI_ERR_TYPE;
+  } else {
+    int rank = smpi_process()->index();
+    MPI_Group group;
+    win->get_group(&group);
+    int src_traced = group->index(target_rank);
+    TRACE_smpi_ptp_in(rank, src_traced, rank, __FUNCTION__, nullptr);
+
+    retval = win->compare_and_swap( origin_addr, compare_addr, result_addr, datatype,
+                                  target_rank, target_disp);
+
+    TRACE_smpi_ptp_out(rank, src_traced, rank, __FUNCTION__);
+  }
+  smpi_bench_begin();
+  return retval;
+}
+
 int PMPI_Win_post(MPI_Group group, int assert, MPI_Win win){
   int retval = 0;
   smpi_bench_end();
diff --git a/src/smpi/smpi_win.cpp b/src/smpi/smpi_win.cpp
index fb16d16e03..48b992dbef 100644
--- a/src/smpi/smpi_win.cpp
+++ b/src/smpi/smpi_win.cpp
@@ -26,6 +26,7 @@ Win::Win(void *base, MPI_Aint size, int disp_unit, MPI_Info info, MPI_Comm comm,
   requests_ = new std::vector<MPI_Request>();
   mut_=xbt_mutex_init();
   lock_mut_=xbt_mutex_init();
+  atomic_mut_=xbt_mutex_init();
   connected_wins_ = new MPI_Win[comm_size];
   connected_wins_[rank_] = this;
   count_ = 0;
@@ -68,6 +69,7 @@ Win::~Win(){
     MSG_barrier_destroy(bar_);
   xbt_mutex_destroy(mut_);
   xbt_mutex_destroy(lock_mut_);
+  xbt_mutex_destroy(atomic_mut_);
 
   if(allocated_ !=0)
     xbt_free(base_);
@@ -352,17 +354,54 @@ int Win::get_accumulate( void *origin_addr, int origin_count, MPI_Datatype origi
     return MPI_ERR_ARG;
 
   XBT_DEBUG("Entering MPI_Get_accumulate from %d", target_rank);
-
+  // hold the target's atomic mutex and finish each request before the next step so get and accumulate stay ordered (slow)
+  MPI_Request req = MPI_REQUEST_NULL;
+  xbt_mutex_acquire(send_win->atomic_mut_);
   get(result_addr, result_count, result_datatype, target_rank,
-              target_disp, target_count, target_datatype);
+              target_disp, target_count, target_datatype, &req);
+  if (req != MPI_REQUEST_NULL)
+    Request::wait(&req, MPI_STATUS_IGNORE);
   if(op!=MPI_NO_OP)
     accumulate(origin_addr, origin_count, origin_datatype, target_rank,
-              target_disp, target_count, target_datatype, op);
-
+              target_disp, target_count, target_datatype, op, &req);
+  if (req != MPI_REQUEST_NULL)
+    Request::wait(&req, MPI_STATUS_IGNORE);
+  xbt_mutex_release(send_win->atomic_mut_);
   return MPI_SUCCESS;
 
 }
 
+int Win::compare_and_swap(void *origin_addr, void *compare_addr,
+              void *result_addr, MPI_Datatype datatype, int target_rank,
+              MPI_Aint target_disp){
+  // window object registered for the target rank; its lockers_ and atomic_mut_ are used below
+  MPI_Win send_win = connected_wins_[target_rank];
+
+  if(opened_==0){//check that post/start has been done
+    // no fence or start: accept the call only if this rank is among the lockers of the target window
+    int locked=0;
+    for(auto it : send_win->lockers_)
+      if (it == comm_->rank())
+        locked = 1;
+    if(locked != 1)
+      return MPI_ERR_WIN;
+  }
+
+  XBT_DEBUG("Entering MPI_Compare_and_swap with %d", target_rank);
+  MPI_Request req = MPI_REQUEST_NULL; // stays null if get() does not hand back a pending request
+  xbt_mutex_acquire(send_win->atomic_mut_);
+  get(result_addr, 1, datatype, target_rank,
+              target_disp, 1, datatype, &req);
+  if (req != MPI_REQUEST_NULL)
+    Request::wait(&req, MPI_STATUS_IGNORE);
+  if(! memcmp (result_addr, compare_addr, datatype->get_extent() )){ // swap only when the fetched value equals compare_addr
+    put(origin_addr, 1, datatype, target_rank,
+              target_disp, 1, datatype);
+  }
+  xbt_mutex_release(send_win->atomic_mut_);
+  return MPI_SUCCESS;
+}
+
 int Win::start(MPI_Group group, int assert){
     /* From MPI forum advices
     The call to MPI_WIN_COMPLETE does not return until the put call has completed at the origin; and the target window
diff --git a/src/smpi/smpi_win.hpp b/src/smpi/smpi_win.hpp
index 1d36f36c73..a0f9aa77e3 100644
--- a/src/smpi/smpi_win.hpp
+++ b/src/smpi/smpi_win.hpp
@@ -31,6 +31,7 @@ class Win : public F2C, public Keyval {
   MPI_Group group_;
   int count_; //for ordering the accs
   xbt_mutex_t lock_mut_;
+  xbt_mutex_t atomic_mut_;
   std::list<int> lockers_;
   int rank_; // to identify owner for barriers in MPI_COMM_WORLD
   int mode_; // exclusive or shared lock
@@ -70,6 +71,9 @@ public:
   int get_accumulate( void *origin_addr, int origin_count, MPI_Datatype origin_datatype, void *result_addr,
               int result_count, MPI_Datatype result_datatype, int target_rank, MPI_Aint target_disp, int target_count,
               MPI_Datatype target_datatype, MPI_Op op);
+  int compare_and_swap(void *origin_addr, void *compare_addr,
+              void *result_addr, MPI_Datatype datatype, int target_rank,
+              MPI_Aint target_disp);
   static Win* f2c(int id);
   int lock(int lock_type, int rank, int assert);
   int unlock(int rank);
-- 
2.20.1