From 2b2f89c4466deda8a3cdb05a229284bf98a41774 Mon Sep 17 00:00:00 2001
From: Ben Cumming <louncharf@gmail.com>
Date: Wed, 10 May 2017 14:58:19 +0200
Subject: [PATCH] Feature/generic cell groups (#259)

Refactor model and recipe to build models that have different cell types.

    Refactor recipe::get_cell to return unique_any so that.
        All recipe definitions in tests and miniap had to be updated to use
        the new interface.
    Make a cell_group_factory that forwards arguments for building a
    cell group to the appropriate cell_group constructor.
    Refactor model to use generic cell types
        Constructor now delegates cell_group generation to the
        cell_group_factory.
        Add an implementation file model.cpp for model to reduce compilation
        times (by 2-7 seconds on my desktop).
    Refactor probe enumeration code in model and cell_group
        add interface to cell_group for querying enumeration of probes
        in a cell_group
        use this interface instead of directly computing enumeration in
        model constructor, which no longer has easy access to probe
        information.
---
 miniapp/miniapp.cpp                      |  16 +-
 miniapp/miniapp_recipes.cpp              |   6 +-
 src/CMakeLists.txt                       |   2 +
 src/cell.hpp                             |  23 +--
 src/cell_group.hpp                       |   5 +-
 src/cell_group_factory.cpp               |  39 ++++
 src/cell_group_factory.hpp               |  20 ++
 src/mc_cell_group.hpp                    |  34 ++++
 src/model.cpp                            | 227 +++++++++++++++++++++
 src/model.hpp                            | 243 ++---------------------
 src/probes.hpp                           |  22 ++
 src/recipe.hpp                           |   7 +-
 src/segment.hpp                          |  14 +-
 tests/unit/test_domain_decomposition.cpp |   4 +-
 14 files changed, 402 insertions(+), 260 deletions(-)
 create mode 100644 src/cell_group_factory.cpp
 create mode 100644 src/cell_group_factory.hpp
 create mode 100644 src/model.cpp
 create mode 100644 src/probes.hpp

diff --git a/miniapp/miniapp.cpp b/miniapp/miniapp.cpp
index fe06203a..9b6cbc0f 100644
--- a/miniapp/miniapp.cpp
+++ b/miniapp/miniapp.cpp
@@ -33,7 +33,7 @@ using sample_trace_type = sample_trace<time_type, double>;
 using file_export_type = io::exporter_spike_file<global_policy>;
 void banner();
 std::unique_ptr<recipe> make_recipe(const io::cl_options&, const probe_distribution&);
-std::unique_ptr<sample_trace_type> make_trace(cell_member_type probe_id, probe_spec probe);
+std::unique_ptr<sample_trace_type> make_trace(probe_record probe);
 using communicator_type = communication::communicator<communication::global_policy>;
 
 void write_trace_json(const sample_trace_type& trace, const std::string& prefix = "trace_");
@@ -117,7 +117,7 @@ int main(int argc, char** argv) {
                 continue;
             }
 
-            traces.push_back(make_trace(probe.id, probe.probe));
+            traces.push_back(make_trace(probe));
             m.attach_sampler(probe.id, make_trace_sampler(traces.back().get(), sample_dt));
         }
 
@@ -207,7 +207,7 @@ std::unique_ptr<recipe> make_recipe(const io::cl_options& options, const probe_d
     }
 }
 
-std::unique_ptr<sample_trace_type> make_trace(cell_member_type probe_id, probe_spec probe) {
+std::unique_ptr<sample_trace_type> make_trace(probe_record probe) {
     std::string name = "";
     std::string units = "";
 
@@ -224,7 +224,7 @@ std::unique_ptr<sample_trace_type> make_trace(cell_member_type probe_id, probe_s
     }
     name += probe.location.segment? "dend" : "soma";
 
-    return util::make_unique<sample_trace_type>(probe_id, name, units);
+    return util::make_unique<sample_trace_type>(probe.id, name, units);
 }
 
 void write_trace_json(const sample_trace_type& trace, const std::string& prefix) {
@@ -249,13 +249,17 @@ void write_trace_json(const sample_trace_type& trace, const std::string& prefix)
 }
 
 void report_compartment_stats(const recipe& rec) {
-std::size_t ncell = rec.num_cells();
+    std::size_t ncell = rec.num_cells();
     std::size_t ncomp_total = 0;
     std::size_t ncomp_min = std::numeric_limits<std::size_t>::max();
     std::size_t ncomp_max = 0;
 
     for (std::size_t i = 0; i<ncell; ++i) {
-        std::size_t ncomp = rec.get_cell(i).num_compartments();
+        std::size_t ncomp = 0;
+        auto c = rec.get_cell(i);
+        if (auto ptr = util::any_cast<cell>(&c)) {
+            ncomp = ptr->num_compartments();
+        }
         ncomp_total += ncomp;
         ncomp_min = std::min(ncomp_min, ncomp);
         ncomp_max = std::max(ncomp_max, ncomp);
diff --git a/miniapp/miniapp_recipes.cpp b/miniapp/miniapp_recipes.cpp
index 0aa559a5..32045e3a 100644
--- a/miniapp/miniapp_recipes.cpp
+++ b/miniapp/miniapp_recipes.cpp
@@ -6,6 +6,7 @@
 #include <cell.hpp>
 #include <morphology.hpp>
 #include <util/debug.hpp>
+#include <util/unique_any.hpp>
 
 #include "miniapp_recipes.hpp"
 #include "morphology_pool.hpp"
@@ -80,7 +81,7 @@ public:
 
     cell_size_type num_cells() const override { return ncell_; }
 
-    cell get_cell(cell_gid_type i) const override {
+    util::unique_any get_cell(cell_gid_type i) const override {
         auto gen = std::mt19937(i); // TODO: replace this with hashing generator...
 
         auto cc = get_cell_count_info(i);
@@ -108,7 +109,8 @@ public:
             }
         }
         EXPECTS(cell.probes().size()==cc.num_probes);
-        return cell;
+
+        return util::unique_any(std::move(cell));
     }
 
     cell_kind get_cell_kind(cell_gid_type) const override {
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index c42d2d13..469086ad 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -1,7 +1,9 @@
 set(BASE_SOURCES
     common_types_io.cpp
     cell.cpp
+    cell_group_factory.cpp
     event_binner.cpp
+    model.cpp
     morphology.cpp
     parameter_list.cpp
     profiling/memory_meter.cpp
diff --git a/src/cell.hpp b/src/cell.hpp
index 13fb3a72..fbcbe504 100644
--- a/src/cell.hpp
+++ b/src/cell.hpp
@@ -8,6 +8,7 @@
 #include <common_types.hpp>
 #include <cell_tree.hpp>
 #include <morphology.hpp>
+#include <probes.hpp>
 #include <segment.hpp>
 #include <stimulus.hpp>
 #include <util/debug.hpp>
@@ -25,29 +26,11 @@ struct compartment_model {
     std::vector<cell_tree::int_type> segment_index;
 };
 
-struct segment_location {
-    segment_location(cell_lid_type s, double l)
-    : segment(s), position(l)
-    {
-        EXPECTS(position>=0. && position<=1.);
-    }
-    friend bool operator==(segment_location l, segment_location r) {
-        return l.segment==r.segment && l.position==r.position;
-    }
-    cell_lid_type segment;
-    double position;
-};
-
 int find_compartment_index(
     segment_location const& location,
     compartment_model const& graph
 );
 
-enum class probeKind {
-    membrane_voltage,
-    membrane_current
-};
-
 struct probe_spec {
     segment_location location;
     probeKind kind;
@@ -207,7 +190,9 @@ public:
     }
 
     const std::vector<probe_spec>&
-    probes() const { return probes_; }
+    probes() const {
+        return probes_;
+    }
 
 private:
     // storage for connections
diff --git a/src/cell_group.hpp b/src/cell_group.hpp
index 35a67c29..1473411b 100644
--- a/src/cell_group.hpp
+++ b/src/cell_group.hpp
@@ -3,13 +3,13 @@
 #include <memory>
 #include <vector>
 
+#include <cell.hpp>
 #include <common_types.hpp>
 #include <event_binner.hpp>
 #include <event_queue.hpp>
+#include <probes.hpp>
 #include <sampler_function.hpp>
 #include <spike.hpp>
-#include <util/optional.hpp>
-#include <util/make_unique.hpp>
 
 namespace nest {
 namespace mc {
@@ -27,6 +27,7 @@ public:
     virtual const std::vector<spike>& spikes() const = 0;
     virtual void clear_spikes() = 0;
     virtual void add_sampler(cell_member_type probe_id, sampler_function s, time_type start_time = 0) = 0;
+    virtual std::vector<probe_record> probes() const = 0;
 };
 
 using cell_group_ptr = std::unique_ptr<cell_group>;
diff --git a/src/cell_group_factory.cpp b/src/cell_group_factory.cpp
new file mode 100644
index 00000000..43abfaef
--- /dev/null
+++ b/src/cell_group_factory.cpp
@@ -0,0 +1,39 @@
+#include <vector>
+
+#include <backends.hpp>
+#include <cell_group.hpp>
+#include <fvm_multicell.hpp>
+#include <mc_cell_group.hpp>
+#include <util/unique_any.hpp>
+
+namespace nest {
+namespace mc {
+
+using gpu_fvm_cell = mc_cell_group<fvm::fvm_multicell<gpu::backend>>;
+using mc_fvm_cell = mc_cell_group<fvm::fvm_multicell<multicore::backend>>;
+
+cell_group_ptr cell_group_factory(
+        cell_kind kind,
+        cell_gid_type first_gid,
+        const std::vector<util::unique_any>& cells,
+        backend_policy backend)
+{
+    if (backend==backend_policy::prefer_gpu) {
+        switch (kind) {
+        case cell_kind::cable1d_neuron:
+            return make_cell_group<gpu_fvm_cell>(first_gid, cells);
+        default:
+            throw std::runtime_error("unknown cell kind");
+        }
+    }
+
+    switch (kind) {
+    case cell_kind::cable1d_neuron:
+        return make_cell_group<mc_fvm_cell>(first_gid, cells);
+    default:
+        throw std::runtime_error("unknown cell kind");
+    }
+}
+
+} // namespace mc
+} // namespace nest
diff --git a/src/cell_group_factory.hpp b/src/cell_group_factory.hpp
new file mode 100644
index 00000000..e5c47185
--- /dev/null
+++ b/src/cell_group_factory.hpp
@@ -0,0 +1,20 @@
+#pragma once
+
+#include <vector>
+
+#include <backends.hpp>
+#include <cell_group.hpp>
+#include <util/unique_any.hpp>
+
+namespace nest {
+namespace mc {
+
+// Helper factory for building cell groups
+cell_group_ptr cell_group_factory(
+    cell_kind kind,
+    cell_gid_type first_gid,
+    const std::vector<util::unique_any>& cells,
+    backend_policy backend);
+
+} // namespace mc
+} // namespace nest
diff --git a/src/mc_cell_group.hpp b/src/mc_cell_group.hpp
index dec775f8..2ed838e8 100644
--- a/src/mc_cell_group.hpp
+++ b/src/mc_cell_group.hpp
@@ -17,6 +17,7 @@
 #include <util/debug.hpp>
 #include <util/partition.hpp>
 #include <util/range.hpp>
+#include <util/unique_any.hpp>
 
 #include <profiling/profiler.hpp>
 
@@ -59,8 +60,34 @@ public:
             ++source_gid;
         }
         EXPECTS(spike_sources_.size()==n_detectors);
+
+        // Create the enumeration of probes attached to cells in this cell group
+        probes_.reserve(n_probes);
+        for (auto i: util::make_span(0, cells.size())){
+            const cell_gid_type probe_gid = gid_base_ + i;
+            const auto probes_on_cell = cells[i].probes();
+            for (cell_lid_type lid: util::make_span(0, probes_on_cell.size())) {
+                // get the unique global identifier of this probe
+                cell_member_type id{probe_gid, lid};
+
+                // get the location and kind information of the probe
+                const auto p = probes_on_cell[lid];
+
+                // record the combined identifier and probe details
+                probes_.push_back(probe_record{id, p.location, p.kind});
+            }
+        }
     }
 
+    mc_cell_group(cell_gid_type first_gid, const std::vector<util::unique_any>& cells):
+        mc_cell_group(
+            first_gid,
+            util::transform_view(
+                cells,
+                [](const util::unique_any& c) -> const cell& {return util::any_cast<const cell&>(c);})
+        )
+    {}
+
     cell_kind get_cell_kind() const override {
         return cell_kind::cable1d_neuron;
     }
@@ -177,6 +204,10 @@ public:
         sample_events_.push({sampler_index, start_time});
     }
 
+    std::vector<probe_record> probes() const override {
+        return probes_;
+    }
+
 private:
     // gid of first cell in group.
     cell_gid_type gid_base_;
@@ -222,6 +253,9 @@ private:
     // Lookup table for target ids -> local target handle indices.
     std::vector<std::size_t> target_handle_divisions_;
 
+    // Enumeration of the probes that are attached to the cells in the cell group
+    std::vector<probe_record> probes_;
+
     // Build handle index lookup tables.
     template <typename Cells>
     void build_handle_partitions(const Cells& cells) {
diff --git a/src/model.cpp b/src/model.cpp
new file mode 100644
index 00000000..9d0fb262
--- /dev/null
+++ b/src/model.cpp
@@ -0,0 +1,227 @@
+#include <model.hpp>
+
+#include <vector>
+
+#include <backends.hpp>
+#include <cell_group.hpp>
+#include <cell_group_factory.hpp>
+#include <domain_decomposition.hpp>
+#include <recipe.hpp>
+#include <util/span.hpp>
+#include <util/unique_any.hpp>
+#include <profiling/profiler.hpp>
+
+namespace nest {
+namespace mc {
+
+model::model(const recipe& rec, const domain_decomposition& decomp):
+    domain_(decomp)
+{
+    // set up communicator based on partition
+    communicator_ = communicator_type(domain_.gid_group_partition());
+
+    // generate the cell groups in parallel, with one task per cell group
+    cell_groups_.resize(domain_.num_local_groups());
+
+    // thread safe vector for constructing the list of probes in parallel
+    threading::parallel_vector<probe_record> probe_tmp;
+
+    threading::parallel_for::apply(0, cell_groups_.size(),
+        [&](cell_gid_type i) {
+            PE("setup", "cells");
+
+            auto group = domain_.get_group(i);
+            std::vector<util::unique_any> cells(group.end-group.begin);
+
+            for (auto gid: util::make_span(group.begin, group.end)) {
+                auto i = gid-group.begin;
+                cells[i] = rec.get_cell(gid);
+            }
+
+            cell_groups_[i] = cell_group_factory(
+                    group.kind, group.begin, cells, domain_.backend());
+            PL(2);
+        });
+
+    // store probes
+    for (const auto& c: cell_groups_) {
+        util::append(probes_, c->probes());
+    }
+
+    // generate the network connections
+    for (cell_gid_type i: util::make_span(domain_.cell_begin(), domain_.cell_end())) {
+        for (const auto& cc: rec.connections_on(i)) {
+            connection conn{cc.source, cc.dest, cc.weight, cc.delay};
+            communicator_.add_connection(conn);
+        }
+    }
+    communicator_.construct();
+
+    // Allocate an empty queue buffer for each cell group
+    // These must be set initially to ensure that a queue is available for each
+    // cell group for the first time step.
+    current_events().resize(num_groups());
+    future_events().resize(num_groups());
+}
+
+void model::reset() {
+    t_ = 0.;
+    for (auto& group: cell_groups_) {
+        group->reset();
+    }
+
+    communicator_.reset();
+
+    for(auto& q : current_events()) {
+        q.clear();
+    }
+    for(auto& q : future_events()) {
+        q.clear();
+    }
+
+    current_spikes().clear();
+    previous_spikes().clear();
+
+    util::profilers_restart();
+}
+
+time_type model::run(time_type tfinal, time_type dt) {
+    // Calculate the size of the largest possible time integration interval
+    // before communication of spikes is required.
+    // If spike exchange and cell update are serialized, this is the
+    // minimum delay of the network, however we use half this period
+    // to overlap communication and computation.
+    time_type t_interval = communicator_.min_delay()/2;
+
+    time_type tuntil;
+
+    // task that updates cell state in parallel.
+    auto update_cells = [&] () {
+        threading::parallel_for::apply(
+            0u, cell_groups_.size(),
+             [&](unsigned i) {
+                auto &group = cell_groups_[i];
+
+                PE("stepping","events");
+                group->enqueue_events(current_events()[i]);
+                PL();
+
+                group->advance(tuntil, dt);
+
+                PE("events");
+                current_spikes().insert(group->spikes());
+                group->clear_spikes();
+                PL(2);
+            });
+    };
+
+    // task that performs spike exchange with the spikes generated in
+    // the previous integration period, generating the postsynaptic
+    // events that must be delivered at the start of the next
+    // integration period at the latest.
+    auto exchange = [&] () {
+        PE("stepping", "communication");
+
+        PE("exchange");
+        auto local_spikes = previous_spikes().gather();
+        auto global_spikes = communicator_.exchange(local_spikes);
+        PL();
+
+        PE("spike output");
+        local_export_callback_(local_spikes);
+        global_export_callback_(global_spikes.values());
+        PL();
+
+        PE("events");
+        future_events() = communicator_.make_event_queues(global_spikes);
+        PL();
+
+        PL(2);
+    };
+
+    while (t_<tfinal) {
+        tuntil = std::min(t_+t_interval, tfinal);
+
+        event_queues_.exchange();
+        local_spikes_.exchange();
+
+        // empty the spike buffers for the current integration period.
+        // these buffers will store the new spikes generated in update_cells.
+        current_spikes().clear();
+
+        // run the tasks, overlapping if the threading model and number of
+        // available threads permits it.
+        threading::task_group g;
+        g.run(exchange);
+        g.run(update_cells);
+        g.wait();
+
+        t_ = tuntil;
+    }
+
+    // Run the exchange one last time to ensure that all spikes are output
+    // to file.
+    event_queues_.exchange();
+    local_spikes_.exchange();
+    exchange();
+
+    return t_;
+}
+
+// only thread safe if called outside the run() method
+void model::add_artificial_spike(cell_member_type source) {
+    add_artificial_spike(source, t_);
+}
+
+// only thread safe if called outside the run() method
+void model::add_artificial_spike(cell_member_type source, time_type tspike) {
+    if (domain_.is_local_gid(source.gid)) {
+        current_spikes().get().push_back({source, tspike});
+    }
+}
+
+void model::attach_sampler(cell_member_type probe_id, sampler_function f, time_type tfrom) {
+    const auto idx = domain_.local_group_from_gid(probe_id.gid);
+
+    // only attach samplers for local cells
+    if (idx) {
+        cell_groups_[*idx]->add_sampler(probe_id, f, tfrom);
+    }
+}
+
+const std::vector<probe_record>& model::probes() const {
+    return probes_;
+}
+
+std::size_t model::num_spikes() const {
+    return communicator_.num_spikes();
+}
+
+std::size_t model::num_groups() const {
+    return cell_groups_.size();
+}
+
+std::size_t model::num_cells() const {
+    return domain_.num_local_cells();
+}
+
+void model::set_binning_policy(binning_kind policy, time_type bin_interval) {
+    for (auto& group: cell_groups_) {
+        group->set_binning_policy(policy, bin_interval);
+    }
+}
+
+cell_group& model::group(int i) {
+    return *cell_groups_[i];
+}
+
+void model::set_global_spike_callback(spike_export_function export_callback) {
+    global_export_callback_ = export_callback;
+}
+
+void model::set_local_spike_callback(spike_export_function export_callback) {
+    local_export_callback_ = export_callback;
+}
+
+} // namespace mc
+} // namespace nest
diff --git a/src/model.hpp b/src/model.hpp
index cef1da4d..3d81d196 100644
--- a/src/model.hpp
+++ b/src/model.hpp
@@ -1,275 +1,68 @@
 #pragma once
 
-#include <memory>
 #include <vector>
 
-#include <cstdlib>
-
 #include <backends.hpp>
-#include <fvm_multicell.hpp>
-
-#include <common_types.hpp>
-#include <cell.hpp>
 #include <cell_group.hpp>
+#include <common_types.hpp>
+#include <domain_decomposition.hpp>
 #include <communication/communicator.hpp>
 #include <communication/global_policy.hpp>
-#include <domain_decomposition.hpp>
-#include <mc_cell_group.hpp>
-#include <profiling/profiler.hpp>
 #include <recipe.hpp>
 #include <sampler_function.hpp>
 #include <thread_private_spike_store.hpp>
-#include <threading/threading.hpp>
-#include <trace_sampler.hpp>
 #include <util/nop.hpp>
-#include <util/partition.hpp>
-#include <util/range.hpp>
+#include <util/unique_any.hpp>
 
 namespace nest {
 namespace mc {
 
-using gpu_lowered_cell =
-    mc_cell_group<fvm::fvm_multicell<gpu::backend>>;
-
-using multicore_lowered_cell =
-    mc_cell_group<fvm::fvm_multicell<multicore::backend>>;
-
 class model {
 public:
     using communicator_type = communication::communicator<communication::global_policy>;
     using spike_export_function = std::function<void(const std::vector<spike>&)>;
 
-    struct probe_record {
-        cell_member_type id;
-        probe_spec probe;
-    };
-
-    model(const recipe& rec, const domain_decomposition& decomp):
-        domain_(decomp)
-    {
-        // set up communicator based on partition
-        communicator_ = communicator_type(domain_.gid_group_partition());
-
-        // generate the cell groups in parallel, with one task per cell group
-        cell_groups_.resize(domain_.num_local_groups());
-
-        // thread safe vector for constructing the list of probes in parallel
-        threading::parallel_vector<probe_record> probe_tmp;
-
-        threading::parallel_for::apply(0, cell_groups_.size(),
-            [&](cell_gid_type i) {
-                PE("setup", "cells");
-
-                auto gids = domain_.get_group(i);
-                std::vector<cell> cells(gids.end-gids.begin);
-
-                for (auto gid: util::make_span(gids.begin, gids.end)) {
-                    auto i = gid-gids.begin;
-                    cells[i] = rec.get_cell(gid);
-
-                    cell_lid_type j = 0;
-                    for (const auto& probe: cells[i].probes()) {
-                        cell_member_type probe_id{gid, j++};
-                        probe_tmp.push_back({probe_id, probe});
-                    }
-                }
-
-                if (domain_.backend()==backend_policy::use_multicore) {
-                    cell_groups_[i] = make_cell_group<multicore_lowered_cell>(gids.begin, cells);
-                }
-                else {
-                    cell_groups_[i] = make_cell_group<gpu_lowered_cell>(gids.begin, cells);
-                }
-                PL(2);
-            });
-
-        // store probes
-        probes_.assign(probe_tmp.begin(), probe_tmp.end());
-
-        // generate the network connections
-        for (cell_gid_type i: util::make_span(domain_.cell_begin(), domain_.cell_end())) {
-            for (const auto& cc: rec.connections_on(i)) {
-                connection conn{cc.source, cc.dest, cc.weight, cc.delay};
-                communicator_.add_connection(conn);
-            }
-        }
-        communicator_.construct();
-
-        // Allocate an empty queue buffer for each cell group
-        // These must be set initially to ensure that a queue is available for each
-        // cell group for the first time step.
-        current_events().resize(num_groups());
-        future_events().resize(num_groups());
-    }
-
-    void reset() {
-        t_ = 0.;
-        for (auto& group: cell_groups_) {
-            group->reset();
-        }
-
-        communicator_.reset();
+    model(const recipe& rec, const domain_decomposition& decomp);
 
-        for(auto& q : current_events()) {
-            q.clear();
-        }
-        for(auto& q : future_events()) {
-            q.clear();
-        }
+    void reset();
 
-        current_spikes().clear();
-        previous_spikes().clear();
-
-        util::profilers_restart();
-    }
-
-    time_type run(time_type tfinal, time_type dt) {
-        // Calculate the size of the largest possible time integration interval
-        // before communication of spikes is required.
-        // If spike exchange and cell update are serialized, this is the
-        // minimum delay of the network, however we use half this period
-        // to overlap communication and computation.
-        time_type t_interval = communicator_.min_delay()/2;
-
-        time_type tuntil;
-
-        // task that updates cell state in parallel.
-        auto update_cells = [&] () {
-            threading::parallel_for::apply(
-                0u, cell_groups_.size(),
-                 [&](unsigned i) {
-                    auto &group = cell_groups_[i];
-
-                    PE("stepping","events");
-                    group->enqueue_events(current_events()[i]);
-                    PL();
-
-                    group->advance(tuntil, dt);
-
-                    PE("events");
-                    current_spikes().insert(group->spikes());
-                    group->clear_spikes();
-                    PL(2);
-                });
-        };
-
-        // task that performs spike exchange with the spikes generated in
-        // the previous integration period, generating the postsynaptic
-        // events that must be delivered at the start of the next
-        // integration period at the latest.
-        auto exchange = [&] () {
-            PE("stepping", "communication");
-
-            PE("exchange");
-            auto local_spikes = previous_spikes().gather();
-            auto global_spikes = communicator_.exchange(local_spikes);
-            PL();
-
-            PE("spike output");
-            local_export_callback_(local_spikes);
-            global_export_callback_(global_spikes.values());
-            PL();
-
-            PE("events");
-            future_events() = communicator_.make_event_queues(global_spikes);
-            PL();
-
-            PL(2);
-        };
-
-        while (t_<tfinal) {
-            tuntil = std::min(t_+t_interval, tfinal);
-
-            event_queues_.exchange();
-            local_spikes_.exchange();
-
-            // empty the spike buffers for the current integration period.
-            // these buffers will store the new spikes generated in update_cells.
-            current_spikes().clear();
-
-            // run the tasks, overlapping if the threading model and number of
-            // available threads permits it.
-            threading::task_group g;
-            g.run(exchange);
-            g.run(update_cells);
-            g.wait();
-
-            t_ = tuntil;
-        }
-
-        // Run the exchange one last time to ensure that all spikes are output
-        // to file.
-        event_queues_.exchange();
-        local_spikes_.exchange();
-        exchange();
-
-        return t_;
-    }
+    time_type run(time_type tfinal, time_type dt);
 
     // only thread safe if called outside the run() method
-    void add_artificial_spike(cell_member_type source) {
-        add_artificial_spike(source, t_);
-    }
+    void add_artificial_spike(cell_member_type source);
 
     // only thread safe if called outside the run() method
-    void add_artificial_spike(cell_member_type source, time_type tspike) {
-        if (domain_.is_local_gid(source.gid)) {
-            current_spikes().get().push_back({source, tspike});
-        }
-    }
-
-    void attach_sampler(cell_member_type probe_id, sampler_function f, time_type tfrom = 0) {
-        const auto idx = domain_.local_group_from_gid(probe_id.gid);
+    void add_artificial_spike(cell_member_type source, time_type tspike);
 
-        // only attach samplers for local cells
-        if (idx) {
-            cell_groups_[*idx]->add_sampler(probe_id, f, tfrom);
-        }
-    }
+    void attach_sampler(cell_member_type probe_id, sampler_function f, time_type tfrom = 0);
 
-    const std::vector<probe_record>& probes() const { return probes_; }
+    const std::vector<probe_record>& probes() const;
 
-    std::size_t num_spikes() const {
-        return communicator_.num_spikes();
-    }
+    std::size_t num_spikes() const;
 
-    std::size_t num_groups() const {
-        return cell_groups_.size();
-    }
+    std::size_t num_groups() const;
 
-    std::size_t num_cells() const {
-        return domain_.num_local_cells();
-    }
+    std::size_t num_cells() const;
 
     // Set event binning policy on all our groups.
-    void set_binning_policy(binning_kind policy, time_type bin_interval) {
-        for (auto& group: cell_groups_) {
-            group->set_binning_policy(policy, bin_interval);
-        }
-    }
+    void set_binning_policy(binning_kind policy, time_type bin_interval);
 
     // access cell_group directly
-    cell_group& group(int i) {
-        return *cell_groups_[i];
-    }
+    cell_group& group(int i);
 
     // register a callback that will perform a export of the global
     // spike vector
-    void set_global_spike_callback(spike_export_function export_callback) {
-        global_export_callback_ = export_callback;
-    }
+    void set_global_spike_callback(spike_export_function export_callback);
 
     // register a callback that will perform a export of the rank local
     // spike vector
-    void set_local_spike_callback(spike_export_function export_callback) {
-        local_export_callback_ = export_callback;
-    }
+    void set_local_spike_callback(spike_export_function export_callback);
 
 private:
     const domain_decomposition &domain_;
 
     time_type t_ = 0.;
-    std::vector<std::unique_ptr<cell_group>> cell_groups_;
+    std::vector<cell_group_ptr> cell_groups_;
     communicator_type communicator_;
     std::vector<probe_record> probes_;
 
diff --git a/src/probes.hpp b/src/probes.hpp
new file mode 100644
index 00000000..36b27c3b
--- /dev/null
+++ b/src/probes.hpp
@@ -0,0 +1,22 @@
+#pragma once
+
+#include <cell.hpp>
+#include <morphology.hpp>
+#include <segment.hpp>
+
+namespace nest {
+namespace mc {
+
+enum class probeKind {
+    membrane_voltage,
+    membrane_current
+};
+
+struct probe_record {
+    cell_member_type id;
+    segment_location location;
+    probeKind kind;
+};
+
+} // namespace mc
+} // namespace nest
diff --git a/src/recipe.hpp b/src/recipe.hpp
index 683ac1d6..496543be 100644
--- a/src/recipe.hpp
+++ b/src/recipe.hpp
@@ -5,6 +5,7 @@
 #include <stdexcept>
 
 #include <cell.hpp>
+#include <util/unique_any.hpp>
 
 namespace nest {
 namespace mc {
@@ -47,7 +48,7 @@ class recipe {
 public:
     virtual cell_size_type num_cells() const =0;
 
-    virtual cell get_cell(cell_gid_type) const =0;
+    virtual util::unique_any get_cell(cell_gid_type) const =0;
     virtual cell_kind get_cell_kind(cell_gid_type) const = 0;
 
     virtual cell_count_info get_cell_count_info(cell_gid_type) const =0;
@@ -69,8 +70,8 @@ public:
         return 1;
     }
 
-    cell get_cell(cell_gid_type) const override {
-        return cell(clone_cell, cell_);
+    util::unique_any get_cell(cell_gid_type) const override {
+        return util::unique_any(cell(clone_cell, cell_));
     }
 
     cell_kind get_cell_kind(cell_gid_type) const override {
diff --git a/src/segment.hpp b/src/segment.hpp
index 53a486c2..e208ec69 100644
--- a/src/segment.hpp
+++ b/src/segment.hpp
@@ -475,6 +475,18 @@ DivCompClass div_compartments(const cable_segment* cable) {
     return DivCompClass(cable->num_compartments(), cable->radii(), cable->lengths());
 }
 
+struct segment_location {
+    segment_location(cell_lid_type s, double l):
+        segment(s), position(l)
+    {
+        EXPECTS(position>=0. && position<=1.);
+    }
+    friend bool operator==(segment_location l, segment_location r) {
+        return l.segment==r.segment && l.position==r.position;
+    }
+    cell_lid_type segment;
+    double position;
+};
+
 } // namespace mc
 } // namespace nest
-
diff --git a/tests/unit/test_domain_decomposition.cpp b/tests/unit/test_domain_decomposition.cpp
index ab5d3fe9..a57f83a5 100644
--- a/tests/unit/test_domain_decomposition.cpp
+++ b/tests/unit/test_domain_decomposition.cpp
@@ -17,8 +17,8 @@ public:
         return size_;
     }
 
-    cell get_cell(cell_gid_type) const override {
-        return cell();
+    util::unique_any get_cell(cell_gid_type) const override {
+        return {};
     }
     cell_kind get_cell_kind(cell_gid_type) const override {
         return cell_kind::cable1d_neuron;
-- 
GitLab