diff --git a/arbor/partition_load_balance.cpp b/arbor/partition_load_balance.cpp
index 3b84768890250c04cb627bdb08b3abec62783f94..87e1b4043e139947465a09500f714978f89f3118 100644
--- a/arbor/partition_load_balance.cpp
+++ b/arbor/partition_load_balance.cpp
@@ -26,23 +26,20 @@ domain_decomposition partition_load_balance(
     const bool gpu_avail = ctx->gpu->has_gpu();
 
     struct partition_gid_domain {
-        partition_gid_domain(gathered_vector<cell_gid_type> divs, unsigned domains):
-            gids_by_rank(std::move(divs)), num_domains(domains)
-        {}
-
-        int operator()(cell_gid_type gid) const {
-            using namespace util;
-            auto rank_part = partition_view(gids_by_rank.partition());
-            for (auto i: count_along(rank_part)) {
-                if (binary_search_index(subrange_view(gids_by_rank.values(), rank_part[i]), gid)) {
-                    return i;
+        partition_gid_domain(const gathered_vector<cell_gid_type>& divs, unsigned domains) {
+            auto rank_part = util::partition_view(divs.partition());
+            for (auto rank: count_along(rank_part)) {
+                for (auto gid: util::subrange_view(divs.values(), rank_part[rank])) {
+                    gid_map[gid] = rank;
                 }
             }
-            return -1;
         }
 
-        const gathered_vector<cell_gid_type> gids_by_rank;
-        unsigned num_domains;
+        int operator()(cell_gid_type gid) const {
+            return gid_map.at(gid);
+        }
+
+        std::unordered_map<cell_gid_type, int> gid_map;
     };
 
     struct cell_identifier {
@@ -214,9 +211,6 @@ domain_decomposition partition_load_balance(
     cell_size_type num_local_cells = local_gids.size();
 
     // Exchange gid list with all other nodes
-
-    util::sort(local_gids);
-
     // global all-to-all to gather a local copy of the global gid list on each node.
     auto global_gids = ctx->distributed->gather_gids(local_gids);
 
@@ -226,7 +220,7 @@ domain_decomposition partition_load_balance(
     d.num_local_cells = num_local_cells;
     d.num_global_cells = num_global_cells;
     d.groups = std::move(groups);
-    d.gid_domain = partition_gid_domain(std::move(global_gids), num_domains);
+    d.gid_domain = partition_gid_domain(global_gids, num_domains);
 
     return d;
 }