diff --git a/modcc/expression.cpp b/modcc/expression.cpp index 250af45d7f31877992edc19be2ec0fd2e8682a38..e90bc3cad1b5fc4d3430fa65d3b6d3d5163aab00 100644 --- a/modcc/expression.cpp +++ b/modcc/expression.cpp @@ -789,7 +789,7 @@ void SolveExpression::semantic(scope_ptr scp) { } expression_ptr SolveExpression::clone() const { - auto s = new SolveExpression(location_, name_, method_); + auto s = new SolveExpression(location_, name_, method_, variant_); s->procedure(procedure_); return expression_ptr{s}; } diff --git a/modcc/expression.hpp b/modcc/expression.hpp index a5ef42a04b564c53a32c376b898284d93c98e796..fe3eab9fb438ffab450c0292f82cd194a7a1e6b9 100644 --- a/modcc/expression.hpp +++ b/modcc/expression.hpp @@ -103,6 +103,11 @@ enum class solverMethod { none }; +enum class solverVariant { + regular, + steadystate +}; + static std::string to_string(solverMethod m) { switch(m) { case solverMethod::cnexp: return std::string("cnexp"); @@ -628,8 +633,9 @@ public: SolveExpression( Location loc, std::string name, - solverMethod method) - : Expression(loc), name_(std::move(name)), method_(method), procedure_(nullptr) + solverMethod method, + solverVariant variant) + : Expression(loc), name_(std::move(name)), method_(method), variant_(variant), procedure_(nullptr) {} std::string to_string() const override { @@ -645,6 +651,10 @@ public: return method_; } + solverVariant variant() const { + return variant_; + } + ProcedureExpression* procedure() const { return procedure_; } @@ -667,6 +677,7 @@ private: /// pointer to the variable symbol for the state variable to be solved for std::string name_; solverMethod method_; + solverVariant variant_; ProcedureExpression* procedure_; }; diff --git a/modcc/module.cpp b/modcc/module.cpp index b982d2059c433ef1e5960ed60c975cd99fbdb438..30ada9a2e38e4c45f80b9abc838c2be344bfee80 100644 --- a/modcc/module.cpp +++ b/modcc/module.cpp @@ -360,9 +360,10 @@ bool Module::semantic() { case solverMethod::cnexp: solver = std::make_unique<CnexpSolverVisitor>(); break; - case solverMethod::sparse: - solver = std::make_unique<SparseSolverVisitor>(); + case solverMethod::sparse: { + solver = std::make_unique<SparseSolverVisitor>(solve_expression->variant()); break; + } case solverMethod::none: solver = std::make_unique<DirectSolverVisitor>(); break; diff --git a/modcc/parser.cpp b/modcc/parser.cpp index 27dac4fca4a57c9c06a54a5e5b740fea3197b2f1..3cacb0f1057ee33f5a8762ce9e8cd352f7afd90f 100644 --- a/modcc/parser.cpp +++ b/modcc/parser.cpp @@ -1549,6 +1549,7 @@ expression_ptr Parser::parse_solve() { Location loc = location_; // solve location for expression std::string name; solverMethod method; + solverVariant variant; get_token(); // consume the SOLVE keyword @@ -1557,10 +1558,14 @@ expression_ptr Parser::parse_solve() { name = token_.spelling; // save name of procedure get_token(); // consume the procedure identifier - if(token_.type != tok::method) { // no method was provided + variant = solverVariant::regular; + if(token_.type != tok::method && token_.type != tok::steadystate) { method = solverMethod::none; } else { + if (token_.type == tok::steadystate) { + variant = solverVariant::steadystate; + } get_token(); // consume the METHOD keyword switch(token_.type) { case tok::cnexp: @@ -1580,12 +1585,14 @@ expression_ptr Parser::parse_solve() { if(token_.type != tok::eof) goto solve_statement_error; } - return make_expression<SolveExpression>(loc, name, method); + return make_expression<SolveExpression>(loc, name, method, variant); solve_statement_error: error( "SOLVE statements must have the form\n" " SOLVE x METHOD method\n" " or\n" + " SOLVE x STEADYSTATE sparse\n" + " or\n" " SOLVE x\n" "where 'x' is the name of a DERIVATIVE block and " "'method' is 'cnexp' or 'sparse'", loc); diff --git a/modcc/solvers.cpp b/modcc/solvers.cpp index c3b42d574c72382f8d2e276960f59d7094b025e1..4a53c0d9eb4081c82f4ca0644f0f283840794853 100644 --- a/modcc/solvers.cpp +++ b/modcc/solvers.cpp @@ -181,6 +181,16 @@ void SparseSolverVisitor::visit(BlockExpression* e) { dvars_.push_back(id->name()); } } + if (solve_variant_ == solverVariant::steadystate) { + // create zero_epression local for the rhs + auto zero_expr = make_expression<NumberExpression>(e->location(), 0.0); + auto local_a_term = make_unique_local_assign(e->scope(), zero_expr.get(), "a_"); + auto a_ = local_a_term.id->is_identifier()->spelling(); + + statements_.push_back(std::move(local_a_term.local_decl)); + statements_.push_back(std::move(local_a_term.assignment)); + steadystate_rhs_ = a_; + } scale_factor_.resize(dvars_.size()); BlockRewriterBase::visit(e); @@ -250,32 +260,34 @@ void SparseSolverVisitor::visit(AssignmentExpression *e) { for (unsigned j = 0; j<dvars_.size(); ++j) { expression_ptr expr; + // For regular solve: // For zero coefficient and diagonal element, the matrix entry is 1. // For non-zero coefficient c and diagonal element, the entry is 1-c*dt. // Otherwise, for non-zero coefficient c, the entry is -c*dt. + // For steady state solve: + // The entry is always the the coefficient. if (r.coef.count(dvars_[j])) { - expr = make_expression<MulBinaryExpression>(loc, - r.coef[dvars_[j]]->clone(), - dt_expr->clone()); + expr = solve_variant_ == solverVariant::steadystate ? r.coef[dvars_[j]]->clone() : + make_expression<MulBinaryExpression>(loc, r.coef[dvars_[j]]->clone(), dt_expr->clone()); if (scale_factor_[j]) { expr = make_expression<DivBinaryExpression>(loc, std::move(expr), scale_factor_[j]->clone()); } } - if (j==deq_index_) { - if (expr) { - expr = make_expression<SubBinaryExpression>(loc, - one_expr->clone(), - std::move(expr)); - } - else { - expr = one_expr->clone(); - } - } - else if (expr) { + if (solve_variant_ != solverVariant::steadystate) { + if (j == deq_index_) { + if (expr) { + expr = make_expression<SubBinaryExpression>(loc, + one_expr->clone(), + std::move(expr)); + } else { + expr = one_expr->clone(); + } + } else if (expr) { expr = make_expression<NegUnaryExpression>(loc, std::move(expr)); + } } if (!expr) continue; @@ -360,9 +372,15 @@ void SparseSolverVisitor::visit(ConserveExpression *e) { } void SparseSolverVisitor::finalize() { + + if (solve_variant_ == solverVariant::steadystate && !conserve_) { + error({"Conserve statement(s) missing in steady-state solver", {}}); + } + std::vector<symge::symbol> rhs; for (const auto& var: dvars_) { - rhs.push_back(symtbl_.define(var)); + auto v = solve_variant_ == solverVariant::steadystate? steadystate_rhs_ : var; + rhs.push_back(symtbl_.define(v)); } if (conserve_) { for (unsigned i = 0; i < conserve_idx_.size(); ++i) { diff --git a/modcc/solvers.hpp b/modcc/solvers.hpp index 1aff2aea5552be56726e1b0084611af70ec68b47..01bdba33a3499a61c5d6dac8bc28c0e624bc68fd 100644 --- a/modcc/solvers.hpp +++ b/modcc/solvers.hpp @@ -68,6 +68,7 @@ public: class SparseSolverVisitor : public SolverVisitorBase { protected: + solverVariant solve_variant_; // 'Current' differential equation is for variable with this // index in `dvars`. unsigned deq_index_ = 0; @@ -91,10 +92,14 @@ protected: // rhs of conserve statement std::vector<std::string> conserve_rhs_; std::vector<unsigned> conserve_idx_; + + // rhs of steadstate + std::string steadystate_rhs_; public: using SolverVisitorBase::visit; - SparseSolverVisitor() {} + explicit SparseSolverVisitor(solverVariant s = solverVariant::regular) : + solve_variant_(s) {} SparseSolverVisitor(scope_ptr enclosing): SolverVisitorBase(enclosing) {} virtual void visit(BlockExpression* e) override; @@ -111,6 +116,7 @@ public: scale_factor_.clear(); conserve_rhs_.clear(); conserve_idx_.clear(); + steadystate_rhs_.clear(); SolverVisitorBase::reset(); } }; diff --git a/modcc/token.cpp b/modcc/token.cpp index 4e09f545796e1f4e81a179e7e9d9c964ed09006d..fb3276637ce092f2b2160c6012c00548a05e41f2 100644 --- a/modcc/token.cpp +++ b/modcc/token.cpp @@ -53,6 +53,7 @@ static Keyword keywords[] = { {"POINT_PROCESS", tok::point_process}, {"COMPARTMENT", tok::compartment}, {"METHOD", tok::method}, + {"STEADYSTATE", tok::steadystate}, {"if", tok::if_stmt}, {"else", tok::else_stmt}, {"cnexp", tok::cnexp}, @@ -125,6 +126,7 @@ static TokenString token_strings[] = { {"POINT_PROCESS", tok::point_process}, {"COMPARTMENT", tok::compartment}, {"METHOD", tok::method}, + {"STEADYSTATE", tok::steadystate}, {"if", tok::if_stmt}, {"else", tok::else_stmt}, {"eof", tok::eof}, diff --git a/modcc/token.hpp b/modcc/token.hpp index 5eef7a943e4c81296064e99badc4ecc2e229cd00..03d064104de766300f05ad4a640960acbce7b1d1 100644 --- a/modcc/token.hpp +++ b/modcc/token.hpp @@ -61,7 +61,7 @@ enum class tok { suffix, nonspecific_current, useion, read, write, valence, range, local, conserve, compartment, - solve, method, + solve, method, steadystate, threadsafe, global, point_process, diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 468264f5c12f9c73b7743bb700d2381859dddc66..0cd9afad10ba24dd970ea26fb36726f9164a8bc2 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -7,10 +7,12 @@ set(test_mechanisms test_linear_init_shuffle test0_kin_diff test0_kin_conserve + test0_kin_compartment + test0_kin_steadystate test1_kin_diff test1_kin_conserve - test0_kin_compartment test1_kin_compartment + test1_kin_steadystate fixed_ica_current point_ica_current linear_ca_conc diff --git a/test/unit/mod/test0_kin_steadystate.mod b/test/unit/mod/test0_kin_steadystate.mod new file mode 100644 index 0000000000000000000000000000000000000000..28098a8f082d07c7436bbaf5a2f31ba3446cdef8 --- /dev/null +++ b/test/unit/mod/test0_kin_steadystate.mod @@ -0,0 +1,30 @@ +NEURON { + SUFFIX test0_kin_steadystate +} + +STATE { + s d h +} + +BREAKPOINT { + SOLVE state STEADYSTATE sparse +} + +KINETIC state { + LOCAL alpha1, beta1, alpha2, beta2 + alpha1 = 2 + beta1 = 0.6 + alpha2 = 3 + beta2 = 0.7 + + ~ s <-> h (alpha1, beta1) + ~ d <-> s (alpha2, beta2) + + CONSERVE s + d + h = 1 +} + +INITIAL { + h = 0.2 + d = 0.3 + s = 1-d-h +} diff --git a/test/unit/mod/test1_kin_steadystate.mod b/test/unit/mod/test1_kin_steadystate.mod new file mode 100644 index 0000000000000000000000000000000000000000..0949b093e029a82aabf53793916c98aa56c243fc --- /dev/null +++ b/test/unit/mod/test1_kin_steadystate.mod @@ -0,0 +1,39 @@ +NEURON { + SUFFIX test1_kin_steadystate +} + +STATE { + a b x y +} + +PARAMETER { + A = 0.5 + B = 0.1 +} + +BREAKPOINT { + SOLVE state STEADYSTATE sparse +} + +KINETIC state { + COMPARTMENT A {a b} + + LOCAL alpha1, beta1, alpha2, beta2 + alpha1 = 2 + beta1 = 0.6 + alpha2 = 3 + beta2 = 0.7 + + ~ a <-> b (alpha1, beta1) + ~ x <-> y (alpha2, beta2) + + CONSERVE a + b = A + CONSERVE x + y = 1 +} + +INITIAL { + a = 0.2 + b = 1 - a + x = 0.6 + y = 1 - x +} diff --git a/test/unit/test_kinetic_linear.cpp b/test/unit/test_kinetic_linear.cpp index a85fc46e42c463f50e3d248a417e9b2d9584158f..fbbe0aad2ef8c23f1d1c95123aa34f9dfdd5e8b0 100644 --- a/test/unit/test_kinetic_linear.cpp +++ b/test/unit/test_kinetic_linear.cpp @@ -96,25 +96,28 @@ TEST(mech_kinetic, kintetic_scaled) { run_test<multicore::backend>("test0_kin_compartment", state_variables, {}, t0_values, t1_0_values); run_test<multicore::backend>("test1_kin_compartment", state_variables, {}, t0_values, t1_1_values); - } TEST(mech_kinetic, kintetic_1_conserve) { std::vector<std::string> state_variables = {"s", "h", "d"}; std::vector<fvm_value_type> t0_values = {0.5, 0.2, 0.3}; - std::vector<fvm_value_type> t1_values = {0.380338, 0.446414, 0.173247}; + std::vector<fvm_value_type> t1_0_values = {0.380338, 0.446414, 0.173247}; + std::vector<fvm_value_type> t1_1_values = {0.218978, 0.729927, 0.0510949}; - run_test<multicore::backend>("test0_kin_diff", state_variables, {}, t0_values, t1_values); - run_test<multicore::backend>("test0_kin_conserve", state_variables, {}, t0_values, t1_values); + run_test<multicore::backend>("test0_kin_diff", state_variables, {}, t0_values, t1_0_values); + run_test<multicore::backend>("test0_kin_conserve", state_variables, {}, t0_values, t1_0_values); + run_test<multicore::backend>("test0_kin_steadystate", state_variables, {}, t0_values, t1_1_values); } TEST(mech_kinetic, kintetic_2_conserve) { std::vector<std::string> state_variables = {"a", "b", "x", "y"}; std::vector<fvm_value_type> t0_values = {0.2, 0.8, 0.6, 0.4}; - std::vector<fvm_value_type> t1_values = {0.217391304, 0.782608696, 0.33333333, 0.66666666}; + std::vector<fvm_value_type> t1_0_values = {0.217391304, 0.782608696, 0.33333333, 0.66666666}; + std::vector<fvm_value_type> t1_1_values = {0.230769, 0.769231, 0.189189, 0.810811}; - run_test<multicore::backend>("test1_kin_diff", state_variables, {}, t0_values, t1_values); - run_test<multicore::backend>("test1_kin_conserve", state_variables, {}, t0_values, t1_values); + run_test<multicore::backend>("test1_kin_diff", state_variables, {}, t0_values, t1_0_values); + run_test<multicore::backend>("test1_kin_conserve", state_variables, {}, t0_values, t1_0_values); + run_test<multicore::backend>("test1_kin_steadystate", state_variables, {}, t0_values, t1_1_values); } TEST(mech_linear, linear) { @@ -141,19 +144,23 @@ TEST(mech_kinetic_gpu, kintetic_scaled) { TEST(mech_kinetic_gpu, kintetic_1_conserve) { std::vector<std::string> state_variables = {"s", "h", "d"}; std::vector<fvm_value_type> t0_values = {0.5, 0.2, 0.3}; - std::vector<fvm_value_type> t1_values = {0.380338, 0.446414, 0.173247}; + std::vector<fvm_value_type> t1_0_values = {0.380338, 0.446414, 0.173247}; + std::vector<fvm_value_type> t1_1_values = {0.218978, 0.729927, 0.0510949}; - run_test<gpu::backend>("test0_kin_diff", state_variables, {}, t0_values, t1_values); - run_test<gpu::backend>("test0_kin_conserve", state_variables, {}, t0_values, t1_values); + run_test<gpu::backend>("test0_kin_diff", state_variables, {}, t0_values, t1_0_values); + run_test<gpu::backend>("test0_kin_conserve", state_variables, {}, t0_values, t1_0_values); + run_test<gpu::backend>("test0_kin_steadystate", state_variables, {}, t0_values, t1_1_values); } TEST(mech_kinetic_gpu, kintetic_2_conserve) { std::vector<std::string> state_variables = {"a", "b", "x", "y"}; std::vector<fvm_value_type> t0_values = {0.2, 0.8, 0.6, 0.4}; - std::vector<fvm_value_type> t1_values = {0.217391304, 0.782608696, 0.33333333, 0.66666666}; + std::vector<fvm_value_type> t1_0_values = {0.217391304, 0.782608696, 0.33333333, 0.66666666}; + std::vector<fvm_value_type> t1_1_values = {0.230769, 0.769231, 0.189189, 0.810811}; - run_test<gpu::backend>("test1_kin_diff", state_variables, {}, t0_values, t1_values); - run_test<gpu::backend>("test1_kin_conserve", state_variables, {}, t0_values, t1_values); + run_test<gpu::backend>("test1_kin_diff", state_variables, {}, t0_values, t1_0_values); + run_test<gpu::backend>("test1_kin_conserve", state_variables, {}, t0_values, t1_0_values); + run_test<gpu ::backend>("test1_kin_steadystate", state_variables, {}, t0_values, t1_1_values); } TEST(mech_linear_gpu, linear) { diff --git a/test/unit/unit_test_catalogue.cpp b/test/unit/unit_test_catalogue.cpp index d5d54e4d24f895ad921275052119a9ed7bfece13..7f803dc21b6da6e3c62cf294e70822822272a0c2 100644 --- a/test/unit/unit_test_catalogue.cpp +++ b/test/unit/unit_test_catalogue.cpp @@ -13,10 +13,12 @@ #include "mechanisms/test_linear_init.hpp" #include "mechanisms/test_linear_init_shuffle.hpp" #include "mechanisms/test0_kin_conserve.hpp" +#include "mechanisms/test0_kin_steadystate.hpp" #include "mechanisms/test0_kin_compartment.hpp" #include "mechanisms/test1_kin_compartment.hpp" #include "mechanisms/test1_kin_diff.hpp" #include "mechanisms/test1_kin_conserve.hpp" +#include "mechanisms/test1_kin_steadystate.hpp" #include "mechanisms/fixed_ica_current.hpp" #include "mechanisms/point_ica_current.hpp" #include "mechanisms/linear_ca_conc.hpp" @@ -50,10 +52,12 @@ mechanism_catalogue make_unit_test_catalogue() { ADD_MECH(cat, test_linear_init_shuffle) ADD_MECH(cat, test0_kin_diff) ADD_MECH(cat, test0_kin_conserve) + ADD_MECH(cat, test0_kin_steadystate) ADD_MECH(cat, test0_kin_compartment) - ADD_MECH(cat, test1_kin_compartment) ADD_MECH(cat, test1_kin_diff) ADD_MECH(cat, test1_kin_conserve) + ADD_MECH(cat, test1_kin_steadystate) + ADD_MECH(cat, test1_kin_compartment) ADD_MECH(cat, fixed_ica_current) ADD_MECH(cat, point_ica_current) ADD_MECH(cat, linear_ca_conc)