From d15afd3925bfe01eddeddc523724bc29332b4547 Mon Sep 17 00:00:00 2001
From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com>
Date: Tue, 21 Jun 2022 14:48:40 +0200
Subject: [PATCH] Add some convenience to simulation creation. (#1904)

Encapsulates 80% of cases:
```cxx
 rec = recipe()
 ctx = make_context()
 dec = partition_load_balance(rec, ctx)
 sim = simulation(rec, dec, ctx)
```
is now written as
```cxx
 rec = recipe()
 sim = simulation(rec)
```
In python we use keyword args to allow
both to be specified separatly.

Partially fixes #1862
---
 arbor/include/arbor/load_balance.hpp           |  2 +-
 arbor/include/arbor/simulation.hpp             |  9 ++++++++-
 arbor/partition_load_balance.cpp               |  2 +-
 arbor/simulation.cpp                           |  4 ++--
 doc/concepts/simulation.rst                    | 12 ++++++++----
 example/bench/bench.cpp                        |  2 +-
 example/brunel/brunel.cpp                      |  5 +++--
 example/drybench/drybench.cpp                  |  4 +---
 example/dryrun/dryrun.cpp                      |  4 +---
 example/gap_junctions/gap_junctions.cpp        |  2 +-
 example/generators/generators.cpp              |  5 +----
 example/lfp/lfp.cpp                            | 18 +++++++-----------
 example/probe-demo/probe-demo.cpp              |  3 +--
 example/ring/ring.cpp                          |  6 +++---
 example/single/single.cpp                      |  3 +--
 python/context.cpp                             |  2 +-
 python/example/brunel.py                       |  2 +-
 python/example/dynamic-catalogue.py            |  4 +---
 python/example/gap_junctions.py                |  6 ++----
 python/example/network_ring.py                 |  5 ++---
 python/example/network_ring_mpi.py             |  3 +--
 python/example/single_cell_cable.py            |  6 +-----
 python/example/single_cell_detailed_recipe.py  | 14 ++++----------
 .../single_cell_extracellular_potentials.py    |  4 +---
 python/example/single_cell_recipe.py           |  7 +------
 python/example/single_cell_stdp.py             |  4 +---
 python/example/two_cell_gap_junctions.py       |  6 +-----
 python/simulation.cpp                          | 16 +++++++++++-----
 python/single_cell_model.cpp                   |  2 +-
 python/test/fixtures.py                        |  2 +-
 python/test/unit/test_cable_probes.py          |  2 +-
 python/test/unit/test_catalogues.py            |  2 +-
 python/test/unit/test_multiple_connections.py  |  8 +++-----
 python/test/unit/test_profiling.py             |  2 +-
 python/test/unit_distributed/test_simulator.py |  2 +-
 test/unit/test_event_delivery.cpp              | 11 ++++++-----
 test/unit/test_fvm_lowered.cpp                 |  4 ++--
 test/unit/test_lif_cell_group.cpp              |  6 +++---
 test/unit/test_probe.cpp                       | 10 +++++-----
 test/unit/test_recipe.cpp                      | 16 ++++++++--------
 test/unit/test_simulation.cpp                  |  6 +++---
 test/unit/test_spikes.cpp                      |  2 +-
 42 files changed, 105 insertions(+), 130 deletions(-)

diff --git a/arbor/include/arbor/load_balance.hpp b/arbor/include/arbor/load_balance.hpp
index b9e8e71b..1c073980 100644
--- a/arbor/include/arbor/load_balance.hpp
+++ b/arbor/include/arbor/load_balance.hpp
@@ -23,5 +23,5 @@ using partition_hint_map = std::unordered_map<cell_kind, partition_hint>;
 ARB_ARBOR_API domain_decomposition partition_load_balance(
     const recipe& rec,
     const context& ctx,
-    partition_hint_map hint_map = {});
+    const partition_hint_map& hint_map = {});
 } // namespace arb
diff --git a/arbor/include/arbor/simulation.hpp b/arbor/include/arbor/simulation.hpp
index 7e116719..c2dd5311 100644
--- a/arbor/include/arbor/simulation.hpp
+++ b/arbor/include/arbor/simulation.hpp
@@ -4,11 +4,13 @@
 #include <memory>
 #include <unordered_map>
 #include <vector>
+#include <functional>
 
 #include <arbor/export.hpp>
 #include <arbor/common_types.hpp>
 #include <arbor/context.hpp>
 #include <arbor/domain_decomposition.hpp>
+#include <arbor/load_balance.hpp>
 #include <arbor/recipe.hpp>
 #include <arbor/sampling.hpp>
 #include <arbor/schedule.hpp>
@@ -25,7 +27,12 @@ class simulation_state;
 
 class ARB_ARBOR_API simulation {
 public:
-    simulation(const recipe& rec, const domain_decomposition& decomp, const context& ctx);
+
+    simulation(const recipe& rec, const context& ctx, const domain_decomposition& decomp);
+
+    simulation(const recipe& rec,
+               const context& ctx=make_context(),
+               std::function<domain_decomposition(const recipe&, const context&)> balancer=[](auto& r, auto& c) { return partition_load_balance(r, c); }): simulation(rec, ctx, balancer(rec, ctx)) {}
 
     void reset();
 
diff --git a/arbor/partition_load_balance.cpp b/arbor/partition_load_balance.cpp
index 19f3ac09..71467aab 100644
--- a/arbor/partition_load_balance.cpp
+++ b/arbor/partition_load_balance.cpp
@@ -23,7 +23,7 @@ namespace arb {
 ARB_ARBOR_API domain_decomposition partition_load_balance(
     const recipe& rec,
     const context& ctx,
-    partition_hint_map hint_map)
+    const partition_hint_map& hint_map)
 {
     using util::make_span;
 
diff --git a/arbor/simulation.cpp b/arbor/simulation.cpp
index 21a07d3c..f604e19b 100644
--- a/arbor/simulation.cpp
+++ b/arbor/simulation.cpp
@@ -510,8 +510,8 @@ void simulation_state::inject_events(const cse_vector& events) {
 
 simulation::simulation(
     const recipe& rec,
-    const domain_decomposition& decomp,
-    const context& ctx)
+    const context& ctx,
+    const domain_decomposition& decomp)
 {
     impl_.reset(new simulation_state(rec, decomp, *ctx));
 }
diff --git a/doc/concepts/simulation.rst b/doc/concepts/simulation.rst
index 2a63cc69..6dc7c055 100644
--- a/doc/concepts/simulation.rst
+++ b/doc/concepts/simulation.rst
@@ -9,12 +9,16 @@ group are scheduled.
 From recipe to simulation
 -------------------------
 
-To build a simulation the following are needed:
+To build a simulation the following is needed:
 
 * A :ref:`recipe <modelrecipe>` that describes the cells and connections in the model.
-* A :ref:`domain decomposition <modeldomdec>` that describes the distribution of the
-  model over the local and distributed :ref:`hardware resources <modelhardware>`.
-* An :ref:`execution context <modelcontext>` used to execute the simulation.
+* A :ref:`domain decomposition <modeldomdec>` that describes the distribution of
+  the model over the local and distributed :ref:`hardware resources
+  <modelhardware>`. If not given, a default algorithm will be used which assigns
+  cells to groups one to one; each group is assigned to a thread from the context.
+* An :ref:`execution context <modelcontext>` used to execute the simulation. If
+  not given, the default context will be used, which allocates one thread, one
+  process (MPI task), and no GPU.
 
 Simulation execution and interaction
 ------------------------------------
diff --git a/example/bench/bench.cpp b/example/bench/bench.cpp
index 0d6fb976..1057ff74 100644
--- a/example/bench/bench.cpp
+++ b/example/bench/bench.cpp
@@ -165,7 +165,7 @@ int main(int argc, char** argv) {
         meters.checkpoint("domain-decomp", context);
 
         // Construct the model.
-        arb::simulation sim(recipe, decomp, context);
+        arb::simulation sim(recipe, context, decomp);
         meters.checkpoint("model-build", context);
 
         // Run the simulation for 100 ms, with time steps of 0.01 ms.
diff --git a/example/brunel/brunel.cpp b/example/brunel/brunel.cpp
index 4343e593..351e656f 100644
--- a/example/brunel/brunel.cpp
+++ b/example/brunel/brunel.cpp
@@ -246,9 +246,10 @@ int main(int argc, char** argv) {
 
         partition_hint_map hints;
         hints[cell_kind::lif].cpu_group_size = group_size;
-        auto decomp = partition_load_balance(recipe, context, hints);
 
-        simulation sim(recipe, decomp, context);
+        simulation sim(recipe,
+                       context,
+                       [&hints](auto& r, auto& c) { return partition_load_balance(r, c, hints); });
 
         // Set up spike recording.
         std::vector<arb::spike> recorded_spikes;
diff --git a/example/drybench/drybench.cpp b/example/drybench/drybench.cpp
index ac891cdd..c5300d1a 100644
--- a/example/drybench/drybench.cpp
+++ b/example/drybench/drybench.cpp
@@ -156,10 +156,8 @@ int main(int argc, char** argv) {
         auto tile = std::make_unique<tile_desc>(params);
         arb::symmetric_recipe recipe(std::move(tile));
 
-        auto decomp = arb::partition_load_balance(recipe, ctx);
-
         // Construct the model.
-        arb::simulation sim(recipe, decomp, ctx);
+        arb::simulation sim(recipe, ctx);
 
         meters.checkpoint("model-init", ctx);
 
diff --git a/example/dryrun/dryrun.cpp b/example/dryrun/dryrun.cpp
index 8a057d09..6f38d968 100644
--- a/example/dryrun/dryrun.cpp
+++ b/example/dryrun/dryrun.cpp
@@ -174,10 +174,8 @@ int main(int argc, char** argv) {
                 params.num_ranks, params.cell, params.min_delay);
         arb::symmetric_recipe recipe(std::move(tile));
 
-        auto decomp = arb::partition_load_balance(recipe, ctx);
-
         // Construct the model.
-        arb::simulation sim(recipe, decomp, ctx);
+        arb::simulation sim(recipe, ctx);
 
         // The id of the only probe on the cell: the cell_member type points to (cell 0, probe 0)
         auto probe_id = cell_member_type{0, 0};
diff --git a/example/gap_junctions/gap_junctions.cpp b/example/gap_junctions/gap_junctions.cpp
index f61b835e..10b124d1 100644
--- a/example/gap_junctions/gap_junctions.cpp
+++ b/example/gap_junctions/gap_junctions.cpp
@@ -172,7 +172,7 @@ int main(int argc, char** argv) {
         auto decomp = arb::partition_load_balance(recipe, context);
 
         // Construct the model.
-        arb::simulation sim(recipe, decomp, context);
+        arb::simulation sim(recipe, context, decomp);
 
         // Set up the probe that will measure voltage in the cell.
 
diff --git a/example/generators/generators.cpp b/example/generators/generators.cpp
index 94bc75be..9e72aff2 100644
--- a/example/generators/generators.cpp
+++ b/example/generators/generators.cpp
@@ -132,11 +132,8 @@ int main() {
     // Create an instance of our recipe.
     generator_recipe recipe;
 
-    // Make the domain decomposition for the model
-    auto decomp = arb::partition_load_balance(recipe, context);
-
     // Construct the model.
-    arb::simulation sim(recipe, decomp, context);
+    arb::simulation sim(recipe, context);
 
     // Set up the probe that will measure voltage in the cell.
 
diff --git a/example/lfp/lfp.cpp b/example/lfp/lfp.cpp
index 9669c26e..cb4f7ec2 100644
--- a/example/lfp/lfp.cpp
+++ b/example/lfp/lfp.cpp
@@ -3,7 +3,6 @@
 #include <vector>
 #include <iostream>
 
-#include <arbor/load_balance.hpp>
 #include <arbor/cable_cell.hpp>
 #include <arbor/morph/morphology.hpp>
 #include <arbor/morph/place_pwlin.hpp>
@@ -79,8 +78,7 @@ private:
         // Apical dendrite, length 490 μm, radius 1 μm, with SWC tag 4.
         tree.append(soma_apex, {0, 0, 10, 1},  {0, 0, 500, 1}, 4);
 
-        decor dec;
-
+        auto dec = arb::decor();
         // Use NEURON defaults for reversal potentials, ion concentrations etc., but override ra, cm.
         dec.set_default(axial_resistivity{100});     // [Ω·cm]
         dec.set_default(membrane_capacitance{0.01}); // [F/m²]
@@ -183,24 +181,22 @@ struct {
 // Run simulation.
 
 int main(int argc, char** argv) {
-    auto context = arb::make_context();
-
-    // Weight 0.005 μS, onset at t = 0 ms, mean frequency 0.1 kHz.
-    auto events = arb::poisson_generator({"syn"}, .005, 0., 0.1, std::minstd_rand{});
-    lfp_demo_recipe R(events);
-
+    // Configuration
     const double t_stop = 100;    // [ms]
     const double sample_dt = 0.1; // [ms]
     const double dt = 0.1;        // [ms]
 
-    arb::simulation sim(R, arb::partition_load_balance(R, context), context);
+    // Weight 0.005 μS, onset at t = 0 ms, mean frequency 0.1 kHz.
+    auto events = arb::poisson_generator({"syn"}, .005, 0., 0.1, std::minstd_rand{});
+    lfp_demo_recipe recipe(events);
+    arb::simulation sim(recipe);
 
     std::vector<position> electrodes = {
         {30, 0, 0},
         {30, 0, 100}
     };
 
-    arb::morphology cell_morphology = any_cast<arb::cable_cell>(R.get_cell_description(0)).morphology();
+    arb::morphology cell_morphology = any_cast<arb::cable_cell>(recipe.get_cell_description(0)).morphology();
     arb::place_pwlin placed_cell(cell_morphology);
 
     auto probe0_metadata = sim.get_probe_metadata(cell_member_type{0, 0});
diff --git a/example/probe-demo/probe-demo.cpp b/example/probe-demo/probe-demo.cpp
index 9e9f7854..0d27c111 100644
--- a/example/probe-demo/probe-demo.cpp
+++ b/example/probe-demo/probe-demo.cpp
@@ -145,8 +145,7 @@ int main(int argc, char** argv) {
 
         cable_recipe R(opt.probe_addr, opt.n_cv);
 
-        auto context = arb::make_context();
-        arb::simulation sim(R, arb::partition_load_balance(R, context), context);
+        arb::simulation sim(R);
 
         sim.add_sampler(arb::all_probes,
                 arb::regular_schedule(opt.sample_dt),
diff --git a/example/ring/ring.cpp b/example/ring/ring.cpp
index 45a3b55b..fa47d3f2 100644
--- a/example/ring/ring.cpp
+++ b/example/ring/ring.cpp
@@ -13,6 +13,7 @@
 
 #include <arborio/label_parse.hpp>
 
+#include <arbor/load_balance.hpp>
 #include <arbor/assert_macro.hpp>
 #include <arbor/common_types.hpp>
 #include <arbor/cable_cell.hpp>
@@ -160,10 +161,9 @@ int main(int argc, char** argv) {
         // Create an instance of our recipe.
         ring_recipe recipe(params.num_cells, params.cell, params.min_delay);
 
-        auto decomp = arb::partition_load_balance(recipe, context);
-
         // Construct the model.
-        arb::simulation sim(recipe, decomp, context);
+        auto decomposition = arb::partition_load_balance(recipe, context);
+        arb::simulation sim(recipe, context, decomposition);
 
         // Set up the probe that will measure voltage in the cell.
 
diff --git a/example/single/single.cpp b/example/single/single.cpp
index 2721d2fd..4befabcd 100644
--- a/example/single/single.cpp
+++ b/example/single/single.cpp
@@ -89,8 +89,7 @@ int main(int argc, char** argv) {
         options opt = parse_options(argc, argv);
         single_recipe R(opt.swc_file.empty()? default_morphology(): read_swc(opt.swc_file), opt.policy);
 
-        auto context = arb::make_context();
-        arb::simulation sim(R, arb::partition_load_balance(R, context), context);
+        arb::simulation sim(R);
 
         // Attach a sampler to the probe described in the recipe, sampling every 0.1 ms.
 
diff --git a/python/context.cpp b/python/context.cpp
index 80fcbc61..1da55c86 100644
--- a/python/context.cpp
+++ b/python/context.cpp
@@ -123,7 +123,7 @@ void register_contexts(pybind11::module& m) {
         .def("__repr__", util::to_string<proc_allocation_shim>);
 
     // context
-    pybind11::class_<context_shim> context(m, "context", "An opaque handle for the hardware resources used in a simulation.");
+    pybind11::class_<context_shim, std::shared_ptr<context_shim>> context(m, "context", "An opaque handle for the hardware resources used in a simulation.");
     context
         .def(pybind11::init(
             [](unsigned threads, pybind11::object gpu, pybind11::object mpi){
diff --git a/python/example/brunel.py b/python/example/brunel.py
index 6a01f411..5455ca69 100755
--- a/python/example/brunel.py
+++ b/python/example/brunel.py
@@ -257,7 +257,7 @@ if __name__ == "__main__":
 
     meters.checkpoint("load-balance", context)
 
-    sim = arbor.simulation(recipe, decomp, context)
+    sim = arbor.simulation(recipe, context, decomp)
     sim.record(arbor.spike_recording.all)
 
     meters.checkpoint("simulation-init", context)
diff --git a/python/example/dynamic-catalogue.py b/python/example/dynamic-catalogue.py
index 6fc5e8fc..67fc7725 100644
--- a/python/example/dynamic-catalogue.py
+++ b/python/example/dynamic-catalogue.py
@@ -42,7 +42,5 @@ where <arbor> is the location of the arbor source tree."""
     exit(1)
 
 rcp = recipe()
-ctx = arb.context()
-dom = arb.partition_load_balance(rcp, ctx)
-sim = arb.simulation(rcp, dom, ctx)
+sim = arb.simulation(rcp)
 sim.run(tfinal=30)
diff --git a/python/example/gap_junctions.py b/python/example/gap_junctions.py
index 2b6b3221..8d5ce084 100644
--- a/python/example/gap_junctions.py
+++ b/python/example/gap_junctions.py
@@ -133,10 +133,8 @@ ncells = nchains * ncells_per_chain
 # Instantiate recipe
 recipe = chain_recipe(ncells_per_chain, nchains)
 
-# Create a default execution context, domain decomposition and simulation
-context = arbor.context()
-decomp = arbor.partition_load_balance(recipe, context)
-sim = arbor.simulation(recipe, decomp, context)
+# Create a default simulation
+sim = arbor.simulation(recipe)
 
 # Set spike generators to record
 sim.record(arbor.spike_recording.all)
diff --git a/python/example/network_ring.py b/python/example/network_ring.py
index 9eec3cbd..9cb97a3a 100755
--- a/python/example/network_ring.py
+++ b/python/example/network_ring.py
@@ -123,9 +123,8 @@ ncells = 4
 recipe = ring_recipe(ncells)
 
 # (12) Create an execution context using all locally available threads, domain decomposition and simulation
-context = arbor.context("avail_threads")
-decomp = arbor.partition_load_balance(recipe, context)
-sim = arbor.simulation(recipe, decomp, context)
+ctx = arbor.context("avail_threads")
+sim = arbor.simulation(recipe, context=ctx)
 
 # (13) Set spike generators to record
 sim.record(arbor.spike_recording.all)
diff --git a/python/example/network_ring_mpi.py b/python/example/network_ring_mpi.py
index 84aa7daa..7ccd2d53 100644
--- a/python/example/network_ring_mpi.py
+++ b/python/example/network_ring_mpi.py
@@ -132,8 +132,7 @@ context = arbor.context(mpi=comm)
 print(context)
 
 # (13) Create a default domain decomposition and simulation
-decomp = arbor.partition_load_balance(recipe, context)
-sim = arbor.simulation(recipe, decomp, context)
+sim = arbor.simulation(recipe, context)
 
 # (14) Set spike generators to record
 sim.record(arbor.spike_recording.all)
diff --git a/python/example/single_cell_cable.py b/python/example/single_cell_cable.py
index c7116fc9..b375ec37 100755
--- a/python/example/single_cell_cable.py
+++ b/python/example/single_cell_cable.py
@@ -208,12 +208,8 @@ if __name__ == "__main__":
     ]
     recipe = Cable(probes, **vars(args))
 
-    # create a default execution context and a default domain decomposition
-    context = arbor.context()
-    domains = arbor.partition_load_balance(recipe, context)
-
     # configure the simulation and handles for the probes
-    sim = arbor.simulation(recipe, domains, context)
+    sim = arbor.simulation(recipe)
     dt = 0.001
     handles = [
         sim.sample((0, i), arbor.regular_schedule(dt)) for i in range(len(probes))
diff --git a/python/example/single_cell_detailed_recipe.py b/python/example/single_cell_detailed_recipe.py
index 25da4580..dd658f66 100644
--- a/python/example/single_cell_detailed_recipe.py
+++ b/python/example/single_cell_detailed_recipe.py
@@ -145,14 +145,8 @@ class single_recipe(arbor.recipe):
 # Pass the probe in a list because that it what single_recipe expects.
 recipe = single_recipe(cell, [probe])
 
-# (4) Create an execution context
-context = arbor.context()
-
-# (5) Create a domain decomposition
-domains = arbor.partition_load_balance(recipe, context)
-
-# (6) Create a simulation
-sim = arbor.simulation(recipe, domains, context)
+# (7) Create a simulation
+sim = arbor.simulation(recipe)
 
 # Instruct the simulation to record the spikes and sample the probe
 sim.record(arbor.spike_recording.all)
@@ -160,10 +154,10 @@ sim.record(arbor.spike_recording.all)
 probe_id = arbor.cell_member(0, 0)
 handle = sim.sample(probe_id, arbor.regular_schedule(0.02))
 
-# (7) Run the simulation
+# (8) Run the simulation
 sim.run(tfinal=100, dt=0.025)
 
-# (8) Print or display the results
+# (9) Print or display the results
 spikes = sim.spikes()
 print(len(spikes), "spikes recorded:")
 for s in spikes:
diff --git a/python/example/single_cell_extracellular_potentials.py b/python/example/single_cell_extracellular_potentials.py
index a5c13334..413ebb90 100644
--- a/python/example/single_cell_extracellular_potentials.py
+++ b/python/example/single_cell_extracellular_potentials.py
@@ -124,9 +124,7 @@ p, cell = make_cable_cell(morphology, clamp_location)
 recipe = Recipe(cell)
 
 # instantiate simulation
-context = arbor.context()
-domains = arbor.partition_load_balance(recipe, context)
-sim = arbor.simulation(recipe, domains, context)
+sim = arbor.simulation(recipe)
 
 # set up sampling on probes with sampling every 1 ms
 schedule = arbor.regular_schedule(1.0)
diff --git a/python/example/single_cell_recipe.py b/python/example/single_cell_recipe.py
index eae1e436..facd5c8d 100644
--- a/python/example/single_cell_recipe.py
+++ b/python/example/single_cell_recipe.py
@@ -60,15 +60,10 @@ class single_recipe(arbor.recipe):
 
 recipe = single_recipe(cell, [arbor.cable_probe_membrane_voltage('"midpoint"')])
 
-# (6) Create a default execution context and a default domain decomposition.
-
-context = arbor.context()
-domains = arbor.partition_load_balance(recipe, context)
-
 # (7) Create and run simulation and set up 10 kHz (every 0.1 ms) sampling on the probe.
 # The probe is located on cell 0, and is the 0th probe on that cell, thus has probe_id (0, 0).
 
-sim = arbor.simulation(recipe, domains, context)
+sim = arbor.simulation(recipe)
 sim.record(arbor.spike_recording.all)
 handle = sim.sample((0, 0), arbor.regular_schedule(0.1))
 sim.run(tfinal=30)
diff --git a/python/example/single_cell_stdp.py b/python/example/single_cell_stdp.py
index 2f0d75bc..8c2d443c 100755
--- a/python/example/single_cell_stdp.py
+++ b/python/example/single_cell_stdp.py
@@ -78,9 +78,7 @@ class single_recipe(arbor.recipe):
 def run(dT, n_pairs=1, do_plots=False):
     recipe = single_recipe(dT, n_pairs)
 
-    context = arbor.context()
-    domains = arbor.partition_load_balance(recipe, context)
-    sim = arbor.simulation(recipe, domains, context)
+    sim = arbor.simulation(recipe)
 
     sim.record(arbor.spike_recording.all)
 
diff --git a/python/example/two_cell_gap_junctions.py b/python/example/two_cell_gap_junctions.py
index add860b9..c303802c 100755
--- a/python/example/two_cell_gap_junctions.py
+++ b/python/example/two_cell_gap_junctions.py
@@ -153,12 +153,8 @@ if __name__ == "__main__":
     probes = [arbor.cable_probe_membrane_voltage('"gj_site"')]
     recipe = TwoCellsWithGapJunction(probes, **vars(args))
 
-    # create a default execution context and a default domain decomposition
-    context = arbor.context()
-    domains = arbor.partition_load_balance(recipe, context)
-
     # configure the simulation and handles for the probes
-    sim = arbor.simulation(recipe, domains, context)
+    sim = arbor.simulation(recipe)
 
     dt = 0.01
     handles = []
diff --git a/python/simulation.cpp b/python/simulation.cpp
index b506ff17..90689f84 100644
--- a/python/simulation.cpp
+++ b/python/simulation.cpp
@@ -55,11 +55,11 @@ class simulation_shim {
     std::unordered_map<arb::sampler_association_handle, sampler_callback> sampler_map_;
 
 public:
-    simulation_shim(std::shared_ptr<py_recipe>& rec, const arb::domain_decomposition& decomp, const context_shim& ctx, pyarb_global_ptr global_ptr):
+    simulation_shim(std::shared_ptr<py_recipe>& rec, const context_shim& ctx, const arb::domain_decomposition& decomp, pyarb_global_ptr global_ptr):
         global_ptr_(global_ptr)
     {
         try {
-            sim_.reset(new arb::simulation(py_recipe_shim(rec), decomp, ctx.context));
+            sim_.reset(new arb::simulation(py_recipe_shim(rec), ctx.context, decomp));
         }
         catch (...) {
             py_reset_and_throw();
@@ -195,9 +195,13 @@ void register_simulation(pybind11::module& m, pyarb_global_ptr global_ptr) {
         // A custom constructor that wraps a python recipe with arb::py_recipe_shim
         // before forwarding it to the arb::recipe constructor.
         .def(pybind11::init(
-            [global_ptr](std::shared_ptr<py_recipe>& rec, const arb::domain_decomposition& decomp, const context_shim& ctx) {
+                 [global_ptr](std::shared_ptr<py_recipe>& rec,
+                              const std::shared_ptr<context_shim>& ctx_,
+                              const std::optional<arb::domain_decomposition>& decomp) {
                 try {
-                    return new simulation_shim(rec, decomp, ctx, global_ptr);
+                    auto ctx = ctx_ ? ctx_ : std::make_shared<context_shim>(arb::make_context());
+                    auto dec = decomp.value_or(arb::partition_load_balance(py_recipe_shim(rec), ctx->context));
+                    return new simulation_shim(rec, *ctx, dec, global_ptr);
                 }
                 catch (...) {
                     py_reset_and_throw();
@@ -208,7 +212,9 @@ void register_simulation(pybind11::module& m, pyarb_global_ptr global_ptr) {
             pybind11::call_guard<pybind11::gil_scoped_release>(),
             "Initialize the model described by a recipe, with cells and network distributed\n"
             "according to the domain decomposition and computational resources described by a context.",
-            "recipe"_a, "domain_decomposition"_a, "context"_a)
+             "recipe"_a,
+             pybind11::arg_v("context", pybind11::none(), "Execution context"),
+             pybind11::arg_v("domains", pybind11::none(), "Domain decomposition"))
         .def("reset", &simulation_shim::reset,
             pybind11::call_guard<pybind11::gil_scoped_release>(),
             "Reset the state of the simulation to its initial state.")
diff --git a/python/single_cell_model.cpp b/python/single_cell_model.cpp
index 5165aa02..cc8c0cb6 100644
--- a/python/single_cell_model.cpp
+++ b/python/single_cell_model.cpp
@@ -171,7 +171,7 @@ public:
 
         auto domdec = arb::partition_load_balance(rec, ctx_);
 
-        sim_ = std::make_unique<arb::simulation>(rec, domdec, ctx_);
+        sim_ = std::make_unique<arb::simulation>(rec, ctx_, domdec);
 
         // Create one trace for each probe.
         traces_.reserve(probes_.size());
diff --git a/python/test/fixtures.py b/python/test/fixtures.py
index 9f44c516..7885b12a 100644
--- a/python/test/fixtures.py
+++ b/python/test/fixtures.py
@@ -246,4 +246,4 @@ def sum_weight_hh_spike_2():
 @art_spiker_recipe
 def art_spiking_sim(context, art_spiker_recipe):
     dd = arbor.partition_load_balance(art_spiker_recipe, context)
-    return arbor.simulation(art_spiker_recipe, dd, context)
+    return arbor.simulation(art_spiker_recipe, context, dd)
diff --git a/python/test/unit/test_cable_probes.py b/python/test/unit/test_cable_probes.py
index 59e8e610..d2a66057 100644
--- a/python/test/unit/test_cable_probes.py
+++ b/python/test/unit/test_cable_probes.py
@@ -91,7 +91,7 @@ class TestCableProbes(unittest.TestCase):
         recipe = cc_recipe()
         context = A.context()
         dd = A.partition_load_balance(recipe, context)
-        sim = A.simulation(recipe, dd, context)
+        sim = A.simulation(recipe, context, dd)
 
         all_cv_cables = [A.cable(0, 0, 1)]
 
diff --git a/python/test/unit/test_catalogues.py b/python/test/unit/test_catalogues.py
index e4bd1e63..6ef96f91 100644
--- a/python/test/unit/test_catalogues.py
+++ b/python/test/unit/test_catalogues.py
@@ -60,7 +60,7 @@ class TestCatalogues(unittest.TestCase):
         rcp = recipe()
         ctx = arb.context()
         dom = arb.partition_load_balance(rcp, ctx)
-        sim = arb.simulation(rcp, dom, ctx)
+        sim = arb.simulation(rcp, ctx, dom)
         sim.run(tfinal=30)
 
     def test_empty(self):
diff --git a/python/test/unit/test_multiple_connections.py b/python/test/unit/test_multiple_connections.py
index 89cc32d7..37b0f16e 100644
--- a/python/test/unit/test_multiple_connections.py
+++ b/python/test/unit/test_multiple_connections.py
@@ -134,8 +134,7 @@ class TestMultipleConnections(unittest.TestCase):
         self.assertAlmostEqual(connections_from_recipe[3].delay, 1.4)
 
         # construct domain_decomposition and simulation object
-        dd = arb.partition_load_balance(art_spiker_recipe, context)
-        sim = arb.simulation(art_spiker_recipe, dd, context)
+        sim = arb.simulation(art_spiker_recipe, context)
         sim.record(arb.spike_recording.all)
 
         # create schedule and handle to record the membrane potential of neuron 3
@@ -365,9 +364,8 @@ class TestMultipleConnections(unittest.TestCase):
         self.assertAlmostEqual(connections_from_recipe[1].weight, weight2)
         self.assertAlmostEqual(connections_from_recipe[1].delay, 1.4)
 
-        # construct domain_decomposition and simulation object
-        dd = arb.partition_load_balance(art_spiker_recipe, context)
-        sim = arb.simulation(art_spiker_recipe, dd, context)
+        # construct simulation object
+        sim = arb.simulation(art_spiker_recipe, context)
         sim.record(arb.spike_recording.all)
 
         # create schedule and handle to record the membrane potential of neuron 3
diff --git a/python/test/unit/test_profiling.py b/python/test/unit/test_profiling.py
index 92a17e23..0c0d484d 100644
--- a/python/test/unit/test_profiling.py
+++ b/python/test/unit/test_profiling.py
@@ -93,7 +93,7 @@ class TestProfiling(unittest.TestCase):
         arb.profiler_initialize(context)
         recipe = a_recipe()
         dd = arb.partition_load_balance(recipe, context)
-        arb.simulation(recipe, dd, context).run(1)
+        arb.simulation(recipe, context, dd).run(1)
         summary = arb.profiler_summary()
         self.assertEqual(str, type(summary), "profiler summary must be str")
         self.assertTrue(summary, "empty summary")
diff --git a/python/test/unit_distributed/test_simulator.py b/python/test/unit_distributed/test_simulator.py
index 8ac2a983..ea657a95 100644
--- a/python/test/unit_distributed/test_simulator.py
+++ b/python/test/unit_distributed/test_simulator.py
@@ -69,7 +69,7 @@ class TestSimulator(unittest.TestCase):
         self.assertEqual(1, len(local_groups))
         self.assertEqual([self.rank], local_groups[0].gids)
 
-        return A.simulation(recipe, dd, context)
+        return A.simulation(recipe, context, dd)
 
     def test_local_spikes(self):
         sim = self.init_sim()
diff --git a/test/unit/test_event_delivery.cpp b/test/unit/test_event_delivery.cpp
index 1294240c..82a029f8 100644
--- a/test/unit/test_event_delivery.cpp
+++ b/test/unit/test_event_delivery.cpp
@@ -49,17 +49,18 @@ using gid_vector = std::vector<cell_gid_type>;
 using group_gids_type = std::vector<gid_vector>;
 
 std::vector<cell_gid_type> run_test_sim(const recipe& R, const group_gids_type& group_gids) {
-    arb::context ctx = make_context(proc_allocation{});
-    unsigned n = R.num_cells();
 
+    unsigned n = R.num_cells();
     std::vector<group_description> groups;
     for (const auto& gidvec: group_gids) {
         groups.emplace_back(cell_kind::cable, gidvec, backend_kind::multicore);
     }
-    auto D = domain_decomposition(R, ctx, groups);
-    std::vector<spike> spikes;
 
-    simulation sim(R, D, ctx);
+    auto C = make_context();
+    auto D = domain_decomposition(R, C, groups);
+    simulation sim(R, C, D);
+
+    std::vector<spike> spikes;
     sim.set_global_spike_callback(
             [&spikes](const std::vector<spike>& ss) {
                 spikes.insert(spikes.end(), ss.begin(), ss.end());
diff --git a/test/unit/test_fvm_lowered.cpp b/test/unit/test_fvm_lowered.cpp
index d9d1aa76..55f10243 100644
--- a/test/unit/test_fvm_lowered.cpp
+++ b/test/unit/test_fvm_lowered.cpp
@@ -485,7 +485,7 @@ TEST(fvm_lowered, derived_mechs) {
         float times[] = {10.f, 20.f};
 
         auto decomp = partition_load_balance(rec, context);
-        simulation sim(rec, decomp, context);
+        simulation sim(rec, context, decomp);
         sim.add_sampler(all_probes, explicit_schedule(times), sampler);
         sim.run(30.0, 1.f/1024);
 
@@ -516,7 +516,7 @@ TEST(fvm_lowered, null_region) {
     rec.catalogue().derive("custom_kin1", "test_kin1", {{"tau", 20.0}});
 
     auto decomp = partition_load_balance(rec, context);
-    simulation sim(rec, decomp, context);
+    simulation sim(rec, context, decomp);
     EXPECT_NO_THROW(sim.run(30.0, 1.f/1024));
 }
 
diff --git a/test/unit/test_lif_cell_group.cpp b/test/unit/test_lif_cell_group.cpp
index 9ca765c4..0ea015a3 100644
--- a/test/unit/test_lif_cell_group.cpp
+++ b/test/unit/test_lif_cell_group.cpp
@@ -133,7 +133,7 @@ TEST(lif_cell_group, throw) {
     probe_recipe rec;
     auto context = make_context();
     auto decomp = partition_load_balance(rec, context);
-    EXPECT_THROW(simulation(rec, decomp, context), bad_cell_probe);
+    EXPECT_THROW(simulation(rec, context, decomp), bad_cell_probe);
 }
 
 TEST(lif_cell_group, recipe)
@@ -153,7 +153,7 @@ TEST(lif_cell_group, spikes) {
     auto context = make_context();
 
     auto decomp = partition_load_balance(recipe, context);
-    simulation sim(recipe, decomp, context);
+    simulation sim(recipe, context, decomp);
 
     cse_vector events;
 
@@ -193,7 +193,7 @@ TEST(lif_cell_group, ring)
     auto decomp = partition_load_balance(recipe, context);
 
     // Creates a simulation with a ring recipe of lif neurons
-    simulation sim(recipe, decomp, context);
+    simulation sim(recipe, context, decomp);
 
     std::vector<spike> spike_buffer;
 
diff --git a/test/unit/test_probe.cpp b/test/unit/test_probe.cpp
index bc9f41c3..ca541823 100644
--- a/test/unit/test_probe.cpp
+++ b/test/unit/test_probe.cpp
@@ -826,7 +826,7 @@ void run_axial_and_ion_current_sampled_probe_test(const context& ctx) {
     partition_hint_map phints = {
        {cell_kind::cable, {partition_hint::max_size, partition_hint::max_size, true}}
     };
-    simulation sim(rec, partition_load_balance(rec, ctx, phints), ctx);
+    simulation sim(rec, ctx, partition_load_balance(rec, ctx, phints));
 
     // Take a sample at 20 tau, and run sim for just a bit longer.
 
@@ -913,7 +913,7 @@ auto run_simple_samplers(
     partition_hint_map phints = {
        {cell_kind::cable, {partition_hint::max_size, partition_hint::max_size, true}}
     };
-    simulation sim(rec, partition_load_balance(rec, ctx, phints), ctx);
+    simulation sim(rec, ctx, partition_load_balance(rec, ctx, phints));
 
     std::vector<trace_vector<SampleData, SampleMeta>> traces(n_probe);
     for (unsigned i = 0; i<n_probe; ++i) {
@@ -1278,13 +1278,13 @@ void run_exact_sampling_probe_test(const context& ctx) {
     };
     domain_decomposition one_cell_group = partition_load_balance(rec, ctx, phints);
 
-    simulation lax_sim(rec, one_cell_group, ctx);
+    simulation lax_sim(rec, ctx, one_cell_group);
     for (unsigned i = 0; i<n_cell; ++i) {
         lax_sim.add_sampler(one_probe({i, 0}), sample_sched, make_simple_sampler(lax_traces.at(i)), sampling_policy::lax);
     }
     lax_sim.run(t_end, max_dt);
 
-    simulation exact_sim(rec, one_cell_group, ctx);
+    simulation exact_sim(rec, ctx, one_cell_group);
     for (unsigned i = 0; i<n_cell; ++i) {
         exact_sim.add_sampler(one_probe({i, 0}), sample_sched, make_simple_sampler(exact_traces.at(i)), sampling_policy::exact);
     }
@@ -1366,7 +1366,7 @@ TEST(probe, get_probe_metadata) {
     partition_hint_map phints = {
        {cell_kind::cable, {partition_hint::max_size, partition_hint::max_size, true}}
     };
-    simulation sim(rec, partition_load_balance(rec, ctx, phints), ctx);
+    simulation sim(rec, ctx, partition_load_balance(rec, ctx, phints));
 
     std::vector<probe_metadata> mm = sim.get_probe_metadata({0, 0});
     ASSERT_EQ(3u, mm.size());
diff --git a/test/unit/test_recipe.cpp b/test/unit/test_recipe.cpp
index f630f851..f1ffe120 100644
--- a/test/unit/test_recipe.cpp
+++ b/test/unit/test_recipe.cpp
@@ -112,7 +112,7 @@ TEST(recipe, gap_junctions)
         auto recipe_0 = custom_recipe({cell_0, cell_1}, {{}, {}}, {gjs_0, gjs_1}, {{}, {}});
         auto decomp_0 = partition_load_balance(recipe_0, context);
 
-        EXPECT_NO_THROW(simulation(recipe_0, decomp_0, context));
+        EXPECT_NO_THROW(simulation(recipe_0, context, decomp_0));
     }
     {
         std::vector<arb::gap_junction_connection> gjs_0 = {{{1, "gapjunction1", policy::assert_univalent}, {"gapjunction0", policy::assert_univalent}, 0.1},
@@ -126,7 +126,7 @@ TEST(recipe, gap_junctions)
         auto recipe_1 = custom_recipe({cell_0, cell_1}, {{}, {}}, {gjs_0, gjs_1}, {{}, {}});
         auto decomp_1 = partition_load_balance(recipe_1, context);
 
-        EXPECT_THROW(simulation(recipe_1, decomp_1, context), arb::bad_connection_label);
+        EXPECT_THROW(simulation(recipe_1, context, decomp_1), arb::bad_connection_label);
 
     }
 }
@@ -150,7 +150,7 @@ TEST(recipe, connections)
         auto recipe_0 = custom_recipe({cell_0, cell_1}, {conns_0, conns_1}, {{}, {}},  {{}, {}});
         auto decomp_0 = partition_load_balance(recipe_0, context);
 
-        EXPECT_NO_THROW(simulation(recipe_0, decomp_0, context));
+        EXPECT_NO_THROW(simulation(recipe_0, context, decomp_0));
     }
     {
         conns_0 = {{{1, "detector0"}, {"synapse0"}, 0.1, 0.1},
@@ -164,7 +164,7 @@ TEST(recipe, connections)
         auto recipe_1 = custom_recipe({cell_0, cell_1}, {conns_0, conns_1}, {{}, {}},  {{}, {}});
         auto decomp_1 = partition_load_balance(recipe_1, context);
 
-        EXPECT_THROW(simulation(recipe_1, decomp_1, context), arb::bad_connection_source_gid);
+        EXPECT_THROW(simulation(recipe_1, context, decomp_1), arb::bad_connection_source_gid);
     }
     {
         conns_0 = {{{1, "detector0"}, {"synapse0"}, 0.1, 0.1},
@@ -178,7 +178,7 @@ TEST(recipe, connections)
         auto recipe_2 = custom_recipe({cell_0, cell_1}, {conns_0, conns_1}, {{}, {}},  {{}, {}});
         auto decomp_2 = partition_load_balance(recipe_2, context);
 
-        EXPECT_THROW(simulation(recipe_2, decomp_2, context), arb::bad_connection_label);
+        EXPECT_THROW(simulation(recipe_2, context, decomp_2), arb::bad_connection_label);
     }
     {
         conns_0 = {{{1, "detector0"}, {"synapse0"}, 0.1, 0.1},
@@ -192,7 +192,7 @@ TEST(recipe, connections)
         auto recipe_4 = custom_recipe({cell_0, cell_1}, {conns_0, conns_1}, {{}, {}},  {{}, {}});
         auto decomp_4 = partition_load_balance(recipe_4, context);
 
-        EXPECT_THROW(simulation(recipe_4, decomp_4, context), arb::bad_connection_label);
+        EXPECT_THROW(simulation(recipe_4, context, decomp_4), arb::bad_connection_label);
     }
 }
 
@@ -210,7 +210,7 @@ TEST(recipe, event_generators) {
         auto recipe_0 = custom_recipe({cell_0, cell_1}, {{}, {}}, {{}, {}},  {gens_0, gens_1});
         auto decomp_0 = partition_load_balance(recipe_0, context);
 
-        EXPECT_NO_THROW(simulation(recipe_0, decomp_0, context));
+        EXPECT_NO_THROW(simulation(recipe_0, context, decomp_0));
     }
     {
         gens_0 = {arb::explicit_generator({{{"synapse0"}, 1.0, 0.1}, {{"synapse3"}, 2.0, 0.1}})};
@@ -219,6 +219,6 @@ TEST(recipe, event_generators) {
         auto recipe_0 = custom_recipe({cell_0, cell_1}, {{}, {}}, {{}, {}},  {gens_0, gens_1});
         auto decomp_0 = partition_load_balance(recipe_0, context);
 
-        EXPECT_THROW(simulation(recipe_0, decomp_0, context), arb::bad_connection_label);
+        EXPECT_THROW(simulation(recipe_0, context, decomp_0), arb::bad_connection_label);
     }
 }
diff --git a/test/unit/test_simulation.cpp b/test/unit/test_simulation.cpp
index 3110d438..b79ac19a 100644
--- a/test/unit/test_simulation.cpp
+++ b/test/unit/test_simulation.cpp
@@ -51,7 +51,7 @@ TEST(simulation, null) {
     auto r = null_recipe{};
     auto c = arb::make_context();
     auto d = arb::partition_load_balance(r, c);
-    auto s = arb::simulation(r, d, c);
+    auto s = arb::simulation(r, c, d);
     s.run(0.05, 0.01);
 }
 
@@ -74,7 +74,7 @@ TEST(simulation, spike_global_callback) {
     play_spikes rec(spike_times);
     auto ctx = n_thread_context(4);
     auto decomp = partition_load_balance(rec, ctx);
-    simulation sim(rec, decomp, ctx);
+    simulation sim(rec, ctx, decomp);
 
     std::vector<spike> collected;
     sim.set_global_spike_callback([&](const std::vector<spike>& spikes) {
@@ -155,7 +155,7 @@ TEST(simulation, restart) {
 
     auto ctx = n_thread_context(4);
     auto decomp = partition_load_balance(rec, ctx);
-    simulation sim(rec, decomp, ctx);
+    simulation sim(rec, ctx, decomp);
 
     std::vector<spike> collected;
     sim.set_global_spike_callback([&](const std::vector<spike>& spikes) {
diff --git a/test/unit/test_spikes.cpp b/test/unit/test_spikes.cpp
index 12a53234..d3279169 100644
--- a/test/unit/test_spikes.cpp
+++ b/test/unit/test_spikes.cpp
@@ -234,7 +234,7 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher_interpolation) {
         cable1d_recipe rec({cell});
 
         auto decomp = arb::partition_load_balance(rec, context);
-        arb::simulation sim(rec, decomp, context);
+        arb::simulation sim(rec, context, decomp);
 
         sim.set_global_spike_callback(
                 [&spikes](const std::vector<arb::spike>& recorded_spikes) {
-- 
GitLab