diff --git a/doc/python/simulation.rst b/doc/python/simulation.rst index dcfc769c81b2f288cfaef02740edd4de621dfa9c..452b6f527bf0c72cf5cc063f5841531f703fd3c9 100644 --- a/doc/python/simulation.rst +++ b/doc/python/simulation.rst @@ -105,6 +105,9 @@ over the local and distributed hardware resources (see :ref:`pydomdec`). Then, t Each spike is represented as a NumPy structured datatype with signature ``('source', [('gid', '<u4'), ('index', '<u4')]), ('time', '<f8')``. + The spikes are sorted in ascending order of spike time, and spikes with the same time are + sorted accourding to source gid then index. + **Sampling probes:** .. function:: sample(probe_id, schedule, policy) @@ -209,6 +212,11 @@ Spikes recorded during a simulation are returned as a NumPy structured datatype ``source`` and ``time``. The ``source`` field itself is a structured datatype with two fields, ``gid`` and ``index``, identifying the spike detector that generated the spike. +.. Note:: + + The spikes returned by :py:func:`simulation.record` are sorted in ascending order of spike time. + Spikes that have the same spike time are sorted in ascending order of gid and local index of the + spike source. .. container:: example-code @@ -219,7 +227,9 @@ Spikes recorded during a simulation are returned as a NumPy structured datatype # Instantiate the simulation. sim = arbor.simulation(recipe, decomp, context) - # Direct the simulation to record all spikes. + # Direct the simulation to record all spikes, which will record all spikes + # across multiple MPI ranks in distrubuted simulation. + # To only record spikes from the local MPI rank, use arbor.spike_recording.local sim.record(arbor.spike_recording.all) # Run the simulation for 2000 ms with time stepping of 0.025 ms diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 08c79ae119a5b9d6af2dd4b89d23dcc34d8f6ecc..dd2028f062bd34a5ddc57dc1d361e864e3be229c 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -38,7 +38,6 @@ set(pyarb_source schedule.cpp simulation.cpp single_cell_model.cpp - spikes.cpp ) # compile the pyarb sources into an object library that will be diff --git a/python/pyarb.cpp b/python/pyarb.cpp index 4c4280885ad8c8029ec2686c8f1d62fb7e33c90e..9a6782b87fccf0ae6209c2b516bd155b0e6ef591 100644 --- a/python/pyarb.cpp +++ b/python/pyarb.cpp @@ -26,7 +26,6 @@ void register_recipe(pybind11::module& m); void register_schedules(pybind11::module& m); void register_simulation(pybind11::module& m, pyarb_global_ptr); void register_single_cell(pybind11::module& m); -void register_spike_handling(pybind11::module& m); #ifdef ARB_MPI_ENABLED void register_mpi(pybind11::module& m); @@ -59,7 +58,6 @@ PYBIND11_MODULE(_arbor, m) { pyarb::register_schedules(m); pyarb::register_simulation(m, global_ptr); pyarb::register_single_cell(m); - pyarb::register_spike_handling(m); #ifdef ARB_MPI_ENABLED pyarb::register_mpi(m); diff --git a/python/simulation.cpp b/python/simulation.cpp index ab29cd6c970e67c47357c261852728baf9f37f02..98a06c5a1a0ee8befc25489f56a54f9e619cae3b 100644 --- a/python/simulation.cpp +++ b/python/simulation.cpp @@ -87,7 +87,14 @@ public: void record(spike_recording policy) { auto spike_recorder = [this](const std::vector<arb::spike>& spikes) { + auto old_size = spike_record_.size(); + // Append the new spikes to the end of the spike record. spike_record_.insert(spike_record_.end(), spikes.begin(), spikes.end()); + // Sort the newly appended spikes. + std::sort(spike_record_.begin()+old_size, spike_record_.end(), + [](const auto& lhs, const auto& rhs) { + return std::tie(lhs.time, lhs.source.gid, lhs.source.index)<std::tie(rhs.time, rhs.source.gid, rhs.source.index); + }); }; switch (policy) { diff --git a/python/spikes.cpp b/python/spikes.cpp deleted file mode 100644 index c2b006f953685f81d7603cba432d0283833a18f1..0000000000000000000000000000000000000000 --- a/python/spikes.cpp +++ /dev/null @@ -1,92 +0,0 @@ -#include <memory> -#include <vector> - -#include <pybind11/pybind11.h> -#include <pybind11/stl.h> - -#include <arbor/spike.hpp> -#include <arbor/simulation.hpp> - -#include "strprintf.hpp" - -namespace pyarb { - -// A functor that models arb::spike_export_function. -// Holds a shared pointer to the spike_vec used to store the spikes, so that if -// the spike_vec in spike_recorder is garbage collected in Python, stores will -// not seg fault. -struct spike_callback { - using spike_vec = std::vector<arb::spike>; - - std::shared_ptr<spike_vec> spike_store; - - spike_callback(const std::shared_ptr<spike_vec>& state): - spike_store(state) - {} - - void operator() (const spike_vec& spikes) { - spike_store->insert(spike_store->end(), spikes.begin(), spikes.end()); - }; -}; - -// Helper type for recording spikes from a simulation. -// This type is wrapped in Python, to expose spike_recorder::spike_store. -struct spike_recorder { - using spike_vec = std::vector<arb::spike>; - std::shared_ptr<spike_vec> spike_store; - - spike_callback callback() { - // initialize the spike_store - spike_store = std::make_shared<spike_vec>(); - - // The callback holds a copy of spike_store, i.e. the shared - // pointer is held by both the spike_recorder and the callback, so if - // the spike_recorder is destructed in the calling Python code, attempts - // to write to spike_store inside the callback will not seg fault. - return spike_callback(spike_store); - } - - const spike_vec& spikes() const { - return *spike_store; - } -}; - -std::shared_ptr<spike_recorder> attach_spike_recorder(arb::simulation& sim) { - auto r = std::make_shared<spike_recorder>(); - sim.set_global_spike_callback(r->callback()); - return r; -} - -std::string spike_str(const arb::spike& s) { - return util::pprintf( - "<arbor.spike: source ({},{}), time {} ms>", - s.source.gid, s.source.index, s.time); -} - -void register_spike_handling(pybind11::module& m) { - using namespace pybind11::literals; - - pybind11::class_<arb::spike> spike(m, "spike"); - spike - .def(pybind11::init<>()) - .def_readwrite("source", &arb::spike::source, "The spike source (type: cell_member).") - .def_readwrite("time", &arb::spike::time, "The spike time [ms].") - .def("__str__", &spike_str) - .def("__repr__", &spike_str); - - // Use shared_ptr for spike_recorder, so that all copies of a recorder will - // see the spikes from the simulation with which the recorder's callback has been - // registered. - pybind11::class_<spike_recorder, std::shared_ptr<spike_recorder>> sprec(m, "spike_recorder"); - sprec - .def(pybind11::init<>()) - .def_property_readonly("spikes", &spike_recorder::spikes, "A list of the recorded spikes."); - - m.def("attach_spike_recorder", &attach_spike_recorder, - "sim"_a, - "Attach a spike recorder to an arbor simulation.\n" - "The recorder that is returned will record all spikes generated after it has been\n" - "attached (spikes generated before attaching are not recorded)."); -} - -} // namespace pyarb diff --git a/python/test/unit/runner.py b/python/test/unit/runner.py index cc9f16d2b4feea3abc2ea145c1496e6d366dfdd6..d27357cf42ec430eb6d12956c809b51ec05475e2 100644 --- a/python/test/unit/runner.py +++ b/python/test/unit/runner.py @@ -19,6 +19,7 @@ try: import test_cable_probes import test_morphology import test_catalogues + import test_spikes # add more if needed except ModuleNotFoundError: from test import options @@ -30,6 +31,7 @@ except ModuleNotFoundError: from test.unit import test_schedules from test.unit import test_cable_probes from test.unit import test_morphology + from test.unit import test_spikes # add more if needed test_modules = [\ @@ -40,7 +42,8 @@ test_modules = [\ test_identifiers,\ test_schedules,\ test_cable_probes,\ - test_morphology\ + test_morphology,\ + test_spikes,\ ] # add more if needed def suite(): diff --git a/python/test/unit/test_simulator.py b/python/test/unit/test_simulator.py deleted file mode 100644 index a13806416f109bce63d93c8c5aefd663b4a36fc5..0000000000000000000000000000000000000000 --- a/python/test/unit/test_simulator.py +++ /dev/null @@ -1,264 +0,0 @@ -# -*- coding: utf-8 -*- -# -# test_simulator.py - -import unittest -import numpy as np -import arbor as A - -# to be able to run .py file from child directory -import sys, os -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))) - -try: - import options -except ModuleNotFoundError: - from test import options - -""" -all tests for the simulator wrapper -""" - -# Test recipe cc2 comprises two cable cells and some probes. - -class cc2_recipe(A.recipe): - def __init__(self): - A.recipe.__init__(self) - st = A.segment_tree() - i = st.append(A.mnpos, (0, 0, 0, 10), (1, 0, 0, 10), 1) - st.append(i, (1, 3, 0, 5), 1) - st.append(i, (1, -4, 0, 3), 1) - self.the_morphology = A.morphology(st) - self.the_cat = A.default_catalogue() - self.the_props = A.neuron_cable_properties() - self.the_props.register(self.the_cat) - - def num_cells(self): - return 2 - - def num_targets(self, gid): - return 0 - - def num_sources(self, gid): - return 0 - - def cell_kind(self, gid): - return A.cell_kind.cable - - def connections_on(self, gid): - return [] - - def event_generators(self, gid): - return [] - - def global_properties(self, kind): - return self.the_props - - def probes(self, gid): - # Cell 0 has three voltage probes: - # 0, 0: end of branch 1 - # 0, 1: end of branch 2 - # 0, 2: all terminal points - # Values sampled from (0, 0) and (0, 1) should correspond - # to the values sampled from (0, 2). - - # Cell 1 has whole cell probes: - # 0, 0: all membrane voltages - # 0, 1: all expsyn state variable 'g' - - if gid==0: - return [A.cable_probe_membrane_voltage('(location 1 1)'), - A.cable_probe_membrane_voltage('(location 2 1)'), - A.cable_probe_membrane_voltage('(terminal)')] - elif gid==1: - return [A.cable_probe_membrane_voltage_cell(), - A.cable_probe_point_state_cell('expsyn', 'g')] - else: - return [] - - def cell_description(self, gid): - c = A.cable_cell(self.the_morphology, A.label_dict()) - c.set_properties(Vm=0.0, cm=0.01, rL=30, tempK=300) - c.paint('(all)', "pas") - c.place('(location 0 0)', A.iclamp(current=10 if gid==0 else 20)) - c.place('(sum (on-branches 0.3) (location 0 0.6))', "expsyn") - return c - -# Test recipe lif2 comprises two independent LIF cells driven by a regular, rapid -# sequence of incoming spikes. The cells have differing refactory periods. - -class lif2_recipe(A.recipe): - def __init__(self): - A.recipe.__init__(self) - - def num_cells(self): - return 2 - - def num_targets(self, gid): - return 0 - - def num_sources(self, gid): - return 0 - - def cell_kind(self, gid): - return A.cell_kind.lif - - def connections_on(self, gid): - return [] - - def event_generators(self, gid): - sched_dt = 0.25 - weight = 400 - return [A.event_generator((gid,0), weight, A.regular_schedule(sched_dt)) for gid in range(0, self.num_cells())] - - def probes(self, gid): - return [] - - def cell_description(self, gid): - c = A.lif_cell() - if gid==0: - c.t_ref = 2 - if gid==1: - c.t_ref = 4 - return c - -class Simulator(unittest.TestCase): - def init_sim(self, recipe): - context = A.context() - dd = A.partition_load_balance(recipe, context) - return A.simulation(recipe, dd, context) - - def test_simple_run(self): - sim = self.init_sim(cc2_recipe()) - sim.run(1.0, 0.01) - - def test_probe_meta(self): - sim = self.init_sim(cc2_recipe()) - - self.assertEqual([A.location(1, 1)], sim.probe_metadata((0, 0))) - self.assertEqual([A.location(2, 1)], sim.probe_metadata((0, 1))) - self.assertEqual([A.location(1, 1), A.location(2, 1)], sorted(sim.probe_metadata((0, 2)), key=lambda x:(x.branch, x.pos))) - - # Default CV policy is one per branch, which also gives a tivial CV over the branch point. - # Expect metadata cables to be one for each full branch, plus three length-zero cables corresponding to the branch point. - self.assertEqual([A.cable(0, 0, 1), A.cable(0, 1, 1), A.cable(1, 0, 0), A.cable(1, 0, 1), A.cable(2, 0, 0), A.cable(2, 0, 1)], - sorted(sim.probe_metadata((1,0))[0], key=lambda x:(x.branch, x.prox, x.dist))) - - # Four expsyn synapses; the two on branch zero should be coalesced, giving a multiplicity of 2. - # Expect entries to be in target index order. - m11 = sim.probe_metadata((1,1))[0] - self.assertEqual(4, len(m11)) - self.assertEqual([0, 1, 2, 3], [x.target for x in m11]) - self.assertEqual([2, 2, 1, 1], [x.multiplicity for x in m11]) - self.assertEqual([A.location(0, 0.3), A.location(0, 0.6), A.location(1, 0.3), A.location(2, 0.3)], [x.location for x in m11]) - - def test_probe_scalar_recorders(self): - sim = self.init_sim(cc2_recipe()) - ts = [0, 0.1, 0.3, 0.7] - h = sim.sample((0, 0), A.explicit_schedule(ts)) - dt = 0.01 - sim.run(10., dt) - s, meta = sim.samples(h)[0] - self.assertEqual(A.location(1, 1), meta) - for i, t in enumerate(s[:,0]): - self.assertLess(abs(t-ts[i]), dt) - - sim.remove_sampler(h) - sim.reset() - h = sim.sample(A.cell_member(0, 0), A.explicit_schedule(ts), A.sampling_policy.exact) - sim.run(10., dt) - s, meta = sim.samples(h)[0] - for i, t in enumerate(s[:,0]): - self.assertEqual(t, ts[i]) - - - def test_probe_multi_scalar_recorders(self): - sim = self.init_sim(cc2_recipe()) - ts = [0, 0.1, 0.3, 0.7] - h0 = sim.sample((0, 0), A.explicit_schedule(ts)) - h1 = sim.sample((0, 1), A.explicit_schedule(ts)) - h2 = sim.sample((0, 2), A.explicit_schedule(ts)) - - dt = 0.01 - sim.run(10., dt) - - r0 = sim.samples(h0) - self.assertEqual(1, len(r0)) - s0, meta0 = r0[0] - - r1 = sim.samples(h1) - self.assertEqual(1, len(r1)) - s1, meta1 = r1[0] - - r2 = sim.samples(h2) - self.assertEqual(2, len(r2)) - s20, meta20 = r2[0] - s21, meta21 = r2[1] - - # Probe id (0, 2) has probes over the two locations that correspond to probes (0, 0) and (0, 1). - - # (order is not guaranteed to line up though) - if meta20==meta0: - self.assertEqual(meta1, meta21) - self.assertTrue((s0[:,1]==s20[:,1]).all()) - self.assertTrue((s1[:,1]==s21[:,1]).all()) - else: - self.assertEqual(meta1, meta20) - self.assertTrue((s1[:,1]==s20[:,1]).all()) - self.assertEqual(meta0, meta21) - self.assertTrue((s0[:,1]==s21[:,1]).all()) - - def test_probe_vector_recorders(self): - sim = self.init_sim(cc2_recipe()) - ts = [0, 0.1, 0.3, 0.7] - h0 = sim.sample((1, 0), A.explicit_schedule(ts), A.sampling_policy.exact) - h1 = sim.sample((1, 1), A.explicit_schedule(ts), A.sampling_policy.exact) - sim.run(10., 0.01) - - # probe (1, 0) is the whole cell voltage; expect time + 6 sample values per row in returned data (see test_probe_meta above). - - s0, meta0 = sim.samples(h0)[0] - self.assertEqual(6, len(meta0)) - self.assertEqual((len(ts), 7), s0.shape) - for i, t in enumerate(s0[:,0]): - self.assertEqual(t, ts[i]) - - # probe (1, 1) is the 'g' state for all expsyn synapses. - # With the default descretization, expect two synapses with multiplicity 2 and two with multiplicity 1. - - s1, meta1 = sim.samples(h1)[0] - self.assertEqual(4, len(meta1)) - self.assertEqual((len(ts), 5), s1.shape) - for i, t in enumerate(s1[:,0]): - self.assertEqual(t, ts[i]) - - meta1_mult = {(m.location.branch, m.location.pos): m.multiplicity for m in meta1} - self.assertEqual(2, meta1_mult[(0, 0.3)]) - self.assertEqual(2, meta1_mult[(0, 0.6)]) - self.assertEqual(1, meta1_mult[(1, 0.3)]) - self.assertEqual(1, meta1_mult[(2, 0.3)]) - - def test_spikes(self): - sim = self.init_sim(lif2_recipe()) - sim.record(A.spike_recording.all) - sim.run(21, 0.01) - - spikes = sim.spikes().tolist() - s0 = sorted([t for s, t in spikes if s==(0, 0)]) - s1 = sorted([t for s, t in spikes if s==(1, 0)]) - - self.assertEqual([0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20], s0) - self.assertEqual([0, 4, 8, 12, 16, 20], s1) - -def suite(): - # specify class and test functions in tuple (here: all tests starting with 'test' from class Contexts - suite = unittest.makeSuite(Simulator, ('test')) - return suite - -def run(): - v = options.parse_arguments().verbosity - runner = unittest.TextTestRunner(verbosity = v) - runner.run(suite()) - -if __name__ == "__main__": - run() diff --git a/python/test/unit/test_spikes.py b/python/test/unit/test_spikes.py new file mode 100644 index 0000000000000000000000000000000000000000..86e7464e756868d22f1c69ab81fdf6c6cb2cc5e2 --- /dev/null +++ b/python/test/unit/test_spikes.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# +# test_spikes.py + +import unittest +import arbor as A + +# to be able to run .py file from child directory +import sys, os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))) + +try: + import options +except ModuleNotFoundError: + from test import options + +""" +all tests for the simulator wrapper +""" + +# Test recipe art_spiker_recipe comprises three artificial spiking cells + +class art_spiker_recipe(A.recipe): + def __init__(self): + A.recipe.__init__(self) + self.props = A.neuron_cable_properties() + self.trains = [ + [0.8, 2, 2.1, 3], + [0.4, 2, 2.2, 3.1, 4.5], + [0.2, 2, 2.8, 3]] + + def num_cells(self): + return 3 + + def num_targets(self, gid): + return 0 + + def num_sources(self, gid): + return 1 + + def cell_kind(self, gid): + return A.cell_kind.spike_source + + def connections_on(self, gid): + return [] + + def event_generators(self, gid): + return [] + + def global_properties(self, kind): + return self.the_props + + def probes(self, gid): + return [] + + def cell_description(self, gid): + return A.spike_source_cell(A.explicit_schedule(self.trains[gid])) + + +class Spikes(unittest.TestCase): + # Helper for constructing a simulation from a recipe using default context and domain decomposition. + def init_sim(self, recipe): + context = A.context() + dd = A.partition_load_balance(recipe, context) + return A.simulation(recipe, dd, context) + + # test that all spikes are sorted by time then by gid + def test_spikes_sorted(self): + sim = self.init_sim(art_spiker_recipe()) + sim.record(A.spike_recording.all) + # run simulation in 5 steps, forcing 5 epochs + sim.run(1, 0.01) + sim.run(2, 0.01) + sim.run(3, 0.01) + sim.run(4, 0.01) + sim.run(5, 0.01) + + spikes = sim.spikes() + times = spikes["time"].tolist() + gids = spikes["source"]["gid"].tolist() + + self.assertEqual([2, 1, 0, 0, 1, 2, 0, 1, 2, 0, 2, 1, 1], gids) + self.assertEqual([0.2, 0.4, 0.8, 2., 2., 2., 2.1, 2.2, 2.8, 3., 3., 3.1, 4.5], times) + +def suite(): + # specify class and test functions in tuple (here: all tests starting with 'test' from class Contexts + suite = unittest.makeSuite(Spikes, ('test')) + return suite + +def run(): + v = options.parse_arguments().verbosity + runner = unittest.TextTestRunner(verbosity = v) + runner.run(suite()) + +if __name__ == "__main__": + run()