diff --git a/arbor/backends/gpu/shared_state.hpp b/arbor/backends/gpu/shared_state.hpp index 03c8e44a00267cc63317f3b9b2df087f26b89422..6aea214e9352a13edeb80622f2bbe45c44265b7e 100644 --- a/arbor/backends/gpu/shared_state.hpp +++ b/arbor/backends/gpu/shared_state.hpp @@ -27,7 +27,7 @@ namespace gpu { struct ion_state { iarray node_index_; // Instance to CV map. - array iX_; // (nA) current + array iX_; // (A/m²) current density array eX_; // (mV) reversal potential array Xi_; // (mM) internal concentration array Xo_; // (mM) external concentration diff --git a/arbor/backends/multicore/shared_state.hpp b/arbor/backends/multicore/shared_state.hpp index 63e8b51f05c4315d6b296a436d553db23eef899a..4d97d1cd0d3d3350c4e9ab8e449fca8b91b3eb5a 100644 --- a/arbor/backends/multicore/shared_state.hpp +++ b/arbor/backends/multicore/shared_state.hpp @@ -42,7 +42,7 @@ struct ion_state { unsigned alignment = 1; // Alignment and padding multiple. iarray node_index_; // Instance to CV map. - array iX_; // (nA) current + array iX_; // (A/m²) current density array eX_; // (mV) reversal potential array Xi_; // (mM) internal concentration array Xo_; // (mM) external concentration diff --git a/arbor/fvm_lowered_cell_impl.hpp b/arbor/fvm_lowered_cell_impl.hpp index 10c5c36ac1d3c32bf102ddd2c20b011a864ad3b8..59400d301ea60b513a88608f90ef9ea5724f7cb5 100644 --- a/arbor/fvm_lowered_cell_impl.hpp +++ b/arbor/fvm_lowered_cell_impl.hpp @@ -454,11 +454,10 @@ void fvm_lowered_cell_impl<B>::initialize( } } else { - // Density Current density contributions from mechanism are in [mA/cm²] - // (NEURON compatibility). F = [mA/cm²] / [A/m²] = 10. + // Current density contributions from mechanism are already in [A/m²]. for (auto i: count_along(layout.cv)) { - layout.weight[i] = 10*config.norm_area[i]; + layout.weight[i] = config.norm_area[i]; } } diff --git a/modcc/blocks.hpp b/modcc/blocks.hpp index b71284ed8f36ce0a4dc4566db0c057e383b31b01..5d9d1cf85fcbb52a31966927d8834f4b42b1e8e4 100644 --- a/modcc/blocks.hpp +++ b/modcc/blocks.hpp @@ -61,11 +61,6 @@ struct IonDep { } }; -enum class moduleKind { - point, - density -}; - typedef std::vector<Token> unit_tokens; struct Id { Token token; diff --git a/modcc/expression.cpp b/modcc/expression.cpp index 2d4accd191e66a60307dbe13f6fcf4872c6bc3a9..441ea63eb3def23b599efcadd234da7b710bdaf2 100644 --- a/modcc/expression.cpp +++ b/modcc/expression.cpp @@ -1,6 +1,7 @@ #include <cstring> #include "expression.hpp" +#include "identifier.hpp" inline std::string to_string(symbolKind k) { switch (k) { @@ -68,7 +69,7 @@ std::string Symbol::to_string() const { std::string LocalVariable::to_string() const { std::string s = blue("Local Variable") + " " + yellow(name()); if(is_indexed()) { - s += " ->(" + token_string(external_->op()) + ") " + yellow(external_->index_name()); + s += " -> " + yellow(external_->name()); } return s; } @@ -279,9 +280,9 @@ std::string VariableExpression::to_string() const { std::string IndexedVariable::to_string() const { return - blue("indexed") + " " + yellow(name()) + "->" + yellow(index_name()) + "(" + blue("indexed") + " " + yellow(name()) + "->" + yellow(::to_string(data_source())) + "(" + (is_write() ? " write-only" : " read-only") - + ", ion" + (is_ion()? colorize(ion_channel(), stringColor::green) + + ", ion " + (is_ion()? colorize(ion_channel(), stringColor::green) : colorize("none", stringColor::red)) + ") "; } diff --git a/modcc/expression.hpp b/modcc/expression.hpp index 2157d3965a1e66c3df53f3cf94bbfcd9082004fa..0613ddcfecffd649bffcb30dba2dac8041f37c5d 100644 --- a/modcc/expression.hpp +++ b/modcc/expression.hpp @@ -512,47 +512,28 @@ protected: Symbol* shadows_ = nullptr; }; -// an indexed variable +// Indexed variables refer to data held in the shared simulation state. +// Printers will rewrite reads from or assignments from indexed variables +// according to its data source and ion channel. + class IndexedVariable : public Symbol { public: IndexedVariable(Location loc, std::string lookup_name, - std::string index_name, sourceKind data_source, accessKind acc, - tok o=tok::eq, std::string channel="") : Symbol(std::move(loc), std::move(lookup_name), symbolKind::indexed_variable), access_(acc), ion_channel_(std::move(channel)), - index_name_(std::move(index_name)), // (TODO: deprecate/remove this...) - data_source_(data_source), - op_(o) + data_source_(data_source) { - std::string msg; // external symbols are either read or write only if(access()==accessKind::readwrite) { - msg = pprintf("attempt to generate an index % with readwrite access", - yellow(lookup_name)); - goto compiler_error; - } - // read only variables must be assigned via equality - if(is_read() && op()!=tok::eq) { - msg = pprintf("read only indexes % must use assignment", - yellow(lookup_name)); - goto compiler_error; - } - // write only variables must be update via addition/subtraction - if(is_write() && (op()!=tok::plus && op()!=tok::minus)) { - msg = pprintf("write only index % must use addition or subtraction", - yellow(lookup_name)); - goto compiler_error; + throw compiler_exception( + pprintf("attempt to generate an index % with readwrite access", yellow(lookup_name)), + location_); } - - return; - - compiler_error: - throw(compiler_exception(msg, location_)); } std::string to_string() const override; @@ -561,8 +542,6 @@ public: std::string ion_channel() const { return ion_channel_; } sourceKind data_source() const { return data_source_; } void data_source(sourceKind k) { data_source_ = k; } - std::string const& index_name() const { return index_name_; } - tok op() const { return op_; } bool is_ion() const { return !ion_channel_.empty(); } bool is_read() const { return access_ == accessKind::read; } @@ -575,9 +554,7 @@ public: protected: accessKind access_; std::string ion_channel_; - std::string index_name_; // hint to printer only sourceKind data_source_; - tok op_; }; class LocalVariable : public Symbol { diff --git a/modcc/identifier.hpp b/modcc/identifier.hpp index f01fe61f9fb45523d32308ad975f0765d5f54bc5..8e62534a999740c71d76a280c90c25cd702ca329 100644 --- a/modcc/identifier.hpp +++ b/modcc/identifier.hpp @@ -4,6 +4,10 @@ #include <string> #include <stdexcept> +enum class moduleKind { + point, density +}; + /// indicate how a variable is accessed /// access is (read, written, or both) /// the distinction between write only and read only is required because @@ -36,10 +40,13 @@ enum class linkageKind { /// possible external data source for indexed variables enum class sourceKind { voltage, + current_density, current, conductivity, + conductance, dt, ion_current, + ion_current_density, ion_revpot, ion_iconc, ion_econc, @@ -73,6 +80,26 @@ inline std::string to_string(linkageKind v) { return std::string("<error : undefined visibilityKind>"); } +inline std::string to_string(sourceKind v) { + switch(v) { + case sourceKind::voltage: return "voltage"; + case sourceKind::current_density: return "current_density"; + case sourceKind::current: return "current"; + case sourceKind::conductivity: return "conductivity"; + case sourceKind::conductance: return "conductance"; + case sourceKind::dt: return "dt"; + case sourceKind::ion_current: return "ion_current"; + case sourceKind::ion_current_density: return "ion_current_density"; + case sourceKind::ion_revpot: return "ion_revpot"; + case sourceKind::ion_iconc: return "ion_iconc"; + case sourceKind::ion_econc: return "ion_econc"; + case sourceKind::ion_valence: return "ion_valence"; + case sourceKind::temperature: return "temperature"; + case sourceKind::no_source: return "no source"; + default: return "unknown source"; + } +} + // ostream writers inline std::ostream& operator<< (std::ostream& os, visibilityKind v) { @@ -85,9 +112,9 @@ inline std::ostream& operator<< (std::ostream& os, linkageKind l) { /// ion variable to data source kind -inline sourceKind ion_source(const std::string& ion, const std::string& var) { +inline sourceKind ion_source(const std::string& ion, const std::string& var, moduleKind mkind) { if (ion.empty()) return sourceKind::no_source; - else if (var=="i"+ion) return sourceKind::ion_current; + else if (var=="i"+ion) return mkind==moduleKind::point? sourceKind::ion_current: sourceKind::ion_current_density; else if (var=="e"+ion) return sourceKind::ion_revpot; else if (var==ion+"i") return sourceKind::ion_iconc; else if (var==ion+"e") return sourceKind::ion_econc; diff --git a/modcc/memop.hpp b/modcc/memop.hpp deleted file mode 100644 index 48232ae45d79546d7e4297f9d53e67a7172bd9a2..0000000000000000000000000000000000000000 --- a/modcc/memop.hpp +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -#include "io/pprintf.hpp" -#include "lexer.hpp" -#include "util.hpp" - -/// Defines a memory operation that is to performed by an APIMethod. -/// Kernels can read/write global state via an index, e.g. -/// - loading voltage v from VEC_V in matrix before computation -/// - loading a variable associated with an ionic variable -/// - accumulating an update to VEC_RHS/VEC_D after computation -/// - adding contribution to an ionic current -/// How these operations are handled will vary significantly from -/// one backend implementation to another, so inserting expressions -/// directly into the APIMethod body to perform them is not appropriate. -/// Instead, each API method stores two lists -/// - a list of load/input transactions to perform before kernel -/// - a list of store/output transactions to perform after kernel -/// The lists are of MemOps, which describe the local and external variables -template <typename Symbol> -struct MemOp { - using symbol_type = Symbol; - tok op; - Symbol *local; - Symbol *external; - - MemOp(tok o, Symbol *loc, Symbol *ext) - : op(o), local(loc), external(ext) - { - const tok valid_ops[] = {tok::plus, tok::minus, tok::eq}; - if(!is_in(op, valid_ops)) { - throw compiler_exception( - "invalid operation for creating a MemOp : " + - loc->to_string() + yellow(token_string(op)) + ext->to_string(), - loc->location()); - } - } -}; - diff --git a/modcc/module.cpp b/modcc/module.cpp index dc7b559340853a3b460c1e23b409df648121b8e6..0e1ba7674b124c47fe6dd696359bc509735b7613 100644 --- a/modcc/module.cpp +++ b/modcc/module.cpp @@ -32,7 +32,11 @@ class NrnCurrentRewriter: public BlockRewriterBase { if(auto var = sym->is_local_variable()) { if(auto ext = var->external_variable()) { sourceKind src = ext->data_source(); - if (src==sourceKind::current || src==sourceKind::ion_current) { + if (src==sourceKind::current_density || + src==sourceKind::current || + src==sourceKind::ion_current_density || + src==sourceKind::ion_current) + { return src; } } @@ -50,33 +54,13 @@ public: virtual void finalize() override { if (has_current_update_) { - // Initialize current_ and conductivity_ as first statements. + // Initialize conductivity_ as first statement. statements_.push_front(make_expression<AssignmentExpression>(loc_, id("conductivity_"), make_expression<NumberExpression>(loc_, 0.0))); statements_.push_front(make_expression<AssignmentExpression>(loc_, id("current_"), make_expression<NumberExpression>(loc_, 0.0))); - - // Scale current and conductivity contributions by weight. - statements_.push_back(make_expression<AssignmentExpression>(loc_, - id("current_"), - make_expression<MulBinaryExpression>(loc_, - id("weight_"), - id("current_")))); - statements_.push_back(make_expression<AssignmentExpression>(loc_, - id("conductivity_"), - make_expression<MulBinaryExpression>(loc_, - id("weight_"), - id("conductivity_")))); - - for (auto& v: ion_current_vars_) { - statements_.push_back(make_expression<AssignmentExpression>(loc_, - id(v), - make_expression<MulBinaryExpression>(loc_, - id("weight_"), - id(v)))); - } } } @@ -90,7 +74,7 @@ public: if (current_source != sourceKind::no_source) { has_current_update_ = true; - if (current_source==sourceKind::ion_current) { + if (current_source==sourceKind::ion_current_density || current_source==sourceKind::ion_current) { ion_current_vars_.insert(e->lhs()->is_identifier()->name()); } else { @@ -254,10 +238,6 @@ bool Module::semantic() { return id; }; - auto numeric_literal = [](double v) -> expression_ptr { - return make_expression<NumberExpression>(Location{}, v); - }; - for (auto& sym: symbols_) { Location loc; @@ -271,17 +251,9 @@ bool Module::semantic() { auto ionvar = shadowed->is_indexed_variable(); if (!ionvar || !ionvar->is_ion() || !ionvar->is_write()) continue; - auto weight = symbols_["weight_"].get(); - if (!weight) throw compiler_exception("missing weight_ global", loc); - ion_assignments.push_back( make_expression<AssignmentExpression>(loc, - sym_to_id(ionvar), - make_expression<MulBinaryExpression>(loc, - make_expression<MulBinaryExpression>(loc, - numeric_literal(0.1), - sym_to_id(weight)), - sym_to_id(state)))); + sym_to_id(ionvar), sym_to_id(state))); } symbols_["write_ions"] = make_symbol<APIMethod>(Location{}, "write_ions", @@ -497,37 +469,26 @@ void Module::add_variables_to_symbols() { return symbols_[var->name()] = symbol_ptr{var}; }; - // mechanisms use a vector of weights to: - // density mechs: - // - convert current densities from 10.A.m^-2 to A.m^-2 - // - density or proportion of a CV's area affected by the mechansim - // point procs: - // - convert current in nA to current densities in A.m^-2 - - create_variable(Token(tok::identifier, "weight_"), - accessKind::read, visibilityKind::global, linkageKind::external, rangeKind::range); - // add indexed variables to the table auto create_indexed_variable = [this] - (std::string const& name, std::string const& indexed_name, sourceKind data_source, - tok op, accessKind acc, std::string ch, Location loc) -> symbol_ptr& + (std::string const& name, sourceKind data_source, + accessKind acc, std::string ch, Location loc) -> symbol_ptr& { if(symbols_.count(name)) { throw compiler_exception( pprintf("the symbol % already exists", yellow(name)), loc); } return symbols_[name] = - make_symbol<IndexedVariable>(loc, name, indexed_name, data_source, acc, op, ch); + make_symbol<IndexedVariable>(loc, name, data_source, acc, ch); }; - create_indexed_variable("current_", "vec_i", sourceKind::current, tok::plus, - accessKind::write, "", Location()); - create_indexed_variable("conductivity_", "vec_g", sourceKind::conductivity, tok::plus, - accessKind::write, "", Location()); - create_indexed_variable("v", "vec_v", sourceKind::voltage, tok::eq, - accessKind::read, "", Location()); - create_indexed_variable("dt", "vec_dt", sourceKind::dt, tok::eq, - accessKind::read, "", Location()); + sourceKind current_kind = kind_==moduleKind::point? sourceKind::current: sourceKind::current_density; + sourceKind conductance_kind = kind_==moduleKind::point? sourceKind::conductance: sourceKind::conductivity; + + create_indexed_variable("current_", current_kind, accessKind::write, "", Location()); + create_indexed_variable("conductivity_", conductance_kind, accessKind::write, "", Location()); + create_indexed_variable("v", sourceKind::voltage, accessKind::read, "", Location()); + create_indexed_variable("dt", sourceKind::dt, accessKind::read, "", Location()); // If we put back support for accessing cell time again from NMODL code, // add indexed_variable also for "time" with appropriate cell-index based @@ -549,8 +510,7 @@ void Module::add_variables_to_symbols() { // data source. Retrieval of value is handled especially by printers. if (id.name() == "celsius") { - create_indexed_variable("celsius", "celsius", - sourceKind::temperature, tok::eq, accessKind::read, "", Location()); + create_indexed_variable("celsius", sourceKind::temperature, accessKind::read, "", Location()); } else { // Parameters are scalar by default, but may later be changed to range. @@ -591,7 +551,7 @@ void Module::add_variables_to_symbols() { (Token const& tkn, accessKind acc, const std::string& channel) { std::string name = tkn.spelling; - sourceKind data_source = ion_source(channel, name); + sourceKind data_source = ion_source(channel, name, kind_); // If the symbol already exists and is not a state variable, // it is an error. @@ -615,8 +575,7 @@ void Module::add_variables_to_symbols() { name += "_shadowed_"; } - auto& sym = create_indexed_variable(name, "ion_"+name, data_source, - acc==accessKind::read ? tok::eq : tok::plus, acc, channel, tkn.location); + auto& sym = create_indexed_variable(name, data_source, acc, channel, tkn.location); if (state) { state->shadows(sym.get()); @@ -631,7 +590,7 @@ void Module::add_variables_to_symbols() { if( neuron_block_.has_nonspecific_current() ) { auto const& i = neuron_block_.nonspecific_current; - create_indexed_variable(i.spelling, "", sourceKind::current, tok::plus, accessKind::write, "", i.location); + create_indexed_variable(i.spelling, sourceKind::current, accessKind::write, "", i.location); } for(auto const& ion : neuron_block_.ions) { @@ -644,8 +603,8 @@ void Module::add_variables_to_symbols() { if(ion.uses_valence()) { Token valence_var = ion.valence_var; - create_indexed_variable(valence_var.spelling, "ion_valence", sourceKind::ion_valence, - tok::eq, accessKind::read, ion.name, valence_var.location); + create_indexed_variable(valence_var.spelling, sourceKind::ion_valence, + accessKind::read, ion.name, valence_var.location); } } diff --git a/modcc/printer/cprinter.cpp b/modcc/printer/cprinter.cpp index 16f96ea74f819f8f8223790c9266e7d67a727ce2..856248d1fe669d63bbd490393733b107d01f3699 100644 --- a/modcc/printer/cprinter.cpp +++ b/modcc/printer/cprinter.cpp @@ -404,16 +404,6 @@ void CPrinter::visit(VariableExpression *sym) { out_ << sym->name() << (sym->is_range()? "[i_]": ""); } -void CPrinter::visit(IndexedVariable *sym) { - indexed_variable_info v = decode_indexed_variable(sym); - if (v.scalar()) { - out_ << v.data_var << "[0]"; - } - else { - out_ << v.data_var << "[" << v.index_var << "[i_]]"; - } -} - void CPrinter::visit(CallExpression* e) { out_ << e->name() << "(i_"; for (auto& arg: e->args()) { @@ -453,11 +443,29 @@ void emit_procedure_proto(std::ostream& out, ProcedureExpression* e, const std:: out << ")"; } +namespace { + // Convenience I/O wrapper for emitting indexed access to an external variable. + + struct deref { + indexed_variable_info d; + deref(indexed_variable_info d): d(d) {} + + friend std::ostream& operator<<(std::ostream& o, const deref& wrap) { + return o << wrap.d.data_var << '[' + << (wrap.d.scalar()? "0": wrap.d.index_var+"[i_]") << ']'; + } + }; +} + void emit_state_read(std::ostream& out, LocalVariable* local) { out << "value_type " << cprint(local) << " = "; if (local->is_read()) { - out << cprint(local->external_variable()) << ";\n"; + auto d = decode_indexed_variable(local->external_variable()); + if (d.scale != 1) { + out << as_c_double(d.scale) << "*"; + } + out << deref(d) << ";\n"; } else { out << "0;\n"; @@ -467,12 +475,27 @@ void emit_state_read(std::ostream& out, LocalVariable* local) { void emit_state_update(std::ostream& out, Symbol* from, IndexedVariable* external) { if (!external->is_write()) return; - if (decode_indexed_variable(external).scalar()) { - throw compiler_exception("Cannot assign to global scalar: "+external->to_string()); + auto d = decode_indexed_variable(external); + double coeff = 1./d.scale; + + if (d.readonly) { + throw compiler_exception("Cannot assign to read-only external state: "+external->to_string()); } - const char* op = external->op()==tok::plus? " += ": " -= "; - out << cprint(external) << op << from->name() << ";\n"; + if (d.accumulate) { + out << deref(d) << " = fma("; + if (coeff != 1) { + out << as_c_double(coeff) << '*'; + } + out << "weight_[i_], " << from->name() << ", " << deref(d) << ");\n"; + } + else { + out << deref(d) << " = "; + if (coeff != 1) { + out << as_c_double(coeff) << '*'; + } + out << from->name() << ";\n"; + } } void emit_api_body(std::ostream& out, APIMethod* method) { @@ -543,18 +566,6 @@ void SimdPrinter::visit(AssignmentExpression* e) { } } -void SimdPrinter::visit(IndexedVariable *sym) { - indexed_variable_info v = decode_indexed_variable(sym); - if (v.scalar()) { - out_ << v.data_var << "[0]"; - } - else { - out_ << "S::indirect(" << v.data_var - << ", " << index_i_name(v.index_var) - << ", constraint_category_)"; - } -} - void SimdPrinter::visit(CallExpression* e) { out_ << e->name() << "(index_"; for (auto& arg: e->args()) { @@ -601,23 +612,27 @@ void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_con out << "simd_value " << local->name(); if (local->is_read()) { - indexed_variable_info v = decode_indexed_variable(local->external_variable()); - if (v.scalar()) { - out << "(" << v.data_var + auto d = decode_indexed_variable(local->external_variable()); + if (d.scalar()) { + out << "(" << d.data_var << "[0]);\n"; } else if (constraint == simd_expr_constraint::contiguous) { - out << "(" << v.data_var - << " + " << v.index_var + out << "(" << d.data_var + << " + " << d.index_var << "[index_]);\n"; } else if (constraint == simd_expr_constraint::constant) { - out << "(" << v.data_var - << "[" << v.index_var + out << "(" << d.data_var + << "[" << d.index_var << "element0]);\n"; } else { - out << "(" << simdprint(local->external_variable()) << ");\n"; + out << "(S::indirect(" << d.data_var << ", " << index_i_name(d.index_var) << ", constraint_category_));\n"; + } + + if (d.scale != 1) { + out << local->name() << " *= " << d.scale << ";\n"; } } else { @@ -628,21 +643,51 @@ void emit_simd_state_read(std::ostream& out, LocalVariable* local, simd_expr_con void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* external, simd_expr_constraint constraint) { if (!external->is_write()) return; - const char* op = external->op()==tok::plus? " += ": " -= "; - indexed_variable_info v = decode_indexed_variable(external); + auto d = decode_indexed_variable(external);; + double coeff = 1./d.scale; - if (v.scalar()) { - throw compiler_exception("Cannot assign to global scalar: "+external->to_string()); + if (d.readonly) { + throw compiler_exception("Cannot assign to read-only external state: "+external->to_string()); } - else { + + if (d.accumulate) { + std::string tempvar = "t_"+external->name(); + if (constraint == simd_expr_constraint::contiguous) { - out << "simd_value t_"<< external->name() <<"(" << v.data_var << " + " << v.index_var << "[index_]);\n"; - out << "t_" << external->name() << op << from->name() << ";\n"; - out << "t_" << external->name() << ".copy_to(" << v.data_var << " + " << v.index_var << "[index_]);\n"; + out << "simd_value "<< tempvar <<"(" << d.data_var << " + " << d.index_var << "[index_]);\n" + << tempvar << " += w_*"; + + if (coeff!=1) out << as_c_double(coeff) << "*"; + out << from->name() << ";\n" + << tempvar << ".copy_to(" << d.data_var << " + " << d.index_var << "[index_]);\n"; } else { - out << simdprint(external) << op << from->name() << ";\n"; + out << "S::indirect(" << d.data_var << ", " << index_i_name(d.index_var) << ", constraint_category_)" + << " += w_*"; + + if (coeff!=1) out << as_c_double(coeff) << "*"; + + out << from->name() << ";\n"; + } + } + else { + if (constraint == simd_expr_constraint::contiguous) { + if (coeff!=1) { + out << "(" << as_c_double(coeff) << "*" << from->name() << ")"; + } + else { + out << from->name(); + } + out << ".copy_to(" << d.data_var << " + " << d.index_var << "[index_]);\n"; + } + else { + out << "S::indirect(" << d.data_var << ", " << index_i_name(d.index_var) << ", constraint_category_)" + << " = "; + + if (coeff!=1) out << as_c_double(coeff) << "*"; + + out << from->name() << ";\n"; } } } @@ -650,21 +695,19 @@ void emit_simd_state_update(std::ostream& out, Symbol* from, IndexedVariable* ex void emit_index_initialize(std::ostream& out, const std::unordered_set<std::string>& indices, simd_expr_constraint constraint) { switch(constraint) { - case simd_expr_constraint::contiguous : - break; - case simd_expr_constraint::constant : { - for (auto& index: indices) { - out << "simd_index::scalar_type " << index << "element0 = " << index << "[index_];\n"; - out << index_i_name(index) << " = " << index << "element0;\n"; - } + case simd_expr_constraint::contiguous: + break; + case simd_expr_constraint::constant: + for (auto& index: indices) { + out << "simd_index::scalar_type " << index << "element0 = " << index << "[index_];\n"; + out << index_i_name(index) << " = " << index << "element0;\n"; } - break; - case simd_expr_constraint::other : { - for (auto& index: indices) { - out << index_i_name(index) << ".copy_from(" << index << ".data() + index_);\n"; - } + break; + case simd_expr_constraint::other: + for (auto& index: indices) { + out << index_i_name(index) << ".copy_from(" << index << ".data() + index_);\n"; } - break; + break; } } @@ -689,6 +732,7 @@ void emit_body_for_loop(std::ostream& out, BlockExpression* body, const std::vec void emit_for_loop_per_constraint(std::ostream& out, BlockExpression* body, const std::vector<LocalVariable*>& indexed_vars, + bool requires_weight, const std::unordered_set<std::string>& indices, const simd_expr_constraint& read_constraint, const simd_expr_constraint& write_constraint, @@ -700,6 +744,9 @@ void emit_for_loop_per_constraint(std::ostream& out, BlockExpression* body, << indent; out << "index_type index_ = index_constraints_." << underlying_constraint_name << "[i_];\n"; + if (requires_weight) { + out << "simd_value w_(weight_+index_);\n"; + } emit_body_for_loop(out, body, indexed_vars, indices, read_constraint, write_constraint); @@ -709,6 +756,7 @@ void emit_for_loop_per_constraint(std::ostream& out, BlockExpression* body, void emit_simd_api_body(std::ostream& out, APIMethod* method, moduleKind module_kind) { auto body = method->body(); auto indexed_vars = indexed_locals(method->scope()); + bool requires_weight = false; std::vector<LocalVariable*> scalar_indexed_vars; std::unordered_set<std::string> indices; @@ -720,6 +768,9 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, moduleKind module_ else { scalar_indexed_vars.push_back(sym); } + if (info.accumulate) { + requires_weight = true; + } } if (!body->statements().empty()) { @@ -734,21 +785,21 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, moduleKind module_ simd_expr_constraint constraint = simd_expr_constraint::contiguous; std::string underlying_constraint = "contiguous"; - emit_for_loop_per_constraint(out, body, indexed_vars, indices, constraint, + emit_for_loop_per_constraint(out, body, indexed_vars, requires_weight, indices, constraint, constraint, underlying_constraint); //Generate for loop for all independent simd_vectors constraint = simd_expr_constraint::other; underlying_constraint = "independent"; - emit_for_loop_per_constraint(out, body, indexed_vars, indices, constraint, + emit_for_loop_per_constraint(out, body, indexed_vars, requires_weight, indices, constraint, constraint, underlying_constraint); //Generate for loop for all simd_vectors that have no optimizing constraints constraint = simd_expr_constraint::other; underlying_constraint = "none"; - emit_for_loop_per_constraint(out, body, indexed_vars, indices, constraint, + emit_for_loop_per_constraint(out, body, indexed_vars, requires_weight, indices, constraint, constraint, underlying_constraint); //Generate for loop for all constant simd_vectors @@ -756,7 +807,7 @@ void emit_simd_api_body(std::ostream& out, APIMethod* method, moduleKind module_ simd_expr_constraint write_constraint = simd_expr_constraint::other; underlying_constraint = "constant"; - emit_for_loop_per_constraint(out, body, indexed_vars, indices, read_constraint, + emit_for_loop_per_constraint(out, body, indexed_vars, requires_weight, indices, read_constraint, write_constraint, underlying_constraint); } diff --git a/modcc/printer/cprinter.hpp b/modcc/printer/cprinter.hpp index e7712b87c9644551e95a9242aacde41f410edfe5..b7612ca98f5ef8669bed42ee0f6de7aed417fa44 100644 --- a/modcc/printer/cprinter.hpp +++ b/modcc/printer/cprinter.hpp @@ -26,7 +26,6 @@ public: void visit(IdentifierExpression*) override; void visit(VariableExpression*) override; void visit(LocalVariable*) override; - void visit(IndexedVariable*) override; // Delegate low-level emits to cexpr_emit: void visit(NumberExpression* e) override { cexpr_emit(e, out_, this); } @@ -61,7 +60,6 @@ public: void visit(IdentifierExpression*) override; void visit(VariableExpression*) override; void visit(LocalVariable*) override; - void visit(IndexedVariable*) override; void visit(AssignmentExpression*) override; void visit(NumberExpression* e) override { cexpr_emit(e, out_, this); } diff --git a/modcc/printer/cudaprinter.cpp b/modcc/printer/cudaprinter.cpp index 3a659214e547e3859e7d018bbb6f5f337f785f0d..783eb9f6573a0bb0cc8173f085b014d5706e6ac5 100644 --- a/modcc/printer/cudaprinter.cpp +++ b/modcc/printer/cudaprinter.cpp @@ -428,37 +428,66 @@ void emit_procedure_body_cu(std::ostream& out, ProcedureExpression* e) { out << cuprint(e->body()); } +namespace { + // Convenience I/O wrapper for emitting indexed access to an external variable. + + struct deref { + indexed_variable_info v; + + deref(indexed_variable_info v): v(v) {} + friend std::ostream& operator<<(std::ostream& o, const deref& wrap) { + return o << "params_." << wrap.v.data_var << '[' + << (wrap.v.scalar()? "0": index_i_name(wrap.v.index_var)) << ']'; + } + }; +} + void emit_state_read_cu(std::ostream& out, LocalVariable* local) { out << "value_type " << cuprint(local) << " = "; if (local->is_read()) { - out << cuprint(local->external_variable()) << ";\n"; + auto d = decode_indexed_variable(local->external_variable()); + if (d.scale != 1) { + out << as_c_double(d.scale) << "*"; + } + out << deref(d) << ";\n"; } else { out << "0;\n"; } } + void emit_state_update_cu(std::ostream& out, Symbol* from, IndexedVariable* external, bool is_point_proc) { if (!external->is_write()) return; - const bool is_minus = external->op()==tok::minus; auto d = decode_indexed_variable(external); + double coeff = 1./d.scale; - if (d.scalar()) { - throw compiler_exception("Cannot assign to global scalar: "+external->to_string()); + if (d.readonly) { + throw compiler_exception("Cannot assign to read-only external state: "+external->to_string()); } - if (is_point_proc) { + if (is_point_proc && d.accumulate) { out << "::arb::gpu::reduce_by_key("; - is_minus && out << "-"; - out << from->name() - << ", params_." << d.data_var << ", " << index_i_name(d.index_var) << ", lane_mask_);\n"; + if (coeff != 1) out << as_c_double(coeff) << '*'; + + out << "params_.weight_[tid_]*" << from->name() << ','; + out << "params_." << d.data_var << ", " << index_i_name(d.index_var) << ", lane_mask_);\n"; + } + else if (d.accumulate) { + out << deref(d) << " = fma("; + if (coeff != 1) out << as_c_double(coeff) << '*'; + + out << "params_.weight_[tid_], " << from->name() << ", " << deref(d) << ");\n"; } else { - out << cuprint(external) << (is_minus? " -= ": " += ") << from->name() << ";\n"; + out << deref(d) << " = "; + if (coeff != 1) out << as_c_double(coeff) << '*'; + + out << from->name() << ";\n"; } } @@ -468,16 +497,6 @@ void CudaPrinter::visit(VariableExpression *sym) { out_ << "params_." << sym->name() << (sym->is_range()? "[tid_]": ""); } -void CudaPrinter::visit(IndexedVariable *e) { - auto d = decode_indexed_variable(e); - if (d.scalar()) { - out_ << "params_." << d.data_var << "[0]"; - } - else { - out_ << "params_." << d.data_var << "[" << index_i_name(d.index_var) << "]"; - } -} - void CudaPrinter::visit(CallExpression* e) { out_ << e->name() << "(params_, tid_"; for (auto& arg: e->args()) { diff --git a/modcc/printer/cudaprinter.hpp b/modcc/printer/cudaprinter.hpp index 2b41f38fa6c1d9851f89a89d6e1090278f8aa390..246f83390df32a491bb9ef8fb74ec7aa1e78ad9b 100644 --- a/modcc/printer/cudaprinter.hpp +++ b/modcc/printer/cudaprinter.hpp @@ -15,6 +15,5 @@ public: void visit(CallExpression*) override; void visit(VariableExpression*) override; - void visit(IndexedVariable*) override; }; diff --git a/modcc/printer/printerutil.cpp b/modcc/printer/printerutil.cpp index e3dbbcf54f2425eac7962128d04b564449ccfd26..4febdb01894843ece1cb01c16c723e1e720a09c7 100644 --- a/modcc/printer/printerutil.cpp +++ b/modcc/printer/printerutil.cpp @@ -114,50 +114,83 @@ NetReceiveExpression* find_net_receive(const Module& m) { } indexed_variable_info decode_indexed_variable(IndexedVariable* sym) { - std::string data_var, ion_pfx; - std::string index_var = "node_index_"; + indexed_variable_info v; + v.index_var = "node_index_"; + v.scale = 1; + v.accumulate = true; + v.readonly = true; + std::string ion_pfx; if (sym->is_ion()) { ion_pfx = "ion_"+sym->ion_channel()+"_"; - index_var = ion_pfx+"index_"; + v.index_var = ion_pfx+"index_"; } switch (sym->data_source()) { case sourceKind::voltage: - data_var="vec_v_"; + v.data_var="vec_v_"; + v.readonly = true; + break; + case sourceKind::current_density: + v.data_var = "vec_i_"; + v.readonly = false; + v.scale = 0.1; break; case sourceKind::current: - data_var="vec_i_"; + // unit scale; sourceKind for point processes updating current variable. + v.data_var = "vec_i_"; + v.readonly = false; break; case sourceKind::conductivity: - data_var="vec_g_"; + v.data_var = "vec_g_"; + v.readonly = false; + v.scale = 0.1; + break; + case sourceKind::conductance: + // unit scale; sourceKind for point processes updating conductivity. + v.data_var = "vec_g_"; + v.readonly = false; break; case sourceKind::dt: - data_var="vec_dt_"; + v.data_var = "vec_dt_"; + v.readonly = true; + break; + case sourceKind::ion_current_density: + v.data_var = ion_pfx+".current_density"; + v.scale = 0.1; + v.readonly = false; break; case sourceKind::ion_current: - data_var=ion_pfx+".current_density"; + // unit scale; sourceKind for point processes updating an ionic current variable. + v.data_var = ion_pfx+".current_density"; + v.readonly = false; break; case sourceKind::ion_revpot: - data_var=ion_pfx+".reversal_potential"; + v.data_var = ion_pfx+".reversal_potential"; + v.accumulate = false; + v.readonly = false; break; case sourceKind::ion_iconc: - data_var=ion_pfx+".internal_concentration"; + v.data_var = ion_pfx+".internal_concentration"; + v.readonly = false; break; case sourceKind::ion_econc: - data_var=ion_pfx+".external_concentration"; + v.data_var = ion_pfx+".external_concentration"; + v.readonly = false; break; case sourceKind::ion_valence: - data_var=ion_pfx+".ionic_charge"; - index_var=""; // scalar global + v.data_var = ion_pfx+".ionic_charge"; + v.index_var = ""; // scalar global + v.readonly = true; break; case sourceKind::temperature: - data_var="temperature_degC_"; - index_var=""; // scalar global + v.data_var = "temperature_degC_"; + v.index_var = ""; // scalar global + v.readonly = true; break; default: throw compiler_exception(pprintf("unrecognized indexed data source: %", sym), sym->location()); } - return {data_var, index_var}; + return v; } diff --git a/modcc/printer/printerutil.hpp b/modcc/printer/printerutil.hpp index 208c577af369c07705909afcd880ec7e5517c31b..5805778a2fff6921af585179df305c3462dbac5b 100644 --- a/modcc/printer/printerutil.hpp +++ b/modcc/printer/printerutil.hpp @@ -117,6 +117,13 @@ NetReceiveExpression* find_net_receive(const Module& m); struct indexed_variable_info { std::string data_var; std::string index_var; + + bool accumulate = true; // true => add with weight_ factor on assignment + bool readonly = false; // true => can never be assigned to by a mechanism + + // Scale is the conversion factor from the data variable + // to the NMODL value. + double scale = 1; bool scalar() const { return index_var.empty(); } }; diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index fa3fdb93f496d78174b61cd8df1aeca082436c44..36e66d6f508be26d9a1e35c0ed5b9cd40a968398 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -2,6 +2,9 @@ set(test_mechanisms celsius_test + fixed_ica_current + point_ica_current + linear_ca_conc test_cl_valence test_ca_read_valence ) diff --git a/test/unit/mod/fixed_ica_current.mod b/test/unit/mod/fixed_ica_current.mod new file mode 100644 index 0000000000000000000000000000000000000000..f3887c09d3ee87bd1339f765cc337a9cb67782e8 --- /dev/null +++ b/test/unit/mod/fixed_ica_current.mod @@ -0,0 +1,24 @@ +: Test mechanism generating a fixed ionic current + +NEURON { + SUFFIX fixed_ica_current + USEION ca WRITE ica VALENCE 2 + RANGE ica_density +} + +PARAMETER { + ica_density = 0 +} + +ASSIGNED {} + +INITIAL { + ica = ica_density +} + +STATE {} + +BREAKPOINT { + ica = ica_density +} + diff --git a/test/unit/mod/linear_ca_conc.mod b/test/unit/mod/linear_ca_conc.mod new file mode 100644 index 0000000000000000000000000000000000000000..86d50bd2268a55f10634cf7407fb9e1826965c1c --- /dev/null +++ b/test/unit/mod/linear_ca_conc.mod @@ -0,0 +1,30 @@ +: Test mechanism with linear response to ica. + +NEURON { + SUFFIX linear_ca_conc + USEION ca READ ica WRITE cai VALENCE 2 + RANGE coeff +} + +PARAMETER { + coeff = 0 +} + +ASSIGNED {} + +STATE { + cai +} + +INITIAL { + cai = 0 +} + +BREAKPOINT { + SOLVE update METHOD cnexp +} + +DERIVATIVE update { + cai' = -coeff*ica +} + diff --git a/test/unit/mod/point_ica_current.mod b/test/unit/mod/point_ica_current.mod new file mode 100644 index 0000000000000000000000000000000000000000..de3ca5255238e514d9aeb252a3c52208b13a98cd --- /dev/null +++ b/test/unit/mod/point_ica_current.mod @@ -0,0 +1,24 @@ +: Test point mechanism generating a fixed ionic current + +NEURON { + POINT_PROCESS point_ica_current + USEION ca WRITE ica VALENCE 2 +} + +ASSIGNED {} + +STATE { + ica_nA +} + +INITIAL { + ica_nA = 0 +} + +BREAKPOINT { + ica = ica_nA +} + +NET_RECEIVE(weight) { + ica_nA = weight +} diff --git a/test/unit/test_fvm_lowered.cpp b/test/unit/test_fvm_lowered.cpp index a693f0fa8d9ceabc1e521f06538a81e3d73ea453..e6d574630b1af13dfea47c414c1e2f9c91d77ad0 100644 --- a/test/unit/test_fvm_lowered.cpp +++ b/test/unit/test_fvm_lowered.cpp @@ -552,6 +552,96 @@ TEST(fvm_lowered, read_valence) { } } +// Test correct scaling of ionic currents in reading and writing + +TEST(fvm_lowered, ionic_currents) { + cable_cell c; + auto soma = c.add_soma(6.0); + + // Mechanism parameter is in NMODL units, i.e. mA/cm². + + const double jca = 1.5; + soma->add_mechanism(mechanism_desc("fixed_ica_current").set("ica_density", jca)); + + // Mechanism models a well-mixed fixed-depth volume without replenishment, + // giving a linear response to ica over time. + // + // cai' = - coeff · ica + // + // with NMODL units: cai' [mM/ms]; ica [mA/cm²], giving coeff in [mol/cm/C]. + + const double coeff = 0.5; + soma->add_mechanism(mechanism_desc("linear_ca_conc").set("coeff", coeff)); + + cable1d_recipe rec(c); + rec.catalogue() = make_unit_test_catalogue(); + + execution_context context; + + std::vector<target_handle> targets; + std::vector<fvm_index_type> cell_to_intdom; + probe_association_map<probe_handle> probe_map; + + fvm_cell fvcell(context); + fvcell.initialize({0}, rec, cell_to_intdom, targets, probe_map); + + auto& state = *(fvcell.*private_state_ptr).get(); + auto& ion = state.ion_data.at("ca"s); + + // Ionic current should be 15 A/m², and initial concentration zero. + EXPECT_EQ(15, ion.iX_[0]); + EXPECT_EQ(0, ion.Xi_[0]); + + // Integration should be (effectively) exact, so check for linear response. + const double time = 12; // [ms] + (void)fvcell.integrate(time, 0.1, {}, {}); + double expected_Xi = -time*coeff*jca; + EXPECT_NEAR(expected_Xi, ion.Xi_[0], 1e-6); +} + +// Test correct scaling of an ionic current updated via a point mechanism + +TEST(fvm_lowered, point_ionic_current) { + cable_cell c; + + double r = 6.0; // [µm] + c.add_soma(6.0); + + double soma_area_m2 = 4*math::pi<double>*r*r*1e-12; // [m²] + + // Event weight is translated by point_ica_current into a current contribution in nA. + c.add_synapse({0u, 0.5}, "point_ica_current"); + + cable1d_recipe rec(c); + rec.catalogue() = make_unit_test_catalogue(); + + execution_context context; + + std::vector<target_handle> targets; + std::vector<fvm_index_type> cell_to_intdom; + probe_association_map<probe_handle> probe_map; + + fvm_cell fvcell(context); + fvcell.initialize({0}, rec, cell_to_intdom, targets, probe_map); + + // Only one target, corresponding to our point process on soma. + double ica_nA = 12.3; + deliverable_event ev = {0.04, target_handle{0, 0, 0}, (float)ica_nA}; + + auto& state = *(fvcell.*private_state_ptr).get(); + auto& ion = state.ion_data.at("ca"s); + + // Ionic current should be 0 A/m² after initialization. + EXPECT_EQ(0, ion.iX_[0]); + + // Ionic current should be ica_nA/soma_area after integrating past event time. + const double time = 0.5; // [ms] + (void)fvcell.integrate(time, 0.01, {ev}, {}); + + double expected_iX = ica_nA*1e-9/soma_area_m2; + EXPECT_FLOAT_EQ(expected_iX, ion.iX_[0]); +} + // Test area-weighted linear combination of ion species concentrations TEST(fvm_lowered, weighted_write_ion) { diff --git a/test/unit/unit_test_catalogue.cpp b/test/unit/unit_test_catalogue.cpp index 34bc8cea8b3f444a64c286251e263f735768c66c..b12af8da2580092de82048956c9c750f4eed6266 100644 --- a/test/unit/unit_test_catalogue.cpp +++ b/test/unit/unit_test_catalogue.cpp @@ -8,6 +8,9 @@ #include "unit_test_catalogue.hpp" #include "mechanisms/celsius_test.hpp" +#include "mechanisms/fixed_ica_current.hpp" +#include "mechanisms/point_ica_current.hpp" +#include "mechanisms/linear_ca_conc.hpp" #include "mechanisms/test_cl_valence.hpp" #include "mechanisms/test_ca_read_valence.hpp" @@ -30,6 +33,9 @@ mechanism_catalogue make_unit_test_catalogue() { mechanism_catalogue cat; ADD_MECH(cat, celsius_test) + ADD_MECH(cat, fixed_ica_current) + ADD_MECH(cat, point_ica_current) + ADD_MECH(cat, linear_ca_conc) ADD_MECH(cat, test_cl_valence) ADD_MECH(cat, test_ca_read_valence)