Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Unused.
[simgrid.git] / src / smpi / internals / smpi_replay.cpp
1 /* Copyright (c) 2009-2018. The SimGrid Team. All rights reserved.          */
2
3 /* This program is free software; you can redistribute it and/or modify it
4  * under the terms of the license (GNU LGPL) which comes with this package. */
5
6 #include "private.hpp"
7 #include "smpi_coll.hpp"
8 #include "smpi_comm.hpp"
9 #include "smpi_datatype.hpp"
10 #include "smpi_group.hpp"
11 #include "smpi_process.hpp"
12 #include "smpi_request.hpp"
13 #include "xbt/replay.hpp"
14
15 #include <boost/algorithm/string/join.hpp>
16 #include <memory>
17 #include <numeric>
18 #include <unordered_map>
19 #include <sstream>
20 #include <vector>
21
22 #include <tuple>
23 // From https://stackoverflow.com/questions/7110301/generic-hash-for-tuples-in-unordered-map-unordered-set
24 // This is all just to make std::unordered_map work with std::tuple. If we need this in other places,
25 // this could go into a header file.
26 namespace hash_tuple{
27     template <typename TT>
28     struct hash
29     {
30         size_t
31         operator()(TT const& tt) const
32         {
33             return std::hash<TT>()(tt);
34         }
35     };
36
37     template <class T>
38     inline void hash_combine(std::size_t& seed, T const& v)
39     {
40         seed ^= hash_tuple::hash<T>()(v) + 0x9e3779b9 + (seed<<6) + (seed>>2);
41     }
42
43     // Recursive template code derived from Matthieu M.
44     template <class Tuple, size_t Index = std::tuple_size<Tuple>::value - 1>
45     struct HashValueImpl
46     {
47       static void apply(size_t& seed, Tuple const& tuple)
48       {
49         HashValueImpl<Tuple, Index-1>::apply(seed, tuple);
50         hash_combine(seed, std::get<Index>(tuple));
51       }
52     };
53
54     template <class Tuple>
55     struct HashValueImpl<Tuple,0>
56     {
57       static void apply(size_t& seed, Tuple const& tuple)
58       {
59         hash_combine(seed, std::get<0>(tuple));
60       }
61     };
62
63     template <typename ... TT>
64     struct hash<std::tuple<TT...>>
65     {
66         size_t
67         operator()(std::tuple<TT...> const& tt) const
68         {
69             size_t seed = 0;
70             HashValueImpl<std::tuple<TT...> >::apply(seed, tt);
71             return seed;
72         }
73     };
74 }
75
76 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_replay,smpi,"Trace Replay with SMPI");
77
78 typedef std::tuple</*sender*/ int, /* reciever */ int, /* tag */int> req_key_t;
79 typedef std::unordered_map<req_key_t, MPI_Request, hash_tuple::hash<std::tuple<int,int,int>>> req_storage_t;
80
81 static MPI_Datatype MPI_DEFAULT_TYPE;
82
83 #define CHECK_ACTION_PARAMS(action, mandatory, optional)                                                               \
84   {                                                                                                                    \
85     if (action.size() < static_cast<unsigned long>(mandatory + 2)) {                                                   \
86       std::stringstream ss;                                                                                            \
87       for (const auto& elem : action) {                                                                                \
88         ss << elem << " ";                                                                                             \
89       }                                                                                                                \
90       THROWF(arg_error, 0, "%s replay failed.\n"                                                                       \
91                            "%zu items were given on the line. First two should be process_id and action.  "            \
92                            "This action needs after them %lu mandatory arguments, and accepts %lu optional ones. \n"   \
93                            "The full line that was given is:\n   %s\n"                                                 \
94                            "Please contact the Simgrid team if support is needed",                                     \
95              __func__, action.size(), static_cast<unsigned long>(mandatory), static_cast<unsigned long>(optional),     \
96              ss.str().c_str());                                                                                        \
97     }                                                                                                                  \
98   }
99
100 static void log_timed_action(simgrid::xbt::ReplayAction& action, double clock)
101 {
102   if (XBT_LOG_ISENABLED(smpi_replay, xbt_log_priority_verbose)){
103     std::string s = boost::algorithm::join(action, " ");
104     XBT_VERB("%s %f", s.c_str(), smpi_process()->simulated_elapsed() - clock);
105   }
106 }
107
108 /* Helper function */
109 static double parse_double(std::string string)
110 {
111   return xbt_str_parse_double(string.c_str(), "%s is not a double");
112 }
113
114 namespace simgrid {
115 namespace smpi {
116
117 namespace replay {
118
119 class RequestStorage {
120 private:
121     req_storage_t store;
122
123 public:
124     RequestStorage() {}
125     int size()
126     {
127       return store.size();
128     }
129
130     req_storage_t& get_store()
131     {
132       return store;
133     }
134
135     void get_requests(std::vector<MPI_Request>& vec)
136     {
137       for (auto& pair : store) {
138         auto& req = pair.second;
139         auto my_proc_id = simgrid::s4u::this_actor::get_pid();
140         if (req != MPI_REQUEST_NULL && (req->src() == my_proc_id || req->dst() == my_proc_id)) {
141           vec.push_back(pair.second);
142           pair.second->print_request("MM");
143         }
144       }
145     }
146
147     MPI_Request find(int src, int dst, int tag)
148     {
149       req_storage_t::iterator it = store.find(req_key_t(src, dst, tag));
150       return (it == store.end()) ? MPI_REQUEST_NULL : it->second;
151     }
152
153     void remove(MPI_Request req)
154     {
155       if (req == MPI_REQUEST_NULL) return;
156
157       store.erase(req_key_t(req->src()-1, req->dst()-1, req->tag()));
158     }
159
160     void add(MPI_Request req)
161     {
162       if (req != MPI_REQUEST_NULL) // Can and does happen in the case of TestAction
163         store.insert({req_key_t(req->src()-1, req->dst()-1, req->tag()), req});
164     }
165
166     /* Sometimes we need to re-insert MPI_REQUEST_NULL but we still need src,dst and tag */
167     void addNullRequest(int src, int dst, int tag)
168     {
169       store.insert({req_key_t(src, dst, tag), MPI_REQUEST_NULL});
170     }
171 };
172
173 class ActionArgParser {
174 public:
175   virtual ~ActionArgParser() = default;
176   virtual void parse(simgrid::xbt::ReplayAction& action, std::string name) { CHECK_ACTION_PARAMS(action, 0, 0) }
177 };
178
179 class WaitTestParser : public ActionArgParser {
180 public:
181   int src;
182   int dst;
183   int tag;
184
185   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
186   {
187     CHECK_ACTION_PARAMS(action, 3, 0)
188     src = std::stoi(action[2]);
189     dst = std::stoi(action[3]);
190     tag = std::stoi(action[4]);
191   }
192 };
193
194 class SendRecvParser : public ActionArgParser {
195 public:
196   /* communication partner; if we send, this is the receiver and vice versa */
197   int partner;
198   double size;
199   int tag;
200   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
201
202   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
203   {
204     CHECK_ACTION_PARAMS(action, 3, 1)
205     partner = std::stoi(action[2]);
206     tag     = std::stoi(action[3]);
207     size    = parse_double(action[4]);
208     if (action.size() > 5)
209       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
210   }
211 };
212
213 class ComputeParser : public ActionArgParser {
214 public:
215   /* communication partner; if we send, this is the receiver and vice versa */
216   double flops;
217
218   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
219   {
220     CHECK_ACTION_PARAMS(action, 1, 0)
221     flops = parse_double(action[2]);
222   }
223 };
224
225 class CollCommParser : public ActionArgParser {
226 public:
227   double size;
228   double comm_size;
229   double comp_size;
230   int send_size;
231   int recv_size;
232   int root = 0;
233   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
234   MPI_Datatype datatype2 = MPI_DEFAULT_TYPE;
235 };
236
237 class BcastArgParser : public CollCommParser {
238 public:
239   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
240   {
241     CHECK_ACTION_PARAMS(action, 1, 2)
242     size = parse_double(action[2]);
243     root = (action.size() > 3) ? std::stoi(action[3]) : 0;
244     if (action.size() > 4)
245       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
246   }
247 };
248
249 class ReduceArgParser : public CollCommParser {
250 public:
251   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
252   {
253     CHECK_ACTION_PARAMS(action, 2, 2)
254     comm_size = parse_double(action[2]);
255     comp_size = parse_double(action[3]);
256     root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
257     if (action.size() > 5)
258       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
259   }
260 };
261
262 class AllReduceArgParser : public CollCommParser {
263 public:
264   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
265   {
266     CHECK_ACTION_PARAMS(action, 2, 1)
267     comm_size = parse_double(action[2]);
268     comp_size = parse_double(action[3]);
269     if (action.size() > 4)
270       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
271   }
272 };
273
274 class AllToAllArgParser : public CollCommParser {
275 public:
276   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
277   {
278     CHECK_ACTION_PARAMS(action, 2, 1)
279     comm_size = MPI_COMM_WORLD->size();
280     send_size = parse_double(action[2]);
281     recv_size = parse_double(action[3]);
282
283     if (action.size() > 4)
284       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
285     if (action.size() > 5)
286       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
287   }
288 };
289
290 class GatherArgParser : public CollCommParser {
291 public:
292   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
293   {
294     /* The structure of the gather action for the rank 0 (total 4 processes) is the following:
295           0 gather 68 68 0 0 0
296         where:
297           1) 68 is the sendcounts
298           2) 68 is the recvcounts
299           3) 0 is the root node
300           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
301           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
302     */
303     CHECK_ACTION_PARAMS(action, 2, 3)
304     comm_size = MPI_COMM_WORLD->size();
305     send_size = parse_double(action[2]);
306     recv_size = parse_double(action[3]);
307
308     if (name == "gather") {
309       root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
310       if (action.size() > 5)
311         datatype1 = simgrid::smpi::Datatype::decode(action[5]);
312       if (action.size() > 6)
313         datatype2 = simgrid::smpi::Datatype::decode(action[6]);
314     }
315     else {
316       if (action.size() > 4)
317         datatype1 = simgrid::smpi::Datatype::decode(action[4]);
318       if (action.size() > 5)
319         datatype2 = simgrid::smpi::Datatype::decode(action[5]);
320     }
321   }
322 };
323
324 class GatherVArgParser : public CollCommParser {
325 public:
326   int recv_size_sum;
327   std::shared_ptr<std::vector<int>> recvcounts;
328   std::vector<int> disps;
329   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
330   {
331     /* The structure of the gatherv action for the rank 0 (total 4 processes) is the following:
332          0 gather 68 68 10 10 10 0 0 0
333        where:
334          1) 68 is the sendcount
335          2) 68 10 10 10 is the recvcounts
336          3) 0 is the root node
337          4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
338          5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
339     */
340     comm_size = MPI_COMM_WORLD->size();
341     CHECK_ACTION_PARAMS(action, comm_size+1, 2)
342     send_size = parse_double(action[2]);
343     disps     = std::vector<int>(comm_size, 0);
344     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
345
346     if (name == "gatherV") {
347       root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
348       if (action.size() > 4 + comm_size)
349         datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
350       if (action.size() > 5 + comm_size)
351         datatype2 = simgrid::smpi::Datatype::decode(action[5 + comm_size]);
352     }
353     else {
354       int datatype_index = 0;
355       int disp_index     = 0;
356       /* The 3 comes from "0 gather <sendcount>", which must always be present.
357        * The + comm_size is the recvcounts array, which must also be present
358        */
359       if (action.size() > 3 + comm_size + comm_size) { /* datatype + disp are specified */
360         datatype_index = 3 + comm_size;
361         disp_index     = datatype_index + 1;
362         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
363         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
364       } else if (action.size() > 3 + comm_size + 2) { /* disps specified; datatype is not specified; use the default one */
365         disp_index     = 3 + comm_size;
366       } else if (action.size() > 3 + comm_size)  { /* only datatype, no disp specified */
367         datatype_index = 3 + comm_size;
368         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
369         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
370       }
371
372       if (disp_index != 0) {
373         for (unsigned int i = 0; i < comm_size; i++)
374           disps[i]          = std::stoi(action[disp_index + i]);
375       }
376     }
377
378     for (unsigned int i = 0; i < comm_size; i++) {
379       (*recvcounts)[i] = std::stoi(action[i + 3]);
380     }
381     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
382   }
383 };
384
385 class ScatterArgParser : public CollCommParser {
386 public:
387   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
388   {
389     /* The structure of the scatter action for the rank 0 (total 4 processes) is the following:
390           0 gather 68 68 0 0 0
391         where:
392           1) 68 is the sendcounts
393           2) 68 is the recvcounts
394           3) 0 is the root node
395           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
396           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
397     */
398     CHECK_ACTION_PARAMS(action, 2, 3)
399     comm_size   = MPI_COMM_WORLD->size();
400     send_size   = parse_double(action[2]);
401     recv_size   = parse_double(action[3]);
402     root   = (action.size() > 4) ? std::stoi(action[4]) : 0;
403     if (action.size() > 5)
404       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
405     if (action.size() > 6)
406       datatype2 = simgrid::smpi::Datatype::decode(action[6]);
407   }
408 };
409
410 class ScatterVArgParser : public CollCommParser {
411 public:
412   int recv_size_sum;
413   int send_size_sum;
414   std::shared_ptr<std::vector<int>> sendcounts;
415   std::vector<int> disps;
416   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
417   {
418     /* The structure of the scatterv action for the rank 0 (total 4 processes) is the following:
419        0 gather 68 10 10 10 68 0 0 0
420         where:
421         1) 68 10 10 10 is the sendcounts
422         2) 68 is the recvcount
423         3) 0 is the root node
424         4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
425         5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
426     */
427     CHECK_ACTION_PARAMS(action, comm_size + 1, 2)
428     recv_size  = parse_double(action[2 + comm_size]);
429     disps      = std::vector<int>(comm_size, 0);
430     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
431
432     if (action.size() > 5 + comm_size)
433       datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
434     if (action.size() > 5 + comm_size)
435       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
436
437     for (unsigned int i = 0; i < comm_size; i++) {
438       (*sendcounts)[i] = std::stoi(action[i + 2]);
439     }
440     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
441     root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
442   }
443 };
444
445 class ReduceScatterArgParser : public CollCommParser {
446 public:
447   int recv_size_sum;
448   std::shared_ptr<std::vector<int>> recvcounts;
449   std::vector<int> disps;
450   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
451   {
452     /* The structure of the reducescatter action for the rank 0 (total 4 processes) is the following:
453          0 reduceScatter 275427 275427 275427 204020 11346849 0
454        where:
455          1) The first four values after the name of the action declare the recvcounts array
456          2) The value 11346849 is the amount of instructions
457          3) The last value corresponds to the datatype, see simgrid::smpi::Datatype::decode().
458     */
459     comm_size = MPI_COMM_WORLD->size();
460     CHECK_ACTION_PARAMS(action, comm_size+1, 1)
461     comp_size = parse_double(action[2+comm_size]);
462     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
463     if (action.size() > 3 + comm_size)
464       datatype1 = simgrid::smpi::Datatype::decode(action[3 + comm_size]);
465
466     for (unsigned int i = 0; i < comm_size; i++) {
467       recvcounts->push_back(std::stoi(action[i + 2]));
468     }
469     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
470   }
471 };
472
473 class AllToAllVArgParser : public CollCommParser {
474 public:
475   int recv_size_sum;
476   int send_size_sum;
477   std::shared_ptr<std::vector<int>> recvcounts;
478   std::shared_ptr<std::vector<int>> sendcounts;
479   std::vector<int> senddisps;
480   std::vector<int> recvdisps;
481   int send_buf_size;
482   int recv_buf_size;
483   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
484   {
485     /* The structure of the allToAllV action for the rank 0 (total 4 processes) is the following:
486           0 allToAllV 100 1 7 10 12 100 1 70 10 5
487        where:
488         1) 100 is the size of the send buffer *sizeof(int),
489         2) 1 7 10 12 is the sendcounts array
490         3) 100*sizeof(int) is the size of the receiver buffer
491         4)  1 70 10 5 is the recvcounts array
492     */
493     comm_size = MPI_COMM_WORLD->size();
494     CHECK_ACTION_PARAMS(action, 2*comm_size+2, 2)
495     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
496     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
497     senddisps  = std::vector<int>(comm_size, 0);
498     recvdisps  = std::vector<int>(comm_size, 0);
499
500     if (action.size() > 5 + 2 * comm_size)
501       datatype1 = simgrid::smpi::Datatype::decode(action[4 + 2 * comm_size]);
502     if (action.size() > 5 + 2 * comm_size)
503       datatype2 = simgrid::smpi::Datatype::decode(action[5 + 2 * comm_size]);
504
505     send_buf_size=parse_double(action[2]);
506     recv_buf_size=parse_double(action[3+comm_size]);
507     for (unsigned int i = 0; i < comm_size; i++) {
508       (*sendcounts)[i] = std::stoi(action[3 + i]);
509       (*recvcounts)[i] = std::stoi(action[4 + comm_size + i]);
510     }
511     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
512     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
513   }
514 };
515
516 template <class T> class ReplayAction {
517 protected:
518   const std::string name;
519   RequestStorage* req_storage; // Points to the right storage for this process, nullptr except for Send/Recv/Wait/Test actions.
520   const int my_proc_id;
521   T args;
522
523 public:
524   explicit ReplayAction(std::string name, RequestStorage& storage) : name(name), req_storage(&storage), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
525   explicit ReplayAction(std::string name) : name(name), req_storage(nullptr), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
526   virtual ~ReplayAction() = default;
527
528   virtual void execute(simgrid::xbt::ReplayAction& action)
529   {
530     // Needs to be re-initialized for every action, hence here
531     double start_time = smpi_process()->simulated_elapsed();
532     args.parse(action, name);
533     kernel(action);
534     if (name != "Init")
535       log_timed_action(action, start_time);
536   }
537
538   virtual void kernel(simgrid::xbt::ReplayAction& action) = 0;
539
540   void* send_buffer(int size)
541   {
542     return smpi_get_tmp_sendbuffer(size);
543   }
544
545   void* recv_buffer(int size)
546   {
547     return smpi_get_tmp_recvbuffer(size);
548   }
549 };
550
551 class WaitAction : public ReplayAction<WaitTestParser> {
552 public:
553   explicit WaitAction(RequestStorage& storage) : ReplayAction("Wait", storage) {}
554   void kernel(simgrid::xbt::ReplayAction& action) override
555   {
556     std::string s = boost::algorithm::join(action, " ");
557     xbt_assert(req_storage->size(), "action wait not preceded by any irecv or isend: %s", s.c_str());
558     MPI_Request request = req_storage->find(args.src, args.dst, args.tag);
559     req_storage->remove(request);
560
561     if (request == MPI_REQUEST_NULL) {
562       /* Assume that the trace is well formed, meaning the comm might have been caught by a MPI_test. Then just
563        * return.*/
564       return;
565     }
566
567     int rank = request->comm() != MPI_COMM_NULL ? request->comm()->rank() : -1;
568
569     // Must be taken before Request::wait() since the request may be set to
570     // MPI_REQUEST_NULL by Request::wait!
571     bool is_wait_for_receive = (request->flags() & RECV);
572     // TODO: Here we take the rank while we normally take the process id (look for my_proc_id)
573     TRACE_smpi_comm_in(rank, __func__, new simgrid::instr::NoOpTIData("wait"));
574
575     MPI_Status status;
576     Request::wait(&request, &status);
577
578     TRACE_smpi_comm_out(rank);
579     if (is_wait_for_receive)
580       TRACE_smpi_recv(args.src, args.dst, args.tag);
581   }
582 };
583
584 class SendAction : public ReplayAction<SendRecvParser> {
585 public:
586   SendAction() = delete;
587   explicit SendAction(std::string name, RequestStorage& storage) : ReplayAction(name, storage) {}
588   void kernel(simgrid::xbt::ReplayAction& action) override
589   {
590     int dst_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
591
592     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
593                                                                              args.tag, Datatype::encode(args.datatype1)));
594     if (not TRACE_smpi_view_internals())
595       TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, args.tag, args.size * args.datatype1->size());
596
597     if (name == "send") {
598       Request::send(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
599     } else if (name == "Isend") {
600       MPI_Request request = Request::isend(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
601       req_storage->add(request);
602     } else {
603       xbt_die("Don't know this action, %s", name.c_str());
604     }
605
606     TRACE_smpi_comm_out(my_proc_id);
607   }
608 };
609
610 class RecvAction : public ReplayAction<SendRecvParser> {
611 public:
612   RecvAction() = delete;
613   explicit RecvAction(std::string name, RequestStorage& storage) : ReplayAction(name, storage) {}
614   void kernel(simgrid::xbt::ReplayAction& action) override
615   {
616     int src_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
617
618     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
619                                                                              args.tag, Datatype::encode(args.datatype1)));
620
621     MPI_Status status;
622     // unknown size from the receiver point of view
623     if (args.size <= 0.0) {
624       Request::probe(args.partner, args.tag, MPI_COMM_WORLD, &status);
625       args.size = status.count;
626     }
627
628     if (name == "recv") {
629       Request::recv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD, &status);
630     } else if (name == "Irecv") {
631       MPI_Request request = Request::irecv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
632       req_storage->add(request);
633     }
634
635     TRACE_smpi_comm_out(my_proc_id);
636     // TODO: Check why this was only activated in the "recv" case and not in the "Irecv" case
637     if (name == "recv" && not TRACE_smpi_view_internals()) {
638       TRACE_smpi_recv(src_traced, my_proc_id, args.tag);
639     }
640   }
641 };
642
643 class ComputeAction : public ReplayAction<ComputeParser> {
644 public:
645   ComputeAction() : ReplayAction("compute") {}
646   void kernel(simgrid::xbt::ReplayAction& action) override
647   {
648     TRACE_smpi_computing_in(my_proc_id, args.flops);
649     smpi_execute_flops(args.flops);
650     TRACE_smpi_computing_out(my_proc_id);
651   }
652 };
653
654 class TestAction : public ReplayAction<WaitTestParser> {
655 public:
656   explicit TestAction(RequestStorage& storage) : ReplayAction("Test", storage) {}
657   void kernel(simgrid::xbt::ReplayAction& action) override
658   {
659     MPI_Request request = req_storage->find(args.src, args.dst, args.tag);
660     req_storage->remove(request);
661     // if request is null here, this may mean that a previous test has succeeded
662     // Different times in traced application and replayed version may lead to this
663     // In this case, ignore the extra calls.
664     if (request != MPI_REQUEST_NULL) {
665       TRACE_smpi_testing_in(my_proc_id);
666
667       MPI_Status status;
668       int flag = Request::test(&request, &status);
669
670       XBT_DEBUG("MPI_Test result: %d", flag);
671       /* push back request in vector to be caught by a subsequent wait. if the test did succeed, the request is now
672        * nullptr.*/
673       if (request == MPI_REQUEST_NULL)
674         req_storage->addNullRequest(args.src, args.dst, args.tag);
675       else
676         req_storage->add(request);
677
678       TRACE_smpi_testing_out(my_proc_id);
679     }
680   }
681 };
682
683 class InitAction : public ReplayAction<ActionArgParser> {
684 public:
685   InitAction() : ReplayAction("Init") {}
686   void kernel(simgrid::xbt::ReplayAction& action) override
687   {
688     CHECK_ACTION_PARAMS(action, 0, 1)
689     MPI_DEFAULT_TYPE = (action.size() > 2) ? MPI_DOUBLE // default MPE datatype
690                                            : MPI_BYTE;  // default TAU datatype
691
692     /* start a simulated timer */
693     smpi_process()->simulated_start();
694   }
695 };
696
697 class CommunicatorAction : public ReplayAction<ActionArgParser> {
698 public:
699   CommunicatorAction() : ReplayAction("Comm") {}
700   void kernel(simgrid::xbt::ReplayAction& action) override { /* nothing to do */}
701 };
702
703 class WaitAllAction : public ReplayAction<ActionArgParser> {
704 public:
705   explicit WaitAllAction(RequestStorage& storage) : ReplayAction("waitAll", storage) {}
706   void kernel(simgrid::xbt::ReplayAction& action) override
707   {
708     const unsigned int count_requests = req_storage->size();
709
710     if (count_requests > 0) {
711       TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData("waitAll", -1, count_requests, ""));
712       std::vector<std::pair</*sender*/int,/*recv*/int>> sender_receiver;
713       std::vector<MPI_Request> reqs;
714       req_storage->get_requests(reqs);
715       for (const auto& req : reqs) {
716         if (req && (req->flags() & RECV)) {
717           sender_receiver.push_back({req->src(), req->dst()});
718         }
719       }
720       MPI_Status status[count_requests];
721       Request::waitall(count_requests, &(reqs.data())[0], status);
722       req_storage->get_store().clear();
723
724       for (auto& pair : sender_receiver) {
725         TRACE_smpi_recv(pair.first, pair.second, 0);
726       }
727       TRACE_smpi_comm_out(my_proc_id);
728     }
729   }
730 };
731
732 class BarrierAction : public ReplayAction<ActionArgParser> {
733 public:
734   BarrierAction() : ReplayAction("barrier") {}
735   void kernel(simgrid::xbt::ReplayAction& action) override
736   {
737     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("barrier"));
738     Colls::barrier(MPI_COMM_WORLD);
739     TRACE_smpi_comm_out(my_proc_id);
740   }
741 };
742
743 class BcastAction : public ReplayAction<BcastArgParser> {
744 public:
745   BcastAction() : ReplayAction("bcast") {}
746   void kernel(simgrid::xbt::ReplayAction& action) override
747   {
748     TRACE_smpi_comm_in(my_proc_id, "action_bcast",
749                        new simgrid::instr::CollTIData("bcast", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
750                                                       -1.0, args.size, -1, Datatype::encode(args.datatype1), ""));
751
752     Colls::bcast(send_buffer(args.size * args.datatype1->size()), args.size, args.datatype1, args.root, MPI_COMM_WORLD);
753
754     TRACE_smpi_comm_out(my_proc_id);
755   }
756 };
757
758 class ReduceAction : public ReplayAction<ReduceArgParser> {
759 public:
760   ReduceAction() : ReplayAction("reduce") {}
761   void kernel(simgrid::xbt::ReplayAction& action) override
762   {
763     TRACE_smpi_comm_in(my_proc_id, "action_reduce",
764                        new simgrid::instr::CollTIData("reduce", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
765                                                       args.comp_size, args.comm_size, -1,
766                                                       Datatype::encode(args.datatype1), ""));
767
768     Colls::reduce(send_buffer(args.comm_size * args.datatype1->size()),
769         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, args.root, MPI_COMM_WORLD);
770     smpi_execute_flops(args.comp_size);
771
772     TRACE_smpi_comm_out(my_proc_id);
773   }
774 };
775
776 class AllReduceAction : public ReplayAction<AllReduceArgParser> {
777 public:
778   AllReduceAction() : ReplayAction("allReduce") {}
779   void kernel(simgrid::xbt::ReplayAction& action) override
780   {
781     TRACE_smpi_comm_in(my_proc_id, "action_allReduce", new simgrid::instr::CollTIData("allReduce", -1, args.comp_size, args.comm_size, -1,
782                                                                                 Datatype::encode(args.datatype1), ""));
783
784     Colls::allreduce(send_buffer(args.comm_size * args.datatype1->size()),
785         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
786     smpi_execute_flops(args.comp_size);
787
788     TRACE_smpi_comm_out(my_proc_id);
789   }
790 };
791
792 class AllToAllAction : public ReplayAction<AllToAllArgParser> {
793 public:
794   AllToAllAction() : ReplayAction("allToAll") {}
795   void kernel(simgrid::xbt::ReplayAction& action) override
796   {
797     TRACE_smpi_comm_in(my_proc_id, "action_allToAll",
798                      new simgrid::instr::CollTIData("allToAll", -1, -1.0, args.send_size, args.recv_size,
799                                                     Datatype::encode(args.datatype1),
800                                                     Datatype::encode(args.datatype2)));
801
802     Colls::alltoall(send_buffer(args.send_size * args.comm_size * args.datatype1->size()), args.send_size,
803                     args.datatype1, recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()),
804                     args.recv_size, args.datatype2, MPI_COMM_WORLD);
805
806     TRACE_smpi_comm_out(my_proc_id);
807   }
808 };
809
810 class GatherAction : public ReplayAction<GatherArgParser> {
811 public:
812   explicit GatherAction(std::string name) : ReplayAction(name) {}
813   void kernel(simgrid::xbt::ReplayAction& action) override
814   {
815     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::CollTIData(name, (name == "gather") ? args.root : -1, -1.0, args.send_size, args.recv_size,
816                                                                           Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
817
818     if (name == "gather") {
819       int rank = MPI_COMM_WORLD->rank();
820       Colls::gather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
821                  (rank == args.root) ? recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
822     }
823     else
824       Colls::allgather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
825                        recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, MPI_COMM_WORLD);
826
827     TRACE_smpi_comm_out(my_proc_id);
828   }
829 };
830
831 class GatherVAction : public ReplayAction<GatherVArgParser> {
832 public:
833   explicit GatherVAction(std::string name) : ReplayAction(name) {}
834   void kernel(simgrid::xbt::ReplayAction& action) override
835   {
836     int rank = MPI_COMM_WORLD->rank();
837
838     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::VarCollTIData(
839                                                name, (name == "gatherV") ? args.root : -1, args.send_size, nullptr, -1, args.recvcounts,
840                                                Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
841
842     if (name == "gatherV") {
843       Colls::gatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
844                      (rank == args.root) ? recv_buffer(args.recv_size_sum * args.datatype2->size()) : nullptr,
845                      args.recvcounts->data(), args.disps.data(), args.datatype2, args.root, MPI_COMM_WORLD);
846     }
847     else {
848       Colls::allgatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
849                         recv_buffer(args.recv_size_sum * args.datatype2->size()), args.recvcounts->data(),
850                         args.disps.data(), args.datatype2, MPI_COMM_WORLD);
851     }
852
853     TRACE_smpi_comm_out(my_proc_id);
854   }
855 };
856
857 class ScatterAction : public ReplayAction<ScatterArgParser> {
858 public:
859   ScatterAction() : ReplayAction("scatter") {}
860   void kernel(simgrid::xbt::ReplayAction& action) override
861   {
862     int rank = MPI_COMM_WORLD->rank();
863     TRACE_smpi_comm_in(my_proc_id, "action_scatter", new simgrid::instr::CollTIData(name, args.root, -1.0, args.send_size, args.recv_size,
864                                                                           Datatype::encode(args.datatype1),
865                                                                           Datatype::encode(args.datatype2)));
866
867     Colls::scatter(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
868                   (rank == args.root) ? recv_buffer(args.recv_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
869
870     TRACE_smpi_comm_out(my_proc_id);
871   }
872 };
873
874
875 class ScatterVAction : public ReplayAction<ScatterVArgParser> {
876 public:
877   ScatterVAction() : ReplayAction("scatterV") {}
878   void kernel(simgrid::xbt::ReplayAction& action) override
879   {
880     int rank = MPI_COMM_WORLD->rank();
881     TRACE_smpi_comm_in(my_proc_id, "action_scatterv", new simgrid::instr::VarCollTIData(name, args.root, -1, args.sendcounts, args.recv_size,
882           nullptr, Datatype::encode(args.datatype1),
883           Datatype::encode(args.datatype2)));
884
885     Colls::scatterv((rank == args.root) ? send_buffer(args.send_size_sum * args.datatype1->size()) : nullptr,
886                     args.sendcounts->data(), args.disps.data(), args.datatype1,
887                     recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, args.root,
888                     MPI_COMM_WORLD);
889
890     TRACE_smpi_comm_out(my_proc_id);
891   }
892 };
893
894 class ReduceScatterAction : public ReplayAction<ReduceScatterArgParser> {
895 public:
896   ReduceScatterAction() : ReplayAction("reduceScatter") {}
897   void kernel(simgrid::xbt::ReplayAction& action) override
898   {
899     TRACE_smpi_comm_in(my_proc_id, "action_reducescatter",
900                        new simgrid::instr::VarCollTIData("reduceScatter", -1, 0, nullptr, -1, args.recvcounts,
901                                                          std::to_string(args.comp_size), /* ugly hack to print comp_size */
902                                                          Datatype::encode(args.datatype1)));
903
904     Colls::reduce_scatter(send_buffer(args.recv_size_sum * args.datatype1->size()),
905                           recv_buffer(args.recv_size_sum * args.datatype1->size()), args.recvcounts->data(),
906                           args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
907
908     smpi_execute_flops(args.comp_size);
909     TRACE_smpi_comm_out(my_proc_id);
910   }
911 };
912
913 class AllToAllVAction : public ReplayAction<AllToAllVArgParser> {
914 public:
915   AllToAllVAction() : ReplayAction("allToAllV") {}
916   void kernel(simgrid::xbt::ReplayAction& action) override
917   {
918     TRACE_smpi_comm_in(my_proc_id, __func__,
919                        new simgrid::instr::VarCollTIData(
920                            "allToAllV", -1, args.send_size_sum, args.sendcounts, args.recv_size_sum, args.recvcounts,
921                            Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
922
923     Colls::alltoallv(send_buffer(args.send_buf_size * args.datatype1->size()), args.sendcounts->data(), args.senddisps.data(), args.datatype1,
924                      recv_buffer(args.recv_buf_size * args.datatype2->size()), args.recvcounts->data(), args.recvdisps.data(), args.datatype2, MPI_COMM_WORLD);
925
926     TRACE_smpi_comm_out(my_proc_id);
927   }
928 };
929 } // Replay Namespace
930 }} // namespace simgrid::smpi
931
932 std::vector<simgrid::smpi::replay::RequestStorage> storage;
933 /** @brief Only initialize the replay, don't do it for real */
934 void smpi_replay_init(int* argc, char*** argv)
935 {
936   simgrid::smpi::Process::init(argc, argv);
937   smpi_process()->mark_as_initialized();
938   smpi_process()->set_replaying(true);
939
940   int my_proc_id = simgrid::s4u::this_actor::get_pid();
941   storage.resize(smpi_process_count());
942
943   TRACE_smpi_init(my_proc_id);
944   TRACE_smpi_computing_init(my_proc_id);
945   TRACE_smpi_comm_in(my_proc_id, "smpi_replay_run_init", new simgrid::instr::NoOpTIData("init"));
946   TRACE_smpi_comm_out(my_proc_id);
947   xbt_replay_action_register("init", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::InitAction().execute(action); });
948   xbt_replay_action_register("finalize", [](simgrid::xbt::ReplayAction& action) { /* nothing to do */ });
949   xbt_replay_action_register("comm_size", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
950   xbt_replay_action_register("comm_split",[](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
951   xbt_replay_action_register("comm_dup",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
952
953   xbt_replay_action_register("send",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("send", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
954   xbt_replay_action_register("Isend", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("Isend", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
955   xbt_replay_action_register("recv",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("recv", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
956   xbt_replay_action_register("Irecv", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("Irecv", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
957   xbt_replay_action_register("test",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::TestAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
958   xbt_replay_action_register("wait",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
959   xbt_replay_action_register("waitAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAllAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
960   xbt_replay_action_register("barrier", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BarrierAction().execute(action); });
961   xbt_replay_action_register("bcast",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BcastAction().execute(action); });
962   xbt_replay_action_register("reduce",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceAction().execute(action); });
963   xbt_replay_action_register("allReduce", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllReduceAction().execute(action); });
964   xbt_replay_action_register("allToAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllAction().execute(action); });
965   xbt_replay_action_register("allToAllV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllVAction().execute(action); });
966   xbt_replay_action_register("gather",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("gather").execute(action); });
967   xbt_replay_action_register("scatter",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterAction().execute(action); });
968   xbt_replay_action_register("gatherV",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("gatherV").execute(action); });
969   xbt_replay_action_register("scatterV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterVAction().execute(action); });
970   xbt_replay_action_register("allGather", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("allGather").execute(action); });
971   xbt_replay_action_register("allGatherV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("allGatherV").execute(action); });
972   xbt_replay_action_register("reduceScatter", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceScatterAction().execute(action); });
973   xbt_replay_action_register("compute", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ComputeAction().execute(action); });
974
975   //if we have a delayed start, sleep here.
976   if(*argc>2){
977     double value = xbt_str_parse_double((*argv)[2], "%s is not a double");
978     XBT_VERB("Delayed start for instance - Sleeping for %f flops ",value );
979     smpi_execute_flops(value);
980   } else {
981     //UGLY: force a context switch to be sure that all MSG_processes begin initialization
982     XBT_DEBUG("Force context switch by smpi_execute_flops  - Sleeping for 0.0 flops ");
983     smpi_execute_flops(0.0);
984   }
985 }
986
987 /** @brief actually run the replay after initialization */
988 void smpi_replay_main(int* argc, char*** argv)
989 {
990   static int active_processes = 0;
991   active_processes++;
992   simgrid::xbt::replay_runner(*argc, *argv);
993
994   /* and now, finalize everything */
995   /* One active process will stop. Decrease the counter*/
996   unsigned int count_requests = storage[simgrid::s4u::this_actor::get_pid() - 1].size();
997   XBT_DEBUG("There are %ud elements in reqq[*]", count_requests);
998   if (count_requests > 0) {
999     MPI_Request requests[count_requests];
1000     MPI_Status status[count_requests];
1001     unsigned int i=0;
1002
1003     for (auto const& pair : storage[simgrid::s4u::this_actor::get_pid() - 1].get_store()) {
1004       requests[i] = pair.second;
1005       i++;
1006     }
1007     simgrid::smpi::Request::waitall(count_requests, requests, status);
1008   }
1009   active_processes--;
1010
1011   if(active_processes==0){
1012     /* Last process alive speaking: end the simulated timer */
1013     XBT_INFO("Simulation time %f", smpi_process()->simulated_elapsed());
1014     smpi_free_replay_tmp_buffers();
1015   }
1016
1017   TRACE_smpi_comm_in(simgrid::s4u::this_actor::get_pid(), "smpi_replay_run_finalize",
1018                      new simgrid::instr::NoOpTIData("finalize"));
1019
1020   smpi_process()->finalize();
1021
1022   TRACE_smpi_comm_out(simgrid::s4u::this_actor::get_pid());
1023   TRACE_smpi_finalize(simgrid::s4u::this_actor::get_pid());
1024 }
1025
1026 /** @brief chain a replay initialization and a replay start */
1027 void smpi_replay_run(int* argc, char*** argv)
1028 {
1029   smpi_replay_init(argc, argv);
1030   smpi_replay_main(argc, argv);
1031 }