Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
use previous buffer check feature in MPI checks, to crash when a buffer overflow...
[simgrid.git] / src / smpi / bindings / smpi_pmpi_coll.cpp
index b80f397..30eef29 100644 (file)
@@ -62,10 +62,11 @@ int PMPI_Bcast(void *buf, int count, MPI_Datatype datatype, int root, MPI_Comm c
 int PMPI_Ibcast(void *buf, int count, MPI_Datatype datatype, 
                    int root, MPI_Comm comm, MPI_Request* request)
 {
+  SET_BUF1(buf)
   CHECK_COMM(5)
-  CHECK_BUFFER(1, buf, count)
   CHECK_COUNT(2, count)
   CHECK_TYPE(3, datatype)
+  CHECK_BUFFER(1, buf, count, datatype)
   CHECK_ROOT(4)
   CHECK_REQUEST(6)
 
@@ -99,17 +100,19 @@ int PMPI_Igather(const void* sendbuf, int sendcount, MPI_Datatype sendtype, void
                  MPI_Datatype recvtype, int root, MPI_Comm comm, MPI_Request* request)
 {
   CHECK_COMM(8)
+  SET_BUF1(sendbuf)
+  SET_BUF2(recvbuf)
   int rank = comm->rank();
   if(sendbuf != MPI_IN_PLACE){
-    CHECK_BUFFER(1,sendbuf, sendcount)
     CHECK_COUNT(2, sendcount)
     CHECK_TYPE(3, sendtype)
+    CHECK_BUFFER(1,sendbuf, sendcount, sendtype)
   }
   if(rank == root){
     CHECK_NOT_IN_PLACE_ROOT(4, recvbuf)
     CHECK_TYPE(6, recvtype)
     CHECK_COUNT(5, recvcount)
-    CHECK_BUFFER(4, recvbuf, recvcount)
+    CHECK_BUFFER(4, recvbuf, recvcount, recvtype)
   } else {
     CHECK_NOT_IN_PLACE_ROOT(1, sendbuf)
   }
@@ -159,12 +162,14 @@ int PMPI_Igatherv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, voi
                   MPI_Datatype recvtype, int root, MPI_Comm comm, MPI_Request* request)
 {
   CHECK_COMM(9)
+  SET_BUF1(sendbuf)
+  SET_BUF2(recvbuf)
   int rank = comm->rank();
-  CHECK_BUFFER(1, sendbuf, sendcount)
   if(sendbuf != MPI_IN_PLACE){
     CHECK_TYPE(3, sendtype)
     CHECK_COUNT(2, sendcount)
   }
+  CHECK_BUFFER(1, sendbuf, sendcount, sendtype)
   if(rank == root){
     CHECK_NOT_IN_PLACE_ROOT(4, recvbuf)
     CHECK_TYPE(6, recvtype)
@@ -179,7 +184,7 @@ int PMPI_Igatherv(const void* sendbuf, int sendcount, MPI_Datatype sendtype, voi
   if (rank == root){
     for (int i = 0; i < comm->size(); i++) {
       CHECK_COUNT(5, recvcounts[i])
-      CHECK_BUFFER(4,recvbuf,recvcounts[i])
+      CHECK_BUFFER(4,recvbuf,recvcounts[i], recvtype)
     }
   }
 
@@ -228,9 +233,9 @@ int PMPI_Iallgather(const void* sendbuf, int sendcount, MPI_Datatype sendtype, v
                     MPI_Datatype recvtype, MPI_Comm comm, MPI_Request* request)
 {
   CHECK_COMM(7)
+  SET_BUF1(sendbuf)
+  SET_BUF2(recvbuf)
   int rank = comm->rank();
-  CHECK_BUFFER(1, sendbuf, sendcount)
-  CHECK_BUFFER(4, recvbuf, recvcount)
   CHECK_NOT_IN_PLACE(4, recvbuf)
   if(sendbuf != MPI_IN_PLACE){
     CHECK_COUNT(2, sendcount)
@@ -238,6 +243,8 @@ int PMPI_Iallgather(const void* sendbuf, int sendcount, MPI_Datatype sendtype, v
   }
   CHECK_TYPE(6, recvtype)
   CHECK_COUNT(5, recvcount)
+  CHECK_BUFFER(1, sendbuf, sendcount, sendtype)
+  CHECK_BUFFER(4, recvbuf, recvcount, recvtype)
   CHECK_REQUEST(8)
 
   if (sendbuf == MPI_IN_PLACE) {
@@ -280,20 +287,23 @@ int PMPI_Iallgatherv(const void* sendbuf, int sendcount, MPI_Datatype sendtype,
                      MPI_Datatype recvtype, MPI_Comm comm, MPI_Request* request)
 {
   CHECK_COMM(8)
+  SET_BUF1(sendbuf)
+  SET_BUF2(recvbuf)
   int rank = comm->rank();
-  CHECK_BUFFER(1, sendbuf, sendcount)
   if(sendbuf != MPI_IN_PLACE)
     CHECK_TYPE(3, sendtype)
   CHECK_TYPE(6, recvtype)
   CHECK_NULL(5, MPI_ERR_COUNT, recvcounts)
   CHECK_NULL(6, MPI_ERR_ARG, displs)
-  if(sendbuf != MPI_IN_PLACE)
+  if(sendbuf != MPI_IN_PLACE){
     CHECK_COUNT(2, sendcount)
+    CHECK_BUFFER(1, sendbuf, sendcount, sendtype)
+  }
   CHECK_REQUEST(9)
   CHECK_NOT_IN_PLACE(4, recvbuf)
   for (int i = 0; i < comm->size(); i++) {
     CHECK_COUNT(5, recvcounts[i])
-    CHECK_BUFFER(4, recvbuf, recvcounts[i])
+    CHECK_BUFFER(4, recvbuf, recvcounts[i], recvtype)
   }
 
   smpi_bench_end();
@@ -336,19 +346,21 @@ int PMPI_Iscatter(const void* sendbuf, int sendcount, MPI_Datatype sendtype, voi
                   MPI_Datatype recvtype, int root, MPI_Comm comm, MPI_Request* request)
 {
   CHECK_COMM(8)
+  SET_BUF1(sendbuf)
+  SET_BUF2(recvbuf)
   int rank = comm->rank();
   if(rank == root){
     CHECK_NOT_IN_PLACE_ROOT(1, sendbuf)
-    CHECK_BUFFER(1, sendbuf, sendcount)
     CHECK_COUNT(2, sendcount)
     CHECK_TYPE(3, sendtype)
+    CHECK_BUFFER(1, sendbuf, sendcount, sendtype)
   } else {
     CHECK_NOT_IN_PLACE_ROOT(4, recvbuf)
   }
   if(recvbuf != MPI_IN_PLACE){
-    CHECK_BUFFER(4, recvbuf, recvcount)
     CHECK_COUNT(5, recvcount)
     CHECK_TYPE(6, recvtype)
+    CHECK_BUFFER(4, recvbuf, recvcount, recvtype)
   }
   CHECK_ROOT(8)
   CHECK_REQUEST(9)
@@ -391,13 +403,15 @@ int PMPI_Scatterv(const void *sendbuf, const int *sendcounts, const int *displs,
 int PMPI_Iscatterv(const void* sendbuf, const int* sendcounts, const int* displs, MPI_Datatype sendtype, void* recvbuf, int recvcount,
                    MPI_Datatype recvtype, int root, MPI_Comm comm, MPI_Request* request)
 {
+  SET_BUF1(sendbuf)
+  SET_BUF2(recvbuf)
   CHECK_COMM(9)
   int rank = comm->rank();
   if(recvbuf != MPI_IN_PLACE){
     CHECK_NOT_IN_PLACE_ROOT(1, sendbuf)
-    CHECK_BUFFER(4, recvbuf, recvcount)
     CHECK_COUNT(5, recvcount)
     CHECK_TYPE(7, recvtype)
+    CHECK_BUFFER(4, recvbuf, recvcount, recvtype)
   }
   CHECK_ROOT(9)
   CHECK_REQUEST(10)
@@ -406,8 +420,8 @@ int PMPI_Iscatterv(const void* sendbuf, const int* sendcounts, const int* displs
     CHECK_NULL(3, MPI_ERR_ARG, displs)
     CHECK_TYPE(4, sendtype)
     for (int i = 0; i < comm->size(); i++){
-      CHECK_BUFFER(1, sendbuf, sendcounts[i])
       CHECK_COUNT(2, sendcounts[i])
+      CHECK_BUFFER(1, sendbuf, sendcounts[i], sendtype)
     }
     if (recvbuf == MPI_IN_PLACE) {
       recvtype  = sendtype;
@@ -454,14 +468,16 @@ int PMPI_Reduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype data
 int PMPI_Ireduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, int root, MPI_Comm comm, MPI_Request* request)
 {
   CHECK_COMM(7)
+  SET_BUF1(sendbuf)
+  SET_BUF2(recvbuf)
   int rank = comm->rank();
-  CHECK_BUFFER(1, sendbuf, count)
+  CHECK_TYPE(4, datatype)
+  CHECK_COUNT(3, count)
+  CHECK_BUFFER(1, sendbuf, count, datatype)
   if(rank == root){
     CHECK_NOT_IN_PLACE(2, recvbuf)
-    CHECK_BUFFER(5, recvbuf, count)
+    CHECK_BUFFER(5, recvbuf, count, datatype)
   }
-  CHECK_TYPE(4, datatype)
-  CHECK_COUNT(3, count)
   CHECK_OP(5, op, datatype)
   CHECK_ROOT(7)
   CHECK_REQUEST(8)
@@ -485,10 +501,12 @@ int PMPI_Ireduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype dat
 
 int PMPI_Reduce_local(const void* inbuf, void* inoutbuf, int count, MPI_Datatype datatype, MPI_Op op)
 {
-  CHECK_BUFFER(1, inbuf, count)
-  CHECK_BUFFER(2, inoutbuf, count)
+  SET_BUF1(inbuf)
+  SET_BUF2(inoutbuf)
   CHECK_TYPE(4, datatype)
   CHECK_COUNT(3, count)
+  CHECK_BUFFER(1, inbuf, count, datatype)
+  CHECK_BUFFER(2, inoutbuf, count, datatype)
   CHECK_OP(5, op, datatype)
 
   smpi_bench_end();
@@ -505,14 +523,16 @@ int PMPI_Allreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype d
 int PMPI_Iallreduce(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm, MPI_Request *request)
 {
   CHECK_COMM(6)
+  SET_BUF1(sendbuf)
+  SET_BUF2(recvbuf)
   int rank = comm->rank();
-  CHECK_BUFFER(1, sendbuf, count)
-  CHECK_BUFFER(2, recvbuf, count)
   CHECK_NOT_IN_PLACE(2, recvbuf)
   CHECK_TYPE(4, datatype)
+  CHECK_OP(5, op, datatype)
   CHECK_COUNT(3, count)
+  CHECK_BUFFER(1, sendbuf, count, datatype)
+  CHECK_BUFFER(2, recvbuf, count, datatype)
   CHECK_REQUEST(7)
-  CHECK_OP(5, op, datatype)
 
   smpi_bench_end();
   std::vector<unsigned char> tmp_sendbuf;
@@ -543,10 +563,12 @@ int PMPI_Scan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype dataty
 int PMPI_Iscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm, MPI_Request* request)
 {
   CHECK_COMM(6)
-  CHECK_BUFFER(1,sendbuf,count)
-  CHECK_BUFFER(2,recvbuf,count)
+  SET_BUF1(sendbuf)
+  SET_BUF2(recvbuf)
   CHECK_TYPE(4, datatype)
   CHECK_COUNT(3, count)
+  CHECK_BUFFER(1,sendbuf,count, datatype)
+  CHECK_BUFFER(2,recvbuf,count, datatype)
   CHECK_REQUEST(7)
   CHECK_OP(5, op, datatype)
 
@@ -578,10 +600,12 @@ int PMPI_Exscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype data
 
 int PMPI_Iexscan(const void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm, MPI_Request* request){
   CHECK_COMM(6)
-  CHECK_BUFFER(1, sendbuf, count)
-  CHECK_BUFFER(2, recvbuf, count)
+  SET_BUF1(sendbuf)
+  SET_BUF2(recvbuf)
   CHECK_TYPE(4, datatype)
   CHECK_COUNT(3, count)
+  CHECK_BUFFER(1, sendbuf, count, datatype)
+  CHECK_BUFFER(2, recvbuf, count, datatype)
   CHECK_REQUEST(7)
   CHECK_OP(5, op, datatype)
 
@@ -614,6 +638,8 @@ int PMPI_Reduce_scatter(const void *sendbuf, void *recvbuf, const int *recvcount
 int PMPI_Ireduce_scatter(const void *sendbuf, void *recvbuf, const int *recvcounts, MPI_Datatype datatype, MPI_Op op, MPI_Comm comm, MPI_Request *request)
 {
   CHECK_COMM(6)
+  SET_BUF1(sendbuf)
+  SET_BUF2(recvbuf)
   int rank = comm->rank();
   CHECK_NOT_IN_PLACE(2, recvbuf)
   CHECK_TYPE(4, datatype)
@@ -622,8 +648,8 @@ int PMPI_Ireduce_scatter(const void *sendbuf, void *recvbuf, const int *recvcoun
   CHECK_OP(5, op, datatype)
   for (int i = 0; i < comm->size(); i++) {
     CHECK_COUNT(3, recvcounts[i])
-    CHECK_BUFFER(1, sendbuf, recvcounts[i])
-    CHECK_BUFFER(2, recvbuf, recvcounts[i])
+    CHECK_BUFFER(1, sendbuf, recvcounts[i], datatype)
+    CHECK_BUFFER(2, recvbuf, recvcounts[i], datatype)
   }
 
   smpi_bench_end();
@@ -664,10 +690,12 @@ int PMPI_Ireduce_scatter_block(const void* sendbuf, void* recvbuf, int recvcount
                                MPI_Comm comm, MPI_Request* request)
 {
   CHECK_COMM(6)
-  CHECK_BUFFER(1, sendbuf, recvcount)
-  CHECK_BUFFER(2, recvbuf, recvcount)
+  SET_BUF1(sendbuf)
+  SET_BUF2(recvbuf)
   CHECK_TYPE(4, datatype)
   CHECK_COUNT(3, recvcount)
+  CHECK_BUFFER(1, sendbuf, recvcount, datatype)
+  CHECK_BUFFER(2, recvbuf, recvcount, datatype)
   CHECK_REQUEST(7)
   CHECK_OP(5, op, datatype)
 
@@ -707,15 +735,17 @@ int PMPI_Ialltoall(const void* sendbuf, int sendcount, MPI_Datatype sendtype, vo
                    MPI_Datatype recvtype, MPI_Comm comm, MPI_Request* request)
 {
   CHECK_COMM(7)
-  CHECK_BUFFER(1, sendbuf, sendcount)
-  CHECK_BUFFER(4, recvbuf, recvcount)
-  if(sendbuf != MPI_IN_PLACE)
+  SET_BUF1(sendbuf)
+  SET_BUF2(recvbuf)
+  if(sendbuf != MPI_IN_PLACE){
     CHECK_TYPE(3, sendtype)
+    CHECK_COUNT(2, sendcount)
+    CHECK_BUFFER(1, sendbuf, sendcount, sendtype)
+  }
   CHECK_TYPE(6, recvtype)
   CHECK_COUNT(5, recvcount)
-  if(sendbuf != MPI_IN_PLACE)
-    CHECK_COUNT(2, sendcount)
   CHECK_COUNT(5, recvcount)
+  CHECK_BUFFER(4, recvbuf, recvcount, recvtype)
   CHECK_REQUEST(8)
 
   int pid                 = simgrid::s4u::this_actor::get_pid();
@@ -766,6 +796,8 @@ int PMPI_Ialltoallv(const void* sendbuf, const int* sendcounts, const int* sendd
                     const int* recvcounts, const int* recvdispls, MPI_Datatype recvtype, MPI_Comm comm, MPI_Request* request)
 {
   CHECK_COMM(9)
+  SET_BUF1(sendbuf)
+  SET_BUF2(recvbuf)
   if(sendbuf != MPI_IN_PLACE){
     CHECK_NULL(2, MPI_ERR_COUNT, sendcounts)
     CHECK_NULL(3, MPI_ERR_ARG, senddispls)
@@ -780,10 +812,10 @@ int PMPI_Ialltoallv(const void* sendbuf, const int* sendcounts, const int* sendd
   int size = comm->size();
   for (int i = 0; i < size; i++) {
     if(sendbuf != MPI_IN_PLACE){
-      CHECK_BUFFER(1, sendbuf, sendcounts[i])
+      CHECK_BUFFER(1, sendbuf, sendcounts[i], sendtype)
       CHECK_COUNT(2, sendcounts[i])
     }
-    CHECK_BUFFER(5, recvbuf, recvcounts[i])
+    CHECK_BUFFER(5, recvbuf, recvcounts[i], recvtype)
     CHECK_COUNT(6, recvcounts[i])
   }
 
@@ -859,6 +891,8 @@ int PMPI_Ialltoallw(const void* sendbuf, const int* sendcounts, const int* sendd
                     const int* recvcounts, const int* recvdispls, const MPI_Datatype* recvtypes, MPI_Comm comm, MPI_Request* request)
 {
   CHECK_COMM(9)
+  SET_BUF1(sendbuf)
+  SET_BUF2(recvbuf)
   if(sendbuf != MPI_IN_PLACE){
     CHECK_NULL(2, MPI_ERR_COUNT, sendcounts)
     CHECK_NULL(3, MPI_ERR_ARG, senddispls)
@@ -872,13 +906,13 @@ int PMPI_Ialltoallw(const void* sendbuf, const int* sendcounts, const int* sendd
   int size = comm->size();
   for (int i = 0; i < size; i++) {
     if(sendbuf != MPI_IN_PLACE){
-      CHECK_BUFFER(1, sendbuf, sendcounts[i])
       CHECK_COUNT(2, sendcounts[i])
       CHECK_TYPE(4, sendtypes[i])
+      CHECK_BUFFER(1, sendbuf, sendcounts[i], sendtypes[i])
     }
-    CHECK_BUFFER(5, recvbuf, recvcounts[i])
     CHECK_COUNT(6, recvcounts[i])
     CHECK_TYPE(8, recvtypes[i])
+    CHECK_BUFFER(5, recvbuf, recvcounts[i], recvtypes[i])
   }
 
   smpi_bench_end();