diff --git a/example/miniapp/io.cpp b/example/miniapp/io.cpp index 5f9a1aa5d9ef66576ae290c6f173d34e15f96bb3..95ec7690af9c584cb6b0aeba0b376f3b931fbe35 100644 --- a/example/miniapp/io.cpp +++ b/example/miniapp/io.cpp @@ -190,10 +190,6 @@ cl_options read_options(int argc, char** argv, bool allow_write) { false, defopts.dry_run_ranks, "positive integer", cmd); TCLAP::SwitchArg verbose_arg( "v", "verbose", "Present more verbose information to stdout", cmd, false); - TCLAP::ValueArg<std::string> ispike_arg( - "I", "ispike_file", - "Input spikes from file", - false, "", "file name", cmd); cmd.reorder_arguments(); @@ -243,11 +239,6 @@ cl_options read_options(int argc, char** argv, bool allow_write) { update_option(options.file_extension, fopts, "file_extension"); } - update_option(options.spike_file_input, fopts, "spike_file_input"); - if (options.spike_file_input) { - update_option(options.input_spike_path, fopts, "input_spike_path"); - } - update_option(options.dry_run_ranks, fopts, "dry_run_ranks"); } catch (std::exception& e) { @@ -282,12 +273,6 @@ cl_options read_options(int argc, char** argv, bool allow_write) { update_option(options.spike_file_output, spike_output_arg); update_option(options.dry_run_ranks, dry_run_ranks_arg); - std::string is_file_name = ispike_arg.getValue(); - if (is_file_name != "") { - options.spike_file_input = true; - update_option(options.input_spike_path, ispike_arg); - } - if (options.trace_format!="csv" && options.trace_format!="json") { throw usage_error("trace format must be one of: csv, json"); } @@ -392,49 +377,5 @@ std::ostream& operator<<(std::ostream& o, const cl_options& options) { return o; } - -/// Parse spike times from a stream -/// A single spike per line, trailing whitespace is ignore -/// Throws a usage error when parsing fails -/// -/// Returns a vector of time_type - -std::vector<time_type> parse_spike_times_from_stream(std::ifstream & fid) { - std::vector<time_type> times; - std::string line; - while (std::getline(fid, line)) { - std::stringstream s(line); - - time_type t; - s >> t >> std::ws; - - if (!s || s.peek() != EOF) { - throw std::runtime_error( util::strprintf( - "Unable to parse spike file on line %d: \"%s\"\n", - times.size(), line)); - } - - times.push_back(t); - } - - return times; -} - -/// Parse spike times from a file supplied in path -/// A single spike per line, trailing white space is ignored -/// Throws a usage error when opening file or parsing fails -/// -/// Returns a vector of time_type - -std::vector<time_type> get_parsed_spike_times_from_path(arb::util::path path) { - std::ifstream fid(path); - if (!fid) { - throw std::runtime_error(util::strprintf( - "Unable to parse spike file: \"%s\"\n", path.c_str())); - } - - return parse_spike_times_from_stream(fid); -} - } // namespace io } // namespace arb diff --git a/example/miniapp/io.hpp b/example/miniapp/io.hpp index a782346639eabb17ce37ca3fd5e0d82ac749a92f..6283ac4d34b860fd4772ffeee7c4dd31ad8178b8 100644 --- a/example/miniapp/io.hpp +++ b/example/miniapp/io.hpp @@ -52,10 +52,6 @@ struct cl_options { std::string file_name = "spikes"; std::string file_extension = "gdf"; - // Parameters for spike input. - bool spike_file_input = false; - std::string input_spike_path; // Path to file with spikes - // Dry run parameters (pertinent only when built with 'dryrun' distrib model). int dry_run_ranks = 1; @@ -85,11 +81,5 @@ std::ostream& operator<<(std::ostream& o, const cl_options& opt); cl_options read_options(int argc, char** argv, bool allow_write = true); -/// Helper function for loading a vector of spike times from file -/// Spike times are expected to be in milli seconds floating points -/// On spike-time per line - -std::vector<time_type> get_parsed_spike_times_from_path(arb::util::path path); - } // namespace io } // namespace arb diff --git a/example/miniapp/miniapp.cpp b/example/miniapp/miniapp.cpp index fc03fa2f0d951bde8a94bd78a41131dd3634b852..2aecfeed17f3f862c989e700b795b6cb4bd62a91 100644 --- a/example/miniapp/miniapp.cpp +++ b/example/miniapp/miniapp.cpp @@ -207,11 +207,6 @@ std::unique_ptr<recipe> make_recipe(const io::cl_options& options, const probe_d p.num_synapses = options.all_to_all? options.cells-1: options.synapses_per_cell; p.synapse_type = options.syn_type; - // Parameters for spike input from file - if (options.spike_file_input) { - p.input_spike_path = options.input_spike_path; - } - if (options.all_to_all) { return make_basic_kgraph_recipe(options.cells, p, pdist); } diff --git a/example/miniapp/miniapp_recipes.cpp b/example/miniapp/miniapp_recipes.cpp index 443df8e4dff13d40a7051c70065c105fe469afd7..16c8893980e431bf5ba12013e07b2571af443d35 100644 --- a/example/miniapp/miniapp_recipes.cpp +++ b/example/miniapp/miniapp_recipes.cpp @@ -4,17 +4,16 @@ #include <utility> #include <cell.hpp> -#include <dss_cell_description.hpp> #include <event_generator.hpp> -#include <rss_cell.hpp> #include <morphology.hpp> +#include <spike_source_cell_group.hpp> +#include <time_sequence.hpp> #include <util/debug.hpp> #include "io.hpp" #include "miniapp_recipes.hpp" #include "morphology_pool.hpp" - namespace arb { // TODO: split cell description into separate morphology, stimulus, mechanisms etc. @@ -87,15 +86,9 @@ public: } util::unique_any get_cell_description(cell_gid_type i) const override { - // The last 'cell' is a spike source cell. Either a regular spiking - // or a spikes from file. + // The last 'cell' is a spike source cell. if (i == ncell_) { - if (param_.input_spike_path) { - auto spike_times = io::get_parsed_spike_times_from_path(param_.input_spike_path.value()); - return util::unique_any(dss_cell_description(spike_times)); - } - - return util::unique_any(rss_cell{0.0, 0.1, 0.1}); + return util::unique_any(spike_source_cell{regular_time_seq(0.0, 0.1, 0.1)}); } auto gen = std::mt19937(i); // TODO: replace this with hashing generator... @@ -145,13 +138,9 @@ public: } cell_kind get_cell_kind(cell_gid_type i) const override { - // The last 'cell' is a rss_cell with one spike at t=0 + // The last 'cell' is a regular spike source with one spike at t=0 if (i == ncell_) { - if (param_.input_spike_path) { - return cell_kind::data_spike_source; - } - - return cell_kind::regular_spike_source; + return cell_kind::spike_source; } return cell_kind::cable1d_neuron; } @@ -269,7 +258,7 @@ public: std::vector<cell_connection> connections_on(cell_gid_type i) const override { std::vector<cell_connection> conns; - // The rss_cell does not have inputs + // The regular spike cell does not have inputs if (i == ncell_) { return conns; } @@ -287,7 +276,7 @@ public: cc.dest = {i, t}; conns.push_back(cc); - // The rss_cell spikes at t=0, with these connections it looks like + // The regular spike source spikes at t=0, with these connections it looks like // (source % 20) == 0 spikes at that moment. if (source % 20 == 0) { cc.source = {ncell_, 0}; @@ -322,7 +311,7 @@ public: std::vector<cell_connection> connections_on(cell_gid_type i) const override { std::vector<cell_connection> conns; - // The rss_cell does not have inputs + // The spike source does not have inputs if (i == ncell_) { return conns; } @@ -337,7 +326,7 @@ public: cc.dest = {i, t}; conns.push_back(cc); - // The rss_cell spikes at t=0, with these connections it looks like + // The spike source spikes at t=0, with these connections it looks like // (source % 20) == 0 spikes at that moment. if (source % 20 == 0) { cc.source = {ncell_, 0}; diff --git a/example/miniapp/miniapp_recipes.hpp b/example/miniapp/miniapp_recipes.hpp index 729be3a13cbbb6f6eb78dd5d51d70f9d120d0a0b..0ed62f6b2942b8b3de0ae200e1608dcf1a830221 100644 --- a/example/miniapp/miniapp_recipes.hpp +++ b/example/miniapp/miniapp_recipes.hpp @@ -39,10 +39,6 @@ struct basic_recipe_param { // If true, iterate through morphologies rather than select randomly. bool morphology_round_robin = false; - - // If set we are importing the spikes injected in the network from file - // instead of a single spike at t==0 - util::optional<std::string> input_spike_path; // Path to file with spikes }; std::unique_ptr<recipe> make_basic_ring_recipe( diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7e0f31855624a8f7e4e8c8bba131b79c39e5cf42..11fe3e273442156b2b1e93ea257117e7564b7695 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -26,6 +26,7 @@ set(arbor_cxx_sources profiling/power_meter.cpp profiling/profiler.cpp schedule.cpp + spike_source_cell_group.cpp swcio.cpp threading/threading.cpp util/debug.cpp diff --git a/src/cell_group_factory.cpp b/src/cell_group_factory.cpp index fbdd4dd9df5f8a6d9f7f5b3f1623e05d856ca37a..862dfdb01bffe2b8afecd012bd1ea438d6443142 100644 --- a/src/cell_group_factory.cpp +++ b/src/cell_group_factory.cpp @@ -3,12 +3,11 @@ #include <backends.hpp> #include <cell_group.hpp> #include <domain_decomposition.hpp> -#include <dss_cell_group.hpp> #include <fvm_lowered_cell.hpp> #include <lif_cell_group.hpp> #include <mc_cell_group.hpp> #include <recipe.hpp> -#include <rss_cell_group.hpp> +#include <spike_source_cell_group.hpp> #include <util/unique_any.hpp> namespace arb { @@ -18,15 +17,12 @@ cell_group_ptr cell_group_factory(const recipe& rec, const group_description& gr case cell_kind::cable1d_neuron: return make_cell_group<mc_cell_group>(group.gids, rec, make_fvm_lowered_cell(group.backend)); - case cell_kind::regular_spike_source: - return make_cell_group<rss_cell_group>(group.gids, rec); + case cell_kind::spike_source: + return make_cell_group<spike_source_cell_group>(group.gids, rec); case cell_kind::lif_neuron: return make_cell_group<lif_cell_group>(group.gids, rec); - case cell_kind::data_spike_source: - return make_cell_group<dss_cell_group>(group.gids, rec); - default: throw std::runtime_error("unknown cell kind"); } diff --git a/src/common_types.hpp b/src/common_types.hpp index 79f77232fdd67b986e723e122462c9a0fa3acdc2..d74bf31ef9fb5266102ccc8e19a4780690bd6fba 100644 --- a/src/common_types.hpp +++ b/src/common_types.hpp @@ -54,7 +54,7 @@ DEFINE_LEXICOGRAPHIC_ORDERING(cell_member_type,(a.gid,a.index),(b.gid,b.index)) // For storing time values [ms] using time_type = float; -constexpr time_type max_time = std::numeric_limits<time_type>::max(); +constexpr time_type terminal_time = std::numeric_limits<time_type>::max(); // Extra contextual information associated with a probe. @@ -68,10 +68,9 @@ using sample_size_type = std::int32_t; // group equal kinds in the same cell group. enum class cell_kind { - cable1d_neuron, // Our own special mc neuron - lif_neuron, // Leaky-integrate and fire neuron - regular_spike_source, // Regular spiking source - data_spike_source, // Spike source from values inserted via description + cable1d_neuron, // Our own special mc neuron + lif_neuron, // Leaky-integrate and fire neuron + spike_source, // Cell that generates spikes at a user-supplied sequence of time points }; } // namespace arb diff --git a/src/common_types_io.cpp b/src/common_types_io.cpp index 3a3fa34e7605d3cb2c7da8c984341d31c2907d41..91360c0fa46bc4b46cbee07a357a8c6dbff7de4b 100644 --- a/src/common_types_io.cpp +++ b/src/common_types_io.cpp @@ -9,12 +9,10 @@ std::ostream& operator<<(std::ostream& O, arb::cell_member_type m) { std::ostream& operator<<(std::ostream& o, arb::cell_kind k) { o << "cell_kind::"; switch (k) { - case arb::cell_kind::regular_spike_source: - return o << "regular_spike_source"; + case arb::cell_kind::spike_source: + return o << "spike_source"; case arb::cell_kind::cable1d_neuron: return o << "cable1d_neuron"; - case arb::cell_kind::data_spike_source: - return o << "data_spike_source"; case arb::cell_kind::lif_neuron: return o << "lif_neuron"; } diff --git a/src/dss_cell_description.hpp b/src/dss_cell_description.hpp deleted file mode 100644 index d0c937a8f3b6bb84bd9025b431fd42e634f85061..0000000000000000000000000000000000000000 --- a/src/dss_cell_description.hpp +++ /dev/null @@ -1,21 +0,0 @@ -#pragma once - -#include <vector> - -#include <common_types.hpp> - -namespace arb { - -/// Description for a data spike source: a cell that generates spikes provided as a vector of -/// spike times at the start of a run. - -struct dss_cell_description { - std::vector<time_type> spike_times; - - /// The description needs a vector of doubles for the description - dss_cell_description(std::vector<time_type> spike_times): - spike_times(std::move(spike_times)) - {} -}; - -} // namespace arb diff --git a/src/dss_cell_group.hpp b/src/dss_cell_group.hpp deleted file mode 100644 index 8cd21c5c37fddd7ac4d1618109bdc8e410a4f5a0..0000000000000000000000000000000000000000 --- a/src/dss_cell_group.hpp +++ /dev/null @@ -1,102 +0,0 @@ -#pragma once - -#include <cell_group.hpp> -#include <dss_cell_description.hpp> -#include <profiling/profiler.hpp> -#include <recipe.hpp> -#include <util/span.hpp> -#include <util/unique_any.hpp> - -namespace arb { - -/// Cell_group to collect spike sources -class dss_cell_group: public cell_group { -public: - dss_cell_group(std::vector<cell_gid_type> gids, const recipe& rec): - gids_(std::move(gids)) - { - for (auto gid: gids_) { - auto desc = util::any_cast<dss_cell_description>(rec.get_cell_description(gid)); - // store spike times from description - auto times = desc.spike_times; - util::sort(times); - spike_times_.push_back(std::move(times)); - - // Take a reference to the first spike time - not_emit_it_.push_back(spike_times_.back().begin()); - } - } - - cell_kind get_cell_kind() const override { - return cell_kind::data_spike_source; - } - - void reset() override { - // Reset the pointers to the next undelivered spike to the start - // of the input range. - auto it = not_emit_it_.begin(); - auto times = spike_times_.begin(); - for (;it != not_emit_it_.end(); ++it, times++) { - *it = times->begin(); - } - - clear_spikes(); - } - - void set_binning_policy(binning_kind policy, time_type bin_interval) override {} - - void advance(epoch ep, time_type dt, const event_lane_subrange& event_lanes) override { - PE(advance_dss); - for (auto i: util::make_span(0, not_emit_it_.size())) { - // The first potential spike_time to emit for this cell - auto spike_time_it = not_emit_it_[i]; - - // Find the first spike past tfinal - not_emit_it_[i] = std::find_if( - spike_time_it, spike_times_[i].end(), - [ep](time_type t) {return t >= ep.tfinal; } - ); - - // Loop over the range and create spikes - for (; spike_time_it != not_emit_it_[i]; ++spike_time_it) { - spikes_.push_back({ {gids_[i], 0u}, *spike_time_it }); - } - } - PL(); - }; - - const std::vector<spike>& spikes() const override { - return spikes_; - } - - void clear_spikes() override { - spikes_.clear(); - } - - void add_sampler(sampler_association_handle h, cell_member_predicate probe_ids, schedule sched, sampler_function fn, sampling_policy policy) override { - std::logic_error("The dss_cells do not support sampling of internal state!"); - } - - void remove_sampler(sampler_association_handle h) override {} - - void remove_all_samplers() override {} - -private: - // Spikes that are generated. - std::vector<spike> spikes_; - - // Map of local index to gid - std::vector<cell_gid_type> gids_; - - // The dss_cell is simple: Put all logic in the cellgroup cause accelerator support - // is not expected. We need storage for the cell state - - // The times the cells should spike, one for each cell, sorted in time - std::vector<std::vector<time_type> > spike_times_; - - // Each cell needs its own itterator to the first spike not yet emitted - std::vector<std::vector<time_type>::iterator > not_emit_it_; -}; - -} // namespace arb - diff --git a/src/event_generator.hpp b/src/event_generator.hpp index 393e391a8b4901034a9a981ce89030fd66786cd4..d72263b68f63cff60f0e4ff9ce16c723f1742f03 100644 --- a/src/event_generator.hpp +++ b/src/event_generator.hpp @@ -6,16 +6,38 @@ #include <common_types.hpp> #include <event_queue.hpp> +#include <time_sequence.hpp> #include <util/range.hpp> #include <util/rangeutil.hpp> namespace arb { +// Generate a postsynaptic spike event that has delivery time set to +// terminal_time. Such events are used as sentinels, to indicate the +// end of a sequence. +inline constexpr +postsynaptic_spike_event make_terminal_pse() { + return postsynaptic_spike_event{cell_member_type{0,0}, terminal_time, 0}; +} + inline -postsynaptic_spike_event terminal_pse() { - return postsynaptic_spike_event{cell_member_type{0,0}, max_time, 0}; +bool is_terminal_pse(const postsynaptic_spike_event& e) { + return e.time==terminal_time; } + +// The simplest possible generator that generates no events. +// Declared ahead of event_generator so that it can be used as the default +// generator. +struct empty_generator { + postsynaptic_spike_event front() { + return postsynaptic_spike_event{cell_member_type{0,0}, terminal_time, 0}; + } + void pop() {} + void reset() {} + void advance(time_type t) {}; +}; + // An event_generator generates a sequence of events to be delivered to a cell. // The sequence of events is always in ascending order, i.e. each event will be // greater than the event that proceded it, where events are ordered by: @@ -28,7 +50,7 @@ public: // copy, move and constructor interface // - event_generator(): event_generator(dummy_generator()) {} + event_generator(): event_generator(empty_generator()) {} template <typename Impl> event_generator(Impl&& impl): @@ -53,10 +75,10 @@ public: // Get the current event in the stream. // Does not modify the state of the stream, i.e. multiple calls to - // next() will return the same event in the absence of calls to pop(), + // front() will return the same event in the absence of calls to pop(), // advance() or reset(). - postsynaptic_spike_event next() { - return impl_->next(); + postsynaptic_spike_event front() { + return impl_->front(); } // Move the generator to the next event in the stream. @@ -69,7 +91,7 @@ public: impl_->reset(); } - // Update state of the generator such that the event returned by next() is + // Update state of the generator such that the event returned by front() is // the first event with delivery time >= t. void advance(time_type t) { return impl_->advance(t); @@ -77,7 +99,7 @@ public: private: struct interface { - virtual postsynaptic_spike_event next() = 0; + virtual postsynaptic_spike_event front() = 0; virtual void pop() = 0; virtual void advance(time_type t) = 0; virtual void reset() = 0; @@ -92,8 +114,8 @@ private: explicit wrap(const Impl& impl): wrapped(impl) {} explicit wrap(Impl&& impl): wrapped(std::move(impl)) {} - postsynaptic_spike_event next() override { - return wrapped.next(); + postsynaptic_spike_event front() override { + return wrapped.front(); } void pop() override { @@ -114,50 +136,38 @@ private: Impl wrapped; }; - - struct dummy_generator { - postsynaptic_spike_event next() { return terminal_pse(); } - void pop() {} - void reset() {} - void advance(time_type t) {}; - }; - }; // Generator that feeds events that are specified with a vector. // Makes a copy of the input sequence of events. struct vector_backed_generator { using pse = postsynaptic_spike_event; - vector_backed_generator(pse_vector events): - events_(std::move(events)), - it_(events_.begin()) - { - if (!util::is_sorted(events_)) { - util::sort(events_); - } - } + vector_backed_generator(cell_member_type target, float weight, std::vector<time_type> samples): + target_(target), + weight_(weight), + tseq_(std::move(samples)) + {} - postsynaptic_spike_event next() { - return it_==events_.end()? terminal_pse(): *it_; + postsynaptic_spike_event front() { + return postsynaptic_spike_event{target_, tseq_.front(), weight_}; } void pop() { - if (it_!=events_.end()) { - ++it_; - } + tseq_.pop(); } void reset() { - it_ = events_.begin(); + tseq_.reset(); } void advance(time_type t) { - it_ = std::lower_bound(events_.begin(), events_.end(), t, event_time_less()); + tseq_.advance(t); } private: - std::vector<postsynaptic_spike_event> events_; - std::vector<postsynaptic_spike_event>::const_iterator it_; + cell_member_type target_; + float weight_; + vector_time_seq tseq_; }; // Generator for events in a generic sequence. @@ -174,8 +184,8 @@ struct seq_generator { EXPECTS(util::is_sorted(events_)); } - postsynaptic_spike_event next() { - return it_==events_.end()? terminal_pse(): *it_; + postsynaptic_spike_event front() { + return it_==events_.end()? make_terminal_pse(): *it_; } void pop() { @@ -193,7 +203,6 @@ struct seq_generator { } private: - const Seq& events_; typename Seq::const_iterator it_; }; @@ -208,56 +217,32 @@ struct regular_generator { float weight, time_type tstart, time_type dt, - time_type tstop=max_time): + time_type tstop=terminal_time): target_(target), weight_(weight), - step_(0), - t_start_(tstart), - dt_(dt), - t_stop_(tstop) + tseq_(tstart, dt, tstop) {} - postsynaptic_spike_event next() { - const auto t = time(); - return t<t_stop_? - postsynaptic_spike_event{target_, t, weight_}: - terminal_pse(); + postsynaptic_spike_event front() { + return postsynaptic_spike_event{target_, tseq_.front(), weight_}; } void pop() { - ++step_; + tseq_.pop(); } void advance(time_type t0) { - t0 = std::max(t0, t_start_); - step_ = (t0-t_start_)/dt_; - - // Finding the smallest value for step_ that satisfies the condition - // that time() >= t0 is unfortunately a horror show because floating - // point precission. - while (step_ && time()>=t0) { - --step_; - } - while (time()<t0) { - ++step_; - } + tseq_.advance(t0); } void reset() { - step_ = 0; + tseq_.reset(); } private: - time_type time() const { - return t_start_ + step_*dt_; - } - cell_member_type target_; float weight_; - std::size_t step_; - time_type t_start_; - time_type dt_; - time_type t_stop_; + regular_time_seq tseq_; }; // Generates a stream of events at times described by a Poisson point process @@ -271,48 +256,34 @@ struct poisson_generator { RandomNumberEngine rng, time_type tstart, time_type rate_per_ms, - time_type tstop=max_time): - exp_(rate_per_ms), - reset_state_(std::move(rng)), + time_type tstop=terminal_time): target_(target), weight_(weight), - t_start_(tstart), - t_stop_(tstop) + tseq_(std::move(rng), tstart, rate_per_ms, tstop) { reset(); } - postsynaptic_spike_event next() { - return next_<t_stop_? - postsynaptic_spike_event{target_, next_, weight_}: - terminal_pse(); + postsynaptic_spike_event front() { + return postsynaptic_spike_event{target_, tseq_.front(), weight_}; } void pop() { - next_ += exp_(rng_); + tseq_.pop(); } void advance(time_type t0) { - while (next_<t0) { - pop(); - } + tseq_.advance(t0); } void reset() { - rng_ = reset_state_; - next_ = t_start_; - pop(); + tseq_.reset(); } private: - std::exponential_distribution<time_type> exp_; - RandomNumberEngine rng_; - const RandomNumberEngine reset_state_; const cell_member_type target_; const float weight_; - const time_type t_start_; - const time_type t_stop_; - time_type next_; + poisson_time_seq<RandomNumberEngine> tseq_; }; } // namespace arb diff --git a/src/merge_events.cpp b/src/merge_events.cpp index 1fd116151559a81a585ffcbcf7eea5e15daefdba..937a61a74561ccedd8049c3215a066d6b907e4ef 100644 --- a/src/merge_events.cpp +++ b/src/merge_events.cpp @@ -46,8 +46,8 @@ tourney_tree::tourney_tree(std::vector<event_generator>& input): // Set the leaf nodes for (auto i=0u; i<leaves_; ++i) { heap_[leaf(i)] = i<n_lanes_? - key_val(i, input[i].next()): - key_val(i, terminal_pse()); // null leaf node + key_val(i, input[i].front()): + key_val(i, make_terminal_pse()); // null leaf node } // Walk the tree to initialize the non-leaf nodes setup(0); @@ -62,7 +62,7 @@ void tourney_tree::print() const { } bool tourney_tree::empty() const { - return event(0).time == max_time; + return event(0).time == terminal_time; } bool tourney_tree::empty(time_type t) const { @@ -81,7 +81,7 @@ void tourney_tree::pop() { // draw the next event from the input lane input_[lane].pop(); // place event the leaf node for this lane - event(i) = input_[lane].next(); + event(i) = input_[lane].front(); // re-heapify the tree with a single walk from leaf to root while ((i=parent(i))) { diff --git a/src/rss_cell.hpp b/src/rss_cell.hpp deleted file mode 100644 index 17f2c4ae26d4c5569b3f18fd2cbe95a9dd792ec7..0000000000000000000000000000000000000000 --- a/src/rss_cell.hpp +++ /dev/null @@ -1,16 +0,0 @@ -#pragma once - -#include <common_types.hpp> - -namespace arb { - -/// Description class for a regular spike source: a cell that generates -/// spikes with a fixed period over a given time interval. - -struct rss_cell { - time_type start_time; - time_type period; - time_type stop_time; -}; - -} // namespace arb diff --git a/src/rss_cell_group.hpp b/src/rss_cell_group.hpp deleted file mode 100644 index 046ead4ef44b8a13f383038d1b9133857ee1067a..0000000000000000000000000000000000000000 --- a/src/rss_cell_group.hpp +++ /dev/null @@ -1,97 +0,0 @@ -#pragma once - -#include <utility> - -#include <cell_group.hpp> -#include <profiling/profiler.hpp> -#include <recipe.hpp> -#include <rss_cell.hpp> -#include <util/unique_any.hpp> - -namespace arb { - -/// Cell group implementing RSS cells. - -class rss_cell_group: public cell_group { -public: - rss_cell_group(std::vector<cell_gid_type> gids, const recipe& rec) { - cells_.reserve(gids.size()); - for (auto gid: gids) { - cells_.emplace_back( - util::any_cast<rss_cell>(rec.get_cell_description(gid)), - gid); - } - reset(); - } - - cell_kind get_cell_kind() const override { - return cell_kind::regular_spike_source; - } - - void reset() override { - clear_spikes(); - for (auto& cell : cells_) { - cell.step = 0; - } - } - - void set_binning_policy(binning_kind policy, time_type bin_interval) override {} - - void advance(epoch ep, time_type dt, const event_lane_subrange& events) override { - PE(advance_dss); - for (auto& cell: cells_) { - - auto t_end = std::min(cell.stop_time, ep.tfinal); - auto t = cell.start_time + cell.step*cell.period; - - while (t < t_end) { - spikes_.push_back({{cell.gid, 0}, t}); - - // Increasing time with period time step has float issues - ++cell.step; - t = cell.start_time + cell.step*cell.period; - } - } - PL(); - } - - const std::vector<spike>& spikes() const override { - return spikes_; - } - - void clear_spikes() override { - spikes_.clear(); - } - - void add_sampler(sampler_association_handle, cell_member_predicate, schedule, sampler_function, sampling_policy) override { - std::logic_error("rss_cell does not support sampling"); - } - - void remove_sampler(sampler_association_handle) override {} - - void remove_all_samplers() override {} - -private: - // RSS description plus gid for each RSS cell. - struct rss_info: public rss_cell { - rss_info(rss_cell desc, cell_gid_type gid): - rss_cell(std::move(desc)), gid(gid) - {} - - cell_gid_type gid; - - // We do not store the time but a count of the number of step since start - // of cell. This prevents float problems at high number. - std::size_t step; - }; - - // RSS cell descriptions. - std::vector<rss_info> cells_; - - - // Spikes that are generated. - std::vector<spike> spikes_; -}; - -} // namespace arb - diff --git a/src/spike.hpp b/src/spike.hpp index 08a807568bffb9d48fe4faed348db2ae1c166121..422e5aad62c3ae86cc82094fbcf88ae8d5268494 100644 --- a/src/spike.hpp +++ b/src/spike.hpp @@ -21,7 +21,7 @@ struct basic_spike { {} friend bool operator==(const basic_spike& l, const basic_spike& r) { - return l.source==r.source && l.time==r.time; + return l.time==r.time && l.source==r.source; } }; diff --git a/src/spike_source_cell_group.cpp b/src/spike_source_cell_group.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1cbdaade5e1b9e230fff25b1d8eb3e80173d91bf --- /dev/null +++ b/src/spike_source_cell_group.cpp @@ -0,0 +1,67 @@ +#include <exception> + +#include <cell_group.hpp> +#include <profiling/profiler.hpp> +#include <recipe.hpp> +#include <spike_source_cell_group.hpp> +#include <time_sequence.hpp> + +namespace arb { + +spike_source_cell_group::spike_source_cell_group(std::vector<cell_gid_type> gids, const recipe& rec): + gids_(std::move(gids)) +{ + time_sequences_.reserve(gids_.size()); + for (auto gid: gids_) { + try { + auto cell = util::any_cast<spike_source_cell>(rec.get_cell_description(gid)); + time_sequences_.push_back(std::move(cell.seq)); + } + catch (util::bad_any_cast& e) { + throw std::runtime_error("model cell type mismatch: gid "+std::to_string(gid)+" is not a spike_source_cell"); + } + } +} + +cell_kind spike_source_cell_group::get_cell_kind() const { + return cell_kind::spike_source; +} + +void spike_source_cell_group::advance(epoch ep, time_type dt, const event_lane_subrange& event_lanes) { + PE(advance_sscell); + + for (auto i: util::make_span(0, gids_.size())) { + auto& tseq = time_sequences_[i]; + const auto gid = gids_[i]; + + while (tseq.front()<ep.tfinal) { + spikes_.push_back({{gid, 0u}, tseq.front()}); + tseq.pop(); + } + } + PL(); +}; + +void spike_source_cell_group::reset() { + for (auto& s: time_sequences_) { + s.reset(); + } + + clear_spikes(); +} + +const std::vector<spike>& spike_source_cell_group::spikes() const { + return spikes_; +} + +void spike_source_cell_group::clear_spikes() { + spikes_.clear(); +} + +void spike_source_cell_group::add_sampler(sampler_association_handle, cell_member_predicate, schedule, sampler_function, sampling_policy) { + std::logic_error("A spike_source_cell group doen't support sampling of internal state!"); +} + +} // namespace arb + + diff --git a/src/spike_source_cell_group.hpp b/src/spike_source_cell_group.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6f81367f54e7fcbb69ac20aa670f47183426d70b --- /dev/null +++ b/src/spike_source_cell_group.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include <cell_group.hpp> +#include <recipe.hpp> +#include <time_sequence.hpp> + +namespace arb { + +struct spike_source_cell { + time_seq seq; +}; + +class spike_source_cell_group: public cell_group { +public: + spike_source_cell_group(std::vector<cell_gid_type> gids, const recipe& rec); + + cell_kind get_cell_kind() const override; + + void advance(epoch ep, time_type dt, const event_lane_subrange& event_lanes) override; + + void reset() override; + + void set_binning_policy(binning_kind policy, time_type bin_interval) override {} + + const std::vector<spike>& spikes() const override; + + void clear_spikes() override; + + void add_sampler(sampler_association_handle h, cell_member_predicate probe_ids, schedule sched, sampler_function fn, sampling_policy policy) override; + + void remove_sampler(sampler_association_handle h) override {} + + void remove_all_samplers() override {} + +private: + std::vector<spike> spikes_; + std::vector<cell_gid_type> gids_; + std::vector<time_seq> time_sequences_; +}; + +} // namespace arb + diff --git a/src/time_sequence.hpp b/src/time_sequence.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9b01d0238e2d91b6286b4ef893d881caa9e7aa22 --- /dev/null +++ b/src/time_sequence.hpp @@ -0,0 +1,251 @@ +#pragma once + +#include <algorithm> +#include <memory> +#include <random> + +#include <common_types.hpp> +#include <event_queue.hpp> +#include <util/meta.hpp> +#include <util/rangeutil.hpp> + +namespace arb { + + +struct empty_time_seq { + time_type front() { return terminal_time; } + void pop() {} + void reset() {} + void advance(time_type t) {}; +}; + +// An time_sequence generates a sequence of time points. +// The sequence of times is always monotonically increasing, i.e. each time is +// not earlier than the time that proceded it. +class time_seq { +public: + // + // copy, move and constructor interface + // + + time_seq(): time_seq(empty_time_seq()) {} + + template < + typename Impl, + typename = typename util::enable_if_t< + !std::is_same<typename std::decay<Impl>::type, + time_seq>::value>> + time_seq(Impl&& impl): + impl_(new wrap<Impl>(std::forward<Impl>(impl))) + {} + + time_seq(time_seq&& other) = default; + time_seq& operator=(time_seq&& other) = default; + + time_seq(const time_seq& other): + impl_(other.impl_->clone()) + {} + + time_seq& operator=(const time_seq& other) { + impl_ = other.impl_->clone(); + return *this; + } + + // + // time sequence interface + // + + // Get the current time in the stream. + // Does not modify the state of the stream, i.e. multiple calls to + // front() will return the same time in the absence of calls to pop(), + // advance() or reset(). + time_type front() { + return impl_->front(); + } + + // Move the generator to the front time in the stream. + void pop() { + impl_->pop(); + } + + // Reset the generator to the same state that it had on construction. + void reset() { + impl_->reset(); + } + + // Update state of the generator such that the time returned by front() is + // the first time with delivery time >= t. + void advance(time_type t) { + return impl_->advance(t); + } + +private: + struct interface { + virtual time_type front() = 0; + virtual void pop() = 0; + virtual void advance(time_type t) = 0; + virtual void reset() = 0; + virtual std::unique_ptr<interface> clone() = 0; + virtual ~interface() {} + }; + + std::unique_ptr<interface> impl_; + + template <typename Impl> + struct wrap: interface { + explicit wrap(const Impl& impl): wrapped(impl) {} + explicit wrap(Impl&& impl): wrapped(std::move(impl)) {} + + time_type front() override { + return wrapped.front(); + } + void pop() override { + return wrapped.pop(); + } + void advance(time_type t) override { + return wrapped.advance(t); + } + void reset() override { + wrapped.reset(); + } + std::unique_ptr<interface> clone() override { + return std::unique_ptr<interface>(new wrap<Impl>(wrapped)); + } + + Impl wrapped; + }; +}; + +// Sequence of time points prescribed by a vector +struct vector_time_seq { + vector_time_seq(std::vector<time_type> seq): + seq_(std::move(seq)) + { + // Ensure that the time values are sorted. + if(!std::is_sorted(seq_.begin(), seq_.end())) { + util::sort(seq_); + } + reset(); + } + + time_type front() { + return it_==seq_.end()? terminal_time: *it_; + } + + void pop() { + if (it_!=seq_.end()) { + ++it_; + } + } + + void advance(time_type t0) { + it_ = std::lower_bound(seq_.begin(), seq_.end(), t0); + } + + void reset() { + it_ = seq_.begin(); + } + +private: + std::vector<time_type> seq_; + std::vector<time_type>::const_iterator it_; +}; + +// Generates a set of regularly spaced time samples. +// t=t_start+n*dt, ∀ t ∈ [t_start, t_stop) +struct regular_time_seq { + regular_time_seq(time_type tstart, + time_type dt, + time_type tstop=terminal_time): + step_(0), + t_start_(tstart), + dt_(dt), + t_stop_(tstop) + {} + + time_type front() { + const auto t = time(); + return t<t_stop_? t: terminal_time; + } + + void pop() { + ++step_; + } + + void advance(time_type t0) { + t0 = std::max(t0, t_start_); + step_ = (t0-t_start_)/dt_; + + // Finding the smallest value for step_ that satisfies the condition + // that time() >= t0 is unfortunately a horror show because floating + // point precission. + while (step_ && time()>=t0) { + --step_; + } + while (time()<t0) { + ++step_; + } + } + + void reset() { + step_ = 0; + } + +private: + time_type time() const { + return t_start_ + step_*dt_; + } + + std::size_t step_; + time_type t_start_; + time_type dt_; + time_type t_stop_; +}; + +// Generates a stream of time points described by a Poisson point process +// with rate_per_ms samples per ms. +template <typename RandomNumberEngine> +struct poisson_time_seq { + poisson_time_seq(RandomNumberEngine rng, + time_type tstart, + time_type rate_per_ms, + time_type tstop=terminal_time): + exp_(rate_per_ms), + reset_state_(std::move(rng)), + t_start_(tstart), + t_stop_(tstop) + { + reset(); + } + + time_type front() { + return next_<t_stop_? next_: terminal_time; + } + + void pop() { + next_ += exp_(rng_); + } + + void advance(time_type t0) { + while (next_<t0) { + pop(); + } + } + + void reset() { + rng_ = reset_state_; + next_ = t_start_; + pop(); + } + +private: + std::exponential_distribution<time_type> exp_; + RandomNumberEngine rng_; + const RandomNumberEngine reset_state_; + const time_type t_start_; + const time_type t_stop_; + time_type next_; +}; + +} // namespace arb + diff --git a/tests/global_communication/test_communicator.cpp b/tests/global_communication/test_communicator.cpp index a41c06f1c80ee6413f8ba664c7597c0d7079d5cf..55f48bae094092c4be7db329d2ba343882628dc2 100644 --- a/tests/global_communication/test_communicator.cpp +++ b/tests/global_communication/test_communicator.cpp @@ -177,7 +177,7 @@ namespace { } cell_kind get_cell_kind(cell_gid_type gid) const override { - return gid%2? cell_kind::cable1d_neuron: cell_kind::regular_spike_source; + return gid%2? cell_kind::cable1d_neuron: cell_kind::spike_source; } cell_size_type num_sources(cell_gid_type) const override { return 1; } @@ -240,7 +240,7 @@ namespace { return {}; } cell_kind get_cell_kind(cell_gid_type gid) const override { - return gid%2? cell_kind::cable1d_neuron: cell_kind::regular_spike_source; + return gid%2? cell_kind::cable1d_neuron: cell_kind::spike_source; } cell_size_type num_sources(cell_gid_type) const override { return 1; } diff --git a/tests/global_communication/test_domain_decomposition.cpp b/tests/global_communication/test_domain_decomposition.cpp index a586e0dacc305b92f1cb1ac1ae92800ef480a987..02304dc2958f4b36e2f2048126d083488bfe8d6e 100644 --- a/tests/global_communication/test_domain_decomposition.cpp +++ b/tests/global_communication/test_domain_decomposition.cpp @@ -41,7 +41,7 @@ namespace { cell_kind get_cell_kind(cell_gid_type gid) const override { return gid%2? - cell_kind::regular_spike_source: + cell_kind::spike_source: cell_kind::cable1d_neuron; } @@ -171,7 +171,7 @@ TEST(domain_decomposition, heterogeneous_population) { EXPECT_EQ(grp.backend, backend_kind::multicore); } - for (auto k: {cell_kind::cable1d_neuron, cell_kind::regular_spike_source}) { + for (auto k: {cell_kind::cable1d_neuron, cell_kind::spike_source}) { const auto& gids = kind_lists[k]; EXPECT_EQ(gids.size(), n_local/2); for (auto gid: gids) { diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index dca52faecdd1bf934dc032d09367f23acfce4425..942b22839db00949c59dabc1750c5dfb2dddd06c 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -45,7 +45,6 @@ set(test_sources test_counter.cpp test_cycle.cpp test_domain_decomposition.cpp - test_dss_cell_group.cpp test_either.cpp test_event_binner.cpp test_event_generators.cpp @@ -75,7 +74,7 @@ set(test_sources test_range.cpp test_segment.cpp test_schedule.cpp - test_rss_cell.cpp + test_spike_source.cpp test_local_context.cpp test_simd.cpp test_span.cpp @@ -85,6 +84,7 @@ set(test_sources test_strprintf.cpp test_swcio.cpp test_synapses.cpp + test_time_seq.cpp test_tree.cpp test_transform.cpp test_uninitialized.cpp diff --git a/tests/unit/test_domain_decomposition.cpp b/tests/unit/test_domain_decomposition.cpp index 1c58a7fb5443cf92d6dc1b8f1442515058e82d02..9c3ec59a9bc4f34db45d72ca529ea98c444727f6 100644 --- a/tests/unit/test_domain_decomposition.cpp +++ b/tests/unit/test_domain_decomposition.cpp @@ -18,9 +18,9 @@ namespace { struct dummy_cell {}; using homo_recipe = homogeneous_recipe<cell_kind::cable1d_neuron, dummy_cell>; - // Heterogenous cell population of cable and rss cells. + // Heterogenous cell population of cable and spike source cells. // Interleaved so that cells with even gid are cable cells, and odd gid are - // rss cells. + // spike source cells. class hetero_recipe: public recipe { public: hetero_recipe(cell_size_type s): size_(s) {} @@ -35,7 +35,7 @@ namespace { cell_kind get_cell_kind(cell_gid_type gid) const override { return gid%2? - cell_kind::regular_spike_source: + cell_kind::spike_source: cell_kind::cable1d_neuron; } @@ -140,7 +140,7 @@ TEST(domain_decomposition, heterogenous_population) EXPECT_EQ(grp.backend, backend_kind::multicore); } - for (auto k: {cell_kind::cable1d_neuron, cell_kind::regular_spike_source}) { + for (auto k: {cell_kind::cable1d_neuron, cell_kind::spike_source}) { const auto& gids = kind_lists[k]; EXPECT_EQ(gids.size(), num_cells/2); for (auto gid: gids) { @@ -177,7 +177,7 @@ TEST(domain_decomposition, heterogenous_population) ++ncells; } } - else if (k==cell_kind::regular_spike_source){ + else if (k==cell_kind::spike_source){ EXPECT_EQ(grp.backend, backend_kind::multicore); EXPECT_EQ(grp.gids.size(), 1u); EXPECT_TRUE(grp.gids.front()%2); diff --git a/tests/unit/test_dss_cell_group.cpp b/tests/unit/test_dss_cell_group.cpp deleted file mode 100644 index ed9fc3c4e4db87af95a9b7570a240399836bb67d..0000000000000000000000000000000000000000 --- a/tests/unit/test_dss_cell_group.cpp +++ /dev/null @@ -1,64 +0,0 @@ -#include "../gtest.h" - -#include <dss_cell_description.hpp> -#include <dss_cell_group.hpp> -#include <util/unique_any.hpp> - -#include "../simple_recipes.hpp" - -using namespace arb; - -using dss_recipe = homogeneous_recipe<cell_kind::data_spike_source, dss_cell_description>; - -TEST(dss_cell, basic_usage) -{ - const time_type spike_time = 0.1; - dss_recipe rec(1u, dss_cell_description({spike_time})); - dss_cell_group sut({0}, rec); - - // No spikes in this time frame. - time_type dt = 0.01; // (note that dt is ignored in dss_cell_group). - epoch ep(0, 0.09); - sut.advance(ep, dt, {}); - - auto spikes = sut.spikes(); - EXPECT_EQ(0u, spikes.size()); - - // Only one in this time frame. - ep.advance(0.11); - sut.advance(ep, dt, {}); - spikes = sut.spikes(); - EXPECT_EQ(1u, spikes.size()); - ASSERT_FLOAT_EQ(spike_time, spikes[0].time); - - // Clear the spikes after 'processing' them. - sut.clear_spikes(); - spikes = sut.spikes(); - EXPECT_EQ(0u, spikes.size()); - - // No spike to be emitted. - ep.advance(0.12); - sut.advance(ep, dt, {}); - spikes = sut.spikes(); - EXPECT_EQ(0u, spikes.size()); - - // Reset the internal state. - sut.reset(); - - // Expect to have the one spike again after reset. - ep.advance(0.2); - sut.advance(ep, dt, {}); - spikes = sut.spikes(); - EXPECT_EQ(1u, spikes.size()); - ASSERT_FLOAT_EQ(spike_time, spikes[0].time); -} - - -TEST(dss_cell, cell_kind_correct) -{ - const time_type spike_time = 0.1; - dss_recipe rec(1u, dss_cell_description({spike_time})); - dss_cell_group sut({0}, rec); - - EXPECT_EQ(cell_kind::data_spike_source, sut.get_cell_kind()); -} diff --git a/tests/unit/test_event_generators.cpp b/tests/unit/test_event_generators.cpp index 34afb0588216a0bbce0d7965fbdfa10fbdc8c52f..8d8b538ac6f66fd8090768c0eb9d25071cef5e80 100644 --- a/tests/unit/test_event_generators.cpp +++ b/tests/unit/test_event_generators.cpp @@ -12,43 +12,14 @@ namespace{ gen.reset(); gen.advance(t0); pse_vector v; - while (gen.next().time<t1) { - v.push_back(gen.next()); + while (gen.front().time<t1) { + v.push_back(gen.front()); gen.pop(); } return v; } } -TEST(event_generators, vector_backed) { - std::vector<pse> in = { - {{0, 0}, 0.1, 1.0}, - {{0, 0}, 1.0, 2.0}, - {{0, 0}, 1.0, 3.0}, - {{0, 0}, 1.5, 4.0}, - {{0, 0}, 2.3, 5.0}, - {{0, 0}, 3.0, 6.0}, - {{0, 0}, 3.5, 7.0}, - }; - - vector_backed_generator gen(in); - - // Test pop, next and reset. - for (auto e: in) { - EXPECT_EQ(e, gen.next()); - gen.pop(); - } - gen.reset(); - for (auto e: in) { - EXPECT_EQ(e, gen.next()); - gen.pop(); - } - - // The loop above should have drained all events from gen, so we expect - // that the next() event will be the special terminal_pse event. - EXPECT_EQ(gen.next(), terminal_pse()); -} - TEST(event_generators, regular) { // make a regular generator that generates its first event at t=2ms and subsequent // events regularly spaced 0.5 ms apart. @@ -71,65 +42,21 @@ TEST(event_generators, regular) { // Test pop, next and reset. for (auto e: expected({2.0, 2.5, 3.0, 3.5, 4.0, 4.5})) { - EXPECT_EQ(e, gen.next()); + EXPECT_EQ(e, gen.front()); gen.pop(); } gen.reset(); for (auto e: expected({2.0, 2.5, 3.0, 3.5, 4.0, 4.5})) { - EXPECT_EQ(e, gen.next()); + EXPECT_EQ(e, gen.front()); gen.pop(); } gen.reset(); // Test advance gen.advance(10.1); - EXPECT_EQ(gen.next().time, time_type(10.5)); + EXPECT_EQ(gen.front().time, time_type(10.5)); gen.advance(12); - EXPECT_EQ(gen.next().time, time_type(12)); -} - -// Test for rounding problems with large time values and the regular generator -TEST(event_generator, regular_rounding) { - // make a regular generator that generates its first event at t=2ms and subsequent - // events regularly spaced 0.5 ms apart. - time_type t0 = 2.0; - time_type dt = 0.5; - cell_member_type target = {42, 3}; - float weight = 3.14; - - // Test for rounding problems with large time values. - // To better understand why this is an issue, uncomment the following: - // float T = 1802667.0f, DT = 0.024999f; - // std::size_t N = std::floor(T/DT); - // std::cout << "T " << T << " DT " << DT << " N " << N - // << " T-N*DT " << T - (N*DT) << " P " << (T - (N*DT))/DT << "\n"; - t0 = 1802667.0f; - dt = 0.024999f; - time_type int_len = 5*dt; - time_type t1 = t0 + int_len; - time_type t2 = t1 + int_len; - event_generator gen = regular_generator(target, weight, t0, dt); - - // Take the interval I_a: t ∈ [t0, t2) - // And the two sub-interavls - // I_l: t ∈ [t0, t1) - // I_r: t ∈ [t1, t2) - // Such that I_a = I_l ∪ I_r. - // If we draw events from each interval then merge them, we expect same set - // of events as when we draw from that large interval. - pse_vector int_l = draw(gen, t0, t1); - pse_vector int_r = draw(gen, t1, t2); - pse_vector int_a = draw(gen, t0, t2); - - pse_vector int_merged = int_l; - util::append(int_merged, int_r); - - EXPECT_TRUE(int_l.front().time >= t0); - EXPECT_TRUE(int_l.back().time < t1); - EXPECT_TRUE(int_r.front().time >= t1); - EXPECT_TRUE(int_r.back().time < t2); - EXPECT_EQ(int_a, int_merged); - EXPECT_TRUE(std::is_sorted(int_a.begin(), int_a.end())); + EXPECT_EQ(gen.front().time, time_type(12)); } TEST(event_generators, seq) { @@ -151,17 +78,17 @@ TEST(event_generators, seq) { // Test pop, next and reset. for (auto e: in) { - EXPECT_EQ(e, gen.next()); + EXPECT_EQ(e, gen.front()); gen.pop(); } gen.reset(); for (auto e: in) { - EXPECT_EQ(e, gen.next()); + EXPECT_EQ(e, gen.front()); gen.pop(); } // The loop above should have drained all events from gen, so we expect - // that the next() event will be the special terminal_pse event. - EXPECT_EQ(gen.next(), terminal_pse()); + // that the front() event will be the special terminal_pse event. + EXPECT_TRUE(is_terminal_pse(gen.front())); gen.reset(); @@ -215,8 +142,8 @@ TEST(event_generators, poisson) { pgen gen(target, weight, G, t0, lambda); pse_vector int1; - while (gen.next().time<t1) { - int1.push_back(gen.next()); + while (gen.front().time<t1) { + int1.push_back(gen.front()); gen.pop(); } // Test that the output is sorted @@ -225,8 +152,8 @@ TEST(event_generators, poisson) { // Reset and generate the same sequence of events gen.reset(); pse_vector int2; - while (gen.next().time<t1) { - int2.push_back(gen.next()); + while (gen.front().time<t1) { + int2.push_back(gen.front()); gen.pop(); } @@ -234,31 +161,3 @@ TEST(event_generators, poisson) { EXPECT_EQ(int1, int2); } -// Test a poisson generator that has a tstop past which no events -// should be generated. -TEST(event_generators, poisson_terminates) { - std::mt19937_64 G; - using pgen = poisson_generator<std::mt19937_64>; - - time_type t0 = 0; - time_type t1 = 10; - time_type t2 = 1e7; // pick a time far past the end of the interval [t0, t1) - time_type lambda = 10; // expect 10 events per ms - cell_member_type target{4, 2}; - float weight = 42; - // construct generator with explicit end time t1 - pgen gen(target, weight, G, t0, lambda, t1); - - pse_vector events; - // pull events off the generator well past the end of the end time t1 - while (gen.next().time<t2) { - events.push_back(gen.next()); - gen.pop(); - } - - // the generator should be exhausted - EXPECT_EQ(gen.next(), terminal_pse()); - - // the last event should be less than the end time - EXPECT_TRUE(events.back().time<t1); -} diff --git a/tests/unit/test_lif_cell_group.cpp b/tests/unit/test_lif_cell_group.cpp index b931d7034043a5cfb7addb38ffc77e52fdb7a9c4..43785e120b7a862ce19d3d339eff414723e3a489 100644 --- a/tests/unit/test_lif_cell_group.cpp +++ b/tests/unit/test_lif_cell_group.cpp @@ -6,9 +6,8 @@ #include <lif_cell_group.hpp> #include <load_balance.hpp> #include <simulation.hpp> +#include <spike_source_cell_group.hpp> #include <recipe.hpp> -#include <rss_cell.hpp> -#include <rss_cell_group.hpp> using namespace arb; // Simple ring network of LIF neurons. @@ -26,7 +25,7 @@ public: // LIF neurons have gid in range [1..n_lif_cells_] whereas fake cell is numbered with 0. cell_kind get_cell_kind(cell_gid_type gid) const override { if (gid == 0) { - return cell_kind::regular_spike_source; + return cell_kind::spike_source; } return cell_kind::lif_neuron; } @@ -60,11 +59,7 @@ public: // regularly spiking cell. if (gid == 0) { // Produces just a single spike at time 0ms. - auto rs = arb::rss_cell(); - rs.start_time = 0; - rs.period = 1; - rs.stop_time = 0.5; - return rs; + return spike_source_cell{vector_time_seq({0.f})}; } // LIF cell. return lif_cell_description(); diff --git a/tests/unit/test_merge_events.cpp b/tests/unit/test_merge_events.cpp index 2afc38dfce5e9bfab900506a30b9ad0a7277605c..03c58a6a57adba4fb20c6b513a843467df2bfce0 100644 --- a/tests/unit/test_merge_events.cpp +++ b/tests/unit/test_merge_events.cpp @@ -14,7 +14,7 @@ TEST(merge_events, empty) pse_vector lc; pse_vector lf; - merge_events(0, max_time, lc, events, empty_gens, lf); + merge_events(0, terminal_time, lc, events, empty_gens, lf); EXPECT_EQ(lf.size(), 0u); } @@ -46,7 +46,7 @@ TEST(merge_events, no_overlap) {{0, 0}, 11, 1}, }; - merge_events(10, max_time, lc, events, empty_gens, lf); + merge_events(10, terminal_time, lc, events, empty_gens, lf); pse_vector expected = { {{8, 0}, 10, 4}, @@ -84,7 +84,7 @@ TEST(merge_events, overlap) {{7, 0}, 10, 8}, }; - merge_events(10, max_time, lc, events, empty_gens, lf); + merge_events(10, terminal_time, lc, events, empty_gens, lf); pse_vector expected = { {{7, 0}, 10, 8}, // from events @@ -214,8 +214,8 @@ TEST(merge_events, tourney_poisson) pse_vector expected; for (auto& gen: generators) { // Push all events before tfinal in gen to the expected values. - while (gen.next().time<tfinal) { - expected.push_back(gen.next()); + while (gen.front().time<tfinal) { + expected.push_back(gen.front()); gen.pop(); } // Reset the generator so that it is ready to generate the same diff --git a/tests/unit/test_rss_cell.cpp b/tests/unit/test_rss_cell.cpp deleted file mode 100644 index f8c545a1cdd7ad7506ca6fb59b9e7c6270a72908..0000000000000000000000000000000000000000 --- a/tests/unit/test_rss_cell.cpp +++ /dev/null @@ -1,87 +0,0 @@ -#include "../gtest.h" - -#include <rss_cell.hpp> -#include <rss_cell_group.hpp> - -#include "../simple_recipes.hpp" - -using namespace arb; - -using rss_recipe = homogeneous_recipe<cell_kind::regular_spike_source, rss_cell>; - -TEST(rss_cell, basic_usage) -{ - constexpr time_type dt = 0.01; // dt is ignored by rss_cell_group::advance(). - - // Use floating point times with an exact representation in order to avoid - // rounding issues. - rss_cell desc{0.125, 0.03125, 0.5}; - rss_cell_group sut({0}, rss_recipe(1u, desc)); - - // No spikes in this time frame. - epoch ep(0, 0.1); - sut.advance(ep, dt, {}); - EXPECT_EQ(0u, sut.spikes().size()); - - // Only on in this time frame - sut.clear_spikes(); - ep.advance(0.127); - sut.advance(ep, dt, {}); - EXPECT_EQ(1u, sut.spikes().size()); - - // Reset cell group state. - sut.reset(); - - // Expect 12 spikes excluding the 0.5 end point. - ep.advance(0.5); - sut.advance(ep, dt, {}); - EXPECT_EQ(12u, sut.spikes().size()); -} - -TEST(rss_cell, poll_time_after_end_time) -{ - constexpr time_type dt = 0.01; // dt is ignored by rss_cell_group::advance(). - - rss_cell desc{0.125, 0.03125, 0.5}; - rss_cell_group sut({0}, rss_recipe(1u, desc)); - - // Expect 12 spikes in this time frame. - sut.advance(epoch(0, 0.7), dt, {}); - EXPECT_EQ(12u, sut.spikes().size()); - - // Now ask for spikes for a time slot already passed: - // It should result in zero spikes because of the internal state! - sut.clear_spikes(); - sut.advance(epoch(0, 0.2), dt, {}); - EXPECT_EQ(0u, sut.spikes().size()); - - sut.reset(); - - // Expect 12 excluding the 0.5 - sut.advance(epoch(0, 0.5), dt, {}); - EXPECT_EQ(12u, sut.spikes().size()); -} - -TEST(rss_cell, rate_bigger_then_epoch) -{ - constexpr time_type dt = 0.01; // dt is ignored by rss_cell_group::advance(). - - rss_cell desc{ 0.0, 100.0, 1000.0 }; - rss_cell_group sut({ 0 }, rss_recipe(1u, desc)); - - // take time steps of 10 ms - for (time_type start = 0.0; start < 1000.0; start += 10) { - sut.advance(epoch(start, start + 10.0), dt, {}); - } - // We spike once every 100 ms so in 1000.0 ms we should have 10 - EXPECT_EQ(10u, sut.spikes().size()); -} - - -TEST(rss_cell, cell_kind_correct) -{ - rss_cell desc{0.1, 0.01, 0.2}; - rss_cell_group sut({0}, rss_recipe(1u, desc)); - - EXPECT_EQ(cell_kind::regular_spike_source, sut.get_cell_kind()); -} diff --git a/tests/unit/test_spike_source.cpp b/tests/unit/test_spike_source.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7e9939b46fa84c834e6006ae365f7837f6ea8dfb --- /dev/null +++ b/tests/unit/test_spike_source.cpp @@ -0,0 +1,113 @@ +#include "../gtest.h" + +#include <spike_source_cell_group.hpp> +#include <time_sequence.hpp> +#include <util/unique_any.hpp> + +#include "../simple_recipes.hpp" + +using namespace arb; +using ss_recipe = homogeneous_recipe<cell_kind::spike_source, spike_source_cell>; +using pseq = arb::poisson_time_seq<std::mt19937_64>; + +// Test that a spike_source_cell_group identifies itself with the correct +// cell_kind enum value. +TEST(spike_source, cell_kind) +{ + ss_recipe rec(1u, spike_source_cell{vector_time_seq({})}); + spike_source_cell_group group({0}, rec); + + EXPECT_EQ(cell_kind::spike_source, group.get_cell_kind()); +} + +// Test that a spike_source_cell_group produces a sequence spikes with spike +// times corresponding to the underlying time_seq. +TEST(spike_source, matches_time_seq) +{ + auto test_seq = [](arb::time_seq seq) { + ss_recipe rec(1u, spike_source_cell{seq}); + spike_source_cell_group group({0}, rec); + + // epoch ending at 10ms + epoch ep(0, 10); + group.advance(ep, 1, {}); + for (auto s: group.spikes()) { + EXPECT_EQ(s.time, seq.front()); + seq.pop(); + } + EXPECT_TRUE(seq.front()>=ep.tfinal); + group.clear_spikes(); + + // advance to 20 ms and repeat + ep.advance(20); + group.advance(ep, 1, {}); + for (auto s: group.spikes()) { + EXPECT_EQ(s.time, seq.front()); + seq.pop(); + } + EXPECT_TRUE(seq.front()>=ep.tfinal); + }; + + std::mt19937_64 G; + test_seq(arb::regular_time_seq(0,1)); + test_seq(pseq(G, 0., 10)); // produce many spikes in each interval + test_seq(pseq(G, 0., 1e-6)); // very unlikely to produce any spikes in either interval +} + +// Test that a spike_source_cell_group will produce the same sequence of spikes +// after being reset. +TEST(spike_source, reset) +{ + auto test_seq = [](arb::time_seq seq) { + ss_recipe rec(1u, spike_source_cell{seq}); + spike_source_cell_group group({0}, rec); + + // Advance for 10 ms and store generated spikes in spikes1. + epoch ep(0, 10); + group.advance(ep, 1, {}); + auto spikes1 = group.spikes(); + + // Reset the model, then advance again to 10 ms, and store the + // generated spikes in spikes2. + group.reset(); + group.advance(ep, 1, {}); + auto spikes2 = group.spikes(); + + // Check that the same spikes were generated in each case. + EXPECT_EQ(spikes1, spikes2); + }; + + std::mt19937_64 G; + test_seq(arb::regular_time_seq(0,1)); + test_seq(pseq(G, 0., 10)); // produce many spikes in each interval + test_seq(pseq(G, 0., 1e-6)); // very unlikely to produce any spikes in either interval +} + +// Test that a spike_source_cell_group will produce the expected +// output when the underlying time_seq is finite. +TEST(spike_source, exhaust) +{ + // This test assumes that seq will exhaust itself before t=10 ms. + auto test_seq = [](arb::time_seq seq) { + ss_recipe rec(1u, spike_source_cell{seq}); + spike_source_cell_group group({0}, rec); + + // epoch ending at 10ms + epoch ep(0, 10); + group.advance(ep, 1, {}); + auto spikes = group.spikes(); + for (auto s: group.spikes()) { + EXPECT_EQ(s.time, seq.front()); + seq.pop(); + } + // The sequence shoule be exhausted, in which case the next value in the + // sequence should be marked as time_max. + EXPECT_EQ(seq.front(), arb::terminal_time); + // Check that the last spike was before the end of the epoch. + EXPECT_LT(spikes.back().time, time_type(10)); + }; + + std::mt19937_64 G; + test_seq(arb::regular_time_seq(0,1,5)); + test_seq(pseq(G, 0., 10, 5)); +} diff --git a/tests/unit/test_time_seq.cpp b/tests/unit/test_time_seq.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f5031513319072192e03cc3ea1a27b03f7649592 --- /dev/null +++ b/tests/unit/test_time_seq.cpp @@ -0,0 +1,174 @@ +#include "../gtest.h" +#include "common.hpp" + +#include <vector> + +#include <time_sequence.hpp> +#include <util/rangeutil.hpp> + +using namespace arb; +using pse = postsynaptic_spike_event; + +namespace{ + // Helper function that draws all samples in the half open interval + // t ∈ [t0, t1) from a time_seq. + std::vector<time_type> draw(time_seq& gen, time_type t0, time_type t1) { + gen.reset(); + gen.advance(t0); + std::vector<time_type> v; + while (gen.front()<t1) { + v.push_back(gen.front()); + gen.pop(); + } + return v; + } +} + +TEST(time_seq, vector) { + std::vector<time_type> times = {0.1, 1.0, 1.0, 1.5, 2.3, 3.0, 3.5, }; + + vector_time_seq seq(times); + + // Test pop, next and reset. + for (auto t: times) { + EXPECT_EQ(t, seq.front()); + seq.pop(); + } + seq.reset(); + for (auto t: times) { + EXPECT_EQ(t, seq.front()); + seq.pop(); + } + + // The loop above should have drained all samples from seq, so we expect + // that the front() time sample will be terminal_time. + EXPECT_EQ(seq.front(), terminal_time); +} + +TEST(time_seq, regular) { + // make a regular generator that generates its first event at t=2ms and subsequent + // events regularly spaced 0.5 ms apart. + regular_time_seq seq(2, 0.5); + + // Test pop, next and reset. + for (auto e: {2.0, 2.5, 3.0, 3.5, 4.0, 4.5}) { + EXPECT_EQ(e, seq.front()); + seq.pop(); + } + seq.reset(); + for (auto e: {2.0, 2.5, 3.0, 3.5, 4.0, 4.5}) { + EXPECT_EQ(e, seq.front()); + seq.pop(); + } + seq.reset(); + + // Test advance() + + seq.advance(10.1); + // next event greater ≥ 10.1 should be 10.5 + EXPECT_EQ(seq.front(), time_type(10.5)); + + seq.advance(12); + // next event greater ≥ 12 should be 12 + EXPECT_EQ(seq.front(), time_type(12)); +} + +// Test for rounding problems with large time values and the regular sequence +TEST(time_seq, regular_rounding) { + // make a regular generator that generates its first time point at t=2ms + // and subsequent times regularly spaced 0.5 ms apart. + time_type t0 = 2.0; + time_type dt = 0.5; + + // Test for rounding problems with large time values. + // To better understand why this is an issue, uncomment the following: + // float T = 1802667.0f, DT = 0.024999f; + // std::size_t N = std::floor(T/DT); + // std::cout << "T " << T << " DT " << DT << " N " << N + // << " T-N*DT " << T - (N*DT) << " P " << (T - (N*DT))/DT << "\n"; + t0 = 1802667.0f; + dt = 0.024999f; + time_type int_len = 5*dt; + time_type t1 = t0 + int_len; + time_type t2 = t1 + int_len; + time_seq seq = regular_time_seq(t0, dt); + + // Take the interval I_a: t ∈ [t0, t2) + // And the two sub-interavls + // I_l: t ∈ [t0, t1) + // I_r: t ∈ [t1, t2) + // Such that I_a = I_l ∪ I_r. + // If we draw points from each interval then merge them, we expect same set + // of points as when we draw from that large interval. + std::vector<time_type> int_l = draw(seq, t0, t1); + std::vector<time_type> int_r = draw(seq, t1, t2); + std::vector<time_type> int_a = draw(seq, t0, t2); + + std::vector<time_type> int_merged = int_l; + util::append(int_merged, int_r); + + EXPECT_TRUE(int_l.front() >= t0); + EXPECT_TRUE(int_l.back() < t1); + EXPECT_TRUE(int_r.front() >= t1); + EXPECT_TRUE(int_r.back() < t2); + EXPECT_EQ(int_a, int_merged); + EXPECT_TRUE(std::is_sorted(int_a.begin(), int_a.end())); +} + +TEST(time_seq, poisson) { + std::mt19937_64 G; + using pseq = poisson_time_seq<std::mt19937_64>; + + time_type t0 = 0; + time_type t1 = 10; + time_type lambda = 10; // expect 10 samples per ms + + pseq seq(G, t0, lambda); + + std::vector<time_type> int1; + while (seq.front()<t1) { + int1.push_back(seq.front()); + seq.pop(); + } + // Test that the output is sorted + EXPECT_TRUE(std::is_sorted(int1.begin(), int1.end())); + + // Reset and generate the same sequence of time points + seq.reset(); + std::vector<time_type> int2; + while (seq.front()<t1) { + int2.push_back(seq.front()); + seq.pop(); + } + + // Assert that the same sequence was generated + EXPECT_EQ(int1, int2); +} + +// Test a poisson generator that has a tstop past which no samples +// should be generated. +TEST(time_seq, poisson_terminates) { + std::mt19937_64 G; + using pseq = poisson_time_seq<std::mt19937_64>; + + time_type t0 = 0; + time_type t1 = 10; + time_type t2 = 1e7; // pick a time far past the end of the interval [t0, t1) + time_type lambda = 10; // expect 10 samples per ms + + // construct sequence with explicit end time t1 + pseq seq(G, t0, lambda, t1); + + std::vector<time_type> sequence; + // pull samples off the sequence well past the end of the end time t1 + while (seq.front()<t2) { + sequence.push_back(seq.front()); + seq.pop(); + } + + // the sequence should be exhausted + EXPECT_EQ(seq.front(), terminal_time); + + // the last sample should be less than the end time + EXPECT_TRUE(sequence.back()<t1); +}