Logo AND Algorithmique Numérique Distribuée

Public GIT Repository
python: fix register_function when passed a class
author Martin Quinson <martin.quinson@ens-rennes.fr>
Mon, 31 Dec 2018 01:26:07 +0000 (02:26 +0100)
committer Martin Quinson <martin.quinson@ens-rennes.fr>
Mon, 31 Dec 2018 01:26:07 +0000 (02:26 +0100)
src/bindings/python/simgrid_python.cpp

index a9f0135..28300bc 100644 (file)
@@ -80,10 +80,25 @@ PYBIND11_MODULE(simgrid, m)
       .def("load_deployment", &Engine::load_deployment,
           "Load a deployment file and launch the actors that it contains, see :cpp:func:`simgrid::s4u::Engine::load_deployment()`")
       .def("run", &Engine::run, "Run the simulation")
-      .def("register_function", [](Engine*, std::string name, std::function<void(std::vector<std::string>)> f) {
+      .def("register_function", [](Engine*, std::string name, py::object obj) {
         simgrid::simix::register_function(name,
-            [f](std::vector<std::string> args) -> simgrid::simix::ActorCode {
-          return [args, f]() { f(args); };
+            [obj](std::vector<std::string> args) -> simgrid::simix::ActorCode {
+          return [obj, args]() {
+            /* Convert the std::vector into a py::tuple */
+            py::tuple params(args.size()-1);
+            for (size_t i=1; i<args.size(); i++)
+              params[i-1] = py::cast(args[i]);
+
+            PyObject *result = PyObject_CallObject(obj.ptr(), params.ptr());
+            if (!result)
+                throw pybind11::error_already_set();
+
+            /* If I was passed a class, I just built an instance, so I need to call it now */
+            if (PyCallable_Check(result)) {
+              py::object obj2 = pybind11::reinterpret_steal<py::object>(pybind11::handle(static_cast<PyObject*>(result)));
+              obj2();
+            }
+          };
         });
       }, "Registers the main function of an actor, see :cpp:func:`simgrid::s4u::Engine::register_function()`")
       ;