diff --git a/src/cell_group.hpp b/src/cell_group.hpp index 7e26da0f4af332993c8a70043b5353b9d2b7cc5c..091fd0b1f5c89ec9c76bd2e432dacb032bdbc504 100644 --- a/src/cell_group.hpp +++ b/src/cell_group.hpp @@ -12,7 +12,6 @@ #include <sampling.hpp> #include <schedule.hpp> #include <spike.hpp> -#include <util/rangeutil.hpp> namespace arb { @@ -24,24 +23,8 @@ public: virtual void reset() = 0; virtual void set_binning_policy(binning_kind policy, time_type bin_interval) = 0; - virtual void advance(epoch epoch, time_type dt) = 0; - - // Pass events to be delivered to targets in the cell group in a future epoch. - // events: - // An unsorted vector of post-synaptic events is maintained for each gid - // on the local domain. These event lists are stored in a vector, with one - // entry for each gid. Event lists for a cell group are contiguous in the - // vector, in same order that input gid were provided to the cell_group - // constructor. - // tfinal: - // The final time for the current integration epoch. This may be used - // by the cell_group implementation to omptimise event queue wrangling. - // epoch: - // The current integration epoch. Events in events are due for delivery - // in epoch+1 and later. - virtual void enqueue_events( - epoch epoch, - util::subrange_view_type<std::vector<std::vector<postsynaptic_spike_event>>> events) = 0; + virtual void advance(epoch epoch, time_type dt, const event_lane_subrange& events) = 0; + virtual const std::vector<spike>& spikes() const = 0; virtual void clear_spikes() = 0; diff --git a/src/communication/communicator.hpp b/src/communication/communicator.hpp index d1813e134baefa54a84d70652bfe2ae530668620..a5b13639865c6e44467d838c39b177e1a8c80176 100644 --- a/src/communication/communicator.hpp +++ b/src/communication/communicator.hpp @@ -39,9 +39,6 @@ class communicator { public: using communication_policy_type = CommunicationPolicy; - /// per-cell group lists of events to be delivered - using event_queue = std::vector<postsynaptic_spike_event>; - communicator() {} explicit communicator(const recipe& rec, const domain_decomposition& dom_dec) { @@ -158,12 +155,12 @@ public: /// Returns a vector of event queues, with one queue for each local cell group. The /// events in each queue are all events that must be delivered to targets in that cell /// group as a result of the global spike exchange. - std::vector<event_queue> make_event_queues(const gathered_vector<spike>& global_spikes) { + std::vector<pse_vector> make_event_queues(const gathered_vector<spike>& global_spikes) { using util::subrange_view; using util::make_span; using util::make_range; - auto queues = std::vector<event_queue>(num_local_cells_); + auto queues = std::vector<pse_vector>(num_local_cells_); const auto& sp = global_spikes.partition(); const auto& cp = connection_part_; for (auto dom: make_span(0, num_domains_)) { @@ -220,6 +217,10 @@ public: /// Returns the total number of global spikes over the duration of the simulation std::uint64_t num_spikes() const { return num_spikes_; } + cell_size_type num_local_cells() const { + return num_local_cells_; + } + const std::vector<connection>& connections() const { return connections_; } diff --git a/src/dss_cell_group.hpp b/src/dss_cell_group.hpp index c3466572b26cc2b26eb493f78fca32c229d8c1d7..e4747fec5887a373d294e1063d2804dff0d1b1e3 100644 --- a/src/dss_cell_group.hpp +++ b/src/dss_cell_group.hpp @@ -44,7 +44,7 @@ public: void set_binning_policy(binning_kind policy, time_type bin_interval) override {} - void advance(epoch ep, time_type dt) override { + void advance(epoch ep, time_type dt, const event_lane_subrange& event_lanes) override { for (auto i: util::make_span(0, not_emit_it_.size())) { // The first potential spike_time to emit for this cell auto spike_time_it = not_emit_it_[i]; @@ -62,10 +62,6 @@ public: } }; - void enqueue_events(epoch, util::subrange_view_type<std::vector<std::vector<postsynaptic_spike_event>>>) override { - std::runtime_error("The dss_cells do not support incoming events!"); - } - const std::vector<spike>& spikes() const override { return spikes_; } diff --git a/src/event_queue.hpp b/src/event_queue.hpp index 02dfd2bbd5b49743a1ac4643e74811ec353932eb..5675e1d1719e7e8f60bcea21e2bbbb4ea3607362 100644 --- a/src/event_queue.hpp +++ b/src/event_queue.hpp @@ -7,12 +7,13 @@ #include <type_traits> #include <utility> -#include "common_types.hpp" -#include "generic_event.hpp" -#include "util/meta.hpp" -#include "util/optional.hpp" -#include "util/range.hpp" -#include "util/strprintf.hpp" +#include <common_types.hpp> +#include <generic_event.hpp> +#include <util/meta.hpp> +#include <util/optional.hpp> +#include <util/range.hpp> +#include <util/rangeutil.hpp> +#include <util/strprintf.hpp> namespace arb { @@ -28,16 +29,23 @@ struct postsynaptic_spike_event { time_type time; float weight; - friend bool operator==(postsynaptic_spike_event l, postsynaptic_spike_event r) { + friend bool operator==(const postsynaptic_spike_event& l, const postsynaptic_spike_event& r) { return l.target==r.target && l.time==r.time && l.weight==r.weight; } + friend bool operator<(const postsynaptic_spike_event& l, const postsynaptic_spike_event& r) { + return std::tie(l.time, l.target, l.weight) < std::tie(r.time, r.target, r.weight); + } + friend std::ostream& operator<<(std::ostream& o, const arb::postsynaptic_spike_event& e) { return o << "E[tgt " << e.target << ", t " << e.time << ", w " << e.weight << "]"; } }; +using pse_vector = std::vector<postsynaptic_spike_event>; +using event_lane_subrange = util::subrange_view_type<std::vector<pse_vector>>; + template <typename Event> class event_queue { public : diff --git a/src/mc_cell_group.hpp b/src/mc_cell_group.hpp index 845e9b3da18e198afe63a5da93196cb745e284a9..df3badc6eb9b4340eedbfb9383f2e47652b95788 100644 --- a/src/mc_cell_group.hpp +++ b/src/mc_cell_group.hpp @@ -67,13 +67,6 @@ public: } } spike_sources_.shrink_to_fit(); - - // Create event lane buffers. - // There is one set for each epoch: current and next. - event_lanes_.resize(2); - // For each epoch there is one lane for each cell in the cell group. - event_lanes_[0].resize(gids_.size()); - event_lanes_[1].resize(gids_.size()); } cell_kind get_cell_kind() const override { @@ -87,11 +80,6 @@ public: b.reset(); } lowered_.reset(); - for (auto& lanes: event_lanes_) { - for (auto& lane: lanes) { - lane.clear(); - } - } } void set_binning_policy(binning_kind policy, time_type bin_interval) override { @@ -99,23 +87,24 @@ public: binners_.resize(gids_.size(), event_binner(policy, bin_interval)); } - - void advance(epoch ep, time_type dt) override { + void advance(epoch ep, time_type dt, const event_lane_subrange& event_lanes) override { PE("advance"); EXPECTS(lowered_.state_synchronized()); time_type tstart = lowered_.min_time(); PE("event-setup"); - const auto& lc = event_lanes(ep.id); staged_events_.clear(); - for (auto lid: util::make_span(0, gids_.size())) { - auto& lane = lc[lid]; - for (auto e: lane) { - if (e.time>=ep.tfinal) break; - e.time = binners_[lid].bin(e.time, tstart); - auto h = target_handles_[target_handle_divisions_[lid]+e.target.index]; - auto ev = deliverable_event(e.time, h, e.weight); - staged_events_.push_back(ev); + // skip event binning if empty lanes are passed + if (event_lanes.size()) { + for (auto lid: util::make_span(0, gids_.size())) { + auto& lane = event_lanes[lid]; + for (auto e: lane) { + if (e.time>=ep.tfinal) break; + e.time = binners_[lid].bin(e.time, tstart); + auto h = target_handles_[target_handle_divisions_[lid]+e.target.index]; + auto ev = deliverable_event(e.time, h, e.weight); + staged_events_.push_back(ev); + } } } PL(); @@ -230,54 +219,6 @@ public: PL(); } - void enqueue_events( - epoch ep, - util::subrange_view_type<std::vector<std::vector<postsynaptic_spike_event>>> events) override - { - using pse = postsynaptic_spike_event; - - // Make convenience variables for event lanes: - // lf: event lanes for the next epoch - // lc: event lanes for the current epoch - auto& lf = event_lanes(ep.id+1); - const auto& lc = event_lanes(ep.id); - - // For a cell, merge the incoming events with events in lc that are - // not to be delivered in thie epoch, and store the result in lf. - auto merge_future_events = [&](unsigned l) { - PE("sort"); - // STEP 1: sort events in place in events[l] - util::sort_by(events[l], [](const pse& e){return event_time(e);}); - PL(); - - PE("merge"); - // STEP 2: clear lf to store merged list - lf[l].clear(); - - // STEP 3: merge new events and future events from lc into lf - auto pos = std::lower_bound(lc[l].begin(), lc[l].end(), ep.tfinal, event_time_less()); - lf[l].resize(events[l].size()+std::distance(pos, lc[l].end())); - std::merge( - events[l].begin(), events[l].end(), pos, lc[l].end(), lf[l].begin(), - [](const pse& l, const pse& r) {return l.time<r.time;}); - PL(); - }; - - // This operation is independent for each cell, so it can be performed - // in parallel if there are sufficient cells in the cell group. - // TODO: use parallel loop based on argument to this function? This would - // allow the caller, which has more context, to decide whether a - // parallel loop is beneficial. - if (gids_.size()>1) { - threading::parallel_for::apply(0, gids_.size(), merge_future_events); - } - else { - for (unsigned l=0; l<gids_.size(); ++l) { - merge_future_events(l); - } - } - } - const std::vector<spike>& spikes() const override { return spikes_; } @@ -310,10 +251,6 @@ public: } private: - std::vector<std::vector<postsynaptic_spike_event>>& event_lanes(std::size_t epoch_id) { - return event_lanes_[epoch_id%2]; - } - // List of the gids of the cells in the group. std::vector<cell_gid_type> gids_; @@ -332,8 +269,7 @@ private: // Event time binning manager. std::vector<event_binner> binners_; - // Pending events to be delivered. - std::vector<std::vector<std::vector<postsynaptic_spike_event>>> event_lanes_; + // List of events to deliver std::vector<deliverable_event> staged_events_; // Pending samples to be taken. diff --git a/src/model.cpp b/src/model.cpp index 42833540a8e66a330b41b8962323fc3646656656..de395019003b88ede71a39141acbf08436323f1d 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -1,3 +1,4 @@ +#include <set> #include <vector> #include <backends.hpp> @@ -6,18 +7,22 @@ #include <domain_decomposition.hpp> #include <model.hpp> #include <recipe.hpp> +#include <util/filter.hpp> #include <util/span.hpp> #include <util/unique_any.hpp> #include <profiling/profiler.hpp> namespace arb { +void merge_events(time_type tfinal, const pse_vector& lc, pse_vector& events, pse_vector& lf); + model::model(const recipe& rec, const domain_decomposition& decomp): communicator_(rec, decomp) { + cell_local_size_type lidx = 0; for (auto i: util::make_span(0, decomp.groups.size())) { for (auto gid: decomp.groups[i].gids) { - gid_groups_[gid] = i; + gid_to_local_[gid] = lidx++; } } @@ -29,14 +34,28 @@ model::model(const recipe& rec, const domain_decomposition& decomp): cell_groups_[i] = cell_group_factory(rec, decomp.groups[i]); PL(2); }); + + + // Create event lane buffers. + // There is one set for each epoch: current (0) and next (1). + // For each epoch there is one lane for each cell in the cell group. + event_lanes_[0].resize(communicator_.num_local_cells()); + event_lanes_[1].resize(communicator_.num_local_cells()); } void model::reset() { t_ = 0.; + for (auto& group: cell_groups_) { group->reset(); } + for (auto& lanes: event_lanes_) { + for (auto& lane: lanes) { + lane.clear(); + } + } + communicator_.reset(); current_spikes().clear(); @@ -63,8 +82,10 @@ time_type model::run(time_type tfinal, time_type dt) { PE("stepping"); auto &group = cell_groups_[i]; - group->advance(epoch_, dt); - + auto queues = util::subrange_view( + event_lanes(epoch_.id), + communicator_.group_queue_range(i)); + group->advance(epoch_, dt, queues); PE("events"); current_spikes().insert(group->spikes()); group->clear_spikes(); @@ -94,11 +115,15 @@ time_type model::run(time_type tfinal, time_type dt) { PL(); PE("enqueue"); - for (auto i: util::make_span(0, cell_groups_.size())) { - cell_groups_[i]->enqueue_events( - epoch_, - util::subrange_view(events, communicator_.group_queue_range(i))); - } + threading::parallel_for::apply(0, num_groups(), + [&](cell_size_type i) { + const auto epid = epoch_.id; + merge_events( + epoch_.tfinal, + event_lanes(epid)[i], + events[i], + event_lanes(epid+1)[i]); + }); PL(2); PL(2); @@ -171,16 +196,16 @@ std::size_t model::num_groups() const { return cell_groups_.size(); } +std::vector<pse_vector>& model::event_lanes(std::size_t epoch_id) { + return event_lanes_[epoch_id%2]; +} + void model::set_binning_policy(binning_kind policy, time_type bin_interval) { for (auto& group: cell_groups_) { group->set_binning_policy(policy, bin_interval); } } -cell_group& model::group(int i) { - return *cell_groups_[i]; -} - void model::set_global_spike_callback(spike_export_function export_callback) { global_export_callback_ = export_callback; } @@ -189,4 +214,67 @@ void model::set_local_spike_callback(spike_export_function export_callback) { local_export_callback_ = export_callback; } +util::optional<cell_size_type> model::local_cell_index(cell_gid_type gid) { + auto it = gid_to_local_.find(gid); + return it==gid_to_local_.end()? + util::nothing: + util::optional<cell_size_type>(it->second); +} + +void model::inject_events(const pse_vector& events) { + auto& lanes = event_lanes(epoch_.id); + + // Append all events that are to be delivered to local cells to the + // appropriate lane. At the same time, keep track of which lanes have been + // modified, because the lanes will have to be sorted once all events have + // been added. + pse_vector local_events; + std::set<cell_size_type> modified_lanes; + for (auto& e: events) { + if (e.time<t_) { + throw std::runtime_error("model::inject_events(): attempt to inject an event at time " + std::to_string(e.time) + ", when model state is at time " + std::to_string(t_)); + } + if (auto lidx = local_cell_index(e.target.gid)) { + lanes[*lidx].push_back(e); + modified_lanes.insert(*lidx); + } + } + + // Sort events in the event lanes that were modified + for (auto l: modified_lanes) { + util::sort(lanes[l]); + } +} + +// Merge events that are to be delivered from two lists into a sorted list. +// Events are sorted by delivery time, then target, then weight. +// +// tfinal: The time at which the current epoch finishes. The output list, `lc`, +// will contain all events with delivery times equal to or greater than +// tfinal. +// lc: Sorted set of events to be delivered before and after `tfinal`. +// events: Unsorted list of events with delivery time greater than or equal to +// tfinal. May be modified by the call. +// lf: Will hold a list of all postsynaptic events in `events` and `lc` that +// have delivery times greater than or equal to `tfinal`. +void merge_events(time_type tfinal, const pse_vector& lc, pse_vector& events, pse_vector& lf) { + // Merge the incoming events with events in lc that are not to be delivered + // in this epoch, and store the result in lf. + + // STEP 1: sort events in place in events[l] + PE("sort"); + util::sort(events); + PL(); + + // STEP 2: clear lf to store merged list + lf.clear(); + + // STEP 3: merge new events and future events from lc into lf + PE("merge"); + auto pos = std::lower_bound(lc.begin(), lc.end(), tfinal, event_time_less()); + lf.resize(events.size()+std::distance(pos, lc.end())); + std::merge(events.begin(), events.end(), pos, lc.end(), lf.begin()); + PL(); +} + } // namespace arb diff --git a/src/model.hpp b/src/model.hpp index 6ad99c7c98ef4d6d97394f7791c7c7a62a6acd36..b4abcdabe9576c17f316c797269ad5ba2143c8c4 100644 --- a/src/model.hpp +++ b/src/model.hpp @@ -1,6 +1,7 @@ #pragma once -#include <mutex> +#include <array> +#include <unordered_map> #include <vector> #include <backends.hpp> @@ -45,20 +46,22 @@ public: // Set event binning policy on all our groups. void set_binning_policy(binning_kind policy, time_type bin_interval); - // access cell_group directly - // TODO: deprecate. Currently used in some validation tests to inject - // events directly into a cell group. - cell_group& group(int i); - - // register a callback that will perform a export of the global - // spike vector + // Register a callback that will perform a export of the global + // spike vector. void set_global_spike_callback(spike_export_function export_callback); - // register a callback that will perform a export of the rank local - // spike vector + // Register a callback that will perform a export of the rank local + // spike vector. void set_local_spike_callback(spike_export_function export_callback); + // Add events directly to targets. + // Must be called before calling model::run, and must contain events that + // are to be delivered at or after the current model time. + void inject_events(const pse_vector& events); + private: + std::vector<pse_vector>& event_lanes(std::size_t epoch_id); + std::size_t num_groups() const; // keep track of information about the current integration interval @@ -73,9 +76,10 @@ private: spike_export_function global_export_callback_ = util::nop_function; spike_export_function local_export_callback_ = util::nop_function; - // Hash table for looking up the group index of the cell_group that - // contains gid - std::unordered_map<cell_gid_type, cell_gid_type> gid_groups_; + // Hash table for looking up the the local index of a cell with a given gid + std::unordered_map<cell_gid_type, cell_size_type> gid_to_local_; + + util::optional<cell_size_type> local_cell_index(cell_gid_type); communicator_type communicator_; @@ -94,6 +98,9 @@ private: local_spike_store_type& current_spikes() { return local_spikes_.get(); } local_spike_store_type& previous_spikes() { return local_spikes_.other(); } + // Pending events to be delivered. + std::array<std::vector<pse_vector>, 2> event_lanes_; + // Sampler associations handles are managed by a helper class. util::handle_set<sampler_association_handle> sassoc_handles_; }; diff --git a/src/rss_cell_group.hpp b/src/rss_cell_group.hpp index 6d46a544267bd083ab12b3606bc2bd8cfef56030..f9e292962714b118525ef217143e60f4a8d8a8cb 100644 --- a/src/rss_cell_group.hpp +++ b/src/rss_cell_group.hpp @@ -34,7 +34,7 @@ public: void set_binning_policy(binning_kind policy, time_type bin_interval) override {} - void advance(epoch ep, time_type dt) override { + void advance(epoch ep, time_type dt, const event_lane_subrange& events) override { for (const auto& cell: cells_) { auto t = std::max(cell.start_time, time_); auto t_end = std::min(cell.stop_time, ep.tfinal); @@ -47,10 +47,6 @@ public: time_ = ep.tfinal; } - void enqueue_events(epoch, util::subrange_view_type<std::vector<std::vector<postsynaptic_spike_event>>>) override { - std::logic_error("rss_cell cannot deliver events"); - } - const std::vector<spike>& spikes() const override { return spikes_; } @@ -59,7 +55,7 @@ public: spikes_.clear(); } - virtual void add_sampler(sampler_association_handle, cell_member_predicate, schedule, sampler_function, sampling_policy) { + void add_sampler(sampler_association_handle, cell_member_predicate, schedule, sampler_function, sampling_policy) override { std::logic_error("rss_cell does not support sampling"); } diff --git a/src/util/rangeutil.hpp b/src/util/rangeutil.hpp index 1556877a0966967d2f64fc47f5d55475c6d0ac91..e9f70699140abc2be75dd85cf55627ccfe00c7ad 100644 --- a/src/util/rangeutil.hpp +++ b/src/util/rangeutil.hpp @@ -146,6 +146,13 @@ sort(Seq& seq) { std::sort(std::begin(canon), std::end(canon)); } +template <typename Seq, typename Less> +enable_if_t<!std::is_const<typename sequence_traits<Seq>::reference>::value> +sort(Seq& seq, const Less& less) { + auto canon = canonical_view(seq); + std::sort(std::begin(canon), std::end(canon), less); +} + template <typename Seq> enable_if_t<!std::is_const<typename sequence_traits<Seq>::reference>::value> sort(const Seq& seq) { @@ -153,6 +160,13 @@ sort(const Seq& seq) { std::sort(std::begin(canon), std::end(canon)); } +template <typename Seq, typename Less> +enable_if_t<!std::is_const<typename sequence_traits<Seq>::reference>::value> +sort(const Seq& seq, const Less& less) { + auto canon = canonical_view(seq); + std::sort(std::begin(canon), std::end(canon), less); +} + // Sort in-place by projection `proj` template <typename Seq, typename Proj> diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 3c64b0a497dad0e282b77c28f57fbfe3b106992b..c0f0d420b136d08b2b33ecd7c535c0d996e53023 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -54,6 +54,7 @@ set(TEST_SOURCES test_math.cpp test_matrix.cpp test_mechanisms.cpp + test_merge_events.cpp test_multi_event_stream.cpp test_nop.cpp test_optional.cpp diff --git a/tests/unit/test_dss_cell_group.cpp b/tests/unit/test_dss_cell_group.cpp index 4baa8d0c48f801aeb3e2ce31e5007fb3c56f6c01..ed9fc3c4e4db87af95a9b7570a240399836bb67d 100644 --- a/tests/unit/test_dss_cell_group.cpp +++ b/tests/unit/test_dss_cell_group.cpp @@ -19,14 +19,14 @@ TEST(dss_cell, basic_usage) // No spikes in this time frame. time_type dt = 0.01; // (note that dt is ignored in dss_cell_group). epoch ep(0, 0.09); - sut.advance(ep, dt); + sut.advance(ep, dt, {}); auto spikes = sut.spikes(); EXPECT_EQ(0u, spikes.size()); // Only one in this time frame. ep.advance(0.11); - sut.advance(ep, dt); + sut.advance(ep, dt, {}); spikes = sut.spikes(); EXPECT_EQ(1u, spikes.size()); ASSERT_FLOAT_EQ(spike_time, spikes[0].time); @@ -38,7 +38,7 @@ TEST(dss_cell, basic_usage) // No spike to be emitted. ep.advance(0.12); - sut.advance(ep, dt); + sut.advance(ep, dt, {}); spikes = sut.spikes(); EXPECT_EQ(0u, spikes.size()); @@ -47,7 +47,7 @@ TEST(dss_cell, basic_usage) // Expect to have the one spike again after reset. ep.advance(0.2); - sut.advance(ep, dt); + sut.advance(ep, dt, {}); spikes = sut.spikes(); EXPECT_EQ(1u, spikes.size()); ASSERT_FLOAT_EQ(spike_time, spikes[0].time); diff --git a/tests/unit/test_mc_cell_group.cpp b/tests/unit/test_mc_cell_group.cpp index e926488d145946d28531a1a97db9f4fe7a239557..92d4f415e68478830fd6b532ce6b9e6a7745af93 100644 --- a/tests/unit/test_mc_cell_group.cpp +++ b/tests/unit/test_mc_cell_group.cpp @@ -33,7 +33,7 @@ TEST(mc_cell_group, get_kind) { TEST(mc_cell_group, test) { mc_cell_group<fvm_cell> group{{0}, cable1d_recipe(make_cell()) }; - group.advance(epoch(0, 50), 0.01); + group.advance(epoch(0, 50), 0.01, {}); // the model is expected to generate 4 spikes as a result of the // fixed stimulus over 50 ms diff --git a/tests/unit/test_merge_events.cpp b/tests/unit/test_merge_events.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f60fca0944128a20fe44bf5a48de623bc8a78fc0 --- /dev/null +++ b/tests/unit/test_merge_events.cpp @@ -0,0 +1,115 @@ +#include "../gtest.h" + +#include <event_queue.hpp> +#include <model.hpp> + +namespace arb { + // Declare prototype of the merge_events function, because it is only + // defined in the TU of model.cpp + void merge_events(time_type tfinal, const pse_vector& lc, pse_vector& events, pse_vector& lf); +} // namespace arb + +using namespace arb; + +std::ostream& operator<<(std::ostream& o, const pse_vector& events) { + o << "{{"; for (auto e: events) o << " " << e; + o << "}}"; + return o; +} + +using pse = postsynaptic_spike_event; +auto ev_bind = [] (const pse& e){ return std::tie(e.time, e.target, e.weight); }; +auto ev_less = [] (const pse& l, const pse& r){ return ev_bind(l)<ev_bind(r); }; + +// Test the trivial case of merging empty sets +TEST(merge_events, empty) +{ + pse_vector events; + pse_vector lc; + pse_vector lf; + + merge_events(0, lc, events, lf); + + EXPECT_EQ(lf.size(), 0u); +} + +// Test the case where there are no events in lc that are to be delivered +// after tfinal. +TEST(merge_events, no_overlap) +{ + pse_vector lc = { + {{0, 0}, 1, 1}, + {{0, 0}, 2, 1}, + {{0, 0}, 3, 3}, + }; + // Check that the inputs satisfy the precondition that lc is sorted. + EXPECT_TRUE(std::is_sorted(lc.begin(), lc.end(), ev_less)); + + // These events should be removed from lf by merge_events, and replaced + // with events to be delivered after t=10 + pse_vector lf = { + {{0, 0}, 1, 1}, + {{0, 0}, 2, 1}, + {{0, 0}, 3, 3}, + }; + + pse_vector events = { + {{0, 0}, 12, 1}, + {{0, 0}, 11, 2}, + {{8, 0}, 10, 4}, + {{0, 0}, 11, 1}, + }; + + merge_events(10, lc, events, lf); + + pse_vector expected = { + {{8, 0}, 10, 4}, + {{0, 0}, 11, 1}, + {{0, 0}, 11, 2}, + {{0, 0}, 12, 1}, + }; + + EXPECT_TRUE(std::is_sorted(lf.begin(), lf.end(), ev_less)); + EXPECT_EQ(expected, lf); +} + +// Test case where current events (lc) contains events that must be deilvered +// in a future epoch, i.e. events with delivery time greater than the tfinal +// argument passed to merge_events. +TEST(merge_events, overlap) +{ + pse_vector lc = { + {{0, 0}, 1, 1}, + {{0, 0}, 2, 1}, + // The current epoch ends at t=10, so all events from here down are expected in lf. + {{8, 0}, 10, 2}, + {{0, 0}, 11, 3}, + }; + EXPECT_TRUE(std::is_sorted(lc.begin(), lc.end(), ev_less)); + + pse_vector lf; + + pse_vector events = { + // events are in reverse order: they should be sorted in the output of merge_events. + {{0, 0}, 12, 1}, + {{0, 0}, 11, 2}, + {{0, 0}, 11, 1}, + {{8, 0}, 10, 3}, + {{7, 0}, 10, 8}, + }; + + merge_events(10, lc, events, lf); + + pse_vector expected = { + {{7, 0}, 10, 8}, // from events + {{8, 0}, 10, 2}, // from lc + {{8, 0}, 10, 3}, // from events + {{0, 0}, 11, 1}, // from events + {{0, 0}, 11, 2}, // from events + {{0, 0}, 11, 3}, // from lc + {{0, 0}, 12, 1}, // from events + }; + + EXPECT_TRUE(std::is_sorted(lf.begin(), lf.end(), ev_less)); + EXPECT_EQ(expected, lf); +} diff --git a/tests/unit/test_range.cpp b/tests/unit/test_range.cpp index d666e45ffac614779e3a2187b4d8a942a81352ec..fb9012d3cbacd4805aae69a223f56f4c6e241a9b 100644 --- a/tests/unit/test_range.cpp +++ b/tests/unit/test_range.cpp @@ -402,6 +402,12 @@ TEST(range, assign_from) { } } +struct foo { + int x; + int y; + friend bool operator==(const foo& l, const foo& r) {return l.x==r.x && l.y==r.y;}; +}; + TEST(range, sort) { char cstr[] = "howdy"; @@ -428,6 +434,16 @@ TEST(range, sort) { util::stable_sort_by(util::strict_view(mixed_range), rank); EXPECT_EQ(std::string("HELLOthere54321"), mixed); + + + // sort with user-provided less comparison function + + std::vector<foo> X = {{0, 5}, {1, 4}, {2, 3}, {3, 2}, {4, 1}, {5, 0}}; + + util::sort(X, [](const foo& l, const foo& r) {return l.y<r.y;}); + EXPECT_EQ(X, (std::vector<foo>{{5, 0}, {4, 1}, {3, 2}, {2, 3}, {1, 4}, {0, 5}})); + util::sort(X, [](const foo& l, const foo& r) {return l.x<r.x;}); + EXPECT_EQ(X, (std::vector<foo>{{0, 5}, {1, 4}, {2, 3}, {3, 2}, {4, 1}, {5, 0}})); } TEST(range, sum_by) { diff --git a/tests/unit/test_rss_cell.cpp b/tests/unit/test_rss_cell.cpp index 690fe258361bc5dae7b2eee9568b9f70a9a91810..4b0b648c0f54c025b639068b5e8afa05b0ce1c23 100644 --- a/tests/unit/test_rss_cell.cpp +++ b/tests/unit/test_rss_cell.cpp @@ -20,13 +20,13 @@ TEST(rss_cell, basic_usage) // No spikes in this time frame. epoch ep(0, 0.1); - sut.advance(ep, dt); + sut.advance(ep, dt, {}); EXPECT_EQ(0u, sut.spikes().size()); // Only on in this time frame sut.clear_spikes(); ep.advance(0.127); - sut.advance(ep, dt); + sut.advance(ep, dt, {}); EXPECT_EQ(1u, sut.spikes().size()); // Reset cell group state. @@ -34,7 +34,7 @@ TEST(rss_cell, basic_usage) // Expect 12 spikes excluding the 0.5 end point. ep.advance(0.5); - sut.advance(ep, dt); + sut.advance(ep, dt, {}); EXPECT_EQ(12u, sut.spikes().size()); } @@ -46,19 +46,19 @@ TEST(rss_cell, poll_time_after_end_time) rss_cell_group sut({0}, rss_recipe(1u, desc)); // Expect 12 spikes in this time frame. - sut.advance(epoch(0, 0.7), dt); + sut.advance(epoch(0, 0.7), dt, {}); EXPECT_EQ(12u, sut.spikes().size()); // Now ask for spikes for a time slot already passed: // It should result in zero spikes because of the internal state! sut.clear_spikes(); - sut.advance(epoch(0, 0.2), dt); + sut.advance(epoch(0, 0.2), dt, {}); EXPECT_EQ(0u, sut.spikes().size()); sut.reset(); // Expect 12 excluding the 0.5 - sut.advance(epoch(0, 0.5), dt); + sut.advance(epoch(0, 0.5), dt, {}); EXPECT_EQ(12u, sut.spikes().size()); } diff --git a/tests/validation/validate_synapses.cpp b/tests/validation/validate_synapses.cpp index 633bd418676b87fb2c2a41a351e50f502c7bdf2a..7b611ca90b6d1346e89a7f1bc38a9d6252c24df7 100644 --- a/tests/validation/validate_synapses.cpp +++ b/tests/validation/validate_synapses.cpp @@ -42,11 +42,11 @@ void run_synapse_test( c.add_synapse({1, 0.5}, syn_default); // injected spike events - std::vector<std::vector<postsynaptic_spike_event>> synthetic_events = {{ + std::vector<postsynaptic_spike_event> synthetic_events = { {{0u, 0u}, 10.0, 0.04}, {{0u, 0u}, 20.0, 0.04}, {{0u, 0u}, 40.0, 0.04} - }}; + }; // exclude points of discontinuity from linf analysis std::vector<float> exclude = {10.f, 20.f, 40.f}; @@ -76,9 +76,7 @@ void run_synapse_test( auto decomp = partition_load_balance(rec, nd); model m(rec, decomp); - // Add events an odd epoch (1), so that they are available during integration of epoch 0. - // This is a bit of a hack. - m.group(0).enqueue_events(epoch(1, 0.), util::subrange_view(synthetic_events, 0, 1)); + m.inject_events(synthetic_events); runner.run(m, ncomp, sample_dt, t_end, dt, exclude); }