Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'master' of scm.gforge.inria.fr:/gitroot/simgrid/simgrid
[simgrid.git] / src / smpi / internals / smpi_replay.cpp
1 /* Copyright (c) 2009-2018. The SimGrid Team. All rights reserved.          */
2
3 /* This program is free software; you can redistribute it and/or modify it
4  * under the terms of the license (GNU LGPL) which comes with this package. */
5
6 #include "private.hpp"
7 #include "smpi_coll.hpp"
8 #include "smpi_comm.hpp"
9 #include "smpi_datatype.hpp"
10 #include "smpi_group.hpp"
11 #include "smpi_process.hpp"
12 #include "smpi_request.hpp"
13 #include "xbt/replay.hpp"
14
15 #include <boost/algorithm/string/join.hpp>
16 #include <memory>
17 #include <numeric>
18 #include <unordered_map>
19 #include <sstream>
20 #include <vector>
21
22 #include <tuple>
23 // From https://stackoverflow.com/questions/7110301/generic-hash-for-tuples-in-unordered-map-unordered-set
24 // This is all just to make std::unordered_map work with std::tuple. If we need this in other places,
25 // this could go into a header file.
26 namespace hash_tuple {
27 template <typename TT> class hash {
28 public:
29   size_t operator()(TT const& tt) const { return std::hash<TT>()(tt); }
30 };
31
32 template <class T> inline void hash_combine(std::size_t& seed, T const& v)
33 {
34   seed ^= hash_tuple::hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
35 }
36
37 // Recursive template code derived from Matthieu M.
38 template <class Tuple, size_t Index = std::tuple_size<Tuple>::value - 1> class HashValueImpl {
39 public:
40   static void apply(size_t& seed, Tuple const& tuple)
41   {
42     HashValueImpl<Tuple, Index - 1>::apply(seed, tuple);
43     hash_combine(seed, std::get<Index>(tuple));
44   }
45 };
46
47 template <class Tuple> class HashValueImpl<Tuple, 0> {
48 public:
49   static void apply(size_t& seed, Tuple const& tuple) { hash_combine(seed, std::get<0>(tuple)); }
50 };
51
52 template <typename... TT> class hash<std::tuple<TT...>> {
53 public:
54   size_t operator()(std::tuple<TT...> const& tt) const
55   {
56     size_t seed = 0;
57     HashValueImpl<std::tuple<TT...>>::apply(seed, tt);
58     return seed;
59   }
60 };
61 }
62
63 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_replay,smpi,"Trace Replay with SMPI");
64
65 typedef std::tuple</*sender*/ int, /* reciever */ int, /* tag */int> req_key_t;
66 typedef std::unordered_map<req_key_t, MPI_Request, hash_tuple::hash<std::tuple<int,int,int>>> req_storage_t;
67
68 static MPI_Datatype MPI_DEFAULT_TYPE;
69
70 #define CHECK_ACTION_PARAMS(action, mandatory, optional)                                                               \
71   {                                                                                                                    \
72     if (action.size() < static_cast<unsigned long>(mandatory + 2)) {                                                   \
73       std::stringstream ss;                                                                                            \
74       for (const auto& elem : action) {                                                                                \
75         ss << elem << " ";                                                                                             \
76       }                                                                                                                \
77       THROWF(arg_error, 0, "%s replay failed.\n"                                                                       \
78                            "%zu items were given on the line. First two should be process_id and action.  "            \
79                            "This action needs after them %lu mandatory arguments, and accepts %lu optional ones. \n"   \
80                            "The full line that was given is:\n   %s\n"                                                 \
81                            "Please contact the Simgrid team if support is needed",                                     \
82              __func__, action.size(), static_cast<unsigned long>(mandatory), static_cast<unsigned long>(optional),     \
83              ss.str().c_str());                                                                                        \
84     }                                                                                                                  \
85   }
86
87 static void log_timed_action(simgrid::xbt::ReplayAction& action, double clock)
88 {
89   if (XBT_LOG_ISENABLED(smpi_replay, xbt_log_priority_verbose)){
90     std::string s = boost::algorithm::join(action, " ");
91     XBT_VERB("%s %f", s.c_str(), smpi_process()->simulated_elapsed() - clock);
92   }
93 }
94
95 /* Helper function */
96 static double parse_double(std::string string)
97 {
98   return xbt_str_parse_double(string.c_str(), "%s is not a double");
99 }
100
101 namespace simgrid {
102 namespace smpi {
103
104 namespace replay {
105
106 class RequestStorage {
107 private:
108     req_storage_t store;
109
110 public:
111     RequestStorage() {}
112     int size()
113     {
114       return store.size();
115     }
116
117     req_storage_t& get_store()
118     {
119       return store;
120     }
121
122     void get_requests(std::vector<MPI_Request>& vec)
123     {
124       for (auto& pair : store) {
125         auto& req = pair.second;
126         auto my_proc_id = simgrid::s4u::this_actor::get_pid();
127         if (req != MPI_REQUEST_NULL && (req->src() == my_proc_id || req->dst() == my_proc_id)) {
128           vec.push_back(pair.second);
129           pair.second->print_request("MM");
130         }
131       }
132     }
133
134     MPI_Request find(int src, int dst, int tag)
135     {
136       req_storage_t::iterator it = store.find(req_key_t(src, dst, tag));
137       return (it == store.end()) ? MPI_REQUEST_NULL : it->second;
138     }
139
140     void remove(MPI_Request req)
141     {
142       if (req == MPI_REQUEST_NULL) return;
143
144       store.erase(req_key_t(req->src()-1, req->dst()-1, req->tag()));
145     }
146
147     void add(MPI_Request req)
148     {
149       if (req != MPI_REQUEST_NULL) // Can and does happen in the case of TestAction
150         store.insert({req_key_t(req->src()-1, req->dst()-1, req->tag()), req});
151     }
152
153     /* Sometimes we need to re-insert MPI_REQUEST_NULL but we still need src,dst and tag */
154     void addNullRequest(int src, int dst, int tag)
155     {
156       store.insert({req_key_t(src, dst, tag), MPI_REQUEST_NULL});
157     }
158 };
159
160 class ActionArgParser {
161 public:
162   virtual ~ActionArgParser() = default;
163   virtual void parse(simgrid::xbt::ReplayAction& action, std::string name) { CHECK_ACTION_PARAMS(action, 0, 0) }
164 };
165
166 class WaitTestParser : public ActionArgParser {
167 public:
168   int src;
169   int dst;
170   int tag;
171
172   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
173   {
174     CHECK_ACTION_PARAMS(action, 3, 0)
175     src = std::stoi(action[2]);
176     dst = std::stoi(action[3]);
177     tag = std::stoi(action[4]);
178   }
179 };
180
181 class SendRecvParser : public ActionArgParser {
182 public:
183   /* communication partner; if we send, this is the receiver and vice versa */
184   int partner;
185   double size;
186   int tag;
187   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
188
189   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
190   {
191     CHECK_ACTION_PARAMS(action, 3, 1)
192     partner = std::stoi(action[2]);
193     tag     = std::stoi(action[3]);
194     size    = parse_double(action[4]);
195     if (action.size() > 5)
196       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
197   }
198 };
199
200 class ComputeParser : public ActionArgParser {
201 public:
202   /* communication partner; if we send, this is the receiver and vice versa */
203   double flops;
204
205   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
206   {
207     CHECK_ACTION_PARAMS(action, 1, 0)
208     flops = parse_double(action[2]);
209   }
210 };
211
212 class CollCommParser : public ActionArgParser {
213 public:
214   double size;
215   double comm_size;
216   double comp_size;
217   int send_size;
218   int recv_size;
219   int root = 0;
220   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
221   MPI_Datatype datatype2 = MPI_DEFAULT_TYPE;
222 };
223
224 class BcastArgParser : public CollCommParser {
225 public:
226   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
227   {
228     CHECK_ACTION_PARAMS(action, 1, 2)
229     size = parse_double(action[2]);
230     root = (action.size() > 3) ? std::stoi(action[3]) : 0;
231     if (action.size() > 4)
232       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
233   }
234 };
235
236 class ReduceArgParser : public CollCommParser {
237 public:
238   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
239   {
240     CHECK_ACTION_PARAMS(action, 2, 2)
241     comm_size = parse_double(action[2]);
242     comp_size = parse_double(action[3]);
243     root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
244     if (action.size() > 5)
245       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
246   }
247 };
248
249 class AllReduceArgParser : public CollCommParser {
250 public:
251   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
252   {
253     CHECK_ACTION_PARAMS(action, 2, 1)
254     comm_size = parse_double(action[2]);
255     comp_size = parse_double(action[3]);
256     if (action.size() > 4)
257       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
258   }
259 };
260
261 class AllToAllArgParser : public CollCommParser {
262 public:
263   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
264   {
265     CHECK_ACTION_PARAMS(action, 2, 1)
266     comm_size = MPI_COMM_WORLD->size();
267     send_size = parse_double(action[2]);
268     recv_size = parse_double(action[3]);
269
270     if (action.size() > 4)
271       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
272     if (action.size() > 5)
273       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
274   }
275 };
276
277 class GatherArgParser : public CollCommParser {
278 public:
279   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
280   {
281     /* The structure of the gather action for the rank 0 (total 4 processes) is the following:
282           0 gather 68 68 0 0 0
283         where:
284           1) 68 is the sendcounts
285           2) 68 is the recvcounts
286           3) 0 is the root node
287           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
288           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
289     */
290     CHECK_ACTION_PARAMS(action, 2, 3)
291     comm_size = MPI_COMM_WORLD->size();
292     send_size = parse_double(action[2]);
293     recv_size = parse_double(action[3]);
294
295     if (name == "gather") {
296       root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
297       if (action.size() > 5)
298         datatype1 = simgrid::smpi::Datatype::decode(action[5]);
299       if (action.size() > 6)
300         datatype2 = simgrid::smpi::Datatype::decode(action[6]);
301     }
302     else {
303       if (action.size() > 4)
304         datatype1 = simgrid::smpi::Datatype::decode(action[4]);
305       if (action.size() > 5)
306         datatype2 = simgrid::smpi::Datatype::decode(action[5]);
307     }
308   }
309 };
310
311 class GatherVArgParser : public CollCommParser {
312 public:
313   int recv_size_sum;
314   std::shared_ptr<std::vector<int>> recvcounts;
315   std::vector<int> disps;
316   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
317   {
318     /* The structure of the gatherv action for the rank 0 (total 4 processes) is the following:
319          0 gather 68 68 10 10 10 0 0 0
320        where:
321          1) 68 is the sendcount
322          2) 68 10 10 10 is the recvcounts
323          3) 0 is the root node
324          4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
325          5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
326     */
327     comm_size = MPI_COMM_WORLD->size();
328     CHECK_ACTION_PARAMS(action, comm_size+1, 2)
329     send_size = parse_double(action[2]);
330     disps     = std::vector<int>(comm_size, 0);
331     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
332
333     if (name == "gatherV") {
334       root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
335       if (action.size() > 4 + comm_size)
336         datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
337       if (action.size() > 5 + comm_size)
338         datatype2 = simgrid::smpi::Datatype::decode(action[5 + comm_size]);
339     }
340     else {
341       int datatype_index = 0;
342       int disp_index     = 0;
343       /* The 3 comes from "0 gather <sendcount>", which must always be present.
344        * The + comm_size is the recvcounts array, which must also be present
345        */
346       if (action.size() > 3 + comm_size + comm_size) { /* datatype + disp are specified */
347         datatype_index = 3 + comm_size;
348         disp_index     = datatype_index + 1;
349         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
350         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
351       } else if (action.size() > 3 + comm_size + 2) { /* disps specified; datatype is not specified; use the default one */
352         disp_index     = 3 + comm_size;
353       } else if (action.size() > 3 + comm_size)  { /* only datatype, no disp specified */
354         datatype_index = 3 + comm_size;
355         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
356         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
357       }
358
359       if (disp_index != 0) {
360         for (unsigned int i = 0; i < comm_size; i++)
361           disps[i]          = std::stoi(action[disp_index + i]);
362       }
363     }
364
365     for (unsigned int i = 0; i < comm_size; i++) {
366       (*recvcounts)[i] = std::stoi(action[i + 3]);
367     }
368     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
369   }
370 };
371
372 class ScatterArgParser : public CollCommParser {
373 public:
374   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
375   {
376     /* The structure of the scatter action for the rank 0 (total 4 processes) is the following:
377           0 gather 68 68 0 0 0
378         where:
379           1) 68 is the sendcounts
380           2) 68 is the recvcounts
381           3) 0 is the root node
382           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
383           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
384     */
385     CHECK_ACTION_PARAMS(action, 2, 3)
386     comm_size   = MPI_COMM_WORLD->size();
387     send_size   = parse_double(action[2]);
388     recv_size   = parse_double(action[3]);
389     root   = (action.size() > 4) ? std::stoi(action[4]) : 0;
390     if (action.size() > 5)
391       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
392     if (action.size() > 6)
393       datatype2 = simgrid::smpi::Datatype::decode(action[6]);
394   }
395 };
396
397 class ScatterVArgParser : public CollCommParser {
398 public:
399   int recv_size_sum;
400   int send_size_sum;
401   std::shared_ptr<std::vector<int>> sendcounts;
402   std::vector<int> disps;
403   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
404   {
405     /* The structure of the scatterv action for the rank 0 (total 4 processes) is the following:
406        0 gather 68 10 10 10 68 0 0 0
407         where:
408         1) 68 10 10 10 is the sendcounts
409         2) 68 is the recvcount
410         3) 0 is the root node
411         4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
412         5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
413     */
414     CHECK_ACTION_PARAMS(action, comm_size + 1, 2)
415     recv_size  = parse_double(action[2 + comm_size]);
416     disps      = std::vector<int>(comm_size, 0);
417     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
418
419     if (action.size() > 5 + comm_size)
420       datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
421     if (action.size() > 5 + comm_size)
422       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
423
424     for (unsigned int i = 0; i < comm_size; i++) {
425       (*sendcounts)[i] = std::stoi(action[i + 2]);
426     }
427     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
428     root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
429   }
430 };
431
432 class ReduceScatterArgParser : public CollCommParser {
433 public:
434   int recv_size_sum;
435   std::shared_ptr<std::vector<int>> recvcounts;
436   std::vector<int> disps;
437   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
438   {
439     /* The structure of the reducescatter action for the rank 0 (total 4 processes) is the following:
440          0 reduceScatter 275427 275427 275427 204020 11346849 0
441        where:
442          1) The first four values after the name of the action declare the recvcounts array
443          2) The value 11346849 is the amount of instructions
444          3) The last value corresponds to the datatype, see simgrid::smpi::Datatype::decode().
445     */
446     comm_size = MPI_COMM_WORLD->size();
447     CHECK_ACTION_PARAMS(action, comm_size+1, 1)
448     comp_size = parse_double(action[2+comm_size]);
449     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
450     if (action.size() > 3 + comm_size)
451       datatype1 = simgrid::smpi::Datatype::decode(action[3 + comm_size]);
452
453     for (unsigned int i = 0; i < comm_size; i++) {
454       recvcounts->push_back(std::stoi(action[i + 2]));
455     }
456     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
457   }
458 };
459
460 class AllToAllVArgParser : public CollCommParser {
461 public:
462   int recv_size_sum;
463   int send_size_sum;
464   std::shared_ptr<std::vector<int>> recvcounts;
465   std::shared_ptr<std::vector<int>> sendcounts;
466   std::vector<int> senddisps;
467   std::vector<int> recvdisps;
468   int send_buf_size;
469   int recv_buf_size;
470   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
471   {
472     /* The structure of the allToAllV action for the rank 0 (total 4 processes) is the following:
473           0 allToAllV 100 1 7 10 12 100 1 70 10 5
474        where:
475         1) 100 is the size of the send buffer *sizeof(int),
476         2) 1 7 10 12 is the sendcounts array
477         3) 100*sizeof(int) is the size of the receiver buffer
478         4)  1 70 10 5 is the recvcounts array
479     */
480     comm_size = MPI_COMM_WORLD->size();
481     CHECK_ACTION_PARAMS(action, 2*comm_size+2, 2)
482     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
483     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
484     senddisps  = std::vector<int>(comm_size, 0);
485     recvdisps  = std::vector<int>(comm_size, 0);
486
487     if (action.size() > 5 + 2 * comm_size)
488       datatype1 = simgrid::smpi::Datatype::decode(action[4 + 2 * comm_size]);
489     if (action.size() > 5 + 2 * comm_size)
490       datatype2 = simgrid::smpi::Datatype::decode(action[5 + 2 * comm_size]);
491
492     send_buf_size=parse_double(action[2]);
493     recv_buf_size=parse_double(action[3+comm_size]);
494     for (unsigned int i = 0; i < comm_size; i++) {
495       (*sendcounts)[i] = std::stoi(action[3 + i]);
496       (*recvcounts)[i] = std::stoi(action[4 + comm_size + i]);
497     }
498     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
499     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
500   }
501 };
502
503 template <class T> class ReplayAction {
504 protected:
505   const std::string name;
506   const int my_proc_id;
507   T args;
508
509 public:
510   explicit ReplayAction(std::string name) : name(name), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
511   virtual ~ReplayAction() = default;
512
513   virtual void execute(simgrid::xbt::ReplayAction& action)
514   {
515     // Needs to be re-initialized for every action, hence here
516     double start_time = smpi_process()->simulated_elapsed();
517     args.parse(action, name);
518     kernel(action);
519     if (name != "Init")
520       log_timed_action(action, start_time);
521   }
522
523   virtual void kernel(simgrid::xbt::ReplayAction& action) = 0;
524
525   void* send_buffer(int size)
526   {
527     return smpi_get_tmp_sendbuffer(size);
528   }
529
530   void* recv_buffer(int size)
531   {
532     return smpi_get_tmp_recvbuffer(size);
533   }
534 };
535
536 class WaitAction : public ReplayAction<WaitTestParser> {
537 private:
538   RequestStorage& req_storage;
539
540 public:
541   explicit WaitAction(RequestStorage& storage) : ReplayAction("Wait"), req_storage(storage) {}
542   void kernel(simgrid::xbt::ReplayAction& action) override
543   {
544     std::string s = boost::algorithm::join(action, " ");
545     xbt_assert(req_storage.size(), "action wait not preceded by any irecv or isend: %s", s.c_str());
546     MPI_Request request = req_storage.find(args.src, args.dst, args.tag);
547     req_storage.remove(request);
548
549     if (request == MPI_REQUEST_NULL) {
550       /* Assume that the trace is well formed, meaning the comm might have been caught by a MPI_test. Then just
551        * return.*/
552       return;
553     }
554
555     int rank = request->comm() != MPI_COMM_NULL ? request->comm()->rank() : -1;
556
557     // Must be taken before Request::wait() since the request may be set to
558     // MPI_REQUEST_NULL by Request::wait!
559     bool is_wait_for_receive = (request->flags() & RECV);
560     // TODO: Here we take the rank while we normally take the process id (look for my_proc_id)
561     TRACE_smpi_comm_in(rank, __func__, new simgrid::instr::NoOpTIData("wait"));
562
563     MPI_Status status;
564     Request::wait(&request, &status);
565
566     TRACE_smpi_comm_out(rank);
567     if (is_wait_for_receive)
568       TRACE_smpi_recv(args.src, args.dst, args.tag);
569   }
570 };
571
572 class SendAction : public ReplayAction<SendRecvParser> {
573 private:
574   RequestStorage& req_storage;
575
576 public:
577   explicit SendAction(std::string name, RequestStorage& storage) : ReplayAction(name), req_storage(storage) {}
578   void kernel(simgrid::xbt::ReplayAction& action) override
579   {
580     int dst_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
581
582     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
583                                                                              args.tag, Datatype::encode(args.datatype1)));
584     if (not TRACE_smpi_view_internals())
585       TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, args.tag, args.size * args.datatype1->size());
586
587     if (name == "send") {
588       Request::send(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
589     } else if (name == "Isend") {
590       MPI_Request request = Request::isend(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
591       req_storage.add(request);
592     } else {
593       xbt_die("Don't know this action, %s", name.c_str());
594     }
595
596     TRACE_smpi_comm_out(my_proc_id);
597   }
598 };
599
600 class RecvAction : public ReplayAction<SendRecvParser> {
601 private:
602   RequestStorage& req_storage;
603
604 public:
605   explicit RecvAction(std::string name, RequestStorage& storage) : ReplayAction(name), req_storage(storage) {}
606   void kernel(simgrid::xbt::ReplayAction& action) override
607   {
608     int src_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
609
610     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
611                                                                              args.tag, Datatype::encode(args.datatype1)));
612
613     MPI_Status status;
614     // unknown size from the receiver point of view
615     if (args.size <= 0.0) {
616       Request::probe(args.partner, args.tag, MPI_COMM_WORLD, &status);
617       args.size = status.count;
618     }
619
620     if (name == "recv") {
621       Request::recv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD, &status);
622     } else if (name == "Irecv") {
623       MPI_Request request = Request::irecv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
624       req_storage.add(request);
625     }
626
627     TRACE_smpi_comm_out(my_proc_id);
628     // TODO: Check why this was only activated in the "recv" case and not in the "Irecv" case
629     if (name == "recv" && not TRACE_smpi_view_internals()) {
630       TRACE_smpi_recv(src_traced, my_proc_id, args.tag);
631     }
632   }
633 };
634
635 class ComputeAction : public ReplayAction<ComputeParser> {
636 public:
637   ComputeAction() : ReplayAction("compute") {}
638   void kernel(simgrid::xbt::ReplayAction& action) override
639   {
640     TRACE_smpi_computing_in(my_proc_id, args.flops);
641     smpi_execute_flops(args.flops);
642     TRACE_smpi_computing_out(my_proc_id);
643   }
644 };
645
646 class TestAction : public ReplayAction<WaitTestParser> {
647 private:
648   RequestStorage& req_storage;
649
650 public:
651   explicit TestAction(RequestStorage& storage) : ReplayAction("Test"), req_storage(storage) {}
652   void kernel(simgrid::xbt::ReplayAction& action) override
653   {
654     MPI_Request request = req_storage.find(args.src, args.dst, args.tag);
655     req_storage.remove(request);
656     // if request is null here, this may mean that a previous test has succeeded
657     // Different times in traced application and replayed version may lead to this
658     // In this case, ignore the extra calls.
659     if (request != MPI_REQUEST_NULL) {
660       TRACE_smpi_testing_in(my_proc_id);
661
662       MPI_Status status;
663       int flag = Request::test(&request, &status);
664
665       XBT_DEBUG("MPI_Test result: %d", flag);
666       /* push back request in vector to be caught by a subsequent wait. if the test did succeed, the request is now
667        * nullptr.*/
668       if (request == MPI_REQUEST_NULL)
669         req_storage.addNullRequest(args.src, args.dst, args.tag);
670       else
671         req_storage.add(request);
672
673       TRACE_smpi_testing_out(my_proc_id);
674     }
675   }
676 };
677
678 class InitAction : public ReplayAction<ActionArgParser> {
679 public:
680   InitAction() : ReplayAction("Init") {}
681   void kernel(simgrid::xbt::ReplayAction& action) override
682   {
683     CHECK_ACTION_PARAMS(action, 0, 1)
684     MPI_DEFAULT_TYPE = (action.size() > 2) ? MPI_DOUBLE // default MPE datatype
685                                            : MPI_BYTE;  // default TAU datatype
686
687     /* start a simulated timer */
688     smpi_process()->simulated_start();
689   }
690 };
691
692 class CommunicatorAction : public ReplayAction<ActionArgParser> {
693 public:
694   CommunicatorAction() : ReplayAction("Comm") {}
695   void kernel(simgrid::xbt::ReplayAction& action) override { /* nothing to do */}
696 };
697
698 class WaitAllAction : public ReplayAction<ActionArgParser> {
699 private:
700   RequestStorage& req_storage;
701
702 public:
703   explicit WaitAllAction(RequestStorage& storage) : ReplayAction("waitAll"), req_storage(storage) {}
704   void kernel(simgrid::xbt::ReplayAction& action) override
705   {
706     const unsigned int count_requests = req_storage.size();
707
708     if (count_requests > 0) {
709       TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData("waitAll", -1, count_requests, ""));
710       std::vector<std::pair</*sender*/int,/*recv*/int>> sender_receiver;
711       std::vector<MPI_Request> reqs;
712       req_storage.get_requests(reqs);
713       for (const auto& req : reqs) {
714         if (req && (req->flags() & RECV)) {
715           sender_receiver.push_back({req->src(), req->dst()});
716         }
717       }
718       MPI_Status status[count_requests];
719       Request::waitall(count_requests, &(reqs.data())[0], status);
720       req_storage.get_store().clear();
721
722       for (auto& pair : sender_receiver) {
723         TRACE_smpi_recv(pair.first, pair.second, 0);
724       }
725       TRACE_smpi_comm_out(my_proc_id);
726     }
727   }
728 };
729
730 class BarrierAction : public ReplayAction<ActionArgParser> {
731 public:
732   BarrierAction() : ReplayAction("barrier") {}
733   void kernel(simgrid::xbt::ReplayAction& action) override
734   {
735     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("barrier"));
736     Colls::barrier(MPI_COMM_WORLD);
737     TRACE_smpi_comm_out(my_proc_id);
738   }
739 };
740
741 class BcastAction : public ReplayAction<BcastArgParser> {
742 public:
743   BcastAction() : ReplayAction("bcast") {}
744   void kernel(simgrid::xbt::ReplayAction& action) override
745   {
746     TRACE_smpi_comm_in(my_proc_id, "action_bcast",
747                        new simgrid::instr::CollTIData("bcast", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
748                                                       -1.0, args.size, -1, Datatype::encode(args.datatype1), ""));
749
750     Colls::bcast(send_buffer(args.size * args.datatype1->size()), args.size, args.datatype1, args.root, MPI_COMM_WORLD);
751
752     TRACE_smpi_comm_out(my_proc_id);
753   }
754 };
755
756 class ReduceAction : public ReplayAction<ReduceArgParser> {
757 public:
758   ReduceAction() : ReplayAction("reduce") {}
759   void kernel(simgrid::xbt::ReplayAction& action) override
760   {
761     TRACE_smpi_comm_in(my_proc_id, "action_reduce",
762                        new simgrid::instr::CollTIData("reduce", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
763                                                       args.comp_size, args.comm_size, -1,
764                                                       Datatype::encode(args.datatype1), ""));
765
766     Colls::reduce(send_buffer(args.comm_size * args.datatype1->size()),
767         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, args.root, MPI_COMM_WORLD);
768     smpi_execute_flops(args.comp_size);
769
770     TRACE_smpi_comm_out(my_proc_id);
771   }
772 };
773
774 class AllReduceAction : public ReplayAction<AllReduceArgParser> {
775 public:
776   AllReduceAction() : ReplayAction("allReduce") {}
777   void kernel(simgrid::xbt::ReplayAction& action) override
778   {
779     TRACE_smpi_comm_in(my_proc_id, "action_allReduce", new simgrid::instr::CollTIData("allReduce", -1, args.comp_size, args.comm_size, -1,
780                                                                                 Datatype::encode(args.datatype1), ""));
781
782     Colls::allreduce(send_buffer(args.comm_size * args.datatype1->size()),
783         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
784     smpi_execute_flops(args.comp_size);
785
786     TRACE_smpi_comm_out(my_proc_id);
787   }
788 };
789
790 class AllToAllAction : public ReplayAction<AllToAllArgParser> {
791 public:
792   AllToAllAction() : ReplayAction("allToAll") {}
793   void kernel(simgrid::xbt::ReplayAction& action) override
794   {
795     TRACE_smpi_comm_in(my_proc_id, "action_allToAll",
796                      new simgrid::instr::CollTIData("allToAll", -1, -1.0, args.send_size, args.recv_size,
797                                                     Datatype::encode(args.datatype1),
798                                                     Datatype::encode(args.datatype2)));
799
800     Colls::alltoall(send_buffer(args.send_size * args.comm_size * args.datatype1->size()), args.send_size,
801                     args.datatype1, recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()),
802                     args.recv_size, args.datatype2, MPI_COMM_WORLD);
803
804     TRACE_smpi_comm_out(my_proc_id);
805   }
806 };
807
808 class GatherAction : public ReplayAction<GatherArgParser> {
809 public:
810   explicit GatherAction(std::string name) : ReplayAction(name) {}
811   void kernel(simgrid::xbt::ReplayAction& action) override
812   {
813     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::CollTIData(name, (name == "gather") ? args.root : -1, -1.0, args.send_size, args.recv_size,
814                                                                           Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
815
816     if (name == "gather") {
817       int rank = MPI_COMM_WORLD->rank();
818       Colls::gather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
819                  (rank == args.root) ? recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
820     }
821     else
822       Colls::allgather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
823                        recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, MPI_COMM_WORLD);
824
825     TRACE_smpi_comm_out(my_proc_id);
826   }
827 };
828
829 class GatherVAction : public ReplayAction<GatherVArgParser> {
830 public:
831   explicit GatherVAction(std::string name) : ReplayAction(name) {}
832   void kernel(simgrid::xbt::ReplayAction& action) override
833   {
834     int rank = MPI_COMM_WORLD->rank();
835
836     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::VarCollTIData(
837                                                name, (name == "gatherV") ? args.root : -1, args.send_size, nullptr, -1, args.recvcounts,
838                                                Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
839
840     if (name == "gatherV") {
841       Colls::gatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
842                      (rank == args.root) ? recv_buffer(args.recv_size_sum * args.datatype2->size()) : nullptr,
843                      args.recvcounts->data(), args.disps.data(), args.datatype2, args.root, MPI_COMM_WORLD);
844     }
845     else {
846       Colls::allgatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
847                         recv_buffer(args.recv_size_sum * args.datatype2->size()), args.recvcounts->data(),
848                         args.disps.data(), args.datatype2, MPI_COMM_WORLD);
849     }
850
851     TRACE_smpi_comm_out(my_proc_id);
852   }
853 };
854
855 class ScatterAction : public ReplayAction<ScatterArgParser> {
856 public:
857   ScatterAction() : ReplayAction("scatter") {}
858   void kernel(simgrid::xbt::ReplayAction& action) override
859   {
860     int rank = MPI_COMM_WORLD->rank();
861     TRACE_smpi_comm_in(my_proc_id, "action_scatter", new simgrid::instr::CollTIData(name, args.root, -1.0, args.send_size, args.recv_size,
862                                                                           Datatype::encode(args.datatype1),
863                                                                           Datatype::encode(args.datatype2)));
864
865     Colls::scatter(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
866                   (rank == args.root) ? recv_buffer(args.recv_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
867
868     TRACE_smpi_comm_out(my_proc_id);
869   }
870 };
871
872
873 class ScatterVAction : public ReplayAction<ScatterVArgParser> {
874 public:
875   ScatterVAction() : ReplayAction("scatterV") {}
876   void kernel(simgrid::xbt::ReplayAction& action) override
877   {
878     int rank = MPI_COMM_WORLD->rank();
879     TRACE_smpi_comm_in(my_proc_id, "action_scatterv", new simgrid::instr::VarCollTIData(name, args.root, -1, args.sendcounts, args.recv_size,
880           nullptr, Datatype::encode(args.datatype1),
881           Datatype::encode(args.datatype2)));
882
883     Colls::scatterv((rank == args.root) ? send_buffer(args.send_size_sum * args.datatype1->size()) : nullptr,
884                     args.sendcounts->data(), args.disps.data(), args.datatype1,
885                     recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, args.root,
886                     MPI_COMM_WORLD);
887
888     TRACE_smpi_comm_out(my_proc_id);
889   }
890 };
891
892 class ReduceScatterAction : public ReplayAction<ReduceScatterArgParser> {
893 public:
894   ReduceScatterAction() : ReplayAction("reduceScatter") {}
895   void kernel(simgrid::xbt::ReplayAction& action) override
896   {
897     TRACE_smpi_comm_in(my_proc_id, "action_reducescatter",
898                        new simgrid::instr::VarCollTIData("reduceScatter", -1, 0, nullptr, -1, args.recvcounts,
899                                                          std::to_string(args.comp_size), /* ugly hack to print comp_size */
900                                                          Datatype::encode(args.datatype1)));
901
902     Colls::reduce_scatter(send_buffer(args.recv_size_sum * args.datatype1->size()),
903                           recv_buffer(args.recv_size_sum * args.datatype1->size()), args.recvcounts->data(),
904                           args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
905
906     smpi_execute_flops(args.comp_size);
907     TRACE_smpi_comm_out(my_proc_id);
908   }
909 };
910
911 class AllToAllVAction : public ReplayAction<AllToAllVArgParser> {
912 public:
913   AllToAllVAction() : ReplayAction("allToAllV") {}
914   void kernel(simgrid::xbt::ReplayAction& action) override
915   {
916     TRACE_smpi_comm_in(my_proc_id, __func__,
917                        new simgrid::instr::VarCollTIData(
918                            "allToAllV", -1, args.send_size_sum, args.sendcounts, args.recv_size_sum, args.recvcounts,
919                            Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
920
921     Colls::alltoallv(send_buffer(args.send_buf_size * args.datatype1->size()), args.sendcounts->data(), args.senddisps.data(), args.datatype1,
922                      recv_buffer(args.recv_buf_size * args.datatype2->size()), args.recvcounts->data(), args.recvdisps.data(), args.datatype2, MPI_COMM_WORLD);
923
924     TRACE_smpi_comm_out(my_proc_id);
925   }
926 };
927 } // Replay Namespace
928 }} // namespace simgrid::smpi
929
930 std::vector<simgrid::smpi::replay::RequestStorage> storage;
931 /** @brief Only initialize the replay, don't do it for real */
932 void smpi_replay_init(int* argc, char*** argv)
933 {
934   simgrid::smpi::Process::init(argc, argv);
935   smpi_process()->mark_as_initialized();
936   smpi_process()->set_replaying(true);
937
938   int my_proc_id = simgrid::s4u::this_actor::get_pid();
939   storage.resize(smpi_process_count());
940
941   TRACE_smpi_init(my_proc_id);
942   TRACE_smpi_computing_init(my_proc_id);
943   TRACE_smpi_comm_in(my_proc_id, "smpi_replay_run_init", new simgrid::instr::NoOpTIData("init"));
944   TRACE_smpi_comm_out(my_proc_id);
945   xbt_replay_action_register("init", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::InitAction().execute(action); });
946   xbt_replay_action_register("finalize", [](simgrid::xbt::ReplayAction& action) { /* nothing to do */ });
947   xbt_replay_action_register("comm_size", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
948   xbt_replay_action_register("comm_split",[](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
949   xbt_replay_action_register("comm_dup",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
950
951   xbt_replay_action_register("send",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("send", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
952   xbt_replay_action_register("Isend", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("Isend", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
953   xbt_replay_action_register("recv",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("recv", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
954   xbt_replay_action_register("Irecv", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("Irecv", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
955   xbt_replay_action_register("test",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::TestAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
956   xbt_replay_action_register("wait",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
957   xbt_replay_action_register("waitAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAllAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
958   xbt_replay_action_register("barrier", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BarrierAction().execute(action); });
959   xbt_replay_action_register("bcast",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BcastAction().execute(action); });
960   xbt_replay_action_register("reduce",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceAction().execute(action); });
961   xbt_replay_action_register("allReduce", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllReduceAction().execute(action); });
962   xbt_replay_action_register("allToAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllAction().execute(action); });
963   xbt_replay_action_register("allToAllV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllVAction().execute(action); });
964   xbt_replay_action_register("gather",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("gather").execute(action); });
965   xbt_replay_action_register("scatter",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterAction().execute(action); });
966   xbt_replay_action_register("gatherV",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("gatherV").execute(action); });
967   xbt_replay_action_register("scatterV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterVAction().execute(action); });
968   xbt_replay_action_register("allGather", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("allGather").execute(action); });
969   xbt_replay_action_register("allGatherV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("allGatherV").execute(action); });
970   xbt_replay_action_register("reduceScatter", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceScatterAction().execute(action); });
971   xbt_replay_action_register("compute", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ComputeAction().execute(action); });
972
973   //if we have a delayed start, sleep here.
974   if(*argc>2){
975     double value = xbt_str_parse_double((*argv)[2], "%s is not a double");
976     XBT_VERB("Delayed start for instance - Sleeping for %f flops ",value );
977     smpi_execute_flops(value);
978   } else {
979     //UGLY: force a context switch to be sure that all MSG_processes begin initialization
980     XBT_DEBUG("Force context switch by smpi_execute_flops  - Sleeping for 0.0 flops ");
981     smpi_execute_flops(0.0);
982   }
983 }
984
985 /** @brief actually run the replay after initialization */
986 void smpi_replay_main(int* argc, char*** argv)
987 {
988   static int active_processes = 0;
989   active_processes++;
990   simgrid::xbt::replay_runner(*argc, *argv);
991
992   /* and now, finalize everything */
993   /* One active process will stop. Decrease the counter*/
994   unsigned int count_requests = storage[simgrid::s4u::this_actor::get_pid() - 1].size();
995   XBT_DEBUG("There are %ud elements in reqq[*]", count_requests);
996   if (count_requests > 0) {
997     MPI_Request requests[count_requests];
998     MPI_Status status[count_requests];
999     unsigned int i=0;
1000
1001     for (auto const& pair : storage[simgrid::s4u::this_actor::get_pid() - 1].get_store()) {
1002       requests[i] = pair.second;
1003       i++;
1004     }
1005     simgrid::smpi::Request::waitall(count_requests, requests, status);
1006   }
1007   active_processes--;
1008
1009   if(active_processes==0){
1010     /* Last process alive speaking: end the simulated timer */
1011     XBT_INFO("Simulation time %f", smpi_process()->simulated_elapsed());
1012     smpi_free_replay_tmp_buffers();
1013   }
1014
1015   TRACE_smpi_comm_in(simgrid::s4u::this_actor::get_pid(), "smpi_replay_run_finalize",
1016                      new simgrid::instr::NoOpTIData("finalize"));
1017
1018   smpi_process()->finalize();
1019
1020   TRACE_smpi_comm_out(simgrid::s4u::this_actor::get_pid());
1021   TRACE_smpi_finalize(simgrid::s4u::this_actor::get_pid());
1022 }
1023
1024 /** @brief chain a replay initialization and a replay start */
1025 void smpi_replay_run(int* argc, char*** argv)
1026 {
1027   smpi_replay_init(argc, argv);
1028   smpi_replay_main(argc, argv);
1029 }