diff --git a/arbor/backends/event_delivery.md b/arbor/backends/event_delivery.md index 796ccbae65b187d1120b20ef4aa2e544f3cd7399..29f809985681894f5bcc9d7453c4f6673deccb87 100644 --- a/arbor/backends/event_delivery.md +++ b/arbor/backends/event_delivery.md @@ -12,7 +12,7 @@ destinations and event information. The back-end event management structure is supplied by the corresponding `backend` class as `backend::multi_event_stream`. It presents a limited public interface to the lowered cell, and is passed by reference as a parameter to the mechanism -`deliver_events` method. +`apply_events` method. ### Target handles @@ -84,10 +84,10 @@ For `fvm_multicell` one integration step comprises: 2. Each mechanism is requested to deliver to itself any marked events that are associated with that mechanism, via the virtual - `mechanism::deliver_events(backend::multi_event_stream&)` method. + `mechanism::apply_events(backend::multi_event_stream&)` method. This action must precede the computation of mechanism current contributions - with `mechanism::nrn_current()`. + with `mechanism::compute_currents()`. 3. Marked events are discarded with `events_.drop_marked_events()`. @@ -102,7 +102,7 @@ For `fvm_multicell` one integration step comprises: 6. The solver matrix is assembled and solved to compute the voltages, using the newly computed currents and integration step times. -7. The mechanism states are updated with `mechanism::nrn_state()`. +7. The mechanism states are updated with `mechanism::advance_state()`. 8. The cell times `time_` are set to the integration step stop times `time_to_`. @@ -114,7 +114,7 @@ For `fvm_multicell` one integration step comprises: Towards the end of the integration period, an integration step may have a zero _dt_ for one or more cells within the group, and this needs to be handled correctly: -* Generated mechanism `nrn_state()` methods should be numerically correct with +* Generated mechanism `advance_state()` methods should be numerically correct with zero _dt_; a possibility is to guard the integration step with a _dt_ check. * Matrix assemble and solve must check for zero _dt_. In the FVM `multicore` diff --git a/arbor/backends/gpu/mechanism.cpp b/arbor/backends/gpu/mechanism.cpp index b96a01891c2883f864c6c8881f4f45ec04611c1a..aa500ddcf5ce27445249e094e84b4bd62eee14f9 100644 --- a/arbor/backends/gpu/mechanism.cpp +++ b/arbor/backends/gpu/mechanism.cpp @@ -118,7 +118,6 @@ void mechanism::instantiate(unsigned id, event_stream_ptr_ = &shared.deliverable_events; vec_t_ptr_ = &shared.time; - vec_t_to_ptr_ = &shared.time_to; // If there are no sites (is this ever meaningful?) there is nothing more to do. if (width_==0) { @@ -214,7 +213,7 @@ void mechanism::initialize() { mechanism_ppack_base* pp = ppack_ptr(); pp->vec_t_ = vec_t_ptr_->data(); - nrn_init(); + init(); auto states = state_table(); if(mult_in_place_) { diff --git a/arbor/backends/gpu/mechanism.hpp b/arbor/backends/gpu/mechanism.hpp index 8f04a2383d83e786b65ea2382652ef82172c8ce6..e63d222a28070862175a611237e15e3304c4659b 100644 --- a/arbor/backends/gpu/mechanism.hpp +++ b/arbor/backends/gpu/mechanism.hpp @@ -1,8 +1,8 @@ #pragma once #include <algorithm> -#include <cstddef> #include <cmath> +#include <cstddef> #include <string> #include <utility> #include <vector> @@ -22,15 +22,8 @@ namespace gpu { class mechanism: public arb::concrete_mechanism<arb::gpu::backend> { public: - using value_type = fvm_value_type; - using index_type = fvm_index_type; - using size_type = fvm_size_type; - protected: - using backend = arb::gpu::backend; - using deliverable_event_stream = backend::deliverable_event_stream; - - using array = arb::gpu::array; + using array = arb::gpu::array; using iarray = arb::gpu::iarray; public: @@ -50,17 +43,17 @@ public: void deliver_events() override { // Delegate to derived class, passing in event queue state. - deliver_events(event_stream_ptr_->marked_events()); + apply_events(event_stream_ptr_->marked_events()); } void update_current() override { mechanism_ppack_base* pp = ppack_ptr(); pp->vec_t_ = vec_t_ptr_->data(); - nrn_current(); + compute_currents(); } void update_state() override { mechanism_ppack_base* pp = ppack_ptr(); pp->vec_t_ = vec_t_ptr_->data(); - nrn_state(); + advance_state(); } void update_ions() override { mechanism_ppack_base* pp = ppack_ptr(); @@ -72,7 +65,7 @@ public: // Peek into mechanism state variable; implements arb::gpu::backend::mechanism_field_data. // Returns pointer to GPU memory corresponding to state variable data. - fvm_value_type* field_data(const std::string& state_var); + fvm_value_type* field_data(const std::string& state_var) override; void initialize() override; @@ -91,63 +84,12 @@ protected: deliverable_event_stream* event_stream_ptr_; const array* vec_t_ptr_; - const array* vec_t_to_ptr_; // Bulk storage for index vectors and state and parameter variables. iarray indices_; array data_; bool mult_in_place_; - - // Generated mechanism field, global and ion table lookup types. - // First component is name, second is pointer to corresponing member in - // the mechanism's parameter pack, or for field_default_table, - // the scalar value used to initialize the field. - - using global_table_entry = std::pair<const char*, value_type*>; - using mechanism_global_table = std::vector<global_table_entry>; - - using state_table_entry = std::pair<const char*, value_type**>; - using mechanism_state_table = std::vector<state_table_entry>; - - using field_table_entry = std::pair<const char*, value_type**>; - using mechanism_field_table = std::vector<field_table_entry>; - - using field_default_entry = std::pair<const char*, value_type>; - using mechanism_field_default_table = std::vector<field_default_entry>; - - using ion_state_entry = std::pair<const char*, ion_state_view*>; - using mechanism_ion_state_table = std::vector<ion_state_entry>; - - using ion_index_entry = std::pair<const char*, const index_type**>; - using mechanism_ion_index_table = std::vector<ion_index_entry>; - - virtual void nrn_init() = 0; - - // Generated mechanisms must implement the following methods, together with - // fingerprint(), clone(), kind(), nrn_init(), nrn_state(), nrn_current() - // and deliver_events() (if required) from arb::mechanism. - - // Member tables: introspection into derived mechanism fields, views etc. - // Default implementations correspond to no corresponding fields/globals/ions. - - virtual mechanism_field_table field_table() { return {}; } - virtual mechanism_field_default_table field_default_table() { return {}; } - virtual mechanism_global_table global_table() { return {}; } - virtual mechanism_state_table state_table() { return {}; } - virtual mechanism_ion_state_table ion_state_table() { return {}; } - virtual mechanism_ion_index_table ion_index_table() { return {}; } - - // Report raw size in bytes of mechanism object. - - virtual std::size_t object_sizeof() const = 0; - - // Event delivery, given event queue state: - - virtual void nrn_state() {}; - virtual void nrn_current() {}; - virtual void deliver_events(deliverable_event_stream::state) {}; - virtual void write_ions() {}; }; } // namespace gpu diff --git a/arbor/backends/gpu/mechanism_ppack_base.hpp b/arbor/backends/gpu/mechanism_ppack_base.hpp index 0a221aa1dac3fb6052e5833bc7ae58e428004330..6f2456ca5b112207599f4bfe84e915157329c80b 100644 --- a/arbor/backends/gpu/mechanism_ppack_base.hpp +++ b/arbor/backends/gpu/mechanism_ppack_base.hpp @@ -3,50 +3,32 @@ // Base class for parameter packs for GPU generated kernels: // will be included by .cu generated sources. +#include <arbor/mechanism.hpp> #include <arbor/fvm_types.hpp> namespace arb { namespace gpu { -// Derived ppack structs may have ion_state_view fields: - -struct ion_state_view { - using value_type = fvm_value_type; - using index_type = fvm_index_type; - - value_type* current_density; - value_type* reversal_potential; - value_type* internal_concentration; - value_type* external_concentration; - value_type* ionic_charge; -}; - // Parameter pack base: - struct mechanism_ppack_base { - using value_type = fvm_value_type; - using index_type = fvm_index_type; - using ion_state_view = ::arb::gpu::ion_state_view; - - index_type width_; - index_type n_detectors_; - - const index_type* vec_ci_; - const index_type* vec_di_; - const value_type* vec_t_; - const value_type* vec_t_to_; - const value_type* vec_dt_; - const value_type* vec_v_; - value_type* vec_i_; - value_type* vec_g_; - const value_type* temperature_degC_; - const value_type* diam_um_; - const value_type* time_since_spike_; - - const index_type* node_index_; - const index_type* multiplicity_; - - const value_type* weight_; + fvm_index_type width_; + fvm_index_type n_detectors_; + + const fvm_index_type* vec_ci_; + const fvm_index_type* vec_di_; + const fvm_value_type* vec_t_; + const fvm_value_type* vec_dt_; + const fvm_value_type* vec_v_; + fvm_value_type* vec_i_; + fvm_value_type* vec_g_; + const fvm_value_type* temperature_degC_; + const fvm_value_type* diam_um_; + const fvm_value_type* time_since_spike_; + + const fvm_index_type* node_index_; + const fvm_index_type* multiplicity_; + + const fvm_value_type* weight_; }; } // namespace gpu diff --git a/arbor/backends/gpu/stimulus.cpp b/arbor/backends/gpu/stimulus.cpp index 960beeaf92e44f2f1fed4b8558016b29af09f1cc..ba05951e9c1ade209ab4bd37fbdf65a9b00301ac 100644 --- a/arbor/backends/gpu/stimulus.cpp +++ b/arbor/backends/gpu/stimulus.cpp @@ -17,14 +17,14 @@ public: mechanismKind kind() const override { return ::arb::mechanismKind::point; } mechanism_ptr clone() const override { return mechanism_ptr(new stimulus()); } - void nrn_init() override {} - void nrn_state() override {} - void nrn_current() override { + void init() override {} + void advance_state() override {} + void compute_currents() override { stimulus_current_impl(size(), pp_); } void write_ions() override {} - void deliver_events(deliverable_event_stream::state events) override {} + void apply_events(deliverable_event_stream::state events) override {} mechanism_ppack_base* ppack_ptr() override { return &pp_; diff --git a/arbor/backends/multicore/mechanism.cpp b/arbor/backends/multicore/mechanism.cpp index 3e7f6c2f5a5b76e8ce44ab6030fb113c9571f708..bb6de164c7344dba2e929bafbce8ab593aa6deed 100644 --- a/arbor/backends/multicore/mechanism.cpp +++ b/arbor/backends/multicore/mechanism.cpp @@ -29,26 +29,6 @@ using util::make_range; using util::ptr_by_key; using util::value_by_key; -// Copy elements from source sequence into destination sequence, -// and fill the remaining elements of the destination sequence -// with the given fill value. -// -// Assumes that the iterators for these sequences are at least -// forward iterators. - -template <typename Source, typename Dest, typename Fill> -void copy_extend(const Source& source, Dest&& dest, const Fill& fill) { - using std::begin; - using std::end; - - auto dest_n = std::size(dest); - auto source_n = std::size(source); - - auto n = source_n<dest_n? source_n: dest_n; - auto tail = std::copy_n(begin(source), n, begin(dest)); - std::fill(tail, end(dest), fill); -} - // The derived class (typically generated code from modcc) holds pointers that need // to be set to point inside the shared state, or into the allocated parameter/variable // data block. @@ -68,7 +48,7 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const me for (auto &kv: overrides.globals) { if (auto opt_ptr = value_by_key(global_table(), kv.first)) { // Take reference to corresponding derived (generated) mechanism value member. - value_type& global = *opt_ptr.value(); + fvm_value_type& global = *opt_ptr.value(); global = kv.second; } else { @@ -116,7 +96,6 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const me } vec_t_ptr_ = &shared.time; - vec_t_to_ptr_ = &shared.time_to; event_stream_ptr_ = &shared.deliverable_events; // If there are no sites (is this ever meaningful?) there is nothing more to do. @@ -132,7 +111,7 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const me auto fields = field_table(); std::size_t n_field = fields.size(); - // (First sub-array of data_ is used for width_, below.) + // (First sub-array of data_ is used for weight_, below.) data_ = array((1+n_field)*width_padded_, NAN, pad); for (std::size_t i = 0; i<n_field; ++i) { // Take reference to corresponding derived (generated) mechanism value pointer member. @@ -152,33 +131,48 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const me // * For indices in the padded tail of node_index_, set index to last valid CV index. // * For indices in the padded tail of ion index maps, set index to last valid ion index. - node_index_ = iarray(width_padded_, pad); - - copy_extend(pos_data.cv, node_index_, pos_data.cv.back()); copy_extend(pos_data.weight, make_range(data_.data(), data_.data()+width_padded_), 0); - index_constraints_ = make_constraint_partition(node_index_, width_, simd_width()); if (mult_in_place_) { multiplicity_ = iarray(width_padded_, pad); copy_extend(pos_data.multiplicity, multiplicity_, 1); } - for (auto i: ion_index_table()) { - auto ion_binding = value_by_key(overrides.ion_rebind, i.first).value_or(i.first); - - ion_state* oion = ptr_by_key(shared.ion_data, ion_binding); - if (!oion) { - throw arbor_internal_error("multicore/mechanism: mechanism holds ion with no corresponding shared state"); - } + { + auto table = ion_index_table(); + // Allocate bulk storage + auto count = (table.size() + 1)*width_padded_; + indices_ = iarray(count, 0, pad); + auto base_ptr = indices_.data(); + + // Setup node indices + node_index_ = base_ptr; + base_ptr += width_padded_; + auto node_index = make_range(node_index_, node_index_ + width_padded_); + copy_extend(pos_data.cv, node_index, pos_data.cv.back()); + index_constraints_ = make_constraint_partition(node_index, width_, simd_width()); + + // Create ion indices + for (const auto& [ion_name, ion_index_ptr]: table) { + // Index into shared_state respecting ion rebindings + auto ion_binding = value_by_key(overrides.ion_rebind, ion_name).value_or(ion_name); + ion_state* oion = ptr_by_key(shared.ion_data, ion_binding); + if (!oion) { + throw arbor_internal_error("multicore/mechanism: mechanism holds ion with no corresponding shared state"); + } - auto indices = util::index_into(node_index_, oion->node_index_); + // Set the table entry, step offset to next location + *ion_index_ptr = base_ptr; + base_ptr += width_padded_; - // Take reference to derived (generated) mechanism ion index member. - auto& ion_index = *i.second; - ion_index = iarray(width_padded_, pad); - copy_extend(indices, ion_index, util::back(indices)); + // Obtain index and move data + auto indices = util::index_into(node_index, oion->node_index_); + auto ion_index = make_range(*ion_index_ptr, *ion_index_ptr + width_padded_); + copy_extend(indices, ion_index, util::back(indices)); - arb_assert(compatible_index_constraints(node_index_, ion_index, simd_width())); + // Check SIMD constraints + arb_assert(compatible_index_constraints(node_index, ion_index, simd_width())); + } } } @@ -190,8 +184,8 @@ void mechanism::set_parameter(const std::string& key, const std::vector<fvm_valu if (width_>0) { // Retrieve corresponding derived (generated) mechanism value pointer member. - value_type* field_ptr = *opt_ptr.value(); - util::range<value_type*> field(field_ptr, field_ptr+width_padded_); + fvm_value_type* field_ptr = *opt_ptr.value(); + util::range<fvm_value_type*> field(field_ptr, field_ptr+width_padded_); copy_extend(values, field, values.back()); } @@ -203,7 +197,7 @@ void mechanism::set_parameter(const std::string& key, const std::vector<fvm_valu void mechanism::initialize() { vec_t_ = vec_t_ptr_->data(); - nrn_init(); + init(); auto states = state_table(); diff --git a/arbor/backends/multicore/mechanism.hpp b/arbor/backends/multicore/mechanism.hpp index afd638c34e5137ecc911eda54273cc6e3f5f5fe9..465ab876f685fd3b7762560d80b5ca2cd76079fa 100644 --- a/arbor/backends/multicore/mechanism.hpp +++ b/arbor/backends/multicore/mechanism.hpp @@ -1,8 +1,8 @@ #pragma once #include <algorithm> -#include <cstddef> #include <cmath> +#include <cstddef> #include <string> #include <utility> #include <vector> @@ -11,10 +11,9 @@ #include <arbor/fvm_types.hpp> #include <arbor/mechanism.hpp> +#include "backends/multicore/fvm.hpp" #include "backends/multicore/multicore_common.hpp" #include "backends/multicore/partition_by_constraint.hpp" -#include "backends/multicore/fvm.hpp" - namespace arb { namespace multicore { @@ -22,26 +21,10 @@ namespace multicore { // Base class for all generated mechanisms for multicore back-end. class mechanism: public arb::concrete_mechanism<arb::multicore::backend> { -public: - using value_type = fvm_value_type; - using index_type = fvm_index_type; - using size_type = fvm_size_type; - protected: - using backend = arb::multicore::backend; - using deliverable_event_stream = backend::deliverable_event_stream; - using array = arb::multicore::array; using iarray = arb::multicore::iarray; - struct ion_state_view { - value_type* current_density; - value_type* reversal_potential; - value_type* internal_concentration; - value_type* external_concentration; - value_type* ionic_charge; - }; - public: std::size_t size() const override { return width_; @@ -49,9 +32,8 @@ public: std::size_t memory() const override { std::size_t s = object_sizeof(); - - s += sizeof(value_type) * data_.size(); - s += sizeof(size_type) * width_padded_ * (n_ion_ + 1); // node and ion indices. + s += sizeof(data_[0]) * data_.size(); + s += sizeof(indices_[0]) * indices_.size(); return s; } @@ -60,15 +42,15 @@ public: void deliver_events() override { // Delegate to derived class, passing in event queue state. - deliver_events(event_stream_ptr_->marked_events()); + apply_events(event_stream_ptr_->marked_events()); } void update_current() override { vec_t_ = vec_t_ptr_->data(); - nrn_current(); + compute_currents(); } void update_state() override { vec_t_ = vec_t_ptr_->data(); - nrn_state(); + advance_state(); } void update_ions() override { vec_t_ = vec_t_ptr_->data(); @@ -78,97 +60,44 @@ public: void set_parameter(const std::string& key, const std::vector<fvm_value_type>& values) override; // Peek into mechanism state variable; implements arb::multicore::backend::mechanism_field_data. - fvm_value_type* field_data(const std::string& state_var); + fvm_value_type* field_data(const std::string& state_var) override; protected: - size_type width_ = 0; // Instance width (number of CVs/sites) - size_type width_padded_ = 0; // Width rounded up to multiple of pad/alignment. - size_type n_ion_ = 0; - size_type n_detectors_ = 0; + fvm_size_type width_ = 0; // Instance width (number of CVs/sites) + fvm_size_type width_padded_ = 0; // Width rounded up to multiple of pad/alignment. + fvm_size_type n_ion_ = 0; + fvm_size_type n_detectors_ = 0; // Non-owning views onto shared cell state, excepting ion state. - const index_type* vec_ci_; // CV to cell index - const index_type* vec_di_; // CV to indom index - const value_type* vec_t_; // Cell index to cell-local time. - const value_type* vec_t_to_; // Cell index to cell-local integration step time end. - const value_type* vec_dt_; // CV to integration time step. - const value_type* vec_v_; // CV to cell membrane voltage. - value_type* vec_i_; // CV to cell membrane current density. - value_type* vec_g_; // CV to cell membrane conductivity. - const value_type* temperature_degC_; // CV to temperature. - const value_type* diam_um_; // CV to diameter. - const value_type* time_since_spike_; // Vector containing time since last spike, indexed by cell index and n_detectors_ + const fvm_index_type* vec_ci_; // CV to cell index + const fvm_index_type* vec_di_; // CV to indom index + const fvm_value_type* vec_t_; // Cell index to cell-local time. + const fvm_value_type* vec_dt_; // CV to integration time step. + const fvm_value_type* vec_v_; // CV to cell membrane voltage. + fvm_value_type* vec_i_; // CV to cell membrane current density. + fvm_value_type* vec_g_; // CV to cell membrane conductivity. + const fvm_value_type* temperature_degC_; // CV to temperature. + const fvm_value_type* diam_um_; // CV to diameter. + const fvm_value_type* time_since_spike_; // Vector containing time since last spike, indexed by cell index and n_detectors_ const array* vec_t_ptr_; - const array* vec_t_to_ptr_; deliverable_event_stream* event_stream_ptr_; // Per-mechanism index and weight data, excepting ion indices. - iarray node_index_; + fvm_index_type* node_index_; iarray multiplicity_; bool mult_in_place_; constraint_partition index_constraints_; - const value_type* weight_; // Points within data_ after instantiation. + const fvm_value_type* weight_; // Points within data_ after instantiation. // Bulk storage for state and parameter variables. array data_; - - // Generated mechanism field, global and ion table lookup types. - // First component is name, second is pointer to corresponing member in - // the mechanism's parameter pack, or for field_default_table, - // the scalar value used to initialize the field. - - using global_table_entry = std::pair<const char*, value_type*>; - using mechanism_global_table = std::vector<global_table_entry>; - - using state_table_entry = std::pair<const char*, value_type**>; - using mechanism_state_table = std::vector<state_table_entry>; - - using field_table_entry = std::pair<const char*, value_type**>; - using mechanism_field_table = std::vector<field_table_entry>; - - using field_default_entry = std::pair<const char*, value_type>; - using mechanism_field_default_table = std::vector<field_default_entry>; - - using ion_state_entry = std::pair<const char*, ion_state_view*>; - using mechanism_ion_state_table = std::vector<ion_state_entry>; - - using ion_index_entry = std::pair<const char*, iarray*>; - using mechanism_ion_index_table = std::vector<ion_index_entry>; - - virtual void nrn_init() = 0; - - // Generated mechanisms must implement the following methods, together with - // fingerprint(), clone(), kind(), nrn_init(), nrn_state(), nrn_current() - // and deliver_events() (if required) from arb::mechanism. - - // Member tables: introspection into derived mechanism fields, views etc. - // Default implementations correspond to no corresponding fields/globals/ions. - - virtual mechanism_field_table field_table() { return {}; } - virtual mechanism_field_default_table field_default_table() { return {}; } - virtual mechanism_global_table global_table() { return {}; } - virtual mechanism_state_table state_table() { return {}; } - virtual mechanism_ion_state_table ion_state_table() { return {}; } - virtual mechanism_ion_index_table ion_index_table() { return {}; } - - // Simd width used in mechanism. + iarray indices_; virtual unsigned simd_width() const { return 1; } - - // Report raw size in bytes of mechanism object. - - virtual std::size_t object_sizeof() const = 0; - - // Event delivery, given event queue state: - - virtual void nrn_state() {}; - virtual void nrn_current() {}; - virtual void deliver_events(deliverable_event_stream::state) {}; - virtual void write_ions() {}; }; } // namespace multicore diff --git a/arbor/backends/multicore/partition_by_constraint.hpp b/arbor/backends/multicore/partition_by_constraint.hpp index dc7d4d926f07757a133f377f73317c19aa12653d..34b28b379c9b095cd43694f5be45be4bc8b71357 100644 --- a/arbor/backends/multicore/partition_by_constraint.hpp +++ b/arbor/backends/multicore/partition_by_constraint.hpp @@ -99,8 +99,8 @@ bool constexpr is_constraint_stronger(index_constraint a, index_constraint b) { (a==index_constraint::independent && b==index_constraint::contiguous); } -template <typename T> -bool compatible_index_constraints(T& node_index, T& ion_index, unsigned simd_width){ +template <typename T, typename U> +bool compatible_index_constraints(const T& node_index, const U& ion_index, unsigned simd_width){ for (unsigned i = 0; i < node_index.size(); i+= simd_width) { auto nc = idx_constraint(&node_index[i], simd_width); auto ic = idx_constraint(&ion_index[i], simd_width); diff --git a/arbor/backends/multicore/stimulus.cpp b/arbor/backends/multicore/stimulus.cpp index d17d2505c3413220b54f3838edde5ab1f0ea0f6b..d6a8c139a816c1c97739a52e62415dc9dd66d7ca 100644 --- a/arbor/backends/multicore/stimulus.cpp +++ b/arbor/backends/multicore/stimulus.cpp @@ -18,9 +18,9 @@ public: mechanismKind kind() const override { return ::arb::mechanismKind::point; } mechanism_ptr clone() const override { return mechanism_ptr(new stimulus()); } - void nrn_init() override {} - void nrn_state() override {} - void nrn_current() override { + void init() override {} + void advance_state() override {} + void compute_currents() override { size_type n = size(); for (size_type i=0; i<n; ++i) { auto cv = node_index_[i]; @@ -33,7 +33,7 @@ public: } } void write_ions() override {} - void deliver_events(deliverable_event_stream::state events) override {} + void apply_events(deliverable_event_stream::state events) override {} protected: std::size_t object_sizeof() const override { return sizeof(*this); } diff --git a/arbor/include/arbor/mechanism.hpp b/arbor/include/arbor/mechanism.hpp index c1361f3021bcab12b5c55bcaeb322ccb80982e94..1c18457e905faff259746b8f5c994105a576ad6f 100644 --- a/arbor/include/arbor/mechanism.hpp +++ b/arbor/include/arbor/mechanism.hpp @@ -9,7 +9,7 @@ namespace arb { -enum class mechanismKind {point, density, revpot}; +enum class mechanismKind { point, density, revpot }; class mechanism; using mechanism_ptr = std::unique_ptr<mechanism>; @@ -20,6 +20,10 @@ using concrete_mech_ptr = std::unique_ptr<concrete_mechanism<B>>; class mechanism { public: + using value_type = fvm_value_type; + using index_type = fvm_index_type; + using size_type = fvm_size_type; + mechanism() = default; mechanism(const mechanism&) = delete; @@ -49,8 +53,11 @@ public: // Non-global parameters can be set post-instantiation: virtual void set_parameter(const std::string& key, const std::vector<fvm_value_type>& values) = 0; + // Peek into state variable + virtual fvm_value_type* field_data(const std::string& var) = 0; + // Simulation interfaces: - virtual void initialize() = 0; + virtual void initialize() {}; virtual void update_state() {}; virtual void update_current() {}; virtual void deliver_events() {}; @@ -60,12 +67,12 @@ public: virtual ~mechanism() = default; // Per-cell group identifier for an instantiated mechanism. - unsigned mechanism_id() const { return mechanism_id_; } + unsigned mechanism_id() const { return mechanism_id_; } protected: // Per-cell group identifier for an instantiation of a mechanism; set by // concrete_mechanism<B>::instantiate() - unsigned mechanism_id_ = -1; + unsigned mechanism_id_ = -1; }; // Backend-specific implementations provide mechanisms that are derived from `concrete_mechanism<Backend>`, @@ -96,14 +103,65 @@ struct mechanism_overrides { std::unordered_map<std::string, std::string> ion_rebind; }; +struct ion_state_view { + fvm_value_type* current_density; + fvm_value_type* reversal_potential; + fvm_value_type* internal_concentration; + fvm_value_type* external_concentration; + fvm_value_type* ionic_charge; +}; + template <typename Backend> class concrete_mechanism: public mechanism { public: using backend = Backend; - // Instantiation: allocate per-instance state; set views/pointers to shared data. - virtual void instantiate(unsigned id, typename backend::shared_state&, const mechanism_overrides&, const mechanism_layout&) = 0; -}; + virtual void instantiate(unsigned id, typename backend::shared_state&, const mechanism_overrides&, const mechanism_layout&) = 0; +protected: + using deliverable_event_stream = typename backend::deliverable_event_stream; + using iarray = typename backend::iarray; + + // Generated mechanism field, global and ion table lookup types. + // First component is name, second is pointer to corresponing member in + // the mechanism's parameter pack, or for field_default_table, + // the scalar value used to initialize the field. + using global_table_entry = std::pair<const char*, value_type*>; + using mechanism_global_table = std::vector<global_table_entry>; + + using state_table_entry = std::pair<const char*, value_type**>; + using mechanism_state_table = std::vector<state_table_entry>; + + using field_table_entry = std::pair<const char*, value_type**>; + using mechanism_field_table = std::vector<field_table_entry>; + + using field_default_entry = std::pair<const char*, value_type>; + using mechanism_field_default_table = std::vector<field_default_entry>; + + using ion_state_entry = std::pair<const char*, ion_state_view*>; + using mechanism_ion_state_table = std::vector<ion_state_entry>; + + using ion_index_entry = std::pair<const char*, index_type**>; + using mechanism_ion_index_table = std::vector<ion_index_entry>; + + // Generated mechanisms must implement the following methods + + // Member tables: introspection into derived mechanism fields, views etc. + // Default implementations correspond to no corresponding fields/globals/ions. + virtual mechanism_field_table field_table() { return {}; } + virtual mechanism_field_default_table field_default_table() { return {}; } + virtual mechanism_global_table global_table() { return {}; } + virtual mechanism_state_table state_table() { return {}; } + virtual mechanism_ion_state_table ion_state_table() { return {}; } + virtual mechanism_ion_index_table ion_index_table() { return {}; } + + virtual void advance_state() {}; + virtual void compute_currents() {}; + virtual void apply_events(typename deliverable_event_stream::state) {}; + virtual void write_ions() {}; + virtual void init() {}; + // Report raw size in bytes of mechanism object. + virtual std::size_t object_sizeof() const = 0; +}; } // namespace arb diff --git a/arbor/util/rangeutil.hpp b/arbor/util/rangeutil.hpp index c8db7115b94bdb7ec70dd61677b92cbe3903b5e6..2d3d481fa42c902a16848b285ce4275a32cd61c2 100644 --- a/arbor/util/rangeutil.hpp +++ b/arbor/util/rangeutil.hpp @@ -431,6 +431,24 @@ auto foldl(BinOp f, Acc a, Seq&& seq) { return a; } +// Copy elements from source sequence into destination sequence, +// and fill the remaining elements of the destination sequence +// with the given fill value. +// +// Assumes that the iterators for these sequences are at least +// forward iterators. +template <typename Source, typename Dest, typename Fill> +void copy_extend(const Source& source, Dest&& dest, const Fill& fill) { + using std::begin; + using std::end; + + auto dest_n = std::size(dest); + auto source_n = std::size(source); + + auto n = source_n<dest_n? source_n: dest_n; + auto tail = std::copy_n(begin(source), n, begin(dest)); + std::fill(tail, end(dest), fill); +} } // namespace util } // namespace arb diff --git a/modcc/mechanism.hpp b/modcc/mechanism.hpp index cf3aeddee7ade4971b85b200b641c6b2243435c4..8151816504cf82e429ce02ed66e424d24ca2c948 100644 --- a/modcc/mechanism.hpp +++ b/modcc/mechanism.hpp @@ -6,7 +6,7 @@ Abstract base class for all mechanisms This works well for the standard interface that is exported by all mechanisms, - i.e. nrn_jacobian(), nrn_current(), etc. The overhead of using virtual dispatch + i.e. compute_currents(), etc. The overhead of using virtual dispatch for such functions is negligable compared to the cost of the operations themselves. However, the friction between compile time and run time dispatch has to be considered carefully: diff --git a/modcc/module.cpp b/modcc/module.cpp index 4bc7a56516269c86100f175e6702388737e2711c..007397650afff58b074da1d5a0b223073f631706 100644 --- a/modcc/module.cpp +++ b/modcc/module.cpp @@ -206,9 +206,9 @@ bool Module::semantic() { } // All API methods are generated from statements in one of the special procedures - // defined in NMODL, e.g. the nrn_init() API call is based on the INITIAL block. + // defined in NMODL, e.g. the init() API call is based on the INITIAL block. // When creating an API method, the first task is to look up the source procedure, - // i.e. the INITIAL block for nrn_init(). This lambda takes care of this repetative + // i.e. the INITIAL block for init(). This lambda takes care of this repetative // lookup work, with error checking. auto make_empty_api_method = [this] (std::string const& name, std::string const& source_name) @@ -274,7 +274,7 @@ bool Module::semantic() { make_expression<BlockExpression>(Location{}, std::move(ion_assignments), false)); //......................................................................... - // nrn_init : based on the INITIAL block (i.e. the 'initial' procedure + // init : based on the INITIAL block (i.e. the 'initial' procedure //......................................................................... // insert an empty INITIAL block if none was defined in the .mod file. @@ -286,14 +286,14 @@ bool Module::semantic() { procedureKind::initial ); } - auto initial_api = make_empty_api_method("nrn_init", "initial"); + auto initial_api = make_empty_api_method("init", "initial"); auto api_init = initial_api.first; auto proc_init = initial_api.second; auto& init_body = api_init->body()->statements(); api_init->semantic(symbols_); - scope_ptr nrn_init_scope = api_init->scope(); + scope_ptr init_scope = api_init->scope(); for(auto& e : *proc_init->body()) { auto solve_expression = e->is_solve_statement(); @@ -314,14 +314,14 @@ bool Module::semantic() { return false; } - rewrite_body->semantic(nrn_init_scope); + rewrite_body->semantic(init_scope); rewrite_body->accept(solver.get()); } else if (solve_proc->kind() == procedureKind::kinetic && solve_expression->variant() == solverVariant::steadystate) { solver = std::make_unique<SparseSolverVisitor>(solverVariant::steadystate); auto rewrite_body = kinetic_rewrite(solve_proc->body()); - rewrite_body->semantic(nrn_init_scope); + rewrite_body->semantic(init_scope); rewrite_body->accept(solver.get()); } else { error("A SOLVE expression in an INITIAL block can only be used to solve a " @@ -342,7 +342,7 @@ bool Module::semantic() { solve_block = remove_unused_locals(solve_block->is_block()); - // Copy body into nrn_init. + // Copy body into init. for (auto &stmt: solve_block->is_block()->statements()) { init_body.emplace_back(stmt->clone()); } @@ -361,11 +361,11 @@ bool Module::semantic() { // Look in the symbol table for a procedure with the name "breakpoint". // This symbol corresponds to the BREAKPOINT block in the .mod file // There are two APIMethods generated from BREAKPOINT. - // The first is nrn_state, which is the first case handled below. - // The second is nrn_current, which is handled after this block - auto state_api = make_empty_api_method("nrn_state", "breakpoint"); + // The first is advance_state, which is the first case handled below. + // The second is compute_currents, which is handled after this block + auto state_api = make_empty_api_method("advance_state", "breakpoint"); auto api_state = state_api.first; - auto breakpoint = state_api.second; // implies we are building the `nrn_state()` method. + auto breakpoint = state_api.second; // implies we are building the `advance_state()` method. if(!breakpoint) { error("a BREAKPOINT block is required"); @@ -373,9 +373,9 @@ bool Module::semantic() { } api_state->semantic(symbols_); - scope_ptr nrn_state_scope = api_state->scope(); + scope_ptr advance_state_scope = api_state->scope(); - // Grab SOLVE statements, put them in `nrn_state` after translation. + // Grab SOLVE statements, put them in `advance_state` after translation. bool found_solve = false; std::set<std::string> solved_ids; @@ -420,14 +420,14 @@ bool Module::semantic() { solver = std::make_unique<SparseNonlinearSolverVisitor>(); } - rewrite_body->semantic(nrn_state_scope); + rewrite_body->semantic(advance_state_scope); rewrite_body->accept(solver.get()); } else if (deriv->kind()==procedureKind::linear) { solver = std::make_unique<LinearSolverVisitor>(state_vars); auto rewrite_body = linear_rewrite(deriv->body(), state_vars); - rewrite_body->semantic(nrn_state_scope); + rewrite_body->semantic(advance_state_scope); rewrite_body->accept(solver.get()); } else { @@ -452,12 +452,13 @@ bool Module::semantic() { } // May have now redundant local variables; remove these first. - solve_block->semantic(nrn_state_scope); + solve_block->semantic(advance_state_scope); solve_block = remove_unused_locals(solve_block->is_block()); - // Copy body into nrn_state. + // Copy body into advance_state. for (auto& stmt: solve_block->is_block()->statements()) { api_state->body()->statements().push_back(std::move(stmt)); + } } else { @@ -478,10 +479,10 @@ bool Module::semantic() { api_state->semantic(symbols_); //.......................................................... - // nrn_current : update contributions to currents + // compute_currents : update contributions to currents //.......................................................... - NrnCurrentRewriter nrn_current_rewriter; - breakpoint->accept(&nrn_current_rewriter); + NrnCurrentRewriter compute_currents_rewriter; + breakpoint->accept(&compute_currents_rewriter); for (auto& s: breakpoint->body()->statements()) { if(s->is_assignment() && !state_vars.empty()) { @@ -491,18 +492,18 @@ bool Module::semantic() { } } - auto nrn_current_block = nrn_current_rewriter.as_block(); - if (!nrn_current_block) { - append_errors(nrn_current_rewriter.errors()); + auto compute_currents_block = compute_currents_rewriter.as_block(); + if (!compute_currents_block) { + append_errors(compute_currents_rewriter.errors()); return false; } - symbols_["nrn_current"] = + symbols_["compute_currents"] = make_symbol<APIMethod>( - breakpoint->location(), "nrn_current", + breakpoint->location(), "compute_currents", std::vector<expression_ptr>(), - constant_simplify(nrn_current_block)); - symbols_["nrn_current"]->semantic(symbols_); + constant_simplify(compute_currents_block)); + symbols_["compute_currents"]->semantic(symbols_); if (has_symbol("net_receive", symbolKind::procedure)) { auto net_rec_api = make_empty_api_method("net_rec_api", "net_receive"); diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index 2d4991c75e2355c28a6b9248b0b6b026ecd66a61..332cd70cf156d8d1c3812b81a9f2d86ceffa55f9 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -109,18 +109,18 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { NetReceiveExpression* net_receive = find_net_receive(module_); PostEventExpression* post_event = find_post_event(module_); - APIMethod* init_api = find_api_method(module_, "nrn_init"); - APIMethod* state_api = find_api_method(module_, "nrn_state"); - APIMethod* current_api = find_api_method(module_, "nrn_current"); + APIMethod* init_api = find_api_method(module_, "init"); + APIMethod* state_api = find_api_method(module_, "advance_state"); + APIMethod* current_api = find_api_method(module_, "compute_currents"); APIMethod* write_ions_api = find_api_method(module_, "write_ions"); bool with_simd = opt.simd.abi!=simd_spec::none; // init_api, state_api, current_api methods are mandatory: - assert_has_scope(init_api, "nrn_init"); - assert_has_scope(state_api, "nrn_state"); - assert_has_scope(current_api, "nrn_current"); + assert_has_scope(init_api, "init"); + assert_has_scope(state_api, "advance_state"); + assert_has_scope(current_api, "compute_currents"); auto vars = local_module_variables(module_); auto ion_deps = module_.ion_deps(); @@ -172,9 +172,6 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { "\n" "using backend = ::arb::multicore::backend;\n" "using base = ::arb::multicore::mechanism;\n" - "using value_type = base::value_type;\n" - "using size_type = base::size_type;\n" - "using index_type = base::index_type;\n" "using ::arb::math::exprelr;\n" "using ::arb::math::safeinv;\n" "using ::std::abs;\n" @@ -246,14 +243,14 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { "::arb::mechanismKind kind() const override { return " << module_kind_str(module_) << "; }\n" "::arb::mechanism_ptr clone() const override { return ::arb::mechanism_ptr(new " << class_name << "()); }\n" "\n" - "void nrn_init() override;\n" - "void nrn_state() override;\n" - "void nrn_current() override;\n" + "void init() override;\n" + "void advance_state() override;\n" + "void compute_currents() override;\n" "void write_ions() override;\n"; net_receive && out << - "void deliver_events(deliverable_event_stream::state events) override;\n" - "void net_receive(int i_, value_type weight);\n"; + "void apply_events(deliverable_event_stream::state events) override;\n" + "void net_receive(int i_, ::arb::fvm_value_type weight);\n"; post_event && out << "void post_event() override;\n"; @@ -342,14 +339,14 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { "private:\n" << indent; for (const auto& scalar: vars.scalars) { - out << "value_type " << scalar->name() << " = " << as_c_double(scalar->value()) << ";\n"; + out << "::arb::fvm_value_type " << scalar->name() << " = " << as_c_double(scalar->value()) << ";\n"; } for (const auto& array: vars.arrays) { - out << "value_type* " << array->name() << ";\n"; + out << "::arb::fvm_value_type* " << array->name() << ";\n"; } for (const auto& dep: ion_deps) { - out << "ion_state_view " << ion_state_field(dep.name) << ";\n"; - out << "iarray " << ion_state_index(dep.name) << ";\n"; + out << "::arb::ion_state_view " << ion_state_field(dep.name) << ";\n"; + out << "::arb::fvm_index_type* " << ion_state_index(dep.name) << ";\n"; } for (auto proc: normal_procedures(module_)) { @@ -371,12 +368,12 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { "return ::arb::concrete_mech_ptr<backend>(new " << class_name << "());\n" << popindent << "}\n\n"; - // Nrn methods: + // Interface methods: if (net_receive) { const std::string weight_arg = net_receive->args().empty() ? "weight" : net_receive->args().front()->is_argument()->name(); out << - "void " << class_name << "::deliver_events(deliverable_event_stream::state events) {\n" << indent << + "void " << class_name << "::apply_events(deliverable_event_stream::state events) {\n" << indent << "auto ncell = events.n_streams();\n" "for (size_type c = 0; c<ncell; ++c) {\n" << indent << "auto begin = events.begin_marked(c);\n" @@ -387,7 +384,7 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { "}\n" << popindent << "}\n" "\n" - "void " << class_name << "::net_receive(int i_, value_type " << weight_arg << ") {\n" << indent << + "void " << class_name << "::net_receive(int i_, ::arb::fvm_value_type " << weight_arg << ") {\n" << indent << cprint(net_receive->body()) << popindent << "}\n\n"; } @@ -420,17 +417,17 @@ std::string emit_cpp_source(const Module& module_, const printer_options& opt) { } }; - out << "void " << class_name << "::nrn_init() {\n" << indent; + out << "void " << class_name << "::init() {\n" << indent; emit_body(init_api); out << popindent << "}\n\n"; - out << "void " << class_name << "::nrn_state() {\n" << indent; + out << "void " << class_name << "::advance_state() {\n" << indent; out << profiler_enter("advance_integrate_state"); emit_body(state_api); out << profiler_leave(); out << popindent << "}\n\n"; - out << "void " << class_name << "::nrn_current() {\n" << indent; + out << "void " << class_name << "::compute_currents() {\n" << indent; out << profiler_enter("advance_integrate_current"); emit_body(current_api); out << profiler_leave(); @@ -493,7 +490,7 @@ void CPrinter::visit(BlockExpression* block) { if (!block->is_nested()) { auto locals = pure_locals(block->scope()); if (!locals.empty()) { - out_ << "value_type "; + out_ << "::arb::fvm_value_type "; io::separator sep(", "); for (auto local: locals) { out_ << sep << local->name(); @@ -517,7 +514,7 @@ static std::string index_i_name(const std::string& index_var) { void emit_procedure_proto(std::ostream& out, ProcedureExpression* e, const std::string& qualified) { out << "void " << qualified << (qualified.empty()? "": "::") << e->name() << "(int i_"; for (auto& arg: e->args()) { - out << ", value_type " << arg->is_argument()->name(); + out << ", ::arb::fvm_value_type " << arg->is_argument()->name(); } out << ")"; } @@ -538,7 +535,7 @@ namespace { } void emit_state_read(std::ostream& out, LocalVariable* local) { - out << "value_type " << cprint(local) << " = "; + out << "::arb::fvm_value_type " << cprint(local) << " = "; if (local->is_read()) { auto d = decode_indexed_variable(local->external_variable()); @@ -722,7 +719,7 @@ void SimdPrinter::visit(BlockExpression* block) { } void emit_simd_procedure_proto(std::ostream& out, ProcedureExpression* e, const std::string& qualified) { - out << "void " << qualified << (qualified.empty()? "": "::") << e->name() << "(index_type i_"; + out << "void " << qualified << (qualified.empty()? "": "::") << e->name() << "(::arb::fvm_index_type i_"; for (auto& arg: e->args()) { out << ", const simd_value& " << arg->is_argument()->name(); } @@ -731,7 +728,7 @@ void emit_simd_procedure_proto(std::ostream& out, ProcedureExpression* e, const void emit_masked_simd_procedure_proto(std::ostream& out, ProcedureExpression* e, const std::string& qualified) { out << "void " << qualified << (qualified.empty()? "": "::") << e->name() - << "(index_type i_, simd_mask mask_input_"; + << "(::arb::fvm_index_type i_, simd_mask mask_input_"; for (auto& arg: e->args()) { out << ", const simd_value& " << arg->is_argument()->name(); } @@ -870,8 +867,8 @@ void emit_simd_index_initialize(std::ostream& out, const std::list<index_prop>& out << "auto " << index_i_name(index.source_var) << " = " << index.source_var << "[" << index.index_name << "];\n"; break; default: - out << "auto " << index_i_name(index.source_var) << " = simd_cast<simd_index>(indirect(" << index.source_var - << ".data() + " << index.index_name << ", simd_width_));\n"; + out << "auto " << index_i_name(index.source_var) << " = simd_cast<simd_index>(indirect(&" << index.source_var + << "[0] + " << index.index_name << ", simd_width_));\n"; break; } } else { @@ -929,7 +926,7 @@ void emit_simd_for_loop_per_constraint(std::ostream& out, BlockExpression* body, << ".size(); i_++) {\n" << indent; - out << "index_type index_ = index_constraints_." << underlying_constraint_name << "[i_];\n"; + out << "::arb::fvm_index_type index_ = index_constraints_." << underlying_constraint_name << "[i_];\n"; if (requires_weight) { out << "simd_value w_;\n" << "assign(w_, indirect((weight_+index_), simd_width_));\n"; diff --git a/modcc/printer/gpuprinter.cpp b/modcc/printer/gpuprinter.cpp index 31148051a7eed63a3a7b097cea096c2b729fb218..2892528209f12894d70c251fb6d33c4c94f2ca68 100644 --- a/modcc/printer/gpuprinter.cpp +++ b/modcc/printer/gpuprinter.cpp @@ -77,13 +77,13 @@ std::string emit_gpu_cpp_source(const Module& module_, const printer_options& op emit_common_defs(out, module_); out << - "void " << class_name << "_nrn_init_(" << ppack_name << "&);\n" - "void " << class_name << "_nrn_state_(" << ppack_name << "&);\n" - "void " << class_name << "_nrn_current_(" << ppack_name << "&);\n" + "void " << class_name << "_init_(" << ppack_name << "&);\n" + "void " << class_name << "_advance_state_(" << ppack_name << "&);\n" + "void " << class_name << "_compute_currents_(" << ppack_name << "&);\n" "void " << class_name << "_write_ions_(" << ppack_name << "&);\n"; net_receive && out << - "void " << class_name << "_deliver_events_(int mech_id, " + "void " << class_name << "_apply_events_(int mech_id, " << ppack_name << "&, deliverable_event_stream_state events);\n"; post_event && out << @@ -102,22 +102,22 @@ std::string emit_gpu_cpp_source(const Module& module_, const printer_options& op "::arb::mechanismKind kind() const override { return " << module_kind_str(module_) << "; }\n" "::arb::mechanism_ptr clone() const override { return ::arb::mechanism_ptr(new " << class_name << "()); }\n" "\n" - "void nrn_init() override {\n" << indent << - class_name << "_nrn_init_(pp_);\n" << popindent << + "void init() override {\n" << indent << + class_name << "_init_(pp_);\n" << popindent << "}\n\n" - "void nrn_state() override {\n" << indent << - class_name << "_nrn_state_(pp_);\n" << popindent << + "void advance_state() override {\n" << indent << + class_name << "_advance_state_(pp_);\n" << popindent << "}\n\n" - "void nrn_current() override {\n" << indent << - class_name << "_nrn_current_(pp_);\n" << popindent << + "void compute_currents() override {\n" << indent << + class_name << "_compute_currents_(pp_);\n" << popindent << "}\n\n" "void write_ions() override {\n" << indent << class_name << "_write_ions_(pp_);\n" << popindent << "}\n\n"; net_receive && out << - "void deliver_events(deliverable_event_stream_state events) override {\n" << indent << - class_name << "_deliver_events_(mechanism_id_, pp_, events);\n" << popindent << + "void apply_events(deliverable_event_stream_state events) override {\n" << indent << + class_name << "_apply_events_(mechanism_id_, pp_, events);\n" << popindent << "}\n\n"; post_event && out << @@ -226,14 +226,14 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt NetReceiveExpression* net_receive = find_net_receive(module_); PostEventExpression* post_event = find_post_event(module_); - APIMethod* init_api = find_api_method(module_, "nrn_init"); - APIMethod* state_api = find_api_method(module_, "nrn_state"); - APIMethod* current_api = find_api_method(module_, "nrn_current"); + APIMethod* init_api = find_api_method(module_, "init"); + APIMethod* state_api = find_api_method(module_, "advance_state"); + APIMethod* current_api = find_api_method(module_, "compute_currents"); APIMethod* write_ions_api = find_api_method(module_, "write_ions"); - assert_has_scope(init_api, "nrn_init"); - assert_has_scope(state_api, "nrn_state"); - assert_has_scope(current_api, "nrn_current"); + assert_has_scope(init_api, "init"); + assert_has_scope(state_api, "advance_state"); + assert_has_scope(current_api, "compute_currents"); io::pfxstringstream out; @@ -250,11 +250,6 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt out << "\n" << namespace_declaration_open(ns_components) << "\n"; - out << - "using value_type = ::arb::gpu::mechanism_ppack_base::value_type;\n" - "using index_type = ::arb::gpu::mechanism_ppack_base::index_type;\n" - "\n"; - emit_common_defs(out, module_); // Print the CUDA code and kernels: @@ -274,7 +269,7 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt << "void " << e->name() << "(" << ppack_name << " params_, int tid_"; for(auto& arg: e->args()) { - out << ", value_type " << arg->is_argument()->name(); + out << ", ::arb::fvm_value_type " << arg->is_argument()->name(); } out << ") {\n" << indent << cuprint(e->body()) @@ -306,7 +301,7 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt if (net_receive) { const std::string weight_arg = net_receive->args().empty() ? "weight" : net_receive->args().front()->is_argument()->name(); out << "__global__\n" - << "void deliver_events(int mech_id_, " << ppack_name << " params_, " + << "void apply_events(int mech_id_, " << ppack_name << " params_, " << "deliverable_event_stream_state events) {\n" << indent << "auto tid_ = threadIdx.x + blockDim.x*blockIdx.x;\n" << "auto const ncell_ = events.n;\n\n" @@ -369,13 +364,13 @@ std::string emit_gpu_cu_source(const Module& module_, const printer_options& opt emit_api_wrapper(write_ions_api); net_receive && out - << "void " << class_name << "_deliver_events_(" + << "void " << class_name << "_apply_events_(" << "int mech_id, " << ppack_name << "& p, deliverable_event_stream_state events) {\n" << indent << "auto n = events.n;\n" << "unsigned block_dim = 128;\n" << "unsigned grid_dim = ::arb::gpu::impl::block_count(n, block_dim);\n" - << "deliver_events<<<grid_dim, block_dim>>>(mech_id, p, events);\n" + << "apply_events<<<grid_dim, block_dim>>>(mech_id, p, events);\n" << popindent << "}\n\n"; post_event && out @@ -405,14 +400,14 @@ void emit_common_defs(std::ostream& out, const Module& module_) { out << "struct " << ppack_name << ": ::arb::gpu::mechanism_ppack_base {\n" << indent; for (const auto& scalar: vars.scalars) { - out << "value_type " << scalar->name() << " = " << as_c_double(scalar->value()) << ";\n"; + out << "::arb::fvm_value_type " << scalar->name() << " = " << as_c_double(scalar->value()) << ";\n"; } for (const auto& array: vars.arrays) { - out << "value_type* " << array->name() << ";\n"; + out << "::arb::fvm_value_type* " << array->name() << ";\n"; } for (const auto& dep: ion_deps) { - out << "ion_state_view " << ion_state_field(dep.name) << ";\n"; - out << "const index_type* " << ion_state_index(dep.name) << ";\n"; + out << "::arb::ion_state_view " << ion_state_field(dep.name) << ";\n"; + out << "::arb::fvm_index_type* " << ion_state_index(dep.name) << ";\n"; } out << popindent << "};\n\n"; @@ -502,7 +497,7 @@ namespace { } void emit_state_read_cu(std::ostream& out, LocalVariable* local) { - out << "value_type " << cuprint(local) << " = "; + out << "::arb::fvm_value_type " << cuprint(local) << " = "; if (local->is_read()) { auto d = decode_indexed_variable(local->external_variable()); diff --git a/test/unit-modcc/test_printers.cpp b/test/unit-modcc/test_printers.cpp index 81bdc785d1cedb20f2b6a7bbe35051be1dcaa842..e88518751057db6e9576bf51367e08f48e50c680 100644 --- a/test/unit-modcc/test_printers.cpp +++ b/test/unit-modcc/test_printers.cpp @@ -124,7 +124,7 @@ TEST(CPrinter, proc_body) { " htau = 1500\n" "}" , - "value_type k;\n" + "::arb::fvm_value_type k;\n" "minf[i_] = 1.0-1.0/(1.0+exp((v-k)/k));\n" "hinf[i_] = 1.0/(1.0+exp((v-k)/k));\n" "mtau[i_] = 0.5;\n" diff --git a/test/unit/mech_private_field_access.cpp b/test/unit/mech_private_field_access.cpp index bc86c9bced448c62a2c5aa2e25579bb3833f2f9c..a0af0953eef0f5d575dc2341062572adc25f2558 100644 --- a/test/unit/mech_private_field_access.cpp +++ b/test/unit/mech_private_field_access.cpp @@ -18,7 +18,7 @@ using field_table_type = std::vector<std::pair<const char*, fvm_value_type**>>; // Multicore mechanisms: -ACCESS_BIND(field_table_type (multicore::mechanism::*)(), multicore_field_table_ptr, &multicore::mechanism::field_table) +ACCESS_BIND(field_table_type (concrete_mechanism<multicore::backend>::*)(), multicore_field_table_ptr, &concrete_mechanism<multicore::backend>::field_table) std::vector<fvm_value_type> mechanism_field(multicore::mechanism* m, const std::string& key) { auto opt_ptr = util::value_by_key((m->*multicore_field_table_ptr)(), key); @@ -31,7 +31,7 @@ std::vector<fvm_value_type> mechanism_field(multicore::mechanism* m, const std:: // GPU mechanisms: #ifdef ARB_GPU_ENABLED -ACCESS_BIND(field_table_type (gpu::mechanism::*)(), gpu_field_table_ptr, &gpu::mechanism::field_table) +ACCESS_BIND(field_table_type (concrete_mechanism<gpu::backend>::*)(), gpu_field_table_ptr, &concrete_mechanism<gpu::backend>::field_table) std::vector<fvm_value_type> mechanism_field(gpu::mechanism* m, const std::string& key) { auto opt_ptr = util::value_by_key((m->*gpu_field_table_ptr)(), key); diff --git a/test/unit/test_fvm_lowered.cpp b/test/unit/test_fvm_lowered.cpp index 19650d69ebfd4adeca375bb3d7bf4320302bb0e7..e4a5635c38e8c2fcc42e3896efa43f4984238df1 100644 --- a/test/unit/test_fvm_lowered.cpp +++ b/test/unit/test_fvm_lowered.cpp @@ -67,22 +67,22 @@ arb::mechanism* find_mechanism(fvm_cell& fvcell, int index) { using mechanism_global_table = std::vector<std::pair<const char*, arb::fvm_value_type*>>; using mechanism_field_table = std::vector<std::pair<const char*, arb::fvm_value_type**>>; -using mechanism_ion_index_table = std::vector<std::pair<const char*, backend::iarray*>>; +using mechanism_ion_index_table = std::vector<std::pair<const char*, arb::fvm_index_type**>>; ACCESS_BIND(\ - mechanism_global_table (arb::multicore::mechanism::*)(),\ + mechanism_global_table (arb::concrete_mechanism<arb::multicore::backend>::*)(), \ private_global_table_ptr,\ - &arb::multicore::mechanism::global_table) + &arb::concrete_mechanism<arb::multicore::backend>::global_table) ACCESS_BIND(\ - mechanism_field_table (arb::multicore::mechanism::*)(),\ + mechanism_field_table (arb::concrete_mechanism<arb::multicore::backend>::*)(),\ private_field_table_ptr,\ - &arb::multicore::mechanism::field_table) + &arb::concrete_mechanism<arb::multicore::backend>::field_table) ACCESS_BIND(\ - mechanism_ion_index_table (arb::multicore::mechanism::*)(),\ + mechanism_ion_index_table (arb::concrete_mechanism<arb::multicore::backend>::*)(),\ private_ion_index_table_ptr,\ - &arb::multicore::mechanism::ion_index_table) + &arb::concrete_mechanism<arb::multicore::backend>::ion_index_table) using namespace arb; diff --git a/test/unit/test_mechanisms.cpp b/test/unit/test_mechanisms.cpp index f281e6cb41834353274f41414db63f546c44dd18..0ab79bd4bac099a481dc395478afff094008e656 100644 --- a/test/unit/test_mechanisms.cpp +++ b/test/unit/test_mechanisms.cpp @@ -77,7 +77,7 @@ void mech_update(T* mech, unsigned num_iters) { std::map<ionKind, ion<typename T::backend>> ions; mech->set_params(); - mech->nrn_init(); + mech->init(); for (auto ion_kind : ion_kinds()) { auto ion_indexes = util::make_copy<std::vector<typename T::size_type>>( mech->node_index_ diff --git a/test/unit/test_mechcat.cpp b/test/unit/test_mechcat.cpp index 2809b31d8342230ee249f5004cd300a4484f9556..60c9dac7662d1c63acf70c6069ef1e78b9e50f35 100644 --- a/test/unit/test_mechcat.cpp +++ b/test/unit/test_mechcat.cpp @@ -79,6 +79,9 @@ struct common_impl: concrete_mechanism<B> { void set_parameter(const std::string& key, const std::vector<fvm_value_type>& vs) override {} + fvm_value_type* field_data(const std::string& var) override { return nullptr; } + std::size_t object_sizeof() const override { return sizeof(*this); } + void initialize() override {} void update_state() override {} void update_current() override {} @@ -98,8 +101,16 @@ std::string ion_binding(const std::unique_ptr<concrete_mechanism<B>>& mech, cons return impl.ion_bindings_.count(ion)? impl.ion_bindings_.at(ion): ""; } +struct foo_stream_state {}; + +struct foo_stream { + using state = foo_stream_state; +}; struct foo_backend { + using iarray = std::vector<fvm_index_type>; + using deliverable_event_stream = foo_stream; + struct shared_state { std::unordered_map<std::string, fvm_value_type> overrides; std::unordered_map<std::string, std::string> ions = { @@ -115,7 +126,15 @@ struct foo_backend { using foo_mechanism = common_impl<foo_backend>; +struct bar_stream_state {}; + +struct bar_stream { + using state = bar_stream_state; +}; + struct bar_backend { + using iarray = std::vector<fvm_index_type>; + using deliverable_event_stream = bar_stream; struct shared_state { std::unordered_map<std::string, fvm_value_type> overrides; std::unordered_map<std::string, std::string> ions = {