diff --git a/miniapp/io.cpp b/miniapp/io.cpp index d3acd53097efaf5e16d8f7f60f400cc60aaaea48..89a976b9a0da367c22508cceb07a235809bf63d5 100644 --- a/miniapp/io.cpp +++ b/miniapp/io.cpp @@ -1,7 +1,11 @@ #include <algorithm> #include <exception> #include <fstream> +#include <iostream> #include <istream> +#include <memory> +#include <sstream> +#include <string> #include <type_traits> #include <tclap/CmdLine.h> @@ -9,6 +13,7 @@ #include <util/meta.hpp> #include <util/optional.hpp> +#include <util/strprintf.hpp> #include "io.hpp" @@ -191,6 +196,11 @@ cl_options read_options(int argc, char** argv, bool allow_write) { "z", "profile-only-zero", "Only output profile information for rank 0", cmd, false); TCLAP::SwitchArg verbose_arg( "v", "verbose", "Present more verbose information to stdout", cmd, false); + TCLAP::ValueArg<std::string> ispike_arg( + "I", "ispike_file", + "Input spikes from file", + false, "", "file name", cmd); + cmd.reorder_arguments(); cmd.parse(argc, argv); @@ -240,6 +250,11 @@ cl_options read_options(int argc, char** argv, bool allow_write) { update_option(options.file_extension, fopts, "file_extension"); } + update_option(options.spike_file_input, fopts, "spike_file_input"); + if (options.spike_file_input) { + update_option(options.input_spike_path, fopts, "input_spike_path"); + } + update_option(options.dry_run_ranks, fopts, "dry_run_ranks"); update_option(options.profile_only_zero, fopts, "profile_only_zero"); @@ -279,6 +294,12 @@ cl_options read_options(int argc, char** argv, bool allow_write) { update_option(options.profile_only_zero, profile_only_zero_arg); update_option(options.dry_run_ranks, dry_run_ranks_arg); + std::string is_file_name = ispike_arg.getValue(); + if (is_file_name != "") { + options.spike_file_input = true; + update_option(options.input_spike_path, ispike_arg); + } + if (options.trace_format!="csv" && options.trace_format!="json") { throw usage_error("trace format must be one of: csv, json"); } @@ -389,6 +410,50 @@ std::ostream& operator<<(std::ostream& o, const cl_options& options) { return o; } + +/// Parse spike times from a stream +/// A single spike per line, trailing whitespace is ignore +/// Throws a usage error when parsing fails +/// +/// Returns a vector of time_type + +std::vector<time_type> parse_spike_times_from_stream(std::ifstream & fid) { + std::vector<time_type> times; + std::string line; + while (std::getline(fid, line)) { + std::stringstream s(line); + + time_type t; + s >> t >> std::ws; + + if (!s || s.peek() != EOF) { + throw std::runtime_error( util::strprintf( + "Unable to parse spike file on line %d: \"%s\"\n", + times.size(), line)); + } + + times.push_back(t); + } + + return times; +} + +/// Parse spike times from a file supplied in path +/// A single spike per line, trailing white space is ignored +/// Throws a usage error when opening file or parsing fails +/// +/// Returns a vector of time_type + +std::vector<time_type> get_parsed_spike_times_from_path(nest::mc::util::path path) { + std::ifstream fid(path); + if (!fid) { + throw std::runtime_error(util::strprintf( + "Unable to parse spike file: \"%s\"\n", path.c_str())); + } + + return parse_spike_times_from_stream(fid); +} + } // namespace io } // namespace mc } // namespace nest diff --git a/miniapp/io.hpp b/miniapp/io.hpp index bc16a6b1099b7a911bf32bb456e9017bead41470..a2887d2cb344aa677c24673de9e162491e7e998a 100644 --- a/miniapp/io.hpp +++ b/miniapp/io.hpp @@ -1,12 +1,15 @@ #pragma once -#include <string> #include <cstdint> #include <iosfwd> #include <stdexcept> +#include <string> #include <utility> +#include <vector> +#include <common_types.hpp> #include <util/optional.hpp> +#include <util/path.hpp> namespace nest { namespace mc { @@ -51,6 +54,10 @@ struct cl_options { std::string file_name = "spikes"; std::string file_extension = "gdf"; + // Parameters for spike input. + bool spike_file_input = false; + std::string input_spike_path; // Path to file with spikes + // Dry run parameters (pertinent only when built with 'dryrun' distrib model). int dry_run_ranks = 1; @@ -80,6 +87,11 @@ std::ostream& operator<<(std::ostream& o, const cl_options& opt); cl_options read_options(int argc, char** argv, bool allow_write = true); +/// Helper function for loading a vector of spike times from file +/// Spike times are expected to be in milli seconds floating points +/// On spike-time per line + +std::vector<time_type> get_parsed_spike_times_from_path(nest::mc::util::path path); } // namespace io } // namespace mc diff --git a/miniapp/miniapp.cpp b/miniapp/miniapp.cpp index 57a6641246310c6f8c7e3a05978f1836c61d9ddf..b2f0d9ccbbd8760cbb8c56355d5526659c74cc5c 100644 --- a/miniapp/miniapp.cpp +++ b/miniapp/miniapp.cpp @@ -202,6 +202,11 @@ std::unique_ptr<recipe> make_recipe(const io::cl_options& options, const probe_d p.num_synapses = options.all_to_all? options.cells-1: options.synapses_per_cell; p.synapse_type = options.syn_type; + // Parameters for spike input from file + if (options.spike_file_input) { + p.input_spike_path = options.input_spike_path; + } + if (options.all_to_all) { return make_basic_kgraph_recipe(options.cells, p, pdist); } diff --git a/miniapp/miniapp_recipes.cpp b/miniapp/miniapp_recipes.cpp index e8c49422ef0e610e4478b7aaeefe9706d30d5fd6..45cb86f4ef3ca012969e5ca978bdaf589d058ebe 100644 --- a/miniapp/miniapp_recipes.cpp +++ b/miniapp/miniapp_recipes.cpp @@ -4,13 +4,16 @@ #include <utility> #include <cell.hpp> +#include <dss_cell_description.hpp> #include <rss_cell.hpp> #include <morphology.hpp> #include <util/debug.hpp> +#include "io.hpp" #include "miniapp_recipes.hpp" #include "morphology_pool.hpp" + namespace nest { namespace mc { @@ -84,10 +87,15 @@ public: } util::unique_any get_cell_description(cell_gid_type i) const override { - // The last 'cell' is a rss_cell with one spike at t=0 + // The last 'cell' is a spike source cell. Either a regular spiking + // or a spikes from file. if (i == ncell_) { - return util::unique_any(std::move( - rss_cell::rss_cell_description(0.0, 0.1, 0.1) )); + if (param_.input_spike_path) { + auto spike_times = io::get_parsed_spike_times_from_path(param_.input_spike_path.get()); + return util::unique_any(dss_cell_description(spike_times)); + } + + return util::unique_any(rss_cell::rss_cell_description(0.0, 0.1, 0.1)); } auto gen = std::mt19937(i); // TODO: replace this with hashing generator... @@ -124,6 +132,10 @@ public: cell_kind get_cell_kind(cell_gid_type i ) const override { // The last 'cell' is a rss_cell with one spike at t=0 if (i == ncell_) { + if (param_.input_spike_path) { + return cell_kind::data_spike_source; + } + return cell_kind::regular_spike_source; } return cell_kind::cable1d_neuron; diff --git a/miniapp/miniapp_recipes.hpp b/miniapp/miniapp_recipes.hpp index afc1a0103fa00707c2edb9a0ebb74866376d6698..bb3c5240b4238806d84b0a89e61e4357db4fc06c 100644 --- a/miniapp/miniapp_recipes.hpp +++ b/miniapp/miniapp_recipes.hpp @@ -5,6 +5,7 @@ #include <stdexcept> #include <recipe.hpp> +#include <util/optional.hpp> #include "morphology_pool.hpp" @@ -39,6 +40,10 @@ struct basic_recipe_param { // If true, iterate through morphologies rather than select randomly. bool morphology_round_robin = false; + + // If set we are importing the spikes injected in the network from file + // instead of a single spike at t==0 + util::optional<std::string> input_spike_path; // Path to file with spikes }; std::unique_ptr<recipe> make_basic_ring_recipe( diff --git a/src/cell_group_factory.cpp b/src/cell_group_factory.cpp index f47419d0b457e2a00c190e567c7aa716f6a619d5..77a604db1743b79777a7e4d3ad8d486d9e2fd4b3 100644 --- a/src/cell_group_factory.cpp +++ b/src/cell_group_factory.cpp @@ -2,9 +2,10 @@ #include <backends.hpp> #include <cell_group.hpp> -#include <rss_cell_group.hpp> +#include <dss_cell_group.hpp> #include <fvm_multicell.hpp> #include <mc_cell_group.hpp> +#include <rss_cell_group.hpp> #include <util/unique_any.hpp> namespace nest { @@ -31,6 +32,9 @@ cell_group_ptr cell_group_factory( case cell_kind::regular_spike_source: return make_cell_group<rss_cell_group>(first_gid, cell_descriptions); + case cell_kind::data_spike_source: + return make_cell_group<dss_cell_group>(first_gid, cell_descriptions); + default: throw std::runtime_error("unknown cell kind"); } diff --git a/src/common_types.hpp b/src/common_types.hpp index d5c1a04bd983348a77354a47c7df67e639193f31..74bb33f6c6bf55c13de90e813c3cb7fe9bccc7c0 100644 --- a/src/common_types.hpp +++ b/src/common_types.hpp @@ -59,6 +59,7 @@ using time_type = float; enum cell_kind { cable1d_neuron, // Our own special mc neuron regular_spike_source, // Regular spiking source + data_spike_source, // Spike source from values inserted via description }; } // namespace mc diff --git a/src/dss_cell_description.hpp b/src/dss_cell_description.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1fddebb75867071e916ed9b8141ea07c94091f41 --- /dev/null +++ b/src/dss_cell_description.hpp @@ -0,0 +1,25 @@ +#pragma once +#pragma once + +#include <vector> + +#include <common_types.hpp> +#include <util/debug.hpp> + +namespace nest { +namespace mc { + +/// Description for a data spike source: A cell that generates spikes provided as a vector of +/// floating point valued spiketimes at the start of a run +struct dss_cell_description { + std::vector<time_type> spike_times; + + /// The description needs a vector of doubles for the description + dss_cell_description(std::vector<time_type> & spike_times) : + spike_times(spike_times) + {} +}; + + +} // namespace mc +} // namespace nest diff --git a/src/dss_cell_group.hpp b/src/dss_cell_group.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c30d3f44adbadd8715b9d8f2b9f13e0fe776a461 --- /dev/null +++ b/src/dss_cell_group.hpp @@ -0,0 +1,120 @@ +#pragma once + +#include <cell_group.hpp> +#include <dss_cell_description.hpp> +#include <util/span.hpp> +#include <util/unique_any.hpp> + +namespace nest { +namespace mc { + +/// Cell_group to collect spike sources +class dss_cell_group : public cell_group { +public: + using source_id_type = cell_member_type; + + dss_cell_group(cell_gid_type first_gid, const std::vector<util::unique_any>& cell_descriptions): + gid_base_(first_gid) + { + using util::make_span; + for (cell_gid_type i: make_span(0, cell_descriptions.size())) { + // store spike times from description + const auto times = util::any_cast<dss_cell_description>(cell_descriptions[i]).spike_times; + spike_times_.push_back(std::vector<time_type>(times)); + + // Assure the spike times are sorted + std::sort(spike_times_[i].begin(), spike_times_[i].end()); + + // Now we can grab the first spike time + not_emit_it_.push_back(spike_times_[i].begin()); + + // create a lid to gid map + spike_sources_.push_back({gid_base_+i, 0}); + } + } + + virtual ~dss_cell_group() = default; + + cell_kind get_cell_kind() const override { + return cell_kind::data_spike_source; + } + + void reset() override { + // Declare both iterators outside of the for loop for consistency + auto it = not_emit_it_.begin(); + auto times = spike_times_.begin(); + + for (;it != not_emit_it_.end(); ++it, times++) { + // Point to the first not emitted spike. + *it = times->begin(); + } + + clear_spikes(); + } + + void set_binning_policy(binning_kind policy, time_type bin_interval) override + {} + + void advance(time_type tfinal, time_type dt) override { + for (auto cell_idx: util::make_span(0, not_emit_it_.size())) { + // The first potential spike_time to emit for this cell + auto spike_time_it = not_emit_it_[cell_idx]; + + // Find the first to not emit and store as iterator + not_emit_it_[cell_idx] = std::find_if( + spike_time_it, spike_times_[cell_idx].end(), + [tfinal](time_type t) {return t >= tfinal; } + ); + + // loop over the range we now have (might be empty), and create spikes + for (; spike_time_it != not_emit_it_[cell_idx]; ++spike_time_it) { + spikes_.push_back({ spike_sources_[cell_idx], *spike_time_it }); + } + } + }; + + void enqueue_events(const std::vector<postsynaptic_spike_event>& events) override { + std::logic_error("The dss_cells do not support incoming events!"); + } + + const std::vector<spike>& spikes() const override { + return spikes_; + } + + void clear_spikes() override { + spikes_.clear(); + } + + std::vector<probe_record> probes() const override { + return probes_; + } + + void add_sampler(cell_member_type probe_id, sampler_function s, time_type start_time = 0) override { + std::logic_error("The dss_cells do not support sampling of internal state!"); + } + +private: + // gid of first cell in group. + cell_gid_type gid_base_; + + // Spikes that are generated. + std::vector<spike> spikes_; + + // Spike generators attached to the cell + std::vector<source_id_type> spike_sources_; + + std::vector<probe_record> probes_; + + // The dss_cell is simple: Put all logic in the cellgroup cause accelerator support + // is not expected. We need storage for the cell state + + // The times the cells should spike, one for each cell, sorted in time + std::vector<std::vector<time_type> > spike_times_; + + // Each cell needs its own itterator to the first spike not yet emitted + std::vector<std::vector<time_type>::iterator > not_emit_it_; +}; + +} // namespace mc +} // namespace nest + diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 564f5964c57c1a948b754a03bd24f9a4124634cf..40427c1ab001a231e1001cafa096b24be472e3f8 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -41,6 +41,7 @@ set(TEST_SOURCES test_counter.cpp test_cycle.cpp test_domain_decomposition.cpp + test_dss_cell_group.cpp test_either.cpp test_event_queue.cpp test_event_binner.cpp @@ -62,6 +63,7 @@ set(TEST_SOURCES test_probe.cpp test_segment.cpp test_range.cpp + test_rss_cell.cpp test_span.cpp test_spikes.cpp test_spike_store.cpp diff --git a/tests/unit/test_dss_cell_group.cpp b/tests/unit/test_dss_cell_group.cpp new file mode 100644 index 0000000000000000000000000000000000000000..675487893551820ad0ae123003a94a0dccc9db24 --- /dev/null +++ b/tests/unit/test_dss_cell_group.cpp @@ -0,0 +1,75 @@ +#include "../gtest.h" + +#include "dss_cell_description.hpp" +#include "dss_cell_group.hpp" +#include <util/unique_any.hpp> + + +using namespace nest::mc; + +TEST(dss_cell, constructor) +{ + std::vector<time_type> spikes; + + std::vector<util::unique_any> cell_descriptions(1); + cell_descriptions[0] = util::unique_any(dss_cell_description(spikes)); + + dss_cell_group sut(0, cell_descriptions); +} + +TEST(dss_cell, basic_usage) +{ + std::vector<time_type> spikes_to_emit; + + time_type spike_time = 0.1; + spikes_to_emit.push_back(spike_time); + + std::vector<util::unique_any> cell_descriptions(1); + cell_descriptions[0] = util::unique_any(dss_cell_description(spikes_to_emit)); + + dss_cell_group sut(0, cell_descriptions); + + // no spikes in this time frame + sut.advance(0.09, 0.01); // The dt (0,01) is not used + + auto spikes = sut.spikes(); + EXPECT_EQ(size_t(0), spikes.size()); + + // only one in this time frame + sut.advance(0.11, 0.01); + spikes = sut.spikes(); + EXPECT_EQ(size_t(1), spikes.size()); + ASSERT_FLOAT_EQ(spike_time, spikes[0].time); + + // Clear the spikes after 'processing' them + sut.clear_spikes(); + spikes = sut.spikes(); + EXPECT_EQ(size_t(0), spikes.size()); + + // No spike to be emitted + sut.advance(0.12, 0.01); + spikes = sut.spikes(); + EXPECT_EQ(size_t(0), spikes.size()); + + // Reset the internal state to null + sut.reset(); + + // Expect 10 excluding the 0.2 + sut.advance(0.2, 0.01); + spikes = sut.spikes(); + EXPECT_EQ(size_t(1), spikes.size()); + ASSERT_FLOAT_EQ(spike_time, spikes[0].time); +} + + +TEST(dss_cell, cell_kind_correct) +{ + std::vector<time_type> spikes_to_emit; + + std::vector<util::unique_any> cell_descriptions(1); + cell_descriptions[0] = util::unique_any(dss_cell_description(spikes_to_emit)); + + dss_cell_group sut(0, cell_descriptions); + + EXPECT_EQ(cell_kind::data_spike_source, sut.get_cell_kind()); +}