Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Add plenty more checks to MPI collectives, to comply with the standard.
authordegomme <adegomme@users.noreply.github.com>
Thu, 4 Apr 2019 13:29:05 +0000 (15:29 +0200)
committerdegomme <adegomme@users.noreply.github.com>
Thu, 4 Apr 2019 14:52:58 +0000 (16:52 +0200)
Coverage was a bit too high, this should help reducing it.

src/smpi/bindings/smpi_pmpi_coll.cpp
src/smpi/include/smpi_op.hpp
src/smpi/mpi/smpi_op.cpp

index cade107..348b2da 100644 (file)
@@ -55,9 +55,13 @@ int PMPI_Ibcast(void *buf, int count, MPI_Datatype datatype,
   smpi_bench_end();
   if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
-  } else if (not datatype->is_valid()) {
-    retval = MPI_ERR_ARG;
-  } else if(request == nullptr){
+  } else if (datatype == MPI_DATATYPE_NULL || not datatype->is_valid()) {
+    retval = MPI_ERR_TYPE;
+  } else if (count < 0){
+    retval = MPI_ERR_COUNT;
+  } else if (root < 0 || root >= comm->size()){
+    retval = MPI_ERR_ROOT;
+  }  else if (request == nullptr){
     retval = MPI_ERR_ARG;
   } else {
     int rank = simgrid::s4u::this_actor::get_pid();
@@ -101,6 +105,8 @@ int PMPI_Igather(void *sendbuf, int sendcount, MPI_Datatype sendtype,void *recvb
     retval = MPI_ERR_TYPE;
   } else if ((( sendbuf != MPI_IN_PLACE) && (sendcount <0)) || ((comm->rank() == root) && (recvcount <0))){
     retval = MPI_ERR_COUNT;
+  } else if (root < 0 || root >= comm->size()){
+    retval = MPI_ERR_ROOT;
   } else if (request == nullptr){
     retval = MPI_ERR_ARG;
   }  else {
@@ -154,6 +160,8 @@ int PMPI_Igatherv(void *sendbuf, int sendcount, MPI_Datatype sendtype, void *rec
     retval = MPI_ERR_COUNT;
   } else if ((comm->rank() == root) && (recvcounts == nullptr || displs == nullptr)) {
     retval = MPI_ERR_ARG;
+  } else if (root < 0 || root >= comm->size()){
+    retval = MPI_ERR_ROOT;
   } else if (request == nullptr){
     retval = MPI_ERR_ARG;
   }  else {
@@ -303,13 +311,18 @@ int PMPI_Iscatter(void *sendbuf, int sendcount, MPI_Datatype sendtype,
 
   if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
-  } else if (((comm->rank() == root) && (not sendtype->is_valid())) ||
-             ((recvbuf != MPI_IN_PLACE) && (not recvtype->is_valid()))) {
+  } else if (((comm->rank() == root) && (sendtype == MPI_DATATYPE_NULL || not sendtype->is_valid())) ||
+             ((recvbuf != MPI_IN_PLACE) && (recvtype == MPI_DATATYPE_NULL || not recvtype->is_valid()))) {
     retval = MPI_ERR_TYPE;
+  } else if (((comm->rank() == root) && (sendcount < 0)) ||
+             ((recvbuf != MPI_IN_PLACE) && (recvcount < 0))) {
+    retval = MPI_ERR_COUNT;
   } else if ((sendbuf == recvbuf) ||
       ((comm->rank()==root) && sendcount>0 && (sendbuf == nullptr))){
     retval = MPI_ERR_BUFFER;
-  }else if (request == nullptr){
+  } else if (root < 0 || root >= comm->size()){
+    retval = MPI_ERR_ROOT;
+  } else if (request == nullptr){
     retval = MPI_ERR_ARG;
   } else {
 
@@ -359,9 +372,15 @@ int PMPI_Iscatterv(void *sendbuf, int *sendcounts, int *displs,
     retval = MPI_ERR_TYPE;
   } else if (request == nullptr){
     retval = MPI_ERR_ARG;
+  } else if (recvbuf != MPI_IN_PLACE && recvcount < 0){
+    retval = MPI_ERR_COUNT;
+  } else if (root < 0 || root >= comm->size()){
+    retval = MPI_ERR_ROOT;
   } else {
     if (recvbuf == MPI_IN_PLACE) {
       recvtype  = sendtype;
+      if(sendcounts[comm->rank()]<0)
+        return MPI_ERR_COUNT;
       recvcount = sendcounts[comm->rank()];
     }
     int rank               = simgrid::s4u::this_actor::get_pid();
@@ -369,8 +388,11 @@ int PMPI_Iscatterv(void *sendbuf, int *sendcounts, int *displs,
 
     std::vector<int>* trace_sendcounts = new std::vector<int>;
     if (comm->rank() == root) {
-      for (int i = 0; i < comm->size(); i++) // copy data to avoid bad free
+      for (int i = 0; i < comm->size(); i++){ // copy data to avoid bad free
         trace_sendcounts->push_back(sendcounts[i] * dt_size_send);
+        if(sendcounts[i]<0)
+          return MPI_ERR_COUNT;
+        }
     }
 
     TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED?"PMPI_Scatterv":"PMPI_Iscatterv",
@@ -403,10 +425,16 @@ int PMPI_Ireduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
 
   if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
-  } else if (not datatype->is_valid() || op == MPI_OP_NULL) {
-    retval = MPI_ERR_ARG;
+  } else if (datatype == MPI_DATATYPE_NULL || not datatype->is_valid()){
+    retval = MPI_ERR_TYPE;
+  } else if (op == MPI_OP_NULL) {
+    retval = MPI_ERR_OP;
   } else if (request == nullptr){
     retval = MPI_ERR_ARG;
+  } else if (root < 0 || root >= comm->size()){
+    retval = MPI_ERR_ROOT;
+  } else if (count < 0){
+    retval = MPI_ERR_COUNT;
   } else {
     int rank = simgrid::s4u::this_actor::get_pid();
 
@@ -432,9 +460,13 @@ int PMPI_Reduce_local(void *inbuf, void *inoutbuf, int count, MPI_Datatype datat
   int retval = 0;
 
   smpi_bench_end();
-  if (not datatype->is_valid() || op == MPI_OP_NULL) {
-    retval = MPI_ERR_ARG;
-  } else {
+  if (datatype == MPI_DATATYPE_NULL || not datatype->is_valid()){
+    retval = MPI_ERR_TYPE;
+  } else if (op == MPI_OP_NULL) {
+    retval = MPI_ERR_OP;
+  } else if (count < 0){
+    retval = MPI_ERR_COUNT;
+  }  else {
     op->apply(inbuf, inoutbuf, &count, datatype);
     retval = MPI_SUCCESS;
   }
@@ -455,8 +487,10 @@ int PMPI_Iallreduce(void *sendbuf, void *recvbuf, int count, MPI_Datatype dataty
 
   if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
-  } else if (not datatype->is_valid()) {
+  } else if (datatype == MPI_DATATYPE_NULL || not datatype->is_valid()) {
     retval = MPI_ERR_TYPE;
+  } else if (count < 0){
+    retval = MPI_ERR_COUNT;
   } else if (op == MPI_OP_NULL) {
     retval = MPI_ERR_OP;
   } else if (request == nullptr){
@@ -503,12 +537,14 @@ int PMPI_Iscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, M
 
   if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
-  } else if (not datatype->is_valid()) {
+  } else if (datatype == MPI_DATATYPE_NULL || not datatype->is_valid()){
     retval = MPI_ERR_TYPE;
   } else if (op == MPI_OP_NULL) {
     retval = MPI_ERR_OP;
   } else if (request == nullptr){
     retval = MPI_ERR_ARG;
+  } else if (count < 0){
+    retval = MPI_ERR_COUNT;
   } else {
     int rank = simgrid::s4u::this_actor::get_pid();
     void* sendtmpbuf = sendbuf;
@@ -553,6 +589,8 @@ int PMPI_Iexscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
     retval = MPI_ERR_OP;
   } else if (request == nullptr){
     retval = MPI_ERR_ARG;
+  } else if (count < 0){
+    retval = MPI_ERR_COUNT;
   } else {
     int rank         = simgrid::s4u::this_actor::get_pid();
     void* sendtmpbuf = sendbuf;
@@ -591,7 +629,7 @@ int PMPI_Ireduce_scatter(void *sendbuf, void *recvbuf, int *recvcounts, MPI_Data
 
   if (comm == MPI_COMM_NULL) {
     retval = MPI_ERR_COMM;
-  } else if (not datatype->is_valid()) {
+  } else if (datatype == MPI_DATATYPE_NULL || not datatype->is_valid()){
     retval = MPI_ERR_TYPE;
   } else if (op == MPI_OP_NULL) {
     retval = MPI_ERR_OP;
@@ -606,6 +644,8 @@ int PMPI_Ireduce_scatter(void *sendbuf, void *recvbuf, int *recvcounts, MPI_Data
     int totalcount    = 0;
 
     for (int i = 0; i < comm->size(); i++) { // copy data to avoid bad free
+      if(recvcounts[i]<0)
+        return MPI_ERR_COUNT;
       trace_recvcounts->push_back(recvcounts[i] * dt_send_size);
       totalcount += recvcounts[i];
     }
@@ -709,6 +749,8 @@ int PMPI_Ialltoall(void* sendbuf, int sendcount, MPI_Datatype sendtype, void* re
     retval = MPI_ERR_COMM;
   } else if ((sendbuf != MPI_IN_PLACE && sendtype == MPI_DATATYPE_NULL) || recvtype == MPI_DATATYPE_NULL) {
     retval = MPI_ERR_TYPE;
+  } else if ((sendbuf != MPI_IN_PLACE && sendcount < 0) || recvcount < 0){
+    retval = MPI_ERR_COUNT;
   } else if (request == nullptr){
     retval = MPI_ERR_ARG;
   } else {
@@ -781,6 +823,8 @@ int PMPI_Ialltoallv(void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype
     MPI_Datatype sendtmptype = sendtype;
     int maxsize              = 0;
     for (int i = 0; i < size; i++) { // copy data to avoid bad free
+      if (recvcounts[i] <0 || (sendbuf != MPI_IN_PLACE && sendcounts[i]<0))
+        return MPI_ERR_COUNT;
       recv_size += recvcounts[i] * dt_size_recv;
       trace_recvcounts->push_back(recvcounts[i] * dt_size_recv);
       if (((recvdisps[i] + recvcounts[i]) * dt_size_recv) > maxsize)
index c0eb212..dcc39c3 100644 (file)
@@ -17,9 +17,10 @@ class Op : public F2C{
   bool is_commutative_;
   bool is_fortran_op_ = false;
   int refcount_ = 1;
+  bool predefined_;
 
 public:
-  Op(MPI_User_function* function, bool commutative) : func_(function), is_commutative_(commutative) {}
+  Op(MPI_User_function* function, bool commutative, bool predefined=false) : func_(function), is_commutative_(commutative), predefined_(predefined) {}
   bool is_commutative() { return is_commutative_; }
   bool is_fortran_op() { return is_fortran_op_; }
   // tell that we were created from fortran, so we need to translate the type to fortran when called
index bb52d0e..0da76e4 100644 (file)
@@ -196,7 +196,7 @@ static void no_func(void*, void*, int*, MPI_Datatype*)
 }
 
 #define CREATE_MPI_OP(name, func)                             \
-  static SMPI_Op mpi_##name (&(func) /* func */, true ); \
+  static SMPI_Op mpi_##name (&(func) /* func */, true, true ); \
 MPI_Op name = &mpi_##name;
 
 CREATE_MPI_OP(MPI_MAX, max_func);
@@ -249,7 +249,7 @@ void Op::ref(){
 void Op::unref(MPI_Op* op){
   if((*op)!=MPI_OP_NULL){
     (*op)->refcount_--;
-    if((*op)->refcount_==0)
+    if((*op)->refcount_==0 && (*op)->predefined_==false)
       delete(*op);
   }
 }