Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
another bunch of codefactor style issues
[simgrid.git] / src / smpi / bindings / smpi_pmpi_request.cpp
index ab2bf56..62e4822 100644 (file)
@@ -20,7 +20,7 @@ static int getPid(MPI_Comm comm, int id)
 
 /* PMPI User level calls */
 
-int PMPI_Send_init(void *buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request * request)
+int PMPI_Send_init(const void *buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request * request)
 {
   int retval = 0;
 
@@ -29,7 +29,7 @@ int PMPI_Send_init(void *buf, int count, MPI_Datatype datatype, int dst, int tag
     retval = MPI_ERR_ARG;
   } else if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
-  } else if (not datatype->is_valid()) {
+  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
     retval = MPI_ERR_TYPE;
   } else if (dst == MPI_PROC_NULL) {
     retval = MPI_SUCCESS;
@@ -52,7 +52,7 @@ int PMPI_Recv_init(void *buf, int count, MPI_Datatype datatype, int src, int tag
     retval = MPI_ERR_ARG;
   } else if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
-  } else if (not datatype->is_valid()) {
+  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
     retval = MPI_ERR_TYPE;
   } else if (src == MPI_PROC_NULL) {
     retval = MPI_SUCCESS;
@@ -66,7 +66,13 @@ int PMPI_Recv_init(void *buf, int count, MPI_Datatype datatype, int src, int tag
   return retval;
 }
 
-int PMPI_Ssend_init(void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request* request)
+int PMPI_Rsend_init(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm,
+                    MPI_Request* request)
+{
+  return PMPI_Send_init(buf, count, datatype, dst, tag, comm, request);
+}
+
+int PMPI_Ssend_init(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request* request)
 {
   int retval = 0;
 
@@ -75,7 +81,7 @@ int PMPI_Ssend_init(void* buf, int count, MPI_Datatype datatype, int dst, int ta
     retval = MPI_ERR_ARG;
   } else if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
-  } else if (not datatype->is_valid()) {
+  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
     retval = MPI_ERR_TYPE;
   } else if (dst == MPI_PROC_NULL) {
     retval = MPI_SUCCESS;
@@ -139,10 +145,9 @@ int PMPI_Startall(int count, MPI_Request * requests)
     if(retval != MPI_ERR_REQUEST) {
       int my_proc_id = simgrid::s4u::this_actor::get_pid();
       TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("Startall"));
-      MPI_Request req = MPI_REQUEST_NULL;
       if (not TRACE_smpi_view_internals())
         for (int i = 0; i < count; i++) {
-          req = requests[i];
+          MPI_Request req = requests[i];
           if (req->flags() & MPI_REQ_SEND)
             TRACE_smpi_send(my_proc_id, my_proc_id, getPid(req->comm(), req->dst()), req->tag(), req->size());
         }
@@ -151,7 +156,7 @@ int PMPI_Startall(int count, MPI_Request * requests)
 
       if (not TRACE_smpi_view_internals())
         for (int i = 0; i < count; i++) {
-          req = requests[i];
+          MPI_Request req = requests[i];
           if (req->flags() & MPI_REQ_RECV)
             TRACE_smpi_recv(getPid(req->comm(), req->src()), my_proc_id, req->tag());
         }
@@ -167,10 +172,9 @@ int PMPI_Request_free(MPI_Request * request)
   int retval = 0;
 
   smpi_bench_end();
-  if (*request == MPI_REQUEST_NULL) {
-    retval = MPI_ERR_ARG;
-  } else {
+  if (*request != MPI_REQUEST_NULL) {
     simgrid::smpi::Request::unref(request);
+    *request = MPI_REQUEST_NULL;
     retval = MPI_SUCCESS;
   }
   smpi_bench_begin();
@@ -194,12 +198,11 @@ int PMPI_Irecv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MP
     retval = MPI_ERR_RANK;
   } else if ((count < 0) || (buf==nullptr && count > 0)) {
     retval = MPI_ERR_COUNT;
-  } else if (not datatype->is_valid()) {
+  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
     retval = MPI_ERR_TYPE;
   } else if(tag<0 && tag !=  MPI_ANY_TAG){
     retval = MPI_ERR_TAG;
   } else {
-
     int my_proc_id = simgrid::s4u::this_actor::get_pid();
 
     TRACE_smpi_comm_in(my_proc_id, __func__,
@@ -220,7 +223,7 @@ int PMPI_Irecv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MP
 }
 
 
-int PMPI_Isend(void *buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request * request)
+int PMPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request * request)
 {
   int retval = 0;
 
@@ -236,7 +239,7 @@ int PMPI_Isend(void *buf, int count, MPI_Datatype datatype, int dst, int tag, MP
     retval = MPI_ERR_RANK;
   } else if ((count < 0) || (buf==nullptr && count > 0)) {
     retval = MPI_ERR_COUNT;
-  } else if (not datatype->is_valid()) {
+  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
     retval = MPI_ERR_TYPE;
   } else if(tag<0 && tag !=  MPI_ANY_TAG){
     retval = MPI_ERR_TAG;
@@ -262,7 +265,13 @@ int PMPI_Isend(void *buf, int count, MPI_Datatype datatype, int dst, int tag, MP
   return retval;
 }
 
-int PMPI_Issend(void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request* request)
+int PMPI_Irsend(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm,
+                MPI_Request* request)
+{
+  return PMPI_Isend(buf, count, datatype, dst, tag, comm, request);
+}
+
+int PMPI_Issend(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request* request)
 {
   int retval = 0;
 
@@ -278,7 +287,7 @@ int PMPI_Issend(void* buf, int count, MPI_Datatype datatype, int dst, int tag, M
     retval = MPI_ERR_RANK;
   } else if ((count < 0)|| (buf==nullptr && count > 0)) {
     retval = MPI_ERR_COUNT;
-  } else if (not datatype->is_valid()) {
+  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
     retval = MPI_ERR_TYPE;
   } else if(tag<0 && tag !=  MPI_ANY_TAG){
     retval = MPI_ERR_TAG;
@@ -320,7 +329,7 @@ int PMPI_Recv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MPI
     retval = MPI_ERR_RANK;
   } else if ((count < 0) || (buf==nullptr && count > 0)) {
     retval = MPI_ERR_COUNT;
-  } else if (not datatype->is_valid()) {
+  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
     retval = MPI_ERR_TYPE;
   } else if(tag<0 && tag !=  MPI_ANY_TAG){
     retval = MPI_ERR_TAG;
@@ -351,7 +360,7 @@ int PMPI_Recv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MPI
   return retval;
 }
 
-int PMPI_Send(void *buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm)
+int PMPI_Send(const void *buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm)
 {
   int retval = 0;
 
@@ -365,7 +374,7 @@ int PMPI_Send(void *buf, int count, MPI_Datatype datatype, int dst, int tag, MPI
     retval = MPI_ERR_RANK;
   } else if ((count < 0) || (buf == nullptr && count > 0)) {
     retval = MPI_ERR_COUNT;
-  } else if (not datatype->is_valid()) {
+  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
     retval = MPI_ERR_TYPE;
   } else if(tag < 0 && tag !=  MPI_ANY_TAG){
     retval = MPI_ERR_TAG;
@@ -390,7 +399,135 @@ int PMPI_Send(void *buf, int count, MPI_Datatype datatype, int dst, int tag, MPI
   return retval;
 }
 
-int PMPI_Ssend(void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm) {
+int PMPI_Rsend(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm)
+{
+  return PMPI_Send(buf, count, datatype, dst, tag, comm);
+}
+
+int PMPI_Bsend(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm)
+{
+  int retval = 0;
+
+  smpi_bench_end();
+
+  if (comm == MPI_COMM_NULL) {
+    retval = MPI_ERR_COMM;
+  } else if (dst == MPI_PROC_NULL) {
+    retval = MPI_SUCCESS;
+  } else if (dst >= comm->group()->size() || dst <0){
+    retval = MPI_ERR_RANK;
+  } else if ((count < 0) || (buf == nullptr && count > 0)) {
+    retval = MPI_ERR_COUNT;
+  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
+    retval = MPI_ERR_TYPE;
+  } else if(tag < 0 && tag !=  MPI_ANY_TAG){
+    retval = MPI_ERR_TAG;
+  } else {
+    int my_proc_id         = simgrid::s4u::this_actor::get_pid();
+    int dst_traced         = getPid(comm, dst);
+    int bsend_buf_size = 0;
+    void* bsend_buf = nullptr;
+    smpi_process()->bsend_buffer(&bsend_buf, &bsend_buf_size);
+    int size = datatype->get_extent() * count;
+    if(bsend_buf==nullptr || bsend_buf_size < size + MPI_BSEND_OVERHEAD )
+      return MPI_ERR_BUFFER;
+    TRACE_smpi_comm_in(my_proc_id, __func__,
+                       new simgrid::instr::Pt2PtTIData("bsend", dst,
+                                                       datatype->is_replayable() ? count : count * datatype->size(),
+                                                       tag, simgrid::smpi::Datatype::encode(datatype)));
+    if (not TRACE_smpi_view_internals()) {
+      TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, tag, count * datatype->size());
+    }
+
+    simgrid::smpi::Request::bsend(buf, count, datatype, dst, tag, comm);
+    retval = MPI_SUCCESS;
+
+    TRACE_smpi_comm_out(my_proc_id);
+  }
+
+  smpi_bench_begin();
+  return retval;
+}
+
+int PMPI_Ibsend(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request* request)
+{
+  int retval = 0;
+
+  smpi_bench_end();
+  if (request == nullptr) {
+    retval = MPI_ERR_ARG;
+  } else if (comm == MPI_COMM_NULL) {
+    retval = MPI_ERR_COMM;
+  } else if (dst == MPI_PROC_NULL) {
+    *request = MPI_REQUEST_NULL;
+    retval = MPI_SUCCESS;
+  } else if (dst >= comm->group()->size() || dst <0){
+    retval = MPI_ERR_RANK;
+  } else if ((count < 0) || (buf==nullptr && count > 0)) {
+    retval = MPI_ERR_COUNT;
+  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
+    retval = MPI_ERR_TYPE;
+  } else if(tag<0 && tag !=  MPI_ANY_TAG){
+    retval = MPI_ERR_TAG;
+  } else {
+    int my_proc_id = simgrid::s4u::this_actor::get_pid();
+    int trace_dst = getPid(comm, dst);
+    int bsend_buf_size = 0;
+    void* bsend_buf = nullptr;
+    smpi_process()->bsend_buffer(&bsend_buf, &bsend_buf_size);
+    int size = datatype->get_extent() * count;
+    if(bsend_buf==nullptr || bsend_buf_size < size + MPI_BSEND_OVERHEAD )
+      return MPI_ERR_BUFFER;
+    TRACE_smpi_comm_in(my_proc_id, __func__,
+                       new simgrid::instr::Pt2PtTIData("ibsend", dst,
+                                                       datatype->is_replayable() ? count : count * datatype->size(),
+                                                       tag, simgrid::smpi::Datatype::encode(datatype)));
+
+    TRACE_smpi_send(my_proc_id, my_proc_id, trace_dst, tag, count * datatype->size());
+
+    *request = simgrid::smpi::Request::ibsend(buf, count, datatype, dst, tag, comm);
+    retval = MPI_SUCCESS;
+
+    TRACE_smpi_comm_out(my_proc_id);
+  }
+
+  smpi_bench_begin();
+  if (retval != MPI_SUCCESS && request!=nullptr)
+    *request = MPI_REQUEST_NULL;
+  return retval;
+}
+
+int PMPI_Bsend_init(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request* request)
+{
+  int retval = 0;
+
+  smpi_bench_end();
+  if (request == nullptr) {
+    retval = MPI_ERR_ARG;
+  } else if (comm == MPI_COMM_NULL) {
+    retval = MPI_ERR_COMM;
+  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
+    retval = MPI_ERR_TYPE;
+  } else if (dst == MPI_PROC_NULL) {
+    retval = MPI_SUCCESS;
+  } else {
+    int bsend_buf_size = 0;
+    void* bsend_buf = nullptr;
+    smpi_process()->bsend_buffer(&bsend_buf, &bsend_buf_size);
+    if( bsend_buf==nullptr || bsend_buf_size < datatype->get_extent() * count + MPI_BSEND_OVERHEAD ) {
+      retval = MPI_ERR_BUFFER;
+    } else {
+      *request = simgrid::smpi::Request::bsend_init(buf, count, datatype, dst, tag, comm);
+      retval   = MPI_SUCCESS;
+    }
+  }
+  smpi_bench_begin();
+  if (retval != MPI_SUCCESS && request != nullptr)
+    *request = MPI_REQUEST_NULL;
+  return retval;
+}
+
+int PMPI_Ssend(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm) {
   int retval = 0;
 
   smpi_bench_end();
@@ -403,7 +540,7 @@ int PMPI_Ssend(void* buf, int count, MPI_Datatype datatype, int dst, int tag, MP
     retval = MPI_ERR_RANK;
   } else if ((count < 0) || (buf==nullptr && count > 0)) {
     retval = MPI_ERR_COUNT;
-  } else if (not datatype->is_valid()) {
+  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
     retval = MPI_ERR_TYPE;
   } else if(tag<0 && tag !=  MPI_ANY_TAG){
     retval = MPI_ERR_TAG;
@@ -426,7 +563,7 @@ int PMPI_Ssend(void* buf, int count, MPI_Datatype datatype, int dst, int tag, MP
   return retval;
 }
 
-int PMPI_Sendrecv(void* sendbuf, int sendcount, MPI_Datatype sendtype, int dst, int sendtag, void* recvbuf,
+int PMPI_Sendrecv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, int dst, int sendtag, void* recvbuf,
                   int recvcount, MPI_Datatype recvtype, int src, int recvtag, MPI_Comm comm, MPI_Status* status)
 {
   int retval = 0;
@@ -490,7 +627,7 @@ int PMPI_Sendrecv_replace(void* buf, int count, MPI_Datatype datatype, int dst,
                           MPI_Comm comm, MPI_Status* status)
 {
   int retval = 0;
-  if (not datatype->is_valid()) {
+  if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
     return MPI_ERR_TYPE;
   } else if (count < 0) {
     return MPI_ERR_COUNT;
@@ -502,7 +639,6 @@ int PMPI_Sendrecv_replace(void* buf, int count, MPI_Datatype datatype, int dst,
       simgrid::smpi::Datatype::copy(recvbuf, count, datatype, buf, count, datatype);
     }
     xbt_free(recvbuf);
-
   }
   return retval;
 }
@@ -656,7 +792,7 @@ int PMPI_Wait(MPI_Request * request, MPI_Status * status)
   } else if (*request == MPI_REQUEST_NULL) {
     retval = MPI_SUCCESS;
   } else {
-    //for tracing, save the handle which might get overriden before we can use the helper on it
+    // for tracing, save the handle which might get overridden before we can use the helper on it
     MPI_Request savedreq = *request;
     if (savedreq != MPI_REQUEST_NULL && not(savedreq->flags() & MPI_REQ_FINISHED)
     && not(savedreq->flags() & MPI_REQ_GENERALIZED))
@@ -692,7 +828,7 @@ int PMPI_Waitany(int count, MPI_Request requests[], int *index, MPI_Status * sta
     return MPI_SUCCESS;
 
   smpi_bench_end();
-  //for tracing, save the handles which might get overriden before we can use the helper on it
+  // for tracing, save the handles which might get overridden before we can use the helper on it
   std::vector<MPI_Request> savedreqs(requests, requests + count);
   for (MPI_Request& req : savedreqs) {
     if (req != MPI_REQUEST_NULL && not(req->flags() & MPI_REQ_FINISHED))
@@ -702,7 +838,7 @@ int PMPI_Waitany(int count, MPI_Request requests[], int *index, MPI_Status * sta
   }
 
   int rank_traced = simgrid::s4u::this_actor::get_pid(); // FIXME: In PMPI_Wait, we check if the comm is null?
-  TRACE_smpi_comm_in(rank_traced, __func__, new simgrid::instr::CpuTIData("waitAny", static_cast<double>(count)));
+  TRACE_smpi_comm_in(rank_traced, __func__, new simgrid::instr::CpuTIData("waitAny", count));
 
   *index = simgrid::smpi::Request::waitany(count, requests, status);
 
@@ -723,7 +859,7 @@ int PMPI_Waitall(int count, MPI_Request requests[], MPI_Status status[])
 {
   smpi_bench_end();
 
-  //for tracing, save the handles which might get overriden before we can use the helper on it
+  // for tracing, save the handles which might get overridden before we can use the helper on it
   std::vector<MPI_Request> savedreqs(requests, requests + count);
   for (MPI_Request& req : savedreqs) {
     if (req != MPI_REQUEST_NULL && not(req->flags() & MPI_REQ_FINISHED))
@@ -733,7 +869,7 @@ int PMPI_Waitall(int count, MPI_Request requests[], MPI_Status status[])
   }
 
   int rank_traced = simgrid::s4u::this_actor::get_pid(); // FIXME: In PMPI_Wait, we check if the comm is null?
-  TRACE_smpi_comm_in(rank_traced, __func__, new simgrid::instr::CpuTIData("waitall", static_cast<double>(count)));
+  TRACE_smpi_comm_in(rank_traced, __func__, new simgrid::instr::CpuTIData("waitall", count));
 
   int retval = simgrid::smpi::Request::waitall(count, requests, status);
 
@@ -780,7 +916,7 @@ int PMPI_Cancel(MPI_Request* request)
   return retval;
 }
 
-int PMPI_Test_cancelled(MPI_Status* status, int* flag){
+int PMPI_Test_cancelled(const MPI_Status* status, int* flag){
   if(status==MPI_STATUS_IGNORE){
     *flag=0;
     return MPI_ERR_ARG;
@@ -789,6 +925,21 @@ int PMPI_Test_cancelled(MPI_Status* status, int* flag){
   return MPI_SUCCESS;  
 }
 
+int PMPI_Status_set_cancelled(MPI_Status* status, int flag){
+  if(status==MPI_STATUS_IGNORE){
+    return MPI_ERR_ARG;
+  }
+  simgrid::smpi::Status::set_cancelled(status,flag);
+  return MPI_SUCCESS;  
+}
+
+int PMPI_Status_set_elements(MPI_Status* status, MPI_Datatype datatype, int count){
+  if(status==MPI_STATUS_IGNORE){
+    return MPI_ERR_ARG;
+  }
+  simgrid::smpi::Status::set_elements(status,datatype, count);
+  return MPI_SUCCESS;  
+}
 
 int PMPI_Grequest_start( MPI_Grequest_query_function *query_fn, MPI_Grequest_free_function *free_fn, MPI_Grequest_cancel_function *cancel_fn, void *extra_state, MPI_Request *request){
   return simgrid::smpi::Request::grequest_start(query_fn, free_fn,cancel_fn, extra_state, request);
@@ -798,10 +949,25 @@ int PMPI_Grequest_complete( MPI_Request request){
   return simgrid::smpi::Request::grequest_complete(request);
 }
 
+int PMPI_Request_get_status( MPI_Request request, int *flag, MPI_Status *status){
+  if(request==MPI_REQUEST_NULL){
+    *flag=1;
+    simgrid::smpi::Status::empty(status);
+    return MPI_SUCCESS;
+  } else if (flag==NULL || status ==NULL){
+    return MPI_ERR_ARG;
+  }
+  return simgrid::smpi::Request::get_status(request,flag,status);
+}
+
 MPI_Request PMPI_Request_f2c(MPI_Fint request){
+  if(request==-1)
+    return MPI_REQUEST_NULL;
   return static_cast<MPI_Request>(simgrid::smpi::Request::f2c(request));
 }
 
 MPI_Fint PMPI_Request_c2f(MPI_Request request) {
+  if(request==MPI_REQUEST_NULL)
+    return -1;
   return request->c2f();
 }