Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
[SMPI] Replay: Move replay_init to the RequestStore
[simgrid.git] / src / smpi / internals / smpi_replay.cpp
1 /* Copyright (c) 2009-2018. The SimGrid Team. All rights reserved.          */
2
3 /* This program is free software; you can redistribute it and/or modify it
4  * under the terms of the license (GNU LGPL) which comes with this package. */
5
6 #include "private.hpp"
7 #include "smpi_coll.hpp"
8 #include "smpi_comm.hpp"
9 #include "smpi_datatype.hpp"
10 #include "smpi_group.hpp"
11 #include "smpi_process.hpp"
12 #include "smpi_request.hpp"
13 #include "xbt/replay.hpp"
14
15 #include <boost/algorithm/string/join.hpp>
16 #include <memory>
17 #include <numeric>
18 #include <unordered_map>
19 #include <sstream>
20 #include <vector>
21
22 using simgrid::s4u::Actor;
23
24 #include <tuple>
25 // From https://stackoverflow.com/questions/7110301/generic-hash-for-tuples-in-unordered-map-unordered-set
26 // This is all just to make std::unordered_map work with std::tuple. If we need this in other places,
27 // this could go into a header file.
28 namespace hash_tuple{
29     template <typename TT>
30     struct hash
31     {
32         size_t
33         operator()(TT const& tt) const
34         {
35             return std::hash<TT>()(tt);
36         }
37     };
38
39     template <class T>
40     inline void hash_combine(std::size_t& seed, T const& v)
41     {
42         seed ^= hash_tuple::hash<T>()(v) + 0x9e3779b9 + (seed<<6) + (seed>>2);
43     }
44
45     // Recursive template code derived from Matthieu M.
46     template <class Tuple, size_t Index = std::tuple_size<Tuple>::value - 1>
47     struct HashValueImpl
48     {
49       static void apply(size_t& seed, Tuple const& tuple)
50       {
51         HashValueImpl<Tuple, Index-1>::apply(seed, tuple);
52         hash_combine(seed, std::get<Index>(tuple));
53       }
54     };
55
56     template <class Tuple>
57     struct HashValueImpl<Tuple,0>
58     {
59       static void apply(size_t& seed, Tuple const& tuple)
60       {
61         hash_combine(seed, std::get<0>(tuple));
62       }
63     };
64
65     template <typename ... TT>
66     struct hash<std::tuple<TT...>>
67     {
68         size_t
69         operator()(std::tuple<TT...> const& tt) const
70         {
71             size_t seed = 0;
72             HashValueImpl<std::tuple<TT...> >::apply(seed, tt);
73             return seed;
74         }
75     };
76 }
77
78 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_replay,smpi,"Trace Replay with SMPI");
79
80 static std::unordered_map<int, std::vector<MPI_Request>*> reqq;
81 typedef std::tuple</*sender*/ int, /* reciever */ int, /* tag */int> req_key_t;
82 typedef std::unordered_map<req_key_t, MPI_Request, hash_tuple::hash<std::tuple<int,int,int>>> req_storage_t;
83
84 static MPI_Datatype MPI_DEFAULT_TYPE;
85
86 #define CHECK_ACTION_PARAMS(action, mandatory, optional)                                                               \
87   {                                                                                                                    \
88     if (action.size() < static_cast<unsigned long>(mandatory + 2)) {                                                   \
89       std::stringstream ss;                                                                                            \
90       for (const auto& elem : action) {                                                                                \
91         ss << elem << " ";                                                                                             \
92       }                                                                                                                \
93       THROWF(arg_error, 0, "%s replay failed.\n"                                                                       \
94                            "%zu items were given on the line. First two should be process_id and action.  "            \
95                            "This action needs after them %lu mandatory arguments, and accepts %lu optional ones. \n"   \
96                            "The full line that was given is:\n   %s\n"                                                 \
97                            "Please contact the Simgrid team if support is needed",                                     \
98              __func__, action.size(), static_cast<unsigned long>(mandatory), static_cast<unsigned long>(optional),     \
99              ss.str().c_str());                                                                                        \
100     }                                                                                                                  \
101   }
102
103 static void log_timed_action(simgrid::xbt::ReplayAction& action, double clock)
104 {
105   if (XBT_LOG_ISENABLED(smpi_replay, xbt_log_priority_verbose)){
106     std::string s = boost::algorithm::join(action, " ");
107     XBT_VERB("%s %f", s.c_str(), smpi_process()->simulated_elapsed() - clock);
108   }
109 }
110
111 static std::vector<MPI_Request>* get_reqq_self()
112 {
113   return reqq.at(simgrid::s4u::this_actor::get_pid());
114 }
115
116 static void set_reqq_self(std::vector<MPI_Request> *mpi_request)
117 {
118   reqq.insert({simgrid::s4u::this_actor::get_pid(), mpi_request});
119 }
120
121 /* Helper function */
122 static double parse_double(std::string string)
123 {
124   return xbt_str_parse_double(string.c_str(), "%s is not a double");
125 }
126
127 namespace simgrid {
128 namespace smpi {
129
130 namespace replay {
131
132 class RequestStorage {
133 private:
134     req_storage_t store;
135
136 public:
137     RequestStorage() {}
138     int size()
139     {
140       return store.size();
141     }
142
143     req_storage_t& get_store()
144     {
145       return store;
146     }
147
148     void get_requests(std::vector<MPI_Request>& vec)
149     {
150       for (auto& pair : store) {
151         auto& req = pair.second;
152         auto my_proc_id = simgrid::s4u::this_actor::get_pid();
153         if (req != MPI_REQUEST_NULL && (req->src() == my_proc_id || req->dst() == my_proc_id)) {
154           vec.push_back(pair.second);
155           pair.second->print_request("MM");
156         }
157       }
158     }
159
160     MPI_Request find(int src, int dst, int tag)
161     {
162       req_storage_t::iterator it = store.find(req_key_t(src, dst, tag));
163       return (it == store.end()) ? MPI_REQUEST_NULL : it->second;
164     }
165
166     void remove(MPI_Request req)
167     {
168       if (req == MPI_REQUEST_NULL) return;
169
170       store.erase(req_key_t(req->src()-1, req->dst()-1, req->tag()));
171     }
172
173     void add(MPI_Request req)
174     {
175       if (req != MPI_REQUEST_NULL) // Can and does happen in the case of TestAction
176         store.insert({req_key_t(req->src()-1, req->dst()-1, req->tag()), req});
177     }
178
179     /* Sometimes we need to re-insert MPI_REQUEST_NULL but we still need src,dst and tag */
180     void addNullRequest(int src, int dst, int tag)
181     {
182       store.insert({req_key_t(src, dst, tag), MPI_REQUEST_NULL});
183     }
184 };
185
186 class ActionArgParser {
187 public:
188   virtual ~ActionArgParser() = default;
189   virtual void parse(simgrid::xbt::ReplayAction& action, std::string name) { CHECK_ACTION_PARAMS(action, 0, 0) }
190 };
191
192 class WaitTestParser : public ActionArgParser {
193 public:
194   int src;
195   int dst;
196   int tag;
197
198   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
199   {
200     CHECK_ACTION_PARAMS(action, 3, 0)
201     src = std::stoi(action[2]);
202     dst = std::stoi(action[3]);
203     tag = std::stoi(action[4]);
204   }
205 };
206
207 class SendRecvParser : public ActionArgParser {
208 public:
209   /* communication partner; if we send, this is the receiver and vice versa */
210   int partner;
211   double size;
212   int tag;
213   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
214
215   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
216   {
217     CHECK_ACTION_PARAMS(action, 3, 1)
218     partner = std::stoi(action[2]);
219     tag     = std::stoi(action[3]);
220     size    = parse_double(action[4]);
221     if (action.size() > 5)
222       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
223   }
224 };
225
226 class ComputeParser : public ActionArgParser {
227 public:
228   /* communication partner; if we send, this is the receiver and vice versa */
229   double flops;
230
231   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
232   {
233     CHECK_ACTION_PARAMS(action, 1, 0)
234     flops = parse_double(action[2]);
235   }
236 };
237
238 class CollCommParser : public ActionArgParser {
239 public:
240   double size;
241   double comm_size;
242   double comp_size;
243   int send_size;
244   int recv_size;
245   int root = 0;
246   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
247   MPI_Datatype datatype2 = MPI_DEFAULT_TYPE;
248 };
249
250 class BcastArgParser : public CollCommParser {
251 public:
252   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
253   {
254     CHECK_ACTION_PARAMS(action, 1, 2)
255     size = parse_double(action[2]);
256     root = (action.size() > 3) ? std::stoi(action[3]) : 0;
257     if (action.size() > 4)
258       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
259   }
260 };
261
262 class ReduceArgParser : public CollCommParser {
263 public:
264   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
265   {
266     CHECK_ACTION_PARAMS(action, 2, 2)
267     comm_size = parse_double(action[2]);
268     comp_size = parse_double(action[3]);
269     root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
270     if (action.size() > 5)
271       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
272   }
273 };
274
275 class AllReduceArgParser : public CollCommParser {
276 public:
277   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
278   {
279     CHECK_ACTION_PARAMS(action, 2, 1)
280     comm_size = parse_double(action[2]);
281     comp_size = parse_double(action[3]);
282     if (action.size() > 4)
283       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
284   }
285 };
286
287 class AllToAllArgParser : public CollCommParser {
288 public:
289   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
290   {
291     CHECK_ACTION_PARAMS(action, 2, 1)
292     comm_size = MPI_COMM_WORLD->size();
293     send_size = parse_double(action[2]);
294     recv_size = parse_double(action[3]);
295
296     if (action.size() > 4)
297       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
298     if (action.size() > 5)
299       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
300   }
301 };
302
303 class GatherArgParser : public CollCommParser {
304 public:
305   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
306   {
307     /* The structure of the gather action for the rank 0 (total 4 processes) is the following:
308           0 gather 68 68 0 0 0
309         where:
310           1) 68 is the sendcounts
311           2) 68 is the recvcounts
312           3) 0 is the root node
313           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
314           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
315     */
316     CHECK_ACTION_PARAMS(action, 2, 3)
317     comm_size = MPI_COMM_WORLD->size();
318     send_size = parse_double(action[2]);
319     recv_size = parse_double(action[3]);
320
321     if (name == "gather") {
322       root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
323       if (action.size() > 5)
324         datatype1 = simgrid::smpi::Datatype::decode(action[5]);
325       if (action.size() > 6)
326         datatype2 = simgrid::smpi::Datatype::decode(action[6]);
327     }
328     else {
329       if (action.size() > 4)
330         datatype1 = simgrid::smpi::Datatype::decode(action[4]);
331       if (action.size() > 5)
332         datatype2 = simgrid::smpi::Datatype::decode(action[5]);
333     }
334   }
335 };
336
337 class GatherVArgParser : public CollCommParser {
338 public:
339   int recv_size_sum;
340   std::shared_ptr<std::vector<int>> recvcounts;
341   std::vector<int> disps;
342   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
343   {
344     /* The structure of the gatherv action for the rank 0 (total 4 processes) is the following:
345          0 gather 68 68 10 10 10 0 0 0
346        where:
347          1) 68 is the sendcount
348          2) 68 10 10 10 is the recvcounts
349          3) 0 is the root node
350          4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
351          5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
352     */
353     comm_size = MPI_COMM_WORLD->size();
354     CHECK_ACTION_PARAMS(action, comm_size+1, 2)
355     send_size = parse_double(action[2]);
356     disps     = std::vector<int>(comm_size, 0);
357     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
358
359     if (name == "gatherV") {
360       root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
361       if (action.size() > 4 + comm_size)
362         datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
363       if (action.size() > 5 + comm_size)
364         datatype2 = simgrid::smpi::Datatype::decode(action[5 + comm_size]);
365     }
366     else {
367       int datatype_index = 0;
368       int disp_index     = 0;
369       /* The 3 comes from "0 gather <sendcount>", which must always be present.
370        * The + comm_size is the recvcounts array, which must also be present
371        */
372       if (action.size() > 3 + comm_size + comm_size) { /* datatype + disp are specified */
373         datatype_index = 3 + comm_size;
374         disp_index     = datatype_index + 1;
375         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
376         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
377       } else if (action.size() > 3 + comm_size + 2) { /* disps specified; datatype is not specified; use the default one */
378         disp_index     = 3 + comm_size;
379       } else if (action.size() > 3 + comm_size)  { /* only datatype, no disp specified */
380         datatype_index = 3 + comm_size;
381         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
382         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
383       }
384
385       if (disp_index != 0) {
386         for (unsigned int i = 0; i < comm_size; i++)
387           disps[i]          = std::stoi(action[disp_index + i]);
388       }
389     }
390
391     for (unsigned int i = 0; i < comm_size; i++) {
392       (*recvcounts)[i] = std::stoi(action[i + 3]);
393     }
394     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
395   }
396 };
397
398 class ScatterArgParser : public CollCommParser {
399 public:
400   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
401   {
402     /* The structure of the scatter action for the rank 0 (total 4 processes) is the following:
403           0 gather 68 68 0 0 0
404         where:
405           1) 68 is the sendcounts
406           2) 68 is the recvcounts
407           3) 0 is the root node
408           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
409           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
410     */
411     CHECK_ACTION_PARAMS(action, 2, 3)
412     comm_size   = MPI_COMM_WORLD->size();
413     send_size   = parse_double(action[2]);
414     recv_size   = parse_double(action[3]);
415     root   = (action.size() > 4) ? std::stoi(action[4]) : 0;
416     if (action.size() > 5)
417       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
418     if (action.size() > 6)
419       datatype2 = simgrid::smpi::Datatype::decode(action[6]);
420   }
421 };
422
423 class ScatterVArgParser : public CollCommParser {
424 public:
425   int recv_size_sum;
426   int send_size_sum;
427   std::shared_ptr<std::vector<int>> sendcounts;
428   std::vector<int> disps;
429   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
430   {
431     /* The structure of the scatterv action for the rank 0 (total 4 processes) is the following:
432        0 gather 68 10 10 10 68 0 0 0
433         where:
434         1) 68 10 10 10 is the sendcounts
435         2) 68 is the recvcount
436         3) 0 is the root node
437         4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
438         5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
439     */
440     CHECK_ACTION_PARAMS(action, comm_size + 1, 2)
441     recv_size  = parse_double(action[2 + comm_size]);
442     disps      = std::vector<int>(comm_size, 0);
443     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
444
445     if (action.size() > 5 + comm_size)
446       datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
447     if (action.size() > 5 + comm_size)
448       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
449
450     for (unsigned int i = 0; i < comm_size; i++) {
451       (*sendcounts)[i] = std::stoi(action[i + 2]);
452     }
453     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
454     root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
455   }
456 };
457
458 class ReduceScatterArgParser : public CollCommParser {
459 public:
460   int recv_size_sum;
461   std::shared_ptr<std::vector<int>> recvcounts;
462   std::vector<int> disps;
463   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
464   {
465     /* The structure of the reducescatter action for the rank 0 (total 4 processes) is the following:
466          0 reduceScatter 275427 275427 275427 204020 11346849 0
467        where:
468          1) The first four values after the name of the action declare the recvcounts array
469          2) The value 11346849 is the amount of instructions
470          3) The last value corresponds to the datatype, see simgrid::smpi::Datatype::decode().
471     */
472     comm_size = MPI_COMM_WORLD->size();
473     CHECK_ACTION_PARAMS(action, comm_size+1, 1)
474     comp_size = parse_double(action[2+comm_size]);
475     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
476     if (action.size() > 3 + comm_size)
477       datatype1 = simgrid::smpi::Datatype::decode(action[3 + comm_size]);
478
479     for (unsigned int i = 0; i < comm_size; i++) {
480       recvcounts->push_back(std::stoi(action[i + 2]));
481     }
482     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
483   }
484 };
485
486 class AllToAllVArgParser : public CollCommParser {
487 public:
488   int recv_size_sum;
489   int send_size_sum;
490   std::shared_ptr<std::vector<int>> recvcounts;
491   std::shared_ptr<std::vector<int>> sendcounts;
492   std::vector<int> senddisps;
493   std::vector<int> recvdisps;
494   int send_buf_size;
495   int recv_buf_size;
496   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
497   {
498     /* The structure of the allToAllV action for the rank 0 (total 4 processes) is the following:
499           0 allToAllV 100 1 7 10 12 100 1 70 10 5
500        where:
501         1) 100 is the size of the send buffer *sizeof(int),
502         2) 1 7 10 12 is the sendcounts array
503         3) 100*sizeof(int) is the size of the receiver buffer
504         4)  1 70 10 5 is the recvcounts array
505     */
506     comm_size = MPI_COMM_WORLD->size();
507     CHECK_ACTION_PARAMS(action, 2*comm_size+2, 2)
508     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
509     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
510     senddisps  = std::vector<int>(comm_size, 0);
511     recvdisps  = std::vector<int>(comm_size, 0);
512
513     if (action.size() > 5 + 2 * comm_size)
514       datatype1 = simgrid::smpi::Datatype::decode(action[4 + 2 * comm_size]);
515     if (action.size() > 5 + 2 * comm_size)
516       datatype2 = simgrid::smpi::Datatype::decode(action[5 + 2 * comm_size]);
517
518     send_buf_size=parse_double(action[2]);
519     recv_buf_size=parse_double(action[3+comm_size]);
520     for (unsigned int i = 0; i < comm_size; i++) {
521       (*sendcounts)[i] = std::stoi(action[3 + i]);
522       (*recvcounts)[i] = std::stoi(action[4 + comm_size + i]);
523     }
524     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
525     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
526   }
527 };
528
529 template <class T> class ReplayAction {
530 protected:
531   const std::string name;
532   RequestStorage* req_storage; // Points to the right storage for this process, nullptr except for Send/Recv/Wait/Test actions.
533   const int my_proc_id;
534   T args;
535
536 public:
537   explicit ReplayAction(std::string name, RequestStorage& storage) : name(name), req_storage(&storage), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
538   explicit ReplayAction(std::string name) : name(name), req_storage(nullptr), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
539   virtual ~ReplayAction() = default;
540
541   virtual void execute(simgrid::xbt::ReplayAction& action)
542   {
543     // Needs to be re-initialized for every action, hence here
544     double start_time = smpi_process()->simulated_elapsed();
545     args.parse(action, name);
546     kernel(action);
547     if (name != "Init")
548       log_timed_action(action, start_time);
549   }
550
551   virtual void kernel(simgrid::xbt::ReplayAction& action) = 0;
552
553   void* send_buffer(int size)
554   {
555     return smpi_get_tmp_sendbuffer(size);
556   }
557
558   void* recv_buffer(int size)
559   {
560     return smpi_get_tmp_recvbuffer(size);
561   }
562 };
563
564 class WaitAction : public ReplayAction<WaitTestParser> {
565 public:
566   WaitAction(RequestStorage& storage) : ReplayAction("Wait", storage) {}
567   void kernel(simgrid::xbt::ReplayAction& action) override
568   {
569     std::string s = boost::algorithm::join(action, " ");
570     xbt_assert(req_storage->size(), "action wait not preceded by any irecv or isend: %s", s.c_str());
571     MPI_Request request = req_storage->find(args.src, args.dst, args.tag);
572     req_storage->remove(request);
573
574     if (request == MPI_REQUEST_NULL) {
575       /* Assume that the trace is well formed, meaning the comm might have been caught by a MPI_test. Then just
576        * return.*/
577       return;
578     }
579
580     int rank = request->comm() != MPI_COMM_NULL ? request->comm()->rank() : -1;
581
582     // Must be taken before Request::wait() since the request may be set to
583     // MPI_REQUEST_NULL by Request::wait!
584     bool is_wait_for_receive = (request->flags() & RECV);
585     // TODO: Here we take the rank while we normally take the process id (look for my_proc_id)
586     TRACE_smpi_comm_in(rank, __func__, new simgrid::instr::NoOpTIData("wait"));
587
588     MPI_Status status;
589     Request::wait(&request, &status);
590
591     TRACE_smpi_comm_out(rank);
592     if (is_wait_for_receive)
593       TRACE_smpi_recv(args.src, args.dst, args.tag);
594   }
595 };
596
597 class SendAction : public ReplayAction<SendRecvParser> {
598 public:
599   SendAction() = delete;
600   explicit SendAction(std::string name, RequestStorage& storage) : ReplayAction(name, storage) {}
601   void kernel(simgrid::xbt::ReplayAction& action) override
602   {
603     int dst_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
604
605     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
606                                                                              args.tag, Datatype::encode(args.datatype1)));
607     if (not TRACE_smpi_view_internals())
608       TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, args.tag, args.size * args.datatype1->size());
609
610     if (name == "send") {
611       Request::send(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
612     } else if (name == "Isend") {
613       MPI_Request request = Request::isend(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
614       req_storage->add(request);
615     } else {
616       xbt_die("Don't know this action, %s", name.c_str());
617     }
618
619     TRACE_smpi_comm_out(my_proc_id);
620   }
621 };
622
623 class RecvAction : public ReplayAction<SendRecvParser> {
624 public:
625   RecvAction() = delete;
626   explicit RecvAction(std::string name, RequestStorage& storage) : ReplayAction(name, storage) {}
627   void kernel(simgrid::xbt::ReplayAction& action) override
628   {
629     int src_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
630
631     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
632                                                                              args.tag, Datatype::encode(args.datatype1)));
633
634     MPI_Status status;
635     // unknown size from the receiver point of view
636     if (args.size <= 0.0) {
637       Request::probe(args.partner, args.tag, MPI_COMM_WORLD, &status);
638       args.size = status.count;
639     }
640
641     if (name == "recv") {
642       Request::recv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD, &status);
643     } else if (name == "Irecv") {
644       MPI_Request request = Request::irecv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
645       req_storage->add(request);
646     }
647
648     TRACE_smpi_comm_out(my_proc_id);
649     // TODO: Check why this was only activated in the "recv" case and not in the "Irecv" case
650     if (name == "recv" && not TRACE_smpi_view_internals()) {
651       TRACE_smpi_recv(src_traced, my_proc_id, args.tag);
652     }
653   }
654 };
655
656 class ComputeAction : public ReplayAction<ComputeParser> {
657 public:
658   ComputeAction() : ReplayAction("compute") {}
659   void kernel(simgrid::xbt::ReplayAction& action) override
660   {
661     TRACE_smpi_computing_in(my_proc_id, args.flops);
662     smpi_execute_flops(args.flops);
663     TRACE_smpi_computing_out(my_proc_id);
664   }
665 };
666
667 class TestAction : public ReplayAction<WaitTestParser> {
668 public:
669   TestAction(RequestStorage& storage) : ReplayAction("Test", storage) {}
670   void kernel(simgrid::xbt::ReplayAction& action) override
671   {
672     MPI_Request request = req_storage->find(args.src, args.dst, args.tag);
673     req_storage->remove(request);
674     // if request is null here, this may mean that a previous test has succeeded
675     // Different times in traced application and replayed version may lead to this
676     // In this case, ignore the extra calls.
677     if (request != MPI_REQUEST_NULL) {
678       TRACE_smpi_testing_in(my_proc_id);
679
680       MPI_Status status;
681       int flag = Request::test(&request, &status);
682
683       XBT_DEBUG("MPI_Test result: %d", flag);
684       /* push back request in vector to be caught by a subsequent wait. if the test did succeed, the request is now
685        * nullptr.*/
686       if (request == MPI_REQUEST_NULL)
687         req_storage->addNullRequest(args.src, args.dst, args.tag);
688       else
689         req_storage->add(request);
690
691       TRACE_smpi_testing_out(my_proc_id);
692     }
693   }
694 };
695
696 class InitAction : public ReplayAction<ActionArgParser> {
697 public:
698   InitAction() : ReplayAction("Init") {}
699   void kernel(simgrid::xbt::ReplayAction& action) override
700   {
701     CHECK_ACTION_PARAMS(action, 0, 1)
702     MPI_DEFAULT_TYPE = (action.size() > 2) ? MPI_DOUBLE // default MPE datatype
703                                            : MPI_BYTE;  // default TAU datatype
704
705     /* start a simulated timer */
706     smpi_process()->simulated_start();
707   }
708 };
709
710 class CommunicatorAction : public ReplayAction<ActionArgParser> {
711 public:
712   CommunicatorAction() : ReplayAction("Comm") {}
713   void kernel(simgrid::xbt::ReplayAction& action) override { /* nothing to do */}
714 };
715
716 class WaitAllAction : public ReplayAction<ActionArgParser> {
717 public:
718   WaitAllAction(RequestStorage& storage) : ReplayAction("waitAll", storage) {}
719   void kernel(simgrid::xbt::ReplayAction& action) override
720   {
721     const unsigned int count_requests = req_storage->size();
722
723     if (count_requests > 0) {
724       TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData("waitAll", -1, count_requests, ""));
725       std::vector<std::pair</*sender*/int,/*recv*/int>> sender_receiver;
726       std::vector<MPI_Request> reqs;
727       req_storage->get_requests(reqs);
728       for (const auto& req : reqs) {
729         if (req && (req->flags() & RECV)) {
730           sender_receiver.push_back({req->src(), req->dst()});
731         }
732       }
733       MPI_Status status[count_requests];
734       Request::waitall(count_requests, &(reqs.data())[0], status);
735
736       for (auto& pair : sender_receiver) {
737         TRACE_smpi_recv(pair.first, pair.second, 0);
738       }
739       TRACE_smpi_comm_out(my_proc_id);
740     }
741   }
742 };
743
744 class BarrierAction : public ReplayAction<ActionArgParser> {
745 public:
746   BarrierAction() : ReplayAction("barrier") {}
747   void kernel(simgrid::xbt::ReplayAction& action) override
748   {
749     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("barrier"));
750     Colls::barrier(MPI_COMM_WORLD);
751     TRACE_smpi_comm_out(my_proc_id);
752   }
753 };
754
755 class BcastAction : public ReplayAction<BcastArgParser> {
756 public:
757   BcastAction() : ReplayAction("bcast") {}
758   void kernel(simgrid::xbt::ReplayAction& action) override
759   {
760     TRACE_smpi_comm_in(my_proc_id, "action_bcast",
761                        new simgrid::instr::CollTIData("bcast", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
762                                                       -1.0, args.size, -1, Datatype::encode(args.datatype1), ""));
763
764     Colls::bcast(send_buffer(args.size * args.datatype1->size()), args.size, args.datatype1, args.root, MPI_COMM_WORLD);
765
766     TRACE_smpi_comm_out(my_proc_id);
767   }
768 };
769
770 class ReduceAction : public ReplayAction<ReduceArgParser> {
771 public:
772   ReduceAction() : ReplayAction("reduce") {}
773   void kernel(simgrid::xbt::ReplayAction& action) override
774   {
775     TRACE_smpi_comm_in(my_proc_id, "action_reduce",
776                        new simgrid::instr::CollTIData("reduce", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
777                                                       args.comp_size, args.comm_size, -1,
778                                                       Datatype::encode(args.datatype1), ""));
779
780     Colls::reduce(send_buffer(args.comm_size * args.datatype1->size()),
781         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, args.root, MPI_COMM_WORLD);
782     smpi_execute_flops(args.comp_size);
783
784     TRACE_smpi_comm_out(my_proc_id);
785   }
786 };
787
788 class AllReduceAction : public ReplayAction<AllReduceArgParser> {
789 public:
790   AllReduceAction() : ReplayAction("allReduce") {}
791   void kernel(simgrid::xbt::ReplayAction& action) override
792   {
793     TRACE_smpi_comm_in(my_proc_id, "action_allReduce", new simgrid::instr::CollTIData("allReduce", -1, args.comp_size, args.comm_size, -1,
794                                                                                 Datatype::encode(args.datatype1), ""));
795
796     Colls::allreduce(send_buffer(args.comm_size * args.datatype1->size()),
797         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
798     smpi_execute_flops(args.comp_size);
799
800     TRACE_smpi_comm_out(my_proc_id);
801   }
802 };
803
804 class AllToAllAction : public ReplayAction<AllToAllArgParser> {
805 public:
806   AllToAllAction() : ReplayAction("allToAll") {}
807   void kernel(simgrid::xbt::ReplayAction& action) override
808   {
809     TRACE_smpi_comm_in(my_proc_id, "action_allToAll",
810                      new simgrid::instr::CollTIData("allToAll", -1, -1.0, args.send_size, args.recv_size,
811                                                     Datatype::encode(args.datatype1),
812                                                     Datatype::encode(args.datatype2)));
813
814     Colls::alltoall(send_buffer(args.send_size * args.comm_size * args.datatype1->size()), args.send_size,
815                     args.datatype1, recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()),
816                     args.recv_size, args.datatype2, MPI_COMM_WORLD);
817
818     TRACE_smpi_comm_out(my_proc_id);
819   }
820 };
821
822 class GatherAction : public ReplayAction<GatherArgParser> {
823 public:
824   explicit GatherAction(std::string name) : ReplayAction(name) {}
825   void kernel(simgrid::xbt::ReplayAction& action) override
826   {
827     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::CollTIData(name, (name == "gather") ? args.root : -1, -1.0, args.send_size, args.recv_size,
828                                                                           Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
829
830     if (name == "gather") {
831       int rank = MPI_COMM_WORLD->rank();
832       Colls::gather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
833                  (rank == args.root) ? recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
834     }
835     else
836       Colls::allgather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
837                        recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, MPI_COMM_WORLD);
838
839     TRACE_smpi_comm_out(my_proc_id);
840   }
841 };
842
843 class GatherVAction : public ReplayAction<GatherVArgParser> {
844 public:
845   explicit GatherVAction(std::string name) : ReplayAction(name) {}
846   void kernel(simgrid::xbt::ReplayAction& action) override
847   {
848     int rank = MPI_COMM_WORLD->rank();
849
850     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::VarCollTIData(
851                                                name, (name == "gatherV") ? args.root : -1, args.send_size, nullptr, -1, args.recvcounts,
852                                                Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
853
854     if (name == "gatherV") {
855       Colls::gatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
856                      (rank == args.root) ? recv_buffer(args.recv_size_sum * args.datatype2->size()) : nullptr,
857                      args.recvcounts->data(), args.disps.data(), args.datatype2, args.root, MPI_COMM_WORLD);
858     }
859     else {
860       Colls::allgatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
861                         recv_buffer(args.recv_size_sum * args.datatype2->size()), args.recvcounts->data(),
862                         args.disps.data(), args.datatype2, MPI_COMM_WORLD);
863     }
864
865     TRACE_smpi_comm_out(my_proc_id);
866   }
867 };
868
869 class ScatterAction : public ReplayAction<ScatterArgParser> {
870 public:
871   ScatterAction() : ReplayAction("scatter") {}
872   void kernel(simgrid::xbt::ReplayAction& action) override
873   {
874     int rank = MPI_COMM_WORLD->rank();
875     TRACE_smpi_comm_in(my_proc_id, "action_scatter", new simgrid::instr::CollTIData(name, args.root, -1.0, args.send_size, args.recv_size,
876                                                                           Datatype::encode(args.datatype1),
877                                                                           Datatype::encode(args.datatype2)));
878
879     Colls::scatter(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
880                   (rank == args.root) ? recv_buffer(args.recv_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
881
882     TRACE_smpi_comm_out(my_proc_id);
883   }
884 };
885
886
887 class ScatterVAction : public ReplayAction<ScatterVArgParser> {
888 public:
889   ScatterVAction() : ReplayAction("scatterV") {}
890   void kernel(simgrid::xbt::ReplayAction& action) override
891   {
892     int rank = MPI_COMM_WORLD->rank();
893     TRACE_smpi_comm_in(my_proc_id, "action_scatterv", new simgrid::instr::VarCollTIData(name, args.root, -1, args.sendcounts, args.recv_size,
894           nullptr, Datatype::encode(args.datatype1),
895           Datatype::encode(args.datatype2)));
896
897     Colls::scatterv((rank == args.root) ? send_buffer(args.send_size_sum * args.datatype1->size()) : nullptr,
898                     args.sendcounts->data(), args.disps.data(), args.datatype1,
899                     recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, args.root,
900                     MPI_COMM_WORLD);
901
902     TRACE_smpi_comm_out(my_proc_id);
903   }
904 };
905
906 class ReduceScatterAction : public ReplayAction<ReduceScatterArgParser> {
907 public:
908   ReduceScatterAction() : ReplayAction("reduceScatter") {}
909   void kernel(simgrid::xbt::ReplayAction& action) override
910   {
911     TRACE_smpi_comm_in(my_proc_id, "action_reducescatter",
912                        new simgrid::instr::VarCollTIData("reduceScatter", -1, 0, nullptr, -1, args.recvcounts,
913                                                          std::to_string(args.comp_size), /* ugly hack to print comp_size */
914                                                          Datatype::encode(args.datatype1)));
915
916     Colls::reduce_scatter(send_buffer(args.recv_size_sum * args.datatype1->size()),
917                           recv_buffer(args.recv_size_sum * args.datatype1->size()), args.recvcounts->data(),
918                           args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
919
920     smpi_execute_flops(args.comp_size);
921     TRACE_smpi_comm_out(my_proc_id);
922   }
923 };
924
925 class AllToAllVAction : public ReplayAction<AllToAllVArgParser> {
926 public:
927   AllToAllVAction() : ReplayAction("allToAllV") {}
928   void kernel(simgrid::xbt::ReplayAction& action) override
929   {
930     TRACE_smpi_comm_in(my_proc_id, __func__,
931                        new simgrid::instr::VarCollTIData(
932                            "allToAllV", -1, args.send_size_sum, args.sendcounts, args.recv_size_sum, args.recvcounts,
933                            Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
934
935     Colls::alltoallv(send_buffer(args.send_buf_size * args.datatype1->size()), args.sendcounts->data(), args.senddisps.data(), args.datatype1,
936                      recv_buffer(args.recv_buf_size * args.datatype2->size()), args.recvcounts->data(), args.recvdisps.data(), args.datatype2, MPI_COMM_WORLD);
937
938     TRACE_smpi_comm_out(my_proc_id);
939   }
940 };
941 } // Replay Namespace
942 }} // namespace simgrid::smpi
943
944 std::vector<simgrid::smpi::replay::RequestStorage> storage;
945 /** @brief Only initialize the replay, don't do it for real */
946 void smpi_replay_init(int* argc, char*** argv)
947 {
948   simgrid::smpi::Process::init(argc, argv);
949   smpi_process()->mark_as_initialized();
950   smpi_process()->set_replaying(true);
951
952   int my_proc_id = simgrid::s4u::this_actor::get_pid();
953   storage.resize(smpi_process_count());
954
955   TRACE_smpi_init(my_proc_id);
956   TRACE_smpi_computing_init(my_proc_id);
957   TRACE_smpi_comm_in(my_proc_id, "smpi_replay_run_init", new simgrid::instr::NoOpTIData("init"));
958   TRACE_smpi_comm_out(my_proc_id);
959   xbt_replay_action_register("init", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::InitAction().execute(action); });
960   xbt_replay_action_register("finalize", [](simgrid::xbt::ReplayAction& action) { /* nothing to do */ });
961   xbt_replay_action_register("comm_size", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
962   xbt_replay_action_register("comm_split",[](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
963   xbt_replay_action_register("comm_dup",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
964
965   xbt_replay_action_register("send",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("send", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
966   xbt_replay_action_register("Isend", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("Isend", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
967   xbt_replay_action_register("recv",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("recv", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
968   xbt_replay_action_register("Irecv", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("Irecv", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
969   xbt_replay_action_register("test",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::TestAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
970   xbt_replay_action_register("wait",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
971   xbt_replay_action_register("waitAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAllAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
972   xbt_replay_action_register("barrier", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BarrierAction().execute(action); });
973   xbt_replay_action_register("bcast",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BcastAction().execute(action); });
974   xbt_replay_action_register("reduce",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceAction().execute(action); });
975   xbt_replay_action_register("allReduce", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllReduceAction().execute(action); });
976   xbt_replay_action_register("allToAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllAction().execute(action); });
977   xbt_replay_action_register("allToAllV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllVAction().execute(action); });
978   xbt_replay_action_register("gather",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("gather").execute(action); });
979   xbt_replay_action_register("scatter",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterAction().execute(action); });
980   xbt_replay_action_register("gatherV",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("gatherV").execute(action); });
981   xbt_replay_action_register("scatterV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterVAction().execute(action); });
982   xbt_replay_action_register("allGather", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("allGather").execute(action); });
983   xbt_replay_action_register("allGatherV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("allGatherV").execute(action); });
984   xbt_replay_action_register("reduceScatter", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceScatterAction().execute(action); });
985   xbt_replay_action_register("compute", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ComputeAction().execute(action); });
986
987   //if we have a delayed start, sleep here.
988   if(*argc>2){
989     double value = xbt_str_parse_double((*argv)[2], "%s is not a double");
990     XBT_VERB("Delayed start for instance - Sleeping for %f flops ",value );
991     smpi_execute_flops(value);
992   } else {
993     //UGLY: force a context switch to be sure that all MSG_processes begin initialization
994     XBT_DEBUG("Force context switch by smpi_execute_flops  - Sleeping for 0.0 flops ");
995     smpi_execute_flops(0.0);
996   }
997 }
998
999 /** @brief actually run the replay after initialization */
1000 void smpi_replay_main(int* argc, char*** argv)
1001 {
1002   static int active_processes = 0;
1003   active_processes++;
1004   simgrid::xbt::replay_runner(*argc, *argv);
1005
1006   /* and now, finalize everything */
1007   /* One active process will stop. Decrease the counter*/
1008   unsigned int count_requests = storage[simgrid::s4u::this_actor::get_pid() - 1].size();
1009   XBT_DEBUG("There are %ud elements in reqq[*]", count_requests);
1010   if (count_requests > 0) {
1011     MPI_Request requests[count_requests];
1012     MPI_Status status[count_requests];
1013     unsigned int i=0;
1014
1015     for (auto const& pair : storage[simgrid::s4u::this_actor::get_pid() - 1].get_store()) {
1016       requests[i] = pair.second;
1017       i++;
1018     }
1019     simgrid::smpi::Request::waitall(count_requests, requests, status);
1020   }
1021   active_processes--;
1022
1023   if(active_processes==0){
1024     /* Last process alive speaking: end the simulated timer */
1025     XBT_INFO("Simulation time %f", smpi_process()->simulated_elapsed());
1026     smpi_free_replay_tmp_buffers();
1027   }
1028
1029   TRACE_smpi_comm_in(simgrid::s4u::this_actor::get_pid(), "smpi_replay_run_finalize",
1030                      new simgrid::instr::NoOpTIData("finalize"));
1031
1032   smpi_process()->finalize();
1033
1034   TRACE_smpi_comm_out(simgrid::s4u::this_actor::get_pid());
1035   TRACE_smpi_finalize(simgrid::s4u::this_actor::get_pid());
1036 }
1037
1038 /** @brief chain a replay initialization and a replay start */
1039 void smpi_replay_run(int* argc, char*** argv)
1040 {
1041   smpi_replay_init(argc, argv);
1042   smpi_replay_main(argc, argv);
1043 }