Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add MPI_Alltoallw and MPI_Ialltoallw
[simgrid.git] / src / smpi / colls / smpi_nbc_impl.cpp
index f1db5ea..6da0fd9 100644 (file)
@@ -256,7 +256,7 @@ int Colls::ialltoallv(void *sendbuf, int *sendcounts, int *senddisps, MPI_Dataty
     int count = 0;
     /* Create all receives that will be posted first */
     for (int i = 0; i < size; ++i) {
-      if (i != rank && recvcounts[i] != 0) {
+      if (i != rank) {
         requests[count] = Request::irecv_init(static_cast<char *>(recvbuf) + recvdisps[i] * recvext,
                                           recvcounts[i], recvtype, i, system_tag, comm);
         count++;
@@ -266,7 +266,7 @@ int Colls::ialltoallv(void *sendbuf, int *sendcounts, int *senddisps, MPI_Dataty
     }
     /* Now create all sends  */
     for (int i = 0; i < size; ++i) {
-      if (i != rank && sendcounts[i] != 0) {
+      if (i != rank) {
       requests[count] = Request::isend_init(static_cast<char *>(sendbuf) + senddisps[i] * sendext,
                                         sendcounts[i], sendtype, i, system_tag, comm);
       count++;
@@ -281,6 +281,50 @@ int Colls::ialltoallv(void *sendbuf, int *sendcounts, int *senddisps, MPI_Dataty
   return err;
 }
 
+int Colls::ialltoallw(void *sendbuf, int *sendcounts, int *senddisps, MPI_Datatype* sendtypes,
+                              void *recvbuf, int *recvcounts, int *recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm, MPI_Request *request){
+  const int system_tag = COLL_TAG_ALLTOALLV;
+  MPI_Request *requests;
+
+  /* Initialize. */
+  int rank = comm->rank();
+  int size = comm->size();
+  (*request) = new Request( nullptr, 0, MPI_BYTE,
+                         rank,rank, COLL_TAG_ALLTOALLV, comm, MPI_REQ_PERSISTENT);
+  /* Local copy from self */
+  int err = (sendcounts[rank]>0 && recvcounts[rank]) ? Datatype::copy(static_cast<char *>(sendbuf) + senddisps[rank], sendcounts[rank], sendtypes[rank],
+                               static_cast<char *>(recvbuf) + recvdisps[rank], recvcounts[rank], recvtypes[rank]): MPI_SUCCESS;
+  if (err == MPI_SUCCESS && size > 1) {
+    /* Initiate all send/recv to/from others. */
+    requests = new MPI_Request[2 * (size - 1)];
+    int count = 0;
+    /* Create all receives that will be posted first */
+    for (int i = 0; i < size; ++i) {
+      if (i != rank) {
+        requests[count] = Request::irecv_init(static_cast<char *>(recvbuf) + recvdisps[i],
+                                          recvcounts[i], recvtypes[i], i, system_tag, comm);
+        count++;
+      }else{
+        XBT_DEBUG("<%d> skip request creation [src = %d, recvcounts[src] = %d]", rank, i, recvcounts[i]);
+      }
+    }
+    /* Now create all sends  */
+    for (int i = 0; i < size; ++i) {
+      if (i != rank) {
+      requests[count] = Request::isend_init(static_cast<char *>(sendbuf) + senddisps[i] ,
+                                        sendcounts[i], sendtypes[i], i, system_tag, comm);
+      count++;
+      }else{
+        XBT_DEBUG("<%d> skip request creation [dst = %d, sendcounts[dst] = %d]", rank, i, sendcounts[i]);
+      }
+    }
+    /* Wait for them all. */
+    Request::startall(count, requests);
+    (*request)->set_nbc_requests(requests, count);
+  }
+  return err;
+}
+
 int Colls::igather(void *sendbuf, int sendcount, MPI_Datatype sendtype,
                      void *recvbuf, int recvcount, MPI_Datatype recvtype, int root, MPI_Comm comm, MPI_Request *request)
 {