diff --git a/mechanisms/CMakeLists.txt b/mechanisms/CMakeLists.txt index f5df463219cdf7e0ae586d4b2c47c5bf4ce4eb26..ac434b0d861d7624a86e9ed6a607422222a7d45f 100644 --- a/mechanisms/CMakeLists.txt +++ b/mechanisms/CMakeLists.txt @@ -1,7 +1,7 @@ include(BuildModules.cmake) # the list of built-in mechanisms to be provided by default -set(mechanisms pas hh expsyn exp2syn test_kin1 test_kinlva) +set(mechanisms pas hh expsyn exp2syn test_kin1 test_kinlva test_ca) set(mod_srcdir "${CMAKE_CURRENT_SOURCE_DIR}/mod") diff --git a/mechanisms/mod/test_ca.mod b/mechanisms/mod/test_ca.mod new file mode 100644 index 0000000000000000000000000000000000000000..67d09c028d62c35277aeacd75bf5f74d8c71e818 --- /dev/null +++ b/mechanisms/mod/test_ca.mod @@ -0,0 +1,34 @@ +: Example of mechanism that updates the concentration of an ionic species. + +NEURON { + SUFFIX test_ca + USEION ca READ ica WRITE cai +} + +UNITS { + (mV) = (millivolt) + (mA) = (milliamp) + (molar) = (1/liter) + (mM) = (millimolar) + (um) = (micron) +} + +PARAMETER { + decay = 80 (ms) : decay rate of calcium + cai0 = 1e-4 (mM) + factor = 2.59e-2 : gamma/2*F*depth +} + +ASSIGNED {} + +STATE { + cai (mM) +} + +BREAKPOINT { + SOLVE states METHOD cnexp +} + +DERIVATIVE states { + cai' = (cai0 - cai)/decay - factor*ica +} diff --git a/modcc/blocks.hpp b/modcc/blocks.hpp index fc26c91dc41d6c3a0383fae52c3b68516754f6ce..9256b8a633e1124d778d8c7a622737254281019a 100644 --- a/modcc/blocks.hpp +++ b/modcc/blocks.hpp @@ -1,5 +1,6 @@ #pragma once +#include <algorithm> #include <string> #include <vector> @@ -11,14 +12,42 @@ // describes a relationship with an ion channel struct IonDep { ionKind kind() const { - if(name=="k") return ionKind::K; - if(name=="na") return ionKind::Na; - if(name=="ca") return ionKind::Ca; - return ionKind::none; + return to_ionKind(name); } std::string name; // name of ion channel std::vector<Token> read; // name of channels parameters to write std::vector<Token> write; // name of channels parameters to read + + bool has_variable(std::string const& name) { + return writes_variable(name) || reads_variable(name); + }; + bool uses_current() { + return has_variable("i"+name); + }; + bool uses_rev_potential() { + return has_variable("e"+name); + }; + bool uses_concentration_int() { + return has_variable(name+"i"); + }; + bool uses_concentration_ext() { + return has_variable(name+"o"); + }; + bool writes_concentration_int() { + return writes_variable(name+"i"); + }; + bool writes_concentration_ext() { + return writes_variable(name+"o"); + }; + + bool reads_variable(const std::string& name) { + return std::find_if(read.begin(), read.end(), + [&name](const Token& t) {return t.spelling==name;}) != read.end(); + } + bool writes_variable(const std::string& name) { + return std::find_if(write.begin(), write.end(), + [&name](const Token& t) {return t.spelling==name;}) != write.end(); + } }; enum class moduleKind { diff --git a/modcc/cprinter.cpp b/modcc/cprinter.cpp index 641efe7002a8d060ebc9f4ae63a546b8bb80aae0..02a29a9b3c57f9db579cd3e91f973fc4ac02e974 100644 --- a/modcc/cprinter.cpp +++ b/modcc/cprinter.cpp @@ -193,42 +193,6 @@ std::string CPrinter::emit_source() { text_.add_line("}"); text_.add_line(); - // return true/false indicating if cell has dependency on k - auto const& ions = module_->neuron_block().ions; - auto find_ion = [&ions] (ionKind k) { - return std::find_if( - ions.begin(), ions.end(), - [k](IonDep const& d) {return d.kind()==k;} - ); - }; - auto has_ion = [&ions, find_ion] (ionKind k) { - return find_ion(k) != ions.end(); - }; - - // bool uses_ion(ionKind k) const override - text_.add_line("bool uses_ion(ionKind k) const override {"); - text_.increase_indentation(); - text_.add_line("switch(k) {"); - text_.increase_indentation(); - text_.add_gutter() - << "case ionKind::na : return " - << (has_ion(ionKind::Na) ? "true" : "false") << ";"; - text_.end_line(); - text_.add_gutter() - << "case ionKind::ca : return " - << (has_ion(ionKind::Ca) ? "true" : "false") << ";"; - text_.end_line(); - text_.add_gutter() - << "case ionKind::k : return " - << (has_ion(ionKind::K) ? "true" : "false") << ";"; - text_.end_line(); - text_.decrease_indentation(); - text_.add_line("}"); - text_.add_line("return false;"); - text_.decrease_indentation(); - text_.add_line("}"); - text_.add_line(); - /*************************************************************************** * * ion channels have the following fields : @@ -245,67 +209,58 @@ std::string CPrinter::emit_source() { * **************************************************************************/ + // ion_spec uses_ion(ionKind k) const override + text_.add_line("typename base::ion_spec uses_ion(ionKind k) const override {"); + text_.increase_indentation(); + text_.add_line("bool uses = false;"); + text_.add_line("bool writes_ext = false;"); + text_.add_line("bool writes_int = false;"); + for (auto k: {ionKind::Na, ionKind::Ca, ionKind::K}) { + if (module_->has_ion(k)) { + auto ion = *module_->find_ion(k); + text_.add_line("if (k==ionKind::" + ion.name + ") {"); + text_.increase_indentation(); + text_.add_line("uses = true;"); + if (ion.writes_concentration_int()) text_.add_line("writes_int = true;"); + if (ion.writes_concentration_ext()) text_.add_line("writes_ext = true;"); + text_.decrease_indentation(); + text_.add_line("}"); + } + } + text_.add_line("return {uses, writes_int, writes_ext};"); + text_.decrease_indentation(); + text_.add_line("}"); + text_.add_line(); + // void set_ion(ionKind k, ion_type& i) override - // TODO: this is done manually, which isn't going to scale - auto has_variable = [] (IonDep const& ion, std::string const& name) { - if( std::find_if(ion.read.begin(), ion.read.end(), - [&name] (Token const& t) {return t.spelling==name;} - ) != ion.read.end() - ) return true; - if( std::find_if(ion.write.begin(), ion.write.end(), - [&name] (Token const& t) {return t.spelling==name;} - ) != ion.write.end() - ) return true; - return false; - }; text_.add_line("void set_ion(ionKind k, ion_type& i, std::vector<size_type>const& index) override {"); text_.increase_indentation(); - text_.add_line("using arb::algorithms::index_into;"); - if(has_ion(ionKind::Na)) { - auto ion = find_ion(ionKind::Na); - text_.add_line("if(k==ionKind::na) {"); - text_.increase_indentation(); - text_.add_line("ion_na.index = iarray(memory::make_const_view(index));"); - if(has_variable(*ion, "ina")) text_.add_line("ion_na.ina = i.current();"); - if(has_variable(*ion, "ena")) text_.add_line("ion_na.ena = i.reversal_potential();"); - if(has_variable(*ion, "nai")) text_.add_line("ion_na.nai = i.internal_concentration();"); - if(has_variable(*ion, "nao")) text_.add_line("ion_na.nao = i.external_concentration();"); - text_.add_line("return;"); - text_.decrease_indentation(); - text_.add_line("}"); - } - if(has_ion(ionKind::Ca)) { - auto ion = find_ion(ionKind::Ca); - text_.add_line("if(k==ionKind::ca) {"); - text_.increase_indentation(); - text_.add_line("ion_ca.index = iarray(memory::make_const_view(index));"); - if(has_variable(*ion, "ica")) text_.add_line("ion_ca.ica = i.current();"); - if(has_variable(*ion, "eca")) text_.add_line("ion_ca.eca = i.reversal_potential();"); - if(has_variable(*ion, "cai")) text_.add_line("ion_ca.cai = i.internal_concentration();"); - if(has_variable(*ion, "cao")) text_.add_line("ion_ca.cao = i.external_concentration();"); - text_.add_line("return;"); - text_.decrease_indentation(); - text_.add_line("}"); - } - if(has_ion(ionKind::K)) { - auto ion = find_ion(ionKind::K); - text_.add_line("if(k==ionKind::k) {"); - text_.increase_indentation(); - text_.add_line("ion_k.index = iarray(memory::make_const_view(index));"); - if(has_variable(*ion, "ik")) text_.add_line("ion_k.ik = i.current();"); - if(has_variable(*ion, "ek")) text_.add_line("ion_k.ek = i.reversal_potential();"); - if(has_variable(*ion, "ki")) text_.add_line("ion_k.ki = i.internal_concentration();"); - if(has_variable(*ion, "ko")) text_.add_line("ion_k.ko = i.external_concentration();"); - text_.add_line("return;"); - text_.decrease_indentation(); - text_.add_line("}"); + for (auto k: {ionKind::Na, ionKind::Ca, ionKind::K}) { + if (module_->has_ion(k)) { + auto ion = *module_->find_ion(k); + text_.add_line("if (k==ionKind::" + ion.name + ") {"); + text_.increase_indentation(); + auto n = ion.name; + auto pre = "ion_"+n; + text_.add_line(pre+".index = memory::make_const_view(index);"); + if (ion.uses_current()) + text_.add_line(pre+".i"+n+" = i.current();"); + if (ion.uses_rev_potential()) + text_.add_line(pre+".e"+n+" = i.reversal_potential();"); + if (ion.uses_concentration_int()) + text_.add_line(pre+"."+n+"i = i.internal_concentration();"); + if (ion.uses_concentration_ext()) + text_.add_line(pre+"."+n+"o = i.external_concentration();"); + text_.add_line("return;"); + text_.decrease_indentation(); + text_.add_line("}"); + } } text_.add_line("throw std::domain_error(arb::util::pprintf(\"mechanism % does not support ion type\\n\", name()));"); text_.decrease_indentation(); text_.add_line("}"); text_.add_line(); - ////////////////////////////////////////////// ////////////////////////////////////////////// auto proctest = [] (procedureKind k) { @@ -344,6 +299,30 @@ std::string CPrinter::emit_source() { text_.add_line(); } + if(module_->write_backs().size()) { + text_.add_line("void write_back() override {"); + text_.increase_indentation(); + + text_.add_line("const size_type n_ = node_index_.size();"); + for (auto& w: module_->write_backs()) { + auto& src = w.source_name; + auto tgt = w.target_name; + tgt.erase(tgt.begin(), tgt.begin()+tgt.find('_')+1); + auto istore = ion_store(w.ion_kind)+"."; + + text_.add_line(); + text_.add_line("auto "+src+"_out_ = util::indirect_view("+istore+tgt+", "+istore+"index);"); + text_.add_line("for (size_type i_ = 0; i_ < n_; ++i_) {"); + text_.increase_indentation(); + text_.add_line("// 1/10 magic number due to unit normalisation"); + text_.add_line(src+"_out_["+istore+"index[i_]] += value_type(0.1)*weights_[i_]*"+src+"[i_];"); + text_.decrease_indentation(); text_.add_line("}"); + + } + text_.decrease_indentation(); text_.add_line("}"); + } + text_.add_line(); + // TODO: replace field_info() generation implemenation with separate schema info generation // as per #349. auto field_info_string = [](const std::string& kind, const Id& id) { diff --git a/modcc/cudaprinter.cpp b/modcc/cudaprinter.cpp index 8005dd8b45f2396e2c342e4c505faebd61778cbb..ba3ff2cd6be572e26505a84659336bba9c776ad5 100644 --- a/modcc/cudaprinter.cpp +++ b/modcc/cudaprinter.cpp @@ -119,6 +119,9 @@ CUDAPrinter::CUDAPrinter(Module &m, bool o) "void deliver_events_" + module_name_ +"(" + pack_name() + " params_, arb::fvm_size_type mech_id, deliverable_event_stream_state state);"); } } + if(module_->write_backs().size()) { + buffer().add_line("void write_back_"+module_name_+"("+pack_name()+" params_);"); + } buffer().add_line(); buffer().add_line("}} // namespace arb::gpu"); @@ -129,7 +132,7 @@ CUDAPrinter::CUDAPrinter(Module &m, bool o) set_buffer(impl_); // kernels - buffer().add_line("#include \"" + module_name_ + "_impl.hpp\""); + buffer().add_line("#include \"" + module_name_ + "_gpu_impl.hpp\""); buffer().add_line(); buffer().add_line("#include <backends/gpu/intrinsics.hpp>"); buffer().add_line("#include <backends/gpu/kernels/reduce_by_key.hpp>"); @@ -161,6 +164,32 @@ CUDAPrinter::CUDAPrinter(Module &m, bool o) } } } + + // print the write_back kernel + if(module_->write_backs().size()) { + buffer().add_line("__global__"); + buffer().add_line("void write_back_"+module_name_+"("+pack_name()+" params_) {"); + buffer().increase_indentation(); + buffer().add_line("using value_type = arb::fvm_value_type;"); + + buffer().add_line("auto tid_ = threadIdx.x + blockDim.x*blockIdx.x;"); + buffer().add_line("auto const n_ = params_.n_;"); + buffer().add_line("if(tid_<n_) {"); + buffer().increase_indentation(); + + for (auto& w: module_->write_backs()) { + auto& src = w.source_name; + auto& tgt = w.target_name; + + auto idx = src + "_idx_"; + buffer().add_line("auto "+idx+" = params_.ion_"+to_string(w.ion_kind)+"_idx_[tid_];"); + buffer().add_line("// 1/10 magic number due to unit normalisation"); + buffer().add_line("params_."+tgt+"["+idx+"] = value_type(0.1)*params_.weights_[tid_]*params_."+src+"[tid_];"); + } + buffer().decrease_indentation(); buffer().add_line("}"); + buffer().decrease_indentation(); buffer().add_line("}"); + } + buffer().decrease_indentation(); buffer().add_line("} // kernel namespace"); @@ -192,6 +221,20 @@ CUDAPrinter::CUDAPrinter(Module &m, bool o) buffer().add_line(); } } + + // add the write_back kernel wrapper if required by this module + if(module_->write_backs().size()) { + buffer().add_line("void write_back_"+module_name_+"("+pack_name()+" params_) {"); + buffer().increase_indentation(); + buffer().add_line("auto n = params_.n_;"); + buffer().add_line("constexpr int blockwidth = 128;"); + buffer().add_line("dim3 dim_block(blockwidth);"); + buffer().add_line("dim3 dim_grid(impl::block_count(n, blockwidth));"); + buffer().add_line("arb::gpu::kernels::write_back_"+module_name_+"<<<dim_grid, dim_block>>>(params_);"); + buffer().decrease_indentation(); + buffer().add_line("}"); + buffer().add_line(); + } buffer().add_line("}} // namespace arb::gpu"); // @@ -214,7 +257,7 @@ CUDAPrinter::CUDAPrinter(Module &m, bool o) buffer().add_line("#include <backends/gpu/multi_event_stream.hpp>"); buffer().add_line("#include <util/pprintf.hpp>"); buffer().add_line(); - buffer().add_line("#include \"" + module_name_ + "_impl.hpp\""); + buffer().add_line("#include \"" + module_name_ + "_gpu_impl.hpp\""); buffer().add_line(); buffer().add_line("namespace arb { namespace gpu{"); @@ -396,41 +439,6 @@ CUDAPrinter::CUDAPrinter(Module &m, bool o) ////////////////////////////////////////////// // print ion channel interface ////////////////////////////////////////////// - // return true/false indicating if cell has dependency on k - auto const& ions = m.neuron_block().ions; - auto find_ion = [&ions] (ionKind k) { - return std::find_if( - ions.begin(), ions.end(), - [k](IonDep const& d) {return d.kind()==k;} - ); - }; - auto has_ion = [&ions, find_ion] (ionKind k) { - return find_ion(k) != ions.end(); - }; - - // bool uses_ion(ionKind k) const override - buffer().add_line("bool uses_ion(ionKind k) const override {"); - buffer().increase_indentation(); - buffer().add_line("switch(k) {"); - buffer().increase_indentation(); - buffer().add_gutter() - << "case ionKind::na : return " - << (has_ion(ionKind::Na) ? "true" : "false") << ";"; - buffer().end_line(); - buffer().add_gutter() - << "case ionKind::ca : return " - << (has_ion(ionKind::Ca) ? "true" : "false") << ";"; - buffer().end_line(); - buffer().add_gutter() - << "case ionKind::k : return " - << (has_ion(ionKind::K) ? "true" : "false") << ";"; - buffer().end_line(); - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line("return false;"); - buffer().decrease_indentation(); - buffer().add_line("}"); - buffer().add_line(); /*************************************************************************** * @@ -448,67 +456,58 @@ CUDAPrinter::CUDAPrinter(Module &m, bool o) * **************************************************************************/ - // void set_ion(ionKind k, ion_type& i, const std::vector<size_type>&) override - // TODO: this is done manually, which isn't going to scale - auto has_variable = [] (IonDep const& ion, std::string const& name) { - if( std::find_if(ion.read.begin(), ion.read.end(), - [&name] (Token const& t) {return t.spelling==name;} - ) != ion.read.end() - ) return true; - if( std::find_if(ion.write.begin(), ion.write.end(), - [&name] (Token const& t) {return t.spelling==name;} - ) != ion.write.end() - ) return true; - return false; - }; - buffer().add_line("void set_ion(ionKind k, ion_type& i, const std::vector<size_type>& index) override {"); + // ion_spec uses_ion(ionKind k) const override + buffer().add_line("typename base::ion_spec uses_ion(ionKind k) const override {"); buffer().increase_indentation(); - buffer().add_line("using arb::algorithms::index_into;"); - if(has_ion(ionKind::Na)) { - auto ion = find_ion(ionKind::Na); - buffer().add_line("if(k==ionKind::na) {"); - buffer().increase_indentation(); - buffer().add_line("ion_na.index = iarray(memory::make_const_view(index));"); - if(has_variable(*ion, "ina")) buffer().add_line("ion_na.ina = i.current();"); - if(has_variable(*ion, "ena")) buffer().add_line("ion_na.ena = i.reversal_potential();"); - if(has_variable(*ion, "nai")) buffer().add_line("ion_na.nai = i.internal_concentration();"); - if(has_variable(*ion, "nao")) buffer().add_line("ion_na.nao = i.external_concentration();"); - buffer().add_line("return;"); - buffer().decrease_indentation(); - buffer().add_line("}"); - } - if(has_ion(ionKind::Ca)) { - auto ion = find_ion(ionKind::Ca); - buffer().add_line("if(k==ionKind::ca) {"); - buffer().increase_indentation(); - buffer().add_line("ion_ca.index = iarray(memory::make_const_view(index));"); - if(has_variable(*ion, "ica")) buffer().add_line("ion_ca.ica = i.current();"); - if(has_variable(*ion, "eca")) buffer().add_line("ion_ca.eca = i.reversal_potential();"); - if(has_variable(*ion, "cai")) buffer().add_line("ion_ca.cai = i.internal_concentration();"); - if(has_variable(*ion, "cao")) buffer().add_line("ion_ca.cao = i.external_concentration();"); - buffer().add_line("return;"); - buffer().decrease_indentation(); - buffer().add_line("}"); + buffer().add_line("bool uses = false;"); + buffer().add_line("bool writes_ext = false;"); + buffer().add_line("bool writes_int = false;"); + for (auto k: {ionKind::Na, ionKind::Ca, ionKind::K}) { + if (module_->has_ion(k)) { + auto ion = *module_->find_ion(k); + buffer().add_line("if (k==ionKind::" + ion.name + ") {"); + buffer().increase_indentation(); + buffer().add_line("uses = true;"); + if (ion.writes_concentration_int()) buffer().add_line("writes_int = true;"); + if (ion.writes_concentration_ext()) buffer().add_line("writes_ext = true;"); + buffer().decrease_indentation(); + buffer().add_line("}"); + } } - if(has_ion(ionKind::K)) { - auto ion = find_ion(ionKind::K); - buffer().add_line("if(k==ionKind::k) {"); - buffer().increase_indentation(); - buffer().add_line("ion_k.index = iarray(memory::make_const_view(index));"); - if(has_variable(*ion, "ik")) buffer().add_line("ion_k.ik = i.current();"); - if(has_variable(*ion, "ek")) buffer().add_line("ion_k.ek = i.reversal_potential();"); - if(has_variable(*ion, "ki")) buffer().add_line("ion_k.ki = i.internal_concentration();"); - if(has_variable(*ion, "ko")) buffer().add_line("ion_k.ko = i.external_concentration();"); - buffer().add_line("return;"); - buffer().decrease_indentation(); - buffer().add_line("}"); + buffer().add_line("return {uses, writes_int, writes_ext};"); + buffer().decrease_indentation(); + buffer().add_line("}"); + buffer().add_line(); + + // void set_ion(ionKind k, ion_type& i) override + buffer().add_line("void set_ion(ionKind k, ion_type& i, std::vector<size_type>const& index) override {"); + buffer().increase_indentation(); + for (auto k: {ionKind::Na, ionKind::Ca, ionKind::K}) { + if (module_->has_ion(k)) { + auto ion = *module_->find_ion(k); + buffer().add_line("if (k==ionKind::" + ion.name + ") {"); + buffer().increase_indentation(); + auto n = ion.name; + auto pre = "ion_"+n; + buffer().add_line(pre+".index = memory::make_const_view(index);"); + if (ion.uses_current()) + buffer().add_line(pre+".i"+n+" = i.current();"); + if (ion.uses_rev_potential()) + buffer().add_line(pre+".e"+n+" = i.reversal_potential();"); + if (ion.uses_concentration_int()) + buffer().add_line(pre+"."+n+"i = i.internal_concentration();"); + if (ion.uses_concentration_ext()) + buffer().add_line(pre+"."+n+"o = i.external_concentration();"); + buffer().add_line("return;"); + buffer().decrease_indentation(); + buffer().add_line("}"); + } } buffer().add_line("throw std::domain_error(arb::util::pprintf(\"mechanism % does not support ion type\\n\", name()));"); buffer().decrease_indentation(); buffer().add_line("}"); buffer().add_line(); - ////////////////////////////////////////////// ////////////////////////////////////////////// for(auto const &var : m.symbols()) { @@ -540,6 +539,14 @@ CUDAPrinter::CUDAPrinter(Module &m, bool o) } } + if(module_->write_backs().size()) { + buffer().add_line("void write_back() override {"); + buffer().increase_indentation(); + buffer().add_line("arb::gpu::write_back_"+module_name_+"(param_pack_);"); + buffer().decrease_indentation(); buffer().add_line("}"); + } + buffer().add_line(); + std::unordered_set<std::string> scalar_set; for (auto& v: scalar_variables) { scalar_set.insert(v->name()); diff --git a/modcc/identifier.hpp b/modcc/identifier.hpp index 60a09436acb63dcf2e2eb39cd49a25bb026aeddb..78a8ce5d19279ccf2a9dbf767c7a9f3c057012f8 100644 --- a/modcc/identifier.hpp +++ b/modcc/identifier.hpp @@ -1,6 +1,8 @@ #pragma once +#include <cstring> #include <string> +#include <stdexcept> /// indicate how a variable is accessed /// access is (read, written, or both) @@ -50,13 +52,22 @@ inline std::string yesno(bool val) { //////////////////////////////////////////// inline std::string to_string(ionKind i) { switch(i) { + case ionKind::Ca : return std::string("ca"); + case ionKind::Na : return std::string("na"); + case ionKind::K : return std::string("k"); case ionKind::none : return std::string("none"); - case ionKind::Ca : return std::string("calcium"); - case ionKind::Na : return std::string("sodium"); - case ionKind::K : return std::string("potassium"); case ionKind::nonspecific : return std::string("nonspecific"); } - return std::string("<error : undefined ionKind>"); + throw std::runtime_error("unknown ionKind"); +} + +inline ionKind to_ionKind(const std::string& s) { + if(s=="k") return ionKind::K; + if(s=="na") return ionKind::Na; + if(s=="ca") return ionKind::Ca; + if(s=="none") return ionKind::Ca; + if(s=="nonspecific") return ionKind::nonspecific; + throw std::runtime_error("invalid ion description string"); } inline std::string to_string(visibilityKind v) { diff --git a/modcc/module.cpp b/modcc/module.cpp index ff8e80c182f02b3f6153aa58a1849b604f614e1f..02db401871036e0fd288df78ce2c0755ad59390b 100644 --- a/modcc/module.cpp +++ b/modcc/module.cpp @@ -390,7 +390,7 @@ void Module::add_variables_to_symbols() { { if(symbols_.count(name)) { throw compiler_exception( - "trying to insert a symbol that already exists", + pprintf("the symbol % already exists", yellow(name)), loc); } symbols_[name] = @@ -496,27 +496,48 @@ void Module::add_variables_to_symbols() { auto update_ion_symbols = [this, create_indexed_variable] (Token const& tkn, accessKind acc, ionKind channel) { - auto const& var = tkn.spelling; + auto const& name = tkn.spelling; + + if(has_symbol(name)) { + auto sym = symbols_[name].get(); + + // if sym is an indexed_variable: error + // else if sym is a state variable: register a writeback call + // else if sym is a range (non parameter) variable: error + // else if sym is a parameter variable: error + // else it does not exist so make an indexed variable + + // If an indexed variable has already been created with the same name + // throw an error. + if(sym->kind()==symbolKind::indexed_variable) { + error(pprintf("the symbol defined % at % can't be redeclared", + sym->location(), yellow(name)), + tkn.location); + return; + } + else if(sym->kind()==symbolKind::variable) { + auto var = sym->is_variable(); + + // state variable: register writeback + if(var->is_state()) { + // create writeback + write_backs_.push_back(WriteBack(name, "ion_"+name, channel)); + return; + } - // add the ion variable's indexed shadow - if(has_symbol(var)) { - auto sym = symbols_[var].get(); - - // has the user declared a range/parameter with the same name? - if(sym->kind()!=symbolKind::indexed_variable) { - warning( - pprintf("the symbol % clashes with the ion channel variable," - " and will be ignored", yellow(var)), - sym->location() - ); - // erase symbol - symbols_.erase(var); + // error: a normal range variable or parameter can't have the same + // name as an indexed ion variable + error(pprintf("the ion channel variable % at % can't be redeclared", + yellow(name), sym->location()), + tkn.location); + return; } } - create_indexed_variable(var, "ion_"+var, + // add the ion variable's indexed shadow + create_indexed_variable(name, "ion_"+name, acc==accessKind::read ? tok::eq : tok::plus, - acc, channel, tkn.location); +acc, channel, tkn.location); }; // check for nonspecific current diff --git a/modcc/module.hpp b/modcc/module.hpp index afadac59e17570a1d548a6ac94510db926861383..8195caaaf48a8973b4f8e8220c6c151158938e14 100644 --- a/modcc/module.hpp +++ b/modcc/module.hpp @@ -6,6 +6,7 @@ #include "blocks.hpp" #include "error.hpp" #include "expression.hpp" +#include "writeback.hpp" // wrapper around a .mod file class Module: public error_stack { @@ -98,6 +99,23 @@ public: void add_variables_to_symbols(); bool semantic(); + const std::vector<WriteBack>& write_backs() const { + return write_backs_; + } + + auto find_ion(ionKind k) -> decltype(neuron_block().ions.begin()) { + auto& ions = neuron_block().ions; + return std::find_if( + ions.begin(), ions.end(), + [k](IonDep const& d) {return d.kind()==k;} + ); + }; + + bool has_ion(ionKind k) { + return find_ion(k) != neuron_block().ions.end(); + }; + + private: moduleKind kind_; std::string title_; @@ -136,4 +154,6 @@ private: UnitsBlock units_block_; ParameterBlock parameter_block_; AssignedBlock assigned_block_; + + std::vector<WriteBack> write_backs_; }; diff --git a/modcc/writeback.hpp b/modcc/writeback.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f8da02100a3af006d9a250d235f3ef9ebf7f419b --- /dev/null +++ b/modcc/writeback.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include <expression.hpp> +#include <identifier.hpp> + +// Holds the state required to generate a write_back call in a mechanism. +struct WriteBack { + // Name of the symbol inside the mechanism used to store. + // must be a state field + std::string source_name; + // Name of the field in the ion channel being written to. + std::string target_name; + // The ion channel being written to. + // must not be ionKind::none + ionKind ion_kind; + + WriteBack(std::string src, std::string tgt, ionKind k): + source_name(std::move(src)), target_name(std::move(tgt)), ion_kind(k) + {} +}; + diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 76b18755f735c0b4c7fa4f47a7c0f0bd7c64d721..4e1b4db31a6c17b18f59f812d625f0c4d3ead903 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -31,7 +31,7 @@ set(CUDA_SOURCES backends/gpu/multi_event_stream.cu backends/gpu/kernels/assemble_matrix.cu backends/gpu/kernels/interleave.cu - backends/gpu/kernels/nernst.cu + backends/gpu/kernels/ions.cu backends/gpu/kernels/solve_matrix.cu backends/gpu/kernels/stim_current.cu backends/gpu/kernels/take_samples.cu diff --git a/src/backends/gpu/fvm.cpp b/src/backends/gpu/fvm.cpp index 90b10e2c48248af267ebb14ad874a56bbe53ee4c..550962e06fba440e573a60e4bcdee1d92365056b 100644 --- a/src/backends/gpu/fvm.cpp +++ b/src/backends/gpu/fvm.cpp @@ -6,6 +6,7 @@ #include <mechanisms/gpu/exp2syn_gpu.hpp> #include <mechanisms/gpu/test_kin1_gpu.hpp> #include <mechanisms/gpu/test_kinlva_gpu.hpp> +#include <mechanisms/gpu/test_ca_gpu.hpp> namespace arb { namespace gpu { @@ -17,7 +18,8 @@ backend::mech_map_ = { { "expsyn", maker<mechanism_expsyn> }, { "exp2syn", maker<mechanism_exp2syn> }, { "test_kin1", maker<mechanism_test_kin1> }, - { "test_kinlva", maker<mechanism_test_kinlva> } + { "test_kinlva", maker<mechanism_test_kinlva> }, + { "test_ca", maker<mechanism_test_ca> } }; } // namespace gpu diff --git a/src/backends/gpu/fvm.hpp b/src/backends/gpu/fvm.hpp index 4dd50b4e38132003eb533cfab741ee6aef8a6aff..0a9a2c9a38c2b18774c0026f38acfabf2f76e760 100644 --- a/src/backends/gpu/fvm.hpp +++ b/src/backends/gpu/fvm.hpp @@ -13,7 +13,7 @@ #include "kernels/take_samples.hpp" #include "matrix_state_interleaved.hpp" #include "multi_event_stream.hpp" -#include "nernst.hpp" +#include "ions.hpp" #include "stimulus.hpp" #include "threshold_watcher.hpp" #include "time_ops.hpp" @@ -147,6 +147,14 @@ struct backend { arb::gpu::nernst(eX.size(), valency, temperature, Xo.data(), Xi.data(), eX.data()); } + static void init_concentration( + view Xi, view Xo, + const_view weight_Xi, const_view weight_Xo, + value_type c_int, value_type c_ext) + { + arb::gpu::init_concentration(Xi.size(), Xi.data(), Xo.data(), weight_Xi.data(), weight_Xo.data(), c_int, c_ext); + } + private: using maker_type = mechanism_ptr (*)(size_type, const_iview, const_view, const_view, const_view, view, view, array&&, iarray&&); static std::map<std::string, maker_type> mech_map_; diff --git a/src/backends/gpu/nernst.hpp b/src/backends/gpu/ions.hpp similarity index 57% rename from src/backends/gpu/nernst.hpp rename to src/backends/gpu/ions.hpp index 2597dada739f6d24dc853464964db57f5c28774b..6cd18b043185681b85b1b3c17edf89dba2307611 100644 --- a/src/backends/gpu/nernst.hpp +++ b/src/backends/gpu/ions.hpp @@ -14,6 +14,12 @@ void nernst(std::size_t n, int valency, const fvm_value_type* Xi, fvm_value_type* eX); +// prototype for inializing ion species concentrations +void init_concentration(std::size_t n, + fvm_value_type* Xi, fvm_value_type* Xo, + const fvm_value_type* weight_Xi, const fvm_value_type* weight_Xo, + fvm_value_type c_int, fvm_value_type c_ext); + } // namespace gpu } // namespace arb diff --git a/src/backends/gpu/kernels/ions.cu b/src/backends/gpu/kernels/ions.cu new file mode 100644 index 0000000000000000000000000000000000000000..3631315a796e20b0bca28751978745a6e61cd366 --- /dev/null +++ b/src/backends/gpu/kernels/ions.cu @@ -0,0 +1,63 @@ +#include <cstdint> + +#include <constants.hpp> + +#include "../ions.hpp" +#include "detail.hpp" + +namespace arb { +namespace gpu { + +namespace kernels { + template <typename T> + __global__ + void nernst(std::size_t n, int valency, T temperature, const T* Xo, const T* Xi, T* eX) { + auto i = threadIdx.x+blockIdx.x*blockDim.x; + + // factor 1e3 to scale from V -> mV + constexpr T RF = 1e3*constant::gas_constant/constant::faraday; + T factor = RF*temperature/valency; + if (i<n) { + eX[i] = factor*std::log(Xo[i]/Xi[i]); + } + } + + template <typename T> + __global__ + void init_concentration(std::size_t n, T* Xi, T* Xo, const T* weight_Xi, const T* weight_Xo, T c_int, T c_ext) { + auto i = threadIdx.x+blockIdx.x*blockDim.x; + + if (i<n) { + Xi[i] = c_int*weight_Xi[i]; + Xo[i] = c_ext*weight_Xo[i]; + } + } +} // namespace kernels + +void nernst(std::size_t n, + int valency, + fvm_value_type temperature, + const fvm_value_type* Xo, + const fvm_value_type* Xi, + fvm_value_type* eX) +{ + constexpr int block_dim = 128; + const int grid_dim = impl::block_count(n, block_dim); + kernels::nernst<<<grid_dim, block_dim>>> + (n, valency, temperature, Xo, Xi, eX); +} + +void init_concentration( + std::size_t n, + fvm_value_type* Xi, fvm_value_type* Xo, + const fvm_value_type* weight_Xi, const fvm_value_type* weight_Xo, + fvm_value_type c_int, fvm_value_type c_ext) +{ + constexpr int block_dim = 128; + const int grid_dim = impl::block_count(n, block_dim); + kernels::init_concentration<<<grid_dim, block_dim>>> + (n, Xi, Xo, weight_Xi, weight_Xo, c_int, c_ext); +} + +} // namespace gpu +} // namespace arb diff --git a/src/backends/gpu/kernels/nernst.cu b/src/backends/gpu/kernels/nernst.cu deleted file mode 100644 index 5e84d63bfe010570a67a93b6ec66fb85ddc5dc3c..0000000000000000000000000000000000000000 --- a/src/backends/gpu/kernels/nernst.cu +++ /dev/null @@ -1,39 +0,0 @@ -#include <cstdint> - -#include <constants.hpp> - -#include "../nernst.hpp" -#include "detail.hpp" - -namespace arb { -namespace gpu { - -namespace kernels { - template <typename T> - __global__ void nernst(std::size_t n, int valency, T temperature, const T* Xo, const T* Xi, T* eX) { - auto i = threadIdx.x+blockIdx.x*blockDim.x; - - // factor 1e3 to scale from V -> mV - constexpr T RF = 1e3*constant::gas_constant/constant::faraday; - T factor = RF*temperature/valency; - if (i<n) { - eX[i] = factor*std::log(Xo[i]/Xi[i]); - } - } -} // namespace kernels - -void nernst(std::size_t n, - int valency, - fvm_value_type temperature, - const fvm_value_type* Xo, - const fvm_value_type* Xi, - fvm_value_type* eX) -{ - constexpr int block_dim = 128; - const int grid_dim = impl::block_count(n, block_dim); - kernels::nernst<<<grid_dim, block_dim>>> - (n, valency, temperature, Xo, Xi, eX); -} - -} // namespace gpu -} // namespace arb diff --git a/src/backends/gpu/stimulus.hpp b/src/backends/gpu/stimulus.hpp index 08995e171dcd0198477dce777a89b4260da57600..72a3a3e0c481eebeaed622bab81c5dea16adb2c6 100644 --- a/src/backends/gpu/stimulus.hpp +++ b/src/backends/gpu/stimulus.hpp @@ -48,8 +48,8 @@ public: return mechanismKind::point; } - bool uses_ion(ionKind k) const override { - return false; + typename base::ion_spec uses_ion(ionKind k) const override { + return {false, false, false}; } void set_ion(ionKind k, ion_type& i, std::vector<size_type>const& index) override { diff --git a/src/backends/multicore/fvm.cpp b/src/backends/multicore/fvm.cpp index 1ebbdb292e22cf2f6f48db381d5572234d0218e8..2c34fb0039979bd29c51c553c92753fb703b3cdb 100644 --- a/src/backends/multicore/fvm.cpp +++ b/src/backends/multicore/fvm.cpp @@ -6,6 +6,7 @@ #include <mechanisms/multicore/exp2syn_cpu.hpp> #include <mechanisms/multicore/test_kin1_cpu.hpp> #include <mechanisms/multicore/test_kinlva_cpu.hpp> +#include <mechanisms/multicore/test_ca_cpu.hpp> namespace arb { namespace multicore { @@ -17,7 +18,8 @@ backend::mech_map_ = { { std::string("expsyn"), maker<mechanism_expsyn> }, { std::string("exp2syn"), maker<mechanism_exp2syn> }, { std::string("test_kin1"), maker<mechanism_test_kin1> }, - { std::string("test_kinlva"), maker<mechanism_test_kinlva> } + { std::string("test_kinlva"), maker<mechanism_test_kinlva> }, + { std::string("test_ca"), maker<mechanism_test_ca> } }; } // namespace multicore diff --git a/src/backends/multicore/fvm.hpp b/src/backends/multicore/fvm.hpp index a6f758a12d05480926c388f793637b4d6665c585..4a0ff6607e7bef389a1b821f8f25724df0e46de2 100644 --- a/src/backends/multicore/fvm.hpp +++ b/src/backends/multicore/fvm.hpp @@ -165,6 +165,17 @@ struct backend { } } + static void init_concentration( + view Xi, view Xo, + const_view weight_Xi, const_view weight_Xo, + value_type c_int, value_type c_ext) + { + for (std::size_t i=0u; i<Xi.size(); ++i) { + Xi[i] = c_int*weight_Xi[i]; + Xo[i] = c_ext*weight_Xo[i]; + } + } + private: using maker_type = mechanism_ptr (*)(value_type, const_iview, const_view, const_view, const_view, view, view, array&&, iarray&&); static std::map<std::string, maker_type> mech_map_; diff --git a/src/backends/multicore/stimulus.hpp b/src/backends/multicore/stimulus.hpp index fba35ca99cfa5145020752fb614447921a85975a..c7c1e9cd7ae9b7c209027a315c37dbd6b8e016ec 100644 --- a/src/backends/multicore/stimulus.hpp +++ b/src/backends/multicore/stimulus.hpp @@ -46,8 +46,8 @@ public: return mechanismKind::point; } - bool uses_ion(ionKind k) const override { - return false; + typename base::ion_spec uses_ion(ionKind k) const override { + return {false, false, false}; } void set_ion(ionKind k, ion_type& i, std::vector<size_type>const& index) override { diff --git a/src/fvm_multicell.hpp b/src/fvm_multicell.hpp index a36ee18b9f75c5528a37c5568c44a65ad4a5e0dd..4a4dd52a03a5db1f501e4c094a018aebac73ef3e 100644 --- a/src/fvm_multicell.hpp +++ b/src/fvm_multicell.hpp @@ -849,6 +849,9 @@ void fvm_multicell<Backend>::initialize( // Keep cv index list for each mechanism for ion set up below. std::map<std::string, std::vector<size_type>> mech_to_cv_index; + // Keep area of each cv occupied by each mechanism, which may be less than + // the total area of the cv. + std::map<std::string, std::vector<value_type>> mech_to_area; // Working vectors (re-used per mechanism). std::vector<size_type> mech_cv(ncomp); @@ -952,6 +955,9 @@ void fvm_multicell<Backend>::initialize( } } + // Save the areas for ion setup below. + mech_to_area[mech_name] = mech_weight; + for (auto& entry: param_tbl) { for (size_type i = 0; i<nindex; ++i) { entry.second.values[i] /= mech_weight[i]; @@ -1032,26 +1038,55 @@ void fvm_multicell<Backend>::initialize( // mechanism that depends on/influences ion std::set<size_type> index_set; for (auto const& mech : mechanisms_) { - if(mech->uses_ion(ion)) { + if(mech->uses_ion(ion).uses) { auto const& ni = mech_to_cv_index[mech->name()]; index_set.insert(ni.begin(), ni.end()); } } std::vector<size_type> indexes(index_set.begin(), index_set.end()); + const auto n = indexes.size(); + + if (n==0u) continue; // create the ion state - if (indexes.size()) { - ions_[ion] = indexes; + ions_[ion] = indexes; + + std::vector<value_type> w_int; + w_int.reserve(n); + for (auto i: indexes) { + w_int.push_back(tmp_cv_areas[i]); } + std::vector<value_type> w_out = w_int; - // join the ion reference in each mechanism into the cell-wide ion state + // Join the ion reference in each mechanism into the cell-wide ion state. for (auto& mech : mechanisms_) { - if (mech->uses_ion(ion)) { - auto const& ni = mech_to_cv_index[mech->name()]; - mech->set_ion(ion, ions_[ion], - util::make_copy<std::vector<size_type>> (algorithms::index_into(ni, indexes))); + const auto spec = mech->uses_ion(ion); + if (spec.uses) { + const auto& ni = mech_to_cv_index[mech->name()]; + const auto m = ni.size(); // number of CVs + const std::vector<size_type> sub_index = + util::assign_from(algorithms::index_into(ni, indexes)); + mech->set_ion(ion, ions_[ion], sub_index); + + const auto& ai = mech_to_area[mech->name()]; + if (spec.write_concentration_in) { + for (auto i: make_span(0, m)) { + w_int[sub_index[i]] -= ai[i]; + } + } + if (spec.write_concentration_out) { + for (auto i: make_span(0, m)) { + w_out[sub_index[i]] -= ai[i]; + } + } } } + // Normalise the weights. + for (auto i: make_span(0, n)) { + w_int[i] /= tmp_cv_areas[indexes[i]]; + w_out[i] /= tmp_cv_areas[indexes[i]]; + } + ions_[ion].set_weights(w_int, w_out); } // Note: NEURON defined default values for reversal potential as follows, @@ -1065,16 +1100,16 @@ void fvm_multicell<Backend>::initialize( // Whereas we use the Nernst equation to calculate reversal potentials at // the start of each time step. - ion_na().default_internal_concentration = 10; - ion_na().default_external_concentration =140; + ion_na().default_int_concentration = 10; + ion_na().default_ext_concentration =140; ion_na().valency = 1; - ion_k().default_internal_concentration =54.4; - ion_k().default_external_concentration = 2.5; + ion_k().default_int_concentration =54.4; + ion_k().default_ext_concentration = 2.5; ion_k().valency = 1; - ion_ca().default_internal_concentration =5e-5; - ion_ca().default_external_concentration = 2.0; + ion_ca().default_int_concentration =5e-5; + ion_ca().default_ext_concentration = 2.0; ion_ca().valency = 2; // initialize mechanism and voltage state @@ -1099,6 +1134,7 @@ void fvm_multicell<Backend>::reset() { for (auto& m : mechanisms_) { m->set_params(); m->nrn_init(); + m->write_back(); } // Update reversal potential to account for changes to concentrations made @@ -1184,6 +1220,15 @@ void fvm_multicell<Backend>::step_integration() { } PL(); + PE("ion-update"); + for(auto& i: ions_) { + i.second.init_concentration(); + } + for(auto& m: mechanisms_) { + m->write_back(); + } + PL(); + memory::copy(time_to_, time_); invalidate_time_cache(); diff --git a/src/ion.hpp b/src/ion.hpp index 99530bdf1cda6993acd9d4be846f660944b91924..3d64d0370371bce0a66c0ae6779bea84cfe49909 100644 --- a/src/ion.hpp +++ b/src/ion.hpp @@ -65,14 +65,22 @@ public : Xi_{idx.size(), std::numeric_limits<value_type>::quiet_NaN()}, Xo_{idx.size(), std::numeric_limits<value_type>::quiet_NaN()}, valency(0), - default_internal_concentration(0), - default_external_concentration(0) + default_int_concentration(0), + default_ext_concentration(0) {} - std::size_t memory() const { - return 4u*size() * sizeof(value_type) - + size() * sizeof(iarray) - + sizeof(ion); + // Set the weights used when setting default concentration values in each CV. + // The concentration of an ion species in a CV is a linear combination of + // default concentration and contributions from mechanisms that update the + // concentration. The weight is a value between 0 and 1 that represents the + // proportion of the CV area for which the default value is to be used + // (i.e. the proportion of the CV where the concentration is prescribed by a + // mechanism). + void set_weights(const std::vector<value_type>& win, const std::vector<value_type>& wout) { + EXPECTS(win.size() == size()); + EXPECTS(wout.size() == size()); + weight_Xi_ = memory::make_const_view(win); + weight_Xo_ = memory::make_const_view(wout); } view current() { @@ -91,13 +99,20 @@ public : return Xo_; } + view internal_concentration_weights() { + return weight_Xi_; + } + + view external_concentration_weights() { + return weight_Xo_; + } + void reset() { // The Nernst equation uses the assumption of nonzero concentrations: - EXPECTS(default_internal_concentration> value_type(0)); - EXPECTS(default_external_concentration> value_type(0)); - memory::fill(iX_, 0); - memory::fill(Xi_, default_internal_concentration); - memory::fill(Xo_, default_external_concentration); + EXPECTS(default_int_concentration > value_type(0)); + EXPECTS(default_ext_concentration > value_type(0)); + memory::fill(iX_, 0); // reset current + init_concentration(); // reset internal and external concentrations nernst_reversal_potential(constant::hh_squid_temp); // TODO: use temperature specfied in model } @@ -107,6 +122,12 @@ public : backend::nernst(valency, temperature, Xo_, Xi_, eX_); } + void init_concentration() { + backend::init_concentration( + Xi_, Xo_, weight_Xi_, weight_Xo_, + default_int_concentration, default_ext_concentration); + } + const_iview node_index() const { return node_index_; } @@ -117,15 +138,17 @@ public : private: iarray node_index_; - array iX_; // (nA) current - array eX_; // (mV) reversal potential - array Xi_; // (mM) internal concentration - array Xo_; // (mM) external concentration + array iX_; // (nA) current + array eX_; // (mV) reversal potential + array Xi_; // (mM) internal concentration + array Xo_; // (mM) external concentration + array weight_Xi_; // (1) concentration weight internal + array weight_Xo_; // (1) concentration weight external public: int valency; // valency of ionic species - value_type default_internal_concentration; // (mM) default internal concentration - value_type default_external_concentration; // (mM) default external concentration + value_type default_int_concentration; // (mM) default internal concentration + value_type default_ext_concentration; // (mM) default external concentration }; } // namespace arb diff --git a/src/mechanism.hpp b/src/mechanism.hpp index 33d53872843bab5504494c942d7daacd70fec70b..b0c4105213d22801556a841217c6eefb6ac02cee 100644 --- a/src/mechanism.hpp +++ b/src/mechanism.hpp @@ -50,6 +50,12 @@ enum class mechanismKind {point, density}; template <typename Backend> class mechanism { public: + struct ion_spec { + bool uses; + bool write_concentration_in; + bool write_concentration_out; + }; + using backend = Backend; using value_type = typename backend::value_type; @@ -100,10 +106,15 @@ public: virtual void nrn_state() = 0; virtual void nrn_current() = 0; virtual void deliver_events(const deliverable_event_stream_state& events) {}; - virtual bool uses_ion(ionKind) const = 0; + virtual ion_spec uses_ion(ionKind) const = 0; virtual void set_ion(ionKind k, ion_type& i, const std::vector<size_type>& index) = 0; virtual mechanismKind kind() const = 0; + // Used by mechanisms that update ion concentrations. + // Calling will copy the concentration, stored as internal state of the + // mechanism, to the "global" copy of ion species state. + virtual void write_back() {}; + // Mechanism instances with different global parameter settings can be distinguished by alias. std::string alias() const { return alias_.empty()? name(): alias_; diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index ff5959382ecefe0a1e6dfe9d5a12978f31f78f1d..7410c7249a933bd798bd79e508f4086fc655fd6c 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -1,7 +1,7 @@ include(${PROJECT_SOURCE_DIR}/mechanisms/BuildModules.cmake) # Build prototype mechanisms for testing in test_mechanisms. -set(proto_mechanisms pas hh expsyn exp2syn test_kin1 test_kinlva) +set(proto_mechanisms pas hh expsyn exp2syn test_kin1 test_kinlva test_ca) set(mech_proto_dir "${CMAKE_CURRENT_BINARY_DIR}/mech_proto") file(MAKE_DIRECTORY "${mech_proto_dir}") diff --git a/tests/unit/test_fvm_multi.cpp b/tests/unit/test_fvm_multi.cpp index d6e2e07720f64731992ca75f874c597a9662f362..1d2f940858be2ffeffb3b69d88947258c0438dec 100644 --- a/tests/unit/test_fvm_multi.cpp +++ b/tests/unit/test_fvm_multi.cpp @@ -1,6 +1,3 @@ -#include <cstddef> -#include <fstream> -#include <set> #include <vector> #include "../gtest.h" @@ -815,3 +812,81 @@ TEST(fvm_multi, target_handles_general) run_target_handle_test(handles); } +// Test area-weighted linear combination of ion species concentrations + +TEST(fvm_multi, ion_weights) { + using namespace arb; + + // Create a cell with 4 segments: + // - Soma (segment 0) plus three dendrites (1, 2, 3) meeting at a branch point. + // - Dendritic segments are given 1 compartments each. + // + // / + // d2 + // / + // s0-d1 + // \. + // d3 + // + // The CV corresponding to the branch point should comprise the terminal + // 1/2 of segment 1 and the initial 1/2 of segments 2 and 3. + // + // Geometry: + // soma 0: radius 5 µm + // dend 1: 100 µm long, 1 µm diameter cynlinder + // dend 2: 200 µm long, 1 µm diameter cynlinder + // dend 3: 100 µm long, 1 µm diameter cynlinder + // The radius of the soma is chosen such that the surface area of soma is + // the same as a 100µm dendrite, which makes it easier to describe the + // expected weights. + + cell c; + c.add_soma(5); + + c.add_cable(0, section_kind::dendrite, 0.5, 0.5, 100); + c.add_cable(1, section_kind::dendrite, 0.5, 0.5, 200); + c.add_cable(1, section_kind::dendrite, 0.5, 0.5, 100); + + for (auto& s: c.segments()) s->set_compartments(1); + + std::vector<std::vector<int>> seg_sets = { + {0}, {0,2}, {2, 3}, {0, 1, 2, 3}, + }; + std::vector<std::vector<unsigned>> expected_nodes = { + {0}, {0, 1, 2}, {0, 1, 2, 3}, {0, 1, 2, 3}, + }; + std::vector<std::vector<fvm_value_type>> expected_wght = { + {1./3}, {1./3, 1./2, 0.}, {1./3, 1./4, 0., 0.}, {0., 0., 0., 0.}, + }; + + double con_int = 80; + double con_ext = 120; + for (auto run=0u; run<seg_sets.size(); ++run) { + for (auto i: seg_sets[run]) { + c.segments()[i]->add_mechanism(mechanism_spec("test_ca")); + } + + std::vector<fvm_cell::target_handle> targets; + probe_association_map<fvm_cell::probe_handle> probe_map; + + fvm_cell fvcell; + fvcell.initialize({0}, cable1d_recipe(c), targets, probe_map); + + auto& ion = fvcell.ion_ca(); + ion.default_int_concentration = con_int; + ion.default_ext_concentration = con_ext; + ion.init_concentration(); + + auto& nodes = expected_nodes[run]; + auto& weights = expected_wght[run]; + auto ncv = nodes.size(); + EXPECT_EQ(ncv, ion.node_index().size()); + for (auto i: util::make_span(0, ncv)) { + EXPECT_EQ(nodes[i], ion.node_index()[i]); + EXPECT_FLOAT_EQ(weights[i], ion.internal_concentration_weights()[i]); + + EXPECT_EQ(con_ext, ion.external_concentration()[i]); + EXPECT_FLOAT_EQ(1.0, ion.external_concentration_weights()[i]); + } + } +} diff --git a/tests/unit/test_mechanisms.cpp b/tests/unit/test_mechanisms.cpp index 3925e0fd348cfb0a18b26a52db5c6a356bdc1fc4..a49557f7adef09d392ee3bca1898887613942827 100644 --- a/tests/unit/test_mechanisms.cpp +++ b/tests/unit/test_mechanisms.cpp @@ -7,6 +7,7 @@ #include "mech_proto/pas_cpu.hpp" #include "mech_proto/test_kin1_cpu.hpp" #include "mech_proto/test_kinlva_cpu.hpp" +#include "mech_proto/test_ca_cpu.hpp" // modcc generated mechanisms #include "mechanisms/multicore/expsyn_cpu.hpp" @@ -15,6 +16,7 @@ #include "mechanisms/multicore/pas_cpu.hpp" #include "mechanisms/multicore/test_kin1_cpu.hpp" #include "mechanisms/multicore/test_kinlva_cpu.hpp" +#include "mechanisms/multicore/test_ca_cpu.hpp" #include <initializer_list> #include <backends/multicore/fvm.hpp> @@ -87,7 +89,7 @@ void mech_update(T* mech, unsigned num_iters) { memory::fill(ion.external_concentration(), 140.); ions[ion_kind] = ion; - if (mech->uses_ion(ion_kind)) { + if (mech->uses_ion(ion_kind).uses) { mech->set_ion(ion_kind, ions[ion_kind], ion_indexes); } } @@ -227,6 +229,10 @@ using mechanism_types = ::testing::Types< mechanism_info< arb::multicore::mechanism_test_kinlva<arb::multicore::backend>, arb::multicore::mechanism_test_kinlva_proto<arb::multicore::backend> + >, + mechanism_info< + arb::multicore::mechanism_test_ca<arb::multicore::backend>, + arb::multicore::mechanism_test_ca_proto<arb::multicore::backend> > >;