Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
cfd24494607d91cd3d3410f9f0539091be2e8354
[simgrid.git] / src / smpi / internals / smpi_replay.cpp
1 /* Copyright (c) 2009-2018. The SimGrid Team. All rights reserved.          */
2
3 /* This program is free software; you can redistribute it and/or modify it
4  * under the terms of the license (GNU LGPL) which comes with this package. */
5
6 #include "private.hpp"
7 #include "smpi_coll.hpp"
8 #include "smpi_comm.hpp"
9 #include "smpi_datatype.hpp"
10 #include "smpi_group.hpp"
11 #include "smpi_process.hpp"
12 #include "smpi_request.hpp"
13 #include "xbt/replay.hpp"
14
15 #include <boost/algorithm/string/join.hpp>
16 #include <memory>
17 #include <numeric>
18 #include <unordered_map>
19 #include <sstream>
20 #include <vector>
21
22 using simgrid::s4u::Actor;
23
24 #include <tuple>
25 // From https://stackoverflow.com/questions/7110301/generic-hash-for-tuples-in-unordered-map-unordered-set
26 // This is all just to make std::unordered_map work with std::tuple. If we need this in other places,
27 // this could go into a header file.
28 namespace hash_tuple{
29     template <typename TT>
30     struct hash
31     {
32         size_t
33         operator()(TT const& tt) const
34         {
35             return std::hash<TT>()(tt);
36         }
37     };
38
39     template <class T>
40     inline void hash_combine(std::size_t& seed, T const& v)
41     {
42         seed ^= hash_tuple::hash<T>()(v) + 0x9e3779b9 + (seed<<6) + (seed>>2);
43     }
44
45     // Recursive template code derived from Matthieu M.
46     template <class Tuple, size_t Index = std::tuple_size<Tuple>::value - 1>
47     struct HashValueImpl
48     {
49       static void apply(size_t& seed, Tuple const& tuple)
50       {
51         HashValueImpl<Tuple, Index-1>::apply(seed, tuple);
52         hash_combine(seed, std::get<Index>(tuple));
53       }
54     };
55
56     template <class Tuple>
57     struct HashValueImpl<Tuple,0>
58     {
59       static void apply(size_t& seed, Tuple const& tuple)
60       {
61         hash_combine(seed, std::get<0>(tuple));
62       }
63     };
64
65     template <typename ... TT>
66     struct hash<std::tuple<TT...>>
67     {
68         size_t
69         operator()(std::tuple<TT...> const& tt) const
70         {
71             size_t seed = 0;
72             HashValueImpl<std::tuple<TT...> >::apply(seed, tt);
73             return seed;
74         }
75     };
76 }
77
78 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_replay,smpi,"Trace Replay with SMPI");
79
80 typedef std::tuple</*sender*/ int, /* reciever */ int, /* tag */int> req_key_t;
81 typedef std::unordered_map<req_key_t, MPI_Request, hash_tuple::hash<std::tuple<int,int,int>>> req_storage_t;
82
83 static MPI_Datatype MPI_DEFAULT_TYPE;
84
85 #define CHECK_ACTION_PARAMS(action, mandatory, optional)                                                               \
86   {                                                                                                                    \
87     if (action.size() < static_cast<unsigned long>(mandatory + 2)) {                                                   \
88       std::stringstream ss;                                                                                            \
89       for (const auto& elem : action) {                                                                                \
90         ss << elem << " ";                                                                                             \
91       }                                                                                                                \
92       THROWF(arg_error, 0, "%s replay failed.\n"                                                                       \
93                            "%zu items were given on the line. First two should be process_id and action.  "            \
94                            "This action needs after them %lu mandatory arguments, and accepts %lu optional ones. \n"   \
95                            "The full line that was given is:\n   %s\n"                                                 \
96                            "Please contact the Simgrid team if support is needed",                                     \
97              __func__, action.size(), static_cast<unsigned long>(mandatory), static_cast<unsigned long>(optional),     \
98              ss.str().c_str());                                                                                        \
99     }                                                                                                                  \
100   }
101
102 static void log_timed_action(simgrid::xbt::ReplayAction& action, double clock)
103 {
104   if (XBT_LOG_ISENABLED(smpi_replay, xbt_log_priority_verbose)){
105     std::string s = boost::algorithm::join(action, " ");
106     XBT_VERB("%s %f", s.c_str(), smpi_process()->simulated_elapsed() - clock);
107   }
108 }
109
110 /* Helper function */
111 static double parse_double(std::string string)
112 {
113   return xbt_str_parse_double(string.c_str(), "%s is not a double");
114 }
115
116 namespace simgrid {
117 namespace smpi {
118
119 namespace replay {
120
121 class RequestStorage {
122 private:
123     req_storage_t store;
124
125 public:
126     RequestStorage() {}
127     int size()
128     {
129       return store.size();
130     }
131
132     req_storage_t& get_store()
133     {
134       return store;
135     }
136
137     void get_requests(std::vector<MPI_Request>& vec)
138     {
139       for (auto& pair : store) {
140         auto& req = pair.second;
141         auto my_proc_id = simgrid::s4u::this_actor::get_pid();
142         if (req != MPI_REQUEST_NULL && (req->src() == my_proc_id || req->dst() == my_proc_id)) {
143           vec.push_back(pair.second);
144           pair.second->print_request("MM");
145         }
146       }
147     }
148
149     MPI_Request find(int src, int dst, int tag)
150     {
151       req_storage_t::iterator it = store.find(req_key_t(src, dst, tag));
152       return (it == store.end()) ? MPI_REQUEST_NULL : it->second;
153     }
154
155     void remove(MPI_Request req)
156     {
157       if (req == MPI_REQUEST_NULL) return;
158
159       store.erase(req_key_t(req->src()-1, req->dst()-1, req->tag()));
160     }
161
162     void add(MPI_Request req)
163     {
164       if (req != MPI_REQUEST_NULL) // Can and does happen in the case of TestAction
165         store.insert({req_key_t(req->src()-1, req->dst()-1, req->tag()), req});
166     }
167
168     /* Sometimes we need to re-insert MPI_REQUEST_NULL but we still need src,dst and tag */
169     void addNullRequest(int src, int dst, int tag)
170     {
171       store.insert({req_key_t(src, dst, tag), MPI_REQUEST_NULL});
172     }
173 };
174
175 class ActionArgParser {
176 public:
177   virtual ~ActionArgParser() = default;
178   virtual void parse(simgrid::xbt::ReplayAction& action, std::string name) { CHECK_ACTION_PARAMS(action, 0, 0) }
179 };
180
181 class WaitTestParser : public ActionArgParser {
182 public:
183   int src;
184   int dst;
185   int tag;
186
187   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
188   {
189     CHECK_ACTION_PARAMS(action, 3, 0)
190     src = std::stoi(action[2]);
191     dst = std::stoi(action[3]);
192     tag = std::stoi(action[4]);
193   }
194 };
195
196 class SendRecvParser : public ActionArgParser {
197 public:
198   /* communication partner; if we send, this is the receiver and vice versa */
199   int partner;
200   double size;
201   int tag;
202   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
203
204   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
205   {
206     CHECK_ACTION_PARAMS(action, 3, 1)
207     partner = std::stoi(action[2]);
208     tag     = std::stoi(action[3]);
209     size    = parse_double(action[4]);
210     if (action.size() > 5)
211       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
212   }
213 };
214
215 class ComputeParser : public ActionArgParser {
216 public:
217   /* communication partner; if we send, this is the receiver and vice versa */
218   double flops;
219
220   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
221   {
222     CHECK_ACTION_PARAMS(action, 1, 0)
223     flops = parse_double(action[2]);
224   }
225 };
226
227 class CollCommParser : public ActionArgParser {
228 public:
229   double size;
230   double comm_size;
231   double comp_size;
232   int send_size;
233   int recv_size;
234   int root = 0;
235   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
236   MPI_Datatype datatype2 = MPI_DEFAULT_TYPE;
237 };
238
239 class BcastArgParser : public CollCommParser {
240 public:
241   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
242   {
243     CHECK_ACTION_PARAMS(action, 1, 2)
244     size = parse_double(action[2]);
245     root = (action.size() > 3) ? std::stoi(action[3]) : 0;
246     if (action.size() > 4)
247       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
248   }
249 };
250
251 class ReduceArgParser : public CollCommParser {
252 public:
253   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
254   {
255     CHECK_ACTION_PARAMS(action, 2, 2)
256     comm_size = parse_double(action[2]);
257     comp_size = parse_double(action[3]);
258     root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
259     if (action.size() > 5)
260       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
261   }
262 };
263
264 class AllReduceArgParser : public CollCommParser {
265 public:
266   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
267   {
268     CHECK_ACTION_PARAMS(action, 2, 1)
269     comm_size = parse_double(action[2]);
270     comp_size = parse_double(action[3]);
271     if (action.size() > 4)
272       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
273   }
274 };
275
276 class AllToAllArgParser : public CollCommParser {
277 public:
278   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
279   {
280     CHECK_ACTION_PARAMS(action, 2, 1)
281     comm_size = MPI_COMM_WORLD->size();
282     send_size = parse_double(action[2]);
283     recv_size = parse_double(action[3]);
284
285     if (action.size() > 4)
286       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
287     if (action.size() > 5)
288       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
289   }
290 };
291
292 class GatherArgParser : public CollCommParser {
293 public:
294   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
295   {
296     /* The structure of the gather action for the rank 0 (total 4 processes) is the following:
297           0 gather 68 68 0 0 0
298         where:
299           1) 68 is the sendcounts
300           2) 68 is the recvcounts
301           3) 0 is the root node
302           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
303           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
304     */
305     CHECK_ACTION_PARAMS(action, 2, 3)
306     comm_size = MPI_COMM_WORLD->size();
307     send_size = parse_double(action[2]);
308     recv_size = parse_double(action[3]);
309
310     if (name == "gather") {
311       root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
312       if (action.size() > 5)
313         datatype1 = simgrid::smpi::Datatype::decode(action[5]);
314       if (action.size() > 6)
315         datatype2 = simgrid::smpi::Datatype::decode(action[6]);
316     }
317     else {
318       if (action.size() > 4)
319         datatype1 = simgrid::smpi::Datatype::decode(action[4]);
320       if (action.size() > 5)
321         datatype2 = simgrid::smpi::Datatype::decode(action[5]);
322     }
323   }
324 };
325
326 class GatherVArgParser : public CollCommParser {
327 public:
328   int recv_size_sum;
329   std::shared_ptr<std::vector<int>> recvcounts;
330   std::vector<int> disps;
331   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
332   {
333     /* The structure of the gatherv action for the rank 0 (total 4 processes) is the following:
334          0 gather 68 68 10 10 10 0 0 0
335        where:
336          1) 68 is the sendcount
337          2) 68 10 10 10 is the recvcounts
338          3) 0 is the root node
339          4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
340          5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
341     */
342     comm_size = MPI_COMM_WORLD->size();
343     CHECK_ACTION_PARAMS(action, comm_size+1, 2)
344     send_size = parse_double(action[2]);
345     disps     = std::vector<int>(comm_size, 0);
346     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
347
348     if (name == "gatherV") {
349       root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
350       if (action.size() > 4 + comm_size)
351         datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
352       if (action.size() > 5 + comm_size)
353         datatype2 = simgrid::smpi::Datatype::decode(action[5 + comm_size]);
354     }
355     else {
356       int datatype_index = 0;
357       int disp_index     = 0;
358       /* The 3 comes from "0 gather <sendcount>", which must always be present.
359        * The + comm_size is the recvcounts array, which must also be present
360        */
361       if (action.size() > 3 + comm_size + comm_size) { /* datatype + disp are specified */
362         datatype_index = 3 + comm_size;
363         disp_index     = datatype_index + 1;
364         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
365         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
366       } else if (action.size() > 3 + comm_size + 2) { /* disps specified; datatype is not specified; use the default one */
367         disp_index     = 3 + comm_size;
368       } else if (action.size() > 3 + comm_size)  { /* only datatype, no disp specified */
369         datatype_index = 3 + comm_size;
370         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
371         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
372       }
373
374       if (disp_index != 0) {
375         for (unsigned int i = 0; i < comm_size; i++)
376           disps[i]          = std::stoi(action[disp_index + i]);
377       }
378     }
379
380     for (unsigned int i = 0; i < comm_size; i++) {
381       (*recvcounts)[i] = std::stoi(action[i + 3]);
382     }
383     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
384   }
385 };
386
387 class ScatterArgParser : public CollCommParser {
388 public:
389   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
390   {
391     /* The structure of the scatter action for the rank 0 (total 4 processes) is the following:
392           0 gather 68 68 0 0 0
393         where:
394           1) 68 is the sendcounts
395           2) 68 is the recvcounts
396           3) 0 is the root node
397           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
398           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
399     */
400     CHECK_ACTION_PARAMS(action, 2, 3)
401     comm_size   = MPI_COMM_WORLD->size();
402     send_size   = parse_double(action[2]);
403     recv_size   = parse_double(action[3]);
404     root   = (action.size() > 4) ? std::stoi(action[4]) : 0;
405     if (action.size() > 5)
406       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
407     if (action.size() > 6)
408       datatype2 = simgrid::smpi::Datatype::decode(action[6]);
409   }
410 };
411
412 class ScatterVArgParser : public CollCommParser {
413 public:
414   int recv_size_sum;
415   int send_size_sum;
416   std::shared_ptr<std::vector<int>> sendcounts;
417   std::vector<int> disps;
418   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
419   {
420     /* The structure of the scatterv action for the rank 0 (total 4 processes) is the following:
421        0 gather 68 10 10 10 68 0 0 0
422         where:
423         1) 68 10 10 10 is the sendcounts
424         2) 68 is the recvcount
425         3) 0 is the root node
426         4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
427         5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
428     */
429     CHECK_ACTION_PARAMS(action, comm_size + 1, 2)
430     recv_size  = parse_double(action[2 + comm_size]);
431     disps      = std::vector<int>(comm_size, 0);
432     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
433
434     if (action.size() > 5 + comm_size)
435       datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
436     if (action.size() > 5 + comm_size)
437       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
438
439     for (unsigned int i = 0; i < comm_size; i++) {
440       (*sendcounts)[i] = std::stoi(action[i + 2]);
441     }
442     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
443     root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
444   }
445 };
446
447 class ReduceScatterArgParser : public CollCommParser {
448 public:
449   int recv_size_sum;
450   std::shared_ptr<std::vector<int>> recvcounts;
451   std::vector<int> disps;
452   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
453   {
454     /* The structure of the reducescatter action for the rank 0 (total 4 processes) is the following:
455          0 reduceScatter 275427 275427 275427 204020 11346849 0
456        where:
457          1) The first four values after the name of the action declare the recvcounts array
458          2) The value 11346849 is the amount of instructions
459          3) The last value corresponds to the datatype, see simgrid::smpi::Datatype::decode().
460     */
461     comm_size = MPI_COMM_WORLD->size();
462     CHECK_ACTION_PARAMS(action, comm_size+1, 1)
463     comp_size = parse_double(action[2+comm_size]);
464     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
465     if (action.size() > 3 + comm_size)
466       datatype1 = simgrid::smpi::Datatype::decode(action[3 + comm_size]);
467
468     for (unsigned int i = 0; i < comm_size; i++) {
469       recvcounts->push_back(std::stoi(action[i + 2]));
470     }
471     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
472   }
473 };
474
475 class AllToAllVArgParser : public CollCommParser {
476 public:
477   int recv_size_sum;
478   int send_size_sum;
479   std::shared_ptr<std::vector<int>> recvcounts;
480   std::shared_ptr<std::vector<int>> sendcounts;
481   std::vector<int> senddisps;
482   std::vector<int> recvdisps;
483   int send_buf_size;
484   int recv_buf_size;
485   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
486   {
487     /* The structure of the allToAllV action for the rank 0 (total 4 processes) is the following:
488           0 allToAllV 100 1 7 10 12 100 1 70 10 5
489        where:
490         1) 100 is the size of the send buffer *sizeof(int),
491         2) 1 7 10 12 is the sendcounts array
492         3) 100*sizeof(int) is the size of the receiver buffer
493         4)  1 70 10 5 is the recvcounts array
494     */
495     comm_size = MPI_COMM_WORLD->size();
496     CHECK_ACTION_PARAMS(action, 2*comm_size+2, 2)
497     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
498     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
499     senddisps  = std::vector<int>(comm_size, 0);
500     recvdisps  = std::vector<int>(comm_size, 0);
501
502     if (action.size() > 5 + 2 * comm_size)
503       datatype1 = simgrid::smpi::Datatype::decode(action[4 + 2 * comm_size]);
504     if (action.size() > 5 + 2 * comm_size)
505       datatype2 = simgrid::smpi::Datatype::decode(action[5 + 2 * comm_size]);
506
507     send_buf_size=parse_double(action[2]);
508     recv_buf_size=parse_double(action[3+comm_size]);
509     for (unsigned int i = 0; i < comm_size; i++) {
510       (*sendcounts)[i] = std::stoi(action[3 + i]);
511       (*recvcounts)[i] = std::stoi(action[4 + comm_size + i]);
512     }
513     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
514     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
515   }
516 };
517
518 template <class T> class ReplayAction {
519 protected:
520   const std::string name;
521   RequestStorage* req_storage; // Points to the right storage for this process, nullptr except for Send/Recv/Wait/Test actions.
522   const int my_proc_id;
523   T args;
524
525 public:
526   explicit ReplayAction(std::string name, RequestStorage& storage) : name(name), req_storage(&storage), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
527   explicit ReplayAction(std::string name) : name(name), req_storage(nullptr), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
528   virtual ~ReplayAction() = default;
529
530   virtual void execute(simgrid::xbt::ReplayAction& action)
531   {
532     // Needs to be re-initialized for every action, hence here
533     double start_time = smpi_process()->simulated_elapsed();
534     args.parse(action, name);
535     kernel(action);
536     if (name != "Init")
537       log_timed_action(action, start_time);
538   }
539
540   virtual void kernel(simgrid::xbt::ReplayAction& action) = 0;
541
542   void* send_buffer(int size)
543   {
544     return smpi_get_tmp_sendbuffer(size);
545   }
546
547   void* recv_buffer(int size)
548   {
549     return smpi_get_tmp_recvbuffer(size);
550   }
551 };
552
553 class WaitAction : public ReplayAction<WaitTestParser> {
554 public:
555   explicit WaitAction(RequestStorage& storage) : ReplayAction("Wait", storage) {}
556   void kernel(simgrid::xbt::ReplayAction& action) override
557   {
558     std::string s = boost::algorithm::join(action, " ");
559     xbt_assert(req_storage->size(), "action wait not preceded by any irecv or isend: %s", s.c_str());
560     MPI_Request request = req_storage->find(args.src, args.dst, args.tag);
561     req_storage->remove(request);
562
563     if (request == MPI_REQUEST_NULL) {
564       /* Assume that the trace is well formed, meaning the comm might have been caught by a MPI_test. Then just
565        * return.*/
566       return;
567     }
568
569     int rank = request->comm() != MPI_COMM_NULL ? request->comm()->rank() : -1;
570
571     // Must be taken before Request::wait() since the request may be set to
572     // MPI_REQUEST_NULL by Request::wait!
573     bool is_wait_for_receive = (request->flags() & RECV);
574     // TODO: Here we take the rank while we normally take the process id (look for my_proc_id)
575     TRACE_smpi_comm_in(rank, __func__, new simgrid::instr::NoOpTIData("wait"));
576
577     MPI_Status status;
578     Request::wait(&request, &status);
579
580     TRACE_smpi_comm_out(rank);
581     if (is_wait_for_receive)
582       TRACE_smpi_recv(args.src, args.dst, args.tag);
583   }
584 };
585
586 class SendAction : public ReplayAction<SendRecvParser> {
587 public:
588   SendAction() = delete;
589   explicit SendAction(std::string name, RequestStorage& storage) : ReplayAction(name, storage) {}
590   void kernel(simgrid::xbt::ReplayAction& action) override
591   {
592     int dst_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
593
594     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
595                                                                              args.tag, Datatype::encode(args.datatype1)));
596     if (not TRACE_smpi_view_internals())
597       TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, args.tag, args.size * args.datatype1->size());
598
599     if (name == "send") {
600       Request::send(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
601     } else if (name == "Isend") {
602       MPI_Request request = Request::isend(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
603       req_storage->add(request);
604     } else {
605       xbt_die("Don't know this action, %s", name.c_str());
606     }
607
608     TRACE_smpi_comm_out(my_proc_id);
609   }
610 };
611
612 class RecvAction : public ReplayAction<SendRecvParser> {
613 public:
614   RecvAction() = delete;
615   explicit RecvAction(std::string name, RequestStorage& storage) : ReplayAction(name, storage) {}
616   void kernel(simgrid::xbt::ReplayAction& action) override
617   {
618     int src_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
619
620     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
621                                                                              args.tag, Datatype::encode(args.datatype1)));
622
623     MPI_Status status;
624     // unknown size from the receiver point of view
625     if (args.size <= 0.0) {
626       Request::probe(args.partner, args.tag, MPI_COMM_WORLD, &status);
627       args.size = status.count;
628     }
629
630     if (name == "recv") {
631       Request::recv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD, &status);
632     } else if (name == "Irecv") {
633       MPI_Request request = Request::irecv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
634       req_storage->add(request);
635     }
636
637     TRACE_smpi_comm_out(my_proc_id);
638     // TODO: Check why this was only activated in the "recv" case and not in the "Irecv" case
639     if (name == "recv" && not TRACE_smpi_view_internals()) {
640       TRACE_smpi_recv(src_traced, my_proc_id, args.tag);
641     }
642   }
643 };
644
645 class ComputeAction : public ReplayAction<ComputeParser> {
646 public:
647   ComputeAction() : ReplayAction("compute") {}
648   void kernel(simgrid::xbt::ReplayAction& action) override
649   {
650     TRACE_smpi_computing_in(my_proc_id, args.flops);
651     smpi_execute_flops(args.flops);
652     TRACE_smpi_computing_out(my_proc_id);
653   }
654 };
655
656 class TestAction : public ReplayAction<WaitTestParser> {
657 public:
658   explicit TestAction(RequestStorage& storage) : ReplayAction("Test", storage) {}
659   void kernel(simgrid::xbt::ReplayAction& action) override
660   {
661     MPI_Request request = req_storage->find(args.src, args.dst, args.tag);
662     req_storage->remove(request);
663     // if request is null here, this may mean that a previous test has succeeded
664     // Different times in traced application and replayed version may lead to this
665     // In this case, ignore the extra calls.
666     if (request != MPI_REQUEST_NULL) {
667       TRACE_smpi_testing_in(my_proc_id);
668
669       MPI_Status status;
670       int flag = Request::test(&request, &status);
671
672       XBT_DEBUG("MPI_Test result: %d", flag);
673       /* push back request in vector to be caught by a subsequent wait. if the test did succeed, the request is now
674        * nullptr.*/
675       if (request == MPI_REQUEST_NULL)
676         req_storage->addNullRequest(args.src, args.dst, args.tag);
677       else
678         req_storage->add(request);
679
680       TRACE_smpi_testing_out(my_proc_id);
681     }
682   }
683 };
684
685 class InitAction : public ReplayAction<ActionArgParser> {
686 public:
687   InitAction() : ReplayAction("Init") {}
688   void kernel(simgrid::xbt::ReplayAction& action) override
689   {
690     CHECK_ACTION_PARAMS(action, 0, 1)
691     MPI_DEFAULT_TYPE = (action.size() > 2) ? MPI_DOUBLE // default MPE datatype
692                                            : MPI_BYTE;  // default TAU datatype
693
694     /* start a simulated timer */
695     smpi_process()->simulated_start();
696   }
697 };
698
699 class CommunicatorAction : public ReplayAction<ActionArgParser> {
700 public:
701   CommunicatorAction() : ReplayAction("Comm") {}
702   void kernel(simgrid::xbt::ReplayAction& action) override { /* nothing to do */}
703 };
704
705 class WaitAllAction : public ReplayAction<ActionArgParser> {
706 public:
707   explicit WaitAllAction(RequestStorage& storage) : ReplayAction("waitAll", storage) {}
708   void kernel(simgrid::xbt::ReplayAction& action) override
709   {
710     const unsigned int count_requests = req_storage->size();
711
712     if (count_requests > 0) {
713       TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData("waitAll", -1, count_requests, ""));
714       std::vector<std::pair</*sender*/int,/*recv*/int>> sender_receiver;
715       std::vector<MPI_Request> reqs;
716       req_storage->get_requests(reqs);
717       for (const auto& req : reqs) {
718         if (req && (req->flags() & RECV)) {
719           sender_receiver.push_back({req->src(), req->dst()});
720         }
721       }
722       MPI_Status status[count_requests];
723       Request::waitall(count_requests, &(reqs.data())[0], status);
724       req_storage->get_store().clear();
725
726       for (auto& pair : sender_receiver) {
727         TRACE_smpi_recv(pair.first, pair.second, 0);
728       }
729       TRACE_smpi_comm_out(my_proc_id);
730     }
731   }
732 };
733
734 class BarrierAction : public ReplayAction<ActionArgParser> {
735 public:
736   BarrierAction() : ReplayAction("barrier") {}
737   void kernel(simgrid::xbt::ReplayAction& action) override
738   {
739     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("barrier"));
740     Colls::barrier(MPI_COMM_WORLD);
741     TRACE_smpi_comm_out(my_proc_id);
742   }
743 };
744
745 class BcastAction : public ReplayAction<BcastArgParser> {
746 public:
747   BcastAction() : ReplayAction("bcast") {}
748   void kernel(simgrid::xbt::ReplayAction& action) override
749   {
750     TRACE_smpi_comm_in(my_proc_id, "action_bcast",
751                        new simgrid::instr::CollTIData("bcast", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
752                                                       -1.0, args.size, -1, Datatype::encode(args.datatype1), ""));
753
754     Colls::bcast(send_buffer(args.size * args.datatype1->size()), args.size, args.datatype1, args.root, MPI_COMM_WORLD);
755
756     TRACE_smpi_comm_out(my_proc_id);
757   }
758 };
759
760 class ReduceAction : public ReplayAction<ReduceArgParser> {
761 public:
762   ReduceAction() : ReplayAction("reduce") {}
763   void kernel(simgrid::xbt::ReplayAction& action) override
764   {
765     TRACE_smpi_comm_in(my_proc_id, "action_reduce",
766                        new simgrid::instr::CollTIData("reduce", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
767                                                       args.comp_size, args.comm_size, -1,
768                                                       Datatype::encode(args.datatype1), ""));
769
770     Colls::reduce(send_buffer(args.comm_size * args.datatype1->size()),
771         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, args.root, MPI_COMM_WORLD);
772     smpi_execute_flops(args.comp_size);
773
774     TRACE_smpi_comm_out(my_proc_id);
775   }
776 };
777
778 class AllReduceAction : public ReplayAction<AllReduceArgParser> {
779 public:
780   AllReduceAction() : ReplayAction("allReduce") {}
781   void kernel(simgrid::xbt::ReplayAction& action) override
782   {
783     TRACE_smpi_comm_in(my_proc_id, "action_allReduce", new simgrid::instr::CollTIData("allReduce", -1, args.comp_size, args.comm_size, -1,
784                                                                                 Datatype::encode(args.datatype1), ""));
785
786     Colls::allreduce(send_buffer(args.comm_size * args.datatype1->size()),
787         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
788     smpi_execute_flops(args.comp_size);
789
790     TRACE_smpi_comm_out(my_proc_id);
791   }
792 };
793
794 class AllToAllAction : public ReplayAction<AllToAllArgParser> {
795 public:
796   AllToAllAction() : ReplayAction("allToAll") {}
797   void kernel(simgrid::xbt::ReplayAction& action) override
798   {
799     TRACE_smpi_comm_in(my_proc_id, "action_allToAll",
800                      new simgrid::instr::CollTIData("allToAll", -1, -1.0, args.send_size, args.recv_size,
801                                                     Datatype::encode(args.datatype1),
802                                                     Datatype::encode(args.datatype2)));
803
804     Colls::alltoall(send_buffer(args.send_size * args.comm_size * args.datatype1->size()), args.send_size,
805                     args.datatype1, recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()),
806                     args.recv_size, args.datatype2, MPI_COMM_WORLD);
807
808     TRACE_smpi_comm_out(my_proc_id);
809   }
810 };
811
812 class GatherAction : public ReplayAction<GatherArgParser> {
813 public:
814   explicit GatherAction(std::string name) : ReplayAction(name) {}
815   void kernel(simgrid::xbt::ReplayAction& action) override
816   {
817     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::CollTIData(name, (name == "gather") ? args.root : -1, -1.0, args.send_size, args.recv_size,
818                                                                           Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
819
820     if (name == "gather") {
821       int rank = MPI_COMM_WORLD->rank();
822       Colls::gather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
823                  (rank == args.root) ? recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
824     }
825     else
826       Colls::allgather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
827                        recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, MPI_COMM_WORLD);
828
829     TRACE_smpi_comm_out(my_proc_id);
830   }
831 };
832
833 class GatherVAction : public ReplayAction<GatherVArgParser> {
834 public:
835   explicit GatherVAction(std::string name) : ReplayAction(name) {}
836   void kernel(simgrid::xbt::ReplayAction& action) override
837   {
838     int rank = MPI_COMM_WORLD->rank();
839
840     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::VarCollTIData(
841                                                name, (name == "gatherV") ? args.root : -1, args.send_size, nullptr, -1, args.recvcounts,
842                                                Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
843
844     if (name == "gatherV") {
845       Colls::gatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
846                      (rank == args.root) ? recv_buffer(args.recv_size_sum * args.datatype2->size()) : nullptr,
847                      args.recvcounts->data(), args.disps.data(), args.datatype2, args.root, MPI_COMM_WORLD);
848     }
849     else {
850       Colls::allgatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
851                         recv_buffer(args.recv_size_sum * args.datatype2->size()), args.recvcounts->data(),
852                         args.disps.data(), args.datatype2, MPI_COMM_WORLD);
853     }
854
855     TRACE_smpi_comm_out(my_proc_id);
856   }
857 };
858
859 class ScatterAction : public ReplayAction<ScatterArgParser> {
860 public:
861   ScatterAction() : ReplayAction("scatter") {}
862   void kernel(simgrid::xbt::ReplayAction& action) override
863   {
864     int rank = MPI_COMM_WORLD->rank();
865     TRACE_smpi_comm_in(my_proc_id, "action_scatter", new simgrid::instr::CollTIData(name, args.root, -1.0, args.send_size, args.recv_size,
866                                                                           Datatype::encode(args.datatype1),
867                                                                           Datatype::encode(args.datatype2)));
868
869     Colls::scatter(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
870                   (rank == args.root) ? recv_buffer(args.recv_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
871
872     TRACE_smpi_comm_out(my_proc_id);
873   }
874 };
875
876
877 class ScatterVAction : public ReplayAction<ScatterVArgParser> {
878 public:
879   ScatterVAction() : ReplayAction("scatterV") {}
880   void kernel(simgrid::xbt::ReplayAction& action) override
881   {
882     int rank = MPI_COMM_WORLD->rank();
883     TRACE_smpi_comm_in(my_proc_id, "action_scatterv", new simgrid::instr::VarCollTIData(name, args.root, -1, args.sendcounts, args.recv_size,
884           nullptr, Datatype::encode(args.datatype1),
885           Datatype::encode(args.datatype2)));
886
887     Colls::scatterv((rank == args.root) ? send_buffer(args.send_size_sum * args.datatype1->size()) : nullptr,
888                     args.sendcounts->data(), args.disps.data(), args.datatype1,
889                     recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, args.root,
890                     MPI_COMM_WORLD);
891
892     TRACE_smpi_comm_out(my_proc_id);
893   }
894 };
895
896 class ReduceScatterAction : public ReplayAction<ReduceScatterArgParser> {
897 public:
898   ReduceScatterAction() : ReplayAction("reduceScatter") {}
899   void kernel(simgrid::xbt::ReplayAction& action) override
900   {
901     TRACE_smpi_comm_in(my_proc_id, "action_reducescatter",
902                        new simgrid::instr::VarCollTIData("reduceScatter", -1, 0, nullptr, -1, args.recvcounts,
903                                                          std::to_string(args.comp_size), /* ugly hack to print comp_size */
904                                                          Datatype::encode(args.datatype1)));
905
906     Colls::reduce_scatter(send_buffer(args.recv_size_sum * args.datatype1->size()),
907                           recv_buffer(args.recv_size_sum * args.datatype1->size()), args.recvcounts->data(),
908                           args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
909
910     smpi_execute_flops(args.comp_size);
911     TRACE_smpi_comm_out(my_proc_id);
912   }
913 };
914
915 class AllToAllVAction : public ReplayAction<AllToAllVArgParser> {
916 public:
917   AllToAllVAction() : ReplayAction("allToAllV") {}
918   void kernel(simgrid::xbt::ReplayAction& action) override
919   {
920     TRACE_smpi_comm_in(my_proc_id, __func__,
921                        new simgrid::instr::VarCollTIData(
922                            "allToAllV", -1, args.send_size_sum, args.sendcounts, args.recv_size_sum, args.recvcounts,
923                            Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
924
925     Colls::alltoallv(send_buffer(args.send_buf_size * args.datatype1->size()), args.sendcounts->data(), args.senddisps.data(), args.datatype1,
926                      recv_buffer(args.recv_buf_size * args.datatype2->size()), args.recvcounts->data(), args.recvdisps.data(), args.datatype2, MPI_COMM_WORLD);
927
928     TRACE_smpi_comm_out(my_proc_id);
929   }
930 };
931 } // Replay Namespace
932 }} // namespace simgrid::smpi
933
934 std::vector<simgrid::smpi::replay::RequestStorage> storage;
935 /** @brief Only initialize the replay, don't do it for real */
936 void smpi_replay_init(int* argc, char*** argv)
937 {
938   simgrid::smpi::Process::init(argc, argv);
939   smpi_process()->mark_as_initialized();
940   smpi_process()->set_replaying(true);
941
942   int my_proc_id = simgrid::s4u::this_actor::get_pid();
943   storage.resize(smpi_process_count());
944
945   TRACE_smpi_init(my_proc_id);
946   TRACE_smpi_computing_init(my_proc_id);
947   TRACE_smpi_comm_in(my_proc_id, "smpi_replay_run_init", new simgrid::instr::NoOpTIData("init"));
948   TRACE_smpi_comm_out(my_proc_id);
949   xbt_replay_action_register("init", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::InitAction().execute(action); });
950   xbt_replay_action_register("finalize", [](simgrid::xbt::ReplayAction& action) { /* nothing to do */ });
951   xbt_replay_action_register("comm_size", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
952   xbt_replay_action_register("comm_split",[](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
953   xbt_replay_action_register("comm_dup",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
954
955   xbt_replay_action_register("send",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("send", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
956   xbt_replay_action_register("Isend", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("Isend", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
957   xbt_replay_action_register("recv",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("recv", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
958   xbt_replay_action_register("Irecv", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("Irecv", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
959   xbt_replay_action_register("test",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::TestAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
960   xbt_replay_action_register("wait",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
961   xbt_replay_action_register("waitAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAllAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
962   xbt_replay_action_register("barrier", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BarrierAction().execute(action); });
963   xbt_replay_action_register("bcast",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BcastAction().execute(action); });
964   xbt_replay_action_register("reduce",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceAction().execute(action); });
965   xbt_replay_action_register("allReduce", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllReduceAction().execute(action); });
966   xbt_replay_action_register("allToAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllAction().execute(action); });
967   xbt_replay_action_register("allToAllV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllVAction().execute(action); });
968   xbt_replay_action_register("gather",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("gather").execute(action); });
969   xbt_replay_action_register("scatter",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterAction().execute(action); });
970   xbt_replay_action_register("gatherV",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("gatherV").execute(action); });
971   xbt_replay_action_register("scatterV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterVAction().execute(action); });
972   xbt_replay_action_register("allGather", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("allGather").execute(action); });
973   xbt_replay_action_register("allGatherV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("allGatherV").execute(action); });
974   xbt_replay_action_register("reduceScatter", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceScatterAction().execute(action); });
975   xbt_replay_action_register("compute", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ComputeAction().execute(action); });
976
977   //if we have a delayed start, sleep here.
978   if(*argc>2){
979     double value = xbt_str_parse_double((*argv)[2], "%s is not a double");
980     XBT_VERB("Delayed start for instance - Sleeping for %f flops ",value );
981     smpi_execute_flops(value);
982   } else {
983     //UGLY: force a context switch to be sure that all MSG_processes begin initialization
984     XBT_DEBUG("Force context switch by smpi_execute_flops  - Sleeping for 0.0 flops ");
985     smpi_execute_flops(0.0);
986   }
987 }
988
989 /** @brief actually run the replay after initialization */
990 void smpi_replay_main(int* argc, char*** argv)
991 {
992   static int active_processes = 0;
993   active_processes++;
994   simgrid::xbt::replay_runner(*argc, *argv);
995
996   /* and now, finalize everything */
997   /* One active process will stop. Decrease the counter*/
998   unsigned int count_requests = storage[simgrid::s4u::this_actor::get_pid() - 1].size();
999   XBT_DEBUG("There are %ud elements in reqq[*]", count_requests);
1000   if (count_requests > 0) {
1001     MPI_Request requests[count_requests];
1002     MPI_Status status[count_requests];
1003     unsigned int i=0;
1004
1005     for (auto const& pair : storage[simgrid::s4u::this_actor::get_pid() - 1].get_store()) {
1006       requests[i] = pair.second;
1007       i++;
1008     }
1009     simgrid::smpi::Request::waitall(count_requests, requests, status);
1010   }
1011   active_processes--;
1012
1013   if(active_processes==0){
1014     /* Last process alive speaking: end the simulated timer */
1015     XBT_INFO("Simulation time %f", smpi_process()->simulated_elapsed());
1016     smpi_free_replay_tmp_buffers();
1017   }
1018
1019   TRACE_smpi_comm_in(simgrid::s4u::this_actor::get_pid(), "smpi_replay_run_finalize",
1020                      new simgrid::instr::NoOpTIData("finalize"));
1021
1022   smpi_process()->finalize();
1023
1024   TRACE_smpi_comm_out(simgrid::s4u::this_actor::get_pid());
1025   TRACE_smpi_finalize(simgrid::s4u::this_actor::get_pid());
1026 }
1027
1028 /** @brief chain a replay initialization and a replay start */
1029 void smpi_replay_run(int* argc, char*** argv)
1030 {
1031   smpi_replay_init(argc, argv);
1032   smpi_replay_main(argc, argv);
1033 }