Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Please sonar: promote struct to class.
[simgrid.git] / src / smpi / internals / smpi_replay.cpp
index e7c7a07..4744eda 100644 (file)
 #include <sstream>
 #include <vector>
 
-using simgrid::s4u::Actor;
-
 #include <tuple>
 // From https://stackoverflow.com/questions/7110301/generic-hash-for-tuples-in-unordered-map-unordered-set
 // This is all just to make std::unordered_map work with std::tuple. If we need this in other places,
 // this could go into a header file.
 namespace hash_tuple{
     template <typename TT>
-    struct hash
+    class hash
     {
+    public:
         size_t
         operator()(TT const& tt) const
         {
@@ -44,8 +43,9 @@ namespace hash_tuple{
 
     // Recursive template code derived from Matthieu M.
     template <class Tuple, size_t Index = std::tuple_size<Tuple>::value - 1>
-    struct HashValueImpl
+    class HashValueImpl
     {
+    public:
       static void apply(size_t& seed, Tuple const& tuple)
       {
         HashValueImpl<Tuple, Index-1>::apply(seed, tuple);
@@ -54,8 +54,9 @@ namespace hash_tuple{
     };
 
     template <class Tuple>
-    struct HashValueImpl<Tuple,0>
+    class HashValueImpl<Tuple,0>
     {
+    public:
       static void apply(size_t& seed, Tuple const& tuple)
       {
         hash_combine(seed, std::get<0>(tuple));
@@ -63,8 +64,9 @@ namespace hash_tuple{
     };
 
     template <typename ... TT>
-    struct hash<std::tuple<TT...>>
+    class hash<std::tuple<TT...>>
     {
+    public:
         size_t
         operator()(std::tuple<TT...> const& tt) const
         {
@@ -77,7 +79,6 @@ namespace hash_tuple{
 
 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_replay,smpi,"Trace Replay with SMPI");
 
-static std::unordered_map<int, std::vector<MPI_Request>*> reqq;
 typedef std::tuple</*sender*/ int, /* reciever */ int, /* tag */int> req_key_t;
 typedef std::unordered_map<req_key_t, MPI_Request, hash_tuple::hash<std::tuple<int,int,int>>> req_storage_t;
 
@@ -108,16 +109,6 @@ static void log_timed_action(simgrid::xbt::ReplayAction& action, double clock)
   }
 }
 
-static std::vector<MPI_Request>* get_reqq_self()
-{
-  return reqq.at(simgrid::s4u::this_actor::get_pid());
-}
-
-static void set_reqq_self(std::vector<MPI_Request> *mpi_request)
-{
-  reqq.insert({simgrid::s4u::this_actor::get_pid(), mpi_request});
-}
-
 /* Helper function */
 static double parse_double(std::string string)
 {
@@ -563,7 +554,7 @@ public:
 
 class WaitAction : public ReplayAction<WaitTestParser> {
 public:
-  WaitAction(RequestStorage& storage) : ReplayAction("Wait", storage) {}
+  explicit WaitAction(RequestStorage& storage) : ReplayAction("Wait", storage) {}
   void kernel(simgrid::xbt::ReplayAction& action) override
   {
     std::string s = boost::algorithm::join(action, " ");
@@ -611,7 +602,7 @@ public:
       Request::send(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
     } else if (name == "Isend") {
       MPI_Request request = Request::isend(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
-      get_reqq_self()->push_back(request);
+      req_storage->add(request);
     } else {
       xbt_die("Don't know this action, %s", name.c_str());
     }
@@ -642,7 +633,7 @@ public:
       Request::recv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD, &status);
     } else if (name == "Irecv") {
       MPI_Request request = Request::irecv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
-      get_reqq_self()->push_back(request);
+      req_storage->add(request);
     }
 
     TRACE_smpi_comm_out(my_proc_id);
@@ -666,11 +657,11 @@ public:
 
 class TestAction : public ReplayAction<WaitTestParser> {
 public:
-  TestAction(RequestStorage& storage) : ReplayAction("Test", storage) {}
+  explicit TestAction(RequestStorage& storage) : ReplayAction("Test", storage) {}
   void kernel(simgrid::xbt::ReplayAction& action) override
   {
-    MPI_Request request = get_reqq_self()->back();
-    get_reqq_self()->pop_back();
+    MPI_Request request = req_storage->find(args.src, args.dst, args.tag);
+    req_storage->remove(request);
     // if request is null here, this may mean that a previous test has succeeded
     // Different times in traced application and replayed version may lead to this
     // In this case, ignore the extra calls.
@@ -683,7 +674,10 @@ public:
       XBT_DEBUG("MPI_Test result: %d", flag);
       /* push back request in vector to be caught by a subsequent wait. if the test did succeed, the request is now
        * nullptr.*/
-      get_reqq_self()->push_back(request);
+      if (request == MPI_REQUEST_NULL)
+        req_storage->addNullRequest(args.src, args.dst, args.tag);
+      else
+        req_storage->add(request);
 
       TRACE_smpi_testing_out(my_proc_id);
     }
@@ -701,7 +695,6 @@ public:
 
     /* start a simulated timer */
     smpi_process()->simulated_start();
-    set_reqq_self(new std::vector<MPI_Request>);
   }
 };
 
@@ -713,21 +706,24 @@ public:
 
 class WaitAllAction : public ReplayAction<ActionArgParser> {
 public:
-  WaitAllAction(RequestStorage& storage) : ReplayAction("waitAll", storage) {}
+  explicit WaitAllAction(RequestStorage& storage) : ReplayAction("waitAll", storage) {}
   void kernel(simgrid::xbt::ReplayAction& action) override
   {
-    const unsigned int count_requests = get_reqq_self()->size();
+    const unsigned int count_requests = req_storage->size();
 
     if (count_requests > 0) {
       TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData("waitAll", -1, count_requests, ""));
       std::vector<std::pair</*sender*/int,/*recv*/int>> sender_receiver;
-      for (const auto& req : (*get_reqq_self())) {
+      std::vector<MPI_Request> reqs;
+      req_storage->get_requests(reqs);
+      for (const auto& req : reqs) {
         if (req && (req->flags() & RECV)) {
           sender_receiver.push_back({req->src(), req->dst()});
         }
       }
       MPI_Status status[count_requests];
-      Request::waitall(count_requests, &(*get_reqq_self())[0], status);
+      Request::waitall(count_requests, &(reqs.data())[0], status);
+      req_storage->get_store().clear();
 
       for (auto& pair : sender_receiver) {
         TRACE_smpi_recv(pair.first, pair.second, 0);
@@ -1001,20 +997,19 @@ void smpi_replay_main(int* argc, char*** argv)
 
   /* and now, finalize everything */
   /* One active process will stop. Decrease the counter*/
-  XBT_DEBUG("There are %zu elements in reqq[*]", get_reqq_self()->size());
-  if (not get_reqq_self()->empty()) {
-    unsigned int count_requests=get_reqq_self()->size();
+  unsigned int count_requests = storage[simgrid::s4u::this_actor::get_pid() - 1].size();
+  XBT_DEBUG("There are %ud elements in reqq[*]", count_requests);
+  if (count_requests > 0) {
     MPI_Request requests[count_requests];
     MPI_Status status[count_requests];
     unsigned int i=0;
 
-    for (auto const& req : *get_reqq_self()) {
-      requests[i] = req;
+    for (auto const& pair : storage[simgrid::s4u::this_actor::get_pid() - 1].get_store()) {
+      requests[i] = pair.second;
       i++;
     }
     simgrid::smpi::Request::waitall(count_requests, requests, status);
   }
-  delete get_reqq_self();
   active_processes--;
 
   if(active_processes==0){