Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
[SMPI] Replay: Remove old functions / datatype
[simgrid.git] / src / smpi / internals / smpi_replay.cpp
1 /* Copyright (c) 2009-2018. The SimGrid Team. All rights reserved.          */
2
3 /* This program is free software; you can redistribute it and/or modify it
4  * under the terms of the license (GNU LGPL) which comes with this package. */
5
6 #include "private.hpp"
7 #include "smpi_coll.hpp"
8 #include "smpi_comm.hpp"
9 #include "smpi_datatype.hpp"
10 #include "smpi_group.hpp"
11 #include "smpi_process.hpp"
12 #include "smpi_request.hpp"
13 #include "xbt/replay.hpp"
14
15 #include <boost/algorithm/string/join.hpp>
16 #include <memory>
17 #include <numeric>
18 #include <unordered_map>
19 #include <sstream>
20 #include <vector>
21
22 using simgrid::s4u::Actor;
23
24 #include <tuple>
25 // From https://stackoverflow.com/questions/7110301/generic-hash-for-tuples-in-unordered-map-unordered-set
26 // This is all just to make std::unordered_map work with std::tuple. If we need this in other places,
27 // this could go into a header file.
28 namespace hash_tuple{
29     template <typename TT>
30     struct hash
31     {
32         size_t
33         operator()(TT const& tt) const
34         {
35             return std::hash<TT>()(tt);
36         }
37     };
38
39     template <class T>
40     inline void hash_combine(std::size_t& seed, T const& v)
41     {
42         seed ^= hash_tuple::hash<T>()(v) + 0x9e3779b9 + (seed<<6) + (seed>>2);
43     }
44
45     // Recursive template code derived from Matthieu M.
46     template <class Tuple, size_t Index = std::tuple_size<Tuple>::value - 1>
47     struct HashValueImpl
48     {
49       static void apply(size_t& seed, Tuple const& tuple)
50       {
51         HashValueImpl<Tuple, Index-1>::apply(seed, tuple);
52         hash_combine(seed, std::get<Index>(tuple));
53       }
54     };
55
56     template <class Tuple>
57     struct HashValueImpl<Tuple,0>
58     {
59       static void apply(size_t& seed, Tuple const& tuple)
60       {
61         hash_combine(seed, std::get<0>(tuple));
62       }
63     };
64
65     template <typename ... TT>
66     struct hash<std::tuple<TT...>>
67     {
68         size_t
69         operator()(std::tuple<TT...> const& tt) const
70         {
71             size_t seed = 0;
72             HashValueImpl<std::tuple<TT...> >::apply(seed, tt);
73             return seed;
74         }
75     };
76 }
77
78 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_replay,smpi,"Trace Replay with SMPI");
79
80 typedef std::tuple</*sender*/ int, /* reciever */ int, /* tag */int> req_key_t;
81 typedef std::unordered_map<req_key_t, MPI_Request, hash_tuple::hash<std::tuple<int,int,int>>> req_storage_t;
82
83 static MPI_Datatype MPI_DEFAULT_TYPE;
84
85 #define CHECK_ACTION_PARAMS(action, mandatory, optional)                                                               \
86   {                                                                                                                    \
87     if (action.size() < static_cast<unsigned long>(mandatory + 2)) {                                                   \
88       std::stringstream ss;                                                                                            \
89       for (const auto& elem : action) {                                                                                \
90         ss << elem << " ";                                                                                             \
91       }                                                                                                                \
92       THROWF(arg_error, 0, "%s replay failed.\n"                                                                       \
93                            "%zu items were given on the line. First two should be process_id and action.  "            \
94                            "This action needs after them %lu mandatory arguments, and accepts %lu optional ones. \n"   \
95                            "The full line that was given is:\n   %s\n"                                                 \
96                            "Please contact the Simgrid team if support is needed",                                     \
97              __func__, action.size(), static_cast<unsigned long>(mandatory), static_cast<unsigned long>(optional),     \
98              ss.str().c_str());                                                                                        \
99     }                                                                                                                  \
100   }
101
102 static void log_timed_action(simgrid::xbt::ReplayAction& action, double clock)
103 {
104   if (XBT_LOG_ISENABLED(smpi_replay, xbt_log_priority_verbose)){
105     std::string s = boost::algorithm::join(action, " ");
106     XBT_VERB("%s %f", s.c_str(), smpi_process()->simulated_elapsed() - clock);
107   }
108 }
109
110 /* Helper function */
111 static double parse_double(std::string string)
112 {
113   return xbt_str_parse_double(string.c_str(), "%s is not a double");
114 }
115
116 namespace simgrid {
117 namespace smpi {
118
119 namespace replay {
120
121 class RequestStorage {
122 private:
123     req_storage_t store;
124
125 public:
126     RequestStorage() {}
127     int size()
128     {
129       return store.size();
130     }
131
132     req_storage_t& get_store()
133     {
134       return store;
135     }
136
137     void get_requests(std::vector<MPI_Request>& vec)
138     {
139       for (auto& pair : store) {
140         auto& req = pair.second;
141         auto my_proc_id = simgrid::s4u::this_actor::get_pid();
142         if (req != MPI_REQUEST_NULL && (req->src() == my_proc_id || req->dst() == my_proc_id)) {
143           vec.push_back(pair.second);
144           pair.second->print_request("MM");
145         }
146       }
147     }
148
149     MPI_Request find(int src, int dst, int tag)
150     {
151       req_storage_t::iterator it = store.find(req_key_t(src, dst, tag));
152       return (it == store.end()) ? MPI_REQUEST_NULL : it->second;
153     }
154
155     void remove(MPI_Request req)
156     {
157       if (req == MPI_REQUEST_NULL) return;
158
159       store.erase(req_key_t(req->src()-1, req->dst()-1, req->tag()));
160     }
161
162     void add(MPI_Request req)
163     {
164       if (req != MPI_REQUEST_NULL) // Can and does happen in the case of TestAction
165         store.insert({req_key_t(req->src()-1, req->dst()-1, req->tag()), req});
166     }
167
168     /* Sometimes we need to re-insert MPI_REQUEST_NULL but we still need src,dst and tag */
169     void addNullRequest(int src, int dst, int tag)
170     {
171       store.insert({req_key_t(src, dst, tag), MPI_REQUEST_NULL});
172     }
173 };
174
175 class ActionArgParser {
176 public:
177   virtual ~ActionArgParser() = default;
178   virtual void parse(simgrid::xbt::ReplayAction& action, std::string name) { CHECK_ACTION_PARAMS(action, 0, 0) }
179 };
180
181 class WaitTestParser : public ActionArgParser {
182 public:
183   int src;
184   int dst;
185   int tag;
186
187   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
188   {
189     CHECK_ACTION_PARAMS(action, 3, 0)
190     src = std::stoi(action[2]);
191     dst = std::stoi(action[3]);
192     tag = std::stoi(action[4]);
193   }
194 };
195
196 class SendRecvParser : public ActionArgParser {
197 public:
198   /* communication partner; if we send, this is the receiver and vice versa */
199   int partner;
200   double size;
201   int tag;
202   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
203
204   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
205   {
206     CHECK_ACTION_PARAMS(action, 3, 1)
207     partner = std::stoi(action[2]);
208     tag     = std::stoi(action[3]);
209     size    = parse_double(action[4]);
210     if (action.size() > 5)
211       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
212   }
213 };
214
215 class ComputeParser : public ActionArgParser {
216 public:
217   /* communication partner; if we send, this is the receiver and vice versa */
218   double flops;
219
220   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
221   {
222     CHECK_ACTION_PARAMS(action, 1, 0)
223     flops = parse_double(action[2]);
224   }
225 };
226
227 class CollCommParser : public ActionArgParser {
228 public:
229   double size;
230   double comm_size;
231   double comp_size;
232   int send_size;
233   int recv_size;
234   int root = 0;
235   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
236   MPI_Datatype datatype2 = MPI_DEFAULT_TYPE;
237 };
238
239 class BcastArgParser : public CollCommParser {
240 public:
241   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
242   {
243     CHECK_ACTION_PARAMS(action, 1, 2)
244     size = parse_double(action[2]);
245     root = (action.size() > 3) ? std::stoi(action[3]) : 0;
246     if (action.size() > 4)
247       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
248   }
249 };
250
251 class ReduceArgParser : public CollCommParser {
252 public:
253   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
254   {
255     CHECK_ACTION_PARAMS(action, 2, 2)
256     comm_size = parse_double(action[2]);
257     comp_size = parse_double(action[3]);
258     root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
259     if (action.size() > 5)
260       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
261   }
262 };
263
264 class AllReduceArgParser : public CollCommParser {
265 public:
266   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
267   {
268     CHECK_ACTION_PARAMS(action, 2, 1)
269     comm_size = parse_double(action[2]);
270     comp_size = parse_double(action[3]);
271     if (action.size() > 4)
272       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
273   }
274 };
275
276 class AllToAllArgParser : public CollCommParser {
277 public:
278   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
279   {
280     CHECK_ACTION_PARAMS(action, 2, 1)
281     comm_size = MPI_COMM_WORLD->size();
282     send_size = parse_double(action[2]);
283     recv_size = parse_double(action[3]);
284
285     if (action.size() > 4)
286       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
287     if (action.size() > 5)
288       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
289   }
290 };
291
292 class GatherArgParser : public CollCommParser {
293 public:
294   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
295   {
296     /* The structure of the gather action for the rank 0 (total 4 processes) is the following:
297           0 gather 68 68 0 0 0
298         where:
299           1) 68 is the sendcounts
300           2) 68 is the recvcounts
301           3) 0 is the root node
302           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
303           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
304     */
305     CHECK_ACTION_PARAMS(action, 2, 3)
306     comm_size = MPI_COMM_WORLD->size();
307     send_size = parse_double(action[2]);
308     recv_size = parse_double(action[3]);
309
310     if (name == "gather") {
311       root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
312       if (action.size() > 5)
313         datatype1 = simgrid::smpi::Datatype::decode(action[5]);
314       if (action.size() > 6)
315         datatype2 = simgrid::smpi::Datatype::decode(action[6]);
316     }
317     else {
318       if (action.size() > 4)
319         datatype1 = simgrid::smpi::Datatype::decode(action[4]);
320       if (action.size() > 5)
321         datatype2 = simgrid::smpi::Datatype::decode(action[5]);
322     }
323   }
324 };
325
326 class GatherVArgParser : public CollCommParser {
327 public:
328   int recv_size_sum;
329   std::shared_ptr<std::vector<int>> recvcounts;
330   std::vector<int> disps;
331   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
332   {
333     /* The structure of the gatherv action for the rank 0 (total 4 processes) is the following:
334          0 gather 68 68 10 10 10 0 0 0
335        where:
336          1) 68 is the sendcount
337          2) 68 10 10 10 is the recvcounts
338          3) 0 is the root node
339          4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
340          5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
341     */
342     comm_size = MPI_COMM_WORLD->size();
343     CHECK_ACTION_PARAMS(action, comm_size+1, 2)
344     send_size = parse_double(action[2]);
345     disps     = std::vector<int>(comm_size, 0);
346     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
347
348     if (name == "gatherV") {
349       root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
350       if (action.size() > 4 + comm_size)
351         datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
352       if (action.size() > 5 + comm_size)
353         datatype2 = simgrid::smpi::Datatype::decode(action[5 + comm_size]);
354     }
355     else {
356       int datatype_index = 0;
357       int disp_index     = 0;
358       /* The 3 comes from "0 gather <sendcount>", which must always be present.
359        * The + comm_size is the recvcounts array, which must also be present
360        */
361       if (action.size() > 3 + comm_size + comm_size) { /* datatype + disp are specified */
362         datatype_index = 3 + comm_size;
363         disp_index     = datatype_index + 1;
364         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
365         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
366       } else if (action.size() > 3 + comm_size + 2) { /* disps specified; datatype is not specified; use the default one */
367         disp_index     = 3 + comm_size;
368       } else if (action.size() > 3 + comm_size)  { /* only datatype, no disp specified */
369         datatype_index = 3 + comm_size;
370         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
371         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
372       }
373
374       if (disp_index != 0) {
375         for (unsigned int i = 0; i < comm_size; i++)
376           disps[i]          = std::stoi(action[disp_index + i]);
377       }
378     }
379
380     for (unsigned int i = 0; i < comm_size; i++) {
381       (*recvcounts)[i] = std::stoi(action[i + 3]);
382     }
383     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
384   }
385 };
386
387 class ScatterArgParser : public CollCommParser {
388 public:
389   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
390   {
391     /* The structure of the scatter action for the rank 0 (total 4 processes) is the following:
392           0 gather 68 68 0 0 0
393         where:
394           1) 68 is the sendcounts
395           2) 68 is the recvcounts
396           3) 0 is the root node
397           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
398           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
399     */
400     CHECK_ACTION_PARAMS(action, 2, 3)
401     comm_size   = MPI_COMM_WORLD->size();
402     send_size   = parse_double(action[2]);
403     recv_size   = parse_double(action[3]);
404     root   = (action.size() > 4) ? std::stoi(action[4]) : 0;
405     if (action.size() > 5)
406       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
407     if (action.size() > 6)
408       datatype2 = simgrid::smpi::Datatype::decode(action[6]);
409   }
410 };
411
412 class ScatterVArgParser : public CollCommParser {
413 public:
414   int recv_size_sum;
415   int send_size_sum;
416   std::shared_ptr<std::vector<int>> sendcounts;
417   std::vector<int> disps;
418   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
419   {
420     /* The structure of the scatterv action for the rank 0 (total 4 processes) is the following:
421        0 gather 68 10 10 10 68 0 0 0
422         where:
423         1) 68 10 10 10 is the sendcounts
424         2) 68 is the recvcount
425         3) 0 is the root node
426         4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
427         5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
428     */
429     CHECK_ACTION_PARAMS(action, comm_size + 1, 2)
430     recv_size  = parse_double(action[2 + comm_size]);
431     disps      = std::vector<int>(comm_size, 0);
432     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
433
434     if (action.size() > 5 + comm_size)
435       datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
436     if (action.size() > 5 + comm_size)
437       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
438
439     for (unsigned int i = 0; i < comm_size; i++) {
440       (*sendcounts)[i] = std::stoi(action[i + 2]);
441     }
442     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
443     root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
444   }
445 };
446
447 class ReduceScatterArgParser : public CollCommParser {
448 public:
449   int recv_size_sum;
450   std::shared_ptr<std::vector<int>> recvcounts;
451   std::vector<int> disps;
452   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
453   {
454     /* The structure of the reducescatter action for the rank 0 (total 4 processes) is the following:
455          0 reduceScatter 275427 275427 275427 204020 11346849 0
456        where:
457          1) The first four values after the name of the action declare the recvcounts array
458          2) The value 11346849 is the amount of instructions
459          3) The last value corresponds to the datatype, see simgrid::smpi::Datatype::decode().
460     */
461     comm_size = MPI_COMM_WORLD->size();
462     CHECK_ACTION_PARAMS(action, comm_size+1, 1)
463     comp_size = parse_double(action[2+comm_size]);
464     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
465     if (action.size() > 3 + comm_size)
466       datatype1 = simgrid::smpi::Datatype::decode(action[3 + comm_size]);
467
468     for (unsigned int i = 0; i < comm_size; i++) {
469       recvcounts->push_back(std::stoi(action[i + 2]));
470     }
471     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
472   }
473 };
474
475 class AllToAllVArgParser : public CollCommParser {
476 public:
477   int recv_size_sum;
478   int send_size_sum;
479   std::shared_ptr<std::vector<int>> recvcounts;
480   std::shared_ptr<std::vector<int>> sendcounts;
481   std::vector<int> senddisps;
482   std::vector<int> recvdisps;
483   int send_buf_size;
484   int recv_buf_size;
485   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
486   {
487     /* The structure of the allToAllV action for the rank 0 (total 4 processes) is the following:
488           0 allToAllV 100 1 7 10 12 100 1 70 10 5
489        where:
490         1) 100 is the size of the send buffer *sizeof(int),
491         2) 1 7 10 12 is the sendcounts array
492         3) 100*sizeof(int) is the size of the receiver buffer
493         4)  1 70 10 5 is the recvcounts array
494     */
495     comm_size = MPI_COMM_WORLD->size();
496     CHECK_ACTION_PARAMS(action, 2*comm_size+2, 2)
497     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
498     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
499     senddisps  = std::vector<int>(comm_size, 0);
500     recvdisps  = std::vector<int>(comm_size, 0);
501
502     if (action.size() > 5 + 2 * comm_size)
503       datatype1 = simgrid::smpi::Datatype::decode(action[4 + 2 * comm_size]);
504     if (action.size() > 5 + 2 * comm_size)
505       datatype2 = simgrid::smpi::Datatype::decode(action[5 + 2 * comm_size]);
506
507     send_buf_size=parse_double(action[2]);
508     recv_buf_size=parse_double(action[3+comm_size]);
509     for (unsigned int i = 0; i < comm_size; i++) {
510       (*sendcounts)[i] = std::stoi(action[3 + i]);
511       (*recvcounts)[i] = std::stoi(action[4 + comm_size + i]);
512     }
513     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
514     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
515   }
516 };
517
518 template <class T> class ReplayAction {
519 protected:
520   const std::string name;
521   RequestStorage* req_storage; // Points to the right storage for this process, nullptr except for Send/Recv/Wait/Test actions.
522   const int my_proc_id;
523   T args;
524
525 public:
526   explicit ReplayAction(std::string name, RequestStorage& storage) : name(name), req_storage(&storage), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
527   explicit ReplayAction(std::string name) : name(name), req_storage(nullptr), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
528   virtual ~ReplayAction() = default;
529
530   virtual void execute(simgrid::xbt::ReplayAction& action)
531   {
532     // Needs to be re-initialized for every action, hence here
533     double start_time = smpi_process()->simulated_elapsed();
534     args.parse(action, name);
535     kernel(action);
536     if (name != "Init")
537       log_timed_action(action, start_time);
538   }
539
540   virtual void kernel(simgrid::xbt::ReplayAction& action) = 0;
541
542   void* send_buffer(int size)
543   {
544     return smpi_get_tmp_sendbuffer(size);
545   }
546
547   void* recv_buffer(int size)
548   {
549     return smpi_get_tmp_recvbuffer(size);
550   }
551 };
552
553 class WaitAction : public ReplayAction<WaitTestParser> {
554 public:
555   WaitAction(RequestStorage& storage) : ReplayAction("Wait", storage) {}
556   void kernel(simgrid::xbt::ReplayAction& action) override
557   {
558     std::string s = boost::algorithm::join(action, " ");
559     xbt_assert(req_storage->size(), "action wait not preceded by any irecv or isend: %s", s.c_str());
560     MPI_Request request = req_storage->find(args.src, args.dst, args.tag);
561     req_storage->remove(request);
562
563     if (request == MPI_REQUEST_NULL) {
564       /* Assume that the trace is well formed, meaning the comm might have been caught by a MPI_test. Then just
565        * return.*/
566       return;
567     }
568
569     int rank = request->comm() != MPI_COMM_NULL ? request->comm()->rank() : -1;
570
571     // Must be taken before Request::wait() since the request may be set to
572     // MPI_REQUEST_NULL by Request::wait!
573     bool is_wait_for_receive = (request->flags() & RECV);
574     // TODO: Here we take the rank while we normally take the process id (look for my_proc_id)
575     TRACE_smpi_comm_in(rank, __func__, new simgrid::instr::NoOpTIData("wait"));
576
577     MPI_Status status;
578     Request::wait(&request, &status);
579
580     TRACE_smpi_comm_out(rank);
581     if (is_wait_for_receive)
582       TRACE_smpi_recv(args.src, args.dst, args.tag);
583   }
584 };
585
586 class SendAction : public ReplayAction<SendRecvParser> {
587 public:
588   SendAction() = delete;
589   explicit SendAction(std::string name, RequestStorage& storage) : ReplayAction(name, storage) {}
590   void kernel(simgrid::xbt::ReplayAction& action) override
591   {
592     int dst_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
593
594     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
595                                                                              args.tag, Datatype::encode(args.datatype1)));
596     if (not TRACE_smpi_view_internals())
597       TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, args.tag, args.size * args.datatype1->size());
598
599     if (name == "send") {
600       Request::send(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
601     } else if (name == "Isend") {
602       MPI_Request request = Request::isend(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
603       req_storage->add(request);
604     } else {
605       xbt_die("Don't know this action, %s", name.c_str());
606     }
607
608     TRACE_smpi_comm_out(my_proc_id);
609   }
610 };
611
612 class RecvAction : public ReplayAction<SendRecvParser> {
613 public:
614   RecvAction() = delete;
615   explicit RecvAction(std::string name, RequestStorage& storage) : ReplayAction(name, storage) {}
616   void kernel(simgrid::xbt::ReplayAction& action) override
617   {
618     int src_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
619
620     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
621                                                                              args.tag, Datatype::encode(args.datatype1)));
622
623     MPI_Status status;
624     // unknown size from the receiver point of view
625     if (args.size <= 0.0) {
626       Request::probe(args.partner, args.tag, MPI_COMM_WORLD, &status);
627       args.size = status.count;
628     }
629
630     if (name == "recv") {
631       Request::recv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD, &status);
632     } else if (name == "Irecv") {
633       MPI_Request request = Request::irecv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
634       req_storage->add(request);
635     }
636
637     TRACE_smpi_comm_out(my_proc_id);
638     // TODO: Check why this was only activated in the "recv" case and not in the "Irecv" case
639     if (name == "recv" && not TRACE_smpi_view_internals()) {
640       TRACE_smpi_recv(src_traced, my_proc_id, args.tag);
641     }
642   }
643 };
644
645 class ComputeAction : public ReplayAction<ComputeParser> {
646 public:
647   ComputeAction() : ReplayAction("compute") {}
648   void kernel(simgrid::xbt::ReplayAction& action) override
649   {
650     TRACE_smpi_computing_in(my_proc_id, args.flops);
651     smpi_execute_flops(args.flops);
652     TRACE_smpi_computing_out(my_proc_id);
653   }
654 };
655
656 class TestAction : public ReplayAction<WaitTestParser> {
657 public:
658   TestAction(RequestStorage& storage) : ReplayAction("Test", storage) {}
659   void kernel(simgrid::xbt::ReplayAction& action) override
660   {
661     MPI_Request request = req_storage->find(args.src, args.dst, args.tag);
662     req_storage->remove(request);
663     // if request is null here, this may mean that a previous test has succeeded
664     // Different times in traced application and replayed version may lead to this
665     // In this case, ignore the extra calls.
666     if (request != MPI_REQUEST_NULL) {
667       TRACE_smpi_testing_in(my_proc_id);
668
669       MPI_Status status;
670       int flag = Request::test(&request, &status);
671
672       XBT_DEBUG("MPI_Test result: %d", flag);
673       /* push back request in vector to be caught by a subsequent wait. if the test did succeed, the request is now
674        * nullptr.*/
675       if (request == MPI_REQUEST_NULL)
676         req_storage->addNullRequest(args.src, args.dst, args.tag);
677       else
678         req_storage->add(request);
679
680       TRACE_smpi_testing_out(my_proc_id);
681     }
682   }
683 };
684
685 class InitAction : public ReplayAction<ActionArgParser> {
686 public:
687   InitAction() : ReplayAction("Init") {}
688   void kernel(simgrid::xbt::ReplayAction& action) override
689   {
690     CHECK_ACTION_PARAMS(action, 0, 1)
691     MPI_DEFAULT_TYPE = (action.size() > 2) ? MPI_DOUBLE // default MPE datatype
692                                            : MPI_BYTE;  // default TAU datatype
693
694     /* start a simulated timer */
695     smpi_process()->simulated_start();
696   }
697 };
698
699 class CommunicatorAction : public ReplayAction<ActionArgParser> {
700 public:
701   CommunicatorAction() : ReplayAction("Comm") {}
702   void kernel(simgrid::xbt::ReplayAction& action) override { /* nothing to do */}
703 };
704
705 class WaitAllAction : public ReplayAction<ActionArgParser> {
706 public:
707   WaitAllAction(RequestStorage& storage) : ReplayAction("waitAll", storage) {}
708   void kernel(simgrid::xbt::ReplayAction& action) override
709   {
710     const unsigned int count_requests = req_storage->size();
711
712     if (count_requests > 0) {
713       TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData("waitAll", -1, count_requests, ""));
714       std::vector<std::pair</*sender*/int,/*recv*/int>> sender_receiver;
715       std::vector<MPI_Request> reqs;
716       req_storage->get_requests(reqs);
717       for (const auto& req : reqs) {
718         if (req && (req->flags() & RECV)) {
719           sender_receiver.push_back({req->src(), req->dst()});
720         }
721       }
722       MPI_Status status[count_requests];
723       Request::waitall(count_requests, &(reqs.data())[0], status);
724
725       for (auto& pair : sender_receiver) {
726         TRACE_smpi_recv(pair.first, pair.second, 0);
727       }
728       TRACE_smpi_comm_out(my_proc_id);
729     }
730   }
731 };
732
733 class BarrierAction : public ReplayAction<ActionArgParser> {
734 public:
735   BarrierAction() : ReplayAction("barrier") {}
736   void kernel(simgrid::xbt::ReplayAction& action) override
737   {
738     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("barrier"));
739     Colls::barrier(MPI_COMM_WORLD);
740     TRACE_smpi_comm_out(my_proc_id);
741   }
742 };
743
744 class BcastAction : public ReplayAction<BcastArgParser> {
745 public:
746   BcastAction() : ReplayAction("bcast") {}
747   void kernel(simgrid::xbt::ReplayAction& action) override
748   {
749     TRACE_smpi_comm_in(my_proc_id, "action_bcast",
750                        new simgrid::instr::CollTIData("bcast", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
751                                                       -1.0, args.size, -1, Datatype::encode(args.datatype1), ""));
752
753     Colls::bcast(send_buffer(args.size * args.datatype1->size()), args.size, args.datatype1, args.root, MPI_COMM_WORLD);
754
755     TRACE_smpi_comm_out(my_proc_id);
756   }
757 };
758
759 class ReduceAction : public ReplayAction<ReduceArgParser> {
760 public:
761   ReduceAction() : ReplayAction("reduce") {}
762   void kernel(simgrid::xbt::ReplayAction& action) override
763   {
764     TRACE_smpi_comm_in(my_proc_id, "action_reduce",
765                        new simgrid::instr::CollTIData("reduce", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
766                                                       args.comp_size, args.comm_size, -1,
767                                                       Datatype::encode(args.datatype1), ""));
768
769     Colls::reduce(send_buffer(args.comm_size * args.datatype1->size()),
770         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, args.root, MPI_COMM_WORLD);
771     smpi_execute_flops(args.comp_size);
772
773     TRACE_smpi_comm_out(my_proc_id);
774   }
775 };
776
777 class AllReduceAction : public ReplayAction<AllReduceArgParser> {
778 public:
779   AllReduceAction() : ReplayAction("allReduce") {}
780   void kernel(simgrid::xbt::ReplayAction& action) override
781   {
782     TRACE_smpi_comm_in(my_proc_id, "action_allReduce", new simgrid::instr::CollTIData("allReduce", -1, args.comp_size, args.comm_size, -1,
783                                                                                 Datatype::encode(args.datatype1), ""));
784
785     Colls::allreduce(send_buffer(args.comm_size * args.datatype1->size()),
786         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
787     smpi_execute_flops(args.comp_size);
788
789     TRACE_smpi_comm_out(my_proc_id);
790   }
791 };
792
793 class AllToAllAction : public ReplayAction<AllToAllArgParser> {
794 public:
795   AllToAllAction() : ReplayAction("allToAll") {}
796   void kernel(simgrid::xbt::ReplayAction& action) override
797   {
798     TRACE_smpi_comm_in(my_proc_id, "action_allToAll",
799                      new simgrid::instr::CollTIData("allToAll", -1, -1.0, args.send_size, args.recv_size,
800                                                     Datatype::encode(args.datatype1),
801                                                     Datatype::encode(args.datatype2)));
802
803     Colls::alltoall(send_buffer(args.send_size * args.comm_size * args.datatype1->size()), args.send_size,
804                     args.datatype1, recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()),
805                     args.recv_size, args.datatype2, MPI_COMM_WORLD);
806
807     TRACE_smpi_comm_out(my_proc_id);
808   }
809 };
810
811 class GatherAction : public ReplayAction<GatherArgParser> {
812 public:
813   explicit GatherAction(std::string name) : ReplayAction(name) {}
814   void kernel(simgrid::xbt::ReplayAction& action) override
815   {
816     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::CollTIData(name, (name == "gather") ? args.root : -1, -1.0, args.send_size, args.recv_size,
817                                                                           Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
818
819     if (name == "gather") {
820       int rank = MPI_COMM_WORLD->rank();
821       Colls::gather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
822                  (rank == args.root) ? recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
823     }
824     else
825       Colls::allgather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
826                        recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, MPI_COMM_WORLD);
827
828     TRACE_smpi_comm_out(my_proc_id);
829   }
830 };
831
832 class GatherVAction : public ReplayAction<GatherVArgParser> {
833 public:
834   explicit GatherVAction(std::string name) : ReplayAction(name) {}
835   void kernel(simgrid::xbt::ReplayAction& action) override
836   {
837     int rank = MPI_COMM_WORLD->rank();
838
839     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::VarCollTIData(
840                                                name, (name == "gatherV") ? args.root : -1, args.send_size, nullptr, -1, args.recvcounts,
841                                                Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
842
843     if (name == "gatherV") {
844       Colls::gatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
845                      (rank == args.root) ? recv_buffer(args.recv_size_sum * args.datatype2->size()) : nullptr,
846                      args.recvcounts->data(), args.disps.data(), args.datatype2, args.root, MPI_COMM_WORLD);
847     }
848     else {
849       Colls::allgatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
850                         recv_buffer(args.recv_size_sum * args.datatype2->size()), args.recvcounts->data(),
851                         args.disps.data(), args.datatype2, MPI_COMM_WORLD);
852     }
853
854     TRACE_smpi_comm_out(my_proc_id);
855   }
856 };
857
858 class ScatterAction : public ReplayAction<ScatterArgParser> {
859 public:
860   ScatterAction() : ReplayAction("scatter") {}
861   void kernel(simgrid::xbt::ReplayAction& action) override
862   {
863     int rank = MPI_COMM_WORLD->rank();
864     TRACE_smpi_comm_in(my_proc_id, "action_scatter", new simgrid::instr::CollTIData(name, args.root, -1.0, args.send_size, args.recv_size,
865                                                                           Datatype::encode(args.datatype1),
866                                                                           Datatype::encode(args.datatype2)));
867
868     Colls::scatter(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
869                   (rank == args.root) ? recv_buffer(args.recv_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
870
871     TRACE_smpi_comm_out(my_proc_id);
872   }
873 };
874
875
876 class ScatterVAction : public ReplayAction<ScatterVArgParser> {
877 public:
878   ScatterVAction() : ReplayAction("scatterV") {}
879   void kernel(simgrid::xbt::ReplayAction& action) override
880   {
881     int rank = MPI_COMM_WORLD->rank();
882     TRACE_smpi_comm_in(my_proc_id, "action_scatterv", new simgrid::instr::VarCollTIData(name, args.root, -1, args.sendcounts, args.recv_size,
883           nullptr, Datatype::encode(args.datatype1),
884           Datatype::encode(args.datatype2)));
885
886     Colls::scatterv((rank == args.root) ? send_buffer(args.send_size_sum * args.datatype1->size()) : nullptr,
887                     args.sendcounts->data(), args.disps.data(), args.datatype1,
888                     recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, args.root,
889                     MPI_COMM_WORLD);
890
891     TRACE_smpi_comm_out(my_proc_id);
892   }
893 };
894
895 class ReduceScatterAction : public ReplayAction<ReduceScatterArgParser> {
896 public:
897   ReduceScatterAction() : ReplayAction("reduceScatter") {}
898   void kernel(simgrid::xbt::ReplayAction& action) override
899   {
900     TRACE_smpi_comm_in(my_proc_id, "action_reducescatter",
901                        new simgrid::instr::VarCollTIData("reduceScatter", -1, 0, nullptr, -1, args.recvcounts,
902                                                          std::to_string(args.comp_size), /* ugly hack to print comp_size */
903                                                          Datatype::encode(args.datatype1)));
904
905     Colls::reduce_scatter(send_buffer(args.recv_size_sum * args.datatype1->size()),
906                           recv_buffer(args.recv_size_sum * args.datatype1->size()), args.recvcounts->data(),
907                           args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
908
909     smpi_execute_flops(args.comp_size);
910     TRACE_smpi_comm_out(my_proc_id);
911   }
912 };
913
914 class AllToAllVAction : public ReplayAction<AllToAllVArgParser> {
915 public:
916   AllToAllVAction() : ReplayAction("allToAllV") {}
917   void kernel(simgrid::xbt::ReplayAction& action) override
918   {
919     TRACE_smpi_comm_in(my_proc_id, __func__,
920                        new simgrid::instr::VarCollTIData(
921                            "allToAllV", -1, args.send_size_sum, args.sendcounts, args.recv_size_sum, args.recvcounts,
922                            Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
923
924     Colls::alltoallv(send_buffer(args.send_buf_size * args.datatype1->size()), args.sendcounts->data(), args.senddisps.data(), args.datatype1,
925                      recv_buffer(args.recv_buf_size * args.datatype2->size()), args.recvcounts->data(), args.recvdisps.data(), args.datatype2, MPI_COMM_WORLD);
926
927     TRACE_smpi_comm_out(my_proc_id);
928   }
929 };
930 } // Replay Namespace
931 }} // namespace simgrid::smpi
932
933 std::vector<simgrid::smpi::replay::RequestStorage> storage;
934 /** @brief Only initialize the replay, don't do it for real */
935 void smpi_replay_init(int* argc, char*** argv)
936 {
937   simgrid::smpi::Process::init(argc, argv);
938   smpi_process()->mark_as_initialized();
939   smpi_process()->set_replaying(true);
940
941   int my_proc_id = simgrid::s4u::this_actor::get_pid();
942   storage.resize(smpi_process_count());
943
944   TRACE_smpi_init(my_proc_id);
945   TRACE_smpi_computing_init(my_proc_id);
946   TRACE_smpi_comm_in(my_proc_id, "smpi_replay_run_init", new simgrid::instr::NoOpTIData("init"));
947   TRACE_smpi_comm_out(my_proc_id);
948   xbt_replay_action_register("init", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::InitAction().execute(action); });
949   xbt_replay_action_register("finalize", [](simgrid::xbt::ReplayAction& action) { /* nothing to do */ });
950   xbt_replay_action_register("comm_size", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
951   xbt_replay_action_register("comm_split",[](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
952   xbt_replay_action_register("comm_dup",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
953
954   xbt_replay_action_register("send",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("send", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
955   xbt_replay_action_register("Isend", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("Isend", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
956   xbt_replay_action_register("recv",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("recv", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
957   xbt_replay_action_register("Irecv", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("Irecv", storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
958   xbt_replay_action_register("test",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::TestAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
959   xbt_replay_action_register("wait",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
960   xbt_replay_action_register("waitAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAllAction(storage[simgrid::s4u::this_actor::get_pid()-1]).execute(action); });
961   xbt_replay_action_register("barrier", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BarrierAction().execute(action); });
962   xbt_replay_action_register("bcast",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BcastAction().execute(action); });
963   xbt_replay_action_register("reduce",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceAction().execute(action); });
964   xbt_replay_action_register("allReduce", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllReduceAction().execute(action); });
965   xbt_replay_action_register("allToAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllAction().execute(action); });
966   xbt_replay_action_register("allToAllV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllVAction().execute(action); });
967   xbt_replay_action_register("gather",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("gather").execute(action); });
968   xbt_replay_action_register("scatter",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterAction().execute(action); });
969   xbt_replay_action_register("gatherV",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("gatherV").execute(action); });
970   xbt_replay_action_register("scatterV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterVAction().execute(action); });
971   xbt_replay_action_register("allGather", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("allGather").execute(action); });
972   xbt_replay_action_register("allGatherV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("allGatherV").execute(action); });
973   xbt_replay_action_register("reduceScatter", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceScatterAction().execute(action); });
974   xbt_replay_action_register("compute", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ComputeAction().execute(action); });
975
976   //if we have a delayed start, sleep here.
977   if(*argc>2){
978     double value = xbt_str_parse_double((*argv)[2], "%s is not a double");
979     XBT_VERB("Delayed start for instance - Sleeping for %f flops ",value );
980     smpi_execute_flops(value);
981   } else {
982     //UGLY: force a context switch to be sure that all MSG_processes begin initialization
983     XBT_DEBUG("Force context switch by smpi_execute_flops  - Sleeping for 0.0 flops ");
984     smpi_execute_flops(0.0);
985   }
986 }
987
988 /** @brief actually run the replay after initialization */
989 void smpi_replay_main(int* argc, char*** argv)
990 {
991   static int active_processes = 0;
992   active_processes++;
993   simgrid::xbt::replay_runner(*argc, *argv);
994
995   /* and now, finalize everything */
996   /* One active process will stop. Decrease the counter*/
997   unsigned int count_requests = storage[simgrid::s4u::this_actor::get_pid() - 1].size();
998   XBT_DEBUG("There are %ud elements in reqq[*]", count_requests);
999   if (count_requests > 0) {
1000     MPI_Request requests[count_requests];
1001     MPI_Status status[count_requests];
1002     unsigned int i=0;
1003
1004     for (auto const& pair : storage[simgrid::s4u::this_actor::get_pid() - 1].get_store()) {
1005       requests[i] = pair.second;
1006       i++;
1007     }
1008     simgrid::smpi::Request::waitall(count_requests, requests, status);
1009   }
1010   active_processes--;
1011
1012   if(active_processes==0){
1013     /* Last process alive speaking: end the simulated timer */
1014     XBT_INFO("Simulation time %f", smpi_process()->simulated_elapsed());
1015     smpi_free_replay_tmp_buffers();
1016   }
1017
1018   TRACE_smpi_comm_in(simgrid::s4u::this_actor::get_pid(), "smpi_replay_run_finalize",
1019                      new simgrid::instr::NoOpTIData("finalize"));
1020
1021   smpi_process()->finalize();
1022
1023   TRACE_smpi_comm_out(simgrid::s4u::this_actor::get_pid());
1024   TRACE_smpi_finalize(simgrid::s4u::this_actor::get_pid());
1025 }
1026
1027 /** @brief chain a replay initialization and a replay start */
1028 void smpi_replay_run(int* argc, char*** argv)
1029 {
1030   smpi_replay_init(argc, argv);
1031   smpi_replay_main(argc, argv);
1032 }