Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add MPI_Ireduce
authorAugustin Degomme <adegomme@users.noreply.github.com>
Mon, 1 Apr 2019 19:08:23 +0000 (21:08 +0200)
committerAugustin Degomme <adegomme@users.noreply.github.com>
Mon, 1 Apr 2019 22:46:45 +0000 (00:46 +0200)
src/smpi/bindings/smpi_mpi.cpp
src/smpi/bindings/smpi_pmpi_coll.cpp
src/smpi/colls/smpi_default_selector.cpp
src/smpi/colls/smpi_nbc_impl.cpp
src/smpi/mpi/smpi_request.cpp
teshsuite/smpi/mpich3-test/coll/nonblocking.c
teshsuite/smpi/mpich3-test/coll/nonblocking2.c

index ceb1b89..f3a1f7e 100644 (file)
@@ -154,7 +154,7 @@ WRAPPED_PMPI_CALL(int,MPI_Igather,(void *sendbuf, int sendcount, MPI_Datatype se
 WRAPPED_PMPI_CALL(int,MPI_Igatherv,(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int *recvcounts, int *displs,MPI_Datatype recvtype, int root, MPI_Comm comm, MPI_Request *request),(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, root, comm, request))
 //WRAPPED_PMPI_CALL(int,MPI_Ireduce_scatter_block,(void *sendbuf, void *recvbuf, int recvcount, MPI_Datatype datatype, MPI_Op op,MPI_Comm comm, MPI_Request *request),(sendbuf, recvbuf, recvcount, datatype, op, comm, request))
 //WRAPPED_PMPI_CALL(int,MPI_Ireduce_scatter,(void *sendbuf, void *recvbuf, int *recvcounts, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm, MPI_Request *request),(sendbuf, recvbuf, recvcounts, datatype, op, comm, request))
-//WRAPPED_PMPI_CALL(int,MPI_Ireduce,(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm, MPI_Request *request),(sendbuf, recvbuf, count, datatype, op, root, comm, request))
+WRAPPED_PMPI_CALL(int,MPI_Ireduce,(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm, MPI_Request *request),(sendbuf, recvbuf, count, datatype, op, root, comm, request))
 //WRAPPED_PMPI_CALL(int,MPI_Iexscan,(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm, MPI_Request *request),(sendbuf, recvbuf, count, datatype, op, comm, request))
 //WRAPPED_PMPI_CALL(int,MPI_Iscan,(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm, MPI_Request *request),(sendbuf, recvbuf, count, datatype, op, comm, request))
 WRAPPED_PMPI_CALL(int,MPI_Iscatter,(void *sendbuf, int sendcount, MPI_Datatype sendtype,void *recvbuf, int recvcount, MPI_Datatype recvtype,int root, MPI_Comm comm, MPI_Request *request),(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm, request))
index 5f71a34..bd0ac32 100644 (file)
@@ -389,6 +389,11 @@ int PMPI_Iscatterv(void *sendbuf, int *sendcounts, int *displs,
 }
 
 int PMPI_Reduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm)
+{
+  return PMPI_Ireduce(sendbuf, recvbuf, count, datatype, op, root, comm, MPI_REQUEST_IGNORED);
+}
+
+int PMPI_Ireduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm, MPI_Request* request)
 {
   int retval = 0;
 
@@ -401,12 +406,15 @@ int PMPI_Reduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
   } else {
     int rank = simgrid::s4u::this_actor::get_pid();
 
-    TRACE_smpi_comm_in(rank, __func__,
-                       new simgrid::instr::CollTIData("reduce", root, 0,
+    TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED ? "PMPI_Reduce":"PMPI_Ireduce",
+                       new simgrid::instr::CollTIData(request==MPI_REQUEST_IGNORED ? "reduce":"ireduce", root, 0,
                                                       datatype->is_replayable() ? count : count * datatype->size(), -1,
                                                       simgrid::smpi::Datatype::encode(datatype), ""));
+    if(request == MPI_REQUEST_IGNORED)
+      simgrid::smpi::Colls::reduce(sendbuf, recvbuf, count, datatype, op, root, comm);
+    else
+      simgrid::smpi::Colls::ireduce(sendbuf, recvbuf, count, datatype, op, root, comm, request);
 
-    simgrid::smpi::Colls::reduce(sendbuf, recvbuf, count, datatype, op, root, comm);
 
     retval = MPI_SUCCESS;
     TRACE_smpi_comm_out(rank);
index dca9430..1aebcf6 100644 (file)
@@ -89,70 +89,13 @@ int Coll_scatter_default::scatter(void *sendbuf, int sendcount, MPI_Datatype sen
 int Coll_reduce_default::reduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root,
                      MPI_Comm comm)
 {
-  const int system_tag = COLL_TAG_REDUCE;
-  MPI_Aint lb = 0;
-  MPI_Aint dataext = 0;
-
-  char* sendtmpbuf = static_cast<char *>(sendbuf);
-
-  int rank = comm->rank();
-  int size = comm->size();
-  if (size <= 0)
-    return MPI_ERR_COMM;
   //non commutative case, use a working algo from openmpi
   if (op != MPI_OP_NULL && not op->is_commutative()) {
-    return Coll_reduce_ompi_basic_linear::reduce(sendtmpbuf, recvbuf, count, datatype, op, root, comm);
+    return Coll_reduce_ompi_basic_linear::reduce(sendbuf, recvbuf, count, datatype, op, root, comm);
   }
-
-  if( sendbuf == MPI_IN_PLACE ) {
-    sendtmpbuf = static_cast<char *>(smpi_get_tmp_sendbuffer(count*datatype->get_extent()));
-    Datatype::copy(recvbuf, count, datatype,sendtmpbuf, count, datatype);
-  }
-
-  if(rank != root) {
-    // Send buffer to root
-    Request::send(sendtmpbuf, count, datatype, root, system_tag, comm);
-  } else {
-    datatype->extent(&lb, &dataext);
-    // Local copy from root
-    if (sendtmpbuf != nullptr && recvbuf != nullptr)
-      Datatype::copy(sendtmpbuf, count, datatype, recvbuf, count, datatype);
-    // Receive buffers from senders
-    MPI_Request *requests = xbt_new(MPI_Request, size - 1);
-    void **tmpbufs = xbt_new(void *, size - 1);
-    int index = 0;
-    for (int src = 0; src < size; src++) {
-      if (src != root) {
-        tmpbufs[index] = smpi_get_tmp_sendbuffer(count * dataext);
-        requests[index] =
-          Request::irecv_init(tmpbufs[index], count, datatype, src, system_tag, comm);
-        index++;
-      }
-    }
-    // Wait for completion of irecv's.
-    Request::startall(size - 1, requests);
-    for (int src = 0; src < size - 1; src++) {
-      index = Request::waitany(size - 1, requests, MPI_STATUS_IGNORE);
-      XBT_DEBUG("finished waiting any request with index %d", index);
-      if(index == MPI_UNDEFINED) {
-        break;
-      }else{
-        Request::unref(&requests[index]);
-      }
-      if(op) /* op can be MPI_OP_NULL that does nothing */
-        if(op!=MPI_OP_NULL) op->apply( tmpbufs[index], recvbuf, &count, datatype);
-    }
-      for(index = 0; index < size - 1; index++) {
-        smpi_free_tmp_buffer(tmpbufs[index]);
-      }
-    xbt_free(tmpbufs);
-    xbt_free(requests);
-
-  }
-  if( sendbuf == MPI_IN_PLACE ) {
-    smpi_free_tmp_buffer(sendtmpbuf);
-  }
-  return MPI_SUCCESS;
+  MPI_Request request;
+  Colls::ireduce(sendbuf, recvbuf, count, datatype, op, root, comm, &request);
+  return Request::wait(&request, MPI_STATUS_IGNORE);
 }
 
 int Coll_allreduce_default::allreduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
index 6da0fd9..0cb25f0 100644 (file)
@@ -442,5 +442,64 @@ int Colls::iscatterv(void *sendbuf, int *sendcounts, int *displs, MPI_Datatype s
   }
   return MPI_SUCCESS;
 }
+
+int Colls::ireduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root,
+                     MPI_Comm comm, MPI_Request* request)
+{
+  const int system_tag = COLL_TAG_REDUCE;
+  MPI_Aint lb = 0;
+  MPI_Aint dataext = 0;
+  MPI_Request* requests;
+
+  char* sendtmpbuf = static_cast<char *>(sendbuf);
+
+  int rank = comm->rank();
+  int size = comm->size();
+
+  if (size <= 0)
+    return MPI_ERR_COMM;
+
+  if( sendbuf == MPI_IN_PLACE ) {
+    sendtmpbuf = static_cast<char *>(smpi_get_tmp_sendbuffer(count*datatype->get_extent()));
+    Datatype::copy(recvbuf, count, datatype,sendtmpbuf, count, datatype);
+  }
+
+  if(rank == root){
+    (*request) =  new Request( recvbuf, count, datatype,
+                         rank,rank, COLL_TAG_REDUCE, comm, MPI_REQ_PERSISTENT, op);
+  }
+  else
+    (*request) = new Request( nullptr, count, datatype,
+                         rank,rank, COLL_TAG_REDUCE, comm, MPI_REQ_PERSISTENT);
+
+  if(rank != root) {
+    // Send buffer to root
+    requests = new MPI_Request[1];
+    requests[0]=Request::isend(sendtmpbuf, count, datatype, root, system_tag, comm);
+    (*request)->set_nbc_requests(requests, 1);
+  } else {
+    datatype->extent(&lb, &dataext);
+    // Local copy from root
+    if (sendtmpbuf != nullptr && recvbuf != nullptr)
+      Datatype::copy(sendtmpbuf, count, datatype, recvbuf, count, datatype);
+    // Receive buffers from senders
+    MPI_Request *requests = new MPI_Request[size - 1];
+    int index = 0;
+    for (int src = 0; src < size; src++) {
+      if (src != root) {
+        requests[index] =
+          Request::irecv_init(smpi_get_tmp_sendbuffer(count * dataext), count, datatype, src, system_tag, comm);
+        index++;
+      }
+    }
+    // Wait for completion of irecv's.
+    Request::startall(size - 1, requests);
+    (*request)->set_nbc_requests(requests, size - 1);
+  }    
+  if( sendbuf == MPI_IN_PLACE ) {
+    smpi_free_tmp_buffer(sendtmpbuf);
+  }
+  return MPI_SUCCESS;
+}
 }
 }
index 0762073..5cf3a49 100644 (file)
@@ -867,6 +867,16 @@ int Request::wait(MPI_Request * request, MPI_Status * status)
   if ((*request)->nbc_requests_size_>0){
     ret = waitall((*request)->nbc_requests_size_, (*request)->nbc_requests_, MPI_STATUSES_IGNORE);
     for (int i = 0; i < (*request)->nbc_requests_size_; i++) {
+      if((*request)->buf_!=nullptr && (*request)->nbc_requests_[i]!=MPI_REQUEST_NULL){//reduce case
+        void * buf=(*request)->nbc_requests_[i]->buf_;
+        if((*request)->old_type_->flags() & DT_FLAG_DERIVED)
+          buf=(*request)->nbc_requests_[i]->old_buf_;
+        if((*request)->op_!=MPI_OP_NULL){
+          int count=(*request)->size_/ (*request)->old_type_->size();
+          (*request)->op_->apply(buf, (*request)->buf_, &count, (*request)->old_type_);
+        }
+        smpi_free_tmp_buffer(buf);
+      }
       if((*request)->nbc_requests_[i]!=MPI_REQUEST_NULL)
         Request::unref(&((*request)->nbc_requests_[i]));
     }
index d3b6646..0038da1 100644 (file)
@@ -153,14 +153,14 @@ int main(int argc, char **argv)
     MPI_Ialltoallw(MPI_IN_PLACE, NULL, NULL, NULL, rbuf, rcounts, rdispls, types, comm, &req);
     MPI_Wait(&req, MPI_STATUS_IGNORE);
 
-/*    MPI_Ireduce(sbuf, rbuf, NUM_INTS, MPI_INT, MPI_SUM, 0, comm, &req);*/
-/*    MPI_Wait(&req, MPI_STATUS_IGNORE);*/
+    MPI_Ireduce(sbuf, rbuf, NUM_INTS, MPI_INT, MPI_SUM, 0, comm, &req);
+    MPI_Wait(&req, MPI_STATUS_IGNORE);
 
-/*    if (0 == rank)*/
-/*        MPI_Ireduce(MPI_IN_PLACE, rbuf, NUM_INTS, MPI_INT, MPI_SUM, 0, comm, &req);*/
-/*    else*/
-/*        MPI_Ireduce(sbuf, rbuf, NUM_INTS, MPI_INT, MPI_SUM, 0, comm, &req);*/
-/*    MPI_Wait(&req, MPI_STATUS_IGNORE);*/
+    if (0 == rank)
+        MPI_Ireduce(MPI_IN_PLACE, rbuf, NUM_INTS, MPI_INT, MPI_SUM, 0, comm, &req);
+    else
+        MPI_Ireduce(sbuf, rbuf, NUM_INTS, MPI_INT, MPI_SUM, 0, comm, &req);
+    MPI_Wait(&req, MPI_STATUS_IGNORE);
 
 /*    MPI_Iallreduce(sbuf, rbuf, NUM_INTS, MPI_INT, MPI_SUM, comm, &req);*/
 /*    MPI_Wait(&req, MPI_STATUS_IGNORE);*/
index c4daa76..72d774f 100644 (file)
@@ -108,42 +108,41 @@ int main(int argc, char **argv)
     MPI_Wait(&req, MPI_STATUS_IGNORE);
 
     /* MPI_Ireduce */
-/*    for (i = 0; i < COUNT; ++i) {*/
-/*        buf[i] = rank + i;*/
-/*        recvbuf[i] = 0xdeadbeef;*/
-/*    }*/
-/*    MPI_Ireduce(buf, recvbuf, COUNT, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD, &req);*/
-/*    MPI_Wait(&req, MPI_STATUS_IGNORE);*/
-/*    if (rank == 0) {*/
-/*        for (i = 0; i < COUNT; ++i) {*/
-/*            if (recvbuf[i] != ((size * (size - 1) / 2) + (i * size)))*/
-/*                printf("got recvbuf[%d]=%d, expected %d\n", i, recvbuf[i],*/
-/*                       ((size * (size - 1) / 2) + (i * size)));*/
-/*            my_assert(recvbuf[i] == ((size * (size - 1) / 2) + (i * size)));*/
-/*        }*/
-/*    }*/
+    for (i = 0; i < COUNT; ++i) {
+        buf[i] = rank + i;
+        recvbuf[i] = 0xdeadbeef;
+    }
+    MPI_Ireduce(buf, recvbuf, COUNT, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD, &req);
+    MPI_Wait(&req, MPI_STATUS_IGNORE);
+    if (rank == 0) {
+        for (i = 0; i < COUNT; ++i) {
+            if (recvbuf[i] != ((size * (size - 1) / 2) + (i * size)))
+                printf("aa got recvbuf[%d]=%d, expected %d\n", i, recvbuf[i],
+                       ((size * (size - 1) / 2) + (i * size)));
+            my_assert(recvbuf[i] == ((size * (size - 1) / 2) + (i * size)));
+        }
+    }
 
     /* same again, use a user op and free it before the wait */
-/*    {*/
-/*        MPI_Op op = MPI_OP_NULL;*/
-/*        MPI_Op_create(sum_fn, 1, &op);*/
-
-/*        for (i = 0; i < COUNT; ++i) {*/
-/*            buf[i] = rank + i;*/
-/*            recvbuf[i] = 0xdeadbeef;*/
-/*        }*/
-/*        MPI_Ireduce(buf, recvbuf, COUNT, MPI_INT, op, 0, MPI_COMM_WORLD, &req);*/
-/*        MPI_Op_free(&op);*/
-/*        MPI_Wait(&req, MPI_STATUS_IGNORE);*/
-/*        if (rank == 0) {*/
-/*            for (i = 0; i < COUNT; ++i) {*/
-/*                if (recvbuf[i] != ((size * (size - 1) / 2) + (i * size)))*/
-/*                    printf("got recvbuf[%d]=%d, expected %d\n", i, recvbuf[i],*/
-/*                           ((size * (size - 1) / 2) + (i * size)));*/
-/*                my_assert(recvbuf[i] == ((size * (size - 1) / 2) + (i * size)));*/
-/*            }*/
-/*        }*/
-/*    }*/
+    {
+        MPI_Op op = MPI_OP_NULL;
+        MPI_Op_create(sum_fn, 1, &op);
+        for (i = 0; i < COUNT; ++i) {
+            buf[i] = rank + i;
+            recvbuf[i] = 0xdeadbeef;
+        }
+        MPI_Ireduce(buf, recvbuf, COUNT, MPI_INT, op, 0, MPI_COMM_WORLD, &req);
+        MPI_Op_free(&op);
+        MPI_Wait(&req, MPI_STATUS_IGNORE);
+        if (rank == 0) {
+            for (i = 0; i < COUNT; ++i) {
+                if (recvbuf[i] != ((size * (size - 1) / 2) + (i * size)))
+                    printf("got recvbuf[%d]=%d, expected %d\n", i, recvbuf[i],
+                           ((size * (size - 1) / 2) + (i * size)));
+                my_assert(recvbuf[i] == ((size * (size - 1) / 2) + (i * size)));
+            }
+        }
+    }
 
     /* MPI_Iallreduce */
 /*    for (i = 0; i < COUNT; ++i) {*/