From 55aac4a9813fa00c2e03a5811d80a5e587ecf4e4 Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Mon, 5 Sep 2022 15:05:48 +0200 Subject: [PATCH] Remove explicit generator (#1962) - Remove the old, multi-target `event_generator` class in favour of `schedule_generator(tgt, weight, explicit_schedule)` - Much simplification ensues, `event_generator` is no longer a type-erasing container, but just what `schedule_generator` was before - Make the label resolution in generators a bit more eager, no longer at simulation time, but now during setup (bit give a wee bit of perf as well) Closes #1488 --- arbor/include/arbor/event_generator.hpp | 173 ++++-------------------- example/diffusion/diffusion.cpp | 2 +- example/dryrun/dryrun.cpp | 2 +- example/ring/ring.cpp | 2 +- python/recipe.cpp | 2 +- test/unit/test_diffusion.cpp | 6 +- test/unit/test_event_generators.cpp | 64 +++------ test/unit/test_merge_events.cpp | 46 +++---- test/unit/test_probe.cpp | 3 +- test/unit/test_recipe.cpp | 15 +- test/unit/test_simulation.cpp | 2 +- 11 files changed, 86 insertions(+), 231 deletions(-) diff --git a/arbor/include/arbor/event_generator.hpp b/arbor/include/arbor/event_generator.hpp index 85db87d8..5e14ecb4 100644 --- a/arbor/include/arbor/event_generator.hpp +++ b/arbor/include/arbor/event_generator.hpp @@ -5,6 +5,7 @@ #include <memory> #include <random> #include <type_traits> +#include <optional> #include <arbor/assert.hpp> #include <arbor/common_types.hpp> @@ -49,116 +50,22 @@ namespace arb { // and `events(t2, t3)` to the same event generator must satisfy // 0 ≤ t0 ≤ t1 ≤ t2 ≤ t3. // -// `event_generator` objects have value semantics, and use type erasure -// to wrap implementation details. An `event_generator` can be constructed -// from an object of an implementation class Impl that is copy-constructable -// and otherwise provides `reset` and `events` methods following the -// API described above. -// -// Some pre-defined event generators are included: -// - `empty_generator`: produces no events -// - `schedule_generator`: produces events according to a time schedule. -// A target is selected using a label resolution function for every generated -// event. -// - `explicit_generator`: is constructed from a vector of {label, time, weight} -// objects. Explicit targets are generated from the labels using a resolution -// function before the first call to the `events` method. +// `event_generator` objects have value semantics. using event_seq = std::pair<const spike_event*, const spike_event*>; using resolution_function = std::function<cell_lid_type(const cell_local_label_type&)>; -// The simplest possible generator that generates no events. -// Declared ahead of event_generator so that it can be used as the default -// generator. -struct empty_generator { - void reset() {} - event_seq events(time_type, time_type) { - return {nullptr, nullptr}; - } - void resolve_label(resolution_function) {} -}; - -class event_generator { -public: - event_generator(): event_generator(empty_generator()) {} - - template <typename Impl, std::enable_if_t<!std::is_same<std::decay_t<Impl>, event_generator>::value, int> = 0> - event_generator(Impl&& impl): - impl_(new wrap<Impl>(std::forward<Impl>(impl))) - {} - - event_generator(event_generator&& other) = default; - event_generator& operator=(event_generator&& other) = default; - - event_generator(const event_generator& other): - impl_(other.impl_->clone()) - {} - - event_generator& operator=(const event_generator& other) { - impl_ = other.impl_->clone(); - return *this; - } - - void reset() { - impl_->reset(); - } - - event_seq events(time_type t0, time_type t1) { - return impl_->events(t0, t1); - } - - void resolve_label(resolution_function label_resolver) { - impl_->resolve_label(std::move(label_resolver)); - } - -private: - struct interface { - virtual void reset() = 0; - virtual void resolve_label(resolution_function) = 0; - virtual event_seq events(time_type, time_type) = 0; - virtual std::unique_ptr<interface> clone() = 0; - virtual ~interface() {} - }; - - std::unique_ptr<interface> impl_; - - template <typename Impl> - struct wrap: interface { - explicit wrap(const Impl& impl): wrapped(impl) {} - explicit wrap(Impl&& impl): wrapped(std::move(impl)) {} - - event_seq events(time_type t0, time_type t1) override { - return wrapped.events(t0, t1); - } - - void reset() override { - wrapped.reset(); - } - - void resolve_label(resolution_function label_resolver) override { - wrapped.resolve_label(std::move(label_resolver)); - } - - std::unique_ptr<interface> clone() override { - return std::unique_ptr<interface>(new wrap<Impl>(wrapped)); - } - - Impl wrapped; - }; -}; - -// Convenience routines for making schedule_generator: // Generate events with a fixed target and weight according to // a provided time schedule. -struct schedule_generator { - schedule_generator(cell_local_label_type target, float weight, schedule sched): +struct event_generator { + event_generator(cell_local_label_type target, float weight, schedule sched): target_(std::move(target)), weight_(weight), sched_(std::move(sched)) {} void resolve_label(resolution_function label_resolver) { - label_resolver_ = std::move(label_resolver); + resolved_ = label_resolver(target_); } void reset() { @@ -166,13 +73,15 @@ struct schedule_generator { } event_seq events(time_type t0, time_type t1) { + if (!resolved_) throw ; + auto tgt = *resolved_; auto ts = sched_.events(t0, t1); events_.clear(); events_.reserve(ts.second-ts.first); for (auto i = ts.first; i!=ts.second; ++i) { - events_.push_back(spike_event{label_resolver_(target_), *i, weight_}); + events_.push_back(spike_event{tgt, *i, weight_}); } return {events_.data(), events_.data()+events_.size()}; @@ -182,10 +91,21 @@ private: pse_vector events_; cell_local_label_type target_; resolution_function label_resolver_; + std::optional<cell_lid_type> resolved_; float weight_; schedule sched_; }; +// Simplest generator: just do nothing +inline +event_generator empty_generator( + cell_local_label_type target, + float weight) +{ + return event_generator(std::move(target), weight, schedule()); +} + + // Generate events at integer multiples of dt that lie between tstart and tstop. inline event_generator regular_generator( @@ -195,7 +115,7 @@ inline event_generator regular_generator( time_type dt, time_type tstop=terminal_time) { - return schedule_generator(std::move(target), weight, regular_schedule(tstart, dt, tstop)); + return event_generator(std::move(target), weight, regular_schedule(tstart, dt, tstop)); } template <typename RNG> @@ -207,56 +127,19 @@ inline event_generator poisson_generator( const RNG& rng, time_type tstop=terminal_time) { - return schedule_generator(std::move(target), weight, poisson_schedule(tstart, rate_kHz, rng, tstop)); + return event_generator(std::move(target), weight, poisson_schedule(tstart, rate_kHz, rng, tstop)); } // Generate events from a predefined sorted event sequence. -struct explicit_generator { - struct labeled_synapse_event { - cell_local_label_type label; - time_type time; - float weight; - }; - - using lse_vector = std::vector<labeled_synapse_event>; - - explicit_generator() = default; - explicit_generator(const explicit_generator&) = default; - explicit_generator(explicit_generator&&) = default; - - explicit_generator(const lse_vector& events): - input_events_(events), start_index_(0) {} - - void resolve_label(resolution_function label_resolver) { - for (const auto& e: input_events_) { - events_.push_back({label_resolver(e.label), e.time, e.weight}); - } - std::sort(events_.begin(), events_.end()); - } - - void reset() { - start_index_ = 0; - } - - event_seq events(time_type t0, time_type t1) { - const spike_event* lb = events_.data()+start_index_; - const spike_event* ub = events_.data()+events_.size(); - - lb = std::lower_bound(lb, ub, t0, event_time_less{}); - ub = std::lower_bound(lb, ub, t1, event_time_less{}); - - start_index_ = ub-events_.data(); - return {lb, ub}; - } - -private: - lse_vector input_events_; - pse_vector events_; - std::size_t start_index_ = 0; -}; - +template<typename S> inline +event_generator explicit_generator(cell_local_label_type target, + float weight, + const S& s) +{ + return event_generator(std::move(target), weight, explicit_schedule(s)); +} } // namespace arb diff --git a/example/diffusion/diffusion.cpp b/example/diffusion/diffusion.cpp index e2ed2537..991a351d 100644 --- a/example/diffusion/diffusion.cpp +++ b/example/diffusion/diffusion.cpp @@ -31,7 +31,7 @@ struct linear: public recipe { cell_kind get_cell_kind(cell_gid_type) const override { return cell_kind::cable; } std::any get_global_properties(cell_kind) const override { return gprop; } std::vector<probe_info> get_probes(cell_gid_type) const override { return {cable_probe_ion_diff_concentration_cell{"na"}}; } - std::vector<event_generator> event_generators(cell_gid_type) const override { return {explicit_generator({{{"Zap"}, 0.0, 0.005}})}; } + std::vector<event_generator> event_generators(cell_gid_type) const override { return {explicit_generator({"Zap"}, 0.005, std::vector<float>{0.f})}; } util::unique_any get_cell_description(cell_gid_type) const override { // Stick morphology // -----|----- diff --git a/example/dryrun/dryrun.cpp b/example/dryrun/dryrun.cpp index 2e00e204..b2f244e8 100644 --- a/example/dryrun/dryrun.cpp +++ b/example/dryrun/dryrun.cpp @@ -112,7 +112,7 @@ public: std::vector<arb::event_generator> event_generators(cell_gid_type gid) const override { std::vector<arb::event_generator> gens; if (gid%20 == 0) { - gens.push_back(arb::explicit_generator({{{"synapse"}, 1.0, event_weight_}})); + gens.push_back(arb::explicit_generator({"synapse"}, event_weight_, std::vector<float>{1.0f})); } return gens; } diff --git a/example/ring/ring.cpp b/example/ring/ring.cpp index 72c4889e..cf896753 100644 --- a/example/ring/ring.cpp +++ b/example/ring/ring.cpp @@ -100,7 +100,7 @@ public: std::vector<arb::event_generator> event_generators(cell_gid_type gid) const override { std::vector<arb::event_generator> gens; if (!gid) { - gens.push_back(arb::explicit_generator({{{"primary_syn"}, 1.0, event_weight_}})); + gens.push_back(arb::explicit_generator({"primary_syn"}, event_weight_, std::vector<float>{1.0f})); } return gens; } diff --git a/python/recipe.cpp b/python/recipe.cpp index 83172209..54cef50e 100644 --- a/python/recipe.cpp +++ b/python/recipe.cpp @@ -107,7 +107,7 @@ static std::vector<arb::event_generator> convert_gen(std::vector<pybind11::objec auto& p = cast<const pyarb::event_generator_shim&>(g); // convert the event_generator to an arb::event_generator - gens.push_back(arb::schedule_generator(p.target, p.weight, std::move(p.time_sched))); + gens.push_back(arb::event_generator(p.target, p.weight, std::move(p.time_sched))); } return gens; diff --git a/test/unit/test_diffusion.cpp b/test/unit/test_diffusion.cpp index 4575ed7d..a3bcc6a4 100644 --- a/test/unit/test_diffusion.cpp +++ b/test/unit/test_diffusion.cpp @@ -54,7 +54,9 @@ struct linear: public recipe { std::vector<arb::event_generator> event_generators(arb::cell_gid_type gid) const override { std::vector<arb::event_generator> result; - for (const auto& [t, w]: inject_at) result.push_back(arb::explicit_generator({{{"Zap"}, t, w}})); + for (const auto& [t, w]: inject_at) { + result.push_back(arb::explicit_generator({"Zap"}, w, std::vector<float>{t})); + } return result; } @@ -62,7 +64,7 @@ struct linear: public recipe { double extent = 1.0, diameter = 1.0, cv_length = 1.0; - std::vector<std::tuple<double, float>> inject_at; + std::vector<std::tuple<float, float>> inject_at; morphology morph; arb::decor decor; diff --git a/test/unit/test_event_generators.cpp b/test/unit/test_event_generators.cpp index cdb97b82..066612c0 100644 --- a/test/unit/test_event_generators.cpp +++ b/test/unit/test_event_generators.cpp @@ -33,8 +33,7 @@ TEST(event_generators, assign_and_copy) { event_generator g1(gen); EXPECT_EQ(expected, first(g1.events(0., 1.))); - event_generator g2; - g2 = gen; + event_generator g2 = gen; EXPECT_EQ(expected, first(g2.events(0., 1.))); const auto& const_gen = gen; @@ -42,8 +41,7 @@ TEST(event_generators, assign_and_copy) { event_generator g3(const_gen); EXPECT_EQ(expected, first(g3.events(0., 1.))); - event_generator g4; - g4 = gen; + event_generator g4 = gen; EXPECT_EQ(expected, first(g4.events(0., 1.))); event_generator g5(std::move(gen)); @@ -82,54 +80,30 @@ TEST(event_generators, regular) { EXPECT_EQ(expected({12, 12.5}), as_vector(gen.events(12, 12.7))); } +using lse_vector = std::vector<std::tuple<cell_local_label_type, time_type, float>>; + TEST(event_generators, seq) { - explicit_generator::lse_vector in = { - {{"l0"}, 0.1, 1.0}, - {{"l0"}, 1.0, 2.0}, - {{"l2"}, 1.0, 3.0}, - {{"l1"}, 1.5, 4.0}, - {{"l2"}, 2.3, 5.0}, - {{"l0"}, 3.0, 6.0}, - {{"l0"}, 3.5, 7.0}, - }; - std::unordered_map<cell_tag_type, cell_lid_type> lid_map = {{"l0", 0},{"l1", 1}, {"l2", 2}}; + std::vector<arb::time_type> times = {1, 2, 3, 4, 5, 6, 7}; + lse_vector in; pse_vector expected; - std::transform(in.begin(), in.end(), std::back_inserter(expected), - [lid_map](const auto& item) {return spike_event{lid_map.at(item.label.tag), item.time, item.weight};}); - - event_generator gen = explicit_generator(in); - gen.resolve_label([lid_map](const cell_local_label_type& item) {return lid_map.at(item.tag);}); - - EXPECT_EQ(expected, as_vector(gen.events(0, 100.))); - gen.reset(); - EXPECT_EQ(expected, as_vector(gen.events(0, 100.))); - gen.reset(); + float weight = 0.42; + arb::cell_local_label_type l0 = {"l0"}; + for (auto time: times) { + in.push_back({l0, weight, time}); + expected.push_back({0, time, weight}); + } - // Check reported sub-intervals against a smaller set of events. - in = { - {{"l0"}, 1.5, 4.0}, - {{"l0"}, 2.3, 5.0}, - {{"l0"}, 3.0, 6.0}, - {{"l0"}, 3.5, 7.0}, - }; - expected.clear(); - std::transform(in.begin(), in.end(), std::back_inserter(expected), - [lid_map](const auto& item) {return spike_event{lid_map.at(item.label.tag), item.time, item.weight};}); + event_generator gen = explicit_generator(l0, weight, times); + gen.resolve_label([](const cell_local_label_type&) {return 0;}); - gen = explicit_generator(in); - gen.resolve_label([lid_map](const cell_local_label_type& item) {return lid_map.at(item.tag);}); + EXPECT_EQ(expected, as_vector(gen.events(0, 100.))); gen.reset(); + EXPECT_EQ(expected, as_vector(gen.events(0, 100.))); gen.reset(); - auto draw = [](event_generator& gen, time_type t0, time_type t1) { - gen.reset(); - return as_vector(gen.events(t0, t1)); - }; - - auto events = [&expected] (int b, int e) { - return pse_vector(expected.begin()+b, expected.begin()+e); - }; + auto draw = [](auto& gen, auto t0, auto t1) { gen.reset(); return as_vector(gen.events(t0, t1)); }; + auto events = [&expected] (int b, int e) { auto beg = expected.begin(); return pse_vector(beg+b, beg+e); }; // a range that includes all the events - EXPECT_EQ(expected, draw(gen, 0, 4)); + EXPECT_EQ(expected, draw(gen, 0, 8)); // a strict subset including the first event EXPECT_EQ(events(0, 2), draw(gen, 0, 3)); diff --git a/test/unit/test_merge_events.cpp b/test/unit/test_merge_events.cpp index f09cffc0..e264b014 100644 --- a/test/unit/test_merge_events.cpp +++ b/test/unit/test_merge_events.cpp @@ -179,40 +179,36 @@ TEST(merge_events, X) EXPECT_EQ(expected, lf); } -// Test the tournament tree for merging two small sequences -TEST(merge_events, tourney_seq) -{ - explicit_generator::lse_vector evs1 = { - {{"l0"}, 1, 1}, - {{"l0"}, 2, 2}, - {{"l0"}, 3, 3}, - {{"l0"}, 4, 4}, - {{"l0"}, 5, 5}, - }; + struct labeled_synapse_event { - explicit_generator::lse_vector evs2 = { - {{"l0"}, 1.5, 1}, - {{"l0"}, 2.5, 2}, - {{"l0"}, 3.5, 3}, - {{"l0"}, 4.5, 4}, - {{"l0"}, 5.5, 5}, }; - pse_vector expected; +using lse_vector = std::vector<std::tuple<cell_local_label_type, time_type, float>>; - auto gen_pse = [](const auto& item) {return spike_event{0, item.time, item.weight};}; - std::transform(evs1.begin(), evs1.end(), std::back_inserter(expected), gen_pse); - std::transform(evs2.begin(), evs2.end(), std::back_inserter(expected), gen_pse); +// Test the tournament tree for merging two small sequences +TEST(merge_events, tourney_seq) +{ + std::vector<arb::time_type> times {1, 2, 3, 4, 5}; + cell_local_label_type l0 = {"l0"}; + float w1 = 1.0f, w2 = 2.0f; + lse_vector evs1, evs2; + pse_vector expected; + for (const auto time: times) { + evs1.emplace_back(l0, w1, time); + evs2.emplace_back(l0, w2, time); + expected.push_back({0, time, w1}); + expected.push_back({0, time, w2}); + } util::sort(expected); - event_generator g1 = explicit_generator(evs1); - event_generator g2 = explicit_generator(evs2); + auto + g1 = explicit_generator(l0, w1, times), + g2 = explicit_generator(l0, w2, times); g1.resolve_label([](const cell_local_label_type&) {return 0;}); g2.resolve_label([](const cell_local_label_type&) {return 0;}); - std::vector<event_span> spans; - spans.emplace_back(g1.events(0, terminal_time)); - spans.emplace_back(g2.events(0, terminal_time)); + std::vector<event_span> spans = {g1.events(0, terminal_time), + g2.events(0, terminal_time)}; impl::tourney_tree tree(spans); pse_vector lf; diff --git a/test/unit/test_probe.cpp b/test/unit/test_probe.cpp index 61f5fef2..5790c9e3 100644 --- a/test/unit/test_probe.cpp +++ b/test/unit/test_probe.cpp @@ -1249,8 +1249,7 @@ void run_exact_sampling_probe_test(context ctx) { std::vector<event_generator> event_generators(cell_gid_type gid) const override { // Send a single event to cell i at 0.1*i milliseconds. - explicit_generator::lse_vector spikes = {{{"syn"}, 0.1*gid, 1.f}}; - return {explicit_generator(spikes)}; + return {explicit_generator({"syn"}, 1.0f, std::vector<float>{0.1f*gid})}; } std::any get_global_properties(cell_kind k) const override { diff --git a/test/unit/test_recipe.cpp b/test/unit/test_recipe.cpp index f1ffe120..b5b45205 100644 --- a/test/unit/test_recipe.cpp +++ b/test/unit/test_recipe.cpp @@ -201,20 +201,21 @@ TEST(recipe, event_generators) { auto cell_0 = custom_cell(1, 2, 0); auto cell_1 = custom_cell(2, 1, 0); - std::vector<arb::event_generator> gens_0, gens_1; { - gens_0 = {arb::explicit_generator({{{"synapse0"}, 1.0, 0.1}, {{"synapse1"}, 2.0, 0.1}})}; - - gens_1 = {arb::explicit_generator({{{"synapse0"}, 1.0, 0.1}})}; + std::vector<arb::event_generator> + gens_0 = {arb::explicit_generator({"synapse0"}, 0.1, std::vector<arb::time_type>{1.0}), + arb::explicit_generator({"synapse1"}, 0.1, std::vector<arb::time_type>{2.0})}, + gens_1 = {arb::explicit_generator({"synapse0"}, 0.1, std::vector<arb::time_type>{1.0})}; auto recipe_0 = custom_recipe({cell_0, cell_1}, {{}, {}}, {{}, {}}, {gens_0, gens_1}); auto decomp_0 = partition_load_balance(recipe_0, context); - EXPECT_NO_THROW(simulation(recipe_0, context, decomp_0)); + EXPECT_NO_THROW(simulation(recipe_0, context, decomp_0).run(1, 0.1)); } { - gens_0 = {arb::explicit_generator({{{"synapse0"}, 1.0, 0.1}, {{"synapse3"}, 2.0, 0.1}})}; - gens_1.clear(); + std::vector<arb::event_generator> + gens_0 = {arb::regular_generator({"totally-not-a-synapse-42"}, 0.1, 0, 0.001)}, + gens_1 = {}; auto recipe_0 = custom_recipe({cell_0, cell_1}, {{}, {}}, {{}, {}}, {gens_0, gens_1}); auto decomp_0 = partition_load_balance(recipe_0, context); diff --git a/test/unit/test_simulation.cpp b/test/unit/test_simulation.cpp index b79ac19a..0e0caf2e 100644 --- a/test/unit/test_simulation.cpp +++ b/test/unit/test_simulation.cpp @@ -124,7 +124,7 @@ struct lif_chain: public recipe { return {}; } else { - return {schedule_generator({"tgt"}, weight_, triggers_)}; + return {event_generator({"tgt"}, weight_, triggers_)}; } } -- GitLab