From 726328c4fb63d3a99363c038f4b1e9e380db03db Mon Sep 17 00:00:00 2001
From: Ben Cumming <bcumming@cscs.ch>
Date: Mon, 24 Jun 2019 09:44:02 +0200
Subject: [PATCH] Python spikes (#788)

Support for recording spikes generated by a simulation in the Python wrapper
* Implement a `spike_recorder` that holds a shared pointer to a `std::vector` of spikes, and a callback for the `arb::simulation` spike recording API.
* Add `python/example/ring.py` that creates a ring network, then records and prints spikes.
* Some fixes to get the full `recipe` -> `domain_decomposition` -> `simulation` -> `spikes` workflow to work
  * always use default `global_parameters`: user customization of global parameters for cable cells can wait until the ion species interface is finished.
  * change the Python recipe interface for `recipe::connections_on` to use `pybind11::objects` because of shim.
* Some small improvements to error and help messages.

Fixes #764
---
 .ycm_extra_conf.py     |  2 +
 arbor/arbexcept.cpp    |  2 +-
 python/CMakeLists.txt  |  3 +-
 python/cells.cpp       |  3 +-
 python/example/ring.py | 65 +++++++++++++++++++++++++++++
 python/identifiers.cpp |  2 +-
 python/pyarb.cpp       |  2 +
 python/recipe.cpp      | 69 +++++++++----------------------
 python/recipe.hpp      | 50 ++++++++++-------------
 python/simulation.cpp  | 13 +++---
 python/spikes.cpp      | 92 ++++++++++++++++++++++++++++++++++++++++++
 11 files changed, 215 insertions(+), 88 deletions(-)
 create mode 100644 python/example/ring.py
 create mode 100644 python/spikes.cpp

diff --git a/.ycm_extra_conf.py b/.ycm_extra_conf.py
index 27093e79..8fb825ba 100644
--- a/.ycm_extra_conf.py
+++ b/.ycm_extra_conf.py
@@ -66,6 +66,8 @@ flags = [
     '/usr/include/python3.6m', # TODO: run a command to find this on "any" system
     '-I',
     'sup/include',
+    '-I',
+    '/usr/include/python3.7m',
 ]
 
 # Set this to the absolute path to the folder (NOT the file!) containing the
diff --git a/arbor/arbexcept.cpp b/arbor/arbexcept.cpp
index 45a8fb27..0a3a382c 100644
--- a/arbor/arbexcept.cpp
+++ b/arbor/arbexcept.cpp
@@ -11,7 +11,7 @@ namespace arb {
 using arb::util::pprintf;
 
 bad_cell_description::bad_cell_description(cell_kind kind, cell_gid_type gid):
-    arbor_exception(pprintf("bad description for cell kind {} on gid {}", kind, gid)),
+    arbor_exception(pprintf("recipe::get_cell_kind(gid={}) -> {} does not match the cell type provided by recipe::get_cell_description(gid={})", gid, kind, gid)),
     gid(gid),
     kind(kind)
 {}
diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt
index da987714..e03878ec 100644
--- a/python/CMakeLists.txt
+++ b/python/CMakeLists.txt
@@ -25,8 +25,9 @@ add_library(pyarb MODULE
     mpi.cpp
     pyarb.cpp
     recipe.cpp
-    simulation.cpp
     schedule.cpp
+    simulation.cpp
+    spikes.cpp
 )
 
 target_link_libraries(pyarb PRIVATE arbor pybind11::module)
diff --git a/python/cells.cpp b/python/cells.cpp
index 190b45b0..d8b43418 100644
--- a/python/cells.cpp
+++ b/python/cells.cpp
@@ -24,6 +24,7 @@ arb::util::unique_any convert_cell(pybind11::object o) {
     using pybind11::isinstance;
     using pybind11::cast;
 
+    pybind11::gil_scoped_acquire guard;
     if (isinstance<arb::spike_source_cell>(o)) {
         return arb::util::unique_any(cast<arb::spike_source_cell>(o));
     }
@@ -132,7 +133,7 @@ arb::cable_cell branch_cell(arb::cell_gid_type gid, const cell_parameters& param
     // Add a synapse to the mid point of the first dendrite.
     cell.add_synapse({1, 0.5}, "expsyn");
 
-    // Add additional synapses that will not be connected to anything.
+    // Add additional synapses.
     for (unsigned i=1u; i<params.synapses; ++i) {
         cell.add_synapse({1, 0.5}, "expsyn");
     }
diff --git a/python/example/ring.py b/python/example/ring.py
new file mode 100644
index 00000000..ef80aa63
--- /dev/null
+++ b/python/example/ring.py
@@ -0,0 +1,65 @@
+import sys
+import arbor
+
+class ring_recipe (arbor.recipe):
+
+    def __init__(self, n=4):
+        # The base C++ class constructor must be called first, to ensure that
+        # all memory in the C++ class is initialized correctly.
+        arbor.recipe.__init__(self)
+        self.ncells = n
+        self.params = arbor.cell_parameters()
+
+    # The num_cells method that returns the total number of cells in the model
+    # must be implemented.
+    def num_cells(self):
+        return self.ncells
+
+    # The cell_description method returns a cell
+    def cell_description(self, gid):
+        return arbor.branch_cell(gid, self.params)
+
+    def num_targets(self, gid):
+        return 1
+
+    def num_sources(self, gid):
+        return 1
+
+    # The kind method returns the type of cell with gid.
+    # Note: this must agree with the type returned by cell_description.
+    def cell_kind(self, gid):
+        return arbor.cell_kind.cable
+
+    # Make a ring network
+    def connections_on(self, gid):
+        src = (gid-1)%self.ncells
+        w = 0.01
+        d = 10
+        return [arbor.connection(arbor.cell_member(src,0), arbor.cell_member(gid,0), w, d)]
+
+    # Attach a generator to the first cell in the ring.
+    def event_generators(self, gid):
+        if gid==0:
+            sched = arbor.explicit_schedule([1])
+            return [arbor.event_generator(arbor.cell_member(0,0), 0.1, sched)]
+        return []
+
+
+context = arbor.context(threads=4, gpu_id=None)
+print(context)
+
+recipe = ring_recipe(100)
+print(recipe)
+
+decomp = arbor.partition_load_balance(recipe, context)
+print(decomp)
+
+sim = arbor.simulation(recipe, decomp, context)
+print(sim)
+
+recorder = arbor.attach_spike_recorder(sim)
+
+sim.run(1000)
+
+for s in recorder.spikes:
+    print(s)
diff --git a/python/identifiers.cpp b/python/identifiers.cpp
index 11895074..43b6b9c4 100644
--- a/python/identifiers.cpp
+++ b/python/identifiers.cpp
@@ -48,7 +48,7 @@ void register_identifiers(pybind11::module& m) {
             "Proxy cell that generates spikes from a spike sequence provided by the user.");
 
     pybind11::enum_<arb::backend_kind>(m, "backend",
-        "Enumeration used to indicate which hardware backend to use for running a cell_group.")
+        "Enumeration used to indicate which hardware backend to execute a cell group on.")
         .value("gpu", arb::backend_kind::gpu,
             "Use GPU backend.")
         .value("multicore", arb::backend_kind::multicore,
diff --git a/python/pyarb.cpp b/python/pyarb.cpp
index 143317d2..c369c2ab 100644
--- a/python/pyarb.cpp
+++ b/python/pyarb.cpp
@@ -15,6 +15,7 @@ void register_identifiers(pybind11::module& m);
 void register_recipe(pybind11::module& m);
 void register_schedules(pybind11::module& m);
 void register_simulation(pybind11::module& m);
+void register_spike_handling(pybind11::module& m);
 
 #ifdef ARB_MPI_ENABLED
 void register_mpi(pybind11::module& m);
@@ -37,4 +38,5 @@ PYBIND11_MODULE(arbor, m) {
     pyarb::register_recipe(m);
     pyarb::register_schedules(m);
     pyarb::register_simulation(m);
+    pyarb::register_spike_handling(m);
 }
diff --git a/python/recipe.cpp b/python/recipe.cpp
index fc1f4569..49b03ec3 100644
--- a/python/recipe.cpp
+++ b/python/recipe.cpp
@@ -22,43 +22,17 @@ namespace pyarb {
 // The py::recipe::cell_decription returns a pybind11::object, that is
 // unwrapped and copied into a arb::util::unique_any.
 arb::util::unique_any py_recipe_shim::get_cell_description(arb::cell_gid_type gid) const {
-    // Aquire the GIL because it must be held when calling isinstance and cast.
-    auto guard = pybind11::gil_scoped_acquire();
-
-    // Get the python object pyarb::cell_description from the python front end
+    pybind11::gil_scoped_acquire guard;
     return convert_cell(impl_->cell_description(gid));
 }
 
-// The py::recipe::global_properties returns a pybind11::object, that is
-// unwrapped and copied into a arb::util::any.
-arb::util::any py_recipe_shim::get_global_properties(arb::cell_kind kind) const {
-    using pybind11::cast;
-
-    // Aquire the GIL because it must be held when calling cast.
-    auto guard = pybind11::gil_scoped_acquire();
-
-    // Get the python object pyarb::global_properties from the python front end
-    pybind11::object o = impl_->global_properties(kind);
-
-    if (kind == arb::cell_kind::cable) {
-        return arb::util::any(cast<arb::cable_cell_global_properties>(o));
-    }
-
-    else return arb::util::any{};
-
-    throw pyarb_error( "recipe.global_properties returned \""
-                       + std::string(pybind11::str(o))
-                       + "\" which does not describe a known Arbor global property description");
-
-}
-
 std::vector<arb::event_generator> py_recipe_shim::event_generators(arb::cell_gid_type gid) const {
     using namespace std::string_literals;
     using pybind11::isinstance;
     using pybind11::cast;
 
     // Aquire the GIL because it must be held when calling isinstance and cast.
-    auto guard = pybind11::gil_scoped_acquire();
+    pybind11::gil_scoped_acquire guard;
 
     // Get the python list of pyarb::event_generator_shim from the python front end.
     auto pygens = impl_->event_generators(gid);
@@ -69,10 +43,9 @@ std::vector<arb::event_generator> py_recipe_shim::event_generators(arb::cell_gid
     for (auto& g: pygens) {
         // check that a valid Python event_generator was passed.
         if (!isinstance<pyarb::event_generator_shim>(g)) {
-            std::stringstream s;
-            s << "recipe supplied an invalid event generator for gid "
-            << gid << ": " << pybind11::str(g);
-            throw pyarb_error(s.str());
+            throw pyarb_error(
+                util::pprintf(
+                    "recipe supplied an invalid event generator for gid {}: {}", gid, pybind11::str(g)));
         }
         // get a reference to the python event_generator
         auto& p = cast<const pyarb::event_generator_shim&>(g);
@@ -128,24 +101,24 @@ void register_recipe(pybind11::module& m) {
     using namespace pybind11::literals;
 
     // Connections
-    pybind11::class_<cell_connection_shim> cell_connection(m, "connection",
+    pybind11::class_<arb::cell_connection> cell_connection(m, "connection",
         "Describes a connection between two cells:\n"
         "  Defined by source and destination end points (that is pre-synaptic and post-synaptic respectively), a connection weight and a delay time.");
     cell_connection
         .def(pybind11::init<arb::cell_member_type, arb::cell_member_type, float, arb::time_type>(),
-            "source"_a = arb::cell_member_type{0,0}, "dest"_a = arb::cell_member_type{0,0}, "weight"_a = 0.f, "delay"_a,
+            "source"_a, "dest"_a, "weight"_a, "delay"_a,
             "Construct a connection with arguments:\n"
-            "  source:      The source end point of the connection (default (0,0)).\n"
-            "  dest:        The destination end point of the connection (default (0,0)).\n"
-            "  weight:      The weight delivered to the target synapse (dimensionless with interpretation specific to synapse type of target, default 0.).\n"
+            "  source:      The source end point of the connection.\n"
+            "  dest:        The destination end point of the connection.\n"
+            "  weight:      The weight delivered to the target synapse (unit: defined by the type of synapse target).\n"
             "  delay:       The delay of the connection (unit: ms).")
-        .def_readwrite("source", &cell_connection_shim::source,
+        .def_readwrite("source", &arb::cell_connection::source,
             "The source of the connection.")
-        .def_readwrite("dest", &cell_connection_shim::destination,
+        .def_readwrite("dest", &arb::cell_connection::dest,
             "The destination of the connection.")
-        .def_readwrite("weight", &cell_connection_shim::weight,
+        .def_readwrite("weight", &arb::cell_connection::weight,
             "The weight of the connection.")
-        .def_property("delay", &cell_connection_shim::get_delay, &cell_connection_shim::set_delay,
+        .def_readwrite("delay", &arb::cell_connection::delay,
             "The delay time of the connection (unit: ms).")
         .def("__str__",  &con_to_string)
         .def("__repr__", &con_to_string);
@@ -155,11 +128,11 @@ void register_recipe(pybind11::module& m) {
         "Describes a gap junction between two gap junction sites.");
     gap_junction_connection
         .def(pybind11::init<arb::cell_member_type, arb::cell_member_type, double>(),
-            "local"_a = arb::cell_member_type{0,0}, "peer"_a = arb::cell_member_type{0,0}, "ggap"_a = 0.f,
+            "local"_a, "peer"_a, "ggap"_a,
             "Construct a gap junction connection with arguments:\n"
-            "  local: One half of the gap junction connection (default (0,0)).\n"
-            "  peer:  Other half of the gap junction connection (default (0,0)).\n"
-            "  ggap:  Gap junction conductance (unit: μS, default 0.).")
+            "  local: One half of the gap junction connection.\n"
+            "  peer:  Other half of the gap junction connection.\n"
+            "  ggap:  Gap junction conductance (unit: μS).")
         .def_readwrite("local", &arb::gap_junction_connection::local,
             "One half of the gap junction connection.")
         .def_readwrite("peer", &arb::gap_junction_connection::peer,
@@ -204,10 +177,8 @@ void register_recipe(pybind11::module& m) {
             "gid"_a,
             "A list of the gap junctions connected to gid (default []).")
         // TODO: py_recipe::get_probe
-        .def("global_properties", &py_recipe::global_properties, pybind11::return_value_policy::copy,
-            "cell_kind"_a,
-            "Global property type specific to a given cell kind.")
-        .def("__str__", [](const py_recipe&){return "<arbor.recipe>";})
+        // TODO: py_recipe::global_properties
+        .def("__str__",  [](const py_recipe&){return "<arbor.recipe>";})
         .def("__repr__", [](const py_recipe&){return "<arbor.recipe>";});
 }
 } // namespace pyarb
diff --git a/python/recipe.hpp b/python/recipe.hpp
index 6057fb6d..f8ccaef3 100644
--- a/python/recipe.hpp
+++ b/python/recipe.hpp
@@ -26,30 +26,31 @@ public:
     virtual ~py_recipe() {}
 
     virtual arb::cell_size_type num_cells() const = 0;
-
     virtual pybind11::object cell_description(arb::cell_gid_type gid) const = 0;
     virtual arb::cell_kind cell_kind(arb::cell_gid_type gid) const = 0;
 
-    virtual arb::cell_size_type num_sources(arb::cell_gid_type) const { return 0; }
-    virtual arb::cell_size_type num_targets(arb::cell_gid_type) const { return 0; }
-
-    //TODO: virtual arb::cell_size_type num_probes(arb::cell_gid_type) const { return 0; }
-
+    virtual arb::cell_size_type num_sources(arb::cell_gid_type) const {
+        return 0;
+    }
+    virtual arb::cell_size_type num_targets(arb::cell_gid_type) const {
+        return 0;
+    }
     virtual arb::cell_size_type num_gap_junction_sites(arb::cell_gid_type gid) const {
         return gap_junctions_on(gid).size();
     }
-
     virtual std::vector<pybind11::object> event_generators(arb::cell_gid_type gid) const {
-        auto guard = pybind11::gil_scoped_acquire();
+        return {};
+    }
+    virtual std::vector<arb::cell_connection> connections_on(arb::cell_gid_type gid) const {
+        return {};
+    }
+    virtual std::vector<arb::gap_junction_connection> gap_junctions_on(arb::cell_gid_type) const {
         return {};
     }
 
-    virtual std::vector<arb::cell_connection> connections_on(arb::cell_gid_type gid) const { return {}; }
-    virtual std::vector<arb::gap_junction_connection> gap_junctions_on(arb::cell_gid_type) const { return {}; }
-
+    //TODO: virtual arb::cell_size_type num_probes(arb::cell_gid_type) const { return 0; }
     //TODO: virtual pybind11::object get_probe (arb::cell_member_type id) const {...}
-
-    virtual pybind11::object global_properties(arb::cell_kind kind) const = 0;
+    //TODO: virtual pybind11::object global_properties(arb::cell_kind kind) const {return pybind11::none();};
 };
 
 class py_recipe_trampoline: public py_recipe {
@@ -74,8 +75,6 @@ public:
         PYBIND11_OVERLOAD(arb::cell_size_type, py_recipe, num_targets, gid);
     }
 
-    //TODO: arb::cell_size_type num_probes(arb::cell_gid_type)
-
     arb::cell_size_type num_gap_junction_sites(arb::cell_gid_type gid) const override {
         PYBIND11_OVERLOAD(arb::cell_size_type, py_recipe, num_gap_junction_sites, gid);
     }
@@ -92,11 +91,8 @@ public:
         PYBIND11_OVERLOAD(std::vector<arb::gap_junction_connection>, py_recipe, gap_junctions_on, gid);
     }
 
+    //TODO: arb::cell_size_type num_probes(arb::cell_gid_type)
     //TODO: pybind11::object get_probe(arb::cell_member_type id)
-
-    pybind11::object global_properties(arb::cell_kind kind) const override {
-        PYBIND11_OVERLOAD_PURE(pybind11::object, py_recipe, global_properties, kind);
-    }
 };
 
 // A recipe shim that holds a pyarb::recipe implementation.
@@ -134,10 +130,8 @@ public:
         return impl_->num_targets(gid);
     }
 
-/* //TODO: arb::cell_size_type num_probes(arb::cell_gid_type gid) const override {
-        return impl_->num_probes(gid);
-    }
-*/
+    //TODO: arb::cell_size_type num_probes(arb::cell_gid_type gid)
+
     arb::cell_size_type num_gap_junction_sites(arb::cell_gid_type gid) const override {
         return impl_->num_gap_junction_sites(gid);
     }
@@ -152,12 +146,12 @@ public:
         return impl_->gap_junctions_on(gid);
     }
 
-    //TODO: arb::probe_info get_probe(arb::cell_member_type id) const override
-
-    // The pyarb::recipe::global_properties returns a pybind11::object, that is
-    // unwrapped and copied into a util::any.
-    arb::util::any get_global_properties(arb::cell_kind kind) const override;
+    //TODO: arb::probe_info get_probe(arb::cell_member_type id)
 
+    // TODO: wrap
+    arb::util::any get_global_properties(arb::cell_kind kind) const override {
+        return arb::util::any{};
+    }
 };
 
 } // namespace pyarb
diff --git a/python/simulation.cpp b/python/simulation.cpp
index 07af633b..6ed19f8c 100644
--- a/python/simulation.cpp
+++ b/python/simulation.cpp
@@ -15,14 +15,13 @@ void register_simulation(pybind11::module& m) {
         "The executable form of a model.\n"
         "A simulation is constructed from a recipe, and then used to update and monitor model state.");
     simulation
-        // A custom constructor that wraps a python recipe with
-        // arb::py_recipe_shim before forwarding it to the arb::recipe constructor.
+        // A custom constructor that wraps a python recipe with arb::py_recipe_shim
+        // before forwarding it to the arb::recipe constructor.
         .def(pybind11::init(
             [](std::shared_ptr<py_recipe>& rec, const arb::domain_decomposition& decomp, const context_shim& ctx) {
                 return new arb::simulation(py_recipe_shim(rec), decomp, ctx.context);
             }),
-            // Release the python gil, so that callbacks into the python
-            // recipe don't deadlock.
+            // Release the python gil, so that callbacks into the python recipe don't deadlock.
             pybind11::call_guard<pybind11::gil_scoped_release>(),
             "Initialize the model described by a recipe, with cells and network distributed\n"
             "according to the domain decomposition and computational resources described by a context.",
@@ -32,10 +31,10 @@ void register_simulation(pybind11::module& m) {
             "Reset the state of the simulation to its initial state.")
         .def("run", &arb::simulation::run,
             pybind11::call_guard<pybind11::gil_scoped_release>(),
-            "Run the simulation from current simulation time to tfinal, with maximum time step size dt.",
-            "tfinal"_a, "dt"_a)
+            "Run the simulation from current simulation time to tfinal (unit: ms), with maximum time step size dt (unit: ms).",
+            "tfinal"_a, "dt"_a=0.025)
         .def("set_binning_policy", &arb::simulation::set_binning_policy,
-            "Set event binning policy on all our groups.",
+            "Set the binning policy for event delivery, and the binning time interval if applicable (unit: ms).",
             "policy"_a, "bin_interval"_a)
         .def("__str__",  [](const arb::simulation&){ return "<arbor.simulation>"; })
         .def("__repr__", [](const arb::simulation&){ return "<arbor.simulation>"; });
diff --git a/python/spikes.cpp b/python/spikes.cpp
new file mode 100644
index 00000000..e362bb50
--- /dev/null
+++ b/python/spikes.cpp
@@ -0,0 +1,92 @@
+#include <memory>
+#include <vector>
+
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include <arbor/spike.hpp>
+#include <arbor/simulation.hpp>
+
+#include "strprintf.hpp"
+
+namespace pyarb {
+
+// A functor that models arb::spike_export_function.
+// Holds a shared pointer to the spike_vec used to store the spikes, so that if
+// the spike_vec in spike_recorder is garbage collected in Python, stores will
+// not seg fault.
+struct spike_callback {
+    using spike_vec = std::vector<arb::spike>;
+
+    std::shared_ptr<spike_vec> spike_store;
+
+    spike_callback(const std::shared_ptr<spike_vec>& state):
+        spike_store(state)
+    {}
+
+    void operator() (const spike_vec& spikes) {
+        spike_store->insert(spike_store->end(), spikes.begin(), spikes.end());
+    };
+};
+
+// Helper type for recording spikes from a simulation.
+// This type is wrapped in Python, to expose spike_recorder::spike_store.
+struct spike_recorder {
+    using spike_vec = std::vector<arb::spike>;
+    std::shared_ptr<spike_vec> spike_store;
+
+    spike_callback callback() {
+        // initialize the spike_store
+        spike_store = std::make_shared<spike_vec>();
+
+        // The callback holds a copy of spike_store, i.e. the shared
+        // pointer is held by both the spike_recorder and the callback, so if
+        // the spike_recorder is destructed in the calling Python code, attempts
+        // to write to spike_store inside the callback will not seg fault.
+        return spike_callback(spike_store);
+    }
+
+    const spike_vec& spikes() const {
+        return *spike_store;
+    }
+};
+
+std::shared_ptr<spike_recorder> attach_spike_recorder(arb::simulation& sim) {
+    auto r = std::make_shared<spike_recorder>();
+    sim.set_global_spike_callback(r->callback());
+    return r;
+}
+
+std::string spike_str(const arb::spike& s) {
+    return util::pprintf(
+            "<arbor.spike: source ({},{}), time {}>",
+            s.source.gid, s.source.index, s.time);
+}
+
+void register_spike_handling(pybind11::module& m) {
+    using namespace pybind11::literals;
+
+    pybind11::class_<arb::spike> spike(m, "spike");
+    spike
+        .def(pybind11::init<>())
+        .def_readwrite("source", &arb::spike::source)
+        .def_readwrite("time", &arb::spike::time)
+        .def("__str__",  &spike_str)
+        .def("__repr__", &spike_str);
+
+    // Use shared_ptr for spike_recorder, so that all copies of a recorder will
+    // see the spikes from the simulation with which the recorder's callback has been
+    // registered.
+    pybind11::class_<spike_recorder, std::shared_ptr<spike_recorder>> sprec(m, "spike_recorder");
+    sprec
+        .def(pybind11::init<>())
+        .def_property_readonly("spikes", &spike_recorder::spikes);
+
+    m.def("attach_spike_recorder", &attach_spike_recorder,
+          "sim"_a,
+          "Attach a spike recorder to an arbor simulation.\n"
+          "The recorder that is returned will record all spikes generated after it has been\n"
+          "attached (spikes generated before attaching are not recorded).");
+}
+
+} // namespace pyarb
-- 
GitLab