diff --git a/src/communication/communicator.hpp b/src/communication/communicator.hpp index ba6cd612e3a96a39648d3ac6f8093b5859212db1..7a9ee094b59cb53708cfa291bae6f3f2b9c76ece 100644 --- a/src/communication/communicator.hpp +++ b/src/communication/communicator.hpp @@ -43,7 +43,7 @@ namespace communication { template <typename Time, typename CommunicationPolicy> class communicator { -private: +public: using communication_policy_type = CommunicationPolicy; using id_type = cell_gid_type; using time_type = Time; @@ -58,60 +58,6 @@ private: using event_queue = std::vector<postsynaptic_spike_event<time_type>>; - /// double buffered storage of the thread private spike lists - util::double_buffer<local_spike_store_type> thread_spikes_; - - /// double buffered storage of the cell group event lists - util::double_buffer<std::vector<event_queue>> events_; - - /// access to the spikes buffered from the previous communication - /// interval. Used internally by the communicator for exchange - local_spike_store_type& buffered_spikes() { - return thread_spikes_.other(); - } - - /// access to thread-private list of spikes, used for storing - /// spikes added via the add_spike() interface - std::vector<spike_type>& thread_spikes() { - return thread_spikes_.get().local(); - } - - template <typename Container> - void clear(Container& c) { - for(auto& item : c) { - item.clear(); - } - } - - void clear_spikes() { clear(thread_spikes_.get()); } - void clear_buffered_spikes() { clear(buffered_spikes()); } - - void clear_events() { clear(events_.get()); } - void clear_buffered_events() { clear(events_.other()); } - - std::vector<spike_type> gather_local_spikes() { - std::vector<spike_type> spikes; - for (auto& v : buffered_spikes()) { - spikes.insert(spikes.end(), v.begin(), v.end()); - } - return spikes; - } - - std::size_t cell_group_index(cell_gid_type cell_gid) const { - // this will be more elaborate when there is more than one cell per cell group - EXPECTS(cell_gid>=cell_gid_from_ && cell_gid<cell_gid_to_); - return cell_gid-cell_gid_from_; - } - - std::vector<connection_type> connections_; - - communication_policy_type communication_policy_; - - uint64_t num_spikes_ = 0u; - id_type cell_gid_from_; - id_type cell_gid_to_; - -public: communicator() = default; // for now, still assuming one-to-one association cells <-> groups, @@ -119,12 +65,11 @@ public: // contiguous. communicator(id_type cell_from, id_type cell_to): cell_gid_from_(cell_from), cell_gid_to_(cell_to) - { - auto num_groups_local_ = cell_gid_to_-cell_gid_from_; + {} - // create an event queue for each target group - events_.get().resize(num_groups_local_); - events_.other().resize(num_groups_local_); + cell_local_size_type num_groups_local() const + { + return cell_gid_to_-cell_gid_from_; } void add_connection(connection_type con) { @@ -152,29 +97,18 @@ public: return communication_policy_.min(local_min); } - void add_spike(spike_type s) { - thread_spikes().push_back(s); - } - - void add_spikes(const std::vector<spike_type>& s) { - auto& v = thread_spikes(); - v.insert(v.end(), s.begin(), s.end()); - } - - void exchange() { + std::vector<event_queue> exchange(const local_spike_store_type& spikes) { // global all-to-all to gather a local copy of the global spike list // on each node - auto global_spikes = communication_policy_.gather_spikes(gather_local_spikes()); + auto global_spikes = + communication_policy_.gather_spikes( + merge_spike_store(spikes) + ); num_spikes_ += global_spikes.size(); - clear_buffered_spikes(); - - // clear the event queue buffers, which will store the events generated from the - // global_spikes below - clear_buffered_events(); // Check each global spike in turn to see it generates local events. // If so, make the events and insert them into the appropriate event list - auto& queues = events_.other(); + auto queues = std::vector<event_queue>(num_groups_local()); for (auto spike : global_spikes) { // search for targets auto targets = @@ -188,14 +122,12 @@ public: queues[gidx].push_back(it->make_event(spike)); } } + + return queues; } uint64_t num_spikes() const { return num_spikes_; } - const std::vector<postsynaptic_spike_event<time_type>>& queue(int i) const { - return events_.get()[i]; - } - const std::vector<connection_type>& connections() const { return connections_; } @@ -204,17 +136,28 @@ public: return communication_policy_; } - void swap_buffers() { - thread_spikes_.exchange(); - events_.exchange(); +private: + std::vector<spike_type> merge_spike_store(const local_spike_store_type& buffers) { + std::vector<spike_type> spikes; + for (auto& v : buffers) { + spikes.insert(spikes.end(), v.begin(), v.end()); + } + return spikes; } - void reset() { - clear_buffered_events(); - clear_buffered_spikes(); - clear_events(); - clear_spikes(); + std::size_t cell_group_index(cell_gid_type cell_gid) const { + // this will be more elaborate when there is more than one cell per cell group + EXPECTS(cell_gid>=cell_gid_from_ && cell_gid<cell_gid_to_); + return cell_gid-cell_gid_from_; } + + std::vector<connection_type> connections_; + + communication_policy_type communication_policy_; + + uint64_t num_spikes_ = 0u; + id_type cell_gid_from_; + id_type cell_gid_to_; }; } // namespace communication diff --git a/src/model.hpp b/src/model.hpp index 44d0b1149791c32dac15fa9e241df4e80912701f..019633387bc259ba1079f46fd20f3819abc52df4 100644 --- a/src/model.hpp +++ b/src/model.hpp @@ -76,8 +76,8 @@ public: while (t_<tfinal) { auto tuntil = std::min(t_+min_delay, tfinal); - // this is crude: the flow of spikes and events should be modelled explicitly - communicator_.swap_buffers(); + event_queues_.exchange(); + local_spikes_.exchange(); tbb::task_group g; @@ -90,13 +90,13 @@ public: auto &group = cell_groups_[i]; PE("stepping","events"); - group.enqueue_events(communicator_.queue(i)); + group.enqueue_events(current_events()[i]); PL(); group.advance(tuntil, dt); PE("events"); - communicator_.add_spikes(group.spikes()); + buffer_spikes(group.spikes()); group.clear_spikes(); PL(2); }); @@ -104,7 +104,7 @@ public: auto exchange = [&] () { PE("stepping", "exchange"); - communicator_.exchange(); + future_events() = communicator_.exchange(current_spikes()); PL(2); }; @@ -117,12 +117,14 @@ public: return t_; } + // TODO : these two calls are only thread safe if called outside the main time + // stepping loop. void add_artificial_spike(cell_member_type source) { add_artificial_spike(source, t_); } void add_artificial_spike(cell_member_type source, time_type tspike) { - communicator_.add_spike({source, tspike}); + previous_spikes().local().push_back({source, tspike}); } void attach_sampler(cell_member_type probe_id, sampler_function f, time_type tfrom = 0) { @@ -144,6 +146,23 @@ private: std::vector<cell_group_type> cell_groups_; communicator_type communicator_; std::vector<probe_record> probes_; + using spike_type = typename communicator_type::spike_type; + + using event_queue_type = typename communicator_type::event_queue; + util::double_buffer< std::vector<event_queue_type> > event_queues_; + + using local_spike_store_type = typename communicator_type::local_spike_store_type; + util::double_buffer< local_spike_store_type > local_spikes_; + + local_spike_store_type& current_spikes() { return local_spikes_.get(); } + local_spike_store_type& previous_spikes() { return local_spikes_.other(); } + std::vector<event_queue_type>& current_events() { return event_queues_.get(); } + std::vector<event_queue_type>& future_events() { return event_queues_.other(); } + + void buffer_spikes(const std::vector<spike_type>& s) { + auto& buff = current_spikes().local(); + buff.insert(buff.end(), s.begin(), s.end()); + } }; } // namespace mc