diff --git a/arbor/include/arbor/load_balance.hpp b/arbor/include/arbor/load_balance.hpp
index b9e8e71b25a9bd75e6aa7487e897c5cc099f3191..1c07398095294f2ff0924ff2a5eed203d4a8283f 100644
--- a/arbor/include/arbor/load_balance.hpp
+++ b/arbor/include/arbor/load_balance.hpp
@@ -23,5 +23,5 @@ using partition_hint_map = std::unordered_map<cell_kind, partition_hint>;
 ARB_ARBOR_API domain_decomposition partition_load_balance(
     const recipe& rec,
     const context& ctx,
-    partition_hint_map hint_map = {});
+    const partition_hint_map& hint_map = {});
 } // namespace arb
diff --git a/arbor/include/arbor/simulation.hpp b/arbor/include/arbor/simulation.hpp
index 7e11671935f2957603452f12f828b66215a22d38..c2dd5311ec2ef6d71a0cb0bbb0bd1ed8232217bd 100644
--- a/arbor/include/arbor/simulation.hpp
+++ b/arbor/include/arbor/simulation.hpp
@@ -4,11 +4,13 @@
 #include <memory>
 #include <unordered_map>
 #include <vector>
+#include <functional>
 
 #include <arbor/export.hpp>
 #include <arbor/common_types.hpp>
 #include <arbor/context.hpp>
 #include <arbor/domain_decomposition.hpp>
+#include <arbor/load_balance.hpp>
 #include <arbor/recipe.hpp>
 #include <arbor/sampling.hpp>
 #include <arbor/schedule.hpp>
@@ -25,7 +27,12 @@ class simulation_state;
 
 class ARB_ARBOR_API simulation {
 public:
-    simulation(const recipe& rec, const domain_decomposition& decomp, const context& ctx);
+    // Construct a simulation of rec on ctx using the caller-supplied domain decomposition decomp.
+    simulation(const recipe& rec, const context& ctx, const domain_decomposition& decomp);
+    // Convenience overload: ctx defaults to make_context(); the decomposition is obtained by calling balancer(rec, ctx), which by default forwards to partition_load_balance(rec, ctx).
+    simulation(const recipe& rec,
+               const context& ctx=make_context(),
+               std::function<domain_decomposition(const recipe&, const context&)> balancer=[](auto& r, auto& c) { return partition_load_balance(r, c); }): simulation(rec, ctx, balancer(rec, ctx)) {}
 
     void reset();
 
diff --git a/arbor/partition_load_balance.cpp b/arbor/partition_load_balance.cpp
index 19f3ac09931380660f81f7ee24d83ecd6d9aeaba..71467aabff1fc7e7cbb7a471a522f2b7046a1b74 100644
--- a/arbor/partition_load_balance.cpp
+++ b/arbor/partition_load_balance.cpp
@@ -23,7 +23,7 @@ namespace arb {
 ARB_ARBOR_API domain_decomposition partition_load_balance(
     const recipe& rec,
     const context& ctx,
-    partition_hint_map hint_map)
+    const partition_hint_map& hint_map)
 {
     using util::make_span;
 
diff --git a/arbor/simulation.cpp b/arbor/simulation.cpp
index 21a07d3c983f854fdcc9704e91a8777c7d38f3c2..f604e19bc560e03f9e078e330f54bc29c8fd5577 100644
--- a/arbor/simulation.cpp
+++ b/arbor/simulation.cpp
@@ -510,8 +510,8 @@ void simulation_state::inject_events(const cse_vector& events) {
 
 simulation::simulation(
     const recipe& rec,
-    const domain_decomposition& decomp,
-    const context& ctx)
+    const context& ctx,
+    const domain_decomposition& decomp)
 {
     impl_.reset(new simulation_state(rec, decomp, *ctx));
 }
diff --git a/doc/concepts/simulation.rst b/doc/concepts/simulation.rst
index 2a63cc69c536bf70ea67093c78007ad9b8d3b1a0..6dc7c0551ca2c8d99e68754cd2f1a73488fc11fa 100644
--- a/doc/concepts/simulation.rst
+++ b/doc/concepts/simulation.rst
@@ -9,12 +9,16 @@ group are scheduled.
 From recipe to simulation
 -------------------------
 
-To build a simulation the following are needed:
+To build a simulation the following are needed:
 
 * A :ref:`recipe <modelrecipe>` that describes the cells and connections in the model.
-* A :ref:`domain decomposition <modeldomdec>` that describes the distribution of the
-  model over the local and distributed :ref:`hardware resources <modelhardware>`.
-* An :ref:`execution context <modelcontext>` used to execute the simulation.
+* A :ref:`domain decomposition <modeldomdec>` that describes the distribution of
+  the model over the local and distributed :ref:`hardware resources
+  <modelhardware>`. If not given, a default algorithm will be used which assigns
+  cells to groups one to one; each group is assigned to a thread from the context.
+* An :ref:`execution context <modelcontext>` used to execute the simulation. If
+  not given, the default context will be used, which allocates one thread, one
+  process (MPI task), and no GPU.
 
 Simulation execution and interaction
 ------------------------------------
diff --git a/example/bench/bench.cpp b/example/bench/bench.cpp
index 0d6fb97679be57675881846ba1db1f6d143ba621..1057ff74f90053e22901c9e50044a74b014db927 100644
--- a/example/bench/bench.cpp
+++ b/example/bench/bench.cpp
@@ -165,7 +165,7 @@ int main(int argc, char** argv) {
         meters.checkpoint("domain-decomp", context);
 
         // Construct the model.
-        arb::simulation sim(recipe, decomp, context);
+        arb::simulation sim(recipe, context, decomp);
         meters.checkpoint("model-build", context);
 
         // Run the simulation for 100 ms, with time steps of 0.01 ms.
diff --git a/example/brunel/brunel.cpp b/example/brunel/brunel.cpp
index 4343e5933e2f72352eacd1bf06cae4afbd1e06a0..351e656fe03f4979ea3bc1403fb29d4c6fd0bcff 100644
--- a/example/brunel/brunel.cpp
+++ b/example/brunel/brunel.cpp
@@ -246,9 +246,10 @@ int main(int argc, char** argv) {
 
         partition_hint_map hints;
         hints[cell_kind::lif].cpu_group_size = group_size;
-        auto decomp = partition_load_balance(recipe, context, hints);
 
-        simulation sim(recipe, decomp, context);
+        simulation sim(recipe,
+                       context,
+                       [&hints](auto& r, auto& c) { return partition_load_balance(r, c, hints); });
 
         // Set up spike recording.
         std::vector<arb::spike> recorded_spikes;
diff --git a/example/drybench/drybench.cpp b/example/drybench/drybench.cpp
index ac891cdd551de85540c202b0f41d2d0fdf248c2a..c5300d1aaab5131f4b4aab3d668c9ca0e6369e73 100644
--- a/example/drybench/drybench.cpp
+++ b/example/drybench/drybench.cpp
@@ -156,10 +156,8 @@ int main(int argc, char** argv) {
         auto tile = std::make_unique<tile_desc>(params);
         arb::symmetric_recipe recipe(std::move(tile));
 
-        auto decomp = arb::partition_load_balance(recipe, ctx);
-
         // Construct the model.
-        arb::simulation sim(recipe, decomp, ctx);
+        arb::simulation sim(recipe, ctx);
 
         meters.checkpoint("model-init", ctx);
 
diff --git a/example/dryrun/dryrun.cpp b/example/dryrun/dryrun.cpp
index 8a057d0972f9a00644212593f9d3fca69c85c74d..6f38d968238449effb77280ab0c5b10297727918 100644
--- a/example/dryrun/dryrun.cpp
+++ b/example/dryrun/dryrun.cpp
@@ -174,10 +174,8 @@ int main(int argc, char** argv) {
                 params.num_ranks, params.cell, params.min_delay);
         arb::symmetric_recipe recipe(std::move(tile));
 
-        auto decomp = arb::partition_load_balance(recipe, ctx);
-
         // Construct the model.
-        arb::simulation sim(recipe, decomp, ctx);
+        arb::simulation sim(recipe, ctx);
 
         // The id of the only probe on the cell: the cell_member type points to (cell 0, probe 0)
         auto probe_id = cell_member_type{0, 0};
diff --git a/example/gap_junctions/gap_junctions.cpp b/example/gap_junctions/gap_junctions.cpp
index f61b835ea1d15337d0851a2e55da6e56a1f5572e..10b124d1e37251552dce0b4b51a128216fad46bb 100644
--- a/example/gap_junctions/gap_junctions.cpp
+++ b/example/gap_junctions/gap_junctions.cpp
@@ -172,7 +172,7 @@ int main(int argc, char** argv) {
         auto decomp = arb::partition_load_balance(recipe, context);
 
         // Construct the model.
-        arb::simulation sim(recipe, decomp, context);
+        arb::simulation sim(recipe, context, decomp);
 
         // Set up the probe that will measure voltage in the cell.
 
diff --git a/example/generators/generators.cpp b/example/generators/generators.cpp
index 94bc75be1fc18444fedf1ec2339278b6d4a3e17a..9e72aff20af3da1ce9b854458f59d25efd57adc4 100644
--- a/example/generators/generators.cpp
+++ b/example/generators/generators.cpp
@@ -132,11 +132,8 @@ int main() {
     // Create an instance of our recipe.
     generator_recipe recipe;
 
-    // Make the domain decomposition for the model
-    auto decomp = arb::partition_load_balance(recipe, context);
-
     // Construct the model.
-    arb::simulation sim(recipe, decomp, context);
+    arb::simulation sim(recipe, context);
 
     // Set up the probe that will measure voltage in the cell.
 
diff --git a/example/lfp/lfp.cpp b/example/lfp/lfp.cpp
index 9669c26ee37d74f8fbfda1c37fb7e196f406f759..cb4f7ec26e223466b2f7267d00843276e5278714 100644
--- a/example/lfp/lfp.cpp
+++ b/example/lfp/lfp.cpp
@@ -3,7 +3,6 @@
 #include <vector>
 #include <iostream>
 
-#include <arbor/load_balance.hpp>
 #include <arbor/cable_cell.hpp>
 #include <arbor/morph/morphology.hpp>
 #include <arbor/morph/place_pwlin.hpp>
@@ -79,8 +78,7 @@ private:
         // Apical dendrite, length 490 μm, radius 1 μm, with SWC tag 4.
         tree.append(soma_apex, {0, 0, 10, 1},  {0, 0, 500, 1}, 4);
 
-        decor dec;
-
+        auto dec = arb::decor();
         // Use NEURON defaults for reversal potentials, ion concentrations etc., but override ra, cm.
         dec.set_default(axial_resistivity{100});     // [Ω·cm]
         dec.set_default(membrane_capacitance{0.01}); // [F/m²]
@@ -183,24 +181,22 @@ struct {
 // Run simulation.
 
 int main(int argc, char** argv) {
-    auto context = arb::make_context();
-
-    // Weight 0.005 μS, onset at t = 0 ms, mean frequency 0.1 kHz.
-    auto events = arb::poisson_generator({"syn"}, .005, 0., 0.1, std::minstd_rand{});
-    lfp_demo_recipe R(events);
-
+    // Configuration
     const double t_stop = 100;    // [ms]
     const double sample_dt = 0.1; // [ms]
     const double dt = 0.1;        // [ms]
 
-    arb::simulation sim(R, arb::partition_load_balance(R, context), context);
+    // Weight 0.005 μS, onset at t = 0 ms, mean frequency 0.1 kHz.
+    auto events = arb::poisson_generator({"syn"}, .005, 0., 0.1, std::minstd_rand{});
+    lfp_demo_recipe recipe(events);
+    arb::simulation sim(recipe);
 
     std::vector<position> electrodes = {
         {30, 0, 0},
         {30, 0, 100}
     };
 
-    arb::morphology cell_morphology = any_cast<arb::cable_cell>(R.get_cell_description(0)).morphology();
+    arb::morphology cell_morphology = any_cast<arb::cable_cell>(recipe.get_cell_description(0)).morphology();
     arb::place_pwlin placed_cell(cell_morphology);
 
     auto probe0_metadata = sim.get_probe_metadata(cell_member_type{0, 0});
diff --git a/example/probe-demo/probe-demo.cpp b/example/probe-demo/probe-demo.cpp
index 9e9f785430e0da57daa17b0f2801c835d652456f..0d27c111bd8b3a8b209717dc3dcbeff8904c42dc 100644
--- a/example/probe-demo/probe-demo.cpp
+++ b/example/probe-demo/probe-demo.cpp
@@ -145,8 +145,7 @@ int main(int argc, char** argv) {
 
         cable_recipe R(opt.probe_addr, opt.n_cv);
 
-        auto context = arb::make_context();
-        arb::simulation sim(R, arb::partition_load_balance(R, context), context);
+        arb::simulation sim(R);
 
         sim.add_sampler(arb::all_probes,
                 arb::regular_schedule(opt.sample_dt),
diff --git a/example/ring/ring.cpp b/example/ring/ring.cpp
index 45a3b55be009b6e2dc7e5a66b24ab23451ac916e..fa47d3f2159a818a69c1b0380792eac8f26dc439 100644
--- a/example/ring/ring.cpp
+++ b/example/ring/ring.cpp
@@ -13,6 +13,7 @@
 
 #include <arborio/label_parse.hpp>
 
+#include <arbor/load_balance.hpp>
 #include <arbor/assert_macro.hpp>
 #include <arbor/common_types.hpp>
 #include <arbor/cable_cell.hpp>
@@ -160,10 +161,9 @@ int main(int argc, char** argv) {
         // Create an instance of our recipe.
         ring_recipe recipe(params.num_cells, params.cell, params.min_delay);
 
-        auto decomp = arb::partition_load_balance(recipe, context);
-
         // Construct the model.
-        arb::simulation sim(recipe, decomp, context);
+        auto decomposition = arb::partition_load_balance(recipe, context);
+        arb::simulation sim(recipe, context, decomposition);
 
         // Set up the probe that will measure voltage in the cell.
 
diff --git a/example/single/single.cpp b/example/single/single.cpp
index 2721d2fd4d0256036d27e7d67d670be6699cf58d..4befabcdc35a3ead4698cf3e3d1a71df164841ce 100644
--- a/example/single/single.cpp
+++ b/example/single/single.cpp
@@ -89,8 +89,7 @@ int main(int argc, char** argv) {
         options opt = parse_options(argc, argv);
         single_recipe R(opt.swc_file.empty()? default_morphology(): read_swc(opt.swc_file), opt.policy);
 
-        auto context = arb::make_context();
-        arb::simulation sim(R, arb::partition_load_balance(R, context), context);
+        arb::simulation sim(R);
 
         // Attach a sampler to the probe described in the recipe, sampling every 0.1 ms.
 
diff --git a/python/context.cpp b/python/context.cpp
index 80fcbc61c93fe01e90b92a8c052054a43708e8f5..1da55c86f0b6546d8883070761b2afd33751d24b 100644
--- a/python/context.cpp
+++ b/python/context.cpp
@@ -123,7 +123,7 @@ void register_contexts(pybind11::module& m) {
         .def("__repr__", util::to_string<proc_allocation_shim>);
 
     // context
-    pybind11::class_<context_shim> context(m, "context", "An opaque handle for the hardware resources used in a simulation.");
+    pybind11::class_<context_shim, std::shared_ptr<context_shim>> context(m, "context", "An opaque handle for the hardware resources used in a simulation.");
     context
         .def(pybind11::init(
             [](unsigned threads, pybind11::object gpu, pybind11::object mpi){
diff --git a/python/example/brunel.py b/python/example/brunel.py
index 6a01f4116dffc1cee023431363dfecd0ca6cb3ca..5455ca69b25d5879487ff25384a451674de1f36d 100755
--- a/python/example/brunel.py
+++ b/python/example/brunel.py
@@ -257,7 +257,7 @@ if __name__ == "__main__":
 
     meters.checkpoint("load-balance", context)
 
-    sim = arbor.simulation(recipe, decomp, context)
+    sim = arbor.simulation(recipe, context, decomp)
     sim.record(arbor.spike_recording.all)
 
     meters.checkpoint("simulation-init", context)
diff --git a/python/example/dynamic-catalogue.py b/python/example/dynamic-catalogue.py
index 6fc5e8fcd965b69aa85d5275244fa6fa6ca1ed23..67fc7725edcf8cfb342f6a8d4c979810b75666a4 100644
--- a/python/example/dynamic-catalogue.py
+++ b/python/example/dynamic-catalogue.py
@@ -42,7 +42,5 @@ where <arbor> is the location of the arbor source tree."""
     exit(1)
 
 rcp = recipe()
-ctx = arb.context()
-dom = arb.partition_load_balance(rcp, ctx)
-sim = arb.simulation(rcp, dom, ctx)
+sim = arb.simulation(rcp)
 sim.run(tfinal=30)
diff --git a/python/example/gap_junctions.py b/python/example/gap_junctions.py
index 2b6b3221fe596c25c661872b37a2f09706881b22..8d5ce0841de596619ac027bf51368490cb9f1ec7 100644
--- a/python/example/gap_junctions.py
+++ b/python/example/gap_junctions.py
@@ -133,10 +133,8 @@ ncells = nchains * ncells_per_chain
 # Instantiate recipe
 recipe = chain_recipe(ncells_per_chain, nchains)
 
-# Create a default execution context, domain decomposition and simulation
-context = arbor.context()
-decomp = arbor.partition_load_balance(recipe, context)
-sim = arbor.simulation(recipe, decomp, context)
+# Create a default simulation
+sim = arbor.simulation(recipe)
 
 # Set spike generators to record
 sim.record(arbor.spike_recording.all)
diff --git a/python/example/network_ring.py b/python/example/network_ring.py
index 9eec3cbdc04caf55d9ad7a762505ed709f4f0395..9cb97a3a72385886b2c0839eabef7b58cd390bc7 100755
--- a/python/example/network_ring.py
+++ b/python/example/network_ring.py
@@ -123,9 +123,8 @@ ncells = 4
 recipe = ring_recipe(ncells)
 
 # (12) Create an execution context using all locally available threads, domain decomposition and simulation
-context = arbor.context("avail_threads")
-decomp = arbor.partition_load_balance(recipe, context)
-sim = arbor.simulation(recipe, decomp, context)
+ctx = arbor.context("avail_threads")
+sim = arbor.simulation(recipe, context=ctx)
 
 # (13) Set spike generators to record
 sim.record(arbor.spike_recording.all)
diff --git a/python/example/network_ring_mpi.py b/python/example/network_ring_mpi.py
index 84aa7daa4809d711ba5f2af0ec07fba06964a3de..7ccd2d53bd2dbbb7b4c25f4c85340ac643ebea3d 100644
--- a/python/example/network_ring_mpi.py
+++ b/python/example/network_ring_mpi.py
@@ -132,8 +132,7 @@ context = arbor.context(mpi=comm)
 print(context)
 
 # (13) Create a default domain decomposition and simulation
-decomp = arbor.partition_load_balance(recipe, context)
-sim = arbor.simulation(recipe, decomp, context)
+sim = arbor.simulation(recipe, context)
 
 # (14) Set spike generators to record
 sim.record(arbor.spike_recording.all)
diff --git a/python/example/single_cell_cable.py b/python/example/single_cell_cable.py
index c7116fc945851a95b48950946e70c9a993e41d42..b375ec377bb9f176abf46c2c3fb0d7f955151fa4 100755
--- a/python/example/single_cell_cable.py
+++ b/python/example/single_cell_cable.py
@@ -208,12 +208,8 @@ if __name__ == "__main__":
     ]
     recipe = Cable(probes, **vars(args))
 
-    # create a default execution context and a default domain decomposition
-    context = arbor.context()
-    domains = arbor.partition_load_balance(recipe, context)
-
     # configure the simulation and handles for the probes
-    sim = arbor.simulation(recipe, domains, context)
+    sim = arbor.simulation(recipe)
     dt = 0.001
     handles = [
         sim.sample((0, i), arbor.regular_schedule(dt)) for i in range(len(probes))
diff --git a/python/example/single_cell_detailed_recipe.py b/python/example/single_cell_detailed_recipe.py
index 25da458083ef394182bd604003c014fcfd68b3ff..dd658f66be94fa062b77f70ceacef5013faf98cc 100644
--- a/python/example/single_cell_detailed_recipe.py
+++ b/python/example/single_cell_detailed_recipe.py
@@ -145,14 +145,8 @@ class single_recipe(arbor.recipe):
 # Pass the probe in a list because that it what single_recipe expects.
 recipe = single_recipe(cell, [probe])
 
-# (4) Create an execution context
-context = arbor.context()
-
-# (5) Create a domain decomposition
-domains = arbor.partition_load_balance(recipe, context)
-
-# (6) Create a simulation
-sim = arbor.simulation(recipe, domains, context)
+# (7) Create a simulation
+sim = arbor.simulation(recipe)
 
 # Instruct the simulation to record the spikes and sample the probe
 sim.record(arbor.spike_recording.all)
@@ -160,10 +154,10 @@ sim.record(arbor.spike_recording.all)
 probe_id = arbor.cell_member(0, 0)
 handle = sim.sample(probe_id, arbor.regular_schedule(0.02))
 
-# (7) Run the simulation
+# (8) Run the simulation
 sim.run(tfinal=100, dt=0.025)
 
-# (8) Print or display the results
+# (9) Print or display the results
 spikes = sim.spikes()
 print(len(spikes), "spikes recorded:")
 for s in spikes:
diff --git a/python/example/single_cell_extracellular_potentials.py b/python/example/single_cell_extracellular_potentials.py
index a5c1333482ccb5bd5342d2f52f3c133ecf2fe766..413ebb90195caf6a20a2895c8203e1573e6d521c 100644
--- a/python/example/single_cell_extracellular_potentials.py
+++ b/python/example/single_cell_extracellular_potentials.py
@@ -124,9 +124,7 @@ p, cell = make_cable_cell(morphology, clamp_location)
 recipe = Recipe(cell)
 
 # instantiate simulation
-context = arbor.context()
-domains = arbor.partition_load_balance(recipe, context)
-sim = arbor.simulation(recipe, domains, context)
+sim = arbor.simulation(recipe)
 
 # set up sampling on probes with sampling every 1 ms
 schedule = arbor.regular_schedule(1.0)
diff --git a/python/example/single_cell_recipe.py b/python/example/single_cell_recipe.py
index eae1e436ee48870f2bec25f5fba44571f11fc077..facd5c8d171797bfe681c04ae13498d106eb0174 100644
--- a/python/example/single_cell_recipe.py
+++ b/python/example/single_cell_recipe.py
@@ -60,15 +60,10 @@ class single_recipe(arbor.recipe):
 
 recipe = single_recipe(cell, [arbor.cable_probe_membrane_voltage('"midpoint"')])
 
-# (6) Create a default execution context and a default domain decomposition.
-
-context = arbor.context()
-domains = arbor.partition_load_balance(recipe, context)
-
 # (7) Create and run simulation and set up 10 kHz (every 0.1 ms) sampling on the probe.
 # The probe is located on cell 0, and is the 0th probe on that cell, thus has probe_id (0, 0).
 
-sim = arbor.simulation(recipe, domains, context)
+sim = arbor.simulation(recipe)
 sim.record(arbor.spike_recording.all)
 handle = sim.sample((0, 0), arbor.regular_schedule(0.1))
 sim.run(tfinal=30)
diff --git a/python/example/single_cell_stdp.py b/python/example/single_cell_stdp.py
index 2f0d75bca3893633b89c8f3d935482d6d1f44d62..8c2d443c86730def84b92c52473ec1f693cb7aa4 100755
--- a/python/example/single_cell_stdp.py
+++ b/python/example/single_cell_stdp.py
@@ -78,9 +78,7 @@ class single_recipe(arbor.recipe):
 def run(dT, n_pairs=1, do_plots=False):
     recipe = single_recipe(dT, n_pairs)
 
-    context = arbor.context()
-    domains = arbor.partition_load_balance(recipe, context)
-    sim = arbor.simulation(recipe, domains, context)
+    sim = arbor.simulation(recipe)
 
     sim.record(arbor.spike_recording.all)
 
diff --git a/python/example/two_cell_gap_junctions.py b/python/example/two_cell_gap_junctions.py
index add860b967139ea6d94f46162dd10d09664f5a63..c303802cb5a3bee410341f96d102f191f2618f6d 100755
--- a/python/example/two_cell_gap_junctions.py
+++ b/python/example/two_cell_gap_junctions.py
@@ -153,12 +153,8 @@ if __name__ == "__main__":
     probes = [arbor.cable_probe_membrane_voltage('"gj_site"')]
     recipe = TwoCellsWithGapJunction(probes, **vars(args))
 
-    # create a default execution context and a default domain decomposition
-    context = arbor.context()
-    domains = arbor.partition_load_balance(recipe, context)
-
     # configure the simulation and handles for the probes
-    sim = arbor.simulation(recipe, domains, context)
+    sim = arbor.simulation(recipe)
 
     dt = 0.01
     handles = []
diff --git a/python/simulation.cpp b/python/simulation.cpp
index b506ff17b96f511c70d2ca77afdc3f557ef05706..90689f843d4288c879d9e7124470fd0c8196eb3e 100644
--- a/python/simulation.cpp
+++ b/python/simulation.cpp
@@ -55,11 +55,11 @@ class simulation_shim {
     std::unordered_map<arb::sampler_association_handle, sampler_callback> sampler_map_;
 
 public:
-    simulation_shim(std::shared_ptr<py_recipe>& rec, const arb::domain_decomposition& decomp, const context_shim& ctx, pyarb_global_ptr global_ptr):
+    simulation_shim(std::shared_ptr<py_recipe>& rec, const context_shim& ctx, const arb::domain_decomposition& decomp, pyarb_global_ptr global_ptr):
         global_ptr_(global_ptr)
     {
         try {
-            sim_.reset(new arb::simulation(py_recipe_shim(rec), decomp, ctx.context));
+            sim_.reset(new arb::simulation(py_recipe_shim(rec), ctx.context, decomp));
         }
         catch (...) {
             py_reset_and_throw();
@@ -195,9 +195,13 @@ void register_simulation(pybind11::module& m, pyarb_global_ptr global_ptr) {
         // A custom constructor that wraps a python recipe with arb::py_recipe_shim
         // before forwarding it to the arb::recipe constructor.
         .def(pybind11::init(
-            [global_ptr](std::shared_ptr<py_recipe>& rec, const arb::domain_decomposition& decomp, const context_shim& ctx) {
+                 [global_ptr](std::shared_ptr<py_recipe>& rec,
+                              const std::shared_ptr<context_shim>& ctx_,
+                              const std::optional<arb::domain_decomposition>& decomp) {
                 try {
-                    return new simulation_shim(rec, decomp, ctx, global_ptr);
+                    auto ctx = ctx_ ? ctx_ : std::make_shared<context_shim>(arb::make_context()); // None -> default context
+                    // Run the default load balancer only when no decomposition was passed; optional::value_or would evaluate it unconditionally.
+                    return new simulation_shim(rec, *ctx, decomp ? *decomp : arb::partition_load_balance(py_recipe_shim(rec), ctx->context), global_ptr);
                 }
                 catch (...) {
                     py_reset_and_throw();
@@ -208,7 +212,9 @@ void register_simulation(pybind11::module& m, pyarb_global_ptr global_ptr) {
             pybind11::call_guard<pybind11::gil_scoped_release>(),
             "Initialize the model described by a recipe, with cells and network distributed\n"
             "according to the domain decomposition and computational resources described by a context.",
-            "recipe"_a, "domain_decomposition"_a, "context"_a)
+             "recipe"_a,
+             pybind11::arg_v("context", pybind11::none(), "None"),
+             pybind11::arg_v("domains", pybind11::none(), "None"))
         .def("reset", &simulation_shim::reset,
             pybind11::call_guard<pybind11::gil_scoped_release>(),
             "Reset the state of the simulation to its initial state.")
diff --git a/python/single_cell_model.cpp b/python/single_cell_model.cpp
index 5165aa02b1c51a794e570a499f3f36c7a6270c6b..cc8c0cb67d4240bdbee8a0c614efd89bc0c34251 100644
--- a/python/single_cell_model.cpp
+++ b/python/single_cell_model.cpp
@@ -171,7 +171,7 @@ public:
 
         auto domdec = arb::partition_load_balance(rec, ctx_);
 
-        sim_ = std::make_unique<arb::simulation>(rec, domdec, ctx_);
+        sim_ = std::make_unique<arb::simulation>(rec, ctx_, domdec);
 
         // Create one trace for each probe.
         traces_.reserve(probes_.size());
diff --git a/python/test/fixtures.py b/python/test/fixtures.py
index 9f44c516cdb6e579f640d9dc93561c0c86318236..7885b12af3a76260548f849785bdf17a74e2428c 100644
--- a/python/test/fixtures.py
+++ b/python/test/fixtures.py
@@ -246,4 +246,4 @@ def sum_weight_hh_spike_2():
 @art_spiker_recipe
 def art_spiking_sim(context, art_spiker_recipe):
     dd = arbor.partition_load_balance(art_spiker_recipe, context)
-    return arbor.simulation(art_spiker_recipe, dd, context)
+    return arbor.simulation(art_spiker_recipe, context, dd)
diff --git a/python/test/unit/test_cable_probes.py b/python/test/unit/test_cable_probes.py
index 59e8e6109a3dd0e5aae8293b57fe455be73c98a5..d2a66057e16889120fbedf7e1d976a26d7dc4ccd 100644
--- a/python/test/unit/test_cable_probes.py
+++ b/python/test/unit/test_cable_probes.py
@@ -91,7 +91,7 @@ class TestCableProbes(unittest.TestCase):
         recipe = cc_recipe()
         context = A.context()
         dd = A.partition_load_balance(recipe, context)
-        sim = A.simulation(recipe, dd, context)
+        sim = A.simulation(recipe, context, dd)
 
         all_cv_cables = [A.cable(0, 0, 1)]
 
diff --git a/python/test/unit/test_catalogues.py b/python/test/unit/test_catalogues.py
index e4bd1e634b7d1ece3410f25689190199ccafc259..6ef96f91f2022aa2948c521a9313b244d1c64fa4 100644
--- a/python/test/unit/test_catalogues.py
+++ b/python/test/unit/test_catalogues.py
@@ -60,7 +60,7 @@ class TestCatalogues(unittest.TestCase):
         rcp = recipe()
         ctx = arb.context()
         dom = arb.partition_load_balance(rcp, ctx)
-        sim = arb.simulation(rcp, dom, ctx)
+        sim = arb.simulation(rcp, ctx, dom)
         sim.run(tfinal=30)
 
     def test_empty(self):
diff --git a/python/test/unit/test_multiple_connections.py b/python/test/unit/test_multiple_connections.py
index 89cc32d74ef2f992fd80cdf74eb91ba8ab32926a..37b0f16e20790602b0154e82fa7fc446699e752b 100644
--- a/python/test/unit/test_multiple_connections.py
+++ b/python/test/unit/test_multiple_connections.py
@@ -134,8 +134,7 @@ class TestMultipleConnections(unittest.TestCase):
         self.assertAlmostEqual(connections_from_recipe[3].delay, 1.4)
 
         # construct domain_decomposition and simulation object
-        dd = arb.partition_load_balance(art_spiker_recipe, context)
-        sim = arb.simulation(art_spiker_recipe, dd, context)
+        sim = arb.simulation(art_spiker_recipe, context)
         sim.record(arb.spike_recording.all)
 
         # create schedule and handle to record the membrane potential of neuron 3
@@ -365,9 +364,8 @@ class TestMultipleConnections(unittest.TestCase):
         self.assertAlmostEqual(connections_from_recipe[1].weight, weight2)
         self.assertAlmostEqual(connections_from_recipe[1].delay, 1.4)
 
-        # construct domain_decomposition and simulation object
-        dd = arb.partition_load_balance(art_spiker_recipe, context)
-        sim = arb.simulation(art_spiker_recipe, dd, context)
+        # construct simulation object
+        sim = arb.simulation(art_spiker_recipe, context)
         sim.record(arb.spike_recording.all)
 
         # create schedule and handle to record the membrane potential of neuron 3
diff --git a/python/test/unit/test_profiling.py b/python/test/unit/test_profiling.py
index 92a17e237b009249b648d78373e7f488b3c00a0c..0c0d484d597a0d715f40a25779c352b931be3bc5 100644
--- a/python/test/unit/test_profiling.py
+++ b/python/test/unit/test_profiling.py
@@ -93,7 +93,7 @@ class TestProfiling(unittest.TestCase):
         arb.profiler_initialize(context)
         recipe = a_recipe()
         dd = arb.partition_load_balance(recipe, context)
-        arb.simulation(recipe, dd, context).run(1)
+        arb.simulation(recipe, context, dd).run(1)
         summary = arb.profiler_summary()
         self.assertEqual(str, type(summary), "profiler summary must be str")
         self.assertTrue(summary, "empty summary")
diff --git a/python/test/unit_distributed/test_simulator.py b/python/test/unit_distributed/test_simulator.py
index 8ac2a98330c4cc7589cf369abae4d7ae211e18f6..ea657a9526a4360b1680036de2d766b4e48a4659 100644
--- a/python/test/unit_distributed/test_simulator.py
+++ b/python/test/unit_distributed/test_simulator.py
@@ -69,7 +69,7 @@ class TestSimulator(unittest.TestCase):
         self.assertEqual(1, len(local_groups))
         self.assertEqual([self.rank], local_groups[0].gids)
 
-        return A.simulation(recipe, dd, context)
+        return A.simulation(recipe, context, dd)
 
     def test_local_spikes(self):
         sim = self.init_sim()
diff --git a/test/unit/test_event_delivery.cpp b/test/unit/test_event_delivery.cpp
index 1294240ce8f8d285684d90e814ca824115d88b5c..82a029f839b532ee6c880cd413ab7666b1f2501f 100644
--- a/test/unit/test_event_delivery.cpp
+++ b/test/unit/test_event_delivery.cpp
@@ -49,17 +49,18 @@ using gid_vector = std::vector<cell_gid_type>;
 using group_gids_type = std::vector<gid_vector>;
 
 std::vector<cell_gid_type> run_test_sim(const recipe& R, const group_gids_type& group_gids) {
-    arb::context ctx = make_context(proc_allocation{});
-    unsigned n = R.num_cells();
 
+    unsigned n = R.num_cells();
     std::vector<group_description> groups;
     for (const auto& gidvec: group_gids) {
         groups.emplace_back(cell_kind::cable, gidvec, backend_kind::multicore);
     }
-    auto D = domain_decomposition(R, ctx, groups);
-    std::vector<spike> spikes;
 
-    simulation sim(R, D, ctx);
+    auto C = make_context();
+    auto D = domain_decomposition(R, C, groups);
+    simulation sim(R, C, D);
+
+    std::vector<spike> spikes;
     sim.set_global_spike_callback(
             [&spikes](const std::vector<spike>& ss) {
                 spikes.insert(spikes.end(), ss.begin(), ss.end());
diff --git a/test/unit/test_fvm_lowered.cpp b/test/unit/test_fvm_lowered.cpp
index d9d1aa76c7ceebc0eb73f2560b4d51cbda281bae..55f10243cb64ca0f6125eb06f914587921efbd7b 100644
--- a/test/unit/test_fvm_lowered.cpp
+++ b/test/unit/test_fvm_lowered.cpp
@@ -485,7 +485,7 @@ TEST(fvm_lowered, derived_mechs) {
         float times[] = {10.f, 20.f};
 
         auto decomp = partition_load_balance(rec, context);
-        simulation sim(rec, decomp, context);
+        simulation sim(rec, context, decomp);
         sim.add_sampler(all_probes, explicit_schedule(times), sampler);
         sim.run(30.0, 1.f/1024);
 
@@ -516,7 +516,7 @@ TEST(fvm_lowered, null_region) {
     rec.catalogue().derive("custom_kin1", "test_kin1", {{"tau", 20.0}});
 
     auto decomp = partition_load_balance(rec, context);
-    simulation sim(rec, decomp, context);
+    simulation sim(rec, context, decomp);
     EXPECT_NO_THROW(sim.run(30.0, 1.f/1024));
 }
 
diff --git a/test/unit/test_lif_cell_group.cpp b/test/unit/test_lif_cell_group.cpp
index 9ca765c40608093c049b57e86c4cf6609620702f..0ea015a39ddbb154e7d171a8e28fbfccb93bb330 100644
--- a/test/unit/test_lif_cell_group.cpp
+++ b/test/unit/test_lif_cell_group.cpp
@@ -133,7 +133,7 @@ TEST(lif_cell_group, throw) {
     probe_recipe rec;
     auto context = make_context();
     auto decomp = partition_load_balance(rec, context);
-    EXPECT_THROW(simulation(rec, decomp, context), bad_cell_probe);
+    EXPECT_THROW(simulation(rec, context, decomp), bad_cell_probe);
 }
 
 TEST(lif_cell_group, recipe)
@@ -153,7 +153,7 @@ TEST(lif_cell_group, spikes) {
     auto context = make_context();
 
     auto decomp = partition_load_balance(recipe, context);
-    simulation sim(recipe, decomp, context);
+    simulation sim(recipe, context, decomp);
 
     cse_vector events;
 
@@ -193,7 +193,7 @@ TEST(lif_cell_group, ring)
     auto decomp = partition_load_balance(recipe, context);
 
     // Creates a simulation with a ring recipe of lif neurons
-    simulation sim(recipe, decomp, context);
+    simulation sim(recipe, context, decomp);
 
     std::vector<spike> spike_buffer;
 
diff --git a/test/unit/test_probe.cpp b/test/unit/test_probe.cpp
index bc9f41c374678bd4f2251c3072d58402fdfd38ed..ca541823faa3cca46e41489e8a04857c534cefca 100644
--- a/test/unit/test_probe.cpp
+++ b/test/unit/test_probe.cpp
@@ -826,7 +826,7 @@ void run_axial_and_ion_current_sampled_probe_test(const context& ctx) {
     partition_hint_map phints = {
        {cell_kind::cable, {partition_hint::max_size, partition_hint::max_size, true}}
     };
-    simulation sim(rec, partition_load_balance(rec, ctx, phints), ctx);
+    simulation sim(rec, ctx, partition_load_balance(rec, ctx, phints));
 
     // Take a sample at 20 tau, and run sim for just a bit longer.
 
@@ -913,7 +913,7 @@ auto run_simple_samplers(
     partition_hint_map phints = {
        {cell_kind::cable, {partition_hint::max_size, partition_hint::max_size, true}}
     };
-    simulation sim(rec, partition_load_balance(rec, ctx, phints), ctx);
+    simulation sim(rec, ctx, partition_load_balance(rec, ctx, phints));
 
     std::vector<trace_vector<SampleData, SampleMeta>> traces(n_probe);
     for (unsigned i = 0; i<n_probe; ++i) {
@@ -1278,13 +1278,13 @@ void run_exact_sampling_probe_test(const context& ctx) {
     };
     domain_decomposition one_cell_group = partition_load_balance(rec, ctx, phints);
 
-    simulation lax_sim(rec, one_cell_group, ctx);
+    simulation lax_sim(rec, ctx, one_cell_group);
     for (unsigned i = 0; i<n_cell; ++i) {
         lax_sim.add_sampler(one_probe({i, 0}), sample_sched, make_simple_sampler(lax_traces.at(i)), sampling_policy::lax);
     }
     lax_sim.run(t_end, max_dt);
 
-    simulation exact_sim(rec, one_cell_group, ctx);
+    simulation exact_sim(rec, ctx, one_cell_group);
     for (unsigned i = 0; i<n_cell; ++i) {
         exact_sim.add_sampler(one_probe({i, 0}), sample_sched, make_simple_sampler(exact_traces.at(i)), sampling_policy::exact);
     }
@@ -1366,7 +1366,7 @@ TEST(probe, get_probe_metadata) {
     partition_hint_map phints = {
        {cell_kind::cable, {partition_hint::max_size, partition_hint::max_size, true}}
     };
-    simulation sim(rec, partition_load_balance(rec, ctx, phints), ctx);
+    simulation sim(rec, ctx, partition_load_balance(rec, ctx, phints));
 
     std::vector<probe_metadata> mm = sim.get_probe_metadata({0, 0});
     ASSERT_EQ(3u, mm.size());
diff --git a/test/unit/test_recipe.cpp b/test/unit/test_recipe.cpp
index f630f851f097fc16faa12ede8015176f012e1070..f1ffe120bc7f332b2c256402531e83cbe3c8597e 100644
--- a/test/unit/test_recipe.cpp
+++ b/test/unit/test_recipe.cpp
@@ -112,7 +112,7 @@ TEST(recipe, gap_junctions)
         auto recipe_0 = custom_recipe({cell_0, cell_1}, {{}, {}}, {gjs_0, gjs_1}, {{}, {}});
         auto decomp_0 = partition_load_balance(recipe_0, context);
 
-        EXPECT_NO_THROW(simulation(recipe_0, decomp_0, context));
+        EXPECT_NO_THROW(simulation(recipe_0, context, decomp_0));
     }
     {
         std::vector<arb::gap_junction_connection> gjs_0 = {{{1, "gapjunction1", policy::assert_univalent}, {"gapjunction0", policy::assert_univalent}, 0.1},
@@ -126,7 +126,7 @@ TEST(recipe, gap_junctions)
         auto recipe_1 = custom_recipe({cell_0, cell_1}, {{}, {}}, {gjs_0, gjs_1}, {{}, {}});
         auto decomp_1 = partition_load_balance(recipe_1, context);
 
-        EXPECT_THROW(simulation(recipe_1, decomp_1, context), arb::bad_connection_label);
+        EXPECT_THROW(simulation(recipe_1, context, decomp_1), arb::bad_connection_label);
 
     }
 }
@@ -150,7 +150,7 @@ TEST(recipe, connections)
         auto recipe_0 = custom_recipe({cell_0, cell_1}, {conns_0, conns_1}, {{}, {}},  {{}, {}});
         auto decomp_0 = partition_load_balance(recipe_0, context);
 
-        EXPECT_NO_THROW(simulation(recipe_0, decomp_0, context));
+        EXPECT_NO_THROW(simulation(recipe_0, context, decomp_0));
     }
     {
         conns_0 = {{{1, "detector0"}, {"synapse0"}, 0.1, 0.1},
@@ -164,7 +164,7 @@ TEST(recipe, connections)
         auto recipe_1 = custom_recipe({cell_0, cell_1}, {conns_0, conns_1}, {{}, {}},  {{}, {}});
         auto decomp_1 = partition_load_balance(recipe_1, context);
 
-        EXPECT_THROW(simulation(recipe_1, decomp_1, context), arb::bad_connection_source_gid);
+        EXPECT_THROW(simulation(recipe_1, context, decomp_1), arb::bad_connection_source_gid);
     }
     {
         conns_0 = {{{1, "detector0"}, {"synapse0"}, 0.1, 0.1},
@@ -178,7 +178,7 @@ TEST(recipe, connections)
         auto recipe_2 = custom_recipe({cell_0, cell_1}, {conns_0, conns_1}, {{}, {}},  {{}, {}});
         auto decomp_2 = partition_load_balance(recipe_2, context);
 
-        EXPECT_THROW(simulation(recipe_2, decomp_2, context), arb::bad_connection_label);
+        EXPECT_THROW(simulation(recipe_2, context, decomp_2), arb::bad_connection_label);
     }
     {
         conns_0 = {{{1, "detector0"}, {"synapse0"}, 0.1, 0.1},
@@ -192,7 +192,7 @@ TEST(recipe, connections)
         auto recipe_4 = custom_recipe({cell_0, cell_1}, {conns_0, conns_1}, {{}, {}},  {{}, {}});
         auto decomp_4 = partition_load_balance(recipe_4, context);
 
-        EXPECT_THROW(simulation(recipe_4, decomp_4, context), arb::bad_connection_label);
+        EXPECT_THROW(simulation(recipe_4, context, decomp_4), arb::bad_connection_label);
     }
 }
 
@@ -210,7 +210,7 @@ TEST(recipe, event_generators) {
         auto recipe_0 = custom_recipe({cell_0, cell_1}, {{}, {}}, {{}, {}},  {gens_0, gens_1});
         auto decomp_0 = partition_load_balance(recipe_0, context);
 
-        EXPECT_NO_THROW(simulation(recipe_0, decomp_0, context));
+        EXPECT_NO_THROW(simulation(recipe_0, context, decomp_0));
     }
     {
         gens_0 = {arb::explicit_generator({{{"synapse0"}, 1.0, 0.1}, {{"synapse3"}, 2.0, 0.1}})};
@@ -219,6 +219,6 @@ TEST(recipe, event_generators) {
         auto recipe_0 = custom_recipe({cell_0, cell_1}, {{}, {}}, {{}, {}},  {gens_0, gens_1});
         auto decomp_0 = partition_load_balance(recipe_0, context);
 
-        EXPECT_THROW(simulation(recipe_0, decomp_0, context), arb::bad_connection_label);
+        EXPECT_THROW(simulation(recipe_0, context, decomp_0), arb::bad_connection_label);
     }
 }
diff --git a/test/unit/test_simulation.cpp b/test/unit/test_simulation.cpp
index 3110d43822af939061ea6422bf77531a6c377589..b79ac19a6fa21d152ea1933b3a72e417a41a4795 100644
--- a/test/unit/test_simulation.cpp
+++ b/test/unit/test_simulation.cpp
@@ -51,7 +51,7 @@ TEST(simulation, null) {
     auto r = null_recipe{};
     auto c = arb::make_context();
     auto d = arb::partition_load_balance(r, c);
-    auto s = arb::simulation(r, d, c);
+    auto s = arb::simulation(r, c, d);
     s.run(0.05, 0.01);
 }
 
@@ -74,7 +74,7 @@ TEST(simulation, spike_global_callback) {
     play_spikes rec(spike_times);
     auto ctx = n_thread_context(4);
     auto decomp = partition_load_balance(rec, ctx);
-    simulation sim(rec, decomp, ctx);
+    simulation sim(rec, ctx, decomp);
 
     std::vector<spike> collected;
     sim.set_global_spike_callback([&](const std::vector<spike>& spikes) {
@@ -155,7 +155,7 @@ TEST(simulation, restart) {
 
     auto ctx = n_thread_context(4);
     auto decomp = partition_load_balance(rec, ctx);
-    simulation sim(rec, decomp, ctx);
+    simulation sim(rec, ctx, decomp);
 
     std::vector<spike> collected;
     sim.set_global_spike_callback([&](const std::vector<spike>& spikes) {
diff --git a/test/unit/test_spikes.cpp b/test/unit/test_spikes.cpp
index 12a53234f044b766d99142e38f9f321d8405421f..d3279169a2f47d13460736288b03c2b87ddeec2f 100644
--- a/test/unit/test_spikes.cpp
+++ b/test/unit/test_spikes.cpp
@@ -234,7 +234,7 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher_interpolation) {
         cable1d_recipe rec({cell});
 
         auto decomp = arb::partition_load_balance(rec, context);
-        arb::simulation sim(rec, decomp, context);
+        arb::simulation sim(rec, context, decomp);
 
         sim.set_global_spike_callback(
                 [&spikes](const std::vector<arb::spike>& recorded_spikes) {