X-Git-Url: http://info.iut-bm.univ-fcomte.fr/pub/gitweb/simgrid.git/blobdiff_plain/2d982ca14966f43a6678859d21fa7e796e8aa32f..198a6259727de39df5bb03861f67cb88f9a84373:/src/smpi/smpi_pmpi.c diff --git a/src/smpi/smpi_pmpi.c b/src/smpi/smpi_pmpi.c index c35b8210dd..1c12d353be 100644 --- a/src/smpi/smpi_pmpi.c +++ b/src/smpi/smpi_pmpi.c @@ -350,14 +350,17 @@ int PMPI_Group_translate_ranks(MPI_Group group1, int n, int *ranks1, MPI_Group group2, int *ranks2) { int retval, i, index; - smpi_bench_end(); if (group1 == MPI_GROUP_NULL || group2 == MPI_GROUP_NULL) { retval = MPI_ERR_GROUP; } else { for (i = 0; i < n; i++) { - index = smpi_group_index(group1, ranks1[i]); - ranks2[i] = smpi_group_rank(group2, index); + if(ranks1[i]==MPI_PROC_NULL){ + ranks2[i]=MPI_PROC_NULL; + }else{ + index = smpi_group_index(group1, ranks1[i]); + ranks2[i] = smpi_group_rank(group2, index); + } } retval = MPI_SUCCESS; } @@ -1656,12 +1659,26 @@ int PMPI_Gather(void *sendbuf, int sendcount, MPI_Datatype sendtype, #endif if (comm == MPI_COMM_NULL) { retval = MPI_ERR_COMM; - } else if (sendtype == MPI_DATATYPE_NULL - || recvtype == MPI_DATATYPE_NULL) { + } else if ((( sendbuf != MPI_IN_PLACE) && (sendtype == MPI_DATATYPE_NULL)) || + ((smpi_comm_rank(comm) == root) && (recvtype == MPI_DATATYPE_NULL))){ retval = MPI_ERR_TYPE; + } else if ((( sendbuf != MPI_IN_PLACE) && (sendcount <0)) || + ((smpi_comm_rank(comm) == root) && (recvcount <0))){ + retval = MPI_ERR_COUNT; } else { - mpi_coll_gather_fun(sendbuf, sendcount, sendtype, recvbuf, recvcount, + + char* sendtmpbuf = (char*) sendbuf; + int sendtmpcount = sendcount; + MPI_Datatype sendtmptype = sendtype; + if( (smpi_comm_rank(comm) == root) && (sendbuf == MPI_IN_PLACE )) { + sendtmpcount=0; + sendtmptype=recvtype; + } + + mpi_coll_gather_fun(sendtmpbuf, sendtmpcount, sendtmptype, recvbuf, recvcount, recvtype, root, comm); + + retval = MPI_SUCCESS; } #ifdef HAVE_TRACING @@ -1687,13 +1704,24 @@ int PMPI_Gatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype, #endif if (comm == MPI_COMM_NULL) { retval = MPI_ERR_COMM; - } else if (sendtype == MPI_DATATYPE_NULL - || recvtype == MPI_DATATYPE_NULL) { + } else if ((( sendbuf != MPI_IN_PLACE) && (sendtype == MPI_DATATYPE_NULL)) || + ((smpi_comm_rank(comm) == root) && (recvtype == MPI_DATATYPE_NULL))){ retval = MPI_ERR_TYPE; + } else if (( sendbuf != MPI_IN_PLACE) && (sendcount <0)){ + retval = MPI_ERR_COUNT; } else if (recvcounts == NULL || displs == NULL) { retval = MPI_ERR_ARG; } else { - smpi_mpi_gatherv(sendbuf, sendcount, sendtype, recvbuf, recvcounts, + + char* sendtmpbuf = (char*) sendbuf; + int sendtmpcount = sendcount; + MPI_Datatype sendtmptype = sendtype; + if( (smpi_comm_rank(comm) == root) && (sendbuf == MPI_IN_PLACE )) { + sendtmpcount=0; + sendtmptype=recvtype; + } + + smpi_mpi_gatherv(sendtmpbuf, sendtmpcount, sendtmptype, recvbuf, recvcounts, displs, recvtype, root, comm); retval = MPI_SUCCESS; } @@ -1719,10 +1747,20 @@ int PMPI_Allgather(void *sendbuf, int sendcount, MPI_Datatype sendtype, #endif if (comm == MPI_COMM_NULL) { retval = MPI_ERR_COMM; - } else if (sendtype == MPI_DATATYPE_NULL - || recvtype == MPI_DATATYPE_NULL) { + } else if ((( sendbuf != MPI_IN_PLACE) && (sendtype == MPI_DATATYPE_NULL)) || + (recvtype == MPI_DATATYPE_NULL)){ retval = MPI_ERR_TYPE; + } else if ((( sendbuf != MPI_IN_PLACE) && (sendcount <0)) || + (recvcount <0)){ + retval = MPI_ERR_COUNT; } else { + + if(sendbuf == MPI_IN_PLACE) { + sendbuf=((char*)recvbuf)+smpi_datatype_get_extent(recvtype)*recvcount*smpi_comm_rank(comm); + sendcount=recvcount; + sendtype=recvtype; + } + mpi_coll_allgather_fun(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm); retval = MPI_SUCCESS; @@ -1748,12 +1786,21 @@ int PMPI_Allgatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype, #endif if (comm == MPI_COMM_NULL) { retval = MPI_ERR_COMM; - } else if (sendtype == MPI_DATATYPE_NULL - || recvtype == MPI_DATATYPE_NULL) { + } else if ((( sendbuf != MPI_IN_PLACE) && (sendtype == MPI_DATATYPE_NULL)) || + (recvtype == MPI_DATATYPE_NULL)){ retval = MPI_ERR_TYPE; + } else if (( sendbuf != MPI_IN_PLACE) && (sendcount <0)){ + retval = MPI_ERR_COUNT; } else if (recvcounts == NULL || displs == NULL) { retval = MPI_ERR_ARG; } else { + + if(sendbuf == MPI_IN_PLACE) { + sendbuf=((char*)recvbuf)+smpi_datatype_get_extent(recvtype)*displs[smpi_comm_rank(comm)]; + sendcount=recvcounts[smpi_comm_rank(comm)]; + sendtype=recvtype; + } + mpi_coll_allgatherv_fun(sendbuf, sendcount, sendtype, recvbuf, recvcounts, displs, recvtype, comm); retval = MPI_SUCCESS; @@ -1782,10 +1829,14 @@ int PMPI_Scatter(void *sendbuf, int sendcount, MPI_Datatype sendtype, #endif if (comm == MPI_COMM_NULL) { retval = MPI_ERR_COMM; - } else if (sendtype == MPI_DATATYPE_NULL - || recvtype == MPI_DATATYPE_NULL) { + } else if (((smpi_comm_rank(comm)==root) && (sendtype == MPI_DATATYPE_NULL)) + || ((recvbuf !=MPI_IN_PLACE) && (recvtype == MPI_DATATYPE_NULL))) { retval = MPI_ERR_TYPE; } else { + + if(recvbuf==MPI_IN_PLACE){ + recvcount=0; + } mpi_coll_scatter_fun(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, root, comm); retval = MPI_SUCCESS; @@ -1813,12 +1864,17 @@ int PMPI_Scatterv(void *sendbuf, int *sendcounts, int *displs, #endif if (comm == MPI_COMM_NULL) { retval = MPI_ERR_COMM; - } else if (sendtype == MPI_DATATYPE_NULL - || recvtype == MPI_DATATYPE_NULL) { - retval = MPI_ERR_TYPE; } else if (sendcounts == NULL || displs == NULL) { retval = MPI_ERR_ARG; + } else if (((smpi_comm_rank(comm)==root) && (sendtype == MPI_DATATYPE_NULL)) + || ((recvbuf !=MPI_IN_PLACE) && (recvtype == MPI_DATATYPE_NULL))) { + retval = MPI_ERR_TYPE; } else { + + if(recvbuf==MPI_IN_PLACE){ + recvcount=0; + } + smpi_mpi_scatterv(sendbuf, sendcounts, displs, sendtype, recvbuf, recvcount, recvtype, root, comm); retval = MPI_SUCCESS; @@ -1848,7 +1904,19 @@ int PMPI_Reduce(void *sendbuf, void *recvbuf, int count, } else if (datatype == MPI_DATATYPE_NULL || op == MPI_OP_NULL) { retval = MPI_ERR_ARG; } else { - mpi_coll_reduce_fun(sendbuf, recvbuf, count, datatype, op, root, comm); + + char* sendtmpbuf = (char*) sendbuf; + if( sendbuf == MPI_IN_PLACE ) { + sendtmpbuf = (char *)xbt_malloc(count*smpi_datatype_get_extent(datatype)); + smpi_datatype_copy(recvbuf, count, datatype,sendtmpbuf, count, datatype); + } + + mpi_coll_reduce_fun(sendtmpbuf, recvbuf, count, datatype, op, root, comm); + + if( sendbuf == MPI_IN_PLACE ) { + xbt_free(sendtmpbuf); + } + retval = MPI_SUCCESS; } #ifdef HAVE_TRACING @@ -1892,8 +1960,21 @@ int PMPI_Allreduce(void *sendbuf, void *recvbuf, int count, } else if (op == MPI_OP_NULL) { retval = MPI_ERR_OP; } else { - mpi_coll_allreduce_fun(sendbuf, recvbuf, count, datatype, op, comm); + + char* sendtmpbuf = (char*) sendbuf; + if( sendbuf == MPI_IN_PLACE ) { + sendtmpbuf = (char *)xbt_malloc(count*smpi_datatype_get_extent(datatype)); + smpi_datatype_copy(recvbuf, count, datatype,sendtmpbuf, count, datatype); + } + + mpi_coll_allreduce_fun(sendtmpbuf, recvbuf, count, datatype, op, comm); + + if( sendbuf == MPI_IN_PLACE ) { + xbt_free(sendtmpbuf); + } + retval = MPI_SUCCESS; + } #ifdef HAVE_TRACING TRACE_smpi_collective_out(rank, -1, __FUNCTION__); @@ -1951,8 +2032,12 @@ int PMPI_Reduce_scatter(void *sendbuf, void *recvbuf, int *recvcounts, } else if (recvcounts == NULL) { retval = MPI_ERR_ARG; } else { + void* sendtmpbuf=sendbuf; + if(sendbuf==MPI_IN_PLACE){ + sendtmpbuf=recvbuf; + } - mpi_coll_reduce_scatter_fun(sendbuf, recvbuf, recvcounts, + mpi_coll_reduce_scatter_fun(sendtmpbuf, recvbuf, recvcounts, datatype, op, comm); retval = MPI_SUCCESS; }