From ecd794a71b0083aef55886754ee207e944e983ee Mon Sep 17 00:00:00 2001
From: Sebastian Schmitt <sebastian.schmitt@kip.uni-heidelberg.de>
Date: Thu, 30 Sep 2021 18:01:13 +0200
Subject: [PATCH] Limit Poisson schedule with a stop time (#1684)

Limit poisson schedule with a stop time using the same semantics as the regular schedule.
closes #1617
---
 arbor/include/arbor/event_generator.hpp |  5 +++--
 arbor/include/arbor/schedule.hpp        | 22 ++++++++++++++------
 python/schedule.cpp                     | 27 +++++++++++++++++++------
 python/schedule.hpp                     |  6 +++++-
 python/test/unit/test_schedules.py      |  9 +++++++++
 test/unit/test_schedule.cpp             | 14 +++++++++++++
 6 files changed, 68 insertions(+), 15 deletions(-)

diff --git a/arbor/include/arbor/event_generator.hpp b/arbor/include/arbor/event_generator.hpp
index 42e5ac40..85db87d8 100644
--- a/arbor/include/arbor/event_generator.hpp
+++ b/arbor/include/arbor/event_generator.hpp
@@ -204,9 +204,10 @@ inline event_generator poisson_generator(
     float weight,
     time_type tstart,
     time_type rate_kHz,
-    const RNG& rng)
+    const RNG& rng,
+    time_type tstop=terminal_time)
 {
-    return schedule_generator(std::move(target), weight, poisson_schedule(tstart, rate_kHz, rng));
+    return schedule_generator(std::move(target), weight, poisson_schedule(tstart, rate_kHz, rng, tstop));
 }
 
 
diff --git a/arbor/include/arbor/schedule.hpp b/arbor/include/arbor/schedule.hpp
index 6e851335..e75903af 100644
--- a/arbor/include/arbor/schedule.hpp
+++ b/arbor/include/arbor/schedule.hpp
@@ -176,10 +176,11 @@ inline schedule explicit_schedule(const std::initializer_list<time_type>& seq) {
 template <typename RandomNumberEngine>
 class poisson_schedule_impl {
 public:
-    poisson_schedule_impl(time_type tstart, time_type rate_kHz, const RandomNumberEngine& rng):
-        tstart_(tstart), exp_(rate_kHz), rng_(rng), reset_state_(rng), next_(tstart)
+    poisson_schedule_impl(time_type tstart, time_type rate_kHz, const RandomNumberEngine& rng, time_type tstop):
+        tstart_(tstart), exp_(rate_kHz), rng_(rng), reset_state_(rng), next_(tstart), tstop_(tstop)
     {
         arb_assert(tstart_>=0);
+        arb_assert(tstart_ <= tstop_);
         step();
     }
 
@@ -190,6 +191,14 @@ public:
     }
 
     time_event_span events(time_type t0, time_type t1) {
+        // if we start after the maximal allowed time, we have nothing to do
+        if (t0 >= tstop_) {
+            return {};
+        }
+
+        // restrict by maximal allowed time
+        t1 = std::min(t1, tstop_);
+
         times_.clear();
 
         while (next_<t0) {
@@ -215,16 +224,17 @@ private:
     RandomNumberEngine reset_state_;
     time_type next_;
     std::vector<time_type> times_;
+    time_type tstop_;
 };
 
 template <typename RandomNumberEngine>
-inline schedule poisson_schedule(time_type rate_kHz, const RandomNumberEngine& rng) {
-    return schedule(poisson_schedule_impl<RandomNumberEngine>(0., rate_kHz, rng));
+inline schedule poisson_schedule(time_type rate_kHz, const RandomNumberEngine& rng, time_type tstop=terminal_time) {
+    return schedule(poisson_schedule_impl<RandomNumberEngine>(0., rate_kHz, rng, tstop));
 }
 
 template <typename RandomNumberEngine>
-inline schedule poisson_schedule(time_type tstart, time_type rate_kHz, const RandomNumberEngine& rng) {
-    return schedule(poisson_schedule_impl<RandomNumberEngine>(tstart, rate_kHz, rng));
+inline schedule poisson_schedule(time_type tstart, time_type rate_kHz, const RandomNumberEngine& rng, time_type tstop=terminal_time) {
+    return schedule(poisson_schedule_impl<RandomNumberEngine>(tstart, rate_kHz, rng, tstop));
 }
 
 } // namespace arb
diff --git a/python/schedule.cpp b/python/schedule.cpp
index 36edc380..70950a4a 100644
--- a/python/schedule.cpp
+++ b/python/schedule.cpp
@@ -26,6 +26,7 @@ std::ostream& operator<<(std::ostream& o, const explicit_schedule_shim& e) {
 
 std::ostream& operator<<(std::ostream& o, const poisson_schedule_shim& p) {
     return o << "<arbor.poisson_schedule: tstart " << p.tstart << " ms"
+             << ", tstop " << util::to_string(p.tstop) << " ms"
              << ", freq " << p.freq << " kHz"
              << ", seed " << p.seed << ">";
 };
@@ -141,11 +142,13 @@ std::vector<arb::time_type> explicit_schedule_shim::events(arb::time_type t0, ar
 poisson_schedule_shim::poisson_schedule_shim(
         arb::time_type ts,
         arb::time_type f,
-        rng_type::result_type s)
+        rng_type::result_type s,
+        py::object tstop)
 {
     set_tstart(ts);
     set_freq(f);
     seed = s;
+    set_tstop(tstop);
 }
 
 poisson_schedule_shim::poisson_schedule_shim(arb::time_type f) {
@@ -164,6 +167,11 @@ void poisson_schedule_shim::set_freq(arb::time_type f) {
     freq = f;
 };
 
+void poisson_schedule_shim::set_tstop(py::object t) {
+    tstop = py2optional<arb::time_type>(
+            t, "tstop must be a non-negative number, or None", is_nonneg());
+};
+
 arb::time_type poisson_schedule_shim::get_tstart() const {
     return tstart;
 }
@@ -172,8 +180,12 @@ arb::time_type poisson_schedule_shim::get_freq() const {
     return freq;
 }
 
+poisson_schedule_shim::opt_time_type poisson_schedule_shim::get_tstop() const {
+    return tstop;
+}
+
 arb::schedule poisson_schedule_shim::schedule() const {
-    return arb::poisson_schedule(tstart, freq, rng_type(seed));
+    return arb::poisson_schedule(tstart, freq, rng_type(seed), tstop.value_or(arb::terminal_time));
 }
 
 std::vector<arb::time_type> poisson_schedule_shim::events(arb::time_type t0, arb::time_type t1) {
@@ -237,15 +249,16 @@ void register_schedules(py::module& m) {
 
     // Poisson schedule
     py::class_<poisson_schedule_shim, schedule_shim_base> poisson_schedule(m, "poisson_schedule",
-        "Describes a schedule according to a Poisson process.");
+        "Describes a schedule according to a Poisson process within the interval [tstart, tstop).");
 
     poisson_schedule
-        .def(py::init<time_type, time_type, std::mt19937_64::result_type>(),
-            "tstart"_a = 0., "freq"_a, "seed"_a = 0,
+        .def(py::init<time_type, time_type, std::mt19937_64::result_type, py::object>(),
+             "tstart"_a = 0., "freq"_a, "seed"_a = 0, "tstop"_a = py::none(),
             "Construct a Poisson schedule with arguments:\n"
             "  tstart: The delivery time of the first event in the sequence [ms], 0 by default.\n"
             "  freq:   The expected frequency [kHz].\n"
-            "  seed:   The seed for the random number generator, 0 by default.")
+            "  seed:   The seed for the random number generator, 0 by default.\n"
+            "  tstop:  No events delivered after this time [ms], None by default.")
         .def(py::init<time_type>(),
             "freq"_a,
             "Construct a Poisson schedule, starting from t = 0, default seed, with:\n"
@@ -256,6 +269,8 @@ void register_schedules(py::module& m) {
             "The expected frequency [kHz].")
         .def_readwrite("seed", &poisson_schedule_shim::seed,
             "The seed for the random number generator.")
+        .def_property("tstop", &poisson_schedule_shim::get_tstop, &poisson_schedule_shim::set_tstop,
+            "No events delivered after this time [ms].")
         .def("events", &poisson_schedule_shim::events,
             "A view of monotonically increasing time values in the half-open interval [t0, t1).")
         .def("__str__",  util::to_string<poisson_schedule_shim>)
diff --git a/python/schedule.hpp b/python/schedule.hpp
index 9ee81fda..d7369674 100644
--- a/python/schedule.hpp
+++ b/python/schedule.hpp
@@ -77,19 +77,23 @@ struct explicit_schedule_shim: schedule_shim_base {
 // arb::poisson_schedule when a C++ recipe is created from a Python recipe.
 struct poisson_schedule_shim: schedule_shim_base {
     using rng_type = std::mt19937_64;
+    using opt_time_type = std::optional<arb::time_type>;
 
     arb::time_type tstart; // ms
     arb::time_type freq; // kHz
+    opt_time_type  tstop = {}; // ms
     rng_type::result_type seed;
 
-    poisson_schedule_shim(arb::time_type ts, arb::time_type f, rng_type::result_type s);
+    poisson_schedule_shim(arb::time_type ts, arb::time_type f, rng_type::result_type s, pybind11::object tstop);
     poisson_schedule_shim(arb::time_type f);
 
     void set_tstart(arb::time_type t);
     void set_freq(arb::time_type f);
+    void set_tstop(pybind11::object t);
 
     arb::time_type get_tstart() const;
     arb::time_type get_freq() const;
+    opt_time_type get_tstop() const;
 
     arb::schedule schedule() const override;
 
diff --git a/python/test/unit/test_schedules.py b/python/test/unit/test_schedules.py
index e0c6d9da..44eec19d 100644
--- a/python/test/unit/test_schedules.py
+++ b/python/test/unit/test_schedules.py
@@ -173,6 +173,15 @@ class PoissonSchedule(unittest.TestCase):
             "t1 must be a non-negative number"):
             ps = arb.poisson_schedule(0,0.01)
             ps.events(1., -1.)
+        with self.assertRaisesRegex(RuntimeError,
+            "tstop must be a non-negative number, or None"):
+            arb.poisson_schedule(0, 0.1, tstop='tstop')
+            ps.events(1., -1.)
+
+    def test_tstop_poisson_schedule(self):
+        tstop = 50
+        events = arb.poisson_schedule(0., 1, 0, tstop).events(0, 100)
+        self.assertTrue(max(events) < tstop)
 
 def suite():
     # specify class and test functions in tuple (here: all tests starting with 'test' from classes RegularSchedule, ExplicitSchedule and PoissonSchedule
diff --git a/test/unit/test_schedule.cpp b/test/unit/test_schedule.cpp
index e03707fe..84523851 100644
--- a/test/unit/test_schedule.cpp
+++ b/test/unit/test_schedule.cpp
@@ -348,3 +348,17 @@ TEST(schedule, poisson_offset_reset) {
     run_reset_check(poisson_schedule(3.3, 9.1, G), 1, 10, 7);
 }
 
+TEST(schedule, poisson_tstop) {
+    SCOPED_TRACE("poisson_tstop");
+    std::mt19937_64 G;
+    G.discard(500);
+
+    const double tstop = 50;
+
+    auto const times = as_vector(poisson_schedule(0, .234, G, tstop).events(0., 100.));
+    auto const max = std::max_element(begin(times), end(times));
+
+    EXPECT_TRUE(max != end(times));
+    EXPECT_TRUE(*max <= tstop);
+}
+
-- 
GitLab