Skip to content
Snippets Groups Projects
Unverified Commit ea25d7aa authored by Nora Abi Akar's avatar Nora Abi Akar Committed by GitHub
Browse files

Solve non-linear systems that are not kinetic schemes. (#1724)

Some refactoring of the SOLVE statement handling in `module.cpp`. 
Allow the usage of `SparseNonlinearSolverVisitor` for non-linear systems not only when in the form of a kinetic scheme. 
Check linearity of kinetic schemes and linear systems (was previously skipped).
parent dc371e5d
No related branches found
No related tags found
No related merge requests found
......@@ -151,7 +151,7 @@ bool Module::semantic() {
// symbol table.
// Returns false if a symbol name clashes with the name of a symbol that
// is already in the symbol table.
bool linear = true;
bool linear_homogeneous = true;
std::vector<std::string> state_vars;
auto move_symbols = [this] (std::vector<symbol_ptr>& symbol_list) {
for(auto& symbol: symbol_list) {
......@@ -388,61 +388,56 @@ bool Module::semantic() {
continue;
}
found_solve = true;
std::unique_ptr<SolverVisitorBase> solver;
// If the derivative block is a kinetic block, perform kinetic rewrite first.
auto deriv = solve_expression->procedure();
auto solve_body = deriv->body()->clone();
if (deriv->kind()==procedureKind::kinetic) {
solve_body = kinetic_rewrite(deriv->body());
}
else if (deriv->kind()==procedureKind::linear) {
solve_body = linear_rewrite(deriv->body(), state_vars);
}
// Calculate linearity and homogeneity of the statements in the derivative block.
bool linear = true;
bool homogeneous = true;
for (auto& s: solve_body->is_block()->statements()) {
if(s->is_assignment() && !state_vars.empty()) {
linear_test_result r = linear_test(s->is_assignment()->rhs(), state_vars);
linear &= r.is_linear;
homogeneous &= r.is_homogeneous;
}
}
linear_homogeneous &= (linear & homogeneous);
// Construct solver based on system kind, linearity and solver method.
std::unique_ptr<SolverVisitorBase> solver;
switch(solve_expression->method()) {
case solverMethod::cnexp:
solver = std::make_unique<CnexpSolverVisitor>();
break;
case solverMethod::sparse: {
solver = std::make_unique<SparseSolverVisitor>(solve_expression->variant());
break;
}
case solverMethod::none:
solver = std::make_unique<DirectSolverVisitor>();
break;
}
// If the derivative block is a kinetic block, perform kinetic
// rewrite first.
auto deriv = solve_expression->procedure();
if (deriv->kind()==procedureKind::kinetic) {
auto rewrite_body = kinetic_rewrite(deriv->body());
bool linear_kinetic = true;
for (auto& s: rewrite_body->is_block()->statements()) {
if(s->is_assignment() && !state_vars.empty()) {
linear_test_result r = linear_test(s->is_assignment()->rhs(), state_vars);
linear_kinetic &= r.is_linear;
}
if (linear) {
solver = std::make_unique<SparseSolverVisitor>(solve_expression->variant());
}
if (!linear_kinetic) {
else {
solver = std::make_unique<SparseNonlinearSolverVisitor>();
}
rewrite_body->semantic(advance_state_scope);
rewrite_body->accept(solver.get());
}
else if (deriv->kind()==procedureKind::linear) {
solver = std::make_unique<LinearSolverVisitor>(state_vars);
auto rewrite_body = linear_rewrite(deriv->body(), state_vars);
rewrite_body->semantic(advance_state_scope);
rewrite_body->accept(solver.get());
break;
}
else {
deriv->body()->accept(solver.get());
for (auto& s: deriv->body()->statements()) {
if(s->is_assignment() && !state_vars.empty()) {
linear_test_result r = linear_test(s->is_assignment()->rhs(), state_vars);
linear &= r.is_linear;
linear &= r.is_homogeneous;
}
case solverMethod::none:
if (deriv->kind()==procedureKind::linear) {
solver = std::make_unique<LinearSolverVisitor>(state_vars);
}
else {
solver = std::make_unique<DirectSolverVisitor>();
}
break;
}
// Perform semantic analysis on the solve block statements and solve them.
solve_body->semantic(advance_state_scope);
solve_body->accept(solver.get());
if (auto solve_block = solver->as_block(false)) {
// Check that we didn't solve an already solved variable.
......@@ -490,8 +485,8 @@ bool Module::semantic() {
for (auto& s: breakpoint->body()->statements()) {
if(s->is_assignment() && !state_vars.empty()) {
linear_test_result r = linear_test(s->is_assignment()->rhs(), state_vars);
linear &= r.is_linear;
linear &= r.is_homogeneous;
linear_homogeneous &= r.is_linear;
linear_homogeneous &= r.is_homogeneous;
}
}
......@@ -517,7 +512,7 @@ bool Module::semantic() {
for (const auto &id: state_vars) {
auto coef = symbolic_pdiff(s->is_assignment()->rhs(), id);
if(!coef) {
linear = false;
linear_homogeneous = false;
continue;
}
if(coef->is_number()) {
......@@ -525,19 +520,19 @@ bool Module::semantic() {
error(pprintf("Left hand side of assignment is not an identifier"));
return false;
}
linear &= s->is_assignment()->lhs()->is_identifier()->name() == id ?
coef->is_number()->value() == 1 :
coef->is_number()->value() == 0;
linear_homogeneous &= s->is_assignment()->lhs()->is_identifier()->name() == id ?
coef->is_number()->value() == 1 :
coef->is_number()->value() == 0;
}
else {
linear = false;
linear_homogeneous = false;
}
}
}
}
}
}
linear_ = linear;
linear_ = linear_homogeneous;
post_events_ = has_symbol("post_event", symbolKind::procedure);
if (post_events_) {
......
......@@ -25,6 +25,8 @@ set(test_mechanisms
test2_kin_diff
test3_kin_diff
test4_kin_compartment
test5_nonlinear_diff
test6_nonlinear_diff
test_ca
test_ca_read_valence
test_cl_valence
......
......@@ -37,5 +37,5 @@ KINETIC state {
COMPARTMENT s1 {d e}
~ A + B <-> C ( x, y )
~ A + d <-> e ( z, w )
~ A + d <-> e ( z, w )
}
\ No newline at end of file
NEURON {
SUFFIX test5_nonlinear_diff
}
STATE {
a b c
}
BREAKPOINT {
SOLVE state METHOD sparse
}
DERIVATIVE state {
LOCAL f0, f1, r0, r1
f0 = 2
r0 = 1
f1 = 3
r1 = 0
a' = -f0*a*b + r0*c
b' = -f0*a*b -r1*b + (r0+f1)*c
c' = f0*a*b +r1*b - (r0+f1)*c
}
INITIAL {
a = 0.2
b = 0.3
c = 0.5
}
NEURON {
SUFFIX test6_nonlinear_diff
}
STATE {
p
}
BREAKPOINT {
SOLVE state METHOD sparse
}
DERIVATIVE state {
p' = sin(p)
}
INITIAL {
p = 1
}
......@@ -127,6 +127,20 @@ TEST(mech_kinetic, kinetic_nonlinear) {
}
TEST(mech_kinetic, normal_nonlinear_0) {
std::vector<std::string> state_variables = {"a", "b", "c"};
std::vector<fvm_value_type> t0_values = {0.2, 0.3, 0.5};
std::vector<fvm_value_type> t1_values = {0.2078873133, 0.34222075, 0.45777925};
run_test<multicore::backend>("test5_nonlinear_diff", state_variables, t0_values, t1_values, 0.025);
}
TEST(mech_kinetic, normal_nonlinear_1) {
std::vector<std::string> state_variables = {"p"};
std::vector<fvm_value_type> t0_values = {1};
std::vector<fvm_value_type> t1_values = {1.0213199524};
run_test<multicore::backend>("test6_nonlinear_diff", state_variables, t0_values, t1_values, 0.025);
}
TEST(mech_kinetic, kinetic_nonlinear_scaled) {
std::vector<std::string> state_variables = {"A", "B", "C", "d", "e"};
std::vector<fvm_value_type> t0_values = {4.5, 6.6, 0.28, 2, 0};
......@@ -191,6 +205,20 @@ TEST(mech_kinetic_gpu, kinetic_nonlinear) {
run_test<gpu::backend>("test3_kin_diff", state_variables, t0_values, t1_1_values, 0.025);
}
TEST(mech_kinetic_gpu, normal_nonlinear_0) {
std::vector<std::string> state_variables = {"a", "b", "c"};
std::vector<fvm_value_type> t0_values = {0.2, 0.3, 0.5};
std::vector<fvm_value_type> t1_values = {0.2078873133, 0.34222075, 0.45777925};
run_test<gpu::backend>("test5_nonlinear_diff", state_variables, t0_values, t1_values, 0.025);
}
TEST(mech_kinetic_gpu, normal_nonlinear_1) {
std::vector<std::string> state_variables = {"p"};
std::vector<fvm_value_type> t0_values = {1};
std::vector<fvm_value_type> t1_values = {1.0213199524};
run_test<gpu::backend>("test6_nonlinear_diff", state_variables, t0_values, t1_values, 0.025);
}
TEST(mech_kinetic_gpu, kinetic_nonlinear_scaled) {
std::vector<std::string> state_variables = {"A", "B", "C", "d", "e"};
std::vector<fvm_value_type> t0_values = {4.5, 6.6, 0.28, 2, 0};
......
......@@ -28,6 +28,8 @@
#include "mechanisms/test2_kin_diff.hpp"
#include "mechanisms/test3_kin_diff.hpp"
#include "mechanisms/test4_kin_compartment.hpp"
#include "mechanisms/test5_nonlinear_diff.hpp"
#include "mechanisms/test6_nonlinear_diff.hpp"
#include "mechanisms/test1_kin_steadystate.hpp"
#include "mechanisms/fixed_ica_current.hpp"
#include "mechanisms/point_ica_current.hpp"
......@@ -86,6 +88,8 @@ mechanism_catalogue make_unit_test_catalogue(const mechanism_catalogue& from) {
ADD_MECH(cat, test1_kin_steadystate)
ADD_MECH(cat, test1_kin_compartment)
ADD_MECH(cat, test4_kin_compartment)
ADD_MECH(cat, test5_nonlinear_diff)
ADD_MECH(cat, test6_nonlinear_diff)
ADD_MECH(cat, fixed_ica_current)
ADD_MECH(cat, non_linear)
ADD_MECH(cat, point_ica_current)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment