From e32ba10acdc69f702da30985943a4e9eebee9fb9 Mon Sep 17 00:00:00 2001
From: Ben Cumming <bcumming@cscs.ch>
Date: Mon, 2 Dec 2019 15:05:55 +0100
Subject: [PATCH] Replace partitioned vector lookup with hash table. (#910)

Fix linear scaling of model initialization w.r.t. number of MPI ranks.

Fixes #909.
---
 arbor/partition_load_balance.cpp | 28 +++++++++++-----------------
 1 file changed, 11 insertions(+), 17 deletions(-)

diff --git a/arbor/partition_load_balance.cpp b/arbor/partition_load_balance.cpp
index 3b847688..87e1b404 100644
--- a/arbor/partition_load_balance.cpp
+++ b/arbor/partition_load_balance.cpp
@@ -26,23 +26,20 @@ domain_decomposition partition_load_balance(
     const bool gpu_avail = ctx->gpu->has_gpu();
 
     struct partition_gid_domain {
-        partition_gid_domain(gathered_vector<cell_gid_type> divs, unsigned domains):
-            gids_by_rank(std::move(divs)), num_domains(domains)
-        {}
-
-        int operator()(cell_gid_type gid) const {
-            using namespace util;
-            auto rank_part = partition_view(gids_by_rank.partition());
-            for (auto i: count_along(rank_part)) {
-                if (binary_search_index(subrange_view(gids_by_rank.values(), rank_part[i]), gid)) {
-                    return i;
+        partition_gid_domain(const gathered_vector<cell_gid_type>& divs, unsigned domains) {
+            auto rank_part = util::partition_view(divs.partition());
+            for (auto rank: count_along(rank_part)) {
+                for (auto gid: util::subrange_view(divs.values(), rank_part[rank])) {
+                    gid_map[gid] = rank;
                 }
             }
-            return -1;
         }
 
-        const gathered_vector<cell_gid_type> gids_by_rank;
-        unsigned num_domains;
+        int operator()(cell_gid_type gid) const {
+            return gid_map.at(gid);
+        }
+
+        std::unordered_map<cell_gid_type, int> gid_map;
     };
 
     struct cell_identifier {
@@ -214,9 +211,6 @@ domain_decomposition partition_load_balance(
     cell_size_type num_local_cells = local_gids.size();
 
     // Exchange gid list with all other nodes
-
-    util::sort(local_gids);
-
     // global all-to-all to gather a local copy of the global gid list on each node.
     auto global_gids = ctx->distributed->gather_gids(local_gids);
 
@@ -226,7 +220,7 @@ domain_decomposition partition_load_balance(
     d.num_local_cells = num_local_cells;
     d.num_global_cells = num_global_cells;
     d.groups = std::move(groups);
-    d.gid_domain = partition_gid_domain(std::move(global_gids), num_domains);
+    d.gid_domain = partition_gid_domain(global_gids, num_domains);
 
     return d;
 }
-- 
GitLab