diff --git a/miniapp/miniapp.cpp b/miniapp/miniapp.cpp
index b4af4cb4cee64330ad02a018e01d0666afee3422..05b1682b431208f8472911da0f255eb772f86e45 100644
--- a/miniapp/miniapp.cpp
+++ b/miniapp/miniapp.cpp
@@ -15,6 +15,7 @@
 #include <hardware/gpu.hpp>
 #include <hardware/node_info.hpp>
 #include <io/exporter_spike_file.hpp>
+#include <load_balance.hpp>
 #include <model.hpp>
 #include <profiling/profiler.hpp>
 #include <profiling/meter_manager.hpp>
@@ -92,7 +93,7 @@ int main(int argc, char** argv) {
             options.file_extension, options.over_write);
     };
 
-    auto decomp = domain_decomposition(*recipe, nd);
+    auto decomp = partition_load_balance(*recipe, nd);
     model m(*recipe, decomp);
 
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index c12c49829df10687bcd690a5e7682c7e01988dee..d189142b6d485c33290f65b478348eed127668a4 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -3,19 +3,20 @@ set(BASE_SOURCES
     common_types_io.cpp
     cell.cpp
     event_binner.cpp
+    hardware/affinity.cpp
+    hardware/gpu.cpp
+    hardware/memory.cpp
+    hardware/node_info.cpp
+    hardware/power.cpp
     model.cpp
     morphology.cpp
     parameter_list.cpp
+    partition_load_balance.cpp
     profiling/memory_meter.cpp
     profiling/meter_manager.cpp
     profiling/power_meter.cpp
     profiling/profiler.cpp
     swcio.cpp
-    hardware/affinity.cpp
-    hardware/gpu.cpp
-    hardware/memory.cpp
-    hardware/node_info.cpp
-    hardware/power.cpp
     threading/threading.cpp
     util/debug.cpp
     util/hostname.cpp
diff --git a/src/communication/communicator.hpp b/src/communication/communicator.hpp
index 553e1bd7c851b29ba5f44d76104c131d503460de..e6a5073969e075a785f5cfcadbddf7c292ab7918 100644
--- a/src/communication/communicator.hpp
+++ b/src/communication/communicator.hpp
@@ -51,7 +51,7 @@ public:
     explicit communicator(const recipe& rec, const domain_decomposition& dom_dec) {
         using util::make_span;
         num_domains_ = comms_.size();
-        num_local_groups_ = dom_dec.num_local_groups();
+        num_local_groups_ = dom_dec.groups.size();
 
         // For caching information about each cell
         struct gid_info {
@@ -72,13 +72,13 @@
         // Also the count of presynaptic sources from each domain
         //  -> src_counts: array with one entry for each domain
         std::vector<gid_info> gid_infos;
-        gid_infos.reserve(dom_dec.num_local_cells());
+        gid_infos.reserve(dom_dec.num_local_cells);
 
         cell_local_size_type n_cons = 0;
         std::vector<unsigned> src_domains;
         std::vector<cell_size_type> src_counts(num_domains_);
 
         for (auto i: make_span(0, num_local_groups_)) {
-            const auto& group = dom_dec.get_group(i);
+            const auto& group = dom_dec.groups[i];
             for (auto gid: group.gids) {
                 gid_info info(gid, i, rec.connections_on(gid));
                 n_cons += info.conns.size();
diff --git a/src/domain_decomposition.hpp b/src/domain_decomposition.hpp
index 796d5dc40e1d6fd8c917f00240bd2eb081d2ea36..9655e4ea3b7adb6ebfd2f82fdd85c6204e22f32a 100644
--- a/src/domain_decomposition.hpp
+++ b/src/domain_decomposition.hpp
@@ -1,8 +1,9 @@
 #pragma once
 
+#include <functional>
 #include <type_traits>
-#include <vector>
 #include <unordered_map>
+#include <vector>
 
 #include <backends.hpp>
 #include <common_types.hpp>
@@ -23,122 +24,54 @@ inline bool has_gpu_backend(cell_kind k) {
     return false;
 }
 
-/// Utility type for meta data for a local cell group.
+/// Meta data for a local cell group.
 struct group_description {
+    /// The kind of cell in the group. All cells in a cell_group have the same type.
     const cell_kind kind;
+
+    /// The gids of the cells in the cell_group, sorted in ascending order.
     const std::vector<cell_gid_type> gids;
+
+    /// The back end on which the cell_group is to run.
     const backend_kind backend;
 
     group_description(cell_kind k, std::vector<cell_gid_type> g, backend_kind b):
         kind(k), gids(std::move(g)), backend(b)
-    {}
-};
-
-class domain_decomposition {
-public:
-    domain_decomposition(const recipe& rec, hw::node_info nd):
-        node_(nd) {
-        using kind_type = std::underlying_type<cell_kind>::type;
-        using util::make_span;
-
-        num_domains_ = communication::global_policy::size();
-        domain_id_ = communication::global_policy::id();
-        num_global_cells_ = rec.num_cells();
-
-        auto dom_size = [this](unsigned dom) -> cell_gid_type {
-            const cell_gid_type B = num_global_cells_/num_domains_;
-            const cell_gid_type R = num_global_cells_ - num_domains_*B;
-            return B + (dom<R);
-        };
-
-        // TODO: load balancing logic will be refactored into its own class,
-        // and the domain decomposition will become a much simpler representation
-        // of the result distribution of cells over domains.
-
-        // Global load balance
-
-        gid_part_ = make_partition(
-            gid_divisions_, transform_view(make_span(0, num_domains_), dom_size));
-
-        // Local load balance
-
-        std::unordered_map<kind_type, std::vector<cell_gid_type>> kind_lists;
-        for (auto gid: make_span(gid_part_[domain_id_])) {
-            kind_lists[rec.get_cell_kind(gid)].push_back(gid);
-        }
-
-        // Create a flat vector of the cell kinds present on this node,
-        // partitioned such that kinds for which GPU implementation are
-        // listed before the others. This is a very primitive attempt at
-        // scheduling; the cell_groups that run on the GPU will be executed
-        // before other cell_groups, which is likely to be more efficient.
-        //
-        // TODO: This creates an dependency between the load balancer and
-        // the threading internals. We need support for setting the priority
-        // of cell group updates according to rules such as the back end on
-        // which the cell group is running.
-        std::vector<cell_kind> kinds;
-        for (auto l: kind_lists) {
-            kinds.push_back(cell_kind(l.first));
-        }
-        std::partition(kinds.begin(), kinds.end(), has_gpu_backend);
-
-        for (auto k: kinds) {
-            // put all cells into a single cell group on the gpu if possible
-            if (node_.num_gpus && has_gpu_backend(k)) {
-                groups_.push_back({k, std::move(kind_lists[k]), backend_kind::gpu});
-            }
-            // otherwise place into cell groups of size 1 on the cpu cores
-            else {
-                for (auto gid: kind_lists[k]) {
-                    groups_.push_back({k, {gid}, backend_kind::multicore});
-                }
-            }
-        }
-    }
+    {
+        EXPECTS(std::is_sorted(gids.begin(), gids.end()));
+    }
+};
 
-    int gid_domain(cell_gid_type gid) const {
-        EXPECTS(gid<num_global_cells_);
-        return gid_part_.index(gid);
+/// Meta data that describes a domain decomposition.
+/// A domain_decomposition type is responsible solely for describing the
+/// distribution of cells across cell_groups and domains.
+/// A load balancing algorithm is responsible for generating the
+/// domain_decomposition, e.g. nest::mc::partition_load_balance().
+struct domain_decomposition {
+    /// Tests whether a gid is on the local domain.
+    bool is_local_gid(cell_gid_type gid) const {
+        return gid_domain(gid)==domain_id;
     }
 
-    /// Returns the total number of cells in the global model.
-    cell_size_type num_global_cells() const {
-        return num_global_cells_;
-    }
+    /// Return the domain id of cell with gid.
+    /// Supplied by the load balancing algorithm that generates the domain
+    /// decomposition.
+    std::function<int(cell_gid_type)> gid_domain;
 
-    /// Returns the number of cells on the local domain.
-    cell_size_type num_local_cells() const {
-        auto rng = gid_part_[domain_id_];
-        return rng.second - rng.first;
-    }
+    /// Number of distributed domains
+    int num_domains;
 
-    /// Returns the number of cell groups on the local domain.
-    cell_size_type num_local_groups() const {
-        return groups_.size();
-    }
+    /// The index of the local domain
+    int domain_id;
 
-    /// Returns meta data for a local cell group.
-    const group_description& get_group(cell_size_type i) const {
-        EXPECTS(i<num_local_groups());
-        return groups_[i];
-    }
+    /// Total number of cells in the local domain
+    cell_size_type num_local_cells;
 
-    /// Tests whether a gid is on the local domain.
-    bool is_local_gid(cell_gid_type gid) const {
-        return algorithms::in_interval(gid, gid_part_[domain_id_]);
-    }
+    /// Total number of cells in the global model (sum over all domains)
+    cell_size_type num_global_cells;
 
-private:
-    int num_domains_;
-    int domain_id_;
-    hw::node_info node_;
-    cell_size_type num_global_cells_;
-    std::vector<cell_gid_type> gid_divisions_;
-    decltype(util::make_partition(gid_divisions_, gid_divisions_)) gid_part_;
-    std::vector<cell_kind> group_kinds_;
-    std::vector<group_description> groups_;
+    /// Descriptions of the cell groups on the local domain
+    std::vector<group_description> groups;
 };
 
 } // namespace mc
diff --git a/src/load_balance.hpp b/src/load_balance.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..6e77ca255d03537425c24a4daca3c6274217dae3
--- /dev/null
+++ b/src/load_balance.hpp
@@ -0,0 +1,12 @@
+#include <communication/global_policy.hpp>
+#include <domain_decomposition.hpp>
+#include <hardware/node_info.hpp>
+#include <recipe.hpp>
+
+namespace nest {
+namespace mc {
+
+domain_decomposition partition_load_balance(const recipe& rec, hw::node_info nd);
+
+} // namespace mc
+} // namespace nest
diff --git a/src/model.cpp b/src/model.cpp
index 43e56ba56afc9d57170abc4a48f5d3b0eee3f138..0e84a1b8a8137f012fdf6995b5b39cca99daafe3 100644
--- a/src/model.cpp
+++ b/src/model.cpp
@@ -17,18 +17,18 @@ namespace mc {
 
 model::model(const recipe& rec, const domain_decomposition& decomp):
     communicator_(rec, decomp)
 {
-    for (auto i: util::make_span(0, decomp.num_local_groups())) {
-        for (auto gid: decomp.get_group(i).gids) {
+    for (auto i: util::make_span(0, decomp.groups.size())) {
+        for (auto gid: decomp.groups[i].gids) {
             gid_groups_[gid] = i;
         }
     }
 
     // Generate the cell groups in parallel, with one task per cell group.
-    cell_groups_.resize(decomp.num_local_groups());
+    cell_groups_.resize(decomp.groups.size());
     threading::parallel_for::apply(0, cell_groups_.size(),
         [&](cell_gid_type i) {
             PE("setup", "cells");
-            cell_groups_[i] = cell_group_factory(rec, decomp.get_group(i));
+            cell_groups_[i] = cell_group_factory(rec, decomp.groups[i]);
             PL(2);
         });
 
diff --git a/src/partition_load_balance.cpp b/src/partition_load_balance.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2147c63febc98da2885f790ac251bc88045ea678
--- /dev/null
+++ b/src/partition_load_balance.cpp
@@ -0,0 +1,98 @@
+#include <communication/global_policy.hpp>
+#include <domain_decomposition.hpp>
+#include <hardware/node_info.hpp>
+#include <recipe.hpp>
+
+namespace nest {
+namespace mc {
+
+domain_decomposition partition_load_balance(const recipe& rec, hw::node_info nd) {
+    struct partition_gid_domain {
+        partition_gid_domain(std::vector<cell_gid_type> divs):
+            gid_divisions(std::move(divs))
+        {}
+
+        int operator()(cell_gid_type gid) const {
+            auto gid_part = util::partition_view(gid_divisions);
+            return gid_part.index(gid);
+        }
+
+        const std::vector<cell_gid_type> gid_divisions;
+    };
+
+    using kind_type = std::underlying_type<cell_kind>::type;
+    using util::make_span;
+
+    unsigned num_domains = communication::global_policy::size();
+    unsigned domain_id = communication::global_policy::id();
+    auto num_global_cells = rec.num_cells();
+
+    auto dom_size = [&](unsigned dom) -> cell_gid_type {
+        const cell_gid_type B = num_global_cells/num_domains;
+        const cell_gid_type R = num_global_cells - num_domains*B;
+        return B + (dom<R);
+    };
+
+    // Global load balance
+
+    std::vector<cell_gid_type> gid_divisions;
+    auto gid_part = make_partition(
+        gid_divisions, transform_view(make_span(0, num_domains), dom_size));
+
+    // Local load balance
+
+    std::unordered_map<kind_type, std::vector<cell_gid_type>> kind_lists;
+    for (auto gid: make_span(gid_part[domain_id])) {
+        kind_lists[rec.get_cell_kind(gid)].push_back(gid);
+    }
+
+    // Create a flat vector of the cell kinds present on this node,
+    // partitioned such that kinds with a GPU implementation are
+    // listed before the others. This is a very primitive attempt at
+    // scheduling; the cell_groups that run on the GPU will be executed
+    // before other cell_groups, which is likely to be more efficient.
+    //
+    // TODO: This creates a dependency between the load balancer and
+    // the threading internals. We need support for setting the priority
+    // of cell group updates according to rules such as the back end on
+    // which the cell group is running.
+    std::vector<cell_kind> kinds;
+    for (auto l: kind_lists) {
+        kinds.push_back(cell_kind(l.first));
+    }
+    std::partition(kinds.begin(), kinds.end(), has_gpu_backend);
+
+    std::vector<group_description> groups;
+    for (auto k: kinds) {
+        // put all cells into a single cell group on the gpu if possible
+        if (nd.num_gpus && has_gpu_backend(k)) {
+            groups.push_back({k, std::move(kind_lists[k]), backend_kind::gpu});
+        }
+        // otherwise place into cell groups of size 1 on the cpu cores
+        else {
+            for (auto gid: kind_lists[k]) {
+                groups.push_back({k, {gid}, backend_kind::multicore});
+            }
+        }
+    }
+
+    // calculate the number of local cells
+    auto rng = gid_part[domain_id];
+    cell_size_type num_local_cells = rng.second - rng.first;
+
+    domain_decomposition d;
+    d.num_domains = num_domains;
+    d.domain_id = domain_id;
+    d.num_local_cells = num_local_cells;
+    d.num_global_cells = num_global_cells;
+    d.groups = std::move(groups);
+    d.gid_domain = partition_gid_domain(std::move(gid_divisions));
+
+    return d;
+
+    //return domain_decomposition(num_domains, domain_id, num_local_cells, num_global_cells, std::move(groups));
+}
+
+} // namespace mc
+} // namespace nest
+
diff --git a/tests/global_communication/test_communicator.cpp b/tests/global_communication/test_communicator.cpp
index 5a423b932be3523e3154d13f8d41359a55bc7b2a..154c458f3d42e61a94075793081b7bc5d8fa8c2b 100644
--- a/tests/global_communication/test_communicator.cpp
+++ b/tests/global_communication/test_communicator.cpp
@@ -9,6 +9,7 @@
 #include <communication/communicator.hpp>
 #include <communication/global_policy.hpp>
 #include <hardware/node_info.hpp>
+#include <load_balance.hpp>
 #include <util/filter.hpp>
 #include <util/rangeutil.hpp>
 #include <util/span.hpp>
@@ -296,8 +297,8 @@ namespace {
     // make a list of the gids on the local domain
     std::vector<cell_gid_type> get_gids(const domain_decomposition& D) {
         std::vector<cell_gid_type> gids;
-        for (auto i: util::make_span(0, D.num_local_groups())) {
-            util::append(gids, D.get_group(i).gids);
+        for (auto i: util::make_span(0, D.groups.size())) {
+            util::append(gids, D.groups[i].gids);
         }
         return gids;
     }
@@ -306,8 +307,8 @@ namespace {
     std::unordered_map<cell_gid_type, cell_gid_type>
     get_group_map(const domain_decomposition& D) {
         std::unordered_map<cell_gid_type, cell_gid_type> map;
-        for (auto i: util::make_span(0, D.num_local_groups())) {
-            for (auto gid: D.get_group(i).gids) {
+        for (auto i: util::make_span(0, D.groups.size())) {
+            for (auto gid: D.groups[i].gids) {
                 map[gid] = i;
             }
         }
@@ -344,7 +345,7 @@ test_ring(const domain_decomposition& D, comm_type& C, F&& f) {
     // generate the events
     auto queues = C.make_event_queues(global_spikes);
-    if (queues.size() != D.num_local_groups()) { // one queue for each cell group
+    if (queues.size() != D.groups.size()) { // one queue for each cell group
         return ::testing::AssertionFailure() << "expect one event queue for each cell group";
     }
 
@@ -354,9 +355,9 @@ test_ring(const domain_decomposition& D, comm_type& C, F&& f) {
     // that gid. If so, look up the event queue of the cell_group of gid, and
     // search for the expected event.
     for (auto gid: gids) {
-        auto src = source_of(gid, D.num_global_cells());
+        auto src = source_of(gid, D.num_global_cells);
         if (f(src)) {
-            auto expected = expected_event_ring(gid, D.num_global_cells());
+            auto expected = expected_event_ring(gid, D.num_global_cells);
             auto grp = group_map[gid];
             auto& q = queues[grp];
             if (std::find(q.begin(), q.end(), expected)==q.end()) {
@@ -391,7 +392,7 @@ TEST(communicator, ring)
     auto R = ring_recipe(n_global);
     // use a node decomposition that reflects the resources available
     // on the node that the test is running on, including gpus.
-    auto D = domain_decomposition(R, hw::node_info());
+    const auto D = partition_load_balance(R, hw::node_info());
     auto C = communication::communicator<policy>(R, D);
 
     // every cell fires
@@ -421,7 +422,7 @@ test_all2all(const domain_decomposition& D, comm_type& C, F&& f) {
     std::reverse(local_spikes.begin(), local_spikes.end());
 
     std::vector<cell_gid_type> spike_gids = assign_from(
-        filter(make_span(0, D.num_global_cells()), f));
+        filter(make_span(0, D.groups.size()), f));
 
     // gather the global set of spikes
     auto global_spikes = C.exchange(local_spikes);
@@ -433,7 +434,7 @@ test_all2all(const domain_decomposition& D, comm_type& C, F&& f) {
     // generate the events
     auto queues = C.make_event_queues(global_spikes);
-    if (queues.size() != D.num_local_groups()) { // one queue for each cell group
+    if (queues.size() != D.groups.size()) { // one queue for each cell group
         return ::testing::AssertionFailure() << "expect one event queue for each cell group";
     }
 
@@ -461,7 +462,7 @@ test_all2all(const domain_decomposition& D, comm_type& C, F&& f) {
     int num_events = std::accumulate(queues.begin(), queues.end(), 0,
         [](int l, decltype(queues.front())& r){return l + r.size();});
 
-    int expected_events = D.num_global_cells()*spike_gids.size();
+    int expected_events = D.num_global_cells*spike_gids.size();
 
     EXPECT_EQ(expected_events, policy::sum(num_events));
     return ::testing::AssertionSuccess();
@@ -480,7 +481,7 @@ TEST(communicator, all2all)
     auto R = all2all_recipe(n_global);
     // use a node decomposition that reflects the resources available
     // on the node that the test is running on, including gpus.
-    auto D = domain_decomposition(R, hw::node_info());
+    const auto D = partition_load_balance(R, hw::node_info());
     auto C = communication::communicator<policy>(R, D);
 
     // every cell fires
diff --git a/tests/global_communication/test_domain_decomposition.cpp b/tests/global_communication/test_domain_decomposition.cpp
index 0e64b4f3a2384361e4d0276ba891276bb78f390c..c90114396fb3ce4cd3fac0806d70b8fbdd917e2a 100644
--- a/tests/global_communication/test_domain_decomposition.cpp
+++ b/tests/global_communication/test_domain_decomposition.cpp
@@ -9,6 +9,7 @@
 #include <communication/communicator.hpp>
 #include <communication/global_policy.hpp>
 #include <hardware/node_info.hpp>
+#include <load_balance.hpp>
 
 using namespace nest::mc;
 
@@ -89,11 +90,11 @@ TEST(domain_decomp, homogeneous) {
         // 10 cells per domain
        unsigned n_local = 10;
         unsigned n_global = n_local*N;
-        domain_decomposition D(homo_recipe(n_global), nd);
+        const auto D = partition_load_balance(homo_recipe(n_global), nd);
 
-        EXPECT_EQ(D.num_global_cells(), n_global);
-        EXPECT_EQ(D.num_local_cells(), n_local);
-        EXPECT_EQ(D.num_local_groups(), n_local);
+        EXPECT_EQ(D.num_global_cells, n_global);
+        EXPECT_EQ(D.num_local_cells, n_local);
+        EXPECT_EQ(D.groups.size(), n_local);
 
         auto b = I*n_local;
         auto e = (I+1)*n_local;
@@ -109,7 +110,7 @@ TEST(domain_decomp, homogeneous) {
         // Each group should also be tagged for cpu execution
         for (auto i: gids) {
             auto local_group = i-b;
-            auto& grp = D.get_group(local_group);
+            auto& grp = D.groups[local_group];
             EXPECT_EQ(grp.gids.size(), 1u);
             EXPECT_EQ(grp.gids.front(), unsigned(i));
             EXPECT_EQ(grp.backend, backend_kind::multicore);
@@ -123,11 +124,11 @@ TEST(domain_decomp, homogeneous) {
         // 10 cells per domain
         unsigned n_local = 10;
         unsigned n_global = n_local*N;
-        domain_decomposition D(homo_recipe(n_global), nd);
+        const auto D = partition_load_balance(homo_recipe(n_global), nd);
 
-        EXPECT_EQ(D.num_global_cells(), n_global);
-        EXPECT_EQ(D.num_local_cells(), n_local);
-        EXPECT_EQ(D.num_local_groups(), 1u);
+        EXPECT_EQ(D.num_global_cells, n_global);
+        EXPECT_EQ(D.num_local_cells, n_local);
+        EXPECT_EQ(D.groups.size(), 1u);
 
         auto b = I*n_local;
         auto e = (I+1)*n_local;
@@ -141,7 +142,7 @@ TEST(domain_decomp, homogeneous) {
 
         // Each cell group contains 1 cell of kind cable1d_neuron
         // Each group should also be tagged for cpu execution
-        auto grp = D.get_group(0u);
+        auto grp = D.groups[0u];
 
         EXPECT_EQ(grp.gids.size(), n_local);
         EXPECT_EQ(grp.gids.front(), b);
@@ -166,11 +167,11 @@ TEST(domain_decomp, heterogeneous) {
         const unsigned n_global = n_local*N;
         const unsigned n_local_grps = n_local; // 1 cell per group
         auto R = hetero_recipe(n_global);
-        domain_decomposition D(R, nd);
+        const auto D = partition_load_balance(R, nd);
 
-        EXPECT_EQ(D.num_global_cells(), n_global);
-        EXPECT_EQ(D.num_local_cells(), n_local);
-        EXPECT_EQ(D.num_local_groups(), n_local);
+        EXPECT_EQ(D.num_global_cells, n_global);
+        EXPECT_EQ(D.num_local_cells, n_local);
+        EXPECT_EQ(D.groups.size(), n_local);
 
         auto b = I*n_local;
         auto e = (I+1)*n_local;
@@ -187,7 +188,7 @@ TEST(domain_decomp, heterogeneous) {
         auto grps = util::make_span(0, n_local_grps);
         std::map<cell_kind, std::set<cell_gid_type>> kind_lists;
         for (auto i: grps) {
-            auto& grp = D.get_group(i);
+            auto& grp = D.groups[i];
             EXPECT_EQ(grp.gids.size(), 1u);
             kind_lists[grp.kind].insert(grp.gids.front());
             EXPECT_EQ(grp.backend, backend_kind::multicore);
diff --git a/tests/unit/test_domain_decomposition.cpp b/tests/unit/test_domain_decomposition.cpp
index b7e313d8bbe3cd34f21f7b6fb8d907a6894d76cb..90c00f5e6e4757369117f38f4f6bbf03fd63f165 100644
--- a/tests/unit/test_domain_decomposition.cpp
+++ b/tests/unit/test_domain_decomposition.cpp
@@ -3,6 +3,7 @@
 #include <backends.hpp>
 #include <domain_decomposition.hpp>
 #include <hardware/node_info.hpp>
+#include <load_balance.hpp>
 
 using namespace nest::mc;
 
@@ -72,14 +73,6 @@ namespace {
     };
 }
 
-// domain_decomposition interface:
-//     int gid_domain(cell_gid_type gid)
-//     cell_size_type num_global_cells()
-//     cell_size_type num_local_cells()
-//     cell_size_type num_local_groups()
-//     const group_description& get_group(cell_size_type i)
-//     bool is_local_gid(cell_gid_type i)
-
 TEST(domain_decomposition, homogenous_population)
 {
     {   // Test on a node with 1 cpu core and no gpus.
@@ -89,11 +82,11 @@
         hw::node_info nd(1, 0);
 
         unsigned num_cells = 10;
-        domain_decomposition D(homo_recipe(num_cells), nd);
+        const auto D = partition_load_balance(homo_recipe(num_cells), nd);
 
-        EXPECT_EQ(D.num_global_cells(), num_cells);
-        EXPECT_EQ(D.num_local_cells(), num_cells);
-        EXPECT_EQ(D.num_local_groups(), num_cells);
+        EXPECT_EQ(D.num_global_cells, num_cells);
+        EXPECT_EQ(D.num_local_cells, num_cells);
+        EXPECT_EQ(D.groups.size(), num_cells);
 
         auto gids = util::make_span(0, num_cells);
         for (auto gid: gids) {
@@ -105,7 +98,7 @@ TEST(domain_decomposition, homogenous_population)
         // Each cell group contains 1 cell of kind cable1d_neuron
         // Each group should also be tagged for cpu execution
         for (auto i: gids) {
-            auto& grp = D.get_group(i);
+            auto& grp = D.groups[i];
             EXPECT_EQ(grp.gids.size(), 1u);
             EXPECT_EQ(grp.gids.front(), unsigned(i));
             EXPECT_EQ(grp.backend, backend_kind::multicore);
@@ -117,11 +110,11 @@ TEST(domain_decomposition, homogenous_population)
         hw::node_info nd(1, 1);
 
         unsigned num_cells = 10;
-        domain_decomposition D(homo_recipe(num_cells), nd);
+        const auto D = partition_load_balance(homo_recipe(num_cells), nd);
 
-        EXPECT_EQ(D.num_global_cells(), num_cells);
-        EXPECT_EQ(D.num_local_cells(), num_cells);
-        EXPECT_EQ(D.num_local_groups(), 1u);
+        EXPECT_EQ(D.num_global_cells, num_cells);
+        EXPECT_EQ(D.num_local_cells, num_cells);
+        EXPECT_EQ(D.groups.size(), 1u);
 
         auto gids = util::make_span(0, num_cells);
         for (auto gid: gids) {
@@ -132,7 +125,7 @@ TEST(domain_decomposition, homogenous_population)
 
         // Each cell group contains 1 cell of kind cable1d_neuron
         // Each group should also be tagged for cpu execution
-        auto grp = D.get_group(0u);
+        auto grp = D.groups[0u];
 
         EXPECT_EQ(grp.gids.size(), num_cells);
         EXPECT_EQ(grp.gids.front(), 0u);
@@ -152,11 +145,11 @@ TEST(domain_decomposition, heterogenous_population)
 
         unsigned num_cells = 10;
         auto R = hetero_recipe(num_cells);
-        domain_decomposition D(R, nd);
+        const auto D = partition_load_balance(R, nd);
 
-        EXPECT_EQ(D.num_global_cells(), num_cells);
-        EXPECT_EQ(D.num_local_cells(), num_cells);
-        EXPECT_EQ(D.num_local_groups(), num_cells);
+        EXPECT_EQ(D.num_global_cells, num_cells);
+        EXPECT_EQ(D.num_local_cells, num_cells);
+        EXPECT_EQ(D.groups.size(), num_cells);
 
         auto gids = util::make_span(0, num_cells);
         for (auto gid: gids) {
@@ -170,7 +163,7 @@ TEST(domain_decomposition, heterogenous_population)
         auto grps = util::make_span(0, num_cells);
         std::map<cell_kind, std::set<cell_gid_type>> kind_lists;
         for (auto i: grps) {
-            auto& grp = D.get_group(i);
+            auto& grp = D.groups[i];
             EXPECT_EQ(grp.gids.size(), 1u);
             auto k = grp.kind;
             kind_lists[k].insert(grp.gids.front());
@@ -192,19 +185,19 @@ TEST(domain_decomposition, heterogenous_population)
 
         unsigned num_cells = 10;
         auto R = hetero_recipe(num_cells);
-        domain_decomposition D(R, nd);
+        const auto D = partition_load_balance(R, nd);
 
-        EXPECT_EQ(D.num_global_cells(), num_cells);
-        EXPECT_EQ(D.num_local_cells(), num_cells);
+        EXPECT_EQ(D.num_global_cells, num_cells);
+        EXPECT_EQ(D.num_local_cells, num_cells);
         // one cell group with num_cells/2 on gpu, and num_cells/2 groups on cpu
         auto expected_groups = num_cells/2+1;
-        EXPECT_EQ(D.num_local_groups(), expected_groups);
+        EXPECT_EQ(D.groups.size(), expected_groups);
 
         auto grps = util::make_span(0, expected_groups);
         unsigned ncells = 0;
         // iterate over each group and test its properties
         for (auto i: grps) {
-            auto& grp = D.get_group(i);
+            auto& grp = D.groups[i];
             auto k = grp.kind;
             if (k==cell_kind::cable1d_neuron) {
                 EXPECT_EQ(grp.backend, backend_kind::gpu);
diff --git a/tests/validation/validate_ball_and_stick.hpp b/tests/validation/validate_ball_and_stick.hpp
index 9c288fce885c29b1a4cb1ddacaaad24b3bdd12d4..21dd6469fd21b4bc167226696c9e4e654efda3a8 100644
--- a/tests/validation/validate_ball_and_stick.hpp
+++ b/tests/validation/validate_ball_and_stick.hpp
@@ -3,6 +3,7 @@
 #include <cell.hpp>
 #include <common_types.hpp>
 #include <fvm_multicell.hpp>
+#include <load_balance.hpp>
 #include <hardware/node_info.hpp>
 #include <model.hpp>
 #include <recipe.hpp>
@@ -51,7 +52,7 @@ void run_ncomp_convergence_test(
         }
     }
 
     hw::node_info nd(1, backend==backend_kind::gpu? 1: 0);
-    domain_decomposition decomp(singleton_recipe{c}, nd);
+    auto decomp = partition_load_balance(singleton_recipe{c}, nd);
     model m(singleton_recipe{c}, decomp);
 
     runner.run(m, ncomp, t_end, dt, exclude);
diff --git a/tests/validation/validate_kinetic.hpp b/tests/validation/validate_kinetic.hpp
index f4bab4bb6215952517703bd4901c754483b59f8d..75a4d5efe4d235e1ae90f1d9a24d71bf5d3f0bfb 100644
--- a/tests/validation/validate_kinetic.hpp
+++ b/tests/validation/validate_kinetic.hpp
@@ -4,6 +4,7 @@
 #include <cell.hpp>
 #include <fvm_multicell.hpp>
 #include <hardware/node_info.hpp>
+#include <load_balance.hpp>
 #include <model.hpp>
 #include <recipe.hpp>
 #include <simple_sampler.hpp>
@@ -34,7 +35,7 @@ void run_kinetic_dt(
     runner.load_reference_data(ref_file);
 
     hw::node_info nd(1, backend==backend_kind::gpu? 1: 0);
-    domain_decomposition decomp(singleton_recipe{c}, nd);
+    auto decomp = partition_load_balance(singleton_recipe{c}, nd);
     model model(singleton_recipe{c}, decomp);
 
     auto exclude = stimulus_ends(c);
diff --git a/tests/validation/validate_soma.hpp b/tests/validation/validate_soma.hpp
index d062af98b879f68a8a8da20eefad8f97db8c2297..b33fa7b80959555a728c57d393fe42915d1f74bc 100644
--- a/tests/validation/validate_soma.hpp
+++ b/tests/validation/validate_soma.hpp
@@ -4,6 +4,7 @@
 #include <cell.hpp>
 #include <fvm_multicell.hpp>
 #include <hardware/node_info.hpp>
+#include <load_balance.hpp>
 #include <model.hpp>
 #include <recipe.hpp>
 #include <simple_sampler.hpp>
@@ -21,7 +22,7 @@ void validate_soma(nest::mc::backend_kind backend) {
     add_common_voltage_probes(c);
 
     hw::node_info nd(1, backend==backend_kind::gpu? 1: 0);
-    domain_decomposition decomp(singleton_recipe{c}, nd);
+    auto decomp = partition_load_balance(singleton_recipe{c}, nd);
     model m(singleton_recipe{c}, decomp);
 
     float sample_dt = .025f;
diff --git a/tests/validation/validate_synapses.hpp b/tests/validation/validate_synapses.hpp
index 7607e46885dd10c859a0e9ed9f1ee5227a48740d..3e814add278ccb91cad9c8722c1c23261a6d8c6b 100644
--- a/tests/validation/validate_synapses.hpp
+++ b/tests/validation/validate_synapses.hpp
@@ -4,6 +4,7 @@
 #include <cell_group.hpp>
 #include <fvm_multicell.hpp>
 #include <hardware/node_info.hpp>
+#include <load_balance.hpp>
 #include <model.hpp>
 #include <recipe.hpp>
 #include <simple_sampler.hpp>
@@ -62,7 +63,7 @@ void run_synapse_test(
     hw::node_info nd(1, backend==backend_kind::gpu? 1: 0);
     for (int ncomp = 10; ncomp<max_ncomp; ncomp*=2) {
         c.cable(1)->set_compartments(ncomp);
-        domain_decomposition decomp(singleton_recipe{c}, nd);
+        auto decomp = partition_load_balance(singleton_recipe{c}, nd);
         model m(singleton_recipe{c}, decomp);
 
         m.group(0).enqueue_events(synthetic_events);
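
For callers, the net effect of this patch is that load balancing becomes an explicit, separate step and `domain_decomposition` turns into a plain description read through public members. Below is a rough sketch of the new call sequence, modelled on the changes to miniapp.cpp and the unit tests above; `make_recipe()` is a hypothetical stand-in for whatever concrete recipe the caller builds, and the `node_info(1, 0)` constructor (one core, no GPU) follows the tests. All other names come from the headers touched by this diff.

```cpp
// Sketch only: assumes a caller-provided make_recipe(); everything else is
// taken from load_balance.hpp, domain_decomposition.hpp, hardware/node_info.hpp
// and model.hpp as changed in this patch.
#include <memory>

#include <hardware/node_info.hpp>
#include <load_balance.hpp>
#include <model.hpp>
#include <recipe.hpp>

using namespace nest::mc;

std::unique_ptr<recipe> make_recipe();   // hypothetical helper, not part of the diff

void build_model() {
    auto rec = make_recipe();

    // Describe the local node: one core, no GPU.
    hw::node_info nd(1, 0);

    // Load balancing is now an explicit step that yields a plain description.
    domain_decomposition decomp = partition_load_balance(*rec, nd);

    // The decomposition is inspected through data members instead of accessors.
    for (const group_description& grp: decomp.groups) {
        // grp.kind, grp.gids (sorted) and grp.backend describe one cell_group.
        (void)grp;
    }
    bool own_gid0 = decomp.is_local_gid(0);   // forwards to the stored gid_domain callback
    (void)own_gid0;

    // The model is built from the recipe plus the precomputed decomposition.
    model m(*rec, decomp);
}
```

This keeps the load-balancing policy out of `domain_decomposition` itself: alternative balancers can populate the same struct, and `gid_domain` is supplied as a `std::function` by whichever algorithm produced the decomposition.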