Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
implement mpi_isendrecv and mpi_isendrecv_replace
authorAugustin Degomme <adegomme@users.noreply.github.com>
Fri, 17 Nov 2023 21:31:48 +0000 (22:31 +0100)
committerAugustin Degomme <adegomme@users.noreply.github.com>
Fri, 17 Nov 2023 21:31:48 +0000 (22:31 +0100)
include/smpi/smpi.h
src/smpi/bindings/smpi_mpi.cpp
src/smpi/bindings/smpi_pmpi_request.cpp
src/smpi/include/smpi_request.hpp
src/smpi/mpi/smpi_request.cpp

index 84cd4ce..0812035 100644 (file)
@@ -705,8 +705,13 @@ MPI_CALL(XBT_PUBLIC int, MPI_Irsend,
 MPI_CALL(XBT_PUBLIC int, MPI_Sendrecv,
          (const void* sendbuf, int sendcount, MPI_Datatype sendtype, int dst, int sendtag, void* recvbuf, int recvcount,
           MPI_Datatype recvtype, int src, int recvtag, MPI_Comm comm, MPI_Status* status));
+MPI_CALL(XBT_PUBLIC int, MPI_Isendrecv,
+         (const void* sendbuf, int sendcount, MPI_Datatype sendtype, int dst, int sendtag, void* recvbuf, int recvcount,
+          MPI_Datatype recvtype, int src, int recvtag, MPI_Comm comm, MPI_Request* req));
 MPI_CALL(XBT_PUBLIC int, MPI_Sendrecv_replace, (void* buf, int count, MPI_Datatype datatype, int dst, int sendtag,
                                                 int src, int recvtag, MPI_Comm comm, MPI_Status* status));
+MPI_CALL(XBT_PUBLIC int, MPI_Isendrecv_replace, (void* buf, int count, MPI_Datatype datatype, int dst, int sendtag,
+                                                int src, int recvtag, MPI_Comm comm, MPI_Request* req));
 
 MPI_CALL(XBT_PUBLIC int, MPI_Test, (MPI_Request * request, int* flag, MPI_Status* status));
 MPI_CALL(XBT_PUBLIC int, MPI_Testany, (int count, MPI_Request requests[], int* index, int* flag, MPI_Status* status));
index 7e62ffa..4b9095c 100644 (file)
@@ -258,7 +258,9 @@ WRAPPED_PMPI_CALL_ERRHANDLER_COMM(int,MPI_Scatter,(const void *sendbuf, int send
 WRAPPED_PMPI_CALL_ERRHANDLER_COMM(int,MPI_Scatterv,(const void *sendbuf, const int *sendcounts, const int *displs, MPI_Datatype sendtype, void *recvbuf, int recvcount,MPI_Datatype recvtype, int root, MPI_Comm comm),(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm))
 WRAPPED_PMPI_CALL_ERRHANDLER_COMM(int,MPI_Send_init,(const void *buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request * request),(buf, count, datatype, dst, tag, comm, request))
 WRAPPED_PMPI_CALL_ERRHANDLER_COMM(int,MPI_Sendrecv_replace,(void *buf, int count, MPI_Datatype datatype, int dst, int sendtag, int src, int recvtag,MPI_Comm comm, MPI_Status * status),(buf, count, datatype, dst, sendtag, src, recvtag, comm, status))
+WRAPPED_PMPI_CALL_ERRHANDLER_COMM(int,MPI_Isendrecv_replace,(void *buf, int count, MPI_Datatype datatype, int dst, int sendtag, int src, int recvtag,MPI_Comm comm, MPI_Request* req),(buf, count, datatype, dst, sendtag, src, recvtag, comm, req))
 WRAPPED_PMPI_CALL_ERRHANDLER_COMM(int,MPI_Sendrecv,(const void *sendbuf, int sendcount, MPI_Datatype sendtype,int dst, int sendtag, void *recvbuf, int recvcount,MPI_Datatype recvtype, int src, int recvtag, MPI_Comm comm, MPI_Status * status),(sendbuf, sendcount, sendtype, dst, sendtag, recvbuf, recvcount, recvtype, src, recvtag,comm, status))
+WRAPPED_PMPI_CALL_ERRHANDLER_COMM(int,MPI_Isendrecv,(const void *sendbuf, int sendcount, MPI_Datatype sendtype,int dst, int sendtag, void *recvbuf, int recvcount,MPI_Datatype recvtype, int src, int recvtag, MPI_Comm comm, MPI_Request* req),(sendbuf, sendcount, sendtype, dst, sendtag, recvbuf, recvcount, recvtype, src, recvtag,comm, req))
 WRAPPED_PMPI_CALL_ERRHANDLER_COMM(int,MPI_Send,(const void *buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm),(buf, count, datatype, dst, tag, comm))
 WRAPPED_PMPI_CALL_ERRHANDLER_COMM(int,MPI_Ssend_init,(const void* buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Request* request),(buf, count, datatype, dest, tag, comm, request))
 WRAPPED_PMPI_CALL_ERRHANDLER_COMM(int,MPI_Ssend,(const void* buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm) ,(buf, count, datatype, dest, tag, comm))
index 8cf9b41..73ea9b9 100644 (file)
@@ -429,6 +429,66 @@ int PMPI_Sendrecv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, int
   return retval;
 }
 
+int PMPI_Isendrecv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, int dst, int sendtag, void* recvbuf,
+                  int recvcount, MPI_Datatype recvtype, int src, int recvtag, MPI_Comm comm, MPI_Request* request)
+{
+  int retval = 0;
+  SET_BUF1(sendbuf)
+  SET_BUF2(recvbuf)
+  CHECK_COUNT(2, sendcount)
+  CHECK_TYPE(3, sendtype)
+  CHECK_TAG(5, sendtag)
+  CHECK_COUNT(7, recvcount)
+  CHECK_TYPE(8, recvtype)
+  CHECK_BUFFER(1, sendbuf, sendcount, sendtype)
+  CHECK_BUFFER(6, recvbuf, recvcount, recvtype)
+  CHECK_ARGS(sendbuf == recvbuf && sendcount > 0 && recvcount > 0, MPI_ERR_BUFFER,
+             "%s: Invalid parameters 1 and 6: sendbuf and recvbuf must be disjoint", __func__);
+  CHECK_TAG(10, recvtag)
+  CHECK_COMM(11)
+  CHECK_REQUEST(12)
+  *request = MPI_REQUEST_NULL;
+  const SmpiBenchGuard suspend_bench;
+
+  if (src == MPI_PROC_NULL && dst != MPI_PROC_NULL){
+    *request=simgrid::smpi::Request::isend(sendbuf, sendcount, sendtype, dst, sendtag, comm);
+    retval = MPI_SUCCESS;
+  } else if (dst == MPI_PROC_NULL){
+    *request = simgrid::smpi::Request::irecv(recvbuf, recvcount, recvtype, src, recvtag, comm);
+    retval = MPI_SUCCESS;
+  } else if (dst >= comm->group()->size() || dst <0 ||
+      (src!=MPI_ANY_SOURCE && (src >= comm->group()->size() || src <0))){
+    retval = MPI_ERR_RANK;
+  } else {
+    aid_t my_proc_id = simgrid::s4u::this_actor::get_pid();
+    aid_t dst_traced = MPI_COMM_WORLD->group()->rank(getPid(comm, dst));
+    aid_t src_traced = MPI_COMM_WORLD->group()->rank(getPid(comm, src));
+
+    // FIXME: Hack the way to trace this one
+    auto dst_hack = std::make_shared<std::vector<int>>();
+    auto src_hack = std::make_shared<std::vector<int>>();
+    dst_hack->push_back(dst_traced);
+    src_hack->push_back(src_traced);
+    TRACE_smpi_comm_in(my_proc_id, __func__,
+                       new simgrid::instr::VarCollTIData(
+                           "isendRecv", -1, sendcount,
+                           dst_hack, recvcount, src_hack,
+                           simgrid::smpi::Datatype::encode(sendtype), simgrid::smpi::Datatype::encode(recvtype)));
+
+    TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, sendtag, sendcount * sendtype->size());
+
+    simgrid::smpi::Request::isendrecv(sendbuf, sendcount, sendtype, dst, sendtag, recvbuf, recvcount, recvtype, src,
+                                     recvtag, comm, request);
+    retval = MPI_SUCCESS;
+
+    TRACE_smpi_recv(src_traced, my_proc_id, recvtag);
+    TRACE_smpi_comm_out(my_proc_id);
+  }
+
+  return retval;
+}
+
+
 int PMPI_Sendrecv_replace(void* buf, int count, MPI_Datatype datatype, int dst, int sendtag, int src, int recvtag,
                           MPI_Comm comm, MPI_Status* status)
 {
@@ -452,6 +512,29 @@ int PMPI_Sendrecv_replace(void* buf, int count, MPI_Datatype datatype, int dst,
   return retval;
 }
 
+int PMPI_Isendrecv_replace(void* buf, int count, MPI_Datatype datatype, int dst, int sendtag, int src, int recvtag,
+                          MPI_Comm comm, MPI_Request* request)
+{
+  int retval = 0;
+  SET_BUF1(buf)
+  CHECK_COUNT(2, count)
+  CHECK_TYPE(3, datatype)
+  CHECK_BUFFER(1, buf, count, datatype)
+  CHECK_REQUEST(10)
+  *request = MPI_REQUEST_NULL;
+
+  int size = datatype->get_extent() * count;
+  if (size == 0)
+    return MPI_SUCCESS;
+  else if (size <0)
+    return MPI_ERR_ARG;
+  std::vector<char> sendbuf(size);
+  simgrid::smpi::Datatype::copy(buf, count, datatype, sendbuf.data(), count, datatype);
+  retval =
+      MPI_Isendrecv(sendbuf.data(), count, datatype, dst, sendtag, buf, count, datatype, src, recvtag, comm, request);
+  return retval;
+}
+
 int PMPI_Test(MPI_Request * request, int *flag, MPI_Status * status)
 {
   int retval = 0;
index 5b3962c..9b0a6cc 100644 (file)
@@ -102,6 +102,8 @@ public:
 
   static void sendrecv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, int dst, int sendtag, void* recvbuf,
                        int recvcount, MPI_Datatype recvtype, int src, int recvtag, MPI_Comm comm, MPI_Status* status);
+  static void isendrecv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, int dst, int sendtag, void* recvbuf,
+                       int recvcount, MPI_Datatype recvtype, int src, int recvtag, MPI_Comm comm, MPI_Request* request);
 
   static void startall(int count, MPI_Request* requests);
 
index 86130de..64008cd 100644 (file)
@@ -439,6 +439,29 @@ void Request::sendrecv(const void *sendbuf, int sendcount, MPI_Datatype sendtype
   }
 }
 
+void Request::isendrecv(const void *sendbuf, int sendcount, MPI_Datatype sendtype,int dst, int sendtag,
+                       void *recvbuf, int recvcount, MPI_Datatype recvtype, int src, int recvtag,
+                       MPI_Comm comm, MPI_Request* request)
+{
+  aid_t source = MPI_PROC_NULL;
+  if (src == MPI_ANY_SOURCE)
+    source = MPI_ANY_SOURCE;
+  else if (src != MPI_PROC_NULL)
+    source = comm->group()->actor(src);
+  aid_t destination = dst != MPI_PROC_NULL ? comm->group()->actor(dst) : MPI_PROC_NULL;
+  
+  (*request) = new Request( nullptr, 0, MPI_BYTE,
+                         src,dst, sendtag, comm, MPI_REQ_PERSISTENT|MPI_REQ_NBC);
+  std::vector<MPI_Request> requests;
+  if (aid_t myid = simgrid::s4u::this_actor::get_pid(); (destination == myid) && (source == myid)) {
+    Datatype::copy(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype);
+    return;
+  }
+  requests.push_back(isend_init(sendbuf, sendcount, sendtype, dst, sendtag, comm));
+  requests.push_back(irecv_init(recvbuf, recvcount, recvtype, src, recvtag, comm));
+  (*request)->start_nbc_requests(requests);
+}
+
 void Request::start()
 {
   s4u::Mailbox* mailbox;