Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
[SMPI] Rename (recv|send)_sum -> (recv|send)_size_sum
[simgrid.git] / src / smpi / internals / smpi_replay.cpp
index fca1f30..1cc4d1f 100644 (file)
@@ -201,7 +201,7 @@ public:
 
 class GatherVArgParser : public CollCommParser {
 public:
-  int recv_sum;
+  int recv_size_sum;
   std::shared_ptr<std::vector<int>> recvcounts;
   std::vector<int> disps;
   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
@@ -252,7 +252,7 @@ public:
     for (unsigned int i = 0; i < comm_size; i++) {
       (*recvcounts)[i] = std::stoi(action[i + 3]);
     }
-    recv_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
+    recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
   }
 };
 
@@ -283,8 +283,8 @@ public:
 
 class ScatterVArgParser : public CollCommParser {
 public:
-  int recv_sum;
-  int send_sum;
+  int recv_size_sum;
+  int send_size_sum;
   std::shared_ptr<std::vector<int>> sendcounts;
   std::vector<int> disps;
   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
@@ -311,14 +311,14 @@ public:
     for (unsigned int i = 0; i < comm_size; i++) {
       (*sendcounts)[i] = std::stoi(action[i + 2]);
     }
-    send_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
+    send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
     root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
   }
 };
 
 class ReduceScatterArgParser : public CollCommParser {
 public:
-  int recv_sum;
+  int recv_size_sum;
   std::shared_ptr<std::vector<int>> recvcounts;
   std::vector<int> disps;
   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
@@ -340,7 +340,50 @@ public:
     for (unsigned int i = 0; i < comm_size; i++) {
       recvcounts->push_back(std::stoi(action[i + 2]));
     }
-    recv_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
+    recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
+  }
+};
+
+class AllToAllVArgParser : public CollCommParser {
+public:
+  int recv_size_sum;
+  int send_size_sum;
+  std::shared_ptr<std::vector<int>> recvcounts;
+  std::shared_ptr<std::vector<int>> sendcounts;
+  std::vector<int> senddisps;
+  std::vector<int> recvdisps;
+  int send_buf_size;
+  int recv_buf_size;
+  void parse(simgrid::xbt::ReplayAction& action, std::string name) override
+  {
+    /* The structure of the allToAllV action for the rank 0 (total 4 processes) is the following:
+          0 allToAllV 100 1 7 10 12 100 1 70 10 5
+       where:
+        1) 100 is the size of the send buffer *sizeof(int),
+        2) 1 7 10 12 is the sendcounts array
+        3) 100*sizeof(int) is the size of the receiver buffer
+        4)  1 70 10 5 is the recvcounts array
+    */
+    comm_size = MPI_COMM_WORLD->size();
+    CHECK_ACTION_PARAMS(action, 2*comm_size+2, 2)
+    sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
+    recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
+    senddisps  = std::vector<int>(comm_size, 0);
+    recvdisps  = std::vector<int>(comm_size, 0);
+
+    if (action.size() > 5 + 2 * comm_size)
+      datatype1 = simgrid::smpi::Datatype::decode(action[4 + 2 * comm_size]);
+    if (action.size() > 5 + 2 * comm_size)
+      datatype2 = simgrid::smpi::Datatype::decode(action[5 + 2 * comm_size]);
+
+    send_buf_size=parse_double(action[2]);
+    recv_buf_size=parse_double(action[3+comm_size]);
+    for (unsigned int i = 0; i < comm_size; i++) {
+      (*sendcounts)[i] = std::stoi(action[3 + i]);
+      (*recvcounts)[i] = std::stoi(action[4 + comm_size + i]);
+    }
+    send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
+    recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
   }
 };
 
@@ -670,12 +713,12 @@ public:
 
     if (name == "gatherV") {
       Colls::gatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1, 
-                     (rank == args.root) ? recv_buffer(args.recv_sum  * args.datatype2->size()) : nullptr, args.recvcounts->data(), args.disps.data(), args.datatype2, args.root,
+                     (rank == args.root) ? recv_buffer(args.recv_size_sum  * args.datatype2->size()) : nullptr, args.recvcounts->data(), args.disps.data(), args.datatype2, args.root,
                      MPI_COMM_WORLD);
     }
     else {
       Colls::allgatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1, 
-                        recv_buffer(args.recv_sum * args.datatype2->size()), args.recvcounts->data(), args.disps.data(), args.datatype2,
+                        recv_buffer(args.recv_size_sum * args.datatype2->size()), args.recvcounts->data(), args.disps.data(), args.datatype2,
                     MPI_COMM_WORLD);
     }
 
@@ -711,7 +754,7 @@ public:
           nullptr, Datatype::encode(args.datatype1),
           Datatype::encode(args.datatype2)));
 
-    Colls::scatterv((rank == args.root) ? send_buffer(args.send_sum * args.datatype1->size()) : nullptr, args.sendcounts->data(), args.disps.data(), 
+    Colls::scatterv((rank == args.root) ? send_buffer(args.send_size_sum * args.datatype1->size()) : nullptr, args.sendcounts->data(), args.disps.data(), 
         args.datatype1, recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, args.root,
         MPI_COMM_WORLD);
 
@@ -729,66 +772,31 @@ public:
                                                          std::to_string(args.comp_size), /* ugly hack to print comp_size */
                                                          Datatype::encode(args.datatype1)));
 
-    Colls::reduce_scatter(send_buffer(args.recv_sum * args.datatype1->size()), recv_buffer(args.recv_sum * args.datatype1->size()), 
+    Colls::reduce_scatter(send_buffer(args.recv_size_sum * args.datatype1->size()), recv_buffer(args.recv_size_sum * args.datatype1->size()), 
                           args.recvcounts->data(), args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
 
     smpi_execute_flops(args.comp_size);
     TRACE_smpi_comm_out(my_proc_id);
   }
 };
-} // Replay Namespace
-
-static void action_allToAllv(simgrid::xbt::ReplayAction& action)
-{
-  /* The structure of the allToAllV action for the rank 0 (total 4 processes) is the following:
-        0 allToAllV 100 1 7 10 12 100 1 70 10 5
-     where:
-        1) 100 is the size of the send buffer *sizeof(int),
-        2) 1 7 10 12 is the sendcounts array
-        3) 100*sizeof(int) is the size of the receiver buffer
-        4)  1 70 10 5 is the recvcounts array
-  */
-  double clock = smpi_process()->simulated_elapsed();
-
-  unsigned long comm_size = MPI_COMM_WORLD->size();
-  CHECK_ACTION_PARAMS(action, 2*comm_size+2, 2)
-  std::shared_ptr<std::vector<int>> sendcounts(new std::vector<int>(comm_size));
-  std::shared_ptr<std::vector<int>> recvcounts(new std::vector<int>(comm_size));
-  std::vector<int> senddisps(comm_size, 0);
-  std::vector<int> recvdisps(comm_size, 0);
-
-  MPI_Datatype MPI_CURRENT_TYPE = (action.size() > 5 + 2 * comm_size)
-                                      ? simgrid::smpi::Datatype::decode(action[4 + 2 * comm_size])
-                                      : MPI_DEFAULT_TYPE;
-  MPI_Datatype MPI_CURRENT_TYPE2{(action.size() > 5 + 2 * comm_size)
-                                     ? simgrid::smpi::Datatype::decode(action[5 + 2 * comm_size])
-                                     : MPI_DEFAULT_TYPE};
-
-  int send_buf_size=parse_double(action[2]);
-  int recv_buf_size=parse_double(action[3+comm_size]);
-  int my_proc_id = Actor::self()->getPid();
-  void *sendbuf = smpi_get_tmp_sendbuffer(send_buf_size* MPI_CURRENT_TYPE->size());
-  void *recvbuf  = smpi_get_tmp_recvbuffer(recv_buf_size* MPI_CURRENT_TYPE2->size());
-
-  for (unsigned int i = 0; i < comm_size; i++) {
-    (*sendcounts)[i] = std::stoi(action[3 + i]);
-    (*recvcounts)[i] = std::stoi(action[4 + comm_size + i]);
-  }
-  int send_size = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
-  int recv_size = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
-
-  TRACE_smpi_comm_in(my_proc_id, __FUNCTION__,
-                     new simgrid::instr::VarCollTIData("allToAllV", -1, send_size, sendcounts, recv_size, recvcounts,
-                                                       Datatype::encode(MPI_CURRENT_TYPE),
-                                                       Datatype::encode(MPI_CURRENT_TYPE2)));
 
-  Colls::alltoallv(sendbuf, sendcounts->data(), senddisps.data(), MPI_CURRENT_TYPE, recvbuf, recvcounts->data(),
-                   recvdisps.data(), MPI_CURRENT_TYPE, MPI_COMM_WORLD);
+class AllToAllVAction : public ReplayAction<AllToAllVArgParser> {
+public:
+  AllToAllVAction() : ReplayAction("allToAllV") {}
+  void kernel(simgrid::xbt::ReplayAction& action) override
+  {
+    TRACE_smpi_comm_in(my_proc_id, __FUNCTION__,
+        new simgrid::instr::VarCollTIData("allToAllV", -1, args.send_size_sum, args.sendcounts, args.recv_size_sum, args.recvcounts,
+          Datatype::encode(args.datatype1),
+          Datatype::encode(args.datatype2)));
 
-  TRACE_smpi_comm_out(my_proc_id);
-  log_timed_action (action, clock);
-}
+    Colls::alltoallv(send_buffer(args.send_buf_size * args.datatype1->size()), args.sendcounts->data(), args.senddisps.data(), args.datatype1,
+                     recv_buffer(args.recv_buf_size * args.datatype2->size()), args.recvcounts->data(), args.recvdisps.data(), args.datatype2, MPI_COMM_WORLD);
 
+    TRACE_smpi_comm_out(my_proc_id);
+  }
+};
+} // Replay Namespace
 }} // namespace simgrid::smpi
 
 /** @brief Only initialize the replay, don't do it for real */
@@ -821,7 +829,7 @@ void smpi_replay_init(int* argc, char*** argv)
   xbt_replay_action_register("reduce",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::ReduceAction().execute(action); });
   xbt_replay_action_register("allReduce", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::AllReduceAction().execute(action); });
   xbt_replay_action_register("allToAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::AllToAllAction().execute(action); });
-  xbt_replay_action_register("allToAllV",  simgrid::smpi::action_allToAllv);
+  xbt_replay_action_register("allToAllV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::AllToAllVAction().execute(action); });
   xbt_replay_action_register("gather",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::GatherAction("gather").execute(action); });
   xbt_replay_action_register("scatter",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::ScatterAction().execute(action); });
   xbt_replay_action_register("gatherV",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::GatherVAction("gatherV").execute(action); });