From 4b07e613e5df4c771ab80f6118a240627810ec7b Mon Sep 17 00:00:00 2001 From: Ben Cumming <bcumming@cscs.ch> Date: Thu, 20 Jun 2019 13:04:33 +0200 Subject: [PATCH] Move implementation of communicator to separate TU (#790) --- arbor/CMakeLists.txt | 1 + arbor/communication/communicator.cpp | 223 +++++++++++++++++++++++++++ arbor/communication/communicator.hpp | 204 ++---------------------- 3 files changed, 233 insertions(+), 195 deletions(-) create mode 100644 arbor/communication/communicator.cpp diff --git a/arbor/CMakeLists.txt b/arbor/CMakeLists.txt index 936aa2d9..923a4c97 100644 --- a/arbor/CMakeLists.txt +++ b/arbor/CMakeLists.txt @@ -6,6 +6,7 @@ set(arbor_sources backends/multicore/mechanism.cpp backends/multicore/shared_state.cpp backends/multicore/stimulus.cpp + communication/communicator.cpp communication/dry_run_context.cpp benchmark_cell_group.cpp builtin_mechanisms.cpp diff --git a/arbor/communication/communicator.cpp b/arbor/communication/communicator.cpp new file mode 100644 index 00000000..50acd362 --- /dev/null +++ b/arbor/communication/communicator.cpp @@ -0,0 +1,223 @@ +#include <utility> +#include <vector> + +#include <arbor/assert.hpp> +#include <arbor/common_types.hpp> +#include <arbor/domain_decomposition.hpp> +#include <arbor/recipe.hpp> +#include <arbor/spike.hpp> + +#include "algorithms.hpp" +#include "communication/gathered_vector.hpp" +#include "connection.hpp" +#include "distributed_context.hpp" +#include "execution_context.hpp" +#include "profile/profiler_macro.hpp" +#include "threading/threading.hpp" +#include "util/partition.hpp" +#include "util/rangeutil.hpp" +#include "util/span.hpp" + +#include "communication/communicator.hpp" + +namespace arb { + +communicator::communicator(const recipe& rec, + const domain_decomposition& dom_dec, + execution_context& ctx) +{ + distributed_ = ctx.distributed; + thread_pool_ = ctx.thread_pool; + + num_domains_ = distributed_->size(); + num_local_groups_ = dom_dec.groups.size(); + num_local_cells_ = dom_dec.num_local_cells; + + // For caching information about each cell + struct gid_info { + using connection_list = decltype(std::declval<recipe>().connections_on(0)); + cell_gid_type gid; // global identifier of cell + cell_size_type index_on_domain; // index of cell in this domain + connection_list conns; // list of connections terminating at this cell + gid_info() = default; // so we can in a std::vector + gid_info(cell_gid_type g, cell_size_type di, connection_list c): + gid(g), index_on_domain(di), conns(std::move(c)) {} + }; + + // Make a list of local gid with their group index and connections + // -> gid_infos + // Count the number of local connections (i.e. connections terminating on this domain) + // -> n_cons: scalar + // Calculate and store domain id of the presynaptic cell on each local connection + // -> src_domains: array with one entry for every local connection + // Also the count of presynaptic sources from each domain + // -> src_counts: array with one entry for each domain + + // Record all the gid in a flat vector. + // These are used to map from local index to gid in the parallel loop + // that populates gid_infos. + std::vector<cell_gid_type> gids; + gids.reserve(num_local_cells_); + for (auto g: dom_dec.groups) { + util::append(gids, g.gids); + } + // Build the connection information for local cells in parallel. + std::vector<gid_info> gid_infos; + gid_infos.resize(num_local_cells_); + threading::parallel_for::apply(0, gids.size(), thread_pool_.get(), + [&](cell_size_type i) { + auto gid = gids[i]; + gid_infos[i] = gid_info(gid, i, rec.connections_on(gid)); + }); + + cell_local_size_type n_cons = + util::sum_by(gid_infos, [](const gid_info& g){ return g.conns.size(); }); + std::vector<unsigned> src_domains; + src_domains.reserve(n_cons); + std::vector<cell_size_type> src_counts(num_domains_); + for (const auto& g: gid_infos) { + for (auto con: g.conns) { + const auto src = dom_dec.gid_domain(con.source.gid); + src_domains.push_back(src); + src_counts[src]++; + } + } + + // Construct the connections. + // The loop above gave the information required to construct in place + // the connections as partitioned by the domain of their source gid. + connections_.resize(n_cons); + connection_part_ = algorithms::make_index(src_counts); + auto offsets = connection_part_; + std::size_t pos = 0; + for (const auto& cell: gid_infos) { + for (auto c: cell.conns) { + const auto i = offsets[src_domains[pos]]++; + connections_[i] = {c.source, c.dest, c.weight, c.delay, cell.index_on_domain}; + ++pos; + } + } + + // Build cell partition by group for passing events to cell groups + index_part_ = util::make_partition(index_divisions_, + util::transform_view( + dom_dec.groups, + [](const group_description& g){return g.gids.size();})); + + // Sort the connections for each domain. + // This is num_domains_ independent sorts, so it can be parallelized trivially. + const auto& cp = connection_part_; + threading::parallel_for::apply(0, num_domains_, thread_pool_.get(), + [&](cell_size_type i) { + util::sort(util::subrange_view(connections_, cp[i], cp[i+1])); + }); +} + +std::pair<cell_size_type, cell_size_type> communicator::group_queue_range(cell_size_type i) { + arb_assert(i<num_local_groups_); + return index_part_[i]; +} + +time_type communicator::min_delay() { + auto local_min = std::numeric_limits<time_type>::max(); + for (auto& con : connections_) { + local_min = std::min(local_min, con.delay()); + } + + return distributed_->min(local_min); +} + +gathered_vector<spike> communicator::exchange(std::vector<spike> local_spikes) { + PE(communication_exchange_sort); + // sort the spikes in ascending order of source gid + util::sort_by(local_spikes, [](spike s){return s.source;}); + PL(); + + PE(communication_exchange_gather); + // global all-to-all to gather a local copy of the global spike list on each node. + auto global_spikes = distributed_->gather_spikes(local_spikes); + num_spikes_ += global_spikes.size(); + PL(); + + return global_spikes; +} + +void communicator::make_event_queues( + const gathered_vector<spike>& global_spikes, + std::vector<pse_vector>& queues) +{ + arb_assert(queues.size()==num_local_cells_); + + using util::subrange_view; + using util::make_span; + using util::make_range; + + const auto& sp = global_spikes.partition(); + const auto& cp = connection_part_; + for (auto dom: make_span(num_domains_)) { + auto cons = subrange_view(connections_, cp[dom], cp[dom+1]); + auto spks = subrange_view(global_spikes.values(), sp[dom], sp[dom+1]); + + struct spike_pred { + bool operator()(const spike& spk, const cell_member_type& src) + {return spk.source<src;} + bool operator()(const cell_member_type& src, const spike& spk) + {return src<spk.source;} + }; + + // We have a choice of whether to walk spikes or connections: + // i.e., we can iterate over the spikes, and for each spike search + // the for connections that have the same source; or alternatively + // for each connection, we can search the list of spikes for spikes + // with the same source. + // + // We iterate over whichever set is the smallest, which has + // complexity of order max(S log(C), C log(S)), where S is the + // number of spikes, and C is the number of connections. + if (cons.size()<spks.size()) { + auto sp = spks.begin(); + auto cn = cons.begin(); + while (cn!=cons.end() && sp!=spks.end()) { + auto sources = std::equal_range(sp, spks.end(), cn->source(), spike_pred()); + for (auto s: make_range(sources)) { + queues[cn->index_on_domain()].push_back(cn->make_event(s)); + } + + sp = sources.first; + ++cn; + } + } + else { + auto cn = cons.begin(); + auto sp = spks.begin(); + while (cn!=cons.end() && sp!=spks.end()) { + auto targets = std::equal_range(cn, cons.end(), sp->source); + for (auto c: make_range(targets)) { + queues[c.index_on_domain()].push_back(c.make_event(*sp)); + } + + cn = targets.first; + ++sp; + } + } + } +} + +std::uint64_t communicator::num_spikes() const { + return num_spikes_; +} + +cell_size_type communicator::num_local_cells() const { + return num_local_cells_; +} + +const std::vector<connection>& communicator::connections() const { + return connections_; +} + +void communicator::reset() { + num_spikes_ = 0; +} + +} // namespace arb + diff --git a/arbor/communication/communicator.hpp b/arbor/communication/communicator.hpp index 38613f4e..2cc98a78 100644 --- a/arbor/communication/communicator.hpp +++ b/arbor/communication/communicator.hpp @@ -1,29 +1,16 @@ #pragma once -#include <algorithm> -#include <functional> -#include <iostream> -#include <random> -#include <utility> #include <vector> -#include <arbor/assert.hpp> #include <arbor/common_types.hpp> #include <arbor/domain_decomposition.hpp> #include <arbor/recipe.hpp> #include <arbor/spike.hpp> -#include "algorithms.hpp" #include "communication/gathered_vector.hpp" #include "connection.hpp" -#include "distributed_context.hpp" #include "execution_context.hpp" -#include "profile/profiler_macro.hpp" -#include "threading/threading.hpp" -#include "util/double_buffer.hpp" #include "util/partition.hpp" -#include "util/rangeutil.hpp" -#include "util/span.hpp" namespace arb { @@ -44,129 +31,19 @@ public: explicit communicator(const recipe& rec, const domain_decomposition& dom_dec, - execution_context& ctx) - { - distributed_ = ctx.distributed; - thread_pool_ = ctx.thread_pool; - - num_domains_ = distributed_->size(); - num_local_groups_ = dom_dec.groups.size(); - num_local_cells_ = dom_dec.num_local_cells; - - // For caching information about each cell - struct gid_info { - using connection_list = decltype(std::declval<recipe>().connections_on(0)); - cell_gid_type gid; // global identifier of cell - cell_size_type index_on_domain; // index of cell in this domain - connection_list conns; // list of connections terminating at this cell - gid_info() = default; // so we can in a std::vector - gid_info(cell_gid_type g, cell_size_type di, connection_list c): - gid(g), index_on_domain(di), conns(std::move(c)) {} - }; - - // Make a list of local gid with their group index and connections - // -> gid_infos - // Count the number of local connections (i.e. connections terminating on this domain) - // -> n_cons: scalar - // Calculate and store domain id of the presynaptic cell on each local connection - // -> src_domains: array with one entry for every local connection - // Also the count of presynaptic sources from each domain - // -> src_counts: array with one entry for each domain - - // Record all the gid in a flat vector. - // These are used to map from local index to gid in the parallel loop - // that populates gid_infos. - std::vector<cell_gid_type> gids; - gids.reserve(num_local_cells_); - for (auto g: dom_dec.groups) { - util::append(gids, g.gids); - } - // Build the connection information for local cells in parallel. - std::vector<gid_info> gid_infos; - gid_infos.resize(num_local_cells_); - threading::parallel_for::apply(0, gids.size(), thread_pool_.get(), - [&](cell_size_type i) { - auto gid = gids[i]; - gid_infos[i] = gid_info(gid, i, rec.connections_on(gid)); - }); - - cell_local_size_type n_cons = - util::sum_by(gid_infos, [](const gid_info& g){ return g.conns.size(); }); - std::vector<unsigned> src_domains; - src_domains.reserve(n_cons); - std::vector<cell_size_type> src_counts(num_domains_); - for (const auto& g: gid_infos) { - for (auto con: g.conns) { - const auto src = dom_dec.gid_domain(con.source.gid); - src_domains.push_back(src); - src_counts[src]++; - } - } - - // Construct the connections. - // The loop above gave the information required to construct in place - // the connections as partitioned by the domain of their source gid. - connections_.resize(n_cons); - connection_part_ = algorithms::make_index(src_counts); - auto offsets = connection_part_; - std::size_t pos = 0; - for (const auto& cell: gid_infos) { - for (auto c: cell.conns) { - const auto i = offsets[src_domains[pos]]++; - connections_[i] = {c.source, c.dest, c.weight, c.delay, cell.index_on_domain}; - ++pos; - } - } - - // Build cell partition by group for passing events to cell groups - index_part_ = util::make_partition(index_divisions_, - util::transform_view( - dom_dec.groups, - [](const group_description& g){return g.gids.size();})); - - // Sort the connections for each domain. - // This is num_domains_ independent sorts, so it can be parallelized trivially. - const auto& cp = connection_part_; - threading::parallel_for::apply(0, num_domains_, thread_pool_.get(), - [&](cell_size_type i) { - util::sort(util::subrange_view(connections_, cp[i], cp[i+1])); - }); - } + execution_context& ctx); /// The range of event queues that belong to cells in group i. - std::pair<cell_size_type, cell_size_type> group_queue_range(cell_size_type i) { - arb_assert(i<num_local_groups_); - return index_part_[i]; - } + std::pair<cell_size_type, cell_size_type> group_queue_range(cell_size_type i); /// The minimum delay of all connections in the global network. - time_type min_delay() { - auto local_min = std::numeric_limits<time_type>::max(); - for (auto& con : connections_) { - local_min = std::min(local_min, con.delay()); - } - - return distributed_->min(local_min); - } + time_type min_delay(); /// Perform exchange of spikes. /// /// Takes as input the list of local_spikes that were generated on the calling domain. /// Returns the full global set of vectors, along with meta data about their partition - gathered_vector<spike> exchange(std::vector<spike> local_spikes) { - PE(communication_exchange_sort); - // sort the spikes in ascending order of source gid - util::sort_by(local_spikes, [](spike s){return s.source;}); - PL(); - - PE(communication_exchange_gather); - // global all-to-all to gather a local copy of the global spike list on each node. - auto global_spikes = distributed_->gather_spikes(local_spikes); - num_spikes_ += global_spikes.size(); - PL(); - - return global_spikes; - } + gathered_vector<spike> exchange(std::vector<spike> local_spikes); /// Check each global spike in turn to see it generates local events. /// If so, make the events and insert them into the appropriate event list. @@ -178,79 +55,16 @@ public: /// in the list. void make_event_queues( const gathered_vector<spike>& global_spikes, - std::vector<pse_vector>& queues) - { - arb_assert(queues.size()==num_local_cells_); - - using util::subrange_view; - using util::make_span; - using util::make_range; - - const auto& sp = global_spikes.partition(); - const auto& cp = connection_part_; - for (auto dom: make_span(num_domains_)) { - auto cons = subrange_view(connections_, cp[dom], cp[dom+1]); - auto spks = subrange_view(global_spikes.values(), sp[dom], sp[dom+1]); - - struct spike_pred { - bool operator()(const spike& spk, const cell_member_type& src) - {return spk.source<src;} - bool operator()(const cell_member_type& src, const spike& spk) - {return src<spk.source;} - }; - - // We have a choice of whether to walk spikes or connections: - // i.e., we can iterate over the spikes, and for each spike search - // the for connections that have the same source; or alternatively - // for each connection, we can search the list of spikes for spikes - // with the same source. - // - // We iterate over whichever set is the smallest, which has - // complexity of order max(S log(C), C log(S)), where S is the - // number of spikes, and C is the number of connections. - if (cons.size()<spks.size()) { - auto sp = spks.begin(); - auto cn = cons.begin(); - while (cn!=cons.end() && sp!=spks.end()) { - auto sources = std::equal_range(sp, spks.end(), cn->source(), spike_pred()); - for (auto s: make_range(sources)) { - queues[cn->index_on_domain()].push_back(cn->make_event(s)); - } - - sp = sources.first; - ++cn; - } - } - else { - auto cn = cons.begin(); - auto sp = spks.begin(); - while (cn!=cons.end() && sp!=spks.end()) { - auto targets = std::equal_range(cn, cons.end(), sp->source); - for (auto c: make_range(targets)) { - queues[c.index_on_domain()].push_back(c.make_event(*sp)); - } - - cn = targets.first; - ++sp; - } - } - } - } + std::vector<pse_vector>& queues); /// Returns the total number of global spikes over the duration of the simulation - std::uint64_t num_spikes() const { return num_spikes_; } + std::uint64_t num_spikes() const; - cell_size_type num_local_cells() const { - return num_local_cells_; - } + cell_size_type num_local_cells() const; - const std::vector<connection>& connections() const { - return connections_; - } + const std::vector<connection>& connections() const; - void reset() { - num_spikes_ = 0; - } + void reset(); private: cell_size_type num_local_cells_; -- GitLab