Skip to content
Snippets Groups Projects
Commit 726328c4 authored by Benjamin Cumming's avatar Benjamin Cumming Committed by akuesters
Browse files

Python spikes (#788)

Support for recording spikes generated by a simulation in the Python wrapper
* Implement a `spike_recorder` that holds a shared pointer to a `std::vector` of spikes, and a callback for the `arb::simulation` spike recording API.
* Add `python/example/ring.py` that creates a ring network, then records and prints spikes.
* Some fixes to get the full `recipe` -> `domain_decomposition` -> `simulation` -> `spikes` workflow to work
  * always use default `global_parameters`: user customization of global parameters for cable cells can wait until the ion species interface is finished.
  * change the Python recipe interface for `recipe::connections_on` to use `pybind11::objects` because of shim.
* Some small improvements to error and help messages.

Fixes #764 
parent 4b07e613
Branches
Tags
No related merge requests found
......@@ -66,6 +66,8 @@ flags = [
'/usr/include/python3.6m', # TODO: run a command to find this on "any" system
'-I',
'sup/include',
'-I',
'/usr/include/python3.7m',
]
# Set this to the absolute path to the folder (NOT the file!) containing the
......
......@@ -11,7 +11,7 @@ namespace arb {
using arb::util::pprintf;
bad_cell_description::bad_cell_description(cell_kind kind, cell_gid_type gid):
arbor_exception(pprintf("bad description for cell kind {} on gid {}", kind, gid)),
arbor_exception(pprintf("recipe::get_cell_kind(gid={}) -> {} does not match the cell type provided by recipe::get_cell_description(gid={})", gid, kind, gid)),
gid(gid),
kind(kind)
{}
......
......@@ -25,8 +25,9 @@ add_library(pyarb MODULE
mpi.cpp
pyarb.cpp
recipe.cpp
simulation.cpp
schedule.cpp
simulation.cpp
spikes.cpp
)
target_link_libraries(pyarb PRIVATE arbor pybind11::module)
......
......@@ -24,6 +24,7 @@ arb::util::unique_any convert_cell(pybind11::object o) {
using pybind11::isinstance;
using pybind11::cast;
pybind11::gil_scoped_acquire guard;
if (isinstance<arb::spike_source_cell>(o)) {
return arb::util::unique_any(cast<arb::spike_source_cell>(o));
}
......@@ -132,7 +133,7 @@ arb::cable_cell branch_cell(arb::cell_gid_type gid, const cell_parameters& param
// Add a synapse to the mid point of the first dendrite.
cell.add_synapse({1, 0.5}, "expsyn");
// Add additional synapses that will not be connected to anything.
// Add additional synapses.
for (unsigned i=1u; i<params.synapses; ++i) {
cell.add_synapse({1, 0.5}, "expsyn");
}
......
import sys
import arbor
class ring_recipe (arbor.recipe):
def __init__(self, n=4):
# The base C++ class constructor must be called first, to ensure that
# all memory in the C++ class is initialized correctly.
arbor.recipe.__init__(self)
self.ncells = n
self.params = arbor.cell_parameters()
# The num_cells method that returns the total number of cells in the model
# must be implemented.
def num_cells(self):
return self.ncells
# The cell_description method returns a cell
def cell_description(self, gid):
return arbor.branch_cell(gid, self.params)
def num_targets(self, gid):
return 1
def num_sources(self, gid):
return 1
# The kind method returns the type of cell with gid.
# Note: this must agree with the type returned by cell_description.
def cell_kind(self, gid):
return arbor.cell_kind.cable
# Make a ring network
def connections_on(self, gid):
src = (gid-1)%self.ncells
w = 0.01
d = 10
return [arbor.connection(arbor.cell_member(src,0), arbor.cell_member(gid,0), w, d)]
# Attach a generator to the first cell in the ring.
def event_generators(self, gid):
if gid==0:
sched = arbor.explicit_schedule([1])
return [arbor.event_generator(arbor.cell_member(0,0), 0.1, sched)]
return []
context = arbor.context(threads=4, gpu_id=None)
print(context)
recipe = ring_recipe(100)
print(recipe)
decomp = arbor.partition_load_balance(recipe, context)
print(decomp)
sim = arbor.simulation(recipe, decomp, context)
print(sim)
recorder = arbor.attach_spike_recorder(sim)
sim.run(1000)
for s in recorder.spikes:
print(s)
......@@ -48,7 +48,7 @@ void register_identifiers(pybind11::module& m) {
"Proxy cell that generates spikes from a spike sequence provided by the user.");
pybind11::enum_<arb::backend_kind>(m, "backend",
"Enumeration used to indicate which hardware backend to use for running a cell_group.")
"Enumeration used to indicate which hardware backend to execute a cell group on.")
.value("gpu", arb::backend_kind::gpu,
"Use GPU backend.")
.value("multicore", arb::backend_kind::multicore,
......
......@@ -15,6 +15,7 @@ void register_identifiers(pybind11::module& m);
void register_recipe(pybind11::module& m);
void register_schedules(pybind11::module& m);
void register_simulation(pybind11::module& m);
void register_spike_handling(pybind11::module& m);
#ifdef ARB_MPI_ENABLED
void register_mpi(pybind11::module& m);
......@@ -37,4 +38,5 @@ PYBIND11_MODULE(arbor, m) {
pyarb::register_recipe(m);
pyarb::register_schedules(m);
pyarb::register_simulation(m);
pyarb::register_spike_handling(m);
}
......@@ -22,43 +22,17 @@ namespace pyarb {
// The py::recipe::cell_decription returns a pybind11::object, that is
// unwrapped and copied into a arb::util::unique_any.
arb::util::unique_any py_recipe_shim::get_cell_description(arb::cell_gid_type gid) const {
// Aquire the GIL because it must be held when calling isinstance and cast.
auto guard = pybind11::gil_scoped_acquire();
// Get the python object pyarb::cell_description from the python front end
pybind11::gil_scoped_acquire guard;
return convert_cell(impl_->cell_description(gid));
}
// The py::recipe::global_properties returns a pybind11::object, that is
// unwrapped and copied into a arb::util::any.
arb::util::any py_recipe_shim::get_global_properties(arb::cell_kind kind) const {
using pybind11::cast;
// Aquire the GIL because it must be held when calling cast.
auto guard = pybind11::gil_scoped_acquire();
// Get the python object pyarb::global_properties from the python front end
pybind11::object o = impl_->global_properties(kind);
if (kind == arb::cell_kind::cable) {
return arb::util::any(cast<arb::cable_cell_global_properties>(o));
}
else return arb::util::any{};
throw pyarb_error( "recipe.global_properties returned \""
+ std::string(pybind11::str(o))
+ "\" which does not describe a known Arbor global property description");
}
std::vector<arb::event_generator> py_recipe_shim::event_generators(arb::cell_gid_type gid) const {
using namespace std::string_literals;
using pybind11::isinstance;
using pybind11::cast;
// Aquire the GIL because it must be held when calling isinstance and cast.
auto guard = pybind11::gil_scoped_acquire();
pybind11::gil_scoped_acquire guard;
// Get the python list of pyarb::event_generator_shim from the python front end.
auto pygens = impl_->event_generators(gid);
......@@ -69,10 +43,9 @@ std::vector<arb::event_generator> py_recipe_shim::event_generators(arb::cell_gid
for (auto& g: pygens) {
// check that a valid Python event_generator was passed.
if (!isinstance<pyarb::event_generator_shim>(g)) {
std::stringstream s;
s << "recipe supplied an invalid event generator for gid "
<< gid << ": " << pybind11::str(g);
throw pyarb_error(s.str());
throw pyarb_error(
util::pprintf(
"recipe supplied an invalid event generator for gid {}: {}", gid, pybind11::str(g)));
}
// get a reference to the python event_generator
auto& p = cast<const pyarb::event_generator_shim&>(g);
......@@ -128,24 +101,24 @@ void register_recipe(pybind11::module& m) {
using namespace pybind11::literals;
// Connections
pybind11::class_<cell_connection_shim> cell_connection(m, "connection",
pybind11::class_<arb::cell_connection> cell_connection(m, "connection",
"Describes a connection between two cells:\n"
" Defined by source and destination end points (that is pre-synaptic and post-synaptic respectively), a connection weight and a delay time.");
cell_connection
.def(pybind11::init<arb::cell_member_type, arb::cell_member_type, float, arb::time_type>(),
"source"_a = arb::cell_member_type{0,0}, "dest"_a = arb::cell_member_type{0,0}, "weight"_a = 0.f, "delay"_a,
"source"_a, "dest"_a, "weight"_a, "delay"_a,
"Construct a connection with arguments:\n"
" source: The source end point of the connection (default (0,0)).\n"
" dest: The destination end point of the connection (default (0,0)).\n"
" weight: The weight delivered to the target synapse (dimensionless with interpretation specific to synapse type of target, default 0.).\n"
" source: The source end point of the connection.\n"
" dest: The destination end point of the connection.\n"
" weight: The weight delivered to the target synapse (unit: defined by the type of synapse target).\n"
" delay: The delay of the connection (unit: ms).")
.def_readwrite("source", &cell_connection_shim::source,
.def_readwrite("source", &arb::cell_connection::source,
"The source of the connection.")
.def_readwrite("dest", &cell_connection_shim::destination,
.def_readwrite("dest", &arb::cell_connection::dest,
"The destination of the connection.")
.def_readwrite("weight", &cell_connection_shim::weight,
.def_readwrite("weight", &arb::cell_connection::weight,
"The weight of the connection.")
.def_property("delay", &cell_connection_shim::get_delay, &cell_connection_shim::set_delay,
.def_readwrite("delay", &arb::cell_connection::delay,
"The delay time of the connection (unit: ms).")
.def("__str__", &con_to_string)
.def("__repr__", &con_to_string);
......@@ -155,11 +128,11 @@ void register_recipe(pybind11::module& m) {
"Describes a gap junction between two gap junction sites.");
gap_junction_connection
.def(pybind11::init<arb::cell_member_type, arb::cell_member_type, double>(),
"local"_a = arb::cell_member_type{0,0}, "peer"_a = arb::cell_member_type{0,0}, "ggap"_a = 0.f,
"local"_a, "peer"_a, "ggap"_a,
"Construct a gap junction connection with arguments:\n"
" local: One half of the gap junction connection (default (0,0)).\n"
" peer: Other half of the gap junction connection (default (0,0)).\n"
" ggap: Gap junction conductance (unit: μS, default 0.).")
" local: One half of the gap junction connection.\n"
" peer: Other half of the gap junction connection.\n"
" ggap: Gap junction conductance (unit: μS).")
.def_readwrite("local", &arb::gap_junction_connection::local,
"One half of the gap junction connection.")
.def_readwrite("peer", &arb::gap_junction_connection::peer,
......@@ -204,9 +177,7 @@ void register_recipe(pybind11::module& m) {
"gid"_a,
"A list of the gap junctions connected to gid (default []).")
// TODO: py_recipe::get_probe
.def("global_properties", &py_recipe::global_properties, pybind11::return_value_policy::copy,
"cell_kind"_a,
"Global property type specific to a given cell kind.")
// TODO: py_recipe::global_properties
.def("__str__", [](const py_recipe&){return "<arbor.recipe>";})
.def("__repr__", [](const py_recipe&){return "<arbor.recipe>";});
}
......
......@@ -26,30 +26,31 @@ public:
virtual ~py_recipe() {}
virtual arb::cell_size_type num_cells() const = 0;
virtual pybind11::object cell_description(arb::cell_gid_type gid) const = 0;
virtual arb::cell_kind cell_kind(arb::cell_gid_type gid) const = 0;
virtual arb::cell_size_type num_sources(arb::cell_gid_type) const { return 0; }
virtual arb::cell_size_type num_targets(arb::cell_gid_type) const { return 0; }
//TODO: virtual arb::cell_size_type num_probes(arb::cell_gid_type) const { return 0; }
virtual arb::cell_size_type num_sources(arb::cell_gid_type) const {
return 0;
}
virtual arb::cell_size_type num_targets(arb::cell_gid_type) const {
return 0;
}
virtual arb::cell_size_type num_gap_junction_sites(arb::cell_gid_type gid) const {
return gap_junctions_on(gid).size();
}
virtual std::vector<pybind11::object> event_generators(arb::cell_gid_type gid) const {
auto guard = pybind11::gil_scoped_acquire();
return {};
}
virtual std::vector<arb::cell_connection> connections_on(arb::cell_gid_type gid) const {
return {};
}
virtual std::vector<arb::gap_junction_connection> gap_junctions_on(arb::cell_gid_type) const {
return {};
}
virtual std::vector<arb::cell_connection> connections_on(arb::cell_gid_type gid) const { return {}; }
virtual std::vector<arb::gap_junction_connection> gap_junctions_on(arb::cell_gid_type) const { return {}; }
//TODO: virtual arb::cell_size_type num_probes(arb::cell_gid_type) const { return 0; }
//TODO: virtual pybind11::object get_probe (arb::cell_member_type id) const {...}
virtual pybind11::object global_properties(arb::cell_kind kind) const = 0;
//TODO: virtual pybind11::object global_properties(arb::cell_kind kind) const {return pybind11::none();};
};
class py_recipe_trampoline: public py_recipe {
......@@ -74,8 +75,6 @@ public:
PYBIND11_OVERLOAD(arb::cell_size_type, py_recipe, num_targets, gid);
}
//TODO: arb::cell_size_type num_probes(arb::cell_gid_type)
arb::cell_size_type num_gap_junction_sites(arb::cell_gid_type gid) const override {
PYBIND11_OVERLOAD(arb::cell_size_type, py_recipe, num_gap_junction_sites, gid);
}
......@@ -92,11 +91,8 @@ public:
PYBIND11_OVERLOAD(std::vector<arb::gap_junction_connection>, py_recipe, gap_junctions_on, gid);
}
//TODO: arb::cell_size_type num_probes(arb::cell_gid_type)
//TODO: pybind11::object get_probe(arb::cell_member_type id)
pybind11::object global_properties(arb::cell_kind kind) const override {
PYBIND11_OVERLOAD_PURE(pybind11::object, py_recipe, global_properties, kind);
}
};
// A recipe shim that holds a pyarb::recipe implementation.
......@@ -134,10 +130,8 @@ public:
return impl_->num_targets(gid);
}
/* //TODO: arb::cell_size_type num_probes(arb::cell_gid_type gid) const override {
return impl_->num_probes(gid);
}
*/
//TODO: arb::cell_size_type num_probes(arb::cell_gid_type gid)
arb::cell_size_type num_gap_junction_sites(arb::cell_gid_type gid) const override {
return impl_->num_gap_junction_sites(gid);
}
......@@ -152,12 +146,12 @@ public:
return impl_->gap_junctions_on(gid);
}
//TODO: arb::probe_info get_probe(arb::cell_member_type id) const override
// The pyarb::recipe::global_properties returns a pybind11::object, that is
// unwrapped and copied into a util::any.
arb::util::any get_global_properties(arb::cell_kind kind) const override;
//TODO: arb::probe_info get_probe(arb::cell_member_type id)
// TODO: wrap
arb::util::any get_global_properties(arb::cell_kind kind) const override {
return arb::util::any{};
}
};
} // namespace pyarb
......@@ -15,14 +15,13 @@ void register_simulation(pybind11::module& m) {
"The executable form of a model.\n"
"A simulation is constructed from a recipe, and then used to update and monitor model state.");
simulation
// A custom constructor that wraps a python recipe with
// arb::py_recipe_shim before forwarding it to the arb::recipe constructor.
// A custom constructor that wraps a python recipe with arb::py_recipe_shim
// before forwarding it to the arb::recipe constructor.
.def(pybind11::init(
[](std::shared_ptr<py_recipe>& rec, const arb::domain_decomposition& decomp, const context_shim& ctx) {
return new arb::simulation(py_recipe_shim(rec), decomp, ctx.context);
}),
// Release the python gil, so that callbacks into the python
// recipe don't deadlock.
// Release the python gil, so that callbacks into the python recipe don't deadlock.
pybind11::call_guard<pybind11::gil_scoped_release>(),
"Initialize the model described by a recipe, with cells and network distributed\n"
"according to the domain decomposition and computational resources described by a context.",
......@@ -32,10 +31,10 @@ void register_simulation(pybind11::module& m) {
"Reset the state of the simulation to its initial state.")
.def("run", &arb::simulation::run,
pybind11::call_guard<pybind11::gil_scoped_release>(),
"Run the simulation from current simulation time to tfinal, with maximum time step size dt.",
"tfinal"_a, "dt"_a)
"Run the simulation from current simulation time to tfinal (unit: ms), with maximum time step size dt (unit: ms).",
"tfinal"_a, "dt"_a=0.025)
.def("set_binning_policy", &arb::simulation::set_binning_policy,
"Set event binning policy on all our groups.",
"Set the binning policy for event delivery, and the binning time interval if applicable (unit: ms).",
"policy"_a, "bin_interval"_a)
.def("__str__", [](const arb::simulation&){ return "<arbor.simulation>"; })
.def("__repr__", [](const arb::simulation&){ return "<arbor.simulation>"; });
......
#include <memory>
#include <vector>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <arbor/spike.hpp>
#include <arbor/simulation.hpp>
#include "strprintf.hpp"
namespace pyarb {
// A functor that models arb::spike_export_function.
// Holds a shared pointer to the spike_vec used to store the spikes, so that if
// the spike_vec in spike_recorder is garbage collected in Python, stores will
// not seg fault.
struct spike_callback {
using spike_vec = std::vector<arb::spike>;
std::shared_ptr<spike_vec> spike_store;
spike_callback(const std::shared_ptr<spike_vec>& state):
spike_store(state)
{}
void operator() (const spike_vec& spikes) {
spike_store->insert(spike_store->end(), spikes.begin(), spikes.end());
};
};
// Helper type for recording spikes from a simulation.
// This type is wrapped in Python, to expose spike_recorder::spike_store.
struct spike_recorder {
using spike_vec = std::vector<arb::spike>;
std::shared_ptr<spike_vec> spike_store;
spike_callback callback() {
// initialize the spike_store
spike_store = std::make_shared<spike_vec>();
// The callback holds a copy of spike_store, i.e. the shared
// pointer is held by both the spike_recorder and the callback, so if
// the spike_recorder is destructed in the calling Python code, attempts
// to write to spike_store inside the callback will not seg fault.
return spike_callback(spike_store);
}
const spike_vec& spikes() const {
return *spike_store;
}
};
std::shared_ptr<spike_recorder> attach_spike_recorder(arb::simulation& sim) {
auto r = std::make_shared<spike_recorder>();
sim.set_global_spike_callback(r->callback());
return r;
}
std::string spike_str(const arb::spike& s) {
return util::pprintf(
"<arbor.spike: source ({},{}), time {}>",
s.source.gid, s.source.index, s.time);
}
void register_spike_handling(pybind11::module& m) {
using namespace pybind11::literals;
pybind11::class_<arb::spike> spike(m, "spike");
spike
.def(pybind11::init<>())
.def_readwrite("source", &arb::spike::source)
.def_readwrite("time", &arb::spike::time)
.def("__str__", &spike_str)
.def("__repr__", &spike_str);
// Use shared_ptr for spike_recorder, so that all copies of a recorder will
// see the spikes from the simulation with which the recorder's callback has been
// registered.
pybind11::class_<spike_recorder, std::shared_ptr<spike_recorder>> sprec(m, "spike_recorder");
sprec
.def(pybind11::init<>())
.def_property_readonly("spikes", &spike_recorder::spikes);
m.def("attach_spike_recorder", &attach_spike_recorder,
"sim"_a,
"Attach a spike recorder to an arbor simulation.\n"
"The recorder that is returned will record all spikes generated after it has been\n"
"attached (spikes generated before attaching are not recorded).");
}
} // namespace pyarb
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment