From 1aa0aed20c1ad58cbcbb310c17774d2287360ac5 Mon Sep 17 00:00:00 2001
From: max9901 <43051058+max9901@users.noreply.github.com>
Date: Fri, 8 Oct 2021 09:31:58 +0200
Subject: [PATCH] Provide a method to clear sampler contents

- Add `simulation.clear_samplers` to remove spikes and samples from Python.
- Usecase: long-running simulations; where unlimited growth of sampling memory is an issue.
---
 doc/python/simulation.rst               |   4 +
 python/simulation.cpp                   |  12 ++
 python/test/unit/runner.py              |   5 +-
 python/test/unit/test_clear_samplers.py | 148 ++++++++++++++++++++++++
 4 files changed, 168 insertions(+), 1 deletion(-)
 create mode 100644 python/test/unit/test_clear_samplers.py

diff --git a/doc/python/simulation.rst b/doc/python/simulation.rst
index 452b6f52..0104843a 100644
--- a/doc/python/simulation.rst
+++ b/doc/python/simulation.rst
@@ -74,6 +74,10 @@ over the local and distributed hardware resources (see :ref:`pydomdec`). Then, t
         Reset the state of the simulation to its initial state.
         Clears recorded spikes and sample data.
 
+    .. function:: clear_samplers()
+
+        Clears recorded spikes and sample data.
+
     .. function:: run(tfinal, dt)
 
         Run the simulation from current simulation time to ``tfinal``,
diff --git a/python/simulation.cpp b/python/simulation.cpp
index 98a06c5a..fecae322 100644
--- a/python/simulation.cpp
+++ b/python/simulation.cpp
@@ -77,6 +77,15 @@ public:
         }
     }
 
+    void clear_samplers() {
+        spike_record_.clear();
+        for (auto&& [handle, cb]: sampler_map_) {
+            for (auto& rec: *cb.recorders) {
+                rec->reset();
+            }
+        }
+    }
+
     arb::time_type run(arb::time_type tfinal, arb::time_type dt) {
         return sim_->run(tfinal, dt);
     }
@@ -199,6 +208,9 @@ void register_simulation(pybind11::module& m, pyarb_global_ptr global_ptr) {
         .def("reset", &simulation_shim::reset,
             pybind11::call_guard<pybind11::gil_scoped_release>(),
             "Reset the state of the simulation to its initial state.")
+        .def("clear_samplers", &simulation_shim::clear_samplers,
+             pybind11::call_guard<pybind11::gil_scoped_release>(),
+             "Clearing spike and sample information. restoring memory")
         .def("run", &simulation_shim::run,
             pybind11::call_guard<pybind11::gil_scoped_release>(),
             "Run the simulation from current simulation time to tfinal [ms], with maximum time step size dt [ms].",
diff --git a/python/test/unit/runner.py b/python/test/unit/runner.py
index 75550dee..224f1957 100644
--- a/python/test/unit/runner.py
+++ b/python/test/unit/runner.py
@@ -12,6 +12,7 @@ try:
     import options
     import test_cable_probes
     import test_catalogues
+    import test_clear_samplers
     import test_contexts
     import test_decor
     import test_domain_decomposition
@@ -27,6 +28,7 @@ except ModuleNotFoundError:
     from test import options
     from test.unit import test_cable_probes
     from test.unit import test_catalogues
+    from test.unit import test_clear_samplers
     from test.unit import test_contexts
     from test.unit import test_decor
     from test.unit import test_domain_decompositions
@@ -40,7 +42,8 @@ except ModuleNotFoundError:
 
 test_modules = [\
     test_cable_probes,\
-    test_catalogues,\
+    test_catalogues, \
+    test_clear_samplers, \
     test_contexts,\
     test_decor,\
     test_domain_decompositions,\
diff --git a/python/test/unit/test_clear_samplers.py b/python/test/unit/test_clear_samplers.py
new file mode 100644
index 00000000..c54db908
--- /dev/null
+++ b/python/test/unit/test_clear_samplers.py
@@ -0,0 +1,148 @@
+# -*- coding: utf-8 -*-
+#
+# test_spikes.py
+
+import unittest
+import arbor as A
+import numpy as np
+
+# to be able to run .py file from child directory
+import sys, os
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
+
+try:
+    import options
+except ModuleNotFoundError:
+    from test import options
+
+"""
+all tests for the simulator wrapper
+"""
+
+def make_cable_cell():
+    # (1) Create a morphology with a single (cylindrical) segment of length=diameter=6 μm
+    tree = A.segment_tree()
+    tree.append(A.mnpos, A.mpoint(-3, 0, 0, 3), A.mpoint(3, 0, 0, 3), tag=1)
+
+    # (2) Define the soma and its midpoint
+    labels = A.label_dict({'soma':   '(tag 1)',
+                               'midpoint': '(location 0 0.5)'})
+
+    # (3) Create cell and set properties
+    decor = A.decor()
+    decor.set_property(Vm=-40)
+    decor.paint('"soma"', 'hh')
+    decor.place('"midpoint"', A.iclamp( 10, 2, 0.8), "iclamp")
+    decor.place('"midpoint"', A.spike_detector(-10), "detector")
+    return A.cable_cell(tree, labels, decor)
+
+# Test recipe art_spiker_recipe comprises three artificial spiking cells
+class art_spiker_recipe(A.recipe):
+    def __init__(self):
+        A.recipe.__init__(self)
+        self.the_props = A.neuron_cable_properties()
+        self.trains = [
+                [0.8, 2, 2.1, 3],
+                [0.4, 2, 2.2, 3.1, 4.5],
+                [0.2, 2, 2.8, 3]]
+
+    def num_cells(self):
+        return 4
+
+    def cell_kind(self, gid):
+        if gid < 3:
+            return A.cell_kind.spike_source
+        else:
+            return A.cell_kind.cable
+
+    def connections_on(self, gid):
+        return []
+
+    def event_generators(self, gid):
+        return []
+
+    def global_properties(self, kind):
+        return self.the_props
+
+    def probes(self, gid):
+        if gid < 3:
+            return []
+        else:
+            return [A.cable_probe_membrane_voltage('"midpoint"')]
+
+    def cell_description(self, gid):
+        if gid < 3:
+            return A.spike_source_cell("src", A.explicit_schedule(self.trains[gid]))
+        else:
+            return make_cable_cell()
+
+
+
+class Clear_samplers(unittest.TestCase):
+    # Helper for constructing a simulation from a recipe using default context and domain decomposition.
+    def init_sim(self, recipe):
+        context = A.context()
+        dd = A.partition_load_balance(recipe, context)
+        return A.simulation(recipe, dd, context)
+
+    # test that all spikes are sorted by time then by gid
+    def test_spike_clearing(self):
+
+        sim = self.init_sim(art_spiker_recipe())
+        sim.record(A.spike_recording.all)
+        handle = sim.sample((3, 0), A.regular_schedule(0.1))
+
+        # baseline to test against Run in exactly the same stepping to make sure there are no rounding differences
+        sim.run(3, 0.01)
+        sim.run(5, 0.01)
+        spikes = sim.spikes()
+        times = spikes["time"].tolist()
+        gids = spikes["source"]["gid"].tolist()
+        data, meta = sim.samples(handle)[0]
+        # reset the simulator
+        sim.reset()
+
+        # simulated with clearing the memory inbetween the steppings
+        sim.run(3,0.01)
+        spikes = sim.spikes()
+        times_t  = spikes["time"].tolist()
+        gids_t   = spikes["source"]["gid"].tolist()
+        data_t, meta_t = sim.samples(handle)[0]
+
+        # clear the samplers memory
+        sim.clear_samplers()
+
+        # Check if the memory is cleared
+        spikes = sim.spikes()
+        self.assertEqual(0, len(spikes["time"].tolist()))
+        self.assertEqual(0, len(spikes["source"]["gid"].tolist()))
+        data_test, meta_test = sim.samples(handle)[0]
+        self.assertEqual(0,data_test.size)
+
+        # run the next part of the simulation
+        sim.run(5, 0.01)
+        spikes = sim.spikes()
+        times_t.extend(spikes["time"].tolist())
+        gids_t.extend(spikes["source"]["gid"].tolist())
+        data_temp, meta_temp = sim.samples(handle)[0]
+        data_t = np.concatenate((data_t, data_temp), 0)
+
+
+        # check if results are the same
+        self.assertEqual(gids, gids_t)
+        self.assertEqual(times_t, times)
+        self.assertEqual(list(data[:, 0]), list(data_t[:, 0]))
+        self.assertEqual(list(data[:, 1]), list(data_t[:, 1]))
+
+def suite():
+    # specify class and test functions in tuple (here: all tests starting with 'test' from class Contexts
+    suite = unittest.makeSuite(Clear_samplers, ('test'))
+    return suite
+
+def run():
+    v = options.parse_arguments().verbosity
+    runner = unittest.TextTestRunner(verbosity = v)
+    runner.run(suite())
+
+if __name__ == "__main__":
+    run()
-- 
GitLab