Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
Merge branch 'master' of ssh://scm.gforge.inria.fr/gitroot/simgrid/simgrid
[simgrid.git] / src / smpi / internals / smpi_replay.cpp
1 /* Copyright (c) 2009-2017. The SimGrid Team. All rights reserved.          */
2
3 /* This program is free software; you can redistribute it and/or modify it
4  * under the terms of the license (GNU LGPL) which comes with this package. */
5
6 #include "private.hpp"
7 #include "smpi_coll.hpp"
8 #include "smpi_comm.hpp"
9 #include "smpi_datatype.hpp"
10 #include "smpi_group.hpp"
11 #include "smpi_process.hpp"
12 #include "smpi_request.hpp"
13 #include "xbt/replay.hpp"
14
15 #include <boost/algorithm/string/join.hpp>
16 #include <memory>
17 #include <numeric>
18 #include <unordered_map>
19 #include <vector>
20
21 using simgrid::s4u::Actor;
22
23 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_replay,smpi,"Trace Replay with SMPI");
24
25 static int active_processes  = 0;
26 static std::unordered_map<int, std::vector<MPI_Request>*> reqq;
27
28 static MPI_Datatype MPI_DEFAULT_TYPE;
29
30 #define CHECK_ACTION_PARAMS(action, mandatory, optional)                                                               \
31   {                                                                                                                    \
32     if (action.size() < static_cast<unsigned long>(mandatory + 2))                                                     \
33       THROWF(arg_error, 0, "%s replay failed.\n"                                                                       \
34                            "%lu items were given on the line. First two should be process_id and action.  "            \
35                            "This action needs after them %lu mandatory arguments, and accepts %lu optional ones. \n"   \
36                            "Please contact the Simgrid team if support is needed",                                     \
37              __FUNCTION__, action.size(), static_cast<unsigned long>(mandatory),                                       \
38              static_cast<unsigned long>(optional));                                                                    \
39   }
40
41 static void log_timed_action(simgrid::xbt::ReplayAction& action, double clock)
42 {
43   if (XBT_LOG_ISENABLED(smpi_replay, xbt_log_priority_verbose)){
44     std::string s = boost::algorithm::join(action, " ");
45     XBT_VERB("%s %f", s.c_str(), smpi_process()->simulated_elapsed() - clock);
46   }
47 }
48
49 static std::vector<MPI_Request>* get_reqq_self()
50 {
51   return reqq.at(Actor::self()->getPid());
52 }
53
54 static void set_reqq_self(std::vector<MPI_Request> *mpi_request)
55 {
56    reqq.insert({Actor::self()->getPid(), mpi_request});
57 }
58
59 /* Helper function */
60 static double parse_double(std::string string)
61 {
62   return xbt_str_parse_double(string.c_str(), "%s is not a double");
63 }
64
65 namespace simgrid {
66 namespace smpi {
67
68 namespace Replay {
69 class ActionArgParser {
70 public:
71   virtual void parse(simgrid::xbt::ReplayAction& action, std::string name) { CHECK_ACTION_PARAMS(action, 0, 0) }
72 };
73
74 class SendRecvParser : public ActionArgParser {
75 public:
76   /* communication partner; if we send, this is the receiver and vice versa */
77   int partner;
78   double size;
79   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
80
81   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
82   {
83     CHECK_ACTION_PARAMS(action, 2, 1)
84     partner = std::stoi(action[2]);
85     size    = parse_double(action[3]);
86     if (action.size() > 4)
87       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
88   }
89 };
90
91 class ComputeParser : public ActionArgParser {
92 public:
93   /* communication partner; if we send, this is the receiver and vice versa */
94   double flops;
95
96   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
97   {
98     CHECK_ACTION_PARAMS(action, 1, 0)
99     flops = parse_double(action[2]);
100   }
101 };
102
103 class CollCommParser : public ActionArgParser {
104 public:
105   double size;
106   double comm_size;
107   double comp_size;
108   int send_size;
109   int recv_size;
110   int root = 0;
111   MPI_Datatype datatype1 = MPI_DEFAULT_TYPE;
112   MPI_Datatype datatype2 = MPI_DEFAULT_TYPE;
113 };
114
115 class BcastArgParser : public CollCommParser {
116 public:
117   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
118   {
119     CHECK_ACTION_PARAMS(action, 1, 2)
120     size = parse_double(action[2]);
121     root = (action.size() > 3) ? std::stoi(action[3]) : 0;
122     if (action.size() > 4)
123       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
124   }
125 };
126
127 class ReduceArgParser : public CollCommParser {
128 public:
129   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
130   {
131     CHECK_ACTION_PARAMS(action, 2, 2)
132     comm_size = parse_double(action[2]);
133     comp_size = parse_double(action[3]);
134     root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
135     if (action.size() > 5)
136       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
137   }
138 };
139
140 class AllReduceArgParser : public CollCommParser {
141 public:
142   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
143   {
144     CHECK_ACTION_PARAMS(action, 2, 1)
145     comm_size = parse_double(action[2]);
146     comp_size = parse_double(action[3]);
147     if (action.size() > 4)
148       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
149   }
150 };
151
152 class AllToAllArgParser : public CollCommParser {
153 public:
154   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
155   {
156     CHECK_ACTION_PARAMS(action, 2, 1)
157     comm_size = MPI_COMM_WORLD->size();
158     send_size = parse_double(action[2]);
159     recv_size = parse_double(action[3]);
160
161     if (action.size() > 4)
162       datatype1 = simgrid::smpi::Datatype::decode(action[4]);
163     if (action.size() > 5)
164       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
165   }
166 };
167
168 class GatherArgParser : public CollCommParser {
169 public:
170   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
171   {
172     /* The structure of the gather action for the rank 0 (total 4 processes) is the following:
173           0 gather 68 68 0 0 0
174         where:
175           1) 68 is the sendcounts
176           2) 68 is the recvcounts
177           3) 0 is the root node
178           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
179           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
180     */
181     CHECK_ACTION_PARAMS(action, 2, 3)
182     comm_size = MPI_COMM_WORLD->size();
183     send_size = parse_double(action[2]);
184     recv_size = parse_double(action[3]);
185
186     if (name == "gather") {
187       root      = (action.size() > 4) ? std::stoi(action[4]) : 0;
188       if (action.size() > 5)
189         datatype1 = simgrid::smpi::Datatype::decode(action[5]);
190       if (action.size() > 6)
191         datatype2 = simgrid::smpi::Datatype::decode(action[6]);
192     }
193     else {
194       if (action.size() > 4)
195         datatype1 = simgrid::smpi::Datatype::decode(action[4]);
196       if (action.size() > 5)
197         datatype2 = simgrid::smpi::Datatype::decode(action[5]);
198     }
199   }
200 };
201
202 class GatherVArgParser : public CollCommParser {
203 public:
204   int recv_size_sum;
205   std::shared_ptr<std::vector<int>> recvcounts;
206   std::vector<int> disps;
207   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
208   {
209     /* The structure of the gatherv action for the rank 0 (total 4 processes) is the following:
210          0 gather 68 68 10 10 10 0 0 0
211        where:
212          1) 68 is the sendcount
213          2) 68 10 10 10 is the recvcounts
214          3) 0 is the root node
215          4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
216          5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
217     */
218     comm_size = MPI_COMM_WORLD->size();
219     CHECK_ACTION_PARAMS(action, comm_size+1, 2)
220     send_size = parse_double(action[2]);
221     disps     = std::vector<int>(comm_size, 0);
222     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
223
224     if (name == "gatherV") {
225       root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
226       if (action.size() > 4 + comm_size)
227         datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
228       if (action.size() > 5 + comm_size)
229         datatype2 = simgrid::smpi::Datatype::decode(action[5 + comm_size]);
230     }
231     else {
232       int datatype_index = 0, disp_index = 0;
233       if (action.size() > 3 + 2 * comm_size) { /* datatype + disp are specified */
234         datatype_index = 3 + comm_size;
235         disp_index     = datatype_index + 1;
236       } else if (action.size() > 3 + 2 * comm_size) { /* disps specified; datatype is not specified; use the default one */
237         datatype_index = -1;
238         disp_index     = 3 + comm_size;
239       } else if (action.size() > 3 + comm_size) { /* only datatype, no disp specified */
240         datatype_index = 3 + comm_size;
241       }
242
243       if (disp_index != 0) {
244         for (unsigned int i = 0; i < comm_size; i++)
245           disps[i]          = std::stoi(action[disp_index + i]);
246       }
247
248       datatype1 = simgrid::smpi::Datatype::decode(action[datatype_index]);
249       datatype2 = simgrid::smpi::Datatype::decode(action[datatype_index]);
250     }
251
252     for (unsigned int i = 0; i < comm_size; i++) {
253       (*recvcounts)[i] = std::stoi(action[i + 3]);
254     }
255     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
256   }
257 };
258
259 class ScatterArgParser : public CollCommParser {
260 public:
261   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
262   {
263     /* The structure of the scatter action for the rank 0 (total 4 processes) is the following:
264           0 gather 68 68 0 0 0
265         where:
266           1) 68 is the sendcounts
267           2) 68 is the recvcounts
268           3) 0 is the root node
269           4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
270           5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
271     */
272     CHECK_ACTION_PARAMS(action, 2, 3)
273     comm_size   = MPI_COMM_WORLD->size();
274     send_size   = parse_double(action[2]);
275     recv_size   = parse_double(action[3]);
276     root   = (action.size() > 4) ? std::stoi(action[4]) : 0;
277     if (action.size() > 5)
278       datatype1 = simgrid::smpi::Datatype::decode(action[5]);
279     if (action.size() > 6)
280       datatype2 = simgrid::smpi::Datatype::decode(action[6]);
281   }
282 };
283
284 class ScatterVArgParser : public CollCommParser {
285 public:
286   int recv_size_sum;
287   int send_size_sum;
288   std::shared_ptr<std::vector<int>> sendcounts;
289   std::vector<int> disps;
290   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
291   {
292     /* The structure of the scatterv action for the rank 0 (total 4 processes) is the following:
293        0 gather 68 10 10 10 68 0 0 0
294         where:
295         1) 68 10 10 10 is the sendcounts
296         2) 68 is the recvcount
297         3) 0 is the root node
298         4) 0 is the send datatype id, see simgrid::smpi::Datatype::decode()
299         5) 0 is the recv datatype id, see simgrid::smpi::Datatype::decode()
300     */
301     CHECK_ACTION_PARAMS(action, comm_size + 1, 2)
302     recv_size  = parse_double(action[2 + comm_size]);
303     disps      = std::vector<int>(comm_size, 0);
304     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
305
306     if (action.size() > 5 + comm_size)
307       datatype1 = simgrid::smpi::Datatype::decode(action[4 + comm_size]);
308     if (action.size() > 5 + comm_size)
309       datatype2 = simgrid::smpi::Datatype::decode(action[5]);
310
311     for (unsigned int i = 0; i < comm_size; i++) {
312       (*sendcounts)[i] = std::stoi(action[i + 2]);
313     }
314     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
315     root = (action.size() > 3 + comm_size) ? std::stoi(action[3 + comm_size]) : 0;
316   }
317 };
318
319 class ReduceScatterArgParser : public CollCommParser {
320 public:
321   int recv_size_sum;
322   std::shared_ptr<std::vector<int>> recvcounts;
323   std::vector<int> disps;
324   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
325   {
326     /* The structure of the reducescatter action for the rank 0 (total 4 processes) is the following:
327          0 reduceScatter 275427 275427 275427 204020 11346849 0
328        where:
329          1) The first four values after the name of the action declare the recvcounts array
330          2) The value 11346849 is the amount of instructions
331          3) The last value corresponds to the datatype, see simgrid::smpi::Datatype::decode().
332     */
333     comm_size = MPI_COMM_WORLD->size();
334     CHECK_ACTION_PARAMS(action, comm_size+1, 1)
335     comp_size = parse_double(action[2+comm_size]);
336     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
337     if (action.size() > 3 + comm_size)
338       datatype1 = simgrid::smpi::Datatype::decode(action[3 + comm_size]);
339
340     for (unsigned int i = 0; i < comm_size; i++) {
341       recvcounts->push_back(std::stoi(action[i + 2]));
342     }
343     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
344   }
345 };
346
347 class AllToAllVArgParser : public CollCommParser {
348 public:
349   int recv_size_sum;
350   int send_size_sum;
351   std::shared_ptr<std::vector<int>> recvcounts;
352   std::shared_ptr<std::vector<int>> sendcounts;
353   std::vector<int> senddisps;
354   std::vector<int> recvdisps;
355   int send_buf_size;
356   int recv_buf_size;
357   void parse(simgrid::xbt::ReplayAction& action, std::string name) override
358   {
359     /* The structure of the allToAllV action for the rank 0 (total 4 processes) is the following:
360           0 allToAllV 100 1 7 10 12 100 1 70 10 5
361        where:
362         1) 100 is the size of the send buffer *sizeof(int),
363         2) 1 7 10 12 is the sendcounts array
364         3) 100*sizeof(int) is the size of the receiver buffer
365         4)  1 70 10 5 is the recvcounts array
366     */
367     comm_size = MPI_COMM_WORLD->size();
368     CHECK_ACTION_PARAMS(action, 2*comm_size+2, 2)
369     sendcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
370     recvcounts = std::shared_ptr<std::vector<int>>(new std::vector<int>(comm_size));
371     senddisps  = std::vector<int>(comm_size, 0);
372     recvdisps  = std::vector<int>(comm_size, 0);
373
374     if (action.size() > 5 + 2 * comm_size)
375       datatype1 = simgrid::smpi::Datatype::decode(action[4 + 2 * comm_size]);
376     if (action.size() > 5 + 2 * comm_size)
377       datatype2 = simgrid::smpi::Datatype::decode(action[5 + 2 * comm_size]);
378
379     send_buf_size=parse_double(action[2]);
380     recv_buf_size=parse_double(action[3+comm_size]);
381     for (unsigned int i = 0; i < comm_size; i++) {
382       (*sendcounts)[i] = std::stoi(action[3 + i]);
383       (*recvcounts)[i] = std::stoi(action[4 + comm_size + i]);
384     }
385     send_size_sum = std::accumulate(sendcounts->begin(), sendcounts->end(), 0);
386     recv_size_sum = std::accumulate(recvcounts->begin(), recvcounts->end(), 0);
387   }
388 };
389
390 template <class T> class ReplayAction {
391 protected:
392   const std::string name;
393   T args;
394
395   int my_proc_id;
396
397 public:
398   explicit ReplayAction(std::string name) : name(name), my_proc_id(simgrid::s4u::Actor::self()->getPid()) {}
399
400   virtual void execute(simgrid::xbt::ReplayAction& action)
401   {
402     // Needs to be re-initialized for every action, hence here
403     double start_time = smpi_process()->simulated_elapsed();
404     args.parse(action, name);
405     kernel(action);
406     if (name != "Init")
407       log_timed_action(action, start_time);
408   }
409
410   virtual void kernel(simgrid::xbt::ReplayAction& action) = 0;
411
412   void* send_buffer(int size)
413   {
414     return smpi_get_tmp_sendbuffer(size);
415   }
416
417   void* recv_buffer(int size)
418   {
419     return smpi_get_tmp_recvbuffer(size);
420   }
421 };
422
423 class WaitAction : public ReplayAction<ActionArgParser> {
424 public:
425   WaitAction() : ReplayAction("Wait") {}
426   void kernel(simgrid::xbt::ReplayAction& action) override
427   {
428     std::string s = boost::algorithm::join(action, " ");
429     xbt_assert(get_reqq_self()->size(), "action wait not preceded by any irecv or isend: %s", s.c_str());
430     MPI_Request request = get_reqq_self()->back();
431     get_reqq_self()->pop_back();
432
433     if (request == nullptr) {
434       /* Assume that the trace is well formed, meaning the comm might have been caught by a MPI_test. Then just
435        * return.*/
436       return;
437     }
438
439     int rank = request->comm() != MPI_COMM_NULL ? request->comm()->rank() : -1;
440
441     // Must be taken before Request::wait() since the request may be set to
442     // MPI_REQUEST_NULL by Request::wait!
443     int src                  = request->comm()->group()->rank(request->src());
444     int dst                  = request->comm()->group()->rank(request->dst());
445     bool is_wait_for_receive = (request->flags() & RECV);
446     // TODO: Here we take the rank while we normally take the process id (look for my_proc_id)
447     TRACE_smpi_comm_in(rank, __FUNCTION__, new simgrid::instr::NoOpTIData("wait"));
448
449     MPI_Status status;
450     Request::wait(&request, &status);
451
452     TRACE_smpi_comm_out(rank);
453     if (is_wait_for_receive)
454       TRACE_smpi_recv(src, dst, 0);
455   }
456 };
457
458 class SendAction : public ReplayAction<SendRecvParser> {
459 public:
460   SendAction() = delete;
461   SendAction(std::string name) : ReplayAction(name) {}
462   void kernel(simgrid::xbt::ReplayAction& action) override
463   {
464     int dst_traced = MPI_COMM_WORLD->group()->actor(args.partner)->getPid();
465
466     TRACE_smpi_comm_in(my_proc_id, __FUNCTION__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
467                                                                                  Datatype::encode(args.datatype1)));
468     if (not TRACE_smpi_view_internals())
469       TRACE_smpi_send(my_proc_id, my_proc_id, dst_traced, 0, args.size * args.datatype1->size());
470
471     if (name == "send") {
472       Request::send(nullptr, args.size, args.datatype1, args.partner, 0, MPI_COMM_WORLD);
473     } else if (name == "Isend") {
474       MPI_Request request = Request::isend(nullptr, args.size, args.datatype1, args.partner, 0, MPI_COMM_WORLD);
475       get_reqq_self()->push_back(request);
476     } else {
477       xbt_die("Don't know this action, %s", name.c_str());
478     }
479
480     TRACE_smpi_comm_out(my_proc_id);
481   }
482 };
483
484 class RecvAction : public ReplayAction<SendRecvParser> {
485 public:
486   RecvAction() = delete;
487   explicit RecvAction(std::string name) : ReplayAction(name) {}
488   void kernel(simgrid::xbt::ReplayAction& action) override
489   {
490     int src_traced = MPI_COMM_WORLD->group()->actor(args.partner)->getPid();
491
492     TRACE_smpi_comm_in(my_proc_id, __FUNCTION__, new simgrid::instr::Pt2PtTIData(name, args.partner, args.size,
493                                                                                  Datatype::encode(args.datatype1)));
494
495     MPI_Status status;
496     // unknown size from the receiver point of view
497     if (args.size <= 0.0) {
498       Request::probe(args.partner, 0, MPI_COMM_WORLD, &status);
499       args.size = status.count;
500     }
501
502     if (name == "recv") {
503       Request::recv(nullptr, args.size, args.datatype1, args.partner, 0, MPI_COMM_WORLD, &status);
504     } else if (name == "Irecv") {
505       MPI_Request request = Request::irecv(nullptr, args.size, args.datatype1, args.partner, 0, MPI_COMM_WORLD);
506       get_reqq_self()->push_back(request);
507     }
508
509     TRACE_smpi_comm_out(my_proc_id);
510     // TODO: Check why this was only activated in the "recv" case and not in the "Irecv" case
511     if (name == "recv" && not TRACE_smpi_view_internals()) {
512       TRACE_smpi_recv(src_traced, my_proc_id, 0);
513     }
514   }
515 };
516
517 class ComputeAction : public ReplayAction<ComputeParser> {
518 public:
519   ComputeAction() : ReplayAction("compute") {}
520   void kernel(simgrid::xbt::ReplayAction& action) override
521   {
522     TRACE_smpi_computing_in(my_proc_id, args.flops);
523     smpi_execute_flops(args.flops);
524     TRACE_smpi_computing_out(my_proc_id);
525   }
526 };
527
528 class TestAction : public ReplayAction<ActionArgParser> {
529 public:
530   TestAction() : ReplayAction("Test") {}
531   void kernel(simgrid::xbt::ReplayAction& action) override
532   {
533     MPI_Request request = get_reqq_self()->back();
534     get_reqq_self()->pop_back();
535     // if request is null here, this may mean that a previous test has succeeded
536     // Different times in traced application and replayed version may lead to this
537     // In this case, ignore the extra calls.
538     if (request != nullptr) {
539       TRACE_smpi_testing_in(my_proc_id);
540
541       MPI_Status status;
542       int flag = Request::test(&request, &status);
543
544       XBT_DEBUG("MPI_Test result: %d", flag);
545       /* push back request in vector to be caught by a subsequent wait. if the test did succeed, the request is now
546        * nullptr.*/
547       get_reqq_self()->push_back(request);
548
549       TRACE_smpi_testing_out(my_proc_id);
550     }
551   }
552 };
553
554 class InitAction : public ReplayAction<ActionArgParser> {
555 public:
556   InitAction() : ReplayAction("Init") {}
557   void kernel(simgrid::xbt::ReplayAction& action) override
558   {
559     CHECK_ACTION_PARAMS(action, 0, 1)
560     MPI_DEFAULT_TYPE = (action.size() > 2) ? MPI_DOUBLE // default MPE datatype
561                                            : MPI_BYTE;  // default TAU datatype
562
563     /* start a simulated timer */
564     smpi_process()->simulated_start();
565     /*initialize the number of active processes */
566     active_processes = smpi_process_count();
567
568     set_reqq_self(new std::vector<MPI_Request>);
569   }
570 };
571
572 class CommunicatorAction : public ReplayAction<ActionArgParser> {
573 public:
574   CommunicatorAction() : ReplayAction("Comm") {}
575   void kernel(simgrid::xbt::ReplayAction& action) override { /* nothing to do */}
576 };
577
578 class WaitAllAction : public ReplayAction<ActionArgParser> {
579 public:
580   WaitAllAction() : ReplayAction("waitAll") {}
581   void kernel(simgrid::xbt::ReplayAction& action) override
582   {
583     const unsigned int count_requests = get_reqq_self()->size();
584
585     if (count_requests > 0) {
586       TRACE_smpi_comm_in(my_proc_id, __FUNCTION__,
587                          new simgrid::instr::Pt2PtTIData("waitAll", -1, count_requests, ""));
588       std::vector<std::pair</*sender*/int,/*recv*/int>> sender_receiver;
589       for (const auto& req : (*get_reqq_self())) {
590         if (req && (req->flags() & RECV)) {
591           sender_receiver.push_back({req->src(), req->dst()});
592         }
593       }
594       MPI_Status status[count_requests];
595       Request::waitall(count_requests, &(*get_reqq_self())[0], status);
596
597       for (auto& pair : sender_receiver) {
598         TRACE_smpi_recv(pair.first, pair.second, 0);
599       }
600       TRACE_smpi_comm_out(my_proc_id);
601     }
602   }
603 };
604
605 class BarrierAction : public ReplayAction<ActionArgParser> {
606 public:
607   BarrierAction() : ReplayAction("barrier") {}
608   void kernel(simgrid::xbt::ReplayAction& action) override
609   {
610     TRACE_smpi_comm_in(my_proc_id, __FUNCTION__, new simgrid::instr::NoOpTIData("barrier"));
611     Colls::barrier(MPI_COMM_WORLD);
612     TRACE_smpi_comm_out(my_proc_id);
613   }
614 };
615
616 class BcastAction : public ReplayAction<BcastArgParser> {
617 public:
618   BcastAction() : ReplayAction("bcast") {}
619   void kernel(simgrid::xbt::ReplayAction& action) override
620   {
621     TRACE_smpi_comm_in(my_proc_id, "action_bcast",
622                        new simgrid::instr::CollTIData("bcast", MPI_COMM_WORLD->group()->actor(args.root)->getPid(),
623                                                       -1.0, args.size, -1, Datatype::encode(args.datatype1), ""));
624
625     Colls::bcast(send_buffer(args.size * args.datatype1->size()), args.size, args.datatype1, args.root, MPI_COMM_WORLD);
626
627     TRACE_smpi_comm_out(my_proc_id);
628   }
629 };
630
631 class ReduceAction : public ReplayAction<ReduceArgParser> {
632 public:
633   ReduceAction() : ReplayAction("reduce") {}
634   void kernel(simgrid::xbt::ReplayAction& action) override
635   {
636     TRACE_smpi_comm_in(my_proc_id, "action_reduce",
637                        new simgrid::instr::CollTIData("reduce", MPI_COMM_WORLD->group()->actor(args.root)->getPid(), args.comp_size,
638                                                       args.comm_size, -1, Datatype::encode(args.datatype1), ""));
639
640     Colls::reduce(send_buffer(args.comm_size * args.datatype1->size()),
641         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, args.root, MPI_COMM_WORLD);
642     smpi_execute_flops(args.comp_size);
643
644     TRACE_smpi_comm_out(my_proc_id);
645   }
646 };
647
648 class AllReduceAction : public ReplayAction<AllReduceArgParser> {
649 public:
650   AllReduceAction() : ReplayAction("allReduce") {}
651   void kernel(simgrid::xbt::ReplayAction& action) override
652   {
653     TRACE_smpi_comm_in(my_proc_id, "action_allReduce", new simgrid::instr::CollTIData("allReduce", -1, args.comp_size, args.comm_size, -1,
654                                                                                 Datatype::encode(args.datatype1), ""));
655
656     Colls::allreduce(send_buffer(args.comm_size * args.datatype1->size()),
657         recv_buffer(args.comm_size * args.datatype1->size()), args.comm_size, args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
658     smpi_execute_flops(args.comp_size);
659
660     TRACE_smpi_comm_out(my_proc_id);
661   }
662 };
663
664 class AllToAllAction : public ReplayAction<AllToAllArgParser> {
665 public:
666   AllToAllAction() : ReplayAction("allToAll") {}
667   void kernel(simgrid::xbt::ReplayAction& action) override
668   {
669     TRACE_smpi_comm_in(my_proc_id, "action_allToAll",
670                      new simgrid::instr::CollTIData("allToAll", -1, -1.0, args.send_size, args.recv_size,
671                                                     Datatype::encode(args.datatype1),
672                                                     Datatype::encode(args.datatype2)));
673
674     Colls::alltoall(send_buffer(args.send_size*args.comm_size* args.datatype1->size()), 
675       args.send_size, args.datatype1, recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()),
676       args.recv_size, args.datatype2, MPI_COMM_WORLD);
677
678     TRACE_smpi_comm_out(my_proc_id);
679   }
680 };
681
682 class GatherAction : public ReplayAction<GatherArgParser> {
683 public:
684   GatherAction(std::string name) : ReplayAction(name) {}
685   void kernel(simgrid::xbt::ReplayAction& action) override
686   {
687     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::CollTIData(name, (name == "gather") ? args.root : -1, -1.0, args.send_size, args.recv_size,
688                                                                           Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
689
690     if (name == "gather") {
691       int rank = MPI_COMM_WORLD->rank();
692       Colls::gather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
693                  (rank == args.root) ? recv_buffer(args.recv_size * args.comm_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
694     }
695     else
696       Colls::allgather(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
697                        recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, MPI_COMM_WORLD);
698
699     TRACE_smpi_comm_out(my_proc_id);
700   }
701 };
702
703 class GatherVAction : public ReplayAction<GatherVArgParser> {
704 public:
705   GatherVAction(std::string name) : ReplayAction(name) {}
706   void kernel(simgrid::xbt::ReplayAction& action) override
707   {
708     int rank = MPI_COMM_WORLD->rank();
709
710     TRACE_smpi_comm_in(my_proc_id, name.c_str(), new simgrid::instr::VarCollTIData(
711                                                name, (name == "gatherV") ? args.root : -1, args.send_size, nullptr, -1, args.recvcounts,
712                                                Datatype::encode(args.datatype1), Datatype::encode(args.datatype2)));
713
714     if (name == "gatherV") {
715       Colls::gatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1, 
716                      (rank == args.root) ? recv_buffer(args.recv_size_sum  * args.datatype2->size()) : nullptr, args.recvcounts->data(), args.disps.data(), args.datatype2, args.root,
717                      MPI_COMM_WORLD);
718     }
719     else {
720       Colls::allgatherv(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1, 
721                         recv_buffer(args.recv_size_sum * args.datatype2->size()), args.recvcounts->data(), args.disps.data(), args.datatype2,
722                     MPI_COMM_WORLD);
723     }
724
725     TRACE_smpi_comm_out(my_proc_id);
726   }
727 };
728
729 class ScatterAction : public ReplayAction<ScatterArgParser> {
730 public:
731   ScatterAction() : ReplayAction("scatter") {}
732   void kernel(simgrid::xbt::ReplayAction& action) override
733   {
734     int rank = MPI_COMM_WORLD->rank();
735     TRACE_smpi_comm_in(my_proc_id, "action_scatter", new simgrid::instr::CollTIData(name, args.root, -1.0, args.send_size, args.recv_size,
736                                                                           Datatype::encode(args.datatype1),
737                                                                           Datatype::encode(args.datatype2)));
738
739     Colls::scatter(send_buffer(args.send_size * args.datatype1->size()), args.send_size, args.datatype1,
740                   (rank == args.root) ? recv_buffer(args.recv_size * args.datatype2->size()) : nullptr, args.recv_size, args.datatype2, args.root, MPI_COMM_WORLD);
741
742     TRACE_smpi_comm_out(my_proc_id);
743   }
744 };
745
746
747 class ScatterVAction : public ReplayAction<ScatterVArgParser> {
748 public:
749   ScatterVAction() : ReplayAction("scatterV") {}
750   void kernel(simgrid::xbt::ReplayAction& action) override
751   {
752     int rank = MPI_COMM_WORLD->rank();
753     TRACE_smpi_comm_in(my_proc_id, "action_scatterv", new simgrid::instr::VarCollTIData(name, args.root, -1, args.sendcounts, args.recv_size,
754           nullptr, Datatype::encode(args.datatype1),
755           Datatype::encode(args.datatype2)));
756
757     Colls::scatterv((rank == args.root) ? send_buffer(args.send_size_sum * args.datatype1->size()) : nullptr, args.sendcounts->data(), args.disps.data(), 
758         args.datatype1, recv_buffer(args.recv_size * args.datatype2->size()), args.recv_size, args.datatype2, args.root,
759         MPI_COMM_WORLD);
760
761     TRACE_smpi_comm_out(my_proc_id);
762   }
763 };
764
765 class ReduceScatterAction : public ReplayAction<ReduceScatterArgParser> {
766 public:
767   ReduceScatterAction() : ReplayAction("reduceScatter") {}
768   void kernel(simgrid::xbt::ReplayAction& action) override
769   {
770     TRACE_smpi_comm_in(my_proc_id, "action_reducescatter",
771                        new simgrid::instr::VarCollTIData("reduceScatter", -1, 0, nullptr, -1, args.recvcounts,
772                                                          std::to_string(args.comp_size), /* ugly hack to print comp_size */
773                                                          Datatype::encode(args.datatype1)));
774
775     Colls::reduce_scatter(send_buffer(args.recv_size_sum * args.datatype1->size()), recv_buffer(args.recv_size_sum * args.datatype1->size()), 
776                           args.recvcounts->data(), args.datatype1, MPI_OP_NULL, MPI_COMM_WORLD);
777
778     smpi_execute_flops(args.comp_size);
779     TRACE_smpi_comm_out(my_proc_id);
780   }
781 };
782
783 class AllToAllVAction : public ReplayAction<AllToAllVArgParser> {
784 public:
785   AllToAllVAction() : ReplayAction("allToAllV") {}
786   void kernel(simgrid::xbt::ReplayAction& action) override
787   {
788     TRACE_smpi_comm_in(my_proc_id, __FUNCTION__,
789         new simgrid::instr::VarCollTIData("allToAllV", -1, args.send_size_sum, args.sendcounts, args.recv_size_sum, args.recvcounts,
790           Datatype::encode(args.datatype1),
791           Datatype::encode(args.datatype2)));
792
793     Colls::alltoallv(send_buffer(args.send_buf_size * args.datatype1->size()), args.sendcounts->data(), args.senddisps.data(), args.datatype1,
794                      recv_buffer(args.recv_buf_size * args.datatype2->size()), args.recvcounts->data(), args.recvdisps.data(), args.datatype2, MPI_COMM_WORLD);
795
796     TRACE_smpi_comm_out(my_proc_id);
797   }
798 };
799 } // Replay Namespace
800 }} // namespace simgrid::smpi
801
802 /** @brief Only initialize the replay, don't do it for real */
803 void smpi_replay_init(int* argc, char*** argv)
804 {
805   simgrid::smpi::Process::init(argc, argv);
806   smpi_process()->mark_as_initialized();
807   smpi_process()->set_replaying(true);
808
809   int my_proc_id = Actor::self()->getPid();
810   TRACE_smpi_init(my_proc_id);
811   TRACE_smpi_computing_init(my_proc_id);
812   TRACE_smpi_comm_in(my_proc_id, "smpi_replay_run_init", new simgrid::instr::NoOpTIData("init"));
813   TRACE_smpi_comm_out(my_proc_id);
814   xbt_replay_action_register("init", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::InitAction().execute(action); });
815   xbt_replay_action_register("finalize", [](simgrid::xbt::ReplayAction& action) { /* nothing to do */ });
816   xbt_replay_action_register("comm_size", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::CommunicatorAction().execute(action); });
817   xbt_replay_action_register("comm_split",[](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::CommunicatorAction().execute(action); });
818   xbt_replay_action_register("comm_dup",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::CommunicatorAction().execute(action); });
819
820   xbt_replay_action_register("send",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::SendAction("send").execute(action); });
821   xbt_replay_action_register("Isend", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::SendAction("Isend").execute(action); });
822   xbt_replay_action_register("recv",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::RecvAction("recv").execute(action); });
823   xbt_replay_action_register("Irecv", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::RecvAction("Irecv").execute(action); });
824   xbt_replay_action_register("test",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::TestAction().execute(action); });
825   xbt_replay_action_register("wait",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::WaitAction().execute(action); });
826   xbt_replay_action_register("waitAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::WaitAllAction().execute(action); });
827   xbt_replay_action_register("barrier", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::BarrierAction().execute(action); });
828   xbt_replay_action_register("bcast",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::BcastAction().execute(action); });
829   xbt_replay_action_register("reduce",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::ReduceAction().execute(action); });
830   xbt_replay_action_register("allReduce", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::AllReduceAction().execute(action); });
831   xbt_replay_action_register("allToAll", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::AllToAllAction().execute(action); });
832   xbt_replay_action_register("allToAllV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::AllToAllVAction().execute(action); });
833   xbt_replay_action_register("gather",   [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::GatherAction("gather").execute(action); });
834   xbt_replay_action_register("scatter",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::ScatterAction().execute(action); });
835   xbt_replay_action_register("gatherV",  [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::GatherVAction("gatherV").execute(action); });
836   xbt_replay_action_register("scatterV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::ScatterVAction().execute(action); });
837   xbt_replay_action_register("allGather", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::GatherAction("allGather").execute(action); });
838   xbt_replay_action_register("allGatherV", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::GatherVAction("allGatherV").execute(action); });
839   xbt_replay_action_register("reduceScatter", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::ReduceScatterAction().execute(action); });
840   xbt_replay_action_register("compute", [](simgrid::xbt::ReplayAction& action) { simgrid::smpi::Replay::ComputeAction().execute(action); });
841
842   //if we have a delayed start, sleep here.
843   if(*argc>2){
844     double value = xbt_str_parse_double((*argv)[2], "%s is not a double");
845     XBT_VERB("Delayed start for instance - Sleeping for %f flops ",value );
846     smpi_execute_flops(value);
847   } else {
848     //UGLY: force a context switch to be sure that all MSG_processes begin initialization
849     XBT_DEBUG("Force context switch by smpi_execute_flops  - Sleeping for 0.0 flops ");
850     smpi_execute_flops(0.0);
851   }
852 }
853
854 /** @brief actually run the replay after initialization */
855 void smpi_replay_main(int* argc, char*** argv)
856 {
857   simgrid::xbt::replay_runner(*argc, *argv);
858
859   /* and now, finalize everything */
860   /* One active process will stop. Decrease the counter*/
861   XBT_DEBUG("There are %zu elements in reqq[*]", get_reqq_self()->size());
862   if (not get_reqq_self()->empty()) {
863     unsigned int count_requests=get_reqq_self()->size();
864     MPI_Request requests[count_requests];
865     MPI_Status status[count_requests];
866     unsigned int i=0;
867
868     for (auto const& req : *get_reqq_self()) {
869       requests[i] = req;
870       i++;
871     }
872     simgrid::smpi::Request::waitall(count_requests, requests, status);
873   }
874   delete get_reqq_self();
875   active_processes--;
876
877   if(active_processes==0){
878     /* Last process alive speaking: end the simulated timer */
879     XBT_INFO("Simulation time %f", smpi_process()->simulated_elapsed());
880     smpi_free_replay_tmp_buffers();
881   }
882
883   TRACE_smpi_comm_in(Actor::self()->getPid(), "smpi_replay_run_finalize", new simgrid::instr::NoOpTIData("finalize"));
884
885   smpi_process()->finalize();
886
887   TRACE_smpi_comm_out(Actor::self()->getPid());
888   TRACE_smpi_finalize(Actor::self()->getPid());
889 }
890
891 /** @brief chain a replay initialization and a replay start */
892 void smpi_replay_run(int* argc, char*** argv)
893 {
894   smpi_replay_init(argc, argv);
895   smpi_replay_main(argc, argv);
896 }