Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
[SMPI] Replay: Move ActionWait to the RequestStore
[simgrid.git] / src / smpi / internals / smpi_replay.cpp
1 /* Copyright (c) 2009-2018. The SimGrid Team. All rights reserved.          */
2
3 /* This program is free software; you can redistribute it and/or modify it
4  * under the terms of the license (GNU LGPL) which comes with this package. */
5
6 #include "private.hpp"
7 #include "smpi_coll.hpp"
8 #include "smpi_comm.hpp"
9 #include "smpi_datatype.hpp"
10 #include "smpi_group.hpp"
11 #include "smpi_process.hpp"
12 #include "smpi_request.hpp"
13 #include "xbt/replay.hpp"
14
15 #include <boost/algorithm/string/join.hpp>
16 #include <memory>
17 #include <numeric>
18 #include <unordered_map>
19 #include <sstream>
20 #include <vector>
21
22 using simgrid::s4u::Actor;
23
24 #include <tuple>
25 // From https://stackoverflow.com/questions/7110301/generic-hash-for-tuples-in-unordered-map-unordered-set
26 // This is all just to make std::unordered_map work with std::tuple. If we need this in other places,
27 // this could go into a header file.
28 namespace hash_tuple{
29     template <typename TT>
30     struct hash
31     {
32         size_t
33         operator()(TT const& tt) const
34         {
35             return std::hash<TT>()(tt);
36         }
37     };
38
39     template <class T>
40     inline void hash_combine(std::size_t& seed, T const& v)
41     {
42         seed ^= hash_tuple::hash<T>()(v) + 0x9e3779b9 + (seed<<6) + (seed>>2);
43     }
44
45     // Recursive template code derived from Matthieu M.
46     template <class Tuple, size_t Index = std::tuple_size<Tuple>::value - 1>
47     struct HashValueImpl
48     {
49       static void apply(size_t& seed, Tuple const& tuple)
50       {
51         HashValueImpl<Tuple, Index-1>::apply(seed, tuple);
52         hash_combine(seed, std::get<Index>(tuple));
53       }
54     };
55
56     template <class Tuple>
57     struct HashValueImpl<Tuple,0>
58     {
59       static void apply(size_t& seed, Tuple const& tuple)
60       {
61         hash_combine(seed, std::get<0>(tuple));
62       }
63     };
64
65     template <typename ... TT>
66     struct hash<std::tuple<TT...>>
67     {
68         size_t
69         operator()(std::tuple<TT...> const& tt) const
70         {
71             size_t seed = 0;
72             HashValueImpl<std::tuple<TT...> >::apply(seed, tt);
73             return seed;
74         }
75     };
76 }
77
78 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_replay,smpi,"Trace Replay with SMPI");
79
80 static std::unordered_map<int, std::vector<MPI_Request>*> reqq;
81 typedef std::tuple</*sender*/ int, /* reciever */ int, /* tag */int> req_key_t;
82 typedef std::unordered_map<req_key_t, MPI_Request, hash_tuple::hash<std::tuple<int,int,int>>> req_storage_t;
83
84 static MPI_Datatype MPI_DEFAULT_TYPE;
85
86 #define CHECK_ACTION_PARAMS(action, mandatory, optional)                                                               \
87   {                                                                                                                    \
88     if (action.size() < static_cast<unsigned long>(mandatory + 2)) {                                                   \
89       std::stringstream ss;                                                                                            \
90       for (const auto& elem : action) {                                                                                \
91         ss << elem << " ";                                                                                             \
92       }                                                                                                                \
93       THROWF(arg_error, 0, "%s replay failed.\n"                                                                       \
94                            "%zu items were given on the line. First two should be process_id and action.  "            \
95                            "This action needs after them %lu mandatory arguments, and accepts %lu optional ones. \n"   \
96                            "The full line that was given is:\n   %s\n"                                                 \
97                            "Please contact the Simgrid team if support is needed",                                     \
98              __func__, action.size(), static_cast<unsigned long>(mandatory), static_cast<unsigned long>(optional),     \
99              ss.str().c_str());                                                                                        \
100     }                                                                                                                  \
101   }
102
103 static void log_timed_action(simgrid::xbt::ReplayAction& action, double clock)
104 {
105   if (XBT_LOG_ISENABLED(smpi_replay, xbt_log_priority_verbose)){
106     std::string s = boost::algorithm::join(action, " ");
107     XBT_VERB("%s %f", s.c_str(), smpi_process()->simulated_elapsed() - clock);
108   }
109 }
110
111 static std::vector<MPI_Request>* get_reqq_self()
112 {
113   return reqq.at(simgrid::s4u::this_actor::get_pid());
114 }
115
116 static void set_reqq_self(std::vector<MPI_Request> *mpi_request)
117 {
118   reqq.insert({simgrid::s4u::this_actor::get_pid(), mpi_request});
119 }
120
121 /* Helper function */
122 static double parse_double(std::string string)
123 {
124   return xbt_str_parse_double(string.c_str(), "%s is not a double");
125 }
126
127 namespace simgrid {
128 namespace smpi {
129
130 namespace replay {
131
132 class RequestStorage {
133 private:
134     req_storage_t store;
135
136 public:
137     RequestStorage() {}
138     int size()
139     {
140       return store.size();
141     }
142
143     req_storage_t& get_store()
144     {
145       return store;
146     }
147
148     void get_requests(std::vector<MPI_Request>& vec)
149     {
150       for (auto& pair : store) {
151         auto& req = pair.second;
152         auto my_proc_id = simgrid::s4u::this_actor::get_pid();
153         if (req != MPI_REQUEST_NULL && (req->src() == my_proc_id || req->dst() == my_proc_id)) {
154           vec.push_back(pair.second);
155           pair.second->print_request("MM");
156         }
157       }
158     }
159
160     MPI_Request find(int src, int dst, int tag)
161     {
162       req_storage_t::iterator it = store.find(req_key_t(src, dst, tag));
163       return (it == store.end()) ? MPI_REQUEST_NULL : it->second;
164     }
165
166     void remove(MPI_Request req)
167     {
168       if (req == MPI_REQUEST_NULL) return;
169
170       store.erase(req_key_t(req->src()-1, req->dst()-1, req->tag()));
171     }
172
173     void add(MPI_Request req)
174     {
175       if (req != MPI_REQUEST_NULL) // Can and does happen in the case of TestAction
176         store.insert({req_key_t(req->src()-1, req->dst()-1, req->tag()), req});
177     }
178
179     /* Sometimes we need to re-insert MPI_REQUEST_NULL but we still need src,dst and tag */
180     void addNullRequest(int src, int dst, int tag)
181     {
182       store.insert({req_key_t(src, dst, tag), MPI_REQUEST_NULL});
183     }
184 };
185
186 class ActionArgParser {
187 public:
188   virtual ~ActionArgParser() = default;
189   virtual void parse(simgrid::xbt::ReplayAction& action, std::string name) { CHECK_ACTION_PARAMS(action, 0, 0) }
190 };
191
192 class WaitTestParser : public ActionArgParser {
193 public:
194   int src;
195   int dst;
196   int tag;
197
198   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
199   {
200     CHECK_ACTION_PARAMS(action, 3, 0)
201     src = std::stoi(action[2]);
202     dst = std::stoi(action[3]);
203     tag = std::stoi(action[4]);
204   }
205 };
206
207 class SendRecvParser : public ActionArgParser {
208 public:
209   /* communication partner; if we send, this is the receiver and vice versa */
210   int partner;
211   double size;
212   int tag;
213   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
214
215   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
216   {
217     CHECK_ACTION_PARAMS(action, 3, 1)
218     partner = std::stoi(action[2]);
219     tag     = std::stoi(action[3]);
220     size    = parse_double(action[4]);
221     if (action.size() > 5)
222       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
223   }
224 };
225
226 class ComputeParser : public ActionArgParser {
227 public:
228   /* communication partner; if we send, this is the receiver and vice versa */
229   double flops;
230
231   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
232   {
233     CHECK_ACTION_PARAMS(action, 1, 0)
234     flops = parse_double(action[2]);
235   }
236 };
237
238 class CollCommParser : public ActionArgParser {
239 public:
240   double size;
241   double comm_size;
242   double comp_size;
243   int send_size;
244   int recv_size;
245   int root = 0;
246   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
247   MPI_Datatype datatype2 = MPI_DEFAULT_TYPE;
248 };
249
250 class BcastArgParser : public CollCommParser {
251 public:
252   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
253   {
254     CHECK_ACTION_PARAMS(action, 1, 2)
255     size = parse_double(action[2]);
256     root = (action.size() > 3) ? std::stoi(action[3]) : 0;
257     if (action.size() > 4)
258       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
259   }
260 };
261
262 class ReduceArgParser : public CollCommParser {
263 public:
264   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
265   {
266     CHECK_ACTION_PARAMS(action, 2, 2)
267     comm_size = parse_double(action[2]);
268     comp_size = parse_double(action[3]);
269     root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
270     if (action.size() > 5)
271       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
272   }
273 };
274
275 class AllReduceArgParser : public CollCommParser {
276 public:
277   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
278   {
279     CHECK_ACTION_PARAMS(action, 2, 1)
280     comm_size = parse_double(action[2]);
281     comp_size = parse_double(action[3]);
282     if (action.size() > 4)
283       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
284   }
285 };
286
287 class AllToAllArgParser : public CollCommParser {
288 public:
289   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
290   {
291     CHECK_ACTION_PARAMS(action, 2, 1)
292     comm_size = MPI_COMM_WORLD->size();
293     send_size = parse_double(action[2]);
294     recv_size = parse_double(action[3]);
295
296     if (action.size() > 4)
297       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
298     if (action.size() > 5)
299       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
300   }
301 };
302
303 class GatherArgParser : public CollCommParser {
304 public:
305   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
306   {
307     /* The structure of the gather action for the rank 0 (total 4 processes) is the following:
308           0 gather 68 68 0 0 0
309         where:
310           1) 68 is the sendcounts
311           2) 68 is the recvcounts
312           3) 0 is the root node
313           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
314           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
315     */
316     CHECK_ACTION_PARAMS(action, 2, 3)
317     comm_size = MPI_COMM_WORLD->size();
318     send_size = parse_double(action[2]);
319     recv_size = parse_double(action[3]);
320
321     if (name == "gather") {
322       root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
323       if (action.size() > 5)
324         datatype1 = simgrid::smpi::Datatype::decode(action[5]);
325       if (action.size() > 6)
326         datatype2 = simgrid::smpi::Datatype::decode(action[6]);
327     }
328     else {
329       if (action.size() > 4)
330         datatype1 = simgrid::smpi::Datatype::decode(action[4]);
331       if (action.size() > 5)
332         datatype2 = simgrid::smpi::Datatype::decode(action[5]);
333     }
334   }
335 };
336
337 class GatherVArgParser : public CollCommParser {
338 public:
339   int recv_size_sum;
340   std::shared_ptr<std::vector<int>> recvcounts;
341   std::vector<int> disps;
342   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
343   {
344     /* The structure of the gatherv action for the rank 0 (total 4 processes) is the following:
345          0 gather 68 68 10 10 10 0 0 0
346        where:
347          1) 68 is the sendcount
348          2) 68 10 10 10 is the recvcounts
349          3) 0 is the root node
350          4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
351          5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
352     */
353     comm_size = MPI_COMM_WORLD->size();
354     CHECK_ACTION_PARAMS(action, comm_size+1, 2)
355     send_size = parse_double(action[2]);
356     disps     = std::vector<int>(comm_size, 0);
357     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
358
359     if (name == "gatherV") {
360       root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
361       if (action.size() > 4 + comm_size)
362         datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
363       if (action.size() > 5 + comm_size)
364         datatype2 = simgrid::smpi::Datatype::decode(action[5 + comm_size]);
365     }
366     else {
367       int datatype_index = 0;
368       int disp_index     = 0;
369       /* The 3 comes from "0 gather <sendcount>", which must always be present.
370        * The + comm_size is the recvcounts array, which must also be present
371        */
372       if (action.size() > 3 + comm_size + comm_size) { /* datatype + disp are specified */
373         datatype_index = 3 + comm_size;
374         disp_index     = datatype_index + 1;
375         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
376         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
377       } else if (action.size() > 3 + comm_size + 2) { /* disps specified; datatype is not specified; use the default one */
378         disp_index     = 3 + comm_size;
379       } else if (action.size() > 3 + comm_size)  { /* only datatype, no disp specified */
380         datatype_index = 3 + comm_size;
381         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
382         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
383       }
384
385       if (disp_index != 0) {
386         for (unsigned int i = 0; i < comm_size; i++)
387           disps[i]          = std::stoi(action[disp_index + i]);
388       }
389     }
390
391     for (unsigned int i = 0; i < comm_size; i++) {
392       (*recvcounts)[i] = std::stoi(action[i + 3]);
393     }
394     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
395   }
396 };
397
398 class ScatterArgParser : public CollCommParser {
399 public:
400   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
401   {
402     /* The structure of the scatter action for the rank 0 (total 4 processes) is the following:
403           0 gather 68 68 0 0 0
404         where:
405           1) 68 is the sendcounts
406           2) 68 is the recvcounts
407           3) 0 is the root node
408           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
409           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
410     */
411     CHECK_ACTION_PARAMS(action, 2, 3)
412     comm_size   = MPI_COMM_WORLD->size();
413     send_size   = parse_double(action[2]);
414     recv_size   = parse_double(action[3]);
415     root   = (action.size() > 4) ? std::stoi(action[4]) : 0;
416     if (action.size() > 5)
417       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
418     if (action.size() > 6)
419       datatype2 = simgrid::smpi::Datatype::decode(action[6]);
420   }
421 };
422
423 class ScatterVArgParser : public CollCommParser {
424 public:
425   int recv_size_sum;
426   int send_size_sum;
427   std::shared_ptr<std::vector<int>> sendcounts;
428   std::vector<int> disps;
429   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
430   {
431     /* The structure of the scatterv action for the rank 0 (total 4 processes) is the following:
432        0 gather 68 10 10 10 68 0 0 0
433         where:
434         1) 68 10 10 10 is the sendcounts
435         2) 68 is the recvcount
436         3) 0 is the root node
437         4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
438         5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
439     */
440     CHECK_ACTION_PARAMS(action, comm_size + 1, 2)
441     recv_size  = parse_double(action[2 + comm_size]);
442     disps      = std::vector<int>(comm_size, 0);
443     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
444
445     if (action.size() > 5 + comm_size)
446       datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
447     if (action.size() > 5 + comm_size)
448       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
449
450     for (unsigned int i = 0; i < comm_size; i++) {
451       (*sendcounts)[i] = std::stoi(action[i + 2]);
452     }
453     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
454     root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
455   }
456 };
457
458 class ReduceScatterArgParser : public CollCommParser {
459 public:
460   int recv_size_sum;
461   std::shared_ptr<std::vector<int>> recvcounts;
462   std::vector<int> disps;
463   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
464   {
465     /* The structure of the reducescatter action for the rank 0 (total 4 processes) is the following:
466          0 reduceScatter 275427 275427 275427 204020 11346849 0
467        where:
468          1) The first four values after the name of the action declare the recvcounts array
469          2) The value 11346849 is the amount of instructions
470          3) The last value corresponds to the datatype, see simgrid::smpi::Datatype::decode().
471     */
472     comm_size = MPI_COMM_WORLD->size();
473     CHECK_ACTION_PARAMS(action, comm_size+1, 1)
474     comp_size = parse_double(action[2+comm_size]);
475     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
476     if (action.size() > 3 + comm_size)
477       datatype1 = simgrid::smpi::Datatype::decode(action[3 + comm_size]);
478
479     for (unsigned int i = 0; i < comm_size; i++) {
480       recvcounts->push_back(std::stoi(action[i + 2]));
481     }
482     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
483   }
484 };
485
486 class AllToAllVArgParser : public CollCommParser {
487 public:
488   int recv_size_sum;
489   int send_size_sum;
490   std::shared_ptr<std::vector<int>> recvcounts;
491   std::shared_ptr<std::vector<int>> sendcounts;
492   std::vector<int> senddisps;
493   std::vector<int> recvdisps;
494   int send_buf_size;
495   int recv_buf_size;
496   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
497   {
498     /* The structure of the allToAllV action for the rank 0 (total 4 processes) is the following:
499           0 allToAllV 100 1 7 10 12 100 1 70 10 5
500        where:
501         1) 100 is the size of the send buffer *sizeof(int),
502         2) 1 7 10 12 is the sendcounts array
503         3) 100*sizeof(int) is the size of the receiver buffer
504         4)  1 70 10 5 is the recvcounts array
505     */
506     comm_size = MPI_COMM_WORLD->size();
507     CHECK_ACTION_PARAMS(action, 2*comm_size+2, 2)
508     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
509     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
510     senddisps  = std::vector<int>(comm_size, 0);
511     recvdisps  = std::vector<int>(comm_size, 0);
512
513     if (action.size() > 5 + 2 * comm_size)
514       datatype1 = simgrid::smpi::Datatype::decode(action[4 + 2 * comm_size]);
515     if (action.size() > 5 + 2 * comm_size)
516       datatype2 = simgrid::smpi::Datatype::decode(action[5 + 2 * comm_size]);
517
518     send_buf_size=parse_double(action[2]);
519     recv_buf_size=parse_double(action[3+comm_size]);
520     for (unsigned int i = 0; i < comm_size; i++) {
521       (*sendcounts)[i] = std::stoi(action[3 + i]);
522       (*recvcounts)[i] = std::stoi(action[4 + comm_size + i]);
523     }
524     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
525     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
526   }
527 };
528
529 template <class T> class ReplayAction {
530 protected:
531   const std::string name;
532   RequestStorage* req_storage; // Points to the right storage for this process, nullptr except for Send/Recv/Wait/Test actions.
533   const int my_proc_id;
534   T args;
535
536 public:
537   explicit ReplayAction(std::string name, RequestStorage& storage) : name(name), req_storage(&storage), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
538   explicit ReplayAction(std::string name) : name(name), req_storage(nullptr), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
539   virtual ~ReplayAction() = default;
540
541   virtual void execute(simgrid::xbt::ReplayAction& action)
542   {
543     // Needs to be re-initialized for every action, hence here
544     double start_time = smpi_process()->simulated_elapsed();
545     args.parse(action, name);
546     kernel(action);
547     if (name != "Init")
548       log_timed_action(action, start_time);
549   }
550
551   virtual void kernel(simgrid::xbt::ReplayAction& action) = 0;
552
553   void* send_buffer(int size)
554   {
555     return smpi_get_tmp_sendbuffer(size);
556   }
557
558   void* recv_buffer(int size)
559   {
560     return smpi_get_tmp_recvbuffer(size);
561   }
562 };
563
564 class WaitAction : public ReplayAction<WaitTestParser> {
565 public:
566   WaitAction(RequestStorage& storage) : ReplayAction("Wait", storage) {}
567   void kernel(simgrid::xbt::ReplayAction& action) override
568   {
569     std::string s = boost::algorithm::join(action, " ");
570     xbt_assert(req_storage->size(), "action wait not preceded by any irecv or isend: %s", s.c_str());
571     MPI_Request request = req_storage->find(args.src, args.dst, args.tag);
572     req_storage->remove(request);
573
574     if (request == MPI_REQUEST_NULL) {
575       /* Assume that the trace is well formed, meaning the comm might have been caught by a MPI_test. Then just
576        * return.*/
577       return;
578     }
579
580     int rank = request->comm() != MPI_COMM_NULL ? request->comm()->rank() : -1;
581
582     // Must be taken before Request::wait() since the request may be set to
583     // MPI_REQUEST_NULL by Request::wait!
584     bool is_wait_for_receive = (request->flags() & RECV);
585     // TODO: Here we take the rank while we normally take the process id (look for my_proc_id)
586     TRACE_smpi_comm_in(rank, __func__, new simgrid::instr::NoOpTIData("wait"));
587
588     MPI_Status status;
589     Request::wait(&request, &status);
590
591     TRACE_smpi_comm_out(rank);
592     if (is_wait_for_receive)
593       TRACE_smpi_recv(args.src, args.dst, args.tag);
594   }
595 };
596
597 class SendAction : public ReplayAction<SendRecvParser> {
598 public:
599   SendAction() = delete;
600   explicit SendAction(std::string name, RequestStorage& storage) : ReplayAction(name, storage) {}
601   void kernel(simgrid::xbt::ReplayAction& action) override
602   {
603     int dst_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
604
605     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
606                                                                              args.tag, Datatype::encode(args.datatype1)));
607     if (not TRACE_smpi_view_internals())
608       TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, args.tag, args.size * args.datatype1->size());
609
610     if (name == "send") {
611       Request::send(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
612     } else if (name == "Isend") {
613       MPI_Request request = Request::isend(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
614       get_reqq_self()->push_back(request);
615     } else {
616       xbt_die("Don't know this action, %s", name.c_str());
617     }
618
619     TRACE_smpi_comm_out(my_proc_id);
620   }
621 };
622
623 class RecvAction : public ReplayAction<SendRecvParser> {
624 public:
625   RecvAction() = delete;
626   explicit RecvAction(std::string name, RequestStorage& storage) : ReplayAction(name, storage) {}
627   void kernel(simgrid::xbt::ReplayAction& action) override
628   {
629     int src_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
630
631     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
632                                                                              args.tag, Datatype::encode(args.datatype1)));
633
634     MPI_Status status;
635     // unknown size from the receiver point of view
636     if (args.size <= 0.0) {
637       Request::probe(args.partner, args.tag, MPI_COMM_WORLD, &status);
638       args.size = status.count;
639     }
640
641     if (name == "recv") {
642       Request::recv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD, &status);
643     } else if (name == "Irecv") {
644       MPI_Request request = Request::irecv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
645       get_reqq_self()->push_back(request);
646     }
647
648     TRACE_smpi_comm_out(my_proc_id);
649     // TODO: Check why this was only activated in the "recv" case and not in the "Irecv" case
650     if (name == "recv" && not TRACE_smpi_view_internals()) {
651       TRACE_smpi_recv(src_traced, my_proc_id, args.tag);
652     }
653   }
654 };
655
656 class ComputeAction : public ReplayAction<ComputeParser> {
657 public:
658   ComputeAction() : ReplayAction("compute") {}
659   void kernel(simgrid::xbt::ReplayAction& action) override
660   {
661     TRACE_smpi_computing_in(my_proc_id, args.flops);
662     smpi_execute_flops(args.flops);
663     TRACE_smpi_computing_out(my_proc_id);
664   }
665 };
666
667 class TestAction : public ReplayAction<WaitTestParser> {
668 public:
669   TestAction(RequestStorage& storage) : ReplayAction("Test", storage) {}
670   void kernel(simgrid::xbt::ReplayAction& action) override
671   {
672     MPI_Request request = get_reqq_self()->back();
673     get_reqq_self()->pop_back();
674     // if request is null here, this may mean that a previous test has succeeded
675     // Different times in traced application and replayed version may lead to this
676     // In this case, ignore the extra calls.
677     if (request != MPI_REQUEST_NULL) {
678       TRACE_smpi_testing_in(my_proc_id);
679
680       MPI_Status status;
681       int flag = Request::test(&request, &status);
682
683       XBT_DEBUG("MPI_Test result: %d", flag);
684       /* push back request in vector to be caught by a subsequent wait. if the test did succeed, the request is now
685        * nullptr.*/
686       get_reqq_self()->push_back(request);
687
688       TRACE_smpi_testing_out(my_proc_id);
689     }
690   }
691 };
692
693 class InitAction : public ReplayAction<ActionArgParser> {
694 public:
695   InitAction() : ReplayAction("Init") {}
696   void kernel(simgrid::xbt::ReplayAction& action) override
697   {
698     CHECK_ACTION_PARAMS(action, 0, 1)
699     MPI_DEFAULT_TYPE = (action.size() > 2) ? MPI_DOUBLE // default MPE datatype
700                                            : MPI_BYTE;  // default TAU datatype
701
702     /* start a simulated timer */
703     smpi_process()->simulated_start();
704     set_reqq_self(new std::vector<MPI_Request>);
705   }
706 };
707
708 class CommunicatorAction : public ReplayAction<ActionArgParser> {
709 public:
710   CommunicatorAction() : ReplayAction("Comm") {}
711   void kernel(simgrid::xbt::ReplayAction& action) override { /* nothing to do */}
712 };
713
714 class WaitAllAction : public ReplayAction<ActionArgParser> {
715 public:
716   WaitAllAction(RequestStorage& storage) : ReplayAction("waitAll", storage) {}
717   void kernel(simgrid::xbt::ReplayAction& action) override
718   {
719     const unsigned int count_requests = get_reqq_self()->size();
720
721     if (count_requests > 0) {
722       TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData("waitAll", -1, count_requests, ""));
723       std::vector<std::pair</*sender*/int,/*recv*/int>> sender_receiver;
724       for (const auto& req : (*get_reqq_self())) {
725         if (req && (req->flags() & RECV)) {
726           sender_receiver.push_back({req->src(), req->dst()});
727         }
728       }
729       MPI_Status status[count_requests];
730       Request::waitall(count_requests, &(*get_reqq_self())[0], status);
731
732       for (auto& pair : sender_receiver) {
733         TRACE_smpi_recv(pair.first, pair.second, 0);
734       }
735       TRACE_smpi_comm_out(my_proc_id);
736     }
737   }
738 };
739
740 class BarrierAction : public ReplayAction<ActionArgParser> {
741 public:
742   BarrierAction() : ReplayAction("barrier") {}
743   void kernel(simgrid::xbt::ReplayAction& action) override
744   {
745     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("barrier"));
746     Colls::barrier(MPI_COMM_WORLD);
747     TRACE_smpi_comm_out(my_proc_id);
748   }
749 };
750
751 class BcastAction : public ReplayAction<BcastArgParser> {
752 public:
753   BcastAction() : ReplayAction("bcast") {}
754   void kernel(simgrid::xbt::ReplayAction& action) override
755   {
756     TRACE_smpi_comm_in(my_proc_id, "action_bcast",
757                        new simgrid::instr::CollTIData("bcast", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
758                                                       -1.0, args.size, -1, Datatype::encode(args.datatype1), ""));
759
760     Colls::bcast(send_buffer(args.size * args.datatype1->size()), args.size, args.datatype1, args.root, MPI_COMM_WORLD);
761
762     TRACE_smpi_comm_out(my_proc_id);
763   }
764 };
765
766 class ReduceAction : public ReplayAction<ReduceArgParser> {
767 public:
768   ReduceAction() : ReplayAction("reduce") {}
769   void kernel(simgrid::xbt::ReplayAction& action) override
770   {
771     TRACE_smpi_comm_in(my_proc_id, "action_reduce",
772                        new simgrid::instr::CollTIData("reduce", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
773                                                       args.comp_size, args.comm_size, -1,
774                                                       Datatype::encode(args.datatype1), ""));
775
776     Colls::reduce(send_buffer(args.comm_size * args.datatype1->size()),
777         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, args.root, MPI_COMM_WORLD);
778     smpi_execute_flops(args.comp_size);
779
780     TRACE_smpi_comm_out(my_proc_id);
781   }
782 };
783
784 class AllReduceAction : public ReplayAction<AllReduceArgParser> {
785 public:
786   AllReduceAction() : ReplayAction("allReduce") {}
787   void kernel(simgrid::xbt::ReplayAction& action) override
788   {
789     TRACE_smpi_comm_in(my_proc_id, "action_allReduce", new simgrid::instr::CollTIData("allReduce", -1, args.comp_size, args.comm_size, -1,
790                                                                                 Datatype::encode(args.datatype1), ""));
791
792     Colls::allreduce(send_buffer(args.comm_size * args.datatype1->size()),
793         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
794     smpi_execute_flops(args.comp_size);
795
796     TRACE_smpi_comm_out(my_proc_id);
797   }
798 };
799
800 class AllToAllAction : public ReplayAction<AllToAllArgParser> {
801 public:
802   AllToAllAction() : ReplayAction("allToAll") {}
803   void kernel(simgrid::xbt::ReplayAction& action) override
804   {
805     TRACE_smpi_comm_in(my_proc_id, "action_allToAll",
806                      new simgrid::instr::CollTIData("allToAll", -1, -1.0, args.send_size, args.recv_size,
807                                                     Datatype::encode(args.datatype1),
808                                                     Datatype::encode(args.datatype2)));
809
810     Colls::alltoall(send_buffer(args.send_size * args.comm_size * args.datatype1->size()), args.send_size,
811                     args.datatype1, recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()),
812                     args.recv_size, args.datatype2, MPI_COMM_WORLD);
813
814     TRACE_smpi_comm_out(my_proc_id);
815   }
816 };
817
818 class GatherAction : public ReplayAction<GatherArgParser> {
819 public:
820   explicit GatherAction(std::string name) : ReplayAction(name) {}
821   void kernel(simgrid::xbt::ReplayAction& action) override
822   {
823     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::CollTIData(name, (name == "gather") ? args.root : -1, -1.0, args.send_size, args.recv_size,
824                                                                           Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
825
826     if (name == "gather") {
827       int rank = MPI_COMM_WORLD->rank();
828       Colls::gather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
829                  (rank == args.root) ? recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
830     }
831     else
832       Colls::allgather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
833                        recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, MPI_COMM_WORLD);
834
835     TRACE_smpi_comm_out(my_proc_id);
836   }
837 };
838
839 class GatherVAction : public ReplayAction<GatherVArgParser> {
840 public:
841   explicit GatherVAction(std::string name) : ReplayAction(name) {}
842   void kernel(simgrid::xbt::ReplayAction& action) override
843   {
844     int rank = MPI_COMM_WORLD->rank();
845
846     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::VarCollTIData(
847                                                name, (name == "gatherV") ? args.root : -1, args.send_size, nullptr, -1, args.recvcounts,
848                                                Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
849
850     if (name == "gatherV") {
851       Colls::gatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
852                      (rank == args.root) ? recv_buffer(args.recv_size_sum * args.datatype2->size()) : nullptr,
853                      args.recvcounts->data(), args.disps.data(), args.datatype2, args.root, MPI_COMM_WORLD);
854     }
855     else {
856       Colls::allgatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
857                         recv_buffer(args.recv_size_sum * args.datatype2->size()), args.recvcounts->data(),
858                         args.disps.data(), args.datatype2, MPI_COMM_WORLD);
859     }
860
861     TRACE_smpi_comm_out(my_proc_id);
862   }
863 };
864
865 class ScatterAction : public ReplayAction<ScatterArgParser> {
866 public:
867   ScatterAction() : ReplayAction("scatter") {}
868   void kernel(simgrid::xbt::ReplayAction& action) override
869   {
870     int rank = MPI_COMM_WORLD->rank();
871     TRACE_smpi_comm_in(my_proc_id, "action_scatter", new simgrid::instr::CollTIData(name, args.root, -1.0, args.send_size, args.recv_size,
872                                                                           Datatype::encode(args.datatype1),
873                                                                           Datatype::encode(args.datatype2)));
874
875     Colls::scatter(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
876                   (rank == args.root) ? recv_buffer(args.recv_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
877
878     TRACE_smpi_comm_out(my_proc_id);
879   }
880 };
881
882
883 class ScatterVAction : public ReplayAction<ScatterVArgParser> {
884 public:
885   ScatterVAction() : ReplayAction("scatterV") {}
886   void kernel(simgrid::xbt::ReplayAction& action) override
887   {
888     int rank = MPI_COMM_WORLD->rank();
889     TRACE_smpi_comm_in(my_proc_id, "action_scatterv", new simgrid::instr::VarCollTIData(name, args.root, -1, args.sendcounts, args.recv_size,
890           nullptr, Datatype::encode(args.datatype1),
891           Datatype::encode(args.datatype2)));
892
893     Colls::scatterv((rank == args.root) ? send_buffer(args.send_size_sum * args.datatype1->size()) : nullptr,
894                     args.sendcounts->data(), args.disps.data(), args.datatype1,
895                     recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, args.root,
896                     MPI_COMM_WORLD);
897
898     TRACE_smpi_comm_out(my_proc_id);
899   }
900 };
901
902 class ReduceScatterAction : public ReplayAction<ReduceScatterArgParser> {
903 public:
904   ReduceScatterAction() : ReplayAction("reduceScatter") {}
905   void kernel(simgrid::xbt::ReplayAction& action) override
906   {
907     TRACE_smpi_comm_in(my_proc_id, "action_reducescatter",
908                        new simgrid::instr::VarCollTIData("reduceScatter", -1, 0, nullptr, -1, args.recvcounts,
909                                                          std::to_string(args.comp_size), /* ugly hack to print comp_size */
910                                                          Datatype::encode(args.datatype1)));
911
912     Colls::reduce_scatter(send_buffer(args.recv_size_sum * args.datatype1->size()),
913                           recv_buffer(args.recv_size_sum * args.datatype1->size()), args.recvcounts->data(),
914                           args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
915
916     smpi_execute_flops(args.comp_size);
917     TRACE_smpi_comm_out(my_proc_id);
918   }
919 };
920
921 class AllToAllVAction : public ReplayAction<AllToAllVArgParser> {
922 public:
923   AllToAllVAction() : ReplayAction("allToAllV") {}
924   void kernel(simgrid::xbt::ReplayAction& action) override
925   {
926     TRACE_smpi_comm_in(my_proc_id, __func__,
927                        new simgrid::instr::VarCollTIData(
928                            "allToAllV", -1, args.send_size_sum, args.sendcounts, args.recv_size_sum, args.recvcounts,
929                            Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
930
931     Colls::alltoallv(send_buffer(args.send_buf_size * args.datatype1->size()), args.sendcounts->data(), args.senddisps.data(), args.datatype1,
932                      recv_buffer(args.recv_buf_size * args.datatype2->size()), args.recvcounts->data(), args.recvdisps.data(), args.datatype2, MPI_COMM_WORLD);
933
934     TRACE_smpi_comm_out(my_proc_id);
935   }
936 };
937 } // Replay Namespace
938 }} // namespace simgrid::smpi
939
940 std::vector<simgrid::smpi::replay::RequestStorage> storage;
941 /** @brief Only initialize the replay, don't do it for real */
942 void smpi_replay_init(int* argc, char*** argv)
943 {
944   simgrid::smpi::Process::init(argc, argv);
945   smpi_process()->mark_as_initialized();
946   smpi_process()->set_replaying(true);
947
948   int my_proc_id = simgrid::s4u::this_actor::get_pid();
949   storage.resize(smpi_process_count());
950
951   TRACE_smpi_init(my_proc_id);
952   TRACE_smpi_computing_init(my_proc_id);
953   TRACE_smpi_comm_in(my_proc_id, "smpi_replay_run_init", new simgrid::instr::NoOpTIData("init"));
954   TRACE_smpi_comm_out(my_proc_id);
955   xbt_replay_action_register("init", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::InitAction().execute(action); });
956   xbt_replay_action_register("finalize", [](simgrid::xbt::ReplayAction& action) { /* nothing to do */ });
957   xbt_replay_action_register("comm_size", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
958   xbt_replay_action_register("comm_split",[](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
959   xbt_replay_action_register("comm_dup",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
960
961   xbt_replay_action_register("send",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("send", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
962   xbt_replay_action_register("Isend", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("Isend", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
963   xbt_replay_action_register("recv",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("recv", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
964   xbt_replay_action_register("Irecv", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("Irecv", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
965   xbt_replay_action_register("test",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::TestAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
966   xbt_replay_action_register("wait",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
967   xbt_replay_action_register("waitAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAllAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
968   xbt_replay_action_register("barrier", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BarrierAction().execute(action); });
969   xbt_replay_action_register("bcast",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BcastAction().execute(action); });
970   xbt_replay_action_register("reduce",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceAction().execute(action); });
971   xbt_replay_action_register("allReduce", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllReduceAction().execute(action); });
972   xbt_replay_action_register("allToAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllAction().execute(action); });
973   xbt_replay_action_register("allToAllV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllVAction().execute(action); });
974   xbt_replay_action_register("gather",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("gather").execute(action); });
975   xbt_replay_action_register("scatter",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterAction().execute(action); });
976   xbt_replay_action_register("gatherV",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("gatherV").execute(action); });
977   xbt_replay_action_register("scatterV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterVAction().execute(action); });
978   xbt_replay_action_register("allGather", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("allGather").execute(action); });
979   xbt_replay_action_register("allGatherV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("allGatherV").execute(action); });
980   xbt_replay_action_register("reduceScatter", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceScatterAction().execute(action); });
981   xbt_replay_action_register("compute", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ComputeAction().execute(action); });
982
983   //if we have a delayed start, sleep here.
984   if(*argc>2){
985     double value = xbt_str_parse_double((*argv)[2], "%s is not a double");
986     XBT_VERB("Delayed start for instance - Sleeping for %f flops ",value );
987     smpi_execute_flops(value);
988   } else {
989     //UGLY: force a context switch to be sure that all MSG_processes begin initialization
990     XBT_DEBUG("Force context switch by smpi_execute_flops  - Sleeping for 0.0 flops ");
991     smpi_execute_flops(0.0);
992   }
993 }
994
995 /** @brief actually run the replay after initialization */
996 void smpi_replay_main(int* argc, char*** argv)
997 {
998   static int active_processes = 0;
999   active_processes++;
1000   simgrid::xbt::replay_runner(*argc, *argv);
1001
1002   /* and now, finalize everything */
1003   /* One active process will stop. Decrease the counter*/
1004   XBT_DEBUG("There are %zu elements in reqq[*]", get_reqq_self()->size());
1005   if (not get_reqq_self()->empty()) {
1006     unsigned int count_requests=get_reqq_self()->size();
1007     MPI_Request requests[count_requests];
1008     MPI_Status status[count_requests];
1009     unsigned int i=0;
1010
1011     for (auto const& req : *get_reqq_self()) {
1012       requests[i] = req;
1013       i++;
1014     }
1015     simgrid::smpi::Request::waitall(count_requests, requests, status);
1016   }
1017   delete get_reqq_self();
1018   active_processes--;
1019
1020   if(active_processes==0){
1021     /* Last process alive speaking: end the simulated timer */
1022     XBT_INFO("Simulation time %f", smpi_process()->simulated_elapsed());
1023     smpi_free_replay_tmp_buffers();
1024   }
1025
1026   TRACE_smpi_comm_in(simgrid::s4u::this_actor::get_pid(), "smpi_replay_run_finalize",
1027                      new simgrid::instr::NoOpTIData("finalize"));
1028
1029   smpi_process()->finalize();
1030
1031   TRACE_smpi_comm_out(simgrid::s4u::this_actor::get_pid());
1032   TRACE_smpi_finalize(simgrid::s4u::this_actor::get_pid());
1033 }
1034
1035 /** @brief chain a replay initialization and a replay start */
1036 void smpi_replay_run(int* argc, char*** argv)
1037 {
1038   smpi_replay_init(argc, argv);
1039   smpi_replay_main(argc, argv);
1040 }