From 2b2f89c4466deda8a3cdb05a229284bf98a41774 Mon Sep 17 00:00:00 2001
From: Ben Cumming <louncharf@gmail.com>
Date: Wed, 10 May 2017 14:58:19 +0200
Subject: [PATCH] Feature/generic cell groups (#259)

Refactor model and recipe to build models that have different cell types.

Refactor recipe::get_cell to return unique_any, so that a recipe can
describe cells of more than one kind. All recipe definitions in the tests
and the miniapp had to be updated to use the new interface.

Make a cell_group_factory that forwards arguments for building a cell
group to the appropriate cell_group constructor.

Refactor model to use generic cell types: the constructor now delegates
cell_group generation to the cell_group_factory. Add an implementation
file model.cpp for model to reduce compilation times (by 2-7 seconds on
my desktop).

Refactor the probe enumeration code in model and cell_group: add an
interface to cell_group for querying the enumeration of probes in a
cell_group, and use this interface instead of computing the enumeration
directly in the model constructor, which no longer has easy access to
probe information.
---
 miniapp/miniapp.cpp                      |  16 +-
 miniapp/miniapp_recipes.cpp              |   6 +-
 src/CMakeLists.txt                       |   2 +
 src/cell.hpp                             |  23 +--
 src/cell_group.hpp                       |   5 +-
 src/cell_group_factory.cpp               |  39 ++++
 src/cell_group_factory.hpp               |  20 ++
 src/mc_cell_group.hpp                    |  34 ++++
 src/model.cpp                            | 227 +++++++++++++++++++++
 src/model.hpp                            | 243 ++---------------------
 src/probes.hpp                           |  22 ++
 src/recipe.hpp                           |   7 +-
 src/segment.hpp                          |  14 +-
 tests/unit/test_domain_decomposition.cpp |   4 +-
 14 files changed, 402 insertions(+), 260 deletions(-)
 create mode 100644 src/cell_group_factory.cpp
 create mode 100644 src/cell_group_factory.hpp
 create mode 100644 src/model.cpp
 create mode 100644 src/probes.hpp

diff --git a/miniapp/miniapp.cpp b/miniapp/miniapp.cpp
index fe06203a..9b6cbc0f 100644
--- a/miniapp/miniapp.cpp
+++ b/miniapp/miniapp.cpp
@@ -33,7 +33,7 @@
 using sample_trace_type = sample_trace<time_type, double>;
 using file_export_type = io::exporter_spike_file<global_policy>;
 void banner();
 std::unique_ptr<recipe> make_recipe(const io::cl_options&, const probe_distribution&);
-std::unique_ptr<sample_trace_type> make_trace(cell_member_type probe_id, probe_spec probe);
+std::unique_ptr<sample_trace_type> make_trace(probe_record probe);
 using communicator_type = communication::communicator<communication::global_policy>;
 void write_trace_json(const sample_trace_type& trace, const std::string& prefix = "trace_");
@@ -117,7 +117,7 @@ int main(int argc, char** argv) {
             continue;
         }

-        traces.push_back(make_trace(probe.id, probe.probe));
+        traces.push_back(make_trace(probe));
         m.attach_sampler(probe.id, make_trace_sampler(traces.back().get(), sample_dt));
     }

@@ -207,7 +207,7 @@ std::unique_ptr<recipe> make_recipe(const io::cl_options& options, const probe_d
     }
 }

-std::unique_ptr<sample_trace_type> make_trace(cell_member_type probe_id, probe_spec probe) {
+std::unique_ptr<sample_trace_type> make_trace(probe_record probe) {
     std::string name = "";
     std::string units = "";

@@ -224,7 +224,7 @@ std::unique_ptr<sample_trace_type> make_trace(cell_member_type probe_id, probe_s
     }

     name += probe.location.segment? "dend" : "soma";
-    return util::make_unique<sample_trace_type>(probe_id, name, units);
+    return util::make_unique<sample_trace_type>(probe.id, name, units);
 }

 void write_trace_json(const sample_trace_type& trace, const std::string& prefix) {
@@ -249,13 +249,17 @@ void write_trace_json(const sample_trace_type& trace, const std::string& prefix)
 }

 void report_compartment_stats(const recipe& rec) {
-std::size_t ncell = rec.num_cells();
+    std::size_t ncell = rec.num_cells();
     std::size_t ncomp_total = 0;
     std::size_t ncomp_min = std::numeric_limits<std::size_t>::max();
     std::size_t ncomp_max = 0;

     for (std::size_t i = 0; i<ncell; ++i) {
-        std::size_t ncomp = rec.get_cell(i).num_compartments();
+        std::size_t ncomp = 0;
+        auto c = rec.get_cell(i);
+        if (auto ptr = util::any_cast<cell>(&c)) {
+            ncomp = ptr->num_compartments();
+        }
         ncomp_total += ncomp;
         ncomp_min = std::min(ncomp_min, ncomp);
         ncomp_max = std::max(ncomp_max, ncomp);
diff --git a/miniapp/miniapp_recipes.cpp b/miniapp/miniapp_recipes.cpp
index 0aa559a5..32045e3a 100644
--- a/miniapp/miniapp_recipes.cpp
+++ b/miniapp/miniapp_recipes.cpp
@@ -6,6 +6,7 @@
 #include <cell.hpp>
 #include <morphology.hpp>
 #include <util/debug.hpp>
+#include <util/unique_any.hpp>

 #include "miniapp_recipes.hpp"
 #include "morphology_pool.hpp"
@@ -80,7 +81,7 @@ public:
     cell_size_type num_cells() const override { return ncell_; }

-    cell get_cell(cell_gid_type i) const override {
+    util::unique_any get_cell(cell_gid_type i) const override {
         auto gen = std::mt19937(i); // TODO: replace this with hashing generator...

         auto cc = get_cell_count_info(i);
@@ -108,7 +109,8 @@ public:
             }
         }
         EXPECTS(cell.probes().size()==cc.num_probes);
-        return cell;
+
+        return util::unique_any(std::move(cell));
     }

     cell_kind get_cell_kind(cell_gid_type) const override {
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index c42d2d13..469086ad 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -1,7 +1,9 @@
 set(BASE_SOURCES
     common_types_io.cpp
     cell.cpp
+    cell_group_factory.cpp
     event_binner.cpp
+    model.cpp
     morphology.cpp
     parameter_list.cpp
     profiling/memory_meter.cpp
diff --git a/src/cell.hpp b/src/cell.hpp
index 13fb3a72..fbcbe504 100644
--- a/src/cell.hpp
+++ b/src/cell.hpp
@@ -8,6 +8,7 @@
 #include <common_types.hpp>
 #include <cell_tree.hpp>
 #include <morphology.hpp>
+#include <probes.hpp>
 #include <segment.hpp>
 #include <stimulus.hpp>
 #include <util/debug.hpp>
@@ -25,29 +26,11 @@ struct compartment_model {
     std::vector<cell_tree::int_type> segment_index;
 };

-struct segment_location {
-    segment_location(cell_lid_type s, double l)
-    : segment(s), position(l)
-    {
-        EXPECTS(position>=0. && position<=1.);
-    }
-    friend bool operator==(segment_location l, segment_location r) {
-        return l.segment==r.segment && l.position==r.position;
-    }
-    cell_lid_type segment;
-    double position;
-};
-
 int find_compartment_index(
     segment_location const& location,
     compartment_model const& graph
 );

-enum class probeKind {
-    membrane_voltage,
-    membrane_current
-};
-
 struct probe_spec {
     segment_location location;
     probeKind kind;
@@ -207,7 +190,9 @@ public:
     }

     const std::vector<probe_spec>&
-    probes() const { return probes_; }
+    probes() const {
+        return probes_;
+    }

 private:
     // storage for connections
diff --git a/src/cell_group.hpp b/src/cell_group.hpp
index 35a67c29..1473411b 100644
--- a/src/cell_group.hpp
+++ b/src/cell_group.hpp
@@ -3,13 +3,13 @@
 #include <memory>
 #include <vector>

+#include <cell.hpp>
 #include <common_types.hpp>
 #include <event_binner.hpp>
 #include <event_queue.hpp>
+#include <probes.hpp>
 #include <sampler_function.hpp>
 #include <spike.hpp>
-#include <util/optional.hpp>
-#include <util/make_unique.hpp>

 namespace nest {
 namespace mc {
@@ -27,6 +27,7 @@ public:
     virtual const std::vector<spike>& spikes() const = 0;
     virtual void clear_spikes() = 0;
     virtual void add_sampler(cell_member_type probe_id, sampler_function s, time_type start_time = 0) = 0;
+    virtual std::vector<probe_record> probes() const = 0;
 };

 using cell_group_ptr = std::unique_ptr<cell_group>;
diff --git a/src/cell_group_factory.cpp b/src/cell_group_factory.cpp
new file mode 100644
index 00000000..43abfaef
--- /dev/null
+++ b/src/cell_group_factory.cpp
@@ -0,0 +1,39 @@
+#include <vector>
+
+#include <backends.hpp>
+#include <cell_group.hpp>
+#include <fvm_multicell.hpp>
+#include <mc_cell_group.hpp>
+#include <util/unique_any.hpp>
+
+namespace nest {
+namespace mc {
+
+using gpu_fvm_cell = mc_cell_group<fvm::fvm_multicell<gpu::backend>>;
+using mc_fvm_cell = mc_cell_group<fvm::fvm_multicell<multicore::backend>>;
+
+cell_group_ptr cell_group_factory(
+    cell_kind kind,
+    cell_gid_type first_gid,
+    const std::vector<util::unique_any>& cells,
+    backend_policy backend)
+{
+    if (backend==backend_policy::prefer_gpu) {
+        switch (kind) {
+        case cell_kind::cable1d_neuron:
+            return make_cell_group<gpu_fvm_cell>(first_gid, cells);
+        default:
+            throw std::runtime_error("unknown cell kind");
+        }
+    }
+
+    switch (kind) {
+    case cell_kind::cable1d_neuron:
+        return make_cell_group<mc_fvm_cell>(first_gid, cells);
+    default:
+        throw std::runtime_error("unknown cell kind");
+    }
+}
+
+} // namespace mc
+} // namespace nest
diff --git a/src/cell_group_factory.hpp b/src/cell_group_factory.hpp
new file mode 100644
index 00000000..e5c47185
--- /dev/null
+++ b/src/cell_group_factory.hpp
@@ -0,0 +1,20 @@
+#pragma once
+
+#include <vector>
+
+#include <backends.hpp>
+#include <cell_group.hpp>
+#include <util/unique_any.hpp>
+
+namespace nest {
+namespace mc {
+
+// Helper factory for building cell groups
+cell_group_ptr cell_group_factory(
+    cell_kind kind,
+    cell_gid_type first_gid,
+    const std::vector<util::unique_any>& cells,
+    backend_policy backend);
+
+} // namespace mc
+} // namespace nest
diff --git a/src/mc_cell_group.hpp b/src/mc_cell_group.hpp
index dec775f8..2ed838e8 100644
--- a/src/mc_cell_group.hpp
+++ b/src/mc_cell_group.hpp
@@ -17,6 +17,7 @@
 #include <util/debug.hpp>
 #include <util/partition.hpp>
 #include <util/range.hpp>
+#include <util/unique_any.hpp>

 #include <profiling/profiler.hpp>
@@ -59,8 +60,34 @@ public:
             ++source_gid;
         }
         EXPECTS(spike_sources_.size()==n_detectors);
+
+        // Create the enumeration of probes attached to cells in this cell group
+        probes_.reserve(n_probes);
+        for (auto i: util::make_span(0, cells.size())){
+            const cell_gid_type probe_gid = gid_base_ + i;
+            const auto probes_on_cell = cells[i].probes();
+            for (cell_lid_type lid: util::make_span(0, probes_on_cell.size())) {
+                // get the unique global identifier of this probe
+                cell_member_type id{probe_gid, lid};
+
+                // get the location and kind information of the probe
+                const auto p = probes_on_cell[lid];
+
+                // record the combined identifier and probe details
+                probes_.push_back(probe_record{id, p.location, p.kind});
+            }
+        }
     }

+    mc_cell_group(cell_gid_type first_gid, const std::vector<util::unique_any>& cells):
+        mc_cell_group(
+            first_gid,
+            util::transform_view(
+                cells,
+                [](const util::unique_any& c) -> const cell& {return util::any_cast<const cell&>(c);})
+        )
+    {}
+
     cell_kind get_cell_kind() const override {
         return cell_kind::cable1d_neuron;
     }
@@ -177,6 +204,10 @@ public:
         sample_events_.push({sampler_index, start_time});
     }

+    std::vector<probe_record> probes() const override {
+        return probes_;
+    }
+
 private:
     // gid of first cell in group.
     cell_gid_type gid_base_;
@@ -222,6 +253,9 @@ private:
     // Lookup table for target ids -> local target handle indices.
     std::vector<std::size_t> target_handle_divisions_;

+    // Enumeration of the probes that are attached to the cells in the cell group
+    std::vector<probe_record> probes_;
+
     // Build handle index lookup tables.
     template <typename Cells>
     void build_handle_partitions(const Cells& cells) {
diff --git a/src/model.cpp b/src/model.cpp
new file mode 100644
index 00000000..9d0fb262
--- /dev/null
+++ b/src/model.cpp
@@ -0,0 +1,227 @@
+#include <model.hpp>
+
+#include <vector>
+
+#include <backends.hpp>
+#include <cell_group.hpp>
+#include <cell_group_factory.hpp>
+#include <domain_decomposition.hpp>
+#include <recipe.hpp>
+#include <util/span.hpp>
+#include <util/unique_any.hpp>
+#include <profiling/profiler.hpp>
+
+namespace nest {
+namespace mc {
+
+model::model(const recipe& rec, const domain_decomposition& decomp):
+    domain_(decomp)
+{
+    // set up communicator based on partition
+    communicator_ = communicator_type(domain_.gid_group_partition());
+
+    // generate the cell groups in parallel, with one task per cell group
+    cell_groups_.resize(domain_.num_local_groups());
+
+    // thread safe vector for constructing the list of probes in parallel
+    threading::parallel_vector<probe_record> probe_tmp;
+
+    threading::parallel_for::apply(0, cell_groups_.size(),
+        [&](cell_gid_type i) {
+            PE("setup", "cells");
+
+            auto group = domain_.get_group(i);
+            std::vector<util::unique_any> cells(group.end-group.begin);
+
+            for (auto gid: util::make_span(group.begin, group.end)) {
+                auto i = gid-group.begin;
+                cells[i] = rec.get_cell(gid);
+            }
+
+            cell_groups_[i] = cell_group_factory(
+                group.kind, group.begin, cells, domain_.backend());
+            PL(2);
+        });
+
+    // store probes
+    for (const auto& c: cell_groups_) {
+        util::append(probes_, c->probes());
+    }
+
+    // generate the network connections
+    for (cell_gid_type i: util::make_span(domain_.cell_begin(), domain_.cell_end())) {
+        for (const auto& cc: rec.connections_on(i)) {
+            connection conn{cc.source, cc.dest, cc.weight, cc.delay};
+            communicator_.add_connection(conn);
+        }
+    }
+    communicator_.construct();
+
+    // Allocate an empty queue buffer for each cell group
+    // These must be set initially to ensure that a queue is available for each
+    // cell group for the first time step.
+    current_events().resize(num_groups());
+    future_events().resize(num_groups());
+}
+
+void model::reset() {
+    t_ = 0.;
+    for (auto& group: cell_groups_) {
+        group->reset();
+    }
+
+    communicator_.reset();
+
+    for(auto& q : current_events()) {
+        q.clear();
+    }
+    for(auto& q : future_events()) {
+        q.clear();
+    }
+
+    current_spikes().clear();
+    previous_spikes().clear();
+
+    util::profilers_restart();
+}
+
+time_type model::run(time_type tfinal, time_type dt) {
+    // Calculate the size of the largest possible time integration interval
+    // before communication of spikes is required.
+    // If spike exchange and cell update are serialized, this is the
+    // minimum delay of the network, however we use half this period
+    // to overlap communication and computation.
+    time_type t_interval = communicator_.min_delay()/2;
+
+    time_type tuntil;
+
+    // task that updates cell state in parallel.
+    auto update_cells = [&] () {
+        threading::parallel_for::apply(
+            0u, cell_groups_.size(),
+            [&](unsigned i) {
+                auto &group = cell_groups_[i];
+
+                PE("stepping","events");
+                group->enqueue_events(current_events()[i]);
+                PL();
+
+                group->advance(tuntil, dt);
+
+                PE("events");
+                current_spikes().insert(group->spikes());
+                group->clear_spikes();
+                PL(2);
+            });
+    };
+
+    // task that performs spike exchange with the spikes generated in
+    // the previous integration period, generating the postsynaptic
+    // events that must be delivered at the start of the next
+    // integration period at the latest.
+    auto exchange = [&] () {
+        PE("stepping", "communication");
+
+        PE("exchange");
+        auto local_spikes = previous_spikes().gather();
+        auto global_spikes = communicator_.exchange(local_spikes);
+        PL();
+
+        PE("spike output");
+        local_export_callback_(local_spikes);
+        global_export_callback_(global_spikes.values());
+        PL();
+
+        PE("events");
+        future_events() = communicator_.make_event_queues(global_spikes);
+        PL();
+
+        PL(2);
+    };
+
+    while (t_<tfinal) {
+        tuntil = std::min(t_+t_interval, tfinal);
+
+        event_queues_.exchange();
+        local_spikes_.exchange();
+
+        // empty the spike buffers for the current integration period.
+        // these buffers will store the new spikes generated in update_cells.
+        current_spikes().clear();
+
+        // run the tasks, overlapping if the threading model and number of
+        // available threads permits it.
+        threading::task_group g;
+        g.run(exchange);
+        g.run(update_cells);
+        g.wait();
+
+        t_ = tuntil;
+    }
+
+    // Run the exchange one last time to ensure that all spikes are output
+    // to file.
+    event_queues_.exchange();
+    local_spikes_.exchange();
+    exchange();
+
+    return t_;
+}
+
+// only thread safe if called outside the run() method
+void model::add_artificial_spike(cell_member_type source) {
+    add_artificial_spike(source, t_);
+}
+
+// only thread safe if called outside the run() method
+void model::add_artificial_spike(cell_member_type source, time_type tspike) {
+    if (domain_.is_local_gid(source.gid)) {
+        current_spikes().get().push_back({source, tspike});
+    }
+}
+
+void model::attach_sampler(cell_member_type probe_id, sampler_function f, time_type tfrom) {
+    const auto idx = domain_.local_group_from_gid(probe_id.gid);
+
+    // only attach samplers for local cells
+    if (idx) {
+        cell_groups_[*idx]->add_sampler(probe_id, f, tfrom);
+    }
+}
+
+const std::vector<probe_record>& model::probes() const {
+    return probes_;
+}
+
+std::size_t model::num_spikes() const {
+    return communicator_.num_spikes();
+}
+
+std::size_t model::num_groups() const {
+    return cell_groups_.size();
+}
+
+std::size_t model::num_cells() const {
+    return domain_.num_local_cells();
+}
+
+void model::set_binning_policy(binning_kind policy, time_type bin_interval) {
+    for (auto& group: cell_groups_) {
+        group->set_binning_policy(policy, bin_interval);
+    }
+}
+
+cell_group& model::group(int i) {
+    return *cell_groups_[i];
+}
+
+void model::set_global_spike_callback(spike_export_function export_callback) {
+    global_export_callback_ = export_callback;
+}
+
+void model::set_local_spike_callback(spike_export_function export_callback) {
+    local_export_callback_ = export_callback;
+}
+
+} // namespace mc
+} // namespace nest
diff --git a/src/model.hpp b/src/model.hpp
index cef1da4d..3d81d196 100644
--- a/src/model.hpp
+++ b/src/model.hpp
@@ -1,275 +1,68 @@
 #pragma once

-#include <memory>
 #include <vector>
-#include <cstdlib>
-
 #include <backends.hpp>
-#include <fvm_multicell.hpp>
-
-#include <common_types.hpp>
-#include <cell.hpp>
 #include <cell_group.hpp>
+#include <common_types.hpp>
+#include <domain_decomposition.hpp>
 #include <communication/communicator.hpp>
 #include <communication/global_policy.hpp>
-#include <domain_decomposition.hpp>
-#include <mc_cell_group.hpp>
-#include <profiling/profiler.hpp>
 #include <recipe.hpp>
 #include <sampler_function.hpp>
 #include <thread_private_spike_store.hpp>
-#include <threading/threading.hpp>
-#include <trace_sampler.hpp>
 #include <util/nop.hpp>
-#include <util/partition.hpp>
-#include <util/range.hpp>
+#include <util/unique_any.hpp>

 namespace nest {
 namespace mc {

-using gpu_lowered_cell =
-    mc_cell_group<fvm::fvm_multicell<gpu::backend>>;
-
-using multicore_lowered_cell =
-    mc_cell_group<fvm::fvm_multicell<multicore::backend>>;
-
 class model {
 public:
     using communicator_type = communication::communicator<communication::global_policy>;
     using spike_export_function = std::function<void(const std::vector<spike>&)>;

-    struct probe_record {
-        cell_member_type id;
-        probe_spec probe;
-    };
-
-    model(const recipe& rec, const domain_decomposition& decomp):
-        domain_(decomp)
-    {
-        // set up communicator based on partition
-        communicator_ = communicator_type(domain_.gid_group_partition());
-
-        // generate the cell groups in parallel, with one task per cell group
-        cell_groups_.resize(domain_.num_local_groups());
-
-        // thread safe vector for constructing the list of probes in parallel
-        threading::parallel_vector<probe_record> probe_tmp;
-
-        threading::parallel_for::apply(0, cell_groups_.size(),
-            [&](cell_gid_type i) {
-                PE("setup", "cells");
-
-                auto gids = domain_.get_group(i);
-                std::vector<cell> cells(gids.end-gids.begin);
-
-                for (auto gid: util::make_span(gids.begin, gids.end)) {
-                    auto i = gid-gids.begin;
-                    cells[i] = rec.get_cell(gid);
-
-                    cell_lid_type j = 0;
-                    for (const auto& probe: cells[i].probes()) {
-                        cell_member_type probe_id{gid, j++};
-                        probe_tmp.push_back({probe_id, probe});
-                    }
-                }
-
-                if (domain_.backend()==backend_policy::use_multicore) {
-                    cell_groups_[i] = make_cell_group<multicore_lowered_cell>(gids.begin, cells);
-                }
-                else {
-                    cell_groups_[i] = make_cell_group<gpu_lowered_cell>(gids.begin, cells);
-                }
-                PL(2);
-            });
-
-        // store probes
-        probes_.assign(probe_tmp.begin(), probe_tmp.end());
-
-        // generate the network connections
-        for (cell_gid_type i: util::make_span(domain_.cell_begin(), domain_.cell_end())) {
-            for (const auto& cc: rec.connections_on(i)) {
-                connection conn{cc.source, cc.dest, cc.weight, cc.delay};
-                communicator_.add_connection(conn);
-            }
-        }
-        communicator_.construct();
-
-        // Allocate an empty queue buffer for each cell group
-        // These must be set initially to ensure that a queue is available for each
-        // cell group for the first time step.
-        current_events().resize(num_groups());
-        future_events().resize(num_groups());
-    }
-
-    void reset() {
-        t_ = 0.;
-        for (auto& group: cell_groups_) {
-            group->reset();
-        }
-
-        communicator_.reset();
+    model(const recipe& rec, const domain_decomposition& decomp);

-        for(auto& q : current_events()) {
-            q.clear();
-        }
-        for(auto& q : future_events()) {
-            q.clear();
-        }
+    void reset();

-        current_spikes().clear();
-        previous_spikes().clear();
-
-        util::profilers_restart();
-    }
-
-    time_type run(time_type tfinal, time_type dt) {
-        // Calculate the size of the largest possible time integration interval
-        // before communication of spikes is required.
-        // If spike exchange and cell update are serialized, this is the
-        // minimum delay of the network, however we use half this period
-        // to overlap communication and computation.
-        time_type t_interval = communicator_.min_delay()/2;
-
-        time_type tuntil;
-
-        // task that updates cell state in parallel.
-        auto update_cells = [&] () {
-            threading::parallel_for::apply(
-                0u, cell_groups_.size(),
-                [&](unsigned i) {
-                    auto &group = cell_groups_[i];
-
-                    PE("stepping","events");
-                    group->enqueue_events(current_events()[i]);
-                    PL();
-
-                    group->advance(tuntil, dt);
-
-                    PE("events");
-                    current_spikes().insert(group->spikes());
-                    group->clear_spikes();
-                    PL(2);
-                });
-        };
-
-        // task that performs spike exchange with the spikes generated in
-        // the previous integration period, generating the postsynaptic
-        // events that must be delivered at the start of the next
-        // integration period at the latest.
-        auto exchange = [&] () {
-            PE("stepping", "communication");
-
-            PE("exchange");
-            auto local_spikes = previous_spikes().gather();
-            auto global_spikes = communicator_.exchange(local_spikes);
-            PL();
-
-            PE("spike output");
-            local_export_callback_(local_spikes);
-            global_export_callback_(global_spikes.values());
-            PL();
-
-            PE("events");
-            future_events() = communicator_.make_event_queues(global_spikes);
-            PL();
-
-            PL(2);
-        };
-
-        while (t_<tfinal) {
-            tuntil = std::min(t_+t_interval, tfinal);
-
-            event_queues_.exchange();
-            local_spikes_.exchange();
-
-            // empty the spike buffers for the current integration period.
-            // these buffers will store the new spikes generated in update_cells.
-            current_spikes().clear();
-
-            // run the tasks, overlapping if the threading model and number of
-            // available threads permits it.
-            threading::task_group g;
-            g.run(exchange);
-            g.run(update_cells);
-            g.wait();
-
-            t_ = tuntil;
-        }
-
-        // Run the exchange one last time to ensure that all spikes are output
-        // to file.
-        event_queues_.exchange();
-        local_spikes_.exchange();
-        exchange();
-
-        return t_;
-    }
+    time_type run(time_type tfinal, time_type dt);

     // only thread safe if called outside the run() method
-    void add_artificial_spike(cell_member_type source) {
-        add_artificial_spike(source, t_);
-    }
+    void add_artificial_spike(cell_member_type source);

     // only thread safe if called outside the run() method
-    void add_artificial_spike(cell_member_type source, time_type tspike) {
-        if (domain_.is_local_gid(source.gid)) {
-            current_spikes().get().push_back({source, tspike});
-        }
-    }
-
-    void attach_sampler(cell_member_type probe_id, sampler_function f, time_type tfrom = 0) {
-        const auto idx = domain_.local_group_from_gid(probe_id.gid);
+    void add_artificial_spike(cell_member_type source, time_type tspike);

-        // only attach samplers for local cells
-        if (idx) {
-            cell_groups_[*idx]->add_sampler(probe_id, f, tfrom);
-        }
-    }
+    void attach_sampler(cell_member_type probe_id, sampler_function f, time_type tfrom = 0);

-    const std::vector<probe_record>& probes() const { return probes_; }
+    const std::vector<probe_record>& probes() const;

-    std::size_t num_spikes() const {
-        return communicator_.num_spikes();
-    }
+    std::size_t num_spikes() const;

-    std::size_t num_groups() const {
-        return cell_groups_.size();
-    }
+    std::size_t num_groups() const;

-    std::size_t num_cells() const {
-        return domain_.num_local_cells();
-    }
+    std::size_t num_cells() const;

     // Set event binning policy on all our groups.
-    void set_binning_policy(binning_kind policy, time_type bin_interval) {
-        for (auto& group: cell_groups_) {
-            group->set_binning_policy(policy, bin_interval);
-        }
-    }
+    void set_binning_policy(binning_kind policy, time_type bin_interval);

     // access cell_group directly
-    cell_group& group(int i) {
-        return *cell_groups_[i];
-    }
+    cell_group& group(int i);

     // register a callback that will perform a export of the global
     // spike vector
-    void set_global_spike_callback(spike_export_function export_callback) {
-        global_export_callback_ = export_callback;
-    }
+    void set_global_spike_callback(spike_export_function export_callback);

     // register a callback that will perform a export of the rank local
     // spike vector
-    void set_local_spike_callback(spike_export_function export_callback) {
-        local_export_callback_ = export_callback;
-    }
+    void set_local_spike_callback(spike_export_function export_callback);

 private:
     const domain_decomposition &domain_;

     time_type t_ = 0.;
-    std::vector<std::unique_ptr<cell_group>> cell_groups_;
+    std::vector<cell_group_ptr> cell_groups_;

     communicator_type communicator_;
     std::vector<probe_record> probes_;
diff --git a/src/probes.hpp b/src/probes.hpp
new file mode 100644
index 00000000..36b27c3b
--- /dev/null
+++ b/src/probes.hpp
@@ -0,0 +1,22 @@
+#pragma once
+
+#include <cell.hpp>
+#include <morphology.hpp>
+#include <segment.hpp>
+
+namespace nest {
+namespace mc {
+
+enum class probeKind {
+    membrane_voltage,
+    membrane_current
+};
+
+struct probe_record {
+    cell_member_type id;
+    segment_location location;
+    probeKind kind;
+};
+
+} // namespace mc
+} // namespace nest
diff --git a/src/recipe.hpp b/src/recipe.hpp
index 683ac1d6..496543be 100644
--- a/src/recipe.hpp
+++ b/src/recipe.hpp
@@ -5,6 +5,7 @@
 #include <stdexcept>

 #include <cell.hpp>
+#include <util/unique_any.hpp>

 namespace nest {
 namespace mc {
@@ -47,7 +48,7 @@ class recipe {
 public:
     virtual cell_size_type num_cells() const =0;

-    virtual cell get_cell(cell_gid_type) const =0;
+    virtual util::unique_any get_cell(cell_gid_type) const =0;
     virtual cell_kind get_cell_kind(cell_gid_type) const = 0;

     virtual cell_count_info get_cell_count_info(cell_gid_type) const =0;
@@ -69,8 +70,8 @@ public:
         return 1;
     }

-    cell get_cell(cell_gid_type) const override {
-        return cell(clone_cell, cell_);
+    util::unique_any get_cell(cell_gid_type) const override {
+        return util::unique_any(cell(clone_cell, cell_));
     }

     cell_kind get_cell_kind(cell_gid_type) const override {
diff --git a/src/segment.hpp b/src/segment.hpp
index 53a486c2..e208ec69 100644
--- a/src/segment.hpp
+++ b/src/segment.hpp
@@ -475,6 +475,18 @@ DivCompClass div_compartments(const cable_segment* cable) {
     return DivCompClass(cable->num_compartments(), cable->radii(), cable->lengths());
 }

+struct segment_location {
+    segment_location(cell_lid_type s, double l):
+        segment(s), position(l)
+    {
+        EXPECTS(position>=0. && position<=1.);
+    }
+    friend bool operator==(segment_location l, segment_location r) {
+        return l.segment==r.segment && l.position==r.position;
+    }
+    cell_lid_type segment;
+    double position;
+};
+
 } // namespace mc
 } // namespace nest
-
diff --git a/tests/unit/test_domain_decomposition.cpp b/tests/unit/test_domain_decomposition.cpp
index ab5d3fe9..a57f83a5 100644
--- a/tests/unit/test_domain_decomposition.cpp
+++ b/tests/unit/test_domain_decomposition.cpp
@@ -17,8 +17,8 @@ public:
         return size_;
     }

-    cell get_cell(cell_gid_type) const override {
-        return cell();
+    util::unique_any get_cell(cell_gid_type) const override {
+        return {};
     }

     cell_kind get_cell_kind(cell_gid_type) const override {
         return cell_kind::cable1d_neuron;
--
GitLab
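
A minimal sketch of what a recipe looks like to downstream code after this patch. It is not part of the patch: the class name one_cell_recipe and the soma radius are hypothetical, only the get_cell/get_cell_kind signatures and the util::unique_any / util::any_cast pattern follow the changes to recipe.hpp and miniapp.cpp above, and the remaining pure virtual members of recipe are elided.

```cpp
#include <cell.hpp>
#include <recipe.hpp>
#include <util/unique_any.hpp>

namespace nest { namespace mc {

// Hypothetical single-cell recipe written against the refactored interface.
class one_cell_recipe: public recipe {
public:
    cell_size_type num_cells() const override { return 1; }

    // Cells are now returned type-erased, so recipes describing other cell
    // kinds can share the same interface.
    util::unique_any get_cell(cell_gid_type) const override {
        cell c;
        c.add_soma(12.6157/2.0); // hypothetical soma radius [um]
        return util::unique_any(std::move(c));
    }

    cell_kind get_cell_kind(cell_gid_type) const override {
        return cell_kind::cable1d_neuron;
    }

    // ... remaining pure virtual members of recipe unchanged from before ...
};

}} // namespace nest::mc

// Callers that need the concrete cell type recover it with util::any_cast,
// as report_compartment_stats() does in the patch:
//     auto c = rec.get_cell(gid);
//     if (auto* ptr = util::any_cast<cell>(&c)) { /* ptr->num_compartments() */ }
```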