Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
[SMPI] Replay: Move TestAction to the RequestStore
[simgrid.git] / src / smpi / internals / smpi_replay.cpp
1 /* Copyright (c) 2009-2018. The SimGrid Team. All rights reserved.          */
2
3 /* This program is free software; you can redistribute it and/or modify it
4  * under the terms of the license (GNU LGPL) which comes with this package. */
5
6 #include "private.hpp"
7 #include "smpi_coll.hpp"
8 #include "smpi_comm.hpp"
9 #include "smpi_datatype.hpp"
10 #include "smpi_group.hpp"
11 #include "smpi_process.hpp"
12 #include "smpi_request.hpp"
13 #include "xbt/replay.hpp"
14
15 #include <boost/algorithm/string/join.hpp>
16 #include <memory>
17 #include <numeric>
18 #include <unordered_map>
19 #include <sstream>
20 #include <vector>
21
22 using simgrid::s4u::Actor;
23
24 #include <tuple>
25 // From https://stackoverflow.com/questions/7110301/generic-hash-for-tuples-in-unordered-map-unordered-set
26 // This is all just to make std::unordered_map work with std::tuple. If we need this in other places,
27 // this could go into a header file.
28 namespace hash_tuple{
29     template <typename TT>
30     struct hash
31     {
32         size_t
33         operator()(TT const& tt) const
34         {
35             return std::hash<TT>()(tt);
36         }
37     };
38
39     template <class T>
40     inline void hash_combine(std::size_t& seed, T const& v)
41     {
42         seed ^= hash_tuple::hash<T>()(v) + 0x9e3779b9 + (seed<<6) + (seed>>2);
43     }
44
45     // Recursive template code derived from Matthieu M.
46     template <class Tuple, size_t Index = std::tuple_size<Tuple>::value - 1>
47     struct HashValueImpl
48     {
49       static void apply(size_t& seed, Tuple const& tuple)
50       {
51         HashValueImpl<Tuple, Index-1>::apply(seed, tuple);
52         hash_combine(seed, std::get<Index>(tuple));
53       }
54     };
55
56     template <class Tuple>
57     struct HashValueImpl<Tuple,0>
58     {
59       static void apply(size_t& seed, Tuple const& tuple)
60       {
61         hash_combine(seed, std::get<0>(tuple));
62       }
63     };
64
65     template <typename ... TT>
66     struct hash<std::tuple<TT...>>
67     {
68         size_t
69         operator()(std::tuple<TT...> const& tt) const
70         {
71             size_t seed = 0;
72             HashValueImpl<std::tuple<TT...> >::apply(seed, tt);
73             return seed;
74         }
75     };
76 }
77
78 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_replay,smpi,"Trace Replay with SMPI");
79
80 static std::unordered_map<int, std::vector<MPI_Request>*> reqq;
81 typedef std::tuple</*sender*/ int, /* reciever */ int, /* tag */int> req_key_t;
82 typedef std::unordered_map<req_key_t, MPI_Request, hash_tuple::hash<std::tuple<int,int,int>>> req_storage_t;
83
84 static MPI_Datatype MPI_DEFAULT_TYPE;
85
86 #define CHECK_ACTION_PARAMS(action, mandatory, optional)                                                               \
87   {                                                                                                                    \
88     if (action.size() < static_cast<unsigned long>(mandatory + 2)) {                                                   \
89       std::stringstream ss;                                                                                            \
90       for (const auto& elem : action) {                                                                                \
91         ss << elem << " ";                                                                                             \
92       }                                                                                                                \
93       THROWF(arg_error, 0, "%s replay failed.\n"                                                                       \
94                            "%zu items were given on the line. First two should be process_id and action.  "            \
95                            "This action needs after them %lu mandatory arguments, and accepts %lu optional ones. \n"   \
96                            "The full line that was given is:\n   %s\n"                                                 \
97                            "Please contact the Simgrid team if support is needed",                                     \
98              __func__, action.size(), static_cast<unsigned long>(mandatory), static_cast<unsigned long>(optional),     \
99              ss.str().c_str());                                                                                        \
100     }                                                                                                                  \
101   }
102
103 static void log_timed_action(simgrid::xbt::ReplayAction& action, double clock)
104 {
105   if (XBT_LOG_ISENABLED(smpi_replay, xbt_log_priority_verbose)){
106     std::string s = boost::algorithm::join(action, " ");
107     XBT_VERB("%s %f", s.c_str(), smpi_process()->simulated_elapsed() - clock);
108   }
109 }
110
111 static std::vector<MPI_Request>* get_reqq_self()
112 {
113   return reqq.at(simgrid::s4u::this_actor::get_pid());
114 }
115
116 static void set_reqq_self(std::vector<MPI_Request> *mpi_request)
117 {
118   reqq.insert({simgrid::s4u::this_actor::get_pid(), mpi_request});
119 }
120
121 /* Helper function */
122 static double parse_double(std::string string)
123 {
124   return xbt_str_parse_double(string.c_str(), "%s is not a double");
125 }
126
127 namespace simgrid {
128 namespace smpi {
129
130 namespace replay {
131
132 class RequestStorage {
133 private:
134     req_storage_t store;
135
136 public:
137     RequestStorage() {}
138     int size()
139     {
140       return store.size();
141     }
142
143     req_storage_t& get_store()
144     {
145       return store;
146     }
147
148     void get_requests(std::vector<MPI_Request>& vec)
149     {
150       for (auto& pair : store) {
151         auto& req = pair.second;
152         auto my_proc_id = simgrid::s4u::this_actor::get_pid();
153         if (req != MPI_REQUEST_NULL && (req->src() == my_proc_id || req->dst() == my_proc_id)) {
154           vec.push_back(pair.second);
155           pair.second->print_request("MM");
156         }
157       }
158     }
159
160     MPI_Request find(int src, int dst, int tag)
161     {
162       req_storage_t::iterator it = store.find(req_key_t(src, dst, tag));
163       return (it == store.end()) ? MPI_REQUEST_NULL : it->second;
164     }
165
166     void remove(MPI_Request req)
167     {
168       if (req == MPI_REQUEST_NULL) return;
169
170       store.erase(req_key_t(req->src()-1, req->dst()-1, req->tag()));
171     }
172
173     void add(MPI_Request req)
174     {
175       if (req != MPI_REQUEST_NULL) // Can and does happen in the case of TestAction
176         store.insert({req_key_t(req->src()-1, req->dst()-1, req->tag()), req});
177     }
178
179     /* Sometimes we need to re-insert MPI_REQUEST_NULL but we still need src,dst and tag */
180     void addNullRequest(int src, int dst, int tag)
181     {
182       store.insert({req_key_t(src, dst, tag), MPI_REQUEST_NULL});
183     }
184 };
185
186 class ActionArgParser {
187 public:
188   virtual ~ActionArgParser() = default;
189   virtual void parse(simgrid::xbt::ReplayAction& action, std::string name) { CHECK_ACTION_PARAMS(action, 0, 0) }
190 };
191
192 class WaitTestParser : public ActionArgParser {
193 public:
194   int src;
195   int dst;
196   int tag;
197
198   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
199   {
200     CHECK_ACTION_PARAMS(action, 3, 0)
201     src = std::stoi(action[2]);
202     dst = std::stoi(action[3]);
203     tag = std::stoi(action[4]);
204   }
205 };
206
207 class SendRecvParser : public ActionArgParser {
208 public:
209   /* communication partner; if we send, this is the receiver and vice versa */
210   int partner;
211   double size;
212   int tag;
213   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
214
215   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
216   {
217     CHECK_ACTION_PARAMS(action, 3, 1)
218     partner = std::stoi(action[2]);
219     tag     = std::stoi(action[3]);
220     size    = parse_double(action[4]);
221     if (action.size() > 5)
222       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
223   }
224 };
225
226 class ComputeParser : public ActionArgParser {
227 public:
228   /* communication partner; if we send, this is the receiver and vice versa */
229   double flops;
230
231   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
232   {
233     CHECK_ACTION_PARAMS(action, 1, 0)
234     flops = parse_double(action[2]);
235   }
236 };
237
238 class CollCommParser : public ActionArgParser {
239 public:
240   double size;
241   double comm_size;
242   double comp_size;
243   int send_size;
244   int recv_size;
245   int root = 0;
246   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
247   MPI_Datatype datatype2 = MPI_DEFAULT_TYPE;
248 };
249
250 class BcastArgParser : public CollCommParser {
251 public:
252   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
253   {
254     CHECK_ACTION_PARAMS(action, 1, 2)
255     size = parse_double(action[2]);
256     root = (action.size() > 3) ? std::stoi(action[3]) : 0;
257     if (action.size() > 4)
258       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
259   }
260 };
261
262 class ReduceArgParser : public CollCommParser {
263 public:
264   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
265   {
266     CHECK_ACTION_PARAMS(action, 2, 2)
267     comm_size = parse_double(action[2]);
268     comp_size = parse_double(action[3]);
269     root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
270     if (action.size() > 5)
271       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
272   }
273 };
274
275 class AllReduceArgParser : public CollCommParser {
276 public:
277   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
278   {
279     CHECK_ACTION_PARAMS(action, 2, 1)
280     comm_size = parse_double(action[2]);
281     comp_size = parse_double(action[3]);
282     if (action.size() > 4)
283       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
284   }
285 };
286
287 class AllToAllArgParser : public CollCommParser {
288 public:
289   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
290   {
291     CHECK_ACTION_PARAMS(action, 2, 1)
292     comm_size = MPI_COMM_WORLD->size();
293     send_size = parse_double(action[2]);
294     recv_size = parse_double(action[3]);
295
296     if (action.size() > 4)
297       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
298     if (action.size() > 5)
299       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
300   }
301 };
302
303 class GatherArgParser : public CollCommParser {
304 public:
305   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
306   {
307     /* The structure of the gather action for the rank 0 (total 4 processes) is the following:
308           0 gather 68 68 0 0 0
309         where:
310           1) 68 is the sendcounts
311           2) 68 is the recvcounts
312           3) 0 is the root node
313           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
314           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
315     */
316     CHECK_ACTION_PARAMS(action, 2, 3)
317     comm_size = MPI_COMM_WORLD->size();
318     send_size = parse_double(action[2]);
319     recv_size = parse_double(action[3]);
320
321     if (name == "gather") {
322       root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
323       if (action.size() > 5)
324         datatype1 = simgrid::smpi::Datatype::decode(action[5]);
325       if (action.size() > 6)
326         datatype2 = simgrid::smpi::Datatype::decode(action[6]);
327     }
328     else {
329       if (action.size() > 4)
330         datatype1 = simgrid::smpi::Datatype::decode(action[4]);
331       if (action.size() > 5)
332         datatype2 = simgrid::smpi::Datatype::decode(action[5]);
333     }
334   }
335 };
336
337 class GatherVArgParser : public CollCommParser {
338 public:
339   int recv_size_sum;
340   std::shared_ptr<std::vector<int>> recvcounts;
341   std::vector<int> disps;
342   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
343   {
344     /* The structure of the gatherv action for the rank 0 (total 4 processes) is the following:
345          0 gather 68 68 10 10 10 0 0 0
346        where:
347          1) 68 is the sendcount
348          2) 68 10 10 10 is the recvcounts
349          3) 0 is the root node
350          4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
351          5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
352     */
353     comm_size = MPI_COMM_WORLD->size();
354     CHECK_ACTION_PARAMS(action, comm_size+1, 2)
355     send_size = parse_double(action[2]);
356     disps     = std::vector<int>(comm_size, 0);
357     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
358
359     if (name == "gatherV") {
360       root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
361       if (action.size() > 4 + comm_size)
362         datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
363       if (action.size() > 5 + comm_size)
364         datatype2 = simgrid::smpi::Datatype::decode(action[5 + comm_size]);
365     }
366     else {
367       int datatype_index = 0;
368       int disp_index     = 0;
369       /* The 3 comes from "0 gather <sendcount>", which must always be present.
370        * The + comm_size is the recvcounts array, which must also be present
371        */
372       if (action.size() > 3 + comm_size + comm_size) { /* datatype + disp are specified */
373         datatype_index = 3 + comm_size;
374         disp_index     = datatype_index + 1;
375         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
376         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
377       } else if (action.size() > 3 + comm_size + 2) { /* disps specified; datatype is not specified; use the default one */
378         disp_index     = 3 + comm_size;
379       } else if (action.size() > 3 + comm_size)  { /* only datatype, no disp specified */
380         datatype_index = 3 + comm_size;
381         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
382         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
383       }
384
385       if (disp_index != 0) {
386         for (unsigned int i = 0; i < comm_size; i++)
387           disps[i]          = std::stoi(action[disp_index + i]);
388       }
389     }
390
391     for (unsigned int i = 0; i < comm_size; i++) {
392       (*recvcounts)[i] = std::stoi(action[i + 3]);
393     }
394     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
395   }
396 };
397
398 class ScatterArgParser : public CollCommParser {
399 public:
400   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
401   {
402     /* The structure of the scatter action for the rank 0 (total 4 processes) is the following:
403           0 gather 68 68 0 0 0
404         where:
405           1) 68 is the sendcounts
406           2) 68 is the recvcounts
407           3) 0 is the root node
408           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
409           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
410     */
411     CHECK_ACTION_PARAMS(action, 2, 3)
412     comm_size   = MPI_COMM_WORLD->size();
413     send_size   = parse_double(action[2]);
414     recv_size   = parse_double(action[3]);
415     root   = (action.size() > 4) ? std::stoi(action[4]) : 0;
416     if (action.size() > 5)
417       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
418     if (action.size() > 6)
419       datatype2 = simgrid::smpi::Datatype::decode(action[6]);
420   }
421 };
422
423 class ScatterVArgParser : public CollCommParser {
424 public:
425   int recv_size_sum;
426   int send_size_sum;
427   std::shared_ptr<std::vector<int>> sendcounts;
428   std::vector<int> disps;
429   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
430   {
431     /* The structure of the scatterv action for the rank 0 (total 4 processes) is the following:
432        0 gather 68 10 10 10 68 0 0 0
433         where:
434         1) 68 10 10 10 is the sendcounts
435         2) 68 is the recvcount
436         3) 0 is the root node
437         4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
438         5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
439     */
440     CHECK_ACTION_PARAMS(action, comm_size + 1, 2)
441     recv_size  = parse_double(action[2 + comm_size]);
442     disps      = std::vector<int>(comm_size, 0);
443     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
444
445     if (action.size() > 5 + comm_size)
446       datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
447     if (action.size() > 5 + comm_size)
448       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
449
450     for (unsigned int i = 0; i < comm_size; i++) {
451       (*sendcounts)[i] = std::stoi(action[i + 2]);
452     }
453     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
454     root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
455   }
456 };
457
458 class ReduceScatterArgParser : public CollCommParser {
459 public:
460   int recv_size_sum;
461   std::shared_ptr<std::vector<int>> recvcounts;
462   std::vector<int> disps;
463   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
464   {
465     /* The structure of the reducescatter action for the rank 0 (total 4 processes) is the following:
466          0 reduceScatter 275427 275427 275427 204020 11346849 0
467        where:
468          1) The first four values after the name of the action declare the recvcounts array
469          2) The value 11346849 is the amount of instructions
470          3) The last value corresponds to the datatype, see simgrid::smpi::Datatype::decode().
471     */
472     comm_size = MPI_COMM_WORLD->size();
473     CHECK_ACTION_PARAMS(action, comm_size+1, 1)
474     comp_size = parse_double(action[2+comm_size]);
475     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
476     if (action.size() > 3 + comm_size)
477       datatype1 = simgrid::smpi::Datatype::decode(action[3 + comm_size]);
478
479     for (unsigned int i = 0; i < comm_size; i++) {
480       recvcounts->push_back(std::stoi(action[i + 2]));
481     }
482     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
483   }
484 };
485
486 class AllToAllVArgParser : public CollCommParser {
487 public:
488   int recv_size_sum;
489   int send_size_sum;
490   std::shared_ptr<std::vector<int>> recvcounts;
491   std::shared_ptr<std::vector<int>> sendcounts;
492   std::vector<int> senddisps;
493   std::vector<int> recvdisps;
494   int send_buf_size;
495   int recv_buf_size;
496   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
497   {
498     /* The structure of the allToAllV action for the rank 0 (total 4 processes) is the following:
499           0 allToAllV 100 1 7 10 12 100 1 70 10 5
500        where:
501         1) 100 is the size of the send buffer *sizeof(int),
502         2) 1 7 10 12 is the sendcounts array
503         3) 100*sizeof(int) is the size of the receiver buffer
504         4)  1 70 10 5 is the recvcounts array
505     */
506     comm_size = MPI_COMM_WORLD->size();
507     CHECK_ACTION_PARAMS(action, 2*comm_size+2, 2)
508     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
509     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
510     senddisps  = std::vector<int>(comm_size, 0);
511     recvdisps  = std::vector<int>(comm_size, 0);
512
513     if (action.size() > 5 + 2 * comm_size)
514       datatype1 = simgrid::smpi::Datatype::decode(action[4 + 2 * comm_size]);
515     if (action.size() > 5 + 2 * comm_size)
516       datatype2 = simgrid::smpi::Datatype::decode(action[5 + 2 * comm_size]);
517
518     send_buf_size=parse_double(action[2]);
519     recv_buf_size=parse_double(action[3+comm_size]);
520     for (unsigned int i = 0; i < comm_size; i++) {
521       (*sendcounts)[i] = std::stoi(action[3 + i]);
522       (*recvcounts)[i] = std::stoi(action[4 + comm_size + i]);
523     }
524     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
525     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
526   }
527 };
528
529 template <class T> class ReplayAction {
530 protected:
531   const std::string name;
532   RequestStorage* req_storage; // Points to the right storage for this process, nullptr except for Send/Recv/Wait/Test actions.
533   const int my_proc_id;
534   T args;
535
536 public:
537   explicit ReplayAction(std::string name, RequestStorage& storage) : name(name), req_storage(&storage), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
538   explicit ReplayAction(std::string name) : name(name), req_storage(nullptr), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
539   virtual ~ReplayAction() = default;
540
541   virtual void execute(simgrid::xbt::ReplayAction& action)
542   {
543     // Needs to be re-initialized for every action, hence here
544     double start_time = smpi_process()->simulated_elapsed();
545     args.parse(action, name);
546     kernel(action);
547     if (name != "Init")
548       log_timed_action(action, start_time);
549   }
550
551   virtual void kernel(simgrid::xbt::ReplayAction& action) = 0;
552
553   void* send_buffer(int size)
554   {
555     return smpi_get_tmp_sendbuffer(size);
556   }
557
558   void* recv_buffer(int size)
559   {
560     return smpi_get_tmp_recvbuffer(size);
561   }
562 };
563
564 class WaitAction : public ReplayAction<WaitTestParser> {
565 public:
566   WaitAction(RequestStorage& storage) : ReplayAction("Wait", storage) {}
567   void kernel(simgrid::xbt::ReplayAction& action) override
568   {
569     std::string s = boost::algorithm::join(action, " ");
570     xbt_assert(req_storage->size(), "action wait not preceded by any irecv or isend: %s", s.c_str());
571     MPI_Request request = req_storage->find(args.src, args.dst, args.tag);
572     req_storage->remove(request);
573
574     if (request == MPI_REQUEST_NULL) {
575       /* Assume that the trace is well formed, meaning the comm might have been caught by a MPI_test. Then just
576        * return.*/
577       return;
578     }
579
580     int rank = request->comm() != MPI_COMM_NULL ? request->comm()->rank() : -1;
581
582     // Must be taken before Request::wait() since the request may be set to
583     // MPI_REQUEST_NULL by Request::wait!
584     bool is_wait_for_receive = (request->flags() & RECV);
585     // TODO: Here we take the rank while we normally take the process id (look for my_proc_id)
586     TRACE_smpi_comm_in(rank, __func__, new simgrid::instr::NoOpTIData("wait"));
587
588     MPI_Status status;
589     Request::wait(&request, &status);
590
591     TRACE_smpi_comm_out(rank);
592     if (is_wait_for_receive)
593       TRACE_smpi_recv(args.src, args.dst, args.tag);
594   }
595 };
596
597 class SendAction : public ReplayAction<SendRecvParser> {
598 public:
599   SendAction() = delete;
600   explicit SendAction(std::string name, RequestStorage& storage) : ReplayAction(name, storage) {}
601   void kernel(simgrid::xbt::ReplayAction& action) override
602   {
603     int dst_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
604
605     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
606                                                                              args.tag, Datatype::encode(args.datatype1)));
607     if (not TRACE_smpi_view_internals())
608       TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, args.tag, args.size * args.datatype1->size());
609
610     if (name == "send") {
611       Request::send(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
612     } else if (name == "Isend") {
613       MPI_Request request = Request::isend(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
614       req_storage->add(request);
615     } else {
616       xbt_die("Don't know this action, %s", name.c_str());
617     }
618
619     TRACE_smpi_comm_out(my_proc_id);
620   }
621 };
622
623 class RecvAction : public ReplayAction<SendRecvParser> {
624 public:
625   RecvAction() = delete;
626   explicit RecvAction(std::string name, RequestStorage& storage) : ReplayAction(name, storage) {}
627   void kernel(simgrid::xbt::ReplayAction& action) override
628   {
629     int src_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
630
631     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
632                                                                              args.tag, Datatype::encode(args.datatype1)));
633
634     MPI_Status status;
635     // unknown size from the receiver point of view
636     if (args.size <= 0.0) {
637       Request::probe(args.partner, args.tag, MPI_COMM_WORLD, &status);
638       args.size = status.count;
639     }
640
641     if (name == "recv") {
642       Request::recv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD, &status);
643     } else if (name == "Irecv") {
644       MPI_Request request = Request::irecv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
645       req_storage->add(request);
646     }
647
648     TRACE_smpi_comm_out(my_proc_id);
649     // TODO: Check why this was only activated in the "recv" case and not in the "Irecv" case
650     if (name == "recv" && not TRACE_smpi_view_internals()) {
651       TRACE_smpi_recv(src_traced, my_proc_id, args.tag);
652     }
653   }
654 };
655
656 class ComputeAction : public ReplayAction<ComputeParser> {
657 public:
658   ComputeAction() : ReplayAction("compute") {}
659   void kernel(simgrid::xbt::ReplayAction& action) override
660   {
661     TRACE_smpi_computing_in(my_proc_id, args.flops);
662     smpi_execute_flops(args.flops);
663     TRACE_smpi_computing_out(my_proc_id);
664   }
665 };
666
667 class TestAction : public ReplayAction<WaitTestParser> {
668 public:
669   TestAction(RequestStorage& storage) : ReplayAction("Test", storage) {}
670   void kernel(simgrid::xbt::ReplayAction& action) override
671   {
672     MPI_Request request = req_storage->find(args.src, args.dst, args.tag);
673     req_storage->remove(request);
674     // if request is null here, this may mean that a previous test has succeeded
675     // Different times in traced application and replayed version may lead to this
676     // In this case, ignore the extra calls.
677     if (request != MPI_REQUEST_NULL) {
678       TRACE_smpi_testing_in(my_proc_id);
679
680       MPI_Status status;
681       int flag = Request::test(&request, &status);
682
683       XBT_DEBUG("MPI_Test result: %d", flag);
684       /* push back request in vector to be caught by a subsequent wait. if the test did succeed, the request is now
685        * nullptr.*/
686       if (request == MPI_REQUEST_NULL)
687         req_storage->addNullRequest(args.src, args.dst, args.tag);
688       else
689         req_storage->add(request);
690
691       TRACE_smpi_testing_out(my_proc_id);
692     }
693   }
694 };
695
696 class InitAction : public ReplayAction<ActionArgParser> {
697 public:
698   InitAction() : ReplayAction("Init") {}
699   void kernel(simgrid::xbt::ReplayAction& action) override
700   {
701     CHECK_ACTION_PARAMS(action, 0, 1)
702     MPI_DEFAULT_TYPE = (action.size() > 2) ? MPI_DOUBLE // default MPE datatype
703                                            : MPI_BYTE;  // default TAU datatype
704
705     /* start a simulated timer */
706     smpi_process()->simulated_start();
707     set_reqq_self(new std::vector<MPI_Request>);
708   }
709 };
710
711 class CommunicatorAction : public ReplayAction<ActionArgParser> {
712 public:
713   CommunicatorAction() : ReplayAction("Comm") {}
714   void kernel(simgrid::xbt::ReplayAction& action) override { /* nothing to do */}
715 };
716
717 class WaitAllAction : public ReplayAction<ActionArgParser> {
718 public:
719   WaitAllAction(RequestStorage& storage) : ReplayAction("waitAll", storage) {}
720   void kernel(simgrid::xbt::ReplayAction& action) override
721   {
722     const unsigned int count_requests = get_reqq_self()->size();
723
724     if (count_requests > 0) {
725       TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData("waitAll", -1, count_requests, ""));
726       std::vector<std::pair</*sender*/int,/*recv*/int>> sender_receiver;
727       for (const auto& req : (*get_reqq_self())) {
728         if (req && (req->flags() & RECV)) {
729           sender_receiver.push_back({req->src(), req->dst()});
730         }
731       }
732       MPI_Status status[count_requests];
733       Request::waitall(count_requests, &(*get_reqq_self())[0], status);
734
735       for (auto& pair : sender_receiver) {
736         TRACE_smpi_recv(pair.first, pair.second, 0);
737       }
738       TRACE_smpi_comm_out(my_proc_id);
739     }
740   }
741 };
742
743 class BarrierAction : public ReplayAction<ActionArgParser> {
744 public:
745   BarrierAction() : ReplayAction("barrier") {}
746   void kernel(simgrid::xbt::ReplayAction& action) override
747   {
748     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("barrier"));
749     Colls::barrier(MPI_COMM_WORLD);
750     TRACE_smpi_comm_out(my_proc_id);
751   }
752 };
753
754 class BcastAction : public ReplayAction<BcastArgParser> {
755 public:
756   BcastAction() : ReplayAction("bcast") {}
757   void kernel(simgrid::xbt::ReplayAction& action) override
758   {
759     TRACE_smpi_comm_in(my_proc_id, "action_bcast",
760                        new simgrid::instr::CollTIData("bcast", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
761                                                       -1.0, args.size, -1, Datatype::encode(args.datatype1), ""));
762
763     Colls::bcast(send_buffer(args.size * args.datatype1->size()), args.size, args.datatype1, args.root, MPI_COMM_WORLD);
764
765     TRACE_smpi_comm_out(my_proc_id);
766   }
767 };
768
769 class ReduceAction : public ReplayAction<ReduceArgParser> {
770 public:
771   ReduceAction() : ReplayAction("reduce") {}
772   void kernel(simgrid::xbt::ReplayAction& action) override
773   {
774     TRACE_smpi_comm_in(my_proc_id, "action_reduce",
775                        new simgrid::instr::CollTIData("reduce", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
776                                                       args.comp_size, args.comm_size, -1,
777                                                       Datatype::encode(args.datatype1), ""));
778
779     Colls::reduce(send_buffer(args.comm_size * args.datatype1->size()),
780         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, args.root, MPI_COMM_WORLD);
781     smpi_execute_flops(args.comp_size);
782
783     TRACE_smpi_comm_out(my_proc_id);
784   }
785 };
786
787 class AllReduceAction : public ReplayAction<AllReduceArgParser> {
788 public:
789   AllReduceAction() : ReplayAction("allReduce") {}
790   void kernel(simgrid::xbt::ReplayAction& action) override
791   {
792     TRACE_smpi_comm_in(my_proc_id, "action_allReduce", new simgrid::instr::CollTIData("allReduce", -1, args.comp_size, args.comm_size, -1,
793                                                                                 Datatype::encode(args.datatype1), ""));
794
795     Colls::allreduce(send_buffer(args.comm_size * args.datatype1->size()),
796         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
797     smpi_execute_flops(args.comp_size);
798
799     TRACE_smpi_comm_out(my_proc_id);
800   }
801 };
802
803 class AllToAllAction : public ReplayAction<AllToAllArgParser> {
804 public:
805   AllToAllAction() : ReplayAction("allToAll") {}
806   void kernel(simgrid::xbt::ReplayAction& action) override
807   {
808     TRACE_smpi_comm_in(my_proc_id, "action_allToAll",
809                      new simgrid::instr::CollTIData("allToAll", -1, -1.0, args.send_size, args.recv_size,
810                                                     Datatype::encode(args.datatype1),
811                                                     Datatype::encode(args.datatype2)));
812
813     Colls::alltoall(send_buffer(args.send_size * args.comm_size * args.datatype1->size()), args.send_size,
814                     args.datatype1, recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()),
815                     args.recv_size, args.datatype2, MPI_COMM_WORLD);
816
817     TRACE_smpi_comm_out(my_proc_id);
818   }
819 };
820
821 class GatherAction : public ReplayAction<GatherArgParser> {
822 public:
823   explicit GatherAction(std::string name) : ReplayAction(name) {}
824   void kernel(simgrid::xbt::ReplayAction& action) override
825   {
826     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::CollTIData(name, (name == "gather") ? args.root : -1, -1.0, args.send_size, args.recv_size,
827                                                                           Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
828
829     if (name == "gather") {
830       int rank = MPI_COMM_WORLD->rank();
831       Colls::gather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
832                  (rank == args.root) ? recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
833     }
834     else
835       Colls::allgather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
836                        recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, MPI_COMM_WORLD);
837
838     TRACE_smpi_comm_out(my_proc_id);
839   }
840 };
841
842 class GatherVAction : public ReplayAction<GatherVArgParser> {
843 public:
844   explicit GatherVAction(std::string name) : ReplayAction(name) {}
845   void kernel(simgrid::xbt::ReplayAction& action) override
846   {
847     int rank = MPI_COMM_WORLD->rank();
848
849     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::VarCollTIData(
850                                                name, (name == "gatherV") ? args.root : -1, args.send_size, nullptr, -1, args.recvcounts,
851                                                Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
852
853     if (name == "gatherV") {
854       Colls::gatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
855                      (rank == args.root) ? recv_buffer(args.recv_size_sum * args.datatype2->size()) : nullptr,
856                      args.recvcounts->data(), args.disps.data(), args.datatype2, args.root, MPI_COMM_WORLD);
857     }
858     else {
859       Colls::allgatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
860                         recv_buffer(args.recv_size_sum * args.datatype2->size()), args.recvcounts->data(),
861                         args.disps.data(), args.datatype2, MPI_COMM_WORLD);
862     }
863
864     TRACE_smpi_comm_out(my_proc_id);
865   }
866 };
867
868 class ScatterAction : public ReplayAction<ScatterArgParser> {
869 public:
870   ScatterAction() : ReplayAction("scatter") {}
871   void kernel(simgrid::xbt::ReplayAction& action) override
872   {
873     int rank = MPI_COMM_WORLD->rank();
874     TRACE_smpi_comm_in(my_proc_id, "action_scatter", new simgrid::instr::CollTIData(name, args.root, -1.0, args.send_size, args.recv_size,
875                                                                           Datatype::encode(args.datatype1),
876                                                                           Datatype::encode(args.datatype2)));
877
878     Colls::scatter(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
879                   (rank == args.root) ? recv_buffer(args.recv_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
880
881     TRACE_smpi_comm_out(my_proc_id);
882   }
883 };
884
885
886 class ScatterVAction : public ReplayAction<ScatterVArgParser> {
887 public:
888   ScatterVAction() : ReplayAction("scatterV") {}
889   void kernel(simgrid::xbt::ReplayAction& action) override
890   {
891     int rank = MPI_COMM_WORLD->rank();
892     TRACE_smpi_comm_in(my_proc_id, "action_scatterv", new simgrid::instr::VarCollTIData(name, args.root, -1, args.sendcounts, args.recv_size,
893           nullptr, Datatype::encode(args.datatype1),
894           Datatype::encode(args.datatype2)));
895
896     Colls::scatterv((rank == args.root) ? send_buffer(args.send_size_sum * args.datatype1->size()) : nullptr,
897                     args.sendcounts->data(), args.disps.data(), args.datatype1,
898                     recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, args.root,
899                     MPI_COMM_WORLD);
900
901     TRACE_smpi_comm_out(my_proc_id);
902   }
903 };
904
905 class ReduceScatterAction : public ReplayAction<ReduceScatterArgParser> {
906 public:
907   ReduceScatterAction() : ReplayAction("reduceScatter") {}
908   void kernel(simgrid::xbt::ReplayAction& action) override
909   {
910     TRACE_smpi_comm_in(my_proc_id, "action_reducescatter",
911                        new simgrid::instr::VarCollTIData("reduceScatter", -1, 0, nullptr, -1, args.recvcounts,
912                                                          std::to_string(args.comp_size), /* ugly hack to print comp_size */
913                                                          Datatype::encode(args.datatype1)));
914
915     Colls::reduce_scatter(send_buffer(args.recv_size_sum * args.datatype1->size()),
916                           recv_buffer(args.recv_size_sum * args.datatype1->size()), args.recvcounts->data(),
917                           args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
918
919     smpi_execute_flops(args.comp_size);
920     TRACE_smpi_comm_out(my_proc_id);
921   }
922 };
923
924 class AllToAllVAction : public ReplayAction<AllToAllVArgParser> {
925 public:
926   AllToAllVAction() : ReplayAction("allToAllV") {}
927   void kernel(simgrid::xbt::ReplayAction& action) override
928   {
929     TRACE_smpi_comm_in(my_proc_id, __func__,
930                        new simgrid::instr::VarCollTIData(
931                            "allToAllV", -1, args.send_size_sum, args.sendcounts, args.recv_size_sum, args.recvcounts,
932                            Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
933
934     Colls::alltoallv(send_buffer(args.send_buf_size * args.datatype1->size()), args.sendcounts->data(), args.senddisps.data(), args.datatype1,
935                      recv_buffer(args.recv_buf_size * args.datatype2->size()), args.recvcounts->data(), args.recvdisps.data(), args.datatype2, MPI_COMM_WORLD);
936
937     TRACE_smpi_comm_out(my_proc_id);
938   }
939 };
940 } // Replay Namespace
941 }} // namespace simgrid::smpi
942
943 std::vector<simgrid::smpi::replay::RequestStorage> storage;
944 /** @brief Only initialize the replay, don't do it for real */
945 void smpi_replay_init(int* argc, char*** argv)
946 {
947   simgrid::smpi::Process::init(argc, argv);
948   smpi_process()->mark_as_initialized();
949   smpi_process()->set_replaying(true);
950
951   int my_proc_id = simgrid::s4u::this_actor::get_pid();
952   storage.resize(smpi_process_count());
953
954   TRACE_smpi_init(my_proc_id);
955   TRACE_smpi_computing_init(my_proc_id);
956   TRACE_smpi_comm_in(my_proc_id, "smpi_replay_run_init", new simgrid::instr::NoOpTIData("init"));
957   TRACE_smpi_comm_out(my_proc_id);
958   xbt_replay_action_register("init", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::InitAction().execute(action); });
959   xbt_replay_action_register("finalize", [](simgrid::xbt::ReplayAction& action) { /* nothing to do */ });
960   xbt_replay_action_register("comm_size", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
961   xbt_replay_action_register("comm_split",[](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
962   xbt_replay_action_register("comm_dup",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
963
964   xbt_replay_action_register("send",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("send", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
965   xbt_replay_action_register("Isend", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("Isend", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
966   xbt_replay_action_register("recv",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("recv", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
967   xbt_replay_action_register("Irecv", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("Irecv", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
968   xbt_replay_action_register("test",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::TestAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
969   xbt_replay_action_register("wait",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
970   xbt_replay_action_register("waitAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAllAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
971   xbt_replay_action_register("barrier", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BarrierAction().execute(action); });
972   xbt_replay_action_register("bcast",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BcastAction().execute(action); });
973   xbt_replay_action_register("reduce",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceAction().execute(action); });
974   xbt_replay_action_register("allReduce", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllReduceAction().execute(action); });
975   xbt_replay_action_register("allToAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllAction().execute(action); });
976   xbt_replay_action_register("allToAllV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllVAction().execute(action); });
977   xbt_replay_action_register("gather",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("gather").execute(action); });
978   xbt_replay_action_register("scatter",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterAction().execute(action); });
979   xbt_replay_action_register("gatherV",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("gatherV").execute(action); });
980   xbt_replay_action_register("scatterV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterVAction().execute(action); });
981   xbt_replay_action_register("allGather", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("allGather").execute(action); });
982   xbt_replay_action_register("allGatherV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("allGatherV").execute(action); });
983   xbt_replay_action_register("reduceScatter", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceScatterAction().execute(action); });
984   xbt_replay_action_register("compute", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ComputeAction().execute(action); });
985
986   //if we have a delayed start, sleep here.
987   if(*argc>2){
988     double value = xbt_str_parse_double((*argv)[2], "%s is not a double");
989     XBT_VERB("Delayed start for instance - Sleeping for %f flops ",value );
990     smpi_execute_flops(value);
991   } else {
992     //UGLY: force a context switch to be sure that all MSG_processes begin initialization
993     XBT_DEBUG("Force context switch by smpi_execute_flops  - Sleeping for 0.0 flops ");
994     smpi_execute_flops(0.0);
995   }
996 }
997
998 /** @brief actually run the replay after initialization */
999 void smpi_replay_main(int* argc, char*** argv)
1000 {
1001   static int active_processes = 0;
1002   active_processes++;
1003   simgrid::xbt::replay_runner(*argc, *argv);
1004
1005   /* and now, finalize everything */
1006   /* One active process will stop. Decrease the counter*/
1007   XBT_DEBUG("There are %zu elements in reqq[*]", get_reqq_self()->size());
1008   if (not get_reqq_self()->empty()) {
1009     unsigned int count_requests=get_reqq_self()->size();
1010     MPI_Request requests[count_requests];
1011     MPI_Status status[count_requests];
1012     unsigned int i=0;
1013
1014     for (auto const& req : *get_reqq_self()) {
1015       requests[i] = req;
1016       i++;
1017     }
1018     simgrid::smpi::Request::waitall(count_requests, requests, status);
1019   }
1020   delete get_reqq_self();
1021   active_processes--;
1022
1023   if(active_processes==0){
1024     /* Last process alive speaking: end the simulated timer */
1025     XBT_INFO("Simulation time %f", smpi_process()->simulated_elapsed());
1026     smpi_free_replay_tmp_buffers();
1027   }
1028
1029   TRACE_smpi_comm_in(simgrid::s4u::this_actor::get_pid(), "smpi_replay_run_finalize",
1030                      new simgrid::instr::NoOpTIData("finalize"));
1031
1032   smpi_process()->finalize();
1033
1034   TRACE_smpi_comm_out(simgrid::s4u::this_actor::get_pid());
1035   TRACE_smpi_finalize(simgrid::s4u::this_actor::get_pid());
1036 }
1037
1038 /** @brief chain a replay initialization and a replay start */
1039 void smpi_replay_run(int* argc, char*** argv)
1040 {
1041   smpi_replay_init(argc, argv);
1042   smpi_replay_main(argc, argv);
1043 }