From 2a7d4780fb29f8e689819f284389fbc7f9ce8543 Mon Sep 17 00:00:00 2001 From: degomme Date: Thu, 4 Apr 2019 16:22:09 +0200 Subject: [PATCH 1/1] Add checks for comms and datatypes as well --- src/smpi/bindings/smpi_pmpi_request.cpp | 20 ++++----- src/smpi/bindings/smpi_pmpi_type.cpp | 55 ++++++++++++++++--------- src/smpi/mpi/smpi_datatype.cpp | 4 +- 3 files changed, 47 insertions(+), 32 deletions(-) diff --git a/src/smpi/bindings/smpi_pmpi_request.cpp b/src/smpi/bindings/smpi_pmpi_request.cpp index 790d371190..97e9200ecd 100644 --- a/src/smpi/bindings/smpi_pmpi_request.cpp +++ b/src/smpi/bindings/smpi_pmpi_request.cpp @@ -29,7 +29,7 @@ int PMPI_Send_init(void *buf, int count, MPI_Datatype datatype, int dst, int tag retval = MPI_ERR_ARG; } else if (comm == MPI_COMM_NULL) { retval = MPI_ERR_COMM; - } else if (not datatype->is_valid()) { + } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { retval = MPI_ERR_TYPE; } else if (dst == MPI_PROC_NULL) { retval = MPI_SUCCESS; @@ -52,7 +52,7 @@ int PMPI_Recv_init(void *buf, int count, MPI_Datatype datatype, int src, int tag retval = MPI_ERR_ARG; } else if (comm == MPI_COMM_NULL) { retval = MPI_ERR_COMM; - } else if (not datatype->is_valid()) { + } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { retval = MPI_ERR_TYPE; } else if (src == MPI_PROC_NULL) { retval = MPI_SUCCESS; @@ -75,7 +75,7 @@ int PMPI_Ssend_init(void* buf, int count, MPI_Datatype datatype, int dst, int ta retval = MPI_ERR_ARG; } else if (comm == MPI_COMM_NULL) { retval = MPI_ERR_COMM; - } else if (not datatype->is_valid()) { + } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { retval = MPI_ERR_TYPE; } else if (dst == MPI_PROC_NULL) { retval = MPI_SUCCESS; @@ -194,7 +194,7 @@ int PMPI_Irecv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MP retval = MPI_ERR_RANK; } else if ((count < 0) || (buf==nullptr && count > 0)) { retval = MPI_ERR_COUNT; - } else if (not datatype->is_valid()) { + } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { retval = MPI_ERR_TYPE; } else if(tag<0 && tag != MPI_ANY_TAG){ retval = MPI_ERR_TAG; @@ -236,7 +236,7 @@ int PMPI_Isend(void *buf, int count, MPI_Datatype datatype, int dst, int tag, MP retval = MPI_ERR_RANK; } else if ((count < 0) || (buf==nullptr && count > 0)) { retval = MPI_ERR_COUNT; - } else if (not datatype->is_valid()) { + } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { retval = MPI_ERR_TYPE; } else if(tag<0 && tag != MPI_ANY_TAG){ retval = MPI_ERR_TAG; @@ -278,7 +278,7 @@ int PMPI_Issend(void* buf, int count, MPI_Datatype datatype, int dst, int tag, M retval = MPI_ERR_RANK; } else if ((count < 0)|| (buf==nullptr && count > 0)) { retval = MPI_ERR_COUNT; - } else if (not datatype->is_valid()) { + } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { retval = MPI_ERR_TYPE; } else if(tag<0 && tag != MPI_ANY_TAG){ retval = MPI_ERR_TAG; @@ -320,7 +320,7 @@ int PMPI_Recv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MPI retval = MPI_ERR_RANK; } else if ((count < 0) || (buf==nullptr && count > 0)) { retval = MPI_ERR_COUNT; - } else if (not datatype->is_valid()) { + } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { retval = MPI_ERR_TYPE; } else if(tag<0 && tag != MPI_ANY_TAG){ retval = MPI_ERR_TAG; @@ -365,7 +365,7 @@ int PMPI_Send(void *buf, int count, MPI_Datatype datatype, int dst, int tag, MPI retval = MPI_ERR_RANK; } else if ((count < 0) || (buf == nullptr && count > 0)) { retval = MPI_ERR_COUNT; - } else if (not datatype->is_valid()) { + } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { retval = MPI_ERR_TYPE; } else if(tag < 0 && tag != MPI_ANY_TAG){ retval = MPI_ERR_TAG; @@ -403,7 +403,7 @@ int PMPI_Ssend(void* buf, int count, MPI_Datatype datatype, int dst, int tag, MP retval = MPI_ERR_RANK; } else if ((count < 0) || (buf==nullptr && count > 0)) { retval = MPI_ERR_COUNT; - } else if (not datatype->is_valid()) { + } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { retval = MPI_ERR_TYPE; } else if(tag<0 && tag != MPI_ANY_TAG){ retval = MPI_ERR_TAG; @@ -490,7 +490,7 @@ int PMPI_Sendrecv_replace(void* buf, int count, MPI_Datatype datatype, int dst, MPI_Comm comm, MPI_Status* status) { int retval = 0; - if (not datatype->is_valid()) { + if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) { return MPI_ERR_TYPE; } else if (count < 0) { return MPI_ERR_COUNT; diff --git a/src/smpi/bindings/smpi_pmpi_type.cpp b/src/smpi/bindings/smpi_pmpi_type.cpp index 344b038da6..20a79ae9c5 100644 --- a/src/smpi/bindings/smpi_pmpi_type.cpp +++ b/src/smpi/bindings/smpi_pmpi_type.cpp @@ -13,8 +13,8 @@ XBT_LOG_EXTERNAL_DEFAULT_CATEGORY(smpi_pmpi); int PMPI_Type_free(MPI_Datatype * datatype) { /* Free a predefined datatype is an error according to the standard, and should be checked for */ - if (*datatype == MPI_DATATYPE_NULL) { - return MPI_ERR_ARG; + if (*datatype == MPI_DATATYPE_NULL || (*datatype)->flags() & DT_FLAG_PREDEFINED) { + return MPI_ERR_TYPE; } else { simgrid::smpi::Datatype::unref(*datatype); return MPI_SUCCESS; @@ -134,8 +134,10 @@ int PMPI_Type_commit(MPI_Datatype* datatype) { int PMPI_Type_vector(int count, int blocklen, int stride, MPI_Datatype old_type, MPI_Datatype* new_type) { if (old_type == MPI_DATATYPE_NULL) { return MPI_ERR_TYPE; - } else if (count<0 || blocklen<0){ + } else if (count<0){ return MPI_ERR_COUNT; + } else if(blocklen<0){ + return MPI_ERR_ARG; } else { return simgrid::smpi::Datatype::create_vector(count, blocklen, stride, old_type, new_type); } @@ -144,8 +146,10 @@ int PMPI_Type_vector(int count, int blocklen, int stride, MPI_Datatype old_type, int PMPI_Type_hvector(int count, int blocklen, MPI_Aint stride, MPI_Datatype old_type, MPI_Datatype* new_type) { if (old_type == MPI_DATATYPE_NULL) { return MPI_ERR_TYPE; - } else if (count<0 || blocklen<0){ + } else if (count<0){ return MPI_ERR_COUNT; + } else if(blocklen<0){ + return MPI_ERR_ARG; } else { return simgrid::smpi::Datatype::create_hvector(count, blocklen, stride, old_type, new_type); } @@ -228,6 +232,9 @@ int PMPI_Type_struct(int count, int* blocklens, MPI_Aint* indices, MPI_Datatype* if (count<0){ return MPI_ERR_COUNT; } else { + for(int i=0; iis_valid()) + } else if (type == MPI_DATATYPE_NULL || not type->is_valid()){ return MPI_ERR_TYPE; - if(comm==MPI_COMM_NULL) + } else if(comm==MPI_COMM_NULL){ return MPI_ERR_COMM; - return type->unpack(inbuf, incount, position, outbuf,outcount, comm); + } else{ + return type->unpack(inbuf, incount, position, outbuf,outcount, comm); + } } int PMPI_Pack(void* inbuf, int incount, MPI_Datatype type, void* outbuf, int outcount, int* position, MPI_Comm comm) { - if(incount<0 || outcount < 0|| inbuf==nullptr || outbuf==nullptr) + if(incount<0){ + return MPI_ERR_COUNT; + } else if(inbuf==nullptr || outbuf==nullptr || outcount < 0){ return MPI_ERR_ARG; - if (not type->is_valid()) + } else if (type == MPI_DATATYPE_NULL || not type->is_valid()){ return MPI_ERR_TYPE; - if(comm==MPI_COMM_NULL) + } else if(comm==MPI_COMM_NULL){ return MPI_ERR_COMM; - return type->pack(inbuf == MPI_BOTTOM ? nullptr : inbuf, incount, outbuf, outcount, position, comm); + } else { + return type->pack(inbuf == MPI_BOTTOM ? nullptr : inbuf, incount, outbuf, outcount, position, comm); + } } int PMPI_Pack_size(int incount, MPI_Datatype datatype, MPI_Comm comm, int* size) { - if(incount<0) - return MPI_ERR_ARG; - if (not datatype->is_valid()) + if(incount<0){ + return MPI_ERR_COUNT; + } else if (datatype == MPI_DATATYPE_NULL || not datatype->is_valid()){ return MPI_ERR_TYPE; - if(comm==MPI_COMM_NULL) + } else if(comm==MPI_COMM_NULL){ return MPI_ERR_COMM; - - *size=incount*datatype->size(); - - return MPI_SUCCESS; + } else { + *size=incount*datatype->size(); + return MPI_SUCCESS; + } } diff --git a/src/smpi/mpi/smpi_datatype.cpp b/src/smpi/mpi/smpi_datatype.cpp index 8d7a45785b..b8c9812647 100644 --- a/src/smpi/mpi/smpi_datatype.cpp +++ b/src/smpi/mpi/smpi_datatype.cpp @@ -250,7 +250,7 @@ void Datatype::set_name(char* name){ int Datatype::pack(void* inbuf, int incount, void* outbuf, int outcount, int* position,MPI_Comm comm){ if (outcount - *position < incount*static_cast(size_)) - return MPI_ERR_BUFFER; + return MPI_ERR_OTHER; Datatype::copy(inbuf, incount, this, static_cast(outbuf) + *position, outcount, MPI_CHAR); *position += incount * size_; return MPI_SUCCESS; @@ -258,7 +258,7 @@ int Datatype::pack(void* inbuf, int incount, void* outbuf, int outcount, int* po int Datatype::unpack(void* inbuf, int insize, int* position, void* outbuf, int outcount,MPI_Comm comm){ if (outcount*static_cast(size_)> insize) - return MPI_ERR_BUFFER; + return MPI_ERR_OTHER; Datatype::copy(static_cast(inbuf) + *position, insize, MPI_CHAR, outbuf, outcount, this); *position += outcount * size_; return MPI_SUCCESS; -- 2.20.1