diff --git a/modcc/expression.cpp b/modcc/expression.cpp index ef404d43fd0e75352fb6bfcdd31b22707a170fda..1e06c82ae657c076d4bf0066feb7561db06778e7 100644 --- a/modcc/expression.cpp +++ b/modcc/expression.cpp @@ -359,6 +359,38 @@ void StoichExpression::semantic(scope_ptr scp) { } } +/******************************************************************************* + CompartmentExpression +*******************************************************************************/ + +expression_ptr CompartmentExpression::clone() const { + std::vector<expression_ptr> cloned_state_vars; + for(auto& e: state_vars()) { + cloned_state_vars.emplace_back(e->clone()); + } + + return make_expression<CompartmentExpression>(location_, scale_factor()->clone(), std::move(cloned_state_vars)); +} + +std::string CompartmentExpression::to_string() const { + std::string s; + s += scale_factor()->to_string(); + s += " {"; + bool first = true; + for(auto& e: state_vars()) { + if (!first) s += ","; + s += e->to_string(); + first = false; + } + s += "}"; + return s; +} + +void CompartmentExpression::semantic(scope_ptr scp) { + scope_ = scp; + scale_factor()->semantic(scp); +} + /******************************************************************************* LinearExpression *******************************************************************************/ @@ -1039,6 +1071,9 @@ void ConditionalExpression::accept(Visitor *v) { void PDiffExpression::accept(Visitor *v) { v->visit(this); } +void CompartmentExpression::accept(Visitor *v) { + v->visit(this); +} expression_ptr unary_expression( Location loc, tok op, diff --git a/modcc/expression.hpp b/modcc/expression.hpp index 1860c7eea77dc57bf1cbfd3eaa63b0799263e55c..a5ef42a04b564c53a32c376b898284d93c98e796 100644 --- a/modcc/expression.hpp +++ b/modcc/expression.hpp @@ -37,6 +37,7 @@ class LinearExpression; class ReactionExpression; class StoichExpression; class StoichTermExpression; +class CompartmentExpression; class ConditionalExpression; class InitialBlock; class SolveExpression; @@ -175,6 +176,7 @@ public: virtual StoichExpression* is_stoich() {return nullptr;} virtual StoichTermExpression* is_stoich_term() {return nullptr;} virtual ConditionalExpression* is_conditional() {return nullptr;} + virtual CompartmentExpression* is_compartment() {return nullptr;} virtual InitialBlock* is_initial_block() {return nullptr;} virtual SolveExpression* is_solve_statement() {return nullptr;} virtual Symbol* is_symbol() {return nullptr;} @@ -860,6 +862,33 @@ private: expression_ptr rev_rate_; }; +class CompartmentExpression : public Expression { +public: + CompartmentExpression(Location loc, + expression_ptr&& scale_factor, + std::vector<expression_ptr>&& state_vars) + : Expression(loc), scale_factor_(std::move(scale_factor)), state_vars_(std::move(state_vars)) {} + + CompartmentExpression* is_compartment() override {return this;} + + std::string to_string() const override; + void semantic(scope_ptr scp) override; + expression_ptr clone() const override; + void accept(Visitor *v) override; + + expression_ptr& scale_factor() { return scale_factor_; } + const expression_ptr& scale_factor() const { return scale_factor_; } + + std::vector<expression_ptr>& state_vars() { return state_vars_; } + const std::vector<expression_ptr>& state_vars() const { return state_vars_; } + + ~CompartmentExpression() {} + +private: + expression_ptr scale_factor_; + std::vector<expression_ptr> state_vars_; +}; + class StoichTermExpression : public Expression { public: StoichTermExpression(Location loc, diff --git a/modcc/parser.cpp b/modcc/parser.cpp index 0202d0b1bb54dbe45aab5db2d27f7d20a81ea42b..27dac4fca4a57c9c06a54a5e5b740fea3197b2f1 100644 --- a/modcc/parser.cpp +++ b/modcc/parser.cpp @@ -1000,6 +1000,8 @@ expression_ptr Parser::parse_statement() { return parse_line_expression(); case tok::conserve : return parse_conserve_expression(); + case tok::compartment : + return parse_compartment_statement(); case tok::tilde : return parse_tilde_expression(); case tok::initial : @@ -1263,7 +1265,7 @@ expression_ptr Parser::parse_tilde_expression() { std::move(fwd), std::move(rev)); } else if (search_to_eol(tok::eq)) { - auto lhs_bin = parse_expression(); + auto lhs_bin = parse_expression(tok::eq); if(token_.type!=tok::eq) { error(pprintf("expected '%', found '%'", yellow("="), yellow(token_.spelling))); @@ -1304,13 +1306,13 @@ expression_ptr Parser::parse_conserve_expression() { return make_expression<ConserveExpression>(here, std::move(lhs), std::move(rhs)); } -expression_ptr Parser::parse_expression(int prec) { +expression_ptr Parser::parse_expression(int prec, tok stop_token) { auto lhs = parse_unaryop(); if(lhs==nullptr) return nullptr; // Combine all sub-expressions with precedence greater than prec. for (;;) { - if(token_.type==tok::eq) { + if(token_.type==stop_token) { return lhs; } @@ -1335,6 +1337,10 @@ expression_ptr Parser::parse_expression() { return parse_expression(0); } +expression_ptr Parser::parse_expression(tok t) { + return parse_expression(0, t); +} + /// Parse a unary expression. /// If called when the current node in the AST is not a unary expression the call /// will be forwarded to parse_primary. This mechanism makes it possible to parse @@ -1739,3 +1745,39 @@ expression_ptr Parser::parse_initial() { return make_expression<InitialBlock>(block_location, std::move(body)); } + +expression_ptr Parser::parse_compartment_statement() { + auto here = location_; + + if(token_.type!=tok::compartment) { + error(pprintf("expected '%', found '%'", yellow("COMPARTMENT"), yellow(token_.spelling))); + return nullptr; + } + + get_token(); // consume 'COMPARTMENT' + auto scale_factor = parse_expression(tok::rbrace); + if (!scale_factor) return nullptr; + + if(token_.type != tok::lbrace) { + error(pprintf("expected '%', found '%'", yellow("{"), yellow(token_.spelling))); + return nullptr; + } + + get_token(); // consume '{' + std::vector<expression_ptr> states; + while (token_.type!=tok::rbrace) { + // check identifier + if(token_.type != tok::identifier) { + error( "expected a valid identifier, found '" + + yellow(token_.spelling) + "'"); + return nullptr; + } + + auto e = make_expression<IdentifierExpression>(token_.location, token_.spelling); + states.emplace_back(std::move(e)); + + get_token(); // consume the identifier + } + get_token(); // consume the rbrace + return make_expression<CompartmentExpression>(here, std::move(scale_factor), std::move(states)); +} diff --git a/modcc/parser.hpp b/modcc/parser.hpp index ad841b57f99ec03fab04433688782d6df2cc5f6c..09130e9a83b33f0707aede7ea0906915d47864c3 100644 --- a/modcc/parser.hpp +++ b/modcc/parser.hpp @@ -21,8 +21,9 @@ public: expression_ptr parse_integer(); expression_ptr parse_real(); expression_ptr parse_call(); - expression_ptr parse_expression(int prec); + expression_ptr parse_expression(int prec, tok t=tok::eq); expression_ptr parse_expression(); + expression_ptr parse_expression(tok); expression_ptr parse_primary(); expression_ptr parse_parenthesis_expression(); expression_ptr parse_line_expression(); @@ -37,6 +38,7 @@ public: expression_ptr parse_conductance(); expression_ptr parse_block(bool); expression_ptr parse_initial(); + expression_ptr parse_compartment_statement(); expression_ptr parse_if(); symbol_ptr parse_procedure(); diff --git a/modcc/solvers.cpp b/modcc/solvers.cpp index 6f77091263e5bc873bfe2b703f726c6aa018817e..3ab902551ff2bb7af6d04a0fc8ee2fc5705af3e7 100644 --- a/modcc/solvers.cpp +++ b/modcc/solvers.cpp @@ -180,10 +180,25 @@ void SparseSolverVisitor::visit(BlockExpression* e) { dvars_.push_back(id->name()); } } + scale_factor_.resize(dvars_.size()); BlockRewriterBase::visit(e); } +void SparseSolverVisitor::visit(CompartmentExpression *e) { + auto loc = e->location(); + + for (auto& s: e->is_compartment()->state_vars()) { + auto it = std::find(dvars_.begin(), dvars_.end(), s->is_identifier()->spelling()); + if (it == dvars_.end()) { + error({"COMPARTMENT variable is not used", loc}); + return; + } + auto idx = it - dvars_.begin(); + scale_factor_[idx] = e->scale_factor()->clone(); + } +} + void SparseSolverVisitor::visit(AssignmentExpression *e) { if (A_.empty()) { unsigned n = dvars_.size(); @@ -242,6 +257,10 @@ void SparseSolverVisitor::visit(AssignmentExpression *e) { expr = make_expression<MulBinaryExpression>(loc, r.coef[dvars_[j]]->clone(), dt_expr->clone()); + + if (scale_factor_[j]) { + expr = make_expression<DivBinaryExpression>(loc, std::move(expr), scale_factor_[j]->clone()); + } } if (j==deq_index_) { @@ -312,6 +331,9 @@ void SparseSolverVisitor::visit(ConserveExpression *e) { if (it != terms.end()) { auto expr = (*it)->is_stoich_term()->coeff()->clone(); + if (scale_factor_[j]) { + expr = make_expression<MulBinaryExpression>(loc, scale_factor_[j]->clone(), std::move(expr)); + } auto local_a_term = make_unique_local_assign(scope, expr.get(), "a_"); auto a_ = local_a_term.id->is_identifier()->spelling(); diff --git a/modcc/solvers.hpp b/modcc/solvers.hpp index 1198d1b10461b6e7b0db6f5f03526b472044aed2..1aff2aea5552be56726e1b0084611af70ec68b47 100644 --- a/modcc/solvers.hpp +++ b/modcc/solvers.hpp @@ -85,6 +85,9 @@ protected: // Flag to indicate whether conserve statements are part of the system bool conserve_ = false; + // state variable multiplier/divider + std::vector<expression_ptr> scale_factor_; + // rhs of conserve statement std::vector<std::string> conserve_rhs_; std::vector<unsigned> conserve_idx_; @@ -96,6 +99,7 @@ public: virtual void visit(BlockExpression* e) override; virtual void visit(AssignmentExpression *e) override; + virtual void visit(CompartmentExpression *e) override; virtual void visit(ConserveExpression *e) override; virtual void finalize() override; virtual void reset() override { @@ -104,6 +108,7 @@ public: A_.clear(); symtbl_.clear(); conserve_ = false; + scale_factor_.clear(); conserve_rhs_.clear(); conserve_idx_.clear(); SolverVisitorBase::reset(); diff --git a/modcc/token.cpp b/modcc/token.cpp index 58915c160f3d3899352286c615a459fec749ad4a..4e09f545796e1f4e81a179e7e9d9c964ed09006d 100644 --- a/modcc/token.cpp +++ b/modcc/token.cpp @@ -51,6 +51,7 @@ static Keyword keywords[] = { {"THREADSAFE", tok::threadsafe}, {"GLOBAL", tok::global}, {"POINT_PROCESS", tok::point_process}, + {"COMPARTMENT", tok::compartment}, {"METHOD", tok::method}, {"if", tok::if_stmt}, {"else", tok::else_stmt}, @@ -122,6 +123,7 @@ static TokenString token_strings[] = { {"THREADSAFE", tok::threadsafe}, {"GLOBAL", tok::global}, {"POINT_PROCESS", tok::point_process}, + {"COMPARTMENT", tok::compartment}, {"METHOD", tok::method}, {"if", tok::if_stmt}, {"else", tok::else_stmt}, diff --git a/modcc/token.hpp b/modcc/token.hpp index 06b2ad6bb76e0db908c7bafeb773df406f8eeb46..5eef7a943e4c81296064e99badc4ecc2e229cd00 100644 --- a/modcc/token.hpp +++ b/modcc/token.hpp @@ -60,7 +60,7 @@ enum class tok { unitsoff, unitson, suffix, nonspecific_current, useion, read, write, valence, - range, local, conserve, + range, local, conserve, compartment, solve, method, threadsafe, global, point_process, diff --git a/modcc/visitor.hpp b/modcc/visitor.hpp index 8f7dd6e47e66b56f821e426a5a4b92f60320d54c..feb9d3a64ff8deb91d7bfc33b187611d8243073d 100644 --- a/modcc/visitor.hpp +++ b/modcc/visitor.hpp @@ -27,6 +27,7 @@ public: virtual void visit(ReactionExpression *e) { visit((Expression*) e); } virtual void visit(StoichTermExpression *e) { visit((Expression*) e); } virtual void visit(StoichExpression *e) { visit((Expression*) e); } + virtual void visit(CompartmentExpression *e) { visit((Expression*) e); } virtual void visit(VariableExpression *e) { visit((Expression*) e); } virtual void visit(IndexedVariable *e) { visit((Expression*) e); } virtual void visit(FunctionExpression *e) { visit((Expression*) e); } diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 751647b42a6046c3450477c3182ac65d3d3d56d4..4bebb1919fb99f20bbb3944a5765071d66ba458a 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -9,6 +9,8 @@ set(test_mechanisms test0_kin_conserve test1_kin_diff test1_kin_conserve + test0_kin_compartment + test1_kin_compartment fixed_ica_current point_ica_current linear_ca_conc diff --git a/test/unit/mod/test0_kin_compartment.mod b/test/unit/mod/test0_kin_compartment.mod new file mode 100644 index 0000000000000000000000000000000000000000..7f3a7bbafd217c0883dc186ffdcb9bd15f749da0 --- /dev/null +++ b/test/unit/mod/test0_kin_compartment.mod @@ -0,0 +1,36 @@ +NEURON { + SUFFIX test0_kin_compartment +} + +STATE { + s d h +} + +PARAMETER { + A = 0.5 + B = 0.1 +} + +BREAKPOINT { + SOLVE state METHOD sparse +} + +KINETIC state { + COMPARTMENT A {s h} + COMPARTMENT B {d} + + LOCAL alpha1, beta1, alpha2, beta2 + alpha1 = 2 + beta1 = 0.6 + alpha2 = 3 + beta2 = 0.7 + + ~ s <-> h (alpha1, beta1) + ~ d <-> s (alpha2, beta2) +} + +INITIAL { + h = 0.2 + d = 0.3 + s = 1-d-h +} diff --git a/test/unit/mod/test1_kin_compartment.mod b/test/unit/mod/test1_kin_compartment.mod new file mode 100644 index 0000000000000000000000000000000000000000..15f2c31152103514e3355b6863cacaf34270bc55 --- /dev/null +++ b/test/unit/mod/test1_kin_compartment.mod @@ -0,0 +1,36 @@ +NEURON { + SUFFIX test1_kin_compartment +} + +STATE { + s d h +} + +PARAMETER { + A = 0.5 +} + +BREAKPOINT { + SOLVE state METHOD sparse +} + +KINETIC state { + COMPARTMENT A {s h d} + + LOCAL alpha1, beta1, alpha2, beta2 + alpha1 = 2 + beta1 = 0.6 + alpha2 = 3 + beta2 = 0.7 + + ~ s <-> h (alpha1, beta1) + ~ d <-> s (alpha2, beta2) + + CONSERVE s + d + h = A +} + +INITIAL { + h = 0.2 + d = 0.3 + s = 1-d-h +} diff --git a/test/unit/test_kinetic_linear.cpp b/test/unit/test_kinetic_linear.cpp index 16dddc4eda39e8a51bc021473a4f1ed39df17fbf..a85fc46e42c463f50e3d248a417e9b2d9584158f 100644 --- a/test/unit/test_kinetic_linear.cpp +++ b/test/unit/test_kinetic_linear.cpp @@ -88,6 +88,17 @@ void run_test(std::string mech_name, } } +TEST(mech_kinetic, kintetic_scaled) { + std::vector<std::string> state_variables = {"s", "h", "d"}; + std::vector<fvm_value_type> t0_values = {0.5, 0.2, 0.3}; + std::vector<fvm_value_type> t1_0_values = {0.373297, 0.591621, 0.0350817}; + std::vector<fvm_value_type> t1_1_values = {0.329897, 0.537371, 0.132732}; + + run_test<multicore::backend>("test0_kin_compartment", state_variables, {}, t0_values, t1_0_values); + run_test<multicore::backend>("test1_kin_compartment", state_variables, {}, t0_values, t1_1_values); + +} + TEST(mech_kinetic, kintetic_1_conserve) { std::vector<std::string> state_variables = {"s", "h", "d"}; std::vector<fvm_value_type> t0_values = {0.5, 0.2, 0.3}; @@ -117,6 +128,16 @@ TEST(mech_linear, linear) { } #ifdef ARB_GPU_ENABLED +TEST(mech_kinetic_gpu, kintetic_scaled) { + std::vector<std::string> state_variables = {"s", "h", "d"}; + std::vector<fvm_value_type> t0_values = {0.5, 0.2, 0.3}; + std::vector<fvm_value_type> t1_0_values = {0.373297, 0.591621, 0.0350817}; + std::vector<fvm_value_type> t1_1_values = {0.329897, 0.537371, 0.132732}; + + run_test<gpu::backend>("test0_kin_compartment", state_variables, {}, t0_values, t1_0_values); + run_test<gpu::backend>("test1_kin_compartment", state_variables, {}, t0_values, t1_1_values); +} + TEST(mech_kinetic_gpu, kintetic_1_conserve) { std::vector<std::string> state_variables = {"s", "h", "d"}; std::vector<fvm_value_type> t0_values = {0.5, 0.2, 0.3}; diff --git a/test/unit/unit_test_catalogue.cpp b/test/unit/unit_test_catalogue.cpp index 5e9606c2a4eb5e45755f143de76096bd341d7f0a..d5d54e4d24f895ad921275052119a9ed7bfece13 100644 --- a/test/unit/unit_test_catalogue.cpp +++ b/test/unit/unit_test_catalogue.cpp @@ -13,6 +13,8 @@ #include "mechanisms/test_linear_init.hpp" #include "mechanisms/test_linear_init_shuffle.hpp" #include "mechanisms/test0_kin_conserve.hpp" +#include "mechanisms/test0_kin_compartment.hpp" +#include "mechanisms/test1_kin_compartment.hpp" #include "mechanisms/test1_kin_diff.hpp" #include "mechanisms/test1_kin_conserve.hpp" #include "mechanisms/fixed_ica_current.hpp" @@ -48,6 +50,8 @@ mechanism_catalogue make_unit_test_catalogue() { ADD_MECH(cat, test_linear_init_shuffle) ADD_MECH(cat, test0_kin_diff) ADD_MECH(cat, test0_kin_conserve) + ADD_MECH(cat, test0_kin_compartment) + ADD_MECH(cat, test1_kin_compartment) ADD_MECH(cat, test1_kin_diff) ADD_MECH(cat, test1_kin_conserve) ADD_MECH(cat, fixed_ica_current)