diff --git a/miniapp/miniapp.cpp b/miniapp/miniapp.cpp index 495e6986ad91df087e2a95a0ca49a11d7216170e..c4add5583f21a69707998d6cf1a6591a78b58d39 100644 --- a/miniapp/miniapp.cpp +++ b/miniapp/miniapp.cpp @@ -36,13 +36,12 @@ using lowered_cell = fvm::fvm_multicell<multicore::backend>; #endif using model_type = model<lowered_cell>; using sample_trace_type = sample_trace<model_type::time_type, model_type::value_type>; -using file_export_type = io::exporter_spike_file<model_type::time_type, global_policy>; +using file_export_type = io::exporter_spike_file<global_policy>; void banner(); std::unique_ptr<recipe> make_recipe(const io::cl_options&, const probe_distribution&); std::unique_ptr<sample_trace_type> make_trace(cell_member_type probe_id, probe_spec probe); std::pair<cell_gid_type, cell_gid_type> distribute_cells(cell_size_type ncells); -using communicator_type = communication::communicator<model_type::time_type, communication::global_policy>; -using spike_type = typename communicator_type::spike_type; +using communicator_type = communication::communicator<communication::global_policy>; void write_trace_json(const sample_trace_type& trace, const std::string& prefix = "trace_"); void report_compartment_stats(const recipe&); @@ -140,14 +139,14 @@ int main(int argc, char** argv) { if (options.single_file_per_rank) { file_exporter = register_exporter(options); m.set_local_spike_callback( - [&](const std::vector<spike_type>& spikes) { + [&](const std::vector<spike>& spikes) { file_exporter->output(spikes); }); } else if(communication::global_policy::id()==0) { file_exporter = register_exporter(options); m.set_global_spike_callback( - [&](const std::vector<spike_type>& spikes) { + [&](const std::vector<spike>& spikes) { file_exporter->output(spikes); }); } diff --git a/src/cell_group.hpp b/src/cell_group.hpp index ef4c71b2344390b4a705df29969d7dabb600354b..990f055ff139e42e0a30eb9deaa9b57498362c67 100644 --- a/src/cell_group.hpp +++ b/src/cell_group.hpp @@ -28,7 +28,7 @@ public: using size_type = typename lowered_cell_type::value_type; using source_id_type = cell_member_type; - using time_type = float; + using time_type = spike::time_type; using sampler_function = std::function<util::optional<time_type>(time_type, double)>; cell_group() = default; @@ -141,8 +141,7 @@ public: } } - const std::vector<spike<source_id_type, time_type>>& - spikes() const { + const std::vector<spike>& spikes() const { return spikes_; } @@ -150,8 +149,7 @@ public: spikes_.clear(); } - const std::vector<source_id_type>& - spike_sources() const { + const std::vector<source_id_type>& spike_sources() const { return spike_sources_; } @@ -197,7 +195,7 @@ private: std::vector<source_id_type> spike_sources_; /// spikes that are generated - std::vector<spike<source_id_type, time_type>> spikes_; + std::vector<spike> spikes_; /// pending events to be delivered event_queue<postsynaptic_spike_event<time_type>> events_; diff --git a/src/communication/communicator.hpp b/src/communication/communicator.hpp index 764020dcfc4accda25c7074469fa515f7f755ab4..ffaad887ca2952a4ae19fd55f8c111c28fc1d173 100644 --- a/src/communication/communicator.hpp +++ b/src/communication/communicator.hpp @@ -29,14 +29,12 @@ namespace communication { // Once all connections have been specified, the construct() method can be used // to build the data structures required for efficient spike communication and // event generation. -template <typename Time, typename CommunicationPolicy> + +template <typename CommunicationPolicy> class communicator { public: using communication_policy_type = CommunicationPolicy; - using id_type = cell_gid_type; - using time_type = Time; - using spike_type = spike<cell_member_type, time_type>; - using connection_type = connection<time_type>; + using time_type = spike::time_type; /// per-cell group lists of events to be delivered using event_queue = @@ -56,13 +54,13 @@ public: return cell_gid_partition_.size(); } - void add_connection(connection_type con) { + void add_connection(connection con) { EXPECTS(is_local_cell(con.destination().gid)); connections_.push_back(con); } /// returns true if the cell with gid is on the domain of the caller - bool is_local_cell(id_type gid) const { + bool is_local_cell(cell_gid_type gid) const { return algorithms::in_interval(gid, cell_gid_partition_.bounds()); } @@ -88,7 +86,7 @@ public: /// /// Takes as input the list of local_spikes that were generated on the calling domain. /// Returns the full global set of vectors, along with meta data about their partition - gathered_vector<spike_type> exchange(const std::vector<spike_type>& local_spikes) { + gathered_vector<spike> exchange(const std::vector<spike>& local_spikes) { // global all-to-all to gather a local copy of the global spike list on each node. auto global_spikes = communication_policy_.gather_spikes( local_spikes ); num_spikes_ += global_spikes.size(); @@ -102,7 +100,7 @@ public: /// Returns a vector of event queues, with one queue for each local cell group. The /// events in each queue are all events that must be delivered to targets in that cell /// group as a result of the global spike exchange. - std::vector<event_queue> make_event_queues(const gathered_vector<spike_type>& global_spikes) { + std::vector<event_queue> make_event_queues(const gathered_vector<spike>& global_spikes) { auto queues = std::vector<event_queue>(num_groups_local()); for (auto spike : global_spikes.values()) { // search for targets @@ -121,7 +119,7 @@ public: /// Returns the total number of global spikes over the duration of the simulation uint64_t num_spikes() const { return num_spikes_; } - const std::vector<connection_type>& connections() const { + const std::vector<connection>& connections() const { return connections_; } @@ -139,7 +137,7 @@ private: return cell_gid_partition_.index(cell_gid); } - std::vector<connection_type> connections_; + std::vector<connection> connections_; communication_policy_type communication_policy_; diff --git a/src/connection.hpp b/src/connection.hpp index 361aee20646113fac75bcde3139fad1026e427a9..3897c6fba8c3440bcbc828e7488eca400ea10793 100644 --- a/src/connection.hpp +++ b/src/connection.hpp @@ -9,15 +9,13 @@ namespace nest { namespace mc { -template <typename Time> class connection { public: - using id_type = cell_member_type; - using time_type = Time; + using time_type = spike::time_type; - connection()=default; + connection() = default; - connection(id_type src, id_type dest, float w, time_type d) : + connection(cell_member_type src, cell_member_type dest, float w, time_type d) : source_(src), destination_(dest), weight_(w), @@ -25,18 +23,18 @@ public: {} float weight() const { return weight_; } - float delay() const { return delay_; } + time_type delay() const { return delay_; } - id_type source() const { return source_; } - id_type destination() const { return destination_; } + cell_member_type source() const { return source_; } + cell_member_type destination() const { return destination_; } - postsynaptic_spike_event<time_type> make_event(spike<id_type, time_type> s) { + postsynaptic_spike_event<time_type> make_event(const spike& s) { return {destination_, s.time + delay_, weight_}; } private: - id_type source_; - id_type destination_; + cell_member_type source_; + cell_member_type destination_; float weight_; time_type delay_; }; @@ -44,26 +42,22 @@ private: // connections are sorted by source id // these operators make for easy interopability with STL algorithms -template <typename T> -static inline bool operator<(connection<T> lhs, connection<T> rhs) { +static inline bool operator<(const connection& lhs, const connection& rhs) { return lhs.source() < rhs.source(); } -template <typename T> -static inline bool operator<(connection<T> lhs, typename connection<T>::id_type rhs) { +static inline bool operator<(const connection& lhs, cell_member_type rhs) { return lhs.source() < rhs; } -template <typename T> -static inline bool operator<(typename connection<T>::id_type lhs, connection<T> rhs) { +static inline bool operator<(cell_member_type lhs, const connection& rhs) { return lhs < rhs.source(); } } // namespace mc } // namespace nest -template <typename T> -static inline std::ostream& operator<<(std::ostream& o, nest::mc::connection<T> const& con) { +static inline std::ostream& operator<<(std::ostream& o, nest::mc::connection const& con) { return o << "con [" << con.source() << " -> " << con.destination() << " : weight " << con.weight() << ", delay " << con.delay() << "]"; diff --git a/src/io/exporter.hpp b/src/io/exporter.hpp index a5885b7b10841593059402a491ccf21148b78df3..e27ae636dada96e006936382873f114f731d927a 100644 --- a/src/io/exporter.hpp +++ b/src/io/exporter.hpp @@ -14,15 +14,11 @@ namespace io { // Exposes one virtual functions: // do_export(vector<type>) receiving a vector of parameters to export -template <typename Time, typename CommunicationPolicy> +template <typename CommunicationPolicy> class exporter { - public: - using time_type = Time; - using spike_type = spike<cell_member_type, time_type>; - // Performs the export of the data - virtual void output(const std::vector<spike_type>&) = 0; + virtual void output(const std::vector<spike>&) = 0; // Returns the status of the exporter virtual bool good() const = 0; diff --git a/src/io/exporter_spike_file.hpp b/src/io/exporter_spike_file.hpp index 94418712ef7578b7ceef9ae4415a30ca9fc10c06..090a4641b99851c2610b3932169eeeb6dacfb91f 100644 --- a/src/io/exporter_spike_file.hpp +++ b/src/io/exporter_spike_file.hpp @@ -19,11 +19,9 @@ namespace nest { namespace mc { namespace io { -template <typename Time, typename CommunicationPolicy> -class exporter_spike_file : public exporter<Time, CommunicationPolicy> { +template <typename CommunicationPolicy> +class exporter_spike_file : public exporter<CommunicationPolicy> { public: - using time_type = Time; - using spike_type = spike<cell_member_type, time_type>; using communication_policy_type = CommunicationPolicy; // Constructor @@ -53,7 +51,7 @@ public: // Performs export of the spikes to file. // One id and spike time with 4 decimals after the comma on a // line space separated. - void output(const std::vector<spike_type>& spikes) override { + void output(const std::vector<spike>& spikes) override { for (auto spike : spikes) { char linebuf[45]; auto n = diff --git a/src/model.hpp b/src/model.hpp index 048656cbd39ea36696a8cff5a5004ef28b4f6a94..c411b4a8092551aa974b052d3235a22094f33fab 100644 --- a/src/model.hpp +++ b/src/model.hpp @@ -28,10 +28,9 @@ public: using cell_group_type = cell_group<Cell>; using time_type = typename cell_group_type::time_type; using value_type = typename cell_group_type::value_type; - using communicator_type = communication::communicator<time_type, communication::global_policy>; + using communicator_type = communication::communicator<communication::global_policy>; using sampler_function = typename cell_group_type::sampler_function; - using spike_type = typename communicator_type::spike_type; - using spike_export_function = std::function<void(const std::vector<spike_type>&)>; + using spike_export_function = std::function<void(const std::vector<spike>&)>; struct probe_record { cell_member_type id; @@ -77,7 +76,7 @@ public: // generate the network connections for (cell_gid_type i: util::make_span(gid_partition().bounds())) { for (const auto& cc: rec.connections_on(i)) { - connection<time_type> conn{cc.source, cc.dest, cc.weight, cc.delay}; + connection conn{cc.source, cc.dest, cc.weight, cc.delay}; communicator_.add_connection(conn); } } @@ -264,7 +263,7 @@ private: using event_queue_type = typename communicator_type::event_queue; util::double_buffer<std::vector<event_queue_type>> event_queues_; - using local_spike_store_type = thread_private_spike_store<time_type>; + using local_spike_store_type = thread_private_spike_store; util::double_buffer<local_spike_store_type> local_spikes_; spike_export_function global_export_callback_ = util::nop_function; diff --git a/src/spike.hpp b/src/spike.hpp index d3ea02551456d2408712886dcff8125b314630e9..6675605dbfc18769ace33acac78e7bcaa681e61a 100644 --- a/src/spike.hpp +++ b/src/spike.hpp @@ -3,37 +3,43 @@ #include <ostream> #include <type_traits> +#include <common_types.hpp> + namespace nest { namespace mc { template <typename I, typename Time> -struct spike { +struct basic_spike { using id_type = I; using time_type = Time; id_type source = id_type{}; - time_type time = -1.; + time_type time = -1; - spike() = default; + basic_spike() = default; - spike(id_type s, time_type t) : + basic_spike(id_type s, time_type t): source(s), time(t) {} }; +/// Standard specialization: +using spike = basic_spike<cell_member_type, float>; + } // namespace mc } // namespace nest -/// custom stream operator for printing nest::mc::spike<> values +// Custom stream operator for printing nest::mc::spike<> values. template <typename I, typename T> -std::ostream& operator<<(std::ostream& o, nest::mc::spike<I, T> s) { +std::ostream& operator<<(std::ostream& o, nest::mc::basic_spike<I, T> s) { return o << "spike[t " << s.time << ", src " << s.source << "]"; } -/// less than comparison operator for nest::mc::spike<> values -/// spikes are ordered by spike time, for use in sorting and queueing +/// Less than comparison operator for nest::mc::spike<> values: +/// spikes are ordered by spike time, for use in sorting and queueing. template <typename I, typename T> -bool operator<(nest::mc::spike<I, T> lhs, nest::mc::spike<I, T> rhs) { +bool operator<(nest::mc::basic_spike<I, T> lhs, nest::mc::basic_spike<I, T> rhs) { return lhs.time < rhs.time; } + diff --git a/src/thread_private_spike_store.hpp b/src/thread_private_spike_store.hpp index 3f320fdd93709c622696e889f80c16588eb5ad66..4b1973e3aceae22cf9820b4161192904d68b2f17 100644 --- a/src/thread_private_spike_store.hpp +++ b/src/thread_private_spike_store.hpp @@ -15,17 +15,12 @@ namespace mc { /// The thread private buffer of the calling thread. /// The insert() and gather() methods add a vector of spikes to the buffer, /// and collate all of the buffers into a single vector respectively. -template <typename Time> class thread_private_spike_store { public : - using id_type = cell_gid_type; - using time_type = Time; - using spike_type = spike<cell_member_type, time_type>; - /// Collate all of the individual buffers into a single vector of spikes. /// Does not modify the buffer contents. - std::vector<spike_type> gather() const { - std::vector<spike_type> spikes; + std::vector<spike> gather() const { + std::vector<spike> spikes; unsigned num_spikes = 0u; for (auto& b : buffers_) { num_spikes += b.size(); @@ -40,12 +35,7 @@ public : } /// Return a reference to the thread private buffer of the calling thread - std::vector<spike_type>& get() { - return buffers_.local(); - } - - /// Return a reference to the thread private buffer of the calling thread - const std::vector<spike_type>& get() const { + std::vector<spike>& get() { return buffers_.local(); } @@ -58,7 +48,7 @@ public : /// Append the passed spikes to the end of the thread private buffer of the /// calling thread - void insert(const std::vector<spike_type>& spikes) { + void insert(const std::vector<spike>& spikes) { auto& buff = get(); buff.insert(buff.end(), spikes.begin(), spikes.end()); } @@ -66,7 +56,7 @@ public : private : /// thread private storage for accumulating spikes using local_spike_store_type = - threading::enumerable_thread_specific<std::vector<spike_type>>; + threading::enumerable_thread_specific<std::vector<spike>>; local_spike_store_type buffers_; diff --git a/tests/global_communication/test_communicator.cpp b/tests/global_communication/test_communicator.cpp index b6bcc03404322b6a7b32eda1afaa83b459056982..e3d427ae36e4df5a2e0b139a05bc3d8b14121b83 100644 --- a/tests/global_communication/test_communicator.cpp +++ b/tests/global_communication/test_communicator.cpp @@ -11,8 +11,7 @@ using namespace nest::mc; -using time_type = float; -using communicator_type = communication::communicator<time_type, communication::global_policy>; +using communicator_type = communication::communicator<communication::global_policy>; TEST(communicator, setup) { /* diff --git a/tests/global_communication/test_exporter_spike_file.cpp b/tests/global_communication/test_exporter_spike_file.cpp index 027a315bf05aeac72055d1fe64deb64bf9fbc911..376230ed607da7eca55b8671c1b593edd9d9b4e4 100644 --- a/tests/global_communication/test_exporter_spike_file.cpp +++ b/tests/global_communication/test_exporter_spike_file.cpp @@ -9,15 +9,14 @@ #include <communication/communicator.hpp> #include <communication/global_policy.hpp> #include <io/exporter_spike_file.hpp> +#include <spike.hpp> class exporter_spike_file_fixture : public ::testing::Test { protected: - using time_type = float; using communicator_type = nest::mc::communication::global_policy; using exporter_type = - nest::mc::io::exporter_spike_file<time_type, communicator_type>; - using spike_type = exporter_type::spike_type; + nest::mc::io::exporter_spike_file<communicator_type>; std::string file_name_; std::string path_; @@ -88,7 +87,7 @@ TEST_F(exporter_spike_file_fixture, do_export) { exporter_type exporter(file_name_, path_, extension_); // Create some spikes - std::vector<spike_type> spikes; + std::vector<nest::mc::spike> spikes; spikes.push_back({ { 0, 0 }, 0.0 }); spikes.push_back({ { 0, 0 }, 0.1 }); spikes.push_back({ { 1, 0 }, 1.0 }); diff --git a/tests/performance/io/disk_io.cpp b/tests/performance/io/disk_io.cpp index 9d77318f36bde0bc1b3366d3f9d79dff0513c6e0..05b64666c4502264d7d81cb1901e9d1152d2ae47 100644 --- a/tests/performance/io/disk_io.cpp +++ b/tests/performance/io/disk_io.cpp @@ -12,14 +12,13 @@ #include <fvm_multicell.hpp> #include <io/exporter_spike_file.hpp> #include <profiling/profiler.hpp> +#include <spike.hpp> using namespace nest::mc; using global_policy = communication::global_policy; using lowered_cell = fvm::fvm_multicell<multicore::backend>; using cell_group_type = cell_group<lowered_cell>; -using time_type = typename cell_group_type::time_type; -using spike_type = io::exporter_spike_file<time_type, global_policy>::spike_type; using timer = util::timer_type; int main(int argc, char** argv) { @@ -67,7 +66,7 @@ int main(int argc, char** argv) { } // Create the sut - io::exporter_spike_file<time_type, global_policy> exporter( + io::exporter_spike_file<global_policy> exporter( "spikes", "./", "gdf", true); // We need the nr of ranks to calculate the nr of spikes to produce per @@ -78,7 +77,7 @@ int main(int argc, char** argv) { auto spikes_per_rank = nr_spikes / nr_ranks; // Create a set of spikes - std::vector<spike_type> spikes; + std::vector<spike> spikes; // ********************************************************************* // To have a somewhat realworld data set we calculate from the nr of spikes diff --git a/tests/unit/test_spike_store.cpp b/tests/unit/test_spike_store.cpp index 362412e6def4060a395e109b3be97ef04f1e8d23..a8039018572cdaecc2515e33fced63b30d330a0b 100644 --- a/tests/unit/test_spike_store.cpp +++ b/tests/unit/test_spike_store.cpp @@ -1,11 +1,14 @@ #include "../gtest.h" +#include <spike.hpp> #include <threading/threading.hpp> #include <thread_private_spike_store.hpp> +using nest::mc::spike; + TEST(spike_store, insert) { - using store_type = nest::mc::thread_private_spike_store<float>; + using store_type = nest::mc::thread_private_spike_store; store_type store; @@ -49,7 +52,7 @@ TEST(spike_store, insert) TEST(spike_store, clear) { - using store_type = nest::mc::thread_private_spike_store<float>; + using store_type = nest::mc::thread_private_spike_store; store_type store; @@ -64,11 +67,11 @@ TEST(spike_store, clear) TEST(spike_store, gather) { - using store_type = nest::mc::thread_private_spike_store<float>; + using store_type = nest::mc::thread_private_spike_store; store_type store; - auto spikes = std::vector<store_type::spike_type> + std::vector<spike> spikes = { {{0,0}, 0.0f}, {{1,2}, 0.5f}, {{2,4}, 1.0f} }; store.insert(spikes);