Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Change xbt_cfg_get_bool -> simgrid::config::get_config<bool>.
[simgrid.git] / src / smpi / internals / smpi_replay.cpp
1 /* Copyright (c) 2009-2018. The SimGrid Team. All rights reserved.          */
2
3 /* This program is free software; you can redistribute it and/or modify it
4  * under the terms of the license (GNU LGPL) which comes with this package. */
5
6 #include "private.hpp"
7 #include "smpi_coll.hpp"
8 #include "smpi_comm.hpp"
9 #include "smpi_datatype.hpp"
10 #include "smpi_group.hpp"
11 #include "smpi_process.hpp"
12 #include "smpi_request.hpp"
13 #include "xbt/replay.hpp"
14
15 #include <boost/algorithm/string/join.hpp>
16 #include <memory>
17 #include <numeric>
18 #include <unordered_map>
19 #include <sstream>
20 #include <vector>
21
22 #include <tuple>
23 // From https://stackoverflow.com/questions/7110301/generic-hash-for-tuples-in-unordered-map-unordered-set
24 // This is all just to make std::unordered_map work with std::tuple. If we need this in other places,
25 // this could go into a header file.
26 namespace hash_tuple {
27 template <typename TT> class hash {
28 public:
29   size_t operator()(TT const& tt) const { return std::hash<TT>()(tt); }
30 };
31
32 template <class T> inline void hash_combine(std::size_t& seed, T const& v)
33 {
34   seed ^= hash_tuple::hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
35 }
36
37 // Recursive template code derived from Matthieu M.
38 template <class Tuple, size_t Index = std::tuple_size<Tuple>::value - 1> class HashValueImpl {
39 public:
40   static void apply(size_t& seed, Tuple const& tuple)
41   {
42     HashValueImpl<Tuple, Index - 1>::apply(seed, tuple);
43     hash_combine(seed, std::get<Index>(tuple));
44   }
45 };
46
47 template <class Tuple> class HashValueImpl<Tuple, 0> {
48 public:
49   static void apply(size_t& seed, Tuple const& tuple) { hash_combine(seed, std::get<0>(tuple)); }
50 };
51
52 template <typename... TT> class hash<std::tuple<TT...>> {
53 public:
54   size_t operator()(std::tuple<TT...> const& tt) const
55   {
56     size_t seed = 0;
57     HashValueImpl<std::tuple<TT...>>::apply(seed, tt);
58     return seed;
59   }
60 };
61 }
62
63 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_replay,smpi,"Trace Replay with SMPI");
64
65 typedef std::tuple</*sender*/ int, /* reciever */ int, /* tag */int> req_key_t;
66 typedef std::unordered_map<req_key_t, MPI_Request, hash_tuple::hash<std::tuple<int,int,int>>> req_storage_t;
67
68 static MPI_Datatype MPI_DEFAULT_TYPE;
69
70 #define CHECK_ACTION_PARAMS(action, mandatory, optional)                                                               \
71   {                                                                                                                    \
72     if (action.size() < static_cast<unsigned long>(mandatory + 2)) {                                                   \
73       std::stringstream ss;                                                                                            \
74       for (const auto& elem : action) {                                                                                \
75         ss << elem << " ";                                                                                             \
76       }                                                                                                                \
77       THROWF(arg_error, 0, "%s replay failed.\n"                                                                       \
78                            "%zu items were given on the line. First two should be process_id and action.  "            \
79                            "This action needs after them %lu mandatory arguments, and accepts %lu optional ones. \n"   \
80                            "The full line that was given is:\n   %s\n"                                                 \
81                            "Please contact the Simgrid team if support is needed",                                     \
82              __func__, action.size(), static_cast<unsigned long>(mandatory), static_cast<unsigned long>(optional),     \
83              ss.str().c_str());                                                                                        \
84     }                                                                                                                  \
85   }
86
87 static void log_timed_action(simgrid::xbt::ReplayAction& action, double clock)
88 {
89   if (XBT_LOG_ISENABLED(smpi_replay, xbt_log_priority_verbose)){
90     std::string s = boost::algorithm::join(action, " ");
91     XBT_VERB("%s %f", s.c_str(), smpi_process()->simulated_elapsed() - clock);
92   }
93 }
94
95 /* Helper function */
96 static double parse_double(std::string string)
97 {
98   return xbt_str_parse_double(string.c_str(), "%s is not a double");
99 }
100
101 namespace simgrid {
102 namespace smpi {
103
104 namespace replay {
105
106 class RequestStorage {
107 private:
108     req_storage_t store;
109
110 public:
111     RequestStorage() {}
112     int size()
113     {
114       return store.size();
115     }
116
117     req_storage_t& get_store()
118     {
119       return store;
120     }
121
122     void get_requests(std::vector<MPI_Request>& vec)
123     {
124       for (auto& pair : store) {
125         auto& req = pair.second;
126         auto my_proc_id = simgrid::s4u::this_actor::get_pid();
127         if (req != MPI_REQUEST_NULL && (req->src() == my_proc_id || req->dst() == my_proc_id)) {
128           vec.push_back(pair.second);
129           pair.second->print_request("MM");
130         }
131       }
132     }
133
134     MPI_Request find(int src, int dst, int tag)
135     {
136       req_storage_t::iterator it = store.find(req_key_t(src, dst, tag));
137       return (it == store.end()) ? MPI_REQUEST_NULL : it->second;
138     }
139
140     void remove(MPI_Request req)
141     {
142       if (req == MPI_REQUEST_NULL) return;
143
144       store.erase(req_key_t(req->src()-1, req->dst()-1, req->tag()));
145     }
146
147     void add(MPI_Request req)
148     {
149       if (req != MPI_REQUEST_NULL) // Can and does happen in the case of TestAction
150         store.insert({req_key_t(req->src()-1, req->dst()-1, req->tag()), req});
151     }
152
153     /* Sometimes we need to re-insert MPI_REQUEST_NULL but we still need src,dst and tag */
154     void addNullRequest(int src, int dst, int tag)
155     {
156       store.insert({req_key_t(src, dst, tag), MPI_REQUEST_NULL});
157     }
158 };
159
160 /**
161  * Base class for all parsers.
162  */
163 class ActionArgParser {
164 public:
165   virtual ~ActionArgParser() = default;
166   virtual void parse(simgrid::xbt::ReplayAction& action, std::string name) { CHECK_ACTION_PARAMS(action, 0, 0) }
167 };
168
169 class WaitTestParser : public ActionArgParser {
170 public:
171   int src;
172   int dst;
173   int tag;
174
175   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
176   {
177     CHECK_ACTION_PARAMS(action, 3, 0)
178     src = std::stoi(action[2]);
179     dst = std::stoi(action[3]);
180     tag = std::stoi(action[4]);
181   }
182 };
183
184 class SendRecvParser : public ActionArgParser {
185 public:
186   /* communication partner; if we send, this is the receiver and vice versa */
187   int partner;
188   double size;
189   int tag;
190   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
191
192   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
193   {
194     CHECK_ACTION_PARAMS(action, 3, 1)
195     partner = std::stoi(action[2]);
196     tag     = std::stoi(action[3]);
197     size    = parse_double(action[4]);
198     if (action.size() > 5)
199       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
200   }
201 };
202
203 class ComputeParser : public ActionArgParser {
204 public:
205   /* communication partner; if we send, this is the receiver and vice versa */
206   double flops;
207
208   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
209   {
210     CHECK_ACTION_PARAMS(action, 1, 0)
211     flops = parse_double(action[2]);
212   }
213 };
214
215 class CollCommParser : public ActionArgParser {
216 public:
217   double size;
218   double comm_size;
219   double comp_size;
220   int send_size;
221   int recv_size;
222   int root = 0;
223   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
224   MPI_Datatype datatype2 = MPI_DEFAULT_TYPE;
225 };
226
227 class BcastArgParser : public CollCommParser {
228 public:
229   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
230   {
231     CHECK_ACTION_PARAMS(action, 1, 2)
232     size = parse_double(action[2]);
233     root = (action.size() > 3) ? std::stoi(action[3]) : 0;
234     if (action.size() > 4)
235       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
236   }
237 };
238
239 class ReduceArgParser : public CollCommParser {
240 public:
241   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
242   {
243     CHECK_ACTION_PARAMS(action, 2, 2)
244     comm_size = parse_double(action[2]);
245     comp_size = parse_double(action[3]);
246     root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
247     if (action.size() > 5)
248       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
249   }
250 };
251
252 class AllReduceArgParser : public CollCommParser {
253 public:
254   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
255   {
256     CHECK_ACTION_PARAMS(action, 2, 1)
257     comm_size = parse_double(action[2]);
258     comp_size = parse_double(action[3]);
259     if (action.size() > 4)
260       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
261   }
262 };
263
264 class AllToAllArgParser : public CollCommParser {
265 public:
266   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
267   {
268     CHECK_ACTION_PARAMS(action, 2, 1)
269     comm_size = MPI_COMM_WORLD->size();
270     send_size = parse_double(action[2]);
271     recv_size = parse_double(action[3]);
272
273     if (action.size() > 4)
274       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
275     if (action.size() > 5)
276       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
277   }
278 };
279
280 class GatherArgParser : public CollCommParser {
281 public:
282   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
283   {
284     /* The structure of the gather action for the rank 0 (total 4 processes) is the following:
285           0 gather 68 68 0 0 0
286         where:
287           1) 68 is the sendcounts
288           2) 68 is the recvcounts
289           3) 0 is the root node
290           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
291           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
292     */
293     CHECK_ACTION_PARAMS(action, 2, 3)
294     comm_size = MPI_COMM_WORLD->size();
295     send_size = parse_double(action[2]);
296     recv_size = parse_double(action[3]);
297
298     if (name == "gather") {
299       root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
300       if (action.size() > 5)
301         datatype1 = simgrid::smpi::Datatype::decode(action[5]);
302       if (action.size() > 6)
303         datatype2 = simgrid::smpi::Datatype::decode(action[6]);
304     }
305     else {
306       if (action.size() > 4)
307         datatype1 = simgrid::smpi::Datatype::decode(action[4]);
308       if (action.size() > 5)
309         datatype2 = simgrid::smpi::Datatype::decode(action[5]);
310     }
311   }
312 };
313
314 class GatherVArgParser : public CollCommParser {
315 public:
316   int recv_size_sum;
317   std::shared_ptr<std::vector<int>> recvcounts;
318   std::vector<int> disps;
319   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
320   {
321     /* The structure of the gatherv action for the rank 0 (total 4 processes) is the following:
322          0 gather 68 68 10 10 10 0 0 0
323        where:
324          1) 68 is the sendcount
325          2) 68 10 10 10 is the recvcounts
326          3) 0 is the root node
327          4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
328          5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
329     */
330     comm_size = MPI_COMM_WORLD->size();
331     CHECK_ACTION_PARAMS(action, comm_size+1, 2)
332     send_size = parse_double(action[2]);
333     disps     = std::vector<int>(comm_size, 0);
334     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
335
336     if (name == "gatherV") {
337       root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
338       if (action.size() > 4 + comm_size)
339         datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
340       if (action.size() > 5 + comm_size)
341         datatype2 = simgrid::smpi::Datatype::decode(action[5 + comm_size]);
342     }
343     else {
344       int datatype_index = 0;
345       int disp_index     = 0;
346       /* The 3 comes from "0 gather <sendcount>", which must always be present.
347        * The + comm_size is the recvcounts array, which must also be present
348        */
349       if (action.size() > 3 + comm_size + comm_size) { /* datatype + disp are specified */
350         datatype_index = 3 + comm_size;
351         disp_index     = datatype_index + 1;
352         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
353         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
354       } else if (action.size() > 3 + comm_size + 2) { /* disps specified; datatype is not specified; use the default one */
355         disp_index     = 3 + comm_size;
356       } else if (action.size() > 3 + comm_size)  { /* only datatype, no disp specified */
357         datatype_index = 3 + comm_size;
358         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
359         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
360       }
361
362       if (disp_index != 0) {
363         for (unsigned int i = 0; i < comm_size; i++)
364           disps[i]          = std::stoi(action[disp_index + i]);
365       }
366     }
367
368     for (unsigned int i = 0; i < comm_size; i++) {
369       (*recvcounts)[i] = std::stoi(action[i + 3]);
370     }
371     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
372   }
373 };
374
375 class ScatterArgParser : public CollCommParser {
376 public:
377   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
378   {
379     /* The structure of the scatter action for the rank 0 (total 4 processes) is the following:
380           0 gather 68 68 0 0 0
381         where:
382           1) 68 is the sendcounts
383           2) 68 is the recvcounts
384           3) 0 is the root node
385           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
386           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
387     */
388     CHECK_ACTION_PARAMS(action, 2, 3)
389     comm_size   = MPI_COMM_WORLD->size();
390     send_size   = parse_double(action[2]);
391     recv_size   = parse_double(action[3]);
392     root   = (action.size() > 4) ? std::stoi(action[4]) : 0;
393     if (action.size() > 5)
394       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
395     if (action.size() > 6)
396       datatype2 = simgrid::smpi::Datatype::decode(action[6]);
397   }
398 };
399
400 class ScatterVArgParser : public CollCommParser {
401 public:
402   int recv_size_sum;
403   int send_size_sum;
404   std::shared_ptr<std::vector<int>> sendcounts;
405   std::vector<int> disps;
406   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
407   {
408     /* The structure of the scatterv action for the rank 0 (total 4 processes) is the following:
409        0 gather 68 10 10 10 68 0 0 0
410         where:
411         1) 68 10 10 10 is the sendcounts
412         2) 68 is the recvcount
413         3) 0 is the root node
414         4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
415         5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
416     */
417     CHECK_ACTION_PARAMS(action, comm_size + 1, 2)
418     recv_size  = parse_double(action[2 + comm_size]);
419     disps      = std::vector<int>(comm_size, 0);
420     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
421
422     if (action.size() > 5 + comm_size)
423       datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
424     if (action.size() > 5 + comm_size)
425       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
426
427     for (unsigned int i = 0; i < comm_size; i++) {
428       (*sendcounts)[i] = std::stoi(action[i + 2]);
429     }
430     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
431     root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
432   }
433 };
434
435 class ReduceScatterArgParser : public CollCommParser {
436 public:
437   int recv_size_sum;
438   std::shared_ptr<std::vector<int>> recvcounts;
439   std::vector<int> disps;
440   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
441   {
442     /* The structure of the reducescatter action for the rank 0 (total 4 processes) is the following:
443          0 reduceScatter 275427 275427 275427 204020 11346849 0
444        where:
445          1) The first four values after the name of the action declare the recvcounts array
446          2) The value 11346849 is the amount of instructions
447          3) The last value corresponds to the datatype, see simgrid::smpi::Datatype::decode().
448     */
449     comm_size = MPI_COMM_WORLD->size();
450     CHECK_ACTION_PARAMS(action, comm_size+1, 1)
451     comp_size = parse_double(action[2+comm_size]);
452     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
453     if (action.size() > 3 + comm_size)
454       datatype1 = simgrid::smpi::Datatype::decode(action[3 + comm_size]);
455
456     for (unsigned int i = 0; i < comm_size; i++) {
457       recvcounts->push_back(std::stoi(action[i + 2]));
458     }
459     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
460   }
461 };
462
463 class AllToAllVArgParser : public CollCommParser {
464 public:
465   int recv_size_sum;
466   int send_size_sum;
467   std::shared_ptr<std::vector<int>> recvcounts;
468   std::shared_ptr<std::vector<int>> sendcounts;
469   std::vector<int> senddisps;
470   std::vector<int> recvdisps;
471   int send_buf_size;
472   int recv_buf_size;
473   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
474   {
475     /* The structure of the allToAllV action for the rank 0 (total 4 processes) is the following:
476           0 allToAllV 100 1 7 10 12 100 1 70 10 5
477        where:
478         1) 100 is the size of the send buffer *sizeof(int),
479         2) 1 7 10 12 is the sendcounts array
480         3) 100*sizeof(int) is the size of the receiver buffer
481         4)  1 70 10 5 is the recvcounts array
482     */
483     comm_size = MPI_COMM_WORLD->size();
484     CHECK_ACTION_PARAMS(action, 2*comm_size+2, 2)
485     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
486     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
487     senddisps  = std::vector<int>(comm_size, 0);
488     recvdisps  = std::vector<int>(comm_size, 0);
489
490     if (action.size() > 5 + 2 * comm_size)
491       datatype1 = simgrid::smpi::Datatype::decode(action[4 + 2 * comm_size]);
492     if (action.size() > 5 + 2 * comm_size)
493       datatype2 = simgrid::smpi::Datatype::decode(action[5 + 2 * comm_size]);
494
495     send_buf_size=parse_double(action[2]);
496     recv_buf_size=parse_double(action[3+comm_size]);
497     for (unsigned int i = 0; i < comm_size; i++) {
498       (*sendcounts)[i] = std::stoi(action[3 + i]);
499       (*recvcounts)[i] = std::stoi(action[4 + comm_size + i]);
500     }
501     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
502     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
503   }
504 };
505
506 /**
507  * Base class for all ReplayActions.
508  * Note that this class actually implements the behavior of each action
509  * while the parsing of the replay arguments is done in the @ActionArgParser class.
510  * In other words: The logic goes here, the setup is done by the ActionArgParser.
511  */
512 template <class T> class ReplayAction {
513 protected:
514   const std::string name;
515   const int my_proc_id;
516   T args;
517
518 public:
519   explicit ReplayAction(std::string name) : name(name), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
520   virtual ~ReplayAction() = default;
521
522   virtual void execute(simgrid::xbt::ReplayAction& action)
523   {
524     // Needs to be re-initialized for every action, hence here
525     double start_time = smpi_process()->simulated_elapsed();
526     args.parse(action, name);
527     kernel(action);
528     if (name != "Init")
529       log_timed_action(action, start_time);
530   }
531
532   virtual void kernel(simgrid::xbt::ReplayAction& action) = 0;
533
534   void* send_buffer(int size)
535   {
536     return smpi_get_tmp_sendbuffer(size);
537   }
538
539   void* recv_buffer(int size)
540   {
541     return smpi_get_tmp_recvbuffer(size);
542   }
543 };
544
545 class WaitAction : public ReplayAction<WaitTestParser> {
546 private:
547   RequestStorage& req_storage;
548
549 public:
550   explicit WaitAction(RequestStorage& storage) : ReplayAction("Wait"), req_storage(storage) {}
551   void kernel(simgrid::xbt::ReplayAction& action) override
552   {
553     std::string s = boost::algorithm::join(action, " ");
554     xbt_assert(req_storage.size(), "action wait not preceded by any irecv or isend: %s", s.c_str());
555     MPI_Request request = req_storage.find(args.src, args.dst, args.tag);
556     req_storage.remove(request);
557
558     if (request == MPI_REQUEST_NULL) {
559       /* Assume that the trace is well formed, meaning the comm might have been caught by a MPI_test. Then just
560        * return.*/
561       return;
562     }
563
564     int rank = request->comm() != MPI_COMM_NULL ? request->comm()->rank() : -1;
565
566     // Must be taken before Request::wait() since the request may be set to
567     // MPI_REQUEST_NULL by Request::wait!
568     bool is_wait_for_receive = (request->flags() & RECV);
569     // TODO: Here we take the rank while we normally take the process id (look for my_proc_id)
570     TRACE_smpi_comm_in(rank, __func__, new simgrid::instr::NoOpTIData("wait"));
571
572     MPI_Status status;
573     Request::wait(&request, &status);
574
575     TRACE_smpi_comm_out(rank);
576     if (is_wait_for_receive)
577       TRACE_smpi_recv(args.src, args.dst, args.tag);
578   }
579 };
580
581 class SendAction : public ReplayAction<SendRecvParser> {
582 private:
583   RequestStorage& req_storage;
584
585 public:
586   explicit SendAction(std::string name, RequestStorage& storage) : ReplayAction(name), req_storage(storage) {}
587   void kernel(simgrid::xbt::ReplayAction& action) override
588   {
589     int dst_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
590
591     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
592                                                                              args.tag, Datatype::encode(args.datatype1)));
593     if (not TRACE_smpi_view_internals())
594       TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, args.tag, args.size * args.datatype1->size());
595
596     if (name == "send") {
597       Request::send(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
598     } else if (name == "Isend") {
599       MPI_Request request = Request::isend(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
600       req_storage.add(request);
601     } else {
602       xbt_die("Don't know this action, %s", name.c_str());
603     }
604
605     TRACE_smpi_comm_out(my_proc_id);
606   }
607 };
608
609 class RecvAction : public ReplayAction<SendRecvParser> {
610 private:
611   RequestStorage& req_storage;
612
613 public:
614   explicit RecvAction(std::string name, RequestStorage& storage) : ReplayAction(name), req_storage(storage) {}
615   void kernel(simgrid::xbt::ReplayAction& action) override
616   {
617     int src_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
618
619     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
620                                                                              args.tag, Datatype::encode(args.datatype1)));
621
622     MPI_Status status;
623     // unknown size from the receiver point of view
624     if (args.size <= 0.0) {
625       Request::probe(args.partner, args.tag, MPI_COMM_WORLD, &status);
626       args.size = status.count;
627     }
628
629     if (name == "recv") {
630       Request::recv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD, &status);
631     } else if (name == "Irecv") {
632       MPI_Request request = Request::irecv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
633       req_storage.add(request);
634     }
635
636     TRACE_smpi_comm_out(my_proc_id);
637     // TODO: Check why this was only activated in the "recv" case and not in the "Irecv" case
638     if (name == "recv" && not TRACE_smpi_view_internals()) {
639       TRACE_smpi_recv(src_traced, my_proc_id, args.tag);
640     }
641   }
642 };
643
644 class ComputeAction : public ReplayAction<ComputeParser> {
645 public:
646   ComputeAction() : ReplayAction("compute") {}
647   void kernel(simgrid::xbt::ReplayAction& action) override
648   {
649     TRACE_smpi_computing_in(my_proc_id, args.flops);
650     smpi_execute_flops(args.flops);
651     TRACE_smpi_computing_out(my_proc_id);
652   }
653 };
654
655 class TestAction : public ReplayAction<WaitTestParser> {
656 private:
657   RequestStorage& req_storage;
658
659 public:
660   explicit TestAction(RequestStorage& storage) : ReplayAction("Test"), req_storage(storage) {}
661   void kernel(simgrid::xbt::ReplayAction& action) override
662   {
663     MPI_Request request = req_storage.find(args.src, args.dst, args.tag);
664     req_storage.remove(request);
665     // if request is null here, this may mean that a previous test has succeeded
666     // Different times in traced application and replayed version may lead to this
667     // In this case, ignore the extra calls.
668     if (request != MPI_REQUEST_NULL) {
669       TRACE_smpi_testing_in(my_proc_id);
670
671       MPI_Status status;
672       int flag = Request::test(&request, &status);
673
674       XBT_DEBUG("MPI_Test result: %d", flag);
675       /* push back request in vector to be caught by a subsequent wait. if the test did succeed, the request is now
676        * nullptr.*/
677       if (request == MPI_REQUEST_NULL)
678         req_storage.addNullRequest(args.src, args.dst, args.tag);
679       else
680         req_storage.add(request);
681
682       TRACE_smpi_testing_out(my_proc_id);
683     }
684   }
685 };
686
687 class InitAction : public ReplayAction<ActionArgParser> {
688 public:
689   InitAction() : ReplayAction("Init") {}
690   void kernel(simgrid::xbt::ReplayAction& action) override
691   {
692     CHECK_ACTION_PARAMS(action, 0, 1)
693     MPI_DEFAULT_TYPE = (action.size() > 2) ? MPI_DOUBLE // default MPE datatype
694                                            : MPI_BYTE;  // default TAU datatype
695
696     /* start a simulated timer */
697     smpi_process()->simulated_start();
698   }
699 };
700
701 class CommunicatorAction : public ReplayAction<ActionArgParser> {
702 public:
703   CommunicatorAction() : ReplayAction("Comm") {}
704   void kernel(simgrid::xbt::ReplayAction& action) override { /* nothing to do */}
705 };
706
707 class WaitAllAction : public ReplayAction<ActionArgParser> {
708 private:
709   RequestStorage& req_storage;
710
711 public:
712   explicit WaitAllAction(RequestStorage& storage) : ReplayAction("waitAll"), req_storage(storage) {}
713   void kernel(simgrid::xbt::ReplayAction& action) override
714   {
715     const unsigned int count_requests = req_storage.size();
716
717     if (count_requests > 0) {
718       TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData("waitAll", -1, count_requests, ""));
719       std::vector<std::pair</*sender*/int,/*recv*/int>> sender_receiver;
720       std::vector<MPI_Request> reqs;
721       req_storage.get_requests(reqs);
722       for (const auto& req : reqs) {
723         if (req && (req->flags() & RECV)) {
724           sender_receiver.push_back({req->src(), req->dst()});
725         }
726       }
727       MPI_Status status[count_requests];
728       Request::waitall(count_requests, &(reqs.data())[0], status);
729       req_storage.get_store().clear();
730
731       for (auto& pair : sender_receiver) {
732         TRACE_smpi_recv(pair.first, pair.second, 0);
733       }
734       TRACE_smpi_comm_out(my_proc_id);
735     }
736   }
737 };
738
739 class BarrierAction : public ReplayAction<ActionArgParser> {
740 public:
741   BarrierAction() : ReplayAction("barrier") {}
742   void kernel(simgrid::xbt::ReplayAction& action) override
743   {
744     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("barrier"));
745     Colls::barrier(MPI_COMM_WORLD);
746     TRACE_smpi_comm_out(my_proc_id);
747   }
748 };
749
750 class BcastAction : public ReplayAction<BcastArgParser> {
751 public:
752   BcastAction() : ReplayAction("bcast") {}
753   void kernel(simgrid::xbt::ReplayAction& action) override
754   {
755     TRACE_smpi_comm_in(my_proc_id, "action_bcast",
756                        new simgrid::instr::CollTIData("bcast", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
757                                                       -1.0, args.size, -1, Datatype::encode(args.datatype1), ""));
758
759     Colls::bcast(send_buffer(args.size * args.datatype1->size()), args.size, args.datatype1, args.root, MPI_COMM_WORLD);
760
761     TRACE_smpi_comm_out(my_proc_id);
762   }
763 };
764
765 class ReduceAction : public ReplayAction<ReduceArgParser> {
766 public:
767   ReduceAction() : ReplayAction("reduce") {}
768   void kernel(simgrid::xbt::ReplayAction& action) override
769   {
770     TRACE_smpi_comm_in(my_proc_id, "action_reduce",
771                        new simgrid::instr::CollTIData("reduce", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
772                                                       args.comp_size, args.comm_size, -1,
773                                                       Datatype::encode(args.datatype1), ""));
774
775     Colls::reduce(send_buffer(args.comm_size * args.datatype1->size()),
776         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, args.root, MPI_COMM_WORLD);
777     smpi_execute_flops(args.comp_size);
778
779     TRACE_smpi_comm_out(my_proc_id);
780   }
781 };
782
783 class AllReduceAction : public ReplayAction<AllReduceArgParser> {
784 public:
785   AllReduceAction() : ReplayAction("allReduce") {}
786   void kernel(simgrid::xbt::ReplayAction& action) override
787   {
788     TRACE_smpi_comm_in(my_proc_id, "action_allReduce", new simgrid::instr::CollTIData("allReduce", -1, args.comp_size, args.comm_size, -1,
789                                                                                 Datatype::encode(args.datatype1), ""));
790
791     Colls::allreduce(send_buffer(args.comm_size * args.datatype1->size()),
792         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
793     smpi_execute_flops(args.comp_size);
794
795     TRACE_smpi_comm_out(my_proc_id);
796   }
797 };
798
799 class AllToAllAction : public ReplayAction<AllToAllArgParser> {
800 public:
801   AllToAllAction() : ReplayAction("allToAll") {}
802   void kernel(simgrid::xbt::ReplayAction& action) override
803   {
804     TRACE_smpi_comm_in(my_proc_id, "action_allToAll",
805                      new simgrid::instr::CollTIData("allToAll", -1, -1.0, args.send_size, args.recv_size,
806                                                     Datatype::encode(args.datatype1),
807                                                     Datatype::encode(args.datatype2)));
808
809     Colls::alltoall(send_buffer(args.send_size * args.comm_size * args.datatype1->size()), args.send_size,
810                     args.datatype1, recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()),
811                     args.recv_size, args.datatype2, MPI_COMM_WORLD);
812
813     TRACE_smpi_comm_out(my_proc_id);
814   }
815 };
816
817 class GatherAction : public ReplayAction<GatherArgParser> {
818 public:
819   explicit GatherAction(std::string name) : ReplayAction(name) {}
820   void kernel(simgrid::xbt::ReplayAction& action) override
821   {
822     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::CollTIData(name, (name == "gather") ? args.root : -1, -1.0, args.send_size, args.recv_size,
823                                                                           Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
824
825     if (name == "gather") {
826       int rank = MPI_COMM_WORLD->rank();
827       Colls::gather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
828                  (rank == args.root) ? recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
829     }
830     else
831       Colls::allgather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
832                        recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, MPI_COMM_WORLD);
833
834     TRACE_smpi_comm_out(my_proc_id);
835   }
836 };
837
838 class GatherVAction : public ReplayAction<GatherVArgParser> {
839 public:
840   explicit GatherVAction(std::string name) : ReplayAction(name) {}
841   void kernel(simgrid::xbt::ReplayAction& action) override
842   {
843     int rank = MPI_COMM_WORLD->rank();
844
845     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::VarCollTIData(
846                                                name, (name == "gatherV") ? args.root : -1, args.send_size, nullptr, -1, args.recvcounts,
847                                                Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
848
849     if (name == "gatherV") {
850       Colls::gatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
851                      (rank == args.root) ? recv_buffer(args.recv_size_sum * args.datatype2->size()) : nullptr,
852                      args.recvcounts->data(), args.disps.data(), args.datatype2, args.root, MPI_COMM_WORLD);
853     }
854     else {
855       Colls::allgatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
856                         recv_buffer(args.recv_size_sum * args.datatype2->size()), args.recvcounts->data(),
857                         args.disps.data(), args.datatype2, MPI_COMM_WORLD);
858     }
859
860     TRACE_smpi_comm_out(my_proc_id);
861   }
862 };
863
864 class ScatterAction : public ReplayAction<ScatterArgParser> {
865 public:
866   ScatterAction() : ReplayAction("scatter") {}
867   void kernel(simgrid::xbt::ReplayAction& action) override
868   {
869     int rank = MPI_COMM_WORLD->rank();
870     TRACE_smpi_comm_in(my_proc_id, "action_scatter", new simgrid::instr::CollTIData(name, args.root, -1.0, args.send_size, args.recv_size,
871                                                                           Datatype::encode(args.datatype1),
872                                                                           Datatype::encode(args.datatype2)));
873
874     Colls::scatter(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
875                   (rank == args.root) ? recv_buffer(args.recv_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
876
877     TRACE_smpi_comm_out(my_proc_id);
878   }
879 };
880
881
882 class ScatterVAction : public ReplayAction<ScatterVArgParser> {
883 public:
884   ScatterVAction() : ReplayAction("scatterV") {}
885   void kernel(simgrid::xbt::ReplayAction& action) override
886   {
887     int rank = MPI_COMM_WORLD->rank();
888     TRACE_smpi_comm_in(my_proc_id, "action_scatterv", new simgrid::instr::VarCollTIData(name, args.root, -1, args.sendcounts, args.recv_size,
889           nullptr, Datatype::encode(args.datatype1),
890           Datatype::encode(args.datatype2)));
891
892     Colls::scatterv((rank == args.root) ? send_buffer(args.send_size_sum * args.datatype1->size()) : nullptr,
893                     args.sendcounts->data(), args.disps.data(), args.datatype1,
894                     recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, args.root,
895                     MPI_COMM_WORLD);
896
897     TRACE_smpi_comm_out(my_proc_id);
898   }
899 };
900
901 class ReduceScatterAction : public ReplayAction<ReduceScatterArgParser> {
902 public:
903   ReduceScatterAction() : ReplayAction("reduceScatter") {}
904   void kernel(simgrid::xbt::ReplayAction& action) override
905   {
906     TRACE_smpi_comm_in(my_proc_id, "action_reducescatter",
907                        new simgrid::instr::VarCollTIData("reduceScatter", -1, 0, nullptr, -1, args.recvcounts,
908                                                          std::to_string(args.comp_size), /* ugly hack to print comp_size */
909                                                          Datatype::encode(args.datatype1)));
910
911     Colls::reduce_scatter(send_buffer(args.recv_size_sum * args.datatype1->size()),
912                           recv_buffer(args.recv_size_sum * args.datatype1->size()), args.recvcounts->data(),
913                           args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
914
915     smpi_execute_flops(args.comp_size);
916     TRACE_smpi_comm_out(my_proc_id);
917   }
918 };
919
920 class AllToAllVAction : public ReplayAction<AllToAllVArgParser> {
921 public:
922   AllToAllVAction() : ReplayAction("allToAllV") {}
923   void kernel(simgrid::xbt::ReplayAction& action) override
924   {
925     TRACE_smpi_comm_in(my_proc_id, __func__,
926                        new simgrid::instr::VarCollTIData(
927                            "allToAllV", -1, args.send_size_sum, args.sendcounts, args.recv_size_sum, args.recvcounts,
928                            Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
929
930     Colls::alltoallv(send_buffer(args.send_buf_size * args.datatype1->size()), args.sendcounts->data(), args.senddisps.data(), args.datatype1,
931                      recv_buffer(args.recv_buf_size * args.datatype2->size()), args.recvcounts->data(), args.recvdisps.data(), args.datatype2, MPI_COMM_WORLD);
932
933     TRACE_smpi_comm_out(my_proc_id);
934   }
935 };
936 } // Replay Namespace
937 }} // namespace simgrid::smpi
938
939 std::vector<simgrid::smpi::replay::RequestStorage> storage;
940 /** @brief Only initialize the replay, don't do it for real */
941 void smpi_replay_init(int* argc, char*** argv)
942 {
943   simgrid::smpi::Process::init(argc, argv);
944   smpi_process()->mark_as_initialized();
945   smpi_process()->set_replaying(true);
946
947   int my_proc_id = simgrid::s4u::this_actor::get_pid();
948   storage.resize(smpi_process_count());
949
950   TRACE_smpi_init(my_proc_id);
951   TRACE_smpi_computing_init(my_proc_id);
952   TRACE_smpi_comm_in(my_proc_id, "smpi_replay_run_init", new simgrid::instr::NoOpTIData("init"));
953   TRACE_smpi_comm_out(my_proc_id);
954   xbt_replay_action_register("init", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::InitAction().execute(action); });
955   xbt_replay_action_register("finalize", [](simgrid::xbt::ReplayAction& action) { /* nothing to do */ });
956   xbt_replay_action_register("comm_size", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
957   xbt_replay_action_register("comm_split",[](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
958   xbt_replay_action_register("comm_dup",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
959   xbt_replay_action_register("send",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("send", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
960   xbt_replay_action_register("Isend", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("Isend", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
961   xbt_replay_action_register("recv",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("recv", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
962   xbt_replay_action_register("Irecv", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("Irecv", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
963   xbt_replay_action_register("test",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::TestAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
964   xbt_replay_action_register("wait",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
965   xbt_replay_action_register("waitAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAllAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
966   xbt_replay_action_register("barrier", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BarrierAction().execute(action); });
967   xbt_replay_action_register("bcast",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BcastAction().execute(action); });
968   xbt_replay_action_register("reduce",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceAction().execute(action); });
969   xbt_replay_action_register("allReduce", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllReduceAction().execute(action); });
970   xbt_replay_action_register("allToAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllAction().execute(action); });
971   xbt_replay_action_register("allToAllV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllVAction().execute(action); });
972   xbt_replay_action_register("gather",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("gather").execute(action); });
973   xbt_replay_action_register("scatter",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterAction().execute(action); });
974   xbt_replay_action_register("gatherV",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("gatherV").execute(action); });
975   xbt_replay_action_register("scatterV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterVAction().execute(action); });
976   xbt_replay_action_register("allGather", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("allGather").execute(action); });
977   xbt_replay_action_register("allGatherV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("allGatherV").execute(action); });
978   xbt_replay_action_register("reduceScatter", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceScatterAction().execute(action); });
979   xbt_replay_action_register("compute", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ComputeAction().execute(action); });
980
981   //if we have a delayed start, sleep here.
982   if(*argc>2){
983     double value = xbt_str_parse_double((*argv)[2], "%s is not a double");
984     XBT_VERB("Delayed start for instance - Sleeping for %f flops ",value );
985     smpi_execute_flops(value);
986   } else {
987     //UGLY: force a context switch to be sure that all MSG_processes begin initialization
988     XBT_DEBUG("Force context switch by smpi_execute_flops  - Sleeping for 0.0 flops ");
989     smpi_execute_flops(0.0);
990   }
991 }
992
993 /** @brief actually run the replay after initialization */
994 void smpi_replay_main(int* argc, char*** argv)
995 {
996   static int active_processes = 0;
997   active_processes++;
998   simgrid::xbt::replay_runner(*argc, *argv);
999
1000   /* and now, finalize everything */
1001   /* One active process will stop. Decrease the counter*/
1002   unsigned int count_requests = storage[simgrid::s4u::this_actor::get_pid() - 1].size();
1003   XBT_DEBUG("There are %ud elements in reqq[*]", count_requests);
1004   if (count_requests > 0) {
1005     MPI_Request requests[count_requests];
1006     MPI_Status status[count_requests];
1007     unsigned int i=0;
1008
1009     for (auto const& pair : storage[simgrid::s4u::this_actor::get_pid() - 1].get_store()) {
1010       requests[i] = pair.second;
1011       i++;
1012     }
1013     simgrid::smpi::Request::waitall(count_requests, requests, status);
1014   }
1015   active_processes--;
1016
1017   if(active_processes==0){
1018     /* Last process alive speaking: end the simulated timer */
1019     XBT_INFO("Simulation time %f", smpi_process()->simulated_elapsed());
1020     smpi_free_replay_tmp_buffers();
1021   }
1022
1023   TRACE_smpi_comm_in(simgrid::s4u::this_actor::get_pid(), "smpi_replay_run_finalize",
1024                      new simgrid::instr::NoOpTIData("finalize"));
1025
1026   smpi_process()->finalize();
1027
1028   TRACE_smpi_comm_out(simgrid::s4u::this_actor::get_pid());
1029   TRACE_smpi_finalize(simgrid::s4u::this_actor::get_pid());
1030 }
1031
1032 /** @brief chain a replay initialization and a replay start */
1033 void smpi_replay_run(int* argc, char*** argv)
1034 {
1035   smpi_replay_init(argc, argv);
1036   smpi_replay_main(argc, argv);
1037 }