Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
[SMPI] Replay/Allgatherv: Account for disps parameters in a replay trace
[simgrid.git] / src / smpi / internals / smpi_replay.cpp
index 42c8904..0f844a4 100644 (file)
@@ -186,7 +186,7 @@ static void action_send(const char *const *action)
   int dst_traced = MPI_COMM_WORLD->group()->actor(to)->getPid();
 
   TRACE_smpi_comm_in(my_proc_id, __FUNCTION__,
-                     new simgrid::instr::Pt2PtTIData("send", to, size, encode_datatype(MPI_CURRENT_TYPE)));
+                     new simgrid::instr::Pt2PtTIData("send", to, size, MPI_CURRENT_TYPE->encode()));
   if (not TRACE_smpi_view_internals())
     TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, 0, size * MPI_CURRENT_TYPE->size());
 
@@ -209,7 +209,7 @@ static void action_Isend(const char *const *action)
   int my_proc_id = Actor::self()->getPid();
   int dst_traced = MPI_COMM_WORLD->group()->actor(to)->getPid();
   TRACE_smpi_comm_in(my_proc_id, __FUNCTION__,
-                     new simgrid::instr::Pt2PtTIData("Isend", to, size, encode_datatype(MPI_CURRENT_TYPE)));
+                     new simgrid::instr::Pt2PtTIData("Isend", to, size, MPI_CURRENT_TYPE->encode()));
   if (not TRACE_smpi_view_internals())
     TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, 0, size * MPI_CURRENT_TYPE->size());
 
@@ -235,7 +235,7 @@ static void action_recv(const char *const *action) {
   int src_traced = MPI_COMM_WORLD->group()->actor(from)->getPid();
 
   TRACE_smpi_comm_in(my_proc_id, __FUNCTION__,
-                     new simgrid::instr::Pt2PtTIData("recv", from, size, encode_datatype(MPI_CURRENT_TYPE)));
+                     new simgrid::instr::Pt2PtTIData("recv", from, size, MPI_CURRENT_TYPE->encode()));
 
   //unknown size from the receiver point of view
   if (size <= 0.0) {
@@ -264,7 +264,7 @@ static void action_Irecv(const char *const *action)
 
   int my_proc_id = Actor::self()->getPid();
   TRACE_smpi_comm_in(my_proc_id, __FUNCTION__,
-                     new simgrid::instr::Pt2PtTIData("Irecv", from, size, encode_datatype(MPI_CURRENT_TYPE)));
+                     new simgrid::instr::Pt2PtTIData("Irecv", from, size, MPI_CURRENT_TYPE->encode()));
   MPI_Status status;
   //unknow size from the receiver pov
   if (size <= 0.0) {
@@ -392,7 +392,7 @@ static void action_bcast(const char *const *action)
   int my_proc_id = Actor::self()->getPid();
   TRACE_smpi_comm_in(my_proc_id, __FUNCTION__,
                      new simgrid::instr::CollTIData("bcast", MPI_COMM_WORLD->group()->actor(root)->getPid(), -1.0, size,
-                                                    -1, encode_datatype(MPI_CURRENT_TYPE), ""));
+                                                    -1, MPI_CURRENT_TYPE->encode(), ""));
 
   void *sendbuf = smpi_get_tmp_sendbuffer(size* MPI_CURRENT_TYPE->size());
 
@@ -415,7 +415,7 @@ static void action_reduce(const char *const *action)
   int my_proc_id = Actor::self()->getPid();
   TRACE_smpi_comm_in(my_proc_id, __FUNCTION__,
                      new simgrid::instr::CollTIData("reduce", MPI_COMM_WORLD->group()->actor(root)->getPid(), comp_size,
-                                                    comm_size, -1, encode_datatype(MPI_CURRENT_TYPE), ""));
+                                                    comm_size, -1, MPI_CURRENT_TYPE->encode(), ""));
 
   void *recvbuf = smpi_get_tmp_sendbuffer(comm_size* MPI_CURRENT_TYPE->size());
   void *sendbuf = smpi_get_tmp_sendbuffer(comm_size* MPI_CURRENT_TYPE->size());
@@ -436,7 +436,7 @@ static void action_allReduce(const char *const *action) {
   double clock = smpi_process()->simulated_elapsed();
   int my_proc_id = Actor::self()->getPid();
   TRACE_smpi_comm_in(my_proc_id, __FUNCTION__, new simgrid::instr::CollTIData("allReduce", -1, comp_size, comm_size, -1,
-                                                                              encode_datatype(MPI_CURRENT_TYPE), ""));
+                                                                              MPI_CURRENT_TYPE->encode(), ""));
 
   void *recvbuf = smpi_get_tmp_sendbuffer(comm_size* MPI_CURRENT_TYPE->size());
   void *sendbuf = smpi_get_tmp_sendbuffer(comm_size* MPI_CURRENT_TYPE->size());
@@ -462,8 +462,7 @@ static void action_allToAll(const char *const *action) {
   int my_proc_id = Actor::self()->getPid();
   TRACE_smpi_comm_in(my_proc_id, __FUNCTION__,
                      new simgrid::instr::CollTIData("allToAll", -1, -1.0, send_size, recv_size,
-                                                    encode_datatype(MPI_CURRENT_TYPE),
-                                                    encode_datatype(MPI_CURRENT_TYPE2)));
+                                                    MPI_CURRENT_TYPE->encode(), MPI_CURRENT_TYPE2->encode()));
 
   Colls::alltoall(send, send_size, MPI_CURRENT_TYPE, recv, recv_size, MPI_CURRENT_TYPE2, MPI_COMM_WORLD);
 
@@ -497,9 +496,9 @@ static void action_gather(const char *const *action) {
   if(rank==root)
     recv = smpi_get_tmp_recvbuffer(recv_size*comm_size* MPI_CURRENT_TYPE2->size());
 
-  TRACE_smpi_comm_in(rank, __FUNCTION__, new simgrid::instr::CollTIData("gather", root, -1.0, send_size, recv_size,
-                                                                        encode_datatype(MPI_CURRENT_TYPE),
-                                                                        encode_datatype(MPI_CURRENT_TYPE2)));
+  TRACE_smpi_comm_in(rank, __FUNCTION__,
+                     new simgrid::instr::CollTIData("gather", root, -1.0, send_size, recv_size,
+                                                    MPI_CURRENT_TYPE->encode(), MPI_CURRENT_TYPE2->encode()));
 
   Colls::gather(send, send_size, MPI_CURRENT_TYPE, recv, recv_size, MPI_CURRENT_TYPE2, root, MPI_COMM_WORLD);
 
@@ -534,9 +533,9 @@ static void action_scatter(const char* const* action)
   if (rank == root)
     recv = smpi_get_tmp_recvbuffer(recv_size * comm_size * MPI_CURRENT_TYPE2->size());
 
-  TRACE_smpi_comm_in(rank, __FUNCTION__, new simgrid::instr::CollTIData("gather", root, -1.0, send_size, recv_size,
-                                                                        encode_datatype(MPI_CURRENT_TYPE),
-                                                                        encode_datatype(MPI_CURRENT_TYPE2)));
+  TRACE_smpi_comm_in(rank, __FUNCTION__,
+                     new simgrid::instr::CollTIData("gather", root, -1.0, send_size, recv_size,
+                                                    MPI_CURRENT_TYPE->encode(), MPI_CURRENT_TYPE2->encode()));
 
   Colls::scatter(send, send_size, MPI_CURRENT_TYPE, recv, recv_size, MPI_CURRENT_TYPE2, root, MPI_COMM_WORLD);
 
@@ -579,9 +578,9 @@ static void action_gatherv(const char *const *action) {
   if(rank==root)
     recv = smpi_get_tmp_recvbuffer(recv_sum* MPI_CURRENT_TYPE2->size());
 
-  TRACE_smpi_comm_in(rank, __FUNCTION__, new simgrid::instr::VarCollTIData(
-                                             "gatherV", root, send_size, nullptr, -1, recvcounts,
-                                             encode_datatype(MPI_CURRENT_TYPE), encode_datatype(MPI_CURRENT_TYPE2)));
+  TRACE_smpi_comm_in(rank, __FUNCTION__,
+                     new simgrid::instr::VarCollTIData("gatherV", root, send_size, nullptr, -1, recvcounts,
+                                                       MPI_CURRENT_TYPE->encode(), MPI_CURRENT_TYPE2->encode()));
 
   Colls::gatherv(send, send_size, MPI_CURRENT_TYPE, recv, recvcounts->data(), disps.data(), MPI_CURRENT_TYPE2, root,
                  MPI_COMM_WORLD);
@@ -626,9 +625,9 @@ static void action_scatterv(const char* const* action)
   if (rank == root)
     send = smpi_get_tmp_sendbuffer(send_sum * MPI_CURRENT_TYPE2->size());
 
-  TRACE_smpi_comm_in(rank, __FUNCTION__, new simgrid::instr::VarCollTIData("gatherV", root, -1, sendcounts, recv_size,
-                                                                           nullptr, encode_datatype(MPI_CURRENT_TYPE),
-                                                                           encode_datatype(MPI_CURRENT_TYPE2)));
+  TRACE_smpi_comm_in(rank, __FUNCTION__,
+                     new simgrid::instr::VarCollTIData("gatherV", root, -1, sendcounts, recv_size, nullptr,
+                                                       MPI_CURRENT_TYPE->encode(), MPI_CURRENT_TYPE2->encode()));
 
   Colls::scatterv(send, sendcounts->data(), disps.data(), MPI_CURRENT_TYPE, recv, recv_size, MPI_CURRENT_TYPE2, root,
                   MPI_COMM_WORLD);
@@ -661,7 +660,7 @@ static void action_reducescatter(const char *const *action) {
   TRACE_smpi_comm_in(my_proc_id, __FUNCTION__,
                      new simgrid::instr::VarCollTIData("reduceScatter", -1, 0, nullptr, -1, recvcounts,
                                                        std::to_string(comp_size), /* ugly hack to print comp_size */
-                                                       encode_datatype(MPI_CURRENT_TYPE)));
+                                                       MPI_CURRENT_TYPE->encode()));
 
   void *sendbuf = smpi_get_tmp_sendbuffer(size* MPI_CURRENT_TYPE->size());
   void *recvbuf = smpi_get_tmp_recvbuffer(size* MPI_CURRENT_TYPE->size());
@@ -697,8 +696,7 @@ static void action_allgather(const char *const *action) {
 
   TRACE_smpi_comm_in(my_proc_id, __FUNCTION__,
                      new simgrid::instr::CollTIData("allGather", -1, -1.0, sendcount, recvcount,
-                                                    encode_datatype(MPI_CURRENT_TYPE),
-                                                    encode_datatype(MPI_CURRENT_TYPE2)));
+                                                    MPI_CURRENT_TYPE->encode(), MPI_CURRENT_TYPE2->encode()));
 
   Colls::allgather(sendbuf, sendcount, MPI_CURRENT_TYPE, recvbuf, recvcount, MPI_CURRENT_TYPE2, MPI_COMM_WORLD);
 
@@ -722,10 +720,23 @@ static void action_allgatherv(const char *const *action) {
   std::shared_ptr<std::vector<int>> recvcounts(new std::vector<int>(comm_size));
   std::vector<int> disps(comm_size, 0);
 
-  MPI_Datatype MPI_CURRENT_TYPE =
-      (action[3 + comm_size] && action[4 + comm_size]) ? decode_datatype(action[3 + comm_size]) : MPI_DEFAULT_TYPE;
-  MPI_Datatype MPI_CURRENT_TYPE2{
-      (action[3 + comm_size] && action[4 + comm_size]) ? decode_datatype(action[4 + comm_size]) : MPI_DEFAULT_TYPE};
+  int datatype_index = 0, disp_index = 0;
+  if (action[3 + 2 * comm_size]) { /* datatype + disp are specified */
+    datatype_index = 3 + comm_size;
+    disp_index     = datatype_index + 1;
+  } else if (action[3 + 2 * comm_size]) { /* disps specified; datatype is not specified; use the default one */
+    datatype_index = -1;
+    disp_index     = 3 + comm_size;
+  } else if (action[3 + comm_size]) { /* only datatype, no disp specified */
+    datatype_index = 3 + comm_size;
+  }
+
+  if (disp_index != 0) {
+    std::copy(action[disp_index], action[disp_index + comm_size], disps.begin());
+  }
+
+  MPI_Datatype MPI_CURRENT_TYPE{(datatype_index > 0) ? decode_datatype(action[datatype_index]) : MPI_DEFAULT_TYPE};
+  MPI_Datatype MPI_CURRENT_TYPE2{(datatype_index > 0) ? decode_datatype(action[datatype_index]) : MPI_DEFAULT_TYPE};
 
   void *sendbuf = smpi_get_tmp_sendbuffer(sendcount* MPI_CURRENT_TYPE->size());
 
@@ -739,8 +750,7 @@ static void action_allgatherv(const char *const *action) {
 
   TRACE_smpi_comm_in(my_proc_id, __FUNCTION__,
                      new simgrid::instr::VarCollTIData("allGatherV", -1, sendcount, nullptr, -1, recvcounts,
-                                                       encode_datatype(MPI_CURRENT_TYPE),
-                                                       encode_datatype(MPI_CURRENT_TYPE2)));
+                                                       MPI_CURRENT_TYPE->encode(), MPI_CURRENT_TYPE2->encode()));
 
   Colls::allgatherv(sendbuf, sendcount, MPI_CURRENT_TYPE, recvbuf, recvcounts->data(), disps.data(), MPI_CURRENT_TYPE2,
                     MPI_COMM_WORLD);
@@ -789,8 +799,7 @@ static void action_allToAllv(const char *const *action) {
 
   TRACE_smpi_comm_in(my_proc_id, __FUNCTION__,
                      new simgrid::instr::VarCollTIData("allToAllV", -1, send_size, sendcounts, recv_size, recvcounts,
-                                                       encode_datatype(MPI_CURRENT_TYPE),
-                                                       encode_datatype(MPI_CURRENT_TYPE2)));
+                                                       MPI_CURRENT_TYPE->encode(), MPI_CURRENT_TYPE2->encode()));
 
   Colls::alltoallv(sendbuf, sendcounts->data(), senddisps.data(), MPI_CURRENT_TYPE, recvbuf, recvcounts->data(),
                    recvdisps.data(), MPI_CURRENT_TYPE, MPI_COMM_WORLD);