diff --git a/CMakeLists.txt b/CMakeLists.txt index 37f1cfe80e468fa7d8c50a7f57ab7de7b01a2f51..44e40b639c2328b5512ad4c557a877d4d9cbec8f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -62,6 +62,12 @@ if(SYSTEM_CRAY) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -dynamic") endif() +# Cray systems +set(TARGET_KNL OFF CACHE BOOL "target Intel KNL architecture") +if(TARGET_KNL) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXXOPT_KNL}") +endif() + # targets for extermal dependencies include(ExternalProject) externalproject_add(modparser diff --git a/cmake/CompilerOptions.cmake b/cmake/CompilerOptions.cmake index 167ae04818f6c2eb791704de08d7bd99adb449fb..6dc09b961cc6a71ce69b1bf9cba5d6821e6949e8 100644 --- a/cmake/CompilerOptions.cmake +++ b/cmake/CompilerOptions.cmake @@ -13,10 +13,21 @@ if(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") set(CXXOPT_WALL "${CXXOPT_WALL} -Wno-missing-braces") endif() +if(${CMAKE_CXX_COMPILER_ID} MATCHES "GNU") + + # compiler flags for generating KNL-specific AVX512 instructions + # supported in gcc 4.9.x and later + #set(CXXOPT_KNL "-mavx512f") + set(CXXOPT_KNL "-march=knl") +endif() + if(${CMAKE_CXX_COMPILER_ID} MATCHES "Intel") # Disable warning for unused template parameter # this is raised by a templated function in the json library set(CXXOPT_WALL "${CXXOPT_WALL} -wd488") + + # compiler flags for generating KNL-specific AVX512 instructions + set(CXXOPT_KNL "-xMIC-AVX512") endif() diff --git a/src/communication/communicator.hpp b/src/communication/communicator.hpp index 871bf65266da7c7ffcc962a41f9913d93fa89fc2..8c72430f5c4646ce7747c76cac088bee65a58df8 100644 --- a/src/communication/communicator.hpp +++ b/src/communication/communicator.hpp @@ -5,11 +5,12 @@ #include <vector> #include <random> +#include <spike.hpp> +#include <util/double_buffer.hpp> #include <algorithms.hpp> #include <connection.hpp> #include <event_queue.hpp> #include <spike.hpp> -#include <threading/threading.hpp> #include <util/debug.hpp> namespace nest { @@ -29,43 +30,50 @@ namespace communication { template <typename Time, typename CommunicationPolicy> class communicator { public: + using communication_policy_type = CommunicationPolicy; using id_type = cell_gid_type; using time_type = Time; - using communication_policy_type = CommunicationPolicy; - using spike_type = spike<cell_member_type, time_type>; + using connection_type = connection<time_type>; + + /// per-cell group lists of events to be delivered + using event_queue = + std::vector<postsynaptic_spike_event<time_type>>; communicator() = default; + // TODO // for now, still assuming one-to-one association cells <-> groups, // so that 'group' gids as represented by their first cell gid are // contiguous. communicator(id_type cell_from, id_type cell_to): cell_gid_from_(cell_from), cell_gid_to_(cell_to) - { - auto num_groups_local_ = cell_gid_to_-cell_gid_from_; + {} - // create an event queue for each target group - events_.resize(num_groups_local_); + cell_local_size_type num_groups_local() const + { + return cell_gid_to_-cell_gid_from_; } - - void add_connection(connection<time_type> con) { + void add_connection(connection_type con) { EXPECTS(is_local_cell(con.destination().gid)); connections_.push_back(con); } + /// returns true if the cell with gid is on the domain of the caller bool is_local_cell(id_type gid) const { return gid>=cell_gid_from_ && gid<cell_gid_to_; } - // builds the optimized data structure + /// builds the optimized data structure + /// must be called after all connections have been added void construct() { if (!std::is_sorted(connections_.begin(), connections_.end())) { std::sort(connections_.begin(), connections_.end()); } } + /// the minimum delay of all connections in the global network. time_type min_delay() { auto local_min = std::numeric_limits<time_type>::max(); for (auto& con : connections_) { @@ -75,31 +83,22 @@ public: return communication_policy_.min(local_min); } - void add_spike(spike_type s) { - thread_spikes().push_back(s); - } - - void add_spikes(const std::vector<spike_type>& s) { - auto& v = thread_spikes(); - v.insert(v.end(), s.begin(), s.end()); - } - - std::vector<spike_type>& thread_spikes() { - return thread_spikes_.local(); - } - - void exchange() { - // global all-to-all to gather a local copy of the global spike list - // on each node - auto global_spikes = communication_policy_.gather_spikes(local_spikes()); + /// Perform exchange of spikes. + /// + /// Takes as input the list of local_spikes that were generated on the calling domain. + /// + /// Returns a vector of event queues, with one queue for each local cell group. The + /// events in each queue are all events that must be delivered to targets in that cell + /// group as a result of the global spike exchange. + std::vector<event_queue> exchange(const std::vector<spike_type>& local_spikes) { + // global all-to-all to gather a local copy of the global spike list on each node. + auto global_spikes = communication_policy_.gather_spikes( local_spikes ); num_spikes_ += global_spikes.size(); - clear_thread_spike_buffers(); - for (auto& q : events_) { - q.clear(); - } + // check each global spike in turn to see it generates local events. + // if so, make the events and insert them into the appropriate event list. + auto queues = std::vector<event_queue>(num_groups_local()); - // check all global spikes to see if they will generate local events for (auto spike : global_spikes) { // search for targets auto targets = @@ -110,18 +109,18 @@ public: // generate an event for each target for (auto it=targets.first; it!=targets.second; ++it) { auto gidx = cell_group_index(it->destination().gid); - events_[gidx].push_back(it->make_event(spike)); + queues[gidx].push_back(it->make_event(spike)); } } + + return queues; } + /// Returns the total number of global spikes over the duration + /// of the simulation uint64_t num_spikes() const { return num_spikes_; } - const std::vector<postsynaptic_spike_event<time_type>>& queue(int i) const { - return events_[i]; - } - - const std::vector<connection<time_type>>& connections() const { + const std::vector<connection_type>& connections() const { return connections_; } @@ -129,28 +128,6 @@ public: return communication_policy_; } - std::vector<spike_type> local_spikes() { - std::vector<spike_type> spikes; - for (auto& v : thread_spikes_) { - spikes.insert(spikes.end(), v.begin(), v.end()); - } - return spikes; - } - - void clear_thread_spike_buffers() { - for (auto& v : thread_spikes_) { - v.clear(); - } - } - - void reset() { - // remove all in-flight spikes/events - clear_thread_spike_buffers(); - for (auto& evbuf: events_) { - evbuf.clear(); - } - } - private: std::size_t cell_group_index(cell_gid_type cell_gid) const { // this will be more elaborate when there is more than one cell per cell group @@ -158,24 +135,7 @@ private: return cell_gid-cell_gid_from_; } - // - // both of these can be fixed with double buffering - // - // FIXME : race condition on the thread_spikes_ buffers when exchange() modifies/access them - // ... other threads will be pushing to them simultaneously - // FIXME : race condition on the group-specific event queues when exchange pushes to them - // ... other threads will be accessing them to update their event queues - - // thread private storage for accumulating spikes - using local_spike_store_type = - nest::mc::threading::enumerable_thread_specific<std::vector<spike_type>>; - local_spike_store_type thread_spikes_; - - std::vector<connection<time_type>> connections_; - std::vector<std::vector<postsynaptic_spike_event<time_type>>> events_; - - // for keeping track of how time is spent where - //util::Profiler profiler_; + std::vector<connection_type> connections_; communication_policy_type communication_policy_; diff --git a/src/model.hpp b/src/model.hpp index 3812a44bfffaa21397d12953c93fcd13bb3843d9..84dd84298a78418ce882b68a29f2211a6e50401b 100644 --- a/src/model.hpp +++ b/src/model.hpp @@ -6,10 +6,11 @@ #include <common_types.hpp> #include <cell.hpp> #include <cell_group.hpp> -#include <communication/communicator.hpp> -#include <communication/global_policy.hpp> #include <fvm_cell.hpp> #include <recipe.hpp> +#include <thread_private_spike_store.hpp> +#include <communication/communicator.hpp> +#include <communication/global_policy.hpp> #include <profiling/profiler.hpp> #include "trace_sampler.hpp" @@ -36,9 +37,10 @@ public: cell_to_(cell_to), communicator_(cell_from, cell_to) { + // generate the cell groups in parallel, with one task per cell group cell_groups_ = std::vector<cell_group_type>{cell_to_-cell_from_}; - threading::parallel_vector<probe_record> probes; + threading::parallel_for::apply(cell_from_, cell_to_, [&](cell_gid_type i) { PE("setup", "cells"); @@ -54,8 +56,10 @@ public: PL(2); }); + // insert probes probes_.assign(probes.begin(), probes.end()); + // generate the network connections for (cell_gid_type i=cell_from_; i<cell_to_; ++i) { for (const auto& cc: rec.connections_on(i)) { connection<time_type> conn{cc.source, cc.dest, cc.weight, cc.delay}; @@ -63,6 +67,12 @@ public: } } communicator_.construct(); + + // Allocate an empty queue buffer for each cell group + // These must be set initially to ensure that a queue is available for each + // cell group for the first time step. + current_events().resize(num_groups()); + future_events().resize(num_groups()); } void reset() { @@ -74,41 +84,74 @@ public: } time_type run(time_type tfinal, time_type dt) { - time_type min_delay = communicator_.min_delay(); - while (t_<tfinal) { - auto tuntil = std::min(t_+min_delay, tfinal); - threading::parallel_for::apply( - 0u, cell_groups_.size(), - [&](unsigned i) { - auto& group = cell_groups_[i]; - - PE("stepping","events"); - group.enqueue_events(communicator_.queue(i)); - PL(); + // Calculate the size of the largest possible time integration interval + // before communication of spikes is required. + // If spike exchange and cell update are serialized, this is the + // minimum delay of the network, however we use half this period + // to overlap communication and computation. + time_type t_interval = communicator_.min_delay()/2; - group.advance(tuntil, dt); - - PE("events"); - communicator_.add_spikes(group.spikes()); - group.clear_spikes(); - PL(2); - }); + while (t_<tfinal) { + auto tuntil = std::min(t_+t_interval, tfinal); + + event_queues_.exchange(); + local_spikes_.exchange(); + + // empty the spike buffers for the current integration period. + // these buffers will store the new spikes generated in update_cells. + current_spikes().clear(); + + // task that updates cell state in parallel. + auto update_cells = [&] () { + threading::parallel_for::apply( + 0u, cell_groups_.size(), + [&](unsigned i) { + auto &group = cell_groups_[i]; + + PE("stepping","events"); + group.enqueue_events(current_events()[i]); + PL(); + + group.advance(tuntil, dt); + + PE("events"); + current_spikes().insert(group.spikes()); + group.clear_spikes(); + PL(2); + }); + }; + + // task that performs spike exchange with the spikes generated in + // the previous integration period, generating the postsynaptic + // events that must be delivered at the start of the next + // integration period at the latest. + auto exchange = [&] () { + PE("stepping", "exchange"); + auto local_spikes = previous_spikes().gather(); + future_events() = communicator_.exchange(local_spikes); + PL(2); + }; - PE("stepping", "exchange"); - communicator_.exchange(); - PL(2); + // run the tasks, overlapping if the threading model and number of + // available threads permits it. + threading::task_group g; + g.run(exchange); + g.run(update_cells); + g.wait(); t_ = tuntil; } return t_; } + // only thread safe if called outside the run() method void add_artificial_spike(cell_member_type source) { add_artificial_spike(source, t_); } + // only thread safe if called outside the run() method void add_artificial_spike(cell_member_type source, time_type tspike) { - communicator_.add_spike({source, tspike}); + current_spikes().get().push_back({source, tspike}); } void attach_sampler(cell_member_type probe_id, sampler_function f, time_type tfrom = 0) { @@ -122,6 +165,7 @@ public: const std::vector<probe_record>& probes() const { return probes_; } std::size_t num_spikes() const { return communicator_.num_spikes(); } + std::size_t num_groups() const { return cell_groups_.size(); } private: cell_gid_type cell_from_; @@ -130,6 +174,35 @@ private: std::vector<cell_group_type> cell_groups_; communicator_type communicator_; std::vector<probe_record> probes_; + using spike_type = typename communicator_type::spike_type; + + using event_queue_type = typename communicator_type::event_queue; + util::double_buffer< std::vector<event_queue_type> > event_queues_; + + using local_spike_store_type = thread_private_spike_store<time_type>; + util::double_buffer< local_spike_store_type > local_spikes_; + + // Convenience functions that map the spike buffers and event queues onto + // the appropriate integration interval. + // + // To overlap communication and computation, integration intervals of + // size Delta/2 are used, where Delta is the minimum delay in the global + // system. + // From the frame of reference of the current integration period we + // define three intervals: previous, current and future + // Then we define the following : + // current_spikes : spikes generated in the current interval + // previous_spikes: spikes generated in the preceding interval + // current_events : events to be delivered at the start of + // the current interval + // future_events : events to be delivered at the start of + // the next interval + + local_spike_store_type& current_spikes() { return local_spikes_.get(); } + local_spike_store_type& previous_spikes() { return local_spikes_.other(); } + + std::vector<event_queue_type>& current_events() { return event_queues_.get(); } + std::vector<event_queue_type>& future_events() { return event_queues_.other(); } }; } // namespace mc diff --git a/src/thread_private_spike_store.hpp b/src/thread_private_spike_store.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1e14ab104619a42b8aa9f014172bb5716576017c --- /dev/null +++ b/src/thread_private_spike_store.hpp @@ -0,0 +1,75 @@ +#pragma once + +#include <vector> + +#include <common_types.hpp> +#include <spike.hpp> +#include <threading/threading.hpp> + +namespace nest { +namespace mc { + +/// Handles the complexity of managing thread private buffers of spikes. +/// Internally stores one thread private buffer of spikes for each hardware thread. +/// This can be accessed directly using the get() method, which returns a reference to +/// The thread private buffer of the calling thread. +/// The insert() and gather() methods add a vector of spikes to the buffer, +/// and collate all of the buffers into a single vector respectively. +template <typename Time> +class thread_private_spike_store { +public : + using id_type = cell_gid_type; + using time_type = Time; + using spike_type = spike<cell_member_type, time_type>; + + /// Collate all of the individual buffers into a single vector of spikes. + /// Does not modify the buffer contents. + std::vector<spike_type> gather() const { + std::vector<spike_type> spikes; + unsigned num_spikes = 0u; + for (auto& b : buffers_) { + num_spikes += b.size(); + } + spikes.reserve(num_spikes); + + for (auto& b : buffers_) { + spikes.insert(spikes.begin(), b.begin(), b.end()); + } + + return spikes; + } + + /// Return a reference to the thread private buffer of the calling thread + std::vector<spike_type>& get() { + return buffers_.local(); + } + + /// Return a reference to the thread private buffer of the calling thread + const std::vector<spike_type>& get() const { + return buffers_.local(); + } + + /// Clear all of the thread private buffers + void clear() { + for (auto& b : buffers_) { + b.clear(); + } + } + + /// Append the passed spikes to the end of the thread private buffer of the + /// calling thread + void insert(const std::vector<spike_type>& spikes) { + auto& buff = get(); + buff.insert(buff.end(), spikes.begin(), spikes.end()); + } + +private : + /// thread private storage for accumulating spikes + using local_spike_store_type = + threading::enumerable_thread_specific<std::vector<spike_type>>; + + local_spike_store_type buffers_; +}; + +} // namespace mc +} // namespace nest diff --git a/src/threading/serial.hpp b/src/threading/serial.hpp index 8d9c4d64fd84805ed0909ad704ffc192bd0da748..c18d4800e3c457e140f0fe6d2ca671aa6ccac241 100644 --- a/src/threading/serial.hpp +++ b/src/threading/serial.hpp @@ -19,6 +19,8 @@ namespace threading { template <typename T> class enumerable_thread_specific { std::array<T, 1> data; + using iterator_type = typename std::array<T, 1>::iterator; + using const_iterator_type = typename std::array<T, 1>::const_iterator; public : @@ -37,11 +39,14 @@ class enumerable_thread_specific { auto size() -> decltype(data.size()) const { return data.size(); } - auto begin() -> decltype(data.begin()) { return data.begin(); } - auto end() -> decltype(data.end()) { return data.end(); } + iterator_type begin() { return data.begin(); } + iterator_type end() { return data.end(); } - auto cbegin() -> decltype(data.cbegin()) const { return data.cbegin(); } - auto cend() -> decltype(data.cend()) const { return data.cend(); } + const_iterator_type begin() const { return data.begin(); } + const_iterator_type end() const { return data.end(); } + + const_iterator_type cbegin() const { return data.cbegin(); } + const_iterator_type cend() const { return data.cend(); } }; @@ -83,6 +88,34 @@ struct timer { constexpr bool multithreaded() { return false; } +/// Proxy for tbb task group. +/// The tbb version launches tasks asynchronously, returning control to the +/// caller. The serial version implemented here simply runs the task, before +/// returning control, effectively serializing all asynchronous calls. +class task_group { +public: + task_group() = default; + + template<typename Func> + void run(const Func& f) { + f(); + } + + template<typename Func> + void run_and_wait(const Func& f) { + f(); + } + + void wait() + {} + + bool is_canceling() { + return false; + } + + void cancel() + {} +}; } // threading } // mc diff --git a/src/threading/tbb.hpp b/src/threading/tbb.hpp index 8e0086741cac1d51e02e9fa148e9eaa3f3f837a3..853ceefde731cb1cd4db323541132c935011c73a 100644 --- a/src/threading/tbb.hpp +++ b/src/threading/tbb.hpp @@ -49,6 +49,8 @@ constexpr bool multithreaded() { return true; } template <typename T> using parallel_vector = tbb::concurrent_vector<T>; +using task_group = tbb::task_group; + } // threading } // mc } // nest diff --git a/src/util/double_buffer.hpp b/src/util/double_buffer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..31cffa1babef2ced0b843b4d3e375678612528b2 --- /dev/null +++ b/src/util/double_buffer.hpp @@ -0,0 +1,66 @@ +#pragma once + +#include <array> +#include <atomic> + +#include <util/debug.hpp> + +namespace nest { +namespace mc { +namespace util { + +/// double buffer with thread safe exchange/flip operation. +template <typename T> +class double_buffer { +private: + std::atomic<int> index_; + std::array<T, 2> buffers_; + + int other_index() { + return index_ ? 0 : 1; + } + +public: + using value_type = T; + + double_buffer() : + index_(0) + {} + + /// remove the copy and move constructors which won't work with std::atomic + double_buffer(double_buffer&&) = delete; + double_buffer(const double_buffer&) = delete; + double_buffer& operator=(const double_buffer&) = delete; + double_buffer& operator=(double_buffer&&) = delete; + + /// flip the buffers in a thread safe manner + /// n calls to exchange will always result in n flips + void exchange() { + // use operator^= which is overloaded by std::atomic<> + index_ ^= 1; + } + + /// get the current/front buffer + value_type& get() { + return buffers_[index_]; + } + + /// get the current/front buffer + const value_type& get() const { + return buffers_[index_]; + } + + /// get the back buffer + value_type& other() { + return buffers_[other_index()]; + } + + /// get the back buffer + const value_type& other() const { + return buffers_[other_index()]; + } +}; + +} // namespace util +} // namespace mc +} // namespace nest diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 4ba5bb8b0be6764e84fb24ea1c16ee803f6bc516..fb1b9369d4e38498c8953fc0ce2f16c1dbc715dc 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -11,6 +11,7 @@ set(HEADERS set(TEST_SOURCES # unit tests test_algorithms.cpp + test_double_buffer.cpp test_cell.cpp test_compartments.cpp test_event_queue.cpp @@ -26,6 +27,7 @@ set(TEST_SOURCES test_probe.cpp test_segment.cpp test_spikes.cpp + test_spike_store.cpp test_stimulus.cpp test_swcio.cpp test_synapses.cpp @@ -45,12 +47,12 @@ foreach(target ${TARGETS}) target_link_libraries(${target} LINK_PUBLIC cellalgo gtest) if(WITH_TBB) - target_link_libraries(${target} LINK_PUBLIC ${TBB_LIBRARIES}) + target_link_libraries(${target} LINK_PUBLIC ${TBB_LIBRARIES}) endif() if(WITH_MPI) - target_link_libraries(${target} LINK_PUBLIC ${MPI_C_LIBRARIES}) - set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS "${MPI_C_LINK_FLAGS}") + target_link_libraries(${target} LINK_PUBLIC ${MPI_C_LIBRARIES}) + set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS "${MPI_C_LINK_FLAGS}") endif() set_target_properties(${target} diff --git a/tests/unit/test_double_buffer.cpp b/tests/unit/test_double_buffer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1689cc2faac17331694c78830c22cd1ea57777fe --- /dev/null +++ b/tests/unit/test_double_buffer.cpp @@ -0,0 +1,58 @@ +#include "gtest.h" + +#include <util/double_buffer.hpp> + +// not much to test here: just test that values passed into the constructor +// are correctly stored in members +TEST(double_buffer, exchange_and_get) +{ + using namespace nest::mc::util; + + double_buffer<int> buf; + + buf.get() = 2134; + buf.exchange(); + buf.get() = 8990; + buf.exchange(); + + EXPECT_EQ(buf.get(), 2134); + EXPECT_EQ(buf.other(), 8990); + buf.exchange(); + EXPECT_EQ(buf.get(), 8990); + EXPECT_EQ(buf.other(), 2134); + buf.exchange(); + EXPECT_EQ(buf.get(), 2134); + EXPECT_EQ(buf.other(), 8990); +} + +TEST(double_buffer, assign_get_other) +{ + using namespace nest::mc::util; + + double_buffer<std::string> buf; + + buf.get() = "1"; + buf.other() = "2"; + + EXPECT_EQ(buf.get(), "1"); + EXPECT_EQ(buf.other(), "2"); +} + +TEST(double_buffer, non_pod) +{ + using namespace nest::mc::util; + + double_buffer<std::string> buf; + + buf.get() = "1"; + buf.other() = "2"; + + EXPECT_EQ(buf.get(), "1"); + EXPECT_EQ(buf.other(), "2"); + buf.exchange(); + EXPECT_EQ(buf.get(), "2"); + EXPECT_EQ(buf.other(), "1"); + buf.exchange(); + EXPECT_EQ(buf.get(), "1"); + EXPECT_EQ(buf.other(), "2"); +} diff --git a/tests/unit/test_spike_store.cpp b/tests/unit/test_spike_store.cpp new file mode 100644 index 0000000000000000000000000000000000000000..dd075f878bfccc7a7a07887e6736fcb2768e0237 --- /dev/null +++ b/tests/unit/test_spike_store.cpp @@ -0,0 +1,84 @@ +#include "gtest.h" + +#include <thread_private_spike_store.hpp> + +TEST(spike_store, insert) +{ + using store_type = nest::mc::thread_private_spike_store<float>; + + store_type store; + + // insert 3 spike events and check that they were inserted correctly + store.insert({ + {{0,0}, 0.0f}, + {{1,2}, 0.5f}, + {{2,4}, 1.0f} + }); + + { + EXPECT_EQ(store.get().size(), 3u); + auto i = 0u; + for (auto& spike : store.get()) { + EXPECT_EQ(spike.source.gid, i); + EXPECT_EQ(spike.source.index, 2*i); + EXPECT_EQ(spike.time, float(i)/2.f); + ++i; + } + } + + // insert another 3 events, then check that they were appended to the + // original three events correctly + store.insert({ + {{3,6}, 1.5f}, + {{4,8}, 2.0f}, + {{5,10}, 2.5f} + }); + + { + EXPECT_EQ(store.get().size(), 6u); + auto i = 0u; + for (auto& spike : store.get()) { + EXPECT_EQ(spike.source.gid, i); + EXPECT_EQ(spike.source.index, 2*i); + EXPECT_EQ(spike.time, float(i)/2.f); + ++i; + } + } +} + +TEST(spike_store, clear) +{ + using store_type = nest::mc::thread_private_spike_store<float>; + + store_type store; + + // insert 3 spike events + store.insert({ + {{0,0}, 0.0f}, {{1,2}, 0.5f}, {{2,4}, 1.0f} + }); + EXPECT_EQ(store.get().size(), 3u); + store.clear(); + EXPECT_EQ(store.get().size(), 0u); +} + +TEST(spike_store, gather) +{ + using store_type = nest::mc::thread_private_spike_store<float>; + + store_type store; + + auto spikes = std::vector<store_type::spike_type> + { {{0,0}, 0.0f}, {{1,2}, 0.5f}, {{2,4}, 1.0f} }; + + store.insert(spikes); + auto gathered_spikes = store.gather(); + + EXPECT_EQ(gathered_spikes.size(), spikes.size()); + + for(auto i=0u; i<spikes.size(); ++i) { + EXPECT_EQ(spikes[i].source.gid, gathered_spikes[i].source.gid); + EXPECT_EQ(spikes[i].source.index, gathered_spikes[i].source.index); + EXPECT_EQ(spikes[i].time, gathered_spikes[i].time); + } +} +