Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
7eb98156779088da9768dfcbd40e368ab1b24ef3
[simgrid.git] / src / smpi / internals / smpi_replay.cpp
1 /* Copyright (c) 2009-2018. The SimGrid Team. All rights reserved.          */
2
3 /* This program is free software; you can redistribute it and/or modify it
4  * under the terms of the license (GNU LGPL) which comes with this package. */
5
6 #include "private.hpp"
7 #include "smpi_coll.hpp"
8 #include "smpi_comm.hpp"
9 #include "smpi_datatype.hpp"
10 #include "smpi_group.hpp"
11 #include "smpi_process.hpp"
12 #include "smpi_request.hpp"
13 #include "xbt/replay.hpp"
14
15 #include <boost/algorithm/string/join.hpp>
16 #include <memory>
17 #include <numeric>
18 #include <unordered_map>
19 #include <sstream>
20 #include <vector>
21
22 using simgrid::s4u::Actor;
23
24 #include <tuple>
25 // From https://stackoverflow.com/questions/7110301/generic-hash-for-tuples-in-unordered-map-unordered-set
26 // This is all just to make std::unordered_map work with std::tuple. If we need this in other places,
27 // this could go into a header file.
28 namespace hash_tuple{
29     template <typename TT>
30     struct hash
31     {
32         size_t
33         operator()(TT const& tt) const
34         {
35             return std::hash<TT>()(tt);
36         }
37     };
38
39     template <class T>
40     inline void hash_combine(std::size_t& seed, T const& v)
41     {
42         seed ^= hash_tuple::hash<T>()(v) + 0x9e3779b9 + (seed<<6) + (seed>>2);
43     }
44
45     // Recursive template code derived from Matthieu M.
46     template <class Tuple, size_t Index = std::tuple_size<Tuple>::value - 1>
47     struct HashValueImpl
48     {
49       static void apply(size_t& seed, Tuple const& tuple)
50       {
51         HashValueImpl<Tuple, Index-1>::apply(seed, tuple);
52         hash_combine(seed, std::get<Index>(tuple));
53       }
54     };
55
56     template <class Tuple>
57     struct HashValueImpl<Tuple,0>
58     {
59       static void apply(size_t& seed, Tuple const& tuple)
60       {
61         hash_combine(seed, std::get<0>(tuple));
62       }
63     };
64
65     template <typename ... TT>
66     struct hash<std::tuple<TT...>>
67     {
68         size_t
69         operator()(std::tuple<TT...> const& tt) const
70         {
71             size_t seed = 0;
72             HashValueImpl<std::tuple<TT...> >::apply(seed, tt);
73             return seed;
74         }
75     };
76 }
77
78 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_replay,smpi,"Trace Replay with SMPI");
79
80 static std::unordered_map<int, std::vector<MPI_Request>*> reqq;
81
82 static MPI_Datatype MPI_DEFAULT_TYPE;
83
84 #define CHECK_ACTION_PARAMS(action, mandatory, optional)                                                               \
85   {                                                                                                                    \
86     if (action.size() < static_cast<unsigned long>(mandatory + 2)) {                                                   \
87       std::stringstream ss;                                                                                            \
88       for (const auto& elem : action) {                                                                                \
89         ss << elem << " ";                                                                                             \
90       }                                                                                                                \
91       THROWF(arg_error, 0, "%s replay failed.\n"                                                                       \
92                            "%zu items were given on the line. First two should be process_id and action.  "            \
93                            "This action needs after them %lu mandatory arguments, and accepts %lu optional ones. \n"   \
94                            "The full line that was given is:\n   %s\n"                                                 \
95                            "Please contact the Simgrid team if support is needed",                                     \
96              __func__, action.size(), static_cast<unsigned long>(mandatory), static_cast<unsigned long>(optional),     \
97              ss.str().c_str());                                                                                        \
98     }                                                                                                                  \
99   }
100
101 static void log_timed_action(simgrid::xbt::ReplayAction& action, double clock)
102 {
103   if (XBT_LOG_ISENABLED(smpi_replay, xbt_log_priority_verbose)){
104     std::string s = boost::algorithm::join(action, " ");
105     XBT_VERB("%s %f", s.c_str(), smpi_process()->simulated_elapsed() - clock);
106   }
107 }
108
109 static std::vector<MPI_Request>* get_reqq_self()
110 {
111   return reqq.at(simgrid::s4u::this_actor::get_pid());
112 }
113
114 static void set_reqq_self(std::vector<MPI_Request> *mpi_request)
115 {
116   reqq.insert({simgrid::s4u::this_actor::get_pid(), mpi_request});
117 }
118
119 /* Helper function */
120 static double parse_double(std::string string)
121 {
122   return xbt_str_parse_double(string.c_str(), "%s is not a double");
123 }
124
125 namespace simgrid {
126 namespace smpi {
127
128 namespace replay {
129 class ActionArgParser {
130 public:
131   virtual ~ActionArgParser() = default;
132   virtual void parse(simgrid::xbt::ReplayAction& action, std::string name) { CHECK_ACTION_PARAMS(action, 0, 0) }
133 };
134
135 class SendRecvParser : public ActionArgParser {
136 public:
137   /* communication partner; if we send, this is the receiver and vice versa */
138   int partner;
139   double size;
140   int tag;
141   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
142
143   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
144   {
145     CHECK_ACTION_PARAMS(action, 3, 1)
146     partner = std::stoi(action[2]);
147     tag     = std::stoi(action[3]);
148     size    = parse_double(action[4]);
149     if (action.size() > 5)
150       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
151   }
152 };
153
154 class ComputeParser : public ActionArgParser {
155 public:
156   /* communication partner; if we send, this is the receiver and vice versa */
157   double flops;
158
159   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
160   {
161     CHECK_ACTION_PARAMS(action, 1, 0)
162     flops = parse_double(action[2]);
163   }
164 };
165
166 class CollCommParser : public ActionArgParser {
167 public:
168   double size;
169   double comm_size;
170   double comp_size;
171   int send_size;
172   int recv_size;
173   int root = 0;
174   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
175   MPI_Datatype datatype2 = MPI_DEFAULT_TYPE;
176 };
177
178 class BcastArgParser : public CollCommParser {
179 public:
180   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
181   {
182     CHECK_ACTION_PARAMS(action, 1, 2)
183     size = parse_double(action[2]);
184     root = (action.size() > 3) ? std::stoi(action[3]) : 0;
185     if (action.size() > 4)
186       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
187   }
188 };
189
190 class ReduceArgParser : public CollCommParser {
191 public:
192   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
193   {
194     CHECK_ACTION_PARAMS(action, 2, 2)
195     comm_size = parse_double(action[2]);
196     comp_size = parse_double(action[3]);
197     root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
198     if (action.size() > 5)
199       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
200   }
201 };
202
203 class AllReduceArgParser : public CollCommParser {
204 public:
205   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
206   {
207     CHECK_ACTION_PARAMS(action, 2, 1)
208     comm_size = parse_double(action[2]);
209     comp_size = parse_double(action[3]);
210     if (action.size() > 4)
211       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
212   }
213 };
214
215 class AllToAllArgParser : public CollCommParser {
216 public:
217   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
218   {
219     CHECK_ACTION_PARAMS(action, 2, 1)
220     comm_size = MPI_COMM_WORLD->size();
221     send_size = parse_double(action[2]);
222     recv_size = parse_double(action[3]);
223
224     if (action.size() > 4)
225       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
226     if (action.size() > 5)
227       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
228   }
229 };
230
231 class GatherArgParser : public CollCommParser {
232 public:
233   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
234   {
235     /* The structure of the gather action for the rank 0 (total 4 processes) is the following:
236           0 gather 68 68 0 0 0
237         where:
238           1) 68 is the sendcounts
239           2) 68 is the recvcounts
240           3) 0 is the root node
241           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
242           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
243     */
244     CHECK_ACTION_PARAMS(action, 2, 3)
245     comm_size = MPI_COMM_WORLD->size();
246     send_size = parse_double(action[2]);
247     recv_size = parse_double(action[3]);
248
249     if (name == "gather") {
250       root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
251       if (action.size() > 5)
252         datatype1 = simgrid::smpi::Datatype::decode(action[5]);
253       if (action.size() > 6)
254         datatype2 = simgrid::smpi::Datatype::decode(action[6]);
255     }
256     else {
257       if (action.size() > 4)
258         datatype1 = simgrid::smpi::Datatype::decode(action[4]);
259       if (action.size() > 5)
260         datatype2 = simgrid::smpi::Datatype::decode(action[5]);
261     }
262   }
263 };
264
265 class GatherVArgParser : public CollCommParser {
266 public:
267   int recv_size_sum;
268   std::shared_ptr<std::vector<int>> recvcounts;
269   std::vector<int> disps;
270   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
271   {
272     /* The structure of the gatherv action for the rank 0 (total 4 processes) is the following:
273          0 gather 68 68 10 10 10 0 0 0
274        where:
275          1) 68 is the sendcount
276          2) 68 10 10 10 is the recvcounts
277          3) 0 is the root node
278          4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
279          5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
280     */
281     comm_size = MPI_COMM_WORLD->size();
282     CHECK_ACTION_PARAMS(action, comm_size+1, 2)
283     send_size = parse_double(action[2]);
284     disps     = std::vector<int>(comm_size, 0);
285     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
286
287     if (name == "gatherV") {
288       root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
289       if (action.size() > 4 + comm_size)
290         datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
291       if (action.size() > 5 + comm_size)
292         datatype2 = simgrid::smpi::Datatype::decode(action[5 + comm_size]);
293     }
294     else {
295       int datatype_index = 0;
296       int disp_index     = 0;
297       /* The 3 comes from "0 gather <sendcount>", which must always be present.
298        * The + comm_size is the recvcounts array, which must also be present
299        */
300       if (action.size() > 3 + comm_size + comm_size) { /* datatype + disp are specified */
301         datatype_index = 3 + comm_size;
302         disp_index     = datatype_index + 1;
303         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
304         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
305       } else if (action.size() > 3 + comm_size + 2) { /* disps specified; datatype is not specified; use the default one */
306         disp_index     = 3 + comm_size;
307       } else if (action.size() > 3 + comm_size)  { /* only datatype, no disp specified */
308         datatype_index = 3 + comm_size;
309         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
310         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
311       }
312
313       if (disp_index != 0) {
314         for (unsigned int i = 0; i < comm_size; i++)
315           disps[i]          = std::stoi(action[disp_index + i]);
316       }
317     }
318
319     for (unsigned int i = 0; i < comm_size; i++) {
320       (*recvcounts)[i] = std::stoi(action[i + 3]);
321     }
322     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
323   }
324 };
325
326 class ScatterArgParser : public CollCommParser {
327 public:
328   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
329   {
330     /* The structure of the scatter action for the rank 0 (total 4 processes) is the following:
331           0 gather 68 68 0 0 0
332         where:
333           1) 68 is the sendcounts
334           2) 68 is the recvcounts
335           3) 0 is the root node
336           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
337           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
338     */
339     CHECK_ACTION_PARAMS(action, 2, 3)
340     comm_size   = MPI_COMM_WORLD->size();
341     send_size   = parse_double(action[2]);
342     recv_size   = parse_double(action[3]);
343     root   = (action.size() > 4) ? std::stoi(action[4]) : 0;
344     if (action.size() > 5)
345       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
346     if (action.size() > 6)
347       datatype2 = simgrid::smpi::Datatype::decode(action[6]);
348   }
349 };
350
351 class ScatterVArgParser : public CollCommParser {
352 public:
353   int recv_size_sum;
354   int send_size_sum;
355   std::shared_ptr<std::vector<int>> sendcounts;
356   std::vector<int> disps;
357   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
358   {
359     /* The structure of the scatterv action for the rank 0 (total 4 processes) is the following:
360        0 gather 68 10 10 10 68 0 0 0
361         where:
362         1) 68 10 10 10 is the sendcounts
363         2) 68 is the recvcount
364         3) 0 is the root node
365         4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
366         5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
367     */
368     CHECK_ACTION_PARAMS(action, comm_size + 1, 2)
369     recv_size  = parse_double(action[2 + comm_size]);
370     disps      = std::vector<int>(comm_size, 0);
371     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
372
373     if (action.size() > 5 + comm_size)
374       datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
375     if (action.size() > 5 + comm_size)
376       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
377
378     for (unsigned int i = 0; i < comm_size; i++) {
379       (*sendcounts)[i] = std::stoi(action[i + 2]);
380     }
381     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
382     root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
383   }
384 };
385
386 class ReduceScatterArgParser : public CollCommParser {
387 public:
388   int recv_size_sum;
389   std::shared_ptr<std::vector<int>> recvcounts;
390   std::vector<int> disps;
391   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
392   {
393     /* The structure of the reducescatter action for the rank 0 (total 4 processes) is the following:
394          0 reduceScatter 275427 275427 275427 204020 11346849 0
395        where:
396          1) The first four values after the name of the action declare the recvcounts array
397          2) The value 11346849 is the amount of instructions
398          3) The last value corresponds to the datatype, see simgrid::smpi::Datatype::decode().
399     */
400     comm_size = MPI_COMM_WORLD->size();
401     CHECK_ACTION_PARAMS(action, comm_size+1, 1)
402     comp_size = parse_double(action[2+comm_size]);
403     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
404     if (action.size() > 3 + comm_size)
405       datatype1 = simgrid::smpi::Datatype::decode(action[3 + comm_size]);
406
407     for (unsigned int i = 0; i < comm_size; i++) {
408       recvcounts->push_back(std::stoi(action[i + 2]));
409     }
410     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
411   }
412 };
413
414 class AllToAllVArgParser : public CollCommParser {
415 public:
416   int recv_size_sum;
417   int send_size_sum;
418   std::shared_ptr<std::vector<int>> recvcounts;
419   std::shared_ptr<std::vector<int>> sendcounts;
420   std::vector<int> senddisps;
421   std::vector<int> recvdisps;
422   int send_buf_size;
423   int recv_buf_size;
424   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
425   {
426     /* The structure of the allToAllV action for the rank 0 (total 4 processes) is the following:
427           0 allToAllV 100 1 7 10 12 100 1 70 10 5
428        where:
429         1) 100 is the size of the send buffer *sizeof(int),
430         2) 1 7 10 12 is the sendcounts array
431         3) 100*sizeof(int) is the size of the receiver buffer
432         4)  1 70 10 5 is the recvcounts array
433     */
434     comm_size = MPI_COMM_WORLD->size();
435     CHECK_ACTION_PARAMS(action, 2*comm_size+2, 2)
436     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
437     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
438     senddisps  = std::vector<int>(comm_size, 0);
439     recvdisps  = std::vector<int>(comm_size, 0);
440
441     if (action.size() > 5 + 2 * comm_size)
442       datatype1 = simgrid::smpi::Datatype::decode(action[4 + 2 * comm_size]);
443     if (action.size() > 5 + 2 * comm_size)
444       datatype2 = simgrid::smpi::Datatype::decode(action[5 + 2 * comm_size]);
445
446     send_buf_size=parse_double(action[2]);
447     recv_buf_size=parse_double(action[3+comm_size]);
448     for (unsigned int i = 0; i < comm_size; i++) {
449       (*sendcounts)[i] = std::stoi(action[3 + i]);
450       (*recvcounts)[i] = std::stoi(action[4 + comm_size + i]);
451     }
452     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
453     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
454   }
455 };
456
457 template <class T> class ReplayAction {
458 protected:
459   const std::string name;
460   const int my_proc_id;
461   T args;
462
463 public:
464   explicit ReplayAction(std::string name) : name(name), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
465   virtual ~ReplayAction() = default;
466
467   virtual void execute(simgrid::xbt::ReplayAction& action)
468   {
469     // Needs to be re-initialized for every action, hence here
470     double start_time = smpi_process()->simulated_elapsed();
471     args.parse(action, name);
472     kernel(action);
473     if (name != "Init")
474       log_timed_action(action, start_time);
475   }
476
477   virtual void kernel(simgrid::xbt::ReplayAction& action) = 0;
478
479   void* send_buffer(int size)
480   {
481     return smpi_get_tmp_sendbuffer(size);
482   }
483
484   void* recv_buffer(int size)
485   {
486     return smpi_get_tmp_recvbuffer(size);
487   }
488 };
489
490 class WaitAction : public ReplayAction<ActionArgParser> {
491 public:
492   WaitAction() : ReplayAction("Wait") {}
493   void kernel(simgrid::xbt::ReplayAction& action) override
494   {
495     std::string s = boost::algorithm::join(action, " ");
496     xbt_assert(get_reqq_self()->size(), "action wait not preceded by any irecv or isend: %s", s.c_str());
497     MPI_Request request = get_reqq_self()->back();
498     get_reqq_self()->pop_back();
499
500     if (request == nullptr) {
501       /* Assume that the trace is well formed, meaning the comm might have been caught by a MPI_test. Then just
502        * return.*/
503       return;
504     }
505
506     int rank = request->comm() != MPI_COMM_NULL ? request->comm()->rank() : -1;
507
508     // Must be taken before Request::wait() since the request may be set to
509     // MPI_REQUEST_NULL by Request::wait!
510     int src                  = request->comm()->group()->rank(request->src());
511     int dst                  = request->comm()->group()->rank(request->dst());
512     int tag                  = request->tag();
513     bool is_wait_for_receive = (request->flags() & RECV);
514     // TODO: Here we take the rank while we normally take the process id (look for my_proc_id)
515     TRACE_smpi_comm_in(rank, __func__, new simgrid::instr::NoOpTIData("wait"));
516
517     MPI_Status status;
518     Request::wait(&request, &status);
519
520     TRACE_smpi_comm_out(rank);
521     if (is_wait_for_receive)
522       TRACE_smpi_recv(src, dst, tag);
523   }
524 };
525
526 class SendAction : public ReplayAction<SendRecvParser> {
527 public:
528   SendAction() = delete;
529   explicit SendAction(std::string name) : ReplayAction(name) {}
530   void kernel(simgrid::xbt::ReplayAction& action) override
531   {
532     int dst_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
533
534     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
535                                                                              args.tag, Datatype::encode(args.datatype1)));
536     if (not TRACE_smpi_view_internals())
537       TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, args.tag, args.size * args.datatype1->size());
538
539     if (name == "send") {
540       Request::send(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
541     } else if (name == "Isend") {
542       MPI_Request request = Request::isend(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
543       get_reqq_self()->push_back(request);
544     } else {
545       xbt_die("Don't know this action, %s", name.c_str());
546     }
547
548     TRACE_smpi_comm_out(my_proc_id);
549   }
550 };
551
552 class RecvAction : public ReplayAction<SendRecvParser> {
553 public:
554   RecvAction() = delete;
555   explicit RecvAction(std::string name) : ReplayAction(name) {}
556   void kernel(simgrid::xbt::ReplayAction& action) override
557   {
558     int src_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
559
560     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
561                                                                              args.tag, Datatype::encode(args.datatype1)));
562
563     MPI_Status status;
564     // unknown size from the receiver point of view
565     if (args.size <= 0.0) {
566       Request::probe(args.partner, args.tag, MPI_COMM_WORLD, &status);
567       args.size = status.count;
568     }
569
570     if (name == "recv") {
571       Request::recv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD, &status);
572     } else if (name == "Irecv") {
573       MPI_Request request = Request::irecv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
574       get_reqq_self()->push_back(request);
575     }
576
577     TRACE_smpi_comm_out(my_proc_id);
578     // TODO: Check why this was only activated in the "recv" case and not in the "Irecv" case
579     if (name == "recv" && not TRACE_smpi_view_internals()) {
580       TRACE_smpi_recv(src_traced, my_proc_id, args.tag);
581     }
582   }
583 };
584
585 class ComputeAction : public ReplayAction<ComputeParser> {
586 public:
587   ComputeAction() : ReplayAction("compute") {}
588   void kernel(simgrid::xbt::ReplayAction& action) override
589   {
590     TRACE_smpi_computing_in(my_proc_id, args.flops);
591     smpi_execute_flops(args.flops);
592     TRACE_smpi_computing_out(my_proc_id);
593   }
594 };
595
596 class TestAction : public ReplayAction<ActionArgParser> {
597 public:
598   TestAction() : ReplayAction("Test") {}
599   void kernel(simgrid::xbt::ReplayAction& action) override
600   {
601     MPI_Request request = get_reqq_self()->back();
602     get_reqq_self()->pop_back();
603     // if request is null here, this may mean that a previous test has succeeded
604     // Different times in traced application and replayed version may lead to this
605     // In this case, ignore the extra calls.
606     if (request != nullptr) {
607       TRACE_smpi_testing_in(my_proc_id);
608
609       MPI_Status status;
610       int flag = Request::test(&request, &status);
611
612       XBT_DEBUG("MPI_Test result: %d", flag);
613       /* push back request in vector to be caught by a subsequent wait. if the test did succeed, the request is now
614        * nullptr.*/
615       get_reqq_self()->push_back(request);
616
617       TRACE_smpi_testing_out(my_proc_id);
618     }
619   }
620 };
621
622 class InitAction : public ReplayAction<ActionArgParser> {
623 public:
624   InitAction() : ReplayAction("Init") {}
625   void kernel(simgrid::xbt::ReplayAction& action) override
626   {
627     CHECK_ACTION_PARAMS(action, 0, 1)
628     MPI_DEFAULT_TYPE = (action.size() > 2) ? MPI_DOUBLE // default MPE datatype
629                                            : MPI_BYTE;  // default TAU datatype
630
631     /* start a simulated timer */
632     smpi_process()->simulated_start();
633     set_reqq_self(new std::vector<MPI_Request>);
634   }
635 };
636
637 class CommunicatorAction : public ReplayAction<ActionArgParser> {
638 public:
639   CommunicatorAction() : ReplayAction("Comm") {}
640   void kernel(simgrid::xbt::ReplayAction& action) override { /* nothing to do */}
641 };
642
643 class WaitAllAction : public ReplayAction<ActionArgParser> {
644 public:
645   WaitAllAction() : ReplayAction("waitAll") {}
646   void kernel(simgrid::xbt::ReplayAction& action) override
647   {
648     const unsigned int count_requests = get_reqq_self()->size();
649
650     if (count_requests > 0) {
651       TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData("waitAll", -1, count_requests, ""));
652       std::vector<std::pair</*sender*/int,/*recv*/int>> sender_receiver;
653       for (const auto& req : (*get_reqq_self())) {
654         if (req && (req->flags() & RECV)) {
655           sender_receiver.push_back({req->src(), req->dst()});
656         }
657       }
658       MPI_Status status[count_requests];
659       Request::waitall(count_requests, &(*get_reqq_self())[0], status);
660
661       for (auto& pair : sender_receiver) {
662         TRACE_smpi_recv(pair.first, pair.second, 0);
663       }
664       TRACE_smpi_comm_out(my_proc_id);
665     }
666   }
667 };
668
669 class BarrierAction : public ReplayAction<ActionArgParser> {
670 public:
671   BarrierAction() : ReplayAction("barrier") {}
672   void kernel(simgrid::xbt::ReplayAction& action) override
673   {
674     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("barrier"));
675     Colls::barrier(MPI_COMM_WORLD);
676     TRACE_smpi_comm_out(my_proc_id);
677   }
678 };
679
680 class BcastAction : public ReplayAction<BcastArgParser> {
681 public:
682   BcastAction() : ReplayAction("bcast") {}
683   void kernel(simgrid::xbt::ReplayAction& action) override
684   {
685     TRACE_smpi_comm_in(my_proc_id, "action_bcast",
686                        new simgrid::instr::CollTIData("bcast", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
687                                                       -1.0, args.size, -1, Datatype::encode(args.datatype1), ""));
688
689     Colls::bcast(send_buffer(args.size * args.datatype1->size()), args.size, args.datatype1, args.root, MPI_COMM_WORLD);
690
691     TRACE_smpi_comm_out(my_proc_id);
692   }
693 };
694
695 class ReduceAction : public ReplayAction<ReduceArgParser> {
696 public:
697   ReduceAction() : ReplayAction("reduce") {}
698   void kernel(simgrid::xbt::ReplayAction& action) override
699   {
700     TRACE_smpi_comm_in(my_proc_id, "action_reduce",
701                        new simgrid::instr::CollTIData("reduce", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
702                                                       args.comp_size, args.comm_size, -1,
703                                                       Datatype::encode(args.datatype1), ""));
704
705     Colls::reduce(send_buffer(args.comm_size * args.datatype1->size()),
706         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, args.root, MPI_COMM_WORLD);
707     smpi_execute_flops(args.comp_size);
708
709     TRACE_smpi_comm_out(my_proc_id);
710   }
711 };
712
713 class AllReduceAction : public ReplayAction<AllReduceArgParser> {
714 public:
715   AllReduceAction() : ReplayAction("allReduce") {}
716   void kernel(simgrid::xbt::ReplayAction& action) override
717   {
718     TRACE_smpi_comm_in(my_proc_id, "action_allReduce", new simgrid::instr::CollTIData("allReduce", -1, args.comp_size, args.comm_size, -1,
719                                                                                 Datatype::encode(args.datatype1), ""));
720
721     Colls::allreduce(send_buffer(args.comm_size * args.datatype1->size()),
722         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
723     smpi_execute_flops(args.comp_size);
724
725     TRACE_smpi_comm_out(my_proc_id);
726   }
727 };
728
729 class AllToAllAction : public ReplayAction<AllToAllArgParser> {
730 public:
731   AllToAllAction() : ReplayAction("allToAll") {}
732   void kernel(simgrid::xbt::ReplayAction& action) override
733   {
734     TRACE_smpi_comm_in(my_proc_id, "action_allToAll",
735                      new simgrid::instr::CollTIData("allToAll", -1, -1.0, args.send_size, args.recv_size,
736                                                     Datatype::encode(args.datatype1),
737                                                     Datatype::encode(args.datatype2)));
738
739     Colls::alltoall(send_buffer(args.send_size * args.comm_size * args.datatype1->size()), args.send_size,
740                     args.datatype1, recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()),
741                     args.recv_size, args.datatype2, MPI_COMM_WORLD);
742
743     TRACE_smpi_comm_out(my_proc_id);
744   }
745 };
746
747 class GatherAction : public ReplayAction<GatherArgParser> {
748 public:
749   explicit GatherAction(std::string name) : ReplayAction(name) {}
750   void kernel(simgrid::xbt::ReplayAction& action) override
751   {
752     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::CollTIData(name, (name == "gather") ? args.root : -1, -1.0, args.send_size, args.recv_size,
753                                                                           Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
754
755     if (name == "gather") {
756       int rank = MPI_COMM_WORLD->rank();
757       Colls::gather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
758                  (rank == args.root) ? recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
759     }
760     else
761       Colls::allgather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
762                        recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, MPI_COMM_WORLD);
763
764     TRACE_smpi_comm_out(my_proc_id);
765   }
766 };
767
768 class GatherVAction : public ReplayAction<GatherVArgParser> {
769 public:
770   explicit GatherVAction(std::string name) : ReplayAction(name) {}
771   void kernel(simgrid::xbt::ReplayAction& action) override
772   {
773     int rank = MPI_COMM_WORLD->rank();
774
775     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::VarCollTIData(
776                                                name, (name == "gatherV") ? args.root : -1, args.send_size, nullptr, -1, args.recvcounts,
777                                                Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
778
779     if (name == "gatherV") {
780       Colls::gatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
781                      (rank == args.root) ? recv_buffer(args.recv_size_sum * args.datatype2->size()) : nullptr,
782                      args.recvcounts->data(), args.disps.data(), args.datatype2, args.root, MPI_COMM_WORLD);
783     }
784     else {
785       Colls::allgatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
786                         recv_buffer(args.recv_size_sum * args.datatype2->size()), args.recvcounts->data(),
787                         args.disps.data(), args.datatype2, MPI_COMM_WORLD);
788     }
789
790     TRACE_smpi_comm_out(my_proc_id);
791   }
792 };
793
794 class ScatterAction : public ReplayAction<ScatterArgParser> {
795 public:
796   ScatterAction() : ReplayAction("scatter") {}
797   void kernel(simgrid::xbt::ReplayAction& action) override
798   {
799     int rank = MPI_COMM_WORLD->rank();
800     TRACE_smpi_comm_in(my_proc_id, "action_scatter", new simgrid::instr::CollTIData(name, args.root, -1.0, args.send_size, args.recv_size,
801                                                                           Datatype::encode(args.datatype1),
802                                                                           Datatype::encode(args.datatype2)));
803
804     Colls::scatter(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
805                   (rank == args.root) ? recv_buffer(args.recv_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
806
807     TRACE_smpi_comm_out(my_proc_id);
808   }
809 };
810
811
812 class ScatterVAction : public ReplayAction<ScatterVArgParser> {
813 public:
814   ScatterVAction() : ReplayAction("scatterV") {}
815   void kernel(simgrid::xbt::ReplayAction& action) override
816   {
817     int rank = MPI_COMM_WORLD->rank();
818     TRACE_smpi_comm_in(my_proc_id, "action_scatterv", new simgrid::instr::VarCollTIData(name, args.root, -1, args.sendcounts, args.recv_size,
819           nullptr, Datatype::encode(args.datatype1),
820           Datatype::encode(args.datatype2)));
821
822     Colls::scatterv((rank == args.root) ? send_buffer(args.send_size_sum * args.datatype1->size()) : nullptr,
823                     args.sendcounts->data(), args.disps.data(), args.datatype1,
824                     recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, args.root,
825                     MPI_COMM_WORLD);
826
827     TRACE_smpi_comm_out(my_proc_id);
828   }
829 };
830
831 class ReduceScatterAction : public ReplayAction<ReduceScatterArgParser> {
832 public:
833   ReduceScatterAction() : ReplayAction("reduceScatter") {}
834   void kernel(simgrid::xbt::ReplayAction& action) override
835   {
836     TRACE_smpi_comm_in(my_proc_id, "action_reducescatter",
837                        new simgrid::instr::VarCollTIData("reduceScatter", -1, 0, nullptr, -1, args.recvcounts,
838                                                          std::to_string(args.comp_size), /* ugly hack to print comp_size */
839                                                          Datatype::encode(args.datatype1)));
840
841     Colls::reduce_scatter(send_buffer(args.recv_size_sum * args.datatype1->size()),
842                           recv_buffer(args.recv_size_sum * args.datatype1->size()), args.recvcounts->data(),
843                           args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
844
845     smpi_execute_flops(args.comp_size);
846     TRACE_smpi_comm_out(my_proc_id);
847   }
848 };
849
850 class AllToAllVAction : public ReplayAction<AllToAllVArgParser> {
851 public:
852   AllToAllVAction() : ReplayAction("allToAllV") {}
853   void kernel(simgrid::xbt::ReplayAction& action) override
854   {
855     TRACE_smpi_comm_in(my_proc_id, __func__,
856                        new simgrid::instr::VarCollTIData(
857                            "allToAllV", -1, args.send_size_sum, args.sendcounts, args.recv_size_sum, args.recvcounts,
858                            Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
859
860     Colls::alltoallv(send_buffer(args.send_buf_size * args.datatype1->size()), args.sendcounts->data(), args.senddisps.data(), args.datatype1,
861                      recv_buffer(args.recv_buf_size * args.datatype2->size()), args.recvcounts->data(), args.recvdisps.data(), args.datatype2, MPI_COMM_WORLD);
862
863     TRACE_smpi_comm_out(my_proc_id);
864   }
865 };
866 } // Replay Namespace
867 }} // namespace simgrid::smpi
868
869 /** @brief Only initialize the replay, don't do it for real */
870 void smpi_replay_init(int* argc, char*** argv)
871 {
872   simgrid::smpi::Process::init(argc, argv);
873   smpi_process()->mark_as_initialized();
874   smpi_process()->set_replaying(true);
875
876   int my_proc_id = simgrid::s4u::this_actor::get_pid();
877   TRACE_smpi_init(my_proc_id);
878   TRACE_smpi_computing_init(my_proc_id);
879   TRACE_smpi_comm_in(my_proc_id, "smpi_replay_run_init", new simgrid::instr::NoOpTIData("init"));
880   TRACE_smpi_comm_out(my_proc_id);
881   xbt_replay_action_register("init", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::InitAction().execute(action); });
882   xbt_replay_action_register("finalize", [](simgrid::xbt::ReplayAction& action) { /* nothing to do */ });
883   xbt_replay_action_register("comm_size", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
884   xbt_replay_action_register("comm_split",[](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
885   xbt_replay_action_register("comm_dup",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
886
887   xbt_replay_action_register("send",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("send").execute(action); });
888   xbt_replay_action_register("Isend", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("Isend").execute(action); });
889   xbt_replay_action_register("recv",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("recv").execute(action); });
890   xbt_replay_action_register("Irecv", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("Irecv").execute(action); });
891   xbt_replay_action_register("test",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::TestAction().execute(action); });
892   xbt_replay_action_register("wait",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAction().execute(action); });
893   xbt_replay_action_register("waitAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAllAction().execute(action); });
894   xbt_replay_action_register("barrier", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BarrierAction().execute(action); });
895   xbt_replay_action_register("bcast",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BcastAction().execute(action); });
896   xbt_replay_action_register("reduce",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceAction().execute(action); });
897   xbt_replay_action_register("allReduce", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllReduceAction().execute(action); });
898   xbt_replay_action_register("allToAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllAction().execute(action); });
899   xbt_replay_action_register("allToAllV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllVAction().execute(action); });
900   xbt_replay_action_register("gather",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("gather").execute(action); });
901   xbt_replay_action_register("scatter",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterAction().execute(action); });
902   xbt_replay_action_register("gatherV",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("gatherV").execute(action); });
903   xbt_replay_action_register("scatterV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterVAction().execute(action); });
904   xbt_replay_action_register("allGather", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("allGather").execute(action); });
905   xbt_replay_action_register("allGatherV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("allGatherV").execute(action); });
906   xbt_replay_action_register("reduceScatter", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceScatterAction().execute(action); });
907   xbt_replay_action_register("compute", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ComputeAction().execute(action); });
908
909   //if we have a delayed start, sleep here.
910   if(*argc>2){
911     double value = xbt_str_parse_double((*argv)[2], "%s is not a double");
912     XBT_VERB("Delayed start for instance - Sleeping for %f flops ",value );
913     smpi_execute_flops(value);
914   } else {
915     //UGLY: force a context switch to be sure that all MSG_processes begin initialization
916     XBT_DEBUG("Force context switch by smpi_execute_flops  - Sleeping for 0.0 flops ");
917     smpi_execute_flops(0.0);
918   }
919 }
920
921 /** @brief actually run the replay after initialization */
922 void smpi_replay_main(int* argc, char*** argv)
923 {
924   static int active_processes = 0;
925   active_processes++;
926   simgrid::xbt::replay_runner(*argc, *argv);
927
928   /* and now, finalize everything */
929   /* One active process will stop. Decrease the counter*/
930   XBT_DEBUG("There are %zu elements in reqq[*]", get_reqq_self()->size());
931   if (not get_reqq_self()->empty()) {
932     unsigned int count_requests=get_reqq_self()->size();
933     MPI_Request requests[count_requests];
934     MPI_Status status[count_requests];
935     unsigned int i=0;
936
937     for (auto const& req : *get_reqq_self()) {
938       requests[i] = req;
939       i++;
940     }
941     simgrid::smpi::Request::waitall(count_requests, requests, status);
942   }
943   delete get_reqq_self();
944   active_processes--;
945
946   if(active_processes==0){
947     /* Last process alive speaking: end the simulated timer */
948     XBT_INFO("Simulation time %f", smpi_process()->simulated_elapsed());
949     smpi_free_replay_tmp_buffers();
950   }
951
952   TRACE_smpi_comm_in(simgrid::s4u::this_actor::get_pid(), "smpi_replay_run_finalize",
953                      new simgrid::instr::NoOpTIData("finalize"));
954
955   smpi_process()->finalize();
956
957   TRACE_smpi_comm_out(simgrid::s4u::this_actor::get_pid());
958   TRACE_smpi_finalize(simgrid::s4u::this_actor::get_pid());
959 }
960
961 /** @brief chain a replay initialization and a replay start */
962 void smpi_replay_run(int* argc, char*** argv)
963 {
964   smpi_replay_init(argc, argv);
965   smpi_replay_main(argc, argv);
966 }