diff --git a/arbor/communication/dry_run_context.cpp b/arbor/communication/dry_run_context.cpp
index 9936eaabba65cae48e358e88143eacd667da643a..4bc17f4d1343d7a01d4a190a3868eb1f5040d780 100644
--- a/arbor/communication/dry_run_context.cpp
+++ b/arbor/communication/dry_run_context.cpp
@@ -70,26 +70,6 @@ struct dry_run_context_impl {
         return gathered_vector<cell_gid_type>(std::move(gathered_gids), std::move(partition));
     }
 
-    std::vector<std::vector<cell_gid_type>>
-    gather_gj_connections(const std::vector<std::vector<cell_gid_type>> & local_connections) const {
-        auto local_size = local_connections.size();
-        std::vector<std::vector<cell_gid_type>> global_connections;
-        global_connections.reserve(local_size*num_ranks_);
-
-        for (unsigned i = 0; i < num_ranks_; i++) {
-            util::append(global_connections, local_connections);
-        }
-
-        for (unsigned i = 0; i < num_ranks_; i++) {
-            for (unsigned j = i*local_size; j < (i+1)*local_size; j++){
-                for (auto& conn_gid: global_connections[j]) {
-                    conn_gid += num_cells_per_tile_*i;
-                }
-            }
-        }
-        return global_connections;
-    }
-
     cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const {
         cell_label_range global_ranges;
         for (unsigned i = 0; i < num_ranks_; i++) {
diff --git a/arbor/communication/mpi_context.cpp b/arbor/communication/mpi_context.cpp
index c2fcf64af76ffa0f226ad23e7e4ef27a87916b0a..8109a5a88931575da10d6a839684fb8f480e960a 100644
--- a/arbor/communication/mpi_context.cpp
+++ b/arbor/communication/mpi_context.cpp
@@ -53,11 +53,6 @@ struct mpi_context_impl {
         return mpi::gather_all_with_partition(local_gids, comm_);
     }
 
-    std::vector<std::vector<cell_gid_type>>
-    gather_gj_connections(const std::vector<std::vector<cell_gid_type>>& local_connections) const {
-        return mpi::gather_all(local_connections, comm_);
-    }
-
     cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const {
         cell_label_range res;
         res.sizes  = mpi::gather_all(local_ranges.sizes, comm_);
@@ -139,11 +134,6 @@ struct remote_context_impl {
     gathered_vector<cell_gid_type>
     gather_gids(const std::vector<cell_gid_type>& local_gids) const { return mpi_.gather_gids(local_gids); }
 
-    std::vector<std::vector<cell_gid_type>>
-    gather_gj_connections(const std::vector<std::vector<cell_gid_type>>& local_connections) const {
-        return mpi_.gather_gj_connections(local_connections);
-    }
-
     cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const {
         return mpi_.gather_cell_label_range(local_ranges);
     }
diff --git a/arbor/distributed_context.hpp b/arbor/distributed_context.hpp
index c988177d33224d88c33cb7e5bf7ce80e92dba1b2..76e8f4c0a9db1b7bc5a890aa26c408eb765b78d3 100644
--- a/arbor/distributed_context.hpp
+++ b/arbor/distributed_context.hpp
@@ -72,10 +72,6 @@ public:
         return impl_->gather_gids(local_gids);
     }
 
-    gj_connection_vector gather_gj_connections(const gj_connection_vector& local_connections) const {
-        return impl_->gather_gj_connections(local_connections);
-    }
-
     cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const {
         return impl_->gather_cell_label_range(local_ranges);
     }
@@ -117,8 +113,6 @@ private:
         remote_gather_spikes(const spike_vector& local_spikes) const = 0;
         virtual gathered_vector<cell_gid_type>
         gather_gids(const gid_vector& local_gids) const = 0;
-        virtual gj_connection_vector
-        gather_gj_connections(const gj_connection_vector& local_connections) const = 0;
         virtual cell_label_range
         gather_cell_label_range(const cell_label_range& local_ranges) const = 0;
         virtual cell_labels_and_gids
@@ -154,10 +148,6 @@ private:
         gather_gids(const gid_vector& local_gids) const override {
             return wrapped.gather_gids(local_gids);
         }
-        std::vector<std::vector<cell_gid_type>>
-        gather_gj_connections(const gj_connection_vector& local_connections) const override {
-            return wrapped.gather_gj_connections(local_connections);
-        }
         cell_label_range
         gather_cell_label_range(const cell_label_range& local_ranges) const override {
             return wrapped.gather_cell_label_range(local_ranges);
@@ -217,10 +207,6 @@ struct local_context {
     }
     void remote_ctrl_send_continue(const epoch&) const {}
     void remote_ctrl_send_done() const {}
-    std::vector<std::vector<cell_gid_type>>
-    gather_gj_connections(const std::vector<std::vector<cell_gid_type>>& local_connections) const {
-        return local_connections;
-    }
     cell_label_range
     gather_cell_label_range(const cell_label_range& local_ranges) const {
         return local_ranges;
diff --git a/arbor/domain_decomposition.cpp b/arbor/domain_decomposition.cpp
index aa22082120a42c53a1359d8e1ea11a7c94fef2a7..f9c14a03d9f85c9d382233d3c8bd1db3f0a5e26b 100644
--- a/arbor/domain_decomposition.cpp
+++ b/arbor/domain_decomposition.cpp
@@ -13,11 +13,9 @@
 #include "util/span.hpp"
 
 namespace arb {
-domain_decomposition::domain_decomposition(
-    const recipe& rec,
-    context ctx,
-    const std::vector<group_description>& groups)
-{
+domain_decomposition::domain_decomposition(const recipe& rec,
+                                           context ctx,
+                                           const std::vector<group_description>& groups) {
     struct partition_gid_domain {
         partition_gid_domain(const gathered_vector<cell_gid_type>& divs, unsigned domains) {
             auto rank_part = util::partition_view(divs.partition());
@@ -27,9 +25,7 @@ domain_decomposition::domain_decomposition(
                 }
             }
         }
-        int operator()(cell_gid_type gid) const {
-            return gid_map.at(gid);
-        }
+        int operator()(cell_gid_type gid) const { return gid_map.at(gid); }
         std::unordered_map<cell_gid_type, int> gid_map;
     };
 
@@ -41,22 +37,14 @@ domain_decomposition::domain_decomposition(
 
     std::vector<cell_gid_type> local_gids;
     for (const auto& g: groups) {
-        if (g.backend == backend_kind::gpu && !has_gpu) {
-            throw invalid_backend(domain_id);
-        }
-        if (g.backend == backend_kind::gpu && g.kind != cell_kind::cable) {
-            throw incompatible_backend(domain_id, g.kind);
-        }
+        if (g.backend == backend_kind::gpu && !has_gpu) throw invalid_backend(domain_id);
+        if (g.backend == backend_kind::gpu && g.kind != cell_kind::cable) throw incompatible_backend(domain_id, g.kind);
 
         std::unordered_set<cell_gid_type> gid_set(g.gids.begin(), g.gids.end());
         for (const auto& gid: g.gids) {
-            if (gid >= num_global_cells) {
-                throw out_of_bounds(gid, num_global_cells);
-            }
+            if (gid >= num_global_cells) throw out_of_bounds(gid, num_global_cells);
             for (const auto& gj: rec.gap_junctions_on(gid)) {
-                if (!gid_set.count(gj.peer.gid)) {
-                    throw invalid_gj_cell_group(gid, gj.peer.gid);
-                }
+                if (!gid_set.count(gj.peer.gid)) throw invalid_gj_cell_group(gid, gj.peer.gid);
             }
         }
         local_gids.insert(local_gids.end(), g.gids.begin(), g.gids.end());
@@ -64,16 +52,12 @@ domain_decomposition::domain_decomposition(
     cell_size_type num_local_cells = local_gids.size();
 
     auto global_gids = dist->gather_gids(local_gids);
-    if (global_gids.size() != num_global_cells) {
-        throw invalid_sum_local_cells(global_gids.size(), num_global_cells);
-    }
+    if (global_gids.size() != num_global_cells) throw invalid_sum_local_cells(global_gids.size(), num_global_cells);
 
     auto global_gid_vals = global_gids.values();
     util::sort(global_gid_vals);
     for (unsigned i = 1; i < global_gid_vals.size(); ++i) {
-        if (global_gid_vals[i] == global_gid_vals[i-1]) {
-            throw duplicate_gid(global_gid_vals[i]);
-        }
+        if (global_gid_vals[i] == global_gid_vals[i-1]) throw duplicate_gid(global_gid_vals[i]);
     }
 
     num_domains_ = num_domains;
diff --git a/arbor/execution_context.hpp b/arbor/execution_context.hpp
index fb75f60cbd3ad87e63a840a660651eb5d5748212..b14fe01b344e06948287cecbbc02a221b28a2d7f 100644
--- a/arbor/execution_context.hpp
+++ b/arbor/execution_context.hpp
@@ -34,7 +34,6 @@ struct ARB_ARBOR_API execution_context {
 
     template <typename Comm>
     execution_context(const proc_allocation& resources, Comm comm, Comm remote);
-
 };
 
 } // namespace arb
diff --git a/arbor/include/arbor/common_types.hpp b/arbor/include/arbor/common_types.hpp
index 5e57f723c26e5d6641b1c8379166175cb63198e9..e2d1eccdc7952aa5ff4b01cca3cee632d850f173 100644
--- a/arbor/include/arbor/common_types.hpp
+++ b/arbor/include/arbor/common_types.hpp
@@ -142,10 +142,12 @@ using sample_index_type = std::int32_t;
 using sample_size_type  = std::uint32_t;
 
 // Enumeration for execution back-end targets, as specified in domain decompositions.
-
+// NOTE(important): Given in order of priority, i.e. we will attempt to schedule gpu before
+//                  MC groups, for reasons of efficiency. Ugly, but as we do not have more
+//                  backends, this is OK for now.
 enum class backend_kind {
+    gpu,         //  Use gpu back-end when supported by cell_group implementation.
     multicore,   //  Use multicore back-end for all computation.
-    gpu          //  Use gpu back-end when supported by cell_group implementation.
 };
 
 // Enumeration used to indentify the cell type/kind, used by the model to
diff --git a/arbor/partition_load_balance.cpp b/arbor/partition_load_balance.cpp
index 8718c743d111eb83506ab8876321e4a804314bbc..550c2144b9b50a3b96aa0eb2d58fad7a1d5d1ab7 100644
--- a/arbor/partition_load_balance.cpp
+++ b/arbor/partition_load_balance.cpp
@@ -1,18 +1,16 @@
-#include <queue>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
+#include <algorithm>
 
 #include <arbor/domdecexcept.hpp>
 #include <arbor/domain_decomposition.hpp>
 #include <arbor/load_balance.hpp>
 #include <arbor/recipe.hpp>
-#include <arbor/symmetric_recipe.hpp>
 #include <arbor/context.hpp>
 
 #include "cell_group_factory.hpp"
 #include "execution_context.hpp"
-#include "gpu_context.hpp"
 #include "util/maputil.hpp"
 #include "util/partition.hpp"
 #include "util/span.hpp"
@@ -20,215 +18,200 @@
 
 namespace arb {
 
-ARB_ARBOR_API domain_decomposition partition_load_balance(
-    const recipe& rec,
-    context ctx,
-    const partition_hint_map& hint_map)
-{
-    using util::make_span;
+namespace {
+using gj_connection_set   = std::unordered_set<cell_gid_type>;
+using gj_connection_table = std::unordered_map<cell_gid_type, gj_connection_set>;
+using gid_range           = std::pair<cell_gid_type, cell_gid_type>;
+using super_cell          = std::vector<cell_gid_type>;
+
+// Build the global GJ connectivity table such that
+// * table[gid] is the set of all gids connected to gid via a GJ
+// * A is in table[B] if and only if B is in table[A]
+auto build_global_gj_connection_table(const recipe& rec) {
+    gj_connection_table res;
+    // Record every GJ of every gid in both directions; gids without GJs get no entry
+    for (cell_gid_type gid = 0; gid < rec.num_cells(); ++gid) {
+        for (const auto& gj: rec.gap_junctions_on(gid)) {
+            auto peer = gj.peer.gid;
+            res[gid].insert(peer);
+            res[peer].insert(gid);
+        }
+    }
+    return res;
+}
 
+// Compute the range [beg, end) of gids for the local domain, such that the first
+// (num_cells % num_domains) domains get one extra element.
+auto make_local_gid_range(context ctx, cell_gid_type num_global_cells) {
     const auto& dist = ctx->distributed;
     unsigned num_domains = dist->size();
     unsigned domain_id = dist->id();
-    const bool gpu_avail = ctx->gpu->has_gpu();
-    auto num_global_cells = rec.num_cells();
-
-    auto dom_size = [&](unsigned dom) -> cell_gid_type {
-        const cell_gid_type B = num_global_cells/num_domains;
-        const cell_gid_type R = num_global_cells - num_domains*B;
-        return B + (dom<R);
-    };
-
-    // Global load balance
-
-    std::vector<cell_gid_type> gid_divisions;
-    auto gid_part = make_partition(
-        gid_divisions, transform_view(make_span(num_domains), dom_size));
-
-    // Global gj_connection table
-
-    // Generate a local gj_connection table.
-    // The table is indexed by the index of the target gid in the gid_part of that domain.
-    // If gid_part[domain_id] = [a, b); local_gj_connection of gid `x` is at index `x-a`.
-    const auto dom_range = gid_part[domain_id];
-    std::vector<std::vector<cell_gid_type>> local_gj_connection_table(dom_range.second-dom_range.first);
-    for (auto gid: make_span(gid_part[domain_id])) {
-        for (const auto& c: rec.gap_junctions_on(gid)) {
-            local_gj_connection_table[gid-dom_range.first].push_back(c.peer.gid);
-        }
+    // normal block size
+    auto block = num_global_cells/num_domains;
+    // domains that need an extra element
+    auto extra = num_global_cells - num_domains*block;
+    // now compute the range
+    if (domain_id < extra) {
+        // all previous domains, incl ours, have an extra element
+        auto beg = domain_id*(block + 1);
+        auto end = beg + block + 1;
+        return std::make_pair(beg, end);
     }
-    // Sort the gj connections of each local cell.
-    for (auto& gid_conns: local_gj_connection_table) {
-        util::sort(gid_conns);
+    else {
+        // in this case the first `extra` domains added an extra element and the
+        // rest has size `block`
+        auto beg = extra + domain_id*block;
+        auto end = beg + block;
+        return std::make_pair(beg, end);
     }
+}
 
-    // Gather the global gj_connection table.
-    // The global gj_connection table after gathering is indexed by gid.
-    auto global_gj_connection_table = dist->gather_gj_connections(local_gj_connection_table);
-
-    // Make all gj_connections bidirectional.
-    std::vector<std::unordered_set<cell_gid_type>> missing_peers(global_gj_connection_table.size());
-    for (auto gid: make_span(global_gj_connection_table.size())) {
-        const auto& local_conns = global_gj_connection_table[gid];
-        for (auto peer: local_conns) {
-            auto& peer_conns = global_gj_connection_table[peer];
-            // If gid is not in the peer connection table insert it into the
-            // missing_peers set
-            if (!std::binary_search(peer_conns.begin(), peer_conns.end(), gid)) {
-                missing_peers[peer].insert(gid);
-            }
-        }
-    }
-    // Append the missing peers into the global_gj_connections table
-    for (unsigned i = 0; i < global_gj_connection_table.size(); ++i) {
-        std::move(missing_peers[i].begin(), missing_peers[i].end(), std::back_inserter(global_gj_connection_table[i]));
-    }
-    // Local load balance
-
-    std::vector<std::vector<cell_gid_type>> super_cells; //cells connected by gj
-    std::vector<cell_gid_type> reg_cells; //independent cells
-
-    // Map to track visited cells (cells that already belong to a group)
-    std::unordered_set<cell_gid_type> visited;
-
-    // Connected components algorithm using BFS
-    std::queue<cell_gid_type> q;
-    for (auto gid: make_span(gid_part[domain_id])) {
-        if (!global_gj_connection_table[gid].empty()) {
-            // If cell hasn't been visited yet, must belong to new super_cell
-            // Perform BFS starting from that cell
-            if (!visited.count(gid)) {
-                visited.insert(gid);
-                std::vector<cell_gid_type> cg;
-                q.push(gid);
+// Build the list of components for the local domain, where a component is a sorted
+// list of cell gids such that
+// * the smallest gid in the list is in the local_gid_range
+// * all gids that are connected to the smallest gid are also in the list
+// In the result, cells w/o GJ connections come first (for historical reasons!?)
+auto build_components(const gj_connection_table& global_gj_connection_table,
+                      gid_range local_gid_range) {
+    // cells connected by gj
+    std::vector<super_cell> super_cells;
+    // singular cells
+    std::vector<super_cell> res;
+    // track visited cells (cells that already belong to a group)
+    gj_connection_set visited;
+    // Connected components via a work-list traversal (popped from the back, i.e. LIFO)
+    std::vector<cell_gid_type> q;
+    for (auto gid: util::make_span(local_gid_range)) {
+        if (global_gj_connection_table.count(gid)) {
+            // If cell hasn't been visited yet, must belong to new component
+            if (visited.insert(gid).second) {
+                // pivot gid: the smallest gid found in this group; it is always
+                // smaller than or equal to `gid`.
+                auto min_gid = gid;
+                q.push_back(gid);
+                super_cell sc;
                 while (!q.empty()) {
-                    auto element = q.front();
-                    q.pop();
-                    cg.push_back(element);
-                    // Adjacency list
-                    for (const auto& peer: global_gj_connection_table[element]) {
-                        if (visited.insert(peer).second) {
-                            q.push(peer);
-                        }
+                    auto element = q.back();
+                    q.pop_back();
+                    sc.push_back(element);
+                    min_gid = std::min(element, min_gid);
+                    // queue up conjoined cells
+                    for (const auto& peer: global_gj_connection_table.at(element)) {
+                        if (visited.insert(peer).second) q.push_back(peer);
                     }
                 }
-                super_cells.push_back(cg);
+                // if the pivot gid belongs to our domain, this group will be part
+                // of our domain, keep it and sort.
+                if (min_gid >= local_gid_range.first) {
+                    std::sort(sc.begin(), sc.end());
+                    super_cells.emplace_back(std::move(sc));
+                }
             }
         }
         else {
-            // If cell has no gap_junctions, put in separate group of independent cells
-            reg_cells.push_back(gid);
+            res.push_back({gid});
         }
     }
+    // append super cells to result
+    res.reserve(res.size() + super_cells.size());
+    std::move(super_cells.begin(), super_cells.end(), std::back_inserter(res));
+    return res;
+}
 
-    // Sort super_cell groups and only keep those where the first element in the group belongs to domain
-    super_cells.erase(std::remove_if(super_cells.begin(), super_cells.end(),
-            [gid_part, domain_id](std::vector<cell_gid_type>& cg)
-            {
-                std::sort(cg.begin(), cg.end());
-                return cg.front() < gid_part[domain_id].first;
-            }), super_cells.end());
-
-    // Collect local gids that belong to this rank, and sort gids into kind lists
-    // kind_lists maps a cell_kind to a vector of either:
-    // 1. gids of regular cells (in reg_cells)
-    // 2. indices of supercells (in super_cells)
-
-    struct cell_identifier {
-        cell_gid_type id;
-        bool is_super_cell;
-    };
-    std::vector<cell_gid_type> local_gids;
-    std::unordered_map<cell_kind, std::vector<cell_identifier>> kind_lists;
-    for (auto gid: reg_cells) {
-        local_gids.push_back(gid);
-        kind_lists[rec.get_cell_kind(gid)].push_back({gid, false});
+// Figure out which backend and group size to use for the given cell kind
+auto get_backend(context ctx, cell_kind kind, const partition_hint_map& hint_map) {
+    auto has_gpu = ctx->gpu->has_gpu() && cell_kind_supported(kind, backend_kind::gpu, *ctx);
+    const auto& hint = util::value_by_key_or(hint_map, kind, {});
+    if (!hint.cpu_group_size) {
+        throw arbor_exception(arb::util::pprintf("unable to perform load balancing because {} has invalid suggested cpu_cell_group size of {}",
+                                                 kind, hint.cpu_group_size));
     }
-
-    for (unsigned i = 0; i < super_cells.size(); i++) {
-        auto kind = rec.get_cell_kind(super_cells[i].front());
-        for (auto gid: super_cells[i]) {
-            if (rec.get_cell_kind(gid) != kind) {
-                throw gj_kind_mismatch(gid, super_cells[i].front());
-            }
-            local_gids.push_back(gid);
-        }
-        kind_lists[kind].push_back({i, true});
+    if (hint.prefer_gpu && !hint.gpu_group_size) {
+        throw arbor_exception(arb::util::pprintf("unable to perform load balancing because {} has invalid suggested gpu_cell_group size of {}",
+                                                 kind, hint.gpu_group_size));
     }
+    if (hint.prefer_gpu && has_gpu) return std::make_pair(backend_kind::gpu, hint.gpu_group_size);
+    return std::make_pair(backend_kind::multicore, hint.cpu_group_size);
+}
 
-
-    // Create a flat vector of the cell kinds present on this node,
-    // partitioned such that kinds for which GPU implementation are
-    // listed before the others. This is a very primitive attempt at
-    // scheduling; the cell_groups that run on the GPU will be executed
-    // before other cell_groups, which is likely to be more efficient.
-    //
-    // TODO: This creates an dependency between the load balancer and
-    // the threading internals. We need support for setting the priority
-    // of cell group updates according to rules such as the back end on
-    // which the cell group is running.
-
-    auto has_gpu_backend = [&ctx](cell_kind c) {
-        return cell_kind_supported(c, backend_kind::gpu, *ctx);
-    };
-
-    std::vector<cell_kind> kinds;
-    for (auto l: kind_lists) {
-        kinds.push_back(cell_kind(l.first));
+struct group_parameters {
+    cell_kind kind;
+    backend_kind backend;
+    size_t size;
+};
+
+// Create a flat vector of the cell kinds present on this node, sorted such that
+// kinds for which a GPU implementation exists are listed before the others. This
+// is a very primitive attempt at scheduling; the cell_groups that run on the GPU
+// will be executed before other cell_groups, which is likely to be more
+// efficient. NOTE(review): the sort below keys on `kind`, not `backend` -- confirm.
+//
+// TODO: This creates a dependency between the load balancer and the threading
+// internals. We need support for setting the priority of cell group updates
+// according to rules such as the back end on which the cell group is running.
+auto build_group_parameters(context ctx,
+                            const partition_hint_map& hint_map,
+                            const std::unordered_map<cell_kind, std::vector<cell_gid_type>>& kind_lists) {
+    std::vector<group_parameters> res;
+    for (const auto& [kind, _gids]: kind_lists) {
+        const auto& [backend, group_size] = get_backend(ctx, kind, hint_map);
+        res.push_back({kind, backend, group_size});
     }
-    std::partition(kinds.begin(), kinds.end(), has_gpu_backend);
+    util::sort_by(res, [](const auto& p) { return p.kind; });
+    return res;
+}
 
-    std::vector<group_description> groups;
-    for (auto k: kinds) {
-        partition_hint hint;
-        if (auto opt_hint = util::value_by_key(hint_map, k)) {
-            hint = opt_hint.value();
-            if(!hint.cpu_group_size) {
-                throw arbor_exception(arb::util::pprintf("unable to perform load balancing because {} has invalid suggested cpu_cell_group size of {}", k, hint.cpu_group_size));
-            }
-            if(hint.prefer_gpu && !hint.gpu_group_size) {
-                throw arbor_exception(arb::util::pprintf("unable to perform load balancing because {} has invalid suggested gpu_cell_group size of {}", k, hint.gpu_group_size));
-            }
-        }
+// Build the list of components (GJ-connected groups and single cells) local to this domain.
+// NOTE This lives in its own function so the global GJ table is destroyed on return, limiting RSS.
+auto build_local_components(const recipe& rec, context ctx) {
+    const auto global_gj_connection_table = build_global_gj_connection_table(rec);
+    const auto local_gid_range = make_local_gid_range(ctx, rec.num_cells());
+    return build_components(global_gj_connection_table, local_gid_range);
+}
 
-        backend_kind backend = backend_kind::multicore;
-        std::size_t group_size = hint.cpu_group_size;
+} // namespace
 
-        if (hint.prefer_gpu && gpu_avail && has_gpu_backend(k)) {
-            backend = backend_kind::gpu;
-            group_size = hint.gpu_group_size;
+ARB_ARBOR_API domain_decomposition partition_load_balance(const recipe& rec,
+                                                          context ctx,
+                                                          const partition_hint_map& hint_map) {
+    const auto components = build_local_components(rec, ctx);
+
+    std::vector<cell_gid_type> local_gids;
+    std::unordered_map<cell_kind, std::vector<cell_gid_type>> kind_lists;
+
+    for (auto idx: util::make_span(components.size())) {
+        const auto& component = components[idx];
+        const auto& first_gid  = component.front();
+        auto kind = rec.get_cell_kind(first_gid);
+        for (auto gid: component) {
+            if (rec.get_cell_kind(gid) != kind) throw gj_kind_mismatch(gid, first_gid);
+            local_gids.push_back(gid);
         }
+        kind_lists[kind].push_back((cell_gid_type) idx);
+    }
 
+    auto kinds = build_group_parameters(ctx, hint_map, kind_lists);
+
+    std::vector<group_description> groups;
+    for (const auto& params: kinds) {
         std::vector<cell_gid_type> group_elements;
-        // group_elements are sorted such that the gids of all members of a super_cell are consecutive.
-        for (auto cell: kind_lists[k]) {
-            if (!cell.is_super_cell) {
-                group_elements.push_back(cell.id);
-            } else {
-                if (group_elements.size() + super_cells[cell.id].size() > group_size && !group_elements.empty()) {
-                    groups.emplace_back(k, std::move(group_elements), backend);
-                    group_elements.clear();
-                }
-                for (auto gid: super_cells[cell.id]) {
-                    group_elements.push_back(gid);
-                }
-            }
-            if (group_elements.size()>=group_size) {
-                groups.emplace_back(k, std::move(group_elements), backend);
+        // group_elements are sorted such that the gids of all members of a component are consecutive.
+        for (auto cell: kind_lists[params.kind]) {
+            const auto& component = components[cell];
+            // adding the current component would exceed the allotted size, so move the
+            // group built so far to the list of groups and start a new one.
+            if (group_elements.size() + component.size() > params.size && !group_elements.empty()) {
+                groups.emplace_back(params.kind, std::move(group_elements), params.backend);
                 group_elements.clear();
             }
+            // we are clear to add the current component. NOTE this may exceed
+            // the allotted size, but only by the minimal amount manageable
+            group_elements.insert(group_elements.end(), component.begin(), component.end());
         }
-        if (!group_elements.empty()) {
-            groups.emplace_back(k, std::move(group_elements), backend);
-        }
+        // we may have a trailing, incomplete group, so add it.
+        if (!group_elements.empty()) groups.emplace_back(params.kind, std::move(group_elements), params.backend);
     }
 
-    // Exchange gid list with all other nodes
-    // global all-to-all to gather a local copy of the global gid list on each node.
-    auto global_gids = dist->gather_gids(local_gids);
-
-    return domain_decomposition(rec, ctx, groups);
+    return {rec, ctx, groups};
 }
-
 } // namespace arb
-
diff --git a/test/unit/test_domain_decomposition.cpp b/test/unit/test_domain_decomposition.cpp
index 10e0b754c6b0d4b9ec199d2008826b7d56383c10..915a346ae922e32b52997298681a519980c4381f 100644
--- a/test/unit/test_domain_decomposition.cpp
+++ b/test/unit/test_domain_decomposition.cpp
@@ -5,11 +5,11 @@
 #include <arbor/domain_decomposition.hpp>
 #include <arbor/load_balance.hpp>
 #include <arbor/version.hpp>
-
 #include <arborenv/default_env.hpp>
 
 #include "util/span.hpp"
-
+#include "../execution_context.hpp"
+#include "../distributed_context.hpp"
 #include "../common_cells.hpp"
 #include "../simple_recipes.hpp"
 
@@ -28,122 +28,152 @@ using arb::util::make_span;
 // partition_load_balance into components that can be tested in isolation.
 
 namespace {
-    // Dummy recipes types for testing.
+// Dummy recipe types for testing.
 
-    struct dummy_cell {};
-    using homo_recipe = homogeneous_recipe<cell_kind::cable, dummy_cell>;
+struct dummy_cell {};
+using homo_recipe = homogeneous_recipe<cell_kind::cable, dummy_cell>;
 
-    // Heterogenous cell population of cable and spike source cells.
-    // Interleaved so that cells with even gid are cable cells, and odd gid are
-    // spike source cells.
-    class hetero_recipe: public recipe {
-    public:
-        hetero_recipe(cell_size_type s): size_(s) {}
+// Heterogeneous cell population of cable and spike source cells.
+// Interleaved so that cells with even gid are cable cells, and odd gid are
+// spike source cells.
+class hetero_recipe: public recipe {
+public:
+    hetero_recipe(cell_size_type s): size_(s) {} // s: number of cells reported by num_cells()
 
-        cell_size_type num_cells() const override {
-            return size_;
-        }
+    cell_size_type num_cells() const override {
+        return size_;
+    }
 
-        util::unique_any get_cell_description(cell_gid_type) const override {
-            return {};
-        }
+    util::unique_any get_cell_description(cell_gid_type) const override { // always an empty description
+        return {};
+    }
 
-        cell_kind get_cell_kind(cell_gid_type gid) const override {
-            return gid%2?
-                cell_kind::spike_source:
-                cell_kind::cable;
-        }
+    cell_kind get_cell_kind(cell_gid_type gid) const override { // kind depends only on the parity of gid
+        return gid%2?
+            cell_kind::spike_source:
+            cell_kind::cable;
+    }
 
-    private:
-        cell_size_type size_;
-    };
+private:
+    cell_size_type size_; // total cell count, fixed at construction
+};
 
-    class gap_recipe: public recipe {
-    public:
-        gap_recipe(bool full_connected): fully_connected_(full_connected) {}
+class gap_recipe: public recipe { // all-cable recipe of 15 cells with a hard-coded gap-junction table
+public:
+    gap_recipe(bool full_connected): fully_connected_(full_connected) {} // flag selects which table gap_junctions_on() uses
 
-        cell_size_type num_cells() const override {
-            return size_;
-        }
+    cell_size_type num_cells() const override {
+        return size_;
+    }
 
-        arb::util::unique_any get_cell_description(cell_gid_type) const override {
-            auto c = arb::make_cell_soma_only(false);
-            c.decorations.place(mlocation{0,1}, junction("gj"), "gj");
-            return {arb::cable_cell(c)};
-        }
+    arb::util::unique_any get_cell_description(cell_gid_type) const override { // same description for every gid
+        auto c = arb::make_cell_soma_only(false);
+        c.decorations.place(mlocation{0,1}, junction("gj"), "gj"); // one junction, labelled "gj"
+        return {arb::cable_cell(c)};
+    }
 
-        cell_kind get_cell_kind(cell_gid_type gid) const override {
-            return cell_kind::cable;
-        }
-        std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override {
-            switch (gid) {
-                case 0:  return {gap_junction_connection({13, "gj"}, {"gj"}, 0.1)};
-                case 2:  return {gap_junction_connection({7,  "gj"}, {"gj"}, 0.1)};
-                case 3:  return {gap_junction_connection({8, "gj"}, {"gj"}, 0.1)};
-                case 4: {
-                    if (!fully_connected_) return {gap_junction_connection({9, "gj"}, {"gj"}, 0.1)};
-                    return {
-                        gap_junction_connection({8, "gj"}, {"gj"}, 0.1),
-                        gap_junction_connection({9, "gj"}, {"gj"}, 0.1)
-                    };
-                }
-                case 7: {
-                    if (!fully_connected_) return {};
-                    return {
-                        gap_junction_connection({2, "gj"}, {"gj"}, 0.1),
-                        gap_junction_connection({11, "gj"}, {"gj"}, 0.1)
-                    };
-                }
-                case 8: {
-                    if (!fully_connected_) return {gap_junction_connection({4, "gj"}, {"gj"}, 0.1)};
-                    return {
-                        gap_junction_connection({3, "gj"}, {"gj"}, 0.1),
-                        gap_junction_connection({4, "gj"}, {"gj"}, 0.1)
-                    };
-                }
-                case 9: {
-                    if (!fully_connected_) return {};
-                    return {gap_junction_connection({4, "gj"}, {"gj"}, 0.1)};
-                }
-                case 11: return {gap_junction_connection({7, "gj"}, {"gj"}, 0.1)};
-                case 13: {
-                    if (!fully_connected_) return {};
-                    return { gap_junction_connection({0, "gj"}, {"gj"}, 0.1)};
-                }
-                default: return {};
+    cell_kind get_cell_kind(cell_gid_type gid) const override {
+        return cell_kind::cable;
+    }
+    std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override { // connections listed for gid
+        switch (gid) {
+            case 0:  return {gap_junction_connection({13, "gj"}, {"gj"}, 0.1)};
+            case 2:  return {gap_junction_connection({7,  "gj"}, {"gj"}, 0.1)};
+            case 3:  return {gap_junction_connection({8, "gj"}, {"gj"}, 0.1)};
+            case 4: {
+                if (!fully_connected_) return {gap_junction_connection({9, "gj"}, {"gj"}, 0.1)};
+                return {
+                    gap_junction_connection({8, "gj"}, {"gj"}, 0.1),
+                    gap_junction_connection({9, "gj"}, {"gj"}, 0.1)
+                };
             }
+            case 7: {
+                if (!fully_connected_) return {};
+                return {
+                    gap_junction_connection({2, "gj"}, {"gj"}, 0.1),
+                    gap_junction_connection({11, "gj"}, {"gj"}, 0.1)
+                };
+            }
+            case 8: {
+                if (!fully_connected_) return {gap_junction_connection({4, "gj"}, {"gj"}, 0.1)};
+                return {
+                    gap_junction_connection({3, "gj"}, {"gj"}, 0.1),
+                    gap_junction_connection({4, "gj"}, {"gj"}, 0.1)
+                };
+            }
+            case 9: {
+                if (!fully_connected_) return {};
+                return {gap_junction_connection({4, "gj"}, {"gj"}, 0.1)};
+            }
+            case 11: return {gap_junction_connection({7, "gj"}, {"gj"}, 0.1)};
+            case 13: {
+                if (!fully_connected_) return {};
+                return { gap_junction_connection({0, "gj"}, {"gj"}, 0.1)};
+            }
+            default: return {}; // every other gid has no gap junctions
         }
+    }
 
-    private:
-        bool fully_connected_ = true;
-        cell_size_type size_ = 15;
-    };
+private:
+    bool fully_connected_ = true; // false: gids 7, 9 and 13 list no connections; 4 and 8 list one each
+    cell_size_type size_ = 15; // fixed cell count returned by num_cells()
+};
 
-    class custom_gap_recipe: public recipe {
-    public:
-        custom_gap_recipe(cell_size_type ncells, std::vector<std::vector<gap_junction_connection>> gj_conns):
-        size_(ncells), gj_conns_(std::move(gj_conns)){}
+class custom_gap_recipe: public recipe { // all-cable recipe whose cell count and gap junctions are supplied by the caller
+public:
+    custom_gap_recipe(cell_size_type ncells, std::vector<std::vector<gap_junction_connection>> gj_conns): // gj_conns is indexed by gid
+    size_(ncells), gj_conns_(std::move(gj_conns)){}
 
-        cell_size_type num_cells() const override {
-            return size_;
-        }
+    cell_size_type num_cells() const override {
+        return size_;
+    }
 
-        arb::util::unique_any get_cell_description(cell_gid_type) const override {
-            auto c = arb::make_cell_soma_only(false);
-            c.decorations.place(mlocation{0,1}, junction("gj"), "gj");
-            return {arb::cable_cell(c)};
-        }
+    arb::util::unique_any get_cell_description(cell_gid_type) const override { // same description for every gid
+        auto c = arb::make_cell_soma_only(false);
+        c.decorations.place(mlocation{0,1}, junction("gj"), "gj"); // one junction, labelled "gj"
+        return {arb::cable_cell(c)};
+    }
+
+    cell_kind get_cell_kind(cell_gid_type gid) const override {
+        return cell_kind::cable;
+    }
+    std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override {
+        return gj_conns_[gid]; // unchecked index: gid must be < gj_conns_.size()
+    }
+private:
+    cell_size_type size_ = 7; // default is replaced by ncells in the constructor above
+    std::vector<std::vector<gap_junction_connection>> gj_conns_; // gj_conns_[gid] is the list returned for gid
+};
+
+struct unimplemented: std::runtime_error { // thrown by the dummy_context stubs below
+    unimplemented(const std::string& f): std::runtime_error{f} {} // f becomes the what() message
+};
+
+struct dummy_context { // stand-in distributed backend: fixed id and size, performs no communication
+    dummy_context(int i, int s): size_{s}, id_{i} {} // i: value returned by id(); s: value returned by size()
+
+    int size_ = 1;
+    int id_ = 0;
+    // Gather operations are not implemented: each one throws `unimplemented` carrying its own function name.
+    gathered_vector<spike> gather_spikes(const std::vector<spike>&) const { throw unimplemented{__FUNCTION__}; }
+    std::vector<spike> remote_gather_spikes(const std::vector<spike>&) const { throw unimplemented{__FUNCTION__}; }
+    gathered_vector<cell_gid_type> gather_gids(const std::vector<cell_gid_type>& local_gids) const { throw unimplemented{__FUNCTION__}; }
+    void remote_ctrl_send_continue(const epoch&) const {} // no-op
+    void remote_ctrl_send_done() const {} // no-op
+    cell_label_range gather_cell_label_range(const cell_label_range& local_ranges) const { throw unimplemented{__FUNCTION__}; }
+    cell_labels_and_gids gather_cell_labels_and_gids(const cell_labels_and_gids& local_labels_and_gids) const { throw unimplemented{__FUNCTION__}; }
+    template <typename T> std::vector<T> gather(T value, int) const { throw unimplemented{__FUNCTION__}; }
+
+    int id() const { return id_; }
+    int size() const { return size_; }
+    // Reductions hand back the caller's value unchanged; barrier() does nothing.
+    template <typename T> T min(T value) const { return value; }
+    template <typename T> T max(T value) const { return value; }
+    template <typename T> T sum(T value) const { return value; }
+    void barrier() const {}
+    std::string name() const { return "dummy"; }
+};
 
-        cell_kind get_cell_kind(cell_gid_type gid) const override {
-            return cell_kind::cable;
-        }
-        std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override {
-            return gj_conns_[gid];
-        }
-    private:
-        cell_size_type size_ = 7;
-        std::vector<std::vector<gap_junction_connection>> gj_conns_;
-    };
 }
 
 // test assumes one domain
@@ -442,7 +472,7 @@ TEST(domain_decomposition, unidirectional_gj_recipe) {
                 {},
                 {},
                 {},
-                {gap_junction_connection({4, "gj"}, {"gj"}, 0.1)},
+                {gap_junction_connection({4, "gj"}, {"gj"}, 0.1), gap_junction_connection({4, "gj"}, {"gj"}, 0.1)},
                 {},
                 {},
                 {gap_junction_connection({5, "gj"}, {"gj"}, 0.1), gap_junction_connection({7, "gj"}, {"gj"}, 0.1)},
@@ -573,3 +603,65 @@ TEST(domain_decomposition, invalid) {
         EXPECT_THROW(domain_decomposition(rec, ctx, groups), invalid_gj_cell_group);
     }
 }
+
+struct gj_symmetric: public recipe { // num_ranks copies of one 10-cell gap-junction pattern, gids laid out copy after copy
+    gj_symmetric(unsigned num_ranks, bool fully_connected):
+        ncopies_(num_ranks),
+        fully_connected_(fully_connected) {}
+
+    cell_size_type num_cells_per_rank() const { return size_; }
+    cell_size_type num_cells() const override { return size_*ncopies_; }
+    arb::util::unique_any get_cell_description(cell_gid_type) const override { return {}; }
+    cell_kind get_cell_kind(cell_gid_type gid) const override { return cell_kind::cable; }
+
+    std::vector<gap_junction_connection> gap_junctions_on(cell_gid_type gid) const override {
+        unsigned shift = (gid/size_)*size_; // first gid of the copy containing gid, so peers stay within that copy
+        switch (gid % size_) { // position of gid inside its copy
+            case 1 :  {
+                if (!fully_connected_) return {};
+                return {gap_junction_connection({7 + shift, "gj"}, {"gj"}, 0.1)};
+            }
+            case 2 :  {
+                if (!fully_connected_) return {};
+                return {
+                    gap_junction_connection({6 + shift, "gj"}, {"gj"}, 0.1),
+                    gap_junction_connection({9 + shift, "gj"}, {"gj"}, 0.1)
+                };
+            }
+            case 6 :  return {
+                gap_junction_connection({2 + shift, "gj"}, {"gj"}, 0.1),
+                gap_junction_connection({7 + shift, "gj"}, {"gj"}, 0.1)
+            };
+            case 7 :  {
+                if (!fully_connected_)  {
+                    return {gap_junction_connection({1 + shift, "gj"}, {"gj"}, 0.1)};
+                }
+                return {
+                    gap_junction_connection({6 + shift, "gj"}, {"gj"}, 0.1),
+                    gap_junction_connection({1 + shift, "gj"}, {"gj"}, 0.1)
+                };
+            }
+            case 9 :  return { gap_junction_connection({2 + shift, "gj"}, {"gj"}, 0.1)};
+            default : return {}; // remaining positions have no gap junctions
+        }
+    }
+
+    cell_size_type size_ = 10; // cells per copy
+    unsigned ncopies_; // number of copies, set from num_ranks
+    bool fully_connected_; // false: positions 1 and 2 list no connections and 7 lists only its link to 1
+};
+
+TEST(domain_decomposition, symmetric_groups) { // calls partition_load_balance for every (rank, nranks) pair with 1 <= nranks < 20
+    auto ctx = make_context();
+    for (int nranks = 1; nranks < 20; ++nranks) {
+        for (int rank = 0; rank < nranks; ++rank) {
+            ctx->distributed = std::make_shared<distributed_context>(dummy_context{rank, nranks}); // swap in the dummy backend for this rank
+            for (const auto& R: {gj_symmetric(nranks, true), gj_symmetric(nranks, false)}) { // both connectivity variants
+                // NOTE: This is a bit silly, but allows us to test _most_ of
+                // the invariants without proper MPI support. If we could get `gather_gids` to
+                // work and return the expected values we could even test all of them.
+                EXPECT_THROW(partition_load_balance(R, {ctx}), unimplemented); // presumably raised by one of dummy_context's gather stubs
+            }
+        }
+    }
+}