Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
avoid potential leak
[simgrid.git] / src / smpi / bindings / smpi_pmpi_coll.cpp
index 3a82044..5f71a34 100644 (file)
@@ -70,6 +70,9 @@ int PMPI_Ibcast(void *buf, int count, MPI_Datatype datatype,
         simgrid::smpi::Colls::bcast(buf, count, datatype, root, comm);
       else
         simgrid::smpi::Colls::ibcast(buf, count, datatype, root, comm, request);
+    } else {
+      if(request!=MPI_REQUEST_IGNORED)
+        *request = MPI_REQUEST_NULL;
     }
     retval = MPI_SUCCESS;
 
@@ -771,3 +774,91 @@ int PMPI_Ialltoallv(void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype
   smpi_bench_begin();
   return retval;
 }
+
+int PMPI_Alltoallw(void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype* sendtypes, void* recvbuf,
+                   int* recvcounts, int* recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm)
+{
+  return PMPI_Ialltoallw(sendbuf, sendcounts, senddisps, sendtypes, recvbuf, recvcounts, recvdisps, recvtypes, comm, MPI_REQUEST_IGNORED);
+}
+
+int PMPI_Ialltoallw(void* sendbuf, int* sendcounts, int* senddisps, MPI_Datatype* sendtypes, void* recvbuf,
+                   int* recvcounts, int* recvdisps, MPI_Datatype* recvtypes, MPI_Comm comm, MPI_Request *request)
+{
+  int retval = 0;
+
+  smpi_bench_end();
+
+  if (comm == MPI_COMM_NULL) {
+    retval = MPI_ERR_COMM;
+  } else if ((sendbuf != MPI_IN_PLACE && sendtypes == nullptr)  || recvtypes == nullptr) {
+    retval = MPI_ERR_TYPE;
+  } else if ((sendbuf != MPI_IN_PLACE && (sendcounts == nullptr || senddisps == nullptr)) || recvcounts == nullptr ||
+             recvdisps == nullptr) {
+    retval = MPI_ERR_ARG;
+  } else if (request == nullptr){
+    retval = MPI_ERR_ARG;
+  }  else {
+    int rank                           = simgrid::s4u::this_actor::get_pid();
+    int size                           = comm->size();
+    int send_size                      = 0;
+    int recv_size                      = 0;
+    std::vector<int>* trace_sendcounts = new std::vector<int>;
+    std::vector<int>* trace_recvcounts = new std::vector<int>;
+
+    void* sendtmpbuf           = static_cast<char*>(sendbuf);
+    int* sendtmpcounts         = sendcounts;
+    int* sendtmpdisps          = senddisps;
+    MPI_Datatype* sendtmptypes = sendtypes;
+    unsigned long maxsize                = 0;
+    for (int i = 0; i < size; i++) { // copy data to avoid bad free
+      if(recvtypes[i]==MPI_DATATYPE_NULL){
+        delete trace_recvcounts;
+        delete trace_sendcounts;
+        return MPI_ERR_TYPE;
+      }
+      recv_size += recvcounts[i] * recvtypes[i]->size();
+      trace_recvcounts->push_back(recvcounts[i] * recvtypes[i]->size());
+      if ((recvdisps[i] + (recvcounts[i] * recvtypes[i]->size())) > maxsize)
+        maxsize = recvdisps[i] + (recvcounts[i] * recvtypes[i]->size());
+    }
+
+    if (sendbuf == MPI_IN_PLACE) {
+      sendtmpbuf = static_cast<void*>(xbt_malloc(maxsize));
+      memcpy(sendtmpbuf, recvbuf, maxsize);
+      sendtmpcounts = static_cast<int*>(xbt_malloc(size * sizeof(int)));
+      memcpy(sendtmpcounts, recvcounts, size * sizeof(int));
+      sendtmpdisps = static_cast<int*>(xbt_malloc(size * sizeof(int)));
+      memcpy(sendtmpdisps, recvdisps, size * sizeof(int));
+      sendtmptypes = static_cast<MPI_Datatype*>(xbt_malloc(size * sizeof(MPI_Datatype)));
+      memcpy(sendtmptypes, recvtypes, size * sizeof(MPI_Datatype));
+    }
+
+    for (int i = 0; i < size; i++) { // copy data to avoid bad free
+      send_size += sendtmpcounts[i] * sendtmptypes[i]->size();
+      trace_sendcounts->push_back(sendtmpcounts[i] * sendtmptypes[i]->size());
+    }
+
+    TRACE_smpi_comm_in(rank, request==MPI_REQUEST_IGNORED?"PMPI_Alltoallw":"PMPI_Ialltoallw",
+                       new simgrid::instr::VarCollTIData(request==MPI_REQUEST_IGNORED ? "alltoallv":"ialltoallv", -1, send_size, trace_sendcounts, recv_size,
+                                                         trace_recvcounts, simgrid::smpi::Datatype::encode(sendtmptypes[0]),
+                                                         simgrid::smpi::Datatype::encode(recvtypes[0])));
+
+    if(request == MPI_REQUEST_IGNORED)
+      retval = simgrid::smpi::Colls::alltoallw(sendtmpbuf, sendtmpcounts, sendtmpdisps, sendtmptypes, recvbuf, recvcounts,
+                                    recvdisps, recvtypes, comm);
+    else
+      retval = simgrid::smpi::Colls::ialltoallw(sendtmpbuf, sendtmpcounts, sendtmpdisps, sendtmptypes, recvbuf, recvcounts,
+                                    recvdisps, recvtypes, comm, request);
+    TRACE_smpi_comm_out(rank);
+
+    if (sendbuf == MPI_IN_PLACE) {
+      xbt_free(sendtmpbuf);
+      xbt_free(sendtmpcounts);
+      xbt_free(sendtmpdisps);
+      xbt_free(sendtmptypes);
+    }
+  }
+
+  smpi_bench_begin();
+  return retval;
+}