Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
implement mpi_isendrecv and mpi_isendrecv_replace
[simgrid.git] / src / smpi / mpi / smpi_request.cpp
index 2d6fdda..64008cd 100644 (file)
@@ -69,7 +69,6 @@ Request::Request(const void* buf, int count, MPI_Datatype datatype, aid_t src, a
     refcount_ = 1;
   else
     refcount_ = 0;
-  message_id_ = 0;
   init_buffer(count);
   this->add_f();
 }
@@ -186,10 +185,11 @@ bool Request::match_recv(void* a, void* b, simgrid::kernel::activity::CommImpl*)
   bool match = match_common(req, req, ref);
   if (not match || ref->comm_ == MPI_COMM_UNINITIALIZED || ref->comm_->is_smp_comm())
     return match;
-
-  if (ref->comm_->get_received_messages_count(ref->comm_->group()->rank(req->src_),
-                                              ref->comm_->group()->rank(req->dst_), req->tag_) == req->message_id_) {
+  auto it = std::find(req->message_id_.begin(), req->message_id_.end(), ref->comm_->get_received_messages_count(ref->comm_->group()->rank(req->src_),
+                                              ref->comm_->group()->rank(req->dst_), req->tag_));
+  if (it != req->message_id_.end()) {
     if (((ref->flags_ & MPI_REQ_PROBE) == 0) && ((req->flags_ & MPI_REQ_PROBE) == 0)) {
+      req->message_id_.erase(it);
       XBT_DEBUG("increasing count in comm %p, which was %u from pid %ld, to pid %ld with tag %d", ref->comm_,
                 ref->comm_->get_received_messages_count(ref->comm_->group()->rank(req->src_),
                                                         ref->comm_->group()->rank(req->dst_), req->tag_),
@@ -204,12 +204,12 @@ bool Request::match_recv(void* a, void* b, simgrid::kernel::activity::CommImpl*)
     match = false;
     req->flags_ &= ~MPI_REQ_MATCHED;
     ref->detached_sender_ = nullptr;
-    XBT_DEBUG("Refusing to match message, as its ID is not the one I expect. in comm %p, %u != %u, "
+    XBT_DEBUG("Refusing to match message, as its ID is not the one I expect. in comm %p, %u, "
               "from pid %ld to pid %ld, with tag %d",
               ref->comm_,
               ref->comm_->get_received_messages_count(ref->comm_->group()->rank(req->src_),
                                                       ref->comm_->group()->rank(req->dst_), req->tag_),
-              req->message_id_, req->src_, req->dst_, req->tag_);
+              req->src_, req->dst_, req->tag_);
   }
   return match;
 }
@@ -439,6 +439,29 @@ void Request::sendrecv(const void *sendbuf, int sendcount, MPI_Datatype sendtype
   }
 }
 
+void Request::isendrecv(const void *sendbuf, int sendcount, MPI_Datatype sendtype,int dst, int sendtag,
+                       void *recvbuf, int recvcount, MPI_Datatype recvtype, int src, int recvtag,
+                       MPI_Comm comm, MPI_Request* request)
+{
+  aid_t source = MPI_PROC_NULL;
+  if (src == MPI_ANY_SOURCE)
+    source = MPI_ANY_SOURCE;
+  else if (src != MPI_PROC_NULL)
+    source = comm->group()->actor(src);
+  aid_t destination = dst != MPI_PROC_NULL ? comm->group()->actor(dst) : MPI_PROC_NULL;
+  
+  (*request) = new Request( nullptr, 0, MPI_BYTE,
+                         src,dst, sendtag, comm, MPI_REQ_PERSISTENT|MPI_REQ_NBC);
+  std::vector<MPI_Request> requests;
+  if (aid_t myid = simgrid::s4u::this_actor::get_pid(); (destination == myid) && (source == myid)) {
+    Datatype::copy(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype);
+    return;
+  }
+  requests.push_back(isend_init(sendbuf, sendcount, sendtype, dst, sendtag, comm));
+  requests.push_back(irecv_init(recvbuf, recvcount, recvtype, src, recvtag, comm));
+  (*request)->start_nbc_requests(requests);
+}
+
 void Request::start()
 {
   s4u::Mailbox* mailbox;
@@ -525,7 +548,7 @@ void Request::start()
       TRACE_smpi_send(src_, src_, dst_, tag_, size_);
     this->print_request("New send");
 
-    message_id_=comm_->get_sent_messages_count(comm_->group()->rank(src_), comm_->group()->rank(dst_), tag_);
+    message_id_.push_back(comm_->get_sent_messages_count(comm_->group()->rank(src_), comm_->group()->rank(dst_), tag_));
     comm_->increment_sent_messages_count(comm_->group()->rank(src_), comm_->group()->rank(dst_), tag_);
 
     void* buf = buf_;