From fc85765e5550334fcc0661739b5ff2b72d00ef9a Mon Sep 17 00:00:00 2001
From: Thorsten Hater <24411438+thorstenhater@users.noreply.github.com>
Date: Wed, 5 Oct 2022 10:11:30 +0200
Subject: [PATCH] Clean up plasticity (#1985)

1. Fix Python bindings for `recipe::update`
   - *drop* the GIL before handing off to C++
   - tighten exception safety
2. Run plasticity examples with threads; both C++ and Python.
   - C++: Guard against I/O interleaving.
   - Py: Drop spikes from source, prettify reporting.
   - C++: use decor chaining.
3. Modernise PYBIND11_OVERLOAD -> *RIDE (advised since 2.6).
4. No longer do we initialise connectivity twice.
   - Simplify communicator construction.
   - Fix unit tests that needed to two-phase init communicator.
---
 arbor/communication/communicator.cpp        | 11 +++---
 arbor/communication/communicator.hpp        |  2 --
 arbor/simulation.cpp                        |  3 +-
 example/plasticity/plasticity.cpp           | 40 ++++++++++++++-------
 python/example/plasticity.py                | 22 ++++++------
 python/recipe.hpp                           | 16 ++++-----
 python/simulation.cpp                       | 15 +++++---
 test/unit-distributed/test_communicator.cpp | 12 ++++---
 8 files changed, 71 insertions(+), 50 deletions(-)

diff --git a/arbor/communication/communicator.cpp b/arbor/communication/communicator.cpp
index 6f816bf7..70fa08ba 100644
--- a/arbor/communication/communicator.cpp
+++ b/arbor/communication/communicator.cpp
@@ -24,16 +24,12 @@ namespace arb {
 
 communicator::communicator(const recipe& rec,
                            const domain_decomposition& dom_dec,
-                           const label_resolution_map& source_resolution_map,
-                           const label_resolution_map& target_resolution_map,
                            execution_context& ctx):  num_total_cells_{rec.num_cells()},
                                                      num_local_cells_{dom_dec.num_local_cells()},
                                                      num_local_groups_{dom_dec.num_groups()},
                                                      num_domains_{(cell_size_type) ctx.distributed->size()},
                                                      distributed_{ctx.distributed},
-                                                     thread_pool_{ctx.thread_pool} {
-    update_connections(rec, dom_dec, source_resolution_map, target_resolution_map);
-}
+                                                     thread_pool_{ctx.thread_pool} {}
 
 void communicator::update_connections(const connectivity& rec,
                                       const domain_decomposition& dom_dec,
@@ -80,7 +76,6 @@ void communicator::update_connections(const connectivity& rec,
             auto gid = gids[i];
             gid_infos[i] = gid_info(gid, i, rec.connections_on(gid));
         });
-
     cell_local_size_type n_cons =
         util::sum_by(gid_infos, [](const gid_info& g){ return g.conns.size(); });
     std::vector<unsigned> src_domains;
@@ -129,7 +124,9 @@ void communicator::update_connections(const connectivity& rec,
     // This is num_domains_ independent sorts, so it can be parallelized trivially.
     const auto& cp = connection_part_;
     threading::parallel_for::apply(0, num_domains_, thread_pool_.get(),
-                                   [&](cell_size_type i) { util::sort(util::subrange_view(connections_, cp[i], cp[i+1])); });
+                                   [&](cell_size_type i) {
+                                       util::sort(util::subrange_view(connections_, cp[i], cp[i+1]));
+                                   });
 }
 
 std::pair<cell_size_type, cell_size_type> communicator::group_queue_range(cell_size_type i) {
diff --git a/arbor/communication/communicator.hpp b/arbor/communication/communicator.hpp
index 7a4becd3..c9eab082 100644
--- a/arbor/communication/communicator.hpp
+++ b/arbor/communication/communicator.hpp
@@ -32,8 +32,6 @@ public:
 
     explicit communicator(const recipe& rec,
                           const domain_decomposition& dom_dec,
-                          const label_resolution_map& source_resolver,
-                          const label_resolution_map& target_resolver,
                           execution_context& ctx);
 
     /// The range of event queues that belong to cells in group i.
diff --git a/arbor/simulation.cpp b/arbor/simulation.cpp
index 55350d68..887462d1 100644
--- a/arbor/simulation.cpp
+++ b/arbor/simulation.cpp
@@ -222,7 +222,7 @@ simulation_state::simulation_state(
 
     source_resolution_map_ = label_resolution_map(std::move(global_sources));
     target_resolution_map_ = label_resolution_map(std::move(local_targets));
-    communicator_ = communicator(rec, ddc_, source_resolution_map_, target_resolution_map_, *ctx_);
+    communicator_ = communicator(rec, ddc_, *ctx_);
     update(rec);
     epoch_.reset();
 }
@@ -268,7 +268,6 @@ void simulation_state::update(const connectivity& rec) {
     event_lanes_[1].resize(num_local_cells);
 }
 
-
 void simulation_state::reset() {
     epoch_ = epoch();
 
diff --git a/example/plasticity/plasticity.cpp b/example/plasticity/plasticity.cpp
index 4b97d458..37748caf 100644
--- a/example/plasticity/plasticity.cpp
+++ b/example/plasticity/plasticity.cpp
@@ -31,7 +31,7 @@ struct recipe: public arb::recipe {
     arb::region all    = "(all)"_reg;           // Whole cell
     arb::cell_size_type n_ = 0;                 // Cell count
 
-    mutable std::unordered_map<arb::cell_gid_type, std::vector<arb::cell_connection>> connected; // lookup table for connections
+    std::unordered_map<arb::cell_gid_type, std::vector<arb::cell_connection>> connected; // lookup table for connections
     // Required but uninteresting methods
     recipe(arb::cell_size_type n): n_{n} {}
     arb::cell_size_type num_cells() const override { return n_; }
@@ -48,7 +48,12 @@ struct recipe: public arb::recipe {
         return {arb::cable_probe_membrane_voltage{center}};
     }
     // Look up the (potential) connection to this cell
-    std::vector<arb::cell_connection> connections_on(arb::cell_gid_type gid) const override { return connected[gid]; }
+    std::vector<arb::cell_connection> connections_on(arb::cell_gid_type gid) const override {
+        if (auto it = connected.find(gid); it != connected.end()) {
+            return it->second;
+        }
+        return {};
+    }
     // Connect cell `to` to the spike source
     void add_connection(arb::cell_gid_type to) { assert(to > 0); connected[to] = {arb::cell_connection({0, src}, {syn}, weight, delay)}; }
     // Return the cell at gid
@@ -57,29 +62,39 @@ struct recipe: public arb::recipe {
         if (gid == 0) return arb::spike_source_cell{src, arb::regular_schedule(f_spike)};
         // all others are receiving cable cells; single CV w/ HH
         arb::segment_tree tree; tree.append(arb::mnpos, {-r_soma, 0, 0, r_soma}, {r_soma, 0, 0, r_soma}, 1);
-        auto decor = arb::decor{};
-        decor.paint(all, arb::density("hh", {{"gl", 5}}));
-        decor.place(center, arb::synapse("expsyn"), syn);
-        decor.place(center, arb::threshold_detector{-10.0}, det);
-        decor.set_default(arb::cv_policy_every_segment());
+        auto decor = arb::decor{}
+            .paint(all, arb::density("hh", {{"gl", 5}}))
+            .place(center, arb::synapse("expsyn"), syn)
+            .place(center, arb::threshold_detector{-10.0}, det)
+            .set_default(arb::cv_policy_every_segment());
         return arb::cable_cell({tree}, {}, decor);
     }
 };
 
+// For demonstration: Avoid interleaving std::cout in multi-threaded scenarios.
+// NEVER do this in HPC!!!
+std::mutex mtx;
+
 void sampler(arb::probe_metadata pm, std::size_t n, const arb::sample_record* samples) {
     auto* loc = arb::util::any_cast<const arb::mlocation*>(pm.meta);
-    std::cout << std::fixed << std::setprecision(4);
+
     for (std::size_t i = 0; i<n; ++i) {
+        std::lock_guard<std::mutex> lock{mtx};
         auto* value = arb::util::any_cast<const double*>(samples[i].data);
-        std::cout << "|  " << samples[i].time << " |      " << loc->pos << " | " << *value << " |\n";
+        std::cout << std::fixed << std::setprecision(4)
+                  << "|  " << samples[i].time << " |      " << loc->pos << " | " << *value << " |\n";
     }
 }
 
 void spike_cb(const std::vector<arb::spike>& spikes) {
-    for(const auto& spike: spikes) std::cout << " * " << spike.source << "@" << spike.time << '\n';
+    for(const auto& spike: spikes) {
+        std::lock_guard<std::mutex> lock{mtx};
+        std::cout << " * " << spike.source << "@" << spike.time << '\n';
+    }
 }
 
 void print_header(double from, double to) {
+    std::lock_guard<std::mutex> lock{mtx};
     std::cout << "\n"
               << "Running simulation from " << from << "ms to " << to << "ms\n"
               << "Spikes are marked: *\n"
@@ -93,9 +108,10 @@ const double dt = 0.05;
 int main(int argc, char** argv) {
     auto rec = recipe(3);
     rec.add_connection(1);
-    auto sim = arb::simulation(rec);
+    auto ctx = arb::make_context(arb::proc_allocation{8, -1});
+    auto sim = arb::simulation(rec, ctx);
     sim.add_sampler(arb::all_probes, arb::regular_schedule(dt), sampler, arb::sampling_policy::exact);
-    sim.set_local_spike_callback(spike_cb);
+    sim.set_global_spike_callback(spike_cb);
     print_header(0, 1);
     sim.run(1.0, dt);
     rec.add_connection(2);
diff --git a/python/example/plasticity.py b/python/example/plasticity.py
index ba4c2099..cd8b8a2b 100644
--- a/python/example/plasticity.py
+++ b/python/example/plasticity.py
@@ -75,9 +75,10 @@ class recipe(A.recipe):
         self.connected.add(to)
 
 
+# Context for multi-threading
+ctx = A.context(threads=2)
 # Make an unconnected network with 2 cable cells and one spike source,
 rec = recipe(3)
-
 # but before setting up anything, connect cable cell gid=1 to spike source gid=0
 # and make the simulation of the simple network
 #
@@ -88,12 +89,10 @@ rec = recipe(3)
 # Note that the connection is just _recorded_ in the recipe, the actual connectivity
 # is set up in the simulation construction.
 rec.add_connection_to_spike_source(1)
-sim = A.simulation(rec)
+sim = A.simulation(rec, ctx)
 sim.record(A.spike_recording.all)
-
 # then run the simulation for a bit
 sim.run(0.25, 0.025)
-
 # update the simulation to
 #
 #    spike_source <gid=0> ----> cable_cell <gid=1>
@@ -101,11 +100,14 @@ sim.run(0.25, 0.025)
 #                         ----> cable_cell <gid=2>
 rec.add_connection_to_spike_source(2)
 sim.update(rec)
-
 # and run the simulation for another bit.
 sim.run(0.5, 0.025)
-
-# When finished, print spike times and locations.
-print("spikes:")
-for sp in sim.spikes():
-    print(" ", sp)
+# when finished, print spike times and locations.
+source_spikes = 0
+print("Spikes:")
+for (gid, lid), t in sim.spikes():
+    if gid == 0:
+        source_spikes += 1
+    else:
+        print(f"  * {t:>8.4f}ms: gid={gid} detector={lid}")
+print(f"Source spiked {source_spikes:>5d} times.")
diff --git a/python/recipe.hpp b/python/recipe.hpp
index a8a658ab..1a3b4453 100644
--- a/python/recipe.hpp
+++ b/python/recipe.hpp
@@ -53,35 +53,35 @@ public:
 class py_recipe_trampoline: public py_recipe {
 public:
     arb::cell_size_type num_cells() const override {
-        PYBIND11_OVERLOAD_PURE(arb::cell_size_type, py_recipe, num_cells);
+        PYBIND11_OVERRIDE_PURE(arb::cell_size_type, py_recipe, num_cells);
     }
 
     pybind11::object cell_description(arb::cell_gid_type gid) const override {
-        PYBIND11_OVERLOAD_PURE(pybind11::object, py_recipe, cell_description, gid);
+        PYBIND11_OVERRIDE_PURE(pybind11::object, py_recipe, cell_description, gid);
     }
 
     arb::cell_kind cell_kind(arb::cell_gid_type gid) const override {
-        PYBIND11_OVERLOAD_PURE(arb::cell_kind, py_recipe, cell_kind, gid);
+        PYBIND11_OVERRIDE_PURE(arb::cell_kind, py_recipe, cell_kind, gid);
     }
 
     std::vector<pybind11::object> event_generators(arb::cell_gid_type gid) const override {
-        PYBIND11_OVERLOAD(std::vector<pybind11::object>, py_recipe, event_generators, gid);
+        PYBIND11_OVERRIDE(std::vector<pybind11::object>, py_recipe, event_generators, gid);
     }
 
     std::vector<arb::cell_connection> connections_on(arb::cell_gid_type gid) const override {
-        PYBIND11_OVERLOAD(std::vector<arb::cell_connection>, py_recipe, connections_on, gid);
+        PYBIND11_OVERRIDE(std::vector<arb::cell_connection>, py_recipe, connections_on, gid);
     }
 
     std::vector<arb::gap_junction_connection> gap_junctions_on(arb::cell_gid_type gid) const override {
-        PYBIND11_OVERLOAD(std::vector<arb::gap_junction_connection>, py_recipe, gap_junctions_on, gid);
+        PYBIND11_OVERRIDE(std::vector<arb::gap_junction_connection>, py_recipe, gap_junctions_on, gid);
     }
 
     std::vector<arb::probe_info> probes(arb::cell_gid_type gid) const override {
-        PYBIND11_OVERLOAD(std::vector<arb::probe_info>, py_recipe, probes, gid);
+        PYBIND11_OVERRIDE(std::vector<arb::probe_info>, py_recipe, probes, gid);
     }
 
     pybind11::object global_properties(arb::cell_kind kind) const override {
-        PYBIND11_OVERLOAD(pybind11::object, py_recipe, global_properties, kind);
+        PYBIND11_OVERRIDE(pybind11::object, py_recipe, global_properties, kind);
     }
 };
 
diff --git a/python/simulation.cpp b/python/simulation.cpp
index ef3e1868..c5a15b2e 100644
--- a/python/simulation.cpp
+++ b/python/simulation.cpp
@@ -67,6 +67,16 @@ public:
         }
     }
 
+    void update(std::shared_ptr<py_recipe>& rec) {
+        try {
+            sim_->update(py_recipe_shim(rec));
+        }
+        catch (...) {
+            py_reset_and_throw();
+            throw;
+        }
+    }
+
     void reset() {
         sim_->reset();
         spike_record_.clear();
@@ -94,10 +104,6 @@ public:
         sim_->set_binning_policy(policy, bin_interval);
     }
 
-    void update(std::shared_ptr<py_recipe>& rec) {
-        sim_->update(py_recipe_shim(rec));
-    }
-
     void record(spike_recording policy) {
         auto spike_recorder = [this](const std::vector<arb::spike>& spikes) {
             auto old_size = spike_record_.size();
@@ -222,6 +228,7 @@ void register_simulation(pybind11::module& m, pyarb_global_ptr global_ptr) {
              pybind11::arg_v("domains", pybind11::none(), "Domain decomposition"),
              pybind11::arg_v("seed", 0u, "Random number generator seed"))
         .def("update", &simulation_shim::update,
+             pybind11::call_guard<pybind11::gil_scoped_release>(),
              "Rebuild the connection table from recipe::connections_on and the event"
              "generators based on recipe::event_generators.",
              "recipe"_a)
diff --git a/test/unit-distributed/test_communicator.cpp b/test/unit-distributed/test_communicator.cpp
index e3ac5250..89f98dc1 100644
--- a/test/unit-distributed/test_communicator.cpp
+++ b/test/unit-distributed/test_communicator.cpp
@@ -531,8 +531,8 @@ TEST(communicator, ring)
     auto global_sources = g_context->distributed->gather_cell_labels_and_gids(local_sources);
 
     // construct the communicator
-    auto C = communicator(R, D, label_resolution_map(global_sources), label_resolution_map(local_targets), *g_context);
-
+    auto C = communicator(R, D, *g_context);
+    C.update_connections(R, D, label_resolution_map(global_sources), label_resolution_map(local_targets));
     // every cell fires
     EXPECT_TRUE(test_ring(D, C, [](cell_gid_type g){return true;}));
     // last cell in each domain fires
@@ -638,11 +638,12 @@ TEST(communicator, all2all)
     auto global_sources = g_context->distributed->gather_cell_labels_and_gids({local_sources, mc_gids});
 
     // construct the communicator
-    auto C = communicator(R, D, label_resolution_map(global_sources), label_resolution_map({local_targets, mc_gids}), *g_context);
+    auto C = communicator(R, D, *g_context);
+    C.update_connections(R, D, label_resolution_map(global_sources), label_resolution_map({local_targets, mc_gids}));
     auto connections = C.connections();
 
     for (auto i: util::make_span(0, n_global)) {
-        for (unsigned j = 0; j < n_local; ++j) {
+        for (auto j: util::make_span(0, n_local)) {
             auto c = connections[i*n_local+j];
             EXPECT_EQ(i, c.source.gid);
             EXPECT_EQ(0u, c.source.index);
@@ -684,7 +685,8 @@ TEST(communicator, mini_network)
     auto global_sources = g_context->distributed->gather_cell_labels_and_gids({local_sources, gids});
 
     // construct the communicator
-    auto C = communicator(R, D, label_resolution_map(global_sources), label_resolution_map({local_targets, gids}), *g_context);
+    auto C = communicator(R, D, *g_context);
+    C.update_connections(R, D, label_resolution_map(global_sources), label_resolution_map({local_targets, gids}));
 
     // sort connections by source then target
     auto connections = C.connections();
-- 
GitLab