From ec1872016478e2a44665e1e40a8608bcb22f08a3 Mon Sep 17 00:00:00 2001 From: degomme Date: Tue, 2 Apr 2019 10:36:46 +0200 Subject: [PATCH] MPI_Ireduce_scatter, MPI_Ireduce_scatter_block --- src/smpi/bindings/smpi_mpi.cpp | 4 +- src/smpi/bindings/smpi_pmpi_coll.cpp | 38 ++++++++--- src/smpi/colls/smpi_nbc_impl.cpp | 37 ++++++++++ teshsuite/smpi/mpich3-test/coll/nonblocking.c | 16 ++--- .../smpi/mpich3-test/coll/nonblocking2.c | 68 +++++++++---------- 5 files changed, 111 insertions(+), 52 deletions(-) diff --git a/src/smpi/bindings/smpi_mpi.cpp b/src/smpi/bindings/smpi_mpi.cpp index e2fd7349f1..96e2fd3270 100644 --- a/src/smpi/bindings/smpi_mpi.cpp +++ b/src/smpi/bindings/smpi_mpi.cpp @@ -152,8 +152,8 @@ WRAPPED_PMPI_CALL(int,MPI_Ialltoallv,(void *sendbuf, int *sendcounts, int *sendd WRAPPED_PMPI_CALL(int,MPI_Ialltoallw,( void *sendbuf, int *sendcnts, int *sdispls, MPI_Datatype *sendtypes, void *recvbuf, int *recvcnts, int *rdispls, MPI_Datatype *recvtypes, MPI_Comm comm, MPI_Request *request),( sendbuf, sendcnts, sdispls, sendtypes, recvbuf, recvcnts, rdispls, recvtypes, comm, request)) WRAPPED_PMPI_CALL(int,MPI_Igather,(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm, MPI_Request *request),(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, request)) WRAPPED_PMPI_CALL(int,MPI_Igatherv,(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *displs,MPI_Datatype recvtype, int root, MPI_Comm comm, MPI_Request *request),(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, request)) -//WRAPPED_PMPI_CALL(int,MPI_Ireduce_scatter_block,(void *sendbuf, void *recvbuf, int recvcount, MPI_Datatype datatype, MPI_Op op,MPI_Comm comm, MPI_Request *request),(sendbuf, recvbuf, recvcount, datatype, op, comm, request)) -//WRAPPED_PMPI_CALL(int,MPI_Ireduce_scatter,(void *sendbuf, void *recvbuf, int *recvcounts, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm, MPI_Request *request),(sendbuf, recvbuf, recvcounts, datatype, op, comm, request)) +WRAPPED_PMPI_CALL(int,MPI_Ireduce_scatter_block,(void *sendbuf, void *recvbuf, int recvcount, MPI_Datatype datatype, MPI_Op op,MPI_Comm comm, MPI_Request *request),(sendbuf, recvbuf, recvcount, datatype, op, comm, request)) +WRAPPED_PMPI_CALL(int,MPI_Ireduce_scatter,(void *sendbuf, void *recvbuf, int *recvcounts, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm, MPI_Request *request),(sendbuf, recvbuf, recvcounts, datatype, op, comm, request)) WRAPPED_PMPI_CALL(int,MPI_Ireduce,(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm, MPI_Request *request),(sendbuf, recvbuf, count, datatype, op, root, comm, request)) WRAPPED_PMPI_CALL(int,MPI_Iexscan,(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm, MPI_Request *request),(sendbuf, recvbuf, count, datatype, op, comm, request)) WRAPPED_PMPI_CALL(int,MPI_Iscan,(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm, MPI_Request *request),(sendbuf, recvbuf, count, datatype, op, comm, request)) diff --git a/src/smpi/bindings/smpi_pmpi_coll.cpp b/src/smpi/bindings/smpi_pmpi_coll.cpp index 825a544acd..66df9540ac 100644 --- a/src/smpi/bindings/smpi_pmpi_coll.cpp +++ b/src/smpi/bindings/smpi_pmpi_coll.cpp @@ -469,7 +469,7 @@ int PMPI_Iallreduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype dataty } int rank = simgrid::s4u::this_actor::get_pid(); - TRACE_smpi_comm_in(rank, __func__, + TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED ?"PMPI_Allreduce":"PMPI_Iallreduce", new simgrid::instr::CollTIData(request==MPI_REQUEST_IGNORED ? "allreduce":"iallreduce", -1, 0, datatype->is_replayable() ? count : count * datatype->size(), -1, simgrid::smpi::Datatype::encode(datatype), "")); @@ -580,6 +580,11 @@ int PMPI_Iexscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, } int PMPI_Reduce_scatter(void *sendbuf, void *recvbuf, int *recvcounts, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) +{ + return PMPI_Ireduce_scatter(sendbuf, recvbuf, recvcounts, datatype, op, comm, MPI_REQUEST_IGNORED); +} + +int PMPI_Ireduce_scatter(void *sendbuf, void *recvbuf, int *recvcounts, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm, MPI_Request *request) { int retval = 0; smpi_bench_end(); @@ -592,6 +597,8 @@ int PMPI_Reduce_scatter(void *sendbuf, void *recvbuf, int *recvcounts, MPI_Datat retval = MPI_ERR_OP; } else if (recvcounts == nullptr) { retval = MPI_ERR_ARG; + } else if (request == nullptr){ + retval = MPI_ERR_ARG; } else { int rank = simgrid::s4u::this_actor::get_pid(); std::vector* trace_recvcounts = new std::vector; @@ -609,11 +616,15 @@ int PMPI_Reduce_scatter(void *sendbuf, void *recvbuf, int *recvcounts, MPI_Datat memcpy(sendtmpbuf, recvbuf, totalcount * datatype->size()); } - TRACE_smpi_comm_in(rank, __func__, new simgrid::instr::VarCollTIData( - "reducescatter", -1, dt_send_size, nullptr, -1, trace_recvcounts, + TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED ? "PMPI_Reduce_scatter": "PMPI_Ireduce_scatter", new simgrid::instr::VarCollTIData( + request==MPI_REQUEST_IGNORED ? "reducescatter":"ireducescatter", -1, dt_send_size, nullptr, -1, trace_recvcounts, simgrid::smpi::Datatype::encode(datatype), "")); - simgrid::smpi::Colls::reduce_scatter(sendtmpbuf, recvbuf, recvcounts, datatype, op, comm); + if(request == MPI_REQUEST_IGNORED) + simgrid::smpi::Colls::reduce_scatter(sendtmpbuf, recvbuf, recvcounts, datatype, op, comm); + else + simgrid::smpi::Colls::ireduce_scatter(sendtmpbuf, recvbuf, recvcounts, datatype, op, comm, request); + retval = MPI_SUCCESS; TRACE_smpi_comm_out(rank); @@ -627,6 +638,12 @@ int PMPI_Reduce_scatter(void *sendbuf, void *recvbuf, int *recvcounts, MPI_Datat int PMPI_Reduce_scatter_block(void *sendbuf, void *recvbuf, int recvcount, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm) +{ + return PMPI_Ireduce_scatter_block(sendbuf, recvbuf, recvcount, datatype, op, comm, MPI_REQUEST_IGNORED); +} + +int PMPI_Ireduce_scatter_block(void *sendbuf, void *recvbuf, int recvcount, + MPI_Datatype datatype, MPI_Op op, MPI_Comm comm, MPI_Request *request) { int retval; smpi_bench_end(); @@ -639,7 +656,9 @@ int PMPI_Reduce_scatter_block(void *sendbuf, void *recvbuf, int recvcount, retval = MPI_ERR_OP; } else if (recvcount < 0) { retval = MPI_ERR_ARG; - } else { + } else if (request == nullptr){ + retval = MPI_ERR_ARG; + } else { int count = comm->size(); int rank = simgrid::s4u::this_actor::get_pid(); @@ -652,14 +671,17 @@ int PMPI_Reduce_scatter_block(void *sendbuf, void *recvbuf, int recvcount, memcpy(sendtmpbuf, recvbuf, recvcount * count * datatype->size()); } - TRACE_smpi_comm_in(rank, __func__, - new simgrid::instr::VarCollTIData("reducescatter", -1, 0, nullptr, -1, trace_recvcounts, + TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED ? "PMPI_Reduce_scatter_block":"PMPI_Ireduce_scatter_block", + new simgrid::instr::VarCollTIData(request==MPI_REQUEST_IGNORED ? "reducescatter":"ireducescatter", -1, 0, nullptr, -1, trace_recvcounts, simgrid::smpi::Datatype::encode(datatype), "")); int* recvcounts = new int[count]; for (int i = 0; i < count; i++) recvcounts[i] = recvcount; - simgrid::smpi::Colls::reduce_scatter(sendtmpbuf, recvbuf, recvcounts, datatype, op, comm); + if(request == MPI_REQUEST_IGNORED) + simgrid::smpi::Colls::reduce_scatter(sendtmpbuf, recvbuf, recvcounts, datatype, op, comm); + else + simgrid::smpi::Colls::ireduce_scatter(sendtmpbuf, recvbuf, recvcounts, datatype, op, comm, request); delete[] recvcounts; retval = MPI_SUCCESS; diff --git a/src/smpi/colls/smpi_nbc_impl.cpp b/src/smpi/colls/smpi_nbc_impl.cpp index 6d5f761547..6c234e69d0 100644 --- a/src/smpi/colls/smpi_nbc_impl.cpp +++ b/src/smpi/colls/smpi_nbc_impl.cpp @@ -602,5 +602,42 @@ int Colls::iexscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatyp return MPI_SUCCESS; } +int Colls::ireduce_scatter(void *sendbuf, void *recvbuf, int *recvcounts, MPI_Datatype datatype, MPI_Op op, + MPI_Comm comm, MPI_Request* request){ +//Version where each process performs the reduce for its own part. Alltoall pattern for comms. + const int system_tag = COLL_TAG_REDUCE_SCATTER; + MPI_Aint lb = 0; + MPI_Aint dataext = 0; + MPI_Request *requests; + + int rank = comm->rank(); + int size = comm->size(); + int count=recvcounts[rank]; + (*request) = new Request( recvbuf, count, datatype, + rank,rank, system_tag, comm, MPI_REQ_PERSISTENT, op); + datatype->extent(&lb, &dataext); + + // Send/Recv buffers to/from others; + requests = new MPI_Request[2 * (size - 1)]; + int index = 0; + int recvdisp=0; + for (int other = 0; other < size; other++) { + if(other != rank) { + requests[index] = Request::isend_init(static_cast(sendbuf) + recvdisp * dataext, recvcounts[other], datatype, other, system_tag,comm); + XBT_VERB("sending with recvdisp %d", recvdisp); + index++; + requests[index] = Request::irecv_init(smpi_get_tmp_sendbuffer(count * dataext), count, datatype, + other, system_tag, comm); + index++; + }else{ + Datatype::copy(static_cast(sendbuf) + recvdisp * dataext, count, datatype, recvbuf, count, datatype); + } + recvdisp+=recvcounts[other]; + } + Request::startall(2 * (size - 1), requests); + (*request)->set_nbc_requests(requests, 2 * (size - 1)); + return MPI_SUCCESS; +} + } } diff --git a/teshsuite/smpi/mpich3-test/coll/nonblocking.c b/teshsuite/smpi/mpich3-test/coll/nonblocking.c index 548b073257..9d8dffb3d6 100644 --- a/teshsuite/smpi/mpich3-test/coll/nonblocking.c +++ b/teshsuite/smpi/mpich3-test/coll/nonblocking.c @@ -168,17 +168,17 @@ int main(int argc, char **argv) MPI_Iallreduce(MPI_IN_PLACE, rbuf, NUM_INTS, MPI_INT, MPI_SUM, comm, &req); MPI_Wait(&req, MPI_STATUS_IGNORE); -/* MPI_Ireduce_scatter(sbuf, rbuf, rcounts, MPI_INT, MPI_SUM, comm, &req);*/ -/* MPI_Wait(&req, MPI_STATUS_IGNORE);*/ + MPI_Ireduce_scatter(sbuf, rbuf, rcounts, MPI_INT, MPI_SUM, comm, &req); + MPI_Wait(&req, MPI_STATUS_IGNORE); -/* MPI_Ireduce_scatter(MPI_IN_PLACE, rbuf, rcounts, MPI_INT, MPI_SUM, comm, &req);*/ -/* MPI_Wait(&req, MPI_STATUS_IGNORE);*/ + MPI_Ireduce_scatter(MPI_IN_PLACE, rbuf, rcounts, MPI_INT, MPI_SUM, comm, &req); + MPI_Wait(&req, MPI_STATUS_IGNORE); -/* MPI_Ireduce_scatter_block(sbuf, rbuf, NUM_INTS, MPI_INT, MPI_SUM, comm, &req);*/ -/* MPI_Wait(&req, MPI_STATUS_IGNORE);*/ + MPI_Ireduce_scatter_block(sbuf, rbuf, NUM_INTS, MPI_INT, MPI_SUM, comm, &req); + MPI_Wait(&req, MPI_STATUS_IGNORE); -/* MPI_Ireduce_scatter_block(MPI_IN_PLACE, rbuf, NUM_INTS, MPI_INT, MPI_SUM, comm, &req);*/ -/* MPI_Wait(&req, MPI_STATUS_IGNORE);*/ + MPI_Ireduce_scatter_block(MPI_IN_PLACE, rbuf, NUM_INTS, MPI_INT, MPI_SUM, comm, &req); + MPI_Wait(&req, MPI_STATUS_IGNORE); MPI_Iscan(sbuf, rbuf, NUM_INTS, MPI_INT, MPI_SUM, comm, &req); MPI_Wait(&req, MPI_STATUS_IGNORE); diff --git a/teshsuite/smpi/mpich3-test/coll/nonblocking2.c b/teshsuite/smpi/mpich3-test/coll/nonblocking2.c index 6e216b5e30..187b2889ad 100644 --- a/teshsuite/smpi/mpich3-test/coll/nonblocking2.c +++ b/teshsuite/smpi/mpich3-test/coll/nonblocking2.c @@ -281,43 +281,43 @@ int main(int argc, char **argv) } /* MPI_Ireduce_scatter */ -/* for (i = 0; i < size; ++i) {*/ -/* recvcounts[i] = COUNT;*/ -/* for (j = 0; j < COUNT; ++j) {*/ -/* buf[i * COUNT + j] = rank + i;*/ -/* recvbuf[i * COUNT + j] = 0xdeadbeef;*/ -/* }*/ -/* }*/ -/* MPI_Ireduce_scatter(buf, recvbuf, recvcounts, MPI_INT, MPI_SUM, MPI_COMM_WORLD, &req);*/ -/* MPI_Wait(&req, MPI_STATUS_IGNORE);*/ -/* for (j = 0; j < COUNT; ++j) {*/ -/* my_assert(recvbuf[j] == (size * rank + ((size - 1) * size) / 2));*/ -/* }*/ -/* for (i = 1; i < size; ++i) {*/ -/* for (j = 0; j < COUNT; ++j) {*/ - /* check we didn't corrupt the rest of the recvbuf */ -/* my_assert(recvbuf[i * COUNT + j] == 0xdeadbeef);*/ -/* }*/ -/* }*/ + for (i = 0; i < size; ++i) { + recvcounts[i] = COUNT; + for (j = 0; j < COUNT; ++j) { + buf[i * COUNT + j] = rank + i; + recvbuf[i * COUNT + j] = 0xdeadbeef; + } + } + MPI_Ireduce_scatter(buf, recvbuf, recvcounts, MPI_INT, MPI_SUM, MPI_COMM_WORLD, &req); + MPI_Wait(&req, MPI_STATUS_IGNORE); + for (j = 0; j < COUNT; ++j) { + my_assert(recvbuf[j] == (size * rank + ((size - 1) * size) / 2)); + } + for (i = 1; i < size; ++i) { + for (j = 0; j < COUNT; ++j) { +/* check we didn't corrupt the rest of the recvbuf */ + my_assert(recvbuf[i * COUNT + j] == 0xdeadbeef); + } + } /* MPI_Ireduce_scatter_block */ -/* for (i = 0; i < size; ++i) {*/ -/* for (j = 0; j < COUNT; ++j) {*/ -/* buf[i * COUNT + j] = rank + i;*/ -/* recvbuf[i * COUNT + j] = 0xdeadbeef;*/ -/* }*/ -/* }*/ -/* MPI_Ireduce_scatter_block(buf, recvbuf, COUNT, MPI_INT, MPI_SUM, MPI_COMM_WORLD, &req);*/ -/* MPI_Wait(&req, MPI_STATUS_IGNORE);*/ -/* for (j = 0; j < COUNT; ++j) {*/ -/* my_assert(recvbuf[j] == (size * rank + ((size - 1) * size) / 2));*/ -/* }*/ -/* for (i = 1; i < size; ++i) {*/ -/* for (j = 0; j < COUNT; ++j) {*/ + for (i = 0; i < size; ++i) { + for (j = 0; j < COUNT; ++j) { + buf[i * COUNT + j] = rank + i; + recvbuf[i * COUNT + j] = 0xdeadbeef; + } + } + MPI_Ireduce_scatter_block(buf, recvbuf, COUNT, MPI_INT, MPI_SUM, MPI_COMM_WORLD, &req); + MPI_Wait(&req, MPI_STATUS_IGNORE); + for (j = 0; j < COUNT; ++j) { + my_assert(recvbuf[j] == (size * rank + ((size - 1) * size) / 2)); + } + for (i = 1; i < size; ++i) { + for (j = 0; j < COUNT; ++j) { /* check we didn't corrupt the rest of the recvbuf */ -/* my_assert(recvbuf[i * COUNT + j] == 0xdeadbeef);*/ -/* }*/ -/* }*/ + my_assert(recvbuf[i * COUNT + j] == 0xdeadbeef); + } + } /* MPI_Igatherv */ for (i = 0; i < size * COUNT; ++i) { -- 2.20.1