Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
fix broken tests
[simgrid.git] / src / smpi / bindings / smpi_pmpi_coll.cpp
index 0e59749..e5a7089 100644 (file)
@@ -105,7 +105,7 @@ int PMPI_Gather(void *sendbuf, int sendcount, MPI_Datatype sendtype,void *recvbu
     extra->datatype1 = encode_datatype(sendtmptype);
     extra->send_size = sendtmptype->is_basic() ? sendtmpcount : sendtmpcount * sendtmptype->size();
     extra->datatype2 = encode_datatype(recvtype);
-    extra->recv_size = recvtype->is_basic() ? recvcount : recvcount * recvtype->size();
+    extra->recv_size = (comm->rank() != root || recvtype->is_basic()) ? recvcount : recvcount * recvtype->size();
 
     TRACE_smpi_collective_in(rank, __FUNCTION__, extra);
 
@@ -284,7 +284,7 @@ int PMPI_Scatter(void *sendbuf, int sendcount, MPI_Datatype sendtype,
     extra->root            = root_traced;
 
     extra->datatype1 = encode_datatype(sendtype);
-    extra->send_size = sendtype->is_basic() ? sendcount : sendcount * sendtype->size();
+    extra->send_size = (comm->rank() != root || sendtype->is_basic()) ? sendcount : sendcount * sendtype->size();
     extra->datatype2 = encode_datatype(recvtype);
     extra->recv_size = recvtype->is_basic() ? recvcount : recvcount * recvtype->size();
 
@@ -480,8 +480,8 @@ int PMPI_Exscan(void *sendbuf, void *recvbuf, int count, MPI_Datatype datatype,
     extra->send_size       = datatype->is_basic() ? count : count * datatype->size();
     void* sendtmpbuf = sendbuf;
     if (sendbuf == MPI_IN_PLACE) {
-      sendtmpbuf = static_cast<void*>(xbt_malloc(extra->send_size));
-      memcpy(sendtmpbuf, recvbuf, extra->send_size);
+      sendtmpbuf = static_cast<void*>(xbt_malloc(count * datatype->size()));
+      memcpy(sendtmpbuf, recvbuf, count * datatype->size());
     }
     TRACE_smpi_collective_in(rank, __FUNCTION__, extra);
 
@@ -666,7 +666,7 @@ int PMPI_Alltoallv(void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype
     extra->recvcounts      = new int[size];
     extra->sendcounts      = new int[size];
     extra->datatype2       = encode_datatype(recvtype);
-    int dt_size_recv       = recvtype->is_basic() ? 1 : recvtype->size();
+    int dt_size_recv       = recvtype->size();
 
     void* sendtmpbuf         = static_cast<char*>(sendbuf);
     int* sendtmpcounts       = sendcounts;
@@ -691,7 +691,7 @@ int PMPI_Alltoallv(void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype
     }
 
     extra->datatype1 = encode_datatype(sendtmptype);
-    int dt_size_send = sendtmptype->is_basic() ? 1 : sendtmptype->size();
+    int dt_size_send = sendtmptype->size();
 
     for (i = 0; i < size; i++) { // copy data to avoid bad free
       extra->send_size += sendtmpcounts[i] * dt_size_send;