diff --git a/arbor/include/arbor/spike_source_cell.hpp b/arbor/include/arbor/spike_source_cell.hpp index 5688b9949dd76016cbac40a6f25716efe680add7..16b91a46a10316913afd7e7ed6f8ea5a04034457 100644 --- a/arbor/include/arbor/spike_source_cell.hpp +++ b/arbor/include/arbor/spike_source_cell.hpp @@ -11,10 +11,12 @@ namespace arb { struct ARB_SYMBOL_VISIBLE spike_source_cell { cell_tag_type source; // Label of source. - schedule seq; + std::vector<schedule> seqs; spike_source_cell() = delete; - spike_source_cell(cell_tag_type source, schedule seq): source(std::move(source)), seq(std::move(seq)) {}; + template<typename... Seqs> + spike_source_cell(cell_tag_type source, Seqs&&... seqs): source(std::move(source)), seqs{std::forward<Seqs>(seqs)...} {} + spike_source_cell(cell_tag_type source, std::vector<schedule> seqs): source(std::move(source)), seqs(std::move(seqs)) {} }; } // namespace arb diff --git a/arbor/spike_source_cell_group.cpp b/arbor/spike_source_cell_group.cpp index 82d1062ba537ec807ca0591eda67a2ff175d9566..2f4176f3d8c729827f154c833eda3968dc405854 100644 --- a/arbor/spike_source_cell_group.cpp +++ b/arbor/spike_source_cell_group.cpp @@ -26,13 +26,13 @@ spike_source_cell_group::spike_source_cell_group( } } - time_sequences_.reserve(gids_.size()); + time_sequences_.reserve(gids.size()); for (auto gid: gids_) { cg_sources.add_cell(); cg_targets.add_cell(); try { auto cell = util::any_cast<spike_source_cell>(rec.get_cell_description(gid)); - time_sequences_.push_back(std::move(cell.seq)); + time_sequences_.emplace_back(cell.seqs); cg_sources.add_label(cell.source, {0, 1}); } catch (std::bad_any_cast& e) { @@ -51,8 +51,10 @@ void spike_source_cell_group::advance(epoch ep, time_type dt, const event_lane_s for (auto i: util::count_along(gids_)) { const auto gid = gids_[i]; - for (auto t: util::make_range(time_sequences_[i].events(ep.t0, ep.t1))) { - spikes_.push_back({{gid, 0u}, t}); + for (auto& ts: time_sequences_[i]) { + for (auto &t: util::make_range(ts.events(ep.t0, ep.t1))) { + spikes_.push_back({{gid, 0u}, t}); + } } } @@ -60,8 +62,10 @@ void spike_source_cell_group::advance(epoch ep, time_type dt, const event_lane_s }; void spike_source_cell_group::reset() { - for (auto& s: time_sequences_) { - s.reset(); + for (auto& ss: time_sequences_) { + for(auto& s: ss) { + s.reset(); + } } clear_spikes(); } diff --git a/arbor/spike_source_cell_group.hpp b/arbor/spike_source_cell_group.hpp index aefe4a0d50fccfa13f8927c5c8fa04bf67b03b40..23a9c4b36a98129fc4b858ead418acaf421a271f 100644 --- a/arbor/spike_source_cell_group.hpp +++ b/arbor/spike_source_cell_group.hpp @@ -40,7 +40,7 @@ public: private: std::vector<spike> spikes_; std::vector<cell_gid_type> gids_; - std::vector<schedule> time_sequences_; + std::vector<std::vector<schedule>> time_sequences_; }; } // namespace arb diff --git a/test/unit/test_spike_source.cpp b/test/unit/test_spike_source.cpp index 82a5173f4d949dd824a775135204c30a8c1e5d6a..ebd2e2ea8eb8cede9778d3e7122b6a7beb201f1e 100644 --- a/test/unit/test_spike_source.cpp +++ b/test/unit/test_spike_source.cpp @@ -115,3 +115,43 @@ TEST(spike_source, exhaust) test_seq(regular_schedule(0, 1, 5)); test_seq(explicit_schedule({0.3, 2.3, 4.7})); } + +TEST(spike_source, multiple) +{ + // This test assumes that seq will exhaust itself before t=10 ms. + auto test_seq = [](auto&&... seqs) { + std::vector<schedule> schedules{seqs...}; + ss_recipe rec(1u, spike_source_cell("src", static_cast<decltype(seqs)&&>(seqs)...)); + cell_label_range srcs, tgts; + spike_source_cell_group group({0}, rec, srcs, tgts); + + // epoch ending at 10ms + epoch ep(0, 0., 10.); + group.advance(ep, 1, {}); + + auto expected = spike_times(group.spikes()); + std::sort(expected.begin(), expected.end()); + + auto actual = std::vector<time_type>{}; + for (auto& schedule: schedules) { + auto ts = as_vector(schedule.events(0, 10)); + actual.insert(actual.end(), + ts.begin(), ts.end()); + } + std::sort(actual.begin(), actual.end()); + EXPECT_EQ(expected, actual); + + // Check that the last spike was before the end of the epoch. + EXPECT_LT(group.spikes().back().time, time_type(10)); + }; + + auto seqs = std::vector<schedule>{regular_schedule(0, 1, 5), + explicit_schedule({0.3, 2.3, 4.7})}; + test_seq(seqs); + test_seq(std::vector<schedule>{regular_schedule(0, 1, 5), + explicit_schedule({0.3, 2.3, 4.7})}); + test_seq(regular_schedule(0, 1, 5), + explicit_schedule({0.3, 2.3, 4.7})); + auto reg_sched = regular_schedule(0, 1, 5); + test_seq(reg_sched, explicit_schedule({0.3, 2.3, 4.7})); +}