From 5846f90bfc8563cc746a401b3550b8a7854774f3 Mon Sep 17 00:00:00 2001 From: Sam Yates <yates@cscs.ch> Date: Sun, 5 Mar 2017 09:20:39 +0100 Subject: [PATCH] Add linear kinetic schemes to modcc. (#145) Incorporate symbolic GE code from prototype (with some simplifications) in msparse.hpp, symge.hpp and symge.cpp, together with unit tests. Add two kinetic scheme test cases for validation: test_kin1 (simple exponential scheme) and test_kinlva (combination of exponential gate and a three-species kinetic scheme, modelling a low voltage-activated Calcium channel from Wang, X. J. et al., J. Neurophys. 1991). Adapt numeric HH validation data generation to LVA Ca channel, with explicit stopping at stimulus discontinuities. Add two new validation tests based on above: kinetic.kin1_numeric_ref and kinetic.kinlva_numeric_ref (multicore backend only). Introduce a BlockRewriterBase visitor base class, as an aid for visitors that transform/rewrite procedure bodies; refactor KineticRewriter over this class. Introduce common error_stack mixin class for common functionality across Module and the various procedure rewriters. Implement visitors and public-facing convenience wrappers in symdiff.hpp and symdiff.cpp: involves_identifer for testing if an expression contains given identifiers. constant_simplify for constant folding with removal of trivial terms arising from a NumberExpression of zero or one. expr_value to extract the numerical value of a NumberExpression, or NaN othereise. is_zero to test if an expression is numerically zero. symbolic_pdiff to perform symbolic partial differentiation; this adds a new (not parseable) expression subclass to represent opaque partial differential terms. substitute to substitute identifiers for other expressions within an expression. linear_test for linearity, diagonality and homogeneity testing (this is probably redundant, given ExpressionClassifier already exists). Simplify unnecessary uses of make_unique with Vistor subclasses. Make SOLVE statement rewriting more generic, through the use of solve-rewriter visitors CnexpSolverVisitor, SparseSolverVisitor, and DirectSolverVisitor; implementations in solvers.hpp and solvers.cpp. Supports multiple SOLVE statements for independent subsets of state variables with the BREAKPOINT block. Add block rewriter for the removal of unused local variables, with convenience wrapper remove_unused_locals. Generalize is_in utility in modccutil.hpp. Simplify expression comparison in modcc unit tests with EXPECT_EXPR_EQ macro added to tests/modcc/test.hpp, that operates by comparing expression text representations. Simplify and consolidate verbose printing in modcc unit tests with verbose_print function that tests the global verbose flag and handles expression_ptr and similar which have to_string methods. --- mechanisms/CMakeLists.txt | 2 +- mechanisms/mod/test_kin1.mod | 38 + mechanisms/mod/test_kinlva.mod | 76 ++ modcc/CMakeLists.txt | 4 + modcc/astmanip.cpp | 11 + modcc/astmanip.hpp | 17 +- modcc/error.hpp | 65 +- modcc/expression.cpp | 28 +- modcc/expression.hpp | 55 +- modcc/functionexpander.cpp | 12 +- modcc/functionexpander.hpp | 16 +- modcc/functioninliner.cpp | 30 +- .../{kinrewriter.hpp => kineticrewriter.cpp} | 103 +-- modcc/kineticrewriter.hpp | 6 + modcc/modcc.cpp | 18 +- modcc/modccutil.hpp | 52 +- modcc/module.cpp | 430 +++++------ modcc/module.hpp | 42 +- modcc/msparse.hpp | 206 ++++++ modcc/parser.cpp | 21 +- modcc/solvers.cpp | 448 ++++++++++++ modcc/solvers.hpp | 97 +++ modcc/symdiff.cpp | 692 ++++++++++++++++++ modcc/symdiff.hpp | 119 +++ modcc/symge.cpp | 109 +++ modcc/symge.hpp | 158 ++++ modcc/token.cpp | 1 + modcc/token.hpp | 1 + modcc/visitor.hpp | 136 ++++ src/backends/fvm_multicore.cpp | 12 +- tests/modcc/CMakeLists.txt | 5 + tests/modcc/driver.cpp | 2 - tests/modcc/test.cpp | 27 + tests/modcc/test.hpp | 63 ++ tests/modcc/test_kinetic_rewriter.cpp | 49 +- tests/modcc/test_lexer.cpp | 44 +- tests/modcc/test_msparse.cpp | 188 +++++ tests/modcc/test_optimization.cpp | 68 +- tests/modcc/test_parser.cpp | 10 +- tests/modcc/test_removelocals.cpp | 197 +++++ tests/modcc/test_symdiff.cpp | 330 +++++++++ tests/modcc/test_symge.cpp | 178 +++++ tests/modcc/test_visitors.cpp | 217 +++--- tests/validation/CMakeLists.txt | 1 + tests/validation/validate_kinetic.cpp | 13 + tests/validation/validate_kinetic.hpp | 90 +++ tests/validation/validate_soma.hpp | 4 +- validation/ref/numeric/CMakeLists.txt | 10 + validation/ref/numeric/LVAChannels.jl | 168 +++++ validation/ref/numeric/numeric_kin1.jl | 33 + validation/ref/numeric/numeric_kinlva.jl | 40 + 51 files changed, 4098 insertions(+), 644 deletions(-) create mode 100644 mechanisms/mod/test_kin1.mod create mode 100644 mechanisms/mod/test_kinlva.mod rename modcc/{kinrewriter.hpp => kineticrewriter.cpp} (58%) create mode 100644 modcc/kineticrewriter.hpp create mode 100644 modcc/msparse.hpp create mode 100644 modcc/solvers.cpp create mode 100644 modcc/solvers.hpp create mode 100644 modcc/symdiff.cpp create mode 100644 modcc/symdiff.hpp create mode 100644 modcc/symge.cpp create mode 100644 modcc/symge.hpp create mode 100644 tests/modcc/test.cpp create mode 100644 tests/modcc/test_msparse.cpp create mode 100644 tests/modcc/test_removelocals.cpp create mode 100644 tests/modcc/test_symdiff.cpp create mode 100644 tests/modcc/test_symge.cpp create mode 100644 tests/validation/validate_kinetic.cpp create mode 100644 tests/validation/validate_kinetic.hpp create mode 100644 validation/ref/numeric/LVAChannels.jl create mode 100644 validation/ref/numeric/numeric_kin1.jl create mode 100644 validation/ref/numeric/numeric_kinlva.jl diff --git a/mechanisms/CMakeLists.txt b/mechanisms/CMakeLists.txt index 10fdae9e..3f543b64 100644 --- a/mechanisms/CMakeLists.txt +++ b/mechanisms/CMakeLists.txt @@ -1,7 +1,7 @@ include(BuildModules.cmake) # the list of built-in mechanisms to be provided by default -set(mechanisms pas hh expsyn exp2syn) +set(mechanisms pas hh expsyn exp2syn test_kin1 test_kinlva) set(modcc_opt) if(NMC_USE_OPTIMIZED_KERNELS) # generate optimized kernels diff --git a/mechanisms/mod/test_kin1.mod b/mechanisms/mod/test_kin1.mod new file mode 100644 index 00000000..c877fb4e --- /dev/null +++ b/mechanisms/mod/test_kin1.mod @@ -0,0 +1,38 @@ +NEURON { + SUFFIX test_kin1 + NONSPECIFIC_CURRENT il +} + +UNITS { + (mV) = (millivolt) + (mA) = (milliamp) + (S) = (siemens) +} + +PARAMETER { + tau = 10 (ms) +} + +STATE { + a (mA/cm2) + b (mA/cm2) +} + +ASSIGNED { + v (mV) +} + +BREAKPOINT { + SOLVE states METHOD sparse + il = 0*v+a +} + +INITIAL { + a = 0.01 + b = 0 +} + +KINETIC states { + ~ a <-> b (2/3/tau, 1/3/tau) +} + diff --git a/mechanisms/mod/test_kinlva.mod b/mechanisms/mod/test_kinlva.mod new file mode 100644 index 00000000..51a541e8 --- /dev/null +++ b/mechanisms/mod/test_kinlva.mod @@ -0,0 +1,76 @@ +: Adaption of T-type calcium channel from Wang, X. J. et al. 1991; +: c.f. NMODL file in ModelDB: +: https://senselab.med.yale.edu/modeldb/showModel.cshtml?model=53893 +: +: Note the temperature rate correction factors of 5 (for m) and 3 +: (for h <-> s <-> d) have been applied to match the model described +: in the current-clamp experiments (see p. 842). + +NEURON { + SUFFIX test_kinlva + USEION ca WRITE ica + NONSPECIFIC_CURRENT il +} + +UNITS { + (mS) = (millisiemens) + (mV) = (millivolts) + (mA) = (millamp) +} + +PARAMETER { + gbar = 0.0002 (S/cm2) + gl = 0.0001 (S/cm2) + eca = 120 (mV) + el = -65 (mV) +} + +STATE { + m h s d +} + +BREAKPOINT { + SOLVE m_state METHOD cnexp + SOLVE dsh_state METHOD sparse + ica = gbar*m^3*h*(v-eca) + il = gl*(v-el) +} + +FUNCTION minf(v) { + minf = 1/(1+exp(-(v+63)/7.8)) +} + +FUNCTION K(v) { + K = (0.25+exp((v+83.5)/6.3))^0.5-0.5 +} + +DERIVATIVE m_state { + LOCAL taum, mi, m_q10 + m_q10 = 5 + mi = minf(v) + taum = (1.7+exp(-(v+28.8)/13.5))*mi + m' = m_q10*(mi - m)/taum +} + +KINETIC dsh_state { + LOCAL k, alpha1, beta1, alpha2, beta2, dsh_q10 + dsh_q10 = 3 + k = K(v) + alpha1 = dsh_q10*exp(-(v+160.3)/17.8) + beta1 = alpha1*k + alpha2 = dsh_q10*(1+exp((v+37.4)/30))/240/(1+k) + beta2 = alpha2*k + + ~ s <-> h (alpha1, beta1) + ~ d <-> s (alpha2, beta2) +} + +INITIAL { + LOCAL k, vrest + vrest = -65 + k = K(v) + m = minf(vrest) + h = 1/(1+k+k^2) + d = h*k^2 + s = 1-h-d +} diff --git a/modcc/CMakeLists.txt b/modcc/CMakeLists.txt index 2dcf6d35..29dffa3c 100644 --- a/modcc/CMakeLists.txt +++ b/modcc/CMakeLists.txt @@ -9,8 +9,12 @@ set(MODCC_SOURCES functionexpander.cpp functioninliner.cpp lexer.cpp + kineticrewriter.cpp module.cpp parser.cpp + solvers.cpp + symdiff.cpp + symge.cpp textbuffer.cpp token.cpp ) diff --git a/modcc/astmanip.cpp b/modcc/astmanip.cpp index e723c675..898374ff 100644 --- a/modcc/astmanip.cpp +++ b/modcc/astmanip.cpp @@ -28,3 +28,14 @@ local_assignment make_unique_local_assign(scope_ptr scope, Expression* e, std::s return { std::move(local), std::move(ass), std::move(id), scope }; } +local_declaration make_unique_local_decl(scope_ptr scope, Location loc, std::string const& prefix) { + std::string name = unique_local_name(scope, prefix); + + auto local = make_expression<LocalDeclaration>(loc, name); + local->semantic(scope); + + auto id = make_expression<LocalDeclaration>(loc, name); + id->semantic(scope); + + return { std::move(local), std::move(id), scope }; +} diff --git a/modcc/astmanip.hpp b/modcc/astmanip.hpp index b3135cfc..e42586ff 100644 --- a/modcc/astmanip.hpp +++ b/modcc/astmanip.hpp @@ -10,19 +10,30 @@ // Create new local variable symbol and local declaration expression in current scope. // Returns the local declaration expression. -expression_ptr make_unique_local_decl(scope_ptr scope, Location loc, std::string const& prefix="ll"); -struct local_assignment { +struct local_declaration { expression_ptr local_decl; - expression_ptr assignment; expression_ptr id; scope_ptr scope; }; +local_declaration make_unique_local_decl( + scope_ptr scope, + Location loc, + std::string const& prefix="ll"); + // Create a local declaration as for `make_unique_local_decl`, together with an // assignment to it from the given expression, using the location of that expression. // Returns local declaration expression, assignment expression, new identifier id and // consequent scope. + +struct local_assignment { + expression_ptr local_decl; + expression_ptr assignment; + expression_ptr id; + scope_ptr scope; +}; + local_assignment make_unique_local_assign( scope_ptr scope, Expression* e, diff --git a/modcc/error.hpp b/modcc/error.hpp index f113eff0..e66fc454 100644 --- a/modcc/error.hpp +++ b/modcc/error.hpp @@ -1,25 +1,76 @@ #pragma once +#include <deque> +#include <iterator> +#include <stdexcept> +#include <string> + #include "location.hpp" +struct error_entry { + std::string message; + Location location; + + error_entry(std::string m): message(std::move(m)) {} + error_entry(std::string m, Location l): message(std::move(m)), location(l) {} +}; + +// Mixin class for managing a stack of error info. + +class error_stack { +private: + std::deque<error_entry> errors_; + std::deque<error_entry> warnings_; + +public: + bool has_error() const { return !errors_.empty(); } + void error(error_entry info) { errors_.push_back(std::move(info)); } + void clear_errors() { errors_.clear(); } + + std::deque<error_entry>& errors() { return errors_; } + const std::deque<error_entry>& errors() const { return errors_; } + + template <typename Seq> + void append_errors(const Seq& seq) { + errors_.insert(errors_.end(), std::begin(seq), std::end(seq)); + } + + bool has_warning() const { return !warnings_.empty(); } + void warning(error_entry info) { warnings_.push_back(std::move(info)); } + void clear_warnings() { warnings_.clear(); } + + std::deque<error_entry>& warnings() { return warnings_; } + const std::deque<error_entry>& warnings() const { return warnings_; } + + template <typename Seq> + void append_warnings(const Seq& seq) { + warnings_.insert(warnings_.end(), std::begin(seq), std::end(seq)); + } +}; + +// Wrap error entry in exception. + class compiler_exception : public std::exception { public: + explicit compiler_exception(error_entry info) + : error_info_(std::move(info)) + {} + compiler_exception(std::string m, Location location) - : location_(location), - message_(std::move(m)) + : error_info_({std::move(m), location}) {} virtual const char* what() const throw() { - return message_.c_str(); + return error_info_.message.c_str(); } Location const& location() const { - return location_; + return error_info_.location; } private: - - Location location_; - std::string message_; + error_entry error_info_; }; + + diff --git a/modcc/expression.cpp b/modcc/expression.cpp index 7c3eca9f..7b05d65f 100644 --- a/modcc/expression.cpp +++ b/modcc/expression.cpp @@ -756,7 +756,7 @@ void BlockExpression::semantic(scope_ptr scp) { } expression_ptr BlockExpression::clone() const { - std::list<expression_ptr> statements; + expr_list_type statements; for(auto& e: statements_) { statements.emplace_back(e->clone()); } @@ -805,6 +805,29 @@ expression_ptr IfExpression::clone() const { ); } +/******************************************************************************* + PDiffExpression +*******************************************************************************/ + +std::string PDiffExpression::to_string() const { + return blue("pdiff") + " ( " + var_->to_string() + "; " + arg_->to_string() + ")"; +} + +void PDiffExpression::semantic(scope_ptr scp) { + scope_ = scp; + + if (!var_->is_identifier()) { + error(pprintf("the variable in the partial differential expression is not " + "an identifier, but instead %", yellow(var_->to_string()))); + } + var_->semantic(scp); + arg_->semantic(scp); +} + +expression_ptr PDiffExpression::clone() const { + return make_expression<PDiffExpression>(location_, var_->clone(), arg_->clone()); +} + #include "visitor.hpp" /* @@ -930,6 +953,9 @@ void PowBinaryExpression::accept(Visitor *v) { void ConditionalExpression::accept(Visitor *v) { v->visit(this); } +void PDiffExpression::accept(Visitor *v) { + v->visit(this); +} expression_ptr unary_expression( Location loc, tok op, diff --git a/modcc/expression.hpp b/modcc/expression.hpp index c20e0006..a44bdc5b 100644 --- a/modcc/expression.hpp +++ b/modcc/expression.hpp @@ -55,11 +55,13 @@ class SolveExpression; class ConductanceExpression; class Symbol; class LocalVariable; +class PDiffExpression; // not parsed; possible result of symbolic differentiation using expression_ptr = std::unique_ptr<Expression>; using symbol_ptr = std::unique_ptr<Symbol>; using scope_type = Scope<Symbol>; using scope_ptr = std::shared_ptr<scope_type>; +using expr_list_type = std::list<expression_ptr>; template <typename T, typename... Args> expression_ptr make_expression(Args&&... args) { @@ -101,14 +103,16 @@ std::string to_string(symbolKind k); /// methods for time stepping state enum class solverMethod { - cnexp, // the only method we have at the moment + cnexp, // for diagonal linear ODE systems. + sparse, // for non-diagonal linear ODE systems. none }; static std::string to_string(solverMethod m) { switch(m) { - case solverMethod::cnexp : return std::string("cnexp"); - case solverMethod::none : return std::string("none"); + case solverMethod::cnexp: return std::string("cnexp"); + case solverMethod::sparse: return std::string("sparse"); + case solverMethod::none: return std::string("none"); } return std::string("<error : undefined solverMethod>"); } @@ -155,6 +159,7 @@ public: virtual expression_ptr clone() const; // easy lookup of properties + virtual CallExpression* is_call() {return nullptr;} virtual CallExpression* is_function_call() {return nullptr;} virtual CallExpression* is_procedure_call() {return nullptr;} virtual BlockExpression* is_block() {return nullptr;} @@ -179,6 +184,7 @@ public: virtual SolveExpression* is_solve_statement() {return nullptr;} virtual Symbol* is_symbol() {return nullptr;} virtual ConductanceExpression* is_conductance_statement() {return nullptr;} + virtual PDiffExpression* is_pdiff() {return nullptr;} virtual bool is_lvalue() const {return false;} virtual bool is_global_lvalue() const {return false;} @@ -732,13 +738,13 @@ private: class BlockExpression : public Expression { protected: - std::list<expression_ptr> statements_; + expr_list_type statements_; bool is_nested_ = false; public: BlockExpression( Location loc, - std::list<expression_ptr>&& statements, + expr_list_type&& statements, bool is_nested) : Expression(loc), statements_(std::move(statements)), @@ -749,7 +755,7 @@ public: return this; } - std::list<expression_ptr>& statements() { + expr_list_type& statements() { return statements_; } @@ -955,6 +961,9 @@ public: void accept(Visitor *v) override; + CallExpression* is_call() override { + return this; + } CallExpression* is_function_call() override { return symbol_->kind() == symbolKind::function ? this : nullptr; } @@ -1003,6 +1012,16 @@ public: BlockExpression* body() { return body_.get()->is_block(); } + void body(expression_ptr&& new_body) { + if(!new_body->is_block()) { + Location loc = new_body? new_body->location(): Location{}; + throw compiler_exception( + " attempt to set ProcedureExpression body with non-block expression, i.e.\n" + + new_body->to_string(), + loc); + } + body_ = std::move(new_body); + } void semantic(scope_type::symbol_map &scp) override; ProcedureExpression* is_procedure() override {return this;} @@ -1044,7 +1063,7 @@ class InitialBlock : public BlockExpression { public: InitialBlock( Location loc, - std::list<expression_ptr>&& statements) + expr_list_type&& statements) : BlockExpression(loc, std::move(statements), true) {} @@ -1315,3 +1334,25 @@ public: void accept(Visitor *v) override; }; + +class PDiffExpression : public Expression { +public: + PDiffExpression(Location loc, expression_ptr&& var, expression_ptr&& arg) + : Expression(loc), var_(std::move(var)), arg_(std::move(arg)) + {} + + std::string to_string() const override; + void accept(Visitor *v) override; + void semantic(scope_ptr scp) override; + expression_ptr clone() const override; + + PDiffExpression* is_pdiff() override { return this; } + + Expression* var() { return var_.get(); } + Expression* arg() { return arg_.get(); } + +private: + expression_ptr var_; + expression_ptr arg_; +}; + diff --git a/modcc/functionexpander.cpp b/modcc/functionexpander.cpp index 402ca71c..7cbc3545 100644 --- a/modcc/functionexpander.cpp +++ b/modcc/functionexpander.cpp @@ -5,9 +5,7 @@ #include "functionexpander.hpp" #include "modccutil.hpp" -using namespace nest::mc; - -expression_ptr insert_unique_local_assignment(call_list_type& stmts, Expression* e) { +expression_ptr insert_unique_local_assignment(expr_list_type& stmts, Expression* e) { auto exprs = make_unique_local_assign(e->scope(), e); stmts.push_front(std::move(exprs.local_decl)); stmts.push_back(std::move(exprs.assignment)); @@ -18,9 +16,9 @@ expression_ptr insert_unique_local_assignment(call_list_type& stmts, Expression* // function call site lowering /////////////////////////////////////////////////////////////////////////////// -call_list_type lower_function_calls(Expression* e) +expr_list_type lower_function_calls(Expression* e) { - auto v = util::make_unique<FunctionCallLowerer>(e->scope()); + auto v = make_unique<FunctionCallLowerer>(e->scope()); if(auto a=e->is_assignment()) { #ifdef LOGGING @@ -107,10 +105,10 @@ void FunctionCallLowerer::visit(BinaryExpression *e) { // function argument lowering /////////////////////////////////////////////////////////////////////////////// -call_list_type +expr_list_type lower_function_arguments(std::vector<expression_ptr>& args) { - call_list_type new_statements; + expr_list_type new_statements; for(auto it=args.begin(); it!=args.end(); ++it) { // get reference to the unique_ptr with the expression auto& e = *it; diff --git a/modcc/functionexpander.hpp b/modcc/functionexpander.hpp index c9185332..b1ea95a2 100644 --- a/modcc/functionexpander.hpp +++ b/modcc/functionexpander.hpp @@ -2,19 +2,17 @@ #include <sstream> +#include "expression.hpp" #include "scope.hpp" #include "visitor.hpp" -// storage for a list of expressions -using call_list_type = std::list<expression_ptr>; - // Make a local declaration and assignment for the given expression, // and insert at the front and back respectively of the statement list. // Return the new unique local identifier. -expression_ptr insert_unique_local_assignment(call_list_type& stmts, Expression* e); +expression_ptr insert_unique_local_assignment(expr_list_type& stmts, Expression* e); // prototype for lowering function calls -call_list_type lower_function_calls(Expression* e); +expr_list_type lower_function_calls(Expression* e); /////////////////////////////////////////////////////////////////////////////// // visitor that takes function call sites and lowers them to inline assignments @@ -48,11 +46,11 @@ public: void visit(NumberExpression *e) override {}; void visit(IdentifierExpression *e) override {}; - call_list_type& calls() { + expr_list_type& calls() { return calls_; } - call_list_type move_calls() { + expr_list_type move_calls() { return std::move(calls_); } @@ -67,7 +65,7 @@ private: replacer(std::move(id)); } - call_list_type calls_; + expr_list_type calls_; scope_ptr scope_; }; @@ -91,5 +89,5 @@ private: // If the calls_ data is spliced directly before the original statement // the function arguments will have been fully lowered /////////////////////////////////////////////////////////////////////////////// -call_list_type lower_function_arguments(std::vector<expression_ptr>& args); +expr_list_type lower_function_arguments(std::vector<expression_ptr>& args); diff --git a/modcc/functioninliner.cpp b/modcc/functioninliner.cpp index 837c36f0..a652c60a 100644 --- a/modcc/functioninliner.cpp +++ b/modcc/functioninliner.cpp @@ -5,8 +5,6 @@ #include "modccutil.hpp" #include "errorvisitor.hpp" -using namespace nest::mc; - expression_ptr inline_function_call(Expression* e) { if(auto f=e->is_function_call()) { @@ -35,12 +33,11 @@ expression_ptr inline_function_call(Expression* e) << id->to_string() << " -> " << fargs[i]->to_string() << " in the expression " << new_e->to_string() << "\n"; #endif - auto v = - util::make_unique<VariableReplacer>( - fargs[i]->is_argument()->spelling(), - id->spelling() - ); - new_e->accept(v.get()); + VariableReplacer v( + fargs[i]->is_argument()->spelling(), + id->spelling() + ); + new_e->accept(&v); } else if(auto value = cargs[i]->is_number()) { #ifdef LOGGING @@ -48,12 +45,11 @@ expression_ptr inline_function_call(Expression* e) << value->to_string() << " -> " << fargs[i]->to_string() << " in the expression " << new_e->to_string() << "\n"; #endif - auto v = - util::make_unique<ValueInliner>( - fargs[i]->is_argument()->spelling(), - value->value() - ); - new_e->accept(v.get()); + ValueInliner v( + fargs[i]->is_argument()->spelling(), + value->value() + ); + new_e->accept(&v); } else { throw compiler_exception( @@ -64,12 +60,12 @@ expression_ptr inline_function_call(Expression* e) } new_e->semantic(e->scope()); - auto v = util::make_unique<ErrorVisitor>(""); - new_e->accept(v.get()); + ErrorVisitor v(""); + new_e->accept(&v); #ifdef LOGGING std::cout << "inline_function_call result " << new_e->to_string() << "\n\n"; #endif - if(v->num_errors()) { + if(v.num_errors()) { throw compiler_exception("something went wrong with inlined function call ", e->location()); } diff --git a/modcc/kinrewriter.hpp b/modcc/kineticrewriter.cpp similarity index 58% rename from modcc/kinrewriter.hpp rename to modcc/kineticrewriter.cpp index 2ccf5c8a..32b6d52a 100644 --- a/modcc/kinrewriter.hpp +++ b/modcc/kineticrewriter.cpp @@ -1,72 +1,48 @@ -#pragma once - #include <iostream> #include <map> #include <string> #include <list> #include "astmanip.hpp" +#include "symdiff.hpp" #include "visitor.hpp" -using stmt_list_type = std::list<expression_ptr>; - -class KineticRewriter : public Visitor { +class KineticRewriter : public BlockRewriterBase { public: - virtual void visit(Expression *) override; + using BlockRewriterBase::visit; - virtual void visit(UnaryExpression *e) override { visit((Expression*)e); } - virtual void visit(BinaryExpression *e) override { visit((Expression*)e); } + KineticRewriter() {} + KineticRewriter(scope_ptr enclosing_scope): BlockRewriterBase(enclosing_scope) {} virtual void visit(ConserveExpression *e) override; virtual void visit(ReactionExpression *e) override; - virtual void visit(BlockExpression *e) override; - virtual void visit(ProcedureExpression* e) override; - - symbol_ptr as_procedure() { - stmt_list_type body_stmts; - for (const auto& s: statements) body_stmts.push_back(s->clone()); - - auto body = make_expression<BlockExpression>( - proc_loc, - std::move(body_stmts), - false); - - return make_symbol<ProcedureExpression>( - proc_loc, - proc_name, - std::vector<expression_ptr>(), - std::move(body)); - } -private: - // Name and location of original kinetic procedure (used for `as_procedure` above). - std::string proc_name; - Location proc_loc; +protected: + virtual void reset() override { + BlockRewriterBase::reset(); + dterms.clear(); + } - // Statements in replacement procedure body. - stmt_list_type statements; + virtual void finalize() override; +private: // Acccumulated terms for derivative expressions, keyed by id name. std::map<std::string, expression_ptr> dterms; - - // Reset state (at e.g. start of kinetic proc). - void reset() { - proc_name = ""; - statements.clear(); - dterms.clear(); - } }; -// By default, copy statements across verbatim. -inline void KineticRewriter::visit(Expression* e) { - statements.push_back(e->clone()); +expression_ptr kinetic_rewrite(BlockExpression* block) { + KineticRewriter visitor; + block->accept(&visitor); + return visitor.as_block(false); } -inline void KineticRewriter::visit(ConserveExpression*) { +// KineticRewriter implementation follows. + +void KineticRewriter::visit(ConserveExpression*) { // Deliberately ignoring these for now! } -inline void KineticRewriter::visit(ReactionExpression* e) { +void KineticRewriter::visit(ReactionExpression* e) { Location loc = e->location(); scope_ptr scope = e->scope(); @@ -79,10 +55,9 @@ inline void KineticRewriter::visit(ReactionExpression* e) { auto& id = term->is_stoich_term()->ident(); auto& coeff = term->is_stoich_term()->coeff(); - fwd = make_expression<MulBinaryExpression>( - loc, + fwd = constant_simplify(make_expression<MulBinaryExpression>(loc, make_expression<PowBinaryExpression>(loc, id->clone(), coeff->clone()), - std::move(fwd)); + std::move(fwd))); } // Similar for reverse rate. @@ -93,10 +68,9 @@ inline void KineticRewriter::visit(ReactionExpression* e) { auto& id = term->is_stoich_term()->ident(); auto& coeff = term->is_stoich_term()->coeff(); - rev = make_expression<MulBinaryExpression>( - loc, - make_expression<PowBinaryExpression>(loc, id->clone(), coeff->clone()), - std::move(rev)); + rev = constant_simplify(make_expression<MulBinaryExpression>(loc, + make_expression<PowBinaryExpression>(loc, id->clone(), coeff->clone()), + std::move(rev))); } auto net_rate = make_expression<SubBinaryExpression>( @@ -105,11 +79,11 @@ inline void KineticRewriter::visit(ReactionExpression* e) { net_rate->semantic(scope); auto local_net_rate = make_unique_local_assign(scope, net_rate, "rate"); - statements.push_back(std::move(local_net_rate.local_decl)); - statements.push_back(std::move(local_net_rate.assignment)); + statements_.push_back(std::move(local_net_rate.local_decl)); + statements_.push_back(std::move(local_net_rate.assignment)); scope = local_net_rate.scope; // nop for now... - auto net_rate_sym = std::move(local_net_rate.id); + const auto& net_rate_sym = local_net_rate.id; // Net change in quantity after forward reaction: // e.g. A + ... <-> 3A + ... @@ -142,8 +116,8 @@ inline void KineticRewriter::visit(ReactionExpression* e) { term->semantic(scope); auto local_term = make_unique_local_assign(scope, term, p.first+"_rate"); - statements.push_back(std::move(local_term.local_decl)); - statements.push_back(std::move(local_term.assignment)); + statements_.push_back(std::move(local_term.local_decl)); + statements_.push_back(std::move(local_term.assignment)); scope = local_term.scope; // nop for now... auto& dterm = dterms[p.first]; @@ -163,13 +137,8 @@ inline void KineticRewriter::visit(ReactionExpression* e) { } } -inline void KineticRewriter::visit(ProcedureExpression* e) { - reset(); - proc_name = e->name(); - proc_loc = e->location(); - e->body()->accept(this); - - // make new procedure from saved statements and terms +void KineticRewriter::finalize() { + // append new derivative assignments from saved terms for (auto& p: dterms) { auto loc = p.second->location(); auto scope = p.second->scope(); @@ -184,13 +153,7 @@ inline void KineticRewriter::visit(ProcedureExpression* e) { std::move(deriv), std::move(p.second)); - assign->scope(scope); // don't re-do semantic analysis here - statements.push_back(std::move(assign)); + statements_.push_back(std::move(assign)); } } -inline void KineticRewriter::visit(BlockExpression* e) { - for (auto& s: e->statements()) { - s->accept(this); - } -} diff --git a/modcc/kineticrewriter.hpp b/modcc/kineticrewriter.hpp new file mode 100644 index 00000000..3dfb2648 --- /dev/null +++ b/modcc/kineticrewriter.hpp @@ -0,0 +1,6 @@ +#pragma once + +#include "expression.hpp" + +// Translate a supplied KINETIC block to equivalent DERIVATIVE block. +expression_ptr kinetic_rewrite(BlockExpression*); diff --git a/modcc/modcc.cpp b/modcc/modcc.cpp index 0375eece..c1c210eb 100644 --- a/modcc/modcc.cpp +++ b/modcc/modcc.cpp @@ -13,8 +13,6 @@ #include "modccutil.hpp" #include "options.hpp" -using namespace nest::mc; - //#define VERBOSE int main(int argc, char **argv) { @@ -113,7 +111,7 @@ int main(int argc, char **argv) { std::cout << m.error_string() << std::endl; } - if(m.status() == lexerStatus::error) { + if(m.has_error()) { return 1; } @@ -123,7 +121,7 @@ int main(int argc, char **argv) { if(Options::instance().optimize) { if(Options::instance().verbose) std::cout << green("[") + "optimize" + green("]") << std::endl; m.optimize(); - if(m.status() == lexerStatus::error) { + if(m.has_error()) { return 1; } } @@ -175,15 +173,15 @@ int main(int argc, char **argv) { std::cout << yellow("method " + method->name()) << "\n"; std::cout << white("-------------------------\n"); - auto flops = util::make_unique<FlopVisitor>(); - method->accept(flops.get()); + FlopVisitor flops; + method->accept(&flops); std::cout << white("FLOPS") << std::endl; - std::cout << flops->print() << std::endl; + std::cout << flops.print() << std::endl; std::cout << white("MEMOPS") << std::endl; - auto memops = util::make_unique<MemOpVisitor>(); - method->accept(memops.get()); - std::cout << memops->print() << std::endl;; + MemOpVisitor memops; + method->accept(&memops); + std::cout << memops.print() << std::endl;; } } } diff --git a/modcc/modccutil.hpp b/modcc/modccutil.hpp index 63f0701a..ad83a03a 100644 --- a/modcc/modccutil.hpp +++ b/modcc/modccutil.hpp @@ -6,25 +6,39 @@ #include <vector> #include <initializer_list> -// is thing in list? -template <typename T, int N> -bool is_in(T thing, const T (&list)[N]) { - for(auto const& item : list) { - if(thing==item) { - return true; +namespace impl { + template <typename C, typename V> + struct has_count_method { + template <typename T, typename U> + static decltype(std::declval<T>().count(std::declval<U>()), std::true_type{}) test(int); + template <typename T, typename U> + static std::false_type test(...); + + using type = decltype(test<C, V>(0)); + }; + + template <typename X, typename C> + bool is_in(const X& x, const C& c, std::false_type) { + for (const auto& y: c) { + if (y==x) return true; } + return false; } - return false; -} -template <typename T> -bool is_in(T thing, const std::initializer_list<T> list) { - for(auto const& item : list) { - if(thing==item) { - return true; - } + template <typename X, typename C> + bool is_in(const X& x, const C& c, std::true_type) { + return !!c.count(x); } - return false; +} + +template <typename X, typename C> +bool is_in(const X& x, const C& c) { + return impl::is_in(x, c, typename impl::has_count_method<C,X>::type{}); +} + +template <typename X> +bool is_in(const X& x, const std::initializer_list<X>& c) { + return impl::is_in(x, c, std::false_type{}); } inline std::string pprintf(const char *s) { @@ -138,16 +152,8 @@ std::ostream& operator<< (std::ostream& os, std::vector<T> const& V) { return os << "]"; } -namespace nest { -namespace mc { -namespace util { - -// just because we aren't using C++14, doesn't mean we shouldn't go -// without make_unique template <typename T, typename... Args> std::unique_ptr<T> make_unique(Args&&... args) { return std::unique_ptr<T>(new T(std::forward<Args>(args) ...)); } -}}} - diff --git a/modcc/module.cpp b/modcc/module.cpp index e65aecd0..2b1590d6 100644 --- a/modcc/module.cpp +++ b/modcc/module.cpp @@ -8,10 +8,82 @@ #include "expressionclassifier.hpp" #include "functionexpander.hpp" #include "functioninliner.hpp" +#include "kineticrewriter.hpp" #include "module.hpp" #include "parser.hpp" +#include "solvers.hpp" +#include "symdiff.hpp" +#include "visitor.hpp" -using namespace nest::mc; +class NrnCurrentRewriter: public BlockRewriterBase { + expression_ptr id(const std::string& name, Location loc) { + return make_expression<IdentifierExpression>(loc, name); + } + + expression_ptr id(const std::string& name) { + return id(name, loc_); + } + + static ionKind is_ion_update(Expression* e) { + if(auto a = e->is_assignment()) { + if(auto sym = a->lhs()->is_identifier()->symbol()) { + if(auto var = sym->is_local_variable()) { + return var->ion_channel(); + } + } + } + return ionKind::none; + } + + moduleKind kind_; + bool has_current_update_ = false; + +public: + using BlockRewriterBase::visit; + + explicit NrnCurrentRewriter(moduleKind kind): kind_(kind) {} + + virtual void finalize() override { + if (has_current_update_) { + // Initialize current_ as first statement. + statements_.push_front(make_expression<AssignmentExpression>(loc_, + id("current_"), + make_expression<NumberExpression>(loc_, 0.0))); + + if (kind_==moduleKind::density) { + statements_.push_back(make_expression<AssignmentExpression>(loc_, + id("current_"), + make_expression<MulBinaryExpression>(loc_, + id("weights_"), + id("current_")))); + } + } + } + + virtual void visit(SolveExpression *e) override {} + virtual void visit(ConductanceExpression *e) override {} + virtual void visit(AssignmentExpression *e) override { + statements_.push_back(e->clone()); + auto loc = e->location(); + + if (is_ion_update(e)!=ionKind::none) { + has_current_update_ = true; + + if (!linear_test(e->rhs(), {"v"}).is_linear) { + error({"current update expressions must be linear in v: "+e->rhs()->to_string(), + e->location()}); + return; + } + else { + statements_.push_back(make_expression<AssignmentExpression>(loc, + id("current_", loc), + make_expression<AddBinaryExpression>(loc, + id("current_", loc), + e->lhs()->clone()))); + } + } + } +}; Module::Module(std::string const& fname) : fname_(fname) @@ -78,22 +150,26 @@ Module::symbols() const { return symbols_; } -void Module::error(std::string const& msg, Location loc) { - std::string location_info = pprintf("%:% ", file_name(), loc); - if(error_string_.size()) {// append to current string - error_string_ += "\n"; +std::string Module::error_string() const { + std::string str; + for (const error_entry& entry: errors()) { + if (!str.empty()) str += '\n'; + str += red("error "); + str += white(pprintf("%:% ", file_name(), entry.location)); + str += entry.message; } - error_string_ += red("error ") + white(location_info) + msg; - status_ = lexerStatus::error; + return str; } -void Module::warning(std::string const& msg, Location loc) { - std::string location_info = pprintf("%:% ", file_name(), loc); - if(error_string_.size()) {// append to current string - error_string_ += "\n"; +std::string Module::warning_string() const { + std::string str; + for (const error_entry& entry: errors()) { + if (!str.empty()) str += '\n'; + str += purple("error "); + str += white(pprintf("%:% ", file_name(), entry.location)); + str += entry.message; } - error_string_ += purple("warning ") + white(location_info) + msg; - has_warning_ = true; + return str; } bool Module::semantic() { @@ -165,13 +241,13 @@ bool Module::semantic() { s->semantic(symbols_); // then use an error visitor to print out all the semantic errors - auto v = util::make_unique<ErrorVisitor>(file_name()); - s->accept(v.get()); - errors += v->num_errors(); + ErrorVisitor v(file_name()); + s->accept(&v); + errors += v.num_errors(); // inline function calls // this requires that the symbol table has already been built - if(v->num_errors()==0) { + if(v.num_errors()==0) { auto &b = s->kind()==symbolKind::function ? s->is_function()->body()->statements() : s->is_procedure()->body()->statements(); @@ -246,9 +322,7 @@ bool Module::semantic() { } if(errors) { - std::cout << "\nthere were " << errors - << " errors in the semantic analysis" << std::endl; - status_ = lexerStatus::error; + error("There were "+std::to_string(errors)+" errors in the semantic analysis"); return false; } @@ -281,7 +355,7 @@ bool Module::semantic() { loc, name, std::vector<expression_ptr>(), // no arguments make_expression<BlockExpression> - (loc, std::list<expression_ptr>(), false) + (loc, expr_list_type(), false) ); auto proc = symbols_[name]->is_api_method(); @@ -312,46 +386,6 @@ bool Module::semantic() { return false; } - - // evaluate whether an expression has the form - // (b - x)/a - // where x is a state variable with name state_variable - // this test is used to detect ODEs with the signature - // dx/dt = (xinf - x)/xtau - // so that we can integrate them efficiently using the cnexp integrator - // - // this is messy, but ok for a once off. If this pattern is - // repeated, it will be worth finding a more sophisticated solution - auto is_gating = [] (Expression* e, std::string const& state_variable) { - IdentifierExpression* a = nullptr; - IdentifierExpression* b = nullptr; - BinaryExpression* other = nullptr; - if(auto binop = e->is_binary()) { - if(binop->op()==tok::divide) { - if((a = binop->rhs()->is_identifier())) { - other = binop->lhs()->is_binary(); - } - } - } - if(other) { - if(other->op()==tok::minus) { - if(auto rhs = other->rhs()->is_identifier()) { - if(rhs->name() == state_variable) { - if(auto lhs = other->lhs()->is_identifier()) { - if(lhs->name() != state_variable) { - b = lhs; - return std::make_pair(a, b); - } - } - } - } - } - } - - a = b = nullptr; - return std::make_pair(a, b); - }; - // Look in the symbol table for a procedure with the name "breakpoint". // This symbol corresponds to the BREAKPOINT block in the .mod file // There are two APIMethods generated from BREAKPOINT. @@ -361,222 +395,118 @@ bool Module::semantic() { auto api_state = state_api.first; auto breakpoint = state_api.second; - if( breakpoint ) { - // helper for making identifiers on the fly - auto id = [] (std::string const& name, Location loc=Location()) { - return make_expression<IdentifierExpression>(loc, name); - }; + api_state->semantic(symbols_); + scope_ptr nrn_state_scope = api_state->scope(); + + if(breakpoint) { //.......................................................... // nrn_state : The temporal integration of state variables //.......................................................... - // find the SOLVE statement - SolveExpression* solve_expression = nullptr; - for(auto& e: *(breakpoint->body())) { - solve_expression = e->is_solve_statement(); - if(solve_expression) break; - } + // grab SOLVE statements, put them in `nrn_state` after translation. + bool found_solve = false; + bool found_non_solve = false; + std::set<std::string> solved_ids; - // handle the case where there is no SOLVE in BREAKPOINT - if( solve_expression==nullptr ) { - warning( " there is no SOLVE statement, required to update the" - " state variables, in the BREAKPOINT block", - breakpoint->location()); - } - else { - // get the DERIVATIVE block - auto dblock = solve_expression->procedure(); - - // body refers to the currently empty body of the APIMethod that - // will hold the AST for the nrn_state function. - auto& body = api_state->body()->statements(); - - auto has_provided_integration_method = - solve_expression->method() == solverMethod::cnexp; - - // loop over the statements in the SOLVE block from the mod file - // put each statement into the new APIMethod, performing - // transformations if necessary. - for(auto& e : *(dblock->body())) { - if(auto ass = e->is_assignment()) { - auto lhs = ass->lhs(); - auto rhs = ass->rhs(); - if(auto deriv = lhs->is_derivative()) { - // Check that a METHOD was provided in the original SOLVE - // statment. We have to do this because it is possible - // to call SOLVE without a METHOD, in which case there should - // be no derivative expressions in the DERIVATIVE block. - if(!has_provided_integration_method) { - error("The DERIVATIVE block has a derivative expression" - " but no METHOD was specified in the SOLVE statement", - deriv->location()); - return false; - } + for(auto& e: (breakpoint->body()->statements())) { + SolveExpression* solve_expression = e->is_solve_statement(); + if(!solve_expression) { + found_non_solve = true; + continue; + } + if(found_non_solve) { + error("SOLVE statements must come first in BREAKPOINT block", + e->location()); + return false; + } - auto sym = deriv->symbol(); - auto name = deriv->name(); - - auto gating_vars = is_gating(rhs, name); - if(gating_vars.first && gating_vars.second) { - auto const& inf = gating_vars.second->spelling(); - auto const& rate = gating_vars.first->spelling(); - auto e_string = name + "=" + inf - + "+(" + name + "-" + inf + ")*exp(-dt/" - + rate + ")"; - auto stmt_update = Parser(e_string).parse_line_expression(); - body.emplace_back(std::move(stmt_update)); - continue; - } - else { - // create visitor for linear analysis - auto v = util::make_unique<ExpressionClassifierVisitor>(sym); - rhs->accept(v.get()); - - // quit if ODE is not linear - if( v->classify() != expressionClassification::linear ) { - error("unable to integrate nonlinear state ODEs", - rhs->location()); - return false; - } - - // the linear differential equation is of the form - // s' = a*s + b - // integration by separation of variables gives the following - // update function to integrate s for one time step dt - // s = -b/a + (s+b/a)*exp(a*dt) - // we are going to build this update function by - // 1. generating statements that define a_=a and ba_=b/a - // 2. generating statements that update the solution - - // statement : a_ = a - auto stmt_a = - binary_expression(Location(), - tok::eq, - id("a_"), - v->linear_coefficient()->clone()); - - // expression : b/a - auto expr_ba = - binary_expression(Location(), - tok::divide, - v->constant_term()->clone(), - id("a_")); - // statement : ba_ = b/a - auto stmt_ba = binary_expression(Location(), tok::eq, id("ba_"), std::move(expr_ba)); - - // the update function - auto e_string = name + " = -ba_ + " - "(" + name + " + ba_)*exp(a_*dt)"; - auto stmt_update = Parser(e_string).parse_line_expression(); - - // add declaration of local variables - body.emplace_back(Parser("LOCAL a_").parse_local()); - body.emplace_back(Parser("LOCAL ba_").parse_local()); - // add integration statements - body.emplace_back(std::move(stmt_a)); - body.emplace_back(std::move(stmt_ba)); - body.emplace_back(std::move(stmt_update)); - continue; - } - } - else { - body.push_back(e->clone()); - continue; - } - } - body.push_back(e->clone()); + found_solve = true; + std::unique_ptr<SolverVisitorBase> solver; + + switch(solve_expression->method()) { + case solverMethod::cnexp: + solver = make_unique<CnexpSolverVisitor>(); + break; + case solverMethod::sparse: + solver = make_unique<SparseSolverVisitor>(); + break; + case solverMethod::none: + solver = make_unique<DirectSolverVisitor>(); + break; } - } - // perform semantic analysis - api_state->semantic(symbols_); + // If the derivative block is a kinetic block, perform kinetic + // rewrite first. - //.......................................................... - // nrn_current : update contributions to currents - //.......................................................... - std::list<expression_ptr> block; - - // helper which tests a statement to see if it updates an ion - // channel variable. - auto is_ion_update = [] (Expression* e) { - if(auto a = e->is_assignment()) { - // semantic analysis has been performed on the original expression - // which ensures that the lhs is an identifier and a variable - if(auto sym = a->lhs()->is_identifier()->symbol()) { - // assume that a scalar stack variable is being used for - // the indexed value: i.e. the value is not cached - if(auto var = sym->is_local_variable()) { - return var->ion_channel(); - } - } + auto deriv = solve_expression->procedure(); + + if (deriv->kind()==procedureKind::kinetic) { + kinetic_rewrite(deriv->body())->accept(solver.get()); } - return ionKind::none; - }; - - // add statements that initialize the reduction variables - bool has_current_update = false; - for(auto& e: *(breakpoint->body())) { - // ignore solve and conductance statements - if(e->is_solve_statement()) continue; - if(e->is_conductance_statement()) continue; - - // add the expression - block.emplace_back(e->clone()); - - // we are updating an ionic current - // so keep track of current and conductance accumulation - auto channel = is_ion_update(e.get()); - if(channel != ionKind::none) { - auto lhs = e->is_assignment()->lhs()->is_identifier(); - auto rhs = e->is_assignment()->rhs(); - - // analyze the expression for linear terms - //auto v = util::make_unique<ExpressionClassifierVisitor>(symbols_["v"].get()); - auto v_symbol = breakpoint->scope()->find("v"); - auto v = util::make_unique<ExpressionClassifierVisitor>(v_symbol); - rhs->accept(v.get()); - - if(v->classify()==expressionClassification::linear) { - // add current update - if(has_current_update) { - block.emplace_back(Parser("current_ = current_ + " + lhs->name()).parse_line_expression()); - } - else { - block.emplace_back(Parser("current_ = " + lhs->name()).parse_line_expression()); + else { + deriv->body()->accept(solver.get()); + } + + if (auto solve_block = solver->as_block(false)) { + // Check that we didn't solve an already solved variable. + for (const auto& id: solver->solved_identifiers()) { + if (solved_ids.count(id)>0) { + error("Variable "+id+" solved twice!", e->location()); + return false; } + solved_ids.insert(id); } - else { - error("current update functions must be a linear" - " function of v : " + rhs->to_string(), e->location()); - return false; + + // May have now redundant local variables; remove these first. + solve_block = remove_unused_locals(solve_block->is_block()); + + // Copy body into nrn_state. + for (auto& stmt: solve_block->is_block()->statements()) { + api_state->body()->statements().push_back(std::move(stmt)); } - has_current_update = true; } + else { + // Something went wrong: copy errors across. + append_errors(solver->errors()); + return false; + } + } + + // handle the case where there is no SOLVE in BREAKPOINT + if(!found_solve) { + warning(" there is no SOLVE statement, required to update the" + " state variables, in the BREAKPOINT block", + breakpoint->location()); } - if(has_current_update && kind()==moduleKind::density) { - block.emplace_back(Parser("current_ = weights_ * current_").parse_line_expression()); + else { + // redo semantic pass in order to elimate any removed local symbols. + api_state->semantic(symbols_); } - auto v = util::make_unique<ConstantFolderVisitor>(); - for(auto& e : block) { - e->accept(v.get()); + //.......................................................... + // nrn_current : update contributions to currents + //.......................................................... + NrnCurrentRewriter nrn_current_rewriter(kind()); + breakpoint->accept(&nrn_current_rewriter); + auto nrn_current_block = nrn_current_rewriter.as_block(); + if (!nrn_current_block) { + append_errors(nrn_current_rewriter.errors()); + return false; } symbols_["nrn_current"] = make_symbol<APIMethod>( breakpoint->location(), "nrn_current", std::vector<expression_ptr>(), - make_expression<BlockExpression>(breakpoint->location(), - std::move(block), false) - ); + constant_simplify(nrn_current_block)); symbols_["nrn_current"]->semantic(symbols_); } else { - error("a BREAKPOINT block is required", Location()); + error("a BREAKPOINT block is required"); return false; } - return status() == lexerStatus::happy; + return !has_error(); } /// populate the symbol table with class scope variables @@ -782,7 +712,7 @@ bool Module::optimize() { // how to structure the optimizer // loop over APIMethods // - apply optimization to each in turn - auto folder = util::make_unique<ConstantFolderVisitor>(); + ConstantFolderVisitor folder; for(auto &symbol : symbols_) { auto kind = symbol.second->kind(); BlockExpression* body; @@ -809,7 +739,7 @@ bool Module::optimize() { // perform constant folding for(auto& line : *body) { - line->accept(folder.get()); + line->accept(&folder); } // preform expression simplification diff --git a/modcc/module.hpp b/modcc/module.hpp index 5a16e64c..41abdeb5 100644 --- a/modcc/module.hpp +++ b/modcc/module.hpp @@ -4,11 +4,12 @@ #include <vector> #include "blocks.hpp" +#include "error.hpp" #include "expression.hpp" // wrapper around a .mod file -class Module { -public : +class Module: public error_stack { +public: using symbol_map = scope_type::symbol_map; using symbol_ptr = scope_type::symbol_ptr; @@ -58,36 +59,30 @@ public : symbol_map const& symbols() const; // error handling - void error(std::string const& msg, Location loc); - std::string const& error_string() { - return error_string_; + using error_stack::error; + void error(std::string const& msg, Location loc = Location{}) { + error({msg, loc}); } - lexerStatus status() const { - return status_; - } + std::string error_string() const; // warnings - void warning(std::string const& msg, Location loc); - bool has_warning() const { - return has_warning_; - } - bool has_error() const { - return status()==lexerStatus::error; + using error_stack::warning; + void warning(std::string const& msg, Location loc = Location{}) { + warning({msg, loc}); } - moduleKind kind() const { - return kind_; - } - void kind(moduleKind k) { - kind_ = k; - } + std::string warning_string() const; + + moduleKind kind() const { return kind_; } + void kind(moduleKind k) { kind_ = k; } // perform semantic analysis void add_variables_to_symbols(); bool semantic(); bool optimize(); -private : + +private: moduleKind kind_; std::string title_; std::string fname_; @@ -97,11 +92,6 @@ private : bool generate_current_api(); bool generate_state_api(); - // error handling - std::string error_string_; - lexerStatus status_ = lexerStatus::happy; - bool has_warning_ = false; - // AST storage std::vector<symbol_ptr> procedures_; std::vector<symbol_ptr> functions_; diff --git a/modcc/msparse.hpp b/modcc/msparse.hpp new file mode 100644 index 00000000..c7868764 --- /dev/null +++ b/modcc/msparse.hpp @@ -0,0 +1,206 @@ +#pragma once + +// (Possibly augmented) matrix implementation, represented as a vector of sparse rows. + +#include <algorithm> +#include <utility> +#include <initializer_list> +#include <iterator> +#include <stdexcept> +#include <vector> + +namespace msparse { + +struct msparse_error: std::runtime_error { + msparse_error(const std::string &what): std::runtime_error(what) {} +}; + +constexpr unsigned row_npos = unsigned(-1); + +// `msparse::row` represents one sparse matrix row as a vector of +// (column, value) pairs, ordered by (unsigned) column. `row_npos` +// is used to represent an invalid column number. + +template <typename X> +class row { +public: + struct entry { + unsigned col; + X value; + }; + static constexpr unsigned npos = row_npos; + +private: + std::vector<entry> data_; + + // Entries must have strictly monotonically increasing column numbers. + bool check_invariant() const { + for (unsigned i = 1; i<data_.size(); ++i) { + if (data_[i].col<=data_[i-1].col) return false; + } + return true; + } + +public: + row() = default; + row(const row&) = default; + + row(std::initializer_list<entry> il): data_(il) { + if (!check_invariant()) + throw msparse_error("improper row element list"); + } + + template <typename InIter> + row(InIter b, InIter e): data_(b, e) { + if (!check_invariant()) + throw msparse_error("improper row element list"); + } + + unsigned size() const { return data_.size(); } + bool empty() const { return size()==0; } + + // Iterators present row as sequence of `entry` objects. + auto begin() -> decltype(data_.begin()) { return data_.begin(); } + auto begin() const -> decltype(data_.cbegin()) { return data_.cbegin(); } + auto end() -> decltype(data_.end()) { return data_.end(); } + auto end() const -> decltype(data_.cend()) { return data_.cend(); } + + // Return column of first (left-most) entry. + unsigned mincol() const { + return empty()? npos: data_.front().col; + } + + // Return column of first entry with column greater than `c`. + unsigned mincol_after(unsigned c) const { + auto i = std::upper_bound(data_.begin(), data_.end(), c, + [](unsigned a, const entry& b) { return a<b.col; }); + + return i==data_.end()? npos: i->col; + } + + // Return column of last (right-most) entry. + unsigned maxcol() const { + return empty()? npos: data_.back().col; + } + + // As opposed to [] indexing (see below), retrieve `i'th entry from + // the list of entries. + const entry& get(unsigned i) const { + return data_[i]; + } + + void push_back(const entry& e) { + if (!empty() && e.col <= data_.back().col) + throw msparse_error("cannot push_back row elements out of order"); + data_.push_back(e); + } + + // Return index into entry list which has column `c`. + unsigned index(unsigned c) const { + auto i = std::lower_bound(data_.begin(), data_.end(), c, + [](const entry& a, unsigned b) { return a.col<b; }); + + return (i==data_.end() || i->col!=c)? npos: std::distance(data_.begin(), i); + } + + // Remove all entries from column `c` onwards. + void truncate(unsigned c) { + auto i = std::lower_bound(data_.begin(), data_.end(), c, + [](const entry& a, unsigned b) { return a.col<b; }); + data_.erase(i, data_.end()); + } + + // Return value at column `c`; if no entry in row, return default-constructed `X`, + // i.e. 0 for numeric types. + X operator[](unsigned c) const { + auto i = index(c); + return i==npos? X{}: data_[i].value; + } + + // Proxy object to allow assigning elements with the syntax `row[c] = value`. + struct assign_proxy { + row<X>& row_; + unsigned c; + + assign_proxy(row<X>& r, unsigned c): row_(r), c(c) {} + + operator X() const { return const_cast<const row<X>&>(row_)[c]; } + assign_proxy& operator=(const X& x) { + auto i = std::lower_bound(row_.data_.begin(), row_.data_.end(), c, + [](const entry& a, unsigned b) { return a.col<b; }); + + if (i==row_.data_.end() || i->col!=c) { + row_.data_.insert(i, {c, x}); + } + else if (x == X{}) { + row_.data_.erase(i); + } + else { + i->value = x; + } + + return *this; + } + }; + + assign_proxy operator[](unsigned c) { + return assign_proxy{*this, c}; + } +}; + +// `msparse::matrix` represents a matrix by a size (number of rows, +// columns) and vector of sparse `mspase::row` rows. +// +// The matrix may also be 'augmented', with columns corresponding to a second +// matrix appended on the right. + +template <typename X> +class matrix { +private: + std::vector<row<X>> rows; + unsigned cols = 0; + unsigned aug = row_npos; + +public: + static constexpr unsigned npos = row_npos; + + matrix() = default; + matrix(unsigned n, unsigned c): rows(n), cols(c) {} + + row<X>& operator[](unsigned i) { return rows[i]; } + const row<X>& operator[](unsigned i) const { return rows[i]; } + + unsigned size() const { return rows.size(); } + unsigned nrow() const { return size(); } + unsigned ncol() const { return cols; } + + // First column corresponding to the augmented submatrix. + unsigned augcol() const { return aug; } + + bool empty() const { return size()==0; } + bool augmented() const { return aug!=npos; } + + // Add a column on the right as part of the augmented submatrix. + // The new entries are provided by a (full, dense representation) + // sequence of values. + template <typename Seq> + void augment(const Seq& col_dense) { + unsigned r = 0; + for (const auto& v: col_dense) { + if (r>=rows.size()) throw msparse_error("augmented column size mismatch"); + rows[r++].push_back({cols, v}); + } + if (aug==npos) aug=cols; + ++cols; + } + + // Remove all augmented columns. + void diminish() { + if (aug==npos) return; + for (auto& row: rows) row.truncate(aug); + cols = aug; + aug = npos; + } +}; + +} // namespace msparse diff --git a/modcc/parser.cpp b/modcc/parser.cpp index 7bf11c37..84f79413 100644 --- a/modcc/parser.cpp +++ b/modcc/parser.cpp @@ -1312,8 +1312,16 @@ expression_ptr Parser::parse_solve() { } else { get_token(); // consume the METHOD keyword - if(token_.type != tok::cnexp) goto solve_statement_error; - method = solverMethod::cnexp; + switch(token_.type) { + case tok::cnexp: + method = solverMethod::cnexp; + break; + case tok::sparse: + method = solverMethod::sparse; + break; + default: + goto solve_statement_error; + } get_token(); // consume the method description } @@ -1326,10 +1334,11 @@ expression_ptr Parser::parse_solve() { solve_statement_error: error( "SOLVE statements must have the form\n" - " SOLVE x METHOD cnexp\n" + " SOLVE x METHOD method\n" " or\n" " SOLVE x\n" - "where 'x' is the name of a DERIVATIVE block", loc); + "where 'x' is the name of a DERIVATIVE block and " + "'method' is 'cnexp' or 'sparse'", loc); return nullptr; } @@ -1431,7 +1440,7 @@ expression_ptr Parser::parse_block(bool is_nested) { // save the location of the first statement as the starting point for the block Location block_location = token_.location; - std::list<expression_ptr> body; + expr_list_type body; while(token_.type != tok::rbrace) { auto e = parse_statement(); if(!e) return e; @@ -1473,7 +1482,7 @@ expression_ptr Parser::parse_initial() { if(!expect(tok::lbrace)) return nullptr; get_token(); // consume '{' - std::list<expression_ptr> body; + expr_list_type body; while(token_.type != tok::rbrace) { auto e = parse_statement(); if(!e) return e; diff --git a/modcc/solvers.cpp b/modcc/solvers.cpp new file mode 100644 index 00000000..f3f4a7d6 --- /dev/null +++ b/modcc/solvers.cpp @@ -0,0 +1,448 @@ +#include <map> +#include <set> +#include <string> +#include <vector> + +#include "astmanip.hpp" +#include "expression.hpp" +#include "parser.hpp" +#include "solvers.hpp" +#include "symdiff.hpp" +#include "symge.hpp" +#include "visitor.hpp" + +// Cnexp solver visitor implementation. + +void CnexpSolverVisitor::visit(BlockExpression* e) { + // Do a first pass to extract variables comprising ODE system + // lhs; can't really trust 'STATE' block. + + for (auto& stmt: e->statements()) { + if (stmt && stmt->is_assignment() && stmt->is_assignment()->lhs()->is_derivative()) { + auto id = stmt->is_assignment()->lhs()->is_derivative(); + dvars_.push_back(id->name()); + } + } + + BlockRewriterBase::visit(e); +} + +void CnexpSolverVisitor::visit(AssignmentExpression *e) { + auto loc = e->location(); + scope_ptr scope = e->scope(); + + auto lhs = e->lhs(); + auto rhs = e->rhs(); + auto deriv = lhs->is_derivative(); + + if (!deriv) { + statements_.push_back(e->clone()); + return; + } + + auto s = deriv->name(); + linear_test_result r = linear_test(rhs, dvars_); + + if (!r.monolinear(s)) { + error({"System not diagonal linear for cnexp", loc}); + return; + } + + Expression* coef = r.coef[s].get(); + if (r.is_homogeneous) { + // s' = a*s becomes s = s*exp(a*dt); use a_ as a local variable + // for the coefficient. + auto local_a_term = make_unique_local_assign(scope, coef, "a_"); + statements_.push_back(std::move(local_a_term.local_decl)); + statements_.push_back(std::move(local_a_term.assignment)); + auto a_ = local_a_term.id->is_identifier()->spelling(); + + std::string s_update = pprintf("% = %*exp(%*dt)", s, s, a_); + statements_.push_back(Parser{s_update}.parse_line_expression()); + return; + } + else if (is_zero(coef)) { + // s' = b becomes s = s + b*dt; use b_ as a local variable for + // the constant term b. + + auto local_b_term = make_unique_local_assign(scope, r.constant.get(), "b_"); + statements_.push_back(std::move(local_b_term.local_decl)); + statements_.push_back(std::move(local_b_term.assignment)); + auto b_ = local_b_term.id->is_identifier()->spelling(); + + std::string s_update = pprintf("% = %+%*dt", s, s, b_); + statements_.push_back(Parser{s_update}.parse_line_expression()); + return; + } + else { + // s' = a*s + b becomes s = -b/a + (s+b/a)*exp(a*dt); use + // a_ as a local variable for the coefficient and ba_ for the + // quotient. + // + // Note though this will be numerically bad for very small + // (or zero) a. Perhaps re-implement as: + // s = s + exprel(a*dt)*(s*a+b)*dt + // where exprel(x) = (exp(x)-1)/x and can be well approximated + // by e.g. a Taylor expansion for small x. + // + // Special case ('gating variable') when s' = (b-s)/a; rather + // than implement more general algebraic simplification, jump + // straight to simplified update: s = b + (s-b)*exp(-dt/a). + + // Check for 'gating' case: + if (rhs->is_binary() && rhs->is_binary()->op()==tok::divide) { + auto denom = rhs->is_binary()->rhs(); + if (involves_identifier(denom, s)) { + goto not_gating; + } + auto numer = rhs->is_binary()->lhs(); + linear_test_result r = linear_test(numer, {s}); + if (expr_value(r.coef[s]) != -1) { + goto not_gating; + } + + auto local_a_term = make_unique_local_assign(scope, denom, "a_"); + auto a_ = local_a_term.id->is_identifier()->spelling(); + auto local_b_term = make_unique_local_assign(scope, r.constant, "b_"); + auto b_ = local_b_term.id->is_identifier()->spelling(); + + statements_.push_back(std::move(local_a_term.local_decl)); + statements_.push_back(std::move(local_a_term.assignment)); + statements_.push_back(std::move(local_b_term.local_decl)); + statements_.push_back(std::move(local_b_term.assignment)); + + std::string s_update = pprintf("% = %+(%-%)*exp(-dt/%)", s, b_, s, b_, a_); + statements_.push_back(Parser{s_update}.parse_line_expression()); + return; + } + +not_gating: + auto local_a_term = make_unique_local_assign(scope, coef, "a_"); + auto a_ = local_a_term.id->is_identifier()->spelling(); + + auto ba_expr = make_expression<DivBinaryExpression>(loc, + r.constant->clone(), local_a_term.id->clone()); + auto local_ba_term = make_unique_local_assign(scope, ba_expr, "ba_"); + auto ba_ = local_ba_term.id->is_identifier()->spelling(); + + statements_.push_back(std::move(local_a_term.local_decl)); + statements_.push_back(std::move(local_a_term.assignment)); + statements_.push_back(std::move(local_ba_term.local_decl)); + statements_.push_back(std::move(local_ba_term.assignment)); + + std::string s_update = pprintf("% = -%+(%+%)*exp(%*dt)", s, ba_, s, ba_, a_); + statements_.push_back(Parser{s_update}.parse_line_expression()); + return; + } +} + + +// Sparse solver visitor implementation. + +static expression_ptr as_expression(symge::symbol_term term) { + Location loc; + if (term.is_zero()) { + return make_expression<IntegerExpression>(loc, 0); + } + else { + return make_expression<MulBinaryExpression>(loc, + make_expression<IdentifierExpression>(loc, name(term.left)), + make_expression<IdentifierExpression>(loc, name(term.right))); + } +} + +static expression_ptr as_expression(symge::symbol_term_diff diff) { + Location loc; + if (diff.left.is_zero() && diff.right.is_zero()) { + return make_expression<IntegerExpression>(loc, 0); + } + else if (diff.right.is_zero()) { + return as_expression(diff.left); + } + else if (diff.left.is_zero()) { + return make_expression<NegUnaryExpression>(loc, + as_expression(diff.right)); + } + else { + return make_expression<SubBinaryExpression>(loc, + as_expression(diff.left), + as_expression(diff.right)); + } +} + +void SparseSolverVisitor::visit(BlockExpression* e) { + // Do a first pass to extract variables comprising ODE system + // lhs; can't really trust 'STATE' block. + + for (auto& stmt: e->statements()) { + if (stmt && stmt->is_assignment() && stmt->is_assignment()->lhs()->is_derivative()) { + auto id = stmt->is_assignment()->lhs()->is_derivative(); + dvars_.push_back(id->name()); + } + } + + BlockRewriterBase::visit(e); +} + +void SparseSolverVisitor::visit(AssignmentExpression *e) { + if (A_.empty()) { + unsigned n = dvars_.size(); + A_ = symge::sym_matrix(n, n); + } + + auto loc = e->location(); + scope_ptr scope = e->scope(); + + auto lhs = e->lhs(); + auto rhs = e->rhs(); + auto deriv = lhs->is_derivative(); + + if (!deriv) { + statements_.push_back(e->clone()); + + auto id = lhs->is_identifier(); + if (id) { + auto expand = substitute(rhs, local_expr_); + if (involves_identifier(expand, dvars_)) { + local_expr_[id->spelling()] = std::move(expand); + } + } + return; + } + + auto s = deriv->name(); + auto expanded_rhs = substitute(rhs, local_expr_); + linear_test_result r = linear_test(expanded_rhs, dvars_); + if (!r.is_homogeneous) { + error({"System not homogeneous linear for sparse", loc}); + return; + } + + // Populate sparse symbolic matrix for GE. + if (s!=dvars_[deq_index_]) { + error({"ICE: inconsistent ordering of derivative assignments", loc}); + } + + auto dt_expr = make_expression<IdentifierExpression>(loc, "dt"); + auto one_expr = make_expression<NumberExpression>(loc, 1.0); + for (unsigned j = 0; j<dvars_.size(); ++j) { + expression_ptr expr; + + // For zero coefficient and diagonal element, the matrix entry is 1. + // For non-zero coefficient c and diagonal element, the entry is 1-c*dt. + // Otherwise, for non-zero coefficient c, the entry is -c*dt. + + if (r.coef.count(dvars_[j])) { + expr = make_expression<MulBinaryExpression>(loc, + r.coef[dvars_[j]]->clone(), + dt_expr->clone()); + } + + if (j==deq_index_) { + if (expr) { + expr = make_expression<SubBinaryExpression>(loc, + one_expr->clone(), + std::move(expr)); + } + else { + expr = one_expr->clone(); + } + } + else if (expr) { + expr = make_expression<NegUnaryExpression>(loc, std::move(expr)); + } + + if (!expr) continue; + + auto local_a_term = make_unique_local_assign(scope, expr.get(), "a_"); + auto a_ = local_a_term.id->is_identifier()->spelling(); + + statements_.push_back(std::move(local_a_term.local_decl)); + statements_.push_back(std::move(local_a_term.assignment)); + + A_[deq_index_].push_back({j, symtbl_.define(a_)}); + } + ++deq_index_; +} + +void SparseSolverVisitor::finalize() { + std::vector<symge::symbol> rhs; + for (const auto& var: dvars_) { + rhs.push_back(symtbl_.define(var)); + } + A_.augment(rhs); + + symge::gj_reduce(A_, symtbl_); + + // Create and assign intermediate variables. + for (unsigned i = 0; i<symtbl_.size(); ++i) { + symge::symbol s = symtbl_[i]; + + if (primitive(s)) continue; + + auto expr = as_expression(definition(s)); + auto local_t_term = make_unique_local_assign(block_scope_, expr.get(), "t_"); + auto t_ = local_t_term.id->is_identifier()->spelling(); + symtbl_.name(s, t_); + + statements_.push_back(std::move(local_t_term.local_decl)); + statements_.push_back(std::move(local_t_term.assignment)); + } + + // State variable updates given by rhs/diagonal for reduced matrix. + Location loc; + for (unsigned i = 0; i<A_.nrow(); ++i) { + unsigned rhs = A_.augcol(); + + auto expr = + make_expression<AssignmentExpression>(loc, + make_expression<IdentifierExpression>(loc, dvars_[i]), + make_expression<DivBinaryExpression>(loc, + make_expression<IdentifierExpression>(loc, symge::name(A_[i][rhs])), + make_expression<IdentifierExpression>(loc, symge::name(A_[i][i])))); + + statements_.push_back(std::move(expr)); + } + + BlockRewriterBase::finalize(); +} + +// Implementation for `remove_unused_locals`: uses two visitors, +// `UnusedVisitor` and `RemoveVariableVisitor` below. + +class UnusedVisitor : public Visitor { +protected: + std::multimap<std::string, std::string> deps; + std::set<std::string> unused_ids; + std::set<std::string> used_ids; + Symbol* lhs_sym = nullptr; + +public: + using Visitor::visit; + + UnusedVisitor() {} + + virtual void visit(Expression* e) override {} + + virtual void visit(BlockExpression* e) override { + for (auto& s: e->statements()) { + s->accept(this); + } + } + + virtual void visit(AssignmentExpression* e) override { + auto lhs = e->lhs()->is_identifier(); + if (!lhs) return; + + lhs_sym = lhs->symbol(); + e->rhs()->accept(this); + lhs_sym = nullptr; + } + + virtual void visit(UnaryExpression* e) override { + e->expression()->accept(this); + } + + virtual void visit(BinaryExpression* e) override { + e->lhs()->accept(this); + e->rhs()->accept(this); + } + + virtual void visit(CallExpression* e) override { + for (auto& a: e->args()) { + a->accept(this); + } + } + + virtual void visit(IfExpression* e) override { + e->condition()->accept(this); + e->true_branch()->accept(this); + e->false_branch()->accept(this); + } + + virtual void visit(IdentifierExpression* e) override { + if (lhs_sym && lhs_sym->is_local_variable()) { + deps.insert({lhs_sym->name(), e->name()}); + } + else { + used_ids.insert(e->name()); + } + } + + virtual void visit(LocalDeclaration* e) override { + for (auto& v: e->variables()) { + unused_ids.insert(v.first); + } + } + + std::set<std::string> unused_locals() { + if (!computed_) { + for (auto& id: used_ids) { + remove_deps_from_unused(id); + } + computed_ = true; + } + return unused_ids; + } + + void reset() { + deps.clear(); + unused_ids.clear(); + used_ids.clear(); + computed_ = false; + } + +private: + bool computed_ = false; + + void remove_deps_from_unused(const std::string& id) { + auto range = deps.equal_range(id); + for (auto i = range.first; i != range.second; ++i) { + if (unused_ids.count(i->second)) { + remove_deps_from_unused(i->second); + } + } + unused_ids.erase(id); + } +}; + +class RemoveVariableVisitor: public BlockRewriterBase { + std::set<std::string> remove_; + +public: + using BlockRewriterBase::visit; + + RemoveVariableVisitor(std::set<std::string> ids): + remove_(std::move(ids)) {} + + RemoveVariableVisitor(std::set<std::string> ids, scope_ptr enclosing): + BlockRewriterBase(enclosing), remove_(std::move(ids)) {} + + virtual void visit(LocalDeclaration* e) override { + auto replacement = e->clone(); + auto& vars = replacement->is_local_declaration()->variables(); + + for (const auto& id: remove_) { + vars.erase(id); + } + if (!vars.empty()) { + statements_.push_back(std::move(replacement)); + } + } + + virtual void visit(AssignmentExpression* e) override { + std::string lhs_id = e->lhs()->is_identifier()->name(); + if (!remove_.count(lhs_id)) { + statements_.push_back(e->clone()); + } + } +}; + +expression_ptr remove_unused_locals(BlockExpression* block) { + UnusedVisitor unused_visitor; + block->accept(&unused_visitor); + + RemoveVariableVisitor remove_visitor(unused_visitor.unused_locals()); + block->accept(&remove_visitor); + return remove_visitor.as_block(false); +} diff --git a/modcc/solvers.hpp b/modcc/solvers.hpp new file mode 100644 index 00000000..e2bbc353 --- /dev/null +++ b/modcc/solvers.hpp @@ -0,0 +1,97 @@ +#pragma once + +// Transform derivative block into AST representing +// an integration step over the state variables, based on +// solver method. + +#include <string> +#include <vector> + +#include "expression.hpp" +#include "symdiff.hpp" +#include "symge.hpp" +#include "visitor.hpp" + +expression_ptr remove_unused_locals(BlockExpression* block); + +class SolverVisitorBase: public BlockRewriterBase { +protected: + // list of identifier names appearing in derivatives on lhs + std::vector<std::string> dvars_; + +public: + using BlockRewriterBase::visit; + + SolverVisitorBase() {} + SolverVisitorBase(scope_ptr enclosing): BlockRewriterBase(enclosing) {} + + virtual std::vector<std::string> solved_identifiers() const { + return dvars_; + } + + virtual void reset() override { + dvars_.clear(); + BlockRewriterBase::reset(); + } +}; + +class DirectSolverVisitor : public SolverVisitorBase { +public: + using SolverVisitorBase::visit; + + DirectSolverVisitor() {} + DirectSolverVisitor(scope_ptr enclosing): SolverVisitorBase(enclosing) {} + + virtual void visit(AssignmentExpression *e) override { + // No solver method, so declare an error if lhs is a derivative. + if(auto deriv = e->lhs()->is_derivative()) { + error({"The DERIVATIVE block has a derivative expression" + " but no METHOD was specified in the SOLVE statement", + deriv->location()}); + } + } +}; + +class CnexpSolverVisitor : public SolverVisitorBase { +public: + using SolverVisitorBase::visit; + + CnexpSolverVisitor() {} + CnexpSolverVisitor(scope_ptr enclosing): SolverVisitorBase(enclosing) {} + + virtual void visit(BlockExpression* e) override; + virtual void visit(AssignmentExpression *e) override; +}; + +class SparseSolverVisitor : public SolverVisitorBase { +protected: + // 'Current' differential equation is for variable with this + // index in `dvars`. + unsigned deq_index_ = 0; + + // Expanded local assignments that need to be substituted in for derivative + // calculations. + substitute_map local_expr_; + + // Symbolic matrix for backwards Euler step. + symge::sym_matrix A_; + + // 'Symbol table' for symbolic manipulation. + symge::symbol_table symtbl_; + +public: + using SolverVisitorBase::visit; + + SparseSolverVisitor() {} + SparseSolverVisitor(scope_ptr enclosing): SolverVisitorBase(enclosing) {} + + virtual void visit(BlockExpression* e) override; + virtual void visit(AssignmentExpression *e) override; + virtual void finalize() override; + virtual void reset() override { + deq_index_ = 0; + local_expr_.clear(); + symtbl_.clear(); + SolverVisitorBase::reset(); + } +}; diff --git a/modcc/symdiff.cpp b/modcc/symdiff.cpp new file mode 100644 index 00000000..7a634434 --- /dev/null +++ b/modcc/symdiff.cpp @@ -0,0 +1,692 @@ +#include <cmath> +#include <map> +#include <set> +#include <stdexcept> +#include <string> +#include <utility> + +#include "error.hpp" +#include "expression.hpp" +#include "symdiff.hpp" +#include "visitor.hpp" + +class FindIdentifierVisitor: public Visitor { +public: + explicit FindIdentifierVisitor(const identifier_set& ids): ids_(ids) {} + + void reset() { found_ = false; } + bool found() const { return found_; } + + void visit(Expression* e) override {} + + void visit(UnaryExpression* e) override { + if (!found()) e->expression()->accept(this); + } + + void visit(BinaryExpression* e) override { + if (!found()) e->lhs()->accept(this); + if (!found()) e->rhs()->accept(this); + } + + void visit(CallExpression* e) override { + for (auto& e: e->args()) { + if (found()) return; + e->accept(this); + } + } + + void visit(PDiffExpression* e) override { + if (!found()) e->arg()->accept(this); + } + + void visit(IdentifierExpression* e) override { + if (!found()) { + found_ |= is_in(e->spelling(), ids_); + } + } + + void visit(DerivativeExpression* e) override { + if (!found()) { + found_ |= is_in(e->spelling(), ids_); + } + } + void visit(ReactionExpression* e) override { + if (!found()) e->lhs()->accept(this); + if (!found()) e->rhs()->accept(this); + if (!found()) e->fwd_rate()->accept(this); + if (!found()) e->rev_rate()->accept(this); + } + + void visit(StoichTermExpression* e) override { + if (!found()) e->ident()->accept(this); + } + + void visit(StoichExpression* e) override { + for (auto& e: e->terms()) { + if (found()) return; + e->accept(this); + } + } + + void visit(BlockExpression* e) override { + for (auto& e: e->statements()) { + if (found()) return; + e->accept(this); + } + } + + void visit(IfExpression* e) override { + if (!found()) e->condition()->accept(this); + if (!found()) e->true_branch()->accept(this); + if (!found()) e->false_branch()->accept(this); + } + +private: + const identifier_set& ids_; + bool found_ = false; +}; + +bool involves_identifier(Expression* e, const identifier_set& ids) { + FindIdentifierVisitor v(ids); + e->accept(&v); + return v.found(); +} + +bool involves_identifier(Expression* e, const std::string& id) { + identifier_set ids = {id}; + FindIdentifierVisitor v(ids); + e->accept(&v); + return v.found(); +} + +class SymPDiffVisitor: public Visitor, public error_stack { +public: + explicit SymPDiffVisitor(std::string id): id_(std::move(id)) {} + + void reset() { result_ = nullptr; } + + // Note: moves result, forces reset. + expression_ptr result() { + auto r = std::move(result_); + reset(); + return r; + } + + void visit(Expression* e) override { + error({"symbolic differential of improper expression", e->location()}); + } + + void visit(UnaryExpression* e) override { + error({"symbolic differential of unrecognized unary expression", e->location()}); + } + + void visit(BinaryExpression* e) override { + error({"symbolic differential of unrecognized binary expression", e->location()}); + } + + void visit(NegUnaryExpression* e) override { + auto loc = e->location(); + e->expression()->accept(this); + result_ = make_expression<NegUnaryExpression>(loc, result()); + } + + void visit(ExpUnaryExpression* e) override { + auto loc = e->location(); + e->expression()->accept(this); + result_ = make_expression<MulBinaryExpression>(loc, result(), e->clone()); + } + + void visit(LogUnaryExpression* e) override { + auto loc = e->location(); + e->expression()->accept(this); + result_ = make_expression<DivBinaryExpression>(loc, result(), e->expression()->clone()); + } + + void visit(CosUnaryExpression* e) override { + auto loc = e->location(); + e->expression()->accept(this); + result_ = make_expression<MulBinaryExpression>(loc, + make_expression<NegUnaryExpression>(loc, + make_expression<SinUnaryExpression>(loc, e->expression()->clone())), + result()); + } + + void visit(SinUnaryExpression* e) override { + auto loc = e->location(); + e->expression()->accept(this); + result_ = make_expression<MulBinaryExpression>(loc, + make_expression<CosUnaryExpression>(loc, e->expression()->clone()), + result()); + } + + void visit(AddBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr dlhs = std::move(result_); + + e->rhs()->accept(this); + result_ = make_expression<AddBinaryExpression>(loc, std::move(dlhs), result()); + } + + void visit(SubBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr dlhs = std::move(result_); + + e->rhs()->accept(this); + result_ = make_expression<SubBinaryExpression>(loc, move(dlhs), result()); + } + + void visit(MulBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr dlhs = std::move(result_); + + e->rhs()->accept(this); + expression_ptr drhs = std::move(result_); + + result_ = make_expression<AddBinaryExpression>(loc, + make_expression<MulBinaryExpression>(loc, e->lhs()->clone(), std::move(drhs)), + make_expression<MulBinaryExpression>(loc, std::move(dlhs), e->rhs()->clone())); + + } + + void visit(DivBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr dlhs = std::move(result_); + + e->rhs()->accept(this); + expression_ptr drhs = std::move(result_); + + result_ = make_expression<SubBinaryExpression>(loc, + make_expression<DivBinaryExpression>(loc, std::move(dlhs), e->rhs()->clone()), + make_expression<MulBinaryExpression>(loc, + make_expression<DivBinaryExpression>(loc, + e->lhs()->clone(), + make_expression<MulBinaryExpression>(loc, e->rhs()->clone(), e->rhs()->clone())), + std::move(drhs))); + } + + void visit(PowBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr dlhs = std::move(result_); + + e->rhs()->accept(this); + expression_ptr drhs = std::move(result_); + + result_ = make_expression<AddBinaryExpression>(loc, + make_expression<MulBinaryExpression>(loc, + std::move(drhs), + make_expression<MulBinaryExpression>(loc, + make_expression<LogUnaryExpression>(loc, e->lhs()->clone()), + make_expression<PowBinaryExpression>(loc, e->lhs()->clone(), e->rhs()->clone()))), + make_expression<MulBinaryExpression>(loc, + e->rhs()->clone(), + make_expression<MulBinaryExpression>(loc, + make_expression<PowBinaryExpression>(loc, + e->lhs()->clone(), + make_expression<SubBinaryExpression>(loc, + e->rhs()->clone(), + make_expression<IntegerExpression>(loc, 1))), + std::move(dlhs)))); + } + + void visit(CallExpression* e) override { + auto loc = e->location(); + result_ = make_expression<PDiffExpression>(loc, + make_expression<IdentifierExpression>(loc, id_), + e->clone()); + } + + void visit(PDiffExpression* e) override { + auto loc = e->location(); + e->arg()->accept(this); + result_ = make_expression<PDiffExpression>(loc, e->var()->clone(), result()); + } + + void visit(IdentifierExpression* e) override { + auto loc = e->location(); + result_ = make_expression<IntegerExpression>(loc, e->spelling()==id_); + } + + void visit(NumberExpression* e) override { + auto loc = e->location(); + result_ = make_expression<IntegerExpression>(loc, 0); + } + +private: + expression_ptr result_; + std::string id_; +}; + +// ConstantSimplifyVisitior is not the same as ConstantFolderVisitor, as there is no way for a visitor +// to modify an expression in place (only its children). This visitor instead builds a new expression +// from the given one with constant simplifications. + +long double expr_value(Expression* e) { + return e && e->is_number()? e->is_number()->value(): NAN; +} + +class ConstantSimplifyVisitor: public Visitor { +private: + expression_ptr result_; + + static bool is_number(Expression* e) { return e && e->is_number(); } + static bool is_number(const expression_ptr& e) { return is_number(e.get()); } + + void as_number(Location loc, long double v) { + result_ = make_expression<NumberExpression>(loc, v); + } + +public: + using Visitor::visit; + + ConstantSimplifyVisitor() {} + + // Note: moves result, forces reset. + expression_ptr result() { + auto r = std::move(result_); + reset(); + return r; + } + + void reset() { + result_ = nullptr; + } + + long double value() const { return expr_value(result_); } + + bool is_number() const { return is_number(result_); } + + void visit(Expression* e) override { + result_ = e->clone(); + } + + void visit(BlockExpression* e) override { + auto block_ = e->clone(); + block_->is_block()->statements().clear(); + + for (auto& stmt: e->statements()) { + stmt->accept(this); + auto simpl = result(); + + // flatten any naked blocks generated by if/else simplification + if (auto inner = simpl->is_block()) { + for (auto& stmt: inner->statements()) { + block_->is_block()->statements().push_back(std::move(stmt)); + } + } + else { + block_->is_block()->statements().push_back(std::move(simpl)); + } + } + result_ = std::move(block_); + } + + void visit(IfExpression* e) override { + auto loc = e->location(); + e->condition()->accept(this); + auto cond_expr = result(); + e->true_branch()->accept(this); + auto true_expr = result(); + e->false_branch()->accept(this); + auto false_expr = result(); + + if (!is_number(cond_expr)) { + result_ = make_expression<IfExpression>(loc, + std::move(cond_expr), std::move(true_expr), std::move(false_expr)); + } + else if (expr_value(cond_expr)) { + result_ = std::move(true_expr); + } + else { + result_ = std::move(false_expr); + } + } + + // TODO: procedure, function expressions + + void visit(UnaryExpression* e) override { + e->expression()->accept(this); + + if (is_number()) { + auto loc = e->location(); + auto val = value(); + + switch (e->op()) { + case tok::minus: + as_number(loc, -val); + return; + case tok::exp: + as_number(loc, std::exp(val)); + return; + case tok::sin: + as_number(loc, std::sin(val)); + return; + case tok::cos: + as_number(loc, std::cos(val)); + return; + case tok::log: + as_number(loc, std::log(val)); + return; + default: ; // treat opaquely as below + } + } + + expression_ptr arg = result(); + result_ = e->clone(); + result_->is_unary()->replace_expression(std::move(arg)); + } + + void visit(BinaryExpression* e) override { + result_ = e->clone(); + } + + void visit(AssignmentExpression* e) override { + auto loc = e->location(); + e->rhs()->accept(this); + result_ = make_expression<AssignmentExpression>(loc, e->lhs()->clone(), result()); + } + + void visit(MulBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr lhs = result(); + e->rhs()->accept(this); + expression_ptr rhs = result(); + + if (is_number(lhs) && is_number(rhs)) { + as_number(loc, expr_value(lhs)*expr_value(rhs)); + } + else if (expr_value(lhs)==0 || expr_value(rhs)==0) { + as_number(loc, 0); + } + else if (expr_value(lhs)==1) { + result_ = std::move(rhs); + } + else if (expr_value(rhs)==1) { + result_ = std::move(lhs); + } + else { + result_ = make_expression<MulBinaryExpression>(loc, std::move(lhs), std::move(rhs)); + } + } + + void visit(DivBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr lhs = result(); + e->rhs()->accept(this); + expression_ptr rhs = result(); + + if (is_number(lhs) && is_number(rhs)) { + as_number(loc, expr_value(lhs)/expr_value(rhs)); + } + else if (expr_value(lhs)==0) { + as_number(loc, 0); + } + else if (expr_value(rhs)==1) { + result_ = e->lhs()->clone(); + } + else { + result_ = make_expression<DivBinaryExpression>(loc, std::move(lhs), std::move(rhs)); + } + } + + void visit(AddBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr lhs = result(); + e->rhs()->accept(this); + expression_ptr rhs = result(); + + if (is_number(lhs) && is_number(rhs)) { + as_number(loc, expr_value(lhs)+expr_value(rhs)); + } + else if (expr_value(lhs)==0) { + result_ = std::move(rhs); + } + else if (expr_value(rhs)==0) { + result_ = std::move(lhs); + } + else { + result_ = make_expression<AddBinaryExpression>(loc, std::move(lhs), std::move(rhs)); + } + } + + void visit(SubBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr lhs = result(); + e->rhs()->accept(this); + expression_ptr rhs = result(); + + if (is_number(lhs) && is_number(rhs)) { + as_number(loc, expr_value(lhs)-expr_value(rhs)); + } + else if (expr_value(lhs)==0) { + make_expression<NegUnaryExpression>(loc, std::move(rhs))->accept(this); + } + else if (expr_value(rhs)==0) { + result_ = std::move(lhs); + } + else { + result_ = make_expression<SubBinaryExpression>(loc, std::move(lhs), std::move(rhs)); + } + } + + void visit(PowBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr lhs = result(); + e->rhs()->accept(this); + expression_ptr rhs = result(); + + if (is_number(lhs) && is_number(rhs)) { + as_number(loc, std::pow(expr_value(lhs),expr_value(rhs))); + } + else if (expr_value(lhs)==0) { + as_number(loc, 0); + } + else if (expr_value(rhs)==0 || expr_value(lhs)==1) { + as_number(loc, 1); + } + else if (expr_value(rhs)==1) { + result_ = std::move(lhs); + } + else { + result_ = make_expression<PowBinaryExpression>(loc, std::move(lhs), std::move(rhs)); + } + } + + void visit(ConditionalExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr lhs = result(); + e->rhs()->accept(this); + expression_ptr rhs = result(); + + if (is_number(lhs) && is_number(rhs)) { + auto lval = expr_value(lhs); + auto rval = expr_value(rhs); + switch (e->op()) { + case tok::equality: + as_number(loc, lval==rval); + return; + case tok::ne: + as_number(loc, lval!=rval); + return; + case tok::lt: + as_number(loc, lval<rval); + return; + case tok::gt: + as_number(loc, lval>rval); + return; + case tok::lte: + as_number(loc, lval<=rval); + return; + case tok::gte: + as_number(loc, lval>=rval); + return; + default: ; + // unrecognized, fall through to non-numeric case below + } + } + if (!is_number(lhs) || !is_number(rhs)) { + result_ = make_expression<ConditionalExpression>(loc, e->op(), std::move(lhs), std::move(rhs)); + } + } +}; + +expression_ptr constant_simplify(Expression* e) { + ConstantSimplifyVisitor csimp_visitor; + e->accept(&csimp_visitor); + return csimp_visitor.result(); +} + + +expression_ptr symbolic_pdiff(Expression* e, const std::string& id) { + if (!involves_identifier(e, id)) { + return make_expression<NumberExpression>(e->location(), 0); + } + + SymPDiffVisitor pdiff_visitor(id); + e->accept(&pdiff_visitor); + + return constant_simplify(pdiff_visitor.result()); +} + +// Substitute all occurances of an identifier within a unary, binary, call +// or (trivially) number expression with a copy of the provided substitute +// expression. + +class SubstituteVisitor: public Visitor { +public: + explicit SubstituteVisitor(const substitute_map& sub): + sub_(sub) {} + + expression_ptr result() { + auto r = std::move(result_); + reset(); + return r; + } + + void reset() { + result_ = nullptr; + } + + void visit(Expression* e) override { + throw compiler_exception("substitution attempt on improper expression", e->location()); + } + + void visit(NumberExpression* e) override { + result_ = e->clone(); + } + + void visit(IdentifierExpression* e) override { + result_ = is_in(e->spelling(), sub_)? sub_.at(e->spelling())->clone(): e->clone(); + } + + void visit(UnaryExpression* e) override { + e->expression()->accept(this); + auto arg = result(); + + result_ = e->clone(); + result_->is_unary()->replace_expression(std::move(arg)); + } + + void visit(BinaryExpression* e) override { + e->lhs()->accept(this); + auto lhs = result(); + e->rhs()->accept(this); + auto rhs = result(); + + result_ = e->clone(); + result_->is_binary()->replace_lhs(std::move(lhs)); + result_->is_binary()->replace_rhs(std::move(rhs)); + } + + void visit(CallExpression* e) override { + auto newexpr = e->clone(); + for (auto& arg: newexpr->is_call()->args()) { + arg->accept(this); + arg = result(); + } + result_ = std::move(newexpr); + } + + void visit(PDiffExpression* e) override { + // Doing the correct thing when the derivative variable is the + // substitution variable would require another 'opaque' expression, + // e.g. `SubstitutionExpression`, but we're not about to do that yet, + // so throw an exception instead, i.e. Don't Do That. + if (is_in(e->var()->is_identifier()->spelling(), sub_)) { + throw compiler_exception("attempt to substitute value for derivative variable", e->location()); + } + + e->arg()->accept(this); + result_ = make_expression<PDiffExpression>(e->location(), e->var()->clone(), result()); + } + +private: + expression_ptr result_; + const substitute_map& sub_; +}; + +expression_ptr substitute(Expression* e, const std::string& id, Expression* sub) { + substitute_map subs; + subs[id] = sub->clone(); + SubstituteVisitor sub_visitor(subs); + e->accept(&sub_visitor); + return sub_visitor.result(); +} + +expression_ptr substitute(Expression* e, const substitute_map& sub) { + SubstituteVisitor sub_visitor(sub); + e->accept(&sub_visitor); + return sub_visitor.result(); +} + +linear_test_result linear_test(Expression* e, const std::vector<std::string>& vars) { + linear_test_result result; + auto loc = e->location(); + auto zero = [loc]() { return make_expression<IntegerExpression>(loc, 0); }; + + result.constant = e->clone(); + for (const auto& id: vars) { + auto coef = symbolic_pdiff(e, id); + if (!is_zero(coef)) result.coef[id] = std::move(coef); + + result.constant = substitute(result.constant, id, zero()); + } + + ConstantSimplifyVisitor csimp_visitor; + result.constant->accept(&csimp_visitor); + result.constant = csimp_visitor.result(); + + // linearity test: take second order derivatives, test against zero. + result.is_linear = true; + for (unsigned i = 0; i<vars.size(); ++i) { + auto v1 = vars[i]; + if (!is_in(v1, result.coef)) continue; + + for (unsigned j = i; j<vars.size(); ++j) { + auto v2 = vars[j]; + + if (!is_zero(symbolic_pdiff(result.coef[v1].get(), v2).get())) { + result.is_linear = false; + goto done; + } + } + } +done: + + if (result.is_linear) { + result.is_homogeneous = is_zero(result.constant); + } + + return result; +} + diff --git a/modcc/symdiff.hpp b/modcc/symdiff.hpp new file mode 100644 index 00000000..31c71501 --- /dev/null +++ b/modcc/symdiff.hpp @@ -0,0 +1,119 @@ +#pragma once + +// Perform naive symbolic differenation on a (rhs) expression; +// treat all identifiers as independent, and function calls +// with the variable in argument as opaque. +// +// This is just for linearity and possibly polynomiality testing, so +// don't try too hard. + +#include <iostream> +#include <map> +#include <string> +#include <utility> + +#include "expression.hpp" + + +// True if `id` matches the spelling of any identifier in the expression. +bool involves_identifier(Expression* e, const std::string& id); + +using identifier_set = std::vector<std::string>; +bool involves_identifier(Expression* e, const identifier_set& ids); + +// Return new expression formed by folding constants and removing trivial terms. +expression_ptr constant_simplify(Expression* e); + +// Extract value of expression that is a NumberExpression, or else return NAN. +long double expr_value(Expression* e); + +// Test if expression is a NumberExpression with value zero. +inline bool is_zero(Expression* e) { + return expr_value(e)==0; +} + +// Return new expression of symbolic partial differentiation of argument wrt `id`. +expression_ptr symbolic_pdiff(Expression* e, const std::string& id); + +// Substitute all occurances of identifier `id` within expression by a clone of `sub`. +// (Only applicable to unary, binary, call and number expressions.) +expression_ptr substitute(Expression* e, const std::string& id, Expression* sub); + +using substitute_map = std::map<std::string, expression_ptr>; +expression_ptr substitute(Expression* e, const substitute_map& sub); + +// Convenience interfaces for the above functions work with `expression_ptr` as +// well as with `Expression*` values. + +inline bool involves_identifier(const expression_ptr& e, const std::string& id) { + return involves_identifier(e.get(), id); +} + +inline bool involves_identifier(const expression_ptr& e, const identifier_set& ids) { + return involves_identifier(e.get(), ids); +} + +inline expression_ptr constant_simplify(const expression_ptr& e) { + return constant_simplify(e.get()); +} + +inline long double expr_value(const expression_ptr& e) { + return expr_value(e.get()); +} + +inline long double is_zero(const expression_ptr& e) { + return is_zero(e.get()); +} + +inline expression_ptr symbolic_pdiff(const expression_ptr& e, const std::string& id) { + return symbolic_pdiff(e.get(), id); +} + +inline expression_ptr substitute(const expression_ptr& e, const std::string& id, const expression_ptr& sub) { + return substitute(e.get(), id, sub.get()); +} + +inline expression_ptr substitute(const expression_ptr& e, const substitute_map& sub) { + return substitute(e.get(), sub); +} + + +// Linearity testing + +struct linear_test_result { + bool is_linear = false; + bool is_homogeneous = false; + expression_ptr constant; + std::map<std::string, expression_ptr> coef; + + bool monolinear() const { + unsigned nlinear = 0; + for (auto& entry: coef) { + if (!is_zero(entry.second) && ++nlinear>1) return false; + } + return true; + } + + bool monolinear(const std::string& var) const { + for (auto& entry: coef) { + if (!is_zero(entry.second) && var!=entry.first) return false; + } + return true; + } + + friend std::ostream& operator<<(std::ostream& o, const linear_test_result& r) { + o << "{linear: " << r.is_linear << "; homogeneous: " << r.is_homogeneous << "\n"; + o << " constant term: " << r.constant->to_string(); + for (const auto& p: r.coef) { + o << "\n coef " << p.first << ": " << p.second->to_string(); + } + o << "}"; + return o; + } +}; + +linear_test_result linear_test(Expression* e, const std::vector<std::string>& vars); + +inline linear_test_result linear_test(const expression_ptr& e, const std::vector<std::string>& vars) { + return linear_test(e.get(), vars); +} diff --git a/modcc/symge.cpp b/modcc/symge.cpp new file mode 100644 index 00000000..4f09a022 --- /dev/null +++ b/modcc/symge.cpp @@ -0,0 +1,109 @@ +#include <algorithm> +#include <stdexcept> +#include <vector> + +#include "symge.hpp" + +namespace symge { + +// Returns q[c]*p - p[c]*q; new symbols required due to fill-in are provided by the +// `define_sym` functor, which takes a `symbol_term_diff` and returns a `symbol`. + +template <typename DefineSym> +sym_row row_reduce(unsigned c, const sym_row& p, const sym_row& q, DefineSym define_sym) { + if (p.index(c)==p.npos || q.index(c)==q.npos) throw std::runtime_error("improper row reduction"); + + sym_row u; + symbol x = q[c]; + symbol y = p[c]; + + auto piter = p.begin(); + auto qiter = q.begin(); + unsigned pj = piter->col; + unsigned qj = qiter->col; + + while (piter!=p.end() || qiter!=q.end()) { + unsigned j = std::min(pj, qj); + symbol_term t1, t2; + + if (j==pj) { + t1 = x*piter->value; + ++piter; + pj = piter==p.end()? p.npos: piter->col; + } + if (j==qj) { + t2 = y*qiter->value; + ++qiter; + qj = qiter==q.end()? q.npos: qiter->col; + } + if (j!=c) { + u.push_back({j, define_sym(t1-t2)}); + } + } + return u; +} + +// Estimate cost of a choice of pivot for G–J reduction below. Uses a simple greedy +// estimate based on immediate fill cost. +double estimate_cost(const sym_matrix& A, unsigned p) { + unsigned nfill = 0; + + auto count_fill = [&nfill](symbol_term_diff t) { + bool l = t.left; + bool r = t.right; + nfill += r&!l; + return symbol{}; + }; + + for (unsigned i = 0; i<A.nrow(); ++i) { + if (i==p || A[i].index(p)==msparse::row_npos) continue; + row_reduce(p, A[i], A[p], count_fill); + } + + return nfill; +} + +// Perform Gauss-Jordan elimination on given symbolic matrix. New symbols +// required due to fill-in are added to the supplied symbol table. +// +// The matrix A is regarded as being diagonally dominant, and so pivots +// are selected from the diagonal. The choice of pivot at each stage of +// the reduction is goverened by a cost estimation (see above). +// +// The reduction is division-free: the result will have non-zero terms +// that are symbols that are either primitive, or defined (in the symbol +// table) as products or differences of products of other symbols. +void gj_reduce(sym_matrix& A, symbol_table& table) { + if (A.nrow()>A.ncol()) throw std::runtime_error("improper matrix for reduction"); + + auto define_sym = [&table](symbol_term_diff t) { return table.define(t); }; + + std::vector<unsigned> pivots; + for (unsigned r = 0; r<A.nrow(); ++r) { + pivots.push_back(r); + } + + std::vector<double> cost(pivots.size()); + + while (!pivots.empty()) { + for (unsigned i = 0; i<pivots.size(); ++i) { + cost[pivots[i]] = estimate_cost(A, pivots[i]); + } + + std::sort(pivots.begin(), pivots.end(), + [&](unsigned r1, unsigned r2) { return cost[r1]>cost[r2]; }); + + unsigned pivrow = pivots.back(); + pivots.erase(std::prev(pivots.end())); + + unsigned pivcol = pivrow; + + for (unsigned i = 0; i<A.nrow(); ++i) { + if (i==pivrow || A[i].index(pivcol)==msparse::row_npos) continue; + + A[i] = row_reduce(pivcol, A[i], A[pivrow], define_sym); + } + } +} + +} // namespace symge diff --git a/modcc/symge.hpp b/modcc/symge.hpp new file mode 100644 index 00000000..edd37b90 --- /dev/null +++ b/modcc/symge.hpp @@ -0,0 +1,158 @@ +#pragma once + +#include <stdexcept> + +#include "msparse.hpp" + +// Symbolic sparse matrix manipulation for symbolic Gauss-Jordan elimination +// (used in `sparse` solver). + +namespace symge { + +struct symbol_error: public std::runtime_error { + symbol_error(const std::string& what): std::runtime_error(what) {} +}; + +// Abstract symbols: + +class symbol_table; + +class symbol { +private: + unsigned index_; + const symbol_table* table_; + + // Valid symbols are constructed via a symbol table. + friend class symbol_table; + symbol(unsigned index, const symbol_table* table): + index_(index), table_(table) {} + +public: + symbol(): index_(0), table_(nullptr) {} + + // true => valid symbol. + operator bool() const { return table_; } + + bool operator==(symbol other) const { return index_==other.index_ && table_==other.table_; } + bool operator!=(symbol other) const { return !(*this==other); } + + const symbol_table* table() const { return table_; } +}; + +// A `symbol_term` is either zero or a product of symbols. + +struct symbol_term { + symbol left, right; + + symbol_term() = default; + bool is_zero() const { return !left || !right; } + operator bool() const { return !is_zero(); } +}; + +struct symbol_term_diff { + symbol_term left, right; + + symbol_term_diff() = default; + symbol_term_diff(const symbol_term& left): left(left), right{} {} + symbol_term_diff(const symbol_term& left, const symbol_term& right): + left(left), right(right) {} +}; + +inline symbol_term operator*(symbol a, symbol b) { + return symbol_term{a, b}; +} + +inline symbol_term_diff operator-(symbol_term l, symbol_term r) { + return symbol_term_diff{l, r}; +} + +inline symbol_term_diff operator-(symbol_term r) { + return symbol_term_diff{symbol_term{}, r}; +} + +// Symbols are not re-assignable; they are created as primitive, or +// have a definition in terms of a `symbol_term_diff`. + +class symbol_table { +private: + struct table_entry { + std::string name; + symbol_term_diff def; + bool defined; + }; + + std::vector<table_entry> entries_; + +public: + // make new primitive symbol + symbol define(const std::string& name="") { + symbol s(size(), this); + entries_.push_back({name, symbol_term_diff{}, false}); + return s; + } + + // make new symbol with definition + symbol define(const std::string& name, const symbol_term_diff& def) { + symbol s(size(), this); + entries_.push_back({name, def, true}); + return s; + } + + symbol define(const symbol_term_diff& def) { + return define("", def); + } + + symbol_term_diff get(symbol s) const { + if (!defined(s)) throw symbol_error("symbol is primitive"); + return entries_[s.index_].def; + } + + bool defined(symbol s) const { + if (!valid(s)) throw symbol_error("symbol not present in this table"); + return entries_[s.index_].defined; + } + + bool primitive(symbol s) const { return !defined(s); } + + const std::string& name(symbol s) const { + if (!valid(s)) throw symbol_error("symbol not present in this table"); + return entries_[s.index_].name; + } + + void name(symbol s, const std::string& n) { + if (!valid(s)) throw symbol_error("symbol not present in this table"); + entries_[s.index_].name = n; + } + + std::size_t size() const { return entries_.size(); } + + symbol operator[](unsigned i) const { return symbol{i, this}; } + + bool valid(symbol s) const { return s.table_==this && s.index_<size(); } + + // Existing symbols associated with this table are invalidated by clear(). + void clear() { entries_.clear(); } +}; + +inline std::string name(symbol s) { + return s? s.table()->name(s): ""; +} + +inline symbol_term_diff definition(symbol s) { + if (!s) throw symbol_error("invalid symbol"); + return s.table()->get(s); +} + +inline bool primitive(symbol s) { + return s && s.table()->primitive(s); +} + +using sym_row = msparse::row<symbol>; +using sym_matrix = msparse::matrix<symbol>; + +// Perform Gauss-Jordan reduction on a (possibly augmented) symbolic matrix, with +// pivots taken from the diagonal elements. New symbol definitions due to fill-in +// will be added via the provided symbol table. +void gj_reduce(sym_matrix& A, symbol_table& table); + +} // namespace symge diff --git a/modcc/token.cpp b/modcc/token.cpp index 7383ca9f..25d08da9 100644 --- a/modcc/token.cpp +++ b/modcc/token.cpp @@ -52,6 +52,7 @@ static Keyword keywords[] = { {"if", tok::if_stmt}, {"else", tok::else_stmt}, {"cnexp", tok::cnexp}, + {"sparse", tok::sparse}, {"exp", tok::exp}, {"sin", tok::sin}, {"cos", tok::cos}, diff --git a/modcc/token.hpp b/modcc/token.hpp index c31cdbc4..95f7c461 100644 --- a/modcc/token.hpp +++ b/modcc/token.hpp @@ -70,6 +70,7 @@ enum class tok { // solver methods cnexp, + sparse, conductance, diff --git a/modcc/visitor.hpp b/modcc/visitor.hpp index c474dc89..921189bc 100644 --- a/modcc/visitor.hpp +++ b/modcc/visitor.hpp @@ -34,9 +34,11 @@ public: virtual void visit(IfExpression *e) { visit((Expression*) e); } virtual void visit(SolveExpression *e) { visit((Expression*) e); } virtual void visit(DerivativeExpression *e) { visit((Expression*) e); } + virtual void visit(PDiffExpression *e) { visit((Expression*) e); } virtual void visit(ProcedureExpression *e) { visit((Expression*) e); } virtual void visit(NetReceiveExpression *e) { visit((ProcedureExpression*) e); } virtual void visit(APIMethod *e) { visit((Expression*) e); } + virtual void visit(ConductanceExpression *e) { visit((Expression*) e); } virtual void visit(BlockExpression *e) { visit((Expression*) e); } virtual void visit(InitialBlock *e) { visit((BlockExpression*) e); } @@ -48,6 +50,7 @@ public: virtual void visit(SinUnaryExpression *e) { visit((UnaryExpression*) e); } virtual void visit(BinaryExpression *e) = 0; + virtual void visit(ConditionalExpression *e) {visit((BinaryExpression*) e); } virtual void visit(AssignmentExpression *e) { visit((BinaryExpression*) e); } virtual void visit(ConserveExpression *e) { visit((BinaryExpression*) e); } virtual void visit(AddBinaryExpression *e) { visit((BinaryExpression*) e); } @@ -59,3 +62,136 @@ public: virtual ~Visitor() {}; }; + +// Visitor specialization intended for use as a base class for visitors that +// operate as function or procedure body rewriters after semantic analysis. +// +// Block rewriter visitors construct a new block body from a supplied +// `BlockExpression`, `ProcedureExpression` or `FunctionExpression`. By default, +// expressions are simply copied to the list of statements corresponding to the +// rewritten block; nested blocks as provided by `IfExpression` objects are +// handled recursively. +// +// The `finalize()` method is called after visiting all the statements in the +// top-level block, and is intended to be overrided by derived classes as required. +// +// The `as_block()` method is intended to be called by users of the derived block +// rewriter objects. It constructs a new `BlockExpression` from the accumulated +// replacement statements and applies a semantic pass if the `BlockRewriterBase` +// was given a corresponding scope +// +// The visitor maintains significant internal state: the `reset` method should +// be called between visits of top-level blocks. +// +// Errors are recorded through the `error_stack` mixin rather than by +// throwing an exception. + +class BlockRewriterBase : public Visitor, public error_stack { +public: + BlockRewriterBase() {} + BlockRewriterBase(scope_ptr block_scope): + block_scope_(block_scope) {} + + virtual void visit(Expression *e) override { + statements_.push_back(e->clone()); + } + + virtual void visit(UnaryExpression *e) override { visit((Expression*)e); } + virtual void visit(BinaryExpression *e) override { visit((Expression*)e); } + + virtual void visit(BlockExpression *e) override { + bool top = !started_; + if (top) { + loc_ = e->location(); + started_ = true; + + if (!block_scope_) { + block_scope_ = e->scope(); + } + } + + for (auto& s: e->statements()) { + s->accept(this); + } + + if (top) { + finalize(); + } + } + + virtual void visit(IfExpression* e) override { + expr_list_type outer; + std::swap(outer, statements_); + + e->true_branch()->accept(this); + auto true_branch = make_expression<BlockExpression>( + e->true_branch()->location(), + std::move(statements_), + true); + + statements_.clear(); + e->false_branch()->accept(this); + auto false_branch = make_expression<BlockExpression>( + e->false_branch()->location(), + std::move(statements_), + true); + + statements_ = std::move(outer); + statements_.push_back(make_expression<IfExpression>( + e->location(), + e->condition()->clone(), + std::move(true_branch), + std::move(false_branch))); + } + + virtual void visit(ProcedureExpression* e) override { + e->body()->accept(this); + } + + virtual void visit(FunctionExpression* e) override { + e->body()->accept(this); + } + + virtual expression_ptr as_block(bool is_nested=false) { + if (has_error()) return nullptr; + + expr_list_type body_stmts; + for (const auto& s: statements_) body_stmts.push_back(s->clone()); + + auto body = make_expression<BlockExpression>( + loc_, + std::move(body_stmts), + is_nested); + + if (block_scope_) { + body->semantic(block_scope_); + } + return body; + } + + // Reset state. + virtual void reset() { + statements_.clear(); + started_ = false; + loc_ = Location{}; + clear_errors(); + clear_warnings(); + } + +protected: + // False until processing of top block starts. + bool started_ = false; + + // Location of original block. + Location loc_; + + // Scope for semantic pass. + scope_ptr block_scope_; + + // Statements in replacement procedure body. + expr_list_type statements_; + + // Finalise statements list at end of top block visit. + virtual void finalize() {} +}; + diff --git a/src/backends/fvm_multicore.cpp b/src/backends/fvm_multicore.cpp index dd8f797c..8d20f2ff 100644 --- a/src/backends/fvm_multicore.cpp +++ b/src/backends/fvm_multicore.cpp @@ -4,6 +4,8 @@ #include <mechanisms/multicore/pas.hpp> #include <mechanisms/multicore/expsyn.hpp> #include <mechanisms/multicore/exp2syn.hpp> +#include <mechanisms/multicore/test_kin1.hpp> +#include <mechanisms/multicore/test_kinlva.hpp> namespace nest { namespace mc { @@ -11,10 +13,12 @@ namespace multicore { std::map<std::string, backend::maker_type> backend::mech_map_ = { - { std::string("pas"), maker<mechanisms::pas::mechanism_pas> }, - { std::string("hh"), maker<mechanisms::hh::mechanism_hh> }, - { std::string("expsyn"), maker<mechanisms::expsyn::mechanism_expsyn> }, - { std::string("exp2syn"), maker<mechanisms::exp2syn::mechanism_exp2syn> } + { std::string("pas"), maker<mechanisms::pas::mechanism_pas> }, + { std::string("hh"), maker<mechanisms::hh::mechanism_hh> }, + { std::string("expsyn"), maker<mechanisms::expsyn::mechanism_expsyn> }, + { std::string("exp2syn"), maker<mechanisms::exp2syn::mechanism_exp2syn> }, + { std::string("test_kin1"), maker<mechanisms::test_kin1::mechanism_test_kin1> }, + { std::string("test_kinlva"), maker<mechanisms::test_kinlva::mechanism_test_kinlva> } }; } // namespace multicore diff --git a/tests/modcc/CMakeLists.txt b/tests/modcc/CMakeLists.txt index bb27815a..34f87ed7 100644 --- a/tests/modcc/CMakeLists.txt +++ b/tests/modcc/CMakeLists.txt @@ -3,9 +3,13 @@ set(MODCC_TEST_SOURCES test_lexer.cpp test_kinetic_rewriter.cpp test_module.cpp + test_msparse.cpp test_optimization.cpp test_parser.cpp #test_printers.cpp + test_removelocals.cpp + test_symdiff.cpp + test_symge.cpp test_visitors.cpp # unit test driver @@ -13,6 +17,7 @@ set(MODCC_TEST_SOURCES # utility expr_expand.cpp + test.cpp ) add_definitions("-DDATADIR=\"${PROJECT_SOURCE_DIR}/data\"") diff --git a/tests/modcc/driver.cpp b/tests/modcc/driver.cpp index 67505a0d..dcd36906 100644 --- a/tests/modcc/driver.cpp +++ b/tests/modcc/driver.cpp @@ -6,8 +6,6 @@ #include "test.hpp" -bool g_verbose_flag = false; - int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); if (argc>1 && (!std::strcmp(argv[1],"-v") || !std::strcmp(argv[1],"--verbose"))) { diff --git a/tests/modcc/test.cpp b/tests/modcc/test.cpp new file mode 100644 index 00000000..72e0db77 --- /dev/null +++ b/tests/modcc/test.cpp @@ -0,0 +1,27 @@ +#include <regex> +#include <string> + +#include "test.hpp" + +bool g_verbose_flag = false; + +std::string plain_text(Expression* expr) { + static std::regex csi_code(R"_(\x1B\[.*?[\x40-\x7E])_"); + return !expr? "null": regex_replace(expr->to_string(), csi_code, ""); +} + +::testing::AssertionResult assert_expr_eq(const char *arg1, const char *arg2, Expression* expected, Expression* value) { + auto value_rep = plain_text(value); + auto expected_rep = plain_text(expected); + + if ((!value && expected) || (value && !expected) || (value_rep!=expected_rep)) { + return ::testing::AssertionFailure() + << "Value of: " << arg2 << "\n" + << " Actual: " << value_rep << "\n" + << "Expected: " << arg1 << "\n" + << "Which is: " << expected_rep << "\n"; + } + else { + return ::testing::AssertionSuccess(); + } +} diff --git a/tests/modcc/test.hpp b/tests/modcc/test.hpp index 45dced64..fe6e7666 100644 --- a/tests/modcc/test.hpp +++ b/tests/modcc/test.hpp @@ -1,7 +1,10 @@ #pragma once +#include <string> + #include "../gtest.h" +#include "expression.hpp" #include "parser.hpp" #include "modccutil.hpp" @@ -24,3 +27,63 @@ inline expression_ptr parse_function(std::string const& s) { inline expression_ptr parse_procedure(std::string const& s) { return Parser(s).parse_procedure(); } + +// Helpers for comparing expressions, and verbose expression printing. + +// Strip ANSI control sequences from `to_string` output. +std::string plain_text(Expression* expr); + +// Compare two expressions via their representation. +// Use with EXPECT_PRED_FORMAT2. +::testing::AssertionResult assert_expr_eq(const char *arg1, const char *arg2, Expression* expected, Expression* value); + +#define EXPECT_EXPR_EQ(a,b) EXPECT_PRED_FORMAT2(assert_expr_eq, a, b) + +// Print arguments, but only if verbose flag set. +// Use `to_string()` to print (smart) pointers to Expression or Scope objects. + +namespace impl { + template <typename X> + struct has_to_string { + template <typename T> + static decltype(std::declval<T>()->to_string(), std::true_type{}) test(int); + template <typename T> + static std::false_type test(...); + + using type = decltype(test<X>(0)); + }; + + template <typename X> + void print(const X& x, std::true_type) { + if (x) { + std::cout << x->to_string(); + } + else { + std::cout << "null"; + } + } + + template <typename X> + void print(const X& x, std::false_type) { + std::cout << x; + } + + template <typename X> + void print(const X& x) { + print(x, typename has_to_string<X>::type{}); + } + +} + +inline void verbose_print() { + if (!g_verbose_flag) return; + std::cout << "\n"; +} + +template <typename X, typename... Args> +void verbose_print(const X& arg, const Args&... tail) { + if (!g_verbose_flag) return; + impl::print(arg); + verbose_print(tail...); +} + diff --git a/tests/modcc/test_kinetic_rewriter.cpp b/tests/modcc/test_kinetic_rewriter.cpp index 8ec3173a..62c2feb2 100644 --- a/tests/modcc/test_kinetic_rewriter.cpp +++ b/tests/modcc/test_kinetic_rewriter.cpp @@ -2,21 +2,31 @@ #include <string> #include "expression.hpp" -#include "kinrewriter.hpp" +#include "kineticrewriter.hpp" #include "parser.hpp" #include "alg_collect.hpp" #include "expr_expand.hpp" #include "test.hpp" -using namespace nest::mc; +expr_list_type& statements(Expression *e) { + if (e) { + if (auto block = e->is_block()) { + return block->statements(); + } + + if (auto sym = e->is_symbol()) { + if (auto proc = sym->is_procedure()) { + return proc->body()->statements(); + } -stmt_list_type& proc_statements(Expression *e) { - if (!e || !e->is_symbol() || ! e->is_symbol()->is_procedure()) { - throw std::runtime_error("not a procedure"); + if (auto proc = sym->is_procedure()) { + return proc->body()->statements(); + } + } } - return e->is_symbol()->is_procedure()->body()->statements(); + throw std::runtime_error("not a block, function or procedure"); } @@ -49,37 +59,40 @@ static const char* derivative_abc = "} \n"; TEST(KineticRewriter, equiv) { - auto visitor = util::make_unique<KineticRewriter>(); auto kin = Parser(kinetic_abc).parse_procedure(); auto deriv = Parser(derivative_abc).parse_procedure(); + auto kin_ptr = kin.get(); + auto deriv_ptr = deriv.get(); + ASSERT_NE(nullptr, kin); ASSERT_NE(nullptr, deriv); ASSERT_TRUE(kin->is_symbol() && kin->is_symbol()->is_procedure()); ASSERT_TRUE(deriv->is_symbol() && deriv->is_symbol()->is_procedure()); - auto kin_weak = kin->is_symbol()->is_procedure(); scope_type::symbol_map globals; globals["kin"] = std::move(kin); + globals["deriv"] = std::move(deriv); globals["a"] = state_var("a"); globals["b"] = state_var("b"); globals["c"] = state_var("c"); globals["u"] = assigned_var("u"); globals["v"] = assigned_var("v"); - kin_weak->semantic(globals); - kin_weak->accept(visitor.get()); + deriv_ptr->semantic(globals); - auto kin_deriv = visitor->as_procedure(); + auto kin_body = kin_ptr->is_procedure()->body(); + scope_ptr scope = std::make_shared<scope_type>(globals); + kin_body->semantic(scope); - if (g_verbose_flag) { - std::cout << "derivative procedure:\n" << deriv->to_string() << "\n"; - std::cout << "kin procedure:\n" << kin_weak->to_string() << "\n"; - std::cout << "rewritten kin procedure:\n" << kin_deriv->to_string() << "\n"; - } + auto kin_deriv = kinetic_rewrite(kin_body); + + verbose_print("derivative procedure:\n", deriv_ptr); + verbose_print("kin procedure:\n", kin_ptr); + verbose_print("rewritten kin body:\n", kin_deriv); - auto deriv_map = expand_assignments(proc_statements(deriv.get())); - auto kin_map = expand_assignments(proc_statements(kin_deriv.get())); + auto deriv_map = expand_assignments(statements(deriv_ptr)); + auto kin_map = expand_assignments(statements(kin_deriv.get())); if (g_verbose_flag) { std::cout << "derivative assignments (canonical):\n"; diff --git a/tests/modcc/test_lexer.cpp b/tests/modcc/test_lexer.cpp index e148b12b..91ee7df7 100644 --- a/tests/modcc/test_lexer.cpp +++ b/tests/modcc/test_lexer.cpp @@ -7,48 +7,38 @@ #include "test.hpp" #include "lexer.hpp" -void verbose_print(const char* string) { - if (!g_verbose_flag) return; - std::cout << "________________\n" << string << "\n________________\n"; -} - -void verbose_print(const Token& token) { - if (!g_verbose_flag) return; - std::cout << "tok: " << token << "\n"; -} - class VerboseLexer: public Lexer { public: template <typename... Args> VerboseLexer(Args&&... args): Lexer(std::forward<Args>(args)...) { - if (g_verbose_flag) { - std::cout << "________________\n" << std::string(begin_, end_) << "\n________________\n"; - } + verbose_print("________________"); + verbose_print(std::string(begin_, end_)); + verbose_print("________________"); } Token parse() { auto tok = Lexer::parse(); - if (g_verbose_flag) { - std::cout << "token: " << tok << "\n"; - } + verbose_print("token: ",tok); return tok; } char character() { char c = Lexer::character(); - if (g_verbose_flag) { - std::cout << "character: "; - if (!std::isprint(c)) { - char buf[5] = "XXXX"; - snprintf(buf, sizeof buf, "0x%02x", (unsigned)c); - std::cout << buf << '\n'; - } - else { - std::cout << c << '\n'; - } - } + verbose_print("character: ", pretty(c)); return c; } + + static const char* pretty(char c) { + static char buf[5] = "XXXX"; + if (!std::isprint(c)) { + snprintf(buf, sizeof buf, "0x%02x", (unsigned)c); + } + else { + buf[0] = c; + buf[1] = 0; + } + return buf; + } }; /************************************************************** diff --git a/tests/modcc/test_msparse.cpp b/tests/modcc/test_msparse.cpp new file mode 100644 index 00000000..6229e725 --- /dev/null +++ b/tests/modcc/test_msparse.cpp @@ -0,0 +1,188 @@ +#include <utility> + +#include "msparse.hpp" +#include "test.hpp" + +using drow = msparse::row<double>; +using dmatrix = msparse::matrix<double>; +using msparse::row_npos; + +namespace msparse { +bool operator==(const drow::entry& a, const drow::entry& b) { + return a.col==b.col && a.value==b.value; +} +} + +TEST(msparse, row_ctor) { + + drow r1; + EXPECT_TRUE(r1.empty()); + + drow r2({{0,3.0},{2,-1.5}}); + EXPECT_FALSE(r2.empty()); + EXPECT_EQ(2u, r2.size()); + + drow r3(r2); + EXPECT_FALSE(r3.empty()); + EXPECT_EQ(2u, r3.size()); + + drow::entry entries[] = { {1,2.}, {4,-2.}, {5,0} }; + drow r4(std::begin(entries), std::end(entries)); + EXPECT_FALSE(r4.empty()); + EXPECT_EQ(3u, r4.size()); + + ASSERT_THROW(drow({{2,-1.5}, {0,3.0}}), msparse::msparse_error); + + drow r5 = r4; + EXPECT_FALSE(r5.empty()); + EXPECT_EQ(3u, r5.size()); +} + +TEST(msparse, row_iter) { + using drow = msparse::row<double>; + + drow r1; + EXPECT_EQ(r1.begin(), r1.end()); + const drow& r1c = r1; + EXPECT_EQ(r1c.begin(), r1c.end()); + + drow r2({{0,3.0},{2,-1.5}}); + EXPECT_EQ(*r2.begin(), drow::entry({0,3.0})); + EXPECT_EQ(*std::prev(r2.end()), drow::entry({2,-1.5})); + EXPECT_EQ(2u, std::distance(r2.begin(), r2.end())); +} + +TEST(msparse, row_query) { + EXPECT_EQ(row_npos, drow{}.mincol()); + EXPECT_EQ(row_npos, drow{}.maxcol()); + EXPECT_EQ(row_npos, drow{}.index(2u)); + + drow r({{1,2.0}, {2,4.0}, {4,7.0}, {6,9.0}}); + + EXPECT_EQ(4u, r.size()); + EXPECT_EQ(1u, r.mincol()); + EXPECT_EQ(6u, r.maxcol()); + + EXPECT_EQ(2u, r.mincol_after(1)); + EXPECT_EQ(4u, r.mincol_after(2)); + EXPECT_EQ(4u, r.mincol_after(3)); + EXPECT_EQ(row_npos, r.mincol_after(6)); + + EXPECT_EQ(0u, r.index(1)); + EXPECT_EQ(1u, r.index(2)); + EXPECT_EQ(3u, r.index(6)); + EXPECT_EQ(row_npos, r.index(0)); + EXPECT_EQ(row_npos, r.index(5)); + EXPECT_EQ(row_npos, r.index(7)); + + EXPECT_EQ(0.0, r[0]); + EXPECT_EQ(2.0, r[1]); + EXPECT_EQ(4.0, r[2]); + EXPECT_EQ(0.0, r[3]); + EXPECT_EQ(9.0, r[6]); + EXPECT_EQ(0.0, r[7]); + + EXPECT_EQ(drow::entry({1,2.0}), r.get(0)); + EXPECT_EQ(drow::entry({2,4.0}), r.get(1)); + EXPECT_EQ(drow::entry({4,7.0}), r.get(2)); +} + +TEST(msparse, row_mutate) { + drow r; + + r.push_back({1, 2.0}); + EXPECT_EQ(1u, r.size()); + + r.push_back({3, 3.0}); + EXPECT_EQ(2u, r.size()); + + ASSERT_THROW(r.push_back({3, 2.0}), msparse::msparse_error); + ASSERT_THROW(r.push_back({2, 2.0}), msparse::msparse_error); + + r.truncate(4); + EXPECT_EQ(2u, r.size()); + + r.truncate(3); + EXPECT_EQ(1u, r.size()); + + r.truncate(0); + EXPECT_EQ(0u, r.size()); + + r[7] = 4.0; + EXPECT_EQ(1u, r.size()); + EXPECT_EQ(7u, r.mincol()); +} + +TEST(msparse, matrix_ctor) { + dmatrix M; + EXPECT_EQ(0u, M.size()); + EXPECT_EQ(0u, M.nrow()); + EXPECT_EQ(0u, M.ncol()); + EXPECT_TRUE(M.empty()); + + dmatrix M2(5,3); + EXPECT_EQ(5u, M2.size()); + EXPECT_EQ(5u, M2.nrow()); + EXPECT_EQ(3u, M2.ncol()); + EXPECT_FALSE(M2.empty()); + + dmatrix M3(M2); + EXPECT_EQ(5u, M3.nrow()); + EXPECT_EQ(3u, M3.ncol()); +} + +TEST(msparse, matrix_index) { + dmatrix M2(5,3); + M2[4].push_back({2, -1.0}); + + const dmatrix& M2c = M2; + + EXPECT_EQ(-1.0, M2[4][2]); + EXPECT_EQ(-1.0, M2c[4][2]); +} + +TEST(msparse, matrix_augment) { + dmatrix M(5,3); + + M[1][2] = 9.0; + EXPECT_FALSE(M.augmented()); + EXPECT_EQ(row_npos, M.augcol()); + + EXPECT_EQ(0u, M[0].size()); + EXPECT_EQ(1u, M[1].size()); + + double aug1[] = {5, 4, 3, 2, 1}; + M.augment(aug1); + + // short is okay... + double aug2[] = {1, 2, 1}; + M.augment(aug2); + + EXPECT_TRUE(M.augmented()); + EXPECT_EQ(3u, M.augcol()); + + EXPECT_EQ(2u, M[0].size()); + EXPECT_EQ(3u, M[0].mincol()); + EXPECT_EQ(4u, M[0].maxcol()); + + EXPECT_EQ(3u, M[1].size()); + EXPECT_EQ(2u, M[1].mincol()); + EXPECT_EQ(4u, M[1].maxcol()); + + EXPECT_EQ(9.0, M[1][2]); + EXPECT_EQ(4.0, M[1][3]); + EXPECT_EQ(2.0, M[1][4]); + + M.diminish(); + + EXPECT_FALSE(M.augmented()); + EXPECT_EQ(row_npos, M.augcol()); + + EXPECT_EQ(0u, M[0].size()); + EXPECT_EQ(1u, M[1].size()); + EXPECT_EQ(9.0, M[1][M[1].maxcol()]); + + // augmenting with too long a column is not okay... + double aug3[] = {1, 2, 3, 4, 5, 6}; + ASSERT_THROW(M.augment(aug3), msparse::msparse_error); +} diff --git a/tests/modcc/test_optimization.cpp b/tests/modcc/test_optimization.cpp index 79102bd2..af289eb6 100644 --- a/tests/modcc/test_optimization.cpp +++ b/tests/modcc/test_optimization.cpp @@ -5,54 +5,52 @@ #include "constantfolder.hpp" #include "modccutil.hpp" -using namespace nest::mc; - TEST(Optimizer, constant_folding) { - auto v = util::make_unique<ConstantFolderVisitor>(); + ConstantFolderVisitor v; { auto e = parse_line_expression("x = 2*3"); - VERBOSE_PRINT( e->to_string() ); - e->accept(v.get()); + verbose_print(e); + e->accept(&v); EXPECT_EQ(e->is_assignment()->rhs()->is_number()->value(), 6); - VERBOSE_PRINT( e->to_string() ); - VERBOSE_PRINT( "" ); + verbose_print(e); + verbose_print(); } { auto e = parse_line_expression("x = 1 + 2 + 3"); - VERBOSE_PRINT( e->to_string() ); - e->accept(v.get()); + verbose_print(e); + e->accept(&v); EXPECT_EQ(e->is_assignment()->rhs()->is_number()->value(), 6); - VERBOSE_PRINT( e->to_string() ); - VERBOSE_PRINT( "" ); + verbose_print(e); + verbose_print(); } { auto e = parse_line_expression("x = exp(2)"); - VERBOSE_PRINT( e->to_string() ); - e->accept(v.get()); + verbose_print(e); + e->accept(&v); // The tolerance has to be loosend to 1e-15, because the optimizer performs // all intermediate calculations in 80 bit precision, which disagrees in // the 16 decimal place to the double precision value from std::exp(2.0). // This is a good thing: by using the constant folder we increase accuracy // over the unoptimized code! EXPECT_EQ(std::fabs(e->is_assignment()->rhs()->is_number()->value()-std::exp(2.0))<1e-15, true); - VERBOSE_PRINT( e->to_string() ); - VERBOSE_PRINT( "" ); + verbose_print(e); + verbose_print("" ); } { auto e = parse_line_expression("x= 2*2 + 3"); - VERBOSE_PRINT( e->to_string() ); - e->accept(v.get()); + verbose_print(e); + e->accept(&v); EXPECT_EQ(e->is_assignment()->rhs()->is_number()->value(), 7); - VERBOSE_PRINT( e->to_string() ); - VERBOSE_PRINT( "" ); + verbose_print(e); + verbose_print(); } { auto e = parse_line_expression("x= 3 + 2*2"); - VERBOSE_PRINT( e->to_string() ); - e->accept(v.get()); + verbose_print(e); + e->accept(&v); EXPECT_EQ(e->is_assignment()->rhs()->is_number()->value(), 7); - VERBOSE_PRINT( e->to_string() ); - VERBOSE_PRINT( "" ); + verbose_print(e); + verbose_print(); } { // this doesn't work: the (y+2) expression is not a constant, so folding stops. @@ -60,23 +58,23 @@ TEST(Optimizer, constant_folding) { // one approach would be try sorting communtative operations so that numbers // are adjacent to one another in the tree auto e = parse_line_expression("x= y + 2 + 3"); - VERBOSE_PRINT( e->to_string() ); - e->accept(v.get()); - VERBOSE_PRINT( e->to_string() ); - VERBOSE_PRINT( "" ); + verbose_print(e); + e->accept(&v); + verbose_print(e); + verbose_print(); } { auto e = parse_line_expression("x= 2 + 3 + y"); - VERBOSE_PRINT( e->to_string() ); - e->accept(v.get()); - VERBOSE_PRINT( e->to_string() ); - VERBOSE_PRINT("");; + verbose_print(e); + e->accept(&v); + verbose_print(e); + verbose_print();; } { auto e = parse_line_expression("foo(2+3, log(32), 2*3 + x)"); - VERBOSE_PRINT( e->to_string() ); - e->accept(v.get()); - VERBOSE_PRINT( e->to_string() ); - VERBOSE_PRINT("");; + verbose_print(e); + e->accept(&v); + verbose_print(e); + verbose_print(); } } diff --git a/tests/modcc/test_parser.cpp b/tests/modcc/test_parser.cpp index 853b2136..d2a39b21 100644 --- a/tests/modcc/test_parser.cpp +++ b/tests/modcc/test_parser.cpp @@ -6,13 +6,13 @@ #include "modccutil.hpp" #include "parser.hpp" +// overload for parser errors template <typename EPtr> void verbose_print(const EPtr& e, Parser& p, const char* text) { - if (!g_verbose_flag) return; - - if (e) std::cout << e->to_string() << "\n"; - if (p.status()==lexerStatus::error) - std::cout << "in " << red(text) << "\t" << p.error_message() << "\n"; + verbose_print(e); + if (p.status()==lexerStatus::error) { + verbose_print("in ", red(text), "\t", p.error_message()); + } } template <typename Derived, typename RetUniqPtr> diff --git a/tests/modcc/test_removelocals.cpp b/tests/modcc/test_removelocals.cpp new file mode 100644 index 00000000..140de2f1 --- /dev/null +++ b/tests/modcc/test_removelocals.cpp @@ -0,0 +1,197 @@ +#include <iostream> +#include <map> +#include <regex> +#include <string> + +#include "expression.hpp" +#include "solvers.hpp" +#include "parser.hpp" +#include "scope.hpp" + +#include "test.hpp" + +symbol_ptr make_global(std::string name) { + return make_symbol<VariableExpression>(Location(), std::move(name)); +} + +using symbol_map = Scope<Symbol>::symbol_map; + +Symbol* add_procedure(symbol_map& symbols, const char* src) { + auto proc = Parser(src).parse_procedure(); + std::string name = proc->is_procedure()->name(); + + Symbol* weak = (symbols[name] = std::move(proc)).get(); + weak->semantic(symbols); + return weak; +} + +Symbol* add_global(symbol_map& symbols, const std::string& name) { + auto& var = (symbols[name] = make_symbol<VariableExpression>(Location(), name)); + return var.get(); +} + +TEST(remove_unused_locals, simple) { + const char* before_src = + "PROCEDURE before { \n" + " LOCAL a \n" + " LOCAL b \n" + " a = 3 \n" + " g = a \n" + " b = 4 \n" + "} \n"; + + const char* expected_src = + "PROCEDURE expected { \n" + " LOCAL a \n" + " a = 3 \n" + " g = a \n" + "} \n"; + + symbol_map symbols; + add_global(symbols, "g"); + + auto before = add_procedure(symbols, before_src); + auto before_body = before->is_procedure()->body(); + + auto expected = add_procedure(symbols, expected_src); + auto expected_body = expected->is_procedure()->body(); + + auto after = remove_unused_locals(before_body); + + verbose_print("before: ", before_body); + verbose_print("after: ", after); + verbose_print("expected: ", expected_body); + + EXPECT_EXPR_EQ(expected_body, after.get()); +} + +TEST(remove_unused_locals, compound) { + const char* before_src = + "PROCEDURE before { \n" + " LOCAL a, b, c, d \n" + " g1 = a \n" + " g2 = c \n" + "} \n"; + + const char* expected_src = + "PROCEDURE expected { \n" + " LOCAL a, c \n" + " g1 = a \n" + " g2 = c \n" + "} \n"; + + symbol_map symbols; + add_global(symbols, "g1"); + add_global(symbols, "g2"); + + auto before = add_procedure(symbols, before_src); + auto before_body = before->is_procedure()->body(); + + auto expected = add_procedure(symbols, expected_src); + auto expected_body = expected->is_procedure()->body(); + + auto after = remove_unused_locals(before_body); + + verbose_print("before: ", before_body); + verbose_print("after: ", after); + verbose_print("expected: ", expected_body); + + EXPECT_EXPR_EQ(expected_body, after.get()); +} + +TEST(remove_unused_locals, with_dependencies) { + const char* before_src = + "PROCEDURE before { \n" + " LOCAL a \n" + " LOCAL b \n" + " LOCAL c \n" + " LOCAL d \n" + " LOCAL e \n" + " LOCAL f \n" + " g1 = f \n" + " b = 3 \n" + " a = b-e \n" + " c = log(a)+2 \n" + " d = 4 \n" + " g2 = c \n" + "} \n"; + + const char* expected_src = + "PROCEDURE expected { \n" + " LOCAL a \n" + " LOCAL b \n" + " LOCAL c \n" + " LOCAL e \n" + " LOCAL f \n" + " g1 = f \n" + " b = 3 \n" + " a = b-e \n" + " c = log(a)+2 \n" + " g2 = c \n" + "} \n"; + + symbol_map symbols; + add_global(symbols, "g1"); + add_global(symbols, "g2"); + + auto before = add_procedure(symbols, before_src); + auto before_body = before->is_procedure()->body(); + + auto expected = add_procedure(symbols, expected_src); + auto expected_body = expected->is_procedure()->body(); + + auto after = remove_unused_locals(before_body); + + verbose_print("before: ", before_body); + verbose_print("after: ", after); + verbose_print("expected: ", expected_body); + + EXPECT_EXPR_EQ(expected_body, after.get()); +} + +TEST(remove_unused_locals, inner_block) { + const char* before_src = + "PROCEDURE before { \n" + " LOCAL a \n" + " LOCAL b \n" + " LOCAL c \n" + " LOCAL d \n" + " if (a>0) { \n" + " g = b \n" + " } \n" + " else { \n" + " d = g \n" + " g = c \n" + " } \n" + "} \n"; + + const char* expected_src = + "PROCEDURE expected { \n" + " LOCAL a \n" + " LOCAL b \n" + " LOCAL c \n" + " if (a>0) { \n" + " g = b \n" + " } \n" + " else { \n" + " g = c \n" + " } \n" + "} \n"; + + symbol_map symbols; + add_global(symbols, "g"); + + auto before = add_procedure(symbols, before_src); + auto before_body = before->is_procedure()->body(); + + auto expected = add_procedure(symbols, expected_src); + auto expected_body = expected->is_procedure()->body(); + + auto after = remove_unused_locals(before_body); + + verbose_print("before: ", before_body); + verbose_print("after: ", after); + verbose_print("expected: ", expected_body); + + EXPECT_EXPR_EQ(expected_body, after.get()); +} diff --git a/tests/modcc/test_symdiff.cpp b/tests/modcc/test_symdiff.cpp new file mode 100644 index 00000000..7f709825 --- /dev/null +++ b/tests/modcc/test_symdiff.cpp @@ -0,0 +1,330 @@ +#include <cmath> + +#include "test.hpp" + +#include "symdiff.hpp" +#include "parser.hpp" +#include "modccutil.hpp" + +// Test visitors for extended constant reduction, +// identifier presence detection and symbolic differentiation. + +TEST(involves_identifier, line_expr) { + const char* expr_x[] = { + "a=4+x", + "x'=a", + "a=exp(x)", + "x=sin(2)", + "~ 2x <-> (1, a)" + }; + + const char* expr_xy[] = { + "a=4+x+y", + "x'=exp(x*sin(y))", + "x=sin(2*y)", + "~ 2x <-> 3y (1, 1)" + }; + + const char* expr_xyz[] = { + "a=4+x+y*log(z)", + "x'=exp(x*sin(y))+func(2,z)", + "x=sin(2*y)+(-z)", + "~ 2x <-> 3y (a, z)" + }; + + identifier_set xyz_ids = { "x", "y", "z" }; + identifier_set yz_ids = { "y", "z" }; + identifier_set uvw_ids = { "u", "v", "w" }; + + for (auto line: expr_x) { + SCOPED_TRACE(std::string("expression: ")+line); + Parser p(line); + auto e = p.parse_statement(); + ASSERT_TRUE(e); + + EXPECT_TRUE(involves_identifier(e, "x")); + EXPECT_FALSE(involves_identifier(e, "y")); + EXPECT_FALSE(involves_identifier(e, "z")); + + EXPECT_TRUE(involves_identifier(e, xyz_ids)); + EXPECT_FALSE(involves_identifier(e, yz_ids)); + EXPECT_FALSE(involves_identifier(e, uvw_ids)); + } + + for (auto line: expr_xy) { + SCOPED_TRACE(std::string("expression: ")+line); + Parser p(line); + auto e = p.parse_statement(); + ASSERT_TRUE(e); + + EXPECT_TRUE(involves_identifier(e, "x")); + EXPECT_TRUE(involves_identifier(e, "y")); + EXPECT_FALSE(involves_identifier(e, "z")); + + EXPECT_TRUE(involves_identifier(e, xyz_ids)); + EXPECT_TRUE(involves_identifier(e, yz_ids)); + EXPECT_FALSE(involves_identifier(e, uvw_ids)); + } + + for (auto line: expr_xyz) { + SCOPED_TRACE(std::string("expression: ")+line); + Parser p(line); + auto e = p.parse_statement(); + ASSERT_TRUE(e); + + EXPECT_TRUE(involves_identifier(e, "x")); + EXPECT_TRUE(involves_identifier(e, "y")); + EXPECT_TRUE(involves_identifier(e, "z")); + + EXPECT_TRUE(involves_identifier(e, xyz_ids)); + EXPECT_TRUE(involves_identifier(e, yz_ids)); + EXPECT_FALSE(involves_identifier(e, uvw_ids)); + } +} + +TEST(constant_simplify, constants) { + struct { const char* repn; double value; } tests[] = { + { "17+3", 20. }, + { "log(exp(2))+cos(0)", 3. }, + { "0/17-1", -1. }, + { "2.5*(34/17-1.0e2)", -245. }, + { "-sin(0.523598775598298873077107230546583814)", -0.5 } + }; + + for (const auto& item: tests) { + SCOPED_TRACE(std::string("expression: ")+item.repn); + Parser p(item.repn); + auto e = p.parse_expression(); + ASSERT_TRUE(e); + + auto value = expr_value(constant_simplify(e)); + EXPECT_FALSE(std::isnan(value)); + EXPECT_NEAR(item.value, value, 1e-8); + } +} + +TEST(constant_simplify, simplified_expr) { + // Expect simplification of 'before' expression matches parse of 'after'. + // Use output string representation of expression for easy comparison. + + struct { const char* before; const char* after; } tests[] = { + { "x+y/z", "x+y/z" }, + { "(0*x)+y/(z-(0*w))", "y/z" }, + { "y*exp(0)", "y" }, + { "x^(2-1)", "x" }, + { "0-(y+0)", "-y" } + }; + + for (const auto& item: tests) { + SCOPED_TRACE(std::string("expressions: ")+item.before+"; "+item.after); + auto before = Parser{item.before}.parse_expression(); + auto after = Parser{item.after}.parse_expression(); + ASSERT_TRUE(before); + ASSERT_TRUE(after); + + EXPECT_EQ(after->to_string(), constant_simplify(before)->to_string()); + } +} + +TEST(constant_simplify, block_with_if) { + const char* before_repn = + "{\n" + " x = 0/z - y*log(1) + w*(3-2)\n" + " if (2>1) {\n" + " if (1==3) {\n" + " y = 64\n" + " }\n" + " else {\n" + " y = exp(0)*x\n" + " z = y-(x*0)\n" + " }\n" + " }\n" + " else {\n" + " y = 32\n" + " }\n" + "}\n"; + + const char* after_repn = + "{\n" + " x = w\n" + " y = x\n" + " z = y\n" + "}\n"; + + auto before = Parser{before_repn}.parse_block(false); + auto after = Parser{after_repn}.parse_block(false); + ASSERT_TRUE(before); + ASSERT_TRUE(after); + + EXPECT_EQ(after->to_string(), constant_simplify(before)->to_string()); +} + +TEST(symbolic_pdiff, expressions) { + struct { const char* before; const char* after; } tests[] = { + { "y+z*4", "0" }, + { "x", "1" }, + { "x*3", "3"}, + { "x-log(x)", "1-1/x"}, + { "sin(x)", "cos(x)"}, + { "cos(x)", "-sin(x)"}, + { "exp(x)", "exp(x)"}, + }; + + for (const auto& item: tests) { + SCOPED_TRACE(std::string("expressions: ")+item.before+"; "+item.after); + auto before = Parser{item.before}.parse_expression(); + auto after = Parser{item.after}.parse_expression(); + ASSERT_TRUE(before); + ASSERT_TRUE(after); + + EXPECT_EQ(after->to_string(), symbolic_pdiff(before, "x")->to_string()); + } +} + +TEST(symbolic_pdiff, linear) { + struct { const char* before; const char* after; } tests[] = { + { "x+y/z", "1" }, + { "3.0*x-x/2.0+x*y", "2.5+y"}, + { "(1+2.*x)/(1-exp(y))", "2./(1-exp(y))"} + }; + + for (const auto& item: tests) { + SCOPED_TRACE(std::string("expressions: ")+item.before+"; "+item.after); + auto before = Parser{item.before}.parse_expression(); + auto after = Parser{item.after}.parse_expression(); + ASSERT_TRUE(before); + ASSERT_TRUE(after); + + EXPECT_EQ(after->to_string(), symbolic_pdiff(before, "x")->to_string()); + } +} + +TEST(symbolic_pdiff, nonlinear) { + struct { const char* before; const char* after; } tests[] = { + { "sin(x)", "cos(x)" }, + { "exp(2*x)", "2*exp(2*x)" }, + { "x^2", "2*x" }, + { "a^x", "log(a)*a^x" } + }; + + for (const auto& item: tests) { + SCOPED_TRACE(std::string("expressions: ")+item.before+"; "+item.after); + auto before = Parser{item.before}.parse_expression(); + auto after = Parser{item.after}.parse_expression(); + ASSERT_TRUE(before); + ASSERT_TRUE(after); + + EXPECT_EQ(after->to_string(), symbolic_pdiff(before, "x")->to_string()); + } +} + +inline expression_ptr operator""_expr(const char* literal, std::size_t) { + return Parser{literal}.parse_expression(); +} + +TEST(substitute, expressions) { + struct { const char* before; const char* after; } tests[] = { + { "x", "y+z" }, + { "y", "y" }, + { "2.0", "2.0" }, + { "sin(x)", "sin(y+z)" }, + { "func(x,3+x)+y", "func(y+z,3+(y+z))+y" } + }; + + auto yplusz = "y+z"_expr; + ASSERT_TRUE(yplusz); + + for (const auto& item: tests) { + SCOPED_TRACE(std::string("expressions: ")+item.before+"; "+item.after); + auto before = Parser{item.before}.parse_expression(); + auto after = Parser{item.after}.parse_expression(); + ASSERT_TRUE(before); + ASSERT_TRUE(after); + + auto result = substitute(before.get(), "x", yplusz.get()); + EXPECT_EQ(after->to_string(), result->to_string()); + } +} + +TEST(substitute, exprmap) { + substitute_map subs; + subs["x"] = "sin(y)"_expr; + subs["y"] = "cos(x)"_expr; + + // note substitute is not recursive! + auto before = "exp(x+y)"_expr; + ASSERT_TRUE(before); + + auto after = "exp(sin(y)+cos(x))"_expr; + ASSERT_TRUE(after); + + auto result = substitute(before.get(), subs); + EXPECT_EQ(after->to_string(), result->to_string()); +} + +TEST(linear_test, homogeneous) { + linear_test_result r; + + r = linear_test("3*x"_expr, {"x"}); + EXPECT_TRUE(r.is_linear); + EXPECT_TRUE(r.is_homogeneous); + EXPECT_TRUE(r.monolinear()); + EXPECT_EQ(r.coef["x"]->to_string(), "3"_expr->to_string()); + + r = linear_test("y-a*x+2*x"_expr, {"x", "y"}); + EXPECT_TRUE(r.is_linear); + EXPECT_TRUE(r.is_homogeneous); + EXPECT_FALSE(r.monolinear()); + EXPECT_EQ(r.coef["x"]->to_string(), "-a+2"_expr->to_string()); + EXPECT_EQ(r.coef["y"]->to_string(), "1"_expr->to_string()); +} + +TEST(linear_test, inhomogeneous) { + linear_test_result r; + + r = linear_test("sin(y)+3*x"_expr, {"x"}); + EXPECT_TRUE(r.is_linear); + EXPECT_FALSE(r.is_homogeneous); + EXPECT_EQ(r.coef["x"]->to_string(), "3"_expr->to_string()); + EXPECT_EQ(r.constant->to_string(), "sin(y)"_expr->to_string()); + + r = linear_test("(x+y+1)*(a+b)"_expr, {"x", "y"}); + EXPECT_TRUE(r.is_linear); + EXPECT_FALSE(r.is_homogeneous); + EXPECT_EQ(r.coef["x"]->to_string(), "a+b"_expr->to_string()); + EXPECT_EQ(r.coef["y"]->to_string(), "a+b"_expr->to_string()); + EXPECT_EQ(r.constant->to_string(), "a+b"_expr->to_string()); + + // check 'gating' case still works! (Use plus instead of minus + // though because of -1 vs (- 1) parsing makes the test harder.) + r = linear_test("(a+x)/b"_expr, {"x"}); + EXPECT_TRUE(r.is_linear); + EXPECT_FALSE(r.is_homogeneous); + EXPECT_EQ(r.coef["x"]->to_string(), "1/b"_expr->to_string()); + EXPECT_EQ(r.constant->to_string(), "a/b"_expr->to_string()); +} + +TEST(linear_test, nonlinear) { + linear_test_result r; + + r = linear_test("x+x^2"_expr, {"x", "y"}); + EXPECT_FALSE(r.is_linear); + + r = linear_test("x+y*x"_expr, {"x", "y"}); + EXPECT_FALSE(r.is_linear); +} + +TEST(linear_test, diagonality) { + auto xdot = "a*x"_expr; + auto ydot = "-b*y/2"_expr; + auto zdot = "x+y+z"_expr; + + // xdot, ydot diagonal linear + EXPECT_TRUE(linear_test(xdot, {"x", "y"}).monolinear("x")); + EXPECT_TRUE(linear_test(ydot, {"x", "y"}).monolinear("y")); + + // but xdot, ydot, zdot not diagonal + EXPECT_TRUE(linear_test(xdot, {"x", "y", "z"}).monolinear("x")); + EXPECT_TRUE(linear_test(ydot, {"x", "y", "z"}).monolinear("y")); + EXPECT_FALSE(linear_test(zdot, {"x", "y", "z"}).monolinear("z")); +} diff --git a/tests/modcc/test_symge.cpp b/tests/modcc/test_symge.cpp new file mode 100644 index 00000000..7ce77aeb --- /dev/null +++ b/tests/modcc/test_symge.cpp @@ -0,0 +1,178 @@ +#include <map> +#include <utility> +#include <vector> + +#include "symge.hpp" +#include "test.hpp" + +using namespace symge; + +TEST(symge, table_define) { + symbol_table tbl; + + EXPECT_EQ(0u, tbl.size()); + + auto s1 = tbl.define(); + auto s2 = tbl.define("foo"); + auto s3 = tbl.define(symbol_term_diff{}); + auto s4 = tbl.define("bar",symbol_term_diff{}); + + EXPECT_EQ(4u, tbl.size()); + EXPECT_TRUE(tbl.valid(s1)); + EXPECT_TRUE(tbl.valid(s2)); + EXPECT_TRUE(tbl.valid(s3)); + EXPECT_TRUE(tbl.valid(s4)); + EXPECT_FALSE(tbl.valid(symbol{})); +} + +TEST(symge, symbol_ops) { + symbol_table tbl; + + symbol a = tbl.define("a"); + symbol b = tbl.define("b"); + symbol c = tbl.define("c", a*b); + symbol d = tbl.define("d", -(a*b)); + symbol e = tbl.define("e", c*d-a*b); + + EXPECT_TRUE(tbl.primitive(a)); + EXPECT_TRUE(tbl.primitive(b)); + EXPECT_FALSE(tbl.primitive(c)); + EXPECT_FALSE(tbl.primitive(d)); + EXPECT_FALSE(tbl.primitive(e)); + + auto cdef = tbl.get(c); + EXPECT_FALSE(cdef.left.is_zero()); + EXPECT_TRUE(cdef.right.is_zero()); + + auto ddef = tbl.get(d); + EXPECT_TRUE(ddef.left.is_zero()); + EXPECT_FALSE(ddef.right.is_zero()); + + auto edef = tbl.get(e); + EXPECT_FALSE(edef.left.is_zero()); + EXPECT_FALSE(edef.right.is_zero()); + + EXPECT_EQ(c, edef.left.left); + EXPECT_EQ(d, edef.left.right); + EXPECT_EQ(a, edef.right.left); + EXPECT_EQ(b, edef.right.right); +} + +TEST(symge, symbol_name) { + symbol_table tbl; + + symbol a = tbl.define("a"); + EXPECT_EQ("a", tbl.name(a)); + EXPECT_EQ("a", name(a)); + tbl.name(a, "b"); + EXPECT_EQ("b", name(a)); +} + +TEST(symge, table_index) { + symbol_table tbl; + + symbol a = tbl.define("a"); + symbol b = tbl.define("b"); + symbol c = tbl.define("c", a*b); + symbol d = tbl.define("d", -(a*b)); + symbol e = tbl.define("e", c*d-a*b); + + EXPECT_EQ(a, tbl[0]); + EXPECT_EQ(d, tbl[3]); + EXPECT_EQ(e, tbl[4]); +} + +#include <iostream> + +struct value_store { + mutable std::map<std::string, double> values; + + std::string new_var() { + return "v"+std::to_string(values.size()); + } + + double& assign(symbol s, double v) { + if (name(s).empty()) { + const_cast<symbol_table*>(s.table())->name(s, new_var()); + } + + return values[name(s)] = v; + } + + double eval(symbol s) { + return values[name(s)]; + } + + double eval(symbol_term t) { + return t.is_zero()? 0.0: eval(t.left)*eval(t.right); + } + + double eval(symbol_term_diff d) { + return eval(d.left)-eval(d.right); + } + + double& operator[](symbol s) { + if (name(s).empty() || !values.count(name(s))) { + return assign(s, 0.0); + } + return values[name(s)]; + } +}; + + +TEST(symge, gj_reduce_3x3) { + // solve: + // + // | 2 0 3 | | x | | 6 | + // | 0 4 0 | | y | = | 7 | + // | 0 -1 5 | | z | | 8 | + // + // with expected answer: + // + // x = 3/40; y = 7/4; z = 39/20 + + symbol_table tbl; + auto a = tbl.define("a"); + auto b = tbl.define("b"); + auto c = tbl.define("c"); + auto d = tbl.define("d"); + auto e = tbl.define("e"); + auto p = tbl.define("p"); + auto q = tbl.define("q"); + auto r = tbl.define("r"); + + sym_matrix A(3,3); + A[0] = sym_row({{0, a}, {2, b}}); + A[1] = sym_row({{1, c}}); + A[2] = sym_row({{1, d}, {2, e}}); + + std::vector<symbol> B = { p, q, r }; + A.augment(B); + + gj_reduce(A, tbl); + + value_store v; + v[a] = 2; + v[b] = 3; + v[c] = 4; + v[d] = -1; + v[e] = 5; + v[p] = 6; + v[q] = 7; + v[r] = 8; + + for (unsigned i = 0; i<tbl.size(); ++i) { + symbol s = tbl[i]; + if (!primitive(s)) { + v.assign(s, v.eval(definition(s))); + } + } + + double x = v.eval(A[0][3])/v.eval(A[0][0]); + double y = v.eval(A[1][3])/v.eval(A[1][1]); + double z = v.eval(A[2][3])/v.eval(A[2][2]); + + EXPECT_NEAR(x, 3.0/40.0, 1e-6); + EXPECT_NEAR(y, 7.0/4.0, 1e-6); + EXPECT_NEAR(z, 39.0/20.0, 1e-6); +} diff --git a/tests/modcc/test_visitors.cpp b/tests/modcc/test_visitors.cpp index e2b7dd7c..6f1caf64 100644 --- a/tests/modcc/test_visitors.cpp +++ b/tests/modcc/test_visitors.cpp @@ -6,110 +6,116 @@ #include "parser.hpp" #include "modccutil.hpp" +// overload for parser errors +template <typename EPtr> +void verbose_print(const EPtr& e, Parser& p, const char* text) { + verbose_print(e); + if (p.status()==lexerStatus::error) { + verbose_print("in ", cyan(text), "\t", p.error_message()); + } +} + /************************************************************** * visitors **************************************************************/ -using namespace nest::mc; - TEST(FlopVisitor, basic) { { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("x+y"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.add, 1); + FlopVisitor visitor; + auto e = parse_expression("x+y"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.add, 1); } { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("x-y"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.add, 1); + FlopVisitor visitor; + auto e = parse_expression("x-y"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.add, 1); } { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("x*y"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.mul, 1); + FlopVisitor visitor; + auto e = parse_expression("x*y"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.mul, 1); } { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("x/y"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.div, 1); + FlopVisitor visitor; + auto e = parse_expression("x/y"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.div, 1); } { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("exp(x)"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.exp, 1); + FlopVisitor visitor; + auto e = parse_expression("exp(x)"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.exp, 1); } { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("log(x)"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.log, 1); + FlopVisitor visitor; + auto e = parse_expression("log(x)"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.log, 1); } { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("cos(x)"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.cos, 1); + FlopVisitor visitor; + auto e = parse_expression("cos(x)"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.cos, 1); } { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("sin(x)"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.sin, 1); + FlopVisitor visitor; + auto e = parse_expression("sin(x)"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.sin, 1); } } TEST(FlopVisitor, compound) { { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("x+y*z/a-b"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.add, 2); - EXPECT_EQ(visitor->flops.mul, 1); - EXPECT_EQ(visitor->flops.div, 1); + FlopVisitor visitor; + auto e = parse_expression("x+y*z/a-b"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.add, 2); + EXPECT_EQ(visitor.flops.mul, 1); + EXPECT_EQ(visitor.flops.div, 1); } { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("exp(x+y+z)"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.add, 2); - EXPECT_EQ(visitor->flops.exp, 1); + FlopVisitor visitor; + auto e = parse_expression("exp(x+y+z)"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.add, 2); + EXPECT_EQ(visitor.flops.exp, 1); } { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("exp(x+y) + 3/(12 + z)"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.add, 3); - EXPECT_EQ(visitor->flops.div, 1); - EXPECT_EQ(visitor->flops.exp, 1); + FlopVisitor visitor; + auto e = parse_expression("exp(x+y) + 3/(12 + z)"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.add, 3); + EXPECT_EQ(visitor.flops.div, 1); + EXPECT_EQ(visitor.flops.exp, 1); } // test asssignment expression { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_line_expression("x = exp(x+y) + 3/(12 + z)"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.add, 3); - EXPECT_EQ(visitor->flops.div, 1); - EXPECT_EQ(visitor->flops.exp, 1); + FlopVisitor visitor; + auto e = parse_line_expression("x = exp(x+y) + 3/(12 + z)"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.add, 3); + EXPECT_EQ(visitor.flops.div, 1); + EXPECT_EQ(visitor.flops.exp, 1); } } TEST(FlopVisitor, procedure) { - { const char *expression = "PROCEDURE trates(v) {\n" " LOCAL qt\n" @@ -119,20 +125,18 @@ TEST(FlopVisitor, procedure) { " mtau = 0.6\n" " htau = 1500\n" "}"; - auto visitor = util::make_unique<FlopVisitor>(); + FlopVisitor visitor; auto e = parse_procedure(expression); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.add, 6); - EXPECT_EQ(visitor->flops.neg, 0); - EXPECT_EQ(visitor->flops.mul, 0); - EXPECT_EQ(visitor->flops.div, 5); - EXPECT_EQ(visitor->flops.exp, 2); - EXPECT_EQ(visitor->flops.pow, 1); - } + e->accept(&visitor); + EXPECT_EQ(visitor.flops.add, 6); + EXPECT_EQ(visitor.flops.neg, 0); + EXPECT_EQ(visitor.flops.mul, 0); + EXPECT_EQ(visitor.flops.div, 5); + EXPECT_EQ(visitor.flops.exp, 2); + EXPECT_EQ(visitor.flops.pow, 1); } TEST(FlopVisitor, function) { - { const char *expression = "FUNCTION foo(v) {\n" " LOCAL qt\n" @@ -141,16 +145,15 @@ TEST(FlopVisitor, function) { " hinf=1/(1+exp((v-vhalfh)/kh))\n" " foo = minf + hinf\n" "}"; - auto visitor = util::make_unique<FlopVisitor>(); + FlopVisitor visitor; auto e = parse_function(expression); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.add, 7); - EXPECT_EQ(visitor->flops.neg, 1); - EXPECT_EQ(visitor->flops.mul, 0); - EXPECT_EQ(visitor->flops.div, 5); - EXPECT_EQ(visitor->flops.exp, 2); - EXPECT_EQ(visitor->flops.pow, 1); - } + e->accept(&visitor); + EXPECT_EQ(visitor.flops.add, 7); + EXPECT_EQ(visitor.flops.neg, 1); + EXPECT_EQ(visitor.flops.mul, 0); + EXPECT_EQ(visitor.flops.div, 5); + EXPECT_EQ(visitor.flops.exp, 2); + EXPECT_EQ(visitor.flops.pow, 1); } TEST(ClassificationVisitor, linear) { @@ -199,20 +202,14 @@ TEST(ClassificationVisitor, linear) { if( e==nullptr ) continue; e->semantic(scope); - auto v = new ExpressionClassifierVisitor(x); - e->accept(v); - //std::cout << "expression " << e->to_string() << std::endl; - //std::cout << "linear " << v->linear_coefficient()->to_string() << std::endl; - //std::cout << "constant " << v->constant_term()->to_string() << std::endl; - EXPECT_EQ(v->classify(), expressionClassification::linear); - -#ifdef VERBOSE_TEST - std::cout << "eq " << e->to_string() - << "\ncoeff " << v->linear_coefficient()->to_string() - << "\nconst " << v-> constant_term()->to_string() - << "\n----" << std::endl; -#endif - delete v; + ExpressionClassifierVisitor v(x); + e->accept(&v); + EXPECT_EQ(v.classify(), expressionClassification::linear); + + verbose_print("eq ", e); + verbose_print("coeff ", v.linear_coefficient()); + verbose_print("const ", v.constant_term()); + verbose_print("----"); } } @@ -235,23 +232,19 @@ TEST(ClassificationVisitor, constant) { auto x = globals["x"].get(); for(auto const& expression : expressions) { - auto e = parse_expression(expression); + Parser p{expression}; + auto e = p.parse_expression(); // sanity check the compiler EXPECT_NE(e, nullptr); if( e==nullptr ) continue; e->semantic(scope); - auto v = new ExpressionClassifierVisitor(x); - e->accept(v); - EXPECT_EQ(v->classify(), expressionClassification::constant); - -#ifdef VERBOSE_TEST - if(e) std::cout << e->to_string() << std::endl; - if(p.status()==lexerStatus::error) - std::cout << "in " << colorize(expression, kCyan) << "\t" << p.error_message() << std::endl; -#endif - delete v; + ExpressionClassifierVisitor v(x); + e->accept(&v); + EXPECT_EQ(v.classify(), expressionClassification::constant); + + verbose_print(e, p, expression); } } @@ -281,24 +274,20 @@ TEST(ClassificationVisitor, nonlinear) { auto scope = std::make_shared<Scope<Symbol>>(globals); auto x = globals["x"].get(); - auto v = new ExpressionClassifierVisitor(x); + ExpressionClassifierVisitor v(x); for(auto const& expression : expressions) { - auto e = parse_expression(expression); + Parser p{expression}; + auto e = p.parse_expression(); // sanity check the compiler EXPECT_NE(e, nullptr); if( e==nullptr ) continue; e->semantic(scope); - v->reset(); - e->accept(v); - EXPECT_EQ(v->classify(), expressionClassification::nonlinear); - -#ifdef VERBOSE_TEST - if(e) std::cout << e->to_string() << std::endl; - if(p.status()==lexerStatus::error) - std::cout << "in " << colorize(expression, kCyan) << "\t" << p.error_message() << std::endl; -#endif + v.reset(); + e->accept(&v); + EXPECT_EQ(v.classify(), expressionClassification::nonlinear); + + verbose_print(e, p, expression); } - delete v; } diff --git a/tests/validation/CMakeLists.txt b/tests/validation/CMakeLists.txt index f7e9f849..a42fc4dd 100644 --- a/tests/validation/CMakeLists.txt +++ b/tests/validation/CMakeLists.txt @@ -3,6 +3,7 @@ set(VALIDATION_SOURCES validate_ball_and_stick.cpp validate_compartment_policy.cpp validate_soma.cpp + validate_kinetic.cpp validate_synapses.cpp # support code diff --git a/tests/validation/validate_kinetic.cpp b/tests/validation/validate_kinetic.cpp new file mode 100644 index 00000000..1c200516 --- /dev/null +++ b/tests/validation/validate_kinetic.cpp @@ -0,0 +1,13 @@ +#include "validate_kinetic.hpp" + +#include "../gtest.h" + +using lowered_cell = nest::mc::fvm::fvm_multicell<nest::mc::multicore::backend>; + +TEST(kinetic, kin1_numeric_ref) { + validate_kinetic_kin1<lowered_cell>(); +} + +TEST(kinetic, kinlva_numeric_ref) { + validate_kinetic_kinlva<lowered_cell>(); +} diff --git a/tests/validation/validate_kinetic.hpp b/tests/validation/validate_kinetic.hpp new file mode 100644 index 00000000..7ccde7c9 --- /dev/null +++ b/tests/validation/validate_kinetic.hpp @@ -0,0 +1,90 @@ +#include <json/json.hpp> + +#include <common_types.hpp> +#include <cell.hpp> +#include <fvm_multicell.hpp> +#include <model.hpp> +#include <recipe.hpp> +#include <simple_sampler.hpp> +#include <util/rangeutil.hpp> + +#include "../test_common_cells.hpp" +#include "convergence_test.hpp" +#include "trace_analysis.hpp" +#include "validation_data.hpp" + +template <typename LoweredCell> +void run_kinetic_dt(nest::mc::cell& c, float t_end, nlohmann::json meta, const std::string& ref_file) { + using namespace nest::mc; + + float sample_dt = .025f; + sampler_info samplers[] = { + {"soma.mid", {0u, 0u}, simple_sampler(sample_dt)} + }; + + meta["sim"] = "nestmc"; + meta["backend"] = LoweredCell::backend::name(); + convergence_test_runner<float> runner("dt", samplers, meta); + runner.load_reference_data(ref_file); + + model<LoweredCell> model(singleton_recipe{c}); + + auto exclude = stimulus_ends(c); + + // use dt = 0.05, 0.02, 0.01, 0.005, 0.002, ... + double max_oo_dt = std::round(1.0/g_trace_io.min_dt()); + for (double base = 100; ; base *= 10) { + for (double multiple: {5., 2., 1.}) { + double oo_dt = base/multiple; + if (oo_dt>max_oo_dt) goto end; + + model.reset(); + float dt = float(1./oo_dt); + runner.run(model, dt, t_end, dt, exclude); + } + } + +end: + runner.report(); + runner.assert_all_convergence(); +} + +template <typename LoweredCell> +void validate_kinetic_kin1() { + using namespace nest::mc; + + // 20 µm diameter soma with single mechanism, current probe + cell c; + auto soma = c.add_soma(10); + c.add_probe({{0, 0.5}, probeKind::membrane_current}); + soma->add_mechanism(std::string("test_kin1")); + + nlohmann::json meta = { + {"model", "test_kin1"}, + {"name", "membrane current"}, + {"units", "nA"} + }; + + run_kinetic_dt<LoweredCell>(c, 100.f, meta, "numeric_kin1.json"); +} + +template <typename LoweredCell> +void validate_kinetic_kinlva() { + using namespace nest::mc; + + // 20 µm diameter soma with single mechanism, current probe + cell c; + auto soma = c.add_soma(10); + c.add_probe({{0, 0.5}, probeKind::membrane_voltage}); + c.add_stimulus({0,0.5}, {20., 130., -0.025}); + soma->add_mechanism(std::string("test_kinlva")); + + nlohmann::json meta = { + {"model", "test_kinlva"}, + {"name", "membrane voltage"}, + {"units", "mV"} + }; + + run_kinetic_dt<LoweredCell>(c, 300.f, meta, "numeric_kinlva.json"); +} + diff --git a/tests/validation/validate_soma.hpp b/tests/validation/validate_soma.hpp index 2658ee1c..eee44e6b 100644 --- a/tests/validation/validate_soma.hpp +++ b/tests/validation/validate_soma.hpp @@ -37,10 +37,10 @@ void validate_soma() { float t_end = 100.f; - // use dt = 0.05, 0.025, 0.01, 0.005, 0.0025, ... + // use dt = 0.05, 0.02, 0.01, 0.005, 0.002, ... double max_oo_dt = std::round(1.0/g_trace_io.min_dt()); for (double base = 100; ; base *= 10) { - for (double multiple: {5., 2.5, 1.}) { + for (double multiple: {5., 2., 1.}) { double oo_dt = base/multiple; if (oo_dt>max_oo_dt) goto end; diff --git a/validation/ref/numeric/CMakeLists.txt b/validation/ref/numeric/CMakeLists.txt index 4a86db71..510f58ee 100644 --- a/validation/ref/numeric/CMakeLists.txt +++ b/validation/ref/numeric/CMakeLists.txt @@ -1,6 +1,16 @@ # note: function add_validation_data defined in validation/CMakeLists.txt if(NMC_BUILD_JULIA_VALIDATION_DATA) + add_validation_data( + OUTPUT numeric_kin1.json + DEPENDS numeric_kin1.jl + COMMAND ${JULIA_BIN} numeric_kin1.jl) + + add_validation_data( + OUTPUT numeric_kinlva.json + DEPENDS numeric_kinlva.jl LVAChannels.jl + COMMAND ${JULIA_BIN} numeric_kinlva.jl) + add_validation_data( OUTPUT numeric_soma.json DEPENDS numeric_soma.jl HHChannels.jl diff --git a/validation/ref/numeric/LVAChannels.jl b/validation/ref/numeric/LVAChannels.jl new file mode 100644 index 00000000..1d11e6b0 --- /dev/null +++ b/validation/ref/numeric/LVAChannels.jl @@ -0,0 +1,168 @@ +module LVAChannels + +export Stim, run_lva, LVAParam + +using Sundials +using SIUnits +using SIUnits.ShortUnits + +const mS = Milli*Siemens + +immutable LVAParam + c_m # membrane spacific capacitance + gbar # Ca channel cross-membrane conductivity + eca # Ca channel reversal potential + gl # leak conductivity + el # leak reversal potential + q10_1 # rate scaling for 'm' activation gate + q10_2 # rate scaling for 'dsh' inactivation gate + vrest # (as hoc) resting potential + + # constructor with default values, with q10 values + # corresponding to the body temperature current-clamp experiments + LVAParam(; + c_m = 0.01F*m^-2, + gbar = 0.2mS*cm^-2, + eca = 120mV, + gl = .1mS*cm^-2, + el = -65mV, + q10_1 = 5, + q10_2 = 3, +# vrest = -61.47mV + vrest = -63mV + ) = new(c_m, gbar, eca, gl, el, q10_1, q10_2, vrest) + +end + +immutable Stim + t0 # start time of stimulus + t1 # stop time of stimulus + i_e # stimulus current density + + Stim() = new(0s, 0s, 0A/m^2) + Stim(t0, t1, i_e) = new(t0, t1, i_e) +end + +# 'm' activation gate +function m_lims(v, q10) + quotient = 1+exp(-(v+63mV)/7.8mV) + mtau = 1ms*(1.7+exp(-(v+28.8mV)/13.5mV))/(q10*quotient) + minf = 1/quotient + return mtau, minf +end + +# 'dsh' inactivation gate: +# d <-> s (alpha2, beta2) +# s <-> h (alpha1, beta1) +# subject to s=1-h-d +function dsh_lims(v, q10) + k = sqrt(0.25+exp((v+83.5mV)/6.3mV))-0.5 + alpha1 = q10*exp(-(v+160.3mV)/17.8mV)/1ms + beta1 = alpha1*k + + tau2 = 240ms/(1+exp((v+37.4mV)/30mV))/q10 + alpha2 = 1/tau2/(1+k) + beta2 = alpha2*k + + hinf = 1/(1+k+k^2) + dinf = k^2*hinf + + return alpha1, beta1, alpha2, beta2, dinf, hinf +end + +# Choose initial conditions for the system such that the gating variables +# are at steady state for the user-specified voltage v. +function initial_conditions(v, q10_1, q10_2) + mtau, minf = m_lims(v, q10_1) + alpha1, beta1, alpha2, beta2, dinf, hinf = dsh_lims(v, q10_2) + + return (v, minf, dinf, hinf) +end + +# Given time t and state (v, m, d, h), +# return (vdot, mdot, ddot, hdot). +function f(t, state; p=LVAParam(), istim=0A/m^2) + v, m, d, h = state + + ica = p.gbar*m^3*h*(v - p.eca) + il = p.gl*(v - p.el) + itot = ica + il - istim + + # Calculate the voltage dependent rates for the gating variables. + mtau, minf = m_lims(v, p.q10_1) + alpha1, beta1, alpha2, beta2, dinf, hinf = dsh_lims(v, p.q10_2) + + mdot = (minf-m)/mtau + hdot = alpha1*(1-h-d)-beta1*h + ddot = beta2*(1-h-d)-alpha2*d + + return (-itot/p.c_m, mdot, ddot, hdot) +end + +function make_range(t_start, dt, t_end) + r = collect(t_start: dt: t_end) + if length(r)>0 && r[length(r)]<t_end + push!(r, t_end) + end + return r +end + +function run_lva(t_end; stim=Stim(), param=LVAParam(), sample_dt=0.01ms) + v_scale = 1V + t_scale = 1s + + v0, m0, h0, d0 = initial_conditions(param.vrest, param.q10_1, param.q10_2) + y0 = [ v0/v_scale, m0, h0, d0 ] + + + fbis(t, y, ydot, istim) = begin + vdot, mdot, hdot, ddot = + f(t*t_scale, (y[1]*v_scale, y[2], y[3], y[4]), istim=istim, p=param) + + ydot[1], ydot[2], ydot[3], ydot[4] = + vdot*t_scale/v_scale, mdot*t_scale, hdot*t_scale, ddot*t_scale + + return Sundials.CV_SUCCESS + end + + fbis_nostim(t, y, ydot) = fbis(t, y, ydot, 0A/m^2) + fbis_stim(t, y, ydot) = fbis(t, y, ydot, stim.i_e) + + # Ideally would run with vector absolute tolerance to account for v_scale, + # but this would prevent us using the nice cvode wrapper. + + res = [] + samples = [] + + t1 = clamp(stim.t0, 0s, t_end) + if t1>0s + ts = make_range(0s, sample_dt, t1) + r = Sundials.cvode(fbis_nostim, y0, map(t->t/t_scale, ts), abstol=1e-6, reltol=5e-10) + y0 = vec(r[size(r)[1], :]) + push!(res, r) + push!(samples, ts) + end + t2 = clamp(stim.t1, t1, t_end) + if t2>t1 + ts = make_range(t1, sample_dt, t2) + r = Sundials.cvode(fbis_stim, y0, map(t->t/t_scale, ts), abstol=1e-6, reltol=5e-10) + y0 = vec(r[size(r)[1], :]) + push!(res, r) + push!(samples, ts) + end + if t_end>t2 + ts = make_range(t2, sample_dt, t_end) + r = Sundials.cvode(fbis_nostim, y0, map(t->t/t_scale, ts), abstol=1e-6, reltol=5e-10) + y0 = vec(r[size(r)[1], :]) + push!(res, r) + push!(samples, ts) + end + + res = vcat(res...) + samples = vcat(samples...) + + # Use map here because of issues with type deduction with arrays and SIUnits. + return samples, map(v->v*v_scale, res[:, 1]), res[:, 2], res[:, 3], res[:, 4] +end + +end # module LVAChannels diff --git a/validation/ref/numeric/numeric_kin1.jl b/validation/ref/numeric/numeric_kin1.jl new file mode 100644 index 00000000..30090314 --- /dev/null +++ b/validation/ref/numeric/numeric_kin1.jl @@ -0,0 +1,33 @@ +#!/usr/bin/env julia + +include("HHChannels.jl") + +using JSON +using SIUnits.ShortUnits + +radius = 20µm/2 +area = 4*pi*radius^2 +sample_dt = 0.025ms +t_end = 100ms + +a0 = 0.01mA/cm^2 +c = 0.01mA/cm^2 + +tau = 10ms + +ts = collect(0s: sample_dt: t_end) +is = area*(1/3*c + (a0-1/3*c)*exp(-ts/tau)) + +trace = Dict( + :name => "membrane current", + :sim => "numeric", + :model => "test_kin1", + :units => "nA", + :data => Dict( + :time => map(t->t/ms, ts), + Symbol("soma.mid") => map(i->i/nA, is) + ) +) + +println(JSON.json([trace])) + diff --git a/validation/ref/numeric/numeric_kinlva.jl b/validation/ref/numeric/numeric_kinlva.jl new file mode 100644 index 00000000..9dc18cea --- /dev/null +++ b/validation/ref/numeric/numeric_kinlva.jl @@ -0,0 +1,40 @@ +#!/usr/bin/env julia + +include("LVAChannels.jl") + +using JSON +using SIUnits.ShortUnits +using LVAChannels + +radius = 20µm/2 +area = 4*pi*radius^2 +current = -0.025nA + +stim = Stim(20ms, 150ms, current/area) +ts, vs, m, d, h = run_lva(300ms, param=LVAParam(vrest=-65mV), stim=stim, sample_dt=0.025ms) + +trace = Dict( + :name => "membrane voltage", + :sim => "numeric", + :model => "test_kinlva", + :units => "mV", + :data => Dict( + :time => map(t->t/ms, ts), + Symbol("soma.mid") => map(v->v/mV, vs) + ) +) + +state = Dict( + :name => "mechanisms state", + :sim => "numeric", + :model => "kinlva", + :units => "1", + :data => Dict( + :time => map(t->t/ms, ts), + Symbol("m") => m, + Symbol("d") => d, + Symbol("h") => h + ) +) +println(JSON.json([trace, state])) + -- GitLab