diff --git a/arbor/include/arbor/load_balance.hpp b/arbor/include/arbor/load_balance.hpp index b9e8e71b25a9bd75e6aa7487e897c5cc099f3191..1c07398095294f2ff0924ff2a5eed203d4a8283f 100644 --- a/arbor/include/arbor/load_balance.hpp +++ b/arbor/include/arbor/load_balance.hpp @@ -23,5 +23,5 @@ using partition_hint_map = std::unordered_map<cell_kind, partition_hint>; ARB_ARBOR_API domain_decomposition partition_load_balance( const recipe& rec, const context& ctx, - partition_hint_map hint_map = {}); + const partition_hint_map& hint_map = {}); } // namespace arb diff --git a/arbor/include/arbor/simulation.hpp b/arbor/include/arbor/simulation.hpp index 7e11671935f2957603452f12f828b66215a22d38..c2dd5311ec2ef6d71a0cb0bbb0bd1ed8232217bd 100644 --- a/arbor/include/arbor/simulation.hpp +++ b/arbor/include/arbor/simulation.hpp @@ -4,11 +4,13 @@ #include <memory> #include <unordered_map> #include <vector> +#include <functional> #include <arbor/export.hpp> #include <arbor/common_types.hpp> #include <arbor/context.hpp> #include <arbor/domain_decomposition.hpp> +#include <arbor/load_balance.hpp> #include <arbor/recipe.hpp> #include <arbor/sampling.hpp> #include <arbor/schedule.hpp> @@ -25,7 +27,12 @@ class simulation_state; class ARB_ARBOR_API simulation { public: - simulation(const recipe& rec, const domain_decomposition& decomp, const context& ctx); + + simulation(const recipe& rec, const context& ctx, const domain_decomposition& decomp); + + simulation(const recipe& rec, + const context& ctx=make_context(), + std::function<domain_decomposition(const recipe&, const context&)> balancer=[](auto& r, auto& c) { return partition_load_balance(r, c); }): simulation(rec, ctx, balancer(rec, ctx)) {} void reset(); diff --git a/arbor/partition_load_balance.cpp b/arbor/partition_load_balance.cpp index 19f3ac09931380660f81f7ee24d83ecd6d9aeaba..71467aabff1fc7e7cbb7a471a522f2b7046a1b74 100644 --- a/arbor/partition_load_balance.cpp +++ b/arbor/partition_load_balance.cpp @@ -23,7 +23,7 @@ namespace arb { ARB_ARBOR_API domain_decomposition partition_load_balance( const recipe& rec, const context& ctx, - partition_hint_map hint_map) + const partition_hint_map& hint_map) { using util::make_span; diff --git a/arbor/simulation.cpp b/arbor/simulation.cpp index 21a07d3c983f854fdcc9704e91a8777c7d38f3c2..f604e19bc560e03f9e078e330f54bc29c8fd5577 100644 --- a/arbor/simulation.cpp +++ b/arbor/simulation.cpp @@ -510,8 +510,8 @@ void simulation_state::inject_events(const cse_vector& events) { simulation::simulation( const recipe& rec, - const domain_decomposition& decomp, - const context& ctx) + const context& ctx, + const domain_decomposition& decomp) { impl_.reset(new simulation_state(rec, decomp, *ctx)); } diff --git a/doc/concepts/simulation.rst b/doc/concepts/simulation.rst index 2a63cc69c536bf70ea67093c78007ad9b8d3b1a0..6dc7c0551ca2c8d99e68754cd2f1a73488fc11fa 100644 --- a/doc/concepts/simulation.rst +++ b/doc/concepts/simulation.rst @@ -9,12 +9,16 @@ group are scheduled. From recipe to simulation ------------------------- -To build a simulation the following are needed: +To build a simulation the following is needed: * A :ref:`recipe <modelrecipe>` that describes the cells and connections in the model. -* A :ref:`domain decomposition <modeldomdec>` that describes the distribution of the - model over the local and distributed :ref:`hardware resources <modelhardware>`. -* An :ref:`execution context <modelcontext>` used to execute the simulation. +* A :ref:`domain decomposition <modeldomdec>` that describes the distribution of + the model over the local and distributed :ref:`hardware resources + <modelhardware>`. If not given, a default algorithm will be used which assigns + cells to groups one to one; each group is assigned to a thread from the context. +* An :ref:`execution context <modelcontext>` used to execute the simulation. If + not given, the default context will be used, which allocates one thread, one + process (MPI task), and no GPU. Simulation execution and interaction ------------------------------------ diff --git a/example/bench/bench.cpp b/example/bench/bench.cpp index 0d6fb97679be57675881846ba1db1f6d143ba621..1057ff74f90053e22901c9e50044a74b014db927 100644 --- a/example/bench/bench.cpp +++ b/example/bench/bench.cpp @@ -165,7 +165,7 @@ int main(int argc, char** argv) { meters.checkpoint("domain-decomp", context); // Construct the model. - arb::simulation sim(recipe, decomp, context); + arb::simulation sim(recipe, context, decomp); meters.checkpoint("model-build", context); // Run the simulation for 100 ms, with time steps of 0.01 ms. diff --git a/example/brunel/brunel.cpp b/example/brunel/brunel.cpp index 4343e5933e2f72352eacd1bf06cae4afbd1e06a0..351e656fe03f4979ea3bc1403fb29d4c6fd0bcff 100644 --- a/example/brunel/brunel.cpp +++ b/example/brunel/brunel.cpp @@ -246,9 +246,10 @@ int main(int argc, char** argv) { partition_hint_map hints; hints[cell_kind::lif].cpu_group_size = group_size; - auto decomp = partition_load_balance(recipe, context, hints); - simulation sim(recipe, decomp, context); + simulation sim(recipe, + context, + [&hints](auto& r, auto& c) { return partition_load_balance(r, c, hints); }); // Set up spike recording. std::vector<arb::spike> recorded_spikes; diff --git a/example/drybench/drybench.cpp b/example/drybench/drybench.cpp index ac891cdd551de85540c202b0f41d2d0fdf248c2a..c5300d1aaab5131f4b4aab3d668c9ca0e6369e73 100644 --- a/example/drybench/drybench.cpp +++ b/example/drybench/drybench.cpp @@ -156,10 +156,8 @@ int main(int argc, char** argv) { auto tile = std::make_unique<tile_desc>(params); arb::symmetric_recipe recipe(std::move(tile)); - auto decomp = arb::partition_load_balance(recipe, ctx); - // Construct the model. - arb::simulation sim(recipe, decomp, ctx); + arb::simulation sim(recipe, ctx); meters.checkpoint("model-init", ctx); diff --git a/example/dryrun/dryrun.cpp b/example/dryrun/dryrun.cpp index 8a057d0972f9a00644212593f9d3fca69c85c74d..6f38d968238449effb77280ab0c5b10297727918 100644 --- a/example/dryrun/dryrun.cpp +++ b/example/dryrun/dryrun.cpp @@ -174,10 +174,8 @@ int main(int argc, char** argv) { params.num_ranks, params.cell, params.min_delay); arb::symmetric_recipe recipe(std::move(tile)); - auto decomp = arb::partition_load_balance(recipe, ctx); - // Construct the model. - arb::simulation sim(recipe, decomp, ctx); + arb::simulation sim(recipe, ctx); // The id of the only probe on the cell: the cell_member type points to (cell 0, probe 0) auto probe_id = cell_member_type{0, 0}; diff --git a/example/gap_junctions/gap_junctions.cpp b/example/gap_junctions/gap_junctions.cpp index f61b835ea1d15337d0851a2e55da6e56a1f5572e..10b124d1e37251552dce0b4b51a128216fad46bb 100644 --- a/example/gap_junctions/gap_junctions.cpp +++ b/example/gap_junctions/gap_junctions.cpp @@ -172,7 +172,7 @@ int main(int argc, char** argv) { auto decomp = arb::partition_load_balance(recipe, context); // Construct the model. - arb::simulation sim(recipe, decomp, context); + arb::simulation sim(recipe, context, decomp); // Set up the probe that will measure voltage in the cell. diff --git a/example/generators/generators.cpp b/example/generators/generators.cpp index 94bc75be1fc18444fedf1ec2339278b6d4a3e17a..9e72aff20af3da1ce9b854458f59d25efd57adc4 100644 --- a/example/generators/generators.cpp +++ b/example/generators/generators.cpp @@ -132,11 +132,8 @@ int main() { // Create an instance of our recipe. generator_recipe recipe; - // Make the domain decomposition for the model - auto decomp = arb::partition_load_balance(recipe, context); - // Construct the model. - arb::simulation sim(recipe, decomp, context); + arb::simulation sim(recipe, context); // Set up the probe that will measure voltage in the cell. diff --git a/example/lfp/lfp.cpp b/example/lfp/lfp.cpp index 9669c26ee37d74f8fbfda1c37fb7e196f406f759..cb4f7ec26e223466b2f7267d00843276e5278714 100644 --- a/example/lfp/lfp.cpp +++ b/example/lfp/lfp.cpp @@ -3,7 +3,6 @@ #include <vector> #include <iostream> -#include <arbor/load_balance.hpp> #include <arbor/cable_cell.hpp> #include <arbor/morph/morphology.hpp> #include <arbor/morph/place_pwlin.hpp> @@ -79,8 +78,7 @@ private: // Apical dendrite, length 490 μm, radius 1 μm, with SWC tag 4. tree.append(soma_apex, {0, 0, 10, 1}, {0, 0, 500, 1}, 4); - decor dec; - + auto dec = arb::decor(); // Use NEURON defaults for reversal potentials, ion concentrations etc., but override ra, cm. dec.set_default(axial_resistivity{100}); // [Ω·cm] dec.set_default(membrane_capacitance{0.01}); // [F/m²] @@ -183,24 +181,22 @@ struct { // Run simulation. int main(int argc, char** argv) { - auto context = arb::make_context(); - - // Weight 0.005 μS, onset at t = 0 ms, mean frequency 0.1 kHz. - auto events = arb::poisson_generator({"syn"}, .005, 0., 0.1, std::minstd_rand{}); - lfp_demo_recipe R(events); - + // Configuration const double t_stop = 100; // [ms] const double sample_dt = 0.1; // [ms] const double dt = 0.1; // [ms] - arb::simulation sim(R, arb::partition_load_balance(R, context), context); + // Weight 0.005 μS, onset at t = 0 ms, mean frequency 0.1 kHz. + auto events = arb::poisson_generator({"syn"}, .005, 0., 0.1, std::minstd_rand{}); + lfp_demo_recipe recipe(events); + arb::simulation sim(recipe); std::vector<position> electrodes = { {30, 0, 0}, {30, 0, 100} }; - arb::morphology cell_morphology = any_cast<arb::cable_cell>(R.get_cell_description(0)).morphology(); + arb::morphology cell_morphology = any_cast<arb::cable_cell>(recipe.get_cell_description(0)).morphology(); arb::place_pwlin placed_cell(cell_morphology); auto probe0_metadata = sim.get_probe_metadata(cell_member_type{0, 0}); diff --git a/example/probe-demo/probe-demo.cpp b/example/probe-demo/probe-demo.cpp index 9e9f785430e0da57daa17b0f2801c835d652456f..0d27c111bd8b3a8b209717dc3dcbeff8904c42dc 100644 --- a/example/probe-demo/probe-demo.cpp +++ b/example/probe-demo/probe-demo.cpp @@ -145,8 +145,7 @@ int main(int argc, char** argv) { cable_recipe R(opt.probe_addr, opt.n_cv); - auto context = arb::make_context(); - arb::simulation sim(R, arb::partition_load_balance(R, context), context); + arb::simulation sim(R); sim.add_sampler(arb::all_probes, arb::regular_schedule(opt.sample_dt), diff --git a/example/ring/ring.cpp b/example/ring/ring.cpp index 45a3b55be009b6e2dc7e5a66b24ab23451ac916e..fa47d3f2159a818a69c1b0380792eac8f26dc439 100644 --- a/example/ring/ring.cpp +++ b/example/ring/ring.cpp @@ -13,6 +13,7 @@ #include <arborio/label_parse.hpp> +#include <arbor/load_balance.hpp> #include <arbor/assert_macro.hpp> #include <arbor/common_types.hpp> #include <arbor/cable_cell.hpp> @@ -160,10 +161,9 @@ int main(int argc, char** argv) { // Create an instance of our recipe. ring_recipe recipe(params.num_cells, params.cell, params.min_delay); - auto decomp = arb::partition_load_balance(recipe, context); - // Construct the model. - arb::simulation sim(recipe, decomp, context); + auto decomposition = arb::partition_load_balance(recipe, context); + arb::simulation sim(recipe, context, decomposition); // Set up the probe that will measure voltage in the cell. diff --git a/example/single/single.cpp b/example/single/single.cpp index 2721d2fd4d0256036d27e7d67d670be6699cf58d..4befabcdc35a3ead4698cf3e3d1a71df164841ce 100644 --- a/example/single/single.cpp +++ b/example/single/single.cpp @@ -89,8 +89,7 @@ int main(int argc, char** argv) { options opt = parse_options(argc, argv); single_recipe R(opt.swc_file.empty()? default_morphology(): read_swc(opt.swc_file), opt.policy); - auto context = arb::make_context(); - arb::simulation sim(R, arb::partition_load_balance(R, context), context); + arb::simulation sim(R); // Attach a sampler to the probe described in the recipe, sampling every 0.1 ms. diff --git a/python/context.cpp b/python/context.cpp index 80fcbc61c93fe01e90b92a8c052054a43708e8f5..1da55c86f0b6546d8883070761b2afd33751d24b 100644 --- a/python/context.cpp +++ b/python/context.cpp @@ -123,7 +123,7 @@ void register_contexts(pybind11::module& m) { .def("__repr__", util::to_string<proc_allocation_shim>); // context - pybind11::class_<context_shim> context(m, "context", "An opaque handle for the hardware resources used in a simulation."); + pybind11::class_<context_shim, std::shared_ptr<context_shim>> context(m, "context", "An opaque handle for the hardware resources used in a simulation."); context .def(pybind11::init( [](unsigned threads, pybind11::object gpu, pybind11::object mpi){ diff --git a/python/example/brunel.py b/python/example/brunel.py index 6a01f4116dffc1cee023431363dfecd0ca6cb3ca..5455ca69b25d5879487ff25384a451674de1f36d 100755 --- a/python/example/brunel.py +++ b/python/example/brunel.py @@ -257,7 +257,7 @@ if __name__ == "__main__": meters.checkpoint("load-balance", context) - sim = arbor.simulation(recipe, decomp, context) + sim = arbor.simulation(recipe, context, decomp) sim.record(arbor.spike_recording.all) meters.checkpoint("simulation-init", context) diff --git a/python/example/dynamic-catalogue.py b/python/example/dynamic-catalogue.py index 6fc5e8fcd965b69aa85d5275244fa6fa6ca1ed23..67fc7725edcf8cfb342f6a8d4c979810b75666a4 100644 --- a/python/example/dynamic-catalogue.py +++ b/python/example/dynamic-catalogue.py @@ -42,7 +42,5 @@ where <arbor> is the location of the arbor source tree.""" exit(1) rcp = recipe() -ctx = arb.context() -dom = arb.partition_load_balance(rcp, ctx) -sim = arb.simulation(rcp, dom, ctx) +sim = arb.simulation(rcp) sim.run(tfinal=30) diff --git a/python/example/gap_junctions.py b/python/example/gap_junctions.py index 2b6b3221fe596c25c661872b37a2f09706881b22..8d5ce0841de596619ac027bf51368490cb9f1ec7 100644 --- a/python/example/gap_junctions.py +++ b/python/example/gap_junctions.py @@ -133,10 +133,8 @@ ncells = nchains * ncells_per_chain # Instantiate recipe recipe = chain_recipe(ncells_per_chain, nchains) -# Create a default execution context, domain decomposition and simulation -context = arbor.context() -decomp = arbor.partition_load_balance(recipe, context) -sim = arbor.simulation(recipe, decomp, context) +# Create a default simulation +sim = arbor.simulation(recipe) # Set spike generators to record sim.record(arbor.spike_recording.all) diff --git a/python/example/network_ring.py b/python/example/network_ring.py index 9eec3cbdc04caf55d9ad7a762505ed709f4f0395..9cb97a3a72385886b2c0839eabef7b58cd390bc7 100755 --- a/python/example/network_ring.py +++ b/python/example/network_ring.py @@ -123,9 +123,8 @@ ncells = 4 recipe = ring_recipe(ncells) # (12) Create an execution context using all locally available threads, domain decomposition and simulation -context = arbor.context("avail_threads") -decomp = arbor.partition_load_balance(recipe, context) -sim = arbor.simulation(recipe, decomp, context) +ctx = arbor.context("avail_threads") +sim = arbor.simulation(recipe, context=ctx) # (13) Set spike generators to record sim.record(arbor.spike_recording.all) diff --git a/python/example/network_ring_mpi.py b/python/example/network_ring_mpi.py index 84aa7daa4809d711ba5f2af0ec07fba06964a3de..7ccd2d53bd2dbbb7b4c25f4c85340ac643ebea3d 100644 --- a/python/example/network_ring_mpi.py +++ b/python/example/network_ring_mpi.py @@ -132,8 +132,7 @@ context = arbor.context(mpi=comm) print(context) # (13) Create a default domain decomposition and simulation -decomp = arbor.partition_load_balance(recipe, context) -sim = arbor.simulation(recipe, decomp, context) +sim = arbor.simulation(recipe, context) # (14) Set spike generators to record sim.record(arbor.spike_recording.all) diff --git a/python/example/single_cell_cable.py b/python/example/single_cell_cable.py index c7116fc945851a95b48950946e70c9a993e41d42..b375ec377bb9f176abf46c2c3fb0d7f955151fa4 100755 --- a/python/example/single_cell_cable.py +++ b/python/example/single_cell_cable.py @@ -208,12 +208,8 @@ if __name__ == "__main__": ] recipe = Cable(probes, **vars(args)) - # create a default execution context and a default domain decomposition - context = arbor.context() - domains = arbor.partition_load_balance(recipe, context) - # configure the simulation and handles for the probes - sim = arbor.simulation(recipe, domains, context) + sim = arbor.simulation(recipe) dt = 0.001 handles = [ sim.sample((0, i), arbor.regular_schedule(dt)) for i in range(len(probes)) diff --git a/python/example/single_cell_detailed_recipe.py b/python/example/single_cell_detailed_recipe.py index 25da458083ef394182bd604003c014fcfd68b3ff..dd658f66be94fa062b77f70ceacef5013faf98cc 100644 --- a/python/example/single_cell_detailed_recipe.py +++ b/python/example/single_cell_detailed_recipe.py @@ -145,14 +145,8 @@ class single_recipe(arbor.recipe): # Pass the probe in a list because that it what single_recipe expects. recipe = single_recipe(cell, [probe]) -# (4) Create an execution context -context = arbor.context() - -# (5) Create a domain decomposition -domains = arbor.partition_load_balance(recipe, context) - -# (6) Create a simulation -sim = arbor.simulation(recipe, domains, context) +# (7) Create a simulation +sim = arbor.simulation(recipe) # Instruct the simulation to record the spikes and sample the probe sim.record(arbor.spike_recording.all) @@ -160,10 +154,10 @@ sim.record(arbor.spike_recording.all) probe_id = arbor.cell_member(0, 0) handle = sim.sample(probe_id, arbor.regular_schedule(0.02)) -# (7) Run the simulation +# (8) Run the simulation sim.run(tfinal=100, dt=0.025) -# (8) Print or display the results +# (9) Print or display the results spikes = sim.spikes() print(len(spikes), "spikes recorded:") for s in spikes: diff --git a/python/example/single_cell_extracellular_potentials.py b/python/example/single_cell_extracellular_potentials.py index a5c1333482ccb5bd5342d2f52f3c133ecf2fe766..413ebb90195caf6a20a2895c8203e1573e6d521c 100644 --- a/python/example/single_cell_extracellular_potentials.py +++ b/python/example/single_cell_extracellular_potentials.py @@ -124,9 +124,7 @@ p, cell = make_cable_cell(morphology, clamp_location) recipe = Recipe(cell) # instantiate simulation -context = arbor.context() -domains = arbor.partition_load_balance(recipe, context) -sim = arbor.simulation(recipe, domains, context) +sim = arbor.simulation(recipe) # set up sampling on probes with sampling every 1 ms schedule = arbor.regular_schedule(1.0) diff --git a/python/example/single_cell_recipe.py b/python/example/single_cell_recipe.py index eae1e436ee48870f2bec25f5fba44571f11fc077..facd5c8d171797bfe681c04ae13498d106eb0174 100644 --- a/python/example/single_cell_recipe.py +++ b/python/example/single_cell_recipe.py @@ -60,15 +60,10 @@ class single_recipe(arbor.recipe): recipe = single_recipe(cell, [arbor.cable_probe_membrane_voltage('"midpoint"')]) -# (6) Create a default execution context and a default domain decomposition. - -context = arbor.context() -domains = arbor.partition_load_balance(recipe, context) - # (7) Create and run simulation and set up 10 kHz (every 0.1 ms) sampling on the probe. # The probe is located on cell 0, and is the 0th probe on that cell, thus has probe_id (0, 0). -sim = arbor.simulation(recipe, domains, context) +sim = arbor.simulation(recipe) sim.record(arbor.spike_recording.all) handle = sim.sample((0, 0), arbor.regular_schedule(0.1)) sim.run(tfinal=30) diff --git a/python/example/single_cell_stdp.py b/python/example/single_cell_stdp.py index 2f0d75bca3893633b89c8f3d935482d6d1f44d62..8c2d443c86730def84b92c52473ec1f693cb7aa4 100755 --- a/python/example/single_cell_stdp.py +++ b/python/example/single_cell_stdp.py @@ -78,9 +78,7 @@ class single_recipe(arbor.recipe): def run(dT, n_pairs=1, do_plots=False): recipe = single_recipe(dT, n_pairs) - context = arbor.context() - domains = arbor.partition_load_balance(recipe, context) - sim = arbor.simulation(recipe, domains, context) + sim = arbor.simulation(recipe) sim.record(arbor.spike_recording.all) diff --git a/python/example/two_cell_gap_junctions.py b/python/example/two_cell_gap_junctions.py index add860b967139ea6d94f46162dd10d09664f5a63..c303802cb5a3bee410341f96d102f191f2618f6d 100755 --- a/python/example/two_cell_gap_junctions.py +++ b/python/example/two_cell_gap_junctions.py @@ -153,12 +153,8 @@ if __name__ == "__main__": probes = [arbor.cable_probe_membrane_voltage('"gj_site"')] recipe = TwoCellsWithGapJunction(probes, **vars(args)) - # create a default execution context and a default domain decomposition - context = arbor.context() - domains = arbor.partition_load_balance(recipe, context) - # configure the simulation and handles for the probes - sim = arbor.simulation(recipe, domains, context) + sim = arbor.simulation(recipe) dt = 0.01 handles = [] diff --git a/python/simulation.cpp b/python/simulation.cpp index b506ff17b96f511c70d2ca77afdc3f557ef05706..90689f843d4288c879d9e7124470fd0c8196eb3e 100644 --- a/python/simulation.cpp +++ b/python/simulation.cpp @@ -55,11 +55,11 @@ class simulation_shim { std::unordered_map<arb::sampler_association_handle, sampler_callback> sampler_map_; public: - simulation_shim(std::shared_ptr<py_recipe>& rec, const arb::domain_decomposition& decomp, const context_shim& ctx, pyarb_global_ptr global_ptr): + simulation_shim(std::shared_ptr<py_recipe>& rec, const context_shim& ctx, const arb::domain_decomposition& decomp, pyarb_global_ptr global_ptr): global_ptr_(global_ptr) { try { - sim_.reset(new arb::simulation(py_recipe_shim(rec), decomp, ctx.context)); + sim_.reset(new arb::simulation(py_recipe_shim(rec), ctx.context, decomp)); } catch (...) { py_reset_and_throw(); @@ -195,9 +195,13 @@ void register_simulation(pybind11::module& m, pyarb_global_ptr global_ptr) { // A custom constructor that wraps a python recipe with arb::py_recipe_shim // before forwarding it to the arb::recipe constructor. .def(pybind11::init( - [global_ptr](std::shared_ptr<py_recipe>& rec, const arb::domain_decomposition& decomp, const context_shim& ctx) { + [global_ptr](std::shared_ptr<py_recipe>& rec, + const std::shared_ptr<context_shim>& ctx_, + const std::optional<arb::domain_decomposition>& decomp) { try { - return new simulation_shim(rec, decomp, ctx, global_ptr); + auto ctx = ctx_ ? ctx_ : std::make_shared<context_shim>(arb::make_context()); + auto dec = decomp.value_or(arb::partition_load_balance(py_recipe_shim(rec), ctx->context)); + return new simulation_shim(rec, *ctx, dec, global_ptr); } catch (...) { py_reset_and_throw(); @@ -208,7 +212,9 @@ void register_simulation(pybind11::module& m, pyarb_global_ptr global_ptr) { pybind11::call_guard<pybind11::gil_scoped_release>(), "Initialize the model described by a recipe, with cells and network distributed\n" "according to the domain decomposition and computational resources described by a context.", - "recipe"_a, "domain_decomposition"_a, "context"_a) + "recipe"_a, + pybind11::arg_v("context", pybind11::none(), "Execution context"), + pybind11::arg_v("domains", pybind11::none(), "Domain decomposition")) .def("reset", &simulation_shim::reset, pybind11::call_guard<pybind11::gil_scoped_release>(), "Reset the state of the simulation to its initial state.") diff --git a/python/single_cell_model.cpp b/python/single_cell_model.cpp index 5165aa02b1c51a794e570a499f3f36c7a6270c6b..cc8c0cb67d4240bdbee8a0c614efd89bc0c34251 100644 --- a/python/single_cell_model.cpp +++ b/python/single_cell_model.cpp @@ -171,7 +171,7 @@ public: auto domdec = arb::partition_load_balance(rec, ctx_); - sim_ = std::make_unique<arb::simulation>(rec, domdec, ctx_); + sim_ = std::make_unique<arb::simulation>(rec, ctx_, domdec); // Create one trace for each probe. traces_.reserve(probes_.size()); diff --git a/python/test/fixtures.py b/python/test/fixtures.py index 9f44c516cdb6e579f640d9dc93561c0c86318236..7885b12af3a76260548f849785bdf17a74e2428c 100644 --- a/python/test/fixtures.py +++ b/python/test/fixtures.py @@ -246,4 +246,4 @@ def sum_weight_hh_spike_2(): @art_spiker_recipe def art_spiking_sim(context, art_spiker_recipe): dd = arbor.partition_load_balance(art_spiker_recipe, context) - return arbor.simulation(art_spiker_recipe, dd, context) + return arbor.simulation(art_spiker_recipe, context, dd) diff --git a/python/test/unit/test_cable_probes.py b/python/test/unit/test_cable_probes.py index 59e8e6109a3dd0e5aae8293b57fe455be73c98a5..d2a66057e16889120fbedf7e1d976a26d7dc4ccd 100644 --- a/python/test/unit/test_cable_probes.py +++ b/python/test/unit/test_cable_probes.py @@ -91,7 +91,7 @@ class TestCableProbes(unittest.TestCase): recipe = cc_recipe() context = A.context() dd = A.partition_load_balance(recipe, context) - sim = A.simulation(recipe, dd, context) + sim = A.simulation(recipe, context, dd) all_cv_cables = [A.cable(0, 0, 1)] diff --git a/python/test/unit/test_catalogues.py b/python/test/unit/test_catalogues.py index e4bd1e634b7d1ece3410f25689190199ccafc259..6ef96f91f2022aa2948c521a9313b244d1c64fa4 100644 --- a/python/test/unit/test_catalogues.py +++ b/python/test/unit/test_catalogues.py @@ -60,7 +60,7 @@ class TestCatalogues(unittest.TestCase): rcp = recipe() ctx = arb.context() dom = arb.partition_load_balance(rcp, ctx) - sim = arb.simulation(rcp, dom, ctx) + sim = arb.simulation(rcp, ctx, dom) sim.run(tfinal=30) def test_empty(self): diff --git a/python/test/unit/test_multiple_connections.py b/python/test/unit/test_multiple_connections.py index 89cc32d74ef2f992fd80cdf74eb91ba8ab32926a..37b0f16e20790602b0154e82fa7fc446699e752b 100644 --- a/python/test/unit/test_multiple_connections.py +++ b/python/test/unit/test_multiple_connections.py @@ -134,8 +134,7 @@ class TestMultipleConnections(unittest.TestCase): self.assertAlmostEqual(connections_from_recipe[3].delay, 1.4) # construct domain_decomposition and simulation object - dd = arb.partition_load_balance(art_spiker_recipe, context) - sim = arb.simulation(art_spiker_recipe, dd, context) + sim = arb.simulation(art_spiker_recipe, context) sim.record(arb.spike_recording.all) # create schedule and handle to record the membrane potential of neuron 3 @@ -365,9 +364,8 @@ class TestMultipleConnections(unittest.TestCase): self.assertAlmostEqual(connections_from_recipe[1].weight, weight2) self.assertAlmostEqual(connections_from_recipe[1].delay, 1.4) - # construct domain_decomposition and simulation object - dd = arb.partition_load_balance(art_spiker_recipe, context) - sim = arb.simulation(art_spiker_recipe, dd, context) + # construct simulation object + sim = arb.simulation(art_spiker_recipe, context) sim.record(arb.spike_recording.all) # create schedule and handle to record the membrane potential of neuron 3 diff --git a/python/test/unit/test_profiling.py b/python/test/unit/test_profiling.py index 92a17e237b009249b648d78373e7f488b3c00a0c..0c0d484d597a0d715f40a25779c352b931be3bc5 100644 --- a/python/test/unit/test_profiling.py +++ b/python/test/unit/test_profiling.py @@ -93,7 +93,7 @@ class TestProfiling(unittest.TestCase): arb.profiler_initialize(context) recipe = a_recipe() dd = arb.partition_load_balance(recipe, context) - arb.simulation(recipe, dd, context).run(1) + arb.simulation(recipe, context, dd).run(1) summary = arb.profiler_summary() self.assertEqual(str, type(summary), "profiler summary must be str") self.assertTrue(summary, "empty summary") diff --git a/python/test/unit_distributed/test_simulator.py b/python/test/unit_distributed/test_simulator.py index 8ac2a98330c4cc7589cf369abae4d7ae211e18f6..ea657a9526a4360b1680036de2d766b4e48a4659 100644 --- a/python/test/unit_distributed/test_simulator.py +++ b/python/test/unit_distributed/test_simulator.py @@ -69,7 +69,7 @@ class TestSimulator(unittest.TestCase): self.assertEqual(1, len(local_groups)) self.assertEqual([self.rank], local_groups[0].gids) - return A.simulation(recipe, dd, context) + return A.simulation(recipe, context, dd) def test_local_spikes(self): sim = self.init_sim() diff --git a/test/unit/test_event_delivery.cpp b/test/unit/test_event_delivery.cpp index 1294240ce8f8d285684d90e814ca824115d88b5c..82a029f839b532ee6c880cd413ab7666b1f2501f 100644 --- a/test/unit/test_event_delivery.cpp +++ b/test/unit/test_event_delivery.cpp @@ -49,17 +49,18 @@ using gid_vector = std::vector<cell_gid_type>; using group_gids_type = std::vector<gid_vector>; std::vector<cell_gid_type> run_test_sim(const recipe& R, const group_gids_type& group_gids) { - arb::context ctx = make_context(proc_allocation{}); - unsigned n = R.num_cells(); + unsigned n = R.num_cells(); std::vector<group_description> groups; for (const auto& gidvec: group_gids) { groups.emplace_back(cell_kind::cable, gidvec, backend_kind::multicore); } - auto D = domain_decomposition(R, ctx, groups); - std::vector<spike> spikes; - simulation sim(R, D, ctx); + auto C = make_context(); + auto D = domain_decomposition(R, C, groups); + simulation sim(R, C, D); + + std::vector<spike> spikes; sim.set_global_spike_callback( [&spikes](const std::vector<spike>& ss) { spikes.insert(spikes.end(), ss.begin(), ss.end()); diff --git a/test/unit/test_fvm_lowered.cpp b/test/unit/test_fvm_lowered.cpp index d9d1aa76c7ceebc0eb73f2560b4d51cbda281bae..55f10243cb64ca0f6125eb06f914587921efbd7b 100644 --- a/test/unit/test_fvm_lowered.cpp +++ b/test/unit/test_fvm_lowered.cpp @@ -485,7 +485,7 @@ TEST(fvm_lowered, derived_mechs) { float times[] = {10.f, 20.f}; auto decomp = partition_load_balance(rec, context); - simulation sim(rec, decomp, context); + simulation sim(rec, context, decomp); sim.add_sampler(all_probes, explicit_schedule(times), sampler); sim.run(30.0, 1.f/1024); @@ -516,7 +516,7 @@ TEST(fvm_lowered, null_region) { rec.catalogue().derive("custom_kin1", "test_kin1", {{"tau", 20.0}}); auto decomp = partition_load_balance(rec, context); - simulation sim(rec, decomp, context); + simulation sim(rec, context, decomp); EXPECT_NO_THROW(sim.run(30.0, 1.f/1024)); } diff --git a/test/unit/test_lif_cell_group.cpp b/test/unit/test_lif_cell_group.cpp index 9ca765c40608093c049b57e86c4cf6609620702f..0ea015a39ddbb154e7d171a8e28fbfccb93bb330 100644 --- a/test/unit/test_lif_cell_group.cpp +++ b/test/unit/test_lif_cell_group.cpp @@ -133,7 +133,7 @@ TEST(lif_cell_group, throw) { probe_recipe rec; auto context = make_context(); auto decomp = partition_load_balance(rec, context); - EXPECT_THROW(simulation(rec, decomp, context), bad_cell_probe); + EXPECT_THROW(simulation(rec, context, decomp), bad_cell_probe); } TEST(lif_cell_group, recipe) @@ -153,7 +153,7 @@ TEST(lif_cell_group, spikes) { auto context = make_context(); auto decomp = partition_load_balance(recipe, context); - simulation sim(recipe, decomp, context); + simulation sim(recipe, context, decomp); cse_vector events; @@ -193,7 +193,7 @@ TEST(lif_cell_group, ring) auto decomp = partition_load_balance(recipe, context); // Creates a simulation with a ring recipe of lif neurons - simulation sim(recipe, decomp, context); + simulation sim(recipe, context, decomp); std::vector<spike> spike_buffer; diff --git a/test/unit/test_probe.cpp b/test/unit/test_probe.cpp index bc9f41c374678bd4f2251c3072d58402fdfd38ed..ca541823faa3cca46e41489e8a04857c534cefca 100644 --- a/test/unit/test_probe.cpp +++ b/test/unit/test_probe.cpp @@ -826,7 +826,7 @@ void run_axial_and_ion_current_sampled_probe_test(const context& ctx) { partition_hint_map phints = { {cell_kind::cable, {partition_hint::max_size, partition_hint::max_size, true}} }; - simulation sim(rec, partition_load_balance(rec, ctx, phints), ctx); + simulation sim(rec, ctx, partition_load_balance(rec, ctx, phints)); // Take a sample at 20 tau, and run sim for just a bit longer. @@ -913,7 +913,7 @@ auto run_simple_samplers( partition_hint_map phints = { {cell_kind::cable, {partition_hint::max_size, partition_hint::max_size, true}} }; - simulation sim(rec, partition_load_balance(rec, ctx, phints), ctx); + simulation sim(rec, ctx, partition_load_balance(rec, ctx, phints)); std::vector<trace_vector<SampleData, SampleMeta>> traces(n_probe); for (unsigned i = 0; i<n_probe; ++i) { @@ -1278,13 +1278,13 @@ void run_exact_sampling_probe_test(const context& ctx) { }; domain_decomposition one_cell_group = partition_load_balance(rec, ctx, phints); - simulation lax_sim(rec, one_cell_group, ctx); + simulation lax_sim(rec, ctx, one_cell_group); for (unsigned i = 0; i<n_cell; ++i) { lax_sim.add_sampler(one_probe({i, 0}), sample_sched, make_simple_sampler(lax_traces.at(i)), sampling_policy::lax); } lax_sim.run(t_end, max_dt); - simulation exact_sim(rec, one_cell_group, ctx); + simulation exact_sim(rec, ctx, one_cell_group); for (unsigned i = 0; i<n_cell; ++i) { exact_sim.add_sampler(one_probe({i, 0}), sample_sched, make_simple_sampler(exact_traces.at(i)), sampling_policy::exact); } @@ -1366,7 +1366,7 @@ TEST(probe, get_probe_metadata) { partition_hint_map phints = { {cell_kind::cable, {partition_hint::max_size, partition_hint::max_size, true}} }; - simulation sim(rec, partition_load_balance(rec, ctx, phints), ctx); + simulation sim(rec, ctx, partition_load_balance(rec, ctx, phints)); std::vector<probe_metadata> mm = sim.get_probe_metadata({0, 0}); ASSERT_EQ(3u, mm.size()); diff --git a/test/unit/test_recipe.cpp b/test/unit/test_recipe.cpp index f630f851f097fc16faa12ede8015176f012e1070..f1ffe120bc7f332b2c256402531e83cbe3c8597e 100644 --- a/test/unit/test_recipe.cpp +++ b/test/unit/test_recipe.cpp @@ -112,7 +112,7 @@ TEST(recipe, gap_junctions) auto recipe_0 = custom_recipe({cell_0, cell_1}, {{}, {}}, {gjs_0, gjs_1}, {{}, {}}); auto decomp_0 = partition_load_balance(recipe_0, context); - EXPECT_NO_THROW(simulation(recipe_0, decomp_0, context)); + EXPECT_NO_THROW(simulation(recipe_0, context, decomp_0)); } { std::vector<arb::gap_junction_connection> gjs_0 = {{{1, "gapjunction1", policy::assert_univalent}, {"gapjunction0", policy::assert_univalent}, 0.1}, @@ -126,7 +126,7 @@ TEST(recipe, gap_junctions) auto recipe_1 = custom_recipe({cell_0, cell_1}, {{}, {}}, {gjs_0, gjs_1}, {{}, {}}); auto decomp_1 = partition_load_balance(recipe_1, context); - EXPECT_THROW(simulation(recipe_1, decomp_1, context), arb::bad_connection_label); + EXPECT_THROW(simulation(recipe_1, context, decomp_1), arb::bad_connection_label); } } @@ -150,7 +150,7 @@ TEST(recipe, connections) auto recipe_0 = custom_recipe({cell_0, cell_1}, {conns_0, conns_1}, {{}, {}}, {{}, {}}); auto decomp_0 = partition_load_balance(recipe_0, context); - EXPECT_NO_THROW(simulation(recipe_0, decomp_0, context)); + EXPECT_NO_THROW(simulation(recipe_0, context, decomp_0)); } { conns_0 = {{{1, "detector0"}, {"synapse0"}, 0.1, 0.1}, @@ -164,7 +164,7 @@ TEST(recipe, connections) auto recipe_1 = custom_recipe({cell_0, cell_1}, {conns_0, conns_1}, {{}, {}}, {{}, {}}); auto decomp_1 = partition_load_balance(recipe_1, context); - EXPECT_THROW(simulation(recipe_1, decomp_1, context), arb::bad_connection_source_gid); + EXPECT_THROW(simulation(recipe_1, context, decomp_1), arb::bad_connection_source_gid); } { conns_0 = {{{1, "detector0"}, {"synapse0"}, 0.1, 0.1}, @@ -178,7 +178,7 @@ TEST(recipe, connections) auto recipe_2 = custom_recipe({cell_0, cell_1}, {conns_0, conns_1}, {{}, {}}, {{}, {}}); auto decomp_2 = partition_load_balance(recipe_2, context); - EXPECT_THROW(simulation(recipe_2, decomp_2, context), arb::bad_connection_label); + EXPECT_THROW(simulation(recipe_2, context, decomp_2), arb::bad_connection_label); } { conns_0 = {{{1, "detector0"}, {"synapse0"}, 0.1, 0.1}, @@ -192,7 +192,7 @@ TEST(recipe, connections) auto recipe_4 = custom_recipe({cell_0, cell_1}, {conns_0, conns_1}, {{}, {}}, {{}, {}}); auto decomp_4 = partition_load_balance(recipe_4, context); - EXPECT_THROW(simulation(recipe_4, decomp_4, context), arb::bad_connection_label); + EXPECT_THROW(simulation(recipe_4, context, decomp_4), arb::bad_connection_label); } } @@ -210,7 +210,7 @@ TEST(recipe, event_generators) { auto recipe_0 = custom_recipe({cell_0, cell_1}, {{}, {}}, {{}, {}}, {gens_0, gens_1}); auto decomp_0 = partition_load_balance(recipe_0, context); - EXPECT_NO_THROW(simulation(recipe_0, decomp_0, context)); + EXPECT_NO_THROW(simulation(recipe_0, context, decomp_0)); } { gens_0 = {arb::explicit_generator({{{"synapse0"}, 1.0, 0.1}, {{"synapse3"}, 2.0, 0.1}})}; @@ -219,6 +219,6 @@ TEST(recipe, event_generators) { auto recipe_0 = custom_recipe({cell_0, cell_1}, {{}, {}}, {{}, {}}, {gens_0, gens_1}); auto decomp_0 = partition_load_balance(recipe_0, context); - EXPECT_THROW(simulation(recipe_0, decomp_0, context), arb::bad_connection_label); + EXPECT_THROW(simulation(recipe_0, context, decomp_0), arb::bad_connection_label); } } diff --git a/test/unit/test_simulation.cpp b/test/unit/test_simulation.cpp index 3110d43822af939061ea6422bf77531a6c377589..b79ac19a6fa21d152ea1933b3a72e417a41a4795 100644 --- a/test/unit/test_simulation.cpp +++ b/test/unit/test_simulation.cpp @@ -51,7 +51,7 @@ TEST(simulation, null) { auto r = null_recipe{}; auto c = arb::make_context(); auto d = arb::partition_load_balance(r, c); - auto s = arb::simulation(r, d, c); + auto s = arb::simulation(r, c, d); s.run(0.05, 0.01); } @@ -74,7 +74,7 @@ TEST(simulation, spike_global_callback) { play_spikes rec(spike_times); auto ctx = n_thread_context(4); auto decomp = partition_load_balance(rec, ctx); - simulation sim(rec, decomp, ctx); + simulation sim(rec, ctx, decomp); std::vector<spike> collected; sim.set_global_spike_callback([&](const std::vector<spike>& spikes) { @@ -155,7 +155,7 @@ TEST(simulation, restart) { auto ctx = n_thread_context(4); auto decomp = partition_load_balance(rec, ctx); - simulation sim(rec, decomp, ctx); + simulation sim(rec, ctx, decomp); std::vector<spike> collected; sim.set_global_spike_callback([&](const std::vector<spike>& spikes) { diff --git a/test/unit/test_spikes.cpp b/test/unit/test_spikes.cpp index 12a53234f044b766d99142e38f9f321d8405421f..d3279169a2f47d13460736288b03c2b87ddeec2f 100644 --- a/test/unit/test_spikes.cpp +++ b/test/unit/test_spikes.cpp @@ -234,7 +234,7 @@ TEST(SPIKES_TEST_CLASS, threshold_watcher_interpolation) { cable1d_recipe rec({cell}); auto decomp = arb::partition_load_balance(rec, context); - arb::simulation sim(rec, decomp, context); + arb::simulation sim(rec, context, decomp); sim.set_global_spike_callback( [&spikes](const std::vector<arb::spike>& recorded_spikes) {