From efe17c53c5f82f9b0f4f422bdb70648332789bf2 Mon Sep 17 00:00:00 2001 From: Nora Abi Akar <nora.abiakar@gmail.com> Date: Mon, 2 Sep 2019 12:41:54 +0200 Subject: [PATCH] Modcc compartment (#846) * Adds modcc support for COMPARTMENT statements of the form: `COMPARTMENT v {state_0, state_1, ..., state_n}`. * Use `COMPARTMENT` values `v` to multiply the derivative of state variables `state_0, state_1, ..., state_n` in associated kinetic scheme blocks. Fixes #838. --- modcc/expression.cpp | 35 ++++++++++++++++++ modcc/expression.hpp | 29 +++++++++++++++ modcc/parser.cpp | 48 +++++++++++++++++++++++-- modcc/parser.hpp | 4 ++- modcc/solvers.cpp | 22 ++++++++++++ modcc/solvers.hpp | 5 +++ modcc/token.cpp | 2 ++ modcc/token.hpp | 2 +- modcc/visitor.hpp | 1 + test/unit/CMakeLists.txt | 2 ++ test/unit/mod/test0_kin_compartment.mod | 36 +++++++++++++++++++ test/unit/mod/test1_kin_compartment.mod | 36 +++++++++++++++++++ test/unit/test_kinetic_linear.cpp | 21 +++++++++++ test/unit/unit_test_catalogue.cpp | 4 +++ 14 files changed, 242 insertions(+), 5 deletions(-) create mode 100644 test/unit/mod/test0_kin_compartment.mod create mode 100644 test/unit/mod/test1_kin_compartment.mod diff --git a/modcc/expression.cpp b/modcc/expression.cpp index ef404d43..1e06c82a 100644 --- a/modcc/expression.cpp +++ b/modcc/expression.cpp @@ -359,6 +359,38 @@ void StoichExpression::semantic(scope_ptr scp) { } } +/******************************************************************************* + CompartmentExpression +*******************************************************************************/ + +expression_ptr CompartmentExpression::clone() const { + std::vector<expression_ptr> cloned_state_vars; + for(auto& e: state_vars()) { + cloned_state_vars.emplace_back(e->clone()); + } + + return make_expression<CompartmentExpression>(location_, scale_factor()->clone(), std::move(cloned_state_vars)); +} + +std::string CompartmentExpression::to_string() const { + std::string s; + s += scale_factor()->to_string(); + s += " {"; + bool first = true; + for(auto& e: state_vars()) { + if (!first) s += ","; + s += e->to_string(); + first = false; + } + s += "}"; + return s; +} + +void CompartmentExpression::semantic(scope_ptr scp) { + scope_ = scp; + scale_factor()->semantic(scp); +} + /******************************************************************************* LinearExpression *******************************************************************************/ @@ -1039,6 +1071,9 @@ void ConditionalExpression::accept(Visitor *v) { void PDiffExpression::accept(Visitor *v) { v->visit(this); } +void CompartmentExpression::accept(Visitor *v) { + v->visit(this); +} expression_ptr unary_expression( Location loc, tok op, diff --git a/modcc/expression.hpp b/modcc/expression.hpp index 1860c7ee..a5ef42a0 100644 --- a/modcc/expression.hpp +++ b/modcc/expression.hpp @@ -37,6 +37,7 @@ class LinearExpression; class ReactionExpression; class StoichExpression; class StoichTermExpression; +class CompartmentExpression; class ConditionalExpression; class InitialBlock; class SolveExpression; @@ -175,6 +176,7 @@ public: virtual StoichExpression* is_stoich() {return nullptr;} virtual StoichTermExpression* is_stoich_term() {return nullptr;} virtual ConditionalExpression* is_conditional() {return nullptr;} + virtual CompartmentExpression* is_compartment() {return nullptr;} virtual InitialBlock* is_initial_block() {return nullptr;} virtual SolveExpression* is_solve_statement() {return nullptr;} virtual Symbol* is_symbol() {return nullptr;} @@ -860,6 +862,33 @@ private: expression_ptr rev_rate_; }; +class CompartmentExpression : public Expression { +public: + CompartmentExpression(Location loc, + expression_ptr&& scale_factor, + std::vector<expression_ptr>&& state_vars) + : Expression(loc), scale_factor_(std::move(scale_factor)), state_vars_(std::move(state_vars)) {} + + CompartmentExpression* is_compartment() override {return this;} + + std::string to_string() const override; + void semantic(scope_ptr scp) override; + expression_ptr clone() const override; + void accept(Visitor *v) override; + + expression_ptr& scale_factor() { return scale_factor_; } + const expression_ptr& scale_factor() const { return scale_factor_; } + + std::vector<expression_ptr>& state_vars() { return state_vars_; } + const std::vector<expression_ptr>& state_vars() const { return state_vars_; } + + ~CompartmentExpression() {} + +private: + expression_ptr scale_factor_; + std::vector<expression_ptr> state_vars_; +}; + class StoichTermExpression : public Expression { public: StoichTermExpression(Location loc, diff --git a/modcc/parser.cpp b/modcc/parser.cpp index 0202d0b1..27dac4fc 100644 --- a/modcc/parser.cpp +++ b/modcc/parser.cpp @@ -1000,6 +1000,8 @@ expression_ptr Parser::parse_statement() { return parse_line_expression(); case tok::conserve : return parse_conserve_expression(); + case tok::compartment : + return parse_compartment_statement(); case tok::tilde : return parse_tilde_expression(); case tok::initial : @@ -1263,7 +1265,7 @@ expression_ptr Parser::parse_tilde_expression() { std::move(fwd), std::move(rev)); } else if (search_to_eol(tok::eq)) { - auto lhs_bin = parse_expression(); + auto lhs_bin = parse_expression(tok::eq); if(token_.type!=tok::eq) { error(pprintf("expected '%', found '%'", yellow("="), yellow(token_.spelling))); @@ -1304,13 +1306,13 @@ expression_ptr Parser::parse_conserve_expression() { return make_expression<ConserveExpression>(here, std::move(lhs), std::move(rhs)); } -expression_ptr Parser::parse_expression(int prec) { +expression_ptr Parser::parse_expression(int prec, tok stop_token) { auto lhs = parse_unaryop(); if(lhs==nullptr) return nullptr; // Combine all sub-expressions with precedence greater than prec. for (;;) { - if(token_.type==tok::eq) { + if(token_.type==stop_token) { return lhs; } @@ -1335,6 +1337,10 @@ expression_ptr Parser::parse_expression() { return parse_expression(0); } +expression_ptr Parser::parse_expression(tok t) { + return parse_expression(0, t); +} + /// Parse a unary expression. /// If called when the current node in the AST is not a unary expression the call /// will be forwarded to parse_primary. This mechanism makes it possible to parse @@ -1739,3 +1745,39 @@ expression_ptr Parser::parse_initial() { return make_expression<InitialBlock>(block_location, std::move(body)); } + +expression_ptr Parser::parse_compartment_statement() { + auto here = location_; + + if(token_.type!=tok::compartment) { + error(pprintf("expected '%', found '%'", yellow("COMPARTMENT"), yellow(token_.spelling))); + return nullptr; + } + + get_token(); // consume 'COMPARTMENT' + auto scale_factor = parse_expression(tok::rbrace); + if (!scale_factor) return nullptr; + + if(token_.type != tok::lbrace) { + error(pprintf("expected '%', found '%'", yellow("{"), yellow(token_.spelling))); + return nullptr; + } + + get_token(); // consume '{' + std::vector<expression_ptr> states; + while (token_.type!=tok::rbrace) { + // check identifier + if(token_.type != tok::identifier) { + error( "expected a valid identifier, found '" + + yellow(token_.spelling) + "'"); + return nullptr; + } + + auto e = make_expression<IdentifierExpression>(token_.location, token_.spelling); + states.emplace_back(std::move(e)); + + get_token(); // consume the identifier + } + get_token(); // consume the rbrace + return make_expression<CompartmentExpression>(here, std::move(scale_factor), std::move(states)); +} diff --git a/modcc/parser.hpp b/modcc/parser.hpp index ad841b57..09130e9a 100644 --- a/modcc/parser.hpp +++ b/modcc/parser.hpp @@ -21,8 +21,9 @@ public: expression_ptr parse_integer(); expression_ptr parse_real(); expression_ptr parse_call(); - expression_ptr parse_expression(int prec); + expression_ptr parse_expression(int prec, tok t=tok::eq); expression_ptr parse_expression(); + expression_ptr parse_expression(tok); expression_ptr parse_primary(); expression_ptr parse_parenthesis_expression(); expression_ptr parse_line_expression(); @@ -37,6 +38,7 @@ public: expression_ptr parse_conductance(); expression_ptr parse_block(bool); expression_ptr parse_initial(); + expression_ptr parse_compartment_statement(); expression_ptr parse_if(); symbol_ptr parse_procedure(); diff --git a/modcc/solvers.cpp b/modcc/solvers.cpp index 6f770912..3ab90255 100644 --- a/modcc/solvers.cpp +++ b/modcc/solvers.cpp @@ -180,10 +180,25 @@ void SparseSolverVisitor::visit(BlockExpression* e) { dvars_.push_back(id->name()); } } + scale_factor_.resize(dvars_.size()); BlockRewriterBase::visit(e); } +void SparseSolverVisitor::visit(CompartmentExpression *e) { + auto loc = e->location(); + + for (auto& s: e->is_compartment()->state_vars()) { + auto it = std::find(dvars_.begin(), dvars_.end(), s->is_identifier()->spelling()); + if (it == dvars_.end()) { + error({"COMPARTMENT variable is not used", loc}); + return; + } + auto idx = it - dvars_.begin(); + scale_factor_[idx] = e->scale_factor()->clone(); + } +} + void SparseSolverVisitor::visit(AssignmentExpression *e) { if (A_.empty()) { unsigned n = dvars_.size(); @@ -242,6 +257,10 @@ void SparseSolverVisitor::visit(AssignmentExpression *e) { expr = make_expression<MulBinaryExpression>(loc, r.coef[dvars_[j]]->clone(), dt_expr->clone()); + + if (scale_factor_[j]) { + expr = make_expression<DivBinaryExpression>(loc, std::move(expr), scale_factor_[j]->clone()); + } } if (j==deq_index_) { @@ -312,6 +331,9 @@ void SparseSolverVisitor::visit(ConserveExpression *e) { if (it != terms.end()) { auto expr = (*it)->is_stoich_term()->coeff()->clone(); + if (scale_factor_[j]) { + expr = make_expression<MulBinaryExpression>(loc, scale_factor_[j]->clone(), std::move(expr)); + } auto local_a_term = make_unique_local_assign(scope, expr.get(), "a_"); auto a_ = local_a_term.id->is_identifier()->spelling(); diff --git a/modcc/solvers.hpp b/modcc/solvers.hpp index 1198d1b1..1aff2aea 100644 --- a/modcc/solvers.hpp +++ b/modcc/solvers.hpp @@ -85,6 +85,9 @@ protected: // Flag to indicate whether conserve statements are part of the system bool conserve_ = false; + // state variable multiplier/divider + std::vector<expression_ptr> scale_factor_; + // rhs of conserve statement std::vector<std::string> conserve_rhs_; std::vector<unsigned> conserve_idx_; @@ -96,6 +99,7 @@ public: virtual void visit(BlockExpression* e) override; virtual void visit(AssignmentExpression *e) override; + virtual void visit(CompartmentExpression *e) override; virtual void visit(ConserveExpression *e) override; virtual void finalize() override; virtual void reset() override { @@ -104,6 +108,7 @@ public: A_.clear(); symtbl_.clear(); conserve_ = false; + scale_factor_.clear(); conserve_rhs_.clear(); conserve_idx_.clear(); SolverVisitorBase::reset(); diff --git a/modcc/token.cpp b/modcc/token.cpp index 58915c16..4e09f545 100644 --- a/modcc/token.cpp +++ b/modcc/token.cpp @@ -51,6 +51,7 @@ static Keyword keywords[] = { {"THREADSAFE", tok::threadsafe}, {"GLOBAL", tok::global}, {"POINT_PROCESS", tok::point_process}, + {"COMPARTMENT", tok::compartment}, {"METHOD", tok::method}, {"if", tok::if_stmt}, {"else", tok::else_stmt}, @@ -122,6 +123,7 @@ static TokenString token_strings[] = { {"THREADSAFE", tok::threadsafe}, {"GLOBAL", tok::global}, {"POINT_PROCESS", tok::point_process}, + {"COMPARTMENT", tok::compartment}, {"METHOD", tok::method}, {"if", tok::if_stmt}, {"else", tok::else_stmt}, diff --git a/modcc/token.hpp b/modcc/token.hpp index 06b2ad6b..5eef7a94 100644 --- a/modcc/token.hpp +++ b/modcc/token.hpp @@ -60,7 +60,7 @@ enum class tok { unitsoff, unitson, suffix, nonspecific_current, useion, read, write, valence, - range, local, conserve, + range, local, conserve, compartment, solve, method, threadsafe, global, point_process, diff --git a/modcc/visitor.hpp b/modcc/visitor.hpp index 8f7dd6e4..feb9d3a6 100644 --- a/modcc/visitor.hpp +++ b/modcc/visitor.hpp @@ -27,6 +27,7 @@ public: virtual void visit(ReactionExpression *e) { visit((Expression*) e); } virtual void visit(StoichTermExpression *e) { visit((Expression*) e); } virtual void visit(StoichExpression *e) { visit((Expression*) e); } + virtual void visit(CompartmentExpression *e) { visit((Expression*) e); } virtual void visit(VariableExpression *e) { visit((Expression*) e); } virtual void visit(IndexedVariable *e) { visit((Expression*) e); } virtual void visit(FunctionExpression *e) { visit((Expression*) e); } diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 751647b4..4bebb191 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -9,6 +9,8 @@ set(test_mechanisms test0_kin_conserve test1_kin_diff test1_kin_conserve + test0_kin_compartment + test1_kin_compartment fixed_ica_current point_ica_current linear_ca_conc diff --git a/test/unit/mod/test0_kin_compartment.mod b/test/unit/mod/test0_kin_compartment.mod new file mode 100644 index 00000000..7f3a7bba --- /dev/null +++ b/test/unit/mod/test0_kin_compartment.mod @@ -0,0 +1,36 @@ +NEURON { + SUFFIX test0_kin_compartment +} + +STATE { + s d h +} + +PARAMETER { + A = 0.5 + B = 0.1 +} + +BREAKPOINT { + SOLVE state METHOD sparse +} + +KINETIC state { + COMPARTMENT A {s h} + COMPARTMENT B {d} + + LOCAL alpha1, beta1, alpha2, beta2 + alpha1 = 2 + beta1 = 0.6 + alpha2 = 3 + beta2 = 0.7 + + ~ s <-> h (alpha1, beta1) + ~ d <-> s (alpha2, beta2) +} + +INITIAL { + h = 0.2 + d = 0.3 + s = 1-d-h +} diff --git a/test/unit/mod/test1_kin_compartment.mod b/test/unit/mod/test1_kin_compartment.mod new file mode 100644 index 00000000..15f2c311 --- /dev/null +++ b/test/unit/mod/test1_kin_compartment.mod @@ -0,0 +1,36 @@ +NEURON { + SUFFIX test1_kin_compartment +} + +STATE { + s d h +} + +PARAMETER { + A = 0.5 +} + +BREAKPOINT { + SOLVE state METHOD sparse +} + +KINETIC state { + COMPARTMENT A {s h d} + + LOCAL alpha1, beta1, alpha2, beta2 + alpha1 = 2 + beta1 = 0.6 + alpha2 = 3 + beta2 = 0.7 + + ~ s <-> h (alpha1, beta1) + ~ d <-> s (alpha2, beta2) + + CONSERVE s + d + h = A +} + +INITIAL { + h = 0.2 + d = 0.3 + s = 1-d-h +} diff --git a/test/unit/test_kinetic_linear.cpp b/test/unit/test_kinetic_linear.cpp index 16dddc4e..a85fc46e 100644 --- a/test/unit/test_kinetic_linear.cpp +++ b/test/unit/test_kinetic_linear.cpp @@ -88,6 +88,17 @@ void run_test(std::string mech_name, } } +TEST(mech_kinetic, kintetic_scaled) { + std::vector<std::string> state_variables = {"s", "h", "d"}; + std::vector<fvm_value_type> t0_values = {0.5, 0.2, 0.3}; + std::vector<fvm_value_type> t1_0_values = {0.373297, 0.591621, 0.0350817}; + std::vector<fvm_value_type> t1_1_values = {0.329897, 0.537371, 0.132732}; + + run_test<multicore::backend>("test0_kin_compartment", state_variables, {}, t0_values, t1_0_values); + run_test<multicore::backend>("test1_kin_compartment", state_variables, {}, t0_values, t1_1_values); + +} + TEST(mech_kinetic, kintetic_1_conserve) { std::vector<std::string> state_variables = {"s", "h", "d"}; std::vector<fvm_value_type> t0_values = {0.5, 0.2, 0.3}; @@ -117,6 +128,16 @@ TEST(mech_linear, linear) { } #ifdef ARB_GPU_ENABLED +TEST(mech_kinetic_gpu, kintetic_scaled) { + std::vector<std::string> state_variables = {"s", "h", "d"}; + std::vector<fvm_value_type> t0_values = {0.5, 0.2, 0.3}; + std::vector<fvm_value_type> t1_0_values = {0.373297, 0.591621, 0.0350817}; + std::vector<fvm_value_type> t1_1_values = {0.329897, 0.537371, 0.132732}; + + run_test<gpu::backend>("test0_kin_compartment", state_variables, {}, t0_values, t1_0_values); + run_test<gpu::backend>("test1_kin_compartment", state_variables, {}, t0_values, t1_1_values); +} + TEST(mech_kinetic_gpu, kintetic_1_conserve) { std::vector<std::string> state_variables = {"s", "h", "d"}; std::vector<fvm_value_type> t0_values = {0.5, 0.2, 0.3}; diff --git a/test/unit/unit_test_catalogue.cpp b/test/unit/unit_test_catalogue.cpp index 5e9606c2..d5d54e4d 100644 --- a/test/unit/unit_test_catalogue.cpp +++ b/test/unit/unit_test_catalogue.cpp @@ -13,6 +13,8 @@ #include "mechanisms/test_linear_init.hpp" #include "mechanisms/test_linear_init_shuffle.hpp" #include "mechanisms/test0_kin_conserve.hpp" +#include "mechanisms/test0_kin_compartment.hpp" +#include "mechanisms/test1_kin_compartment.hpp" #include "mechanisms/test1_kin_diff.hpp" #include "mechanisms/test1_kin_conserve.hpp" #include "mechanisms/fixed_ica_current.hpp" @@ -48,6 +50,8 @@ mechanism_catalogue make_unit_test_catalogue() { ADD_MECH(cat, test_linear_init_shuffle) ADD_MECH(cat, test0_kin_diff) ADD_MECH(cat, test0_kin_conserve) + ADD_MECH(cat, test0_kin_compartment) + ADD_MECH(cat, test1_kin_compartment) ADD_MECH(cat, test1_kin_diff) ADD_MECH(cat, test1_kin_conserve) ADD_MECH(cat, fixed_ica_current) -- GitLab