Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Unify input checking using shared macros. Avoid repeating code.
authorAugustin Degomme <adegomme@gmail.com>
Mon, 9 Dec 2019 17:43:42 +0000 (18:43 +0100)
committerAugustin Degomme <adegomme@gmail.com>
Mon, 9 Dec 2019 17:44:13 +0000 (18:44 +0100)
src/smpi/bindings/smpi_pmpi_coll.cpp
src/smpi/bindings/smpi_pmpi_file.cpp
src/smpi/bindings/smpi_pmpi_request.cpp
src/smpi/include/private.hpp

index 446501d..0a97fa0 100644 (file)
 
 XBT_LOG_EXTERNAL_DEFAULT_CATEGORY(smpi_pmpi);
 
-#define CHECK_ARGS(test, errcode, ...)                                                                                 \
-  if (test) {                                                                                                          \
-    XBT_WARN(__VA_ARGS__);                                                                                             \
-    return (errcode);                                                                                                  \
-  }
-
-#define CHECK_COMM(num)\
-  CHECK_ARGS(comm == MPI_COMM_NULL, MPI_ERR_COMM,\
-             "%s: param %d communicator cannot be MPI_COMM_NULL", __func__, num);
-#define CHECK_REQUEST(num)\
-  CHECK_ARGS(request == nullptr, MPI_ERR_ARG,\
-             "%s: param %d request cannot be NULL",__func__, num);
-#define CHECK_BUFFER(num,buf,count)\
-  CHECK_ARGS(buf == nullptr && count > 0, MPI_ERR_BUFFER,\
-             "%s: param %d %s cannot be NULL if %s > 0",__func__, num, #buf, #count);
-#define CHECK_COUNT(num,count)\
-  CHECK_ARGS(count < 0, MPI_ERR_COUNT,\
-             "%s: param %d %s cannot be negative", __func__, num, #count);
-#define CHECK_TYPE(num, datatype)\
-  CHECK_ARGS((datatype == MPI_DATATYPE_NULL|| not datatype->is_valid()), MPI_ERR_TYPE,\
-             "%s: param %d %s cannot be MPI_DATATYPE_NULL or invalid", __func__, num, #datatype);
-#define CHECK_OP(num)\
-  CHECK_ARGS(op == MPI_OP_NULL, MPI_ERR_OP,\
-             "%s: param %d op cannot be MPI_OP_NULL or invalid", __func__, num);
-#define CHECK_ROOT(num)\
-  CHECK_ARGS((root < 0 || root >= comm->size()), MPI_ERR_ROOT,\
-             "%s: param %d root (=%d) cannot be negative or larger than communicator size (=%d)", __func__, num, root,\
-             comm->size());
-#define CHECK_NULL(num,err,buf)\
-  CHECK_ARGS(buf == nullptr, err,\
-             "%s: param %d %s cannot be NULL", __func__, num, #buf);
-
   static const void* smpi_get_in_place_buf(const void* inplacebuf, const void* otherbuf,std::unique_ptr<unsigned char[]>& tmp_sendbuf, int count, MPI_Datatype datatype){
   if (inplacebuf == MPI_IN_PLACE) {
       tmp_sendbuf.reset(new unsigned char[count * datatype->get_extent()]);
index d8a1ceb..32e1750 100644 (file)
@@ -8,6 +8,8 @@
 #include "smpi_file.hpp"
 #include "smpi_datatype.hpp"
 
+XBT_LOG_EXTERNAL_DEFAULT_CATEGORY(smpi_pmpi);
+
 extern MPI_Errhandler SMPI_default_File_Errhandler;
 
 int PMPI_File_open(MPI_Comm comm, const char *filename, int amode, MPI_Info info, MPI_File *fh){
@@ -38,21 +40,11 @@ int PMPI_File_close(MPI_File *fh){
   smpi_bench_begin();
   return ret;
 }
-#define CHECK_FILE(fh)                                                                                                 \
-  if ((fh) == MPI_FILE_NULL)                                                                                           \
-    return MPI_ERR_FILE;
-#define CHECK_BUFFER(buf, count)                                                                                       \
-  if ((buf) == nullptr && (count) > 0)                                                                                 \
-    return MPI_ERR_BUFFER;
-#define CHECK_COUNT(count)                                                                                             \
-  if ((count) < 0)                                                                                                     \
-    return MPI_ERR_COUNT;
+
+
 #define CHECK_OFFSET(offset)                                                                                           \
   if ((offset) < 0)                                                                                                    \
     return MPI_ERR_DISP;
-#define CHECK_DATATYPE(datatype, count)                                                                                \
-  if ((datatype) == MPI_DATATYPE_NULL && (count) > 0)                                                                  \
-    return MPI_ERR_TYPE;
 #define CHECK_STATUS(status)                                                                                           \
   if ((status) == nullptr)                                                                                             \
     return MPI_ERR_ARG;
@@ -70,7 +62,7 @@ int PMPI_File_close(MPI_File *fh){
   }
 
 int PMPI_File_seek(MPI_File fh, MPI_Offset offset, int whence){
-  CHECK_FILE(fh)
+  CHECK_FILE(1, fh)
   smpi_bench_end();
   int ret = fh->seek(offset,whence);
   smpi_bench_begin();
@@ -78,7 +70,7 @@ int PMPI_File_seek(MPI_File fh, MPI_Offset offset, int whence){
 }
 
 int PMPI_File_seek_shared(MPI_File fh, MPI_Offset offset, int whence){
-  CHECK_FILE(fh)
+  CHECK_FILE(1, fh)
   smpi_bench_end();
   int ret = fh->seek_shared(offset,whence);
   smpi_bench_begin();
@@ -95,7 +87,7 @@ int PMPI_File_get_position(MPI_File fh, MPI_Offset* offset){
 }
 
 int PMPI_File_get_position_shared(MPI_File fh, MPI_Offset* offset){
-  CHECK_FILE(fh)
+  CHECK_FILE(1, fh)
   if (offset==nullptr)
     return MPI_ERR_DISP;
   smpi_bench_end();
@@ -105,10 +97,10 @@ int PMPI_File_get_position_shared(MPI_File fh, MPI_Offset* offset){
 }
 
 int PMPI_File_read(MPI_File fh, void *buf, int count,MPI_Datatype datatype, MPI_Status *status){
-  CHECK_FILE(fh)
-  CHECK_BUFFER(buf, count)
-  CHECK_COUNT(count)
-  CHECK_DATATYPE(datatype, count)
+  CHECK_FILE(1, fh)
+  CHECK_BUFFER(2, buf, count)
+  CHECK_COUNT(3, count)
+  CHECK_TYPE(4, datatype)
   CHECK_STATUS(status)
   CHECK_FLAGS(fh)
   PASS_ZEROCOUNT(count)
@@ -122,10 +114,10 @@ int PMPI_File_read(MPI_File fh, void *buf, int count,MPI_Datatype datatype, MPI_
 }
 
 int PMPI_File_read_shared(MPI_File fh, void *buf, int count,MPI_Datatype datatype, MPI_Status *status){
-  CHECK_FILE(fh)
-  CHECK_BUFFER(buf, count)
-  CHECK_COUNT(count)
-  CHECK_DATATYPE(datatype, count)
+  CHECK_FILE(1, fh)
+  CHECK_BUFFER(2, buf, count)
+  CHECK_COUNT(3, count)
+  CHECK_TYPE(4, datatype)
   CHECK_STATUS(status)
   CHECK_FLAGS(fh)
   PASS_ZEROCOUNT(count)
@@ -140,10 +132,10 @@ int PMPI_File_read_shared(MPI_File fh, void *buf, int count,MPI_Datatype datatyp
 }
 
 int PMPI_File_write(MPI_File fh, const void *buf, int count,MPI_Datatype datatype, MPI_Status *status){
-  CHECK_FILE(fh)
-  CHECK_BUFFER(buf, count)
-  CHECK_COUNT(count)
-  CHECK_DATATYPE(datatype, count)
+  CHECK_FILE(1, fh)
+  CHECK_BUFFER(2, buf, count)
+  CHECK_COUNT(3, count)
+  CHECK_TYPE(4, datatype)
   CHECK_STATUS(status)
   CHECK_FLAGS(fh)
   CHECK_RDONLY(fh)
@@ -158,10 +150,10 @@ int PMPI_File_write(MPI_File fh, const void *buf, int count,MPI_Datatype datatyp
 }
 
 int PMPI_File_write_shared(MPI_File fh, const void *buf, int count,MPI_Datatype datatype, MPI_Status *status){
-  CHECK_FILE(fh)
-  CHECK_BUFFER(buf, count)
-  CHECK_COUNT(count)
-  CHECK_DATATYPE(datatype, count)
+  CHECK_FILE(1, fh)
+  CHECK_BUFFER(2, buf, count)
+  CHECK_COUNT(3, count)
+  CHECK_TYPE(4, datatype)
   CHECK_STATUS(status)
   CHECK_FLAGS(fh)
   CHECK_RDONLY(fh)
@@ -177,10 +169,10 @@ int PMPI_File_write_shared(MPI_File fh, const void *buf, int count,MPI_Datatype
 }
 
 int PMPI_File_read_all(MPI_File fh, void *buf, int count,MPI_Datatype datatype, MPI_Status *status){
-  CHECK_FILE(fh)
-  CHECK_BUFFER(buf, count)
-  CHECK_COUNT(count)
-  CHECK_DATATYPE(datatype, count)
+  CHECK_FILE(1, fh)
+  CHECK_BUFFER(2, buf, count)
+  CHECK_COUNT(3, count)
+  CHECK_TYPE(4, datatype)
   CHECK_STATUS(status)
   CHECK_FLAGS(fh)
   smpi_bench_end();
@@ -193,10 +185,10 @@ int PMPI_File_read_all(MPI_File fh, void *buf, int count,MPI_Datatype datatype,
 }
 
 int PMPI_File_read_ordered(MPI_File fh, void *buf, int count,MPI_Datatype datatype, MPI_Status *status){
-  CHECK_FILE(fh)
-  CHECK_BUFFER(buf, count)
-  CHECK_COUNT(count)
-  CHECK_DATATYPE(datatype, count)
+  CHECK_FILE(1, fh)
+  CHECK_BUFFER(2, buf, count)
+  CHECK_COUNT(3, count)
+  CHECK_TYPE(4, datatype)
   CHECK_STATUS(status)
   CHECK_FLAGS(fh)
   smpi_bench_end();
@@ -210,10 +202,10 @@ int PMPI_File_read_ordered(MPI_File fh, void *buf, int count,MPI_Datatype dataty
 }
 
 int PMPI_File_write_all(MPI_File fh, const void *buf, int count,MPI_Datatype datatype, MPI_Status *status){
-  CHECK_FILE(fh)
-  CHECK_BUFFER(buf, count)
-  CHECK_COUNT(count)
-  CHECK_DATATYPE(datatype, count)
+  CHECK_FILE(1, fh)
+  CHECK_BUFFER(2, buf, count)
+  CHECK_COUNT(3, count)
+  CHECK_TYPE(4, datatype)
   CHECK_STATUS(status)
   CHECK_FLAGS(fh)
   CHECK_RDONLY(fh)
@@ -227,10 +219,10 @@ int PMPI_File_write_all(MPI_File fh, const void *buf, int count,MPI_Datatype dat
 }
 
 int PMPI_File_write_ordered(MPI_File fh, const void *buf, int count,MPI_Datatype datatype, MPI_Status *status){
-  CHECK_FILE(fh)
-  CHECK_BUFFER(buf, count)
-  CHECK_COUNT(count)
-  CHECK_DATATYPE(datatype, count)
+  CHECK_FILE(1, fh)
+  CHECK_BUFFER(2, buf, count)
+  CHECK_COUNT(3, count)
+  CHECK_TYPE(4, datatype)
   CHECK_STATUS(status)
   CHECK_FLAGS(fh)
   CHECK_RDONLY(fh)
@@ -245,11 +237,11 @@ int PMPI_File_write_ordered(MPI_File fh, const void *buf, int count,MPI_Datatype
 }
 
 int PMPI_File_read_at(MPI_File fh, MPI_Offset offset, void *buf, int count,MPI_Datatype datatype, MPI_Status *status){
-  CHECK_FILE(fh)
-  CHECK_BUFFER(buf, count)
+  CHECK_FILE(1, fh)
+  CHECK_BUFFER(2, buf, count)
   CHECK_OFFSET(offset)
-  CHECK_COUNT(count)
-  CHECK_DATATYPE(datatype, count)
+  CHECK_COUNT(3, count)
+  CHECK_TYPE(4, datatype)
   CHECK_STATUS(status)
   CHECK_FLAGS(fh)
   PASS_ZEROCOUNT(count);
@@ -266,11 +258,11 @@ int PMPI_File_read_at(MPI_File fh, MPI_Offset offset, void *buf, int count,MPI_D
 }
 
 int PMPI_File_read_at_all(MPI_File fh, MPI_Offset offset, void *buf, int count,MPI_Datatype datatype, MPI_Status *status){
-  CHECK_FILE(fh)
-  CHECK_BUFFER(buf, count)
+  CHECK_FILE(1, fh)
+  CHECK_BUFFER(2, buf, count)
   CHECK_OFFSET(offset)
-  CHECK_COUNT(count)
-  CHECK_DATATYPE(datatype, count)
+  CHECK_COUNT(3, count)
+  CHECK_TYPE(4, datatype)
   CHECK_STATUS(status)
   CHECK_FLAGS(fh)
   smpi_bench_end();
@@ -287,11 +279,11 @@ int PMPI_File_read_at_all(MPI_File fh, MPI_Offset offset, void *buf, int count,M
 }
 
 int PMPI_File_write_at(MPI_File fh, MPI_Offset offset, const void *buf, int count,MPI_Datatype datatype, MPI_Status *status){
-  CHECK_FILE(fh)
-  CHECK_BUFFER(buf, count)
+  CHECK_FILE(1, fh)
+  CHECK_BUFFER(2, buf, count)
   CHECK_OFFSET(offset)
-  CHECK_COUNT(count)
-  CHECK_DATATYPE(datatype, count)
+  CHECK_COUNT(4, count)
+  CHECK_TYPE(5, datatype)
   CHECK_STATUS(status)
   CHECK_FLAGS(fh)
   CHECK_RDONLY(fh)
@@ -309,11 +301,11 @@ int PMPI_File_write_at(MPI_File fh, MPI_Offset offset, const void *buf, int coun
 }
 
 int PMPI_File_write_at_all(MPI_File fh, MPI_Offset offset, const void *buf, int count,MPI_Datatype datatype, MPI_Status *status){
-  CHECK_FILE(fh)
-  CHECK_BUFFER(buf, count)
+  CHECK_FILE(1, fh)
+  CHECK_BUFFER(2, buf, count)
   CHECK_OFFSET(offset)
-  CHECK_COUNT(count)
-  CHECK_DATATYPE(datatype, count)
+  CHECK_COUNT(4, count)
+  CHECK_TYPE(5, datatype)
   CHECK_STATUS(status)
   CHECK_FLAGS(fh)
   CHECK_RDONLY(fh)
@@ -341,42 +333,42 @@ int PMPI_File_delete(const char *filename, MPI_Info info){
 
 int PMPI_File_get_info(MPI_File  fh, MPI_Info* info)
 {
-  CHECK_FILE(fh)
+  CHECK_FILE(1, fh)
   *info = fh->info();
   return MPI_SUCCESS;
 }
 
 int PMPI_File_set_info(MPI_File  fh, MPI_Info info)
 {
-  CHECK_FILE(fh)
+  CHECK_FILE(1, fh)
   fh->set_info(info);
   return MPI_SUCCESS;
 }
 
 int PMPI_File_get_size(MPI_File  fh, MPI_Offset* size)
 {
-  CHECK_FILE(fh)
+  CHECK_FILE(1, fh)
   *size = fh->size();
   return MPI_SUCCESS;
 }
 
 int PMPI_File_get_amode(MPI_File  fh, int* amode)
 {
-  CHECK_FILE(fh)
+  CHECK_FILE(1, fh)
   *amode = fh->flags();
   return MPI_SUCCESS;
 }
 
 int PMPI_File_get_group(MPI_File  fh, MPI_Group* group)
 {
-  CHECK_FILE(fh)
+  CHECK_FILE(1, fh)
   *group = fh->comm()->group();
   return MPI_SUCCESS;
 }
 
 int PMPI_File_sync(MPI_File  fh)
 {
-  CHECK_FILE(fh)
+  CHECK_FILE(1, fh)
   fh->sync();
   return MPI_SUCCESS;
 }
index 62e4822..3caec95 100644 (file)
@@ -18,52 +18,49 @@ static int getPid(MPI_Comm comm, int id)
   return (actor == nullptr) ? MPI_UNDEFINED : actor->get_pid();
 }
 
+#define CHECK_SEND_INPUTS\
+  CHECK_BUFFER(1, buf, count)\
+  CHECK_COUNT(2, count)\
+  CHECK_TYPE(3, datatype)\
+  CHECK_PROC(4, dst)\
+  CHECK_TAG(5, tag)\
+  CHECK_COMM(6)\
+
+#define CHECK_ISEND_INPUTS\
+  CHECK_REQUEST(7)\
+  *request = MPI_REQUEST_NULL;\
+  CHECK_SEND_INPUTS
+  
+#define CHECK_IRECV_INPUTS\
+  CHECK_REQUEST(7)\
+  *request = MPI_REQUEST_NULL;\
+  CHECK_BUFFER(1, buf, count)\
+  CHECK_COUNT(2, count)\
+  CHECK_TYPE(3, datatype)\
+  CHECK_PROC(4, src)\
+  CHECK_TAG(5, tag)\
+  CHECK_COMM(6)
 /* PMPI User level calls */
 
 int PMPI_Send_init(const void *buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request * request)
 {
-  int retval = 0;
+  CHECK_ISEND_INPUTS
 
   smpi_bench_end();
-  if (request == nullptr) {
-    retval = MPI_ERR_ARG;
-  } else if (comm == MPI_COMM_NULL) {
-    retval = MPI_ERR_COMM;
-  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
-    retval = MPI_ERR_TYPE;
-  } else if (dst == MPI_PROC_NULL) {
-    retval = MPI_SUCCESS;
-  } else {
-    *request = simgrid::smpi::Request::send_init(buf, count, datatype, dst, tag, comm);
-    retval   = MPI_SUCCESS;
-  }
+  *request = simgrid::smpi::Request::send_init(buf, count, datatype, dst, tag, comm);
   smpi_bench_begin();
-  if (retval != MPI_SUCCESS && request != nullptr)
-    *request = MPI_REQUEST_NULL;
-  return retval;
+
+  return MPI_SUCCESS;
 }
 
 int PMPI_Recv_init(void *buf, int count, MPI_Datatype datatype, int src, int tag, MPI_Comm comm, MPI_Request * request)
 {
-  int retval = 0;
+  CHECK_IRECV_INPUTS
 
   smpi_bench_end();
-  if (request == nullptr) {
-    retval = MPI_ERR_ARG;
-  } else if (comm == MPI_COMM_NULL) {
-    retval = MPI_ERR_COMM;
-  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
-    retval = MPI_ERR_TYPE;
-  } else if (src == MPI_PROC_NULL) {
-    retval = MPI_SUCCESS;
-  } else {
-    *request = simgrid::smpi::Request::recv_init(buf, count, datatype, src, tag, comm);
-    retval = MPI_SUCCESS;
-  }
+  *request = simgrid::smpi::Request::recv_init(buf, count, datatype, src, tag, comm);
   smpi_bench_begin();
-  if (retval != MPI_SUCCESS && request != nullptr)
-    *request = MPI_REQUEST_NULL;
-  return retval;
+  return MPI_SUCCESS;
 }
 
 int PMPI_Rsend_init(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm,
@@ -74,24 +71,14 @@ int PMPI_Rsend_init(const void* buf, int count, MPI_Datatype datatype, int dst,
 
 int PMPI_Ssend_init(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request* request)
 {
-  int retval = 0;
+  CHECK_ISEND_INPUTS
 
+  int retval = 0;
   smpi_bench_end();
-  if (request == nullptr) {
-    retval = MPI_ERR_ARG;
-  } else if (comm == MPI_COMM_NULL) {
-    retval = MPI_ERR_COMM;
-  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
-    retval = MPI_ERR_TYPE;
-  } else if (dst == MPI_PROC_NULL) {
-    retval = MPI_SUCCESS;
-  } else {
-    *request = simgrid::smpi::Request::ssend_init(buf, count, datatype, dst, tag, comm);
-    retval = MPI_SUCCESS;
-  }
+  *request = simgrid::smpi::Request::ssend_init(buf, count, datatype, dst, tag, comm);
+  retval = MPI_SUCCESS;
+
   smpi_bench_begin();
-  if (retval != MPI_SUCCESS && request != nullptr)
-    *request = MPI_REQUEST_NULL;
   return retval;
 }
 
@@ -105,7 +92,8 @@ int PMPI_Start(MPI_Request * request)
   int retval = 0;
 
   smpi_bench_end();
-  if (request == nullptr || *request == MPI_REQUEST_NULL) {
+  CHECK_REQUEST(1)
+  if ( *request == MPI_REQUEST_NULL) {
     retval = MPI_ERR_REQUEST;
   } else {
     MPI_Request req = *request;
@@ -183,25 +171,12 @@ int PMPI_Request_free(MPI_Request * request)
 
 int PMPI_Irecv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MPI_Comm comm, MPI_Request * request)
 {
-  int retval = 0;
+  CHECK_IRECV_INPUTS
 
   smpi_bench_end();
-
-  if (request == nullptr) {
-    retval = MPI_ERR_ARG;
-  } else if (comm == MPI_COMM_NULL) {
-    retval = MPI_ERR_COMM;
-  } else if (src == MPI_PROC_NULL) {
-    *request = MPI_REQUEST_NULL;
-    retval = MPI_SUCCESS;
-  } else if (src!=MPI_ANY_SOURCE && (src >= comm->group()->size() || src <0)){
+  int retval = 0;
+  if (src!=MPI_ANY_SOURCE && (src >= comm->group()->size() || src <0)){
     retval = MPI_ERR_RANK;
-  } else if ((count < 0) || (buf==nullptr && count > 0)) {
-    retval = MPI_ERR_COUNT;
-  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
-    retval = MPI_ERR_TYPE;
-  } else if(tag<0 && tag !=  MPI_ANY_TAG){
-    retval = MPI_ERR_TAG;
   } else {
     int my_proc_id = simgrid::s4u::this_actor::get_pid();
 
@@ -217,32 +192,18 @@ int PMPI_Irecv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MP
   }
 
   smpi_bench_begin();
-  if (retval != MPI_SUCCESS && request != nullptr)
-    *request = MPI_REQUEST_NULL;
   return retval;
 }
 
 
 int PMPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request * request)
 {
-  int retval = 0;
+  CHECK_ISEND_INPUTS
 
   smpi_bench_end();
-  if (request == nullptr) {
-    retval = MPI_ERR_ARG;
-  } else if (comm == MPI_COMM_NULL) {
-    retval = MPI_ERR_COMM;
-  } else if (dst == MPI_PROC_NULL) {
-    *request = MPI_REQUEST_NULL;
-    retval = MPI_SUCCESS;
-  } else if (dst >= comm->group()->size() || dst <0){
+  int retval = 0;
+  if (dst >= comm->group()->size() || dst <0){
     retval = MPI_ERR_RANK;
-  } else if ((count < 0) || (buf==nullptr && count > 0)) {
-    retval = MPI_ERR_COUNT;
-  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
-    retval = MPI_ERR_TYPE;
-  } else if(tag<0 && tag !=  MPI_ANY_TAG){
-    retval = MPI_ERR_TAG;
   } else {
     int my_proc_id = simgrid::s4u::this_actor::get_pid();
     int trace_dst = getPid(comm, dst);
@@ -260,8 +221,7 @@ int PMPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dst, int t
   }
 
   smpi_bench_begin();
-  if (retval != MPI_SUCCESS && request!=nullptr)
-    *request = MPI_REQUEST_NULL;
+
   return retval;
 }
 
@@ -273,24 +233,12 @@ int PMPI_Irsend(const void* buf, int count, MPI_Datatype datatype, int dst, int
 
 int PMPI_Issend(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request* request)
 {
-  int retval = 0;
+  CHECK_ISEND_INPUTS
 
   smpi_bench_end();
-  if (request == nullptr) {
-    retval = MPI_ERR_ARG;
-  } else if (comm == MPI_COMM_NULL) {
-    retval = MPI_ERR_COMM;
-  } else if (dst == MPI_PROC_NULL) {
-    *request = MPI_REQUEST_NULL;
-    retval = MPI_SUCCESS;
-  } else if (dst >= comm->group()->size() || dst <0){
+  int retval = 0;
+  if (dst >= comm->group()->size() || dst <0){
     retval = MPI_ERR_RANK;
-  } else if ((count < 0)|| (buf==nullptr && count > 0)) {
-    retval = MPI_ERR_COUNT;
-  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
-    retval = MPI_ERR_TYPE;
-  } else if(tag<0 && tag !=  MPI_ANY_TAG){
-    retval = MPI_ERR_TAG;
   } else {
     int my_proc_id = simgrid::s4u::this_actor::get_pid();
     int trace_dst = getPid(comm, dst);
@@ -307,8 +255,6 @@ int PMPI_Issend(const void* buf, int count, MPI_Datatype datatype, int dst, int
   }
 
   smpi_bench_begin();
-  if (retval != MPI_SUCCESS && request!=nullptr)
-    *request = MPI_REQUEST_NULL;
   return retval;
 }
 
@@ -316,10 +262,14 @@ int PMPI_Recv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MPI
 {
   int retval = 0;
 
+  CHECK_BUFFER(1, buf, count)
+  CHECK_COUNT(2, count)
+  CHECK_TYPE(3, datatype)
+  CHECK_TAG(5, tag)
+  CHECK_COMM(6)
+
   smpi_bench_end();
-  if (comm == MPI_COMM_NULL) {
-    retval = MPI_ERR_COMM;
-  } else if (src == MPI_PROC_NULL) {
+  if (src == MPI_PROC_NULL) {
     if(status != MPI_STATUS_IGNORE){
       simgrid::smpi::Status::empty(status);
       status->MPI_SOURCE = MPI_PROC_NULL;
@@ -327,12 +277,6 @@ int PMPI_Recv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MPI
     retval = MPI_SUCCESS;
   } else if (src!=MPI_ANY_SOURCE && (src >= comm->group()->size() || src <0)){
     retval = MPI_ERR_RANK;
-  } else if ((count < 0) || (buf==nullptr && count > 0)) {
-    retval = MPI_ERR_COUNT;
-  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
-    retval = MPI_ERR_TYPE;
-  } else if(tag<0 && tag !=  MPI_ANY_TAG){
-    retval = MPI_ERR_TAG;
   } else {
     int my_proc_id = simgrid::s4u::this_actor::get_pid();
     TRACE_smpi_comm_in(my_proc_id, __func__,
@@ -362,22 +306,12 @@ int PMPI_Recv(void *buf, int count, MPI_Datatype datatype, int src, int tag, MPI
 
 int PMPI_Send(const void *buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm)
 {
-  int retval = 0;
+  CHECK_SEND_INPUTS
 
   smpi_bench_end();
-
-  if (comm == MPI_COMM_NULL) {
-    retval = MPI_ERR_COMM;
-  } else if (dst == MPI_PROC_NULL) {
-    retval = MPI_SUCCESS;
-  } else if (dst >= comm->group()->size() || dst <0){
+  int retval = 0;
+  if (dst >= comm->group()->size() || dst <0){
     retval = MPI_ERR_RANK;
-  } else if ((count < 0) || (buf == nullptr && count > 0)) {
-    retval = MPI_ERR_COUNT;
-  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
-    retval = MPI_ERR_TYPE;
-  } else if(tag < 0 && tag !=  MPI_ANY_TAG){
-    retval = MPI_ERR_TAG;
   } else {
     int my_proc_id         = simgrid::s4u::this_actor::get_pid();
     int dst_traced         = getPid(comm, dst);
@@ -406,22 +340,12 @@ int PMPI_Rsend(const void* buf, int count, MPI_Datatype datatype, int dst, int t
 
 int PMPI_Bsend(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm)
 {
-  int retval = 0;
+  CHECK_SEND_INPUTS
 
   smpi_bench_end();
-
-  if (comm == MPI_COMM_NULL) {
-    retval = MPI_ERR_COMM;
-  } else if (dst == MPI_PROC_NULL) {
-    retval = MPI_SUCCESS;
-  } else if (dst >= comm->group()->size() || dst <0){
+  int retval = 0;
+  if (dst >= comm->group()->size() || dst <0){
     retval = MPI_ERR_RANK;
-  } else if ((count < 0) || (buf == nullptr && count > 0)) {
-    retval = MPI_ERR_COUNT;
-  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
-    retval = MPI_ERR_TYPE;
-  } else if(tag < 0 && tag !=  MPI_ANY_TAG){
-    retval = MPI_ERR_TAG;
   } else {
     int my_proc_id         = simgrid::s4u::this_actor::get_pid();
     int dst_traced         = getPid(comm, dst);
@@ -451,24 +375,12 @@ int PMPI_Bsend(const void* buf, int count, MPI_Datatype datatype, int dst, int t
 
 int PMPI_Ibsend(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request* request)
 {
-  int retval = 0;
+  CHECK_ISEND_INPUTS
 
   smpi_bench_end();
-  if (request == nullptr) {
-    retval = MPI_ERR_ARG;
-  } else if (comm == MPI_COMM_NULL) {
-    retval = MPI_ERR_COMM;
-  } else if (dst == MPI_PROC_NULL) {
-    *request = MPI_REQUEST_NULL;
-    retval = MPI_SUCCESS;
-  } else if (dst >= comm->group()->size() || dst <0){
+  int retval = 0;
+  if (dst >= comm->group()->size() || dst <0){
     retval = MPI_ERR_RANK;
-  } else if ((count < 0) || (buf==nullptr && count > 0)) {
-    retval = MPI_ERR_COUNT;
-  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
-    retval = MPI_ERR_TYPE;
-  } else if(tag<0 && tag !=  MPI_ANY_TAG){
-    retval = MPI_ERR_TAG;
   } else {
     int my_proc_id = simgrid::s4u::this_actor::get_pid();
     int trace_dst = getPid(comm, dst);
@@ -499,51 +411,31 @@ int PMPI_Ibsend(const void* buf, int count, MPI_Datatype datatype, int dst, int
 
 int PMPI_Bsend_init(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm, MPI_Request* request)
 {
-  int retval = 0;
+  CHECK_ISEND_INPUTS
 
   smpi_bench_end();
-  if (request == nullptr) {
-    retval = MPI_ERR_ARG;
-  } else if (comm == MPI_COMM_NULL) {
-    retval = MPI_ERR_COMM;
-  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
-    retval = MPI_ERR_TYPE;
-  } else if (dst == MPI_PROC_NULL) {
-    retval = MPI_SUCCESS;
+  int retval = 0;
+  int bsend_buf_size = 0;
+  void* bsend_buf = nullptr;
+  smpi_process()->bsend_buffer(&bsend_buf, &bsend_buf_size);
+  if( bsend_buf==nullptr || bsend_buf_size < datatype->get_extent() * count + MPI_BSEND_OVERHEAD ) {
+    retval = MPI_ERR_BUFFER;
   } else {
-    int bsend_buf_size = 0;
-    void* bsend_buf = nullptr;
-    smpi_process()->bsend_buffer(&bsend_buf, &bsend_buf_size);
-    if( bsend_buf==nullptr || bsend_buf_size < datatype->get_extent() * count + MPI_BSEND_OVERHEAD ) {
-      retval = MPI_ERR_BUFFER;
-    } else {
-      *request = simgrid::smpi::Request::bsend_init(buf, count, datatype, dst, tag, comm);
-      retval   = MPI_SUCCESS;
-    }
+    *request = simgrid::smpi::Request::bsend_init(buf, count, datatype, dst, tag, comm);
+    retval   = MPI_SUCCESS;
   }
   smpi_bench_begin();
-  if (retval != MPI_SUCCESS && request != nullptr)
-    *request = MPI_REQUEST_NULL;
   return retval;
 }
 
-int PMPI_Ssend(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm) {
-  int retval = 0;
+int PMPI_Ssend(const void* buf, int count, MPI_Datatype datatype, int dst, int tag, MPI_Comm comm)
+{
+  CHECK_SEND_INPUTS
 
   smpi_bench_end();
-
-  if (comm == MPI_COMM_NULL) {
-    retval = MPI_ERR_COMM;
-  } else if (dst == MPI_PROC_NULL) {
-    retval = MPI_SUCCESS;
-  } else if (dst >= comm->group()->size() || dst <0){
+  int retval = 0;
+  if (dst >= comm->group()->size() || dst <0){
     retval = MPI_ERR_RANK;
-  } else if ((count < 0) || (buf==nullptr && count > 0)) {
-    retval = MPI_ERR_COUNT;
-  } else if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
-    retval = MPI_ERR_TYPE;
-  } else if(tag<0 && tag !=  MPI_ANY_TAG){
-    retval = MPI_ERR_TAG;
   } else {
     int my_proc_id         = simgrid::s4u::this_actor::get_pid();
     int dst_traced         = getPid(comm, dst);
@@ -569,12 +461,16 @@ int PMPI_Sendrecv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, int
   int retval = 0;
 
   smpi_bench_end();
-
-  if (comm == MPI_COMM_NULL) {
-    retval = MPI_ERR_COMM;
-  } else if (not sendtype->is_valid() || not recvtype->is_valid()) {
-    retval = MPI_ERR_TYPE;
-  } else if (src == MPI_PROC_NULL) {
+  CHECK_BUFFER(1, sendbuf, sendcount)
+  CHECK_COUNT(2, sendcount)
+  CHECK_TYPE(3, sendtype)
+  CHECK_TAG(5, sendtag)
+  CHECK_BUFFER(6, recvbuf, recvcount)
+  CHECK_COUNT(7, recvcount)
+  CHECK_TYPE(8, recvtype)
+  CHECK_TAG(10, recvtag)
+  CHECK_COMM(11)
+  if (src == MPI_PROC_NULL) {
     if(status!=MPI_STATUS_IGNORE){
       simgrid::smpi::Status::empty(status);
       status->MPI_SOURCE = MPI_PROC_NULL;
@@ -582,17 +478,12 @@ int PMPI_Sendrecv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, int
     if(dst != MPI_PROC_NULL)
       simgrid::smpi::Request::send(sendbuf, sendcount, sendtype, dst, sendtag, comm);
     retval = MPI_SUCCESS;
-  }else if (dst == MPI_PROC_NULL){
+  } else if (dst == MPI_PROC_NULL){
     simgrid::smpi::Request::recv(recvbuf, recvcount, recvtype, src, recvtag, comm, status);
     retval = MPI_SUCCESS;
-  }else if (dst >= comm->group()->size() || dst <0 ||
+  } else if (dst >= comm->group()->size() || dst <0 ||
       (src!=MPI_ANY_SOURCE && (src >= comm->group()->size() || src <0))){
     retval = MPI_ERR_RANK;
-  } else if ((sendcount < 0 || recvcount<0) ||
-      (sendbuf==nullptr && sendcount > 0) || (recvbuf==nullptr && recvcount>0)) {
-    retval = MPI_ERR_COUNT;
-  } else if((sendtag<0 && sendtag !=  MPI_ANY_TAG)||(recvtag<0 && recvtag != MPI_ANY_TAG)){
-    retval = MPI_ERR_TAG;
   } else {
     int my_proc_id         = simgrid::s4u::this_actor::get_pid();
     int dst_traced         = getPid(comm, dst);
@@ -627,19 +518,17 @@ int PMPI_Sendrecv_replace(void* buf, int count, MPI_Datatype datatype, int dst,
                           MPI_Comm comm, MPI_Status* status)
 {
   int retval = 0;
-  if (datatype==MPI_DATATYPE_NULL || not datatype->is_valid()) {
-    return MPI_ERR_TYPE;
-  } else if (count < 0) {
-    return MPI_ERR_COUNT;
-  } else {
-    int size = datatype->get_extent() * count;
-    void* recvbuf = xbt_new0(char, size);
-    retval = MPI_Sendrecv(buf, count, datatype, dst, sendtag, recvbuf, count, datatype, src, recvtag, comm, status);
-    if(retval==MPI_SUCCESS){
-      simgrid::smpi::Datatype::copy(recvbuf, count, datatype, buf, count, datatype);
-    }
-    xbt_free(recvbuf);
+  CHECK_BUFFER(1, buf, count)
+  CHECK_COUNT(2, count)
+  CHECK_TYPE(3, datatype)
+
+  int size = datatype->get_extent() * count;
+  void* recvbuf = xbt_new0(char, size);
+  retval = MPI_Sendrecv(buf, count, datatype, dst, sendtag, recvbuf, count, datatype, src, recvtag, comm, status);
+  if(retval==MPI_SUCCESS){
+    simgrid::smpi::Datatype::copy(recvbuf, count, datatype, buf, count, datatype);
   }
+  xbt_free(recvbuf);
   return retval;
 }
 
@@ -659,7 +548,6 @@ int PMPI_Test(MPI_Request * request, int *flag, MPI_Status * status)
     int my_proc_id = ((*request)->comm() != MPI_COMM_NULL) ? simgrid::s4u::this_actor::get_pid() : -1;
 
     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("test"));
-    
     retval = simgrid::smpi::Request::test(request,status, flag);
 
     TRACE_smpi_comm_out(my_proc_id);
@@ -671,7 +559,7 @@ int PMPI_Test(MPI_Request * request, int *flag, MPI_Status * status)
 int PMPI_Testany(int count, MPI_Request requests[], int *index, int *flag, MPI_Status * status)
 {
   int retval = 0;
-
+  CHECK_COUNT(1, count)
   smpi_bench_end();
   if (index == nullptr || flag == nullptr) {
     retval = MPI_ERR_ARG;
@@ -688,7 +576,7 @@ int PMPI_Testany(int count, MPI_Request requests[], int *index, int *flag, MPI_S
 int PMPI_Testall(int count, MPI_Request* requests, int* flag, MPI_Status* statuses)
 {
   int retval = 0;
-
+  CHECK_COUNT(1, count)
   smpi_bench_end();
   if (flag == nullptr) {
     retval = MPI_ERR_ARG;
@@ -705,7 +593,7 @@ int PMPI_Testall(int count, MPI_Request* requests, int* flag, MPI_Status* status
 int PMPI_Testsome(int incount, MPI_Request requests[], int* outcount, int* indices, MPI_Status status[])
 {
   int retval = 0;
-
+  CHECK_COUNT(1, incount)
   smpi_bench_end();
   if (outcount == nullptr) {
     retval = MPI_ERR_ARG;
@@ -723,9 +611,9 @@ int PMPI_Probe(int source, int tag, MPI_Comm comm, MPI_Status* status) {
   int retval = 0;
   smpi_bench_end();
 
-  if (comm == MPI_COMM_NULL) {
-    retval = MPI_ERR_COMM;
-  } else if (source == MPI_PROC_NULL) {
+  CHECK_COMM(6)
+  CHECK_TAG(2, tag)
+  if (source == MPI_PROC_NULL) {
     if (status != MPI_STATUS_IGNORE){
       simgrid::smpi::Status::empty(status);
       status->MPI_SOURCE = MPI_PROC_NULL;
@@ -742,11 +630,10 @@ int PMPI_Probe(int source, int tag, MPI_Comm comm, MPI_Status* status) {
 int PMPI_Iprobe(int source, int tag, MPI_Comm comm, int* flag, MPI_Status* status) {
   int retval = 0;
   smpi_bench_end();
-
+  CHECK_COMM(6)
+  CHECK_TAG(2, tag)
   if (flag == nullptr) {
     retval = MPI_ERR_ARG;
-  } else if (comm == MPI_COMM_NULL) {
-    retval = MPI_ERR_COMM;
   } else if (source == MPI_PROC_NULL) {
     *flag=true;
     if (status != MPI_STATUS_IGNORE){
@@ -787,9 +674,8 @@ int PMPI_Wait(MPI_Request * request, MPI_Status * status)
 
   simgrid::smpi::Status::empty(status);
 
-  if (request == nullptr) {
-    retval = MPI_ERR_ARG;
-  } else if (*request == MPI_REQUEST_NULL) {
+  CHECK_REQUEST(1) 
+  if (*request == MPI_REQUEST_NULL) {
     retval = MPI_SUCCESS;
   } else {
     // for tracing, save the handle which might get overridden before we can use the helper on it
@@ -858,7 +744,7 @@ int PMPI_Waitany(int count, MPI_Request requests[], int *index, MPI_Status * sta
 int PMPI_Waitall(int count, MPI_Request requests[], MPI_Status status[])
 {
   smpi_bench_end();
-
+  CHECK_COUNT(1, count)
   // for tracing, save the handles which might get overridden before we can use the helper on it
   std::vector<MPI_Request> savedreqs(requests, requests + count);
   for (MPI_Request& req : savedreqs) {
@@ -889,7 +775,7 @@ int PMPI_Waitall(int count, MPI_Request requests[], MPI_Status status[])
 int PMPI_Waitsome(int incount, MPI_Request requests[], int *outcount, int *indices, MPI_Status status[])
 {
   int retval = 0;
-
+  CHECK_COUNT(1, incount)
   smpi_bench_end();
   if (outcount == nullptr) {
     retval = MPI_ERR_ARG;
@@ -906,6 +792,7 @@ int PMPI_Cancel(MPI_Request* request)
   int retval = 0;
 
   smpi_bench_end();
+  CHECK_REQUEST(1)
   if (*request == MPI_REQUEST_NULL) {
     retval = MPI_ERR_REQUEST;
   } else {
index 5238bb5..f8cb15f 100644 (file)
@@ -498,4 +498,45 @@ XBT_PUBLIC smpi_trace_call_location_t* smpi_trace_get_call_location();
 
 XBT_PRIVATE void private_execute_flops(double flops);
 
+
+#define CHECK_ARGS(test, errcode, ...)                                                                                 \
+  if (test) {                                                                                                          \
+    XBT_WARN(__VA_ARGS__);                                                                                             \
+    return (errcode);                                                                                                  \
+  }
+
+#define CHECK_COMM(num)                                                                                                \
+  CHECK_ARGS(comm == MPI_COMM_NULL, MPI_ERR_COMM,                                                                      \
+             "%s: param %d communicator cannot be MPI_COMM_NULL", __func__, num);
+#define CHECK_REQUEST(num)                                                                                             \
+  CHECK_ARGS(request == nullptr, MPI_ERR_REQUEST,                                                                      \
+             "%s: param %d request cannot be NULL",__func__, num);
+#define CHECK_BUFFER(num,buf,count)                                                                                    \
+  CHECK_ARGS(buf == nullptr && count > 0, MPI_ERR_BUFFER,                                                              \
+             "%s: param %d %s cannot be NULL if %s > 0",__func__, num, #buf, #count);
+#define CHECK_COUNT(num,count)                                                                                         \
+  CHECK_ARGS(count < 0, MPI_ERR_COUNT,                                                                                 \
+             "%s: param %d %s cannot be negative", __func__, num, #count);
+#define CHECK_TYPE(num, datatype)                                                                                      \
+  CHECK_ARGS((datatype == MPI_DATATYPE_NULL|| not datatype->is_valid()), MPI_ERR_TYPE,                                 \
+             "%s: param %d %s cannot be MPI_DATATYPE_NULL or invalid", __func__, num, #datatype);
+#define CHECK_OP(num)                                                                                                  \
+  CHECK_ARGS(op == MPI_OP_NULL, MPI_ERR_OP,                                                                            \
+             "%s: param %d op cannot be MPI_OP_NULL or invalid", __func__, num);
+#define CHECK_ROOT(num)\
+  CHECK_ARGS((root < 0 || root >= comm->size()), MPI_ERR_ROOT,                                                         \
+             "%s: param %d root (=%d) cannot be negative or larger than communicator size (=%d)", __func__, num, root, \
+             comm->size());
+#define CHECK_NULL(num,err,buf)                                                                                        \
+  CHECK_ARGS(buf == nullptr, err,                                                                                      \
+             "%s: param %d %s cannot be NULL", __func__, num, #buf);
+#define CHECK_PROC(num,proc)                                                                                           \
+  CHECK_ARGS(proc == MPI_PROC_NULL, MPI_SUCCESS,                                                                       \
+             "%s: param %d %s cannot be MPI_PROC_NULL", __func__, num, #proc);
+#define CHECK_TAG(num,tag)                                                                                             \
+  CHECK_ARGS((tag<0 && tag !=  MPI_ANY_TAG), MPI_ERR_TAG,                                                              \
+             "%s: param %d %s cannot be negative", __func__, num, #tag);
+#define CHECK_FILE(num, fh)                                                                                                 \
+  CHECK_ARGS(fh == MPI_FILE_NULL, MPI_ERR_FILE,                                                                       \
+             "%s: param %d %s cannot be MPI_PROC_NULL", __func__, num, #fh);
 #endif