diff --git a/src/communication/communicator.hpp b/src/communication/communicator.hpp index 50ce644c2bc62904625fd74c68fe7a98de059395..5da7f1898028bdc479f4cc430cae89a64ae7fccb 100644 --- a/src/communication/communicator.hpp +++ b/src/communication/communicator.hpp @@ -6,13 +6,13 @@ #include <random> #include <functional> -#include <spike.hpp> -#include <util/double_buffer.hpp> #include <algorithms.hpp> #include <connection.hpp> +#include <communication/gathered_vector.hpp> #include <event_queue.hpp> #include <spike.hpp> #include <util/debug.hpp> +#include <util/double_buffer.hpp> namespace nest { namespace mc { @@ -87,29 +87,26 @@ public: /// Perform exchange of spikes. /// /// Takes as input the list of local_spikes that were generated on the calling domain. - /// - /// Returns a vector of event queues, with one queue for each local cell group. The - /// events in each queue are all events that must be delivered to targets in that cell - /// group as a result of the global spike exchange. - std::vector<event_queue> exchange(const std::vector<spike_type>& local_spikes, - std::function<void(const std::vector<spike_type>&)> global_export_callback) - { + /// Returns the full global set of vectors, along with meta data about their partition + gathered_vector<spike_type> exchange(const std::vector<spike_type>& local_spikes) { // global all-to-all to gather a local copy of the global spike list on each node. auto global_spikes = communication_policy_.gather_spikes( local_spikes ); num_spikes_ += global_spikes.size(); + return global_spikes; + } - global_export_callback(global_spikes); - - // check each global spike in turn to see it generates local events. - // if so, make the events and insert them into the appropriate event list. + /// Check each global spike in turn to see it generates local events. + /// If so, make the events and insert them into the appropriate event list. + /// Return a vector that contains the event queues for each local cell group. + /// + /// Returns a vector of event queues, with one queue for each local cell group. The + /// events in each queue are all events that must be delivered to targets in that cell + /// group as a result of the global spike exchange. + std::vector<event_queue> make_event_queues(const gathered_vector<spike_type>& global_spikes) { auto queues = std::vector<event_queue>(num_groups_local()); - - for (auto spike : global_spikes) { + for (auto spike : global_spikes.values()) { // search for targets - auto targets = - std::equal_range( - connections_.begin(), connections_.end(), spike.source - ); + auto targets = std::equal_range(connections_.begin(), connections_.end(), spike.source); // generate an event for each target for (auto it=targets.first; it!=targets.second; ++it) { @@ -121,8 +118,7 @@ public: return queues; } - /// Returns the total number of global spikes over the duration - /// of the simulation + /// Returns the total number of global spikes over the duration of the simulation uint64_t num_spikes() const { return num_spikes_; } const std::vector<connection_type>& connections() const { diff --git a/src/communication/gathered_vector.hpp b/src/communication/gathered_vector.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9c83f74430a5564096604f3daf2b7aeda90c4b86 --- /dev/null +++ b/src/communication/gathered_vector.hpp @@ -0,0 +1,52 @@ +#pragma once + +#include <cstdint> +#include <numeric> +#include <vector> + +#include <algorithms.hpp> + +namespace nest { +namespace mc { + +template <typename T> +class gathered_vector { +public: + using value_type = T; + using count_type = unsigned; + + gathered_vector(std::vector<value_type>&& v, std::vector<count_type>&& p) : + values_(std::move(v)), + partition_(std::move(p)) + { + EXPECTS(std::is_sorted(partition_.begin(), partition_.end())); + EXPECTS(std::size_t(partition_.back()) == v.size()); + } + + /// the partition of distribution + const std::vector<count_type>& partition() const { + return partition_; + } + + /// the number of entries in the gathered vector in partiion i + count_type count(std::size_t i) const { + return partition_[i+1] - partition[i]; + } + + /// the values in the gathered vector + const std::vector<value_type>& values() const { + return values_; + } + + /// the size of the gathered vector + std::size_t size() const { + return values_.size(); + } + +private: + std::vector<value_type> values_; + std::vector<count_type> partition_; +}; + +} // namespace mc +} // namespace nest diff --git a/src/communication/mpi.hpp b/src/communication/mpi.hpp index 5be71b6381bf8311bcab85d5396d64ba2c3eb26b..5b43f1a86df763426184d9d998f574cbb2a091da 100644 --- a/src/communication/mpi.hpp +++ b/src/communication/mpi.hpp @@ -10,6 +10,7 @@ #include <mpi.h> #include <algorithms.hpp> +#include <communication/gathered_vector.hpp> namespace nest { namespace mc { @@ -121,6 +122,38 @@ namespace mpi { return buffer; } + /// Gather all of a distributed vector + /// Retains the meta data (i.e. vector partition) + template <typename T> + gathered_vector<T> gather_all_meta(const std::vector<T>& values) { + using gathered_type = gathered_vector<T>; + using count_type = typename gathered_vector<T>::count_type; + using traits = mpi_traits<T>; + + // We have to use int for the count and displs vectors instead + // of count_type because these are used as arguments to MPI_Allgatherv + // which expects int arguments. + auto counts = gather_all(int(values.size())); + for (auto& c : counts) { + c *= traits::count(); + } + auto displs = algorithms::make_index(counts); + + std::vector<T> buffer(displs.back()/traits::count()); + + MPI_Allgatherv( + // send buffer + values.data(), counts[rank()], traits::mpi_type(), + // receive buffer + buffer.data(), counts.data(), displs.data(), traits::mpi_type(), + MPI_COMM_WORLD + ); + + std::vector<count_type> part(displs.size()); + std::copy(displs.begin(), displs.end(), part.begin()); + return gathered_type(std::move(buffer), std::move(part)); + } + template <typename T> T reduce(T value, MPI_Op op, int root) { using traits = mpi_traits<T>; diff --git a/src/communication/mpi_global_policy.hpp b/src/communication/mpi_global_policy.hpp index 0e9ec33370695990e2e2fa0d8d03d4a1e8b26c26..8bb597dd0bba87874a6098983a958a91e7f71a05 100644 --- a/src/communication/mpi_global_policy.hpp +++ b/src/communication/mpi_global_policy.hpp @@ -10,6 +10,7 @@ #include <algorithms.hpp> #include <common_types.hpp> +#include <communication/gathered_vector.hpp> #include <communication/mpi.hpp> #include <spike.hpp> @@ -18,10 +19,10 @@ namespace mc { namespace communication { struct mpi_global_policy { - template <typename I, typename T> - static std::vector<spike<I, T>> - gather_spikes(const std::vector<spike<I,T>>& local_spikes) { - return mpi::gather_all(local_spikes); + template <typename Spike> + static gathered_vector<Spike> + gather_spikes(const std::vector<Spike>& local_spikes) { + return mpi::gather_all_meta(local_spikes); } static int id() { return mpi::rank(); } diff --git a/src/communication/serial_global_policy.hpp b/src/communication/serial_global_policy.hpp index 9af5eae3fd0dd84be71cb6346018c607fa9ab3df..b048b6cfbf249596064e06ee6300d2e967c71b1e 100644 --- a/src/communication/serial_global_policy.hpp +++ b/src/communication/serial_global_policy.hpp @@ -4,6 +4,7 @@ #include <type_traits> #include <vector> +#include <communication/gathered_vector.hpp> #include <spike.hpp> namespace nest { @@ -11,10 +12,10 @@ namespace mc { namespace communication { struct serial_global_policy { - template <typename I, typename T> - static const std::vector<spike<I, T>>& - gather_spikes(const std::vector<spike<I, T>>& local_spikes) { - return local_spikes; + template <typename Spike> + static gathered_vector<Spike> + gather_spikes(const std::vector<Spike>& local_spikes) { + return gathered_vector<Spike>(std::vector<Spike>(local_spikes), {0u, 1u}); } static int id() { diff --git a/src/model.hpp b/src/model.hpp index 4fa778e06cf366826774fe4c980f1270b52c3cb1..ae7b4a7a199de3e4dee4f305ccabe498b160b96b 100644 --- a/src/model.hpp +++ b/src/model.hpp @@ -130,18 +130,23 @@ public: // the previous integration period, generating the postsynaptic // events that must be delivered at the start of the next // integration period at the latest. - - //TODO: - //An improvement might be : - //the exchange method simply exchanges spikes, and does not generate the event queues.It returns a struct that has both 1) the global spike list 2) an integer vector that describes the distribution of spikes across the ranks - //another method called something like build_queues that takes this spike info and returns the local spikes - // and the callbacks can then be called on the spike information directly in the model. auto exchange = [&] () { - PE("stepping", "exchange"); + PE("stepping", "communciation"); + + PE("exchange"); auto local_spikes = previous_spikes().gather(); + auto global_spikes = communicator_.exchange(local_spikes); + PL(); + + PE("spike output"); local_export_callback_(local_spikes); - future_events() = - communicator_.exchange(local_spikes, global_export_callback_); + global_export_callback_(global_spikes.values()); + PL(); + + PE("events"); + future_events() = communicator_.make_event_queues(global_spikes); + PL(); + PL(2); };