Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
add support of MPI_PROC_NULL, correct behavior of waitall and waitsome for some speci...
[simgrid.git] / src / smpi / smpi_pmpi.c
index 64f0161..caf671e 100644 (file)
@@ -825,6 +825,8 @@ int PMPI_Send_init(void *buf, int count, MPI_Datatype datatype, int dst,
     retval = MPI_ERR_ARG;
   } else if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
+  } else if (dst == MPI_PROC_NULL) {
+    retval = MPI_SUCCESS;
   } else {
     *request = smpi_mpi_send_init(buf, count, datatype, dst, tag, comm);
     retval = MPI_SUCCESS;
@@ -843,6 +845,8 @@ int PMPI_Recv_init(void *buf, int count, MPI_Datatype datatype, int src,
     retval = MPI_ERR_ARG;
   } else if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
+  } else if (src == MPI_PROC_NULL) {
+      retval = MPI_SUCCESS;
   } else {
     *request = smpi_mpi_recv_init(buf, count, datatype, src, tag, comm);
     retval = MPI_SUCCESS;
@@ -856,7 +860,7 @@ int PMPI_Start(MPI_Request * request)
   int retval;
 
   smpi_bench_end();
-  if (request == NULL) {
+  if (request == NULL || *request == MPI_REQUEST_NULL) {
     retval = MPI_ERR_ARG;
   } else {
     smpi_mpi_start(*request);
@@ -886,7 +890,7 @@ int PMPI_Request_free(MPI_Request * request)
   int retval;
 
   smpi_bench_end();
-  if (request == NULL) {
+  if (request == MPI_REQUEST_NULL) {
     retval = MPI_ERR_ARG;
   } else {
     smpi_mpi_request_free(request);
@@ -911,6 +915,9 @@ int PMPI_Irecv(void *buf, int count, MPI_Datatype datatype, int src,
     retval = MPI_ERR_ARG;
   } else if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
+  }else if (src == MPI_PROC_NULL) {
+    *request = MPI_REQUEST_NULL;
+    retval = MPI_SUCCESS;
   } else {
     *request = smpi_mpi_irecv(buf, count, datatype, src, tag, comm);
     retval = MPI_SUCCESS;
@@ -930,6 +937,15 @@ int PMPI_Isend(void *buf, int count, MPI_Datatype datatype, int dst,
   int retval;
 
   smpi_bench_end();
+  if (request == NULL) {
+    retval = MPI_ERR_ARG;
+  } else if (comm == MPI_COMM_NULL) {
+    retval = MPI_ERR_COMM;
+  } else if (dst == MPI_PROC_NULL) {
+    *request = MPI_REQUEST_NULL;
+    retval = MPI_SUCCESS;
+  } else {
+
 #ifdef HAVE_TRACING
   int rank = comm != MPI_COMM_NULL ? smpi_process_index() : -1;
   TRACE_smpi_computing_out(rank);
@@ -937,19 +953,17 @@ int PMPI_Isend(void *buf, int count, MPI_Datatype datatype, int dst,
   TRACE_smpi_ptp_in(rank, rank, dst_traced, __FUNCTION__);
   TRACE_smpi_send(rank, rank, dst_traced);
 #endif
-  if (request == NULL) {
-    retval = MPI_ERR_ARG;
-  } else if (comm == MPI_COMM_NULL) {
-    retval = MPI_ERR_COMM;
-  } else {
+
     *request = smpi_mpi_isend(buf, count, datatype, dst, tag, comm);
     retval = MPI_SUCCESS;
-  }
+
 #ifdef HAVE_TRACING
   TRACE_smpi_ptp_out(rank, rank, dst_traced, __FUNCTION__);
   (*request)->send = 1;
   TRACE_smpi_computing_in(rank);
 #endif
+  }
+
   smpi_bench_begin();
   return retval;
 }
@@ -962,6 +976,14 @@ int PMPI_Recv(void *buf, int count, MPI_Datatype datatype, int src, int tag,
   int retval;
 
   smpi_bench_end();
+
+  if (comm == MPI_COMM_NULL) {
+    retval = MPI_ERR_COMM;
+  } else if (src == MPI_PROC_NULL) {
+    smpi_empty_status(status);
+    status->MPI_SOURCE = MPI_PROC_NULL;
+    retval = MPI_SUCCESS;
+  } else {
 #ifdef HAVE_TRACING
   int rank = comm != MPI_COMM_NULL ? smpi_process_index() : -1;
   int src_traced = smpi_group_index(smpi_comm_group(comm), src);
@@ -969,12 +991,10 @@ int PMPI_Recv(void *buf, int count, MPI_Datatype datatype, int src, int tag,
 
   TRACE_smpi_ptp_in(rank, src_traced, rank, __FUNCTION__);
 #endif
-  if (comm == MPI_COMM_NULL) {
-    retval = MPI_ERR_COMM;
-  } else {
+
     smpi_mpi_recv(buf, count, datatype, src, tag, comm, status);
     retval = MPI_SUCCESS;
-  }
+
 #ifdef HAVE_TRACING
   //the src may not have been known at the beginning of the recv (MPI_ANY_SOURCE)
   if(status!=MPI_STATUS_IGNORE)src_traced = smpi_group_index(smpi_comm_group(comm), status->MPI_SOURCE);
@@ -982,6 +1002,8 @@ int PMPI_Recv(void *buf, int count, MPI_Datatype datatype, int src, int tag,
   TRACE_smpi_recv(rank, src_traced, rank);
   TRACE_smpi_computing_in(rank);
 #endif
+  }
+
   smpi_bench_begin();
   return retval;
 }
@@ -1001,6 +1023,8 @@ int PMPI_Send(void *buf, int count, MPI_Datatype datatype, int dst, int tag,
 #endif
   if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
+  } else if (dst == MPI_PROC_NULL) {
+    retval = MPI_SUCCESS;
   } else {
     smpi_mpi_send(buf, count, datatype, dst, tag, comm);
     retval = MPI_SUCCESS;
@@ -1035,6 +1059,10 @@ int PMPI_Sendrecv(void *sendbuf, int sendcount, MPI_Datatype sendtype,
   } else if (sendtype == MPI_DATATYPE_NULL
              || recvtype == MPI_DATATYPE_NULL) {
     retval = MPI_ERR_TYPE;
+  } else if (src == MPI_PROC_NULL || dst == MPI_PROC_NULL) {
+      smpi_empty_status(status);
+      status->MPI_SOURCE = MPI_PROC_NULL;
+      retval = MPI_SUCCESS;
   } else {
     smpi_mpi_sendrecv(sendbuf, sendcount, sendtype, dst, sendtag, recvbuf,
                       recvcount, recvtype, src, recvtag, comm, status);
@@ -1077,7 +1105,7 @@ int PMPI_Test(MPI_Request * request, int *flag, MPI_Status * status)
   if (request == MPI_REQUEST_NULL || flag == NULL) {
     retval = MPI_ERR_ARG;
   } else if (*request == MPI_REQUEST_NULL) {
-    *flag=TRUE;
+    *flag= TRUE;
     retval = MPI_ERR_REQUEST;
   } else {
     *flag = smpi_mpi_test(request, status);
@@ -1123,12 +1151,16 @@ int PMPI_Probe(int source, int tag, MPI_Comm comm, MPI_Status* status) {
   smpi_bench_end();
 
   if (status == NULL) {
-     retval = MPI_ERR_ARG;
-  }else if (comm == MPI_COMM_NULL) {
-       retval = MPI_ERR_COMM;
+    retval = MPI_ERR_ARG;
+  } else if (comm == MPI_COMM_NULL) {
+    retval = MPI_ERR_COMM;
+  } else if (source == MPI_PROC_NULL) {
+    smpi_empty_status(status);
+    status->MPI_SOURCE = MPI_PROC_NULL;
+    retval = MPI_SUCCESS;
   } else {
-       smpi_mpi_probe(source, tag, comm, status);
-       retval = MPI_SUCCESS;
+    smpi_mpi_probe(source, tag, comm, status);
+    retval = MPI_SUCCESS;
   }
   smpi_bench_begin();
   return retval;
@@ -1140,14 +1172,18 @@ int PMPI_Iprobe(int source, int tag, MPI_Comm comm, int* flag, MPI_Status* statu
   smpi_bench_end();
 
   if (flag == NULL) {
-       retval = MPI_ERR_ARG;
-    }else if (status == NULL) {
-     retval = MPI_ERR_ARG;
-  }else if (comm == MPI_COMM_NULL) {
-       retval = MPI_ERR_COMM;
+    retval = MPI_ERR_ARG;
+  } else if (status == NULL) {
+    retval = MPI_ERR_ARG;
+  } else if (comm == MPI_COMM_NULL) {
+    retval = MPI_ERR_COMM;
+  } else if (source == MPI_PROC_NULL) {
+    smpi_empty_status(status);
+    status->MPI_SOURCE = MPI_PROC_NULL;
+    retval = MPI_SUCCESS;
   } else {
-       smpi_mpi_iprobe(source, tag, comm, flag, status);
-       retval = MPI_SUCCESS;
+    smpi_mpi_iprobe(source, tag, comm, flag, status);
+    retval = MPI_SUCCESS;
   }
   smpi_bench_begin();
   return retval;
@@ -1158,6 +1194,13 @@ int PMPI_Wait(MPI_Request * request, MPI_Status * status)
   int retval;
 
   smpi_bench_end();
+
+  if (request == NULL) {
+    retval = MPI_ERR_ARG;
+  } else if (*request == MPI_REQUEST_NULL) {
+    retval = MPI_ERR_REQUEST;
+  } else {
+
 #ifdef HAVE_TRACING
   int rank = request && (*request)->comm != MPI_COMM_NULL
       ? smpi_process_index()
@@ -1170,22 +1213,20 @@ int PMPI_Wait(MPI_Request * request, MPI_Status * status)
   int is_wait_for_receive = (*request)->recv;
   TRACE_smpi_ptp_in(rank, src_traced, dst_traced, __FUNCTION__);
 #endif
-  if (request == NULL) {
-    retval = MPI_ERR_ARG;
-  } else if (*request == MPI_REQUEST_NULL) {
-    retval = MPI_ERR_REQUEST;
-  } else {
+
     smpi_mpi_wait(request, status);
     retval = MPI_SUCCESS;
-  }
+
 #ifdef HAVE_TRACING
   TRACE_smpi_ptp_out(rank, src_traced, dst_traced, __FUNCTION__);
   if (is_wait_for_receive) {
     TRACE_smpi_recv(rank, src_traced, dst_traced);
   }
   TRACE_smpi_computing_in(rank);
-
 #endif
+
+  }
+
   smpi_bench_begin();
   return retval;
 }
@@ -1267,19 +1308,27 @@ int PMPI_Waitall(int count, MPI_Request requests[], MPI_Status status[])
   xbt_dynar_t dsts = xbt_dynar_new(sizeof(int), NULL);
   xbt_dynar_t recvs = xbt_dynar_new(sizeof(int), NULL);
   for (i = 0; i < count; i++) {
-    MPI_Request req = requests[i];      //all req should be valid in Waitall
-    int *asrc = xbt_new(int, 1);
-    int *adst = xbt_new(int, 1);
-    int *arecv = xbt_new(int, 1);
-    *asrc = req->src;
-    *adst = req->dst;
-    *arecv = req->recv;
-    xbt_dynar_insert_at(srcs, i, asrc);
-    xbt_dynar_insert_at(dsts, i, adst);
-    xbt_dynar_insert_at(recvs, i, arecv);
-    xbt_free(asrc);
-    xbt_free(adst);
-    xbt_free(arecv);
+    MPI_Request req = requests[i];
+    if(req){
+      int *asrc = xbt_new(int, 1);
+      int *adst = xbt_new(int, 1);
+      int *arecv = xbt_new(int, 1);
+      *asrc = req->src;
+      *adst = req->dst;
+      *arecv = req->recv;
+      xbt_dynar_insert_at(srcs, i, asrc);
+      xbt_dynar_insert_at(dsts, i, adst);
+      xbt_dynar_insert_at(recvs, i, arecv);
+      xbt_free(asrc);
+      xbt_free(adst);
+      xbt_free(arecv);
+    }else {
+      int *t = xbt_new(int, 1);
+      xbt_dynar_insert_at(srcs, i, t);
+      xbt_dynar_insert_at(dsts, i, t);
+      xbt_dynar_insert_at(recvs, i, t);
+      xbt_free(t);
+    }
   }
   int rank_traced = smpi_process_index();
   TRACE_smpi_computing_out(rank_traced);