Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
4744eda445af2ae99cd05f73e3055a378d85d960
[simgrid.git] / src / smpi / internals / smpi_replay.cpp
1 /* Copyright (c) 2009-2018. The SimGrid Team. All rights reserved.          */
2
3 /* This program is free software; you can redistribute it and/or modify it
4  * under the terms of the license (GNU LGPL) which comes with this package. */
5
6 #include "private.hpp"
7 #include "smpi_coll.hpp"
8 #include "smpi_comm.hpp"
9 #include "smpi_datatype.hpp"
10 #include "smpi_group.hpp"
11 #include "smpi_process.hpp"
12 #include "smpi_request.hpp"
13 #include "xbt/replay.hpp"
14
15 #include <boost/algorithm/string/join.hpp>
16 #include <memory>
17 #include <numeric>
18 #include <unordered_map>
19 #include <sstream>
20 #include <vector>
21
22 #include <tuple>
23 // From https://stackoverflow.com/questions/7110301/generic-hash-for-tuples-in-unordered-map-unordered-set
24 // This is all just to make std::unordered_map work with std::tuple. If we need this in other places,
25 // this could go into a header file.
26 namespace hash_tuple{
27     template <typename TT>
28     class hash
29     {
30     public:
31         size_t
32         operator()(TT const& tt) const
33         {
34             return std::hash<TT>()(tt);
35         }
36     };
37
38     template <class T>
39     inline void hash_combine(std::size_t& seed, T const& v)
40     {
41         seed ^= hash_tuple::hash<T>()(v) + 0x9e3779b9 + (seed<<6) + (seed>>2);
42     }
43
44     // Recursive template code derived from Matthieu M.
45     template <class Tuple, size_t Index = std::tuple_size<Tuple>::value - 1>
46     class HashValueImpl
47     {
48     public:
49       static void apply(size_t& seed, Tuple const& tuple)
50       {
51         HashValueImpl<Tuple, Index-1>::apply(seed, tuple);
52         hash_combine(seed, std::get<Index>(tuple));
53       }
54     };
55
56     template <class Tuple>
57     class HashValueImpl<Tuple,0>
58     {
59     public:
60       static void apply(size_t& seed, Tuple const& tuple)
61       {
62         hash_combine(seed, std::get<0>(tuple));
63       }
64     };
65
66     template <typename ... TT>
67     class hash<std::tuple<TT...>>
68     {
69     public:
70         size_t
71         operator()(std::tuple<TT...> const& tt) const
72         {
73             size_t seed = 0;
74             HashValueImpl<std::tuple<TT...> >::apply(seed, tt);
75             return seed;
76         }
77     };
78 }
79
80 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_replay,smpi,"Trace Replay with SMPI");
81
82 typedef std::tuple</*sender*/ int, /* reciever */ int, /* tag */int> req_key_t;
83 typedef std::unordered_map<req_key_t, MPI_Request, hash_tuple::hash<std::tuple<int,int,int>>> req_storage_t;
84
85 static MPI_Datatype MPI_DEFAULT_TYPE;
86
87 #define CHECK_ACTION_PARAMS(action, mandatory, optional)                                                               \
88   {                                                                                                                    \
89     if (action.size() < static_cast<unsigned long>(mandatory + 2)) {                                                   \
90       std::stringstream ss;                                                                                            \
91       for (const auto& elem : action) {                                                                                \
92         ss << elem << " ";                                                                                             \
93       }                                                                                                                \
94       THROWF(arg_error, 0, "%s replay failed.\n"                                                                       \
95                            "%zu items were given on the line. First two should be process_id and action.  "            \
96                            "This action needs after them %lu mandatory arguments, and accepts %lu optional ones. \n"   \
97                            "The full line that was given is:\n   %s\n"                                                 \
98                            "Please contact the Simgrid team if support is needed",                                     \
99              __func__, action.size(), static_cast<unsigned long>(mandatory), static_cast<unsigned long>(optional),     \
100              ss.str().c_str());                                                                                        \
101     }                                                                                                                  \
102   }
103
104 static void log_timed_action(simgrid::xbt::ReplayAction& action, double clock)
105 {
106   if (XBT_LOG_ISENABLED(smpi_replay, xbt_log_priority_verbose)){
107     std::string s = boost::algorithm::join(action, " ");
108     XBT_VERB("%s %f", s.c_str(), smpi_process()->simulated_elapsed() - clock);
109   }
110 }
111
112 /* Helper function */
113 static double parse_double(std::string string)
114 {
115   return xbt_str_parse_double(string.c_str(), "%s is not a double");
116 }
117
118 namespace simgrid {
119 namespace smpi {
120
121 namespace replay {
122
123 class RequestStorage {
124 private:
125     req_storage_t store;
126
127 public:
128     RequestStorage() {}
129     int size()
130     {
131       return store.size();
132     }
133
134     req_storage_t& get_store()
135     {
136       return store;
137     }
138
139     void get_requests(std::vector<MPI_Request>& vec)
140     {
141       for (auto& pair : store) {
142         auto& req = pair.second;
143         auto my_proc_id = simgrid::s4u::this_actor::get_pid();
144         if (req != MPI_REQUEST_NULL && (req->src() == my_proc_id || req->dst() == my_proc_id)) {
145           vec.push_back(pair.second);
146           pair.second->print_request("MM");
147         }
148       }
149     }
150
151     MPI_Request find(int src, int dst, int tag)
152     {
153       req_storage_t::iterator it = store.find(req_key_t(src, dst, tag));
154       return (it == store.end()) ? MPI_REQUEST_NULL : it->second;
155     }
156
157     void remove(MPI_Request req)
158     {
159       if (req == MPI_REQUEST_NULL) return;
160
161       store.erase(req_key_t(req->src()-1, req->dst()-1, req->tag()));
162     }
163
164     void add(MPI_Request req)
165     {
166       if (req != MPI_REQUEST_NULL) // Can and does happen in the case of TestAction
167         store.insert({req_key_t(req->src()-1, req->dst()-1, req->tag()), req});
168     }
169
170     /* Sometimes we need to re-insert MPI_REQUEST_NULL but we still need src,dst and tag */
171     void addNullRequest(int src, int dst, int tag)
172     {
173       store.insert({req_key_t(src, dst, tag), MPI_REQUEST_NULL});
174     }
175 };
176
177 class ActionArgParser {
178 public:
179   virtual ~ActionArgParser() = default;
180   virtual void parse(simgrid::xbt::ReplayAction& action, std::string name) { CHECK_ACTION_PARAMS(action, 0, 0) }
181 };
182
183 class WaitTestParser : public ActionArgParser {
184 public:
185   int src;
186   int dst;
187   int tag;
188
189   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
190   {
191     CHECK_ACTION_PARAMS(action, 3, 0)
192     src = std::stoi(action[2]);
193     dst = std::stoi(action[3]);
194     tag = std::stoi(action[4]);
195   }
196 };
197
198 class SendRecvParser : public ActionArgParser {
199 public:
200   /* communication partner; if we send, this is the receiver and vice versa */
201   int partner;
202   double size;
203   int tag;
204   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
205
206   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
207   {
208     CHECK_ACTION_PARAMS(action, 3, 1)
209     partner = std::stoi(action[2]);
210     tag     = std::stoi(action[3]);
211     size    = parse_double(action[4]);
212     if (action.size() > 5)
213       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
214   }
215 };
216
217 class ComputeParser : public ActionArgParser {
218 public:
219   /* communication partner; if we send, this is the receiver and vice versa */
220   double flops;
221
222   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
223   {
224     CHECK_ACTION_PARAMS(action, 1, 0)
225     flops = parse_double(action[2]);
226   }
227 };
228
229 class CollCommParser : public ActionArgParser {
230 public:
231   double size;
232   double comm_size;
233   double comp_size;
234   int send_size;
235   int recv_size;
236   int root = 0;
237   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
238   MPI_Datatype datatype2 = MPI_DEFAULT_TYPE;
239 };
240
241 class BcastArgParser : public CollCommParser {
242 public:
243   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
244   {
245     CHECK_ACTION_PARAMS(action, 1, 2)
246     size = parse_double(action[2]);
247     root = (action.size() > 3) ? std::stoi(action[3]) : 0;
248     if (action.size() > 4)
249       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
250   }
251 };
252
253 class ReduceArgParser : public CollCommParser {
254 public:
255   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
256   {
257     CHECK_ACTION_PARAMS(action, 2, 2)
258     comm_size = parse_double(action[2]);
259     comp_size = parse_double(action[3]);
260     root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
261     if (action.size() > 5)
262       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
263   }
264 };
265
266 class AllReduceArgParser : public CollCommParser {
267 public:
268   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
269   {
270     CHECK_ACTION_PARAMS(action, 2, 1)
271     comm_size = parse_double(action[2]);
272     comp_size = parse_double(action[3]);
273     if (action.size() > 4)
274       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
275   }
276 };
277
278 class AllToAllArgParser : public CollCommParser {
279 public:
280   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
281   {
282     CHECK_ACTION_PARAMS(action, 2, 1)
283     comm_size = MPI_COMM_WORLD->size();
284     send_size = parse_double(action[2]);
285     recv_size = parse_double(action[3]);
286
287     if (action.size() > 4)
288       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
289     if (action.size() > 5)
290       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
291   }
292 };
293
294 class GatherArgParser : public CollCommParser {
295 public:
296   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
297   {
298     /* The structure of the gather action for the rank 0 (total 4 processes) is the following:
299           0 gather 68 68 0 0 0
300         where:
301           1) 68 is the sendcounts
302           2) 68 is the recvcounts
303           3) 0 is the root node
304           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
305           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
306     */
307     CHECK_ACTION_PARAMS(action, 2, 3)
308     comm_size = MPI_COMM_WORLD->size();
309     send_size = parse_double(action[2]);
310     recv_size = parse_double(action[3]);
311
312     if (name == "gather") {
313       root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
314       if (action.size() > 5)
315         datatype1 = simgrid::smpi::Datatype::decode(action[5]);
316       if (action.size() > 6)
317         datatype2 = simgrid::smpi::Datatype::decode(action[6]);
318     }
319     else {
320       if (action.size() > 4)
321         datatype1 = simgrid::smpi::Datatype::decode(action[4]);
322       if (action.size() > 5)
323         datatype2 = simgrid::smpi::Datatype::decode(action[5]);
324     }
325   }
326 };
327
328 class GatherVArgParser : public CollCommParser {
329 public:
330   int recv_size_sum;
331   std::shared_ptr<std::vector<int>> recvcounts;
332   std::vector<int> disps;
333   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
334   {
335     /* The structure of the gatherv action for the rank 0 (total 4 processes) is the following:
336          0 gather 68 68 10 10 10 0 0 0
337        where:
338          1) 68 is the sendcount
339          2) 68 10 10 10 is the recvcounts
340          3) 0 is the root node
341          4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
342          5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
343     */
344     comm_size = MPI_COMM_WORLD->size();
345     CHECK_ACTION_PARAMS(action, comm_size+1, 2)
346     send_size = parse_double(action[2]);
347     disps     = std::vector<int>(comm_size, 0);
348     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
349
350     if (name == "gatherV") {
351       root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
352       if (action.size() > 4 + comm_size)
353         datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
354       if (action.size() > 5 + comm_size)
355         datatype2 = simgrid::smpi::Datatype::decode(action[5 + comm_size]);
356     }
357     else {
358       int datatype_index = 0;
359       int disp_index     = 0;
360       /* The 3 comes from "0 gather <sendcount>", which must always be present.
361        * The + comm_size is the recvcounts array, which must also be present
362        */
363       if (action.size() > 3 + comm_size + comm_size) { /* datatype + disp are specified */
364         datatype_index = 3 + comm_size;
365         disp_index     = datatype_index + 1;
366         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
367         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
368       } else if (action.size() > 3 + comm_size + 2) { /* disps specified; datatype is not specified; use the default one */
369         disp_index     = 3 + comm_size;
370       } else if (action.size() > 3 + comm_size)  { /* only datatype, no disp specified */
371         datatype_index = 3 + comm_size;
372         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
373         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
374       }
375
376       if (disp_index != 0) {
377         for (unsigned int i = 0; i < comm_size; i++)
378           disps[i]          = std::stoi(action[disp_index + i]);
379       }
380     }
381
382     for (unsigned int i = 0; i < comm_size; i++) {
383       (*recvcounts)[i] = std::stoi(action[i + 3]);
384     }
385     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
386   }
387 };
388
389 class ScatterArgParser : public CollCommParser {
390 public:
391   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
392   {
393     /* The structure of the scatter action for the rank 0 (total 4 processes) is the following:
394           0 gather 68 68 0 0 0
395         where:
396           1) 68 is the sendcounts
397           2) 68 is the recvcounts
398           3) 0 is the root node
399           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
400           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
401     */
402     CHECK_ACTION_PARAMS(action, 2, 3)
403     comm_size   = MPI_COMM_WORLD->size();
404     send_size   = parse_double(action[2]);
405     recv_size   = parse_double(action[3]);
406     root   = (action.size() > 4) ? std::stoi(action[4]) : 0;
407     if (action.size() > 5)
408       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
409     if (action.size() > 6)
410       datatype2 = simgrid::smpi::Datatype::decode(action[6]);
411   }
412 };
413
414 class ScatterVArgParser : public CollCommParser {
415 public:
416   int recv_size_sum;
417   int send_size_sum;
418   std::shared_ptr<std::vector<int>> sendcounts;
419   std::vector<int> disps;
420   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
421   {
422     /* The structure of the scatterv action for the rank 0 (total 4 processes) is the following:
423        0 gather 68 10 10 10 68 0 0 0
424         where:
425         1) 68 10 10 10 is the sendcounts
426         2) 68 is the recvcount
427         3) 0 is the root node
428         4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
429         5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
430     */
431     CHECK_ACTION_PARAMS(action, comm_size + 1, 2)
432     recv_size  = parse_double(action[2 + comm_size]);
433     disps      = std::vector<int>(comm_size, 0);
434     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
435
436     if (action.size() > 5 + comm_size)
437       datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
438     if (action.size() > 5 + comm_size)
439       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
440
441     for (unsigned int i = 0; i < comm_size; i++) {
442       (*sendcounts)[i] = std::stoi(action[i + 2]);
443     }
444     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
445     root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
446   }
447 };
448
449 class ReduceScatterArgParser : public CollCommParser {
450 public:
451   int recv_size_sum;
452   std::shared_ptr<std::vector<int>> recvcounts;
453   std::vector<int> disps;
454   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
455   {
456     /* The structure of the reducescatter action for the rank 0 (total 4 processes) is the following:
457          0 reduceScatter 275427 275427 275427 204020 11346849 0
458        where:
459          1) The first four values after the name of the action declare the recvcounts array
460          2) The value 11346849 is the amount of instructions
461          3) The last value corresponds to the datatype, see simgrid::smpi::Datatype::decode().
462     */
463     comm_size = MPI_COMM_WORLD->size();
464     CHECK_ACTION_PARAMS(action, comm_size+1, 1)
465     comp_size = parse_double(action[2+comm_size]);
466     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
467     if (action.size() > 3 + comm_size)
468       datatype1 = simgrid::smpi::Datatype::decode(action[3 + comm_size]);
469
470     for (unsigned int i = 0; i < comm_size; i++) {
471       recvcounts->push_back(std::stoi(action[i + 2]));
472     }
473     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
474   }
475 };
476
477 class AllToAllVArgParser : public CollCommParser {
478 public:
479   int recv_size_sum;
480   int send_size_sum;
481   std::shared_ptr<std::vector<int>> recvcounts;
482   std::shared_ptr<std::vector<int>> sendcounts;
483   std::vector<int> senddisps;
484   std::vector<int> recvdisps;
485   int send_buf_size;
486   int recv_buf_size;
487   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
488   {
489     /* The structure of the allToAllV action for the rank 0 (total 4 processes) is the following:
490           0 allToAllV 100 1 7 10 12 100 1 70 10 5
491        where:
492         1) 100 is the size of the send buffer *sizeof(int),
493         2) 1 7 10 12 is the sendcounts array
494         3) 100*sizeof(int) is the size of the receiver buffer
495         4)  1 70 10 5 is the recvcounts array
496     */
497     comm_size = MPI_COMM_WORLD->size();
498     CHECK_ACTION_PARAMS(action, 2*comm_size+2, 2)
499     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
500     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
501     senddisps  = std::vector<int>(comm_size, 0);
502     recvdisps  = std::vector<int>(comm_size, 0);
503
504     if (action.size() > 5 + 2 * comm_size)
505       datatype1 = simgrid::smpi::Datatype::decode(action[4 + 2 * comm_size]);
506     if (action.size() > 5 + 2 * comm_size)
507       datatype2 = simgrid::smpi::Datatype::decode(action[5 + 2 * comm_size]);
508
509     send_buf_size=parse_double(action[2]);
510     recv_buf_size=parse_double(action[3+comm_size]);
511     for (unsigned int i = 0; i < comm_size; i++) {
512       (*sendcounts)[i] = std::stoi(action[3 + i]);
513       (*recvcounts)[i] = std::stoi(action[4 + comm_size + i]);
514     }
515     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
516     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
517   }
518 };
519
520 template <class T> class ReplayAction {
521 protected:
522   const std::string name;
523   RequestStorage* req_storage; // Points to the right storage for this process, nullptr except for Send/Recv/Wait/Test actions.
524   const int my_proc_id;
525   T args;
526
527 public:
528   explicit ReplayAction(std::string name, RequestStorage& storage) : name(name), req_storage(&storage), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
529   explicit ReplayAction(std::string name) : name(name), req_storage(nullptr), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
530   virtual ~ReplayAction() = default;
531
532   virtual void execute(simgrid::xbt::ReplayAction& action)
533   {
534     // Needs to be re-initialized for every action, hence here
535     double start_time = smpi_process()->simulated_elapsed();
536     args.parse(action, name);
537     kernel(action);
538     if (name != "Init")
539       log_timed_action(action, start_time);
540   }
541
542   virtual void kernel(simgrid::xbt::ReplayAction& action) = 0;
543
544   void* send_buffer(int size)
545   {
546     return smpi_get_tmp_sendbuffer(size);
547   }
548
549   void* recv_buffer(int size)
550   {
551     return smpi_get_tmp_recvbuffer(size);
552   }
553 };
554
555 class WaitAction : public ReplayAction<WaitTestParser> {
556 public:
557   explicit WaitAction(RequestStorage& storage) : ReplayAction("Wait", storage) {}
558   void kernel(simgrid::xbt::ReplayAction& action) override
559   {
560     std::string s = boost::algorithm::join(action, " ");
561     xbt_assert(req_storage->size(), "action wait not preceded by any irecv or isend: %s", s.c_str());
562     MPI_Request request = req_storage->find(args.src, args.dst, args.tag);
563     req_storage->remove(request);
564
565     if (request == MPI_REQUEST_NULL) {
566       /* Assume that the trace is well formed, meaning the comm might have been caught by a MPI_test. Then just
567        * return.*/
568       return;
569     }
570
571     int rank = request->comm() != MPI_COMM_NULL ? request->comm()->rank() : -1;
572
573     // Must be taken before Request::wait() since the request may be set to
574     // MPI_REQUEST_NULL by Request::wait!
575     bool is_wait_for_receive = (request->flags() & RECV);
576     // TODO: Here we take the rank while we normally take the process id (look for my_proc_id)
577     TRACE_smpi_comm_in(rank, __func__, new simgrid::instr::NoOpTIData("wait"));
578
579     MPI_Status status;
580     Request::wait(&request, &status);
581
582     TRACE_smpi_comm_out(rank);
583     if (is_wait_for_receive)
584       TRACE_smpi_recv(args.src, args.dst, args.tag);
585   }
586 };
587
588 class SendAction : public ReplayAction<SendRecvParser> {
589 public:
590   SendAction() = delete;
591   explicit SendAction(std::string name, RequestStorage& storage) : ReplayAction(name, storage) {}
592   void kernel(simgrid::xbt::ReplayAction& action) override
593   {
594     int dst_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
595
596     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
597                                                                              args.tag, Datatype::encode(args.datatype1)));
598     if (not TRACE_smpi_view_internals())
599       TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, args.tag, args.size * args.datatype1->size());
600
601     if (name == "send") {
602       Request::send(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
603     } else if (name == "Isend") {
604       MPI_Request request = Request::isend(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
605       req_storage->add(request);
606     } else {
607       xbt_die("Don't know this action, %s", name.c_str());
608     }
609
610     TRACE_smpi_comm_out(my_proc_id);
611   }
612 };
613
614 class RecvAction : public ReplayAction<SendRecvParser> {
615 public:
616   RecvAction() = delete;
617   explicit RecvAction(std::string name, RequestStorage& storage) : ReplayAction(name, storage) {}
618   void kernel(simgrid::xbt::ReplayAction& action) override
619   {
620     int src_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
621
622     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
623                                                                              args.tag, Datatype::encode(args.datatype1)));
624
625     MPI_Status status;
626     // unknown size from the receiver point of view
627     if (args.size <= 0.0) {
628       Request::probe(args.partner, args.tag, MPI_COMM_WORLD, &status);
629       args.size = status.count;
630     }
631
632     if (name == "recv") {
633       Request::recv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD, &status);
634     } else if (name == "Irecv") {
635       MPI_Request request = Request::irecv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
636       req_storage->add(request);
637     }
638
639     TRACE_smpi_comm_out(my_proc_id);
640     // TODO: Check why this was only activated in the "recv" case and not in the "Irecv" case
641     if (name == "recv" && not TRACE_smpi_view_internals()) {
642       TRACE_smpi_recv(src_traced, my_proc_id, args.tag);
643     }
644   }
645 };
646
647 class ComputeAction : public ReplayAction<ComputeParser> {
648 public:
649   ComputeAction() : ReplayAction("compute") {}
650   void kernel(simgrid::xbt::ReplayAction& action) override
651   {
652     TRACE_smpi_computing_in(my_proc_id, args.flops);
653     smpi_execute_flops(args.flops);
654     TRACE_smpi_computing_out(my_proc_id);
655   }
656 };
657
658 class TestAction : public ReplayAction<WaitTestParser> {
659 public:
660   explicit TestAction(RequestStorage& storage) : ReplayAction("Test", storage) {}
661   void kernel(simgrid::xbt::ReplayAction& action) override
662   {
663     MPI_Request request = req_storage->find(args.src, args.dst, args.tag);
664     req_storage->remove(request);
665     // if request is null here, this may mean that a previous test has succeeded
666     // Different times in traced application and replayed version may lead to this
667     // In this case, ignore the extra calls.
668     if (request != MPI_REQUEST_NULL) {
669       TRACE_smpi_testing_in(my_proc_id);
670
671       MPI_Status status;
672       int flag = Request::test(&request, &status);
673
674       XBT_DEBUG("MPI_Test result: %d", flag);
675       /* push back request in vector to be caught by a subsequent wait. if the test did succeed, the request is now
676        * nullptr.*/
677       if (request == MPI_REQUEST_NULL)
678         req_storage->addNullRequest(args.src, args.dst, args.tag);
679       else
680         req_storage->add(request);
681
682       TRACE_smpi_testing_out(my_proc_id);
683     }
684   }
685 };
686
687 class InitAction : public ReplayAction<ActionArgParser> {
688 public:
689   InitAction() : ReplayAction("Init") {}
690   void kernel(simgrid::xbt::ReplayAction& action) override
691   {
692     CHECK_ACTION_PARAMS(action, 0, 1)
693     MPI_DEFAULT_TYPE = (action.size() > 2) ? MPI_DOUBLE // default MPE datatype
694                                            : MPI_BYTE;  // default TAU datatype
695
696     /* start a simulated timer */
697     smpi_process()->simulated_start();
698   }
699 };
700
701 class CommunicatorAction : public ReplayAction<ActionArgParser> {
702 public:
703   CommunicatorAction() : ReplayAction("Comm") {}
704   void kernel(simgrid::xbt::ReplayAction& action) override { /* nothing to do */}
705 };
706
707 class WaitAllAction : public ReplayAction<ActionArgParser> {
708 public:
709   explicit WaitAllAction(RequestStorage& storage) : ReplayAction("waitAll", storage) {}
710   void kernel(simgrid::xbt::ReplayAction& action) override
711   {
712     const unsigned int count_requests = req_storage->size();
713
714     if (count_requests > 0) {
715       TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData("waitAll", -1, count_requests, ""));
716       std::vector<std::pair</*sender*/int,/*recv*/int>> sender_receiver;
717       std::vector<MPI_Request> reqs;
718       req_storage->get_requests(reqs);
719       for (const auto& req : reqs) {
720         if (req && (req->flags() & RECV)) {
721           sender_receiver.push_back({req->src(), req->dst()});
722         }
723       }
724       MPI_Status status[count_requests];
725       Request::waitall(count_requests, &(reqs.data())[0], status);
726       req_storage->get_store().clear();
727
728       for (auto& pair : sender_receiver) {
729         TRACE_smpi_recv(pair.first, pair.second, 0);
730       }
731       TRACE_smpi_comm_out(my_proc_id);
732     }
733   }
734 };
735
736 class BarrierAction : public ReplayAction<ActionArgParser> {
737 public:
738   BarrierAction() : ReplayAction("barrier") {}
739   void kernel(simgrid::xbt::ReplayAction& action) override
740   {
741     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("barrier"));
742     Colls::barrier(MPI_COMM_WORLD);
743     TRACE_smpi_comm_out(my_proc_id);
744   }
745 };
746
747 class BcastAction : public ReplayAction<BcastArgParser> {
748 public:
749   BcastAction() : ReplayAction("bcast") {}
750   void kernel(simgrid::xbt::ReplayAction& action) override
751   {
752     TRACE_smpi_comm_in(my_proc_id, "action_bcast",
753                        new simgrid::instr::CollTIData("bcast", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
754                                                       -1.0, args.size, -1, Datatype::encode(args.datatype1), ""));
755
756     Colls::bcast(send_buffer(args.size * args.datatype1->size()), args.size, args.datatype1, args.root, MPI_COMM_WORLD);
757
758     TRACE_smpi_comm_out(my_proc_id);
759   }
760 };
761
762 class ReduceAction : public ReplayAction<ReduceArgParser> {
763 public:
764   ReduceAction() : ReplayAction("reduce") {}
765   void kernel(simgrid::xbt::ReplayAction& action) override
766   {
767     TRACE_smpi_comm_in(my_proc_id, "action_reduce",
768                        new simgrid::instr::CollTIData("reduce", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
769                                                       args.comp_size, args.comm_size, -1,
770                                                       Datatype::encode(args.datatype1), ""));
771
772     Colls::reduce(send_buffer(args.comm_size * args.datatype1->size()),
773         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, args.root, MPI_COMM_WORLD);
774     smpi_execute_flops(args.comp_size);
775
776     TRACE_smpi_comm_out(my_proc_id);
777   }
778 };
779
780 class AllReduceAction : public ReplayAction<AllReduceArgParser> {
781 public:
782   AllReduceAction() : ReplayAction("allReduce") {}
783   void kernel(simgrid::xbt::ReplayAction& action) override
784   {
785     TRACE_smpi_comm_in(my_proc_id, "action_allReduce", new simgrid::instr::CollTIData("allReduce", -1, args.comp_size, args.comm_size, -1,
786                                                                                 Datatype::encode(args.datatype1), ""));
787
788     Colls::allreduce(send_buffer(args.comm_size * args.datatype1->size()),
789         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
790     smpi_execute_flops(args.comp_size);
791
792     TRACE_smpi_comm_out(my_proc_id);
793   }
794 };
795
796 class AllToAllAction : public ReplayAction<AllToAllArgParser> {
797 public:
798   AllToAllAction() : ReplayAction("allToAll") {}
799   void kernel(simgrid::xbt::ReplayAction& action) override
800   {
801     TRACE_smpi_comm_in(my_proc_id, "action_allToAll",
802                      new simgrid::instr::CollTIData("allToAll", -1, -1.0, args.send_size, args.recv_size,
803                                                     Datatype::encode(args.datatype1),
804                                                     Datatype::encode(args.datatype2)));
805
806     Colls::alltoall(send_buffer(args.send_size * args.comm_size * args.datatype1->size()), args.send_size,
807                     args.datatype1, recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()),
808                     args.recv_size, args.datatype2, MPI_COMM_WORLD);
809
810     TRACE_smpi_comm_out(my_proc_id);
811   }
812 };
813
814 class GatherAction : public ReplayAction<GatherArgParser> {
815 public:
816   explicit GatherAction(std::string name) : ReplayAction(name) {}
817   void kernel(simgrid::xbt::ReplayAction& action) override
818   {
819     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::CollTIData(name, (name == "gather") ? args.root : -1, -1.0, args.send_size, args.recv_size,
820                                                                           Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
821
822     if (name == "gather") {
823       int rank = MPI_COMM_WORLD->rank();
824       Colls::gather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
825                  (rank == args.root) ? recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
826     }
827     else
828       Colls::allgather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
829                        recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, MPI_COMM_WORLD);
830
831     TRACE_smpi_comm_out(my_proc_id);
832   }
833 };
834
835 class GatherVAction : public ReplayAction<GatherVArgParser> {
836 public:
837   explicit GatherVAction(std::string name) : ReplayAction(name) {}
838   void kernel(simgrid::xbt::ReplayAction& action) override
839   {
840     int rank = MPI_COMM_WORLD->rank();
841
842     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::VarCollTIData(
843                                                name, (name == "gatherV") ? args.root : -1, args.send_size, nullptr, -1, args.recvcounts,
844                                                Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
845
846     if (name == "gatherV") {
847       Colls::gatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
848                      (rank == args.root) ? recv_buffer(args.recv_size_sum * args.datatype2->size()) : nullptr,
849                      args.recvcounts->data(), args.disps.data(), args.datatype2, args.root, MPI_COMM_WORLD);
850     }
851     else {
852       Colls::allgatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
853                         recv_buffer(args.recv_size_sum * args.datatype2->size()), args.recvcounts->data(),
854                         args.disps.data(), args.datatype2, MPI_COMM_WORLD);
855     }
856
857     TRACE_smpi_comm_out(my_proc_id);
858   }
859 };
860
861 class ScatterAction : public ReplayAction<ScatterArgParser> {
862 public:
863   ScatterAction() : ReplayAction("scatter") {}
864   void kernel(simgrid::xbt::ReplayAction& action) override
865   {
866     int rank = MPI_COMM_WORLD->rank();
867     TRACE_smpi_comm_in(my_proc_id, "action_scatter", new simgrid::instr::CollTIData(name, args.root, -1.0, args.send_size, args.recv_size,
868                                                                           Datatype::encode(args.datatype1),
869                                                                           Datatype::encode(args.datatype2)));
870
871     Colls::scatter(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
872                   (rank == args.root) ? recv_buffer(args.recv_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
873
874     TRACE_smpi_comm_out(my_proc_id);
875   }
876 };
877
878
879 class ScatterVAction : public ReplayAction<ScatterVArgParser> {
880 public:
881   ScatterVAction() : ReplayAction("scatterV") {}
882   void kernel(simgrid::xbt::ReplayAction& action) override
883   {
884     int rank = MPI_COMM_WORLD->rank();
885     TRACE_smpi_comm_in(my_proc_id, "action_scatterv", new simgrid::instr::VarCollTIData(name, args.root, -1, args.sendcounts, args.recv_size,
886           nullptr, Datatype::encode(args.datatype1),
887           Datatype::encode(args.datatype2)));
888
889     Colls::scatterv((rank == args.root) ? send_buffer(args.send_size_sum * args.datatype1->size()) : nullptr,
890                     args.sendcounts->data(), args.disps.data(), args.datatype1,
891                     recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, args.root,
892                     MPI_COMM_WORLD);
893
894     TRACE_smpi_comm_out(my_proc_id);
895   }
896 };
897
898 class ReduceScatterAction : public ReplayAction<ReduceScatterArgParser> {
899 public:
900   ReduceScatterAction() : ReplayAction("reduceScatter") {}
901   void kernel(simgrid::xbt::ReplayAction& action) override
902   {
903     TRACE_smpi_comm_in(my_proc_id, "action_reducescatter",
904                        new simgrid::instr::VarCollTIData("reduceScatter", -1, 0, nullptr, -1, args.recvcounts,
905                                                          std::to_string(args.comp_size), /* ugly hack to print comp_size */
906                                                          Datatype::encode(args.datatype1)));
907
908     Colls::reduce_scatter(send_buffer(args.recv_size_sum * args.datatype1->size()),
909                           recv_buffer(args.recv_size_sum * args.datatype1->size()), args.recvcounts->data(),
910                           args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
911
912     smpi_execute_flops(args.comp_size);
913     TRACE_smpi_comm_out(my_proc_id);
914   }
915 };
916
917 class AllToAllVAction : public ReplayAction<AllToAllVArgParser> {
918 public:
919   AllToAllVAction() : ReplayAction("allToAllV") {}
920   void kernel(simgrid::xbt::ReplayAction& action) override
921   {
922     TRACE_smpi_comm_in(my_proc_id, __func__,
923                        new simgrid::instr::VarCollTIData(
924                            "allToAllV", -1, args.send_size_sum, args.sendcounts, args.recv_size_sum, args.recvcounts,
925                            Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
926
927     Colls::alltoallv(send_buffer(args.send_buf_size * args.datatype1->size()), args.sendcounts->data(), args.senddisps.data(), args.datatype1,
928                      recv_buffer(args.recv_buf_size * args.datatype2->size()), args.recvcounts->data(), args.recvdisps.data(), args.datatype2, MPI_COMM_WORLD);
929
930     TRACE_smpi_comm_out(my_proc_id);
931   }
932 };
933 } // Replay Namespace
934 }} // namespace simgrid::smpi
935
936 std::vector<simgrid::smpi::replay::RequestStorage> storage;
937 /** @brief Only initialize the replay, don't do it for real */
938 void smpi_replay_init(int* argc, char*** argv)
939 {
940   simgrid::smpi::Process::init(argc, argv);
941   smpi_process()->mark_as_initialized();
942   smpi_process()->set_replaying(true);
943
944   int my_proc_id = simgrid::s4u::this_actor::get_pid();
945   storage.resize(smpi_process_count());
946
947   TRACE_smpi_init(my_proc_id);
948   TRACE_smpi_computing_init(my_proc_id);
949   TRACE_smpi_comm_in(my_proc_id, "smpi_replay_run_init", new simgrid::instr::NoOpTIData("init"));
950   TRACE_smpi_comm_out(my_proc_id);
951   xbt_replay_action_register("init", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::InitAction().execute(action); });
952   xbt_replay_action_register("finalize", [](simgrid::xbt::ReplayAction& action) { /* nothing to do */ });
953   xbt_replay_action_register("comm_size", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
954   xbt_replay_action_register("comm_split",[](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
955   xbt_replay_action_register("comm_dup",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
956
957   xbt_replay_action_register("send",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("send", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
958   xbt_replay_action_register("Isend", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("Isend", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
959   xbt_replay_action_register("recv",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("recv", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
960   xbt_replay_action_register("Irecv", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("Irecv", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
961   xbt_replay_action_register("test",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::TestAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
962   xbt_replay_action_register("wait",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
963   xbt_replay_action_register("waitAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAllAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
964   xbt_replay_action_register("barrier", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BarrierAction().execute(action); });
965   xbt_replay_action_register("bcast",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BcastAction().execute(action); });
966   xbt_replay_action_register("reduce",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceAction().execute(action); });
967   xbt_replay_action_register("allReduce", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllReduceAction().execute(action); });
968   xbt_replay_action_register("allToAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllAction().execute(action); });
969   xbt_replay_action_register("allToAllV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllVAction().execute(action); });
970   xbt_replay_action_register("gather",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("gather").execute(action); });
971   xbt_replay_action_register("scatter",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterAction().execute(action); });
972   xbt_replay_action_register("gatherV",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("gatherV").execute(action); });
973   xbt_replay_action_register("scatterV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterVAction().execute(action); });
974   xbt_replay_action_register("allGather", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("allGather").execute(action); });
975   xbt_replay_action_register("allGatherV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("allGatherV").execute(action); });
976   xbt_replay_action_register("reduceScatter", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceScatterAction().execute(action); });
977   xbt_replay_action_register("compute", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ComputeAction().execute(action); });
978
979   //if we have a delayed start, sleep here.
980   if(*argc>2){
981     double value = xbt_str_parse_double((*argv)[2], "%s is not a double");
982     XBT_VERB("Delayed start for instance - Sleeping for %f flops ",value );
983     smpi_execute_flops(value);
984   } else {
985     //UGLY: force a context switch to be sure that all MSG_processes begin initialization
986     XBT_DEBUG("Force context switch by smpi_execute_flops  - Sleeping for 0.0 flops ");
987     smpi_execute_flops(0.0);
988   }
989 }
990
991 /** @brief actually run the replay after initialization */
992 void smpi_replay_main(int* argc, char*** argv)
993 {
994   static int active_processes = 0;
995   active_processes++;
996   simgrid::xbt::replay_runner(*argc, *argv);
997
998   /* and now, finalize everything */
999   /* One active process will stop. Decrease the counter*/
1000   unsigned int count_requests = storage[simgrid::s4u::this_actor::get_pid() - 1].size();
1001   XBT_DEBUG("There are %ud elements in reqq[*]", count_requests);
1002   if (count_requests > 0) {
1003     MPI_Request requests[count_requests];
1004     MPI_Status status[count_requests];
1005     unsigned int i=0;
1006
1007     for (auto const& pair : storage[simgrid::s4u::this_actor::get_pid() - 1].get_store()) {
1008       requests[i] = pair.second;
1009       i++;
1010     }
1011     simgrid::smpi::Request::waitall(count_requests, requests, status);
1012   }
1013   active_processes--;
1014
1015   if(active_processes==0){
1016     /* Last process alive speaking: end the simulated timer */
1017     XBT_INFO("Simulation time %f", smpi_process()->simulated_elapsed());
1018     smpi_free_replay_tmp_buffers();
1019   }
1020
1021   TRACE_smpi_comm_in(simgrid::s4u::this_actor::get_pid(), "smpi_replay_run_finalize",
1022                      new simgrid::instr::NoOpTIData("finalize"));
1023
1024   smpi_process()->finalize();
1025
1026   TRACE_smpi_comm_out(simgrid::s4u::this_actor::get_pid());
1027   TRACE_smpi_finalize(simgrid::s4u::this_actor::get_pid());
1028 }
1029
1030 /** @brief chain a replay initialization and a replay start */
1031 void smpi_replay_run(int* argc, char*** argv)
1032 {
1033   smpi_replay_init(argc, argv);
1034   smpi_replay_main(argc, argv);
1035 }