From 953d50072d915be9f81144d2e9e1c46b53c91f4e Mon Sep 17 00:00:00 2001
From: akuesters <42005107+akuesters@users.noreply.github.com>
Date: Fri, 16 Aug 2019 10:12:03 +0200
Subject: [PATCH] Python wrapper: add hint_map to domain decomposition (#827)

- wraps the `partition_hint` struct for Python (see the usage sketch below)
- adjusts the `partition_hint` struct in `load_balance.hpp` so that a zero unsigned `cpu_group_size`/`gpu_group_size` is replaced by the default value (via setter/getter)
- adds documentation for `partition_hint`
- adds tests for `domain_decomposition`, including `partition_hint` (in unit and unit_distributed)
- adds `partition_hint` to the example `ring.py`
- corrects `config.cpp` (and doc) to test for `ARB_GPU_ENABLED` instead of `ARB_WITH_GPU`

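A minimal usage sketch of the new Python interface (assuming a user-defined
`arbor.recipe` subclass `my_recipe`, as in the documentation example added by
this patch):

    import arbor

    context = arbor.context(threads=4, gpu_id=None)
    recipe = my_recipe(100)  # hypothetical user-defined recipe with 100 cells

    # Keep cable cells on the multicore backend, three cells per group.
    hint = arbor.partition_hint()
    hint.prefer_gpu = False
    hint.cpu_group_size = 3
    hints = {arbor.cell_kind.cable: hint}

    decomp = arbor.partition_load_balance(recipe, context, hints)
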
Fixes #776
Addresses #799, #769
---
 arbor/partition_load_balance.cpp              |   7 +
 doc/py_domdec.rst                             |  62 ++-
 doc/py_hardware.rst                           |   2 +-
 python/config.cpp                             |   2 +-
 python/domain_decomposition.cpp               |  42 +-
 python/example/ring.py                        |  11 +-
 python/test/unit/runner.py                    |   3 +
 .../test/unit/test_domain_decompositions.py   | 258 ++++++++++
 python/test/unit_distributed/runner.py        |   5 +-
 .../unit_distributed/test_contexts_mpi4py.py  |   2 +-
 .../test_domain_decompositions.py             | 464 ++++++++++++++++++
 11 files changed, 846 insertions(+), 12 deletions(-)
 create mode 100644 python/test/unit/test_domain_decompositions.py
 create mode 100644 python/test/unit_distributed/test_domain_decompositions.py

diff --git a/arbor/partition_load_balance.cpp b/arbor/partition_load_balance.cpp
index 5497e6bc..3b847688 100644
--- a/arbor/partition_load_balance.cpp
+++ b/arbor/partition_load_balance.cpp
@@ -14,6 +14,7 @@
 #include "util/maputil.hpp"
 #include "util/partition.hpp"
 #include "util/span.hpp"
+#include "util/strprintf.hpp"
 
 namespace arb {
 
@@ -170,6 +171,12 @@ domain_decomposition partition_load_balance(
         partition_hint hint;
         if (auto opt_hint = util::value_by_key(hint_map, k)) {
             hint = opt_hint.value();
+            if(!hint.cpu_group_size) {
+                throw arbor_exception(arb::util::pprintf("unable to perform load balancing because {} has invalid suggested cpu_cell_group size of {}", k, hint.cpu_group_size));
+            }
+            if(hint.prefer_gpu && !hint.gpu_group_size) {
+                throw arbor_exception(arb::util::pprintf("unable to perform load balancing because {} has invalid suggested gpu_cell_group size of {}", k, hint.gpu_group_size));
+            }
         }
 
         backend_kind backend = backend_kind::multicore;
diff --git a/doc/py_domdec.rst b/doc/py_domdec.rst
index 3b8f75ea..87cb0c61 100644
--- a/doc/py_domdec.rst
+++ b/doc/py_domdec.rst
@@ -18,7 +18,7 @@ If the model is distributed with MPI, the partitioning algorithm for cells is
 distributed with MPI communication. The returned :class:`domain_decomposition`
 describes the cell groups on the local MPI rank.
 
-.. function:: partition_load_balance(recipe, context)
+.. function:: partition_load_balance(recipe, context, hints)
 
     Construct a :class:`domain_decomposition` that distributes the cells
     in the model described by an :class:`arbor.recipe` over the distributed and local hardware
@@ -31,12 +31,70 @@ describes the cell groups on the local MPI rank.
     grained parallelism in the cell group.
     Otherwise, cells are grouped into small groups that fit in cache, and can be
     distributed over the available cores.
+    Optionally, provide a dictionary of :class:`partition_hint` s for certain cell kinds; by default this dictionary is empty.
 
     .. Note::
         The partitioning assumes that all cells of the same kind have equal
         computational cost, hence it may not produce a balanced partition for
         models with cells that have a large variance in computational costs.
 
+.. class:: partition_hint
+
+    Provide a hint on how the cell groups should be partitioned.
+
+    .. function:: partition_hint(cpu_group_size, gpu_group_size, prefer_gpu)
+
+        Construct a partition hint with arguments :attr:`cpu_group_size` and :attr:`gpu_group_size`, and whether to :attr:`prefer_gpu`.
+
+        By default returns a partition hint with :attr:`cpu_group_size` = ``1``, i.e., each cell is put in its own group, :attr:`gpu_group_size` = ``max``, i.e., all cells are put in one group, and :attr:`prefer_gpu` = ``True``, i.e., GPU usage is preferred.
+
+    .. attribute:: cpu_group_size
+
+        The size of the cell group assigned to CPU.
+        Must be positive, else set to default value.
+
+    .. attribute:: gpu_group_size
+
+        The size of the cell group assigned to GPU.
+        Must be positive, else set to default value.
+
+    .. attribute:: prefer_gpu
+
+        Whether GPU usage is preferred.
+
+    .. attribute:: max_size
+
+        Get the maximum size of cell groups.
+
+An example of a partition load balance with hints reads as follows:
+
+.. container:: example-code
+
+    .. code-block:: python
+
+        import arbor
+
+        # Get a communication context (with 4 threads, no GPU)
+        context = arbor.context(threads=4, gpu_id=None)
+
+        # Initialise a recipe of user defined type my_recipe with 100 cells.
+        n_cells = 100
+        recipe = my_recipe(n_cells)
+
+        # The hints prefer the multicore backend, so the decomposition is expected
+        # to never have cell groups on the GPU, regardless of whether a GPU is
+        # available or not.
+        cable_hint                  = arbor.partition_hint()
+        cable_hint.prefer_gpu       = False
+        cable_hint.cpu_group_size   = 3
+        spike_hint                  = arbor.partition_hint()
+        spike_hint.prefer_gpu       = False
+        spike_hint.cpu_group_size   = 4
+        hints = dict([(arbor.cell_kind.cable, cable_hint), (arbor.cell_kind.spike_source, spike_hint)])
+
+        decomp = arbor.partition_load_balance(recipe, context, hints)
+
+
 Decomposition
 -------------
 As defined in :ref:`modeldomdec` a domain decomposition is a description of the distribution of the model over the available computational resources.
@@ -119,4 +177,4 @@ Therefore, the following data structures are used to describe domain decompositi
 
         .. attribute:: backend
 
-            The hardware backend on which the cell group will run.
+            The hardware :class:`backend` on which the cell group will run.
diff --git a/doc/py_hardware.rst b/doc/py_hardware.rst
index 4785588a..0a5e5cd1 100644
--- a/doc/py_hardware.rst
+++ b/doc/py_hardware.rst
@@ -21,7 +21,7 @@ Helper functions for checking cmake or environment variables, as well as configu
 
       * ``ARB_MPI_ENABLED``
       * ``ARB_WITH_MPI4PY``
-      * ``ARB_WITH_GPU``
+      * ``ARB_GPU_ENABLED``
       * ``ARB_VERSION``
 
     .. container:: example-code
diff --git a/python/config.cpp b/python/config.cpp
index d011d26a..78d0a7ba 100644
--- a/python/config.cpp
+++ b/python/config.cpp
@@ -24,7 +24,7 @@ pybind11::dict config() {
 #else
     dict[pybind11::str("mpi4py")]  = pybind11::bool_(false);
 #endif
-#ifdef ARB_WITH_GPU
+#ifdef ARB_GPU_ENABLED
     dict[pybind11::str("gpu")]     = pybind11::bool_(true);
 #else
     dict[pybind11::str("gpu")]     = pybind11::bool_(false);
diff --git a/python/domain_decomposition.cpp b/python/domain_decomposition.cpp
index 0191aec4..3b43d215 100644
--- a/python/domain_decomposition.cpp
+++ b/python/domain_decomposition.cpp
@@ -1,7 +1,9 @@
+#include <limits>
 #include <string>
 #include <sstream>
 
 #include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
 
 #include <arbor/context.hpp>
 #include <arbor/domain_decomposition.hpp>
@@ -15,7 +17,7 @@ namespace pyarb {
 
 std::string gd_string(const arb::group_description& g) {
     return util::pprintf(
-        "<arbor.group_description: num_cells {}, gids [{}], {}, {}",
+        "<arbor.group_description: num_cells {}, gids [{}], {}, {}>",
         g.gids.size(), util::csv(g.gids, 5), g.kind, g.backend);
 }
 
@@ -25,6 +27,12 @@ std::string dd_string(const arb::domain_decomposition& d) {
         d.domain_id, d.num_domains, d.num_local_cells, d.num_global_cells, d.groups.size());
 }
 
+std::string ph_string(const arb::partition_hint& h) {
+    return util::pprintf(
+        "<arbor.partition_hint: cpu_group_size {}, gpu_group_size {}, prefer_gpu {}>",
+        h.cpu_group_size, h.gpu_group_size, (h.prefer_gpu == 1) ? "True" : "False");
+}
+
 void register_domain_decomposition(pybind11::module& m) {
     using namespace pybind11::literals;
 
@@ -44,6 +52,29 @@ void register_domain_decomposition(pybind11::module& m) {
         .def("__str__",  &gd_string)
         .def("__repr__", &gd_string);
 
+    // Partition hint
+    pybind11::class_<arb::partition_hint> partition_hint(m, "partition_hint",
+        "Provide a hint on how the cell groups should be partitioned.");
+    partition_hint
+        .def(pybind11::init<std::size_t, std::size_t, bool>(),
+            "cpu_group_size"_a = 1, "gpu_group_size"_a = std::numeric_limits<std::size_t>::max(), "prefer_gpu"_a = true,
+            "Construct a partition hint with arguments:\n"
+            "  cpu_group_size: The size of cell group assigned to CPU, each cell in its own group by default.\n"
+            "                  Must be positive, else set to default value.\n"
+            "  gpu_group_size: The size of cell group assigned to GPU, all cells in one group by default.\n"
+            "                  Must be positive, else set to default value.\n"
+            "  prefer_gpu:     Whether GPU is preferred, True by default.")
+        .def_readwrite("cpu_group_size", &arb::partition_hint::cpu_group_size,
+                                        "The size of the cell group assigned to CPU.")
+        .def_readwrite("gpu_group_size", &arb::partition_hint::gpu_group_size,
+                                        "The size of the cell group assigned to GPU.")
+        .def_readwrite("prefer_gpu", &arb::partition_hint::prefer_gpu,
+                                        "Whether GPU usage is preferred.")
+        .def_property_readonly_static("max_size",  [](pybind11::object) { return arb::partition_hint::max_size; },
+                                        "Get the maximum size of cell groups.")
+        .def("__str__",  &ph_string)
+        .def("__repr__", &ph_string);
+
     // Domain decomposition
     pybind11::class_<arb::domain_decomposition> domain_decomposition(m, "domain_decomposition",
         "The domain decomposition is responsible for describing the distribution of cells across cell groups and domains.");
@@ -72,12 +103,13 @@ void register_domain_decomposition(pybind11::module& m) {
     // Partition load balancer
     // The Python recipe has to be shimmed for passing to the function that takes a C++ recipe.
     m.def("partition_load_balance",
-        [](std::shared_ptr<py_recipe>& recipe, const context_shim& ctx) {
-            return arb::partition_load_balance(py_recipe_shim(recipe), ctx.context);
+        [](std::shared_ptr<py_recipe>& recipe, const context_shim& ctx, arb::partition_hint_map hint_map) {
+            return arb::partition_load_balance(py_recipe_shim(recipe), ctx.context, std::move(hint_map));
         },
         "Construct a domain_decomposition that distributes the cells in the model described by recipe\n"
-        "over the distributed and local hardware resources described by context.",
-        "recipe"_a, "context"_a);
+        "over the distributed and local hardware resources described by context.\n"
+        "Optionally, provide a dictionary of partition hints for certain cell kinds, by default empty.",
+        "recipe"_a, "context"_a, "hints"_a=arb::partition_hint_map{});
 }
 
 } // namespace pyarb
diff --git a/python/example/ring.py b/python/example/ring.py
index f5ba456e..d4710218 100644
--- a/python/example/ring.py
+++ b/python/example/ring.py
@@ -59,10 +59,19 @@ meters.checkpoint('recipe-create', context)
 decomp = arbor.partition_load_balance(recipe, context)
 print(f'{decomp}')
 
+hint = arbor.partition_hint()
+hint.prefer_gpu = True
+hint.gpu_group_size = 1000
+print(f'{hint}')
+
+hints = dict([(arbor.cell_kind.cable, hint)])
+decomp = arbor.partition_load_balance(recipe, context, hints)
+print(f'{decomp}')
+
 meters.checkpoint('load-balance', context)
 
 sim = arbor.simulation(recipe, decomp, context)
-print(f'{sim}')
+print(f'{sim} finished')
 
 meters.checkpoint('simulation-init', context)
 
diff --git a/python/test/unit/runner.py b/python/test/unit/runner.py
index 2833e381..06ebb671 100644
--- a/python/test/unit/runner.py
+++ b/python/test/unit/runner.py
@@ -11,6 +11,7 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../.
 try:
     import options
     import test_contexts
+    import test_domain_decompositions
     import test_event_generators
     import test_identifiers
     import test_tests
@@ -19,6 +20,7 @@ try:
 except ModuleNotFoundError:
     from test import options
     from test.unit import test_contexts
+    from test.unit import test_domain_decompositions
     from test.unit import test_event_generators
     from test.unit import test_identifiers
     from test.unit import test_schedules
@@ -26,6 +28,7 @@ except ModuleNotFoundError:
 
 test_modules = [\
     test_contexts,\
+    test_domain_decompositions,\
     test_event_generators,\
     test_identifiers,\
     test_schedules\
diff --git a/python/test/unit/test_domain_decompositions.py b/python/test/unit/test_domain_decompositions.py
new file mode 100644
index 00000000..fc239f9c
--- /dev/null
+++ b/python/test/unit/test_domain_decompositions.py
@@ -0,0 +1,258 @@
+# -*- coding: utf-8 -*-
+#
+# test_domain_decompositions.py
+
+import unittest
+
+import arbor as arb
+
+# to be able to run .py file from child directory
+import sys, os
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
+
+try:
+    import options
+except ModuleNotFoundError:
+    from test import options
+
+# check Arbor's configuration of mpi and gpu
+config = arb.config()
+gpu_enabled = config["gpu"]
+
+"""
+all tests for non-distributed arb.domain_decomposition
+"""
+
+# Dummy recipe
+class homo_recipe (arb.recipe):
+    def __init__(self, n=4):
+        arb.recipe.__init__(self)
+        self.ncells = n
+
+    def num_cells(self):
+        return self.ncells
+
+    def cell_description(self, gid):
+        return []
+
+    def cell_kind(self, gid):
+        return arb.cell_kind.cable
+
+# Heterogeneous cell population of cable and spike source cells.
+# Interleaved so that cells with even gid are cable cells, and cells with odd gid are spike source cells.
+class hetero_recipe (arb.recipe):
+    def __init__(self, n=4):
+        arb.recipe.__init__(self)
+        self.ncells = n
+
+    def num_cells(self):
+        return self.ncells
+
+    def cell_description(self, gid):
+        return []
+
+    def cell_kind(self, gid):
+        if (gid%2):
+            return arb.cell_kind.spike_source
+        else:
+            return arb.cell_kind.cable
+
+class Domain_Decompositions(unittest.TestCase):
+    # 1 cpu core, no gpus; assumes all cells will be put into cell groups of size 1
+    def test_domain_decomposition_homogenous_CPU(self):
+        n_cells = 10
+        recipe = homo_recipe(n_cells)
+        context = arb.context()
+        decomp = arb.partition_load_balance(recipe, context)
+
+        self.assertEqual(decomp.num_local_cells, n_cells)
+        self.assertEqual(decomp.num_global_cells, n_cells)
+        self.assertEqual(len(decomp.groups), n_cells)
+
+        gids = list(range(n_cells))
+        for gid in gids:
+            self.assertEqual(0, decomp.gid_domain(gid))
+
+        # Each cell group contains 1 cell of kind cable
+        # Each group should also be tagged for cpu execution
+        for i in gids:
+            grp = decomp.groups[i]
+            self.assertEqual(len(grp.gids), 1)
+            self.assertEqual(grp.gids[0], i)
+            self.assertEqual(grp.backend, arb.backend.multicore)
+            self.assertEqual(grp.kind, arb.cell_kind.cable)
+
+    # 1 cpu core, 1 gpu; assumes all cells will be placed on gpu in a single cell group
+    @unittest.skipIf(gpu_enabled == False, "GPU not enabled")
+    def test_domain_decomposition_homogenous_GPU(self):
+        n_cells = 10
+        recipe = homo_recipe(n_cells)
+        context = arb.context(threads=1, gpu_id=0)
+        decomp = arb.partition_load_balance(recipe, context)
+
+        self.assertEqual(decomp.num_local_cells, n_cells)
+        self.assertEqual(decomp.num_global_cells, n_cells)
+        self.assertEqual(len(decomp.groups), 1)
+
+        gids = range(n_cells)
+        for gid in gids:
+            self.assertEqual(0, decomp.gid_domain(gid))
+
+        # The single cell group should contain all cells of kind cable
+        # and should be tagged for gpu execution
+
+        grp = decomp.groups[0]
+
+        self.assertEqual(len(grp.gids), n_cells)
+        self.assertEqual(grp.gids[0], 0)
+        self.assertEqual(grp.gids[-1], n_cells-1)
+        self.assertEqual(grp.backend, arb.backend.gpu)
+        self.assertEqual(grp.kind, arb.cell_kind.cable)
+
+    # 1 cpu core, no gpus; assumes all cells will be put into cell groups of size 1
+    def test_domain_decomposition_heterogenous_CPU(self):
+        n_cells = 10
+        recipe = hetero_recipe(n_cells)
+        context = arb.context()
+        decomp = arb.partition_load_balance(recipe, context)
+
+        self.assertEqual(decomp.num_local_cells, n_cells)
+        self.assertEqual(decomp.num_global_cells, n_cells)
+        self.assertEqual(len(decomp.groups), n_cells)
+
+        gids = list(range(n_cells))
+        for gid in gids:
+            self.assertEqual(0, decomp.gid_domain(gid))
+
+        # Each cell group contains 1 cell, of kind cable or spike_source
+        # Each group should also be tagged for cpu execution
+        grps = list(range(n_cells))
+        kind_lists = dict()
+        for i in grps:
+            grp = decomp.groups[i]
+            self.assertEqual(len(grp.gids), 1)
+            k = grp.kind
+            if k not in kind_lists:
+                kind_lists[k] = []
+            kind_lists[k].append(grp.gids[0])
+
+            self.assertEqual(grp.backend, arb.backend.multicore)
+
+        kinds = [arb.cell_kind.cable, arb.cell_kind.spike_source]
+        for k in kinds:
+            gids = kind_lists[k]
+            self.assertEqual(len(gids), int(n_cells/2))
+            for gid in gids:
+                self.assertEqual(k, recipe.cell_kind(gid))
+
+    # 1 cpu core, 1 gpu; assumes cable cells will be placed on gpu in a single cell group; spike cells are on cpu in cell groups of size 1
+    @unittest.skipIf(gpu_enabled == False, "GPU not enabled")
+    def test_domain_decomposition_heterogenous_GPU(self):
+        n_cells = 10
+        recipe = hetero_recipe(n_cells)
+        context = arb.context(threads=1, gpu_id=0)
+        decomp = arb.partition_load_balance(recipe, context)
+
+        self.assertEqual(decomp.num_local_cells, n_cells)
+        self.assertEqual(decomp.num_global_cells, n_cells)
+
+        # one cell group with n_cells/2 on gpu, and n_cells/2 groups on cpu
+        expected_groups = int(n_cells/2) + 1
+        self.assertEqual(len(decomp.groups), expected_groups)
+
+        grps = range(expected_groups)
+        n = 0
+        # iterate over each group and test its properties
+        for i in grps:
+            grp = decomp.groups[i]
+            k = grp.kind
+            if (k == arb.cell_kind.cable):
+                self.assertEqual(grp.backend, arb.backend.gpu)
+                self.assertEqual(len(grp.gids), int(n_cells/2))
+                for gid in grp.gids:
+                    self.assertTrue(gid%2==0)
+                    n += 1
+            elif (k == arb.cell_kind.spike_source):
+                self.assertEqual(grp.backend, arb.backend.multicore)
+                self.assertEqual(len(grp.gids), 1)
+                self.assertTrue(grp.gids[0]%2)
+                n += 1
+        self.assertEqual(n_cells, n)
+
+    def test_domain_decomposition_hints(self):
+        n_cells = 20
+        recipe = hetero_recipe(n_cells)
+        context = arb.context()
+        # The hints prefer the multicore backend, so the decomposition is expected
+        # to never have cell groups on the GPU, regardless of whether a GPU is
+        # available or not.
+        cable_hint = arb.partition_hint()
+        cable_hint.prefer_gpu = False
+        cable_hint.cpu_group_size = 3
+        spike_hint = arb.partition_hint()
+        spike_hint.prefer_gpu = False
+        spike_hint.cpu_group_size = 4
+        hints = dict([(arb.cell_kind.cable, cable_hint), (arb.cell_kind.spike_source, spike_hint)])
+
+        decomp = arb.partition_load_balance(recipe, context, hints)
+
+        exp_cable_groups = [[0, 2, 4], [6, 8, 10], [12, 14, 16], [18]]
+        exp_spike_groups = [[1, 3, 5, 7], [9, 11, 13, 15], [17, 19]]
+
+        cable_groups = []
+        spike_groups = []
+
+        for g in decomp.groups:
+            self.assertTrue(g.kind == arb.cell_kind.cable or g.kind == arb.cell_kind.spike_source)
+
+            if (g.kind == arb.cell_kind.cable):
+                cable_groups.append(g.gids)
+            elif (g.kind == arb.cell_kind.spike_source):
+                spike_groups.append(g.gids)
+
+        self.assertEqual(exp_cable_groups, cable_groups)
+        self.assertEqual(exp_spike_groups, spike_groups)
+
+    def test_domain_decomposition_exceptions(self):
+        n_cells = 20
+        recipe = hetero_recipe(n_cells)
+        context = arb.context()
+        # Hints with a group size of zero are invalid; the load balancer is
+        # expected to raise an exception instead of returning a domain
+        # decomposition.
+        cable_hint = arb.partition_hint()
+        cable_hint.prefer_gpu = False
+        cable_hint.cpu_group_size = 0
+        spike_hint = arb.partition_hint()
+        spike_hint.prefer_gpu = False
+        spike_hint.gpu_group_size = 1
+        hints = dict([(arb.cell_kind.cable, cable_hint), (arb.cell_kind.spike_source, spike_hint)])
+
+        with self.assertRaisesRegex(RuntimeError,
+            "unable to perform load balancing because cell_kind::cable has invalid suggested cpu_cell_group size of 0"):
+            decomp = arb.partition_load_balance(recipe, context, hints)
+
+        cable_hint = arb.partition_hint()
+        cable_hint.prefer_gpu = False
+        cable_hint.cpu_group_size = 1
+        spike_hint = arb.partition_hint()
+        spike_hint.prefer_gpu = True
+        spike_hint.gpu_group_size = 0
+        hints = dict([(arb.cell_kind.cable, cable_hint), (arb.cell_kind.spike_source, spike_hint)])
+
+        with self.assertRaisesRegex(RuntimeError,
+            "unable to perform load balancing because cell_kind::spike_source has invalid suggested gpu_cell_group size of 0"):
+            decomp = arb.partition_load_balance(recipe, context, hints)
+
+def suite():
+    # specify class and test functions (here: all tests starting with 'test' from class Domain_Decompositions)
+    suite = unittest.makeSuite(Domain_Decompositions, ('test'))
+    return suite
+
+def run():
+    v = options.parse_arguments().verbosity
+    runner = unittest.TextTestRunner(verbosity = v)
+    runner.run(suite())
+
+if __name__ == "__main__":
+    run()
diff --git a/python/test/unit_distributed/runner.py b/python/test/unit_distributed/runner.py
index ddb27608..8a913781 100644
--- a/python/test/unit_distributed/runner.py
+++ b/python/test/unit_distributed/runner.py
@@ -21,16 +21,19 @@ try:
     import options
     import test_contexts_arbmpi
     import test_contexts_mpi4py
+    import test_domain_decompositions
     # add more if needed
 except ModuleNotFoundError:
     from test import options
     from test.unit_distributed import test_contexts_arbmpi
     from test.unit_distributed import test_contexts_mpi4py
+    from test.unit_distributed import test_domain_decompositions
     # add more if needed
 
 test_modules = [\
     test_contexts_arbmpi,\
-    test_contexts_mpi4py\
+    test_contexts_mpi4py,\
+    test_domain_decompositions\
 ] # add more if needed
 
 def suite():
diff --git a/python/test/unit_distributed/test_contexts_mpi4py.py b/python/test/unit_distributed/test_contexts_mpi4py.py
index 7a05bb5a..4dd225b6 100644
--- a/python/test/unit_distributed/test_contexts_mpi4py.py
+++ b/python/test/unit_distributed/test_contexts_mpi4py.py
@@ -78,7 +78,7 @@ def suite():
 def run():
     v = options.parse_arguments().verbosity
 
-    comm = arb.mpi_comm_from_mpi4py(mpi.COMM_WORLD)
+    comm = arb.mpi_comm(mpi.COMM_WORLD)
     alloc = arb.proc_allocation()
     ctx = arb.context(alloc, comm)
     rank = ctx.rank
diff --git a/python/test/unit_distributed/test_domain_decompositions.py b/python/test/unit_distributed/test_domain_decompositions.py
new file mode 100644
index 00000000..17bd5252
--- /dev/null
+++ b/python/test/unit_distributed/test_domain_decompositions.py
@@ -0,0 +1,464 @@
+# -*- coding: utf-8 -*-
+#
+# test_domain_decompositions.py
+
+import unittest
+
+import arbor as arb
+
+# to be able to run .py file from child directory
+import sys, os
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
+
+try:
+    import options
+except ModuleNotFoundError:
+    from test import options
+
+# check Arbor's configuration of mpi and gpu
+config = arb.config()
+gpu_enabled = config["gpu"]
+mpi_enabled = config["mpi"]
+
+"""
+all tests for distributed arb.domain_decomposition
+"""
+
+# Dummy recipe
+class homo_recipe (arb.recipe):
+    def __init__(self, n=4):
+        arb.recipe.__init__(self)
+        self.ncells = n
+
+    def num_cells(self):
+        return self.ncells
+
+    def cell_description(self, gid):
+        return []
+
+    def cell_kind(self, gid):
+        return arb.cell_kind.cable
+
+# Heterogeneous cell population of cable and spike source cells.
+# Interleaved so that cells with even gid are cable cells, and cells with odd gid are spike source cells.
+class hetero_recipe (arb.recipe):
+    def __init__(self, n=4):
+        arb.recipe.__init__(self)
+        self.ncells = n
+
+    def num_cells(self):
+        return self.ncells
+
+    def cell_description(self, gid):
+        return []
+
+    def cell_kind(self, gid):
+        if (gid%2):
+            return arb.cell_kind.spike_source
+        else:
+            return arb.cell_kind.cable
+
+    def num_sources(self, gid):
+        return 0
+
+    def num_targets(self, gid):
+        return 0
+
+    def connections_on(self, gid):
+        return []
+
+    def event_generators(self, gid):
+        return []
+
+class gj_switch:
+    def __init__(self, gid, shift):
+        self.gid_ = gid
+        self.shift_ = shift
+
+    def switch(self, arg):
+        default = []
+        return getattr(self, 'case_' + str(arg), lambda: default)()
+
+    def case_1(self):
+        return [arb.gap_junction_connection(arb.cell_member(7 + self.shift_, 0), arb.cell_member(self.gid_, 0), 0.1)]
+
+    def case_2(self):
+        return [arb.gap_junction_connection(arb.cell_member(6 + self.shift_, 0), arb.cell_member(self.gid_, 0), 0.1),
+                arb.gap_junction_connection(arb.cell_member(9 + self.shift_, 0), arb.cell_member(self.gid_, 0), 0.1)]
+
+    def case_6(self):
+        return [arb.gap_junction_connection(arb.cell_member(2 + self.shift_, 0), arb.cell_member(self.gid_, 0), 0.1),
+                arb.gap_junction_connection(arb.cell_member(7 + self.shift_, 0), arb.cell_member(self.gid_, 0), 0.1)]
+
+    def case_7(self):
+        return [arb.gap_junction_connection(arb.cell_member(6 + self.shift_, 0), arb.cell_member(self.gid_, 0), 0.1),
+                arb.gap_junction_connection(arb.cell_member(1 + self.shift_, 0), arb.cell_member(self.gid_, 0), 0.1)]
+
+    def case_9(self):
+        return [arb.gap_junction_connection(arb.cell_member(2 + self.shift_, 0), arb.cell_member(self.gid_, 0), 0.1)]
+
+class gj_symmetric (arb.recipe):
+    def __init__(self, num_ranks):
+        arb.recipe.__init__(self)
+        self.ncopies = num_ranks
+        self.size    = 10
+
+    def num_cells(self):
+        return self.size*self.ncopies
+
+    def cell_description(self, gid):
+        return []
+
+    def cell_kind(self, gid):
+        return arb.cell_kind.cable
+
+    def gap_junctions_on(self, gid):
+        shift = int((gid/self.size))
+        shift *= self.size
+        s = gj_switch(gid, shift)
+        return s.switch(gid%self.size)
+
+class gj_non_symmetric (arb.recipe):
+    def __init__(self, num_ranks):
+        arb.recipe.__init__(self)
+        self.groups = num_ranks
+        self.size   = num_ranks
+
+    def num_cells(self):
+        return self.size*self.groups
+
+    def cell_description(self, gid):
+        return []
+
+    def cell_kind(self, gid):
+        return arb.cell_kind.cable
+
+    def gap_junctions_on(self, gid):
+        group = int(gid/self.groups)
+        id = gid%self.size
+
+        if (id == group and group != (self.groups - 1)):
+            return [arb.gap_junction_connection(arb.cell_member(gid + self.size, 0), arb.cell_member(gid, 0), 0.1)]
+        elif (id == group - 1):
+            return [arb.gap_junction_connection(arb.cell_member(gid - self.size, 0), arb.cell_member(gid, 0), 0.1)]
+        else:
+            return []
+
+@unittest.skipIf(mpi_enabled == False, "MPI not enabled")
+class Domain_Decompositions_Distributed(unittest.TestCase):
+    # Initialize mpi only once in this class (when adding classes, move initialization to setUpModule())
+    @classmethod
+    def setUpClass(self):
+        self.local_mpi = False
+        if not arb.mpi_is_initialized():
+            arb.mpi_init()
+            self.local_mpi = True
+    # Finalize mpi only once in this class (when adding classes, move finalization to setUpModule())
+    @classmethod
+    def tearDownClass(self):
+        if self.local_mpi:
+            arb.mpi_finalize()
+
+    # 1 node with 1 cpu core, no gpus; assumes all cells will be put into cell groups of size 1
+    def test_domain_decomposition_homogenous_MC(self):
+        if (mpi_enabled):
+            comm = arb.mpi_comm()
+            context = arb.context(threads=1, gpu_id=None, mpi=comm)
+        else:
+            context = arb.context(threads=1, gpu_id=None)
+
+        N = context.ranks
+        I = context.rank
+
+        # 10 cells per domain
+        n_local = 10
+        n_global = n_local * N
+
+        recipe = homo_recipe(n_global)
+        decomp = arb.partition_load_balance(recipe, context)
+
+        self.assertEqual(decomp.num_local_cells, n_local)
+        self.assertEqual(decomp.num_global_cells, n_global)
+        self.assertEqual(len(decomp.groups), n_local)
+
+        b = I * n_local
+        e = (I + 1) * n_local
+        gids = list(range(b,e))
+
+        for gid in gids:
+            self.assertEqual(I, decomp.gid_domain(gid))
+
+        # Each cell group contains 1 cell of kind cable
+        # Each group should also be tagged for cpu execution
+
+        for i in gids:
+            local_group = i - b
+            grp = decomp.groups[local_group]
+
+            self.assertEqual(len(grp.gids), 1)
+            self.assertEqual(grp.gids[0], i)
+            self.assertEqual(grp.backend, arb.backend.multicore)
+            self.assertEqual(grp.kind, arb.cell_kind.cable)
+
+    # 1 node with 1 cpu core, 1 gpu; assumes all cells will be placed on gpu in a single cell group
+    @unittest.skipIf(gpu_enabled == False, "GPU not enabled")
+    def test_domain_decomposition_homogenous_GPU(self):
+
+        if (mpi_enabled):
+            comm = arb.mpi_comm()
+            context = arb.context(threads=1, gpu_id=0, mpi=comm)
+        else:
+            context = arb.context(threads=1, gpu_id=0)
+
+        N = context.ranks
+        I = context.rank
+
+        # 10 cells per domain
+        n_local = 10
+        n_global = n_local * N
+
+        recipe = homo_recipe(n_global)
+        decomp = arb.partition_load_balance(recipe, context)
+
+        self.assertEqual(decomp.num_local_cells, n_local)
+        self.assertEqual(decomp.num_global_cells, n_global)
+        self.assertEqual(len(decomp.groups), 1)
+
+        b = I * n_local
+        e = (I + 1) * n_local
+        gids = list(range(b,e))
+
+        for gid in gids:
+            self.assertEqual(I, decomp.gid_domain(gid))
+
+        # The single cell group should contain all local cells of kind cable
+        # and should be tagged for gpu execution
+
+        grp = decomp.groups[0]
+
+        self.assertEqual(len(grp.gids), n_local)
+        self.assertEqual(grp.gids[0], b)
+        self.assertEqual(grp.gids[-1], e-1)
+        self.assertEqual(grp.backend, arb.backend.gpu)
+        self.assertEqual(grp.kind, arb.cell_kind.cable)
+
+    # 1 node with 1 cpu core, no gpus; assumes all cells will be put into cell groups of size 1
+    def test_domain_decomposition_heterogenous_MC(self):
+        if (mpi_enabled):
+            comm = arb.mpi_comm()
+            context = arb.context(threads=1, gpu_id=None, mpi=comm)
+        else:
+            context = arb.context(threads=1, gpu_id=None)
+
+        N = context.ranks
+        I = context.rank
+
+        # 10 cells per domain
+        n_local = 10
+        n_global = n_local * N
+        n_local_groups = n_local # 1 cell per group
+
+        recipe = hetero_recipe(n_global)
+        decomp = arb.partition_load_balance(recipe, context)
+
+        self.assertEqual(decomp.num_local_cells, n_local)
+        self.assertEqual(decomp.num_global_cells, n_global)
+        self.assertEqual(len(decomp.groups), n_local)
+
+        b = I * n_local
+        e = (I + 1) * n_local
+        gids = list(range(b,e))
+
+        for gid in gids:
+            self.assertEqual(I, decomp.gid_domain(gid))
+
+        # Each cell group contains 1 cell, of kind cable or spike_source
+        # Each group should also be tagged for cpu execution
+        grps = list(range(n_local_groups))
+        kind_lists = dict()
+        for i in grps:
+            grp = decomp.groups[i]
+            self.assertEqual(len(grp.gids), 1)
+            k = grp.kind
+            if k not in kind_lists:
+                kind_lists[k] = []
+            kind_lists[k].append(grp.gids[0])
+
+            self.assertEqual(grp.backend, arb.backend.multicore)
+
+        kinds = [arb.cell_kind.cable, arb.cell_kind.spike_source]
+        for k in kinds:
+            gids = kind_lists[k]
+            self.assertEqual(len(gids), int(n_local/2))
+            for gid in gids:
+                self.assertEqual(k, recipe.cell_kind(gid))
+
+    def test_domain_decomposition_symmetric(self):
+        nranks = 1
+        rank = 0
+        if (mpi_enabled):
+            comm = arb.mpi_comm()
+            context = arb.context(threads=1, gpu_id=None, mpi=comm)
+            nranks = context.ranks
+            rank = context.rank
+        else:
+            context = arb.context(threads=1, gpu_id=None)
+
+        recipe = gj_symmetric(nranks)
+        decomp0 = arb.partition_load_balance(recipe, context)
+
+        self.assertEqual(6, len(decomp0.groups))
+
+        shift = int((rank * recipe.num_cells())/nranks)
+
+        exp_groups0 = [ [0 + shift],
+                        [3 + shift],
+                        [4 + shift],
+                        [5 + shift],
+                        [8 + shift],
+                        [1 + shift, 2 + shift, 6 + shift, 7 + shift, 9 + shift]]
+
+        for i in range(6):
+            self.assertEqual(exp_groups0[i], decomp0.groups[i].gids)
+
+        cells_per_rank = int(recipe.num_cells()/nranks)
+
+        for i in range(recipe.num_cells()):
+            self.assertEqual(int(i/cells_per_rank), decomp0.gid_domain(i))
+
+        # Test different partition hints
+        hint1 = arb.partition_hint()
+        hint1.prefer_gpu = False
+        hint1.cpu_group_size = recipe.num_cells()
+        hints1 = dict([(arb.cell_kind.cable, hint1)])
+
+        decomp1 = arb.partition_load_balance(recipe, context, hints1)
+        self.assertEqual(1, len(decomp1.groups))
+
+        exp_groups1 = [0 + shift, 3 + shift, 4 + shift, 5 + shift, 8 + shift,
+                        1 + shift, 2 + shift, 6 + shift, 7 + shift, 9 + shift]
+
+        self.assertEqual(exp_groups1, decomp1.groups[0].gids)
+
+        for i in range(recipe.num_cells()):
+            self.assertEqual(int(i/cells_per_rank), decomp1.gid_domain(i))
+
+        hint2 = arb.partition_hint()
+        hint2.prefer_gpu = False
+        hint2.cpu_group_size = int(cells_per_rank/2)
+        hints2 = dict([(arb.cell_kind.cable, hint2)])
+
+        decomp2 = arb.partition_load_balance(recipe, context, hints2)
+        self.assertEqual(2, len(decomp2.groups))
+
+        exp_groups2 = [ [0 + shift, 3 + shift, 4 + shift, 5 + shift, 8 + shift],
+                        [1 + shift, 2 + shift, 6 + shift, 7 + shift, 9 + shift] ]
+
+        for i in range(2):
+            self.assertEqual(exp_groups2[i], decomp2.groups[i].gids)
+
+        for i in range(recipe.num_cells()):
+            self.assertEqual(int(i/cells_per_rank), decomp2.gid_domain(i))
+
+    def test_domain_decomposition_nonsymmetric(self):
+        nranks = 1
+        rank = 0
+        if (mpi_enabled):
+            comm = arb.mpi_comm()
+            context = arb.context(threads=1, gpu_id=None, mpi=comm)
+            nranks = context.ranks
+            rank = context.rank
+        else:
+            context = arb.context(threads=1, gpu_id=None)
+
+        recipe = gj_non_symmetric(nranks)
+        decomp = arb.partition_load_balance(recipe, context)
+
+        cells_per_rank = nranks
+
+        # check groups
+        i = 0
+        for gid in range(rank*cells_per_rank, (rank + 1)*cells_per_rank):
+            if (gid%nranks == rank - 1):
+                continue
+            elif (gid%nranks == rank and rank != nranks - 1):
+                cg = [gid, gid + cells_per_rank]
+                self.assertEqual(cg, decomp.groups[len(decomp.groups)-1].gids)
+            else:
+                cg = [gid]
+                self.assertEqual(cg, decomp.groups[i].gids)
+                i += 1
+
+        # check gid_domains
+        for gid in range(recipe.num_cells()):
+            group = int(gid/cells_per_rank)
+            idx = gid%cells_per_rank
+            ngroups = nranks
+            if (idx == group - 1):
+                self.assertEqual(group - 1, decomp.gid_domain(gid))
+            elif (idx == group and group != ngroups - 1):
+                self.assertEqual(group, decomp.gid_domain(gid))
+            else:
+                self.assertEqual(group, decomp.gid_domain(gid))
+
+    def test_domain_decomposition_exceptions(self):
+        nranks = 1
+        rank = 0
+        if (mpi_enabled):
+            comm = arb.mpi_comm()
+            context = arb.context(threads=1, gpu_id=None, mpi=comm)
+            nranks = context.ranks
+            rank = context.rank
+        else:
+            context = arb.context(threads=1, gpu_id=None)
+
+        recipe = gj_symmetric(nranks)
+
+        hint1 = arb.partition_hint()
+        hint1.prefer_gpu = False
+        hint1.cpu_group_size = 0
+        hints1 = dict([(arb.cell_kind.cable, hint1)])
+
+        with self.assertRaisesRegex(RuntimeError,
+            "unable to perform load balancing because cell_kind::cable has invalid suggested cpu_cell_group size of 0"):
+            decomp1 = arb.partition_load_balance(recipe, context, hints1)
+
+        hint2 = arb.partition_hint()
+        hint2.prefer_gpu = True
+        hint2.gpu_group_size = 0
+        hints2 = dict([(arb.cell_kind.cable, hint2)])
+
+        with self.assertRaisesRegex(RuntimeError,
+            "unable to perform load balancing because cell_kind::cable has invalid suggested gpu_cell_group size of 0"):
+            decomp2 = arb.partition_load_balance(recipe, context, hints2)
+
+def suite():
+    # specify class and test functions (here: all tests starting with 'test' from class Domain_Decompositions_Distributed)
+    suite = unittest.makeSuite(Domain_Decompositions_Distributed, ('test'))
+    return suite
+
+def run():
+    v = options.parse_arguments().verbosity
+
+    if not arb.mpi_is_initialized():
+        arb.mpi_init()
+
+    comm = arb.mpi_comm()
+
+    alloc = arb.proc_allocation()
+    ctx = arb.context(alloc, comm)
+    rank = ctx.rank
+
+    if rank == 0:
+        runner = unittest.TextTestRunner(verbosity = v)
+    else:
+        sys.stdout = open(os.devnull, 'w')
+        runner = unittest.TextTestRunner(stream=sys.stdout)
+
+    runner.run(suite())
+
+    if not arb.mpi_is_finalized():
+        arb.mpi_finalize()
+
+if __name__ == "__main__":
+    run()
-- 
GitLab