diff --git a/src/cell_group.hpp b/src/cell_group.hpp index 990f055ff139e42e0a30eb9deaa9b57498362c67..55cf7dd9d4bb9f50b76e49fdc6f99bdd0d867681 100644 --- a/src/cell_group.hpp +++ b/src/cell_group.hpp @@ -19,10 +19,79 @@ namespace nest { namespace mc { +enum class binning_kind { + none, + regular, // => round time down to multiple of binning interval. + following, // => round times down to previous event if within binning interval. +}; + +class event_binner { +public: + using time_type = spike::time_type; + + void reset() { + last_event_times_.clear(); + } + + event_binner(): policy_(binning_kind::none), bin_interval_(0) {} + + event_binner(binning_kind policy, time_type bin_interval): + policy_(policy), bin_interval_(bin_interval) + {} + + // Determine binned time for an event based on policy. + // If `t_min` is specified, the binned time will be no lower than `t_min`. + // Otherwise the returned binned time will be less than or equal to the parameter `t`, + // and within `bin_interval_`. + + time_type bin(cell_gid_type id, time_type t, time_type t_min = std::numeric_limits<time_type>::lowest()) { + time_type t_binned = t; + + switch (policy_) { + case binning_kind::none: + break; + case binning_kind::regular: + if (bin_interval_>0) { + t_binned = std::floor(t/bin_interval_)*bin_interval_; + } + break; + case binning_kind::following: + if (auto last_t = last_event_time(id)) { + if (t-*last_t<bin_interval_) { + t_binned = *last_t; + } + } + update_last_event_time(id, t_binned); + break; + default: + throw std::logic_error("unrecognized binning policy"); + } + + return std::max(t_binned, t_min); + } + +private: + binning_kind policy_; + + // Interval in which event times can be aliased. + time_type bin_interval_; + + // (Consider replacing this with a vector-backed store.) + std::unordered_map<cell_gid_type, time_type> last_event_times_; + + util::optional<time_type> last_event_time(cell_gid_type id) { + auto it = last_event_times_.find(id); + return it==last_event_times_.end()? util::nothing: util::just(it->second); + } + + void update_last_event_time(cell_gid_type id, time_type t) { + last_event_times_[id] = t; + } +}; + template <typename LoweredCell> class cell_group { public: - using iarray = cell_gid_type; using lowered_cell_type = LoweredCell; using value_type = typename lowered_cell_type::value_type; using size_type = typename lowered_cell_type::value_type; @@ -48,7 +117,7 @@ public: target_handles_.resize(n_targets); probe_handles_.resize(n_probes); - cell_.initialize(cells, target_handles_, probe_handles_); + lowered_.initialize(cells, target_handles_, probe_handles_); // Create a list of the global identifiers for the spike sources auto source_gid = cell_gid_type{gid_base_}; @@ -65,23 +134,24 @@ public: spikes_.clear(); clear_events(); reset_samplers(); - cell_.reset(); + binner_.reset(); + lowered_.reset(); } - time_type min_step(time_type dt) { - return 0.1*dt; + void set_binning_policy(binning_kind policy, time_type bin_interval) { + binner_ = event_binner(policy, bin_interval); } void advance(time_type tfinal, time_type dt) { - while (cell_.time()<tfinal) { + while (lowered_.time()<tfinal) { // take any pending samples - time_type cell_time = cell_.time(); + time_type cell_time = lowered_.time(); PE("sampling"); while (auto m = sample_events_.pop_if_before(cell_time)) { auto& s = samplers_[m->sampler_index]; EXPECTS((bool)s.sampler); - auto next = s.sampler(cell_.time(), cell_.probe(s.handle)); + auto next = s.sampler(lowered_.time(), lowered_.probe(s.handle)); if (next) { m->time = std::max(*next, cell_time); @@ -91,31 +161,31 @@ public: PL(); // look for events in the next time step - time_type tstep = cell_.time()+dt; + time_type tstep = lowered_.time()+dt; tstep = std::min(tstep, tfinal); auto next = events_.pop_if_before(tstep); // apply events that are due within the smallest allowed time step. - while (next && (next->time-cell_.time()) < min_step(dt)) { + while (next && (next->time-lowered_.time()) < 0.1*(dt)) { auto handle = get_target_handle(next->target); - cell_.deliver_event(handle, next->weight); + lowered_.deliver_event(handle, next->weight); next = events_.pop_if_before(tstep); } // integrate cell state time_type tnext = next ? next->time: tstep; - cell_.advance(tnext - cell_.time()); + lowered_.advance(tnext - lowered_.time()); - if (util::is_debug_mode() && !cell_.is_physical_solution()) { + if (util::is_debug_mode() && !lowered_.is_physical_solution()) { std::cerr << "warning: solution out of bounds for cell " - << gid_base_ << " at t " << cell_.time() << " ms\n"; + << gid_base_ << " at t " << lowered_.time() << " ms\n"; } // apply events PE("events"); if (next) { auto handle = get_target_handle(next->target); - cell_.deliver_event(handle, next->weight); + lowered_.deliver_event(handle, next->weight); } PL(); } @@ -125,12 +195,12 @@ public: // record the local spike source index, which must be converted to a // global index for spike communication. PE("events"); - for (auto c: cell_.get_spikes()) { + for (auto c: lowered_.get_spikes()) { spikes_.push_back({spike_sources_[c.index], time_type(c.time)}); } // Now that the spikes have been generated, clear the old crossings // to get ready to record spikes from the next integration period. - cell_.clear_spikes(); + lowered_.clear_spikes(); PL(); } @@ -181,33 +251,33 @@ public: } value_type probe(cell_member_type probe_id) const { - return cell_.probe(get_probe_handle(probe_id)); + return lowered_.probe(get_probe_handle(probe_id)); } private: - /// gid of first cell in group + // gid of first cell in group. cell_gid_type gid_base_; - /// the lowered cell state (e.g. FVM) of the cell - lowered_cell_type cell_; + // The lowered cell state (e.g. FVM) of the cell. + lowered_cell_type lowered_; - /// spike detectors attached to the cell + // Spike detectors attached to the cell. std::vector<source_id_type> spike_sources_; - /// spikes that are generated + // Spikes that are generated. std::vector<spike> spikes_; - /// pending events to be delivered + // Event time binning manager. + event_binner binner_; + + // Pending events to be delivered. event_queue<postsynaptic_spike_event<time_type>> events_; - /// pending samples to be taken + // Pending samples to be taken. event_queue<sample_event<time_type>> sample_events_; std::vector<time_type> sampler_start_times_; - /// the global id of the first target (e.g. a synapse) in this group - iarray first_target_gid_; - - /// handles for accessing lowered cell + // Handles for accessing lowered cell. using target_handle = typename lowered_cell_type::target_handle; std::vector<target_handle> target_handles_; @@ -219,16 +289,16 @@ private: sampler_function sampler; }; - /// collection of samplers to be run against probes in this group + // Collection of samplers to be run against probes in this group. std::vector<sampler_entry> samplers_; - /// lookup table for probe ids -> local probe handle indices + // Lookup table for probe ids -> local probe handle indices. std::vector<std::size_t> probe_handle_divisions_; - /// lookup table for target ids -> local target handle indices + // Lookup table for target ids -> local target handle indices. std::vector<std::size_t> target_handle_divisions_; - /// build handle index lookup tables + // Build handle index lookup tables. template <typename Cells> void build_handle_partitions(const Cells& cells) { auto probe_counts = util::transform_view(cells, [](const cell& c) { return c.probes().size(); }); @@ -238,7 +308,7 @@ private: make_partition(target_handle_divisions_, target_counts); } - /// use handle partition to get index from id + // Use handle partition to get index from id. template <typename Divisions> std::size_t handle_partition_lookup(const Divisions& divisions, cell_member_type id) const { // NB: without any assertion checking, this would just be: @@ -256,12 +326,12 @@ private: return i; } - /// get probe handle from probe id + // Get probe handle from probe id. probe_handle get_probe_handle(cell_member_type probe_id) const { return probe_handles_[handle_partition_lookup(probe_handle_divisions_, probe_id)]; } - /// get target handle from target id + // Get target handle from target id. target_handle get_target_handle(cell_member_type target_id) const { return target_handles_[handle_partition_lookup(target_handle_divisions_, target_id)]; } diff --git a/tests/unit/common.hpp b/tests/unit/common.hpp index 687bfef82e4f7ac895e17fdce17b0a5e64f77d38..dc7b5a357fdbc8e0a1aa25ec993a36e3d9c6427f 100644 --- a/tests/unit/common.hpp +++ b/tests/unit/common.hpp @@ -5,6 +5,8 @@ * more than one unit test. */ +#include "../gtest.h" + namespace testing { // sentinel for use with range-related tests @@ -106,4 +108,37 @@ int nomove<V>::copy_ctor_count; template <typename V> int nomove<V>::copy_assign_count; +// Google Test assertion-returning predicates: + +// Assert two sequences of floating point values are almost equal. +// (Uses internal class `FloatingPoint` from gtest.) +template <typename FPType, typename Seq1, typename Seq2> +::testing::AssertionResult seq_almost_eq(Seq1&& seq1, Seq2&& seq2) { + using std::begin; + using std::end; + + auto i1 = begin(seq1); + auto i2 = begin(seq2); + + auto e1 = end(seq1); + auto e2 = end(seq2); + + for (std::size_t j = 0; i1!=e1 && i2!=e2; ++i1, ++i2, ++j) { + using FP = testing::internal::FloatingPoint<FPType>; + + auto v1 = *i1; + auto v2 = *i2; + + if (!FP{v1}.AlmostEquals(FP{v2})) { + return ::testing::AssertionFailure() << "floating point numbers " << v1 << " and " << v2 << " differ at index " << j; + } + + } + + if (i1!=e1 || i2!=e2) { + return ::testing::AssertionFailure() << "sequences differ in length"; + } + return ::testing::AssertionSuccess(); +} + } // namespace testing diff --git a/tests/unit/test_cell_group.cpp b/tests/unit/test_cell_group.cpp index f20945ad0f639708ef3fd7742acd6aead554d154..afd2d20354b22f49f461a6427d50a64805bd02f0 100644 --- a/tests/unit/test_cell_group.cpp +++ b/tests/unit/test_cell_group.cpp @@ -5,10 +5,11 @@ #include <fvm_multicell.hpp> #include <util/rangeutil.hpp> +#include "common.hpp" #include "../test_common_cells.hpp" -using fvm_cell = - nest::mc::fvm::fvm_multicell<nest::mc::multicore::backend>; +using namespace nest::mc; +using fvm_cell = fvm::fvm_multicell<nest::mc::multicore::backend>; nest::mc::cell make_cell() { using namespace nest::mc; @@ -21,10 +22,7 @@ nest::mc::cell make_cell() { return cell; } -TEST(cell_group, test) -{ - using namespace nest::mc; - +TEST(cell_group, test) { using cell_group_type = cell_group<fvm_cell>; auto group = cell_group_type{0, util::singleton_view(make_cell())}; @@ -35,10 +33,7 @@ TEST(cell_group, test) EXPECT_EQ(4u, group.spikes().size()); } -TEST(cell_group, sources) -{ - using namespace nest::mc; - +TEST(cell_group, sources) { using cell_group_type = cell_group<fvm_cell>; auto cell = make_cell(); @@ -66,3 +61,101 @@ TEST(cell_group, sources) } } } + +TEST(cell_group, event_binner) { + using testing::seq_almost_eq; + + std::pair<cell_gid_type, float> binning_test_data[] = { + { 11, 0.50 }, + { 12, 0.70 }, + { 14, 0.73 }, + { 11, 1.80 }, + { 12, 1.83 }, + { 11, 1.90 }, + { 11, 2.00 }, + { 14, 2.00 }, + { 11, 2.10 }, + { 14, 2.30 } + }; + + std::unordered_map<cell_gid_type, std::vector<float>> ev_times; + std::vector<float> expected; + + auto run_binner = [&](event_binner&& binner) { + ev_times.clear(); + for (auto p: binning_test_data) { + ev_times[p.first].push_back(binner.bin(p.first, p.second)); + } + }; + + run_binner(event_binner{binning_kind::none, 0}); + + EXPECT_TRUE(seq_almost_eq<float>(ev_times[11], (float []){0.50, 1.80, 1.90, 2.00, 2.10})); + EXPECT_TRUE(seq_almost_eq<float>(ev_times[12], (float []){0.70, 1.83})); + EXPECT_TRUE(ev_times[13].empty()); + EXPECT_TRUE(seq_almost_eq<float>(ev_times[14], (float []){0.73, 2.00, 2.30})); + + run_binner(event_binner{binning_kind::regular, 0.25}); + + EXPECT_TRUE(seq_almost_eq<float>(ev_times[11], (float []){0.50, 1.75, 1.75, 2.00, 2.00})); + EXPECT_TRUE(seq_almost_eq<float>(ev_times[12], (float []){0.50, 1.75})); + EXPECT_TRUE(ev_times[13].empty()); + EXPECT_TRUE(seq_almost_eq<float>(ev_times[14], (float []){0.50, 2.00, 2.25})); + + run_binner(event_binner{binning_kind::following, 0.25}); + + EXPECT_TRUE(seq_almost_eq<float>(ev_times[11], (float []){0.50, 1.80, 1.80, 1.80, 2.10})); + EXPECT_TRUE(seq_almost_eq<float>(ev_times[12], (float []){0.70, 1.83})); + EXPECT_TRUE(ev_times[13].empty()); + EXPECT_TRUE(seq_almost_eq<float>(ev_times[14], (float []){0.73, 2.00, 2.30})); +} + +TEST(cell_group, event_binner_with_min) { + using testing::seq_almost_eq; + + struct test_time { + float time; + float t_min; + }; + test_time test_data[] = { + {0.8f, 1.0f}, + {1.6f, 1.0f}, + {1.9f, 1.8f}, + {2.0f, 1.8f}, + {2.2f, 1.8f} + }; + + std::vector<float> times; + auto run_binner = [&](event_binner&& binner, bool use_min) { + times.clear(); + for (auto p: test_data) { + if (use_min) { + times.push_back(binner.bin(0, p.time, p.t_min)); + } + else { + times.push_back(binner.bin(0, p.time)); + } + } + }; + + // 'none' binning + run_binner(event_binner{binning_kind::none, 0.5}, false); + EXPECT_TRUE(seq_almost_eq<float>(times, (float []){0.8, 1.6, 1.9, 2.0, 2.2})); + + run_binner(event_binner{binning_kind::none, 0.5}, true); + EXPECT_TRUE(seq_almost_eq<float>(times, (float []){1.0, 1.6, 1.9, 2.0, 2.2})); + + // 'regular' binning + run_binner(event_binner{binning_kind::regular, 0.5}, false); + EXPECT_TRUE(seq_almost_eq<float>(times, (float []){0.5, 1.5, 1.5, 2.0, 2.0})); + + run_binner(event_binner{binning_kind::regular, 0.5}, true); + EXPECT_TRUE(seq_almost_eq<float>(times, (float []){1.0, 1.5, 1.8, 2.0, 2.0})); + + // 'following' binning + run_binner(event_binner{binning_kind::following, 0.5}, false); + EXPECT_TRUE(seq_almost_eq<float>(times, (float []){0.8, 1.6, 1.6, 1.6, 2.2})); + + run_binner(event_binner{binning_kind::following, 0.5}, true); + EXPECT_TRUE(seq_almost_eq<float>(times, (float []){1.0, 1.6, 1.8, 1.8, 2.2})); +}