Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
[SMPI] Replay: Introduce and use WaitTestParser
[simgrid.git] / src / smpi / internals / smpi_replay.cpp
1 /* Copyright (c) 2009-2018. The SimGrid Team. All rights reserved.          */
2
3 /* This program is free software; you can redistribute it and/or modify it
4  * under the terms of the license (GNU LGPL) which comes with this package. */
5
6 #include "private.hpp"
7 #include "smpi_coll.hpp"
8 #include "smpi_comm.hpp"
9 #include "smpi_datatype.hpp"
10 #include "smpi_group.hpp"
11 #include "smpi_process.hpp"
12 #include "smpi_request.hpp"
13 #include "xbt/replay.hpp"
14
15 #include <boost/algorithm/string/join.hpp>
16 #include <memory>
17 #include <numeric>
18 #include <unordered_map>
19 #include <sstream>
20 #include <vector>
21
22 using simgrid::s4u::Actor;
23
24 #include <tuple>
25 // From https://stackoverflow.com/questions/7110301/generic-hash-for-tuples-in-unordered-map-unordered-set
26 // This is all just to make std::unordered_map work with std::tuple. If we need this in other places,
27 // this could go into a header file.
28 namespace hash_tuple{
29     template <typename TT>
30     struct hash
31     {
32         size_t
33         operator()(TT const& tt) const
34         {
35             return std::hash<TT>()(tt);
36         }
37     };
38
39     template <class T>
40     inline void hash_combine(std::size_t& seed, T const& v)
41     {
42         seed ^= hash_tuple::hash<T>()(v) + 0x9e3779b9 + (seed<<6) + (seed>>2);
43     }
44
45     // Recursive template code derived from Matthieu M.
46     template <class Tuple, size_t Index = std::tuple_size<Tuple>::value - 1>
47     struct HashValueImpl
48     {
49       static void apply(size_t& seed, Tuple const& tuple)
50       {
51         HashValueImpl<Tuple, Index-1>::apply(seed, tuple);
52         hash_combine(seed, std::get<Index>(tuple));
53       }
54     };
55
56     template <class Tuple>
57     struct HashValueImpl<Tuple,0>
58     {
59       static void apply(size_t& seed, Tuple const& tuple)
60       {
61         hash_combine(seed, std::get<0>(tuple));
62       }
63     };
64
65     template <typename ... TT>
66     struct hash<std::tuple<TT...>>
67     {
68         size_t
69         operator()(std::tuple<TT...> const& tt) const
70         {
71             size_t seed = 0;
72             HashValueImpl<std::tuple<TT...> >::apply(seed, tt);
73             return seed;
74         }
75     };
76 }
77
78 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_replay,smpi,"Trace Replay with SMPI");
79
80 static std::unordered_map<int, std::vector<MPI_Request>*> reqq;
81
82 static MPI_Datatype MPI_DEFAULT_TYPE;
83
84 #define CHECK_ACTION_PARAMS(action, mandatory, optional)                                                               \
85   {                                                                                                                    \
86     if (action.size() < static_cast<unsigned long>(mandatory + 2)) {                                                   \
87       std::stringstream ss;                                                                                            \
88       for (const auto& elem : action) {                                                                                \
89         ss << elem << " ";                                                                                             \
90       }                                                                                                                \
91       THROWF(arg_error, 0, "%s replay failed.\n"                                                                       \
92                            "%zu items were given on the line. First two should be process_id and action.  "            \
93                            "This action needs after them %lu mandatory arguments, and accepts %lu optional ones. \n"   \
94                            "The full line that was given is:\n   %s\n"                                                 \
95                            "Please contact the Simgrid team if support is needed",                                     \
96              __func__, action.size(), static_cast<unsigned long>(mandatory), static_cast<unsigned long>(optional),     \
97              ss.str().c_str());                                                                                        \
98     }                                                                                                                  \
99   }
100
101 static void log_timed_action(simgrid::xbt::ReplayAction& action, double clock)
102 {
103   if (XBT_LOG_ISENABLED(smpi_replay, xbt_log_priority_verbose)){
104     std::string s = boost::algorithm::join(action, " ");
105     XBT_VERB("%s %f", s.c_str(), smpi_process()->simulated_elapsed() - clock);
106   }
107 }
108
109 static std::vector<MPI_Request>* get_reqq_self()
110 {
111   return reqq.at(simgrid::s4u::this_actor::get_pid());
112 }
113
114 static void set_reqq_self(std::vector<MPI_Request> *mpi_request)
115 {
116   reqq.insert({simgrid::s4u::this_actor::get_pid(), mpi_request});
117 }
118
119 /* Helper function */
120 static double parse_double(std::string string)
121 {
122   return xbt_str_parse_double(string.c_str(), "%s is not a double");
123 }
124
125 namespace simgrid {
126 namespace smpi {
127
128 namespace replay {
129 class ActionArgParser {
130 public:
131   virtual ~ActionArgParser() = default;
132   virtual void parse(simgrid::xbt::ReplayAction& action, std::string name) { CHECK_ACTION_PARAMS(action, 0, 0) }
133 };
134
135 class WaitTestParser : public ActionArgParser {
136 public:
137   int src;
138   int dst;
139   int tag;
140
141   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
142   {
143     CHECK_ACTION_PARAMS(action, 3, 0)
144     src = std::stoi(action[2]);
145     dst = std::stoi(action[3]);
146     tag = std::stoi(action[4]);
147   }
148 };
149
150 class SendRecvParser : public ActionArgParser {
151 public:
152   /* communication partner; if we send, this is the receiver and vice versa */
153   int partner;
154   double size;
155   int tag;
156   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
157
158   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
159   {
160     CHECK_ACTION_PARAMS(action, 3, 1)
161     partner = std::stoi(action[2]);
162     tag     = std::stoi(action[3]);
163     size    = parse_double(action[4]);
164     if (action.size() > 5)
165       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
166   }
167 };
168
169 class ComputeParser : public ActionArgParser {
170 public:
171   /* communication partner; if we send, this is the receiver and vice versa */
172   double flops;
173
174   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
175   {
176     CHECK_ACTION_PARAMS(action, 1, 0)
177     flops = parse_double(action[2]);
178   }
179 };
180
181 class CollCommParser : public ActionArgParser {
182 public:
183   double size;
184   double comm_size;
185   double comp_size;
186   int send_size;
187   int recv_size;
188   int root = 0;
189   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
190   MPI_Datatype datatype2 = MPI_DEFAULT_TYPE;
191 };
192
193 class BcastArgParser : public CollCommParser {
194 public:
195   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
196   {
197     CHECK_ACTION_PARAMS(action, 1, 2)
198     size = parse_double(action[2]);
199     root = (action.size() > 3) ? std::stoi(action[3]) : 0;
200     if (action.size() > 4)
201       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
202   }
203 };
204
205 class ReduceArgParser : public CollCommParser {
206 public:
207   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
208   {
209     CHECK_ACTION_PARAMS(action, 2, 2)
210     comm_size = parse_double(action[2]);
211     comp_size = parse_double(action[3]);
212     root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
213     if (action.size() > 5)
214       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
215   }
216 };
217
218 class AllReduceArgParser : public CollCommParser {
219 public:
220   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
221   {
222     CHECK_ACTION_PARAMS(action, 2, 1)
223     comm_size = parse_double(action[2]);
224     comp_size = parse_double(action[3]);
225     if (action.size() > 4)
226       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
227   }
228 };
229
230 class AllToAllArgParser : public CollCommParser {
231 public:
232   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
233   {
234     CHECK_ACTION_PARAMS(action, 2, 1)
235     comm_size = MPI_COMM_WORLD->size();
236     send_size = parse_double(action[2]);
237     recv_size = parse_double(action[3]);
238
239     if (action.size() > 4)
240       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
241     if (action.size() > 5)
242       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
243   }
244 };
245
246 class GatherArgParser : public CollCommParser {
247 public:
248   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
249   {
250     /* The structure of the gather action for the rank 0 (total 4 processes) is the following:
251           0 gather 68 68 0 0 0
252         where:
253           1) 68 is the sendcounts
254           2) 68 is the recvcounts
255           3) 0 is the root node
256           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
257           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
258     */
259     CHECK_ACTION_PARAMS(action, 2, 3)
260     comm_size = MPI_COMM_WORLD->size();
261     send_size = parse_double(action[2]);
262     recv_size = parse_double(action[3]);
263
264     if (name == "gather") {
265       root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
266       if (action.size() > 5)
267         datatype1 = simgrid::smpi::Datatype::decode(action[5]);
268       if (action.size() > 6)
269         datatype2 = simgrid::smpi::Datatype::decode(action[6]);
270     }
271     else {
272       if (action.size() > 4)
273         datatype1 = simgrid::smpi::Datatype::decode(action[4]);
274       if (action.size() > 5)
275         datatype2 = simgrid::smpi::Datatype::decode(action[5]);
276     }
277   }
278 };
279
280 class GatherVArgParser : public CollCommParser {
281 public:
282   int recv_size_sum;
283   std::shared_ptr<std::vector<int>> recvcounts;
284   std::vector<int> disps;
285   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
286   {
287     /* The structure of the gatherv action for the rank 0 (total 4 processes) is the following:
288          0 gather 68 68 10 10 10 0 0 0
289        where:
290          1) 68 is the sendcount
291          2) 68 10 10 10 is the recvcounts
292          3) 0 is the root node
293          4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
294          5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
295     */
296     comm_size = MPI_COMM_WORLD->size();
297     CHECK_ACTION_PARAMS(action, comm_size+1, 2)
298     send_size = parse_double(action[2]);
299     disps     = std::vector<int>(comm_size, 0);
300     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
301
302     if (name == "gatherV") {
303       root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
304       if (action.size() > 4 + comm_size)
305         datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
306       if (action.size() > 5 + comm_size)
307         datatype2 = simgrid::smpi::Datatype::decode(action[5 + comm_size]);
308     }
309     else {
310       int datatype_index = 0;
311       int disp_index     = 0;
312       /* The 3 comes from "0 gather <sendcount>", which must always be present.
313        * The + comm_size is the recvcounts array, which must also be present
314        */
315       if (action.size() > 3 + comm_size + comm_size) { /* datatype + disp are specified */
316         datatype_index = 3 + comm_size;
317         disp_index     = datatype_index + 1;
318         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
319         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
320       } else if (action.size() > 3 + comm_size + 2) { /* disps specified; datatype is not specified; use the default one */
321         disp_index     = 3 + comm_size;
322       } else if (action.size() > 3 + comm_size)  { /* only datatype, no disp specified */
323         datatype_index = 3 + comm_size;
324         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
325         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
326       }
327
328       if (disp_index != 0) {
329         for (unsigned int i = 0; i < comm_size; i++)
330           disps[i]          = std::stoi(action[disp_index + i]);
331       }
332     }
333
334     for (unsigned int i = 0; i < comm_size; i++) {
335       (*recvcounts)[i] = std::stoi(action[i + 3]);
336     }
337     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
338   }
339 };
340
341 class ScatterArgParser : public CollCommParser {
342 public:
343   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
344   {
345     /* The structure of the scatter action for the rank 0 (total 4 processes) is the following:
346           0 gather 68 68 0 0 0
347         where:
348           1) 68 is the sendcounts
349           2) 68 is the recvcounts
350           3) 0 is the root node
351           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
352           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
353     */
354     CHECK_ACTION_PARAMS(action, 2, 3)
355     comm_size   = MPI_COMM_WORLD->size();
356     send_size   = parse_double(action[2]);
357     recv_size   = parse_double(action[3]);
358     root   = (action.size() > 4) ? std::stoi(action[4]) : 0;
359     if (action.size() > 5)
360       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
361     if (action.size() > 6)
362       datatype2 = simgrid::smpi::Datatype::decode(action[6]);
363   }
364 };
365
366 class ScatterVArgParser : public CollCommParser {
367 public:
368   int recv_size_sum;
369   int send_size_sum;
370   std::shared_ptr<std::vector<int>> sendcounts;
371   std::vector<int> disps;
372   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
373   {
374     /* The structure of the scatterv action for the rank 0 (total 4 processes) is the following:
375        0 gather 68 10 10 10 68 0 0 0
376         where:
377         1) 68 10 10 10 is the sendcounts
378         2) 68 is the recvcount
379         3) 0 is the root node
380         4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
381         5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
382     */
383     CHECK_ACTION_PARAMS(action, comm_size + 1, 2)
384     recv_size  = parse_double(action[2 + comm_size]);
385     disps      = std::vector<int>(comm_size, 0);
386     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
387
388     if (action.size() > 5 + comm_size)
389       datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
390     if (action.size() > 5 + comm_size)
391       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
392
393     for (unsigned int i = 0; i < comm_size; i++) {
394       (*sendcounts)[i] = std::stoi(action[i + 2]);
395     }
396     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
397     root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
398   }
399 };
400
401 class ReduceScatterArgParser : public CollCommParser {
402 public:
403   int recv_size_sum;
404   std::shared_ptr<std::vector<int>> recvcounts;
405   std::vector<int> disps;
406   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
407   {
408     /* The structure of the reducescatter action for the rank 0 (total 4 processes) is the following:
409          0 reduceScatter 275427 275427 275427 204020 11346849 0
410        where:
411          1) The first four values after the name of the action declare the recvcounts array
412          2) The value 11346849 is the amount of instructions
413          3) The last value corresponds to the datatype, see simgrid::smpi::Datatype::decode().
414     */
415     comm_size = MPI_COMM_WORLD->size();
416     CHECK_ACTION_PARAMS(action, comm_size+1, 1)
417     comp_size = parse_double(action[2+comm_size]);
418     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
419     if (action.size() > 3 + comm_size)
420       datatype1 = simgrid::smpi::Datatype::decode(action[3 + comm_size]);
421
422     for (unsigned int i = 0; i < comm_size; i++) {
423       recvcounts->push_back(std::stoi(action[i + 2]));
424     }
425     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
426   }
427 };
428
429 class AllToAllVArgParser : public CollCommParser {
430 public:
431   int recv_size_sum;
432   int send_size_sum;
433   std::shared_ptr<std::vector<int>> recvcounts;
434   std::shared_ptr<std::vector<int>> sendcounts;
435   std::vector<int> senddisps;
436   std::vector<int> recvdisps;
437   int send_buf_size;
438   int recv_buf_size;
439   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
440   {
441     /* The structure of the allToAllV action for the rank 0 (total 4 processes) is the following:
442           0 allToAllV 100 1 7 10 12 100 1 70 10 5
443        where:
444         1) 100 is the size of the send buffer *sizeof(int),
445         2) 1 7 10 12 is the sendcounts array
446         3) 100*sizeof(int) is the size of the receiver buffer
447         4)  1 70 10 5 is the recvcounts array
448     */
449     comm_size = MPI_COMM_WORLD->size();
450     CHECK_ACTION_PARAMS(action, 2*comm_size+2, 2)
451     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
452     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
453     senddisps  = std::vector<int>(comm_size, 0);
454     recvdisps  = std::vector<int>(comm_size, 0);
455
456     if (action.size() > 5 + 2 * comm_size)
457       datatype1 = simgrid::smpi::Datatype::decode(action[4 + 2 * comm_size]);
458     if (action.size() > 5 + 2 * comm_size)
459       datatype2 = simgrid::smpi::Datatype::decode(action[5 + 2 * comm_size]);
460
461     send_buf_size=parse_double(action[2]);
462     recv_buf_size=parse_double(action[3+comm_size]);
463     for (unsigned int i = 0; i < comm_size; i++) {
464       (*sendcounts)[i] = std::stoi(action[3 + i]);
465       (*recvcounts)[i] = std::stoi(action[4 + comm_size + i]);
466     }
467     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
468     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
469   }
470 };
471
472 template <class T> class ReplayAction {
473 protected:
474   const std::string name;
475   const int my_proc_id;
476   T args;
477
478 public:
479   explicit ReplayAction(std::string name) : name(name), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
480   virtual ~ReplayAction() = default;
481
482   virtual void execute(simgrid::xbt::ReplayAction& action)
483   {
484     // Needs to be re-initialized for every action, hence here
485     double start_time = smpi_process()->simulated_elapsed();
486     args.parse(action, name);
487     kernel(action);
488     if (name != "Init")
489       log_timed_action(action, start_time);
490   }
491
492   virtual void kernel(simgrid::xbt::ReplayAction& action) = 0;
493
494   void* send_buffer(int size)
495   {
496     return smpi_get_tmp_sendbuffer(size);
497   }
498
499   void* recv_buffer(int size)
500   {
501     return smpi_get_tmp_recvbuffer(size);
502   }
503 };
504
505 class WaitAction : public ReplayAction<WaitTestParser> {
506 public:
507   WaitAction() : ReplayAction("Wait") {}
508   void kernel(simgrid::xbt::ReplayAction& action) override
509   {
510     std::string s = boost::algorithm::join(action, " ");
511     xbt_assert(get_reqq_self()->size(), "action wait not preceded by any irecv or isend: %s", s.c_str());
512     MPI_Request request = get_reqq_self()->back();
513     get_reqq_self()->pop_back();
514
515     if (request == MPI_REQUEST_NULL) {
516       /* Assume that the trace is well formed, meaning the comm might have been caught by a MPI_test. Then just
517        * return.*/
518       return;
519     }
520
521     int rank = request->comm() != MPI_COMM_NULL ? request->comm()->rank() : -1;
522
523     // Must be taken before Request::wait() since the request may be set to
524     // MPI_REQUEST_NULL by Request::wait!
525     bool is_wait_for_receive = (request->flags() & RECV);
526     // TODO: Here we take the rank while we normally take the process id (look for my_proc_id)
527     TRACE_smpi_comm_in(rank, __func__, new simgrid::instr::NoOpTIData("wait"));
528
529     MPI_Status status;
530     Request::wait(&request, &status);
531
532     TRACE_smpi_comm_out(rank);
533     if (is_wait_for_receive)
534       TRACE_smpi_recv(args.src, args.dst, args.tag);
535   }
536 };
537
538 class SendAction : public ReplayAction<SendRecvParser> {
539 public:
540   SendAction() = delete;
541   explicit SendAction(std::string name) : ReplayAction(name) {}
542   void kernel(simgrid::xbt::ReplayAction& action) override
543   {
544     int dst_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
545
546     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
547                                                                              args.tag, Datatype::encode(args.datatype1)));
548     if (not TRACE_smpi_view_internals())
549       TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, args.tag, args.size * args.datatype1->size());
550
551     if (name == "send") {
552       Request::send(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
553     } else if (name == "Isend") {
554       MPI_Request request = Request::isend(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
555       get_reqq_self()->push_back(request);
556     } else {
557       xbt_die("Don't know this action, %s", name.c_str());
558     }
559
560     TRACE_smpi_comm_out(my_proc_id);
561   }
562 };
563
564 class RecvAction : public ReplayAction<SendRecvParser> {
565 public:
566   RecvAction() = delete;
567   explicit RecvAction(std::string name) : ReplayAction(name) {}
568   void kernel(simgrid::xbt::ReplayAction& action) override
569   {
570     int src_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
571
572     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
573                                                                              args.tag, Datatype::encode(args.datatype1)));
574
575     MPI_Status status;
576     // unknown size from the receiver point of view
577     if (args.size <= 0.0) {
578       Request::probe(args.partner, args.tag, MPI_COMM_WORLD, &status);
579       args.size = status.count;
580     }
581
582     if (name == "recv") {
583       Request::recv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD, &status);
584     } else if (name == "Irecv") {
585       MPI_Request request = Request::irecv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
586       get_reqq_self()->push_back(request);
587     }
588
589     TRACE_smpi_comm_out(my_proc_id);
590     // TODO: Check why this was only activated in the "recv" case and not in the "Irecv" case
591     if (name == "recv" && not TRACE_smpi_view_internals()) {
592       TRACE_smpi_recv(src_traced, my_proc_id, args.tag);
593     }
594   }
595 };
596
597 class ComputeAction : public ReplayAction<ComputeParser> {
598 public:
599   ComputeAction() : ReplayAction("compute") {}
600   void kernel(simgrid::xbt::ReplayAction& action) override
601   {
602     TRACE_smpi_computing_in(my_proc_id, args.flops);
603     smpi_execute_flops(args.flops);
604     TRACE_smpi_computing_out(my_proc_id);
605   }
606 };
607
608 class TestAction : public ReplayAction<WaitTestParser> {
609 public:
610   TestAction() : ReplayAction("Test") {}
611   void kernel(simgrid::xbt::ReplayAction& action) override
612   {
613     MPI_Request request = get_reqq_self()->back();
614     get_reqq_self()->pop_back();
615     // if request is null here, this may mean that a previous test has succeeded
616     // Different times in traced application and replayed version may lead to this
617     // In this case, ignore the extra calls.
618     if (request != nullptr) {
619       TRACE_smpi_testing_in(my_proc_id);
620
621       MPI_Status status;
622       int flag = Request::test(&request, &status);
623
624       XBT_DEBUG("MPI_Test result: %d", flag);
625       /* push back request in vector to be caught by a subsequent wait. if the test did succeed, the request is now
626        * nullptr.*/
627       get_reqq_self()->push_back(request);
628
629       TRACE_smpi_testing_out(my_proc_id);
630     }
631   }
632 };
633
634 class InitAction : public ReplayAction<ActionArgParser> {
635 public:
636   InitAction() : ReplayAction("Init") {}
637   void kernel(simgrid::xbt::ReplayAction& action) override
638   {
639     CHECK_ACTION_PARAMS(action, 0, 1)
640     MPI_DEFAULT_TYPE = (action.size() > 2) ? MPI_DOUBLE // default MPE datatype
641                                            : MPI_BYTE;  // default TAU datatype
642
643     /* start a simulated timer */
644     smpi_process()->simulated_start();
645     set_reqq_self(new std::vector<MPI_Request>);
646   }
647 };
648
649 class CommunicatorAction : public ReplayAction<ActionArgParser> {
650 public:
651   CommunicatorAction() : ReplayAction("Comm") {}
652   void kernel(simgrid::xbt::ReplayAction& action) override { /* nothing to do */}
653 };
654
655 class WaitAllAction : public ReplayAction<ActionArgParser> {
656 public:
657   WaitAllAction() : ReplayAction("waitAll") {}
658   void kernel(simgrid::xbt::ReplayAction& action) override
659   {
660     const unsigned int count_requests = get_reqq_self()->size();
661
662     if (count_requests > 0) {
663       TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData("waitAll", -1, count_requests, ""));
664       std::vector<std::pair</*sender*/int,/*recv*/int>> sender_receiver;
665       for (const auto& req : (*get_reqq_self())) {
666         if (req && (req->flags() & RECV)) {
667           sender_receiver.push_back({req->src(), req->dst()});
668         }
669       }
670       MPI_Status status[count_requests];
671       Request::waitall(count_requests, &(*get_reqq_self())[0], status);
672
673       for (auto& pair : sender_receiver) {
674         TRACE_smpi_recv(pair.first, pair.second, 0);
675       }
676       TRACE_smpi_comm_out(my_proc_id);
677     }
678   }
679 };
680
681 class BarrierAction : public ReplayAction<ActionArgParser> {
682 public:
683   BarrierAction() : ReplayAction("barrier") {}
684   void kernel(simgrid::xbt::ReplayAction& action) override
685   {
686     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("barrier"));
687     Colls::barrier(MPI_COMM_WORLD);
688     TRACE_smpi_comm_out(my_proc_id);
689   }
690 };
691
692 class BcastAction : public ReplayAction<BcastArgParser> {
693 public:
694   BcastAction() : ReplayAction("bcast") {}
695   void kernel(simgrid::xbt::ReplayAction& action) override
696   {
697     TRACE_smpi_comm_in(my_proc_id, "action_bcast",
698                        new simgrid::instr::CollTIData("bcast", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
699                                                       -1.0, args.size, -1, Datatype::encode(args.datatype1), ""));
700
701     Colls::bcast(send_buffer(args.size * args.datatype1->size()), args.size, args.datatype1, args.root, MPI_COMM_WORLD);
702
703     TRACE_smpi_comm_out(my_proc_id);
704   }
705 };
706
707 class ReduceAction : public ReplayAction<ReduceArgParser> {
708 public:
709   ReduceAction() : ReplayAction("reduce") {}
710   void kernel(simgrid::xbt::ReplayAction& action) override
711   {
712     TRACE_smpi_comm_in(my_proc_id, "action_reduce",
713                        new simgrid::instr::CollTIData("reduce", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
714                                                       args.comp_size, args.comm_size, -1,
715                                                       Datatype::encode(args.datatype1), ""));
716
717     Colls::reduce(send_buffer(args.comm_size * args.datatype1->size()),
718         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, args.root, MPI_COMM_WORLD);
719     smpi_execute_flops(args.comp_size);
720
721     TRACE_smpi_comm_out(my_proc_id);
722   }
723 };
724
725 class AllReduceAction : public ReplayAction<AllReduceArgParser> {
726 public:
727   AllReduceAction() : ReplayAction("allReduce") {}
728   void kernel(simgrid::xbt::ReplayAction& action) override
729   {
730     TRACE_smpi_comm_in(my_proc_id, "action_allReduce", new simgrid::instr::CollTIData("allReduce", -1, args.comp_size, args.comm_size, -1,
731                                                                                 Datatype::encode(args.datatype1), ""));
732
733     Colls::allreduce(send_buffer(args.comm_size * args.datatype1->size()),
734         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
735     smpi_execute_flops(args.comp_size);
736
737     TRACE_smpi_comm_out(my_proc_id);
738   }
739 };
740
741 class AllToAllAction : public ReplayAction<AllToAllArgParser> {
742 public:
743   AllToAllAction() : ReplayAction("allToAll") {}
744   void kernel(simgrid::xbt::ReplayAction& action) override
745   {
746     TRACE_smpi_comm_in(my_proc_id, "action_allToAll",
747                      new simgrid::instr::CollTIData("allToAll", -1, -1.0, args.send_size, args.recv_size,
748                                                     Datatype::encode(args.datatype1),
749                                                     Datatype::encode(args.datatype2)));
750
751     Colls::alltoall(send_buffer(args.send_size * args.comm_size * args.datatype1->size()), args.send_size,
752                     args.datatype1, recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()),
753                     args.recv_size, args.datatype2, MPI_COMM_WORLD);
754
755     TRACE_smpi_comm_out(my_proc_id);
756   }
757 };
758
759 class GatherAction : public ReplayAction<GatherArgParser> {
760 public:
761   explicit GatherAction(std::string name) : ReplayAction(name) {}
762   void kernel(simgrid::xbt::ReplayAction& action) override
763   {
764     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::CollTIData(name, (name == "gather") ? args.root : -1, -1.0, args.send_size, args.recv_size,
765                                                                           Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
766
767     if (name == "gather") {
768       int rank = MPI_COMM_WORLD->rank();
769       Colls::gather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
770                  (rank == args.root) ? recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
771     }
772     else
773       Colls::allgather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
774                        recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, MPI_COMM_WORLD);
775
776     TRACE_smpi_comm_out(my_proc_id);
777   }
778 };
779
780 class GatherVAction : public ReplayAction<GatherVArgParser> {
781 public:
782   explicit GatherVAction(std::string name) : ReplayAction(name) {}
783   void kernel(simgrid::xbt::ReplayAction& action) override
784   {
785     int rank = MPI_COMM_WORLD->rank();
786
787     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::VarCollTIData(
788                                                name, (name == "gatherV") ? args.root : -1, args.send_size, nullptr, -1, args.recvcounts,
789                                                Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
790
791     if (name == "gatherV") {
792       Colls::gatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
793                      (rank == args.root) ? recv_buffer(args.recv_size_sum * args.datatype2->size()) : nullptr,
794                      args.recvcounts->data(), args.disps.data(), args.datatype2, args.root, MPI_COMM_WORLD);
795     }
796     else {
797       Colls::allgatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
798                         recv_buffer(args.recv_size_sum * args.datatype2->size()), args.recvcounts->data(),
799                         args.disps.data(), args.datatype2, MPI_COMM_WORLD);
800     }
801
802     TRACE_smpi_comm_out(my_proc_id);
803   }
804 };
805
806 class ScatterAction : public ReplayAction<ScatterArgParser> {
807 public:
808   ScatterAction() : ReplayAction("scatter") {}
809   void kernel(simgrid::xbt::ReplayAction& action) override
810   {
811     int rank = MPI_COMM_WORLD->rank();
812     TRACE_smpi_comm_in(my_proc_id, "action_scatter", new simgrid::instr::CollTIData(name, args.root, -1.0, args.send_size, args.recv_size,
813                                                                           Datatype::encode(args.datatype1),
814                                                                           Datatype::encode(args.datatype2)));
815
816     Colls::scatter(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
817                   (rank == args.root) ? recv_buffer(args.recv_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
818
819     TRACE_smpi_comm_out(my_proc_id);
820   }
821 };
822
823
824 class ScatterVAction : public ReplayAction<ScatterVArgParser> {
825 public:
826   ScatterVAction() : ReplayAction("scatterV") {}
827   void kernel(simgrid::xbt::ReplayAction& action) override
828   {
829     int rank = MPI_COMM_WORLD->rank();
830     TRACE_smpi_comm_in(my_proc_id, "action_scatterv", new simgrid::instr::VarCollTIData(name, args.root, -1, args.sendcounts, args.recv_size,
831           nullptr, Datatype::encode(args.datatype1),
832           Datatype::encode(args.datatype2)));
833
834     Colls::scatterv((rank == args.root) ? send_buffer(args.send_size_sum * args.datatype1->size()) : nullptr,
835                     args.sendcounts->data(), args.disps.data(), args.datatype1,
836                     recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, args.root,
837                     MPI_COMM_WORLD);
838
839     TRACE_smpi_comm_out(my_proc_id);
840   }
841 };
842
843 class ReduceScatterAction : public ReplayAction<ReduceScatterArgParser> {
844 public:
845   ReduceScatterAction() : ReplayAction("reduceScatter") {}
846   void kernel(simgrid::xbt::ReplayAction& action) override
847   {
848     TRACE_smpi_comm_in(my_proc_id, "action_reducescatter",
849                        new simgrid::instr::VarCollTIData("reduceScatter", -1, 0, nullptr, -1, args.recvcounts,
850                                                          std::to_string(args.comp_size), /* ugly hack to print comp_size */
851                                                          Datatype::encode(args.datatype1)));
852
853     Colls::reduce_scatter(send_buffer(args.recv_size_sum * args.datatype1->size()),
854                           recv_buffer(args.recv_size_sum * args.datatype1->size()), args.recvcounts->data(),
855                           args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
856
857     smpi_execute_flops(args.comp_size);
858     TRACE_smpi_comm_out(my_proc_id);
859   }
860 };
861
862 class AllToAllVAction : public ReplayAction<AllToAllVArgParser> {
863 public:
864   AllToAllVAction() : ReplayAction("allToAllV") {}
865   void kernel(simgrid::xbt::ReplayAction& action) override
866   {
867     TRACE_smpi_comm_in(my_proc_id, __func__,
868                        new simgrid::instr::VarCollTIData(
869                            "allToAllV", -1, args.send_size_sum, args.sendcounts, args.recv_size_sum, args.recvcounts,
870                            Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
871
872     Colls::alltoallv(send_buffer(args.send_buf_size * args.datatype1->size()), args.sendcounts->data(), args.senddisps.data(), args.datatype1,
873                      recv_buffer(args.recv_buf_size * args.datatype2->size()), args.recvcounts->data(), args.recvdisps.data(), args.datatype2, MPI_COMM_WORLD);
874
875     TRACE_smpi_comm_out(my_proc_id);
876   }
877 };
878 } // Replay Namespace
879 }} // namespace simgrid::smpi
880
881 /** @brief Only initialize the replay, don't do it for real */
882 void smpi_replay_init(int* argc, char*** argv)
883 {
884   simgrid::smpi::Process::init(argc, argv);
885   smpi_process()->mark_as_initialized();
886   smpi_process()->set_replaying(true);
887
888   int my_proc_id = simgrid::s4u::this_actor::get_pid();
889   TRACE_smpi_init(my_proc_id);
890   TRACE_smpi_computing_init(my_proc_id);
891   TRACE_smpi_comm_in(my_proc_id, "smpi_replay_run_init", new simgrid::instr::NoOpTIData("init"));
892   TRACE_smpi_comm_out(my_proc_id);
893   xbt_replay_action_register("init", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::InitAction().execute(action); });
894   xbt_replay_action_register("finalize", [](simgrid::xbt::ReplayAction& action) { /* nothing to do */ });
895   xbt_replay_action_register("comm_size", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
896   xbt_replay_action_register("comm_split",[](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
897   xbt_replay_action_register("comm_dup",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
898
899   xbt_replay_action_register("send",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("send").execute(action); });
900   xbt_replay_action_register("Isend", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("Isend").execute(action); });
901   xbt_replay_action_register("recv",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("recv").execute(action); });
902   xbt_replay_action_register("Irecv", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("Irecv").execute(action); });
903   xbt_replay_action_register("test",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::TestAction().execute(action); });
904   xbt_replay_action_register("wait",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAction().execute(action); });
905   xbt_replay_action_register("waitAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAllAction().execute(action); });
906   xbt_replay_action_register("barrier", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BarrierAction().execute(action); });
907   xbt_replay_action_register("bcast",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BcastAction().execute(action); });
908   xbt_replay_action_register("reduce",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceAction().execute(action); });
909   xbt_replay_action_register("allReduce", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllReduceAction().execute(action); });
910   xbt_replay_action_register("allToAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllAction().execute(action); });
911   xbt_replay_action_register("allToAllV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllVAction().execute(action); });
912   xbt_replay_action_register("gather",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("gather").execute(action); });
913   xbt_replay_action_register("scatter",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterAction().execute(action); });
914   xbt_replay_action_register("gatherV",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("gatherV").execute(action); });
915   xbt_replay_action_register("scatterV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterVAction().execute(action); });
916   xbt_replay_action_register("allGather", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("allGather").execute(action); });
917   xbt_replay_action_register("allGatherV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("allGatherV").execute(action); });
918   xbt_replay_action_register("reduceScatter", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceScatterAction().execute(action); });
919   xbt_replay_action_register("compute", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ComputeAction().execute(action); });
920
921   //if we have a delayed start, sleep here.
922   if(*argc>2){
923     double value = xbt_str_parse_double((*argv)[2], "%s is not a double");
924     XBT_VERB("Delayed start for instance - Sleeping for %f flops ",value );
925     smpi_execute_flops(value);
926   } else {
927     //UGLY: force a context switch to be sure that all MSG_processes begin initialization
928     XBT_DEBUG("Force context switch by smpi_execute_flops  - Sleeping for 0.0 flops ");
929     smpi_execute_flops(0.0);
930   }
931 }
932
933 /** @brief actually run the replay after initialization */
934 void smpi_replay_main(int* argc, char*** argv)
935 {
936   static int active_processes = 0;
937   active_processes++;
938   simgrid::xbt::replay_runner(*argc, *argv);
939
940   /* and now, finalize everything */
941   /* One active process will stop. Decrease the counter*/
942   XBT_DEBUG("There are %zu elements in reqq[*]", get_reqq_self()->size());
943   if (not get_reqq_self()->empty()) {
944     unsigned int count_requests=get_reqq_self()->size();
945     MPI_Request requests[count_requests];
946     MPI_Status status[count_requests];
947     unsigned int i=0;
948
949     for (auto const& req : *get_reqq_self()) {
950       requests[i] = req;
951       i++;
952     }
953     simgrid::smpi::Request::waitall(count_requests, requests, status);
954   }
955   delete get_reqq_self();
956   active_processes--;
957
958   if(active_processes==0){
959     /* Last process alive speaking: end the simulated timer */
960     XBT_INFO("Simulation time %f", smpi_process()->simulated_elapsed());
961     smpi_free_replay_tmp_buffers();
962   }
963
964   TRACE_smpi_comm_in(simgrid::s4u::this_actor::get_pid(), "smpi_replay_run_finalize",
965                      new simgrid::instr::NoOpTIData("finalize"));
966
967   smpi_process()->finalize();
968
969   TRACE_smpi_comm_out(simgrid::s4u::this_actor::get_pid());
970   TRACE_smpi_finalize(simgrid::s4u::this_actor::get_pid());
971 }
972
973 /** @brief chain a replay initialization and a replay start */
974 void smpi_replay_run(int* argc, char*** argv)
975 {
976   smpi_replay_init(argc, argv);
977   smpi_replay_main(argc, argv);
978 }