From b1b584fa3ec8f5ccde404819c412198619c55138 Mon Sep 17 00:00:00 2001 From: Sam Yates <halfflat@gmail.com> Date: Mon, 17 Jun 2019 11:17:27 +0200 Subject: [PATCH] Implement mechanism ion rebinding. (#786) * Make global parameters and ion rebindings part of the instantiate interface, rather than insist that all concrete mechanisms implement these as methods. * Mechanism catalogue instance() returns a pair, comprising the concrete mechanism for the requested backend, together with the override data. * Extend catalgoue derive() method to take a list of old ion name -> new ion name remappings for a mechanism. * Add exceptions for ion remapping errors, and check for these errors. * Add convenience function for reparameterizing a mechanism with a single ion dependency over other ions. (This will be used for the future nernst pre-supplied mechanism.) * Add unit tests for: chained renamings of ion names across multiple derivations; correct shared state ion assignment after renamings; ion remapping exceptions; parameterize_over_ion. --- arbor/arbexcept.cpp | 10 ++ arbor/backends/gpu/mechanism.cpp | 35 +++-- arbor/backends/gpu/mechanism.hpp | 4 +- arbor/backends/multicore/mechanism.cpp | 36 +++--- arbor/backends/multicore/mechanism.hpp | 4 +- arbor/fvm_lowered_cell_impl.hpp | 12 +- arbor/include/arbor/arbexcept.hpp | 7 + arbor/include/arbor/mechanism.hpp | 39 ++++-- arbor/include/arbor/mechcat.hpp | 36 ++++-- arbor/include/arbor/mechinfo.hpp | 14 +- arbor/include/arbor/util/optional.hpp | 13 +- arbor/mechcat.cpp | 129 +++++++++++++++++-- test/simple_recipes.hpp | 4 + test/unit/mod/test_ca_read_valence.mod | 4 +- test/unit/test_fvm_lowered.cpp | 69 +++++++--- test/unit/test_mech_temperature.cpp | 10 +- test/unit/test_mechcat.cpp | 172 ++++++++++++++++++++----- test/unit/test_synapses.cpp | 8 +- 18 files changed, 461 insertions(+), 145 deletions(-) diff --git a/arbor/arbexcept.cpp b/arbor/arbexcept.cpp index 10da063a..45a8fb27 100644 --- a/arbor/arbexcept.cpp +++ b/arbor/arbexcept.cpp @@ -72,6 +72,16 @@ invalid_parameter_value::invalid_parameter_value(const std::string& mech_name, c value(value) {} +invalid_ion_remap::invalid_ion_remap(const std::string& mech_name): + arbor_exception(pprintf("invalid ion parameter remapping for mechanism {}", mech_name)) +{} + +invalid_ion_remap::invalid_ion_remap(const std::string& mech_name, const std::string& from_ion = "", const std::string& to_ion = ""): + arbor_exception(pprintf("invalid ion parameter remapping for mechanism {}: {} -> {}", mech_name, from_ion, to_ion)), + from_ion(from_ion), + to_ion(to_ion) +{} + no_such_implementation::no_such_implementation(const std::string& mech_name): arbor_exception(pprintf("missing implementation for mechanism {} in catalogue", mech_name)), mech_name(mech_name) diff --git a/arbor/backends/gpu/mechanism.cpp b/arbor/backends/gpu/mechanism.cpp index 9d32f5ff..8d7c1aff 100644 --- a/arbor/backends/gpu/mechanism.cpp +++ b/arbor/backends/gpu/mechanism.cpp @@ -54,8 +54,22 @@ memory::const_device_view<T> device_view(const T* ptr, std::size_t n) { void mechanism::instantiate(unsigned id, backend::shared_state& shared, - const layout& pos_data) + const mechanism_overrides& overrides, + const mechanism_layout& pos_data) { + // Assign global scalar parameters: + + for (auto &kv: overrides.globals) { + if (auto opt_ptr = value_by_key(global_table(), kv.first)) { + // Take reference to corresponding derived (generated) mechanism value member. + value_type& global = *opt_ptr.value(); + global = kv.second; + } + else { + throw arbor_internal_error("multicore/mechanism: no such mechanism global"); + } + } + mult_in_place_ = !pos_data.multiplicity.empty(); mechanism_id_ = id; width_ = pos_data.cv.size(); @@ -84,7 +98,9 @@ void mechanism::instantiate(unsigned id, num_ions_ = ion_state_tbl.size(); for (auto i: ion_state_tbl) { - util::optional<ion_state&> oion = value_by_key(shared.ion_data, i.first); + auto ion_binding = value_by_key(overrides.ion_rebind, i.first).value_or(i.first); + + util::optional<ion_state&> oion = value_by_key(shared.ion_data, ion_binding); if (!oion) { throw arbor_internal_error("gpu/mechanism: mechanism holds ion with no corresponding shared state"); } @@ -137,7 +153,9 @@ void mechanism::instantiate(unsigned id, arb_assert(num_ions_==ion_index_tbl.size()); for (auto i: make_span(0, num_ions_)) { - util::optional<ion_state&> oion = value_by_key(shared.ion_data, ion_index_tbl[i].first); + auto ion_binding = value_by_key(overrides.ion_rebind, ion_index_tbl[i].first).value_or(ion_index_tbl[i].first); + + util::optional<ion_state&> oion = value_by_key(shared.ion_data, ion_binding); if (!oion) { throw arbor_internal_error("gpu/mechanism: mechanism holds ion with no corresponding shared state"); } @@ -176,17 +194,6 @@ void mechanism::set_parameter(const std::string& key, const std::vector<fvm_valu } } -void mechanism::set_global(const std::string& key, fvm_value_type value) { - if (auto opt_ptr = value_by_key(global_table(), key)) { - // Take reference to corresponding derived (generated) mechanism value member. - value_type& global = *opt_ptr.value(); - global = value; - } - else { - throw arbor_internal_error("gpu/mechanism: no such mechanism global"); - } -} - void multiply_in_place(fvm_value_type* s, const fvm_index_type* p, int n); void mechanism::initialize() { diff --git a/arbor/backends/gpu/mechanism.hpp b/arbor/backends/gpu/mechanism.hpp index c9c242d2..42c7e57a 100644 --- a/arbor/backends/gpu/mechanism.hpp +++ b/arbor/backends/gpu/mechanism.hpp @@ -46,7 +46,7 @@ public: return s; } - void instantiate(fvm_size_type id, backend::shared_state& shared, const layout& w) override; + void instantiate(fvm_size_type id, backend::shared_state& shared, const mechanism_overrides&, const mechanism_layout&) override; void deliver_events() override { // Delegate to derived class, passing in event queue state. @@ -55,8 +55,6 @@ public: void set_parameter(const std::string& key, const std::vector<fvm_value_type>& values) override; - void set_global(const std::string& key, fvm_value_type value) override; - void initialize() override; protected: diff --git a/arbor/backends/multicore/mechanism.cpp b/arbor/backends/multicore/mechanism.cpp index ab101ab7..117fdc25 100644 --- a/arbor/backends/multicore/mechanism.cpp +++ b/arbor/backends/multicore/mechanism.cpp @@ -61,9 +61,22 @@ void copy_extend(const Source& source, Dest&& dest, const Fill& fill) { // these past-the-end values are given a weight of zero, and any corresponding // indices into shared state point to the last valid slot. -void mechanism::instantiate(unsigned id, backend::shared_state& shared, const layout& pos_data) { +void mechanism::instantiate(unsigned id, backend::shared_state& shared, const mechanism_overrides& overrides, const mechanism_layout& pos_data) { using util::make_range; + // Assign global scalar parameters: + + for (auto &kv: overrides.globals) { + if (auto opt_ptr = value_by_key(global_table(), kv.first)) { + // Take reference to corresponding derived (generated) mechanism value member. + value_type& global = *opt_ptr.value(); + global = kv.second; + } + else { + throw arbor_internal_error("multicore/mechanism: no such mechanism global"); + } + } + mult_in_place_ = !pos_data.multiplicity.empty(); util::padded_allocator<> pad(shared.alignment); mechanism_id_ = id; @@ -85,7 +98,9 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const la auto ion_state_tbl = ion_state_table(); n_ion_ = ion_state_tbl.size(); for (auto i: ion_state_tbl) { - util::optional<ion_state&> oion = value_by_key(shared.ion_data, i.first); + auto ion_binding = value_by_key(overrides.ion_rebind, i.first).value_or(i.first); + + util::optional<ion_state&> oion = value_by_key(shared.ion_data, ion_binding); if (!oion) { throw arbor_internal_error("multicore/mechanism: mechanism holds ion with no corresponding shared state"); } @@ -145,7 +160,9 @@ void mechanism::instantiate(unsigned id, backend::shared_state& shared, const la } for (auto i: ion_index_table()) { - util::optional<ion_state&> oion = value_by_key(shared.ion_data, i.first); + auto ion_binding = value_by_key(overrides.ion_rebind, i.first).value_or(i.first); + + util::optional<ion_state&> oion = value_by_key(shared.ion_data, ion_binding); if (!oion) { throw arbor_internal_error("multicore/mechanism: mechanism holds ion with no corresponding shared state"); } @@ -180,23 +197,12 @@ void mechanism::set_parameter(const std::string& key, const std::vector<fvm_valu } } -void mechanism::set_global(const std::string& key, fvm_value_type value) { - if (auto opt_ptr = value_by_key(global_table(), key)) { - // Take reference to corresponding derived (generated) mechanism value member. - value_type& global = *opt_ptr.value(); - global = value; - } - else { - throw arbor_internal_error("multicore/mechanism: no such mechanism global"); - } -} - void mechanism::initialize() { nrn_init(); auto states = state_table(); - if(mult_in_place_) { + if (mult_in_place_) { for (auto& state: states) { for (std::size_t j = 0; j < width_; ++j) { (*state.second)[j] *= multiplicity_[j]; diff --git a/arbor/backends/multicore/mechanism.hpp b/arbor/backends/multicore/mechanism.hpp index 937432ac..bcd5c46e 100644 --- a/arbor/backends/multicore/mechanism.hpp +++ b/arbor/backends/multicore/mechanism.hpp @@ -55,7 +55,7 @@ public: return s; } - void instantiate(fvm_size_type id, backend::shared_state& shared, const layout& w) override; + void instantiate(fvm_size_type id, backend::shared_state& shared, const mechanism_overrides&, const mechanism_layout&) override; void initialize() override; void deliver_events() override { @@ -65,8 +65,6 @@ public: void set_parameter(const std::string& key, const std::vector<fvm_value_type>& values) override; - void set_global(const std::string& key, fvm_value_type value) override; - protected: size_type width_ = 0; // Instance width (number of CVs/sites) size_type width_padded_ = 0; // Width rounded up to multiple of pad/alignment. diff --git a/arbor/fvm_lowered_cell_impl.hpp b/arbor/fvm_lowered_cell_impl.hpp index 60394aeb..10c5c36a 100644 --- a/arbor/fvm_lowered_cell_impl.hpp +++ b/arbor/fvm_lowered_cell_impl.hpp @@ -396,7 +396,7 @@ void fvm_lowered_cell_impl<B>::initialize( unsigned data_alignment = util::max_value( util::transform_view(keys(mech_data.mechanisms), - [&](const std::string& name) { return mech_instance(name)->data_alignment(); })); + [&](const std::string& name) { return mech_instance(name).mech->data_alignment(); })); state_ = std::make_unique<shared_state>(num_intdoms, cv_to_intdom, gj_vector, data_alignment? data_alignment: 1u); @@ -420,7 +420,7 @@ void fvm_lowered_cell_impl<B>::initialize( auto& config = m.second; unsigned mech_id = mechanisms_.size(); - mechanism::layout layout; + mechanism_layout layout; layout.cv = config.cv; layout.multiplicity = config.multiplicity; layout.weight.resize(layout.cv.size()); @@ -462,13 +462,13 @@ void fvm_lowered_cell_impl<B>::initialize( } } - auto mech = mech_instance(name); - mech->instantiate(mech_id, *state_, layout); + auto minst = mech_instance(name); + minst.mech->instantiate(mech_id, *state_, minst.overrides, layout); for (auto& pv: config.param_values) { - mech->set_parameter(pv.first, pv.second); + minst.mech->set_parameter(pv.first, pv.second); } - mechanisms_.push_back(mechanism_ptr(mech.release())); + mechanisms_.push_back(mechanism_ptr(minst.mech.release())); } // Collect detectors, probe handles. diff --git a/arbor/include/arbor/arbexcept.hpp b/arbor/include/arbor/arbexcept.hpp index c30904e0..4f7d0d4d 100644 --- a/arbor/include/arbor/arbexcept.hpp +++ b/arbor/include/arbor/arbexcept.hpp @@ -95,6 +95,13 @@ struct invalid_parameter_value: arbor_exception { double value; }; +struct invalid_ion_remap: arbor_exception { + explicit invalid_ion_remap(const std::string& mech_name); + invalid_ion_remap(const std::string& mech_name, const std::string& from_ion, const std::string& to_ion); + std::string from_ion; + std::string to_ion; +}; + struct no_such_implementation: arbor_exception { explicit no_such_implementation(const std::string& mech_name); std::string mech_name; diff --git a/arbor/include/arbor/mechanism.hpp b/arbor/include/arbor/mechanism.hpp index 0fe32db4..b6ed27f9 100644 --- a/arbor/include/arbor/mechanism.hpp +++ b/arbor/include/arbor/mechanism.hpp @@ -23,15 +23,6 @@ public: mechanism() = default; mechanism(const mechanism&) = delete; - // Description of layout of mechanism across cell group: used as parameter in - // `concrete_mechanism<B>::instantiate` (v.i.) - struct layout { - std::vector<fvm_index_type> cv; // Maps in-instance index to CV index. - std::vector<fvm_value_type> weight; // Maps in-instance index to compartment contribution. - std::vector<fvm_index_type> multiplicity; // Number of logical point processes at in-instance index; - // if empty point processes are not coalesced, all multipliers are 1 - }; - // Return fingerprint of mechanism dynamics source description for validation/replication. virtual const mechanism_fingerprint& fingerprint() const = 0; @@ -55,9 +46,6 @@ public: // copy any state. virtual mechanism_ptr clone() const = 0; - // Parameter setting - virtual void set_global(const std::string& param, fvm_value_type value) = 0; - // Non-global parameters can be set post-instantiation: virtual void set_parameter(const std::string& key, const std::vector<fvm_value_type>& values) = 0; @@ -81,6 +69,31 @@ protected: // Backend-specific implementations provide mechanisms that are derived from `concrete_mechanism<Backend>`, // likely via an intermediate class that captures common behaviour for that backend. +// +// `concrete_mechanism` provides the `instantiate` method, which takes the backend-specific shared state, +// together with a layout derived from the discretization, and any global parameter overrides. + +struct mechanism_layout { + // Maps in-instance index to CV index. + std::vector<fvm_index_type> cv; + + // Maps in-instance index to compartment contribution. + std::vector<fvm_value_type> weight; + + // Number of logical point processes at in-instance index; + // if empty, point processes are not coalesced and all multipliers are 1. + std::vector<fvm_index_type> multiplicity; +}; + +struct mechanism_overrides { + // Global scalar parameters (any value down-conversion to fvm_value_type is the + // responsibility of the concrete mechanism). + std::unordered_map<std::string, double> globals; + + // Ion renaming: keys are ion dependency names as + // reported by the mechanism info. + std::unordered_map<std::string, std::string> ion_rebind; +}; template <typename Backend> class concrete_mechanism: public mechanism { @@ -88,7 +101,7 @@ public: using backend = Backend; // Instantiation: allocate per-instance state; set views/pointers to shared data. - virtual void instantiate(unsigned id, typename backend::shared_state&, const layout&) = 0; + virtual void instantiate(unsigned id, typename backend::shared_state&, const mechanism_overrides&, const mechanism_layout&) = 0; }; diff --git a/arbor/include/arbor/mechcat.hpp b/arbor/include/arbor/mechcat.hpp index 981112f4..76a03df4 100644 --- a/arbor/include/arbor/mechcat.hpp +++ b/arbor/include/arbor/mechcat.hpp @@ -14,7 +14,7 @@ // 1. Collection of mechanism metadata indexed by name. // // 2. A further hierarchy of 'derived' mechanisms, that allow specialization of -// global parameters and implementations. +// global parameters, ion bindings, and implementations. // // 3. A map taking mechanism names x back-end class -> mechanism implementation // prototype object. @@ -66,19 +66,31 @@ public: const mechanism_fingerprint& fingerprint(const std::string& name) const; // Construct a schema for a mechanism derived from an existing entry, - // with a sequence of overrides for global scalar parameter settings. - void derive(const std::string& name, const std::string& parent, const std::vector<std::pair<std::string, double>>& global_params); + // with a sequence of overrides for global scalar parameter settings + // and a set of ion renamings. + void derive(const std::string& name, const std::string& parent, + const std::vector<std::pair<std::string, double>>& global_params, + const std::vector<std::pair<std::string, std::string>>& ion_remap = {}); // Remove mechanism from catalogue, together with any derived. void remove(const std::string& name); // Clone the implementation associated with name (search derivation hierarchy starting from - // most derived) and set global parameters according to derivations. + // most derived) and return together with any global overrides. template <typename B> - std::unique_ptr<concrete_mechanism<B>> instance(const std::string& name) const { - mechanism_ptr mech = instance_impl(std::type_index(typeid(B)), name); + struct cat_instance { + std::unique_ptr<concrete_mechanism<B>> mech; + mechanism_overrides overrides; + }; + + template <typename B> + cat_instance<B> instance(const std::string& name) const { + auto mech = instance_impl(std::type_index(typeid(B)), name); - return std::unique_ptr<concrete_mechanism<B>>(dynamic_cast<concrete_mechanism<B>*>(mech.release())); + return cat_instance<B>{ + std::unique_ptr<concrete_mechanism<B>>(dynamic_cast<concrete_mechanism<B>*>(mech.first.release())), + std::move(mech.second) + }; } // Associate a concrete (prototype) mechanism for a given back-end B with a (possibly derived) @@ -100,7 +112,8 @@ private: struct derivation { std::string parent; - string_map<value_type> globals; // global overrides relative to parent + string_map<value_type> globals; // global overrides relative to parent + string_map<std::string> ion_remap; // ion name remap overrides relative to parent mechanism_info_ptr derived_info; }; @@ -111,13 +124,18 @@ private: string_map<std::unordered_map<std::type_index, mechanism_ptr>> impl_map_; // Concrete-type erased helper methods. - mechanism_ptr instance_impl(std::type_index, const std::string&) const; + std::pair<mechanism_ptr, mechanism_overrides> instance_impl(std::type_index, const std::string&) const; void register_impl(std::type_index, const std::string&, mechanism_ptr); // Perform copy and prototype clone from other catalogue (overwrites all entries). void copy_impl(const mechanism_catalogue&); }; +// Convenience routine for deriving a single-ion dependency mechanism 'name' +// over a new ion 'ion', and adding it to the catalogue as 'name/ion'. + +void parameterize_over_ion(mechanism_catalogue&, const std::string& name, const std::string& ion); + // Reference to global default mechanism catalogue. const mechanism_catalogue& global_default_catalogue(); diff --git a/arbor/include/arbor/mechinfo.hpp b/arbor/include/arbor/mechinfo.hpp index 1badd32c..f3cc748a 100644 --- a/arbor/include/arbor/mechinfo.hpp +++ b/arbor/include/arbor/mechinfo.hpp @@ -30,17 +30,17 @@ struct mechanism_field_spec { }; struct ion_dependency { - bool write_concentration_int; - bool write_concentration_ext; + bool write_concentration_int = false; + bool write_concentration_ext = false; - bool read_reversal_potential; - bool write_reversal_potential; + bool read_reversal_potential = false; + bool write_reversal_potential = false; - bool read_ion_charge; + bool read_ion_charge = false; // Support for NMODL 'VALENCE n' construction. - bool verify_ion_charge; - int expected_ion_charge; + bool verify_ion_charge = false; + int expected_ion_charge = 0; }; // A hash of the mechanism dynamics description is used to ensure that offline-compiled diff --git a/arbor/include/arbor/util/optional.hpp b/arbor/include/arbor/util/optional.hpp index e15d5add..c314f3a8 100644 --- a/arbor/include/arbor/util/optional.hpp +++ b/arbor/include/arbor/util/optional.hpp @@ -354,14 +354,17 @@ struct optional<X&>: detail::optional_base<X&> { return assert_set(), ref(); } - template <typename T> - X& value_or(T& alternative) { - return set? ref(): static_cast<X&>(alternative); + X& value_or(X& alternative) & { + return set? ref(): alternative; + } + + const X& value_or(const X& alternative) const& { + return set? ref(): alternative; } template <typename T> - const X& value_or(const T& alternative) const { - return set? ref(): static_cast<const X&>(alternative); + const X value_or(const T& alternative) && { + return set? ref(): static_cast<X>(alternative); } }; diff --git a/arbor/mechcat.cpp b/arbor/mechcat.cpp index 26d21b93..af1eb807 100644 --- a/arbor/mechcat.cpp +++ b/arbor/mechcat.cpp @@ -8,6 +8,49 @@ #include "util/maputil.hpp" +/* Notes on implementation: + * + * The catalogue maintains the following data: + * + * 1. impl_map_ + * + * This contains the mapping between mechanism names and concrete mechanisms + * for a specific backend that have been registered with + * register_implementation(). + * + * It is a two-level map, first indexed by name, and then by the back-end + * type (using std::type_index). + * + * 2. info_map_ + * + * Contains the mechanism_info metadata for a mechanism, as given to the + * catalogue via the add() method. + * + * 3. derived_map_ + * + * A 'derived' mechanism is one that shares the same metadata schema as its + * parent, but with possible overrides to its global scalar parameters and + * to the bindings of its ion names. + * + * The derived_map_ entry for a given mechanism gives: the parent mechanism + * from which it is derived (which might also be a derived mechanism); the + * set of changes to global parameters relative to its parent; the set of + * ion rebindings relative to its parent; and an updated copy of the + * mechanism_info metadata that reflects those changes. + * + * The derived_map_ and info_map_ together constitute a forest: info_map_ has + * an entry for each un-derived mechanism in the catalogue, while for any + * derived mechanism, the parent field in derived_map_ provides the parent in + * the derivation tree, or a root mechanism which is catalogued in info_map_. + * + * When an instance of the mechanism is requested from the catalogue, the + * instance_impl_() function walks up the derivation tree to find the first + * entry which has an associated implementation. It then accumulates the set of + * global parameter and ion overrides that need to be applied, starting from + * the top-most (least-derived) ancestor and working down to the requested derived + * mechanism. + */ + namespace arb { using util::value_by_key; @@ -45,7 +88,10 @@ const mechanism_fingerprint& mechanism_catalogue::fingerprint(const std::string& throw no_such_mechanism(name); } -void mechanism_catalogue::derive(const std::string& name, const std::string& parent, const std::vector<std::pair<std::string, double>>& global_params) { +void mechanism_catalogue::derive(const std::string& name, const std::string& parent, + const std::vector<std::pair<std::string, double>>& global_params, + const std::vector<std::pair<std::string, std::string>>& ion_remap_vec) +{ if (has(name)) { throw duplicate_mechanism(name); } @@ -54,9 +100,12 @@ void mechanism_catalogue::derive(const std::string& name, const std::string& par throw no_such_mechanism(parent); } - derivation deriv = {parent, {}, nullptr}; + string_map<std::string> ion_remap_map(ion_remap_vec.begin(), ion_remap_vec.end()); + derivation deriv = {parent, {}, ion_remap_map, nullptr}; mechanism_info_ptr info = mechanism_info_ptr(new mechanism_info((*this)[deriv.parent])); + // Update global parameter values in info for derived mechanism. + for (const auto& kv: global_params) { const auto& param = kv.first; const auto& value = kv.second; @@ -74,6 +123,35 @@ void mechanism_catalogue::derive(const std::string& name, const std::string& par info->globals.at(param).default_value = value; } + for (const auto& kv: ion_remap_vec) { + if (!info->ions.count(kv.first)) { + throw invalid_ion_remap(name, kv.first, kv.second); + } + } + + // Update ion dependencies in info to reflect the requested ion remapping. + + string_map<ion_dependency> new_ions; + for (const auto& kv: info->ions) { + if (auto new_ion = value_by_key(ion_remap_map, kv.first)) { + if (!new_ions.insert({*new_ion, kv.second}).second) { + throw invalid_ion_remap(name, kv.first, *new_ion); + } + } + else { + if (!new_ions.insert(kv).second) { + // (find offending remap to report in exception) + for (const auto& entry: ion_remap_map) { + if (entry.second==kv.first) { + throw invalid_ion_remap(name, kv.first, entry.second); + } + } + throw arbor_internal_error("inconsistent catalogue ion remap state"); + } + } + } + info->ions = std::move(new_ions); + deriv.derived_info = std::move(info); derived_map_[name] = std::move(deriv); } @@ -108,7 +186,10 @@ void mechanism_catalogue::remove(const std::string& name) { } while (n_delete>0); } -std::unique_ptr<mechanism> mechanism_catalogue::instance_impl(std::type_index tidx, const std::string& name) const { +std::pair<std::unique_ptr<mechanism>, mechanism_overrides> +mechanism_catalogue::instance_impl(std::type_index tidx, const std::string& name) const { + std::pair<std::unique_ptr<mechanism>, mechanism_overrides> mech; + // Find implementation associated with this name or its closest ancestor. auto impl_name = name; @@ -131,18 +212,38 @@ std::unique_ptr<mechanism> mechanism_catalogue::instance_impl(std::type_index ti } } - std::unique_ptr<mechanism> mech = prototype->clone(); + mech.first = prototype->clone(); - auto apply_globals = [this](auto& self, const std::string& name, mechanism* mptr) -> void { + // Recurse up the derivation tree to find the most distant ancestor; + // accumulate global parameter settings and ion remappings down to the + // requested mechanism. + + auto apply_globals = [this](auto& self, const std::string& name, mechanism_overrides& over) -> void { if (auto p = value_by_key(derived_map_, name)) { - self(self, p->parent, mptr); + self(self, p->parent, over); for (auto& kv: p->globals) { - mptr->set_global(kv.first, kv.second); + over.globals[kv.first] = kv.second; + } + + if (!p->ion_remap.empty()) { + string_map<std::string> new_rebind = p->ion_remap; + for (auto& kv: over.ion_rebind) { + if (auto opt_v = value_by_key(p->ion_remap, kv.second)) { + new_rebind.erase(kv.second); + new_rebind[kv.first] = *opt_v; + } + } + for (auto& kv: over.ion_rebind) { + if (!value_by_key(p->ion_remap, kv.second)) { + new_rebind[kv.first] = kv.second; + } + } + std::swap(new_rebind, over.ion_rebind); } } }; - apply_globals(apply_globals, name, mech.get()); + apply_globals(apply_globals, name, mech.second); return mech; } @@ -165,7 +266,7 @@ void mechanism_catalogue::copy_impl(const mechanism_catalogue& other) { derived_map_.clear(); for (const auto& kv: other.derived_map_) { const derivation& v = kv.second; - derived_map_[kv.first] = {v.parent, v.globals, make_unique<mechanism_info>(*v.derived_info)}; + derived_map_[kv.first] = {v.parent, v.globals, v.ion_remap, make_unique<mechanism_info>(*v.derived_info)}; } impl_map_.clear(); @@ -179,4 +280,14 @@ void mechanism_catalogue::copy_impl(const mechanism_catalogue& other) { } } +void parameterize_over_ion(mechanism_catalogue& cat, const std::string& name, const std::string& ion) { + mechanism_info info = cat[name]; + if (info.ions.size()!=1) { + throw invalid_ion_remap(name); + } + + std::string from_ion = info.ions.begin()->first; + cat.derive(name+"/"+ion, name, {}, {{from_ion, ion}}); +} + } // namespace arb diff --git a/test/simple_recipes.hpp b/test/simple_recipes.hpp index b92ae176..e6da789a 100644 --- a/test/simple_recipes.hpp +++ b/test/simple_recipes.hpp @@ -50,6 +50,10 @@ public: return catalogue_; } + void add_ion(const char* name, int charge, double iconc, double econc) { + cell_gprop_.ion_default[name] = {charge, iconc, econc}; + } + protected: std::unordered_map<cell_gid_type, std::vector<probe_info>> probes_; cable_cell_global_properties cell_gprop_; diff --git a/test/unit/mod/test_ca_read_valence.mod b/test/unit/mod/test_ca_read_valence.mod index 97f40707..23bd9ecf 100644 --- a/test/unit/mod/test_ca_read_valence.mod +++ b/test/unit/mod/test_ca_read_valence.mod @@ -10,11 +10,11 @@ PARAMETER {} ASSIGNED {} STATE { - record_zca + record_z } INITIAL { - record_zca = zca + record_z = zca } BREAKPOINT { diff --git a/test/unit/test_fvm_lowered.cpp b/test/unit/test_fvm_lowered.cpp index 011939b0..a693f0fa 100644 --- a/test/unit/test_fvm_lowered.cpp +++ b/test/unit/test_fvm_lowered.cpp @@ -55,6 +55,11 @@ arb::mechanism* find_mechanism(fvm_cell& fvcell, const std::string& name) { return nullptr; } +arb::mechanism* find_mechanism(fvm_cell& fvcell, int index) { + auto& mechs = fvcell.*private_mechanisms_ptr; + return index<(int)mechs.size()? mechs[index].get(): nullptr; +} + // Access to mechanism-internal data: using mechanism_global_table = std::vector<std::pair<const char*, arb::fvm_value_type*>>; @@ -490,31 +495,61 @@ TEST(fvm_lowered, derived_mechs) { // Test that ion charge is propagated into mechanism variable. TEST(fvm_lowered, read_valence) { - std::vector<cable_cell> cells(1); - - cable_cell& c = cells[0]; - c.add_soma(6.0)->add_mechanism("test_ca_read_valence"); - - cable1d_recipe rec(cells); - rec.catalogue() = make_unit_test_catalogue(); + execution_context context; std::vector<target_handle> targets; std::vector<fvm_index_type> cell_to_intdom; probe_association_map<probe_handle> probe_map; - execution_context context; - fvm_cell fvcell(context); - fvcell.initialize({0}, rec, cell_to_intdom, targets, probe_map); + { + std::vector<cable_cell> cells(1); + + cable_cell& c = cells[0]; + auto soma = c.add_soma(6.0); + soma->add_mechanism("test_ca_read_valence"); + + cable1d_recipe rec(cells); + rec.catalogue() = make_unit_test_catalogue(); + + fvm_cell fvcell(context); + fvcell.initialize({0}, rec, cell_to_intdom, targets, probe_map); + + // test_ca_read_valence initialization should write ca ion valence + // to state variable 'record_zca': + + auto mech_ptr = dynamic_cast<multicore::mechanism*>(find_mechanism(fvcell, "test_ca_read_valence")); + auto opt_record_z_ptr = util::value_by_key((mech_ptr->*private_field_table_ptr)(), "record_z"s); + + ASSERT_TRUE(opt_record_z_ptr); + auto& record_z = *opt_record_z_ptr.value(); + ASSERT_EQ(2.0, record_z[0]); + } + + { + // Check ion renaming. + std::vector<cable_cell> cells(1); + + cable_cell& c = cells[0]; + auto soma = c.add_soma(6.0); + soma->add_mechanism("cr_read_valence"); + + cable1d_recipe rec(cells); + rec.catalogue() = make_unit_test_catalogue(); - // test_ca_read_valence initialization should write ca ion valence - // to state variable 'record_zca': + rec.catalogue().derive("na_read_valence", "test_ca_read_valence", {}, {{"ca", "na"}}); + rec.catalogue().derive("cr_read_valence", "na_read_valence", {}, {{"na", "cr"}}); + rec.add_ion("cr", 7, 0, 0); - auto mech_ptr = dynamic_cast<multicore::mechanism*>(find_mechanism(fvcell, "test_ca_read_valence")); - auto opt_record_zca_ptr = util::value_by_key((mech_ptr->*private_field_table_ptr)(), "record_zca"s); + fvm_cell fvcell(context); + fvcell.initialize({0}, rec, cell_to_intdom, targets, probe_map); + + auto cr_mech_ptr = dynamic_cast<multicore::mechanism*>(find_mechanism(fvcell, 0)); + auto cr_opt_record_z_ptr = util::value_by_key((cr_mech_ptr->*private_field_table_ptr)(), "record_z"s); - ASSERT_TRUE(opt_record_zca_ptr); - auto& record_zca = *opt_record_zca_ptr.value(); - ASSERT_EQ(2.0, record_zca[0]); + ASSERT_TRUE(cr_opt_record_z_ptr); + auto& cr_record_z = *cr_opt_record_z_ptr.value(); + ASSERT_EQ(7.0, cr_record_z[0]); + } } // Test area-weighted linear combination of ion species concentrations diff --git a/test/unit/test_mech_temperature.cpp b/test/unit/test_mech_temperature.cpp index 3fcb4caa..0fae8d43 100644 --- a/test/unit/test_mech_temperature.cpp +++ b/test/unit/test_mech_temperature.cpp @@ -25,17 +25,21 @@ void run_celsius_test() { std::vector<fvm_index_type> cv_to_intdom(ncv, 0); std::vector<fvm_gap_junction> gj = {}; - auto celsius_test = cat.instance<backend>("celsius_test"); + auto instance = cat.instance<backend>("celsius_test"); + auto& celsius_test = instance.mech; + auto shared_state = std::make_unique<typename backend::shared_state>( ncell, cv_to_intdom, gj, celsius_test->data_alignment()); - mechanism::layout layout; + mechanism_layout layout; + mechanism_overrides overrides; + layout.weight.assign(ncv, 1.); for (fvm_size_type i = 0; i<ncv; ++i) { layout.cv.push_back(i); } - celsius_test->instantiate(0, *shared_state, layout); + celsius_test->instantiate(0, *shared_state, overrides, layout); double temperature_K = 300.; double temperature_C = temperature_K-273.15; diff --git a/test/unit/test_mechcat.cpp b/test/unit/test_mechcat.cpp index 44df5f86..db77676d 100644 --- a/test/unit/test_mechcat.cpp +++ b/test/unit/test_mechcat.cpp @@ -34,7 +34,7 @@ mechanism_info burble_info = { {"xyzzy", {field_kind::global, "mV", 5.1, -20, 20.}}}, {}, {}, - {}, + {{"x", {}}}, "burbleprint" }; @@ -43,7 +43,7 @@ mechanism_info fleeb_info = { {"norf", {field_kind::global, "mGy", 0.1, 0, 5000.}}}, {}, {}, - {}, + {{"a", {}}, {"b", {}}, {"c", {}}, {"d", {}}}, "fleebprint" }; @@ -51,13 +51,22 @@ mechanism_info fleeb_info = { template <typename B> struct common_impl: concrete_mechanism<B> { - void instantiate(fvm_size_type id, typename B::shared_state& state, const mechanism::layout& l) override { + void instantiate(fvm_size_type id, typename B::shared_state& state, const mechanism_overrides& o, const mechanism_layout& l) override { width_ = l.cv.size(); // Write mechanism global values to shared state to test instatiation call and catalogue global // variable overrides. - for (auto& kv: overrides_) { + for (auto& kv: o.globals) { state.overrides.insert(kv); } + + for (auto& ion: mech_ions) { + if (o.ion_rebind.count(ion)) { + ion_bindings_[ion] = state.ions.at(o.ion_rebind.at(ion)); + } + else { + ion_bindings_[ion] = state.ions.at(ion); + } + } } std::size_t memory() const override { return 10u; } @@ -65,23 +74,37 @@ struct common_impl: concrete_mechanism<B> { void set_parameter(const std::string& key, const std::vector<fvm_value_type>& vs) override {} - void set_global(const std::string& key, fvm_value_type v) override { - overrides_[key] = v; - } - void initialize() override {} void nrn_state() override {} void nrn_current() override {} void deliver_events() override {} void write_ions() override {} - std::unordered_map<std::string, fvm_value_type> overrides_; std::size_t width_ = 0; + + std::vector<std::string> mech_ions; + + std::unordered_map<std::string, std::string> ion_bindings_; }; +template <typename B> +std::string ion_binding(const std::unique_ptr<concrete_mechanism<B>>& mech, const char* ion) { + const common_impl<B>& impl = dynamic_cast<const common_impl<B>&>(*mech.get()); + return impl.ion_bindings_.count(ion)? impl.ion_bindings_.at(ion): ""; +} + + struct foo_backend { struct shared_state { std::unordered_map<std::string, fvm_value_type> overrides; + std::unordered_map<std::string, std::string> ions = { + { "a", "foo_ion_a" }, + { "b", "foo_ion_b" }, + { "c", "foo_ion_c" }, + { "d", "foo_ion_d" }, + { "e", "foo_ion_e" }, + { "f", "foo_ion_f" } + }; }; }; @@ -90,6 +113,14 @@ using foo_mechanism = common_impl<foo_backend>; struct bar_backend { struct shared_state { std::unordered_map<std::string, fvm_value_type> overrides; + std::unordered_map<std::string, std::string> ions = { + { "a", "bar_ion_a" }, + { "b", "bar_ion_b" }, + { "c", "bar_ion_c" }, + { "d", "bar_ion_d" }, + { "e", "bar_ion_e" }, + { "f", "bar_ion_f" } + }; }; }; @@ -98,6 +129,10 @@ using bar_mechanism = common_impl<bar_backend>; // Fleeb implementations: struct fleeb_foo: foo_mechanism { + fleeb_foo() { + this->mech_ions = {"a", "b", "c", "d"}; + } + const mechanism_fingerprint& fingerprint() const override { static mechanism_fingerprint hash = "fleebprint"; return hash; @@ -109,6 +144,10 @@ struct fleeb_foo: foo_mechanism { }; struct special_fleeb_foo: foo_mechanism { + special_fleeb_foo() { + this->mech_ions = {"a", "b", "c", "d"}; + } + const mechanism_fingerprint& fingerprint() const override { static mechanism_fingerprint hash = "fleebprint"; return hash; @@ -120,6 +159,10 @@ struct special_fleeb_foo: foo_mechanism { }; struct fleeb_bar: bar_mechanism { + fleeb_bar() { + this->mech_ions = {"a", "b", "c", "d"}; + } + const mechanism_fingerprint& fingerprint() const override { static mechanism_fingerprint hash = "fleebprint"; return hash; @@ -174,9 +217,10 @@ mechanism_catalogue build_fake_catalogue() { // Add derived versions with global overrides: - cat.derive("fleeb1", "fleeb", {{"plugh", 1.0}}); + cat.derive("fleeb1", "fleeb", {{"plugh", 1.0}}, {{"a", "b"}, {"b", "a"}}); cat.derive("special_fleeb", "fleeb", {{"plugh", 2.0}}); cat.derive("fleeb2", "special_fleeb", {{"norf", 11.0}}); + cat.derive("fleeb3", "fleeb1", {}, {{"b", "c"}, {"c", "b"}}); cat.derive("bleeble", "burble", {{"quux", 10.}, {"xyzzy", -20.}}); // Attach implementations: @@ -252,33 +296,33 @@ TEST(mechcat, instance) { // All fleebs on the bar backend have the same implementation: - auto fleeb_bar_mech = cat.instance<bar_backend>("fleeb"); - auto fleeb1_bar_mech = cat.instance<bar_backend>("fleeb1"); - auto special_fleeb_bar_mech = cat.instance<bar_backend>("special_fleeb"); - auto fleeb2_bar_mech = cat.instance<bar_backend>("fleeb2"); + auto fleeb_bar_inst = cat.instance<bar_backend>("fleeb"); + auto fleeb1_bar_inst = cat.instance<bar_backend>("fleeb1"); + auto special_fleeb_bar_inst = cat.instance<bar_backend>("special_fleeb"); + auto fleeb2_bar_inst = cat.instance<bar_backend>("fleeb2"); - EXPECT_EQ(typeid(fleeb_bar), typeid(*fleeb_bar_mech.get())); - EXPECT_EQ(typeid(fleeb_bar), typeid(*fleeb1_bar_mech.get())); - EXPECT_EQ(typeid(fleeb_bar), typeid(*special_fleeb_bar_mech.get())); - EXPECT_EQ(typeid(fleeb_bar), typeid(*fleeb2_bar_mech.get())); + EXPECT_EQ(typeid(fleeb_bar), typeid(*fleeb_bar_inst.mech.get())); + EXPECT_EQ(typeid(fleeb_bar), typeid(*fleeb1_bar_inst.mech.get())); + EXPECT_EQ(typeid(fleeb_bar), typeid(*special_fleeb_bar_inst.mech.get())); + EXPECT_EQ(typeid(fleeb_bar), typeid(*fleeb2_bar_inst.mech.get())); - EXPECT_EQ("fleeb"s, fleeb2_bar_mech->internal_name()); + EXPECT_EQ("fleeb"s, fleeb2_bar_inst.mech->internal_name()); // special_fleeb and fleeb2 (deriving from special_fleeb) have a specialized // implementation: - auto fleeb_foo_mech = cat.instance<foo_backend>("fleeb"); - auto fleeb1_foo_mech = cat.instance<foo_backend>("fleeb1"); - auto special_fleeb_foo_mech = cat.instance<foo_backend>("special_fleeb"); - auto fleeb2_foo_mech = cat.instance<foo_backend>("fleeb2"); + auto fleeb_foo_inst = cat.instance<foo_backend>("fleeb"); + auto fleeb1_foo_inst = cat.instance<foo_backend>("fleeb1"); + auto special_fleeb_foo_inst = cat.instance<foo_backend>("special_fleeb"); + auto fleeb2_foo_inst = cat.instance<foo_backend>("fleeb2"); - EXPECT_EQ(typeid(fleeb_foo), typeid(*fleeb_foo_mech.get())); - EXPECT_EQ(typeid(fleeb_foo), typeid(*fleeb1_foo_mech.get())); - EXPECT_EQ(typeid(special_fleeb_foo), typeid(*special_fleeb_foo_mech.get())); - EXPECT_EQ(typeid(special_fleeb_foo), typeid(*fleeb2_foo_mech.get())); + EXPECT_EQ(typeid(fleeb_foo), typeid(*fleeb_foo_inst.mech.get())); + EXPECT_EQ(typeid(fleeb_foo), typeid(*fleeb1_foo_inst.mech.get())); + EXPECT_EQ(typeid(special_fleeb_foo), typeid(*special_fleeb_foo_inst.mech.get())); + EXPECT_EQ(typeid(special_fleeb_foo), typeid(*fleeb2_foo_inst.mech.get())); - EXPECT_EQ("fleeb"s, fleeb1_foo_mech->internal_name()); - EXPECT_EQ("special fleeb"s, fleeb2_foo_mech->internal_name()); + EXPECT_EQ("fleeb"s, fleeb1_foo_inst.mech->internal_name()); + EXPECT_EQ("special fleeb"s, fleeb2_foo_inst.mech->internal_name()); } TEST(mechcat, instantiate) { @@ -286,18 +330,76 @@ TEST(mechcat, instantiate) { // write its specialized global variables to shared state, but we do in // these tests for testing purposes. - mechanism::layout layout = {{0u, 1u, 2u}, {1., 2., 1.}, {1u, 1u, 1u}}; + mechanism_layout layout = {{0u, 1u, 2u}, {1., 2., 1.}, {1u, 1u, 1u}}; bar_backend::shared_state bar_state; auto cat = build_fake_catalogue(); - cat.instance<bar_backend>("fleeb")->instantiate(0, bar_state, layout); + auto fleeb = cat.instance<bar_backend>("fleeb"); + fleeb.mech->instantiate(0, bar_state, fleeb.overrides, layout); EXPECT_TRUE(bar_state.overrides.empty()); bar_state.overrides.clear(); - cat.instance<bar_backend>("fleeb2")->instantiate(0, bar_state, layout); + auto fleeb2 = cat.instance<bar_backend>("fleeb2"); + fleeb2.mech->instantiate(0, bar_state, fleeb2.overrides, layout); EXPECT_EQ(2.0, bar_state.overrides.at("plugh")); EXPECT_EQ(11.0, bar_state.overrides.at("norf")); + + // Check ion rebinding: + // fleeb1 should have ions 'a' and 'b' swapped; + // fleeb2 should swap 'b' and 'c' relative to fleeb1, so that + // 'b' maps to the state 'c' ion, 'c' maps to the state 'a' ion, + // and 'a' maps to the state 'b' ion. + + EXPECT_EQ("bar_ion_a", ion_binding(fleeb.mech, "a")); + EXPECT_EQ("bar_ion_b", ion_binding(fleeb.mech, "b")); + EXPECT_EQ("bar_ion_c", ion_binding(fleeb.mech, "c")); + EXPECT_EQ("bar_ion_d", ion_binding(fleeb.mech, "d")); + + auto fleeb3 = cat.instance<bar_backend>("fleeb3"); + fleeb3.mech->instantiate(0, bar_state, fleeb3.overrides, layout); + + foo_backend::shared_state foo_state; + auto fleeb1 = cat.instance<foo_backend>("fleeb1"); + fleeb1.mech->instantiate(0, foo_state, fleeb1.overrides, layout); + + EXPECT_EQ("foo_ion_b", ion_binding(fleeb1.mech, "a")); + EXPECT_EQ("foo_ion_a", ion_binding(fleeb1.mech, "b")); + EXPECT_EQ("foo_ion_c", ion_binding(fleeb1.mech, "c")); + EXPECT_EQ("foo_ion_d", ion_binding(fleeb1.mech, "d")); + + EXPECT_EQ("bar_ion_c", ion_binding(fleeb3.mech, "a")); + EXPECT_EQ("bar_ion_a", ion_binding(fleeb3.mech, "b")); + EXPECT_EQ("bar_ion_b", ion_binding(fleeb3.mech, "c")); + EXPECT_EQ("bar_ion_d", ion_binding(fleeb3.mech, "d")); +} + +TEST(mechcat, bad_ion_rename) { + auto cat = build_fake_catalogue(); + + // missing ion + EXPECT_THROW(cat.derive("ono", "fleeb", {}, {{"nosuchion", "x"}}), invalid_ion_remap); + + // two ions with the same name, the original 'b', and the renamed 'a' + EXPECT_THROW(cat.derive("alas", "fleeb", {}, {{"a", "b"}}), invalid_ion_remap); +} + +TEST(mechcat, parameterize_over_ion) { + auto cat = build_fake_catalogue(); + + // Can only use parametrize_over_ion with mechanisms that depend on exactly + // one ion. + + EXPECT_THROW(parameterize_over_ion(cat, "fleeb", "nope"), invalid_ion_remap); + + parameterize_over_ion(cat, "burble", "one"); + parameterize_over_ion(cat, "burble", "two"); + + auto b1 = cat["burble/one"]; + EXPECT_EQ("one", b1.ions.begin()->first); + + auto b2 = cat["burble/two"]; + EXPECT_EQ("two", b2.ions.begin()->first); } TEST(mechcat, copy) { @@ -306,10 +408,10 @@ TEST(mechcat, copy) { EXPECT_EQ(cat["fleeb2"], cat2["fleeb2"]); - auto fleeb2_instance = cat.instance<foo_backend>("fleeb2"); - auto fleeb2_instance2 = cat2.instance<foo_backend>("fleeb2"); + auto fleeb2_inst = cat.instance<foo_backend>("fleeb2"); + auto fleeb2_inst2 = cat2.instance<foo_backend>("fleeb2"); - EXPECT_EQ(typeid(*fleeb2_instance.get()), typeid(*fleeb2_instance.get())); + EXPECT_EQ(typeid(*fleeb2_inst.mech.get()), typeid(*fleeb2_inst2.mech.get())); } diff --git a/test/unit/test_synapses.cpp b/test/unit/test_synapses.cpp index 7b46ad4a..0efc50f5 100644 --- a/test/unit/test_synapses.cpp +++ b/test/unit/test_synapses.cpp @@ -79,10 +79,10 @@ TEST(synapses, syn_basic_state) { int num_comp = 4; int num_intdom = 1; - auto expsyn = unique_cast<multicore::mechanism>(global_default_catalogue().instance<backend>("expsyn")); + auto expsyn = unique_cast<multicore::mechanism>(global_default_catalogue().instance<backend>("expsyn").mech); ASSERT_TRUE(expsyn); - auto exp2syn = unique_cast<multicore::mechanism>(global_default_catalogue().instance<backend>("exp2syn")); + auto exp2syn = unique_cast<multicore::mechanism>(global_default_catalogue().instance<backend>("exp2syn").mech); ASSERT_TRUE(exp2syn); std::vector<fvm_gap_junction> gj = {}; @@ -98,8 +98,8 @@ TEST(synapses, syn_basic_state) { std::vector<index_type> syn_mult(num_syn, 1); std::vector<value_type> syn_weight(num_syn, 1.0); - expsyn->instantiate(0, state, {syn_cv, syn_weight, syn_mult}); - exp2syn->instantiate(1, state, {syn_cv, syn_weight, syn_mult}); + expsyn->instantiate(0, state, {}, {syn_cv, syn_weight, syn_mult}); + exp2syn->instantiate(1, state, {}, {syn_cv, syn_weight, syn_mult}); // Parameters initialized to default values? -- GitLab