From 1aa0aed20c1ad58cbcbb310c17774d2287360ac5 Mon Sep 17 00:00:00 2001 From: max9901 <43051058+max9901@users.noreply.github.com> Date: Fri, 8 Oct 2021 09:31:58 +0200 Subject: [PATCH] Provide a method to clear sampler contents - Add `simulation.clear_samplers` to remove spikes and samples from Python. - Usecase: long-running simulations; where unlimited growth of sampling memory is an issue. --- doc/python/simulation.rst | 4 + python/simulation.cpp | 12 ++ python/test/unit/runner.py | 5 +- python/test/unit/test_clear_samplers.py | 148 ++++++++++++++++++++++++ 4 files changed, 168 insertions(+), 1 deletion(-) create mode 100644 python/test/unit/test_clear_samplers.py diff --git a/doc/python/simulation.rst b/doc/python/simulation.rst index 452b6f52..0104843a 100644 --- a/doc/python/simulation.rst +++ b/doc/python/simulation.rst @@ -74,6 +74,10 @@ over the local and distributed hardware resources (see :ref:`pydomdec`). Then, t Reset the state of the simulation to its initial state. Clears recorded spikes and sample data. + .. function:: clear_samplers() + + Clears recorded spikes and sample data. + .. function:: run(tfinal, dt) Run the simulation from current simulation time to ``tfinal``, diff --git a/python/simulation.cpp b/python/simulation.cpp index 98a06c5a..fecae322 100644 --- a/python/simulation.cpp +++ b/python/simulation.cpp @@ -77,6 +77,15 @@ public: } } + void clear_samplers() { + spike_record_.clear(); + for (auto&& [handle, cb]: sampler_map_) { + for (auto& rec: *cb.recorders) { + rec->reset(); + } + } + } + arb::time_type run(arb::time_type tfinal, arb::time_type dt) { return sim_->run(tfinal, dt); } @@ -199,6 +208,9 @@ void register_simulation(pybind11::module& m, pyarb_global_ptr global_ptr) { .def("reset", &simulation_shim::reset, pybind11::call_guard<pybind11::gil_scoped_release>(), "Reset the state of the simulation to its initial state.") + .def("clear_samplers", &simulation_shim::clear_samplers, + pybind11::call_guard<pybind11::gil_scoped_release>(), + "Clearing spike and sample information. restoring memory") .def("run", &simulation_shim::run, pybind11::call_guard<pybind11::gil_scoped_release>(), "Run the simulation from current simulation time to tfinal [ms], with maximum time step size dt [ms].", diff --git a/python/test/unit/runner.py b/python/test/unit/runner.py index 75550dee..224f1957 100644 --- a/python/test/unit/runner.py +++ b/python/test/unit/runner.py @@ -12,6 +12,7 @@ try: import options import test_cable_probes import test_catalogues + import test_clear_samplers import test_contexts import test_decor import test_domain_decomposition @@ -27,6 +28,7 @@ except ModuleNotFoundError: from test import options from test.unit import test_cable_probes from test.unit import test_catalogues + from test.unit import test_clear_samplers from test.unit import test_contexts from test.unit import test_decor from test.unit import test_domain_decompositions @@ -40,7 +42,8 @@ except ModuleNotFoundError: test_modules = [\ test_cable_probes,\ - test_catalogues,\ + test_catalogues, \ + test_clear_samplers, \ test_contexts,\ test_decor,\ test_domain_decompositions,\ diff --git a/python/test/unit/test_clear_samplers.py b/python/test/unit/test_clear_samplers.py new file mode 100644 index 00000000..c54db908 --- /dev/null +++ b/python/test/unit/test_clear_samplers.py @@ -0,0 +1,148 @@ +# -*- coding: utf-8 -*- +# +# test_spikes.py + +import unittest +import arbor as A +import numpy as np + +# to be able to run .py file from child directory +import sys, os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))) + +try: + import options +except ModuleNotFoundError: + from test import options + +""" +all tests for the simulator wrapper +""" + +def make_cable_cell(): + # (1) Create a morphology with a single (cylindrical) segment of length=diameter=6 μm + tree = A.segment_tree() + tree.append(A.mnpos, A.mpoint(-3, 0, 0, 3), A.mpoint(3, 0, 0, 3), tag=1) + + # (2) Define the soma and its midpoint + labels = A.label_dict({'soma': '(tag 1)', + 'midpoint': '(location 0 0.5)'}) + + # (3) Create cell and set properties + decor = A.decor() + decor.set_property(Vm=-40) + decor.paint('"soma"', 'hh') + decor.place('"midpoint"', A.iclamp( 10, 2, 0.8), "iclamp") + decor.place('"midpoint"', A.spike_detector(-10), "detector") + return A.cable_cell(tree, labels, decor) + +# Test recipe art_spiker_recipe comprises three artificial spiking cells +class art_spiker_recipe(A.recipe): + def __init__(self): + A.recipe.__init__(self) + self.the_props = A.neuron_cable_properties() + self.trains = [ + [0.8, 2, 2.1, 3], + [0.4, 2, 2.2, 3.1, 4.5], + [0.2, 2, 2.8, 3]] + + def num_cells(self): + return 4 + + def cell_kind(self, gid): + if gid < 3: + return A.cell_kind.spike_source + else: + return A.cell_kind.cable + + def connections_on(self, gid): + return [] + + def event_generators(self, gid): + return [] + + def global_properties(self, kind): + return self.the_props + + def probes(self, gid): + if gid < 3: + return [] + else: + return [A.cable_probe_membrane_voltage('"midpoint"')] + + def cell_description(self, gid): + if gid < 3: + return A.spike_source_cell("src", A.explicit_schedule(self.trains[gid])) + else: + return make_cable_cell() + + + +class Clear_samplers(unittest.TestCase): + # Helper for constructing a simulation from a recipe using default context and domain decomposition. + def init_sim(self, recipe): + context = A.context() + dd = A.partition_load_balance(recipe, context) + return A.simulation(recipe, dd, context) + + # test that all spikes are sorted by time then by gid + def test_spike_clearing(self): + + sim = self.init_sim(art_spiker_recipe()) + sim.record(A.spike_recording.all) + handle = sim.sample((3, 0), A.regular_schedule(0.1)) + + # baseline to test against Run in exactly the same stepping to make sure there are no rounding differences + sim.run(3, 0.01) + sim.run(5, 0.01) + spikes = sim.spikes() + times = spikes["time"].tolist() + gids = spikes["source"]["gid"].tolist() + data, meta = sim.samples(handle)[0] + # reset the simulator + sim.reset() + + # simulated with clearing the memory inbetween the steppings + sim.run(3,0.01) + spikes = sim.spikes() + times_t = spikes["time"].tolist() + gids_t = spikes["source"]["gid"].tolist() + data_t, meta_t = sim.samples(handle)[0] + + # clear the samplers memory + sim.clear_samplers() + + # Check if the memory is cleared + spikes = sim.spikes() + self.assertEqual(0, len(spikes["time"].tolist())) + self.assertEqual(0, len(spikes["source"]["gid"].tolist())) + data_test, meta_test = sim.samples(handle)[0] + self.assertEqual(0,data_test.size) + + # run the next part of the simulation + sim.run(5, 0.01) + spikes = sim.spikes() + times_t.extend(spikes["time"].tolist()) + gids_t.extend(spikes["source"]["gid"].tolist()) + data_temp, meta_temp = sim.samples(handle)[0] + data_t = np.concatenate((data_t, data_temp), 0) + + + # check if results are the same + self.assertEqual(gids, gids_t) + self.assertEqual(times_t, times) + self.assertEqual(list(data[:, 0]), list(data_t[:, 0])) + self.assertEqual(list(data[:, 1]), list(data_t[:, 1])) + +def suite(): + # specify class and test functions in tuple (here: all tests starting with 'test' from class Contexts + suite = unittest.makeSuite(Clear_samplers, ('test')) + return suite + +def run(): + v = options.parse_arguments().verbosity + runner = unittest.TextTestRunner(verbosity = v) + runner.run(suite()) + +if __name__ == "__main__": + run() -- GitLab