diff --git a/src/communication/communicator.hpp b/src/communication/communicator.hpp index 2eab80bd0556722d9fd3489b29c5b60d68180565..5adef5669280f94b68ae90d1cf99a2c38497344f 100644 --- a/src/communication/communicator.hpp +++ b/src/communication/communicator.hpp @@ -27,20 +27,6 @@ namespace communication { // Once all connections have been specified, the construct() method can be used // to build the data structures required for efficient spike communication and // event generation. -// -// To overlap communication and computation, i.e. to perform spike -// exchange at the same time as cell state update, thread safe access to the -// spike and event lists must be provided. We use double buffering, whereby -// for each of the spikes and events one buffer is exposed publicly, while -// the other is used internally by the communicator -// - the spike lists are not directly exposed via the communicator -// interface. Instead they are updated when the add_spike() methods -// are called. -// - the event queues for each cell group are exposed via the queue() -// method. -// For each double buffer, the current buffer (accessed via buffer.get()) -// is exposed to the user, and the other buffer is used inside exchange(). - template <typename Time, typename CommunicationPolicy> class communicator { public: @@ -56,6 +42,7 @@ public: communicator() = default; + // TODO // for now, still assuming one-to-one association cells <-> groups, // so that 'group' gids as represented by their first cell gid are // contiguous. 
@@ -73,17 +60,20 @@ public: connections_.push_back(con); } + /// returns true if the cell with gid is on the domain of the caller bool is_local_cell(id_type gid) const { return gid>=cell_gid_from_ && gid<cell_gid_to_; } - // builds the optimized data structure + /// builds the optimized data structure + /// must be called after all connections have been added void construct() { if (!std::is_sorted(connections_.begin(), connections_.end())) { std::sort(connections_.begin(), connections_.end()); } } + /// the minimum delay of all connections in the global network. time_type min_delay() { auto local_min = std::numeric_limits<time_type>::max(); for (auto& con : connections_) { @@ -93,6 +83,13 @@ public: return communication_policy_.min(local_min); } + /// Perform exchange of spikes. + /// + /// Takes as input the list of local_spikes that were generated on the calling domain. + /// + /// Returns a vector of event queues, with one queue for each local cell group. The + /// events in each queue are all events that must be delivered to targets in that cell + /// group as a result of the global spike exchange. std::vector<event_queue> exchange(const std::vector<spike_type>& local_spikes) { // global all-to-all to gather a local copy of the global spike list on each node. auto global_spikes = communication_policy_.gather_spikes( local_spikes ); @@ -118,6 +115,8 @@ public: return queues; } + /// Returns the total number of global spikes over the duration + /// of the simulation uint64_t num_spikes() const { return num_spikes_; } const std::vector<connection_type>& connections() const { diff --git a/src/model.hpp b/src/model.hpp index 29d59db4d110e798ca403eb8dbd07dc2653d926e..650231f3697198a3f28bd07f6412f3b37722de8f 100644 --- a/src/model.hpp +++ b/src/model.hpp @@ -95,9 +95,7 @@ public: // these buffers will store the new spikes generated in update_cells. 
current_spikes().clear(); - // TODO this needs a threading wrapper - tbb::task_group g; - + // task that updates cell state in parallel. auto update_cells = [&] () { threading::parallel_for::apply( 0u, cell_groups_.size(), @@ -117,6 +115,10 @@ public: }); }; + // task that performs spike exchange with the spikes generated in + // the previous integration period, generating the postsynaptic + // events that must be delivered at the start of the next + // integration period at the latest. auto exchange = [&] () { PE("stepping", "exchange"); auto local_spikes = previous_spikes().gather(); @@ -124,6 +126,9 @@ public: PL(2); }; + // run the tasks, overlapping if the threading model and number of + // available threads permit it. + threading::task_group g; g.run(exchange); g.run(update_cells); g.wait(); @@ -133,12 +138,12 @@ public: return t_; } - // TODO : these two calls are only thread safe if called outside the main time - // stepping loop. + // only thread safe if called outside the run() method void add_artificial_spike(cell_member_type source) { add_artificial_spike(source, t_); } + // only thread safe if called outside the run() method void add_artificial_spike(cell_member_type source, time_type tspike) { current_spikes().get().push_back({source, tspike}); } @@ -171,8 +176,25 @@ private: using local_spike_store_type = thread_private_spike_store<time_type>; util::double_buffer< local_spike_store_type > local_spikes_; + // Convenience functions that map the spike buffers and event queues onto + // the appropriate integration interval. + // + // To overlap communication and computation, integration intervals of + // size Delta/2 are used, where Delta is the minimum delay in the global + // system. 
+ // From the frame of reference of the current integration period we + // define three intervals: previous, current and future + // Then we define the following: + // current_spikes : spikes generated in the current interval + // previous_spikes: spikes generated in the preceding interval + // current_events : events to be delivered at the start of + // the current interval + // future_events : events to be delivered at the start of + // the next interval + local_spike_store_type& current_spikes() { return local_spikes_.get(); } local_spike_store_type& previous_spikes() { return local_spikes_.other(); } + std::vector<event_queue_type>& current_events() { return event_queues_.get(); } std::vector<event_queue_type>& future_events() { return event_queues_.other(); } }; diff --git a/src/threading/serial.hpp b/src/threading/serial.hpp index 8d9c4d64fd84805ed0909ad704ffc192bd0da748..c18d4800e3c457e140f0fe6d2ca671aa6ccac241 100644 --- a/src/threading/serial.hpp +++ b/src/threading/serial.hpp @@ -19,6 +19,8 @@ namespace threading { template <typename T> class enumerable_thread_specific { std::array<T, 1> data; + using iterator_type = typename std::array<T, 1>::iterator; + using const_iterator_type = typename std::array<T, 1>::const_iterator; public : @@ -37,11 +39,14 @@ class enumerable_thread_specific { auto size() -> decltype(data.size()) const { return data.size(); } - auto begin() -> decltype(data.begin()) { return data.begin(); } - auto end() -> decltype(data.end()) { return data.end(); } + iterator_type begin() { return data.begin(); } + iterator_type end() { return data.end(); } - auto cbegin() -> decltype(data.cbegin()) const { return data.cbegin(); } - auto cend() -> decltype(data.cend()) const { return data.cend(); } + const_iterator_type begin() const { return data.begin(); } + const_iterator_type end() const { return data.end(); } + + const_iterator_type cbegin() const { return data.cbegin(); } + const_iterator_type cend() const { return data.cend(); } }; @@ 
-83,6 +88,34 @@ struct timer { constexpr bool multithreaded() { return false; } +/// Proxy for tbb task group. +/// The tbb version launches tasks asynchronously, returning control to the +/// caller. The serial version implemented here simply runs the task, before +/// returning control, effectively serializing all asynchronous calls. +class task_group { +public: + task_group() = default; + + template<typename Func> + void run(const Func& f) { + f(); + } + + template<typename Func> + void run_and_wait(const Func& f) { + f(); + } + + void wait() + {} + + bool is_canceling() { + return false; + } + + void cancel() + {} +}; } // threading } // mc diff --git a/src/threading/tbb.hpp b/src/threading/tbb.hpp index 8e0086741cac1d51e02e9fa148e9eaa3f3f837a3..853ceefde731cb1cd4db323541132c935011c73a 100644 --- a/src/threading/tbb.hpp +++ b/src/threading/tbb.hpp @@ -49,6 +49,8 @@ constexpr bool multithreaded() { return true; } template <typename T> using parallel_vector = tbb::concurrent_vector<T>; +using task_group = tbb::task_group; + } // threading } // mc } // nest diff --git a/tests/unit/test_spike_store.cpp b/tests/unit/test_spike_store.cpp index 0cddfda3e71f81417716a114221fecd7305ee94c..dd075f878bfccc7a7a07887e6736fcb2768e0237 100644 --- a/tests/unit/test_spike_store.cpp +++ b/tests/unit/test_spike_store.cpp @@ -61,14 +61,6 @@ TEST(spike_store, clear) EXPECT_EQ(store.get().size(), 0u); } -template <typename I, typename T> -bool test_spike_equality( - nest::mc::spike<I,T> lhs, - nest::mc::spike<I,T> rhs) -{ - return lhs.time==rhs.time; -} - TEST(spike_store, gather) { using store_type = nest::mc::thread_private_spike_store<float>;