Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
[SMPI] Replay: Cosmetics. (Align '\' in a macro)
[simgrid.git] / src / smpi / internals / smpi_replay.cpp
1 /* Copyright (c) 2009-2018. The SimGrid Team. All rights reserved.          */
2
3 /* This program is free software; you can redistribute it and/or modify it
4  * under the terms of the license (GNU LGPL) which comes with this package. */
5
6 #include "private.hpp"
7 #include "smpi_coll.hpp"
8 #include "smpi_comm.hpp"
9 #include "smpi_datatype.hpp"
10 #include "smpi_group.hpp"
11 #include "smpi_process.hpp"
12 #include "smpi_request.hpp"
13 #include "xbt/replay.hpp"
14
15 #include <boost/algorithm/string/join.hpp>
16 #include <memory>
17 #include <numeric>
18 #include <unordered_map>
19 #include <sstream>
20 #include <vector>
21
22 using simgrid::s4u::Actor;
23
24 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_replay,smpi,"Trace Replay with SMPI");
25
26 static std::unordered_map<int, std::vector<MPI_Request>*> reqq;
27
28 static MPI_Datatype MPI_DEFAULT_TYPE;
29
30 #define CHECK_ACTION_PARAMS(action, mandatory, optional)                                                               \
31   {                                                                                                                    \
32     if (action.size() < static_cast<unsigned long>(mandatory + 2)) {                                                   \
33       std::stringstream ss;                                                                                            \
34       for (const auto& elem : action) {                                                                                \
35         ss << elem << " ";                                                                                             \
36       }                                                                                                                \
37       THROWF(arg_error, 0, "%s replay failed.\n"                                                                       \
38                            "%zu items were given on the line. First two should be process_id and action.  "            \
39                            "This action needs after them %lu mandatory arguments, and accepts %lu optional ones. \n"   \
40                            "The full line that was given is:\n   %s\n"                                                 \
41                            "Please contact the Simgrid team if support is needed",                                     \
42              __func__, action.size(), static_cast<unsigned long>(mandatory), static_cast<unsigned long>(optional),     \
43              ss.str().c_str());                                                                                        \
44     }                                                                                                                  \
45   }
46
47 static void log_timed_action(simgrid::xbt::ReplayAction& action, double clock)
48 {
49   if (XBT_LOG_ISENABLED(smpi_replay, xbt_log_priority_verbose)){
50     std::string s = boost::algorithm::join(action, " ");
51     XBT_VERB("%s %f", s.c_str(), smpi_process()->simulated_elapsed() - clock);
52   }
53 }
54
55 static std::vector<MPI_Request>* get_reqq_self()
56 {
57   return reqq.at(simgrid::s4u::this_actor::get_pid());
58 }
59
60 static void set_reqq_self(std::vector<MPI_Request> *mpi_request)
61 {
62   reqq.insert({simgrid::s4u::this_actor::get_pid(), mpi_request});
63 }
64
65 /* Helper function */
66 static double parse_double(std::string string)
67 {
68   return xbt_str_parse_double(string.c_str(), "%s is not a double");
69 }
70
71 namespace simgrid {
72 namespace smpi {
73
74 namespace replay {
75 class ActionArgParser {
76 public:
77   virtual ~ActionArgParser() = default;
78   virtual void parse(simgrid::xbt::ReplayAction& action, std::string name) { CHECK_ACTION_PARAMS(action, 0, 0) }
79 };
80
81 class SendRecvParser : public ActionArgParser {
82 public:
83   /* communication partner; if we send, this is the receiver and vice versa */
84   int partner;
85   double size;
86   int tag;
87   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
88
89   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
90   {
91     CHECK_ACTION_PARAMS(action, 3, 1)
92     partner = std::stoi(action[2]);
93     tag     = std::stoi(action[3]);
94     size    = parse_double(action[4]);
95     if (action.size() > 5)
96       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
97   }
98 };
99
100 class ComputeParser : public ActionArgParser {
101 public:
102   /* communication partner; if we send, this is the receiver and vice versa */
103   double flops;
104
105   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
106   {
107     CHECK_ACTION_PARAMS(action, 1, 0)
108     flops = parse_double(action[2]);
109   }
110 };
111
112 class CollCommParser : public ActionArgParser {
113 public:
114   double size;
115   double comm_size;
116   double comp_size;
117   int send_size;
118   int recv_size;
119   int root = 0;
120   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
121   MPI_Datatype datatype2 = MPI_DEFAULT_TYPE;
122 };
123
124 class BcastArgParser : public CollCommParser {
125 public:
126   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
127   {
128     CHECK_ACTION_PARAMS(action, 1, 2)
129     size = parse_double(action[2]);
130     root = (action.size() > 3) ? std::stoi(action[3]) : 0;
131     if (action.size() > 4)
132       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
133   }
134 };
135
136 class ReduceArgParser : public CollCommParser {
137 public:
138   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
139   {
140     CHECK_ACTION_PARAMS(action, 2, 2)
141     comm_size = parse_double(action[2]);
142     comp_size = parse_double(action[3]);
143     root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
144     if (action.size() > 5)
145       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
146   }
147 };
148
149 class AllReduceArgParser : public CollCommParser {
150 public:
151   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
152   {
153     CHECK_ACTION_PARAMS(action, 2, 1)
154     comm_size = parse_double(action[2]);
155     comp_size = parse_double(action[3]);
156     if (action.size() > 4)
157       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
158   }
159 };
160
161 class AllToAllArgParser : public CollCommParser {
162 public:
163   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
164   {
165     CHECK_ACTION_PARAMS(action, 2, 1)
166     comm_size = MPI_COMM_WORLD->size();
167     send_size = parse_double(action[2]);
168     recv_size = parse_double(action[3]);
169
170     if (action.size() > 4)
171       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
172     if (action.size() > 5)
173       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
174   }
175 };
176
177 class GatherArgParser : public CollCommParser {
178 public:
179   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
180   {
181     /* The structure of the gather action for the rank 0 (total 4 processes) is the following:
182           0 gather 68 68 0 0 0
183         where:
184           1) 68 is the sendcounts
185           2) 68 is the recvcounts
186           3) 0 is the root node
187           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
188           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
189     */
190     CHECK_ACTION_PARAMS(action, 2, 3)
191     comm_size = MPI_COMM_WORLD->size();
192     send_size = parse_double(action[2]);
193     recv_size = parse_double(action[3]);
194
195     if (name == "gather") {
196       root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
197       if (action.size() > 5)
198         datatype1 = simgrid::smpi::Datatype::decode(action[5]);
199       if (action.size() > 6)
200         datatype2 = simgrid::smpi::Datatype::decode(action[6]);
201     }
202     else {
203       if (action.size() > 4)
204         datatype1 = simgrid::smpi::Datatype::decode(action[4]);
205       if (action.size() > 5)
206         datatype2 = simgrid::smpi::Datatype::decode(action[5]);
207     }
208   }
209 };
210
211 class GatherVArgParser : public CollCommParser {
212 public:
213   int recv_size_sum;
214   std::shared_ptr<std::vector<int>> recvcounts;
215   std::vector<int> disps;
216   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
217   {
218     /* The structure of the gatherv action for the rank 0 (total 4 processes) is the following:
219          0 gather 68 68 10 10 10 0 0 0
220        where:
221          1) 68 is the sendcount
222          2) 68 10 10 10 is the recvcounts
223          3) 0 is the root node
224          4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
225          5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
226     */
227     comm_size = MPI_COMM_WORLD->size();
228     CHECK_ACTION_PARAMS(action, comm_size+1, 2)
229     send_size = parse_double(action[2]);
230     disps     = std::vector<int>(comm_size, 0);
231     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
232
233     if (name == "gatherV") {
234       root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
235       if (action.size() > 4 + comm_size)
236         datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
237       if (action.size() > 5 + comm_size)
238         datatype2 = simgrid::smpi::Datatype::decode(action[5 + comm_size]);
239     }
240     else {
241       int datatype_index = 0;
242       int disp_index     = 0;
243       /* The 3 comes from "0 gather <sendcount>", which must always be present.
244        * The + comm_size is the recvcounts array, which must also be present
245        */
246       if (action.size() > 3 + comm_size + comm_size) { /* datatype + disp are specified */
247         datatype_index = 3 + comm_size;
248         disp_index     = datatype_index + 1;
249         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
250         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
251       } else if (action.size() > 3 + comm_size + 2) { /* disps specified; datatype is not specified; use the default one */
252         disp_index     = 3 + comm_size;
253       } else if (action.size() > 3 + comm_size)  { /* only datatype, no disp specified */
254         datatype_index = 3 + comm_size;
255         datatype1      = simgrid::smpi::Datatype::decode(action[datatype_index]);
256         datatype2      = simgrid::smpi::Datatype::decode(action[datatype_index]);
257       }
258
259       if (disp_index != 0) {
260         for (unsigned int i = 0; i < comm_size; i++)
261           disps[i]          = std::stoi(action[disp_index + i]);
262       }
263     }
264
265     for (unsigned int i = 0; i < comm_size; i++) {
266       (*recvcounts)[i] = std::stoi(action[i + 3]);
267     }
268     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
269   }
270 };
271
272 class ScatterArgParser : public CollCommParser {
273 public:
274   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
275   {
276     /* The structure of the scatter action for the rank 0 (total 4 processes) is the following:
277           0 gather 68 68 0 0 0
278         where:
279           1) 68 is the sendcounts
280           2) 68 is the recvcounts
281           3) 0 is the root node
282           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
283           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
284     */
285     CHECK_ACTION_PARAMS(action, 2, 3)
286     comm_size   = MPI_COMM_WORLD->size();
287     send_size   = parse_double(action[2]);
288     recv_size   = parse_double(action[3]);
289     root   = (action.size() > 4) ? std::stoi(action[4]) : 0;
290     if (action.size() > 5)
291       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
292     if (action.size() > 6)
293       datatype2 = simgrid::smpi::Datatype::decode(action[6]);
294   }
295 };
296
297 class ScatterVArgParser : public CollCommParser {
298 public:
299   int recv_size_sum;
300   int send_size_sum;
301   std::shared_ptr<std::vector<int>> sendcounts;
302   std::vector<int> disps;
303   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
304   {
305     /* The structure of the scatterv action for the rank 0 (total 4 processes) is the following:
306        0 gather 68 10 10 10 68 0 0 0
307         where:
308         1) 68 10 10 10 is the sendcounts
309         2) 68 is the recvcount
310         3) 0 is the root node
311         4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
312         5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
313     */
314     CHECK_ACTION_PARAMS(action, comm_size + 1, 2)
315     recv_size  = parse_double(action[2 + comm_size]);
316     disps      = std::vector<int>(comm_size, 0);
317     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
318
319     if (action.size() > 5 + comm_size)
320       datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
321     if (action.size() > 5 + comm_size)
322       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
323
324     for (unsigned int i = 0; i < comm_size; i++) {
325       (*sendcounts)[i] = std::stoi(action[i + 2]);
326     }
327     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
328     root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
329   }
330 };
331
332 class ReduceScatterArgParser : public CollCommParser {
333 public:
334   int recv_size_sum;
335   std::shared_ptr<std::vector<int>> recvcounts;
336   std::vector<int> disps;
337   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
338   {
339     /* The structure of the reducescatter action for the rank 0 (total 4 processes) is the following:
340          0 reduceScatter 275427 275427 275427 204020 11346849 0
341        where:
342          1) The first four values after the name of the action declare the recvcounts array
343          2) The value 11346849 is the amount of instructions
344          3) The last value corresponds to the datatype, see simgrid::smpi::Datatype::decode().
345     */
346     comm_size = MPI_COMM_WORLD->size();
347     CHECK_ACTION_PARAMS(action, comm_size+1, 1)
348     comp_size = parse_double(action[2+comm_size]);
349     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
350     if (action.size() > 3 + comm_size)
351       datatype1 = simgrid::smpi::Datatype::decode(action[3 + comm_size]);
352
353     for (unsigned int i = 0; i < comm_size; i++) {
354       recvcounts->push_back(std::stoi(action[i + 2]));
355     }
356     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
357   }
358 };
359
360 class AllToAllVArgParser : public CollCommParser {
361 public:
362   int recv_size_sum;
363   int send_size_sum;
364   std::shared_ptr<std::vector<int>> recvcounts;
365   std::shared_ptr<std::vector<int>> sendcounts;
366   std::vector<int> senddisps;
367   std::vector<int> recvdisps;
368   int send_buf_size;
369   int recv_buf_size;
370   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
371   {
372     /* The structure of the allToAllV action for the rank 0 (total 4 processes) is the following:
373           0 allToAllV 100 1 7 10 12 100 1 70 10 5
374        where:
375         1) 100 is the size of the send buffer *sizeof(int),
376         2) 1 7 10 12 is the sendcounts array
377         3) 100*sizeof(int) is the size of the receiver buffer
378         4)  1 70 10 5 is the recvcounts array
379     */
380     comm_size = MPI_COMM_WORLD->size();
381     CHECK_ACTION_PARAMS(action, 2*comm_size+2, 2)
382     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
383     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
384     senddisps  = std::vector<int>(comm_size, 0);
385     recvdisps  = std::vector<int>(comm_size, 0);
386
387     if (action.size() > 5 + 2 * comm_size)
388       datatype1 = simgrid::smpi::Datatype::decode(action[4 + 2 * comm_size]);
389     if (action.size() > 5 + 2 * comm_size)
390       datatype2 = simgrid::smpi::Datatype::decode(action[5 + 2 * comm_size]);
391
392     send_buf_size=parse_double(action[2]);
393     recv_buf_size=parse_double(action[3+comm_size]);
394     for (unsigned int i = 0; i < comm_size; i++) {
395       (*sendcounts)[i] = std::stoi(action[3 + i]);
396       (*recvcounts)[i] = std::stoi(action[4 + comm_size + i]);
397     }
398     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
399     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
400   }
401 };
402
403 template <class T> class ReplayAction {
404 protected:
405   const std::string name;
406   const int my_proc_id;
407   T args;
408
409 public:
410   explicit ReplayAction(std::string name) : name(name), my_proc_id(simgrid::s4u::this_actor::get_pid()) {}
411   virtual ~ReplayAction() = default;
412
413   virtual void execute(simgrid::xbt::ReplayAction& action)
414   {
415     // Needs to be re-initialized for every action, hence here
416     double start_time = smpi_process()->simulated_elapsed();
417     args.parse(action, name);
418     kernel(action);
419     if (name != "Init")
420       log_timed_action(action, start_time);
421   }
422
423   virtual void kernel(simgrid::xbt::ReplayAction& action) = 0;
424
425   void* send_buffer(int size)
426   {
427     return smpi_get_tmp_sendbuffer(size);
428   }
429
430   void* recv_buffer(int size)
431   {
432     return smpi_get_tmp_recvbuffer(size);
433   }
434 };
435
436 class WaitAction : public ReplayAction<ActionArgParser> {
437 public:
438   WaitAction() : ReplayAction("Wait") {}
439   void kernel(simgrid::xbt::ReplayAction& action) override
440   {
441     std::string s = boost::algorithm::join(action, " ");
442     xbt_assert(get_reqq_self()->size(), "action wait not preceded by any irecv or isend: %s", s.c_str());
443     MPI_Request request = get_reqq_self()->back();
444     get_reqq_self()->pop_back();
445
446     if (request == nullptr) {
447       /* Assume that the trace is well formed, meaning the comm might have been caught by a MPI_test. Then just
448        * return.*/
449       return;
450     }
451
452     int rank = request->comm() != MPI_COMM_NULL ? request->comm()->rank() : -1;
453
454     // Must be taken before Request::wait() since the request may be set to
455     // MPI_REQUEST_NULL by Request::wait!
456     int src                  = request->comm()->group()->rank(request->src());
457     int dst                  = request->comm()->group()->rank(request->dst());
458     int tag                  = request->tag();
459     bool is_wait_for_receive = (request->flags() & RECV);
460     // TODO: Here we take the rank while we normally take the process id (look for my_proc_id)
461     TRACE_smpi_comm_in(rank, __func__, new simgrid::instr::NoOpTIData("wait"));
462
463     MPI_Status status;
464     Request::wait(&request, &status);
465
466     TRACE_smpi_comm_out(rank);
467     if (is_wait_for_receive)
468       TRACE_smpi_recv(src, dst, tag);
469   }
470 };
471
472 class SendAction : public ReplayAction<SendRecvParser> {
473 public:
474   SendAction() = delete;
475   explicit SendAction(std::string name) : ReplayAction(name) {}
476   void kernel(simgrid::xbt::ReplayAction& action) override
477   {
478     int dst_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
479
480     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
481                                                                              args.tag, Datatype::encode(args.datatype1)));
482     if (not TRACE_smpi_view_internals())
483       TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, args.tag, args.size * args.datatype1->size());
484
485     if (name == "send") {
486       Request::send(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
487     } else if (name == "Isend") {
488       MPI_Request request = Request::isend(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
489       get_reqq_self()->push_back(request);
490     } else {
491       xbt_die("Don't know this action, %s", name.c_str());
492     }
493
494     TRACE_smpi_comm_out(my_proc_id);
495   }
496 };
497
498 class RecvAction : public ReplayAction<SendRecvParser> {
499 public:
500   RecvAction() = delete;
501   explicit RecvAction(std::string name) : ReplayAction(name) {}
502   void kernel(simgrid::xbt::ReplayAction& action) override
503   {
504     int src_traced = MPI_COMM_WORLD->group()->actor(args.partner)->get_pid();
505
506     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
507                                                                              args.tag, Datatype::encode(args.datatype1)));
508
509     MPI_Status status;
510     // unknown size from the receiver point of view
511     if (args.size <= 0.0) {
512       Request::probe(args.partner, args.tag, MPI_COMM_WORLD, &status);
513       args.size = status.count;
514     }
515
516     if (name == "recv") {
517       Request::recv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD, &status);
518     } else if (name == "Irecv") {
519       MPI_Request request = Request::irecv(nullptr, args.size, args.datatype1, args.partner, args.tag, MPI_COMM_WORLD);
520       get_reqq_self()->push_back(request);
521     }
522
523     TRACE_smpi_comm_out(my_proc_id);
524     // TODO: Check why this was only activated in the "recv" case and not in the "Irecv" case
525     if (name == "recv" && not TRACE_smpi_view_internals()) {
526       TRACE_smpi_recv(src_traced, my_proc_id, args.tag);
527     }
528   }
529 };
530
531 class ComputeAction : public ReplayAction<ComputeParser> {
532 public:
533   ComputeAction() : ReplayAction("compute") {}
534   void kernel(simgrid::xbt::ReplayAction& action) override
535   {
536     TRACE_smpi_computing_in(my_proc_id, args.flops);
537     smpi_execute_flops(args.flops);
538     TRACE_smpi_computing_out(my_proc_id);
539   }
540 };
541
542 class TestAction : public ReplayAction<ActionArgParser> {
543 public:
544   TestAction() : ReplayAction("Test") {}
545   void kernel(simgrid::xbt::ReplayAction& action) override
546   {
547     MPI_Request request = get_reqq_self()->back();
548     get_reqq_self()->pop_back();
549     // if request is null here, this may mean that a previous test has succeeded
550     // Different times in traced application and replayed version may lead to this
551     // In this case, ignore the extra calls.
552     if (request != nullptr) {
553       TRACE_smpi_testing_in(my_proc_id);
554
555       MPI_Status status;
556       int flag = Request::test(&request, &status);
557
558       XBT_DEBUG("MPI_Test result: %d", flag);
559       /* push back request in vector to be caught by a subsequent wait. if the test did succeed, the request is now
560        * nullptr.*/
561       get_reqq_self()->push_back(request);
562
563       TRACE_smpi_testing_out(my_proc_id);
564     }
565   }
566 };
567
568 class InitAction : public ReplayAction<ActionArgParser> {
569 public:
570   InitAction() : ReplayAction("Init") {}
571   void kernel(simgrid::xbt::ReplayAction& action) override
572   {
573     CHECK_ACTION_PARAMS(action, 0, 1)
574     MPI_DEFAULT_TYPE = (action.size() > 2) ? MPI_DOUBLE // default MPE datatype
575                                            : MPI_BYTE;  // default TAU datatype
576
577     /* start a simulated timer */
578     smpi_process()->simulated_start();
579     set_reqq_self(new std::vector<MPI_Request>);
580   }
581 };
582
583 class CommunicatorAction : public ReplayAction<ActionArgParser> {
584 public:
585   CommunicatorAction() : ReplayAction("Comm") {}
586   void kernel(simgrid::xbt::ReplayAction& action) override { /* nothing to do */}
587 };
588
589 class WaitAllAction : public ReplayAction<ActionArgParser> {
590 public:
591   WaitAllAction() : ReplayAction("waitAll") {}
592   void kernel(simgrid::xbt::ReplayAction& action) override
593   {
594     const unsigned int count_requests = get_reqq_self()->size();
595
596     if (count_requests > 0) {
597       TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::Pt2PtTIData("waitAll", -1, count_requests, ""));
598       std::vector<std::pair</*sender*/int,/*recv*/int>> sender_receiver;
599       for (const auto& req : (*get_reqq_self())) {
600         if (req && (req->flags() & RECV)) {
601           sender_receiver.push_back({req->src(), req->dst()});
602         }
603       }
604       MPI_Status status[count_requests];
605       Request::waitall(count_requests, &(*get_reqq_self())[0], status);
606
607       for (auto& pair : sender_receiver) {
608         TRACE_smpi_recv(pair.first, pair.second, 0);
609       }
610       TRACE_smpi_comm_out(my_proc_id);
611     }
612   }
613 };
614
615 class BarrierAction : public ReplayAction<ActionArgParser> {
616 public:
617   BarrierAction() : ReplayAction("barrier") {}
618   void kernel(simgrid::xbt::ReplayAction& action) override
619   {
620     TRACE_smpi_comm_in(my_proc_id, __func__, new simgrid::instr::NoOpTIData("barrier"));
621     Colls::barrier(MPI_COMM_WORLD);
622     TRACE_smpi_comm_out(my_proc_id);
623   }
624 };
625
626 class BcastAction : public ReplayAction<BcastArgParser> {
627 public:
628   BcastAction() : ReplayAction("bcast") {}
629   void kernel(simgrid::xbt::ReplayAction& action) override
630   {
631     TRACE_smpi_comm_in(my_proc_id, "action_bcast",
632                        new simgrid::instr::CollTIData("bcast", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
633                                                       -1.0, args.size, -1, Datatype::encode(args.datatype1), ""));
634
635     Colls::bcast(send_buffer(args.size * args.datatype1->size()), args.size, args.datatype1, args.root, MPI_COMM_WORLD);
636
637     TRACE_smpi_comm_out(my_proc_id);
638   }
639 };
640
641 class ReduceAction : public ReplayAction<ReduceArgParser> {
642 public:
643   ReduceAction() : ReplayAction("reduce") {}
644   void kernel(simgrid::xbt::ReplayAction& action) override
645   {
646     TRACE_smpi_comm_in(my_proc_id, "action_reduce",
647                        new simgrid::instr::CollTIData("reduce", MPI_COMM_WORLD->group()->actor(args.root)->get_pid(),
648                                                       args.comp_size, args.comm_size, -1,
649                                                       Datatype::encode(args.datatype1), ""));
650
651     Colls::reduce(send_buffer(args.comm_size * args.datatype1->size()),
652         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, args.root, MPI_COMM_WORLD);
653     smpi_execute_flops(args.comp_size);
654
655     TRACE_smpi_comm_out(my_proc_id);
656   }
657 };
658
659 class AllReduceAction : public ReplayAction<AllReduceArgParser> {
660 public:
661   AllReduceAction() : ReplayAction("allReduce") {}
662   void kernel(simgrid::xbt::ReplayAction& action) override
663   {
664     TRACE_smpi_comm_in(my_proc_id, "action_allReduce", new simgrid::instr::CollTIData("allReduce", -1, args.comp_size, args.comm_size, -1,
665                                                                                 Datatype::encode(args.datatype1), ""));
666
667     Colls::allreduce(send_buffer(args.comm_size * args.datatype1->size()),
668         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
669     smpi_execute_flops(args.comp_size);
670
671     TRACE_smpi_comm_out(my_proc_id);
672   }
673 };
674
675 class AllToAllAction : public ReplayAction<AllToAllArgParser> {
676 public:
677   AllToAllAction() : ReplayAction("allToAll") {}
678   void kernel(simgrid::xbt::ReplayAction& action) override
679   {
680     TRACE_smpi_comm_in(my_proc_id, "action_allToAll",
681                      new simgrid::instr::CollTIData("allToAll", -1, -1.0, args.send_size, args.recv_size,
682                                                     Datatype::encode(args.datatype1),
683                                                     Datatype::encode(args.datatype2)));
684
685     Colls::alltoall(send_buffer(args.send_size * args.comm_size * args.datatype1->size()), args.send_size,
686                     args.datatype1, recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()),
687                     args.recv_size, args.datatype2, MPI_COMM_WORLD);
688
689     TRACE_smpi_comm_out(my_proc_id);
690   }
691 };
692
693 class GatherAction : public ReplayAction<GatherArgParser> {
694 public:
695   explicit GatherAction(std::string name) : ReplayAction(name) {}
696   void kernel(simgrid::xbt::ReplayAction& action) override
697   {
698     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::CollTIData(name, (name == "gather") ? args.root : -1, -1.0, args.send_size, args.recv_size,
699                                                                           Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
700
701     if (name == "gather") {
702       int rank = MPI_COMM_WORLD->rank();
703       Colls::gather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
704                  (rank == args.root) ? recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
705     }
706     else
707       Colls::allgather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
708                        recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, MPI_COMM_WORLD);
709
710     TRACE_smpi_comm_out(my_proc_id);
711   }
712 };
713
714 class GatherVAction : public ReplayAction<GatherVArgParser> {
715 public:
716   explicit GatherVAction(std::string name) : ReplayAction(name) {}
717   void kernel(simgrid::xbt::ReplayAction& action) override
718   {
719     int rank = MPI_COMM_WORLD->rank();
720
721     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::VarCollTIData(
722                                                name, (name == "gatherV") ? args.root : -1, args.send_size, nullptr, -1, args.recvcounts,
723                                                Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
724
725     if (name == "gatherV") {
726       Colls::gatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
727                      (rank == args.root) ? recv_buffer(args.recv_size_sum * args.datatype2->size()) : nullptr,
728                      args.recvcounts->data(), args.disps.data(), args.datatype2, args.root, MPI_COMM_WORLD);
729     }
730     else {
731       Colls::allgatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
732                         recv_buffer(args.recv_size_sum * args.datatype2->size()), args.recvcounts->data(),
733                         args.disps.data(), args.datatype2, MPI_COMM_WORLD);
734     }
735
736     TRACE_smpi_comm_out(my_proc_id);
737   }
738 };
739
740 class ScatterAction : public ReplayAction<ScatterArgParser> {
741 public:
742   ScatterAction() : ReplayAction("scatter") {}
743   void kernel(simgrid::xbt::ReplayAction& action) override
744   {
745     int rank = MPI_COMM_WORLD->rank();
746     TRACE_smpi_comm_in(my_proc_id, "action_scatter", new simgrid::instr::CollTIData(name, args.root, -1.0, args.send_size, args.recv_size,
747                                                                           Datatype::encode(args.datatype1),
748                                                                           Datatype::encode(args.datatype2)));
749
750     Colls::scatter(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
751                   (rank == args.root) ? recv_buffer(args.recv_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
752
753     TRACE_smpi_comm_out(my_proc_id);
754   }
755 };
756
757
758 class ScatterVAction : public ReplayAction<ScatterVArgParser> {
759 public:
760   ScatterVAction() : ReplayAction("scatterV") {}
761   void kernel(simgrid::xbt::ReplayAction& action) override
762   {
763     int rank = MPI_COMM_WORLD->rank();
764     TRACE_smpi_comm_in(my_proc_id, "action_scatterv", new simgrid::instr::VarCollTIData(name, args.root, -1, args.sendcounts, args.recv_size,
765           nullptr, Datatype::encode(args.datatype1),
766           Datatype::encode(args.datatype2)));
767
768     Colls::scatterv((rank == args.root) ? send_buffer(args.send_size_sum * args.datatype1->size()) : nullptr,
769                     args.sendcounts->data(), args.disps.data(), args.datatype1,
770                     recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, args.root,
771                     MPI_COMM_WORLD);
772
773     TRACE_smpi_comm_out(my_proc_id);
774   }
775 };
776
777 class ReduceScatterAction : public ReplayAction<ReduceScatterArgParser> {
778 public:
779   ReduceScatterAction() : ReplayAction("reduceScatter") {}
780   void kernel(simgrid::xbt::ReplayAction& action) override
781   {
782     TRACE_smpi_comm_in(my_proc_id, "action_reducescatter",
783                        new simgrid::instr::VarCollTIData("reduceScatter", -1, 0, nullptr, -1, args.recvcounts,
784                                                          std::to_string(args.comp_size), /* ugly hack to print comp_size */
785                                                          Datatype::encode(args.datatype1)));
786
787     Colls::reduce_scatter(send_buffer(args.recv_size_sum * args.datatype1->size()),
788                           recv_buffer(args.recv_size_sum * args.datatype1->size()), args.recvcounts->data(),
789                           args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
790
791     smpi_execute_flops(args.comp_size);
792     TRACE_smpi_comm_out(my_proc_id);
793   }
794 };
795
796 class AllToAllVAction : public ReplayAction<AllToAllVArgParser> {
797 public:
798   AllToAllVAction() : ReplayAction("allToAllV") {}
799   void kernel(simgrid::xbt::ReplayAction& action) override
800   {
801     TRACE_smpi_comm_in(my_proc_id, __func__,
802                        new simgrid::instr::VarCollTIData(
803                            "allToAllV", -1, args.send_size_sum, args.sendcounts, args.recv_size_sum, args.recvcounts,
804                            Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
805
806     Colls::alltoallv(send_buffer(args.send_buf_size * args.datatype1->size()), args.sendcounts->data(), args.senddisps.data(), args.datatype1,
807                      recv_buffer(args.recv_buf_size * args.datatype2->size()), args.recvcounts->data(), args.recvdisps.data(), args.datatype2, MPI_COMM_WORLD);
808
809     TRACE_smpi_comm_out(my_proc_id);
810   }
811 };
812 } // Replay Namespace
813 }} // namespace simgrid::smpi
814
815 /** @brief Only initialize the replay, don't do it for real */
816 void smpi_replay_init(int* argc, char*** argv)
817 {
818   simgrid::smpi::Process::init(argc, argv);
819   smpi_process()->mark_as_initialized();
820   smpi_process()->set_replaying(true);
821
822   int my_proc_id = simgrid::s4u::this_actor::get_pid();
823   TRACE_smpi_init(my_proc_id);
824   TRACE_smpi_computing_init(my_proc_id);
825   TRACE_smpi_comm_in(my_proc_id, "smpi_replay_run_init", new simgrid::instr::NoOpTIData("init"));
826   TRACE_smpi_comm_out(my_proc_id);
827   xbt_replay_action_register("init", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::InitAction().execute(action); });
828   xbt_replay_action_register("finalize", [](simgrid::xbt::ReplayAction& action) { /* nothing to do */ });
829   xbt_replay_action_register("comm_size", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
830   xbt_replay_action_register("comm_split",[](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
831   xbt_replay_action_register("comm_dup",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::CommunicatorAction().execute(action); });
832
833   xbt_replay_action_register("send",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("send").execute(action); });
834   xbt_replay_action_register("Isend", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::SendAction("Isend").execute(action); });
835   xbt_replay_action_register("recv",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("recv").execute(action); });
836   xbt_replay_action_register("Irecv", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::RecvAction("Irecv").execute(action); });
837   xbt_replay_action_register("test",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::TestAction().execute(action); });
838   xbt_replay_action_register("wait",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAction().execute(action); });
839   xbt_replay_action_register("waitAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::WaitAllAction().execute(action); });
840   xbt_replay_action_register("barrier", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BarrierAction().execute(action); });
841   xbt_replay_action_register("bcast",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::BcastAction().execute(action); });
842   xbt_replay_action_register("reduce",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceAction().execute(action); });
843   xbt_replay_action_register("allReduce", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllReduceAction().execute(action); });
844   xbt_replay_action_register("allToAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllAction().execute(action); });
845   xbt_replay_action_register("allToAllV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::AllToAllVAction().execute(action); });
846   xbt_replay_action_register("gather",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("gather").execute(action); });
847   xbt_replay_action_register("scatter",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterAction().execute(action); });
848   xbt_replay_action_register("gatherV",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("gatherV").execute(action); });
849   xbt_replay_action_register("scatterV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ScatterVAction().execute(action); });
850   xbt_replay_action_register("allGather", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherAction("allGather").execute(action); });
851   xbt_replay_action_register("allGatherV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::GatherVAction("allGatherV").execute(action); });
852   xbt_replay_action_register("reduceScatter", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ReduceScatterAction().execute(action); });
853   xbt_replay_action_register("compute", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::replay::ComputeAction().execute(action); });
854
855   //if we have a delayed start, sleep here.
856   if(*argc>2){
857     double value = xbt_str_parse_double((*argv)[2], "%s is not a double");
858     XBT_VERB("Delayed start for instance - Sleeping for %f flops ",value );
859     smpi_execute_flops(value);
860   } else {
861     //UGLY: force a context switch to be sure that all MSG_processes begin initialization
862     XBT_DEBUG("Force context switch by smpi_execute_flops  - Sleeping for 0.0 flops ");
863     smpi_execute_flops(0.0);
864   }
865 }
866
867 /** @brief actually run the replay after initialization */
868 void smpi_replay_main(int* argc, char*** argv)
869 {
870   static int active_processes = 0;
871   active_processes++;
872   simgrid::xbt::replay_runner(*argc, *argv);
873
874   /* and now, finalize everything */
875   /* One active process will stop. Decrease the counter*/
876   XBT_DEBUG("There are %zu elements in reqq[*]", get_reqq_self()->size());
877   if (not get_reqq_self()->empty()) {
878     unsigned int count_requests=get_reqq_self()->size();
879     MPI_Request requests[count_requests];
880     MPI_Status status[count_requests];
881     unsigned int i=0;
882
883     for (auto const& req : *get_reqq_self()) {
884       requests[i] = req;
885       i++;
886     }
887     simgrid::smpi::Request::waitall(count_requests, requests, status);
888   }
889   delete get_reqq_self();
890   active_processes--;
891
892   if(active_processes==0){
893     /* Last process alive speaking: end the simulated timer */
894     XBT_INFO("Simulation time %f", smpi_process()->simulated_elapsed());
895     smpi_free_replay_tmp_buffers();
896   }
897
898   TRACE_smpi_comm_in(simgrid::s4u::this_actor::get_pid(), "smpi_replay_run_finalize",
899                      new simgrid::instr::NoOpTIData("finalize"));
900
901   smpi_process()->finalize();
902
903   TRACE_smpi_comm_out(simgrid::s4u::this_actor::get_pid());
904   TRACE_smpi_finalize(simgrid::s4u::this_actor::get_pid());
905 }
906
907 /** @brief chain a replay initialization and a replay start */
908 void smpi_replay_run(int* argc, char*** argv)
909 {
910   smpi_replay_init(argc, argv);
911   smpi_replay_main(argc, argv);
912 }