From 55aac4a9813fa00c2e03a5811d80a5e587ecf4e4 Mon Sep 17 00:00:00 2001
From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com>
Date: Mon, 5 Sep 2022 15:05:48 +0200
Subject: [PATCH] Remove explicit generator (#1962)

- Remove the old, multi-target `event_generator` class in favour of `schedule_generator(tgt, weight, explicit_schedule)`
- Much simplification ensues, `event_generator` is no longer a type-erasing container, but just what
   `schedule_generator` was before
- Make the label resolution in generators a bit more eager, no longer at simulation time,
   but now during setup (bit give a wee bit of perf as well)

Closes #1488
---
 arbor/include/arbor/event_generator.hpp | 173 ++++--------------------
 example/diffusion/diffusion.cpp         |   2 +-
 example/dryrun/dryrun.cpp               |   2 +-
 example/ring/ring.cpp                   |   2 +-
 python/recipe.cpp                       |   2 +-
 test/unit/test_diffusion.cpp            |   6 +-
 test/unit/test_event_generators.cpp     |  64 +++------
 test/unit/test_merge_events.cpp         |  46 +++----
 test/unit/test_probe.cpp                |   3 +-
 test/unit/test_recipe.cpp               |  15 +-
 test/unit/test_simulation.cpp           |   2 +-
 11 files changed, 86 insertions(+), 231 deletions(-)

diff --git a/arbor/include/arbor/event_generator.hpp b/arbor/include/arbor/event_generator.hpp
index 85db87d8..5e14ecb4 100644
--- a/arbor/include/arbor/event_generator.hpp
+++ b/arbor/include/arbor/event_generator.hpp
@@ -5,6 +5,7 @@
 #include <memory>
 #include <random>
 #include <type_traits>
+#include <optional>
 
 #include <arbor/assert.hpp>
 #include <arbor/common_types.hpp>
@@ -49,116 +50,22 @@ namespace arb {
 // and `events(t2, t3)` to the same event generator must satisfy
 // 0 ≤ t0 ≤ t1 ≤ t2 ≤ t3.
 //
-// `event_generator` objects have value semantics, and use type erasure
-// to wrap implementation details. An `event_generator` can be constructed
-// from an object of an implementation class Impl that is copy-constructable
-// and otherwise provides `reset` and `events` methods following the
-// API described above.
-//
-// Some pre-defined event generators are included:
-//  - `empty_generator`: produces no events
-//  - `schedule_generator`: produces events according to a time schedule.
-//    A target is selected using a label resolution function for every generated
-//    event.
-//  - `explicit_generator`: is constructed from a vector of {label, time, weight}
-//    objects. Explicit targets are generated from the labels using a resolution
-//    function before the first call to the `events` method.
+// `event_generator` objects have value semantics.
 
 using event_seq = std::pair<const spike_event*, const spike_event*>;
 using resolution_function = std::function<cell_lid_type(const cell_local_label_type&)>;
 
-// The simplest possible generator that generates no events.
-// Declared ahead of event_generator so that it can be used as the default
-// generator.
-struct empty_generator {
-    void reset() {}
-    event_seq events(time_type, time_type) {
-        return {nullptr, nullptr};
-    }
-    void resolve_label(resolution_function) {}
-};
-
-class event_generator {
-public:
-    event_generator(): event_generator(empty_generator()) {}
-
-    template <typename Impl, std::enable_if_t<!std::is_same<std::decay_t<Impl>, event_generator>::value, int> = 0>
-    event_generator(Impl&& impl):
-        impl_(new wrap<Impl>(std::forward<Impl>(impl)))
-    {}
-
-    event_generator(event_generator&& other) = default;
-    event_generator& operator=(event_generator&& other) = default;
-
-    event_generator(const event_generator& other):
-        impl_(other.impl_->clone())
-    {}
-
-    event_generator& operator=(const event_generator& other) {
-        impl_ = other.impl_->clone();
-        return *this;
-    }
-
-    void reset() {
-        impl_->reset();
-    }
-
-    event_seq events(time_type t0, time_type t1) {
-        return impl_->events(t0, t1);
-    }
-
-    void resolve_label(resolution_function label_resolver) {
-        impl_->resolve_label(std::move(label_resolver));
-    }
-
-private:
-    struct interface {
-        virtual void reset() = 0;
-        virtual void resolve_label(resolution_function) = 0;
-        virtual event_seq events(time_type, time_type) = 0;
-        virtual std::unique_ptr<interface> clone() = 0;
-        virtual ~interface() {}
-    };
-
-    std::unique_ptr<interface> impl_;
-
-    template <typename Impl>
-    struct wrap: interface {
-        explicit wrap(const Impl& impl): wrapped(impl) {}
-        explicit wrap(Impl&& impl): wrapped(std::move(impl)) {}
-
-        event_seq events(time_type t0, time_type t1) override {
-            return wrapped.events(t0, t1);
-        }
-
-        void reset() override {
-            wrapped.reset();
-        }
-
-        void resolve_label(resolution_function label_resolver) override {
-            wrapped.resolve_label(std::move(label_resolver));
-        }
-
-        std::unique_ptr<interface> clone() override {
-            return std::unique_ptr<interface>(new wrap<Impl>(wrapped));
-        }
-
-        Impl wrapped;
-    };
-};
-
-// Convenience routines for making schedule_generator:
 
 // Generate events with a fixed target and weight according to
 // a provided time schedule.
 
-struct schedule_generator {
-    schedule_generator(cell_local_label_type target, float weight, schedule sched):
+struct event_generator {
+    event_generator(cell_local_label_type target, float weight, schedule sched):
         target_(std::move(target)), weight_(weight), sched_(std::move(sched))
     {}
 
     void resolve_label(resolution_function label_resolver) {
-        label_resolver_ = std::move(label_resolver);
+        resolved_ = label_resolver(target_);
     }
 
     void reset() {
@@ -166,13 +73,15 @@ struct schedule_generator {
     }
 
     event_seq events(time_type t0, time_type t1) {
+        if (!resolved_) throw ;
+        auto tgt = *resolved_;
         auto ts = sched_.events(t0, t1);
 
         events_.clear();
         events_.reserve(ts.second-ts.first);
 
         for (auto i = ts.first; i!=ts.second; ++i) {
-            events_.push_back(spike_event{label_resolver_(target_), *i, weight_});
+            events_.push_back(spike_event{tgt, *i, weight_});
         }
 
         return {events_.data(), events_.data()+events_.size()};
@@ -182,10 +91,21 @@ private:
     pse_vector events_;
     cell_local_label_type target_;
     resolution_function label_resolver_;
+    std::optional<cell_lid_type> resolved_;
     float weight_;
     schedule sched_;
 };
 
+// Simplest generator: just do nothing
+inline
+event_generator empty_generator(
+    cell_local_label_type target,
+    float weight)
+{
+    return event_generator(std::move(target), weight, schedule());
+}
+
+
 // Generate events at integer multiples of dt that lie between tstart and tstop.
 
 inline event_generator regular_generator(
@@ -195,7 +115,7 @@ inline event_generator regular_generator(
     time_type dt,
     time_type tstop=terminal_time)
 {
-    return schedule_generator(std::move(target), weight, regular_schedule(tstart, dt, tstop));
+    return event_generator(std::move(target), weight, regular_schedule(tstart, dt, tstop));
 }
 
 template <typename RNG>
@@ -207,56 +127,19 @@ inline event_generator poisson_generator(
     const RNG& rng,
     time_type tstop=terminal_time)
 {
-    return schedule_generator(std::move(target), weight, poisson_schedule(tstart, rate_kHz, rng, tstop));
+    return event_generator(std::move(target), weight, poisson_schedule(tstart, rate_kHz, rng, tstop));
 }
 
 
 // Generate events from a predefined sorted event sequence.
 
-struct explicit_generator {
-    struct labeled_synapse_event {
-        cell_local_label_type label;
-        time_type time;
-        float weight;
-    };
-
-    using lse_vector = std::vector<labeled_synapse_event>;
-
-    explicit_generator() = default;
-    explicit_generator(const explicit_generator&) = default;
-    explicit_generator(explicit_generator&&) = default;
-
-    explicit_generator(const lse_vector& events):
-        input_events_(events), start_index_(0) {}
-
-    void resolve_label(resolution_function label_resolver) {
-        for (const auto& e: input_events_) {
-            events_.push_back({label_resolver(e.label), e.time, e.weight});
-        }
-        std::sort(events_.begin(), events_.end());
-    }
-
-    void reset() {
-        start_index_ = 0;
-    }
-
-    event_seq events(time_type t0, time_type t1) {
-        const spike_event* lb = events_.data()+start_index_;
-        const spike_event* ub = events_.data()+events_.size();
-
-        lb = std::lower_bound(lb, ub, t0, event_time_less{});
-        ub = std::lower_bound(lb, ub, t1, event_time_less{});
-
-        start_index_ = ub-events_.data();
-        return {lb, ub};
-    }
-
-private:
-    lse_vector input_events_;
-    pse_vector events_;
-    std::size_t start_index_ = 0;
-};
-
+template<typename S> inline
+event_generator explicit_generator(cell_local_label_type target,
+                                   float weight,
+                                   const S& s)
+{
+    return event_generator(std::move(target), weight, explicit_schedule(s));
+}
 
 } // namespace arb
 
diff --git a/example/diffusion/diffusion.cpp b/example/diffusion/diffusion.cpp
index e2ed2537..991a351d 100644
--- a/example/diffusion/diffusion.cpp
+++ b/example/diffusion/diffusion.cpp
@@ -31,7 +31,7 @@ struct linear: public recipe {
     cell_kind get_cell_kind(cell_gid_type)                       const override { return cell_kind::cable; }
     std::any get_global_properties(cell_kind)                    const override { return gprop; }
     std::vector<probe_info> get_probes(cell_gid_type)            const override { return {cable_probe_ion_diff_concentration_cell{"na"}}; }
-    std::vector<event_generator> event_generators(cell_gid_type) const override { return {explicit_generator({{{"Zap"}, 0.0, 0.005}})}; }
+    std::vector<event_generator> event_generators(cell_gid_type) const override { return {explicit_generator({"Zap"}, 0.005, std::vector<float>{0.f})}; }
     util::unique_any get_cell_description(cell_gid_type)         const override {
         // Stick morphology
         // -----|-----
diff --git a/example/dryrun/dryrun.cpp b/example/dryrun/dryrun.cpp
index 2e00e204..b2f244e8 100644
--- a/example/dryrun/dryrun.cpp
+++ b/example/dryrun/dryrun.cpp
@@ -112,7 +112,7 @@ public:
     std::vector<arb::event_generator> event_generators(cell_gid_type gid) const override {
         std::vector<arb::event_generator> gens;
         if (gid%20 == 0) {
-            gens.push_back(arb::explicit_generator({{{"synapse"}, 1.0, event_weight_}}));
+            gens.push_back(arb::explicit_generator({"synapse"}, event_weight_, std::vector<float>{1.0f}));
         }
         return gens;
     }
diff --git a/example/ring/ring.cpp b/example/ring/ring.cpp
index 72c4889e..cf896753 100644
--- a/example/ring/ring.cpp
+++ b/example/ring/ring.cpp
@@ -100,7 +100,7 @@ public:
     std::vector<arb::event_generator> event_generators(cell_gid_type gid) const override {
         std::vector<arb::event_generator> gens;
         if (!gid) {
-            gens.push_back(arb::explicit_generator({{{"primary_syn"}, 1.0, event_weight_}}));
+            gens.push_back(arb::explicit_generator({"primary_syn"}, event_weight_, std::vector<float>{1.0f}));
         }
         return gens;
     }
diff --git a/python/recipe.cpp b/python/recipe.cpp
index 83172209..54cef50e 100644
--- a/python/recipe.cpp
+++ b/python/recipe.cpp
@@ -107,7 +107,7 @@ static std::vector<arb::event_generator> convert_gen(std::vector<pybind11::objec
         auto& p = cast<const pyarb::event_generator_shim&>(g);
 
         // convert the event_generator to an arb::event_generator
-        gens.push_back(arb::schedule_generator(p.target, p.weight, std::move(p.time_sched)));
+        gens.push_back(arb::event_generator(p.target, p.weight, std::move(p.time_sched)));
     }
 
     return gens;
diff --git a/test/unit/test_diffusion.cpp b/test/unit/test_diffusion.cpp
index 4575ed7d..a3bcc6a4 100644
--- a/test/unit/test_diffusion.cpp
+++ b/test/unit/test_diffusion.cpp
@@ -54,7 +54,9 @@ struct linear: public recipe {
 
     std::vector<arb::event_generator> event_generators(arb::cell_gid_type gid) const override {
         std::vector<arb::event_generator> result;
-        for (const auto& [t, w]: inject_at) result.push_back(arb::explicit_generator({{{"Zap"}, t, w}}));
+        for (const auto& [t, w]: inject_at) {
+            result.push_back(arb::explicit_generator({"Zap"}, w, std::vector<float>{t}));
+        }
         return result;
     }
 
@@ -62,7 +64,7 @@ struct linear: public recipe {
     double extent = 1.0,
            diameter = 1.0,
            cv_length = 1.0;
-    std::vector<std::tuple<double, float>> inject_at;
+    std::vector<std::tuple<float, float>> inject_at;
     morphology morph;
     arb::decor decor;
 
diff --git a/test/unit/test_event_generators.cpp b/test/unit/test_event_generators.cpp
index cdb97b82..066612c0 100644
--- a/test/unit/test_event_generators.cpp
+++ b/test/unit/test_event_generators.cpp
@@ -33,8 +33,7 @@ TEST(event_generators, assign_and_copy) {
     event_generator g1(gen);
     EXPECT_EQ(expected, first(g1.events(0., 1.)));
 
-    event_generator g2;
-    g2 = gen;
+    event_generator g2 = gen;
     EXPECT_EQ(expected, first(g2.events(0., 1.)));
 
     const auto& const_gen = gen;
@@ -42,8 +41,7 @@ TEST(event_generators, assign_and_copy) {
     event_generator g3(const_gen);
     EXPECT_EQ(expected, first(g3.events(0., 1.)));
 
-    event_generator g4;
-    g4 = gen;
+    event_generator g4 = gen;
     EXPECT_EQ(expected, first(g4.events(0., 1.)));
 
     event_generator g5(std::move(gen));
@@ -82,54 +80,30 @@ TEST(event_generators, regular) {
     EXPECT_EQ(expected({12, 12.5}), as_vector(gen.events(12, 12.7)));
 }
 
+using lse_vector = std::vector<std::tuple<cell_local_label_type, time_type, float>>;
+
 TEST(event_generators, seq) {
-    explicit_generator::lse_vector in = {
-        {{"l0"}, 0.1, 1.0},
-        {{"l0"}, 1.0, 2.0},
-        {{"l2"}, 1.0, 3.0},
-        {{"l1"}, 1.5, 4.0},
-        {{"l2"}, 2.3, 5.0},
-        {{"l0"}, 3.0, 6.0},
-        {{"l0"}, 3.5, 7.0},
-    };
-    std::unordered_map<cell_tag_type, cell_lid_type> lid_map = {{"l0", 0},{"l1", 1}, {"l2", 2}};
+    std::vector<arb::time_type> times = {1, 2, 3, 4, 5, 6, 7};
+    lse_vector in;
     pse_vector expected;
-    std::transform(in.begin(), in.end(), std::back_inserter(expected),
-        [lid_map](const auto& item) {return spike_event{lid_map.at(item.label.tag), item.time, item.weight};});
-
-    event_generator gen = explicit_generator(in);
-    gen.resolve_label([lid_map](const cell_local_label_type& item) {return lid_map.at(item.tag);});
-
-    EXPECT_EQ(expected, as_vector(gen.events(0, 100.)));
-    gen.reset();
-    EXPECT_EQ(expected, as_vector(gen.events(0, 100.)));
-    gen.reset();
+    float weight = 0.42;
+    arb::cell_local_label_type l0 = {"l0"};
+    for (auto time: times) {
+        in.push_back({l0, weight, time});
+        expected.push_back({0, time, weight});
+    }
 
-    // Check reported sub-intervals against a smaller set of events.
-    in = {
-        {{"l0"}, 1.5, 4.0},
-        {{"l0"}, 2.3, 5.0},
-        {{"l0"}, 3.0, 6.0},
-        {{"l0"}, 3.5, 7.0},
-    };
-    expected.clear();
-    std::transform(in.begin(), in.end(), std::back_inserter(expected),
-        [lid_map](const auto& item) {return spike_event{lid_map.at(item.label.tag), item.time, item.weight};});
+    event_generator gen = explicit_generator(l0, weight, times);
+    gen.resolve_label([](const cell_local_label_type&) {return 0;});
 
-    gen = explicit_generator(in);
-    gen.resolve_label([lid_map](const cell_local_label_type& item) {return lid_map.at(item.tag);});
+    EXPECT_EQ(expected, as_vector(gen.events(0, 100.))); gen.reset();
+    EXPECT_EQ(expected, as_vector(gen.events(0, 100.))); gen.reset();
 
-    auto draw = [](event_generator& gen, time_type t0, time_type t1) {
-        gen.reset();
-        return as_vector(gen.events(t0, t1));
-    };
-
-    auto events = [&expected] (int b, int e) {
-      return pse_vector(expected.begin()+b, expected.begin()+e);
-    };
+    auto draw = [](auto& gen, auto t0, auto t1) { gen.reset(); return as_vector(gen.events(t0, t1)); };
+    auto events = [&expected] (int b, int e) { auto beg = expected.begin(); return pse_vector(beg+b, beg+e); };
 
     // a range that includes all the events
-    EXPECT_EQ(expected, draw(gen, 0, 4));
+    EXPECT_EQ(expected, draw(gen, 0, 8));
 
     // a strict subset including the first event
     EXPECT_EQ(events(0, 2), draw(gen, 0, 3));
diff --git a/test/unit/test_merge_events.cpp b/test/unit/test_merge_events.cpp
index f09cffc0..e264b014 100644
--- a/test/unit/test_merge_events.cpp
+++ b/test/unit/test_merge_events.cpp
@@ -179,40 +179,36 @@ TEST(merge_events, X)
     EXPECT_EQ(expected, lf);
 }
 
-// Test the tournament tree for merging two small sequences 
-TEST(merge_events, tourney_seq)
-{
-    explicit_generator::lse_vector evs1 = {
-        {{"l0"}, 1, 1},
-        {{"l0"}, 2, 2},
-        {{"l0"}, 3, 3},
-        {{"l0"}, 4, 4},
-        {{"l0"}, 5, 5},
-    };
+    struct labeled_synapse_event {
 
-    explicit_generator::lse_vector evs2 = {
-        {{"l0"}, 1.5, 1},
-        {{"l0"}, 2.5, 2},
-        {{"l0"}, 3.5, 3},
-        {{"l0"}, 4.5, 4},
-        {{"l0"}, 5.5, 5},
     };
 
-    pse_vector expected;
+using lse_vector = std::vector<std::tuple<cell_local_label_type, time_type, float>>;
 
-    auto gen_pse = [](const auto& item) {return spike_event{0, item.time, item.weight};};
-    std::transform(evs1.begin(), evs1.end(), std::back_inserter(expected), gen_pse);
-    std::transform(evs2.begin(), evs2.end(), std::back_inserter(expected), gen_pse);
+// Test the tournament tree for merging two small sequences 
+TEST(merge_events, tourney_seq)
+{
+    std::vector<arb::time_type> times {1, 2, 3, 4, 5};
+    cell_local_label_type l0 = {"l0"};
+    float w1 = 1.0f, w2 = 2.0f;
+    lse_vector evs1, evs2;
+    pse_vector expected;
+    for (const auto time: times) {
+        evs1.emplace_back(l0, w1, time);
+        evs2.emplace_back(l0, w2, time);
+        expected.push_back({0, time, w1});
+        expected.push_back({0, time, w2});
+    }
     util::sort(expected);
 
-    event_generator g1 = explicit_generator(evs1);
-    event_generator g2 = explicit_generator(evs2);
+    auto
+        g1 = explicit_generator(l0, w1, times),
+        g2 = explicit_generator(l0, w2, times);
     g1.resolve_label([](const cell_local_label_type&) {return 0;});
     g2.resolve_label([](const cell_local_label_type&) {return 0;});
 
-    std::vector<event_span> spans;
-    spans.emplace_back(g1.events(0, terminal_time));
-    spans.emplace_back(g2.events(0, terminal_time));
+    std::vector<event_span> spans = {g1.events(0, terminal_time),
+                                     g2.events(0, terminal_time)};
     impl::tourney_tree tree(spans);
 
     pse_vector lf;
diff --git a/test/unit/test_probe.cpp b/test/unit/test_probe.cpp
index 61f5fef2..5790c9e3 100644
--- a/test/unit/test_probe.cpp
+++ b/test/unit/test_probe.cpp
@@ -1249,8 +1249,7 @@ void run_exact_sampling_probe_test(context ctx) {
 
         std::vector<event_generator> event_generators(cell_gid_type gid) const override {
             // Send a single event to cell i at 0.1*i milliseconds.
-            explicit_generator::lse_vector spikes = {{{"syn"}, 0.1*gid, 1.f}};
-            return {explicit_generator(spikes)};
+            return {explicit_generator({"syn"}, 1.0f, std::vector<float>{0.1f*gid})};
         }
 
         std::any get_global_properties(cell_kind k) const override {
diff --git a/test/unit/test_recipe.cpp b/test/unit/test_recipe.cpp
index f1ffe120..b5b45205 100644
--- a/test/unit/test_recipe.cpp
+++ b/test/unit/test_recipe.cpp
@@ -201,20 +201,21 @@ TEST(recipe, event_generators) {
 
     auto cell_0 = custom_cell(1, 2, 0);
     auto cell_1 = custom_cell(2, 1, 0);
-    std::vector<arb::event_generator> gens_0, gens_1;
     {
-        gens_0 = {arb::explicit_generator({{{"synapse0"}, 1.0, 0.1}, {{"synapse1"}, 2.0, 0.1}})};
-
-        gens_1 = {arb::explicit_generator({{{"synapse0"}, 1.0, 0.1}})};
+        std::vector<arb::event_generator>
+            gens_0 = {arb::explicit_generator({"synapse0"}, 0.1, std::vector<arb::time_type>{1.0}),
+                      arb::explicit_generator({"synapse1"}, 0.1, std::vector<arb::time_type>{2.0})},
+            gens_1 = {arb::explicit_generator({"synapse0"}, 0.1, std::vector<arb::time_type>{1.0})};
 
         auto recipe_0 = custom_recipe({cell_0, cell_1}, {{}, {}}, {{}, {}},  {gens_0, gens_1});
         auto decomp_0 = partition_load_balance(recipe_0, context);
 
-        EXPECT_NO_THROW(simulation(recipe_0, context, decomp_0));
+        EXPECT_NO_THROW(simulation(recipe_0, context, decomp_0).run(1, 0.1));
     }
     {
-        gens_0 = {arb::explicit_generator({{{"synapse0"}, 1.0, 0.1}, {{"synapse3"}, 2.0, 0.1}})};
-        gens_1.clear();
+        std::vector<arb::event_generator>
+            gens_0 = {arb::regular_generator({"totally-not-a-synapse-42"}, 0.1, 0, 0.001)},
+            gens_1 = {};
 
         auto recipe_0 = custom_recipe({cell_0, cell_1}, {{}, {}}, {{}, {}},  {gens_0, gens_1});
         auto decomp_0 = partition_load_balance(recipe_0, context);
diff --git a/test/unit/test_simulation.cpp b/test/unit/test_simulation.cpp
index b79ac19a..0e0caf2e 100644
--- a/test/unit/test_simulation.cpp
+++ b/test/unit/test_simulation.cpp
@@ -124,7 +124,7 @@ struct lif_chain: public recipe {
             return {};
         }
         else {
-            return {schedule_generator({"tgt"}, weight_, triggers_)};
+            return {event_generator({"tgt"}, weight_, triggers_)};
         }
     }
 
-- 
GitLab