Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
MPI_Translate_ranks should return MPI_PROC_NULL if we provide MPI_PROC_NULL as parameter
[simgrid.git] / src / smpi / smpi_pmpi.c
index c35b821..1c12d35 100644 (file)
@@ -350,14 +350,17 @@ int PMPI_Group_translate_ranks(MPI_Group group1, int n, int *ranks1,
                               MPI_Group group2, int *ranks2)
 {
   int retval, i, index;
-
   smpi_bench_end();
   if (group1 == MPI_GROUP_NULL || group2 == MPI_GROUP_NULL) {
     retval = MPI_ERR_GROUP;
   } else {
     for (i = 0; i < n; i++) {
-      index = smpi_group_index(group1, ranks1[i]);
-      ranks2[i] = smpi_group_rank(group2, index);
+      if(ranks1[i]==MPI_PROC_NULL){
+        ranks2[i]=MPI_PROC_NULL;
+      }else{
+        index = smpi_group_index(group1, ranks1[i]);
+        ranks2[i] = smpi_group_rank(group2, index);
+      }
     }
     retval = MPI_SUCCESS;
   }
@@ -1656,12 +1659,26 @@ int PMPI_Gather(void *sendbuf, int sendcount, MPI_Datatype sendtype,
 #endif
   if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
-  } else if (sendtype == MPI_DATATYPE_NULL
-             || recvtype == MPI_DATATYPE_NULL) {
+  } else if ((( sendbuf != MPI_IN_PLACE) && (sendtype == MPI_DATATYPE_NULL)) ||
+            ((smpi_comm_rank(comm) == root) && (recvtype == MPI_DATATYPE_NULL))){
     retval = MPI_ERR_TYPE;
+  } else if ((( sendbuf != MPI_IN_PLACE) && (sendcount <0)) ||
+            ((smpi_comm_rank(comm) == root) && (recvcount <0))){
+    retval = MPI_ERR_COUNT;
   } else {
-    mpi_coll_gather_fun(sendbuf, sendcount, sendtype, recvbuf, recvcount,
+
+    char* sendtmpbuf = (char*) sendbuf;
+    int sendtmpcount = sendcount;
+    MPI_Datatype sendtmptype = sendtype;
+    if( (smpi_comm_rank(comm) == root) && (sendbuf == MPI_IN_PLACE )) {
+      sendtmpcount=0;
+      sendtmptype=recvtype;
+    }
+
+    mpi_coll_gather_fun(sendtmpbuf, sendtmpcount, sendtmptype, recvbuf, recvcount,
                     recvtype, root, comm);
+
+
     retval = MPI_SUCCESS;
   }
 #ifdef HAVE_TRACING
@@ -1687,13 +1704,24 @@ int PMPI_Gatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype,
 #endif
   if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
-  } else if (sendtype == MPI_DATATYPE_NULL
-             || recvtype == MPI_DATATYPE_NULL) {
+  } else if ((( sendbuf != MPI_IN_PLACE) && (sendtype == MPI_DATATYPE_NULL)) ||
+            ((smpi_comm_rank(comm) == root) && (recvtype == MPI_DATATYPE_NULL))){
     retval = MPI_ERR_TYPE;
+  } else if (( sendbuf != MPI_IN_PLACE) && (sendcount <0)){
+    retval = MPI_ERR_COUNT;
   } else if (recvcounts == NULL || displs == NULL) {
     retval = MPI_ERR_ARG;
   } else {
-    smpi_mpi_gatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts,
+
+    char* sendtmpbuf = (char*) sendbuf;
+    int sendtmpcount = sendcount;
+    MPI_Datatype sendtmptype = sendtype;
+    if( (smpi_comm_rank(comm) == root) && (sendbuf == MPI_IN_PLACE )) {
+      sendtmpcount=0;
+      sendtmptype=recvtype;
+    }
+
+    smpi_mpi_gatherv(sendtmpbuf, sendtmpcount, sendtmptype, recvbuf, recvcounts,
                      displs, recvtype, root, comm);
     retval = MPI_SUCCESS;
   }
@@ -1719,10 +1747,20 @@ int PMPI_Allgather(void *sendbuf, int sendcount, MPI_Datatype sendtype,
 #endif
   if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
-  } else if (sendtype == MPI_DATATYPE_NULL
-             || recvtype == MPI_DATATYPE_NULL) {
+  } else if ((( sendbuf != MPI_IN_PLACE) && (sendtype == MPI_DATATYPE_NULL)) ||
+            (recvtype == MPI_DATATYPE_NULL)){
     retval = MPI_ERR_TYPE;
+  } else if ((( sendbuf != MPI_IN_PLACE) && (sendcount <0)) ||
+            (recvcount <0)){
+    retval = MPI_ERR_COUNT;
   } else {
+
+    if(sendbuf == MPI_IN_PLACE) {
+      sendbuf=((char*)recvbuf)+smpi_datatype_get_extent(recvtype)*recvcount*smpi_comm_rank(comm);
+      sendcount=recvcount;
+      sendtype=recvtype;
+    }
+
     mpi_coll_allgather_fun(sendbuf, sendcount, sendtype, recvbuf, recvcount,
                            recvtype, comm);
     retval = MPI_SUCCESS;
@@ -1748,12 +1786,21 @@ int PMPI_Allgatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype,
 #endif
   if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
-  } else if (sendtype == MPI_DATATYPE_NULL
-             || recvtype == MPI_DATATYPE_NULL) {
+  } else if ((( sendbuf != MPI_IN_PLACE) && (sendtype == MPI_DATATYPE_NULL)) ||
+            (recvtype == MPI_DATATYPE_NULL)){
     retval = MPI_ERR_TYPE;
+  } else if (( sendbuf != MPI_IN_PLACE) && (sendcount <0)){
+    retval = MPI_ERR_COUNT;
   } else if (recvcounts == NULL || displs == NULL) {
     retval = MPI_ERR_ARG;
   } else {
+
+    if(sendbuf == MPI_IN_PLACE) {
+      sendbuf=((char*)recvbuf)+smpi_datatype_get_extent(recvtype)*displs[smpi_comm_rank(comm)];
+      sendcount=recvcounts[smpi_comm_rank(comm)];
+      sendtype=recvtype;
+    }
+
     mpi_coll_allgatherv_fun(sendbuf, sendcount, sendtype, recvbuf, recvcounts,
                         displs, recvtype, comm);
     retval = MPI_SUCCESS;
@@ -1782,10 +1829,14 @@ int PMPI_Scatter(void *sendbuf, int sendcount, MPI_Datatype sendtype,
 #endif
   if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
-  } else if (sendtype == MPI_DATATYPE_NULL
-             || recvtype == MPI_DATATYPE_NULL) {
+  } else if (((smpi_comm_rank(comm)==root) && (sendtype == MPI_DATATYPE_NULL))
+             || ((recvbuf !=MPI_IN_PLACE) && (recvtype == MPI_DATATYPE_NULL))) {
     retval = MPI_ERR_TYPE;
   } else {
+
+    if(recvbuf==MPI_IN_PLACE){
+       recvcount=0;
+    }
     mpi_coll_scatter_fun(sendbuf, sendcount, sendtype, recvbuf, recvcount,
                      recvtype, root, comm);
     retval = MPI_SUCCESS;
@@ -1813,12 +1864,17 @@ int PMPI_Scatterv(void *sendbuf, int *sendcounts, int *displs,
 #endif
   if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
-  } else if (sendtype == MPI_DATATYPE_NULL
-             || recvtype == MPI_DATATYPE_NULL) {
-    retval = MPI_ERR_TYPE;
   } else if (sendcounts == NULL || displs == NULL) {
     retval = MPI_ERR_ARG;
+  } else if (((smpi_comm_rank(comm)==root) && (sendtype == MPI_DATATYPE_NULL))
+             || ((recvbuf !=MPI_IN_PLACE) && (recvtype == MPI_DATATYPE_NULL))) {
+    retval = MPI_ERR_TYPE;
   } else {
+
+    if(recvbuf==MPI_IN_PLACE){
+       recvcount=0;
+    }
+
     smpi_mpi_scatterv(sendbuf, sendcounts, displs, sendtype, recvbuf,
                       recvcount, recvtype, root, comm);
     retval = MPI_SUCCESS;
@@ -1848,7 +1904,19 @@ int PMPI_Reduce(void *sendbuf, void *recvbuf, int count,
   } else if (datatype == MPI_DATATYPE_NULL || op == MPI_OP_NULL) {
     retval = MPI_ERR_ARG;
   } else {
-    mpi_coll_reduce_fun(sendbuf, recvbuf, count, datatype, op, root, comm);
+
+    char* sendtmpbuf = (char*) sendbuf;
+    if( sendbuf == MPI_IN_PLACE ) {
+      sendtmpbuf = (char *)xbt_malloc(count*smpi_datatype_get_extent(datatype));
+      smpi_datatype_copy(recvbuf, count, datatype,sendtmpbuf, count, datatype);
+    }
+
+    mpi_coll_reduce_fun(sendtmpbuf, recvbuf, count, datatype, op, root, comm);
+
+    if( sendbuf == MPI_IN_PLACE ) {
+      xbt_free(sendtmpbuf);
+    }
+
     retval = MPI_SUCCESS;
   }
 #ifdef HAVE_TRACING
@@ -1892,8 +1960,21 @@ int PMPI_Allreduce(void *sendbuf, void *recvbuf, int count,
   } else if (op == MPI_OP_NULL) {
     retval = MPI_ERR_OP;
   } else {
-      mpi_coll_allreduce_fun(sendbuf, recvbuf, count, datatype, op, comm);
+
+    char* sendtmpbuf = (char*) sendbuf;
+    if( sendbuf == MPI_IN_PLACE ) {
+      sendtmpbuf = (char *)xbt_malloc(count*smpi_datatype_get_extent(datatype));
+      smpi_datatype_copy(recvbuf, count, datatype,sendtmpbuf, count, datatype);
+    }
+
+    mpi_coll_allreduce_fun(sendtmpbuf, recvbuf, count, datatype, op, comm);
+
+    if( sendbuf == MPI_IN_PLACE ) {
+      xbt_free(sendtmpbuf);
+    }
+
     retval = MPI_SUCCESS;
+
   }
 #ifdef HAVE_TRACING
   TRACE_smpi_collective_out(rank, -1, __FUNCTION__);
@@ -1951,8 +2032,12 @@ int PMPI_Reduce_scatter(void *sendbuf, void *recvbuf, int *recvcounts,
   } else if (recvcounts == NULL) {
     retval = MPI_ERR_ARG;
   } else {
+    void* sendtmpbuf=sendbuf;
+    if(sendbuf==MPI_IN_PLACE){
+      sendtmpbuf=recvbuf;
+    }
 
-    mpi_coll_reduce_scatter_fun(sendbuf, recvbuf, recvcounts,
+    mpi_coll_reduce_scatter_fun(sendtmpbuf, recvbuf, recvcounts,
                        datatype,  op, comm);
     retval = MPI_SUCCESS;
   }