diff --git a/arbor/communication/dry_run_context.cpp b/arbor/communication/dry_run_context.cpp index 9936eaabba65cae48e358e88143eacd667da643a..4bc17f4d1343d7a01d4a190a3868eb1f5040d780 100644 --- a/arbor/communication/dry_run_context.cpp +++ b/arbor/communication/dry_run_context.cpp @@ -70,26 +70,6 @@ struct dry_run_context_impl { return gathered_vector<cell_gid_type>(std::move(gathered_gids), std::move(partition)); } - std::vector<std::vector<cell_gid_type>> - gather_gj_connections(const std::vector<std::vector<cell_gid_type>> & local_connections) const { - auto local_size = local_connections.size(); - std::vector<std::vector<cell_gid_type>> global_connections; - global_connections.reserve(local_size*num_ranks_); - - for (unsigned i = 0; i < num_ranks_; i++) { - util::append(global_connections, local_connections); - } - - for (unsigned i = 0; i < num_ranks_; i++) { - for (unsigned j = i*local_size; j < (i+1)*local_size; j++){ - for (auto& conn_gid: global_connections[j]) { - conn_gid += num_cells_per_tile_*i; - } - } - } - return global_connections; - } - cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const { cell_label_range global_ranges; for (unsigned i = 0; i < num_ranks_; i++) { diff --git a/arbor/communication/mpi_context.cpp b/arbor/communication/mpi_context.cpp index c2fcf64af76ffa0f226ad23e7e4ef27a87916b0a..8109a5a88931575da10d6a839684fb8f480e960a 100644 --- a/arbor/communication/mpi_context.cpp +++ b/arbor/communication/mpi_context.cpp @@ -53,11 +53,6 @@ struct mpi_context_impl { return mpi::gather_all_with_partition(local_gids, comm_); } - std::vector<std::vector<cell_gid_type>> - gather_gj_connections(const std::vector<std::vector<cell_gid_type>>& local_connections) const { - return mpi::gather_all(local_connections, comm_); - } - cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const { cell_label_range res; res.sizes = mpi::gather_all(local_ranges.sizes, comm_); @@ -139,11 +134,6 @@ struct 
remote_context_impl { gathered_vector<cell_gid_type> gather_gids(const std::vector<cell_gid_type>& local_gids) const { return mpi_.gather_gids(local_gids); } - std::vector<std::vector<cell_gid_type>> - gather_gj_connections(const std::vector<std::vector<cell_gid_type>>& local_connections) const { - return mpi_.gather_gj_connections(local_connections); - } - cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const { return mpi_.gather_cell_label_range(local_ranges); } diff --git a/arbor/distributed_context.hpp b/arbor/distributed_context.hpp index c988177d33224d88c33cb7e5bf7ce80e92dba1b2..76e8f4c0a9db1b7bc5a890aa26c408eb765b78d3 100644 --- a/arbor/distributed_context.hpp +++ b/arbor/distributed_context.hpp @@ -72,10 +72,6 @@ public: return impl_->gather_gids(local_gids); } - gj_connection_vector gather_gj_connections(const gj_connection_vector& local_connections) const { - return impl_->gather_gj_connections(local_connections); - } - cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const { return impl_->gather_cell_label_range(local_ranges); } @@ -117,8 +113,6 @@ private: remote_gather_spikes(const spike_vector& local_spikes) const = 0; virtual gathered_vector<cell_gid_type> gather_gids(const gid_vector& local_gids) const = 0; - virtual gj_connection_vector - gather_gj_connections(const gj_connection_vector& local_connections) const = 0; virtual cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const = 0; virtual cell_labels_and_gids @@ -154,10 +148,6 @@ private: gather_gids(const gid_vector& local_gids) const override { return wrapped.gather_gids(local_gids); } - std::vector<std::vector<cell_gid_type>> - gather_gj_connections(const gj_connection_vector& local_connections) const override { - return wrapped.gather_gj_connections(local_connections); - } cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const override { return 
wrapped.gather_cell_label_range(local_ranges); @@ -217,10 +207,6 @@ struct local_context { } void remote_ctrl_send_continue(const epoch&) const {} void remote_ctrl_send_done() const {} - std::vector<std::vector<cell_gid_type>> - gather_gj_connections(const std::vector<std::vector<cell_gid_type>>& local_connections) const { - return local_connections; - } cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const { return local_ranges; diff --git a/arbor/domain_decomposition.cpp b/arbor/domain_decomposition.cpp index aa22082120a42c53a1359d8e1ea11a7c94fef2a7..f9c14a03d9f85c9d382233d3c8bd1db3f0a5e26b 100644 --- a/arbor/domain_decomposition.cpp +++ b/arbor/domain_decomposition.cpp @@ -13,11 +13,9 @@ #include "util/span.hpp" namespace arb { -domain_decomposition::domain_decomposition( - const recipe& rec, - context ctx, - const std::vector<group_description>& groups) -{ +domain_decomposition::domain_decomposition(const recipe& rec, + context ctx, + const std::vector<group_description>& groups) { struct partition_gid_domain { partition_gid_domain(const gathered_vector<cell_gid_type>& divs, unsigned domains) { auto rank_part = util::partition_view(divs.partition()); @@ -27,9 +25,7 @@ domain_decomposition::domain_decomposition( } } } - int operator()(cell_gid_type gid) const { - return gid_map.at(gid); - } + int operator()(cell_gid_type gid) const { return gid_map.at(gid); } std::unordered_map<cell_gid_type, int> gid_map; }; @@ -41,22 +37,14 @@ domain_decomposition::domain_decomposition( std::vector<cell_gid_type> local_gids; for (const auto& g: groups) { - if (g.backend == backend_kind::gpu && !has_gpu) { - throw invalid_backend(domain_id); - } - if (g.backend == backend_kind::gpu && g.kind != cell_kind::cable) { - throw incompatible_backend(domain_id, g.kind); - } + if (g.backend == backend_kind::gpu && !has_gpu) throw invalid_backend(domain_id); + if (g.backend == backend_kind::gpu && g.kind != cell_kind::cable) throw 
incompatible_backend(domain_id, g.kind); std::unordered_set<cell_gid_type> gid_set(g.gids.begin(), g.gids.end()); for (const auto& gid: g.gids) { - if (gid >= num_global_cells) { - throw out_of_bounds(gid, num_global_cells); - } + if (gid >= num_global_cells) throw out_of_bounds(gid, num_global_cells); for (const auto& gj: rec.gap_junctions_on(gid)) { - if (!gid_set.count(gj.peer.gid)) { - throw invalid_gj_cell_group(gid, gj.peer.gid); - } + if (!gid_set.count(gj.peer.gid)) throw invalid_gj_cell_group(gid, gj.peer.gid); } } local_gids.insert(local_gids.end(), g.gids.begin(), g.gids.end()); @@ -64,16 +52,12 @@ domain_decomposition::domain_decomposition( cell_size_type num_local_cells = local_gids.size(); auto global_gids = dist->gather_gids(local_gids); - if (global_gids.size() != num_global_cells) { - throw invalid_sum_local_cells(global_gids.size(), num_global_cells); - } + if (global_gids.size() != num_global_cells) throw invalid_sum_local_cells(global_gids.size(), num_global_cells); auto global_gid_vals = global_gids.values(); util::sort(global_gid_vals); for (unsigned i = 1; i < global_gid_vals.size(); ++i) { - if (global_gid_vals[i] == global_gid_vals[i-1]) { - throw duplicate_gid(global_gid_vals[i]); - } + if (global_gid_vals[i] == global_gid_vals[i-1]) throw duplicate_gid(global_gid_vals[i]); } num_domains_ = num_domains; diff --git a/arbor/execution_context.hpp b/arbor/execution_context.hpp index fb75f60cbd3ad87e63a840a660651eb5d5748212..b14fe01b344e06948287cecbbc02a221b28a2d7f 100644 --- a/arbor/execution_context.hpp +++ b/arbor/execution_context.hpp @@ -34,7 +34,6 @@ struct ARB_ARBOR_API execution_context { template <typename Comm> execution_context(const proc_allocation& resources, Comm comm, Comm remote); - }; } // namespace arb diff --git a/arbor/include/arbor/common_types.hpp b/arbor/include/arbor/common_types.hpp index 5e57f723c26e5d6641b1c8379166175cb63198e9..e2d1eccdc7952aa5ff4b01cca3cee632d850f173 100644 --- a/arbor/include/arbor/common_types.hpp 
+++ b/arbor/include/arbor/common_types.hpp @@ -142,10 +142,12 @@ using sample_index_type = std::int32_t; using sample_size_type = std::uint32_t; // Enumeration for execution back-end targets, as specified in domain decompositions. - +// NOTE(important): Given in order of priority, i.e. we will attempt to schedule gpu before +// MC groups, for reasons of efficiency. Ugly, but as we do not have more +// backends, this is OK for now. enum class backend_kind { + gpu, // Use gpu back-end when supported by cell_group implementation. multicore, // Use multicore back-end for all computation. - gpu // Use gpu back-end when supported by cell_group implementation. }; // Enumeration used to indentify the cell type/kind, used by the model to diff --git a/arbor/partition_load_balance.cpp b/arbor/partition_load_balance.cpp index 8718c743d111eb83506ab8876321e4a804314bbc..550c2144b9b50a3b96aa0eb2d58fad7a1d5d1ab7 100644 --- a/arbor/partition_load_balance.cpp +++ b/arbor/partition_load_balance.cpp @@ -1,18 +1,16 @@ -#include <queue> #include <unordered_map> #include <unordered_set> #include <vector> +#include <algorithm> #include <arbor/domdecexcept.hpp> #include <arbor/domain_decomposition.hpp> #include <arbor/load_balance.hpp> #include <arbor/recipe.hpp> -#include <arbor/symmetric_recipe.hpp> #include <arbor/context.hpp> #include "cell_group_factory.hpp" #include "execution_context.hpp" -#include "gpu_context.hpp" #include "util/maputil.hpp" #include "util/partition.hpp" #include "util/span.hpp" @@ -20,215 +18,200 @@ namespace arb { -ARB_ARBOR_API domain_decomposition partition_load_balance( - const recipe& rec, - context ctx, - const partition_hint_map& hint_map) -{ - using util::make_span; +namespace { +using gj_connection_set = std::unordered_set<cell_gid_type>; +using gj_connection_table = std::unordered_map<cell_gid_type, gj_connection_set>; +using gid_range = std::pair<cell_gid_type, cell_gid_type>; +using super_cell = std::vector<cell_gid_type>; + +// Build global GJ connectivity
table such that +// * table[gid] is the set of all gids connected to gid via a GJ +// * iff A in table[B], then B in table[A] +auto build_global_gj_connection_table(const recipe& rec) { + gj_connection_table res; + // Collect all explicit GJ connections and make them bi-directional + for (cell_gid_type gid = 0; gid < rec.num_cells(); ++gid) { + for (const auto& gj: rec.gap_junctions_on(gid)) { + auto peer = gj.peer.gid; + res[gid].insert(peer); + res[peer].insert(gid); + } + } + return res; +} +// compute range of gids for the local domain, such that the first (= num_cells +// % num_dom) domains get an extra element. +auto make_local_gid_range(context ctx, cell_gid_type num_global_cells) { const auto& dist = ctx->distributed; unsigned num_domains = dist->size(); unsigned domain_id = dist->id(); - const bool gpu_avail = ctx->gpu->has_gpu(); - auto num_global_cells = rec.num_cells(); - - auto dom_size = [&](unsigned dom) -> cell_gid_type { - const cell_gid_type B = num_global_cells/num_domains; - const cell_gid_type R = num_global_cells - num_domains*B; - return B + (dom<R); - }; - - // Global load balance - - std::vector<cell_gid_type> gid_divisions; - auto gid_part = make_partition( - gid_divisions, transform_view(make_span(num_domains), dom_size)); - - // Global gj_connection table - - // Generate a local gj_connection table. - // The table is indexed by the index of the target gid in the gid_part of that domain. - // If gid_part[domain_id] = [a, b); local_gj_connection of gid `x` is at index `x-a`. 
- const auto dom_range = gid_part[domain_id]; - std::vector<std::vector<cell_gid_type>> local_gj_connection_table(dom_range.second-dom_range.first); - for (auto gid: make_span(gid_part[domain_id])) { - for (const auto& c: rec.gap_junctions_on(gid)) { - local_gj_connection_table[gid-dom_range.first].push_back(c.peer.gid); - } + // normal block size + auto block = num_global_cells/num_domains; + // domains that need an extra element + auto extra = num_global_cells - num_domains*block; + // now compute the range + if (domain_id < extra) { + // all previous domains, incl ours, have an extra element + auto beg = domain_id*(block + 1); + auto end = beg + block + 1; + return std::make_pair(beg, end); } - // Sort the gj connections of each local cell. - for (auto& gid_conns: local_gj_connection_table) { - util::sort(gid_conns); + else { + // in this case the first `extra` domains added an extra element and the + // rest has size `block` + auto beg = extra + domain_id*block; + auto end = beg + block; + return std::make_pair(beg, end); } +} - // Gather the global gj_connection table. - // The global gj_connection table after gathering is indexed by gid. - auto global_gj_connection_table = dist->gather_gj_connections(local_gj_connection_table); - - // Make all gj_connections bidirectional. 
- std::vector<std::unordered_set<cell_gid_type>> missing_peers(global_gj_connection_table.size()); - for (auto gid: make_span(global_gj_connection_table.size())) { - const auto& local_conns = global_gj_connection_table[gid]; - for (auto peer: local_conns) { - auto& peer_conns = global_gj_connection_table[peer]; - // If gid is not in the peer connection table insert it into the - // missing_peers set - if (!std::binary_search(peer_conns.begin(), peer_conns.end(), gid)) { - missing_peers[peer].insert(gid); - } - } - } - // Append the missing peers into the global_gj_connections table - for (unsigned i = 0; i < global_gj_connection_table.size(); ++i) { - std::move(missing_peers[i].begin(), missing_peers[i].end(), std::back_inserter(global_gj_connection_table[i])); - } - // Local load balance - - std::vector<std::vector<cell_gid_type>> super_cells; //cells connected by gj - std::vector<cell_gid_type> reg_cells; //independent cells - - // Map to track visited cells (cells that already belong to a group) - std::unordered_set<cell_gid_type> visited; - - // Connected components algorithm using BFS - std::queue<cell_gid_type> q; - for (auto gid: make_span(gid_part[domain_id])) { - if (!global_gj_connection_table[gid].empty()) { - // If cell hasn't been visited yet, must belong to new super_cell - // Perform BFS starting from that cell - if (!visited.count(gid)) { - visited.insert(gid); - std::vector<cell_gid_type> cg; - q.push(gid); +// build the list of components for the local domain, where a component is a list of +// cell gids such that +// * the smallest gid in the list is in the local_gid_range +// * all gids that are connected to the smallest gid are also in the list +// * all gids w/o GJ connections come first (for historical reasons!?) 
+auto build_components(const gj_connection_table& global_gj_connection_table, + gid_range local_gid_range) { + // cells connected by gj + std::vector<super_cell> super_cells; + // singular cells + std::vector<super_cell> res; + // track visited cells (cells that already belong to a group) + gj_connection_set visited; + // Connected components via depth-first traversal; `q` is used as a stack + std::vector<cell_gid_type> q; + for (auto gid: util::make_span(local_gid_range)) { + if (global_gj_connection_table.count(gid)) { + // If cell hasn't been visited yet, must belong to new component + if (visited.insert(gid).second) { + // pivot gid: the smallest found in this group; must be + // smaller than or equal to `gid`. + auto min_gid = gid; + q.push_back(gid); + super_cell sc; while (!q.empty()) { - auto element = q.front(); - q.pop(); - cg.push_back(element); - // Adjacency list - for (const auto& peer: global_gj_connection_table[element]) { - if (visited.insert(peer).second) { - q.push(peer); - } + auto element = q.back(); + q.pop_back(); + sc.push_back(element); + min_gid = std::min(element, min_gid); + // queue up conjoined cells + for (const auto& peer: global_gj_connection_table.at(element)) { + if (visited.insert(peer).second) q.push_back(peer); } } - super_cells.push_back(cg); + // if the pivot gid belongs to our domain, this group will be part + // of our domain, keep it and sort.
+ if (min_gid >= local_gid_range.first) { + std::sort(sc.begin(), sc.end()); + super_cells.emplace_back(std::move(sc)); + } } } else { - // If cell has no gap_junctions, put in separate group of independent cells - reg_cells.push_back(gid); + res.push_back({gid}); } } + // append super cells to result + res.reserve(res.size() + super_cells.size()); + std::move(super_cells.begin(), super_cells.end(), std::back_inserter(res)); + return res; +} - // Sort super_cell groups and only keep those where the first element in the group belongs to domain - super_cells.erase(std::remove_if(super_cells.begin(), super_cells.end(), - [gid_part, domain_id](std::vector<cell_gid_type>& cg) - { - std::sort(cg.begin(), cg.end()); - return cg.front() < gid_part[domain_id].first; - }), super_cells.end()); - - // Collect local gids that belong to this rank, and sort gids into kind lists - // kind_lists maps a cell_kind to a vector of either: - // 1. gids of regular cells (in reg_cells) - // 2. indices of supercells (in super_cells) - - struct cell_identifier { - cell_gid_type id; - bool is_super_cell; - }; - std::vector<cell_gid_type> local_gids; - std::unordered_map<cell_kind, std::vector<cell_identifier>> kind_lists; - for (auto gid: reg_cells) { - local_gids.push_back(gid); - kind_lists[rec.get_cell_kind(gid)].push_back({gid, false}); +// Figure what backend and group size to use +auto get_backend(context ctx, cell_kind kind, const partition_hint_map& hint_map) { + auto has_gpu = ctx->gpu->has_gpu() && cell_kind_supported(kind, backend_kind::gpu, *ctx); + const auto& hint = util::value_by_key_or(hint_map, kind, {}); + if (!hint.cpu_group_size) { + throw arbor_exception(arb::util::pprintf("unable to perform load balancing because {} has invalid suggested cpu_cell_group size of {}", + kind, hint.cpu_group_size)); } - - for (unsigned i = 0; i < super_cells.size(); i++) { - auto kind = rec.get_cell_kind(super_cells[i].front()); - for (auto gid: super_cells[i]) { - if 
(rec.get_cell_kind(gid) != kind) { - throw gj_kind_mismatch(gid, super_cells[i].front()); - } - local_gids.push_back(gid); - } - kind_lists[kind].push_back({i, true}); + if (hint.prefer_gpu && !hint.gpu_group_size) { + throw arbor_exception(arb::util::pprintf("unable to perform load balancing because {} has invalid suggested gpu_cell_group size of {}", + kind, hint.gpu_group_size)); } + if (hint.prefer_gpu && has_gpu) return std::make_pair(backend_kind::gpu, hint.gpu_group_size); + return std::make_pair(backend_kind::multicore, hint.cpu_group_size); +} - - // Create a flat vector of the cell kinds present on this node, - // partitioned such that kinds for which GPU implementation are - // listed before the others. This is a very primitive attempt at - // scheduling; the cell_groups that run on the GPU will be executed - // before other cell_groups, which is likely to be more efficient. - // - // TODO: This creates an dependency between the load balancer and - // the threading internals. We need support for setting the priority - // of cell group updates according to rules such as the back end on - // which the cell group is running. - - auto has_gpu_backend = [&ctx](cell_kind c) { - return cell_kind_supported(c, backend_kind::gpu, *ctx); - }; - - std::vector<cell_kind> kinds; - for (auto l: kind_lists) { - kinds.push_back(cell_kind(l.first)); +struct group_parameters { + cell_kind kind; + backend_kind backend; + size_t size; +}; + +// Create a flat vector of the cell kinds present on this node, sorted such that +// kinds for which a GPU implementation exists are listed before the others. This is a +// very primitive attempt at scheduling; the cell_groups that run on the GPU +// will be executed before other cell_groups, which is likely to be more +// efficient. +// +// TODO: This creates a dependency between the load balancer and the threading +// internals.
We need support for setting the priority of cell group updates +// according to rules such as the back end on which the cell group is running. +auto build_group_parameters(context ctx, + const partition_hint_map& hint_map, + const std::unordered_map<cell_kind, std::vector<cell_gid_type>>& kind_lists) { + std::vector<group_parameters> res; + for (const auto& [kind, _gids]: kind_lists) { + const auto& [backend, group_size] = get_backend(ctx, kind, hint_map); + res.push_back({kind, backend, group_size}); } - std::partition(kinds.begin(), kinds.end(), has_gpu_backend); + util::sort_by(res, [](const auto& p) { return p.kind; }); + return res; +} - std::vector<group_description> groups; - for (auto k: kinds) { - partition_hint hint; - if (auto opt_hint = util::value_by_key(hint_map, k)) { - hint = opt_hint.value(); - if(!hint.cpu_group_size) { - throw arbor_exception(arb::util::pprintf("unable to perform load balancing because {} has invalid suggested cpu_cell_group size of {}", k, hint.cpu_group_size)); - } - if(hint.prefer_gpu && !hint.gpu_group_size) { - throw arbor_exception(arb::util::pprintf("unable to perform load balancing because {} has invalid suggested gpu_cell_group size of {}", k, hint.gpu_group_size)); - } - } +// Build the list of GJ-connected cells local to this domain. +// NOTE We put this into its own function to avoid increasing RSS. 
+auto build_local_components(const recipe& rec, context ctx) { + const auto global_gj_connection_table = build_global_gj_connection_table(rec); + const auto local_gid_range = make_local_gid_range(ctx, rec.num_cells()); + return build_components(global_gj_connection_table, local_gid_range); +} - backend_kind backend = backend_kind::multicore; - std::size_t group_size = hint.cpu_group_size; +} // namespace - if (hint.prefer_gpu && gpu_avail && has_gpu_backend(k)) { - backend = backend_kind::gpu; - group_size = hint.gpu_group_size; +ARB_ARBOR_API domain_decomposition partition_load_balance(const recipe& rec, + context ctx, + const partition_hint_map& hint_map) { + const auto components = build_local_components(rec, ctx); + + std::vector<cell_gid_type> local_gids; + std::unordered_map<cell_kind, std::vector<cell_gid_type>> kind_lists; + + for (auto idx: util::make_span(components.size())) { + const auto& component = components[idx]; + const auto& first_gid = component.front(); + auto kind = rec.get_cell_kind(first_gid); + for (auto gid: component) { + if (rec.get_cell_kind(gid) != kind) throw gj_kind_mismatch(gid, first_gid); + local_gids.push_back(gid); } + kind_lists[kind].push_back((cell_gid_type) idx); + } + auto kinds = build_group_parameters(ctx, hint_map, kind_lists); + + std::vector<group_description> groups; + for (const auto& params: kinds) { std::vector<cell_gid_type> group_elements; - // group_elements are sorted such that the gids of all members of a super_cell are consecutive. 
- for (auto cell: kind_lists[k]) { - if (!cell.is_super_cell) { - group_elements.push_back(cell.id); - } else { - if (group_elements.size() + super_cells[cell.id].size() > group_size && !group_elements.empty()) { - groups.emplace_back(k, std::move(group_elements), backend); - group_elements.clear(); - } - for (auto gid: super_cells[cell.id]) { - group_elements.push_back(gid); - } - } - if (group_elements.size()>=group_size) { - groups.emplace_back(k, std::move(group_elements), backend); + // group_elements are sorted such that the gids of all members of a component are consecutive. + for (auto cell: kind_lists[params.kind]) { + const auto& component = components[cell]; + // adding the current component would go beyond the allotted size, so add the group + // built so far to the list of groups and start a new one. + if (group_elements.size() + component.size() > params.size && !group_elements.empty()) { + groups.emplace_back(params.kind, std::move(group_elements), params.backend); group_elements.clear(); } + // we are clear to add the current component. NOTE this may exceed + // the allotted size, but only by the minimal amount manageable + group_elements.insert(group_elements.end(), component.begin(), component.end()); } - if (!group_elements.empty()) { - groups.emplace_back(k, std::move(group_elements), backend); - } + // we may have a trailing, incomplete group, so add it. + if (!group_elements.empty()) groups.emplace_back(params.kind, std::move(group_elements), params.backend); } - // Exchange gid list with all other nodes - // global all-to-all to gather a local copy of the global gid list on each node.
- auto global_gids = dist->gather_gids(local_gids); - - return domain_decomposition(rec, ctx, groups); + return {rec, ctx, groups}; } - } // namespace arb - diff --git a/test/unit/test_domain_decomposition.cpp b/test/unit/test_domain_decomposition.cpp index 10e0b754c6b0d4b9ec199d2008826b7d56383c10..915a346ae922e32b52997298681a519980c4381f 100644 --- a/test/unit/test_domain_decomposition.cpp +++ b/test/unit/test_domain_decomposition.cpp @@ -5,11 +5,11 @@ #include <arbor/domain_decomposition.hpp> #include <arbor/load_balance.hpp> #include <arbor/version.hpp> - #include <arborenv/default_env.hpp> #include "util/span.hpp" - +#include "../execution_context.hpp" +#include "../distributed_context.hpp" #include "../common_cells.hpp" #include "../simple_recipes.hpp" @@ -28,122 +28,152 @@ using arb::util::make_span; // partition_load_balance into components that can be tested in isolation. namespace { - // Dummy recipes types for testing. +// Dummy recipes types for testing. - struct dummy_cell {}; - using homo_recipe = homogeneous_recipe<cell_kind::cable, dummy_cell>; +struct dummy_cell {}; +using homo_recipe = homogeneous_recipe<cell_kind::cable, dummy_cell>; - // Heterogenous cell population of cable and spike source cells. - // Interleaved so that cells with even gid are cable cells, and odd gid are - // spike source cells. - class hetero_recipe: public recipe { - public: - hetero_recipe(cell_size_type s): size_(s) {} +// Heterogenous cell population of cable and spike source cells. +// Interleaved so that cells with even gid are cable cells, and odd gid are +// spike source cells. 
+class hetero_recipe: public recipe { +public: + hetero_recipe(cell_size_type s): size_(s) {} - cell_size_type num_cells() const override { - return size_; - } + cell_size_type num_cells() const override { + return size_; + } - util::unique_any get_cell_description(cell_gid_type) const override { - return {}; - } + util::unique_any get_cell_description(cell_gid_type) const override { + return {}; + } - cell_kind get_cell_kind(cell_gid_type gid) const override { - return gid%2? - cell_kind::spike_source: - cell_kind::cable; - } + cell_kind get_cell_kind(cell_gid_type gid) const override { + return gid%2? + cell_kind::spike_source: + cell_kind::cable; + } - private: - cell_size_type size_; - }; +private: + cell_size_type size_; +}; - class gap_recipe: public recipe { - public: - gap_recipe(bool full_connected): fully_connected_(full_connected) {} +class gap_recipe: public recipe { +public: + gap_recipe(bool full_connected): fully_connected_(full_connected) {} - cell_size_type num_cells() const override { - return size_; - } + cell_size_type num_cells() const override { + return size_; + } - arb::util::unique_any get_cell_description(cell_gid_type) const override { - auto c = arb::make_cell_soma_only(false); - c.decorations.place(mlocation{0,1}, junction("gj"), "gj"); - return {arb::cable_cell(c)}; - } + arb::util::unique_any get_cell_description(cell_gid_type) const override { + auto c = arb::make_cell_soma_only(false); + c.decorations.place(mlocation{0,1}, junction("gj"), "gj"); + return {arb::cable_cell(c)}; + } - cell_kind get_cell_kind(cell_gid_type gid) const override { - return cell_kind::cable; - } - std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { - switch (gid) { - case 0: return {gap_junction_connection({13, "gj"}, {"gj"}, 0.1)}; - case 2: return {gap_junction_connection({7, "gj"}, {"gj"}, 0.1)}; - case 3: return {gap_junction_connection({8, "gj"}, {"gj"}, 0.1)}; - case 4: { - if (!fully_connected_) return 
{gap_junction_connection({9, "gj"}, {"gj"}, 0.1)}; - return { - gap_junction_connection({8, "gj"}, {"gj"}, 0.1), - gap_junction_connection({9, "gj"}, {"gj"}, 0.1) - }; - } - case 7: { - if (!fully_connected_) return {}; - return { - gap_junction_connection({2, "gj"}, {"gj"}, 0.1), - gap_junction_connection({11, "gj"}, {"gj"}, 0.1) - }; - } - case 8: { - if (!fully_connected_) return {gap_junction_connection({4, "gj"}, {"gj"}, 0.1)}; - return { - gap_junction_connection({3, "gj"}, {"gj"}, 0.1), - gap_junction_connection({4, "gj"}, {"gj"}, 0.1) - }; - } - case 9: { - if (!fully_connected_) return {}; - return {gap_junction_connection({4, "gj"}, {"gj"}, 0.1)}; - } - case 11: return {gap_junction_connection({7, "gj"}, {"gj"}, 0.1)}; - case 13: { - if (!fully_connected_) return {}; - return { gap_junction_connection({0, "gj"}, {"gj"}, 0.1)}; - } - default: return {}; + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable; + } + std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { + switch (gid) { + case 0: return {gap_junction_connection({13, "gj"}, {"gj"}, 0.1)}; + case 2: return {gap_junction_connection({7, "gj"}, {"gj"}, 0.1)}; + case 3: return {gap_junction_connection({8, "gj"}, {"gj"}, 0.1)}; + case 4: { + if (!fully_connected_) return {gap_junction_connection({9, "gj"}, {"gj"}, 0.1)}; + return { + gap_junction_connection({8, "gj"}, {"gj"}, 0.1), + gap_junction_connection({9, "gj"}, {"gj"}, 0.1) + }; } + case 7: { + if (!fully_connected_) return {}; + return { + gap_junction_connection({2, "gj"}, {"gj"}, 0.1), + gap_junction_connection({11, "gj"}, {"gj"}, 0.1) + }; + } + case 8: { + if (!fully_connected_) return {gap_junction_connection({4, "gj"}, {"gj"}, 0.1)}; + return { + gap_junction_connection({3, "gj"}, {"gj"}, 0.1), + gap_junction_connection({4, "gj"}, {"gj"}, 0.1) + }; + } + case 9: { + if (!fully_connected_) return {}; + return {gap_junction_connection({4, "gj"}, {"gj"}, 0.1)}; + } + 
case 11: return {gap_junction_connection({7, "gj"}, {"gj"}, 0.1)}; + case 13: { + if (!fully_connected_) return {}; + return { gap_junction_connection({0, "gj"}, {"gj"}, 0.1)}; + } + default: return {}; } + } - private: - bool fully_connected_ = true; - cell_size_type size_ = 15; - }; +private: + bool fully_connected_ = true; + cell_size_type size_ = 15; +}; - class custom_gap_recipe: public recipe { - public: - custom_gap_recipe(cell_size_type ncells, std::vector<std::vector<gap_junction_connection>> gj_conns): - size_(ncells), gj_conns_(std::move(gj_conns)){} +class custom_gap_recipe: public recipe { +public: + custom_gap_recipe(cell_size_type ncells, std::vector<std::vector<gap_junction_connection>> gj_conns): + size_(ncells), gj_conns_(std::move(gj_conns)){} - cell_size_type num_cells() const override { - return size_; - } + cell_size_type num_cells() const override { + return size_; + } - arb::util::unique_any get_cell_description(cell_gid_type) const override { - auto c = arb::make_cell_soma_only(false); - c.decorations.place(mlocation{0,1}, junction("gj"), "gj"); - return {arb::cable_cell(c)}; - } + arb::util::unique_any get_cell_description(cell_gid_type) const override { + auto c = arb::make_cell_soma_only(false); + c.decorations.place(mlocation{0,1}, junction("gj"), "gj"); + return {arb::cable_cell(c)}; + } + + cell_kind get_cell_kind(cell_gid_type gid) const override { + return cell_kind::cable; + } + std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { + return gj_conns_[gid]; + } +private: + cell_size_type size_ = 7; + std::vector<std::vector<gap_junction_connection>> gj_conns_; +}; + +struct unimplemented: std::runtime_error { + unimplemented(const std::string& f): std::runtime_error{f} {} +}; + +struct dummy_context { + dummy_context(int i, int s): size_{s}, id_{i} {} + + int size_ = 1; + int id_ = 0; + + gathered_vector<spike> gather_spikes(const std::vector<spike>&) const { throw unimplemented{__FUNCTION__}; } + 
std::vector<spike> remote_gather_spikes(const std::vector<spike>&) const { throw unimplemented{__FUNCTION__}; } + gathered_vector<cell_gid_type> gather_gids(const std::vector<cell_gid_type>& local_gids) const { throw unimplemented{__FUNCTION__}; } + void remote_ctrl_send_continue(const epoch&) const {} + void remote_ctrl_send_done() const {} + cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const { throw unimplemented{__FUNCTION__}; } + cell_labels_and_gids gather_cell_labels_and_gids(const cell_labels_and_gids& local_labels_and_gids) const { throw unimplemented{__FUNCTION__}; } + template <typename T> std::vector<T> gather(T value, int) const { throw unimplemented{__FUNCTION__}; } + + int id() const { return id_; } + int size() const { return size_; } + + template <typename T> T min(T value) const { return value; } + template <typename T> T max(T value) const { return value; } + template <typename T> T sum(T value) const { return value; } + void barrier() const {} + std::string name() const { return "dummy"; } +}; - cell_kind get_cell_kind(cell_gid_type gid) const override { - return cell_kind::cable; - } - std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { - return gj_conns_[gid]; - } - private: - cell_size_type size_ = 7; - std::vector<std::vector<gap_junction_connection>> gj_conns_; - }; } // test assumes one domain @@ -442,7 +472,7 @@ TEST(domain_decomposition, unidirectional_gj_recipe) { {}, {}, {}, - {gap_junction_connection({4, "gj"}, {"gj"}, 0.1)}, + {gap_junction_connection({4, "gj"}, {"gj"}, 0.1), gap_junction_connection({4, "gj"}, {"gj"}, 0.1)}, {}, {}, {gap_junction_connection({5, "gj"}, {"gj"}, 0.1), gap_junction_connection({7, "gj"}, {"gj"}, 0.1)}, @@ -573,3 +603,65 @@ TEST(domain_decomposition, invalid) { EXPECT_THROW(domain_decomposition(rec, ctx, groups), invalid_gj_cell_group); } } + +struct gj_symmetric: public recipe { + gj_symmetric(unsigned num_ranks, bool fully_connected): 
+ ncopies_(num_ranks), + fully_connected_(fully_connected) {} + + cell_size_type num_cells_per_rank() const { return size_; } + cell_size_type num_cells() const override { return size_*ncopies_; } + arb::util::unique_any get_cell_description(cell_gid_type) const override { return {}; } + cell_kind get_cell_kind(cell_gid_type gid) const override { return cell_kind::cable; } + + std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { + unsigned shift = (gid/size_)*size_; + switch (gid % size_) { + case 1 : { + if (!fully_connected_) return {}; + return {gap_junction_connection({7 + shift, "gj"}, {"gj"}, 0.1)}; + } + case 2 : { + if (!fully_connected_) return {}; + return { + gap_junction_connection({6 + shift, "gj"}, {"gj"}, 0.1), + gap_junction_connection({9 + shift, "gj"}, {"gj"}, 0.1) + }; + } + case 6 : return { + gap_junction_connection({2 + shift, "gj"}, {"gj"}, 0.1), + gap_junction_connection({7 + shift, "gj"}, {"gj"}, 0.1) + }; + case 7 : { + if (!fully_connected_) { + return {gap_junction_connection({1 + shift, "gj"}, {"gj"}, 0.1)}; + } + return { + gap_junction_connection({6 + shift, "gj"}, {"gj"}, 0.1), + gap_junction_connection({1 + shift, "gj"}, {"gj"}, 0.1) + }; + } + case 9 : return { gap_junction_connection({2 + shift, "gj"}, {"gj"}, 0.1)}; + default : return {}; + } + } + + cell_size_type size_ = 10; + unsigned ncopies_; + bool fully_connected_; +}; + +TEST(domain_decomposition, symmetric_groups) { + auto ctx = make_context(); + for (int nranks = 1; nranks < 20; ++nranks) { + for (int rank = 0; rank < nranks; ++rank) { + ctx->distributed = std::make_shared<distributed_context>(dummy_context{rank, nranks}); + for (const auto& R: {gj_symmetric(nranks, true), gj_symmetric(nranks, false)}) { + // NOTE: This is a bit silly, but allows us to test _most_ of + // the invariants without proper MPI support. If we could get `gather_gids` to + // work and return the expected values we could even test all of them. 
+ EXPECT_THROW(partition_load_balance(R, {ctx}), unimplemented); + } + } + } +}