From: degomme Date: Thu, 4 Apr 2019 13:29:05 +0000 (+0200) Subject: Add plenty more checks to MPI collectives, to comply with the standard. X-Git-Tag: v3.22.2~176 X-Git-Url: http://info.iut-bm.univ-fcomte.fr/pub/gitweb/simgrid.git/commitdiff_plain/ac307dad2932084d3b5b5af7bcd298d057649c1a?hp=1a1c52b967de67cce8c9c5eefab40b8ca7c106a4 Add plenty more checks to MPI collectives, to comply with the standard. Coverage was a bit too high, this should help reducing it. --- diff --git a/src/smpi/bindings/smpi_pmpi_coll.cpp b/src/smpi/bindings/smpi_pmpi_coll.cpp index cade107058..348b2dab2b 100644 --- a/src/smpi/bindings/smpi_pmpi_coll.cpp +++ b/src/smpi/bindings/smpi_pmpi_coll.cpp @@ -55,9 +55,13 @@ int PMPI_Ibcast(void *buf, int count, MPI_Datatype datatype, smpi_bench_end(); if (comm == MPI_COMM_NULL) { retval = MPI_ERR_COMM; - } else if (not datatype->is_valid()) { - retval = MPI_ERR_ARG; - } else if(request == nullptr){ + } else if (datatype == MPI_DATATYPE_NULL || not datatype->is_valid()) { + retval = MPI_ERR_TYPE; + } else if (count < 0){ + retval = MPI_ERR_COUNT; + } else if (root < 0 || root >= comm->size()){ + retval = MPI_ERR_ROOT; + } else if (request == nullptr){ retval = MPI_ERR_ARG; } else { int rank = simgrid::s4u::this_actor::get_pid(); @@ -101,6 +105,8 @@ int PMPI_Igather(void *sendbuf, int sendcount, MPI_Datatype sendtype,void *recvb retval = MPI_ERR_TYPE; } else if ((( sendbuf != MPI_IN_PLACE) && (sendcount <0)) || ((comm->rank() == root) && (recvcount <0))){ retval = MPI_ERR_COUNT; + } else if (root < 0 || root >= comm->size()){ + retval = MPI_ERR_ROOT; } else if (request == nullptr){ retval = MPI_ERR_ARG; } else { @@ -154,6 +160,8 @@ int PMPI_Igatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *rec retval = MPI_ERR_COUNT; } else if ((comm->rank() == root) && (recvcounts == nullptr || displs == nullptr)) { retval = MPI_ERR_ARG; + } else if (root < 0 || root >= comm->size()){ + retval = MPI_ERR_ROOT; } else if (request == nullptr){ retval = MPI_ERR_ARG; } else { @@ -303,13 +311,18 @@ int PMPI_Iscatter(void *sendbuf, int sendcount, MPI_Datatype sendtype, if (comm == MPI_COMM_NULL) { retval = MPI_ERR_COMM; - } else if (((comm->rank() == root) && (not sendtype->is_valid())) || - ((recvbuf != MPI_IN_PLACE) && (not recvtype->is_valid()))) { + } else if (((comm->rank() == root) && (sendtype == MPI_DATATYPE_NULL || not sendtype->is_valid())) || + ((recvbuf != MPI_IN_PLACE) && (recvtype == MPI_DATATYPE_NULL || not recvtype->is_valid()))) { retval = MPI_ERR_TYPE; + } else if (((comm->rank() == root) && (sendcount < 0)) || + ((recvbuf != MPI_IN_PLACE) && (recvcount < 0))) { + retval = MPI_ERR_COUNT; } else if ((sendbuf == recvbuf) || ((comm->rank()==root) && sendcount>0 && (sendbuf == nullptr))){ retval = MPI_ERR_BUFFER; - }else if (request == nullptr){ + } else if (root < 0 || root >= comm->size()){ + retval = MPI_ERR_ROOT; + } else if (request == nullptr){ retval = MPI_ERR_ARG; } else { @@ -359,9 +372,15 @@ int PMPI_Iscatterv(void *sendbuf, int *sendcounts, int *displs, retval = MPI_ERR_TYPE; } else if (request == nullptr){ retval = MPI_ERR_ARG; + } else if (recvbuf != MPI_IN_PLACE && recvcount < 0){ + retval = MPI_ERR_COUNT; + } else if (root < 0 || root >= comm->size()){ + retval = MPI_ERR_ROOT; } else { if (recvbuf == MPI_IN_PLACE) { recvtype = sendtype; + if(sendcounts[comm->rank()]<0) + return MPI_ERR_COUNT; recvcount = sendcounts[comm->rank()]; } int rank = simgrid::s4u::this_actor::get_pid(); @@ -369,8 +388,11 @@ int PMPI_Iscatterv(void *sendbuf, int *sendcounts, int *displs, std::vector* trace_sendcounts = new std::vector; if (comm->rank() == root) { - for (int i = 0; i < comm->size(); i++) // copy data to avoid bad free + for (int i = 0; i < comm->size(); i++){ // copy data to avoid bad free trace_sendcounts->push_back(sendcounts[i] * dt_size_send); + if(sendcounts[i]<0) + return MPI_ERR_COUNT; + } } TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED?"PMPI_Scatterv":"PMPI_Iscatterv", @@ -403,10 +425,16 @@ int PMPI_Ireduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, if (comm == MPI_COMM_NULL) { retval = MPI_ERR_COMM; - } else if (not datatype->is_valid() || op == MPI_OP_NULL) { - retval = MPI_ERR_ARG; + } else if (datatype == MPI_DATATYPE_NULL || not datatype->is_valid()){ + retval = MPI_ERR_TYPE; + } else if (op == MPI_OP_NULL) { + retval = MPI_ERR_OP; } else if (request == nullptr){ retval = MPI_ERR_ARG; + } else if (root < 0 || root >= comm->size()){ + retval = MPI_ERR_ROOT; + } else if (count < 0){ + retval = MPI_ERR_COUNT; } else { int rank = simgrid::s4u::this_actor::get_pid(); @@ -432,9 +460,13 @@ int PMPI_Reduce_local(void *inbuf, void *inoutbuf, int count, MPI_Datatype datat int retval = 0; smpi_bench_end(); - if (not datatype->is_valid() || op == MPI_OP_NULL) { - retval = MPI_ERR_ARG; - } else { + if (datatype == MPI_DATATYPE_NULL || not datatype->is_valid()){ + retval = MPI_ERR_TYPE; + } else if (op == MPI_OP_NULL) { + retval = MPI_ERR_OP; + } else if (count < 0){ + retval = MPI_ERR_COUNT; + } else { op->apply(inbuf, inoutbuf, &count, datatype); retval = MPI_SUCCESS; } @@ -455,8 +487,10 @@ int PMPI_Iallreduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype dataty if (comm == MPI_COMM_NULL) { retval = MPI_ERR_COMM; - } else if (not datatype->is_valid()) { + } else if (datatype == MPI_DATATYPE_NULL || not datatype->is_valid()) { retval = MPI_ERR_TYPE; + } else if (count < 0){ + retval = MPI_ERR_COUNT; } else if (op == MPI_OP_NULL) { retval = MPI_ERR_OP; } else if (request == nullptr){ @@ -503,12 +537,14 @@ int PMPI_Iscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, M if (comm == MPI_COMM_NULL) { retval = MPI_ERR_COMM; - } else if (not datatype->is_valid()) { + } else if (datatype == MPI_DATATYPE_NULL || not datatype->is_valid()){ retval = MPI_ERR_TYPE; } else if (op == MPI_OP_NULL) { retval = MPI_ERR_OP; } else if (request == nullptr){ retval = MPI_ERR_ARG; + } else if (count < 0){ + retval = MPI_ERR_COUNT; } else { int rank = simgrid::s4u::this_actor::get_pid(); void* sendtmpbuf = sendbuf; @@ -553,6 +589,8 @@ int PMPI_Iexscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, retval = MPI_ERR_OP; } else if (request == nullptr){ retval = MPI_ERR_ARG; + } else if (count < 0){ + retval = MPI_ERR_COUNT; } else { int rank = simgrid::s4u::this_actor::get_pid(); void* sendtmpbuf = sendbuf; @@ -591,7 +629,7 @@ int PMPI_Ireduce_scatter(void *sendbuf, void *recvbuf, int *recvcounts, MPI_Data if (comm == MPI_COMM_NULL) { retval = MPI_ERR_COMM; - } else if (not datatype->is_valid()) { + } else if (datatype == MPI_DATATYPE_NULL || not datatype->is_valid()){ retval = MPI_ERR_TYPE; } else if (op == MPI_OP_NULL) { retval = MPI_ERR_OP; @@ -606,6 +644,8 @@ int PMPI_Ireduce_scatter(void *sendbuf, void *recvbuf, int *recvcounts, MPI_Data int totalcount = 0; for (int i = 0; i < comm->size(); i++) { // copy data to avoid bad free + if(recvcounts[i]<0) + return MPI_ERR_COUNT; trace_recvcounts->push_back(recvcounts[i] * dt_send_size); totalcount += recvcounts[i]; } @@ -709,6 +749,8 @@ int PMPI_Ialltoall(void* sendbuf, int sendcount, MPI_Datatype sendtype, void* re retval = MPI_ERR_COMM; } else if ((sendbuf != MPI_IN_PLACE && sendtype == MPI_DATATYPE_NULL) || recvtype == MPI_DATATYPE_NULL) { retval = MPI_ERR_TYPE; + } else if ((sendbuf != MPI_IN_PLACE && sendcount < 0) || recvcount < 0){ + retval = MPI_ERR_COUNT; } else if (request == nullptr){ retval = MPI_ERR_ARG; } else { @@ -781,6 +823,8 @@ int PMPI_Ialltoallv(void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype MPI_Datatype sendtmptype = sendtype; int maxsize = 0; for (int i = 0; i < size; i++) { // copy data to avoid bad free + if (recvcounts[i] <0 || (sendbuf != MPI_IN_PLACE && sendcounts[i]<0)) + return MPI_ERR_COUNT; recv_size += recvcounts[i] * dt_size_recv; trace_recvcounts->push_back(recvcounts[i] * dt_size_recv); if (((recvdisps[i] + recvcounts[i]) * dt_size_recv) > maxsize) diff --git a/src/smpi/include/smpi_op.hpp b/src/smpi/include/smpi_op.hpp index c0eb212982..dcc39c3994 100644 --- a/src/smpi/include/smpi_op.hpp +++ b/src/smpi/include/smpi_op.hpp @@ -17,9 +17,10 @@ class Op : public F2C{ bool is_commutative_; bool is_fortran_op_ = false; int refcount_ = 1; + bool predefined_; public: - Op(MPI_User_function* function, bool commutative) : func_(function), is_commutative_(commutative) {} + Op(MPI_User_function* function, bool commutative, bool predefined=false) : func_(function), is_commutative_(commutative), predefined_(predefined) {} bool is_commutative() { return is_commutative_; } bool is_fortran_op() { return is_fortran_op_; } // tell that we were created from fortran, so we need to translate the type to fortran when called diff --git a/src/smpi/mpi/smpi_op.cpp b/src/smpi/mpi/smpi_op.cpp index bb52d0e4ec..0da76e479f 100644 --- a/src/smpi/mpi/smpi_op.cpp +++ b/src/smpi/mpi/smpi_op.cpp @@ -196,7 +196,7 @@ static void no_func(void*, void*, int*, MPI_Datatype*) } #define CREATE_MPI_OP(name, func) \ - static SMPI_Op mpi_##name (&(func) /* func */, true ); \ + static SMPI_Op mpi_##name (&(func) /* func */, true, true ); \ MPI_Op name = &mpi_##name; CREATE_MPI_OP(MPI_MAX, max_func); @@ -249,7 +249,7 @@ void Op::ref(){ void Op::unref(MPI_Op* op){ if((*op)!=MPI_OP_NULL){ (*op)->refcount_--; - if((*op)->refcount_==0) + if((*op)->refcount_==0 && (*op)->predefined_==false) delete(*op); } }