AND — Algorithmique Numérique Distribuée

Public GIT Repository
More doc for SMPI
[simgrid.git] / src / smpi / colls / smpi_coll.cpp
index 6a213c0..cf22e98 100644 (file)
@@ -1,6 +1,6 @@
 /* smpi_coll.c -- various optimized routing for collectives                 */
 
-/* Copyright (c) 2009-2020. The SimGrid Team. All rights reserved.          */
+/* Copyright (c) 2009-2023. The SimGrid Team. All rights reserved.          */
 
 /* This program is free software; you can redistribute it and/or modify it
  * under the terms of the license (GNU LGPL) which comes with this package. */
 #include "smpi_request.hpp"
 #include "xbt/config.hpp"
 
+#include <iomanip>
 #include <map>
+#include <numeric>
+#include <sstream>
 
 XBT_LOG_NEW_DEFAULT_SUBCATEGORY(smpi_coll, smpi, "Logging specific to SMPI collectives.");
 
-namespace simgrid {
-namespace smpi {
+namespace simgrid::smpi {
 
-std::map<std::string, std::vector<s_mpi_coll_description_t>> smpi_coll_descriptions(
-    {{std::string("gather"),
+std::map<std::string, std::vector<s_mpi_coll_description_t>, std::less<>> smpi_coll_descriptions(
+    {{"gather",
       {{"default", "gather default collective", (void*)gather__default},
        {"ompi", "gather ompi collective", (void*)gather__ompi},
        {"ompi_basic_linear", "gather ompi_basic_linear collective", (void*)gather__ompi_basic_linear},
@@ -103,6 +105,7 @@ std::map<std::string, std::vector<s_mpi_coll_description_t>> smpi_coll_descripti
        {"ompi_basic_recursivehalving", "reduce_scatter ompi_basic_recursivehalving collective",
         (void*)reduce_scatter__ompi_basic_recursivehalving},
        {"ompi_ring", "reduce_scatter ompi_ring collective", (void*)reduce_scatter__ompi_ring},
+       {"ompi_butterfly", "reduce_scatter ompi_butterfly collective", (void*)reduce_scatter__ompi_butterfly},
        {"mpich", "reduce_scatter mpich collective", (void*)reduce_scatter__mpich},
        {"mpich_pair", "reduce_scatter mpich_pair collective", (void*)reduce_scatter__mpich_pair},
        {"mpich_rdb", "reduce_scatter mpich_rdb collective", (void*)reduce_scatter__mpich_rdb},
@@ -115,6 +118,7 @@ std::map<std::string, std::vector<s_mpi_coll_description_t>> smpi_coll_descripti
       {{"default", "scatter default collective", (void*)scatter__default},
        {"ompi", "scatter ompi collective", (void*)scatter__ompi},
        {"ompi_basic_linear", "scatter ompi_basic_linear collective", (void*)scatter__ompi_basic_linear},
+       {"ompi_linear_nb", "scatter ompi_linear nonblocking collective", (void*)scatter__ompi_linear_nb},
        {"ompi_binomial", "scatter ompi_binomial collective", (void*)scatter__ompi_binomial},
        {"mpich", "scatter mpich collective", (void*)scatter__mpich},
        {"mvapich2", "scatter mvapich2 collective", (void*)scatter__mvapich2},
@@ -237,16 +241,37 @@ std::map<std::string, std::vector<s_mpi_coll_description_t>> smpi_coll_descripti
 std::vector<s_mpi_coll_description_t>* colls::get_smpi_coll_descriptions(const std::string& name)
 {
   auto iter = smpi_coll_descriptions.find(name);
-  if (iter == smpi_coll_descriptions.end())
-    xbt_die("No collective named %s. This is a bug.", name.c_str());
+  xbt_assert(iter != smpi_coll_descriptions.end(), "No collective named %s. This is a bug.", name.c_str());
   return &iter->second;
 }
 
+std::string colls::get_smpi_coll_help()
+{
+  size_t max_name_len =
+      std::accumulate(begin(smpi_coll_descriptions), end(smpi_coll_descriptions), 0, [](auto len, auto const& coll) {
+        return std::max(len, std::accumulate(begin(coll.second), end(coll.second), 0, [](auto len, auto const& descr) {
+                          return std::max<size_t>(len, descr.name.length());
+                        }));
+      });
+  std::ostringstream oss;
+  std::string title = "Available collective algorithms (select them with \"smpi/collective_name:algo_name\"):";
+  oss << title << '\n' << std::setfill('=') << std::setw(title.length() + 1) << '\n';
+  for (auto const& [coll, algos] : smpi_coll_descriptions) {
+    std::string line = "Collective: \"" + coll + "\"";
+    oss << line << '\n' << std::setfill('-') << std::setw(line.length() + 1) << '\n';
+    oss << std::setfill(' ') << std::left;
+    for (auto const& [name, descr, _] : algos)
+      oss << "  " << std::setw(max_name_len) << name << " " << descr << "\n";
+    oss << std::right << '\n';
+  }
+  oss << "Please see https://simgrid.org/doc/latest/app_smpi.html#available-algorithms for more information.\n";
+  return oss.str();
+}
+
 static s_mpi_coll_description_t* find_coll_description(const std::string& collective, const std::string& algo)
 {
   std::vector<s_mpi_coll_description_t>* table = colls::get_smpi_coll_descriptions(collective);
-  if (table->empty())
-    xbt_die("No registered algorithm for collective '%s'! This is a bug.", collective.c_str());
+  xbt_assert(not table->empty(), "No registered algorithm for collective '%s'! This is a bug.", collective.c_str());
 
   for (auto& desc : *table) {
     if (algo == desc.name) {
@@ -259,7 +284,8 @@ static s_mpi_coll_description_t* find_coll_description(const std::string& collec
   std::string name_list = table->at(0).name;
   for (unsigned long i = 1; i < table->size(); i++)
     name_list = name_list + ", " + table->at(i).name;
-  xbt_die("Collective '%s' has no algorithm '%s'! Valid algorithms: %s.", collective.c_str(), algo.c_str(), name_list.c_str());
+  xbt_die("Collective '%s' has no algorithm '%s'! Valid algorithms: %s. Please use --help-coll for details.",
+          collective.c_str(), algo.c_str(), name_list.c_str());
 }
 
 int (*colls::gather)(const void* send_buff, int send_count, MPI_Datatype send_type, void* recv_buff, int recv_count,
@@ -289,8 +315,7 @@ void (*colls::smpi_coll_cleanup_callback)();
   {                                                                                                                    \
     auto desc = find_coll_description(_XBT_STRINGIFY(cat), name);                                                      \
     cat       = reinterpret_cast<ret(*) args>(desc->coll);                                                             \
-    if (cat == nullptr)                                                                                                \
-      xbt_die("Collective " _XBT_STRINGIFY(cat) " set to nullptr!");                                                   \
+    xbt_assert(cat != nullptr, "Collective " _XBT_STRINGIFY(cat) " set to nullptr!");                                  \
   }
 COLL_APPLY(COLL_SETTER, COLL_GATHER_SIG, "")
 COLL_APPLY(COLL_SETTER,COLL_ALLGATHER_SIG,"")
@@ -382,7 +407,7 @@ int colls::scan(const void* sendbuf, void* recvbuf, int count, MPI_Datatype data
         break;
       }
       if(index < rank) {
-        // #Request is below rank: it's a irecv
+        // #Request is below rank: it's an irecv
         op->apply( tmpbufs[index], recvbuf, &count, datatype);
       }
     }
@@ -444,7 +469,7 @@ int colls::exscan(const void* sendbuf, void* recvbuf, int count, MPI_Datatype da
           Datatype::copy(tmpbufs[index], count, datatype, recvbuf, count, datatype);
           recvbuf_is_empty=0;
         } else
-          // #Request is below rank: it's a irecv
+          // #Request is below rank: it's an irecv
           op->apply( tmpbufs[index], recvbuf, &count, datatype);
       }
     }
@@ -483,5 +508,4 @@ int colls::alltoallw(const void* sendbuf, const int* sendcounts, const int* send
   return Request::wait(&request, MPI_STATUS_IGNORE);
 }
 
-}
-}
+} // namespace simgrid::smpi