Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
[SMPI] Replay: Classify the actions, start with Wait
[simgrid.git] / src / smpi / internals / smpi_replay.cpp
index 95b3e80..8067fd4 100644 (file)
@@ -65,6 +65,76 @@ static double parse_double(std::string string)
 namespace simgrid {
 namespace smpi {
 
+namespace Replay {
+/**
+ * Base class of the argument parsers used by replay actions.
+ *
+ * The default parse() extracts nothing; WaitAction uses this class as-is.
+ */
+class ActionArgParser {
+public:
+  /* Polymorphic base: keep destruction through a base pointer or reference well-defined. */
+  virtual ~ActionArgParser() = default;
+
+  /** Read the arguments of @p action. This default implementation does nothing. */
+  virtual void parse(simgrid::xbt::ReplayAction& action) {}
+};
+
+/**
+ * Common skeleton of a replayed action.
+ *
+ * T is the argument parser type: an instance is held in `args` and its parse() is called by execute().
+ * Concrete actions supply their specific work by overriding kernel().
+ */
+template <class T> class ReplayAction {
+protected:
+  /** Name given at construction. */
+  const std::string name;
+  /** Argument parser; execute() calls args.parse() before kernel(). */
+  T args;
+
+  /*
+   * Used to compute the duration of this action.
+   * Sampled from smpi_process()->simulated_elapsed() when the object is constructed.
+   */
+  double start_time;
+
+  /** Pid of the current actor (simgrid::s4u::Actor::self()) at construction time. */
+  int my_proc_id;
+
+public:
+  explicit ReplayAction(std::string name)
+      : name(name), start_time(smpi_process()->simulated_elapsed()), my_proc_id(simgrid::s4u::Actor::self()->getPid())
+  {
+  }
+
+  /* Polymorphic base: keep destruction through a base pointer or reference well-defined. */
+  virtual ~ReplayAction() = default;
+
+  /** Parse the arguments of @p action, run kernel() on it, then hand it with start_time to log_timed_action(). */
+  virtual void execute(simgrid::xbt::ReplayAction& action)
+  {
+    args.parse(action);
+    kernel(action);
+    log_timed_action(action, start_time);
+  }
+
+  /** Action-specific behaviour, to be implemented by each concrete action. */
+  virtual void kernel(simgrid::xbt::ReplayAction& action) = 0;
+};
+
+/**
+ * Replay of a "wait" action: pops the last request from the queue returned by get_reqq_self() and waits for it,
+ * emitting the surrounding TRACE_smpi_* events. It uses the no-op ActionArgParser.
+ */
+class WaitAction : public ReplayAction<ActionArgParser> {
+public:
+  WaitAction() : ReplayAction("Wait") {}
+  void kernel(simgrid::xbt::ReplayAction& action) override
+  {
+    CHECK_ACTION_PARAMS(action, 0, 0)
+    MPI_Status status;
+
+    // The joined action line is only used in the assertion message below.
+    std::string s = boost::algorithm::join(action, " ");
+    xbt_assert(get_reqq_self()->size(), "action wait not preceded by any irecv or isend: %s", s.c_str());
+    // Take the request at the back of the queue and remove it from the queue.
+    MPI_Request request = get_reqq_self()->back();
+    get_reqq_self()->pop_back();
+
+    if (request == nullptr) {
+      /* Assume that the trace is well formed, meaning the comm might have been caught by a MPI_test. Then just
+       * return.*/
+      return;
+    }
+
+    // rank falls back to -1 when the request has no communicator.
+    int rank = request->comm() != MPI_COMM_NULL ? request->comm()->rank() : -1;
+
+    // NOTE(review): comm() was just compared against MPI_COMM_NULL, yet it is dereferenced unconditionally here.
+    // Confirm that it can never be MPI_COMM_NULL at this point, or guard this access too.
+    MPI_Group group          = request->comm()->group();
+    // These values are read from the request before Request::wait() and used after it -- presumably because the
+    // request may no longer be usable once the wait completed (TODO confirm).
+    int src_traced           = group->rank(request->src());
+    int dst_traced           = group->rank(request->dst());
+    bool is_wait_for_receive = (request->flags() & RECV);
+    // TODO: Here we take the rank while we normally take the process id (look for my_proc_id)
+    // NOTE(review): inside this method __FUNCTION__ names "kernel"; confirm that no trace consumer relies on the
+    // name of the enclosing function.
+    TRACE_smpi_comm_in(rank, __FUNCTION__, new simgrid::instr::NoOpTIData("wait"));
+
+    Request::wait(&request, &status);
+
+    TRACE_smpi_comm_out(rank);
+    // A receive event is only traced when the RECV flag was set on the request.
+    if (is_wait_for_receive)
+      TRACE_smpi_recv(src_traced, dst_traced, 0);
+  }
+};
+
 static void action_init(simgrid::xbt::ReplayAction& action)
 {
   XBT_DEBUG("Initialize the counters");
@@ -253,34 +323,7 @@ static void action_test(simgrid::xbt::ReplayAction& action)
 
 static void action_wait(simgrid::xbt::ReplayAction& action)
 {
-  CHECK_ACTION_PARAMS(action, 0, 0)
-  double clock = smpi_process()->simulated_elapsed();
-  MPI_Status status;
-
-  std::string s = boost::algorithm::join(action, " ");
-  xbt_assert(get_reqq_self()->size(), "action wait not preceded by any irecv or isend: %s", s.c_str());
-  MPI_Request request = get_reqq_self()->back();
-  get_reqq_self()->pop_back();
-
-  if (request==nullptr){
-    /* Assume that the trace is well formed, meaning the comm might have been caught by a MPI_test. Then just return.*/
-    return;
-  }
-
-  int rank = request->comm() != MPI_COMM_NULL ? request->comm()->rank() : -1;
-
-  MPI_Group group = request->comm()->group();
-  int src_traced = group->rank(request->src());
-  int dst_traced = group->rank(request->dst());
-  int is_wait_for_receive = (request->flags() & RECV);
-  TRACE_smpi_comm_in(rank, __FUNCTION__, new simgrid::instr::NoOpTIData("wait"));
-
-  Request::wait(&request, &status);
-
-  TRACE_smpi_comm_out(rank);
-  if (is_wait_for_receive)
-    TRACE_smpi_recv(src_traced, dst_traced, 0);
-  log_timed_action (action, clock);
+  // Delegate to a temporary Replay::WaitAction: its execute() parses the arguments, runs the wait and logs the
+  // timed action.
+  Replay::WaitAction().execute(action);
 }
 
 static void action_waitall(simgrid::xbt::ReplayAction& action)