From b512c05a92b5f7e29bd6cf767e1de5b40344aa61 Mon Sep 17 00:00:00 2001 From: Augustin Degomme Date: Mon, 1 Apr 2019 23:43:26 +0200 Subject: [PATCH] MPI_Iallreduce --- src/smpi/bindings/smpi_pmpi_coll.cpp | 10 ++---- src/smpi/colls/smpi_nbc_impl.cpp | 35 +++++++++++++++++++ src/smpi/mpi/smpi_request.cpp | 10 +++--- teshsuite/smpi/mpich3-test/coll/nonblocking.c | 8 ++--- .../smpi/mpich3-test/coll/nonblocking2.c | 26 +++++++------- 5 files changed, 61 insertions(+), 28 deletions(-) diff --git a/src/smpi/bindings/smpi_pmpi_coll.cpp b/src/smpi/bindings/smpi_pmpi_coll.cpp index bd0ac325a3..cc86aec60f 100644 --- a/src/smpi/bindings/smpi_pmpi_coll.cpp +++ b/src/smpi/bindings/smpi_pmpi_coll.cpp @@ -455,11 +455,7 @@ int PMPI_Iallreduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype dataty retval = MPI_ERR_TYPE; } else if (op == MPI_OP_NULL) { retval = MPI_ERR_OP; - } else if (request != MPI_REQUEST_IGNORED) { - xbt_die("Iallreduce is not yet implemented. WIP"); - retval = MPI_ERR_ARG; } else { - char* sendtmpbuf = static_cast<char*>(sendbuf); if( sendbuf == MPI_IN_PLACE ) { sendtmpbuf = static_cast<char*>(xbt_malloc(count*datatype->get_extent())); @@ -472,10 +468,10 @@ int PMPI_Iallreduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype dataty datatype->is_replayable() ? 
count : count * datatype->size(), -1, simgrid::smpi::Datatype::encode(datatype), "")); -// if(request == MPI_REQUEST_IGNORED) + if(request == MPI_REQUEST_IGNORED) simgrid::smpi::Colls::allreduce(sendtmpbuf, recvbuf, count, datatype, op, comm); -// else -// simgrid::smpi::Colls::iallreduce(sendtmpbuf, recvbuf, count, datatype, op, comm, request); + else + simgrid::smpi::Colls::iallreduce(sendtmpbuf, recvbuf, count, datatype, op, comm, request); if( sendbuf == MPI_IN_PLACE ) xbt_free(sendtmpbuf); diff --git a/src/smpi/colls/smpi_nbc_impl.cpp b/src/smpi/colls/smpi_nbc_impl.cpp index 0cb25f02ec..8e9b16d425 100644 --- a/src/smpi/colls/smpi_nbc_impl.cpp +++ b/src/smpi/colls/smpi_nbc_impl.cpp @@ -501,5 +501,40 @@ int Colls::ireduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatyp } return MPI_SUCCESS; } + +int Colls::iallreduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, + MPI_Op op, MPI_Comm comm, MPI_Request* request) +{ + + const int system_tag = COLL_TAG_ALLREDUCE; + MPI_Aint lb = 0; + MPI_Aint dataext = 0; + MPI_Request *requests; + + int rank = comm->rank(); + int size = comm->size(); + (*request) = new Request( recvbuf, count, datatype, + rank,rank, COLL_TAG_ALLREDUCE, comm, MPI_REQ_PERSISTENT, op); + // FIXME: check for errors + datatype->extent(&lb, &dataext); + // Local copy from self + Datatype::copy(sendbuf, count, datatype, recvbuf, count, datatype); + // Send/Recv buffers to/from others; + requests = new MPI_Request[2 * (size - 1)]; + int index = 0; + for (int other = 0; other < size; other++) { + if(other != rank) { + requests[index] = Request::isend_init(sendbuf, count, datatype, other, system_tag,comm); + index++; + requests[index] = Request::irecv_init(smpi_get_tmp_sendbuffer(count * dataext), count, datatype, + other, system_tag, comm); + index++; + } + } + Request::startall(2 * (size - 1), requests); + (*request)->set_nbc_requests(requests, 2 * (size - 1)); + return MPI_SUCCESS; +} + } } diff --git 
a/src/smpi/mpi/smpi_request.cpp b/src/smpi/mpi/smpi_request.cpp index 5cf3a4990b..47b37fea93 100644 --- a/src/smpi/mpi/smpi_request.cpp +++ b/src/smpi/mpi/smpi_request.cpp @@ -871,11 +871,13 @@ int Request::wait(MPI_Request * request, MPI_Status * status) void * buf=(*request)->nbc_requests_[i]->buf_; if((*request)->old_type_->flags() & DT_FLAG_DERIVED) buf=(*request)->nbc_requests_[i]->old_buf_; - if((*request)->op_!=MPI_OP_NULL){ - int count=(*request)->size_/ (*request)->old_type_->size(); - (*request)->op_->apply(buf, (*request)->buf_, &count, (*request)->old_type_); + if((*request)->nbc_requests_[i]->flags_ & MPI_REQ_RECV ){ + if((*request)->op_!=MPI_OP_NULL){ + int count=(*request)->size_/ (*request)->old_type_->size(); + (*request)->op_->apply(buf, (*request)->buf_, &count, (*request)->old_type_); + } + smpi_free_tmp_buffer(buf); } - smpi_free_tmp_buffer(buf); } if((*request)->nbc_requests_[i]!=MPI_REQUEST_NULL) Request::unref(&((*request)->nbc_requests_[i])); diff --git a/teshsuite/smpi/mpich3-test/coll/nonblocking.c b/teshsuite/smpi/mpich3-test/coll/nonblocking.c index 0038da1399..94938e5dbc 100644 --- a/teshsuite/smpi/mpich3-test/coll/nonblocking.c +++ b/teshsuite/smpi/mpich3-test/coll/nonblocking.c @@ -162,11 +162,11 @@ int main(int argc, char **argv) MPI_Ireduce(sbuf, rbuf, NUM_INTS, MPI_INT, MPI_SUM, 0, comm, &req); MPI_Wait(&req, MPI_STATUS_IGNORE); -/* MPI_Iallreduce(sbuf, rbuf, NUM_INTS, MPI_INT, MPI_SUM, comm, &req);*/ -/* MPI_Wait(&req, MPI_STATUS_IGNORE);*/ + MPI_Iallreduce(sbuf, rbuf, NUM_INTS, MPI_INT, MPI_SUM, comm, &req); + MPI_Wait(&req, MPI_STATUS_IGNORE); -/* MPI_Iallreduce(MPI_IN_PLACE, rbuf, NUM_INTS, MPI_INT, MPI_SUM, comm, &req);*/ -/* MPI_Wait(&req, MPI_STATUS_IGNORE);*/ + MPI_Iallreduce(MPI_IN_PLACE, rbuf, NUM_INTS, MPI_INT, MPI_SUM, comm, &req); + MPI_Wait(&req, MPI_STATUS_IGNORE); /* MPI_Ireduce_scatter(sbuf, rbuf, rcounts, MPI_INT, MPI_SUM, comm, &req);*/ /* MPI_Wait(&req, MPI_STATUS_IGNORE);*/ diff --git 
a/teshsuite/smpi/mpich3-test/coll/nonblocking2.c b/teshsuite/smpi/mpich3-test/coll/nonblocking2.c index 72d774f8dd..f266f5466d 100644 --- a/teshsuite/smpi/mpich3-test/coll/nonblocking2.c +++ b/teshsuite/smpi/mpich3-test/coll/nonblocking2.c @@ -117,7 +117,7 @@ int main(int argc, char **argv) if (rank == 0) { for (i = 0; i < COUNT; ++i) { if (recvbuf[i] != ((size * (size - 1) / 2) + (i * size))) - printf("aa got recvbuf[%d]=%d, expected %d\n", i, recvbuf[i], + printf("got recvbuf[%d]=%d, expected %d\n", i, recvbuf[i], ((size * (size - 1) / 2) + (i * size))); my_assert(recvbuf[i] == ((size * (size - 1) / 2) + (i * size))); } @@ -145,18 +145,18 @@ int main(int argc, char **argv) } /* MPI_Iallreduce */ -/* for (i = 0; i < COUNT; ++i) {*/ -/* buf[i] = rank + i;*/ -/* recvbuf[i] = 0xdeadbeef;*/ -/* }*/ -/* MPI_Iallreduce(buf, recvbuf, COUNT, MPI_INT, MPI_SUM, MPI_COMM_WORLD, &req);*/ -/* MPI_Wait(&req, MPI_STATUS_IGNORE);*/ -/* for (i = 0; i < COUNT; ++i) {*/ -/* if (recvbuf[i] != ((size * (size - 1) / 2) + (i * size)))*/ -/* printf("got recvbuf[%d]=%d, expected %d\n", i, recvbuf[i],*/ -/* ((size * (size - 1) / 2) + (i * size)));*/ -/* my_assert(recvbuf[i] == ((size * (size - 1) / 2) + (i * size)));*/ -/* }*/ + for (i = 0; i < COUNT; ++i) { + buf[i] = rank + i; + recvbuf[i] = 0xdeadbeef; + } + MPI_Iallreduce(buf, recvbuf, COUNT, MPI_INT, MPI_SUM, MPI_COMM_WORLD, &req); + MPI_Wait(&req, MPI_STATUS_IGNORE); + for (i = 0; i < COUNT; ++i) { + if (recvbuf[i] != ((size * (size - 1) / 2) + (i * size))) + printf("got recvbuf[%d]=%d, expected %d\n", i, recvbuf[i], + ((size * (size - 1) / 2) + (i * size))); + my_assert(recvbuf[i] == ((size * (size - 1) / 2) + (i * size))); + } /* MPI_Ialltoallv (a weak test, neither irregular nor sparse) */ for (i = 0; i < size; ++i) { -- 2.20.1