From 4bd6f0978be628903eadeaa29c938c31956ba855 Mon Sep 17 00:00:00 2001
From: Nora Abi Akar <nora.abiakar@gmail.com>
Date: Tue, 1 Oct 2019 16:33:07 +0200
Subject: [PATCH] Modcc: Steady State (#879)

* Modify `SparseSolverVisitor` to allow solving kinetic equations at steady state.

Addresses #837
---
 modcc/expression.cpp                    |  2 +-
 modcc/expression.hpp                    | 15 ++++++--
 modcc/module.cpp                        |  5 +--
 modcc/parser.cpp                        | 11 ++++--
 modcc/solvers.cpp                       | 48 +++++++++++++++++--------
 modcc/solvers.hpp                       |  8 ++++-
 modcc/token.cpp                         |  2 ++
 modcc/token.hpp                         |  2 +-
 test/unit/CMakeLists.txt                |  4 ++-
 test/unit/mod/test0_kin_steadystate.mod | 30 ++++++++++++++++
 test/unit/mod/test1_kin_steadystate.mod | 39 ++++++++++++++++++++
 test/unit/test_kinetic_linear.cpp       | 33 ++++++++++-------
 test/unit/unit_test_catalogue.cpp       |  6 +++-
 13 files changed, 166 insertions(+), 39 deletions(-)
 create mode 100644 test/unit/mod/test0_kin_steadystate.mod
 create mode 100644 test/unit/mod/test1_kin_steadystate.mod

diff --git a/modcc/expression.cpp b/modcc/expression.cpp
index 250af45d..e90bc3ca 100644
--- a/modcc/expression.cpp
+++ b/modcc/expression.cpp
@@ -789,7 +789,7 @@ void SolveExpression::semantic(scope_ptr scp) {
 }
 
 expression_ptr SolveExpression::clone() const {
-    auto s = new SolveExpression(location_, name_, method_);
+    auto s = new SolveExpression(location_, name_, method_, variant_);
     s->procedure(procedure_);
     return expression_ptr{s};
 }
diff --git a/modcc/expression.hpp b/modcc/expression.hpp
index a5ef42a0..fe3eab9f 100644
--- a/modcc/expression.hpp
+++ b/modcc/expression.hpp
@@ -103,6 +103,11 @@ enum class solverMethod {
     none
 };
 
+enum class solverVariant {
+    regular,
+    steadystate
+};
+
 static std::string to_string(solverMethod m) {
     switch(m) {
         case solverMethod::cnexp:  return std::string("cnexp");
@@ -628,8 +633,9 @@ public:
     SolveExpression(
             Location loc,
             std::string name,
-            solverMethod method)
-    :   Expression(loc), name_(std::move(name)), method_(method), procedure_(nullptr)
+            solverMethod method,
+            solverVariant variant)
+    :   Expression(loc), name_(std::move(name)), method_(method), variant_(variant), procedure_(nullptr)
     {}
 
     std::string to_string() const override {
@@ -645,6 +651,10 @@ public:
         return method_;
     }
 
+    solverVariant variant() const {
+        return variant_;
+    }
+
     ProcedureExpression* procedure() const {
         return procedure_;
     }
@@ -667,6 +677,7 @@ private:
     /// pointer to the variable symbol for the state variable to be solved for
     std::string name_;
     solverMethod method_;
+    solverVariant variant_;
 
     ProcedureExpression* procedure_;
 };
diff --git a/modcc/module.cpp b/modcc/module.cpp
index b982d205..30ada9a2 100644
--- a/modcc/module.cpp
+++ b/modcc/module.cpp
@@ -360,9 +360,10 @@ bool Module::semantic() {
             case solverMethod::cnexp:
                 solver = std::make_unique<CnexpSolverVisitor>();
                 break;
-            case solverMethod::sparse:
-                solver = std::make_unique<SparseSolverVisitor>();
+            case solverMethod::sparse: {
+                solver = std::make_unique<SparseSolverVisitor>(solve_expression->variant());
                 break;
+            }
             case solverMethod::none:
                 solver = std::make_unique<DirectSolverVisitor>();
                 break;
diff --git a/modcc/parser.cpp b/modcc/parser.cpp
index 27dac4fc..3cacb0f1 100644
--- a/modcc/parser.cpp
+++ b/modcc/parser.cpp
@@ -1549,6 +1549,7 @@ expression_ptr Parser::parse_solve() {
     Location loc = location_; // solve location for expression
     std::string name;
     solverMethod method;
+    solverVariant variant;
 
     get_token(); // consume the SOLVE keyword
 
@@ -1557,10 +1558,14 @@ expression_ptr Parser::parse_solve() {
     name = token_.spelling; // save name of procedure
     get_token(); // consume the procedure identifier
 
-    if(token_.type != tok::method) { // no method was provided
+    variant = solverVariant::regular;
+    if(token_.type != tok::method && token_.type != tok::steadystate) {
         method = solverMethod::none;
     }
     else {
+        if (token_.type == tok::steadystate) {
+            variant = solverVariant::steadystate;
+        }
         get_token(); // consume the METHOD keyword
         switch(token_.type) {
         case tok::cnexp:
@@ -1580,12 +1585,14 @@ expression_ptr Parser::parse_solve() {
         if(token_.type != tok::eof) goto solve_statement_error;
     }
 
-    return make_expression<SolveExpression>(loc, name, method);
+    return make_expression<SolveExpression>(loc, name, method, variant);
 
 solve_statement_error:
     error( "SOLVE statements must have the form\n"
            "  SOLVE x METHOD method\n"
            "    or\n"
+           "  SOLVE x STEADYSTATE sparse\n"
+           "    or\n"
            "  SOLVE x\n"
            "where 'x' is the name of a DERIVATIVE block and "
            "'method' is 'cnexp' or 'sparse'", loc);
diff --git a/modcc/solvers.cpp b/modcc/solvers.cpp
index c3b42d57..4a53c0d9 100644
--- a/modcc/solvers.cpp
+++ b/modcc/solvers.cpp
@@ -181,6 +181,16 @@ void SparseSolverVisitor::visit(BlockExpression* e) {
             dvars_.push_back(id->name());
         }
     }
+    if (solve_variant_ == solverVariant::steadystate) {
+        // create zero_epression local for the rhs
+        auto zero_expr = make_expression<NumberExpression>(e->location(), 0.0);
+        auto local_a_term = make_unique_local_assign(e->scope(), zero_expr.get(), "a_");
+        auto a_ = local_a_term.id->is_identifier()->spelling();
+
+        statements_.push_back(std::move(local_a_term.local_decl));
+        statements_.push_back(std::move(local_a_term.assignment));
+        steadystate_rhs_ = a_;
+    }
     scale_factor_.resize(dvars_.size());
 
     BlockRewriterBase::visit(e);
@@ -250,32 +260,34 @@ void SparseSolverVisitor::visit(AssignmentExpression *e) {
     for (unsigned j = 0; j<dvars_.size(); ++j) {
         expression_ptr expr;
 
+        // For regular solve:
         // For zero coefficient and diagonal element, the matrix entry is 1.
         // For non-zero coefficient c and diagonal element, the entry is 1-c*dt.
         // Otherwise, for non-zero coefficient c, the entry is -c*dt.
+        // For steady state solve:
+        // The entry is always the the coefficient.
 
         if (r.coef.count(dvars_[j])) {
-            expr = make_expression<MulBinaryExpression>(loc,
-                       r.coef[dvars_[j]]->clone(),
-                       dt_expr->clone());
+            expr = solve_variant_ == solverVariant::steadystate ? r.coef[dvars_[j]]->clone() :
+                    make_expression<MulBinaryExpression>(loc, r.coef[dvars_[j]]->clone(), dt_expr->clone());
 
             if (scale_factor_[j]) {
                 expr =  make_expression<DivBinaryExpression>(loc, std::move(expr), scale_factor_[j]->clone());
             }
         }
 
-        if (j==deq_index_) {
-            if (expr) {
-                expr = make_expression<SubBinaryExpression>(loc,
-                           one_expr->clone(),
-                           std::move(expr));
-            }
-            else {
-                expr = one_expr->clone();
-            }
-        }
-        else if (expr) {
+        if (solve_variant_ != solverVariant::steadystate) {
+            if (j == deq_index_) {
+                if (expr) {
+                    expr = make_expression<SubBinaryExpression>(loc,
+                                                                one_expr->clone(),
+                                                                std::move(expr));
+                } else {
+                    expr = one_expr->clone();
+                }
+            } else if (expr) {
                 expr = make_expression<NegUnaryExpression>(loc, std::move(expr));
+            }
         }
 
         if (!expr) continue;
@@ -360,9 +372,15 @@ void SparseSolverVisitor::visit(ConserveExpression *e) {
 }
 
 void SparseSolverVisitor::finalize() {
+
+    if (solve_variant_ == solverVariant::steadystate && !conserve_) {
+        error({"Conserve statement(s) missing in steady-state solver", {}});
+    }
+
     std::vector<symge::symbol> rhs;
     for (const auto& var: dvars_) {
-        rhs.push_back(symtbl_.define(var));
+        auto v = solve_variant_ == solverVariant::steadystate? steadystate_rhs_ : var;
+        rhs.push_back(symtbl_.define(v));
     }
     if (conserve_) {
         for (unsigned i = 0; i < conserve_idx_.size(); ++i) {
diff --git a/modcc/solvers.hpp b/modcc/solvers.hpp
index 1aff2aea..01bdba33 100644
--- a/modcc/solvers.hpp
+++ b/modcc/solvers.hpp
@@ -68,6 +68,7 @@ public:
 
 class SparseSolverVisitor : public SolverVisitorBase {
 protected:
+    solverVariant solve_variant_;
     // 'Current' differential equation is for variable with this
     // index in `dvars`.
     unsigned deq_index_ = 0;
@@ -91,10 +92,14 @@ protected:
     // rhs of conserve statement
     std::vector<std::string> conserve_rhs_;
     std::vector<unsigned> conserve_idx_;
+
+    // rhs of steadstate
+    std::string steadystate_rhs_;
 public:
     using SolverVisitorBase::visit;
 
-    SparseSolverVisitor() {}
+    explicit SparseSolverVisitor(solverVariant s = solverVariant::regular) :
+        solve_variant_(s) {}
     SparseSolverVisitor(scope_ptr enclosing): SolverVisitorBase(enclosing) {}
 
     virtual void visit(BlockExpression* e) override;
@@ -111,6 +116,7 @@ public:
         scale_factor_.clear();
         conserve_rhs_.clear();
         conserve_idx_.clear();
+        steadystate_rhs_.clear();
         SolverVisitorBase::reset();
     }
 };
diff --git a/modcc/token.cpp b/modcc/token.cpp
index 4e09f545..fb327663 100644
--- a/modcc/token.cpp
+++ b/modcc/token.cpp
@@ -53,6 +53,7 @@ static Keyword keywords[] = {
     {"POINT_PROCESS", tok::point_process},
     {"COMPARTMENT", tok::compartment},
     {"METHOD",      tok::method},
+    {"STEADYSTATE", tok::steadystate},
     {"if",          tok::if_stmt},
     {"else",        tok::else_stmt},
     {"cnexp",       tok::cnexp},
@@ -125,6 +126,7 @@ static TokenString token_strings[] = {
     {"POINT_PROCESS", tok::point_process},
     {"COMPARTMENT", tok::compartment},
     {"METHOD",      tok::method},
+    {"STEADYSTATE", tok::steadystate},
     {"if",          tok::if_stmt},
     {"else",        tok::else_stmt},
     {"eof",         tok::eof},
diff --git a/modcc/token.hpp b/modcc/token.hpp
index 5eef7a94..03d06410 100644
--- a/modcc/token.hpp
+++ b/modcc/token.hpp
@@ -61,7 +61,7 @@ enum class tok {
     suffix, nonspecific_current, useion,
     read, write, valence,
     range, local, conserve, compartment,
-    solve, method,
+    solve, method, steadystate,
     threadsafe, global,
     point_process,
 
diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt
index 468264f5..0cd9afad 100644
--- a/test/unit/CMakeLists.txt
+++ b/test/unit/CMakeLists.txt
@@ -7,10 +7,12 @@ set(test_mechanisms
     test_linear_init_shuffle
     test0_kin_diff
     test0_kin_conserve
+    test0_kin_compartment
+    test0_kin_steadystate
     test1_kin_diff
     test1_kin_conserve
-    test0_kin_compartment
     test1_kin_compartment
+    test1_kin_steadystate
     fixed_ica_current
     point_ica_current
     linear_ca_conc
diff --git a/test/unit/mod/test0_kin_steadystate.mod b/test/unit/mod/test0_kin_steadystate.mod
new file mode 100644
index 00000000..28098a8f
--- /dev/null
+++ b/test/unit/mod/test0_kin_steadystate.mod
@@ -0,0 +1,30 @@
+NEURON {
+    SUFFIX test0_kin_steadystate
+}
+
+STATE {
+    s d h
+}
+
+BREAKPOINT {
+    SOLVE state STEADYSTATE sparse
+}
+
+KINETIC state {
+    LOCAL alpha1, beta1, alpha2, beta2
+    alpha1 = 2
+    beta1 = 0.6
+    alpha2 = 3
+    beta2 = 0.7
+
+    ~ s <-> h (alpha1, beta1)
+    ~ d <-> s (alpha2, beta2)
+
+    CONSERVE s + d + h = 1
+}
+
+INITIAL {
+    h = 0.2
+    d = 0.3
+    s = 1-d-h
+}
diff --git a/test/unit/mod/test1_kin_steadystate.mod b/test/unit/mod/test1_kin_steadystate.mod
new file mode 100644
index 00000000..0949b093
--- /dev/null
+++ b/test/unit/mod/test1_kin_steadystate.mod
@@ -0,0 +1,39 @@
+NEURON {
+    SUFFIX test1_kin_steadystate
+}
+
+STATE {
+    a b x y
+}
+
+PARAMETER {
+    A = 0.5
+    B = 0.1
+}
+
+BREAKPOINT {
+    SOLVE state STEADYSTATE sparse
+}
+
+KINETIC state {
+    COMPARTMENT A {a b}
+
+    LOCAL alpha1, beta1, alpha2, beta2
+    alpha1 = 2
+    beta1 = 0.6
+    alpha2 = 3
+    beta2 = 0.7
+
+    ~ a <-> b (alpha1, beta1)
+    ~ x <-> y (alpha2, beta2)
+
+    CONSERVE a + b = A
+    CONSERVE x + y = 1
+}
+
+INITIAL {
+    a = 0.2
+    b = 1 - a
+    x = 0.6
+    y = 1 - x
+}
diff --git a/test/unit/test_kinetic_linear.cpp b/test/unit/test_kinetic_linear.cpp
index a85fc46e..fbbe0aad 100644
--- a/test/unit/test_kinetic_linear.cpp
+++ b/test/unit/test_kinetic_linear.cpp
@@ -96,25 +96,28 @@ TEST(mech_kinetic, kintetic_scaled) {
 
     run_test<multicore::backend>("test0_kin_compartment", state_variables, {}, t0_values, t1_0_values);
     run_test<multicore::backend>("test1_kin_compartment", state_variables, {}, t0_values, t1_1_values);
-
 }
 
 TEST(mech_kinetic, kintetic_1_conserve) {
     std::vector<std::string> state_variables = {"s", "h", "d"};
     std::vector<fvm_value_type> t0_values = {0.5, 0.2, 0.3};
-    std::vector<fvm_value_type> t1_values = {0.380338, 0.446414, 0.173247};
+    std::vector<fvm_value_type> t1_0_values = {0.380338, 0.446414, 0.173247};
+    std::vector<fvm_value_type> t1_1_values = {0.218978, 0.729927, 0.0510949};
 
-    run_test<multicore::backend>("test0_kin_diff", state_variables, {}, t0_values, t1_values);
-    run_test<multicore::backend>("test0_kin_conserve", state_variables, {}, t0_values, t1_values);
+    run_test<multicore::backend>("test0_kin_diff", state_variables, {}, t0_values, t1_0_values);
+    run_test<multicore::backend>("test0_kin_conserve", state_variables, {}, t0_values, t1_0_values);
+    run_test<multicore::backend>("test0_kin_steadystate", state_variables, {}, t0_values, t1_1_values);
 }
 
 TEST(mech_kinetic, kintetic_2_conserve) {
     std::vector<std::string> state_variables = {"a", "b", "x", "y"};
     std::vector<fvm_value_type> t0_values = {0.2, 0.8, 0.6, 0.4};
-    std::vector<fvm_value_type> t1_values = {0.217391304, 0.782608696, 0.33333333, 0.66666666};
+    std::vector<fvm_value_type> t1_0_values = {0.217391304, 0.782608696, 0.33333333, 0.66666666};
+    std::vector<fvm_value_type> t1_1_values = {0.230769, 0.769231, 0.189189, 0.810811};
 
-    run_test<multicore::backend>("test1_kin_diff", state_variables, {}, t0_values, t1_values);
-    run_test<multicore::backend>("test1_kin_conserve", state_variables, {}, t0_values, t1_values);
+    run_test<multicore::backend>("test1_kin_diff", state_variables, {}, t0_values, t1_0_values);
+    run_test<multicore::backend>("test1_kin_conserve", state_variables, {}, t0_values, t1_0_values);
+    run_test<multicore::backend>("test1_kin_steadystate", state_variables, {}, t0_values, t1_1_values);
 }
 
 TEST(mech_linear, linear) {
@@ -141,19 +144,23 @@ TEST(mech_kinetic_gpu, kintetic_scaled) {
 TEST(mech_kinetic_gpu, kintetic_1_conserve) {
     std::vector<std::string> state_variables = {"s", "h", "d"};
     std::vector<fvm_value_type> t0_values = {0.5, 0.2, 0.3};
-    std::vector<fvm_value_type> t1_values = {0.380338, 0.446414, 0.173247};
+    std::vector<fvm_value_type> t1_0_values = {0.380338, 0.446414, 0.173247};
+    std::vector<fvm_value_type> t1_1_values = {0.218978, 0.729927, 0.0510949};
 
-    run_test<gpu::backend>("test0_kin_diff", state_variables, {}, t0_values, t1_values);
-    run_test<gpu::backend>("test0_kin_conserve", state_variables, {}, t0_values, t1_values);
+    run_test<gpu::backend>("test0_kin_diff", state_variables, {}, t0_values, t1_0_values);
+    run_test<gpu::backend>("test0_kin_conserve", state_variables, {}, t0_values, t1_0_values);
+    run_test<gpu::backend>("test0_kin_steadystate", state_variables, {}, t0_values, t1_1_values);
 }
 
 TEST(mech_kinetic_gpu, kintetic_2_conserve) {
     std::vector<std::string> state_variables = {"a", "b", "x", "y"};
     std::vector<fvm_value_type> t0_values = {0.2, 0.8, 0.6, 0.4};
-    std::vector<fvm_value_type> t1_values = {0.217391304, 0.782608696, 0.33333333, 0.66666666};
+    std::vector<fvm_value_type> t1_0_values = {0.217391304, 0.782608696, 0.33333333, 0.66666666};
+    std::vector<fvm_value_type> t1_1_values = {0.230769, 0.769231, 0.189189, 0.810811};
 
-    run_test<gpu::backend>("test1_kin_diff", state_variables, {}, t0_values, t1_values);
-    run_test<gpu::backend>("test1_kin_conserve", state_variables, {}, t0_values, t1_values);
+    run_test<gpu::backend>("test1_kin_diff", state_variables, {}, t0_values, t1_0_values);
+    run_test<gpu::backend>("test1_kin_conserve", state_variables, {}, t0_values, t1_0_values);
+    run_test<gpu    ::backend>("test1_kin_steadystate", state_variables, {}, t0_values, t1_1_values);
 }
 
 TEST(mech_linear_gpu, linear) {
diff --git a/test/unit/unit_test_catalogue.cpp b/test/unit/unit_test_catalogue.cpp
index d5d54e4d..7f803dc2 100644
--- a/test/unit/unit_test_catalogue.cpp
+++ b/test/unit/unit_test_catalogue.cpp
@@ -13,10 +13,12 @@
 #include "mechanisms/test_linear_init.hpp"
 #include "mechanisms/test_linear_init_shuffle.hpp"
 #include "mechanisms/test0_kin_conserve.hpp"
+#include "mechanisms/test0_kin_steadystate.hpp"
 #include "mechanisms/test0_kin_compartment.hpp"
 #include "mechanisms/test1_kin_compartment.hpp"
 #include "mechanisms/test1_kin_diff.hpp"
 #include "mechanisms/test1_kin_conserve.hpp"
+#include "mechanisms/test1_kin_steadystate.hpp"
 #include "mechanisms/fixed_ica_current.hpp"
 #include "mechanisms/point_ica_current.hpp"
 #include "mechanisms/linear_ca_conc.hpp"
@@ -50,10 +52,12 @@ mechanism_catalogue make_unit_test_catalogue() {
     ADD_MECH(cat, test_linear_init_shuffle)
     ADD_MECH(cat, test0_kin_diff)
     ADD_MECH(cat, test0_kin_conserve)
+    ADD_MECH(cat, test0_kin_steadystate)
     ADD_MECH(cat, test0_kin_compartment)
-    ADD_MECH(cat, test1_kin_compartment)
     ADD_MECH(cat, test1_kin_diff)
     ADD_MECH(cat, test1_kin_conserve)
+    ADD_MECH(cat, test1_kin_steadystate)
+    ADD_MECH(cat, test1_kin_compartment)
     ADD_MECH(cat, fixed_ica_current)
     ADD_MECH(cat, point_ica_current)
     ADD_MECH(cat, linear_ca_conc)
-- 
GitLab