Skip to content
Snippets Groups Projects
Commit 507bca67 authored by Benjamin Cumming's avatar Benjamin Cumming
Browse files

update serial threading for asynchronous compute

* fixed const interface to the serial proxy for thread private storage
* added a serial proxy for tbb task group and moved the tbb task group
  behind the threading interface.
* added comments to the communicator to better document each method
parent a5872409
No related branches found
No related tags found
No related merge requests found
...@@ -27,20 +27,6 @@ namespace communication { ...@@ -27,20 +27,6 @@ namespace communication {
// Once all connections have been specified, the construct() method can be used // Once all connections have been specified, the construct() method can be used
// to build the data structures required for efficient spike communication and // to build the data structures required for efficient spike communication and
// event generation. // event generation.
//
// To overlap communication and computation, i.e. to perform spike
// exchange at the same time as cell state update, thread safe access to the
// spike and event lists must be provided. We use double buffering, whereby
// for each of the spikes and events one buffer is exposed publicly, while
// the other is used internally by the communicator
// - the spike lists are not directly exposed via the communicator
// interface. Instead they are updated when the add_spike() methods
// are called.
// - the event queues for each cell group are exposed via the queue()
// method.
// For each double buffer, the current buffer (accessed via buffer.get())
// is exposed to the user, and the other buffer is used inside exchange().
template <typename Time, typename CommunicationPolicy> template <typename Time, typename CommunicationPolicy>
class communicator { class communicator {
public: public:
...@@ -56,6 +42,7 @@ public: ...@@ -56,6 +42,7 @@ public:
communicator() = default; communicator() = default;
// TODO
// for now, still assuming one-to-one association cells <-> groups, // for now, still assuming one-to-one association cells <-> groups,
// so that 'group' gids as represented by their first cell gid are // so that 'group' gids as represented by their first cell gid are
// contiguous. // contiguous.
...@@ -73,17 +60,20 @@ public: ...@@ -73,17 +60,20 @@ public:
connections_.push_back(con); connections_.push_back(con);
} }
/// returns true if the cell with gid is on the domain of the caller
bool is_local_cell(id_type gid) const { bool is_local_cell(id_type gid) const {
return gid>=cell_gid_from_ && gid<cell_gid_to_; return gid>=cell_gid_from_ && gid<cell_gid_to_;
} }
// builds the optimized data structure /// builds the optimized data structure
/// must be called after all connections have been added
void construct() { void construct() {
if (!std::is_sorted(connections_.begin(), connections_.end())) { if (!std::is_sorted(connections_.begin(), connections_.end())) {
std::sort(connections_.begin(), connections_.end()); std::sort(connections_.begin(), connections_.end());
} }
} }
/// the minimum delay of all connections in the global network.
time_type min_delay() { time_type min_delay() {
auto local_min = std::numeric_limits<time_type>::max(); auto local_min = std::numeric_limits<time_type>::max();
for (auto& con : connections_) { for (auto& con : connections_) {
...@@ -93,6 +83,13 @@ public: ...@@ -93,6 +83,13 @@ public:
return communication_policy_.min(local_min); return communication_policy_.min(local_min);
} }
/// Perform exchange of spikes.
///
/// Takes as input the list of local_spikes that were generated on the calling domain.
///
/// Returns a vector of event queues, with one queue for each local cell group. The
/// events in each queue are all events that must be delivered to targets in that cell
/// group as a result of the global spike exchange.
std::vector<event_queue> exchange(const std::vector<spike_type>& local_spikes) { std::vector<event_queue> exchange(const std::vector<spike_type>& local_spikes) {
// global all-to-all to gather a local copy of the global spike list on each node. // global all-to-all to gather a local copy of the global spike list on each node.
auto global_spikes = communication_policy_.gather_spikes( local_spikes ); auto global_spikes = communication_policy_.gather_spikes( local_spikes );
...@@ -118,6 +115,8 @@ public: ...@@ -118,6 +115,8 @@ public:
return queues; return queues;
} }
/// Returns the total number of global spikes over the duration
/// of the simulation
uint64_t num_spikes() const { return num_spikes_; } uint64_t num_spikes() const { return num_spikes_; }
const std::vector<connection_type>& connections() const { const std::vector<connection_type>& connections() const {
......
...@@ -95,9 +95,7 @@ public: ...@@ -95,9 +95,7 @@ public:
// these buffers will store the new spikes generated in update_cells. // these buffers will store the new spikes generated in update_cells.
current_spikes().clear(); current_spikes().clear();
// TODO this needs a threading wrapper // task that updates cell state in parallel.
tbb::task_group g;
auto update_cells = [&] () { auto update_cells = [&] () {
threading::parallel_for::apply( threading::parallel_for::apply(
0u, cell_groups_.size(), 0u, cell_groups_.size(),
...@@ -117,6 +115,10 @@ public: ...@@ -117,6 +115,10 @@ public:
}); });
}; };
// task that performs spike exchange with the spikes generated in
// the previous integration period, generating the postsynaptic
// events that must be delivered at the start of the next
// integration period at the latest.
auto exchange = [&] () { auto exchange = [&] () {
PE("stepping", "exchange"); PE("stepping", "exchange");
auto local_spikes = previous_spikes().gather(); auto local_spikes = previous_spikes().gather();
...@@ -124,6 +126,9 @@ public: ...@@ -124,6 +126,9 @@ public:
PL(2); PL(2);
}; };
// run the tasks, overlapping if the threading model and number of
// available threads permits it.
threading::task_group g;
g.run(exchange); g.run(exchange);
g.run(update_cells); g.run(update_cells);
g.wait(); g.wait();
...@@ -133,12 +138,12 @@ public: ...@@ -133,12 +138,12 @@ public:
return t_; return t_;
} }
// TODO : these two calls are only thread safe if called outside the main time // only thread safe if called outside the run() method
// stepping loop.
void add_artificial_spike(cell_member_type source) { void add_artificial_spike(cell_member_type source) {
add_artificial_spike(source, t_); add_artificial_spike(source, t_);
} }
// only thread safe if called outside the run() method
void add_artificial_spike(cell_member_type source, time_type tspike) { void add_artificial_spike(cell_member_type source, time_type tspike) {
current_spikes().get().push_back({source, tspike}); current_spikes().get().push_back({source, tspike});
} }
...@@ -171,8 +176,25 @@ private: ...@@ -171,8 +176,25 @@ private:
using local_spike_store_type = thread_private_spike_store<time_type>; using local_spike_store_type = thread_private_spike_store<time_type>;
util::double_buffer< local_spike_store_type > local_spikes_; util::double_buffer< local_spike_store_type > local_spikes_;
// Convenience functions that map the spike buffers and event queues onto
// the appropriate integration interval.
//
// To overlap communication and computation, integration intervals of
// size Delta/2 are used, where Delta is the minimum delay in the global
// system.
// From the frame of reference of the current integration period we
// define three intervals: previous, current and future
// Then we define the following :
// current_spikes : spikes generated in the current interval
// previous_spikes: spikes generated in the preceding interval
// current_events : events to be delivered at the start of
// the current interval
// future_events : events to be delivered at the start of
// the next interval
local_spike_store_type& current_spikes() { return local_spikes_.get(); } local_spike_store_type& current_spikes() { return local_spikes_.get(); }
local_spike_store_type& previous_spikes() { return local_spikes_.other(); } local_spike_store_type& previous_spikes() { return local_spikes_.other(); }
std::vector<event_queue_type>& current_events() { return event_queues_.get(); } std::vector<event_queue_type>& current_events() { return event_queues_.get(); }
std::vector<event_queue_type>& future_events() { return event_queues_.other(); } std::vector<event_queue_type>& future_events() { return event_queues_.other(); }
}; };
......
...@@ -19,6 +19,8 @@ namespace threading { ...@@ -19,6 +19,8 @@ namespace threading {
template <typename T> template <typename T>
class enumerable_thread_specific { class enumerable_thread_specific {
std::array<T, 1> data; std::array<T, 1> data;
using iterator_type = typename std::array<T, 1>::iterator;
using const_iterator_type = typename std::array<T, 1>::const_iterator;
public : public :
...@@ -37,11 +39,14 @@ class enumerable_thread_specific { ...@@ -37,11 +39,14 @@ class enumerable_thread_specific {
auto size() -> decltype(data.size()) const { return data.size(); } auto size() -> decltype(data.size()) const { return data.size(); }
auto begin() -> decltype(data.begin()) { return data.begin(); } iterator_type begin() { return data.begin(); }
auto end() -> decltype(data.end()) { return data.end(); } iterator_type end() { return data.end(); }
auto cbegin() -> decltype(data.cbegin()) const { return data.cbegin(); } const_iterator_type begin() const { return data.begin(); }
auto cend() -> decltype(data.cend()) const { return data.cend(); } const_iterator_type end() const { return data.end(); }
const_iterator_type cbegin() const { return data.cbegin(); }
const_iterator_type cend() const { return data.cend(); }
}; };
...@@ -83,6 +88,34 @@ struct timer { ...@@ -83,6 +88,34 @@ struct timer {
constexpr bool multithreaded() { return false; } constexpr bool multithreaded() { return false; }
/// Serial stand-in for tbb::task_group.
/// Where the tbb implementation schedules tasks asynchronously and returns
/// control to the caller immediately, this version executes each task
/// synchronously at the point of submission, so all "asynchronous" work is
/// effectively serialized.
class task_group {
public:
    task_group() = default;

    /// Execute the task immediately; in serial mode there is no deferred
    /// scheduling, so the task has completed when this call returns.
    template<typename Func>
    void run(const Func& f) {
        f();
    }

    /// Identical to run() in the serial implementation: the task is
    /// executed synchronously before control returns.
    template<typename Func>
    void run_and_wait(const Func& f) {
        f();
    }

    /// No-op: every submitted task has already finished by the time it
    /// could be waited on.
    void wait()
    {}

    /// Cancellation is never in progress in the serial implementation.
    bool is_canceling() {
        return false;
    }

    /// No-op: there are never pending tasks to cancel.
    void cancel()
    {}
};
} // threading } // threading
} // mc } // mc
......
...@@ -49,6 +49,8 @@ constexpr bool multithreaded() { return true; } ...@@ -49,6 +49,8 @@ constexpr bool multithreaded() { return true; }
template <typename T> template <typename T>
using parallel_vector = tbb::concurrent_vector<T>; using parallel_vector = tbb::concurrent_vector<T>;
using task_group = tbb::task_group;
} // threading } // threading
} // mc } // mc
} // nest } // nest
......
...@@ -61,14 +61,6 @@ TEST(spike_store, clear) ...@@ -61,14 +61,6 @@ TEST(spike_store, clear)
EXPECT_EQ(store.get().size(), 0u); EXPECT_EQ(store.get().size(), 0u);
} }
// Test helper: compares two spikes by their time stamp only.
// NOTE(review): despite the name, fields of the spikes other than `time`
// (e.g. the source id) are not compared — confirm callers only need
// time-based equality before relying on this as full equality.
template <typename I, typename T>
bool test_spike_equality(
    nest::mc::spike<I,T> lhs,
    nest::mc::spike<I,T> rhs)
{
    return lhs.time==rhs.time;
}
TEST(spike_store, gather) TEST(spike_store, gather)
{ {
using store_type = nest::mc::thread_private_spike_store<float>; using store_type = nest::mc::thread_private_spike_store<float>;
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment