From fc85765e5550334fcc0661739b5ff2b72d00ef9a Mon Sep 17 00:00:00 2001 From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com> Date: Wed, 5 Oct 2022 10:11:30 +0200 Subject: [PATCH] Clean up plasticity (#1985) 1. Fix Python bindings for `recipe::update` - *drop* the GIL before handing off to C++ - tighten exception safety 2. Run plasticity examples with threads; both C++ and Python. - C++: Guard against I/O interleaving. - Py: Drop spikes from source, prettify reporting. - C++: use decor chaining. 3. Modernise PYBIND11_OVERLOAD -> *RIDE (advised since 2.6). 4. No longer do we initialise connectivity twice. - Simplify communicator construction. - Fix unit tests that needed to two-phase init communicator. --- arbor/communication/communicator.cpp | 11 +++--- arbor/communication/communicator.hpp | 2 -- arbor/simulation.cpp | 3 +- example/plasticity/plasticity.cpp | 40 ++++++++++++++------- python/example/plasticity.py | 22 ++++++------ python/recipe.hpp | 16 ++++----- python/simulation.cpp | 15 +++++--- test/unit-distributed/test_communicator.cpp | 12 ++++--- 8 files changed, 71 insertions(+), 50 deletions(-) diff --git a/arbor/communication/communicator.cpp b/arbor/communication/communicator.cpp index 6f816bf7..70fa08ba 100644 --- a/arbor/communication/communicator.cpp +++ b/arbor/communication/communicator.cpp @@ -24,16 +24,12 @@ namespace arb { communicator::communicator(const recipe& rec, const domain_decomposition& dom_dec, - const label_resolution_map& source_resolution_map, - const label_resolution_map& target_resolution_map, execution_context& ctx): num_total_cells_{rec.num_cells()}, num_local_cells_{dom_dec.num_local_cells()}, num_local_groups_{dom_dec.num_groups()}, num_domains_{(cell_size_type) ctx.distributed->size()}, distributed_{ctx.distributed}, - thread_pool_{ctx.thread_pool} { - update_connections(rec, dom_dec, source_resolution_map, target_resolution_map); -} + thread_pool_{ctx.thread_pool} {} void communicator::update_connections(const connectivity& rec, const domain_decomposition& dom_dec, @@ -80,7 +76,6 @@ void communicator::update_connections(const connectivity& rec, auto gid = gids[i]; gid_infos[i] = gid_info(gid, i, rec.connections_on(gid)); }); - cell_local_size_type n_cons = util::sum_by(gid_infos, [](const gid_info& g){ return g.conns.size(); }); std::vector<unsigned> src_domains; @@ -129,7 +124,9 @@ void communicator::update_connections(const connectivity& rec, // This is num_domains_ independent sorts, so it can be parallelized trivially. const auto& cp = connection_part_; threading::parallel_for::apply(0, num_domains_, thread_pool_.get(), - [&](cell_size_type i) { util::sort(util::subrange_view(connections_, cp[i], cp[i+1])); }); + [&](cell_size_type i) { + util::sort(util::subrange_view(connections_, cp[i], cp[i+1])); + }); } std::pair<cell_size_type, cell_size_type> communicator::group_queue_range(cell_size_type i) { diff --git a/arbor/communication/communicator.hpp b/arbor/communication/communicator.hpp index 7a4becd3..c9eab082 100644 --- a/arbor/communication/communicator.hpp +++ b/arbor/communication/communicator.hpp @@ -32,8 +32,6 @@ public: explicit communicator(const recipe& rec, const domain_decomposition& dom_dec, - const label_resolution_map& source_resolver, - const label_resolution_map& target_resolver, execution_context& ctx); /// The range of event queues that belong to cells in group i. diff --git a/arbor/simulation.cpp b/arbor/simulation.cpp index 55350d68..887462d1 100644 --- a/arbor/simulation.cpp +++ b/arbor/simulation.cpp @@ -222,7 +222,7 @@ simulation_state::simulation_state( source_resolution_map_ = label_resolution_map(std::move(global_sources)); target_resolution_map_ = label_resolution_map(std::move(local_targets)); - communicator_ = communicator(rec, ddc_, source_resolution_map_, target_resolution_map_, *ctx_); + communicator_ = communicator(rec, ddc_, *ctx_); update(rec); epoch_.reset(); } @@ -268,7 +268,6 @@ void simulation_state::update(const connectivity& rec) { event_lanes_[1].resize(num_local_cells); } - void simulation_state::reset() { epoch_ = epoch(); diff --git a/example/plasticity/plasticity.cpp b/example/plasticity/plasticity.cpp index 4b97d458..37748caf 100644 --- a/example/plasticity/plasticity.cpp +++ b/example/plasticity/plasticity.cpp @@ -31,7 +31,7 @@ struct recipe: public arb::recipe { arb::region all = "(all)"_reg; // Whole cell arb::cell_size_type n_ = 0; // Cell count - mutable std::unordered_map<arb::cell_gid_type, std::vector<arb::cell_connection>> connected; // lookup table for connections + std::unordered_map<arb::cell_gid_type, std::vector<arb::cell_connection>> connected; // lookup table for connections // Required but uninteresting methods recipe(arb::cell_size_type n): n_{n} {} arb::cell_size_type num_cells() const override { return n_; } @@ -48,7 +48,12 @@ struct recipe: public arb::recipe { return {arb::cable_probe_membrane_voltage{center}}; } // Look up the (potential) connection to this cell - std::vector<arb::cell_connection> connections_on(arb::cell_gid_type gid) const override { return connected[gid]; } + std::vector<arb::cell_connection> connections_on(arb::cell_gid_type gid) const override { + if (auto it = connected.find(gid); it != connected.end()) { + return it->second; + } + return {}; + } // Connect cell `to` to the spike source void add_connection(arb::cell_gid_type to) { assert(to > 0); connected[to] = {arb::cell_connection({0, src}, {syn}, weight, delay)}; } // Return the cell at gid @@ -57,29 +62,39 @@ struct recipe: public arb::recipe { if (gid == 0) return arb::spike_source_cell{src, arb::regular_schedule(f_spike)}; // all others are receiving cable cells; single CV w/ HH arb::segment_tree tree; tree.append(arb::mnpos, {-r_soma, 0, 0, r_soma}, {r_soma, 0, 0, r_soma}, 1); - auto decor = arb::decor{}; - decor.paint(all, arb::density("hh", {{"gl", 5}})); - decor.place(center, arb::synapse("expsyn"), syn); - decor.place(center, arb::threshold_detector{-10.0}, det); - decor.set_default(arb::cv_policy_every_segment()); + auto decor = arb::decor{} + .paint(all, arb::density("hh", {{"gl", 5}})) + .place(center, arb::synapse("expsyn"), syn) + .place(center, arb::threshold_detector{-10.0}, det) + .set_default(arb::cv_policy_every_segment()); return arb::cable_cell({tree}, {}, decor); } }; +// For demonstration: Avoid interleaving std::cout in multi-threaded scenarios. +// NEVER do this in HPC!!! +std::mutex mtx; + void sampler(arb::probe_metadata pm, std::size_t n, const arb::sample_record* samples) { auto* loc = arb::util::any_cast<const arb::mlocation*>(pm.meta); - std::cout << std::fixed << std::setprecision(4); + for (std::size_t i = 0; i<n; ++i) { + std::lock_guard<std::mutex> lock{mtx}; auto* value = arb::util::any_cast<const double*>(samples[i].data); - std::cout << "| " << samples[i].time << " | " << loc->pos << " | " << *value << " |\n"; + std::cout << std::fixed << std::setprecision(4) + << "| " << samples[i].time << " | " << loc->pos << " | " << *value << " |\n"; } } void spike_cb(const std::vector<arb::spike>& spikes) { - for(const auto& spike: spikes) std::cout << " * " << spike.source << "@" << spike.time << '\n'; + for(const auto& spike: spikes) { + std::lock_guard<std::mutex> lock{mtx}; + std::cout << " * " << spike.source << "@" << spike.time << '\n'; + } } void print_header(double from, double to) { + std::lock_guard<std::mutex> lock{mtx}; std::cout << "\n" << "Running simulation from " << from << "ms to " << to << "ms\n" << "Spikes are marked: *\n" @@ -93,9 +108,10 @@ const double dt = 0.05; int main(int argc, char** argv) { auto rec = recipe(3); rec.add_connection(1); - auto sim = arb::simulation(rec); + auto ctx = arb::make_context(arb::proc_allocation{8, -1}); + auto sim = arb::simulation(rec, ctx); sim.add_sampler(arb::all_probes, arb::regular_schedule(dt), sampler, arb::sampling_policy::exact); - sim.set_local_spike_callback(spike_cb); + sim.set_global_spike_callback(spike_cb); print_header(0, 1); sim.run(1.0, dt); rec.add_connection(2); diff --git a/python/example/plasticity.py b/python/example/plasticity.py index ba4c2099..cd8b8a2b 100644 --- a/python/example/plasticity.py +++ b/python/example/plasticity.py @@ -75,9 +75,10 @@ class recipe(A.recipe): self.connected.add(to) +# Context for multi-threading +ctx = A.context(threads=2) # Make an unconnected network with 2 cable cells and one spike source, rec = recipe(3) - # but before setting up anything, connect cable cell gid=1 to spike source gid=0 # and make the simulation of the simple network # @@ -88,12 +89,10 @@ rec = recipe(3) # Note that the connection is just _recorded_ in the recipe, the actual connectivity # is set up in the simulation construction. rec.add_connection_to_spike_source(1) -sim = A.simulation(rec) +sim = A.simulation(rec, ctx) sim.record(A.spike_recording.all) - # then run the simulation for a bit sim.run(0.25, 0.025) - # update the simulation to # # spike_source <gid=0> ----> cable_cell <gid=1> @@ -101,11 +100,14 @@ sim.run(0.25, 0.025) # ----> cable_cell <gid=2> rec.add_connection_to_spike_source(2) sim.update(rec) - # and run the simulation for another bit. sim.run(0.5, 0.025) - -# When finished, print spike times and locations. -print("spikes:") -for sp in sim.spikes(): - print(" ", sp) +# when finished, print spike times and locations. +source_spikes = 0 +print("Spikes:") +for (gid, lid), t in sim.spikes(): + if gid == 0: + source_spikes += 1 + else: + print(f" * {t:>8.4f}ms: gid={gid} detector={lid}") +print(f"Source spiked {source_spikes:>5d} times.") diff --git a/python/recipe.hpp b/python/recipe.hpp index a8a658ab..1a3b4453 100644 --- a/python/recipe.hpp +++ b/python/recipe.hpp @@ -53,35 +53,35 @@ public: class py_recipe_trampoline: public py_recipe { public: arb::cell_size_type num_cells() const override { - PYBIND11_OVERLOAD_PURE(arb::cell_size_type, py_recipe, num_cells); + PYBIND11_OVERRIDE_PURE(arb::cell_size_type, py_recipe, num_cells); } pybind11::object cell_description(arb::cell_gid_type gid) const override { - PYBIND11_OVERLOAD_PURE(pybind11::object, py_recipe, cell_description, gid); + PYBIND11_OVERRIDE_PURE(pybind11::object, py_recipe, cell_description, gid); } arb::cell_kind cell_kind(arb::cell_gid_type gid) const override { - PYBIND11_OVERLOAD_PURE(arb::cell_kind, py_recipe, cell_kind, gid); + PYBIND11_OVERRIDE_PURE(arb::cell_kind, py_recipe, cell_kind, gid); } std::vector<pybind11::object> event_generators(arb::cell_gid_type gid) const override { - PYBIND11_OVERLOAD(std::vector<pybind11::object>, py_recipe, event_generators, gid); + PYBIND11_OVERRIDE(std::vector<pybind11::object>, py_recipe, event_generators, gid); } std::vector<arb::cell_connection> connections_on(arb::cell_gid_type gid) const override { - PYBIND11_OVERLOAD(std::vector<arb::cell_connection>, py_recipe, connections_on, gid); + PYBIND11_OVERRIDE(std::vector<arb::cell_connection>, py_recipe, connections_on, gid); } std::vector<arb::gap_junction_connection> gap_junctions_on(arb::cell_gid_type gid) const override { - PYBIND11_OVERLOAD(std::vector<arb::gap_junction_connection>, py_recipe, gap_junctions_on, gid); + PYBIND11_OVERRIDE(std::vector<arb::gap_junction_connection>, py_recipe, gap_junctions_on, gid); } std::vector<arb::probe_info> probes(arb::cell_gid_type gid) const override { - PYBIND11_OVERLOAD(std::vector<arb::probe_info>, py_recipe, probes, gid); + PYBIND11_OVERRIDE(std::vector<arb::probe_info>, py_recipe, probes, gid); } pybind11::object global_properties(arb::cell_kind kind) const override { - PYBIND11_OVERLOAD(pybind11::object, py_recipe, global_properties, kind); + PYBIND11_OVERRIDE(pybind11::object, py_recipe, global_properties, kind); } }; diff --git a/python/simulation.cpp b/python/simulation.cpp index ef3e1868..c5a15b2e 100644 --- a/python/simulation.cpp +++ b/python/simulation.cpp @@ -67,6 +67,16 @@ public: } } + void update(std::shared_ptr<py_recipe>& rec) { + try { + sim_->update(py_recipe_shim(rec)); + } + catch (...) { + py_reset_and_throw(); + throw; + } + } + void reset() { sim_->reset(); spike_record_.clear(); @@ -94,10 +104,6 @@ public: sim_->set_binning_policy(policy, bin_interval); } - void update(std::shared_ptr<py_recipe>& rec) { - sim_->update(py_recipe_shim(rec)); - } - void record(spike_recording policy) { auto spike_recorder = [this](const std::vector<arb::spike>& spikes) { auto old_size = spike_record_.size(); @@ -222,6 +228,7 @@ void register_simulation(pybind11::module& m, pyarb_global_ptr global_ptr) { pybind11::arg_v("domains", pybind11::none(), "Domain decomposition"), pybind11::arg_v("seed", 0u, "Random number generator seed")) .def("update", &simulation_shim::update, + pybind11::call_guard<pybind11::gil_scoped_release>(), "Rebuild the connection table from recipe::connections_on and the event" "generators based on recipe::event_generators.", "recipe"_a) diff --git a/test/unit-distributed/test_communicator.cpp b/test/unit-distributed/test_communicator.cpp index e3ac5250..89f98dc1 100644 --- a/test/unit-distributed/test_communicator.cpp +++ b/test/unit-distributed/test_communicator.cpp @@ -531,8 +531,8 @@ TEST(communicator, ring) auto global_sources = g_context->distributed->gather_cell_labels_and_gids(local_sources); // construct the communicator - auto C = communicator(R, D, label_resolution_map(global_sources), label_resolution_map(local_targets), *g_context); - + auto C = communicator(R, D, *g_context); + C.update_connections(R, D, label_resolution_map(global_sources), label_resolution_map(local_targets)); // every cell fires EXPECT_TRUE(test_ring(D, C, [](cell_gid_type g){return true;})); // last cell in each domain fires @@ -638,11 +638,12 @@ TEST(communicator, all2all) auto global_sources = g_context->distributed->gather_cell_labels_and_gids({local_sources, mc_gids}); // construct the communicator - auto C = communicator(R, D, label_resolution_map(global_sources), label_resolution_map({local_targets, mc_gids}), *g_context); + auto C = communicator(R, D, *g_context); + C.update_connections(R, D, label_resolution_map(global_sources), label_resolution_map({local_targets, mc_gids})); auto connections = C.connections(); for (auto i: util::make_span(0, n_global)) { - for (unsigned j = 0; j < n_local; ++j) { + for (auto j: util::make_span(0, n_local)) { auto c = connections[i*n_local+j]; EXPECT_EQ(i, c.source.gid); EXPECT_EQ(0u, c.source.index); @@ -684,7 +685,8 @@ TEST(communicator, mini_network) auto global_sources = g_context->distributed->gather_cell_labels_and_gids({local_sources, gids}); // construct the communicator - auto C = communicator(R, D, label_resolution_map(global_sources), label_resolution_map({local_targets, gids}), *g_context); + auto C = communicator(R, D, *g_context); + C.update_connections(R, D, label_resolution_map(global_sources), label_resolution_map({local_targets, gids})); // sort connections by source then target auto connections = C.connections(); -- GitLab