Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
[SMPI] Replay: Use MPI_REQUEST_NULL instead of nullptr
[simgrid.git] / src / smpi / internals / smpi_replay.cpp
1 /* Copyright (c) 2009-2018. The SimGrid Team. All rights reserved.          */
2
3 /* This program is free software; you can redistribute it and/or modify it
4  * under the terms of the license (GNU LGPL) which comes with this package. */
5
6 #include "private.hpp"
7 #include "smpi_coll.hpp"
8 #include "smpi_comm.hpp"
9 #include "smpi_datatype.hpp"
10 #include "smpi_group.hpp"
11 #include "smpi_process.hpp"
12 #include "smpi_request.hpp"
13 #include "xbt/replay.hpp"
14
15 #include <boost/algorithm/string/join.hpp>
16 #include <memory>
17 #include <numeric>
18 #include <unordered_map>
19 #include <sstream>
20 #include <vector>
21
22 using simgrid::s4u::Actor;
23
24 #include <tuple>
25 // From https://stackoverflow.com/questions/7110301/generic-hash-for-tuples-in-unordered-map-unordered-set
26 // This is all just to make std::unordered_map work with std::tuple. If we need this in other places,
27 // this could go into a header file.
28 namespace hash_tuple{
29     template <typename TT>
30     struct hash
31     {
32         size_t
33         operator()(TT const& tt) const
34         {
35             return std::hash<TT>()(tt);
36         }
37     };
38
39     template <class T>
40     inline void hash_combine(std::size_t& seed, T const& v)
41     {
42         seed ^= hash_tuple::hash<T>()(v) + 0x9e3779b9 + (seed<<6) + (seed>>2);
43     }
44
45     // Recursive template code derived from Matthieu M.
46     template <class Tuple, size_t Index = std::tuple_size<Tuple>::value - 1>
47     struct HashValueImpl
48     {
49       static void apply(size_t& seed, Tuple const& tuple)
50       {
51         HashValueImpl<Tuple, Index-1>::apply(seed, tuple);
52         hash_combine(seed, std::get<Index>(tuple));
53       }
54     };
55
56     template <class Tuple>
57     struct HashValueImpl<Tuple,0>
58     {
59       static void apply(size_t& seed, Tuple const& tuple)
60       {
61         hash_combine(seed, std::get<0>(tuple));
62       }
63     };
64
65     template <typename ... TT>
66     struct hash<std::tuple<TT...>>
67     {
68         size_t
69         operator()(std::tuple<TT...> const& tt) const
70         {
71             size_t seed = 0;
72             HashValueImpl<std::tuple<TT...> >::apply(seed, tt);
73             return seed;
74         }
75     };
76 }
77
78 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_replay,smpi,"Trace Replay with SMPI");
79
80 static std::unordered_map<int, std::vector<MPI_Request>*> reqq;
81
82 static MPI_Datatype MPI_DEFAULT_TYPE;
83
84 #define CHECK_ACTION_PARAMS(action, mandatory, optional)                                                               \
85   {                                                                                                                    \
86     if (action.size() < static_cast<unsigned long>(mandatory + 2)) {                                                   \
87       std::stringstream ss;                                                                                            \
88       for (const auto& elem : action) {                                                                                \
89         ss << elem << " ";                                                                                             \
90       }                                                                                                                \
91       THROWF(arg_error, 0, "%s replay failed.\n"                                                                       \
92                            "%zu items were given on the line. First two should be process_id and action.  "            \
93                            "This action needs after them %lu mandatory arguments, and accepts %lu optional ones. \n"   \
94                            "The full line that was given is:\n   %s\n"                                                 \
95                            "Please contact the Simgrid team if support is needed",                                     \
96              __func__, action.size(), static_cast<unsigned long>(mandatory), static_cast<unsigned long>(optional),     \
97              ss.str().c_str());                                                                                        \
98     }                                                                                                                  \
99   }
100
101 static void log_timed_action(simgrid::xbt::ReplayAction& action, double clock)
102 {
103   if (XBT_LOG_ISENABLED(smpi_replay, xbt_log_priority_verbose)){
104     std::string s = boost::algorithm::join(action, " ");
105     XBT_VERB("%s %f", s.c_str(), smpi_process()->simulated_elapsed() - clock);
106   }
107 }
108
109 static std::vector<MPI_Request>* get_reqq_self()
110 {
111   return reqq.at(simgrid::s4u::this_actor::get_pid());
112 }
113
114 static void set_reqq_self(std::vector<MPI_Request> *mpi_request)
115 {
116   reqq.insert({simgrid::s4u::this_actor::get_pid(), mpi_request});
117 }
118
119 /* Helper function */
120 static double parse_double(std::string string)
121 {
122   return xbt_str_parse_double(string.c_str(), "%s is not a double");
123 }
124
125 namespace simgrid {
126 namespace smpi {
127
128 namespace replay {
129
130 class RequestStorage {
131 private:
132     req_storage_t store;
133
134 public:
135     RequestStorage() {}
136     int size()
137     {
138       return store.size();
139     }
140
141     req_storage_t& get_store()
142     {
143       return store;
144     }
145
146     void get_requests(std::vector<MPI_Request>& vec)
147     {
148       for (auto& pair : store) {
149         auto& req = pair.second;
150         auto my_proc_id = simgrid::s4u::this_actor::getPid();
151         if (req != MPI_REQUEST_NULL && (req->src() == my_proc_id || req->dst() == my_proc_id)) {
152           vec.push_back(pair.second);
153           pair.second->print_request("MM");
154         }
155       }
156     }
157
158     MPI_Request find(int src, int dst, int tag)
159     {
160       req_storage_t::iterator it = store.find(req_key_t(src, dst, tag));
161       return (it == store.end()) ? MPI_REQUEST_NULL : it->second;
162     }
163
164     void remove(MPI_Request req)
165     {
166       if (req == MPI_REQUEST_NULL) return;
167
168       store.erase(req_key_t(req->src()-1, req->dst()-1, req->tag()));
169     }
170
171     void add(MPI_Request req)
172     {
173       if (req != MPI_REQUEST_NULL) // Can and does happen in the case of TestAction
174         store.insert({req_key_t(req->src()-1, req->dst()-1, req->tag()), req});
175     }
176
177     /* Sometimes we need to re-insert MPI_REQUEST_NULL but we still need src,dst and tag */
178     void addNullRequest(int src, int dst, int tag)
179     {
180       store.insert({req_key_t(src, dst, tag), MPI_REQUEST_NULL});
181     }
182 };
183
184 class ActionArgParser {
185 public:
186   virtual ~ActionArgParser() = default;
187   virtual void parse(simgrid::xbt::ReplayAction& action, std::string name) { CHECK_ACTION_PARAMS(action, 0, 0) }
188 };
189
190 class WaitTestParser : public ActionArgParser {
191 public:
192   int src;
193   int dst;
194   int tag;
195
196   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
197   {
198     CHECK_ACTION_PARAMS(action, 3, 0)
199     src = std::stoi(action[2]);
200     dst = std::stoi(action[3]);
201     tag = std::stoi(action[4]);
202   }
203 };
204
205 class SendRecvParser : public ActionArgParser {
206 public:
207   /* communication partner; if we send, this is the receiver and vice versa */
208   int partner;
209   double size;
210   int tag;
211   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
212
213   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
214   {
215     CHECK_ACTION_PARAMS(action, 3, 1)
216     partner = std::stoi(action[2]);
217     tag     = std::stoi(action[3]);
218     size    = parse_double(action[4]);
219     if (action.size() > 5)
220       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
221   }
222 };
223
224 class ComputeParser : public ActionArgParser {
225 public:
226   /* communication partner; if we send, this is the receiver and vice versa */
227   double flops;
228
229   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
230   {
231     CHECK_ACTION_PARAMS(action, 1, 0)
232     flops = parse_double(action[2]);
233   }
234 };
235
236 class CollCommParser : public ActionArgParser {
237 public:
238   double size;
239   double comm_size;
240   double comp_size;
241   int send_size;
242   int recv_size;
243   int root = 0;
244   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
245   MPI_Datatype datatype2 = MPI_DEFAULT_TYPE;
246 };
247
248 class BcastArgParser : public CollCommParser {
249 public:
250   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
251   {
252     CHECK_ACTION_PARAMS(action, 1, 2)
253     size = parse_double(action[2]);
254     root = (action.size() > 3) ? std::stoi(action[3]) : 0;
255     if (action.size() > 4)
256       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
257   }
258 };
259
260 class ReduceArgParser : public CollCommParser {
261 public:
262   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
263   {
264     CHECK_ACTION_PARAMS(action, 2, 2)
265     comm_size = parse_double(action[2]);
266     comp_size = parse_double(action[3]);
267     root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
268     if (action.size() > 5)
269       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
270   }
271 };
272
273 class AllReduceArgParser : public CollCommParser {
274 public:
275   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
276   {
277     CHECK_ACTION_PARAMS(action, 2, 1)
278     comm_size = parse_double(action[2]);
279     comp_size = parse_double(action[3]);
280     if (action.size() > 4)
281       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
282   }
283 };
284
285 class AllToAllArgParser : public CollCommParser {
286 public:
287   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
288   {
289     CHECK_ACTION_PARAMS(action, 2, 1)
290     comm_size = MPI_COMM_WORLD->size();
291     send_size = parse_double(action[2]);
292     recv_size = parse_double(action[3]);
293
294     if (action.size() > 4)
295       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
296     if (action.size() > 5)
297       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
298   }
299 };
300
301 class GatherArgParser : public CollCommParser {
302 public:
303   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
304   {
305     /* The structure of the gather action for the rank 0 (total 4 processes) is the following:
306           0 gather 68 68 0 0 0
307         where:
308           1) 68 is the sendcounts
309           2) 68 is the recvcounts
310           3) 0 is the root node
311           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
312           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
313     */
314     CHECK_ACTION_PARAMS(action, 2, 3)
315     comm_size = MPI_COMM_WORLD->size();
316     send_size = parse_double(action[2]);
317     recv_size = parse_double(action[3]);
318
319     if (name == "gather") {
320       root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
321       if (action.size() > 5)
322         datatype1 = simgrid::smpi::Datatype::decode(action[5]);
323       if (action.size() > 6)
324         datatype2 = simgrid::smpi::Datatype::decode(action[6]);
325     }
326     else {
327       if (action.size() > 4)
328         datatype1 = simgrid::smpi::Datatype::decode(action[4]);
329       if (action.size() > 5)
330         datatype2 = simgrid::smpi::Datatype::decode(action[5]);
331     }
332   }
333 };
334
335 class GatherVArgParser : public CollCommParser {
336 public:
337   int recv_size_sum;
338   std::shared_ptr<std::vector<int>> recvcounts;
339   std::vector<int> disps;
340   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
341   {
342     /* The structure of the gatherv action for the rank 0 (total 4 processes) is the following:
343          0 gather 68 68 10 10 10 0 0 0
344        where:
345          1) 68 is the sendcount
346          2) 68 10 10 10 is the recvcounts
347          3) 0 is the root node
348          4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
349          5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
350     */
351     comm_size = MPI_COMM_WORLD->size();
352     CHECK_ACTION_PARAMS(action, comm_size+1, 2)
353     send_size = parse_double(action[2]);
354     disps     = std::vector<int>(comm_size, 0);
355     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
356
357     if (name == "gatherV") {
358       root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
359       if (action.size() > 4 + comm_size)
360         datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
361       if (action.size() > 5 + comm_size)
362         datatype2 = simgrid::smpi::Datatype::decode(action[5 + comm_size]);
363     }
364     else {
365       int datatype_index = 0;
366       int disp_index     = 0;
367       /* The 3 comes from "0 gather <sendcount>", which must always be present.
368        * The + comm_size is the recvcounts array, which must also be present
369        */
370       if (action.size() > 3 + comm_size + comm_size) { /* datatype + disp are specified */
371         datatype_index = 3 + comm_size;
372         disp_index     = datatype_index + 1;
373         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
374         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
375       } else if (action.size() > 3 + comm_size + 2) { /* disps specified; datatype is not specified; use the default one */
376         disp_index     = 3 + comm_size;
377       } else if (action.size() > 3 + comm_size)  { /* only datatype, no disp specified */
378         datatype_index = 3 + comm_size;
379         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
380         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
381       }
382
383       if (disp_index != 0) {
384         for (unsigned int i = 0; i < comm_size; i++)
385           disps[i]          = std::stoi(action[disp_index + i]);
386       }
387     }
388
389     for (unsigned int i = 0; i < comm_size; i++) {
390       (*recvcounts)[i] = std::stoi(action[i + 3]);
391     }
392     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
393   }
394 };
395
396 class ScatterArgParser : public CollCommParser {
397 public:
398   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
399   {
400     /* The structure of the scatter action for the rank 0 (total 4 processes) is the following:
401           0 gather 68 68 0 0 0
402         where:
403           1) 68 is the sendcounts
404           2) 68 is the recvcounts
405           3) 0 is the root node
406           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
407           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
408     */
409     CHECK_ACTION_PARAMS(action, 2, 3)
410     comm_size   = MPI_COMM_WORLD->size();
411     send_size   = parse_double(action[2]);
412     recv_size   = parse_double(action[3]);
413     root   = (action.size() > 4) ? std::stoi(action[4]) : 0;
414     if (action.size() > 5)
415       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
416     if (action.size() > 6)
417       datatype2 = simgrid::smpi::Datatype::decode(action[6]);
418   }
419 };
420
421 class ScatterVArgParser : public CollCommParser {
422 public:
423   int recv_size_sum;
424   int send_size_sum;
425   std::shared_ptr<std::vector<int>> sendcounts;
426   std::vector<int> disps;
427   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
428   {
429     /* The structure of the scatterv action for the rank 0 (total 4 processes) is the following:
430        0 gather 68 10 10 10 68 0 0 0
431         where:
432         1) 68 10 10 10 is the sendcounts
433         2) 68 is the recvcount
434         3) 0 is the root node
435         4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
436         5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
437     */
438     CHECK_ACTION_PARAMS(action, comm_size + 1, 2)
439     recv_size  = parse_double(action[2 + comm_size]);
440     disps      = std::vector<int>(comm_size, 0);
441     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
442
443     if (action.size() > 5 + comm_size)
444       datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
445     if (action.size() > 5 + comm_size)
446       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
447
448     for (unsigned int i = 0; i < comm_size; i++) {
449       (*sendcounts)[i] = std::stoi(action[i + 2]);
450     }
451     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
452     root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
453   }
454 };
455
456 class ReduceScatterArgParser : public CollCommParser {
457 public:
458   int recv_size_sum;
459   std::shared_ptr<std::vector<int>> recvcounts;
460   std::vector<int> disps;
461   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
462   {
463     /* The structure of the reducescatter action for the rank 0 (total 4 processes) is the following:
464          0 reduceScatter 275427 275427 275427 204020 11346849 0
465        where:
466          1) The first four values after the name of the action declare the recvcounts array
467          2) The value 11346849 is the amount of instructions
468          3) The last value corresponds to the datatype, see simgrid::smpi::Datatype::decode().
469     */
470     comm_size = MPI_COMM_WORLD->size();
471     CHECK_ACTION_PARAMS(action, comm_size+1, 1)
472     comp_size = parse_double(action[2+comm_size]);
473     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
474     if (action.size() > 3 + comm_size)
475       datatype1 = simgrid::smpi::Datatype::decode(action[3 + comm_size]);
476
477     for (unsigned int i = 0; i < comm_size; i++) {
478       recvcounts->push_back(std::stoi(action[i + 2]));
479     }
480     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
481   }
482 };
483
484 class AllToAllVArgParser : public CollCommParser {
485 public:
486   int recv_size_sum;
487   int send_size_sum;
488   std::shared_ptr<std::vector<int>> recvcounts;
489   std::shared_ptr<std::vector<int>> sendcounts;
490   std::vector<int> senddisps;
491   std::vector<int> recvdisps;
492   int send_buf_size;
493   int recv_buf_size;
494   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
495   {
496     /* The structure of the allToAllV action for the rank 0 (total 4 processes) is the following:
497           0 allToAllV 100 1 7 10 12 100 1 70 10 5
498        where:
499         1) 100 is the size of the send buffer *sizeof(int),
500         2) 1 7 10 12 is the sendcounts array
501         3) 100*sizeof(int) is the size of the receiver buffer
502         4)  1 70 10 5 is the recvcounts array
503     */
504     comm_size = MPI_COMM_WORLD->size();
505     CHECK_ACTION_PARAMS(action, 2*comm_size+2, 2)
506     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
507     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
508     senddisps  = std::vector<int>(comm_size, 0);
509     recvdisps  = std::vector<int>(comm_size, 0);
510
511     if (action.size() > 5 + 2 * comm_size)
512       datatype1 = simgrid::smpi::Datatype::decode(action[4 + 2 * comm_size]);
513     if (action.size() > 5 + 2 * comm_size)
514       datatype2 = simgrid::smpi::Datatype::decode(action[5 + 2 * comm_size]);
515
516     send_buf_size=parse_double(action[2]);
517     recv_buf_size=parse_double(action[3+comm_size]);
518     for (unsigned int i = 0; i < comm_size; i++) {
519       (*sendcounts)[i] = std::stoi(action[3 + i]);
520       (*recvcounts)[i] = std::stoi(action[4 + comm_size + i]);
521     }
522     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
523     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
524   }
525 };
526
527 template <class T> class ReplayAction {
528 protected:
529   const std::string name;
530   const int my_proc_id;
531   T args;
532
533 public:
534   explicit ReplayAction(std::string name) : name(name), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
535   virtual ~ReplayAction() = default;
536
537   virtual void execute(simgrid::xbt::ReplayAction& action)
538   {
539     // Needs to be re-initialized for every action, hence here
540     double start_time = smpi_process()->simulated_elapsed();
541     args.parse(action, name);
542     kernel(action);
543     if (name != "Init")
544       log_timed_action(action, start_time);
545   }
546
547   virtual void kernel(simgrid::xbt::ReplayAction& action) = 0;
548
549   void* send_buffer(int size)
550   {
551     return smpi_get_tmp_sendbuffer(size);
552   }
553
554   void* recv_buffer(int size)
555   {
556     return smpi_get_tmp_recvbuffer(size);
557   }
558 };
559
560 class WaitAction : public ReplayAction<WaitTestParser> {
561 public:
562   WaitAction() : ReplayAction("Wait") {}
563   void kernel(simgrid::xbt::ReplayAction& action) override
564   {
565     std::string s = boost::algorithm::join(action, " ");
566     xbt_assert(get_reqq_self()->size(), "action wait not preceded by any irecv or isend: %s", s.c_str());
567     MPI_Request request = get_reqq_self()->back();
568     get_reqq_self()->pop_back();
569
570     if (request == MPI_REQUEST_NULL) {
571       /* Assume that the trace is well formed, meaning the comm might have been caught by a MPI_test. Then just
572        * return.*/
573       return;
574     }
575
576     int rank = request->comm() != MPI_COMM_NULL ? request->comm()->rank() : -1;
577
578     // Must be taken before Request::wait() since the request may be set to
579     // MPI_REQUEST_NULL by Request::wait!
580     bool is_wait_for_receive = (request->flags() & RECV);
581     // TODO: Here we take the rank while we normally take the process id (look for my_proc_id)
582     TRACE_smpi_comm_in(rank, __func__, new simgrid::instr::NoOpTIData("wait"));
583
584     MPI_Status status;
585     Request::wait(&request, &status);
586
587     TRACE_smpi_comm_out(rank);
588     if (is_wait_for_receive)
589       TRACE_smpi_recv(args.src, args.dst, args.tag);
590   }
591 };
592
593 class SendAction : public ReplayAction<SendRecvParser> {
594 public:
595   SendAction() = delete;
596   explicit SendAction(std::string name) : ReplayAction(name) {}
597   void kernel(simgrid::xbt::ReplayAction& action) override
598   {
599     int dst_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
600
601     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
602                                                                              args.tag, Datatype::encode(args.datatype1)));
603     if (not TRACE_smpi_view_internals())
604       TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, args.tag, args.size * args.datatype1->size());
605
606     if (name == "send") {
607       Request::send(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
608     } else if (name == "Isend") {
609       MPI_Request request = Request::isend(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
610       get_reqq_self()->push_back(request);
611     } else {
612       xbt_die("Don't know this action, %s", name.c_str());
613     }
614
615     TRACE_smpi_comm_out(my_proc_id);
616   }
617 };
618
619 class RecvAction : public ReplayAction<SendRecvParser> {
620 public:
621   RecvAction() = delete;
622   explicit RecvAction(std::string name) : ReplayAction(name) {}
623   void kernel(simgrid::xbt::ReplayAction& action) override
624   {
625     int src_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
626
627     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
628                                                                              args.tag, Datatype::encode(args.datatype1)));
629
630     MPI_Status status;
631     // unknown size from the receiver point of view
632     if (args.size <= 0.0) {
633       Request::probe(args.partner, args.tag, MPI_COMM_WORLD, &status);
634       args.size = status.count;
635     }
636
637     if (name == "recv") {
638       Request::recv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD, &status);
639     } else if (name == "Irecv") {
640       MPI_Request request = Request::irecv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
641       get_reqq_self()->push_back(request);
642     }
643
644     TRACE_smpi_comm_out(my_proc_id);
645     // TODO: Check why this was only activated in the "recv" case and not in the "Irecv" case
646     if (name == "recv" && not TRACE_smpi_view_internals()) {
647       TRACE_smpi_recv(src_traced, my_proc_id, args.tag);
648     }
649   }
650 };
651
652 class ComputeAction : public ReplayAction<ComputeParser> {
653 public:
654   ComputeAction() : ReplayAction("compute") {}
655   void kernel(simgrid::xbt::ReplayAction& action) override
656   {
657     TRACE_smpi_computing_in(my_proc_id, args.flops);
658     smpi_execute_flops(args.flops);
659     TRACE_smpi_computing_out(my_proc_id);
660   }
661 };
662
663 class TestAction : public ReplayAction<WaitTestParser> {
664 public:
665   TestAction() : ReplayAction("Test") {}
666   void kernel(simgrid::xbt::ReplayAction& action) override
667   {
668     MPI_Request request = get_reqq_self()->back();
669     get_reqq_self()->pop_back();
670     // if request is null here, this may mean that a previous test has succeeded
671     // Different times in traced application and replayed version may lead to this
672     // In this case, ignore the extra calls.
673     if (request != MPI_REQUEST_NULL) {
674       TRACE_smpi_testing_in(my_proc_id);
675
676       MPI_Status status;
677       int flag = Request::test(&request, &status);
678
679       XBT_DEBUG("MPI_Test result: %d", flag);
680       /* push back request in vector to be caught by a subsequent wait. if the test did succeed, the request is now
681        * nullptr.*/
682       get_reqq_self()->push_back(request);
683
684       TRACE_smpi_testing_out(my_proc_id);
685     }
686   }
687 };
688
689 class InitAction : public ReplayAction<ActionArgParser> {
690 public:
691   InitAction() : ReplayAction("Init") {}
692   void kernel(simgrid::xbt::ReplayAction& action) override
693   {
694     CHECK_ACTION_PARAMS(action, 0, 1)
695     MPI_DEFAULT_TYPE = (action.size() > 2) ? MPI_DOUBLE // default MPE datatype
696                                            : MPI_BYTE;  // default TAU datatype
697
698     /* start a simulated timer */
699     smpi_process()->simulated_start();
700     set_reqq_self(new std::vector<MPI_Request>);
701   }
702 };
703
704 class CommunicatorAction : public ReplayAction<ActionArgParser> {
705 public:
706   CommunicatorAction() : ReplayAction("Comm") {}
707   void kernel(simgrid::xbt::ReplayAction& action) override { /* nothing to do */}
708 };
709
710 class WaitAllAction : public ReplayAction<ActionArgParser> {
711 public:
712   WaitAllAction() : ReplayAction("waitAll") {}
713   void kernel(simgrid::xbt::ReplayAction& action) override
714   {
715     const unsigned int count_requests = get_reqq_self()->size();
716
717     if (count_requests > 0) {
718       TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData("waitAll", -1, count_requests, ""));
719       std::vector<std::pair</*sender*/int,/*recv*/int>> sender_receiver;
720       for (const auto& req : (*get_reqq_self())) {
721         if (req && (req->flags() & RECV)) {
722           sender_receiver.push_back({req->src(), req->dst()});
723         }
724       }
725       MPI_Status status[count_requests];
726       Request::waitall(count_requests, &(*get_reqq_self())[0], status);
727
728       for (auto& pair : sender_receiver) {
729         TRACE_smpi_recv(pair.first, pair.second, 0);
730       }
731       TRACE_smpi_comm_out(my_proc_id);
732     }
733   }
734 };
735
736 class BarrierAction : public ReplayAction<ActionArgParser> {
737 public:
738   BarrierAction() : ReplayAction("barrier") {}
739   void kernel(simgrid::xbt::ReplayAction& action) override
740   {
741     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("barrier"));
742     Colls::barrier(MPI_COMM_WORLD);
743     TRACE_smpi_comm_out(my_proc_id);
744   }
745 };
746
747 class BcastAction : public ReplayAction<BcastArgParser> {
748 public:
749   BcastAction() : ReplayAction("bcast") {}
750   void kernel(simgrid::xbt::ReplayAction& action) override
751   {
752     TRACE_smpi_comm_in(my_proc_id, "action_bcast",
753                        new simgrid::instr::CollTIData("bcast", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
754                                                       -1.0, args.size, -1, Datatype::encode(args.datatype1), ""));
755
756     Colls::bcast(send_buffer(args.size * args.datatype1->size()), args.size, args.datatype1, args.root, MPI_COMM_WORLD);
757
758     TRACE_smpi_comm_out(my_proc_id);
759   }
760 };
761
762 class ReduceAction : public ReplayAction<ReduceArgParser> {
763 public:
764   ReduceAction() : ReplayAction("reduce") {}
765   void kernel(simgrid::xbt::ReplayAction& action) override
766   {
767     TRACE_smpi_comm_in(my_proc_id, "action_reduce",
768                        new simgrid::instr::CollTIData("reduce", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
769                                                       args.comp_size, args.comm_size, -1,
770                                                       Datatype::encode(args.datatype1), ""));
771
772     Colls::reduce(send_buffer(args.comm_size * args.datatype1->size()),
773         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, args.root, MPI_COMM_WORLD);
774     smpi_execute_flops(args.comp_size);
775
776     TRACE_smpi_comm_out(my_proc_id);
777   }
778 };
779
780 class AllReduceAction : public ReplayAction<AllReduceArgParser> {
781 public:
782   AllReduceAction() : ReplayAction("allReduce") {}
783   void kernel(simgrid::xbt::ReplayAction& action) override
784   {
785     TRACE_smpi_comm_in(my_proc_id, "action_allReduce", new simgrid::instr::CollTIData("allReduce", -1, args.comp_size, args.comm_size, -1,
786                                                                                 Datatype::encode(args.datatype1), ""));
787
788     Colls::allreduce(send_buffer(args.comm_size * args.datatype1->size()),
789         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
790     smpi_execute_flops(args.comp_size);
791
792     TRACE_smpi_comm_out(my_proc_id);
793   }
794 };
795
796 class AllToAllAction : public ReplayAction<AllToAllArgParser> {
797 public:
798   AllToAllAction() : ReplayAction("allToAll") {}
799   void kernel(simgrid::xbt::ReplayAction& action) override
800   {
801     TRACE_smpi_comm_in(my_proc_id, "action_allToAll",
802                      new simgrid::instr::CollTIData("allToAll", -1, -1.0, args.send_size, args.recv_size,
803                                                     Datatype::encode(args.datatype1),
804                                                     Datatype::encode(args.datatype2)));
805
806     Colls::alltoall(send_buffer(args.send_size * args.comm_size * args.datatype1->size()), args.send_size,
807                     args.datatype1, recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()),
808                     args.recv_size, args.datatype2, MPI_COMM_WORLD);
809
810     TRACE_smpi_comm_out(my_proc_id);
811   }
812 };
813
814 class GatherAction : public ReplayAction<GatherArgParser> {
815 public:
816   explicit GatherAction(std::string name) : ReplayAction(name) {}
817   void kernel(simgrid::xbt::ReplayAction& action) override
818   {
819     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::CollTIData(name, (name == "gather") ? args.root : -1, -1.0, args.send_size, args.recv_size,
820                                                                           Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
821
822     if (name == "gather") {
823       int rank = MPI_COMM_WORLD->rank();
824       Colls::gather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
825                  (rank == args.root) ? recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
826     }
827     else
828       Colls::allgather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
829                        recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, MPI_COMM_WORLD);
830
831     TRACE_smpi_comm_out(my_proc_id);
832   }
833 };
834
835 class GatherVAction : public ReplayAction<GatherVArgParser> {
836 public:
837   explicit GatherVAction(std::string name) : ReplayAction(name) {}
838   void kernel(simgrid::xbt::ReplayAction& action) override
839   {
840     int rank = MPI_COMM_WORLD->rank();
841
842     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::VarCollTIData(
843                                                name, (name == "gatherV") ? args.root : -1, args.send_size, nullptr, -1, args.recvcounts,
844                                                Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
845
846     if (name == "gatherV") {
847       Colls::gatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
848                      (rank == args.root) ? recv_buffer(args.recv_size_sum * args.datatype2->size()) : nullptr,
849                      args.recvcounts->data(), args.disps.data(), args.datatype2, args.root, MPI_COMM_WORLD);
850     }
851     else {
852       Colls::allgatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
853                         recv_buffer(args.recv_size_sum * args.datatype2->size()), args.recvcounts->data(),
854                         args.disps.data(), args.datatype2, MPI_COMM_WORLD);
855     }
856
857     TRACE_smpi_comm_out(my_proc_id);
858   }
859 };
860
861 class ScatterAction : public ReplayAction<ScatterArgParser> {
862 public:
863   ScatterAction() : ReplayAction("scatter") {}
864   void kernel(simgrid::xbt::ReplayAction& action) override
865   {
866     int rank = MPI_COMM_WORLD->rank();
867     TRACE_smpi_comm_in(my_proc_id, "action_scatter", new simgrid::instr::CollTIData(name, args.root, -1.0, args.send_size, args.recv_size,
868                                                                           Datatype::encode(args.datatype1),
869                                                                           Datatype::encode(args.datatype2)));
870
871     Colls::scatter(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
872                   (rank == args.root) ? recv_buffer(args.recv_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
873
874     TRACE_smpi_comm_out(my_proc_id);
875   }
876 };
877
878
879 class ScatterVAction : public ReplayAction<ScatterVArgParser> {
880 public:
881   ScatterVAction() : ReplayAction("scatterV") {}
882   void kernel(simgrid::xbt::ReplayAction& action) override
883   {
884     int rank = MPI_COMM_WORLD->rank();
885     TRACE_smpi_comm_in(my_proc_id, "action_scatterv", new simgrid::instr::VarCollTIData(name, args.root, -1, args.sendcounts, args.recv_size,
886           nullptr, Datatype::encode(args.datatype1),
887           Datatype::encode(args.datatype2)));
888
889     Colls::scatterv((rank == args.root) ? send_buffer(args.send_size_sum * args.datatype1->size()) : nullptr,
890                     args.sendcounts->data(), args.disps.data(), args.datatype1,
891                     recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, args.root,
892                     MPI_COMM_WORLD);
893
894     TRACE_smpi_comm_out(my_proc_id);
895   }
896 };
897
898 class ReduceScatterAction : public ReplayAction<ReduceScatterArgParser> {
899 public:
900   ReduceScatterAction() : ReplayAction("reduceScatter") {}
901   void kernel(simgrid::xbt::ReplayAction& action) override
902   {
903     TRACE_smpi_comm_in(my_proc_id, "action_reducescatter",
904                        new simgrid::instr::VarCollTIData("reduceScatter", -1, 0, nullptr, -1, args.recvcounts,
905                                                          std::to_string(args.comp_size), /* ugly hack to print comp_size */
906                                                          Datatype::encode(args.datatype1)));
907
908     Colls::reduce_scatter(send_buffer(args.recv_size_sum * args.datatype1->size()),
909                           recv_buffer(args.recv_size_sum * args.datatype1->size()), args.recvcounts->data(),
910                           args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
911
912     smpi_execute_flops(args.comp_size);
913     TRACE_smpi_comm_out(my_proc_id);
914   }
915 };
916
917 class AllToAllVAction : public ReplayAction<AllToAllVArgParser> {
918 public:
919   AllToAllVAction() : ReplayAction("allToAllV") {}
920   void kernel(simgrid::xbt::ReplayAction& action) override
921   {
922     TRACE_smpi_comm_in(my_proc_id, __func__,
923                        new simgrid::instr::VarCollTIData(
924                            "allToAllV", -1, args.send_size_sum, args.sendcounts, args.recv_size_sum, args.recvcounts,
925                            Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
926
927     Colls::alltoallv(send_buffer(args.send_buf_size * args.datatype1->size()), args.sendcounts->data(), args.senddisps.data(), args.datatype1,
928                      recv_buffer(args.recv_buf_size * args.datatype2->size()), args.recvcounts->data(), args.recvdisps.data(), args.datatype2, MPI_COMM_WORLD);
929
930     TRACE_smpi_comm_out(my_proc_id);
931   }
932 };
933 } // Replay Namespace
934 }} // namespace simgrid::smpi
935
936 std::vector<simgrid::smpi::replay::RequestStorage> storage;
937 /** @brief Only initialize the replay, don't do it for real */
938 void smpi_replay_init(int* argc, char*** argv)
939 {
940   simgrid::smpi::Process::init(argc, argv);
941   smpi_process()->mark_as_initialized();
942   smpi_process()->set_replaying(true);
943
944   int my_proc_id = simgrid::s4u::this_actor::get_pid();
945   TRACE_smpi_init(my_proc_id);
946   TRACE_smpi_computing_init(my_proc_id);
947   TRACE_smpi_comm_in(my_proc_id, "smpi_replay_run_init", new simgrid::instr::NoOpTIData("init"));
948   TRACE_smpi_comm_out(my_proc_id);
949   xbt_replay_action_register("init", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::InitAction().execute(action); });
950   xbt_replay_action_register("finalize", [](simgrid::xbt::ReplayAction& action) { /* nothing to do */ });
951   xbt_replay_action_register("comm_size", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
952   xbt_replay_action_register("comm_split",[](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
953   xbt_replay_action_register("comm_dup",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
954
955   xbt_replay_action_register("send",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("send").execute(action); });
956   xbt_replay_action_register("Isend", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("Isend").execute(action); });
957   xbt_replay_action_register("recv",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("recv").execute(action); });
958   xbt_replay_action_register("Irecv", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("Irecv").execute(action); });
959   xbt_replay_action_register("test",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::TestAction().execute(action); });
960   xbt_replay_action_register("wait",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAction().execute(action); });
961   xbt_replay_action_register("waitAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAllAction().execute(action); });
962   xbt_replay_action_register("barrier", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BarrierAction().execute(action); });
963   xbt_replay_action_register("bcast",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BcastAction().execute(action); });
964   xbt_replay_action_register("reduce",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceAction().execute(action); });
965   xbt_replay_action_register("allReduce", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllReduceAction().execute(action); });
966   xbt_replay_action_register("allToAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllAction().execute(action); });
967   xbt_replay_action_register("allToAllV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllVAction().execute(action); });
968   xbt_replay_action_register("gather",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("gather").execute(action); });
969   xbt_replay_action_register("scatter",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterAction().execute(action); });
970   xbt_replay_action_register("gatherV",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("gatherV").execute(action); });
971   xbt_replay_action_register("scatterV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterVAction().execute(action); });
972   xbt_replay_action_register("allGather", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("allGather").execute(action); });
973   xbt_replay_action_register("allGatherV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("allGatherV").execute(action); });
974   xbt_replay_action_register("reduceScatter", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceScatterAction().execute(action); });
975   xbt_replay_action_register("compute", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ComputeAction().execute(action); });
976
977   //if we have a delayed start, sleep here.
978   if(*argc>2){
979     double value = xbt_str_parse_double((*argv)[2], "%s is not a double");
980     XBT_VERB("Delayed start for instance - Sleeping for %f flops ",value );
981     smpi_execute_flops(value);
982   } else {
983     //UGLY: force a context switch to be sure that all MSG_processes begin initialization
984     XBT_DEBUG("Force context switch by smpi_execute_flops  - Sleeping for 0.0 flops ");
985     smpi_execute_flops(0.0);
986   }
987 }
988
989 /** @brief actually run the replay after initialization */
990 void smpi_replay_main(int* argc, char*** argv)
991 {
992   static int active_processes = 0;
993   active_processes++;
994   simgrid::xbt::replay_runner(*argc, *argv);
995
996   /* and now, finalize everything */
997   /* One active process will stop. Decrease the counter*/
998   XBT_DEBUG("There are %zu elements in reqq[*]", get_reqq_self()->size());
999   if (not get_reqq_self()->empty()) {
1000     unsigned int count_requests=get_reqq_self()->size();
1001     MPI_Request requests[count_requests];
1002     MPI_Status status[count_requests];
1003     unsigned int i=0;
1004
1005     for (auto const& req : *get_reqq_self()) {
1006       requests[i] = req;
1007       i++;
1008     }
1009     simgrid::smpi::Request::waitall(count_requests, requests, status);
1010   }
1011   delete get_reqq_self();
1012   active_processes--;
1013
1014   if(active_processes==0){
1015     /* Last process alive speaking: end the simulated timer */
1016     XBT_INFO("Simulation time %f", smpi_process()->simulated_elapsed());
1017     smpi_free_replay_tmp_buffers();
1018   }
1019
1020   TRACE_smpi_comm_in(simgrid::s4u::this_actor::get_pid(), "smpi_replay_run_finalize",
1021                      new simgrid::instr::NoOpTIData("finalize"));
1022
1023   smpi_process()->finalize();
1024
1025   TRACE_smpi_comm_out(simgrid::s4u::this_actor::get_pid());
1026   TRACE_smpi_finalize(simgrid::s4u::this_actor::get_pid());
1027 }
1028
1029 /** @brief chain a replay initialization and a replay start */
1030 void smpi_replay_run(int* argc, char*** argv)
1031 {
1032   smpi_replay_init(argc, argv);
1033   smpi_replay_main(argc, argv);
1034 }