Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add MPI_Alltoallw and MPI_Ialltoallw
authordegomme <adegomme@users.noreply.github.com>
Fri, 29 Mar 2019 13:24:17 +0000 (14:24 +0100)
committerdegomme <adegomme@users.noreply.github.com>
Fri, 29 Mar 2019 13:24:17 +0000 (14:24 +0100)
include/smpi/smpi.h
src/smpi/bindings/smpi_mpi.cpp
src/smpi/bindings/smpi_pmpi_coll.cpp
src/smpi/colls/alltoallv/alltoallv-ompi-basic-linear.cpp
src/smpi/colls/smpi_coll.cpp
src/smpi/colls/smpi_default_selector.cpp
src/smpi/colls/smpi_nbc_impl.cpp
src/smpi/include/smpi_coll.hpp

index 4af1b13..ad51171 100644 (file)
@@ -600,7 +600,9 @@ MPI_CALL(XBT_PUBLIC int, MPI_Ialltoall, (void* sendbuf, int sendcount, MPI_Datat
 MPI_CALL(XBT_PUBLIC int, MPI_Ialltoallv,
          (void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype sendtype, void* recvbuf, int* recvcounts,
           int* recvdisps, MPI_Datatype recvtype, MPI_Comm comm, MPI_Request *request));
-          
+MPI_CALL(XBT_PUBLIC int, MPI_Ialltoallw,
+         (void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype* sendtypes, void* recvbuf, int* recvcounts,
+          int* recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm, MPI_Request *request));
 MPI_CALL(XBT_PUBLIC int, MPI_Gather, (void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf, int recvcount,
                                       MPI_Datatype recvtype, int root, MPI_Comm comm));
 MPI_CALL(XBT_PUBLIC int, MPI_Gatherv, (void* sendbuf, int sendcount, MPI_Datatype sendtype, void* recvbuf,
index efe7ed9..ceb1b89 100644 (file)
@@ -80,6 +80,7 @@ WRAPPED_PMPI_CALL(int,MPI_Alloc_mem,(MPI_Aint size, MPI_Info info, void *baseptr
 WRAPPED_PMPI_CALL(int,MPI_Allreduce,(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm),(sendbuf, recvbuf, count, datatype, op, comm))
 WRAPPED_PMPI_CALL(int,MPI_Alltoall,(void *sendbuf, int sendcount, MPI_Datatype sendtype,void *recvbuf, int recvcount,MPI_Datatype recvtype, MPI_Comm comm),(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm))
 WRAPPED_PMPI_CALL(int,MPI_Alltoallv,(void *sendbuf, int *sendcounts, int *senddisps, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *recvdisps, MPI_Datatype recvtype, MPI_Comm comm),(sendbuf, sendcounts, senddisps, sendtype, recvbuf, recvcounts, recvdisps, recvtype, comm))
+WRAPPED_PMPI_CALL(int,MPI_Alltoallw,( void *sendbuf, int *sendcnts, int *sdispls, MPI_Datatype *sendtypes, void *recvbuf, int *recvcnts, int *rdispls, MPI_Datatype *recvtypes, MPI_Comm comm),( sendbuf, sendcnts, sdispls, sendtypes, recvbuf, recvcnts, rdispls, recvtypes, comm))
 WRAPPED_PMPI_CALL(int,MPI_Attr_delete,(MPI_Comm comm, int keyval) ,(comm, keyval))
 WRAPPED_PMPI_CALL(int,MPI_Attr_get,(MPI_Comm comm, int keyval, void* attr_value, int* flag) ,(comm, keyval, attr_value, flag))
 WRAPPED_PMPI_CALL(int,MPI_Attr_put,(MPI_Comm comm, int keyval, void* attr_value) ,(comm, keyval, attr_value))
@@ -148,6 +149,7 @@ WRAPPED_PMPI_CALL(int,MPI_Iallgatherv,(void *sendbuf, int sendcount, MPI_Datatyp
 WRAPPED_PMPI_CALL(int,MPI_Iallreduce,(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm, MPI_Request *request),(sendbuf, recvbuf, count, datatype, op, comm, request))
 WRAPPED_PMPI_CALL(int,MPI_Ialltoall,(void *sendbuf, int sendcount, MPI_Datatype sendtype,void *recvbuf, int recvcount,MPI_Datatype recvtype, MPI_Comm comm, MPI_Request *request),(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm, request))
 WRAPPED_PMPI_CALL(int,MPI_Ialltoallv,(void *sendbuf, int *sendcounts, int *senddisps, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *recvdisps, MPI_Datatype recvtype, MPI_Comm comm, MPI_Request *request),(sendbuf, sendcounts, senddisps, sendtype, recvbuf, recvcounts, recvdisps, recvtype, comm, request))
+WRAPPED_PMPI_CALL(int,MPI_Ialltoallw,( void *sendbuf, int *sendcnts, int *sdispls, MPI_Datatype *sendtypes, void *recvbuf, int *recvcnts, int *rdispls, MPI_Datatype *recvtypes, MPI_Comm comm, MPI_Request *request),( sendbuf, sendcnts, sdispls, sendtypes, recvbuf, recvcnts, rdispls, recvtypes, comm, request))
 WRAPPED_PMPI_CALL(int,MPI_Igather,(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm, MPI_Request *request),(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, request))
 WRAPPED_PMPI_CALL(int,MPI_Igatherv,(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *displs,MPI_Datatype recvtype, int root, MPI_Comm comm, MPI_Request *request),(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, request))
 //WRAPPED_PMPI_CALL(int,MPI_Ireduce_scatter_block,(void *sendbuf, void *recvbuf, int recvcount, MPI_Datatype datatype, MPI_Op op,MPI_Comm comm, MPI_Request *request),(sendbuf, recvbuf, recvcount, datatype, op, comm, request))
@@ -310,7 +312,6 @@ WRAPPED_PMPI_CALL(int,MPI_Status_set_elements,( MPI_Status *status, MPI_Datatype
 UNIMPLEMENTED_WRAPPED_PMPI_CALL(int,MPI_Add_error_class,( int *errorclass),( errorclass))
 UNIMPLEMENTED_WRAPPED_PMPI_CALL(int,MPI_Add_error_code,(int errorclass, int *errorcode),(errorclass, errorcode))
 UNIMPLEMENTED_WRAPPED_PMPI_CALL(int,MPI_Add_error_string,( int errorcode, char *string),(errorcode, string))
-UNIMPLEMENTED_WRAPPED_PMPI_CALL(int,MPI_Alltoallw,( void *sendbuf, int *sendcnts, int *sdispls, MPI_Datatype *sendtypes, void *recvbuf, int *recvcnts, int *rdispls, MPI_Datatype *recvtypes, MPI_Comm comm),( sendbuf, sendcnts, sdispls, sendtypes, recvbuf, recvcnts, rdispls, recvtypes, comm))
 UNIMPLEMENTED_WRAPPED_PMPI_CALL(int,MPI_Bsend_init,(void* buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Request* request),(buf, count, datatype, dest, tag, comm, request))
 UNIMPLEMENTED_WRAPPED_PMPI_CALL(int,MPI_Bsend,(void* buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm) ,(buf, count, datatype, dest, tag, comm))
 UNIMPLEMENTED_WRAPPED_PMPI_CALL(int,MPI_Buffer_attach,(void* buffer, int size) ,(buffer, size))
index 91bbfd8..9e3f6ea 100644 (file)
@@ -774,3 +774,88 @@ int PMPI_Ialltoallv(void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype
   smpi_bench_begin();
   return retval;
 }
+
+int PMPI_Alltoallw(void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype* sendtypes, void* recvbuf,
+                   int* recvcounts, int* recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm)
+{
+  return PMPI_Ialltoallw(sendbuf, sendcounts, senddisps, sendtypes, recvbuf, recvcounts, recvdisps, recvtypes, comm, MPI_REQUEST_IGNORED);
+}
+
+int PMPI_Ialltoallw(void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype* sendtypes, void* recvbuf,
+                   int* recvcounts, int* recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm, MPI_Request *request)
+{
+  int retval = 0;
+
+  smpi_bench_end();
+
+  if (comm == MPI_COMM_NULL) {
+    retval = MPI_ERR_COMM;
+  } else if ((sendbuf != MPI_IN_PLACE && sendtypes == nullptr)  || recvtypes == nullptr) {
+    retval = MPI_ERR_TYPE;
+  } else if ((sendbuf != MPI_IN_PLACE && (sendcounts == nullptr || senddisps == nullptr)) || recvcounts == nullptr ||
+             recvdisps == nullptr) {
+    retval = MPI_ERR_ARG;
+  } else if (request == nullptr){
+    retval = MPI_ERR_ARG;
+  }  else {
+    int rank                           = simgrid::s4u::this_actor::get_pid();
+    int size                           = comm->size();
+    int send_size                      = 0;
+    int recv_size                      = 0;
+    std::vector<int>* trace_sendcounts = new std::vector<int>;
+    std::vector<int>* trace_recvcounts = new std::vector<int>;
+
+    void* sendtmpbuf           = static_cast<char*>(sendbuf);
+    int* sendtmpcounts         = sendcounts;
+    int* sendtmpdisps          = senddisps;
+    MPI_Datatype* sendtmptypes = sendtypes;
+    unsigned long maxsize                = 0;
+    for (int i = 0; i < size; i++) { // copy data to avoid bad free
+      if(recvtypes[i]==MPI_DATATYPE_NULL)
+        return MPI_ERR_TYPE;
+      recv_size += recvcounts[i] * recvtypes[i]->size();
+      trace_recvcounts->push_back(recvcounts[i] * recvtypes[i]->size());
+      if ((recvdisps[i] + (recvcounts[i] * recvtypes[i]->size())) > maxsize)
+        maxsize = recvdisps[i] + (recvcounts[i] * recvtypes[i]->size());
+    }
+
+    if (sendbuf == MPI_IN_PLACE) {
+      sendtmpbuf = static_cast<void*>(xbt_malloc(maxsize));
+      memcpy(sendtmpbuf, recvbuf, maxsize);
+      sendtmpcounts = static_cast<int*>(xbt_malloc(size * sizeof(int)));
+      memcpy(sendtmpcounts, recvcounts, size * sizeof(int));
+      sendtmpdisps = static_cast<int*>(xbt_malloc(size * sizeof(int)));
+      memcpy(sendtmpdisps, recvdisps, size * sizeof(int));
+      sendtmptypes = static_cast<MPI_Datatype*>(xbt_malloc(size * sizeof(MPI_Datatype)));
+      memcpy(sendtmptypes, recvtypes, size * sizeof(MPI_Datatype));
+    }
+
+    for (int i = 0; i < size; i++) { // copy data to avoid bad free
+      send_size += sendtmpcounts[i] * sendtmptypes[i]->size();
+      trace_sendcounts->push_back(sendtmpcounts[i] * sendtmptypes[i]->size());
+    }
+
+    TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED?"PMPI_Alltoallw":"PMPI_Ialltoallw",
+                       new simgrid::instr::VarCollTIData(request==MPI_REQUEST_IGNORED ? "alltoallv":"ialltoallv", -1, send_size, trace_sendcounts, recv_size,
+                                                         trace_recvcounts, simgrid::smpi::Datatype::encode(sendtmptypes[0]),
+                                                         simgrid::smpi::Datatype::encode(recvtypes[0])));
+
+    if(request == MPI_REQUEST_IGNORED)
+      retval = simgrid::smpi::Colls::alltoallw(sendtmpbuf, sendtmpcounts, sendtmpdisps, sendtmptypes, recvbuf, recvcounts,
+                                    recvdisps, recvtypes, comm);
+    else
+      retval = simgrid::smpi::Colls::ialltoallw(sendtmpbuf, sendtmpcounts, sendtmpdisps, sendtmptypes, recvbuf, recvcounts,
+                                    recvdisps, recvtypes, comm, request);
+    TRACE_smpi_comm_out(rank);
+
+    if (sendbuf == MPI_IN_PLACE) {
+      xbt_free(sendtmpbuf);
+      xbt_free(sendtmpcounts);
+      xbt_free(sendtmpdisps);
+      xbt_free(sendtmptypes);
+    }
+  }
+
+  smpi_bench_begin();
+  return retval;
+}
index a665472..501d15e 100644 (file)
@@ -56,7 +56,7 @@ Coll_alltoallv_ompi_basic_linear::alltoallv(void *sbuf, int *scounts, int *sdisp
 
     /* Post all receives first */
     for (i = 0; i < size; ++i) {
-        if (i == rank || 0 == rcounts[i]) {
+        if (i == rank) {
             continue;
         }
 
@@ -72,7 +72,7 @@ Coll_alltoallv_ompi_basic_linear::alltoallv(void *sbuf, int *scounts, int *sdisp
 
     /* Now post all sends */
     for (i = 0; i < size; ++i) {
-        if (i == rank || 0 == scounts[i]) {
+        if (i == rank) {
             continue;
         }
 
index 4662d01..36eca5f 100644 (file)
@@ -282,5 +282,23 @@ int Colls::exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype
   return MPI_SUCCESS;
 }
 
+int Colls::alltoallw(void *sendbuf, int *sendcounts, int *senddisps, MPI_Datatype* sendtypes,
+                              void *recvbuf, int *recvcounts, int *recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm)
+{
+  MPI_Request request;
+  int err = Colls::ialltoallw(sendbuf, sendcounts, senddisps, sendtypes, recvbuf, recvcounts, recvdisps, recvtypes, comm, &request);
+  MPI_Request* requests = request->get_nbc_requests();
+  int count = request->get_nbc_requests_size();
+  XBT_DEBUG("<%d> wait for %d requests", comm->rank(), count);
+  Request::waitall(count, requests, MPI_STATUS_IGNORE);
+  for (int i = 0; i < count; i++) {
+    if(requests[i]!=MPI_REQUEST_NULL)
+      Request::unref(&requests[i]);
+  }
+  delete[] requests;
+  Request::unref(&request);
+  return err;
+}
+
 }
 }
index 0b5198d..f2751b9 100644 (file)
@@ -197,8 +197,6 @@ int Coll_alltoall_default::alltoall( void *sbuf, int scount, MPI_Datatype sdtype
   return Coll_alltoall_ompi::alltoall(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm);
 }
 
-
-
 int Coll_alltoallv_default::alltoallv(void *sendbuf, int *sendcounts, int *senddisps, MPI_Datatype sendtype,
                               void *recvbuf, int *recvcounts, int *recvdisps, MPI_Datatype recvtype, MPI_Comm comm)
 {
index f1db5ea..6da0fd9 100644 (file)
@@ -256,7 +256,7 @@ int Colls::ialltoallv(void *sendbuf, int *sendcounts, int *senddisps, MPI_Dataty
     int count = 0;
     /* Create all receives that will be posted first */
     for (int i = 0; i < size; ++i) {
-      if (i != rank && recvcounts[i] != 0) {
+      if (i != rank) {
         requests[count] = Request::irecv_init(static_cast<char *>(recvbuf) + recvdisps[i] * recvext,
                                           recvcounts[i], recvtype, i, system_tag, comm);
         count++;
@@ -266,7 +266,7 @@ int Colls::ialltoallv(void *sendbuf, int *sendcounts, int *senddisps, MPI_Dataty
     }
     /* Now create all sends  */
     for (int i = 0; i < size; ++i) {
-      if (i != rank && sendcounts[i] != 0) {
+      if (i != rank) {
       requests[count] = Request::isend_init(static_cast<char *>(sendbuf) + senddisps[i] * sendext,
                                         sendcounts[i], sendtype, i, system_tag, comm);
       count++;
@@ -281,6 +281,50 @@ int Colls::ialltoallv(void *sendbuf, int *sendcounts, int *senddisps, MPI_Dataty
   return err;
 }
 
+int Colls::ialltoallw(void *sendbuf, int *sendcounts, int *senddisps, MPI_Datatype* sendtypes,
+                              void *recvbuf, int *recvcounts, int *recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm, MPI_Request *request){
+  const int system_tag = COLL_TAG_ALLTOALLV;
+  MPI_Request *requests;
+
+  /* Initialize. */
+  int rank = comm->rank();
+  int size = comm->size();
+  (*request) = new Request( nullptr, 0, MPI_BYTE,
+                         rank,rank, COLL_TAG_ALLTOALLV, comm, MPI_REQ_PERSISTENT);
+  /* Local copy from self */
+  int err = (sendcounts[rank]>0 && recvcounts[rank]) ? Datatype::copy(static_cast<char *>(sendbuf) + senddisps[rank], sendcounts[rank], sendtypes[rank],
+                               static_cast<char *>(recvbuf) + recvdisps[rank], recvcounts[rank], recvtypes[rank]): MPI_SUCCESS;
+  if (err == MPI_SUCCESS && size > 1) {
+    /* Initiate all send/recv to/from others. */
+    requests = new MPI_Request[2 * (size - 1)];
+    int count = 0;
+    /* Create all receives that will be posted first */
+    for (int i = 0; i < size; ++i) {
+      if (i != rank) {
+        requests[count] = Request::irecv_init(static_cast<char *>(recvbuf) + recvdisps[i],
+                                          recvcounts[i], recvtypes[i], i, system_tag, comm);
+        count++;
+      }else{
+        XBT_DEBUG("<%d> skip request creation [src = %d, recvcounts[src] = %d]", rank, i, recvcounts[i]);
+      }
+    }
+    /* Now create all sends  */
+    for (int i = 0; i < size; ++i) {
+      if (i != rank) {
+      requests[count] = Request::isend_init(static_cast<char *>(sendbuf) + senddisps[i] ,
+                                        sendcounts[i], sendtypes[i], i, system_tag, comm);
+      count++;
+      }else{
+        XBT_DEBUG("<%d> skip request creation [dst = %d, sendcounts[dst] = %d]", rank, i, sendcounts[i]);
+      }
+    }
+    /* Wait for them all. */
+    Request::startall(count, requests);
+    (*request)->set_nbc_requests(requests, count);
+  }
+  return err;
+}
+
 int Colls::igather(void *sendbuf, int sendcount, MPI_Datatype sendtype,
                      void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm, MPI_Request *request)
 {
index 9597fc3..e4c5df3 100644 (file)
@@ -115,7 +115,10 @@ public:
                       MPI_Datatype recvtype, int root, MPI_Comm comm);
   static int scan(void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm);
   static int exscan(void* sendbuf, void* recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm);
-  
+  static int alltoallw
+         (void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype* sendtypes, void* recvbuf, int* recvcounts,
+          int* recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm);
+
   //async collectives
   static int ibarrier(MPI_Comm comm, MPI_Request* request);
   static int ibcast(void *buf, int count, MPI_Datatype datatype, 
@@ -149,6 +152,9 @@ public:
   static int ialltoallv
          (void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype sendtype, void* recvbuf, int* recvcounts,
           int* recvdisps, MPI_Datatype recvtype, MPI_Comm comm, MPI_Request *request);
+  static int ialltoallw
+         (void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype* sendtypes, void* recvbuf, int* recvcounts,
+          int* recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm, MPI_Request *request);
 
 
   static void (*smpi_coll_cleanup_callback)();