diff --git a/arbor/partition_load_balance.cpp b/arbor/partition_load_balance.cpp index 3b84768890250c04cb627bdb08b3abec62783f94..87e1b4043e139947465a09500f714978f89f3118 100644 --- a/arbor/partition_load_balance.cpp +++ b/arbor/partition_load_balance.cpp @@ -26,23 +26,20 @@ domain_decomposition partition_load_balance( const bool gpu_avail = ctx->gpu->has_gpu(); struct partition_gid_domain { - partition_gid_domain(gathered_vector<cell_gid_type> divs, unsigned domains): - gids_by_rank(std::move(divs)), num_domains(domains) - {} - - int operator()(cell_gid_type gid) const { - using namespace util; - auto rank_part = partition_view(gids_by_rank.partition()); - for (auto i: count_along(rank_part)) { - if (binary_search_index(subrange_view(gids_by_rank.values(), rank_part[i]), gid)) { - return i; + partition_gid_domain(const gathered_vector<cell_gid_type>& divs, unsigned domains) { + auto rank_part = util::partition_view(divs.partition()); + for (auto rank: count_along(rank_part)) { + for (auto gid: util::subrange_view(divs.values(), rank_part[rank])) { + gid_map[gid] = rank; } } - return -1; } - const gathered_vector<cell_gid_type> gids_by_rank; - unsigned num_domains; + int operator()(cell_gid_type gid) const { + return gid_map.at(gid); + } + + std::unordered_map<cell_gid_type, int> gid_map; }; struct cell_identifier { @@ -214,9 +211,6 @@ domain_decomposition partition_load_balance( cell_size_type num_local_cells = local_gids.size(); // Exchange gid list with all other nodes - - util::sort(local_gids); - // global all-to-all to gather a local copy of the global gid list on each node. auto global_gids = ctx->distributed->gather_gids(local_gids); @@ -226,7 +220,7 @@ domain_decomposition partition_load_balance( d.num_local_cells = num_local_cells; d.num_global_cells = num_global_cells; d.groups = std::move(groups); - d.gid_domain = partition_gid_domain(std::move(global_gids), num_domains); + d.gid_domain = partition_gid_domain(global_gids, num_domains); return d; }