From: degomme Date: Fri, 29 Mar 2019 13:24:17 +0000 (+0100) Subject: Add MPI_Alltoallw and MPI_Ialltoallw X-Git-Tag: v3_22~19 X-Git-Url: http://info.iut-bm.univ-fcomte.fr/pub/gitweb/simgrid.git/commitdiff_plain/68f707462521f974a8839675ab66e3527125ccbc?hp=5194f75c72a1e45fbe1f4ca7b7ec2e968168f89c Add MPI_Alltoallw and MPI_Ialltoallw --- diff --git a/include/smpi/smpi.h b/include/smpi/smpi.h index 4af1b13b7f..ad5117115a 100644 --- a/include/smpi/smpi.h +++ b/include/smpi/smpi.h @@ -600,7 +600,9 @@ MPI_CALL(XBT_PUBLIC int, MPI_Ialltoall, (void* sendbuf, int sendcount, MPI_Datat MPI_CALL(XBT_PUBLIC int, MPI_Ialltoallv, (void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype sendtype, void* recvbuf, int* recvcounts, int* recvdisps, MPI_Datatype recvtype, MPI_Comm comm, MPI_Request *request)); - +MPI_CALL(XBT_PUBLIC int, MPI_Ialltoallw, + (void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype* sendtypes, void* recvbuf, int* recvcounts, + int* recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm, MPI_Request *request)); MPI_CALL(XBT_PUBLIC int, MPI_Gather, (void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm)); MPI_CALL(XBT_PUBLIC int, MPI_Gatherv, (void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, diff --git a/src/smpi/bindings/smpi_mpi.cpp b/src/smpi/bindings/smpi_mpi.cpp index efe7ed932d..ceb1b89a61 100644 --- a/src/smpi/bindings/smpi_mpi.cpp +++ b/src/smpi/bindings/smpi_mpi.cpp @@ -80,6 +80,7 @@ WRAPPED_PMPI_CALL(int,MPI_Alloc_mem,(MPI_Aint size, MPI_Info info, void *baseptr WRAPPED_PMPI_CALL(int,MPI_Allreduce,(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm),(sendbuf, recvbuf, count, datatype, op, comm)) WRAPPED_PMPI_CALL(int,MPI_Alltoall,(void *sendbuf, int sendcount, MPI_Datatype sendtype,void *recvbuf, int recvcount,MPI_Datatype recvtype, MPI_Comm comm),(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm)) WRAPPED_PMPI_CALL(int,MPI_Alltoallv,(void *sendbuf, int *sendcounts, int *senddisps, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *recvdisps, MPI_Datatype recvtype, MPI_Comm comm),(sendbuf, sendcounts, senddisps, sendtype, recvbuf, recvcounts, recvdisps, recvtype, comm)) +WRAPPED_PMPI_CALL(int,MPI_Alltoallw,( void *sendbuf, int *sendcnts, int *sdispls, MPI_Datatype *sendtypes, void *recvbuf, int *recvcnts, int *rdispls, MPI_Datatype *recvtypes, MPI_Comm comm),( sendbuf, sendcnts, sdispls, sendtypes, recvbuf, recvcnts, rdispls, recvtypes, comm)) WRAPPED_PMPI_CALL(int,MPI_Attr_delete,(MPI_Comm comm, int keyval) ,(comm, keyval)) WRAPPED_PMPI_CALL(int,MPI_Attr_get,(MPI_Comm comm, int keyval, void* attr_value, int* flag) ,(comm, keyval, attr_value, flag)) WRAPPED_PMPI_CALL(int,MPI_Attr_put,(MPI_Comm comm, int keyval, void* attr_value) ,(comm, keyval, attr_value)) @@ -148,6 +149,7 @@ WRAPPED_PMPI_CALL(int,MPI_Iallgatherv,(void *sendbuf, int sendcount, MPI_Datatyp WRAPPED_PMPI_CALL(int,MPI_Iallreduce,(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm, MPI_Request *request),(sendbuf, recvbuf, count, datatype, op, comm, request)) WRAPPED_PMPI_CALL(int,MPI_Ialltoall,(void *sendbuf, int sendcount, MPI_Datatype sendtype,void *recvbuf, int recvcount,MPI_Datatype recvtype, MPI_Comm comm, MPI_Request *request),(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, request)) WRAPPED_PMPI_CALL(int,MPI_Ialltoallv,(void *sendbuf, int *sendcounts, int *senddisps, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *recvdisps, MPI_Datatype recvtype, MPI_Comm comm, MPI_Request *request),(sendbuf, sendcounts, senddisps, sendtype, recvbuf, recvcounts, recvdisps, recvtype, comm, request)) +WRAPPED_PMPI_CALL(int,MPI_Ialltoallw,( void *sendbuf, int *sendcnts, int *sdispls, MPI_Datatype *sendtypes, void *recvbuf, int *recvcnts, int *rdispls, MPI_Datatype *recvtypes, MPI_Comm comm, MPI_Request *request),( sendbuf, sendcnts, sdispls, sendtypes, recvbuf, recvcnts, rdispls, recvtypes, comm, request)) WRAPPED_PMPI_CALL(int,MPI_Igather,(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm, MPI_Request *request),(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, request)) WRAPPED_PMPI_CALL(int,MPI_Igatherv,(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *displs,MPI_Datatype recvtype, int root, MPI_Comm comm, MPI_Request *request),(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, request)) //WRAPPED_PMPI_CALL(int,MPI_Ireduce_scatter_block,(void *sendbuf, void *recvbuf, int recvcount, MPI_Datatype datatype, MPI_Op op,MPI_Comm comm, MPI_Request *request),(sendbuf, recvbuf, recvcount, datatype, op, comm, request)) @@ -310,7 +312,6 @@ WRAPPED_PMPI_CALL(int,MPI_Status_set_elements,( MPI_Status *status, MPI_Datatype UNIMPLEMENTED_WRAPPED_PMPI_CALL(int,MPI_Add_error_class,( int *errorclass),( errorclass)) UNIMPLEMENTED_WRAPPED_PMPI_CALL(int,MPI_Add_error_code,(int errorclass, int *errorcode),(errorclass, errorcode)) UNIMPLEMENTED_WRAPPED_PMPI_CALL(int,MPI_Add_error_string,( int errorcode, char *string),(errorcode, string)) -UNIMPLEMENTED_WRAPPED_PMPI_CALL(int,MPI_Alltoallw,( void *sendbuf, int *sendcnts, int *sdispls, MPI_Datatype *sendtypes, void *recvbuf, int *recvcnts, int *rdispls, MPI_Datatype *recvtypes, MPI_Comm comm),( sendbuf, sendcnts, sdispls, sendtypes, recvbuf, recvcnts, rdispls, recvtypes, comm)) UNIMPLEMENTED_WRAPPED_PMPI_CALL(int,MPI_Bsend_init,(void* buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Request* request),(buf, count, datatype, dest, tag, comm, request)) UNIMPLEMENTED_WRAPPED_PMPI_CALL(int,MPI_Bsend,(void* buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm) ,(buf, count, datatype, dest, tag, comm)) UNIMPLEMENTED_WRAPPED_PMPI_CALL(int,MPI_Buffer_attach,(void* buffer, int size) ,(buffer, size)) diff --git a/src/smpi/bindings/smpi_pmpi_coll.cpp b/src/smpi/bindings/smpi_pmpi_coll.cpp index 91bbfd8494..9e3f6ea3b3 100644 --- a/src/smpi/bindings/smpi_pmpi_coll.cpp +++ b/src/smpi/bindings/smpi_pmpi_coll.cpp @@ -774,3 +774,88 @@ int PMPI_Ialltoallv(void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype smpi_bench_begin(); return retval; } + +int PMPI_Alltoallw(void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype* sendtypes, void* recvbuf, + int* recvcounts, int* recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm) +{ + return PMPI_Ialltoallw(sendbuf, sendcounts, senddisps, sendtypes, recvbuf, recvcounts, recvdisps, recvtypes, comm, MPI_REQUEST_IGNORED); +} + +int PMPI_Ialltoallw(void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype* sendtypes, void* recvbuf, + int* recvcounts, int* recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm, MPI_Request *request) +{ + int retval = 0; + + smpi_bench_end(); + + if (comm == MPI_COMM_NULL) { + retval = MPI_ERR_COMM; + } else if ((sendbuf != MPI_IN_PLACE && sendtypes == nullptr) || recvtypes == nullptr) { + retval = MPI_ERR_TYPE; + } else if ((sendbuf != MPI_IN_PLACE && (sendcounts == nullptr || senddisps == nullptr)) || recvcounts == nullptr || + recvdisps == nullptr) { + retval = MPI_ERR_ARG; + } else if (request == nullptr){ + retval = MPI_ERR_ARG; + } else { + int rank = simgrid::s4u::this_actor::get_pid(); + int size = comm->size(); + int send_size = 0; + int recv_size = 0; + std::vector* trace_sendcounts = new std::vector; + std::vector* trace_recvcounts = new std::vector; + + void* sendtmpbuf = static_cast(sendbuf); + int* sendtmpcounts = sendcounts; + int* sendtmpdisps = senddisps; + MPI_Datatype* sendtmptypes = sendtypes; + unsigned long maxsize = 0; + for (int i = 0; i < size; i++) { // copy data to avoid bad free + if(recvtypes[i]==MPI_DATATYPE_NULL) + return MPI_ERR_TYPE; + recv_size += recvcounts[i] * recvtypes[i]->size(); + trace_recvcounts->push_back(recvcounts[i] * recvtypes[i]->size()); + if ((recvdisps[i] + (recvcounts[i] * recvtypes[i]->size())) > maxsize) + maxsize = recvdisps[i] + (recvcounts[i] * recvtypes[i]->size()); + } + + if (sendbuf == MPI_IN_PLACE) { + sendtmpbuf = static_cast(xbt_malloc(maxsize)); + memcpy(sendtmpbuf, recvbuf, maxsize); + sendtmpcounts = static_cast(xbt_malloc(size * sizeof(int))); + memcpy(sendtmpcounts, recvcounts, size * sizeof(int)); + sendtmpdisps = static_cast(xbt_malloc(size * sizeof(int))); + memcpy(sendtmpdisps, recvdisps, size * sizeof(int)); + sendtmptypes = static_cast(xbt_malloc(size * sizeof(MPI_Datatype))); + memcpy(sendtmptypes, recvtypes, size * sizeof(MPI_Datatype)); + } + + for (int i = 0; i < size; i++) { // copy data to avoid bad free + send_size += sendtmpcounts[i] * sendtmptypes[i]->size(); + trace_sendcounts->push_back(sendtmpcounts[i] * sendtmptypes[i]->size()); + } + + TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED?"PMPI_Alltoallw":"PMPI_Ialltoallw", + new simgrid::instr::VarCollTIData(request==MPI_REQUEST_IGNORED ? "alltoallv":"ialltoallv", -1, send_size, trace_sendcounts, recv_size, + trace_recvcounts, simgrid::smpi::Datatype::encode(sendtmptypes[0]), + simgrid::smpi::Datatype::encode(recvtypes[0]))); + + if(request == MPI_REQUEST_IGNORED) + retval = simgrid::smpi::Colls::alltoallw(sendtmpbuf, sendtmpcounts, sendtmpdisps, sendtmptypes, recvbuf, recvcounts, + recvdisps, recvtypes, comm); + else + retval = simgrid::smpi::Colls::ialltoallw(sendtmpbuf, sendtmpcounts, sendtmpdisps, sendtmptypes, recvbuf, recvcounts, + recvdisps, recvtypes, comm, request); + TRACE_smpi_comm_out(rank); + + if (sendbuf == MPI_IN_PLACE) { + xbt_free(sendtmpbuf); + xbt_free(sendtmpcounts); + xbt_free(sendtmpdisps); + xbt_free(sendtmptypes); + } + } + + smpi_bench_begin(); + return retval; +} diff --git a/src/smpi/colls/alltoallv/alltoallv-ompi-basic-linear.cpp b/src/smpi/colls/alltoallv/alltoallv-ompi-basic-linear.cpp index a665472e95..501d15e383 100644 --- a/src/smpi/colls/alltoallv/alltoallv-ompi-basic-linear.cpp +++ b/src/smpi/colls/alltoallv/alltoallv-ompi-basic-linear.cpp @@ -56,7 +56,7 @@ Coll_alltoallv_ompi_basic_linear::alltoallv(void *sbuf, int *scounts, int *sdisp /* Post all receives first */ for (i = 0; i < size; ++i) { - if (i == rank || 0 == rcounts[i]) { + if (i == rank) { continue; } @@ -72,7 +72,7 @@ Coll_alltoallv_ompi_basic_linear::alltoallv(void *sbuf, int *scounts, int *sdisp /* Now post all sends */ for (i = 0; i < size; ++i) { - if (i == rank || 0 == scounts[i]) { + if (i == rank) { continue; } diff --git a/src/smpi/colls/smpi_coll.cpp b/src/smpi/colls/smpi_coll.cpp index 4662d013c5..36eca5fc52 100644 --- a/src/smpi/colls/smpi_coll.cpp +++ b/src/smpi/colls/smpi_coll.cpp @@ -282,5 +282,23 @@ int Colls::exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype return MPI_SUCCESS; } +int Colls::alltoallw(void *sendbuf, int *sendcounts, int *senddisps, MPI_Datatype* sendtypes, + void *recvbuf, int *recvcounts, int *recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm) +{ + MPI_Request request; + int err = Colls::ialltoallw(sendbuf, sendcounts, senddisps, sendtypes, recvbuf, recvcounts, recvdisps, recvtypes, comm, &request); + MPI_Request* requests = request->get_nbc_requests(); + int count = request->get_nbc_requests_size(); + XBT_DEBUG("<%d> wait for %d requests", comm->rank(), count); + Request::waitall(count, requests, MPI_STATUS_IGNORE); + for (int i = 0; i < count; i++) { + if(requests[i]!=MPI_REQUEST_NULL) + Request::unref(&requests[i]); + } + delete[] requests; + Request::unref(&request); + return err; +} + } } diff --git a/src/smpi/colls/smpi_default_selector.cpp b/src/smpi/colls/smpi_default_selector.cpp index 0b5198d673..f2751b96ef 100644 --- a/src/smpi/colls/smpi_default_selector.cpp +++ b/src/smpi/colls/smpi_default_selector.cpp @@ -197,8 +197,6 @@ int Coll_alltoall_default::alltoall( void *sbuf, int scount, MPI_Datatype sdtype return Coll_alltoall_ompi::alltoall(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm); } - - int Coll_alltoallv_default::alltoallv(void *sendbuf, int *sendcounts, int *senddisps, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *recvdisps, MPI_Datatype recvtype, MPI_Comm comm) { diff --git a/src/smpi/colls/smpi_nbc_impl.cpp b/src/smpi/colls/smpi_nbc_impl.cpp index f1db5eac28..6da0fd9a9e 100644 --- a/src/smpi/colls/smpi_nbc_impl.cpp +++ b/src/smpi/colls/smpi_nbc_impl.cpp @@ -256,7 +256,7 @@ int Colls::ialltoallv(void *sendbuf, int *sendcounts, int *senddisps, MPI_Dataty int count = 0; /* Create all receives that will be posted first */ for (int i = 0; i < size; ++i) { - if (i != rank && recvcounts[i] != 0) { + if (i != rank) { requests[count] = Request::irecv_init(static_cast(recvbuf) + recvdisps[i] * recvext, recvcounts[i], recvtype, i, system_tag, comm); count++; @@ -266,7 +266,7 @@ int Colls::ialltoallv(void *sendbuf, int *sendcounts, int *senddisps, MPI_Dataty } /* Now create all sends */ for (int i = 0; i < size; ++i) { - if (i != rank && sendcounts[i] != 0) { + if (i != rank) { requests[count] = Request::isend_init(static_cast(sendbuf) + senddisps[i] * sendext, sendcounts[i], sendtype, i, system_tag, comm); count++; @@ -281,6 +281,50 @@ int Colls::ialltoallv(void *sendbuf, int *sendcounts, int *senddisps, MPI_Dataty return err; } +int Colls::ialltoallw(void *sendbuf, int *sendcounts, int *senddisps, MPI_Datatype* sendtypes, + void *recvbuf, int *recvcounts, int *recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm, MPI_Request *request){ + const int system_tag = COLL_TAG_ALLTOALLV; + MPI_Request *requests; + + /* Initialize. */ + int rank = comm->rank(); + int size = comm->size(); + (*request) = new Request( nullptr, 0, MPI_BYTE, + rank,rank, COLL_TAG_ALLTOALLV, comm, MPI_REQ_PERSISTENT); + /* Local copy from self */ + int err = (sendcounts[rank]>0 && recvcounts[rank]) ? Datatype::copy(static_cast(sendbuf) + senddisps[rank], sendcounts[rank], sendtypes[rank], + static_cast(recvbuf) + recvdisps[rank], recvcounts[rank], recvtypes[rank]): MPI_SUCCESS; + if (err == MPI_SUCCESS && size > 1) { + /* Initiate all send/recv to/from others. */ + requests = new MPI_Request[2 * (size - 1)]; + int count = 0; + /* Create all receives that will be posted first */ + for (int i = 0; i < size; ++i) { + if (i != rank) { + requests[count] = Request::irecv_init(static_cast(recvbuf) + recvdisps[i], + recvcounts[i], recvtypes[i], i, system_tag, comm); + count++; + }else{ + XBT_DEBUG("<%d> skip request creation [src = %d, recvcounts[src] = %d]", rank, i, recvcounts[i]); + } + } + /* Now create all sends */ + for (int i = 0; i < size; ++i) { + if (i != rank) { + requests[count] = Request::isend_init(static_cast(sendbuf) + senddisps[i] , + sendcounts[i], sendtypes[i], i, system_tag, comm); + count++; + }else{ + XBT_DEBUG("<%d> skip request creation [dst = %d, sendcounts[dst] = %d]", rank, i, sendcounts[i]); + } + } + /* Wait for them all. */ + Request::startall(count, requests); + (*request)->set_nbc_requests(requests, count); + } + return err; +} + int Colls::igather(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm, MPI_Request *request) { diff --git a/src/smpi/include/smpi_coll.hpp b/src/smpi/include/smpi_coll.hpp index 9597fc3515..e4c5df3ac5 100644 --- a/src/smpi/include/smpi_coll.hpp +++ b/src/smpi/include/smpi_coll.hpp @@ -115,7 +115,10 @@ public: MPI_Datatype recvtype, int root, MPI_Comm comm); static int scan(void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm); static int exscan(void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm); - + static int alltoallw + (void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype* sendtypes, void* recvbuf, int* recvcounts, + int* recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm); + //async collectives static int ibarrier(MPI_Comm comm, MPI_Request* request); static int ibcast(void *buf, int count, MPI_Datatype datatype, @@ -149,6 +152,9 @@ public: static int ialltoallv (void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype sendtype, void* recvbuf, int* recvcounts, int* recvdisps, MPI_Datatype recvtype, MPI_Comm comm, MPI_Request *request); + static int ialltoallw + (void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype* sendtypes, void* recvbuf, int* recvcounts, + int* recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm, MPI_Request *request); static void (*smpi_coll_cleanup_callback)();