From c83adfcf0b7c2e533f14b54f0525780fa74d8685 Mon Sep 17 00:00:00 2001
From: Sam Yates <yates@cscs.ch>
Date: Wed, 22 Mar 2017 14:11:58 +0100
Subject: [PATCH] Tame `time_type` proliferation. (#203)

Differing classes had their own time_type and other classes were parameterized on it in ways that were compatibile, but only by chance.

With these changes, modifying the time type used in spike will propagate through to all dependent classes.

  * Rename generic spike<I, T> as basic_spike<I, T>.
  * Use spike = basic_spike<cell_member_type, float> as the common spike type.
  * Replace instances of spike_type aliases with just spike.
  * time_type aliases are defined in terms of spike::time_type.
  * Remove time_type parameterization in connection.
  * Remove time_type parameterization in communicator.
  * Remove time_type parameterization in exporter classes.
---
 miniapp/miniapp.cpp                           |  9 +++---
 src/cell_group.hpp                            | 10 +++---
 src/communication/communicator.hpp            | 20 ++++++------
 src/connection.hpp                            | 32 ++++++++-----------
 src/io/exporter.hpp                           |  8 ++---
 src/io/exporter_spike_file.hpp                |  8 ++---
 src/model.hpp                                 |  9 +++---
 src/spike.hpp                                 | 24 ++++++++------
 src/thread_private_spike_store.hpp            | 20 +++---------
 .../test_communicator.cpp                     |  3 +-
 .../test_exporter_spike_file.cpp              |  7 ++--
 tests/performance/io/disk_io.cpp              |  7 ++--
 tests/unit/test_spike_store.cpp               | 11 ++++---
 13 files changed, 73 insertions(+), 95 deletions(-)

diff --git a/miniapp/miniapp.cpp b/miniapp/miniapp.cpp
index 495e6986..c4add558 100644
--- a/miniapp/miniapp.cpp
+++ b/miniapp/miniapp.cpp
@@ -36,13 +36,12 @@ using lowered_cell = fvm::fvm_multicell<multicore::backend>;
 #endif
 using model_type = model<lowered_cell>;
 using sample_trace_type = sample_trace<model_type::time_type, model_type::value_type>;
-using file_export_type = io::exporter_spike_file<model_type::time_type, global_policy>;
+using file_export_type = io::exporter_spike_file<global_policy>;
 void banner();
 std::unique_ptr<recipe> make_recipe(const io::cl_options&, const probe_distribution&);
 std::unique_ptr<sample_trace_type> make_trace(cell_member_type probe_id, probe_spec probe);
 std::pair<cell_gid_type, cell_gid_type> distribute_cells(cell_size_type ncells);
-using communicator_type = communication::communicator<model_type::time_type, communication::global_policy>;
-using spike_type = typename communicator_type::spike_type;
+using communicator_type = communication::communicator<communication::global_policy>;
 
 void write_trace_json(const sample_trace_type& trace, const std::string& prefix = "trace_");
 void report_compartment_stats(const recipe&);
@@ -140,14 +139,14 @@ int main(int argc, char** argv) {
             if (options.single_file_per_rank) {
                 file_exporter = register_exporter(options);
                 m.set_local_spike_callback(
-                    [&](const std::vector<spike_type>& spikes) {
+                    [&](const std::vector<spike>& spikes) {
                         file_exporter->output(spikes);
                     });
             }
             else if(communication::global_policy::id()==0) {
                 file_exporter = register_exporter(options);
                 m.set_global_spike_callback(
-                    [&](const std::vector<spike_type>& spikes) {
+                    [&](const std::vector<spike>& spikes) {
                        file_exporter->output(spikes);
                     });
             }
diff --git a/src/cell_group.hpp b/src/cell_group.hpp
index ef4c71b2..990f055f 100644
--- a/src/cell_group.hpp
+++ b/src/cell_group.hpp
@@ -28,7 +28,7 @@ public:
     using size_type  = typename lowered_cell_type::value_type;
     using source_id_type = cell_member_type;
 
-    using time_type = float;
+    using time_type = spike::time_type;
     using sampler_function = std::function<util::optional<time_type>(time_type, double)>;
 
     cell_group() = default;
@@ -141,8 +141,7 @@ public:
         }
     }
 
-    const std::vector<spike<source_id_type, time_type>>&
-    spikes() const {
+    const std::vector<spike>& spikes() const {
         return spikes_;
     }
 
@@ -150,8 +149,7 @@ public:
         spikes_.clear();
     }
 
-    const std::vector<source_id_type>&
-    spike_sources() const {
+    const std::vector<source_id_type>& spike_sources() const {
         return spike_sources_;
     }
 
@@ -197,7 +195,7 @@ private:
     std::vector<source_id_type> spike_sources_;
 
     /// spikes that are generated
-    std::vector<spike<source_id_type, time_type>> spikes_;
+    std::vector<spike> spikes_;
 
     /// pending events to be delivered
     event_queue<postsynaptic_spike_event<time_type>> events_;
diff --git a/src/communication/communicator.hpp b/src/communication/communicator.hpp
index 764020dc..ffaad887 100644
--- a/src/communication/communicator.hpp
+++ b/src/communication/communicator.hpp
@@ -29,14 +29,12 @@ namespace communication {
 // Once all connections have been specified, the construct() method can be used
 // to build the data structures required for efficient spike communication and
 // event generation.
-template <typename Time, typename CommunicationPolicy>
+
+template <typename CommunicationPolicy>
 class communicator {
 public:
     using communication_policy_type = CommunicationPolicy;
-    using id_type = cell_gid_type;
-    using time_type = Time;
-    using spike_type = spike<cell_member_type, time_type>;
-    using connection_type = connection<time_type>;
+    using time_type = spike::time_type;
 
     /// per-cell group lists of events to be delivered
     using event_queue =
@@ -56,13 +54,13 @@ public:
         return cell_gid_partition_.size();
     }
 
-    void add_connection(connection_type con) {
+    void add_connection(connection con) {
         EXPECTS(is_local_cell(con.destination().gid));
         connections_.push_back(con);
     }
 
     /// returns true if the cell with gid is on the domain of the caller
-    bool is_local_cell(id_type gid) const {
+    bool is_local_cell(cell_gid_type gid) const {
         return algorithms::in_interval(gid, cell_gid_partition_.bounds());
     }
 
@@ -88,7 +86,7 @@ public:
     ///
     /// Takes as input the list of local_spikes that were generated on the calling domain.
     /// Returns the full global set of vectors, along with meta data about their partition
-    gathered_vector<spike_type> exchange(const std::vector<spike_type>& local_spikes) {
+    gathered_vector<spike> exchange(const std::vector<spike>& local_spikes) {
         // global all-to-all to gather a local copy of the global spike list on each node.
         auto global_spikes = communication_policy_.gather_spikes( local_spikes );
         num_spikes_ += global_spikes.size();
@@ -102,7 +100,7 @@ public:
     /// Returns a vector of event queues, with one queue for each local cell group. The
     /// events in each queue are all events that must be delivered to targets in that cell
     /// group as a result of the global spike exchange.
-    std::vector<event_queue> make_event_queues(const gathered_vector<spike_type>& global_spikes) {
+    std::vector<event_queue> make_event_queues(const gathered_vector<spike>& global_spikes) {
         auto queues = std::vector<event_queue>(num_groups_local());
         for (auto spike : global_spikes.values()) {
             // search for targets
@@ -121,7 +119,7 @@ public:
     /// Returns the total number of global spikes over the duration of the simulation
     uint64_t num_spikes() const { return num_spikes_; }
 
-    const std::vector<connection_type>& connections() const {
+    const std::vector<connection>& connections() const {
         return connections_;
     }
 
@@ -139,7 +137,7 @@ private:
         return cell_gid_partition_.index(cell_gid);
     }
 
-    std::vector<connection_type> connections_;
+    std::vector<connection> connections_;
 
     communication_policy_type communication_policy_;
 
diff --git a/src/connection.hpp b/src/connection.hpp
index 361aee20..3897c6fb 100644
--- a/src/connection.hpp
+++ b/src/connection.hpp
@@ -9,15 +9,13 @@
 namespace nest {
 namespace mc {
 
-template <typename Time>
 class connection {
 public:
-    using id_type = cell_member_type;
-    using time_type = Time;
+    using time_type = spike::time_type;
 
-    connection()=default;
+    connection() = default;
 
-    connection(id_type src, id_type dest, float w, time_type d) :
+    connection(cell_member_type src, cell_member_type dest, float w, time_type d) :
         source_(src),
         destination_(dest),
         weight_(w),
@@ -25,18 +23,18 @@ public:
     {}
 
     float weight() const { return weight_; }
-    float delay() const { return delay_; }
+    time_type delay() const { return delay_; }
 
-    id_type source() const { return source_; }
-    id_type destination() const { return destination_; }
+    cell_member_type source() const { return source_; }
+    cell_member_type destination() const { return destination_; }
 
-    postsynaptic_spike_event<time_type> make_event(spike<id_type, time_type> s) {
+    postsynaptic_spike_event<time_type> make_event(const spike& s) {
         return {destination_, s.time + delay_, weight_};
     }
 
 private:
-    id_type source_;
-    id_type destination_;
+    cell_member_type source_;
+    cell_member_type destination_;
     float weight_;
     time_type delay_;
 };
@@ -44,26 +42,22 @@ private:
 // connections are sorted by source id
 // these operators make for easy interopability with STL algorithms
 
-template <typename T>
-static inline bool operator<(connection<T> lhs, connection<T> rhs) {
+static inline bool operator<(const connection& lhs, const connection& rhs) {
     return lhs.source() < rhs.source();
 }
 
-template <typename T>
-static inline bool operator<(connection<T> lhs, typename connection<T>::id_type rhs) {
+static inline bool operator<(const connection& lhs, cell_member_type rhs) {
     return lhs.source() < rhs;
 }
 
-template <typename T>
-static inline bool operator<(typename connection<T>::id_type lhs, connection<T> rhs) {
+static inline bool operator<(cell_member_type lhs, const connection& rhs) {
     return lhs < rhs.source();
 }
 
 } // namespace mc
 } // namespace nest
 
-template <typename T>
-static inline std::ostream& operator<<(std::ostream& o, nest::mc::connection<T> const& con) {
+static inline std::ostream& operator<<(std::ostream& o, nest::mc::connection const& con) {
     return o << "con [" << con.source() << " -> " << con.destination()
              << " : weight " << con.weight()
              << ", delay " << con.delay() << "]";
diff --git a/src/io/exporter.hpp b/src/io/exporter.hpp
index a5885b7b..e27ae636 100644
--- a/src/io/exporter.hpp
+++ b/src/io/exporter.hpp
@@ -14,15 +14,11 @@ namespace io {
 // Exposes one virtual functions:
 //    do_export(vector<type>) receiving a vector of parameters to export
 
-template <typename Time, typename CommunicationPolicy>
+template <typename CommunicationPolicy>
 class exporter {
-
 public:
-    using time_type = Time;
-    using spike_type = spike<cell_member_type, time_type>;
-
     // Performs the export of the data
-    virtual void output(const std::vector<spike_type>&) = 0;
+    virtual void output(const std::vector<spike>&) = 0;
 
     // Returns the status of the exporter
     virtual bool good() const = 0;
diff --git a/src/io/exporter_spike_file.hpp b/src/io/exporter_spike_file.hpp
index 94418712..090a4641 100644
--- a/src/io/exporter_spike_file.hpp
+++ b/src/io/exporter_spike_file.hpp
@@ -19,11 +19,9 @@ namespace nest {
 namespace mc {
 namespace io {
 
-template <typename Time, typename CommunicationPolicy>
-class exporter_spike_file : public exporter<Time, CommunicationPolicy> {
+template <typename CommunicationPolicy>
+class exporter_spike_file : public exporter<CommunicationPolicy> {
 public:
-    using time_type = Time;
-    using spike_type = spike<cell_member_type, time_type>;
     using communication_policy_type = CommunicationPolicy;
 
     // Constructor
@@ -53,7 +51,7 @@ public:
     // Performs export of the spikes to file.
     // One id and spike time with 4 decimals after the comma on a
     // line space separated.
-    void output(const std::vector<spike_type>& spikes) override {
+    void output(const std::vector<spike>& spikes) override {
         for (auto spike : spikes) {
             char linebuf[45];
             auto n =
diff --git a/src/model.hpp b/src/model.hpp
index 048656cb..c411b4a8 100644
--- a/src/model.hpp
+++ b/src/model.hpp
@@ -28,10 +28,9 @@ public:
     using cell_group_type = cell_group<Cell>;
     using time_type = typename cell_group_type::time_type;
     using value_type = typename cell_group_type::value_type;
-    using communicator_type = communication::communicator<time_type, communication::global_policy>;
+    using communicator_type = communication::communicator<communication::global_policy>;
     using sampler_function = typename cell_group_type::sampler_function;
-    using spike_type = typename communicator_type::spike_type;
-    using spike_export_function = std::function<void(const std::vector<spike_type>&)>;
+    using spike_export_function = std::function<void(const std::vector<spike>&)>;
 
     struct probe_record {
         cell_member_type id;
@@ -77,7 +76,7 @@ public:
         // generate the network connections
         for (cell_gid_type i: util::make_span(gid_partition().bounds())) {
             for (const auto& cc: rec.connections_on(i)) {
-                connection<time_type> conn{cc.source, cc.dest, cc.weight, cc.delay};
+                connection conn{cc.source, cc.dest, cc.weight, cc.delay};
                 communicator_.add_connection(conn);
             }
         }
@@ -264,7 +263,7 @@ private:
     using event_queue_type = typename communicator_type::event_queue;
     util::double_buffer<std::vector<event_queue_type>> event_queues_;
 
-    using local_spike_store_type = thread_private_spike_store<time_type>;
+    using local_spike_store_type = thread_private_spike_store;
     util::double_buffer<local_spike_store_type> local_spikes_;
 
     spike_export_function global_export_callback_ = util::nop_function;
diff --git a/src/spike.hpp b/src/spike.hpp
index d3ea0255..6675605d 100644
--- a/src/spike.hpp
+++ b/src/spike.hpp
@@ -3,37 +3,43 @@
 #include <ostream>
 #include <type_traits>
 
+#include <common_types.hpp>
+
 namespace nest {
 namespace mc {
 
 template <typename I, typename Time>
-struct spike {
+struct basic_spike {
     using id_type = I;
     using time_type = Time;
 
     id_type source = id_type{};
-    time_type time = -1.;
+    time_type time = -1;
 
-    spike() = default;
+    basic_spike() = default;
 
-    spike(id_type s, time_type t) :
+    basic_spike(id_type s, time_type t):
         source(s), time(t)
     {}
 };
 
+/// Standard specialization:
+using spike = basic_spike<cell_member_type, float>;
+
 } // namespace mc
 } // namespace nest
 
-/// custom stream operator for printing nest::mc::spike<> values
+// Custom stream operator for printing nest::mc::spike<> values.
 template <typename I, typename T>
-std::ostream& operator<<(std::ostream& o, nest::mc::spike<I, T> s) {
+std::ostream& operator<<(std::ostream& o, nest::mc::basic_spike<I, T> s) {
     return o << "spike[t " << s.time << ", src " << s.source << "]";
 }
 
-/// less than comparison operator for nest::mc::spike<> values
-/// spikes are ordered by spike time, for use in sorting and queueing
+/// Less than comparison operator for nest::mc::spike<> values:
+/// spikes are ordered by spike time, for use in sorting and queueing.
 template <typename I, typename T>
-bool operator<(nest::mc::spike<I, T> lhs, nest::mc::spike<I, T> rhs) {
+bool operator<(nest::mc::basic_spike<I, T> lhs, nest::mc::basic_spike<I, T> rhs) {
     return lhs.time < rhs.time;
 }
 
+
diff --git a/src/thread_private_spike_store.hpp b/src/thread_private_spike_store.hpp
index 3f320fdd..4b1973e3 100644
--- a/src/thread_private_spike_store.hpp
+++ b/src/thread_private_spike_store.hpp
@@ -15,17 +15,12 @@ namespace mc {
 /// The thread private buffer of the calling thread.
 /// The insert() and gather() methods add a vector of spikes to the buffer,
 /// and collate all of the buffers into a single vector respectively.
-template <typename Time>
 class thread_private_spike_store {
 public :
-    using id_type = cell_gid_type;
-    using time_type = Time;
-    using spike_type = spike<cell_member_type, time_type>;
-
     /// Collate all of the individual buffers into a single vector of spikes.
     /// Does not modify the buffer contents.
-    std::vector<spike_type> gather() const {
-        std::vector<spike_type> spikes;
+    std::vector<spike> gather() const {
+        std::vector<spike> spikes;
         unsigned num_spikes = 0u;
         for (auto& b : buffers_) {
             num_spikes += b.size();
@@ -40,12 +35,7 @@ public :
     }
 
     /// Return a reference to the thread private buffer of the calling thread
-    std::vector<spike_type>& get() {
-        return buffers_.local();
-    }
-
-    /// Return a reference to the thread private buffer of the calling thread
-    const std::vector<spike_type>& get() const {
+    std::vector<spike>& get() {
         return buffers_.local();
     }
 
@@ -58,7 +48,7 @@ public :
 
     /// Append the passed spikes to the end of the thread private buffer of the
     /// calling thread
-    void insert(const std::vector<spike_type>& spikes) {
+    void insert(const std::vector<spike>& spikes) {
         auto& buff = get();
         buff.insert(buff.end(), spikes.begin(), spikes.end());
     }
@@ -66,7 +56,7 @@ public :
 private :
     /// thread private storage for accumulating spikes
     using local_spike_store_type =
-        threading::enumerable_thread_specific<std::vector<spike_type>>;
+        threading::enumerable_thread_specific<std::vector<spike>>;
 
     local_spike_store_type buffers_;
 
diff --git a/tests/global_communication/test_communicator.cpp b/tests/global_communication/test_communicator.cpp
index b6bcc034..e3d427ae 100644
--- a/tests/global_communication/test_communicator.cpp
+++ b/tests/global_communication/test_communicator.cpp
@@ -11,8 +11,7 @@
 
 using namespace nest::mc;
 
-using time_type = float;
-using communicator_type = communication::communicator<time_type, communication::global_policy>;
+using communicator_type = communication::communicator<communication::global_policy>;
 
 TEST(communicator, setup) {
     /*
diff --git a/tests/global_communication/test_exporter_spike_file.cpp b/tests/global_communication/test_exporter_spike_file.cpp
index 027a315b..376230ed 100644
--- a/tests/global_communication/test_exporter_spike_file.cpp
+++ b/tests/global_communication/test_exporter_spike_file.cpp
@@ -9,15 +9,14 @@
 #include <communication/communicator.hpp>
 #include <communication/global_policy.hpp>
 #include <io/exporter_spike_file.hpp>
+#include <spike.hpp>
 
 class exporter_spike_file_fixture : public ::testing::Test {
 protected:
-    using time_type = float;
     using communicator_type = nest::mc::communication::global_policy;
 
     using exporter_type =
-        nest::mc::io::exporter_spike_file<time_type, communicator_type>;
-    using spike_type = exporter_type::spike_type;
+        nest::mc::io::exporter_spike_file<communicator_type>;
 
     std::string file_name_;
     std::string path_;
@@ -88,7 +87,7 @@ TEST_F(exporter_spike_file_fixture, do_export) {
         exporter_type exporter(file_name_, path_, extension_);
 
         // Create some spikes
-        std::vector<spike_type> spikes;
+        std::vector<nest::mc::spike> spikes;
         spikes.push_back({ { 0, 0 }, 0.0 });
         spikes.push_back({ { 0, 0 }, 0.1 });
         spikes.push_back({ { 1, 0 }, 1.0 });
diff --git a/tests/performance/io/disk_io.cpp b/tests/performance/io/disk_io.cpp
index 9d77318f..05b64666 100644
--- a/tests/performance/io/disk_io.cpp
+++ b/tests/performance/io/disk_io.cpp
@@ -12,14 +12,13 @@
 #include <fvm_multicell.hpp>
 #include <io/exporter_spike_file.hpp>
 #include <profiling/profiler.hpp>
+#include <spike.hpp>
 
 using namespace nest::mc;
 
 using global_policy = communication::global_policy;
 using lowered_cell = fvm::fvm_multicell<multicore::backend>;
 using cell_group_type = cell_group<lowered_cell>;
-using time_type = typename cell_group_type::time_type;
-using spike_type = io::exporter_spike_file<time_type, global_policy>::spike_type;
 using timer = util::timer_type;
 
 int main(int argc, char** argv) {
@@ -67,7 +66,7 @@ int main(int argc, char** argv) {
     }
 
     // Create the sut
-    io::exporter_spike_file<time_type, global_policy> exporter(
+    io::exporter_spike_file<global_policy> exporter(
          "spikes", "./", "gdf", true);
 
     // We need the nr of ranks to calculate the nr of spikes to produce per
@@ -78,7 +77,7 @@ int main(int argc, char** argv) {
     auto spikes_per_rank = nr_spikes / nr_ranks;
 
     // Create a set of spikes
-    std::vector<spike_type> spikes;
+    std::vector<spike> spikes;
 
     // *********************************************************************
     // To have a  somewhat realworld data set we calculate from the nr of spikes
diff --git a/tests/unit/test_spike_store.cpp b/tests/unit/test_spike_store.cpp
index 362412e6..a8039018 100644
--- a/tests/unit/test_spike_store.cpp
+++ b/tests/unit/test_spike_store.cpp
@@ -1,11 +1,14 @@
 #include "../gtest.h"
 
+#include <spike.hpp>
 #include <threading/threading.hpp>
 #include <thread_private_spike_store.hpp>
 
+using nest::mc::spike;
+
 TEST(spike_store, insert)
 {
-    using store_type = nest::mc::thread_private_spike_store<float>;
+    using store_type = nest::mc::thread_private_spike_store;
 
     store_type store;
 
@@ -49,7 +52,7 @@ TEST(spike_store, insert)
 
 TEST(spike_store, clear)
 {
-    using store_type = nest::mc::thread_private_spike_store<float>;
+    using store_type = nest::mc::thread_private_spike_store;
 
     store_type store;
 
@@ -64,11 +67,11 @@ TEST(spike_store, clear)
 
 TEST(spike_store, gather)
 {
-    using store_type = nest::mc::thread_private_spike_store<float>;
+    using store_type = nest::mc::thread_private_spike_store;
 
     store_type store;
 
-    auto spikes = std::vector<store_type::spike_type>
+    std::vector<spike> spikes =
         { {{0,0}, 0.0f}, {{1,2}, 0.5f}, {{2,4}, 1.0f} };
 
     store.insert(spikes);
-- 
GitLab