Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
implement mpi_isendrecv and mpi_isendrecv_replace
[simgrid.git] / src / smpi / bindings / smpi_pmpi_request.cpp
index 8cf9b41..73ea9b9 100644 (file)
@@ -429,6 +429,66 @@ int PMPI_Sendrecv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, int
   return retval;
 }
 
+int PMPI_Isendrecv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, int dst, int sendtag, void* recvbuf,
+                  int recvcount, MPI_Datatype recvtype, int src, int recvtag, MPI_Comm comm, MPI_Request* request)
+{
+  int retval = 0;
+  SET_BUF1(sendbuf)
+  SET_BUF2(recvbuf)
+  CHECK_COUNT(2, sendcount)
+  CHECK_TYPE(3, sendtype)
+  CHECK_TAG(5, sendtag)
+  CHECK_COUNT(7, recvcount)
+  CHECK_TYPE(8, recvtype)
+  CHECK_BUFFER(1, sendbuf, sendcount, sendtype)
+  CHECK_BUFFER(6, recvbuf, recvcount, recvtype)
+  CHECK_ARGS(sendbuf == recvbuf && sendcount > 0 && recvcount > 0, MPI_ERR_BUFFER,
+             "%s: Invalid parameters 1 and 6: sendbuf and recvbuf must be disjoint", __func__);
+  CHECK_TAG(10, recvtag)
+  CHECK_COMM(11)
+  CHECK_REQUEST(12)
+  *request = MPI_REQUEST_NULL;
+  const SmpiBenchGuard suspend_bench;
+
+  if (src == MPI_PROC_NULL && dst != MPI_PROC_NULL){
+    *request=simgrid::smpi::Request::isend(sendbuf, sendcount, sendtype, dst, sendtag, comm);
+    retval = MPI_SUCCESS;
+  } else if (dst == MPI_PROC_NULL){
+    *request = simgrid::smpi::Request::irecv(recvbuf, recvcount, recvtype, src, recvtag, comm);
+    retval = MPI_SUCCESS;
+  } else if (dst >= comm->group()->size() || dst <0 ||
+      (src!=MPI_ANY_SOURCE && (src >= comm->group()->size() || src <0))){
+    retval = MPI_ERR_RANK;
+  } else {
+    aid_t my_proc_id = simgrid::s4u::this_actor::get_pid();
+    aid_t dst_traced = MPI_COMM_WORLD->group()->rank(getPid(comm, dst));
+    aid_t src_traced = MPI_COMM_WORLD->group()->rank(getPid(comm, src));
+
+    // FIXME: Hack the way to trace this one
+    auto dst_hack = std::make_shared<std::vector<int>>();
+    auto src_hack = std::make_shared<std::vector<int>>();
+    dst_hack->push_back(dst_traced);
+    src_hack->push_back(src_traced);
+    TRACE_smpi_comm_in(my_proc_id, __func__,
+                       new simgrid::instr::VarCollTIData(
+                           "isendRecv", -1, sendcount,
+                           dst_hack, recvcount, src_hack,
+                           simgrid::smpi::Datatype::encode(sendtype), simgrid::smpi::Datatype::encode(recvtype)));
+
+    TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, sendtag, sendcount * sendtype->size());
+
+    simgrid::smpi::Request::isendrecv(sendbuf, sendcount, sendtype, dst, sendtag, recvbuf, recvcount, recvtype, src,
+                                     recvtag, comm, request);
+    retval = MPI_SUCCESS;
+
+    TRACE_smpi_recv(src_traced, my_proc_id, recvtag);
+    TRACE_smpi_comm_out(my_proc_id);
+  }
+
+  return retval;
+}
+
+
 int PMPI_Sendrecv_replace(void* buf, int count, MPI_Datatype datatype, int dst, int sendtag, int src, int recvtag,
                           MPI_Comm comm, MPI_Status* status)
 {
@@ -452,6 +512,29 @@ int PMPI_Sendrecv_replace(void* buf, int count, MPI_Datatype datatype, int dst,
   return retval;
 }
 
+int PMPI_Isendrecv_replace(void* buf, int count, MPI_Datatype datatype, int dst, int sendtag, int src, int recvtag,
+                          MPI_Comm comm, MPI_Request* request)
+{
+  int retval = 0;
+  SET_BUF1(buf)
+  CHECK_COUNT(2, count)
+  CHECK_TYPE(3, datatype)
+  CHECK_BUFFER(1, buf, count, datatype)
+  CHECK_REQUEST(10)
+  *request = MPI_REQUEST_NULL;
+
+  int size = datatype->get_extent() * count;
+  if (size == 0)
+    return MPI_SUCCESS;
+  else if (size <0)
+    return MPI_ERR_ARG;
+  std::vector<char> sendbuf(size);
+  simgrid::smpi::Datatype::copy(buf, count, datatype, sendbuf.data(), count, datatype);
+  retval =
+      MPI_Isendrecv(sendbuf.data(), count, datatype, dst, sendtag, buf, count, datatype, src, recvtag, comm, request);
+  return retval;
+}
+
 int PMPI_Test(MPI_Request * request, int *flag, MPI_Status * status)
 {
   int retval = 0;