Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
MPI_Iexscan, MPI_Iscan
[simgrid.git] / src / smpi / bindings / smpi_pmpi_coll.cpp
index 4d2e9e8..825a544 100644 (file)
@@ -35,7 +35,7 @@ int PMPI_Ibarrier(MPI_Comm comm, MPI_Request *request)
     retval = MPI_ERR_ARG;
   }else{
     int rank = simgrid::s4u::this_actor::get_pid();
-    TRACE_smpi_comm_in(rank, __func__, new simgrid::instr::NoOpTIData(request==MPI_REQUEST_IGNORED? "barrier" : "ibarrier"));
+    TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED? "PMPI_Barrier" : "PMPI_Ibarrier", new simgrid::instr::NoOpTIData(request==MPI_REQUEST_IGNORED? "barrier" : "ibarrier"));
     if(request==MPI_REQUEST_IGNORED){
       simgrid::smpi::Colls::barrier(comm);
       //Barrier can be used to synchronize RMA calls. Finish all requests from comm before.
@@ -61,7 +61,7 @@ int PMPI_Ibcast(void *buf, int count, MPI_Datatype datatype,
     retval = MPI_ERR_ARG;
   } else {
     int rank = simgrid::s4u::this_actor::get_pid();
-    TRACE_smpi_comm_in(rank, __func__,
+    TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED?"PMPI_Bcast":"PMPI_Ibcast",
                        new simgrid::instr::CollTIData(request==MPI_REQUEST_IGNORED?"bcast":"ibcast", root, -1.0,
                                                       datatype->is_replayable() ? count : count * datatype->size(), -1,
                                                       simgrid::smpi::Datatype::encode(datatype), ""));
@@ -70,6 +70,9 @@ int PMPI_Ibcast(void *buf, int count, MPI_Datatype datatype,
         simgrid::smpi::Colls::bcast(buf, count, datatype, root, comm);
       else
         simgrid::smpi::Colls::ibcast(buf, count, datatype, root, comm, request);
+    } else {
+      if(request!=MPI_REQUEST_IGNORED)
+        *request = MPI_REQUEST_NULL;
     }
     retval = MPI_SUCCESS;
 
@@ -112,7 +115,7 @@ int PMPI_Igather(void *sendbuf, int sendcount, MPI_Datatype sendtype,void *recvb
     int rank = simgrid::s4u::this_actor::get_pid();
 
     TRACE_smpi_comm_in(
-        rank, __func__,
+        rank, request==MPI_REQUEST_IGNORED?"PMPI_Gather":"PMPI_Igather",
         new simgrid::instr::CollTIData(
             request==MPI_REQUEST_IGNORED ? "gather":"igather", root, -1.0, sendtmptype->is_replayable() ? sendtmpcount : sendtmpcount * sendtmptype->size(),
             (comm->rank() != root || recvtype->is_replayable()) ? recvcount : recvcount * recvtype->size(),
@@ -171,7 +174,7 @@ int PMPI_Igatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *rec
         trace_recvcounts->push_back(recvcounts[i] * dt_size_recv);
     }
 
-    TRACE_smpi_comm_in(rank, __func__,
+    TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED?"PMPI_Gatherv":"PMPI_Igatherv",
                        new simgrid::instr::VarCollTIData(
                            request==MPI_REQUEST_IGNORED ? "gatherv":"igatherv", root,
                            sendtmptype->is_replayable() ? sendtmpcount : sendtmpcount * sendtmptype->size(), nullptr,
@@ -218,7 +221,7 @@ int PMPI_Iallgather(void *sendbuf, int sendcount, MPI_Datatype sendtype,
     }
     int rank = simgrid::s4u::this_actor::get_pid();
 
-    TRACE_smpi_comm_in(rank, __func__,
+    TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED?"PMPI_Allgather":"PMPI_Iallggather",
                        new simgrid::instr::CollTIData(
                            request==MPI_REQUEST_IGNORED ? "allgather" : "iallgather", -1, -1.0, sendtype->is_replayable() ? sendcount : sendcount * sendtype->size(),
                            recvtype->is_replayable() ? recvcount : recvcount * recvtype->size(),
@@ -269,7 +272,7 @@ int PMPI_Iallgatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype,
     for (int i = 0; i < comm->size(); i++) // copy data to avoid bad free
       trace_recvcounts->push_back(recvcounts[i] * dt_size_recv);
 
-    TRACE_smpi_comm_in(rank, __func__,
+    TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED?"PMPI_Allgatherv":"PMPI_Iallgatherv",
                        new simgrid::instr::VarCollTIData(
                            request==MPI_REQUEST_IGNORED ? "allgatherv" : "iallgatherv", -1, sendtype->is_replayable() ? sendcount : sendcount * sendtype->size(),
                            nullptr, dt_size_recv, trace_recvcounts, simgrid::smpi::Datatype::encode(sendtype),
@@ -306,7 +309,9 @@ int PMPI_Iscatter(void *sendbuf, int sendcount, MPI_Datatype sendtype,
   } else if ((sendbuf == recvbuf) ||
       ((comm->rank()==root) && sendcount>0 && (sendbuf == nullptr))){
     retval = MPI_ERR_BUFFER;
-  }else {
+  }else if (request == nullptr){
+    retval = MPI_ERR_ARG;
+  } else {
 
     if (recvbuf == MPI_IN_PLACE) {
       recvtype  = sendtype;
@@ -315,7 +320,7 @@ int PMPI_Iscatter(void *sendbuf, int sendcount, MPI_Datatype sendtype,
     int rank = simgrid::s4u::this_actor::get_pid();
 
     TRACE_smpi_comm_in(
-        rank, __func__,
+        rank, request==MPI_REQUEST_IGNORED?"PMPI_Scatter":"PMPI_Iscatter",
         new simgrid::instr::CollTIData(
             request==MPI_REQUEST_IGNORED ? "scatter" : "iscatter", root, -1.0,
             (comm->rank() != root || sendtype->is_replayable()) ? sendcount : sendcount * sendtype->size(),
@@ -368,7 +373,7 @@ int PMPI_Iscatterv(void *sendbuf, int *sendcounts, int *displs,
         trace_sendcounts->push_back(sendcounts[i] * dt_size_send);
     }
 
-    TRACE_smpi_comm_in(rank, __func__,
+    TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED?"PMPI_Scatterv":"PMPI_Iscatterv",
                        new simgrid::instr::VarCollTIData(
                            request==MPI_REQUEST_IGNORED ? "scatterv":"iscatterv", root, dt_size_send, trace_sendcounts,
                            recvtype->is_replayable() ? recvcount : recvcount * recvtype->size(), nullptr,
@@ -386,6 +391,11 @@ int PMPI_Iscatterv(void *sendbuf, int *sendcounts, int *displs,
 }
 
 int PMPI_Reduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm)
+{
+  return PMPI_Ireduce(sendbuf, recvbuf, count, datatype, op, root, comm, MPI_REQUEST_IGNORED);
+}
+
+int PMPI_Ireduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm, MPI_Request* request)
 {
   int retval = 0;
 
@@ -395,15 +405,20 @@ int PMPI_Reduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
     retval = MPI_ERR_COMM;
   } else if (not datatype->is_valid() || op == MPI_OP_NULL) {
     retval = MPI_ERR_ARG;
+  } else if (request == nullptr){
+    retval = MPI_ERR_ARG;
   } else {
     int rank = simgrid::s4u::this_actor::get_pid();
 
-    TRACE_smpi_comm_in(rank, __func__,
-                       new simgrid::instr::CollTIData("reduce", root, 0,
+    TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED ? "PMPI_Reduce":"PMPI_Ireduce",
+                       new simgrid::instr::CollTIData(request==MPI_REQUEST_IGNORED ? "reduce":"ireduce", root, 0,
                                                       datatype->is_replayable() ? count : count * datatype->size(), -1,
                                                       simgrid::smpi::Datatype::encode(datatype), ""));
+    if(request == MPI_REQUEST_IGNORED)
+      simgrid::smpi::Colls::reduce(sendbuf, recvbuf, count, datatype, op, root, comm);
+    else
+      simgrid::smpi::Colls::ireduce(sendbuf, recvbuf, count, datatype, op, root, comm, request);
 
-    simgrid::smpi::Colls::reduce(sendbuf, recvbuf, count, datatype, op, root, comm);
 
     retval = MPI_SUCCESS;
     TRACE_smpi_comm_out(rank);
@@ -444,11 +459,9 @@ int PMPI_Iallreduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype dataty
     retval = MPI_ERR_TYPE;
   } else if (op == MPI_OP_NULL) {
     retval = MPI_ERR_OP;
-  } else if (request != MPI_REQUEST_IGNORED) {
-    xbt_die("Iallreduce is not yet implemented. WIP");
+  } else if (request == nullptr){
     retval = MPI_ERR_ARG;
   } else {
-
     char* sendtmpbuf = static_cast<char*>(sendbuf);
     if( sendbuf == MPI_IN_PLACE ) {
       sendtmpbuf = static_cast<char*>(xbt_malloc(count*datatype->get_extent()));
@@ -461,10 +474,10 @@ int PMPI_Iallreduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype dataty
                                                       datatype->is_replayable() ? count : count * datatype->size(), -1,
                                                       simgrid::smpi::Datatype::encode(datatype), ""));
 
-//    if(request == MPI_REQUEST_IGNORED)
+    if(request == MPI_REQUEST_IGNORED)
       simgrid::smpi::Colls::allreduce(sendtmpbuf, recvbuf, count, datatype, op, comm);
-//    else
-//      simgrid::smpi::Colls::iallreduce(sendtmpbuf, recvbuf, count, datatype, op, comm, request);
+    else
+      simgrid::smpi::Colls::iallreduce(sendtmpbuf, recvbuf, count, datatype, op, comm, request);
 
     if( sendbuf == MPI_IN_PLACE )
       xbt_free(sendtmpbuf);
@@ -478,6 +491,11 @@ int PMPI_Iallreduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype dataty
 }
 
 int PMPI_Scan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
+{
+  return PMPI_Iscan(sendbuf, recvbuf, count, datatype, op, comm, MPI_REQUEST_IGNORED);
+}
+
+int PMPI_Iscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm, MPI_Request* request)
 {
   int retval = 0;
 
@@ -489,6 +507,8 @@ int PMPI_Scan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MP
     retval = MPI_ERR_TYPE;
   } else if (op == MPI_OP_NULL) {
     retval = MPI_ERR_OP;
+  } else if (request == nullptr){
+    retval = MPI_ERR_ARG;
   } else {
     int rank = simgrid::s4u::this_actor::get_pid();
     void* sendtmpbuf = sendbuf;
@@ -496,11 +516,15 @@ int PMPI_Scan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MP
       sendtmpbuf = static_cast<void*>(xbt_malloc(count * datatype->size()));
       memcpy(sendtmpbuf, recvbuf, count * datatype->size());
     }
-    TRACE_smpi_comm_in(rank, __func__, new simgrid::instr::Pt2PtTIData(
-                                           "scan", -1, datatype->is_replayable() ? count : count * datatype->size(),
+    TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED ? "PMPI_Scan" : "PMPI_Iscan", new simgrid::instr::Pt2PtTIData(
+                                           request==MPI_REQUEST_IGNORED ? "scan":"iscan", -1, 
+                                           datatype->is_replayable() ? count : count * datatype->size(),
                                            simgrid::smpi::Datatype::encode(datatype)));
 
-    retval = simgrid::smpi::Colls::scan(sendtmpbuf, recvbuf, count, datatype, op, comm);
+    if(request == MPI_REQUEST_IGNORED)
+      retval = simgrid::smpi::Colls::scan(sendtmpbuf, recvbuf, count, datatype, op, comm);
+    else
+      retval = simgrid::smpi::Colls::iscan(sendtmpbuf, recvbuf, count, datatype, op, comm, request);
 
     TRACE_smpi_comm_out(rank);
     if (sendbuf == MPI_IN_PLACE)
@@ -511,7 +535,12 @@ int PMPI_Scan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MP
   return retval;
 }
 
-int PMPI_Exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm){
+int PMPI_Exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm)
+{
+  return PMPI_Iexscan(sendbuf, recvbuf, count, datatype, op, comm, MPI_REQUEST_IGNORED);
+}
+
+int PMPI_Iexscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm, MPI_Request* request){
   int retval = 0;
 
   smpi_bench_end();
@@ -522,6 +551,8 @@ int PMPI_Exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
     retval = MPI_ERR_TYPE;
   } else if (op == MPI_OP_NULL) {
     retval = MPI_ERR_OP;
+  } else if (request == nullptr){
+    retval = MPI_ERR_ARG;
   } else {
     int rank         = simgrid::s4u::this_actor::get_pid();
     void* sendtmpbuf = sendbuf;
@@ -530,11 +561,14 @@ int PMPI_Exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
       memcpy(sendtmpbuf, recvbuf, count * datatype->size());
     }
 
-    TRACE_smpi_comm_in(rank, __func__, new simgrid::instr::Pt2PtTIData(
-                                           "exscan", -1, datatype->is_replayable() ? count : count * datatype->size(),
+    TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED ? "PMPI_Exscan" : "PMPI_Iexscan", new simgrid::instr::Pt2PtTIData(
+                                           request==MPI_REQUEST_IGNORED ? "exscan":"iexscan", -1, datatype->is_replayable() ? count : count * datatype->size(),
                                            simgrid::smpi::Datatype::encode(datatype)));
 
-    retval = simgrid::smpi::Colls::exscan(sendtmpbuf, recvbuf, count, datatype, op, comm);
+    if(request == MPI_REQUEST_IGNORED)
+      retval = simgrid::smpi::Colls::exscan(sendtmpbuf, recvbuf, count, datatype, op, comm);
+    else
+      retval = simgrid::smpi::Colls::iexscan(sendtmpbuf, recvbuf, count, datatype, op, comm, request);
 
     TRACE_smpi_comm_out(rank);
     if (sendbuf == MPI_IN_PLACE)
@@ -667,7 +701,7 @@ int PMPI_Ialltoall(void* sendbuf, int sendcount, MPI_Datatype sendtype, void* re
       sendtmptype  = recvtype;
     }
 
-    TRACE_smpi_comm_in(rank, __func__,
+    TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED?"PMPI_Alltoall":"PMPI_Ialltoall",
                        new simgrid::instr::CollTIData(
                            request==MPI_REQUEST_IGNORED ? "alltoall" : "ialltoall", -1, -1.0,
                            sendtmptype->is_replayable() ? sendtmpcount : sendtmpcount * sendtmptype->size(),
@@ -748,7 +782,7 @@ int PMPI_Ialltoallv(void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype
       trace_sendcounts->push_back(sendtmpcounts[i] * dt_size_send);
     }
 
-    TRACE_smpi_comm_in(rank, __func__,
+    TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED?"PMPI_Alltoallv":"PMPI_Ialltoallv",
                        new simgrid::instr::VarCollTIData(request==MPI_REQUEST_IGNORED ? "alltoallv":"ialltoallv", -1, send_size, trace_sendcounts, recv_size,
                                                          trace_recvcounts, simgrid::smpi::Datatype::encode(sendtype),
                                                          simgrid::smpi::Datatype::encode(recvtype)));
@@ -771,3 +805,91 @@ int PMPI_Ialltoallv(void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype
   smpi_bench_begin();
   return retval;
 }
+
+int PMPI_Alltoallw(void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype* sendtypes, void* recvbuf,
+                   int* recvcounts, int* recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm)
+{
+  return PMPI_Ialltoallw(sendbuf, sendcounts, senddisps, sendtypes, recvbuf, recvcounts, recvdisps, recvtypes, comm, MPI_REQUEST_IGNORED);
+}
+
+int PMPI_Ialltoallw(void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype* sendtypes, void* recvbuf,
+                   int* recvcounts, int* recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm, MPI_Request *request)
+{
+  int retval = 0;
+
+  smpi_bench_end();
+
+  if (comm == MPI_COMM_NULL) {
+    retval = MPI_ERR_COMM;
+  } else if ((sendbuf != MPI_IN_PLACE && sendtypes == nullptr)  || recvtypes == nullptr) {
+    retval = MPI_ERR_TYPE;
+  } else if ((sendbuf != MPI_IN_PLACE && (sendcounts == nullptr || senddisps == nullptr)) || recvcounts == nullptr ||
+             recvdisps == nullptr) {
+    retval = MPI_ERR_ARG;
+  } else if (request == nullptr){
+    retval = MPI_ERR_ARG;
+  }  else {
+    int rank                           = simgrid::s4u::this_actor::get_pid();
+    int size                           = comm->size();
+    int send_size                      = 0;
+    int recv_size                      = 0;
+    std::vector<int>* trace_sendcounts = new std::vector<int>;
+    std::vector<int>* trace_recvcounts = new std::vector<int>;
+
+    void* sendtmpbuf           = static_cast<char*>(sendbuf);
+    int* sendtmpcounts         = sendcounts;
+    int* sendtmpdisps          = senddisps;
+    MPI_Datatype* sendtmptypes = sendtypes;
+    unsigned long maxsize                = 0;
+    for (int i = 0; i < size; i++) { // copy data to avoid bad free
+      if(recvtypes[i]==MPI_DATATYPE_NULL){
+        delete trace_recvcounts;
+        delete trace_sendcounts;
+        return MPI_ERR_TYPE;
+      }
+      recv_size += recvcounts[i] * recvtypes[i]->size();
+      trace_recvcounts->push_back(recvcounts[i] * recvtypes[i]->size());
+      if ((recvdisps[i] + (recvcounts[i] * recvtypes[i]->size())) > maxsize)
+        maxsize = recvdisps[i] + (recvcounts[i] * recvtypes[i]->size());
+    }
+
+    if (sendbuf == MPI_IN_PLACE) {
+      sendtmpbuf = static_cast<void*>(xbt_malloc(maxsize));
+      memcpy(sendtmpbuf, recvbuf, maxsize);
+      sendtmpcounts = static_cast<int*>(xbt_malloc(size * sizeof(int)));
+      memcpy(sendtmpcounts, recvcounts, size * sizeof(int));
+      sendtmpdisps = static_cast<int*>(xbt_malloc(size * sizeof(int)));
+      memcpy(sendtmpdisps, recvdisps, size * sizeof(int));
+      sendtmptypes = static_cast<MPI_Datatype*>(xbt_malloc(size * sizeof(MPI_Datatype)));
+      memcpy(sendtmptypes, recvtypes, size * sizeof(MPI_Datatype));
+    }
+
+    for (int i = 0; i < size; i++) { // copy data to avoid bad free
+      send_size += sendtmpcounts[i] * sendtmptypes[i]->size();
+      trace_sendcounts->push_back(sendtmpcounts[i] * sendtmptypes[i]->size());
+    }
+
+    TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED?"PMPI_Alltoallw":"PMPI_Ialltoallw",
+                       new simgrid::instr::VarCollTIData(request==MPI_REQUEST_IGNORED ? "alltoallv":"ialltoallv", -1, send_size, trace_sendcounts, recv_size,
+                                                         trace_recvcounts, simgrid::smpi::Datatype::encode(sendtmptypes[0]),
+                                                         simgrid::smpi::Datatype::encode(recvtypes[0])));
+
+    if(request == MPI_REQUEST_IGNORED)
+      retval = simgrid::smpi::Colls::alltoallw(sendtmpbuf, sendtmpcounts, sendtmpdisps, sendtmptypes, recvbuf, recvcounts,
+                                    recvdisps, recvtypes, comm);
+    else
+      retval = simgrid::smpi::Colls::ialltoallw(sendtmpbuf, sendtmpcounts, sendtmpdisps, sendtmptypes, recvbuf, recvcounts,
+                                    recvdisps, recvtypes, comm, request);
+    TRACE_smpi_comm_out(rank);
+
+    if (sendbuf == MPI_IN_PLACE) {
+      xbt_free(sendtmpbuf);
+      xbt_free(sendtmpcounts);
+      xbt_free(sendtmpdisps);
+      xbt_free(sendtmptypes);
+    }
+  }
+
+  smpi_bench_begin();
+  return retval;
+}