diff --git a/mechanisms/CMakeLists.txt b/mechanisms/CMakeLists.txt index 10fdae9e86e5700cd79a5ffed9dbafbcaf46d677..3f543b64acd1538dd684a4f3a0fa4e359c981649 100644 --- a/mechanisms/CMakeLists.txt +++ b/mechanisms/CMakeLists.txt @@ -1,7 +1,7 @@ include(BuildModules.cmake) # the list of built-in mechanisms to be provided by default -set(mechanisms pas hh expsyn exp2syn) +set(mechanisms pas hh expsyn exp2syn test_kin1 test_kinlva) set(modcc_opt) if(NMC_USE_OPTIMIZED_KERNELS) # generate optimized kernels diff --git a/mechanisms/mod/test_kin1.mod b/mechanisms/mod/test_kin1.mod new file mode 100644 index 0000000000000000000000000000000000000000..c877fb4ef8ba04f0ab17adb720deb03e791e5964 --- /dev/null +++ b/mechanisms/mod/test_kin1.mod @@ -0,0 +1,38 @@ +NEURON { + SUFFIX test_kin1 + NONSPECIFIC_CURRENT il +} + +UNITS { + (mV) = (millivolt) + (mA) = (milliamp) + (S) = (siemens) +} + +PARAMETER { + tau = 10 (ms) +} + +STATE { + a (mA/cm2) + b (mA/cm2) +} + +ASSIGNED { + v (mV) +} + +BREAKPOINT { + SOLVE states METHOD sparse + il = 0*v+a +} + +INITIAL { + a = 0.01 + b = 0 +} + +KINETIC states { + ~ a <-> b (2/3/tau, 1/3/tau) +} + diff --git a/mechanisms/mod/test_kinlva.mod b/mechanisms/mod/test_kinlva.mod new file mode 100644 index 0000000000000000000000000000000000000000..51a541e8e8f6b61a7608dd9a96263ed85d8fe12a --- /dev/null +++ b/mechanisms/mod/test_kinlva.mod @@ -0,0 +1,76 @@ +: Adaption of T-type calcium channel from Wang, X. J. et al. 1991; +: c.f. NMODL file in ModelDB: +: https://senselab.med.yale.edu/modeldb/showModel.cshtml?model=53893 +: +: Note the temperature rate correction factors of 5 (for m) and 3 +: (for h <-> s <-> d) have been applied to match the model described +: in the current-clamp experiments (see p. 842). + +NEURON { + SUFFIX test_kinlva + USEION ca WRITE ica + NONSPECIFIC_CURRENT il +} + +UNITS { + (mS) = (millisiemens) + (mV) = (millivolts) + (mA) = (millamp) +} + +PARAMETER { + gbar = 0.0002 (S/cm2) + gl = 0.0001 (S/cm2) + eca = 120 (mV) + el = -65 (mV) +} + +STATE { + m h s d +} + +BREAKPOINT { + SOLVE m_state METHOD cnexp + SOLVE dsh_state METHOD sparse + ica = gbar*m^3*h*(v-eca) + il = gl*(v-el) +} + +FUNCTION minf(v) { + minf = 1/(1+exp(-(v+63)/7.8)) +} + +FUNCTION K(v) { + K = (0.25+exp((v+83.5)/6.3))^0.5-0.5 +} + +DERIVATIVE m_state { + LOCAL taum, mi, m_q10 + m_q10 = 5 + mi = minf(v) + taum = (1.7+exp(-(v+28.8)/13.5))*mi + m' = m_q10*(mi - m)/taum +} + +KINETIC dsh_state { + LOCAL k, alpha1, beta1, alpha2, beta2, dsh_q10 + dsh_q10 = 3 + k = K(v) + alpha1 = dsh_q10*exp(-(v+160.3)/17.8) + beta1 = alpha1*k + alpha2 = dsh_q10*(1+exp((v+37.4)/30))/240/(1+k) + beta2 = alpha2*k + + ~ s <-> h (alpha1, beta1) + ~ d <-> s (alpha2, beta2) +} + +INITIAL { + LOCAL k, vrest + vrest = -65 + k = K(v) + m = minf(vrest) + h = 1/(1+k+k^2) + d = h*k^2 + s = 1-h-d +} diff --git a/modcc/CMakeLists.txt b/modcc/CMakeLists.txt index 2dcf6d351d13cd23f3c18513e71c9eb41eb48e21..29dffa3c309634e9366d6c7359fa5ff888f1c38f 100644 --- a/modcc/CMakeLists.txt +++ b/modcc/CMakeLists.txt @@ -9,8 +9,12 @@ set(MODCC_SOURCES functionexpander.cpp functioninliner.cpp lexer.cpp + kineticrewriter.cpp module.cpp parser.cpp + solvers.cpp + symdiff.cpp + symge.cpp textbuffer.cpp token.cpp ) diff --git a/modcc/astmanip.cpp b/modcc/astmanip.cpp index e723c675a1a755358f5aa88e260773c1a919c61c..898374ff332395e40eb7572f03429bffd9f37a86 100644 --- a/modcc/astmanip.cpp +++ b/modcc/astmanip.cpp @@ -28,3 +28,14 @@ local_assignment make_unique_local_assign(scope_ptr scope, Expression* e, std::s return { std::move(local), std::move(ass), std::move(id), scope }; } +local_declaration make_unique_local_decl(scope_ptr scope, Location loc, std::string const& prefix) { + std::string name = unique_local_name(scope, prefix); + + auto local = make_expression<LocalDeclaration>(loc, name); + local->semantic(scope); + + auto id = make_expression<LocalDeclaration>(loc, name); + id->semantic(scope); + + return { std::move(local), std::move(id), scope }; +} diff --git a/modcc/astmanip.hpp b/modcc/astmanip.hpp index b3135cfc54366a3ce9b0fa6d6cc8874f952949ae..e42586ff9b4197e633b7702801a3f133c58ad67f 100644 --- a/modcc/astmanip.hpp +++ b/modcc/astmanip.hpp @@ -10,19 +10,30 @@ // Create new local variable symbol and local declaration expression in current scope. // Returns the local declaration expression. -expression_ptr make_unique_local_decl(scope_ptr scope, Location loc, std::string const& prefix="ll"); -struct local_assignment { +struct local_declaration { expression_ptr local_decl; - expression_ptr assignment; expression_ptr id; scope_ptr scope; }; +local_declaration make_unique_local_decl( + scope_ptr scope, + Location loc, + std::string const& prefix="ll"); + // Create a local declaration as for `make_unique_local_decl`, together with an // assignment to it from the given expression, using the location of that expression. // Returns local declaration expression, assignment expression, new identifier id and // consequent scope. + +struct local_assignment { + expression_ptr local_decl; + expression_ptr assignment; + expression_ptr id; + scope_ptr scope; +}; + local_assignment make_unique_local_assign( scope_ptr scope, Expression* e, diff --git a/modcc/error.hpp b/modcc/error.hpp index f113eff06de12c34c2f766e7e1b53e8db69e9101..e66fc45478faa8dde356e37b7e51d7b0460f000c 100644 --- a/modcc/error.hpp +++ b/modcc/error.hpp @@ -1,25 +1,76 @@ #pragma once +#include <deque> +#include <iterator> +#include <stdexcept> +#include <string> + #include "location.hpp" +struct error_entry { + std::string message; + Location location; + + error_entry(std::string m): message(std::move(m)) {} + error_entry(std::string m, Location l): message(std::move(m)), location(l) {} +}; + +// Mixin class for managing a stack of error info. + +class error_stack { +private: + std::deque<error_entry> errors_; + std::deque<error_entry> warnings_; + +public: + bool has_error() const { return !errors_.empty(); } + void error(error_entry info) { errors_.push_back(std::move(info)); } + void clear_errors() { errors_.clear(); } + + std::deque<error_entry>& errors() { return errors_; } + const std::deque<error_entry>& errors() const { return errors_; } + + template <typename Seq> + void append_errors(const Seq& seq) { + errors_.insert(errors_.end(), std::begin(seq), std::end(seq)); + } + + bool has_warning() const { return !warnings_.empty(); } + void warning(error_entry info) { warnings_.push_back(std::move(info)); } + void clear_warnings() { warnings_.clear(); } + + std::deque<error_entry>& warnings() { return warnings_; } + const std::deque<error_entry>& warnings() const { return warnings_; } + + template <typename Seq> + void append_warnings(const Seq& seq) { + warnings_.insert(warnings_.end(), std::begin(seq), std::end(seq)); + } +}; + +// Wrap error entry in exception. + class compiler_exception : public std::exception { public: + explicit compiler_exception(error_entry info) + : error_info_(std::move(info)) + {} + compiler_exception(std::string m, Location location) - : location_(location), - message_(std::move(m)) + : error_info_({std::move(m), location}) {} virtual const char* what() const throw() { - return message_.c_str(); + return error_info_.message.c_str(); } Location const& location() const { - return location_; + return error_info_.location; } private: - - Location location_; - std::string message_; + error_entry error_info_; }; + + diff --git a/modcc/expression.cpp b/modcc/expression.cpp index 7c3eca9f87e0abac094fc95f4cc8e0514ad0b0a5..7b05d65fe8c974a9bc54c5e01e81e9b6cc666a49 100644 --- a/modcc/expression.cpp +++ b/modcc/expression.cpp @@ -756,7 +756,7 @@ void BlockExpression::semantic(scope_ptr scp) { } expression_ptr BlockExpression::clone() const { - std::list<expression_ptr> statements; + expr_list_type statements; for(auto& e: statements_) { statements.emplace_back(e->clone()); } @@ -805,6 +805,29 @@ expression_ptr IfExpression::clone() const { ); } +/******************************************************************************* + PDiffExpression +*******************************************************************************/ + +std::string PDiffExpression::to_string() const { + return blue("pdiff") + " ( " + var_->to_string() + "; " + arg_->to_string() + ")"; +} + +void PDiffExpression::semantic(scope_ptr scp) { + scope_ = scp; + + if (!var_->is_identifier()) { + error(pprintf("the variable in the partial differential expression is not " + "an identifier, but instead %", yellow(var_->to_string()))); + } + var_->semantic(scp); + arg_->semantic(scp); +} + +expression_ptr PDiffExpression::clone() const { + return make_expression<PDiffExpression>(location_, var_->clone(), arg_->clone()); +} + #include "visitor.hpp" /* @@ -930,6 +953,9 @@ void PowBinaryExpression::accept(Visitor *v) { void ConditionalExpression::accept(Visitor *v) { v->visit(this); } +void PDiffExpression::accept(Visitor *v) { + v->visit(this); +} expression_ptr unary_expression( Location loc, tok op, diff --git a/modcc/expression.hpp b/modcc/expression.hpp index c20e00062a9fb37acaa91c4f3fd8ebdf47f323c8..a44bdc5bdea229a98ec4529c0a0e305f571009d4 100644 --- a/modcc/expression.hpp +++ b/modcc/expression.hpp @@ -55,11 +55,13 @@ class SolveExpression; class ConductanceExpression; class Symbol; class LocalVariable; +class PDiffExpression; // not parsed; possible result of symbolic differentiation using expression_ptr = std::unique_ptr<Expression>; using symbol_ptr = std::unique_ptr<Symbol>; using scope_type = Scope<Symbol>; using scope_ptr = std::shared_ptr<scope_type>; +using expr_list_type = std::list<expression_ptr>; template <typename T, typename... Args> expression_ptr make_expression(Args&&... args) { @@ -101,14 +103,16 @@ std::string to_string(symbolKind k); /// methods for time stepping state enum class solverMethod { - cnexp, // the only method we have at the moment + cnexp, // for diagonal linear ODE systems. + sparse, // for non-diagonal linear ODE systems. none }; static std::string to_string(solverMethod m) { switch(m) { - case solverMethod::cnexp : return std::string("cnexp"); - case solverMethod::none : return std::string("none"); + case solverMethod::cnexp: return std::string("cnexp"); + case solverMethod::sparse: return std::string("sparse"); + case solverMethod::none: return std::string("none"); } return std::string("<error : undefined solverMethod>"); } @@ -155,6 +159,7 @@ public: virtual expression_ptr clone() const; // easy lookup of properties + virtual CallExpression* is_call() {return nullptr;} virtual CallExpression* is_function_call() {return nullptr;} virtual CallExpression* is_procedure_call() {return nullptr;} virtual BlockExpression* is_block() {return nullptr;} @@ -179,6 +184,7 @@ public: virtual SolveExpression* is_solve_statement() {return nullptr;} virtual Symbol* is_symbol() {return nullptr;} virtual ConductanceExpression* is_conductance_statement() {return nullptr;} + virtual PDiffExpression* is_pdiff() {return nullptr;} virtual bool is_lvalue() const {return false;} virtual bool is_global_lvalue() const {return false;} @@ -732,13 +738,13 @@ private: class BlockExpression : public Expression { protected: - std::list<expression_ptr> statements_; + expr_list_type statements_; bool is_nested_ = false; public: BlockExpression( Location loc, - std::list<expression_ptr>&& statements, + expr_list_type&& statements, bool is_nested) : Expression(loc), statements_(std::move(statements)), @@ -749,7 +755,7 @@ public: return this; } - std::list<expression_ptr>& statements() { + expr_list_type& statements() { return statements_; } @@ -955,6 +961,9 @@ public: void accept(Visitor *v) override; + CallExpression* is_call() override { + return this; + } CallExpression* is_function_call() override { return symbol_->kind() == symbolKind::function ? this : nullptr; } @@ -1003,6 +1012,16 @@ public: BlockExpression* body() { return body_.get()->is_block(); } + void body(expression_ptr&& new_body) { + if(!new_body->is_block()) { + Location loc = new_body? new_body->location(): Location{}; + throw compiler_exception( + " attempt to set ProcedureExpression body with non-block expression, i.e.\n" + + new_body->to_string(), + loc); + } + body_ = std::move(new_body); + } void semantic(scope_type::symbol_map &scp) override; ProcedureExpression* is_procedure() override {return this;} @@ -1044,7 +1063,7 @@ class InitialBlock : public BlockExpression { public: InitialBlock( Location loc, - std::list<expression_ptr>&& statements) + expr_list_type&& statements) : BlockExpression(loc, std::move(statements), true) {} @@ -1315,3 +1334,25 @@ public: void accept(Visitor *v) override; }; + +class PDiffExpression : public Expression { +public: + PDiffExpression(Location loc, expression_ptr&& var, expression_ptr&& arg) + : Expression(loc), var_(std::move(var)), arg_(std::move(arg)) + {} + + std::string to_string() const override; + void accept(Visitor *v) override; + void semantic(scope_ptr scp) override; + expression_ptr clone() const override; + + PDiffExpression* is_pdiff() override { return this; } + + Expression* var() { return var_.get(); } + Expression* arg() { return arg_.get(); } + +private: + expression_ptr var_; + expression_ptr arg_; +}; + diff --git a/modcc/functionexpander.cpp b/modcc/functionexpander.cpp index 402ca71cfaa7c2677dcd9900079fa59526ed7e75..7cbc3545f1e109002cb64e8d183b7507b4dd8ed0 100644 --- a/modcc/functionexpander.cpp +++ b/modcc/functionexpander.cpp @@ -5,9 +5,7 @@ #include "functionexpander.hpp" #include "modccutil.hpp" -using namespace nest::mc; - -expression_ptr insert_unique_local_assignment(call_list_type& stmts, Expression* e) { +expression_ptr insert_unique_local_assignment(expr_list_type& stmts, Expression* e) { auto exprs = make_unique_local_assign(e->scope(), e); stmts.push_front(std::move(exprs.local_decl)); stmts.push_back(std::move(exprs.assignment)); @@ -18,9 +16,9 @@ expression_ptr insert_unique_local_assignment(call_list_type& stmts, Expression* // function call site lowering /////////////////////////////////////////////////////////////////////////////// -call_list_type lower_function_calls(Expression* e) +expr_list_type lower_function_calls(Expression* e) { - auto v = util::make_unique<FunctionCallLowerer>(e->scope()); + auto v = make_unique<FunctionCallLowerer>(e->scope()); if(auto a=e->is_assignment()) { #ifdef LOGGING @@ -107,10 +105,10 @@ void FunctionCallLowerer::visit(BinaryExpression *e) { // function argument lowering /////////////////////////////////////////////////////////////////////////////// -call_list_type +expr_list_type lower_function_arguments(std::vector<expression_ptr>& args) { - call_list_type new_statements; + expr_list_type new_statements; for(auto it=args.begin(); it!=args.end(); ++it) { // get reference to the unique_ptr with the expression auto& e = *it; diff --git a/modcc/functionexpander.hpp b/modcc/functionexpander.hpp index c9185332d458a9595d077fb6e841f9f5dbd5c9b8..b1ea95a2f3777d7c45869728406883074af02862 100644 --- a/modcc/functionexpander.hpp +++ b/modcc/functionexpander.hpp @@ -2,19 +2,17 @@ #include <sstream> +#include "expression.hpp" #include "scope.hpp" #include "visitor.hpp" -// storage for a list of expressions -using call_list_type = std::list<expression_ptr>; - // Make a local declaration and assignment for the given expression, // and insert at the front and back respectively of the statement list. // Return the new unique local identifier. -expression_ptr insert_unique_local_assignment(call_list_type& stmts, Expression* e); +expression_ptr insert_unique_local_assignment(expr_list_type& stmts, Expression* e); // prototype for lowering function calls -call_list_type lower_function_calls(Expression* e); +expr_list_type lower_function_calls(Expression* e); /////////////////////////////////////////////////////////////////////////////// // visitor that takes function call sites and lowers them to inline assignments @@ -48,11 +46,11 @@ public: void visit(NumberExpression *e) override {}; void visit(IdentifierExpression *e) override {}; - call_list_type& calls() { + expr_list_type& calls() { return calls_; } - call_list_type move_calls() { + expr_list_type move_calls() { return std::move(calls_); } @@ -67,7 +65,7 @@ private: replacer(std::move(id)); } - call_list_type calls_; + expr_list_type calls_; scope_ptr scope_; }; @@ -91,5 +89,5 @@ private: // If the calls_ data is spliced directly before the original statement // the function arguments will have been fully lowered /////////////////////////////////////////////////////////////////////////////// -call_list_type lower_function_arguments(std::vector<expression_ptr>& args); +expr_list_type lower_function_arguments(std::vector<expression_ptr>& args); diff --git a/modcc/functioninliner.cpp b/modcc/functioninliner.cpp index 837c36f03a3922e4bfbf7d8223bd3a2ee8024fc2..a652c60a1088dd904f2868e79ca2eba5d0dd8de4 100644 --- a/modcc/functioninliner.cpp +++ b/modcc/functioninliner.cpp @@ -5,8 +5,6 @@ #include "modccutil.hpp" #include "errorvisitor.hpp" -using namespace nest::mc; - expression_ptr inline_function_call(Expression* e) { if(auto f=e->is_function_call()) { @@ -35,12 +33,11 @@ expression_ptr inline_function_call(Expression* e) << id->to_string() << " -> " << fargs[i]->to_string() << " in the expression " << new_e->to_string() << "\n"; #endif - auto v = - util::make_unique<VariableReplacer>( - fargs[i]->is_argument()->spelling(), - id->spelling() - ); - new_e->accept(v.get()); + VariableReplacer v( + fargs[i]->is_argument()->spelling(), + id->spelling() + ); + new_e->accept(&v); } else if(auto value = cargs[i]->is_number()) { #ifdef LOGGING @@ -48,12 +45,11 @@ expression_ptr inline_function_call(Expression* e) << value->to_string() << " -> " << fargs[i]->to_string() << " in the expression " << new_e->to_string() << "\n"; #endif - auto v = - util::make_unique<ValueInliner>( - fargs[i]->is_argument()->spelling(), - value->value() - ); - new_e->accept(v.get()); + ValueInliner v( + fargs[i]->is_argument()->spelling(), + value->value() + ); + new_e->accept(&v); } else { throw compiler_exception( @@ -64,12 +60,12 @@ expression_ptr inline_function_call(Expression* e) } new_e->semantic(e->scope()); - auto v = util::make_unique<ErrorVisitor>(""); - new_e->accept(v.get()); + ErrorVisitor v(""); + new_e->accept(&v); #ifdef LOGGING std::cout << "inline_function_call result " << new_e->to_string() << "\n\n"; #endif - if(v->num_errors()) { + if(v.num_errors()) { throw compiler_exception("something went wrong with inlined function call ", e->location()); } diff --git a/modcc/kinrewriter.hpp b/modcc/kineticrewriter.cpp similarity index 58% rename from modcc/kinrewriter.hpp rename to modcc/kineticrewriter.cpp index 2ccf5c8aa00dee993fe180fd025397f06f3107eb..32b6d52a6424aca58c5bda5628db027272c00480 100644 --- a/modcc/kinrewriter.hpp +++ b/modcc/kineticrewriter.cpp @@ -1,72 +1,48 @@ -#pragma once - #include <iostream> #include <map> #include <string> #include <list> #include "astmanip.hpp" +#include "symdiff.hpp" #include "visitor.hpp" -using stmt_list_type = std::list<expression_ptr>; - -class KineticRewriter : public Visitor { +class KineticRewriter : public BlockRewriterBase { public: - virtual void visit(Expression *) override; + using BlockRewriterBase::visit; - virtual void visit(UnaryExpression *e) override { visit((Expression*)e); } - virtual void visit(BinaryExpression *e) override { visit((Expression*)e); } + KineticRewriter() {} + KineticRewriter(scope_ptr enclosing_scope): BlockRewriterBase(enclosing_scope) {} virtual void visit(ConserveExpression *e) override; virtual void visit(ReactionExpression *e) override; - virtual void visit(BlockExpression *e) override; - virtual void visit(ProcedureExpression* e) override; - - symbol_ptr as_procedure() { - stmt_list_type body_stmts; - for (const auto& s: statements) body_stmts.push_back(s->clone()); - - auto body = make_expression<BlockExpression>( - proc_loc, - std::move(body_stmts), - false); - - return make_symbol<ProcedureExpression>( - proc_loc, - proc_name, - std::vector<expression_ptr>(), - std::move(body)); - } -private: - // Name and location of original kinetic procedure (used for `as_procedure` above). - std::string proc_name; - Location proc_loc; +protected: + virtual void reset() override { + BlockRewriterBase::reset(); + dterms.clear(); + } - // Statements in replacement procedure body. - stmt_list_type statements; + virtual void finalize() override; +private: // Acccumulated terms for derivative expressions, keyed by id name. std::map<std::string, expression_ptr> dterms; - - // Reset state (at e.g. start of kinetic proc). - void reset() { - proc_name = ""; - statements.clear(); - dterms.clear(); - } }; -// By default, copy statements across verbatim. -inline void KineticRewriter::visit(Expression* e) { - statements.push_back(e->clone()); +expression_ptr kinetic_rewrite(BlockExpression* block) { + KineticRewriter visitor; + block->accept(&visitor); + return visitor.as_block(false); } -inline void KineticRewriter::visit(ConserveExpression*) { +// KineticRewriter implementation follows. + +void KineticRewriter::visit(ConserveExpression*) { // Deliberately ignoring these for now! } -inline void KineticRewriter::visit(ReactionExpression* e) { +void KineticRewriter::visit(ReactionExpression* e) { Location loc = e->location(); scope_ptr scope = e->scope(); @@ -79,10 +55,9 @@ inline void KineticRewriter::visit(ReactionExpression* e) { auto& id = term->is_stoich_term()->ident(); auto& coeff = term->is_stoich_term()->coeff(); - fwd = make_expression<MulBinaryExpression>( - loc, + fwd = constant_simplify(make_expression<MulBinaryExpression>(loc, make_expression<PowBinaryExpression>(loc, id->clone(), coeff->clone()), - std::move(fwd)); + std::move(fwd))); } // Similar for reverse rate. @@ -93,10 +68,9 @@ inline void KineticRewriter::visit(ReactionExpression* e) { auto& id = term->is_stoich_term()->ident(); auto& coeff = term->is_stoich_term()->coeff(); - rev = make_expression<MulBinaryExpression>( - loc, - make_expression<PowBinaryExpression>(loc, id->clone(), coeff->clone()), - std::move(rev)); + rev = constant_simplify(make_expression<MulBinaryExpression>(loc, + make_expression<PowBinaryExpression>(loc, id->clone(), coeff->clone()), + std::move(rev))); } auto net_rate = make_expression<SubBinaryExpression>( @@ -105,11 +79,11 @@ inline void KineticRewriter::visit(ReactionExpression* e) { net_rate->semantic(scope); auto local_net_rate = make_unique_local_assign(scope, net_rate, "rate"); - statements.push_back(std::move(local_net_rate.local_decl)); - statements.push_back(std::move(local_net_rate.assignment)); + statements_.push_back(std::move(local_net_rate.local_decl)); + statements_.push_back(std::move(local_net_rate.assignment)); scope = local_net_rate.scope; // nop for now... - auto net_rate_sym = std::move(local_net_rate.id); + const auto& net_rate_sym = local_net_rate.id; // Net change in quantity after forward reaction: // e.g. A + ... <-> 3A + ... @@ -142,8 +116,8 @@ inline void KineticRewriter::visit(ReactionExpression* e) { term->semantic(scope); auto local_term = make_unique_local_assign(scope, term, p.first+"_rate"); - statements.push_back(std::move(local_term.local_decl)); - statements.push_back(std::move(local_term.assignment)); + statements_.push_back(std::move(local_term.local_decl)); + statements_.push_back(std::move(local_term.assignment)); scope = local_term.scope; // nop for now... auto& dterm = dterms[p.first]; @@ -163,13 +137,8 @@ inline void KineticRewriter::visit(ReactionExpression* e) { } } -inline void KineticRewriter::visit(ProcedureExpression* e) { - reset(); - proc_name = e->name(); - proc_loc = e->location(); - e->body()->accept(this); - - // make new procedure from saved statements and terms +void KineticRewriter::finalize() { + // append new derivative assignments from saved terms for (auto& p: dterms) { auto loc = p.second->location(); auto scope = p.second->scope(); @@ -184,13 +153,7 @@ inline void KineticRewriter::visit(ProcedureExpression* e) { std::move(deriv), std::move(p.second)); - assign->scope(scope); // don't re-do semantic analysis here - statements.push_back(std::move(assign)); + statements_.push_back(std::move(assign)); } } -inline void KineticRewriter::visit(BlockExpression* e) { - for (auto& s: e->statements()) { - s->accept(this); - } -} diff --git a/modcc/kineticrewriter.hpp b/modcc/kineticrewriter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3dfb2648a1006bc6432e9bd7d082e988b49a0795 --- /dev/null +++ b/modcc/kineticrewriter.hpp @@ -0,0 +1,6 @@ +#pragma once + +#include "expression.hpp" + +// Translate a supplied KINETIC block to equivalent DERIVATIVE block. +expression_ptr kinetic_rewrite(BlockExpression*); diff --git a/modcc/modcc.cpp b/modcc/modcc.cpp index 0375eece7adddbe903a58ee95372e2a8479d7ba1..c1c210ebce0c098278bad2ae0a1a1e0b80dbf6e3 100644 --- a/modcc/modcc.cpp +++ b/modcc/modcc.cpp @@ -13,8 +13,6 @@ #include "modccutil.hpp" #include "options.hpp" -using namespace nest::mc; - //#define VERBOSE int main(int argc, char **argv) { @@ -113,7 +111,7 @@ int main(int argc, char **argv) { std::cout << m.error_string() << std::endl; } - if(m.status() == lexerStatus::error) { + if(m.has_error()) { return 1; } @@ -123,7 +121,7 @@ int main(int argc, char **argv) { if(Options::instance().optimize) { if(Options::instance().verbose) std::cout << green("[") + "optimize" + green("]") << std::endl; m.optimize(); - if(m.status() == lexerStatus::error) { + if(m.has_error()) { return 1; } } @@ -175,15 +173,15 @@ int main(int argc, char **argv) { std::cout << yellow("method " + method->name()) << "\n"; std::cout << white("-------------------------\n"); - auto flops = util::make_unique<FlopVisitor>(); - method->accept(flops.get()); + FlopVisitor flops; + method->accept(&flops); std::cout << white("FLOPS") << std::endl; - std::cout << flops->print() << std::endl; + std::cout << flops.print() << std::endl; std::cout << white("MEMOPS") << std::endl; - auto memops = util::make_unique<MemOpVisitor>(); - method->accept(memops.get()); - std::cout << memops->print() << std::endl;; + MemOpVisitor memops; + method->accept(&memops); + std::cout << memops.print() << std::endl;; } } } diff --git a/modcc/modccutil.hpp b/modcc/modccutil.hpp index 63f0701a57da67c4cf029882cc1da1417df9374e..ad83a03a30400e2bd02c34f4d08cdf104a8364b8 100644 --- a/modcc/modccutil.hpp +++ b/modcc/modccutil.hpp @@ -6,25 +6,39 @@ #include <vector> #include <initializer_list> -// is thing in list? -template <typename T, int N> -bool is_in(T thing, const T (&list)[N]) { - for(auto const& item : list) { - if(thing==item) { - return true; +namespace impl { + template <typename C, typename V> + struct has_count_method { + template <typename T, typename U> + static decltype(std::declval<T>().count(std::declval<U>()), std::true_type{}) test(int); + template <typename T, typename U> + static std::false_type test(...); + + using type = decltype(test<C, V>(0)); + }; + + template <typename X, typename C> + bool is_in(const X& x, const C& c, std::false_type) { + for (const auto& y: c) { + if (y==x) return true; } + return false; } - return false; -} -template <typename T> -bool is_in(T thing, const std::initializer_list<T> list) { - for(auto const& item : list) { - if(thing==item) { - return true; - } + template <typename X, typename C> + bool is_in(const X& x, const C& c, std::true_type) { + return !!c.count(x); } - return false; +} + +template <typename X, typename C> +bool is_in(const X& x, const C& c) { + return impl::is_in(x, c, typename impl::has_count_method<C,X>::type{}); +} + +template <typename X> +bool is_in(const X& x, const std::initializer_list<X>& c) { + return impl::is_in(x, c, std::false_type{}); } inline std::string pprintf(const char *s) { @@ -138,16 +152,8 @@ std::ostream& operator<< (std::ostream& os, std::vector<T> const& V) { return os << "]"; } -namespace nest { -namespace mc { -namespace util { - -// just because we aren't using C++14, doesn't mean we shouldn't go -// without make_unique template <typename T, typename... Args> std::unique_ptr<T> make_unique(Args&&... args) { return std::unique_ptr<T>(new T(std::forward<Args>(args) ...)); } -}}} - diff --git a/modcc/module.cpp b/modcc/module.cpp index e65aecd03a3fab2055c6a8526ea3e474b438b566..2b1590d667c3872aa4dfd142ecd0fea816baba20 100644 --- a/modcc/module.cpp +++ b/modcc/module.cpp @@ -8,10 +8,82 @@ #include "expressionclassifier.hpp" #include "functionexpander.hpp" #include "functioninliner.hpp" +#include "kineticrewriter.hpp" #include "module.hpp" #include "parser.hpp" +#include "solvers.hpp" +#include "symdiff.hpp" +#include "visitor.hpp" -using namespace nest::mc; +class NrnCurrentRewriter: public BlockRewriterBase { + expression_ptr id(const std::string& name, Location loc) { + return make_expression<IdentifierExpression>(loc, name); + } + + expression_ptr id(const std::string& name) { + return id(name, loc_); + } + + static ionKind is_ion_update(Expression* e) { + if(auto a = e->is_assignment()) { + if(auto sym = a->lhs()->is_identifier()->symbol()) { + if(auto var = sym->is_local_variable()) { + return var->ion_channel(); + } + } + } + return ionKind::none; + } + + moduleKind kind_; + bool has_current_update_ = false; + +public: + using BlockRewriterBase::visit; + + explicit NrnCurrentRewriter(moduleKind kind): kind_(kind) {} + + virtual void finalize() override { + if (has_current_update_) { + // Initialize current_ as first statement. + statements_.push_front(make_expression<AssignmentExpression>(loc_, + id("current_"), + make_expression<NumberExpression>(loc_, 0.0))); + + if (kind_==moduleKind::density) { + statements_.push_back(make_expression<AssignmentExpression>(loc_, + id("current_"), + make_expression<MulBinaryExpression>(loc_, + id("weights_"), + id("current_")))); + } + } + } + + virtual void visit(SolveExpression *e) override {} + virtual void visit(ConductanceExpression *e) override {} + virtual void visit(AssignmentExpression *e) override { + statements_.push_back(e->clone()); + auto loc = e->location(); + + if (is_ion_update(e)!=ionKind::none) { + has_current_update_ = true; + + if (!linear_test(e->rhs(), {"v"}).is_linear) { + error({"current update expressions must be linear in v: "+e->rhs()->to_string(), + e->location()}); + return; + } + else { + statements_.push_back(make_expression<AssignmentExpression>(loc, + id("current_", loc), + make_expression<AddBinaryExpression>(loc, + id("current_", loc), + e->lhs()->clone()))); + } + } + } +}; Module::Module(std::string const& fname) : fname_(fname) @@ -78,22 +150,26 @@ Module::symbols() const { return symbols_; } -void Module::error(std::string const& msg, Location loc) { - std::string location_info = pprintf("%:% ", file_name(), loc); - if(error_string_.size()) {// append to current string - error_string_ += "\n"; +std::string Module::error_string() const { + std::string str; + for (const error_entry& entry: errors()) { + if (!str.empty()) str += '\n'; + str += red("error "); + str += white(pprintf("%:% ", file_name(), entry.location)); + str += entry.message; } - error_string_ += red("error ") + white(location_info) + msg; - status_ = lexerStatus::error; + return str; } -void Module::warning(std::string const& msg, Location loc) { - std::string location_info = pprintf("%:% ", file_name(), loc); - if(error_string_.size()) {// append to current string - error_string_ += "\n"; +std::string Module::warning_string() const { + std::string str; + for (const error_entry& entry: errors()) { + if (!str.empty()) str += '\n'; + str += purple("error "); + str += white(pprintf("%:% ", file_name(), entry.location)); + str += entry.message; } - error_string_ += purple("warning ") + white(location_info) + msg; - has_warning_ = true; + return str; } bool Module::semantic() { @@ -165,13 +241,13 @@ bool Module::semantic() { s->semantic(symbols_); // then use an error visitor to print out all the semantic errors - auto v = util::make_unique<ErrorVisitor>(file_name()); - s->accept(v.get()); - errors += v->num_errors(); + ErrorVisitor v(file_name()); + s->accept(&v); + errors += v.num_errors(); // inline function calls // this requires that the symbol table has already been built - if(v->num_errors()==0) { + if(v.num_errors()==0) { auto &b = s->kind()==symbolKind::function ? s->is_function()->body()->statements() : s->is_procedure()->body()->statements(); @@ -246,9 +322,7 @@ bool Module::semantic() { } if(errors) { - std::cout << "\nthere were " << errors - << " errors in the semantic analysis" << std::endl; - status_ = lexerStatus::error; + error("There were "+std::to_string(errors)+" errors in the semantic analysis"); return false; } @@ -281,7 +355,7 @@ bool Module::semantic() { loc, name, std::vector<expression_ptr>(), // no arguments make_expression<BlockExpression> - (loc, std::list<expression_ptr>(), false) + (loc, expr_list_type(), false) ); auto proc = symbols_[name]->is_api_method(); @@ -312,46 +386,6 @@ bool Module::semantic() { return false; } - - // evaluate whether an expression has the form - // (b - x)/a - // where x is a state variable with name state_variable - // this test is used to detect ODEs with the signature - // dx/dt = (xinf - x)/xtau - // so that we can integrate them efficiently using the cnexp integrator - // - // this is messy, but ok for a once off. If this pattern is - // repeated, it will be worth finding a more sophisticated solution - auto is_gating = [] (Expression* e, std::string const& state_variable) { - IdentifierExpression* a = nullptr; - IdentifierExpression* b = nullptr; - BinaryExpression* other = nullptr; - if(auto binop = e->is_binary()) { - if(binop->op()==tok::divide) { - if((a = binop->rhs()->is_identifier())) { - other = binop->lhs()->is_binary(); - } - } - } - if(other) { - if(other->op()==tok::minus) { - if(auto rhs = other->rhs()->is_identifier()) { - if(rhs->name() == state_variable) { - if(auto lhs = other->lhs()->is_identifier()) { - if(lhs->name() != state_variable) { - b = lhs; - return std::make_pair(a, b); - } - } - } - } - } - } - - a = b = nullptr; - return std::make_pair(a, b); - }; - // Look in the symbol table for a procedure with the name "breakpoint". // This symbol corresponds to the BREAKPOINT block in the .mod file // There are two APIMethods generated from BREAKPOINT. @@ -361,222 +395,118 @@ bool Module::semantic() { auto api_state = state_api.first; auto breakpoint = state_api.second; - if( breakpoint ) { - // helper for making identifiers on the fly - auto id = [] (std::string const& name, Location loc=Location()) { - return make_expression<IdentifierExpression>(loc, name); - }; + api_state->semantic(symbols_); + scope_ptr nrn_state_scope = api_state->scope(); + + if(breakpoint) { //.......................................................... // nrn_state : The temporal integration of state variables //.......................................................... - // find the SOLVE statement - SolveExpression* solve_expression = nullptr; - for(auto& e: *(breakpoint->body())) { - solve_expression = e->is_solve_statement(); - if(solve_expression) break; - } + // grab SOLVE statements, put them in `nrn_state` after translation. + bool found_solve = false; + bool found_non_solve = false; + std::set<std::string> solved_ids; - // handle the case where there is no SOLVE in BREAKPOINT - if( solve_expression==nullptr ) { - warning( " there is no SOLVE statement, required to update the" - " state variables, in the BREAKPOINT block", - breakpoint->location()); - } - else { - // get the DERIVATIVE block - auto dblock = solve_expression->procedure(); - - // body refers to the currently empty body of the APIMethod that - // will hold the AST for the nrn_state function. - auto& body = api_state->body()->statements(); - - auto has_provided_integration_method = - solve_expression->method() == solverMethod::cnexp; - - // loop over the statements in the SOLVE block from the mod file - // put each statement into the new APIMethod, performing - // transformations if necessary. - for(auto& e : *(dblock->body())) { - if(auto ass = e->is_assignment()) { - auto lhs = ass->lhs(); - auto rhs = ass->rhs(); - if(auto deriv = lhs->is_derivative()) { - // Check that a METHOD was provided in the original SOLVE - // statment. We have to do this because it is possible - // to call SOLVE without a METHOD, in which case there should - // be no derivative expressions in the DERIVATIVE block. - if(!has_provided_integration_method) { - error("The DERIVATIVE block has a derivative expression" - " but no METHOD was specified in the SOLVE statement", - deriv->location()); - return false; - } + for(auto& e: (breakpoint->body()->statements())) { + SolveExpression* solve_expression = e->is_solve_statement(); + if(!solve_expression) { + found_non_solve = true; + continue; + } + if(found_non_solve) { + error("SOLVE statements must come first in BREAKPOINT block", + e->location()); + return false; + } - auto sym = deriv->symbol(); - auto name = deriv->name(); - - auto gating_vars = is_gating(rhs, name); - if(gating_vars.first && gating_vars.second) { - auto const& inf = gating_vars.second->spelling(); - auto const& rate = gating_vars.first->spelling(); - auto e_string = name + "=" + inf - + "+(" + name + "-" + inf + ")*exp(-dt/" - + rate + ")"; - auto stmt_update = Parser(e_string).parse_line_expression(); - body.emplace_back(std::move(stmt_update)); - continue; - } - else { - // create visitor for linear analysis - auto v = util::make_unique<ExpressionClassifierVisitor>(sym); - rhs->accept(v.get()); - - // quit if ODE is not linear - if( v->classify() != expressionClassification::linear ) { - error("unable to integrate nonlinear state ODEs", - rhs->location()); - return false; - } - - // the linear differential equation is of the form - // s' = a*s + b - // integration by separation of variables gives the following - // update function to integrate s for one time step dt - // s = -b/a + (s+b/a)*exp(a*dt) - // we are going to build this update function by - // 1. generating statements that define a_=a and ba_=b/a - // 2. generating statements that update the solution - - // statement : a_ = a - auto stmt_a = - binary_expression(Location(), - tok::eq, - id("a_"), - v->linear_coefficient()->clone()); - - // expression : b/a - auto expr_ba = - binary_expression(Location(), - tok::divide, - v->constant_term()->clone(), - id("a_")); - // statement : ba_ = b/a - auto stmt_ba = binary_expression(Location(), tok::eq, id("ba_"), std::move(expr_ba)); - - // the update function - auto e_string = name + " = -ba_ + " - "(" + name + " + ba_)*exp(a_*dt)"; - auto stmt_update = Parser(e_string).parse_line_expression(); - - // add declaration of local variables - body.emplace_back(Parser("LOCAL a_").parse_local()); - body.emplace_back(Parser("LOCAL ba_").parse_local()); - // add integration statements - body.emplace_back(std::move(stmt_a)); - body.emplace_back(std::move(stmt_ba)); - body.emplace_back(std::move(stmt_update)); - continue; - } - } - else { - body.push_back(e->clone()); - continue; - } - } - body.push_back(e->clone()); + found_solve = true; + std::unique_ptr<SolverVisitorBase> solver; + + switch(solve_expression->method()) { + case solverMethod::cnexp: + solver = make_unique<CnexpSolverVisitor>(); + break; + case solverMethod::sparse: + solver = make_unique<SparseSolverVisitor>(); + break; + case solverMethod::none: + solver = make_unique<DirectSolverVisitor>(); + break; } - } - // perform semantic analysis - api_state->semantic(symbols_); + // If the derivative block is a kinetic block, perform kinetic + // rewrite first. - //.......................................................... - // nrn_current : update contributions to currents - //.......................................................... - std::list<expression_ptr> block; - - // helper which tests a statement to see if it updates an ion - // channel variable. - auto is_ion_update = [] (Expression* e) { - if(auto a = e->is_assignment()) { - // semantic analysis has been performed on the original expression - // which ensures that the lhs is an identifier and a variable - if(auto sym = a->lhs()->is_identifier()->symbol()) { - // assume that a scalar stack variable is being used for - // the indexed value: i.e. the value is not cached - if(auto var = sym->is_local_variable()) { - return var->ion_channel(); - } - } + auto deriv = solve_expression->procedure(); + + if (deriv->kind()==procedureKind::kinetic) { + kinetic_rewrite(deriv->body())->accept(solver.get()); } - return ionKind::none; - }; - - // add statements that initialize the reduction variables - bool has_current_update = false; - for(auto& e: *(breakpoint->body())) { - // ignore solve and conductance statements - if(e->is_solve_statement()) continue; - if(e->is_conductance_statement()) continue; - - // add the expression - block.emplace_back(e->clone()); - - // we are updating an ionic current - // so keep track of current and conductance accumulation - auto channel = is_ion_update(e.get()); - if(channel != ionKind::none) { - auto lhs = e->is_assignment()->lhs()->is_identifier(); - auto rhs = e->is_assignment()->rhs(); - - // analyze the expression for linear terms - //auto v = util::make_unique<ExpressionClassifierVisitor>(symbols_["v"].get()); - auto v_symbol = breakpoint->scope()->find("v"); - auto v = util::make_unique<ExpressionClassifierVisitor>(v_symbol); - rhs->accept(v.get()); - - if(v->classify()==expressionClassification::linear) { - // add current update - if(has_current_update) { - block.emplace_back(Parser("current_ = current_ + " + lhs->name()).parse_line_expression()); - } - else { - block.emplace_back(Parser("current_ = " + lhs->name()).parse_line_expression()); + else { + deriv->body()->accept(solver.get()); + } + + if (auto solve_block = solver->as_block(false)) { + // Check that we didn't solve an already solved variable. + for (const auto& id: solver->solved_identifiers()) { + if (solved_ids.count(id)>0) { + error("Variable "+id+" solved twice!", e->location()); + return false; } + solved_ids.insert(id); } - else { - error("current update functions must be a linear" - " function of v : " + rhs->to_string(), e->location()); - return false; + + // May have now redundant local variables; remove these first. + solve_block = remove_unused_locals(solve_block->is_block()); + + // Copy body into nrn_state. + for (auto& stmt: solve_block->is_block()->statements()) { + api_state->body()->statements().push_back(std::move(stmt)); } - has_current_update = true; } + else { + // Something went wrong: copy errors across. + append_errors(solver->errors()); + return false; + } + } + + // handle the case where there is no SOLVE in BREAKPOINT + if(!found_solve) { + warning(" there is no SOLVE statement, required to update the" + " state variables, in the BREAKPOINT block", + breakpoint->location()); } - if(has_current_update && kind()==moduleKind::density) { - block.emplace_back(Parser("current_ = weights_ * current_").parse_line_expression()); + else { + // redo semantic pass in order to elimate any removed local symbols. + api_state->semantic(symbols_); } - auto v = util::make_unique<ConstantFolderVisitor>(); - for(auto& e : block) { - e->accept(v.get()); + //.......................................................... + // nrn_current : update contributions to currents + //.......................................................... + NrnCurrentRewriter nrn_current_rewriter(kind()); + breakpoint->accept(&nrn_current_rewriter); + auto nrn_current_block = nrn_current_rewriter.as_block(); + if (!nrn_current_block) { + append_errors(nrn_current_rewriter.errors()); + return false; } symbols_["nrn_current"] = make_symbol<APIMethod>( breakpoint->location(), "nrn_current", std::vector<expression_ptr>(), - make_expression<BlockExpression>(breakpoint->location(), - std::move(block), false) - ); + constant_simplify(nrn_current_block)); symbols_["nrn_current"]->semantic(symbols_); } else { - error("a BREAKPOINT block is required", Location()); + error("a BREAKPOINT block is required"); return false; } - return status() == lexerStatus::happy; + return !has_error(); } /// populate the symbol table with class scope variables @@ -782,7 +712,7 @@ bool Module::optimize() { // how to structure the optimizer // loop over APIMethods // - apply optimization to each in turn - auto folder = util::make_unique<ConstantFolderVisitor>(); + ConstantFolderVisitor folder; for(auto &symbol : symbols_) { auto kind = symbol.second->kind(); BlockExpression* body; @@ -809,7 +739,7 @@ bool Module::optimize() { // perform constant folding for(auto& line : *body) { - line->accept(folder.get()); + line->accept(&folder); } // preform expression simplification diff --git a/modcc/module.hpp b/modcc/module.hpp index 5a16e64c45c75adbe17a2f88c169c73722b26cb4..41abdeb5256ed77b499ea6e0ed1aae038f2b38fd 100644 --- a/modcc/module.hpp +++ b/modcc/module.hpp @@ -4,11 +4,12 @@ #include <vector> #include "blocks.hpp" +#include "error.hpp" #include "expression.hpp" // wrapper around a .mod file -class Module { -public : +class Module: public error_stack { +public: using symbol_map = scope_type::symbol_map; using symbol_ptr = scope_type::symbol_ptr; @@ -58,36 +59,30 @@ public : symbol_map const& symbols() const; // error handling - void error(std::string const& msg, Location loc); - std::string const& error_string() { - return error_string_; + using error_stack::error; + void error(std::string const& msg, Location loc = Location{}) { + error({msg, loc}); } - lexerStatus status() const { - return status_; - } + std::string error_string() const; // warnings - void warning(std::string const& msg, Location loc); - bool has_warning() const { - return has_warning_; - } - bool has_error() const { - return status()==lexerStatus::error; + using error_stack::warning; + void warning(std::string const& msg, Location loc = Location{}) { + warning({msg, loc}); } - moduleKind kind() const { - return kind_; - } - void kind(moduleKind k) { - kind_ = k; - } + std::string warning_string() const; + + moduleKind kind() const { return kind_; } + void kind(moduleKind k) { kind_ = k; } // perform semantic analysis void add_variables_to_symbols(); bool semantic(); bool optimize(); -private : + +private: moduleKind kind_; std::string title_; std::string fname_; @@ -97,11 +92,6 @@ private : bool generate_current_api(); bool generate_state_api(); - // error handling - std::string error_string_; - lexerStatus status_ = lexerStatus::happy; - bool has_warning_ = false; - // AST storage std::vector<symbol_ptr> procedures_; std::vector<symbol_ptr> functions_; diff --git a/modcc/msparse.hpp b/modcc/msparse.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c7868764e93861d7d49bbb25b91ee3ecaa7cd573 --- /dev/null +++ b/modcc/msparse.hpp @@ -0,0 +1,206 @@ +#pragma once + +// (Possibly augmented) matrix implementation, represented as a vector of sparse rows. + +#include <algorithm> +#include <utility> +#include <initializer_list> +#include <iterator> +#include <stdexcept> +#include <vector> + +namespace msparse { + +struct msparse_error: std::runtime_error { + msparse_error(const std::string &what): std::runtime_error(what) {} +}; + +constexpr unsigned row_npos = unsigned(-1); + +// `msparse::row` represents one sparse matrix row as a vector of +// (column, value) pairs, ordered by (unsigned) column. `row_npos` +// is used to represent an invalid column number. + +template <typename X> +class row { +public: + struct entry { + unsigned col; + X value; + }; + static constexpr unsigned npos = row_npos; + +private: + std::vector<entry> data_; + + // Entries must have strictly monotonically increasing column numbers. + bool check_invariant() const { + for (unsigned i = 1; i<data_.size(); ++i) { + if (data_[i].col<=data_[i-1].col) return false; + } + return true; + } + +public: + row() = default; + row(const row&) = default; + + row(std::initializer_list<entry> il): data_(il) { + if (!check_invariant()) + throw msparse_error("improper row element list"); + } + + template <typename InIter> + row(InIter b, InIter e): data_(b, e) { + if (!check_invariant()) + throw msparse_error("improper row element list"); + } + + unsigned size() const { return data_.size(); } + bool empty() const { return size()==0; } + + // Iterators present row as sequence of `entry` objects. + auto begin() -> decltype(data_.begin()) { return data_.begin(); } + auto begin() const -> decltype(data_.cbegin()) { return data_.cbegin(); } + auto end() -> decltype(data_.end()) { return data_.end(); } + auto end() const -> decltype(data_.cend()) { return data_.cend(); } + + // Return column of first (left-most) entry. + unsigned mincol() const { + return empty()? npos: data_.front().col; + } + + // Return column of first entry with column greater than `c`. + unsigned mincol_after(unsigned c) const { + auto i = std::upper_bound(data_.begin(), data_.end(), c, + [](unsigned a, const entry& b) { return a<b.col; }); + + return i==data_.end()? npos: i->col; + } + + // Return column of last (right-most) entry. + unsigned maxcol() const { + return empty()? npos: data_.back().col; + } + + // As opposed to [] indexing (see below), retrieve `i'th entry from + // the list of entries. + const entry& get(unsigned i) const { + return data_[i]; + } + + void push_back(const entry& e) { + if (!empty() && e.col <= data_.back().col) + throw msparse_error("cannot push_back row elements out of order"); + data_.push_back(e); + } + + // Return index into entry list which has column `c`. + unsigned index(unsigned c) const { + auto i = std::lower_bound(data_.begin(), data_.end(), c, + [](const entry& a, unsigned b) { return a.col<b; }); + + return (i==data_.end() || i->col!=c)? npos: std::distance(data_.begin(), i); + } + + // Remove all entries from column `c` onwards. + void truncate(unsigned c) { + auto i = std::lower_bound(data_.begin(), data_.end(), c, + [](const entry& a, unsigned b) { return a.col<b; }); + data_.erase(i, data_.end()); + } + + // Return value at column `c`; if no entry in row, return default-constructed `X`, + // i.e. 0 for numeric types. + X operator[](unsigned c) const { + auto i = index(c); + return i==npos? X{}: data_[i].value; + } + + // Proxy object to allow assigning elements with the syntax `row[c] = value`. + struct assign_proxy { + row<X>& row_; + unsigned c; + + assign_proxy(row<X>& r, unsigned c): row_(r), c(c) {} + + operator X() const { return const_cast<const row<X>&>(row_)[c]; } + assign_proxy& operator=(const X& x) { + auto i = std::lower_bound(row_.data_.begin(), row_.data_.end(), c, + [](const entry& a, unsigned b) { return a.col<b; }); + + if (i==row_.data_.end() || i->col!=c) { + row_.data_.insert(i, {c, x}); + } + else if (x == X{}) { + row_.data_.erase(i); + } + else { + i->value = x; + } + + return *this; + } + }; + + assign_proxy operator[](unsigned c) { + return assign_proxy{*this, c}; + } +}; + +// `msparse::matrix` represents a matrix by a size (number of rows, +// columns) and vector of sparse `mspase::row` rows. +// +// The matrix may also be 'augmented', with columns corresponding to a second +// matrix appended on the right. + +template <typename X> +class matrix { +private: + std::vector<row<X>> rows; + unsigned cols = 0; + unsigned aug = row_npos; + +public: + static constexpr unsigned npos = row_npos; + + matrix() = default; + matrix(unsigned n, unsigned c): rows(n), cols(c) {} + + row<X>& operator[](unsigned i) { return rows[i]; } + const row<X>& operator[](unsigned i) const { return rows[i]; } + + unsigned size() const { return rows.size(); } + unsigned nrow() const { return size(); } + unsigned ncol() const { return cols; } + + // First column corresponding to the augmented submatrix. + unsigned augcol() const { return aug; } + + bool empty() const { return size()==0; } + bool augmented() const { return aug!=npos; } + + // Add a column on the right as part of the augmented submatrix. + // The new entries are provided by a (full, dense representation) + // sequence of values. + template <typename Seq> + void augment(const Seq& col_dense) { + unsigned r = 0; + for (const auto& v: col_dense) { + if (r>=rows.size()) throw msparse_error("augmented column size mismatch"); + rows[r++].push_back({cols, v}); + } + if (aug==npos) aug=cols; + ++cols; + } + + // Remove all augmented columns. + void diminish() { + if (aug==npos) return; + for (auto& row: rows) row.truncate(aug); + cols = aug; + aug = npos; + } +}; + +} // namespace msparse diff --git a/modcc/parser.cpp b/modcc/parser.cpp index 7bf11c37f9140b6800240eca05256a44c1ed1146..84f7941371baef881239581a6e30bbc88eaeb07d 100644 --- a/modcc/parser.cpp +++ b/modcc/parser.cpp @@ -1312,8 +1312,16 @@ expression_ptr Parser::parse_solve() { } else { get_token(); // consume the METHOD keyword - if(token_.type != tok::cnexp) goto solve_statement_error; - method = solverMethod::cnexp; + switch(token_.type) { + case tok::cnexp: + method = solverMethod::cnexp; + break; + case tok::sparse: + method = solverMethod::sparse; + break; + default: + goto solve_statement_error; + } get_token(); // consume the method description } @@ -1326,10 +1334,11 @@ expression_ptr Parser::parse_solve() { solve_statement_error: error( "SOLVE statements must have the form\n" - " SOLVE x METHOD cnexp\n" + " SOLVE x METHOD method\n" " or\n" " SOLVE x\n" - "where 'x' is the name of a DERIVATIVE block", loc); + "where 'x' is the name of a DERIVATIVE block and " + "'method' is 'cnexp' or 'sparse'", loc); return nullptr; } @@ -1431,7 +1440,7 @@ expression_ptr Parser::parse_block(bool is_nested) { // save the location of the first statement as the starting point for the block Location block_location = token_.location; - std::list<expression_ptr> body; + expr_list_type body; while(token_.type != tok::rbrace) { auto e = parse_statement(); if(!e) return e; @@ -1473,7 +1482,7 @@ expression_ptr Parser::parse_initial() { if(!expect(tok::lbrace)) return nullptr; get_token(); // consume '{' - std::list<expression_ptr> body; + expr_list_type body; while(token_.type != tok::rbrace) { auto e = parse_statement(); if(!e) return e; diff --git a/modcc/solvers.cpp b/modcc/solvers.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f3f4a7d6f7d9675ca28bde0300174bc7306896b3 --- /dev/null +++ b/modcc/solvers.cpp @@ -0,0 +1,448 @@ +#include <map> +#include <set> +#include <string> +#include <vector> + +#include "astmanip.hpp" +#include "expression.hpp" +#include "parser.hpp" +#include "solvers.hpp" +#include "symdiff.hpp" +#include "symge.hpp" +#include "visitor.hpp" + +// Cnexp solver visitor implementation. + +void CnexpSolverVisitor::visit(BlockExpression* e) { + // Do a first pass to extract variables comprising ODE system + // lhs; can't really trust 'STATE' block. + + for (auto& stmt: e->statements()) { + if (stmt && stmt->is_assignment() && stmt->is_assignment()->lhs()->is_derivative()) { + auto id = stmt->is_assignment()->lhs()->is_derivative(); + dvars_.push_back(id->name()); + } + } + + BlockRewriterBase::visit(e); +} + +void CnexpSolverVisitor::visit(AssignmentExpression *e) { + auto loc = e->location(); + scope_ptr scope = e->scope(); + + auto lhs = e->lhs(); + auto rhs = e->rhs(); + auto deriv = lhs->is_derivative(); + + if (!deriv) { + statements_.push_back(e->clone()); + return; + } + + auto s = deriv->name(); + linear_test_result r = linear_test(rhs, dvars_); + + if (!r.monolinear(s)) { + error({"System not diagonal linear for cnexp", loc}); + return; + } + + Expression* coef = r.coef[s].get(); + if (r.is_homogeneous) { + // s' = a*s becomes s = s*exp(a*dt); use a_ as a local variable + // for the coefficient. + auto local_a_term = make_unique_local_assign(scope, coef, "a_"); + statements_.push_back(std::move(local_a_term.local_decl)); + statements_.push_back(std::move(local_a_term.assignment)); + auto a_ = local_a_term.id->is_identifier()->spelling(); + + std::string s_update = pprintf("% = %*exp(%*dt)", s, s, a_); + statements_.push_back(Parser{s_update}.parse_line_expression()); + return; + } + else if (is_zero(coef)) { + // s' = b becomes s = s + b*dt; use b_ as a local variable for + // the constant term b. + + auto local_b_term = make_unique_local_assign(scope, r.constant.get(), "b_"); + statements_.push_back(std::move(local_b_term.local_decl)); + statements_.push_back(std::move(local_b_term.assignment)); + auto b_ = local_b_term.id->is_identifier()->spelling(); + + std::string s_update = pprintf("% = %+%*dt", s, s, b_); + statements_.push_back(Parser{s_update}.parse_line_expression()); + return; + } + else { + // s' = a*s + b becomes s = -b/a + (s+b/a)*exp(a*dt); use + // a_ as a local variable for the coefficient and ba_ for the + // quotient. + // + // Note though this will be numerically bad for very small + // (or zero) a. Perhaps re-implement as: + // s = s + exprel(a*dt)*(s*a+b)*dt + // where exprel(x) = (exp(x)-1)/x and can be well approximated + // by e.g. a Taylor expansion for small x. + // + // Special case ('gating variable') when s' = (b-s)/a; rather + // than implement more general algebraic simplification, jump + // straight to simplified update: s = b + (s-b)*exp(-dt/a). + + // Check for 'gating' case: + if (rhs->is_binary() && rhs->is_binary()->op()==tok::divide) { + auto denom = rhs->is_binary()->rhs(); + if (involves_identifier(denom, s)) { + goto not_gating; + } + auto numer = rhs->is_binary()->lhs(); + linear_test_result r = linear_test(numer, {s}); + if (expr_value(r.coef[s]) != -1) { + goto not_gating; + } + + auto local_a_term = make_unique_local_assign(scope, denom, "a_"); + auto a_ = local_a_term.id->is_identifier()->spelling(); + auto local_b_term = make_unique_local_assign(scope, r.constant, "b_"); + auto b_ = local_b_term.id->is_identifier()->spelling(); + + statements_.push_back(std::move(local_a_term.local_decl)); + statements_.push_back(std::move(local_a_term.assignment)); + statements_.push_back(std::move(local_b_term.local_decl)); + statements_.push_back(std::move(local_b_term.assignment)); + + std::string s_update = pprintf("% = %+(%-%)*exp(-dt/%)", s, b_, s, b_, a_); + statements_.push_back(Parser{s_update}.parse_line_expression()); + return; + } + +not_gating: + auto local_a_term = make_unique_local_assign(scope, coef, "a_"); + auto a_ = local_a_term.id->is_identifier()->spelling(); + + auto ba_expr = make_expression<DivBinaryExpression>(loc, + r.constant->clone(), local_a_term.id->clone()); + auto local_ba_term = make_unique_local_assign(scope, ba_expr, "ba_"); + auto ba_ = local_ba_term.id->is_identifier()->spelling(); + + statements_.push_back(std::move(local_a_term.local_decl)); + statements_.push_back(std::move(local_a_term.assignment)); + statements_.push_back(std::move(local_ba_term.local_decl)); + statements_.push_back(std::move(local_ba_term.assignment)); + + std::string s_update = pprintf("% = -%+(%+%)*exp(%*dt)", s, ba_, s, ba_, a_); + statements_.push_back(Parser{s_update}.parse_line_expression()); + return; + } +} + + +// Sparse solver visitor implementation. + +static expression_ptr as_expression(symge::symbol_term term) { + Location loc; + if (term.is_zero()) { + return make_expression<IntegerExpression>(loc, 0); + } + else { + return make_expression<MulBinaryExpression>(loc, + make_expression<IdentifierExpression>(loc, name(term.left)), + make_expression<IdentifierExpression>(loc, name(term.right))); + } +} + +static expression_ptr as_expression(symge::symbol_term_diff diff) { + Location loc; + if (diff.left.is_zero() && diff.right.is_zero()) { + return make_expression<IntegerExpression>(loc, 0); + } + else if (diff.right.is_zero()) { + return as_expression(diff.left); + } + else if (diff.left.is_zero()) { + return make_expression<NegUnaryExpression>(loc, + as_expression(diff.right)); + } + else { + return make_expression<SubBinaryExpression>(loc, + as_expression(diff.left), + as_expression(diff.right)); + } +} + +void SparseSolverVisitor::visit(BlockExpression* e) { + // Do a first pass to extract variables comprising ODE system + // lhs; can't really trust 'STATE' block. + + for (auto& stmt: e->statements()) { + if (stmt && stmt->is_assignment() && stmt->is_assignment()->lhs()->is_derivative()) { + auto id = stmt->is_assignment()->lhs()->is_derivative(); + dvars_.push_back(id->name()); + } + } + + BlockRewriterBase::visit(e); +} + +void SparseSolverVisitor::visit(AssignmentExpression *e) { + if (A_.empty()) { + unsigned n = dvars_.size(); + A_ = symge::sym_matrix(n, n); + } + + auto loc = e->location(); + scope_ptr scope = e->scope(); + + auto lhs = e->lhs(); + auto rhs = e->rhs(); + auto deriv = lhs->is_derivative(); + + if (!deriv) { + statements_.push_back(e->clone()); + + auto id = lhs->is_identifier(); + if (id) { + auto expand = substitute(rhs, local_expr_); + if (involves_identifier(expand, dvars_)) { + local_expr_[id->spelling()] = std::move(expand); + } + } + return; + } + + auto s = deriv->name(); + auto expanded_rhs = substitute(rhs, local_expr_); + linear_test_result r = linear_test(expanded_rhs, dvars_); + if (!r.is_homogeneous) { + error({"System not homogeneous linear for sparse", loc}); + return; + } + + // Populate sparse symbolic matrix for GE. + if (s!=dvars_[deq_index_]) { + error({"ICE: inconsistent ordering of derivative assignments", loc}); + } + + auto dt_expr = make_expression<IdentifierExpression>(loc, "dt"); + auto one_expr = make_expression<NumberExpression>(loc, 1.0); + for (unsigned j = 0; j<dvars_.size(); ++j) { + expression_ptr expr; + + // For zero coefficient and diagonal element, the matrix entry is 1. + // For non-zero coefficient c and diagonal element, the entry is 1-c*dt. + // Otherwise, for non-zero coefficient c, the entry is -c*dt. + + if (r.coef.count(dvars_[j])) { + expr = make_expression<MulBinaryExpression>(loc, + r.coef[dvars_[j]]->clone(), + dt_expr->clone()); + } + + if (j==deq_index_) { + if (expr) { + expr = make_expression<SubBinaryExpression>(loc, + one_expr->clone(), + std::move(expr)); + } + else { + expr = one_expr->clone(); + } + } + else if (expr) { + expr = make_expression<NegUnaryExpression>(loc, std::move(expr)); + } + + if (!expr) continue; + + auto local_a_term = make_unique_local_assign(scope, expr.get(), "a_"); + auto a_ = local_a_term.id->is_identifier()->spelling(); + + statements_.push_back(std::move(local_a_term.local_decl)); + statements_.push_back(std::move(local_a_term.assignment)); + + A_[deq_index_].push_back({j, symtbl_.define(a_)}); + } + ++deq_index_; +} + +void SparseSolverVisitor::finalize() { + std::vector<symge::symbol> rhs; + for (const auto& var: dvars_) { + rhs.push_back(symtbl_.define(var)); + } + A_.augment(rhs); + + symge::gj_reduce(A_, symtbl_); + + // Create and assign intermediate variables. + for (unsigned i = 0; i<symtbl_.size(); ++i) { + symge::symbol s = symtbl_[i]; + + if (primitive(s)) continue; + + auto expr = as_expression(definition(s)); + auto local_t_term = make_unique_local_assign(block_scope_, expr.get(), "t_"); + auto t_ = local_t_term.id->is_identifier()->spelling(); + symtbl_.name(s, t_); + + statements_.push_back(std::move(local_t_term.local_decl)); + statements_.push_back(std::move(local_t_term.assignment)); + } + + // State variable updates given by rhs/diagonal for reduced matrix. + Location loc; + for (unsigned i = 0; i<A_.nrow(); ++i) { + unsigned rhs = A_.augcol(); + + auto expr = + make_expression<AssignmentExpression>(loc, + make_expression<IdentifierExpression>(loc, dvars_[i]), + make_expression<DivBinaryExpression>(loc, + make_expression<IdentifierExpression>(loc, symge::name(A_[i][rhs])), + make_expression<IdentifierExpression>(loc, symge::name(A_[i][i])))); + + statements_.push_back(std::move(expr)); + } + + BlockRewriterBase::finalize(); +} + +// Implementation for `remove_unused_locals`: uses two visitors, +// `UnusedVisitor` and `RemoveVariableVisitor` below. + +class UnusedVisitor : public Visitor { +protected: + std::multimap<std::string, std::string> deps; + std::set<std::string> unused_ids; + std::set<std::string> used_ids; + Symbol* lhs_sym = nullptr; + +public: + using Visitor::visit; + + UnusedVisitor() {} + + virtual void visit(Expression* e) override {} + + virtual void visit(BlockExpression* e) override { + for (auto& s: e->statements()) { + s->accept(this); + } + } + + virtual void visit(AssignmentExpression* e) override { + auto lhs = e->lhs()->is_identifier(); + if (!lhs) return; + + lhs_sym = lhs->symbol(); + e->rhs()->accept(this); + lhs_sym = nullptr; + } + + virtual void visit(UnaryExpression* e) override { + e->expression()->accept(this); + } + + virtual void visit(BinaryExpression* e) override { + e->lhs()->accept(this); + e->rhs()->accept(this); + } + + virtual void visit(CallExpression* e) override { + for (auto& a: e->args()) { + a->accept(this); + } + } + + virtual void visit(IfExpression* e) override { + e->condition()->accept(this); + e->true_branch()->accept(this); + e->false_branch()->accept(this); + } + + virtual void visit(IdentifierExpression* e) override { + if (lhs_sym && lhs_sym->is_local_variable()) { + deps.insert({lhs_sym->name(), e->name()}); + } + else { + used_ids.insert(e->name()); + } + } + + virtual void visit(LocalDeclaration* e) override { + for (auto& v: e->variables()) { + unused_ids.insert(v.first); + } + } + + std::set<std::string> unused_locals() { + if (!computed_) { + for (auto& id: used_ids) { + remove_deps_from_unused(id); + } + computed_ = true; + } + return unused_ids; + } + + void reset() { + deps.clear(); + unused_ids.clear(); + used_ids.clear(); + computed_ = false; + } + +private: + bool computed_ = false; + + void remove_deps_from_unused(const std::string& id) { + auto range = deps.equal_range(id); + for (auto i = range.first; i != range.second; ++i) { + if (unused_ids.count(i->second)) { + remove_deps_from_unused(i->second); + } + } + unused_ids.erase(id); + } +}; + +class RemoveVariableVisitor: public BlockRewriterBase { + std::set<std::string> remove_; + +public: + using BlockRewriterBase::visit; + + RemoveVariableVisitor(std::set<std::string> ids): + remove_(std::move(ids)) {} + + RemoveVariableVisitor(std::set<std::string> ids, scope_ptr enclosing): + BlockRewriterBase(enclosing), remove_(std::move(ids)) {} + + virtual void visit(LocalDeclaration* e) override { + auto replacement = e->clone(); + auto& vars = replacement->is_local_declaration()->variables(); + + for (const auto& id: remove_) { + vars.erase(id); + } + if (!vars.empty()) { + statements_.push_back(std::move(replacement)); + } + } + + virtual void visit(AssignmentExpression* e) override { + std::string lhs_id = e->lhs()->is_identifier()->name(); + if (!remove_.count(lhs_id)) { + statements_.push_back(e->clone()); + } + } +}; + +expression_ptr remove_unused_locals(BlockExpression* block) { + UnusedVisitor unused_visitor; + block->accept(&unused_visitor); + + RemoveVariableVisitor remove_visitor(unused_visitor.unused_locals()); + block->accept(&remove_visitor); + return remove_visitor.as_block(false); +} diff --git a/modcc/solvers.hpp b/modcc/solvers.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e2bbc35323cda815a5f1918d9a16b7e00189d2b9 --- /dev/null +++ b/modcc/solvers.hpp @@ -0,0 +1,97 @@ +#pragma once + +// Transform derivative block into AST representing +// an integration step over the state variables, based on +// solver method. + +#include <string> +#include <vector> + +#include "expression.hpp" +#include "symdiff.hpp" +#include "symge.hpp" +#include "visitor.hpp" + +expression_ptr remove_unused_locals(BlockExpression* block); + +class SolverVisitorBase: public BlockRewriterBase { +protected: + // list of identifier names appearing in derivatives on lhs + std::vector<std::string> dvars_; + +public: + using BlockRewriterBase::visit; + + SolverVisitorBase() {} + SolverVisitorBase(scope_ptr enclosing): BlockRewriterBase(enclosing) {} + + virtual std::vector<std::string> solved_identifiers() const { + return dvars_; + } + + virtual void reset() override { + dvars_.clear(); + BlockRewriterBase::reset(); + } +}; + +class DirectSolverVisitor : public SolverVisitorBase { +public: + using SolverVisitorBase::visit; + + DirectSolverVisitor() {} + DirectSolverVisitor(scope_ptr enclosing): SolverVisitorBase(enclosing) {} + + virtual void visit(AssignmentExpression *e) override { + // No solver method, so declare an error if lhs is a derivative. + if(auto deriv = e->lhs()->is_derivative()) { + error({"The DERIVATIVE block has a derivative expression" + " but no METHOD was specified in the SOLVE statement", + deriv->location()}); + } + } +}; + +class CnexpSolverVisitor : public SolverVisitorBase { +public: + using SolverVisitorBase::visit; + + CnexpSolverVisitor() {} + CnexpSolverVisitor(scope_ptr enclosing): SolverVisitorBase(enclosing) {} + + virtual void visit(BlockExpression* e) override; + virtual void visit(AssignmentExpression *e) override; +}; + +class SparseSolverVisitor : public SolverVisitorBase { +protected: + // 'Current' differential equation is for variable with this + // index in `dvars`. + unsigned deq_index_ = 0; + + // Expanded local assignments that need to be substituted in for derivative + // calculations. + substitute_map local_expr_; + + // Symbolic matrix for backwards Euler step. + symge::sym_matrix A_; + + // 'Symbol table' for symbolic manipulation. + symge::symbol_table symtbl_; + +public: + using SolverVisitorBase::visit; + + SparseSolverVisitor() {} + SparseSolverVisitor(scope_ptr enclosing): SolverVisitorBase(enclosing) {} + + virtual void visit(BlockExpression* e) override; + virtual void visit(AssignmentExpression *e) override; + virtual void finalize() override; + virtual void reset() override { + deq_index_ = 0; + local_expr_.clear(); + symtbl_.clear(); + SolverVisitorBase::reset(); + } +}; diff --git a/modcc/symdiff.cpp b/modcc/symdiff.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7a634434cd85bdcab48113fb9c1e8cb220d81379 --- /dev/null +++ b/modcc/symdiff.cpp @@ -0,0 +1,692 @@ +#include <cmath> +#include <map> +#include <set> +#include <stdexcept> +#include <string> +#include <utility> + +#include "error.hpp" +#include "expression.hpp" +#include "symdiff.hpp" +#include "visitor.hpp" + +class FindIdentifierVisitor: public Visitor { +public: + explicit FindIdentifierVisitor(const identifier_set& ids): ids_(ids) {} + + void reset() { found_ = false; } + bool found() const { return found_; } + + void visit(Expression* e) override {} + + void visit(UnaryExpression* e) override { + if (!found()) e->expression()->accept(this); + } + + void visit(BinaryExpression* e) override { + if (!found()) e->lhs()->accept(this); + if (!found()) e->rhs()->accept(this); + } + + void visit(CallExpression* e) override { + for (auto& e: e->args()) { + if (found()) return; + e->accept(this); + } + } + + void visit(PDiffExpression* e) override { + if (!found()) e->arg()->accept(this); + } + + void visit(IdentifierExpression* e) override { + if (!found()) { + found_ |= is_in(e->spelling(), ids_); + } + } + + void visit(DerivativeExpression* e) override { + if (!found()) { + found_ |= is_in(e->spelling(), ids_); + } + } + void visit(ReactionExpression* e) override { + if (!found()) e->lhs()->accept(this); + if (!found()) e->rhs()->accept(this); + if (!found()) e->fwd_rate()->accept(this); + if (!found()) e->rev_rate()->accept(this); + } + + void visit(StoichTermExpression* e) override { + if (!found()) e->ident()->accept(this); + } + + void visit(StoichExpression* e) override { + for (auto& e: e->terms()) { + if (found()) return; + e->accept(this); + } + } + + void visit(BlockExpression* e) override { + for (auto& e: e->statements()) { + if (found()) return; + e->accept(this); + } + } + + void visit(IfExpression* e) override { + if (!found()) e->condition()->accept(this); + if (!found()) e->true_branch()->accept(this); + if (!found()) e->false_branch()->accept(this); + } + +private: + const identifier_set& ids_; + bool found_ = false; +}; + +bool involves_identifier(Expression* e, const identifier_set& ids) { + FindIdentifierVisitor v(ids); + e->accept(&v); + return v.found(); +} + +bool involves_identifier(Expression* e, const std::string& id) { + identifier_set ids = {id}; + FindIdentifierVisitor v(ids); + e->accept(&v); + return v.found(); +} + +class SymPDiffVisitor: public Visitor, public error_stack { +public: + explicit SymPDiffVisitor(std::string id): id_(std::move(id)) {} + + void reset() { result_ = nullptr; } + + // Note: moves result, forces reset. + expression_ptr result() { + auto r = std::move(result_); + reset(); + return r; + } + + void visit(Expression* e) override { + error({"symbolic differential of improper expression", e->location()}); + } + + void visit(UnaryExpression* e) override { + error({"symbolic differential of unrecognized unary expression", e->location()}); + } + + void visit(BinaryExpression* e) override { + error({"symbolic differential of unrecognized binary expression", e->location()}); + } + + void visit(NegUnaryExpression* e) override { + auto loc = e->location(); + e->expression()->accept(this); + result_ = make_expression<NegUnaryExpression>(loc, result()); + } + + void visit(ExpUnaryExpression* e) override { + auto loc = e->location(); + e->expression()->accept(this); + result_ = make_expression<MulBinaryExpression>(loc, result(), e->clone()); + } + + void visit(LogUnaryExpression* e) override { + auto loc = e->location(); + e->expression()->accept(this); + result_ = make_expression<DivBinaryExpression>(loc, result(), e->expression()->clone()); + } + + void visit(CosUnaryExpression* e) override { + auto loc = e->location(); + e->expression()->accept(this); + result_ = make_expression<MulBinaryExpression>(loc, + make_expression<NegUnaryExpression>(loc, + make_expression<SinUnaryExpression>(loc, e->expression()->clone())), + result()); + } + + void visit(SinUnaryExpression* e) override { + auto loc = e->location(); + e->expression()->accept(this); + result_ = make_expression<MulBinaryExpression>(loc, + make_expression<CosUnaryExpression>(loc, e->expression()->clone()), + result()); + } + + void visit(AddBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr dlhs = std::move(result_); + + e->rhs()->accept(this); + result_ = make_expression<AddBinaryExpression>(loc, std::move(dlhs), result()); + } + + void visit(SubBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr dlhs = std::move(result_); + + e->rhs()->accept(this); + result_ = make_expression<SubBinaryExpression>(loc, move(dlhs), result()); + } + + void visit(MulBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr dlhs = std::move(result_); + + e->rhs()->accept(this); + expression_ptr drhs = std::move(result_); + + result_ = make_expression<AddBinaryExpression>(loc, + make_expression<MulBinaryExpression>(loc, e->lhs()->clone(), std::move(drhs)), + make_expression<MulBinaryExpression>(loc, std::move(dlhs), e->rhs()->clone())); + + } + + void visit(DivBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr dlhs = std::move(result_); + + e->rhs()->accept(this); + expression_ptr drhs = std::move(result_); + + result_ = make_expression<SubBinaryExpression>(loc, + make_expression<DivBinaryExpression>(loc, std::move(dlhs), e->rhs()->clone()), + make_expression<MulBinaryExpression>(loc, + make_expression<DivBinaryExpression>(loc, + e->lhs()->clone(), + make_expression<MulBinaryExpression>(loc, e->rhs()->clone(), e->rhs()->clone())), + std::move(drhs))); + } + + void visit(PowBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr dlhs = std::move(result_); + + e->rhs()->accept(this); + expression_ptr drhs = std::move(result_); + + result_ = make_expression<AddBinaryExpression>(loc, + make_expression<MulBinaryExpression>(loc, + std::move(drhs), + make_expression<MulBinaryExpression>(loc, + make_expression<LogUnaryExpression>(loc, e->lhs()->clone()), + make_expression<PowBinaryExpression>(loc, e->lhs()->clone(), e->rhs()->clone()))), + make_expression<MulBinaryExpression>(loc, + e->rhs()->clone(), + make_expression<MulBinaryExpression>(loc, + make_expression<PowBinaryExpression>(loc, + e->lhs()->clone(), + make_expression<SubBinaryExpression>(loc, + e->rhs()->clone(), + make_expression<IntegerExpression>(loc, 1))), + std::move(dlhs)))); + } + + void visit(CallExpression* e) override { + auto loc = e->location(); + result_ = make_expression<PDiffExpression>(loc, + make_expression<IdentifierExpression>(loc, id_), + e->clone()); + } + + void visit(PDiffExpression* e) override { + auto loc = e->location(); + e->arg()->accept(this); + result_ = make_expression<PDiffExpression>(loc, e->var()->clone(), result()); + } + + void visit(IdentifierExpression* e) override { + auto loc = e->location(); + result_ = make_expression<IntegerExpression>(loc, e->spelling()==id_); + } + + void visit(NumberExpression* e) override { + auto loc = e->location(); + result_ = make_expression<IntegerExpression>(loc, 0); + } + +private: + expression_ptr result_; + std::string id_; +}; + +// ConstantSimplifyVisitior is not the same as ConstantFolderVisitor, as there is no way for a visitor +// to modify an expression in place (only its children). This visitor instead builds a new expression +// from the given one with constant simplifications. + +long double expr_value(Expression* e) { + return e && e->is_number()? e->is_number()->value(): NAN; +} + +class ConstantSimplifyVisitor: public Visitor { +private: + expression_ptr result_; + + static bool is_number(Expression* e) { return e && e->is_number(); } + static bool is_number(const expression_ptr& e) { return is_number(e.get()); } + + void as_number(Location loc, long double v) { + result_ = make_expression<NumberExpression>(loc, v); + } + +public: + using Visitor::visit; + + ConstantSimplifyVisitor() {} + + // Note: moves result, forces reset. + expression_ptr result() { + auto r = std::move(result_); + reset(); + return r; + } + + void reset() { + result_ = nullptr; + } + + long double value() const { return expr_value(result_); } + + bool is_number() const { return is_number(result_); } + + void visit(Expression* e) override { + result_ = e->clone(); + } + + void visit(BlockExpression* e) override { + auto block_ = e->clone(); + block_->is_block()->statements().clear(); + + for (auto& stmt: e->statements()) { + stmt->accept(this); + auto simpl = result(); + + // flatten any naked blocks generated by if/else simplification + if (auto inner = simpl->is_block()) { + for (auto& stmt: inner->statements()) { + block_->is_block()->statements().push_back(std::move(stmt)); + } + } + else { + block_->is_block()->statements().push_back(std::move(simpl)); + } + } + result_ = std::move(block_); + } + + void visit(IfExpression* e) override { + auto loc = e->location(); + e->condition()->accept(this); + auto cond_expr = result(); + e->true_branch()->accept(this); + auto true_expr = result(); + e->false_branch()->accept(this); + auto false_expr = result(); + + if (!is_number(cond_expr)) { + result_ = make_expression<IfExpression>(loc, + std::move(cond_expr), std::move(true_expr), std::move(false_expr)); + } + else if (expr_value(cond_expr)) { + result_ = std::move(true_expr); + } + else { + result_ = std::move(false_expr); + } + } + + // TODO: procedure, function expressions + + void visit(UnaryExpression* e) override { + e->expression()->accept(this); + + if (is_number()) { + auto loc = e->location(); + auto val = value(); + + switch (e->op()) { + case tok::minus: + as_number(loc, -val); + return; + case tok::exp: + as_number(loc, std::exp(val)); + return; + case tok::sin: + as_number(loc, std::sin(val)); + return; + case tok::cos: + as_number(loc, std::cos(val)); + return; + case tok::log: + as_number(loc, std::log(val)); + return; + default: ; // treat opaquely as below + } + } + + expression_ptr arg = result(); + result_ = e->clone(); + result_->is_unary()->replace_expression(std::move(arg)); + } + + void visit(BinaryExpression* e) override { + result_ = e->clone(); + } + + void visit(AssignmentExpression* e) override { + auto loc = e->location(); + e->rhs()->accept(this); + result_ = make_expression<AssignmentExpression>(loc, e->lhs()->clone(), result()); + } + + void visit(MulBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr lhs = result(); + e->rhs()->accept(this); + expression_ptr rhs = result(); + + if (is_number(lhs) && is_number(rhs)) { + as_number(loc, expr_value(lhs)*expr_value(rhs)); + } + else if (expr_value(lhs)==0 || expr_value(rhs)==0) { + as_number(loc, 0); + } + else if (expr_value(lhs)==1) { + result_ = std::move(rhs); + } + else if (expr_value(rhs)==1) { + result_ = std::move(lhs); + } + else { + result_ = make_expression<MulBinaryExpression>(loc, std::move(lhs), std::move(rhs)); + } + } + + void visit(DivBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr lhs = result(); + e->rhs()->accept(this); + expression_ptr rhs = result(); + + if (is_number(lhs) && is_number(rhs)) { + as_number(loc, expr_value(lhs)/expr_value(rhs)); + } + else if (expr_value(lhs)==0) { + as_number(loc, 0); + } + else if (expr_value(rhs)==1) { + result_ = e->lhs()->clone(); + } + else { + result_ = make_expression<DivBinaryExpression>(loc, std::move(lhs), std::move(rhs)); + } + } + + void visit(AddBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr lhs = result(); + e->rhs()->accept(this); + expression_ptr rhs = result(); + + if (is_number(lhs) && is_number(rhs)) { + as_number(loc, expr_value(lhs)+expr_value(rhs)); + } + else if (expr_value(lhs)==0) { + result_ = std::move(rhs); + } + else if (expr_value(rhs)==0) { + result_ = std::move(lhs); + } + else { + result_ = make_expression<AddBinaryExpression>(loc, std::move(lhs), std::move(rhs)); + } + } + + void visit(SubBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr lhs = result(); + e->rhs()->accept(this); + expression_ptr rhs = result(); + + if (is_number(lhs) && is_number(rhs)) { + as_number(loc, expr_value(lhs)-expr_value(rhs)); + } + else if (expr_value(lhs)==0) { + make_expression<NegUnaryExpression>(loc, std::move(rhs))->accept(this); + } + else if (expr_value(rhs)==0) { + result_ = std::move(lhs); + } + else { + result_ = make_expression<SubBinaryExpression>(loc, std::move(lhs), std::move(rhs)); + } + } + + void visit(PowBinaryExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr lhs = result(); + e->rhs()->accept(this); + expression_ptr rhs = result(); + + if (is_number(lhs) && is_number(rhs)) { + as_number(loc, std::pow(expr_value(lhs),expr_value(rhs))); + } + else if (expr_value(lhs)==0) { + as_number(loc, 0); + } + else if (expr_value(rhs)==0 || expr_value(lhs)==1) { + as_number(loc, 1); + } + else if (expr_value(rhs)==1) { + result_ = std::move(lhs); + } + else { + result_ = make_expression<PowBinaryExpression>(loc, std::move(lhs), std::move(rhs)); + } + } + + void visit(ConditionalExpression* e) override { + auto loc = e->location(); + e->lhs()->accept(this); + expression_ptr lhs = result(); + e->rhs()->accept(this); + expression_ptr rhs = result(); + + if (is_number(lhs) && is_number(rhs)) { + auto lval = expr_value(lhs); + auto rval = expr_value(rhs); + switch (e->op()) { + case tok::equality: + as_number(loc, lval==rval); + return; + case tok::ne: + as_number(loc, lval!=rval); + return; + case tok::lt: + as_number(loc, lval<rval); + return; + case tok::gt: + as_number(loc, lval>rval); + return; + case tok::lte: + as_number(loc, lval<=rval); + return; + case tok::gte: + as_number(loc, lval>=rval); + return; + default: ; + // unrecognized, fall through to non-numeric case below + } + } + if (!is_number(lhs) || !is_number(rhs)) { + result_ = make_expression<ConditionalExpression>(loc, e->op(), std::move(lhs), std::move(rhs)); + } + } +}; + +expression_ptr constant_simplify(Expression* e) { + ConstantSimplifyVisitor csimp_visitor; + e->accept(&csimp_visitor); + return csimp_visitor.result(); +} + + +expression_ptr symbolic_pdiff(Expression* e, const std::string& id) { + if (!involves_identifier(e, id)) { + return make_expression<NumberExpression>(e->location(), 0); + } + + SymPDiffVisitor pdiff_visitor(id); + e->accept(&pdiff_visitor); + + return constant_simplify(pdiff_visitor.result()); +} + +// Substitute all occurances of an identifier within a unary, binary, call +// or (trivially) number expression with a copy of the provided substitute +// expression. + +class SubstituteVisitor: public Visitor { +public: + explicit SubstituteVisitor(const substitute_map& sub): + sub_(sub) {} + + expression_ptr result() { + auto r = std::move(result_); + reset(); + return r; + } + + void reset() { + result_ = nullptr; + } + + void visit(Expression* e) override { + throw compiler_exception("substitution attempt on improper expression", e->location()); + } + + void visit(NumberExpression* e) override { + result_ = e->clone(); + } + + void visit(IdentifierExpression* e) override { + result_ = is_in(e->spelling(), sub_)? sub_.at(e->spelling())->clone(): e->clone(); + } + + void visit(UnaryExpression* e) override { + e->expression()->accept(this); + auto arg = result(); + + result_ = e->clone(); + result_->is_unary()->replace_expression(std::move(arg)); + } + + void visit(BinaryExpression* e) override { + e->lhs()->accept(this); + auto lhs = result(); + e->rhs()->accept(this); + auto rhs = result(); + + result_ = e->clone(); + result_->is_binary()->replace_lhs(std::move(lhs)); + result_->is_binary()->replace_rhs(std::move(rhs)); + } + + void visit(CallExpression* e) override { + auto newexpr = e->clone(); + for (auto& arg: newexpr->is_call()->args()) { + arg->accept(this); + arg = result(); + } + result_ = std::move(newexpr); + } + + void visit(PDiffExpression* e) override { + // Doing the correct thing when the derivative variable is the + // substitution variable would require another 'opaque' expression, + // e.g. `SubstitutionExpression`, but we're not about to do that yet, + // so throw an exception instead, i.e. Don't Do That. + if (is_in(e->var()->is_identifier()->spelling(), sub_)) { + throw compiler_exception("attempt to substitute value for derivative variable", e->location()); + } + + e->arg()->accept(this); + result_ = make_expression<PDiffExpression>(e->location(), e->var()->clone(), result()); + } + +private: + expression_ptr result_; + const substitute_map& sub_; +}; + +expression_ptr substitute(Expression* e, const std::string& id, Expression* sub) { + substitute_map subs; + subs[id] = sub->clone(); + SubstituteVisitor sub_visitor(subs); + e->accept(&sub_visitor); + return sub_visitor.result(); +} + +expression_ptr substitute(Expression* e, const substitute_map& sub) { + SubstituteVisitor sub_visitor(sub); + e->accept(&sub_visitor); + return sub_visitor.result(); +} + +linear_test_result linear_test(Expression* e, const std::vector<std::string>& vars) { + linear_test_result result; + auto loc = e->location(); + auto zero = [loc]() { return make_expression<IntegerExpression>(loc, 0); }; + + result.constant = e->clone(); + for (const auto& id: vars) { + auto coef = symbolic_pdiff(e, id); + if (!is_zero(coef)) result.coef[id] = std::move(coef); + + result.constant = substitute(result.constant, id, zero()); + } + + ConstantSimplifyVisitor csimp_visitor; + result.constant->accept(&csimp_visitor); + result.constant = csimp_visitor.result(); + + // linearity test: take second order derivatives, test against zero. + result.is_linear = true; + for (unsigned i = 0; i<vars.size(); ++i) { + auto v1 = vars[i]; + if (!is_in(v1, result.coef)) continue; + + for (unsigned j = i; j<vars.size(); ++j) { + auto v2 = vars[j]; + + if (!is_zero(symbolic_pdiff(result.coef[v1].get(), v2).get())) { + result.is_linear = false; + goto done; + } + } + } +done: + + if (result.is_linear) { + result.is_homogeneous = is_zero(result.constant); + } + + return result; +} + diff --git a/modcc/symdiff.hpp b/modcc/symdiff.hpp new file mode 100644 index 0000000000000000000000000000000000000000..31c7150150e5b2dd47c48e8ae9372cb22848d7d6 --- /dev/null +++ b/modcc/symdiff.hpp @@ -0,0 +1,119 @@ +#pragma once + +// Perform naive symbolic differenation on a (rhs) expression; +// treat all identifiers as independent, and function calls +// with the variable in argument as opaque. +// +// This is just for linearity and possibly polynomiality testing, so +// don't try too hard. + +#include <iostream> +#include <map> +#include <string> +#include <utility> + +#include "expression.hpp" + + +// True if `id` matches the spelling of any identifier in the expression. +bool involves_identifier(Expression* e, const std::string& id); + +using identifier_set = std::vector<std::string>; +bool involves_identifier(Expression* e, const identifier_set& ids); + +// Return new expression formed by folding constants and removing trivial terms. +expression_ptr constant_simplify(Expression* e); + +// Extract value of expression that is a NumberExpression, or else return NAN. +long double expr_value(Expression* e); + +// Test if expression is a NumberExpression with value zero. +inline bool is_zero(Expression* e) { + return expr_value(e)==0; +} + +// Return new expression of symbolic partial differentiation of argument wrt `id`. +expression_ptr symbolic_pdiff(Expression* e, const std::string& id); + +// Substitute all occurances of identifier `id` within expression by a clone of `sub`. +// (Only applicable to unary, binary, call and number expressions.) +expression_ptr substitute(Expression* e, const std::string& id, Expression* sub); + +using substitute_map = std::map<std::string, expression_ptr>; +expression_ptr substitute(Expression* e, const substitute_map& sub); + +// Convenience interfaces for the above functions work with `expression_ptr` as +// well as with `Expression*` values. + +inline bool involves_identifier(const expression_ptr& e, const std::string& id) { + return involves_identifier(e.get(), id); +} + +inline bool involves_identifier(const expression_ptr& e, const identifier_set& ids) { + return involves_identifier(e.get(), ids); +} + +inline expression_ptr constant_simplify(const expression_ptr& e) { + return constant_simplify(e.get()); +} + +inline long double expr_value(const expression_ptr& e) { + return expr_value(e.get()); +} + +inline long double is_zero(const expression_ptr& e) { + return is_zero(e.get()); +} + +inline expression_ptr symbolic_pdiff(const expression_ptr& e, const std::string& id) { + return symbolic_pdiff(e.get(), id); +} + +inline expression_ptr substitute(const expression_ptr& e, const std::string& id, const expression_ptr& sub) { + return substitute(e.get(), id, sub.get()); +} + +inline expression_ptr substitute(const expression_ptr& e, const substitute_map& sub) { + return substitute(e.get(), sub); +} + + +// Linearity testing + +struct linear_test_result { + bool is_linear = false; + bool is_homogeneous = false; + expression_ptr constant; + std::map<std::string, expression_ptr> coef; + + bool monolinear() const { + unsigned nlinear = 0; + for (auto& entry: coef) { + if (!is_zero(entry.second) && ++nlinear>1) return false; + } + return true; + } + + bool monolinear(const std::string& var) const { + for (auto& entry: coef) { + if (!is_zero(entry.second) && var!=entry.first) return false; + } + return true; + } + + friend std::ostream& operator<<(std::ostream& o, const linear_test_result& r) { + o << "{linear: " << r.is_linear << "; homogeneous: " << r.is_homogeneous << "\n"; + o << " constant term: " << r.constant->to_string(); + for (const auto& p: r.coef) { + o << "\n coef " << p.first << ": " << p.second->to_string(); + } + o << "}"; + return o; + } +}; + +linear_test_result linear_test(Expression* e, const std::vector<std::string>& vars); + +inline linear_test_result linear_test(const expression_ptr& e, const std::vector<std::string>& vars) { + return linear_test(e.get(), vars); +} diff --git a/modcc/symge.cpp b/modcc/symge.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4f09a0226ad0975b36aaaa2887cd331136f8ebf2 --- /dev/null +++ b/modcc/symge.cpp @@ -0,0 +1,109 @@ +#include <algorithm> +#include <stdexcept> +#include <vector> + +#include "symge.hpp" + +namespace symge { + +// Returns q[c]*p - p[c]*q; new symbols required due to fill-in are provided by the +// `define_sym` functor, which takes a `symbol_term_diff` and returns a `symbol`. + +template <typename DefineSym> +sym_row row_reduce(unsigned c, const sym_row& p, const sym_row& q, DefineSym define_sym) { + if (p.index(c)==p.npos || q.index(c)==q.npos) throw std::runtime_error("improper row reduction"); + + sym_row u; + symbol x = q[c]; + symbol y = p[c]; + + auto piter = p.begin(); + auto qiter = q.begin(); + unsigned pj = piter->col; + unsigned qj = qiter->col; + + while (piter!=p.end() || qiter!=q.end()) { + unsigned j = std::min(pj, qj); + symbol_term t1, t2; + + if (j==pj) { + t1 = x*piter->value; + ++piter; + pj = piter==p.end()? p.npos: piter->col; + } + if (j==qj) { + t2 = y*qiter->value; + ++qiter; + qj = qiter==q.end()? q.npos: qiter->col; + } + if (j!=c) { + u.push_back({j, define_sym(t1-t2)}); + } + } + return u; +} + +// Estimate cost of a choice of pivot for G–J reduction below. Uses a simple greedy +// estimate based on immediate fill cost. +double estimate_cost(const sym_matrix& A, unsigned p) { + unsigned nfill = 0; + + auto count_fill = [&nfill](symbol_term_diff t) { + bool l = t.left; + bool r = t.right; + nfill += r&!l; + return symbol{}; + }; + + for (unsigned i = 0; i<A.nrow(); ++i) { + if (i==p || A[i].index(p)==msparse::row_npos) continue; + row_reduce(p, A[i], A[p], count_fill); + } + + return nfill; +} + +// Perform Gauss-Jordan elimination on given symbolic matrix. New symbols +// required due to fill-in are added to the supplied symbol table. +// +// The matrix A is regarded as being diagonally dominant, and so pivots +// are selected from the diagonal. The choice of pivot at each stage of +// the reduction is goverened by a cost estimation (see above). +// +// The reduction is division-free: the result will have non-zero terms +// that are symbols that are either primitive, or defined (in the symbol +// table) as products or differences of products of other symbols. +void gj_reduce(sym_matrix& A, symbol_table& table) { + if (A.nrow()>A.ncol()) throw std::runtime_error("improper matrix for reduction"); + + auto define_sym = [&table](symbol_term_diff t) { return table.define(t); }; + + std::vector<unsigned> pivots; + for (unsigned r = 0; r<A.nrow(); ++r) { + pivots.push_back(r); + } + + std::vector<double> cost(pivots.size()); + + while (!pivots.empty()) { + for (unsigned i = 0; i<pivots.size(); ++i) { + cost[pivots[i]] = estimate_cost(A, pivots[i]); + } + + std::sort(pivots.begin(), pivots.end(), + [&](unsigned r1, unsigned r2) { return cost[r1]>cost[r2]; }); + + unsigned pivrow = pivots.back(); + pivots.erase(std::prev(pivots.end())); + + unsigned pivcol = pivrow; + + for (unsigned i = 0; i<A.nrow(); ++i) { + if (i==pivrow || A[i].index(pivcol)==msparse::row_npos) continue; + + A[i] = row_reduce(pivcol, A[i], A[pivrow], define_sym); + } + } +} + +} // namespace symge diff --git a/modcc/symge.hpp b/modcc/symge.hpp new file mode 100644 index 0000000000000000000000000000000000000000..edd37b9031d6cad6640088661b04e544eb10692b --- /dev/null +++ b/modcc/symge.hpp @@ -0,0 +1,158 @@ +#pragma once + +#include <stdexcept> + +#include "msparse.hpp" + +// Symbolic sparse matrix manipulation for symbolic Gauss-Jordan elimination +// (used in `sparse` solver). + +namespace symge { + +struct symbol_error: public std::runtime_error { + symbol_error(const std::string& what): std::runtime_error(what) {} +}; + +// Abstract symbols: + +class symbol_table; + +class symbol { +private: + unsigned index_; + const symbol_table* table_; + + // Valid symbols are constructed via a symbol table. + friend class symbol_table; + symbol(unsigned index, const symbol_table* table): + index_(index), table_(table) {} + +public: + symbol(): index_(0), table_(nullptr) {} + + // true => valid symbol. + operator bool() const { return table_; } + + bool operator==(symbol other) const { return index_==other.index_ && table_==other.table_; } + bool operator!=(symbol other) const { return !(*this==other); } + + const symbol_table* table() const { return table_; } +}; + +// A `symbol_term` is either zero or a product of symbols. + +struct symbol_term { + symbol left, right; + + symbol_term() = default; + bool is_zero() const { return !left || !right; } + operator bool() const { return !is_zero(); } +}; + +struct symbol_term_diff { + symbol_term left, right; + + symbol_term_diff() = default; + symbol_term_diff(const symbol_term& left): left(left), right{} {} + symbol_term_diff(const symbol_term& left, const symbol_term& right): + left(left), right(right) {} +}; + +inline symbol_term operator*(symbol a, symbol b) { + return symbol_term{a, b}; +} + +inline symbol_term_diff operator-(symbol_term l, symbol_term r) { + return symbol_term_diff{l, r}; +} + +inline symbol_term_diff operator-(symbol_term r) { + return symbol_term_diff{symbol_term{}, r}; +} + +// Symbols are not re-assignable; they are created as primitive, or +// have a definition in terms of a `symbol_term_diff`. + +class symbol_table { +private: + struct table_entry { + std::string name; + symbol_term_diff def; + bool defined; + }; + + std::vector<table_entry> entries_; + +public: + // make new primitive symbol + symbol define(const std::string& name="") { + symbol s(size(), this); + entries_.push_back({name, symbol_term_diff{}, false}); + return s; + } + + // make new symbol with definition + symbol define(const std::string& name, const symbol_term_diff& def) { + symbol s(size(), this); + entries_.push_back({name, def, true}); + return s; + } + + symbol define(const symbol_term_diff& def) { + return define("", def); + } + + symbol_term_diff get(symbol s) const { + if (!defined(s)) throw symbol_error("symbol is primitive"); + return entries_[s.index_].def; + } + + bool defined(symbol s) const { + if (!valid(s)) throw symbol_error("symbol not present in this table"); + return entries_[s.index_].defined; + } + + bool primitive(symbol s) const { return !defined(s); } + + const std::string& name(symbol s) const { + if (!valid(s)) throw symbol_error("symbol not present in this table"); + return entries_[s.index_].name; + } + + void name(symbol s, const std::string& n) { + if (!valid(s)) throw symbol_error("symbol not present in this table"); + entries_[s.index_].name = n; + } + + std::size_t size() const { return entries_.size(); } + + symbol operator[](unsigned i) const { return symbol{i, this}; } + + bool valid(symbol s) const { return s.table_==this && s.index_<size(); } + + // Existing symbols associated with this table are invalidated by clear(). + void clear() { entries_.clear(); } +}; + +inline std::string name(symbol s) { + return s? s.table()->name(s): ""; +} + +inline symbol_term_diff definition(symbol s) { + if (!s) throw symbol_error("invalid symbol"); + return s.table()->get(s); +} + +inline bool primitive(symbol s) { + return s && s.table()->primitive(s); +} + +using sym_row = msparse::row<symbol>; +using sym_matrix = msparse::matrix<symbol>; + +// Perform Gauss-Jordan reduction on a (possibly augmented) symbolic matrix, with +// pivots taken from the diagonal elements. New symbol definitions due to fill-in +// will be added via the provided symbol table. +void gj_reduce(sym_matrix& A, symbol_table& table); + +} // namespace symge diff --git a/modcc/token.cpp b/modcc/token.cpp index 7383ca9fc684b4f47122a98e8f1af035724c8d9a..25d08da9d6060ed3767042b94de0388927f02395 100644 --- a/modcc/token.cpp +++ b/modcc/token.cpp @@ -52,6 +52,7 @@ static Keyword keywords[] = { {"if", tok::if_stmt}, {"else", tok::else_stmt}, {"cnexp", tok::cnexp}, + {"sparse", tok::sparse}, {"exp", tok::exp}, {"sin", tok::sin}, {"cos", tok::cos}, diff --git a/modcc/token.hpp b/modcc/token.hpp index c31cdbc409a2dd411cc5c074bd351fd994d4d029..95f7c461f368775f4bc1c67d6538f0c15e506503 100644 --- a/modcc/token.hpp +++ b/modcc/token.hpp @@ -70,6 +70,7 @@ enum class tok { // solver methods cnexp, + sparse, conductance, diff --git a/modcc/visitor.hpp b/modcc/visitor.hpp index c474dc89fda572abad331c9e7254241a7f31230b..921189bc70cda3674bf94fabe6441a3a43f174b4 100644 --- a/modcc/visitor.hpp +++ b/modcc/visitor.hpp @@ -34,9 +34,11 @@ public: virtual void visit(IfExpression *e) { visit((Expression*) e); } virtual void visit(SolveExpression *e) { visit((Expression*) e); } virtual void visit(DerivativeExpression *e) { visit((Expression*) e); } + virtual void visit(PDiffExpression *e) { visit((Expression*) e); } virtual void visit(ProcedureExpression *e) { visit((Expression*) e); } virtual void visit(NetReceiveExpression *e) { visit((ProcedureExpression*) e); } virtual void visit(APIMethod *e) { visit((Expression*) e); } + virtual void visit(ConductanceExpression *e) { visit((Expression*) e); } virtual void visit(BlockExpression *e) { visit((Expression*) e); } virtual void visit(InitialBlock *e) { visit((BlockExpression*) e); } @@ -48,6 +50,7 @@ public: virtual void visit(SinUnaryExpression *e) { visit((UnaryExpression*) e); } virtual void visit(BinaryExpression *e) = 0; + virtual void visit(ConditionalExpression *e) {visit((BinaryExpression*) e); } virtual void visit(AssignmentExpression *e) { visit((BinaryExpression*) e); } virtual void visit(ConserveExpression *e) { visit((BinaryExpression*) e); } virtual void visit(AddBinaryExpression *e) { visit((BinaryExpression*) e); } @@ -59,3 +62,136 @@ public: virtual ~Visitor() {}; }; + +// Visitor specialization intended for use as a base class for visitors that +// operate as function or procedure body rewriters after semantic analysis. +// +// Block rewriter visitors construct a new block body from a supplied +// `BlockExpression`, `ProcedureExpression` or `FunctionExpression`. By default, +// expressions are simply copied to the list of statements corresponding to the +// rewritten block; nested blocks as provided by `IfExpression` objects are +// handled recursively. +// +// The `finalize()` method is called after visiting all the statements in the +// top-level block, and is intended to be overrided by derived classes as required. +// +// The `as_block()` method is intended to be called by users of the derived block +// rewriter objects. It constructs a new `BlockExpression` from the accumulated +// replacement statements and applies a semantic pass if the `BlockRewriterBase` +// was given a corresponding scope +// +// The visitor maintains significant internal state: the `reset` method should +// be called between visits of top-level blocks. +// +// Errors are recorded through the `error_stack` mixin rather than by +// throwing an exception. + +class BlockRewriterBase : public Visitor, public error_stack { +public: + BlockRewriterBase() {} + BlockRewriterBase(scope_ptr block_scope): + block_scope_(block_scope) {} + + virtual void visit(Expression *e) override { + statements_.push_back(e->clone()); + } + + virtual void visit(UnaryExpression *e) override { visit((Expression*)e); } + virtual void visit(BinaryExpression *e) override { visit((Expression*)e); } + + virtual void visit(BlockExpression *e) override { + bool top = !started_; + if (top) { + loc_ = e->location(); + started_ = true; + + if (!block_scope_) { + block_scope_ = e->scope(); + } + } + + for (auto& s: e->statements()) { + s->accept(this); + } + + if (top) { + finalize(); + } + } + + virtual void visit(IfExpression* e) override { + expr_list_type outer; + std::swap(outer, statements_); + + e->true_branch()->accept(this); + auto true_branch = make_expression<BlockExpression>( + e->true_branch()->location(), + std::move(statements_), + true); + + statements_.clear(); + e->false_branch()->accept(this); + auto false_branch = make_expression<BlockExpression>( + e->false_branch()->location(), + std::move(statements_), + true); + + statements_ = std::move(outer); + statements_.push_back(make_expression<IfExpression>( + e->location(), + e->condition()->clone(), + std::move(true_branch), + std::move(false_branch))); + } + + virtual void visit(ProcedureExpression* e) override { + e->body()->accept(this); + } + + virtual void visit(FunctionExpression* e) override { + e->body()->accept(this); + } + + virtual expression_ptr as_block(bool is_nested=false) { + if (has_error()) return nullptr; + + expr_list_type body_stmts; + for (const auto& s: statements_) body_stmts.push_back(s->clone()); + + auto body = make_expression<BlockExpression>( + loc_, + std::move(body_stmts), + is_nested); + + if (block_scope_) { + body->semantic(block_scope_); + } + return body; + } + + // Reset state. + virtual void reset() { + statements_.clear(); + started_ = false; + loc_ = Location{}; + clear_errors(); + clear_warnings(); + } + +protected: + // False until processing of top block starts. + bool started_ = false; + + // Location of original block. + Location loc_; + + // Scope for semantic pass. + scope_ptr block_scope_; + + // Statements in replacement procedure body. + expr_list_type statements_; + + // Finalise statements list at end of top block visit. + virtual void finalize() {} +}; + diff --git a/src/backends/fvm_multicore.cpp b/src/backends/fvm_multicore.cpp index dd8f797c99ad29347fc9a103a2b730527f23f976..8d20f2ffe7161832bbdd66fd77dd83255766f8b4 100644 --- a/src/backends/fvm_multicore.cpp +++ b/src/backends/fvm_multicore.cpp @@ -4,6 +4,8 @@ #include <mechanisms/multicore/pas.hpp> #include <mechanisms/multicore/expsyn.hpp> #include <mechanisms/multicore/exp2syn.hpp> +#include <mechanisms/multicore/test_kin1.hpp> +#include <mechanisms/multicore/test_kinlva.hpp> namespace nest { namespace mc { @@ -11,10 +13,12 @@ namespace multicore { std::map<std::string, backend::maker_type> backend::mech_map_ = { - { std::string("pas"), maker<mechanisms::pas::mechanism_pas> }, - { std::string("hh"), maker<mechanisms::hh::mechanism_hh> }, - { std::string("expsyn"), maker<mechanisms::expsyn::mechanism_expsyn> }, - { std::string("exp2syn"), maker<mechanisms::exp2syn::mechanism_exp2syn> } + { std::string("pas"), maker<mechanisms::pas::mechanism_pas> }, + { std::string("hh"), maker<mechanisms::hh::mechanism_hh> }, + { std::string("expsyn"), maker<mechanisms::expsyn::mechanism_expsyn> }, + { std::string("exp2syn"), maker<mechanisms::exp2syn::mechanism_exp2syn> }, + { std::string("test_kin1"), maker<mechanisms::test_kin1::mechanism_test_kin1> }, + { std::string("test_kinlva"), maker<mechanisms::test_kinlva::mechanism_test_kinlva> } }; } // namespace multicore diff --git a/tests/modcc/CMakeLists.txt b/tests/modcc/CMakeLists.txt index bb27815a76e014d5e51ae2264c6204964f1ff971..34f87ed79d325d6954f50a3c0473e9aa2d31ae14 100644 --- a/tests/modcc/CMakeLists.txt +++ b/tests/modcc/CMakeLists.txt @@ -3,9 +3,13 @@ set(MODCC_TEST_SOURCES test_lexer.cpp test_kinetic_rewriter.cpp test_module.cpp + test_msparse.cpp test_optimization.cpp test_parser.cpp #test_printers.cpp + test_removelocals.cpp + test_symdiff.cpp + test_symge.cpp test_visitors.cpp # unit test driver @@ -13,6 +17,7 @@ set(MODCC_TEST_SOURCES # utility expr_expand.cpp + test.cpp ) add_definitions("-DDATADIR=\"${PROJECT_SOURCE_DIR}/data\"") diff --git a/tests/modcc/driver.cpp b/tests/modcc/driver.cpp index 67505a0d20d3406d1ed8142715fb2243e74c6b0b..dcd36906c77fc5485fa3d80b31e3fd7483156436 100644 --- a/tests/modcc/driver.cpp +++ b/tests/modcc/driver.cpp @@ -6,8 +6,6 @@ #include "test.hpp" -bool g_verbose_flag = false; - int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); if (argc>1 && (!std::strcmp(argv[1],"-v") || !std::strcmp(argv[1],"--verbose"))) { diff --git a/tests/modcc/test.cpp b/tests/modcc/test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..72e0db777b525350a6e6ddfa5591df5bc372f955 --- /dev/null +++ b/tests/modcc/test.cpp @@ -0,0 +1,27 @@ +#include <regex> +#include <string> + +#include "test.hpp" + +bool g_verbose_flag = false; + +std::string plain_text(Expression* expr) { + static std::regex csi_code(R"_(\x1B\[.*?[\x40-\x7E])_"); + return !expr? "null": regex_replace(expr->to_string(), csi_code, ""); +} + +::testing::AssertionResult assert_expr_eq(const char *arg1, const char *arg2, Expression* expected, Expression* value) { + auto value_rep = plain_text(value); + auto expected_rep = plain_text(expected); + + if ((!value && expected) || (value && !expected) || (value_rep!=expected_rep)) { + return ::testing::AssertionFailure() + << "Value of: " << arg2 << "\n" + << " Actual: " << value_rep << "\n" + << "Expected: " << arg1 << "\n" + << "Which is: " << expected_rep << "\n"; + } + else { + return ::testing::AssertionSuccess(); + } +} diff --git a/tests/modcc/test.hpp b/tests/modcc/test.hpp index 45dced6442094dd73ad0aac80b51d743c43c2750..fe6e76662170b7ca0a3edf0eb30ffac00ebda125 100644 --- a/tests/modcc/test.hpp +++ b/tests/modcc/test.hpp @@ -1,7 +1,10 @@ #pragma once +#include <string> + #include "../gtest.h" +#include "expression.hpp" #include "parser.hpp" #include "modccutil.hpp" @@ -24,3 +27,63 @@ inline expression_ptr parse_function(std::string const& s) { inline expression_ptr parse_procedure(std::string const& s) { return Parser(s).parse_procedure(); } + +// Helpers for comparing expressions, and verbose expression printing. + +// Strip ANSI control sequences from `to_string` output. +std::string plain_text(Expression* expr); + +// Compare two expressions via their representation. +// Use with EXPECT_PRED_FORMAT2. +::testing::AssertionResult assert_expr_eq(const char *arg1, const char *arg2, Expression* expected, Expression* value); + +#define EXPECT_EXPR_EQ(a,b) EXPECT_PRED_FORMAT2(assert_expr_eq, a, b) + +// Print arguments, but only if verbose flag set. +// Use `to_string()` to print (smart) pointers to Expression or Scope objects. + +namespace impl { + template <typename X> + struct has_to_string { + template <typename T> + static decltype(std::declval<T>()->to_string(), std::true_type{}) test(int); + template <typename T> + static std::false_type test(...); + + using type = decltype(test<X>(0)); + }; + + template <typename X> + void print(const X& x, std::true_type) { + if (x) { + std::cout << x->to_string(); + } + else { + std::cout << "null"; + } + } + + template <typename X> + void print(const X& x, std::false_type) { + std::cout << x; + } + + template <typename X> + void print(const X& x) { + print(x, typename has_to_string<X>::type{}); + } + +} + +inline void verbose_print() { + if (!g_verbose_flag) return; + std::cout << "\n"; +} + +template <typename X, typename... Args> +void verbose_print(const X& arg, const Args&... tail) { + if (!g_verbose_flag) return; + impl::print(arg); + verbose_print(tail...); +} + diff --git a/tests/modcc/test_kinetic_rewriter.cpp b/tests/modcc/test_kinetic_rewriter.cpp index 8ec3173a509102e75472096bff129fa2ca28b096..62c2feb2e663eb1536579766cc367a96379ac6db 100644 --- a/tests/modcc/test_kinetic_rewriter.cpp +++ b/tests/modcc/test_kinetic_rewriter.cpp @@ -2,21 +2,31 @@ #include <string> #include "expression.hpp" -#include "kinrewriter.hpp" +#include "kineticrewriter.hpp" #include "parser.hpp" #include "alg_collect.hpp" #include "expr_expand.hpp" #include "test.hpp" -using namespace nest::mc; +expr_list_type& statements(Expression *e) { + if (e) { + if (auto block = e->is_block()) { + return block->statements(); + } + + if (auto sym = e->is_symbol()) { + if (auto proc = sym->is_procedure()) { + return proc->body()->statements(); + } -stmt_list_type& proc_statements(Expression *e) { - if (!e || !e->is_symbol() || ! e->is_symbol()->is_procedure()) { - throw std::runtime_error("not a procedure"); + if (auto proc = sym->is_procedure()) { + return proc->body()->statements(); + } + } } - return e->is_symbol()->is_procedure()->body()->statements(); + throw std::runtime_error("not a block, function or procedure"); } @@ -49,37 +59,40 @@ static const char* derivative_abc = "} \n"; TEST(KineticRewriter, equiv) { - auto visitor = util::make_unique<KineticRewriter>(); auto kin = Parser(kinetic_abc).parse_procedure(); auto deriv = Parser(derivative_abc).parse_procedure(); + auto kin_ptr = kin.get(); + auto deriv_ptr = deriv.get(); + ASSERT_NE(nullptr, kin); ASSERT_NE(nullptr, deriv); ASSERT_TRUE(kin->is_symbol() && kin->is_symbol()->is_procedure()); ASSERT_TRUE(deriv->is_symbol() && deriv->is_symbol()->is_procedure()); - auto kin_weak = kin->is_symbol()->is_procedure(); scope_type::symbol_map globals; globals["kin"] = std::move(kin); + globals["deriv"] = std::move(deriv); globals["a"] = state_var("a"); globals["b"] = state_var("b"); globals["c"] = state_var("c"); globals["u"] = assigned_var("u"); globals["v"] = assigned_var("v"); - kin_weak->semantic(globals); - kin_weak->accept(visitor.get()); + deriv_ptr->semantic(globals); - auto kin_deriv = visitor->as_procedure(); + auto kin_body = kin_ptr->is_procedure()->body(); + scope_ptr scope = std::make_shared<scope_type>(globals); + kin_body->semantic(scope); - if (g_verbose_flag) { - std::cout << "derivative procedure:\n" << deriv->to_string() << "\n"; - std::cout << "kin procedure:\n" << kin_weak->to_string() << "\n"; - std::cout << "rewritten kin procedure:\n" << kin_deriv->to_string() << "\n"; - } + auto kin_deriv = kinetic_rewrite(kin_body); + + verbose_print("derivative procedure:\n", deriv_ptr); + verbose_print("kin procedure:\n", kin_ptr); + verbose_print("rewritten kin body:\n", kin_deriv); - auto deriv_map = expand_assignments(proc_statements(deriv.get())); - auto kin_map = expand_assignments(proc_statements(kin_deriv.get())); + auto deriv_map = expand_assignments(statements(deriv_ptr)); + auto kin_map = expand_assignments(statements(kin_deriv.get())); if (g_verbose_flag) { std::cout << "derivative assignments (canonical):\n"; diff --git a/tests/modcc/test_lexer.cpp b/tests/modcc/test_lexer.cpp index e148b12b0631e1622bd52c2a4771f9b96802b23b..91ee7df7d1ddac0f48e6da079e54ed115279f75a 100644 --- a/tests/modcc/test_lexer.cpp +++ b/tests/modcc/test_lexer.cpp @@ -7,48 +7,38 @@ #include "test.hpp" #include "lexer.hpp" -void verbose_print(const char* string) { - if (!g_verbose_flag) return; - std::cout << "________________\n" << string << "\n________________\n"; -} - -void verbose_print(const Token& token) { - if (!g_verbose_flag) return; - std::cout << "tok: " << token << "\n"; -} - class VerboseLexer: public Lexer { public: template <typename... Args> VerboseLexer(Args&&... args): Lexer(std::forward<Args>(args)...) { - if (g_verbose_flag) { - std::cout << "________________\n" << std::string(begin_, end_) << "\n________________\n"; - } + verbose_print("________________"); + verbose_print(std::string(begin_, end_)); + verbose_print("________________"); } Token parse() { auto tok = Lexer::parse(); - if (g_verbose_flag) { - std::cout << "token: " << tok << "\n"; - } + verbose_print("token: ",tok); return tok; } char character() { char c = Lexer::character(); - if (g_verbose_flag) { - std::cout << "character: "; - if (!std::isprint(c)) { - char buf[5] = "XXXX"; - snprintf(buf, sizeof buf, "0x%02x", (unsigned)c); - std::cout << buf << '\n'; - } - else { - std::cout << c << '\n'; - } - } + verbose_print("character: ", pretty(c)); return c; } + + static const char* pretty(char c) { + static char buf[5] = "XXXX"; + if (!std::isprint(c)) { + snprintf(buf, sizeof buf, "0x%02x", (unsigned)c); + } + else { + buf[0] = c; + buf[1] = 0; + } + return buf; + } }; /************************************************************** diff --git a/tests/modcc/test_msparse.cpp b/tests/modcc/test_msparse.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6229e725a3912046cdd70c0fa152d3a5580f5ad7 --- /dev/null +++ b/tests/modcc/test_msparse.cpp @@ -0,0 +1,188 @@ +#include <utility> + +#include "msparse.hpp" +#include "test.hpp" + +using drow = msparse::row<double>; +using dmatrix = msparse::matrix<double>; +using msparse::row_npos; + +namespace msparse { +bool operator==(const drow::entry& a, const drow::entry& b) { + return a.col==b.col && a.value==b.value; +} +} + +TEST(msparse, row_ctor) { + + drow r1; + EXPECT_TRUE(r1.empty()); + + drow r2({{0,3.0},{2,-1.5}}); + EXPECT_FALSE(r2.empty()); + EXPECT_EQ(2u, r2.size()); + + drow r3(r2); + EXPECT_FALSE(r3.empty()); + EXPECT_EQ(2u, r3.size()); + + drow::entry entries[] = { {1,2.}, {4,-2.}, {5,0} }; + drow r4(std::begin(entries), std::end(entries)); + EXPECT_FALSE(r4.empty()); + EXPECT_EQ(3u, r4.size()); + + ASSERT_THROW(drow({{2,-1.5}, {0,3.0}}), msparse::msparse_error); + + drow r5 = r4; + EXPECT_FALSE(r5.empty()); + EXPECT_EQ(3u, r5.size()); +} + +TEST(msparse, row_iter) { + using drow = msparse::row<double>; + + drow r1; + EXPECT_EQ(r1.begin(), r1.end()); + const drow& r1c = r1; + EXPECT_EQ(r1c.begin(), r1c.end()); + + drow r2({{0,3.0},{2,-1.5}}); + EXPECT_EQ(*r2.begin(), drow::entry({0,3.0})); + EXPECT_EQ(*std::prev(r2.end()), drow::entry({2,-1.5})); + EXPECT_EQ(2u, std::distance(r2.begin(), r2.end())); +} + +TEST(msparse, row_query) { + EXPECT_EQ(row_npos, drow{}.mincol()); + EXPECT_EQ(row_npos, drow{}.maxcol()); + EXPECT_EQ(row_npos, drow{}.index(2u)); + + drow r({{1,2.0}, {2,4.0}, {4,7.0}, {6,9.0}}); + + EXPECT_EQ(4u, r.size()); + EXPECT_EQ(1u, r.mincol()); + EXPECT_EQ(6u, r.maxcol()); + + EXPECT_EQ(2u, r.mincol_after(1)); + EXPECT_EQ(4u, r.mincol_after(2)); + EXPECT_EQ(4u, r.mincol_after(3)); + EXPECT_EQ(row_npos, r.mincol_after(6)); + + EXPECT_EQ(0u, r.index(1)); + EXPECT_EQ(1u, r.index(2)); + EXPECT_EQ(3u, r.index(6)); + EXPECT_EQ(row_npos, r.index(0)); + EXPECT_EQ(row_npos, r.index(5)); + EXPECT_EQ(row_npos, r.index(7)); + + EXPECT_EQ(0.0, r[0]); + EXPECT_EQ(2.0, r[1]); + EXPECT_EQ(4.0, r[2]); + EXPECT_EQ(0.0, r[3]); + EXPECT_EQ(9.0, r[6]); + EXPECT_EQ(0.0, r[7]); + + EXPECT_EQ(drow::entry({1,2.0}), r.get(0)); + EXPECT_EQ(drow::entry({2,4.0}), r.get(1)); + EXPECT_EQ(drow::entry({4,7.0}), r.get(2)); +} + +TEST(msparse, row_mutate) { + drow r; + + r.push_back({1, 2.0}); + EXPECT_EQ(1u, r.size()); + + r.push_back({3, 3.0}); + EXPECT_EQ(2u, r.size()); + + ASSERT_THROW(r.push_back({3, 2.0}), msparse::msparse_error); + ASSERT_THROW(r.push_back({2, 2.0}), msparse::msparse_error); + + r.truncate(4); + EXPECT_EQ(2u, r.size()); + + r.truncate(3); + EXPECT_EQ(1u, r.size()); + + r.truncate(0); + EXPECT_EQ(0u, r.size()); + + r[7] = 4.0; + EXPECT_EQ(1u, r.size()); + EXPECT_EQ(7u, r.mincol()); +} + +TEST(msparse, matrix_ctor) { + dmatrix M; + EXPECT_EQ(0u, M.size()); + EXPECT_EQ(0u, M.nrow()); + EXPECT_EQ(0u, M.ncol()); + EXPECT_TRUE(M.empty()); + + dmatrix M2(5,3); + EXPECT_EQ(5u, M2.size()); + EXPECT_EQ(5u, M2.nrow()); + EXPECT_EQ(3u, M2.ncol()); + EXPECT_FALSE(M2.empty()); + + dmatrix M3(M2); + EXPECT_EQ(5u, M3.nrow()); + EXPECT_EQ(3u, M3.ncol()); +} + +TEST(msparse, matrix_index) { + dmatrix M2(5,3); + M2[4].push_back({2, -1.0}); + + const dmatrix& M2c = M2; + + EXPECT_EQ(-1.0, M2[4][2]); + EXPECT_EQ(-1.0, M2c[4][2]); +} + +TEST(msparse, matrix_augment) { + dmatrix M(5,3); + + M[1][2] = 9.0; + EXPECT_FALSE(M.augmented()); + EXPECT_EQ(row_npos, M.augcol()); + + EXPECT_EQ(0u, M[0].size()); + EXPECT_EQ(1u, M[1].size()); + + double aug1[] = {5, 4, 3, 2, 1}; + M.augment(aug1); + + // short is okay... + double aug2[] = {1, 2, 1}; + M.augment(aug2); + + EXPECT_TRUE(M.augmented()); + EXPECT_EQ(3u, M.augcol()); + + EXPECT_EQ(2u, M[0].size()); + EXPECT_EQ(3u, M[0].mincol()); + EXPECT_EQ(4u, M[0].maxcol()); + + EXPECT_EQ(3u, M[1].size()); + EXPECT_EQ(2u, M[1].mincol()); + EXPECT_EQ(4u, M[1].maxcol()); + + EXPECT_EQ(9.0, M[1][2]); + EXPECT_EQ(4.0, M[1][3]); + EXPECT_EQ(2.0, M[1][4]); + + M.diminish(); + + EXPECT_FALSE(M.augmented()); + EXPECT_EQ(row_npos, M.augcol()); + + EXPECT_EQ(0u, M[0].size()); + EXPECT_EQ(1u, M[1].size()); + EXPECT_EQ(9.0, M[1][M[1].maxcol()]); + + // augmenting with too long a column is not okay... + double aug3[] = {1, 2, 3, 4, 5, 6}; + ASSERT_THROW(M.augment(aug3), msparse::msparse_error); +} diff --git a/tests/modcc/test_optimization.cpp b/tests/modcc/test_optimization.cpp index 79102bd2c8b4b335b91066cd7c14232d834036b6..af289eb65d765b55f1d46c68b4ad4840fd58f61b 100644 --- a/tests/modcc/test_optimization.cpp +++ b/tests/modcc/test_optimization.cpp @@ -5,54 +5,52 @@ #include "constantfolder.hpp" #include "modccutil.hpp" -using namespace nest::mc; - TEST(Optimizer, constant_folding) { - auto v = util::make_unique<ConstantFolderVisitor>(); + ConstantFolderVisitor v; { auto e = parse_line_expression("x = 2*3"); - VERBOSE_PRINT( e->to_string() ); - e->accept(v.get()); + verbose_print(e); + e->accept(&v); EXPECT_EQ(e->is_assignment()->rhs()->is_number()->value(), 6); - VERBOSE_PRINT( e->to_string() ); - VERBOSE_PRINT( "" ); + verbose_print(e); + verbose_print(); } { auto e = parse_line_expression("x = 1 + 2 + 3"); - VERBOSE_PRINT( e->to_string() ); - e->accept(v.get()); + verbose_print(e); + e->accept(&v); EXPECT_EQ(e->is_assignment()->rhs()->is_number()->value(), 6); - VERBOSE_PRINT( e->to_string() ); - VERBOSE_PRINT( "" ); + verbose_print(e); + verbose_print(); } { auto e = parse_line_expression("x = exp(2)"); - VERBOSE_PRINT( e->to_string() ); - e->accept(v.get()); + verbose_print(e); + e->accept(&v); // The tolerance has to be loosend to 1e-15, because the optimizer performs // all intermediate calculations in 80 bit precision, which disagrees in // the 16 decimal place to the double precision value from std::exp(2.0). // This is a good thing: by using the constant folder we increase accuracy // over the unoptimized code! EXPECT_EQ(std::fabs(e->is_assignment()->rhs()->is_number()->value()-std::exp(2.0))<1e-15, true); - VERBOSE_PRINT( e->to_string() ); - VERBOSE_PRINT( "" ); + verbose_print(e); + verbose_print("" ); } { auto e = parse_line_expression("x= 2*2 + 3"); - VERBOSE_PRINT( e->to_string() ); - e->accept(v.get()); + verbose_print(e); + e->accept(&v); EXPECT_EQ(e->is_assignment()->rhs()->is_number()->value(), 7); - VERBOSE_PRINT( e->to_string() ); - VERBOSE_PRINT( "" ); + verbose_print(e); + verbose_print(); } { auto e = parse_line_expression("x= 3 + 2*2"); - VERBOSE_PRINT( e->to_string() ); - e->accept(v.get()); + verbose_print(e); + e->accept(&v); EXPECT_EQ(e->is_assignment()->rhs()->is_number()->value(), 7); - VERBOSE_PRINT( e->to_string() ); - VERBOSE_PRINT( "" ); + verbose_print(e); + verbose_print(); } { // this doesn't work: the (y+2) expression is not a constant, so folding stops. @@ -60,23 +58,23 @@ TEST(Optimizer, constant_folding) { // one approach would be try sorting communtative operations so that numbers // are adjacent to one another in the tree auto e = parse_line_expression("x= y + 2 + 3"); - VERBOSE_PRINT( e->to_string() ); - e->accept(v.get()); - VERBOSE_PRINT( e->to_string() ); - VERBOSE_PRINT( "" ); + verbose_print(e); + e->accept(&v); + verbose_print(e); + verbose_print(); } { auto e = parse_line_expression("x= 2 + 3 + y"); - VERBOSE_PRINT( e->to_string() ); - e->accept(v.get()); - VERBOSE_PRINT( e->to_string() ); - VERBOSE_PRINT("");; + verbose_print(e); + e->accept(&v); + verbose_print(e); + verbose_print();; } { auto e = parse_line_expression("foo(2+3, log(32), 2*3 + x)"); - VERBOSE_PRINT( e->to_string() ); - e->accept(v.get()); - VERBOSE_PRINT( e->to_string() ); - VERBOSE_PRINT("");; + verbose_print(e); + e->accept(&v); + verbose_print(e); + verbose_print(); } } diff --git a/tests/modcc/test_parser.cpp b/tests/modcc/test_parser.cpp index 853b2136511d3c73705111d2062946605e6c12bc..d2a39b21ba735e889421f1ec317c33e4ecf9b879 100644 --- a/tests/modcc/test_parser.cpp +++ b/tests/modcc/test_parser.cpp @@ -6,13 +6,13 @@ #include "modccutil.hpp" #include "parser.hpp" +// overload for parser errors template <typename EPtr> void verbose_print(const EPtr& e, Parser& p, const char* text) { - if (!g_verbose_flag) return; - - if (e) std::cout << e->to_string() << "\n"; - if (p.status()==lexerStatus::error) - std::cout << "in " << red(text) << "\t" << p.error_message() << "\n"; + verbose_print(e); + if (p.status()==lexerStatus::error) { + verbose_print("in ", red(text), "\t", p.error_message()); + } } template <typename Derived, typename RetUniqPtr> diff --git a/tests/modcc/test_removelocals.cpp b/tests/modcc/test_removelocals.cpp new file mode 100644 index 0000000000000000000000000000000000000000..140de2f18a969e45e1f0933a4ca2a77c8b538565 --- /dev/null +++ b/tests/modcc/test_removelocals.cpp @@ -0,0 +1,197 @@ +#include <iostream> +#include <map> +#include <regex> +#include <string> + +#include "expression.hpp" +#include "solvers.hpp" +#include "parser.hpp" +#include "scope.hpp" + +#include "test.hpp" + +symbol_ptr make_global(std::string name) { + return make_symbol<VariableExpression>(Location(), std::move(name)); +} + +using symbol_map = Scope<Symbol>::symbol_map; + +Symbol* add_procedure(symbol_map& symbols, const char* src) { + auto proc = Parser(src).parse_procedure(); + std::string name = proc->is_procedure()->name(); + + Symbol* weak = (symbols[name] = std::move(proc)).get(); + weak->semantic(symbols); + return weak; +} + +Symbol* add_global(symbol_map& symbols, const std::string& name) { + auto& var = (symbols[name] = make_symbol<VariableExpression>(Location(), name)); + return var.get(); +} + +TEST(remove_unused_locals, simple) { + const char* before_src = + "PROCEDURE before { \n" + " LOCAL a \n" + " LOCAL b \n" + " a = 3 \n" + " g = a \n" + " b = 4 \n" + "} \n"; + + const char* expected_src = + "PROCEDURE expected { \n" + " LOCAL a \n" + " a = 3 \n" + " g = a \n" + "} \n"; + + symbol_map symbols; + add_global(symbols, "g"); + + auto before = add_procedure(symbols, before_src); + auto before_body = before->is_procedure()->body(); + + auto expected = add_procedure(symbols, expected_src); + auto expected_body = expected->is_procedure()->body(); + + auto after = remove_unused_locals(before_body); + + verbose_print("before: ", before_body); + verbose_print("after: ", after); + verbose_print("expected: ", expected_body); + + EXPECT_EXPR_EQ(expected_body, after.get()); +} + +TEST(remove_unused_locals, compound) { + const char* before_src = + "PROCEDURE before { \n" + " LOCAL a, b, c, d \n" + " g1 = a \n" + " g2 = c \n" + "} \n"; + + const char* expected_src = + "PROCEDURE expected { \n" + " LOCAL a, c \n" + " g1 = a \n" + " g2 = c \n" + "} \n"; + + symbol_map symbols; + add_global(symbols, "g1"); + add_global(symbols, "g2"); + + auto before = add_procedure(symbols, before_src); + auto before_body = before->is_procedure()->body(); + + auto expected = add_procedure(symbols, expected_src); + auto expected_body = expected->is_procedure()->body(); + + auto after = remove_unused_locals(before_body); + + verbose_print("before: ", before_body); + verbose_print("after: ", after); + verbose_print("expected: ", expected_body); + + EXPECT_EXPR_EQ(expected_body, after.get()); +} + +TEST(remove_unused_locals, with_dependencies) { + const char* before_src = + "PROCEDURE before { \n" + " LOCAL a \n" + " LOCAL b \n" + " LOCAL c \n" + " LOCAL d \n" + " LOCAL e \n" + " LOCAL f \n" + " g1 = f \n" + " b = 3 \n" + " a = b-e \n" + " c = log(a)+2 \n" + " d = 4 \n" + " g2 = c \n" + "} \n"; + + const char* expected_src = + "PROCEDURE expected { \n" + " LOCAL a \n" + " LOCAL b \n" + " LOCAL c \n" + " LOCAL e \n" + " LOCAL f \n" + " g1 = f \n" + " b = 3 \n" + " a = b-e \n" + " c = log(a)+2 \n" + " g2 = c \n" + "} \n"; + + symbol_map symbols; + add_global(symbols, "g1"); + add_global(symbols, "g2"); + + auto before = add_procedure(symbols, before_src); + auto before_body = before->is_procedure()->body(); + + auto expected = add_procedure(symbols, expected_src); + auto expected_body = expected->is_procedure()->body(); + + auto after = remove_unused_locals(before_body); + + verbose_print("before: ", before_body); + verbose_print("after: ", after); + verbose_print("expected: ", expected_body); + + EXPECT_EXPR_EQ(expected_body, after.get()); +} + +TEST(remove_unused_locals, inner_block) { + const char* before_src = + "PROCEDURE before { \n" + " LOCAL a \n" + " LOCAL b \n" + " LOCAL c \n" + " LOCAL d \n" + " if (a>0) { \n" + " g = b \n" + " } \n" + " else { \n" + " d = g \n" + " g = c \n" + " } \n" + "} \n"; + + const char* expected_src = + "PROCEDURE expected { \n" + " LOCAL a \n" + " LOCAL b \n" + " LOCAL c \n" + " if (a>0) { \n" + " g = b \n" + " } \n" + " else { \n" + " g = c \n" + " } \n" + "} \n"; + + symbol_map symbols; + add_global(symbols, "g"); + + auto before = add_procedure(symbols, before_src); + auto before_body = before->is_procedure()->body(); + + auto expected = add_procedure(symbols, expected_src); + auto expected_body = expected->is_procedure()->body(); + + auto after = remove_unused_locals(before_body); + + verbose_print("before: ", before_body); + verbose_print("after: ", after); + verbose_print("expected: ", expected_body); + + EXPECT_EXPR_EQ(expected_body, after.get()); +} diff --git a/tests/modcc/test_symdiff.cpp b/tests/modcc/test_symdiff.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7f7098254cdb081a2e61c2c5a964353dde057520 --- /dev/null +++ b/tests/modcc/test_symdiff.cpp @@ -0,0 +1,330 @@ +#include <cmath> + +#include "test.hpp" + +#include "symdiff.hpp" +#include "parser.hpp" +#include "modccutil.hpp" + +// Test visitors for extended constant reduction, +// identifier presence detection and symbolic differentiation. + +TEST(involves_identifier, line_expr) { + const char* expr_x[] = { + "a=4+x", + "x'=a", + "a=exp(x)", + "x=sin(2)", + "~ 2x <-> (1, a)" + }; + + const char* expr_xy[] = { + "a=4+x+y", + "x'=exp(x*sin(y))", + "x=sin(2*y)", + "~ 2x <-> 3y (1, 1)" + }; + + const char* expr_xyz[] = { + "a=4+x+y*log(z)", + "x'=exp(x*sin(y))+func(2,z)", + "x=sin(2*y)+(-z)", + "~ 2x <-> 3y (a, z)" + }; + + identifier_set xyz_ids = { "x", "y", "z" }; + identifier_set yz_ids = { "y", "z" }; + identifier_set uvw_ids = { "u", "v", "w" }; + + for (auto line: expr_x) { + SCOPED_TRACE(std::string("expression: ")+line); + Parser p(line); + auto e = p.parse_statement(); + ASSERT_TRUE(e); + + EXPECT_TRUE(involves_identifier(e, "x")); + EXPECT_FALSE(involves_identifier(e, "y")); + EXPECT_FALSE(involves_identifier(e, "z")); + + EXPECT_TRUE(involves_identifier(e, xyz_ids)); + EXPECT_FALSE(involves_identifier(e, yz_ids)); + EXPECT_FALSE(involves_identifier(e, uvw_ids)); + } + + for (auto line: expr_xy) { + SCOPED_TRACE(std::string("expression: ")+line); + Parser p(line); + auto e = p.parse_statement(); + ASSERT_TRUE(e); + + EXPECT_TRUE(involves_identifier(e, "x")); + EXPECT_TRUE(involves_identifier(e, "y")); + EXPECT_FALSE(involves_identifier(e, "z")); + + EXPECT_TRUE(involves_identifier(e, xyz_ids)); + EXPECT_TRUE(involves_identifier(e, yz_ids)); + EXPECT_FALSE(involves_identifier(e, uvw_ids)); + } + + for (auto line: expr_xyz) { + SCOPED_TRACE(std::string("expression: ")+line); + Parser p(line); + auto e = p.parse_statement(); + ASSERT_TRUE(e); + + EXPECT_TRUE(involves_identifier(e, "x")); + EXPECT_TRUE(involves_identifier(e, "y")); + EXPECT_TRUE(involves_identifier(e, "z")); + + EXPECT_TRUE(involves_identifier(e, xyz_ids)); + EXPECT_TRUE(involves_identifier(e, yz_ids)); + EXPECT_FALSE(involves_identifier(e, uvw_ids)); + } +} + +TEST(constant_simplify, constants) { + struct { const char* repn; double value; } tests[] = { + { "17+3", 20. }, + { "log(exp(2))+cos(0)", 3. }, + { "0/17-1", -1. }, + { "2.5*(34/17-1.0e2)", -245. }, + { "-sin(0.523598775598298873077107230546583814)", -0.5 } + }; + + for (const auto& item: tests) { + SCOPED_TRACE(std::string("expression: ")+item.repn); + Parser p(item.repn); + auto e = p.parse_expression(); + ASSERT_TRUE(e); + + auto value = expr_value(constant_simplify(e)); + EXPECT_FALSE(std::isnan(value)); + EXPECT_NEAR(item.value, value, 1e-8); + } +} + +TEST(constant_simplify, simplified_expr) { + // Expect simplification of 'before' expression matches parse of 'after'. + // Use output string representation of expression for easy comparison. + + struct { const char* before; const char* after; } tests[] = { + { "x+y/z", "x+y/z" }, + { "(0*x)+y/(z-(0*w))", "y/z" }, + { "y*exp(0)", "y" }, + { "x^(2-1)", "x" }, + { "0-(y+0)", "-y" } + }; + + for (const auto& item: tests) { + SCOPED_TRACE(std::string("expressions: ")+item.before+"; "+item.after); + auto before = Parser{item.before}.parse_expression(); + auto after = Parser{item.after}.parse_expression(); + ASSERT_TRUE(before); + ASSERT_TRUE(after); + + EXPECT_EQ(after->to_string(), constant_simplify(before)->to_string()); + } +} + +TEST(constant_simplify, block_with_if) { + const char* before_repn = + "{\n" + " x = 0/z - y*log(1) + w*(3-2)\n" + " if (2>1) {\n" + " if (1==3) {\n" + " y = 64\n" + " }\n" + " else {\n" + " y = exp(0)*x\n" + " z = y-(x*0)\n" + " }\n" + " }\n" + " else {\n" + " y = 32\n" + " }\n" + "}\n"; + + const char* after_repn = + "{\n" + " x = w\n" + " y = x\n" + " z = y\n" + "}\n"; + + auto before = Parser{before_repn}.parse_block(false); + auto after = Parser{after_repn}.parse_block(false); + ASSERT_TRUE(before); + ASSERT_TRUE(after); + + EXPECT_EQ(after->to_string(), constant_simplify(before)->to_string()); +} + +TEST(symbolic_pdiff, expressions) { + struct { const char* before; const char* after; } tests[] = { + { "y+z*4", "0" }, + { "x", "1" }, + { "x*3", "3"}, + { "x-log(x)", "1-1/x"}, + { "sin(x)", "cos(x)"}, + { "cos(x)", "-sin(x)"}, + { "exp(x)", "exp(x)"}, + }; + + for (const auto& item: tests) { + SCOPED_TRACE(std::string("expressions: ")+item.before+"; "+item.after); + auto before = Parser{item.before}.parse_expression(); + auto after = Parser{item.after}.parse_expression(); + ASSERT_TRUE(before); + ASSERT_TRUE(after); + + EXPECT_EQ(after->to_string(), symbolic_pdiff(before, "x")->to_string()); + } +} + +TEST(symbolic_pdiff, linear) { + struct { const char* before; const char* after; } tests[] = { + { "x+y/z", "1" }, + { "3.0*x-x/2.0+x*y", "2.5+y"}, + { "(1+2.*x)/(1-exp(y))", "2./(1-exp(y))"} + }; + + for (const auto& item: tests) { + SCOPED_TRACE(std::string("expressions: ")+item.before+"; "+item.after); + auto before = Parser{item.before}.parse_expression(); + auto after = Parser{item.after}.parse_expression(); + ASSERT_TRUE(before); + ASSERT_TRUE(after); + + EXPECT_EQ(after->to_string(), symbolic_pdiff(before, "x")->to_string()); + } +} + +TEST(symbolic_pdiff, nonlinear) { + struct { const char* before; const char* after; } tests[] = { + { "sin(x)", "cos(x)" }, + { "exp(2*x)", "2*exp(2*x)" }, + { "x^2", "2*x" }, + { "a^x", "log(a)*a^x" } + }; + + for (const auto& item: tests) { + SCOPED_TRACE(std::string("expressions: ")+item.before+"; "+item.after); + auto before = Parser{item.before}.parse_expression(); + auto after = Parser{item.after}.parse_expression(); + ASSERT_TRUE(before); + ASSERT_TRUE(after); + + EXPECT_EQ(after->to_string(), symbolic_pdiff(before, "x")->to_string()); + } +} + +inline expression_ptr operator""_expr(const char* literal, std::size_t) { + return Parser{literal}.parse_expression(); +} + +TEST(substitute, expressions) { + struct { const char* before; const char* after; } tests[] = { + { "x", "y+z" }, + { "y", "y" }, + { "2.0", "2.0" }, + { "sin(x)", "sin(y+z)" }, + { "func(x,3+x)+y", "func(y+z,3+(y+z))+y" } + }; + + auto yplusz = "y+z"_expr; + ASSERT_TRUE(yplusz); + + for (const auto& item: tests) { + SCOPED_TRACE(std::string("expressions: ")+item.before+"; "+item.after); + auto before = Parser{item.before}.parse_expression(); + auto after = Parser{item.after}.parse_expression(); + ASSERT_TRUE(before); + ASSERT_TRUE(after); + + auto result = substitute(before.get(), "x", yplusz.get()); + EXPECT_EQ(after->to_string(), result->to_string()); + } +} + +TEST(substitute, exprmap) { + substitute_map subs; + subs["x"] = "sin(y)"_expr; + subs["y"] = "cos(x)"_expr; + + // note substitute is not recursive! + auto before = "exp(x+y)"_expr; + ASSERT_TRUE(before); + + auto after = "exp(sin(y)+cos(x))"_expr; + ASSERT_TRUE(after); + + auto result = substitute(before.get(), subs); + EXPECT_EQ(after->to_string(), result->to_string()); +} + +TEST(linear_test, homogeneous) { + linear_test_result r; + + r = linear_test("3*x"_expr, {"x"}); + EXPECT_TRUE(r.is_linear); + EXPECT_TRUE(r.is_homogeneous); + EXPECT_TRUE(r.monolinear()); + EXPECT_EQ(r.coef["x"]->to_string(), "3"_expr->to_string()); + + r = linear_test("y-a*x+2*x"_expr, {"x", "y"}); + EXPECT_TRUE(r.is_linear); + EXPECT_TRUE(r.is_homogeneous); + EXPECT_FALSE(r.monolinear()); + EXPECT_EQ(r.coef["x"]->to_string(), "-a+2"_expr->to_string()); + EXPECT_EQ(r.coef["y"]->to_string(), "1"_expr->to_string()); +} + +TEST(linear_test, inhomogeneous) { + linear_test_result r; + + r = linear_test("sin(y)+3*x"_expr, {"x"}); + EXPECT_TRUE(r.is_linear); + EXPECT_FALSE(r.is_homogeneous); + EXPECT_EQ(r.coef["x"]->to_string(), "3"_expr->to_string()); + EXPECT_EQ(r.constant->to_string(), "sin(y)"_expr->to_string()); + + r = linear_test("(x+y+1)*(a+b)"_expr, {"x", "y"}); + EXPECT_TRUE(r.is_linear); + EXPECT_FALSE(r.is_homogeneous); + EXPECT_EQ(r.coef["x"]->to_string(), "a+b"_expr->to_string()); + EXPECT_EQ(r.coef["y"]->to_string(), "a+b"_expr->to_string()); + EXPECT_EQ(r.constant->to_string(), "a+b"_expr->to_string()); + + // check 'gating' case still works! (Use plus instead of minus + // though because of -1 vs (- 1) parsing makes the test harder.) + r = linear_test("(a+x)/b"_expr, {"x"}); + EXPECT_TRUE(r.is_linear); + EXPECT_FALSE(r.is_homogeneous); + EXPECT_EQ(r.coef["x"]->to_string(), "1/b"_expr->to_string()); + EXPECT_EQ(r.constant->to_string(), "a/b"_expr->to_string()); +} + +TEST(linear_test, nonlinear) { + linear_test_result r; + + r = linear_test("x+x^2"_expr, {"x", "y"}); + EXPECT_FALSE(r.is_linear); + + r = linear_test("x+y*x"_expr, {"x", "y"}); + EXPECT_FALSE(r.is_linear); +} + +TEST(linear_test, diagonality) { + auto xdot = "a*x"_expr; + auto ydot = "-b*y/2"_expr; + auto zdot = "x+y+z"_expr; + + // xdot, ydot diagonal linear + EXPECT_TRUE(linear_test(xdot, {"x", "y"}).monolinear("x")); + EXPECT_TRUE(linear_test(ydot, {"x", "y"}).monolinear("y")); + + // but xdot, ydot, zdot not diagonal + EXPECT_TRUE(linear_test(xdot, {"x", "y", "z"}).monolinear("x")); + EXPECT_TRUE(linear_test(ydot, {"x", "y", "z"}).monolinear("y")); + EXPECT_FALSE(linear_test(zdot, {"x", "y", "z"}).monolinear("z")); +} diff --git a/tests/modcc/test_symge.cpp b/tests/modcc/test_symge.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7ce77aebb5423e28a5b26ff38117dbc132f80096 --- /dev/null +++ b/tests/modcc/test_symge.cpp @@ -0,0 +1,178 @@ +#include <map> +#include <utility> +#include <vector> + +#include "symge.hpp" +#include "test.hpp" + +using namespace symge; + +TEST(symge, table_define) { + symbol_table tbl; + + EXPECT_EQ(0u, tbl.size()); + + auto s1 = tbl.define(); + auto s2 = tbl.define("foo"); + auto s3 = tbl.define(symbol_term_diff{}); + auto s4 = tbl.define("bar",symbol_term_diff{}); + + EXPECT_EQ(4u, tbl.size()); + EXPECT_TRUE(tbl.valid(s1)); + EXPECT_TRUE(tbl.valid(s2)); + EXPECT_TRUE(tbl.valid(s3)); + EXPECT_TRUE(tbl.valid(s4)); + EXPECT_FALSE(tbl.valid(symbol{})); +} + +TEST(symge, symbol_ops) { + symbol_table tbl; + + symbol a = tbl.define("a"); + symbol b = tbl.define("b"); + symbol c = tbl.define("c", a*b); + symbol d = tbl.define("d", -(a*b)); + symbol e = tbl.define("e", c*d-a*b); + + EXPECT_TRUE(tbl.primitive(a)); + EXPECT_TRUE(tbl.primitive(b)); + EXPECT_FALSE(tbl.primitive(c)); + EXPECT_FALSE(tbl.primitive(d)); + EXPECT_FALSE(tbl.primitive(e)); + + auto cdef = tbl.get(c); + EXPECT_FALSE(cdef.left.is_zero()); + EXPECT_TRUE(cdef.right.is_zero()); + + auto ddef = tbl.get(d); + EXPECT_TRUE(ddef.left.is_zero()); + EXPECT_FALSE(ddef.right.is_zero()); + + auto edef = tbl.get(e); + EXPECT_FALSE(edef.left.is_zero()); + EXPECT_FALSE(edef.right.is_zero()); + + EXPECT_EQ(c, edef.left.left); + EXPECT_EQ(d, edef.left.right); + EXPECT_EQ(a, edef.right.left); + EXPECT_EQ(b, edef.right.right); +} + +TEST(symge, symbol_name) { + symbol_table tbl; + + symbol a = tbl.define("a"); + EXPECT_EQ("a", tbl.name(a)); + EXPECT_EQ("a", name(a)); + tbl.name(a, "b"); + EXPECT_EQ("b", name(a)); +} + +TEST(symge, table_index) { + symbol_table tbl; + + symbol a = tbl.define("a"); + symbol b = tbl.define("b"); + symbol c = tbl.define("c", a*b); + symbol d = tbl.define("d", -(a*b)); + symbol e = tbl.define("e", c*d-a*b); + + EXPECT_EQ(a, tbl[0]); + EXPECT_EQ(d, tbl[3]); + EXPECT_EQ(e, tbl[4]); +} + +#include <iostream> + +struct value_store { + mutable std::map<std::string, double> values; + + std::string new_var() { + return "v"+std::to_string(values.size()); + } + + double& assign(symbol s, double v) { + if (name(s).empty()) { + const_cast<symbol_table*>(s.table())->name(s, new_var()); + } + + return values[name(s)] = v; + } + + double eval(symbol s) { + return values[name(s)]; + } + + double eval(symbol_term t) { + return t.is_zero()? 0.0: eval(t.left)*eval(t.right); + } + + double eval(symbol_term_diff d) { + return eval(d.left)-eval(d.right); + } + + double& operator[](symbol s) { + if (name(s).empty() || !values.count(name(s))) { + return assign(s, 0.0); + } + return values[name(s)]; + } +}; + + +TEST(symge, gj_reduce_3x3) { + // solve: + // + // | 2 0 3 | | x | | 6 | + // | 0 4 0 | | y | = | 7 | + // | 0 -1 5 | | z | | 8 | + // + // with expected answer: + // + // x = 3/40; y = 7/4; z = 39/20 + + symbol_table tbl; + auto a = tbl.define("a"); + auto b = tbl.define("b"); + auto c = tbl.define("c"); + auto d = tbl.define("d"); + auto e = tbl.define("e"); + auto p = tbl.define("p"); + auto q = tbl.define("q"); + auto r = tbl.define("r"); + + sym_matrix A(3,3); + A[0] = sym_row({{0, a}, {2, b}}); + A[1] = sym_row({{1, c}}); + A[2] = sym_row({{1, d}, {2, e}}); + + std::vector<symbol> B = { p, q, r }; + A.augment(B); + + gj_reduce(A, tbl); + + value_store v; + v[a] = 2; + v[b] = 3; + v[c] = 4; + v[d] = -1; + v[e] = 5; + v[p] = 6; + v[q] = 7; + v[r] = 8; + + for (unsigned i = 0; i<tbl.size(); ++i) { + symbol s = tbl[i]; + if (!primitive(s)) { + v.assign(s, v.eval(definition(s))); + } + } + + double x = v.eval(A[0][3])/v.eval(A[0][0]); + double y = v.eval(A[1][3])/v.eval(A[1][1]); + double z = v.eval(A[2][3])/v.eval(A[2][2]); + + EXPECT_NEAR(x, 3.0/40.0, 1e-6); + EXPECT_NEAR(y, 7.0/4.0, 1e-6); + EXPECT_NEAR(z, 39.0/20.0, 1e-6); +} diff --git a/tests/modcc/test_visitors.cpp b/tests/modcc/test_visitors.cpp index e2b7dd7caf055cee0bcafeac18f74ddeff8dc9d5..6f1caf64d3a066d80d940cee3530c52a9f1d114d 100644 --- a/tests/modcc/test_visitors.cpp +++ b/tests/modcc/test_visitors.cpp @@ -6,110 +6,116 @@ #include "parser.hpp" #include "modccutil.hpp" +// overload for parser errors +template <typename EPtr> +void verbose_print(const EPtr& e, Parser& p, const char* text) { + verbose_print(e); + if (p.status()==lexerStatus::error) { + verbose_print("in ", cyan(text), "\t", p.error_message()); + } +} + /************************************************************** * visitors **************************************************************/ -using namespace nest::mc; - TEST(FlopVisitor, basic) { { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("x+y"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.add, 1); + FlopVisitor visitor; + auto e = parse_expression("x+y"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.add, 1); } { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("x-y"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.add, 1); + FlopVisitor visitor; + auto e = parse_expression("x-y"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.add, 1); } { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("x*y"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.mul, 1); + FlopVisitor visitor; + auto e = parse_expression("x*y"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.mul, 1); } { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("x/y"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.div, 1); + FlopVisitor visitor; + auto e = parse_expression("x/y"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.div, 1); } { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("exp(x)"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.exp, 1); + FlopVisitor visitor; + auto e = parse_expression("exp(x)"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.exp, 1); } { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("log(x)"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.log, 1); + FlopVisitor visitor; + auto e = parse_expression("log(x)"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.log, 1); } { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("cos(x)"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.cos, 1); + FlopVisitor visitor; + auto e = parse_expression("cos(x)"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.cos, 1); } { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("sin(x)"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.sin, 1); + FlopVisitor visitor; + auto e = parse_expression("sin(x)"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.sin, 1); } } TEST(FlopVisitor, compound) { { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("x+y*z/a-b"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.add, 2); - EXPECT_EQ(visitor->flops.mul, 1); - EXPECT_EQ(visitor->flops.div, 1); + FlopVisitor visitor; + auto e = parse_expression("x+y*z/a-b"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.add, 2); + EXPECT_EQ(visitor.flops.mul, 1); + EXPECT_EQ(visitor.flops.div, 1); } { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("exp(x+y+z)"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.add, 2); - EXPECT_EQ(visitor->flops.exp, 1); + FlopVisitor visitor; + auto e = parse_expression("exp(x+y+z)"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.add, 2); + EXPECT_EQ(visitor.flops.exp, 1); } { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_expression("exp(x+y) + 3/(12 + z)"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.add, 3); - EXPECT_EQ(visitor->flops.div, 1); - EXPECT_EQ(visitor->flops.exp, 1); + FlopVisitor visitor; + auto e = parse_expression("exp(x+y) + 3/(12 + z)"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.add, 3); + EXPECT_EQ(visitor.flops.div, 1); + EXPECT_EQ(visitor.flops.exp, 1); } // test asssignment expression { - auto visitor = util::make_unique<FlopVisitor>(); - auto e = parse_line_expression("x = exp(x+y) + 3/(12 + z)"); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.add, 3); - EXPECT_EQ(visitor->flops.div, 1); - EXPECT_EQ(visitor->flops.exp, 1); + FlopVisitor visitor; + auto e = parse_line_expression("x = exp(x+y) + 3/(12 + z)"); + e->accept(&visitor); + EXPECT_EQ(visitor.flops.add, 3); + EXPECT_EQ(visitor.flops.div, 1); + EXPECT_EQ(visitor.flops.exp, 1); } } TEST(FlopVisitor, procedure) { - { const char *expression = "PROCEDURE trates(v) {\n" " LOCAL qt\n" @@ -119,20 +125,18 @@ TEST(FlopVisitor, procedure) { " mtau = 0.6\n" " htau = 1500\n" "}"; - auto visitor = util::make_unique<FlopVisitor>(); + FlopVisitor visitor; auto e = parse_procedure(expression); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.add, 6); - EXPECT_EQ(visitor->flops.neg, 0); - EXPECT_EQ(visitor->flops.mul, 0); - EXPECT_EQ(visitor->flops.div, 5); - EXPECT_EQ(visitor->flops.exp, 2); - EXPECT_EQ(visitor->flops.pow, 1); - } + e->accept(&visitor); + EXPECT_EQ(visitor.flops.add, 6); + EXPECT_EQ(visitor.flops.neg, 0); + EXPECT_EQ(visitor.flops.mul, 0); + EXPECT_EQ(visitor.flops.div, 5); + EXPECT_EQ(visitor.flops.exp, 2); + EXPECT_EQ(visitor.flops.pow, 1); } TEST(FlopVisitor, function) { - { const char *expression = "FUNCTION foo(v) {\n" " LOCAL qt\n" @@ -141,16 +145,15 @@ TEST(FlopVisitor, function) { " hinf=1/(1+exp((v-vhalfh)/kh))\n" " foo = minf + hinf\n" "}"; - auto visitor = util::make_unique<FlopVisitor>(); + FlopVisitor visitor; auto e = parse_function(expression); - e->accept(visitor.get()); - EXPECT_EQ(visitor->flops.add, 7); - EXPECT_EQ(visitor->flops.neg, 1); - EXPECT_EQ(visitor->flops.mul, 0); - EXPECT_EQ(visitor->flops.div, 5); - EXPECT_EQ(visitor->flops.exp, 2); - EXPECT_EQ(visitor->flops.pow, 1); - } + e->accept(&visitor); + EXPECT_EQ(visitor.flops.add, 7); + EXPECT_EQ(visitor.flops.neg, 1); + EXPECT_EQ(visitor.flops.mul, 0); + EXPECT_EQ(visitor.flops.div, 5); + EXPECT_EQ(visitor.flops.exp, 2); + EXPECT_EQ(visitor.flops.pow, 1); } TEST(ClassificationVisitor, linear) { @@ -199,20 +202,14 @@ TEST(ClassificationVisitor, linear) { if( e==nullptr ) continue; e->semantic(scope); - auto v = new ExpressionClassifierVisitor(x); - e->accept(v); - //std::cout << "expression " << e->to_string() << std::endl; - //std::cout << "linear " << v->linear_coefficient()->to_string() << std::endl; - //std::cout << "constant " << v->constant_term()->to_string() << std::endl; - EXPECT_EQ(v->classify(), expressionClassification::linear); - -#ifdef VERBOSE_TEST - std::cout << "eq " << e->to_string() - << "\ncoeff " << v->linear_coefficient()->to_string() - << "\nconst " << v-> constant_term()->to_string() - << "\n----" << std::endl; -#endif - delete v; + ExpressionClassifierVisitor v(x); + e->accept(&v); + EXPECT_EQ(v.classify(), expressionClassification::linear); + + verbose_print("eq ", e); + verbose_print("coeff ", v.linear_coefficient()); + verbose_print("const ", v.constant_term()); + verbose_print("----"); } } @@ -235,23 +232,19 @@ TEST(ClassificationVisitor, constant) { auto x = globals["x"].get(); for(auto const& expression : expressions) { - auto e = parse_expression(expression); + Parser p{expression}; + auto e = p.parse_expression(); // sanity check the compiler EXPECT_NE(e, nullptr); if( e==nullptr ) continue; e->semantic(scope); - auto v = new ExpressionClassifierVisitor(x); - e->accept(v); - EXPECT_EQ(v->classify(), expressionClassification::constant); - -#ifdef VERBOSE_TEST - if(e) std::cout << e->to_string() << std::endl; - if(p.status()==lexerStatus::error) - std::cout << "in " << colorize(expression, kCyan) << "\t" << p.error_message() << std::endl; -#endif - delete v; + ExpressionClassifierVisitor v(x); + e->accept(&v); + EXPECT_EQ(v.classify(), expressionClassification::constant); + + verbose_print(e, p, expression); } } @@ -281,24 +274,20 @@ TEST(ClassificationVisitor, nonlinear) { auto scope = std::make_shared<Scope<Symbol>>(globals); auto x = globals["x"].get(); - auto v = new ExpressionClassifierVisitor(x); + ExpressionClassifierVisitor v(x); for(auto const& expression : expressions) { - auto e = parse_expression(expression); + Parser p{expression}; + auto e = p.parse_expression(); // sanity check the compiler EXPECT_NE(e, nullptr); if( e==nullptr ) continue; e->semantic(scope); - v->reset(); - e->accept(v); - EXPECT_EQ(v->classify(), expressionClassification::nonlinear); - -#ifdef VERBOSE_TEST - if(e) std::cout << e->to_string() << std::endl; - if(p.status()==lexerStatus::error) - std::cout << "in " << colorize(expression, kCyan) << "\t" << p.error_message() << std::endl; -#endif + v.reset(); + e->accept(&v); + EXPECT_EQ(v.classify(), expressionClassification::nonlinear); + + verbose_print(e, p, expression); } - delete v; } diff --git a/tests/validation/CMakeLists.txt b/tests/validation/CMakeLists.txt index f7e9f8494bc2802d3354bbcde5b8c47df27a3ce9..a42fc4ddd65cb0ad6fce7fe8376bac4caa25d48e 100644 --- a/tests/validation/CMakeLists.txt +++ b/tests/validation/CMakeLists.txt @@ -3,6 +3,7 @@ set(VALIDATION_SOURCES validate_ball_and_stick.cpp validate_compartment_policy.cpp validate_soma.cpp + validate_kinetic.cpp validate_synapses.cpp # support code diff --git a/tests/validation/validate_kinetic.cpp b/tests/validation/validate_kinetic.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1c2005165dac0295542fdc2335147606efc0538a --- /dev/null +++ b/tests/validation/validate_kinetic.cpp @@ -0,0 +1,13 @@ +#include "validate_kinetic.hpp" + +#include "../gtest.h" + +using lowered_cell = nest::mc::fvm::fvm_multicell<nest::mc::multicore::backend>; + +TEST(kinetic, kin1_numeric_ref) { + validate_kinetic_kin1<lowered_cell>(); +} + +TEST(kinetic, kinlva_numeric_ref) { + validate_kinetic_kinlva<lowered_cell>(); +} diff --git a/tests/validation/validate_kinetic.hpp b/tests/validation/validate_kinetic.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7ccde7c93e5609fb12364ec308facecab498e2eb --- /dev/null +++ b/tests/validation/validate_kinetic.hpp @@ -0,0 +1,90 @@ +#include <json/json.hpp> + +#include <common_types.hpp> +#include <cell.hpp> +#include <fvm_multicell.hpp> +#include <model.hpp> +#include <recipe.hpp> +#include <simple_sampler.hpp> +#include <util/rangeutil.hpp> + +#include "../test_common_cells.hpp" +#include "convergence_test.hpp" +#include "trace_analysis.hpp" +#include "validation_data.hpp" + +template <typename LoweredCell> +void run_kinetic_dt(nest::mc::cell& c, float t_end, nlohmann::json meta, const std::string& ref_file) { + using namespace nest::mc; + + float sample_dt = .025f; + sampler_info samplers[] = { + {"soma.mid", {0u, 0u}, simple_sampler(sample_dt)} + }; + + meta["sim"] = "nestmc"; + meta["backend"] = LoweredCell::backend::name(); + convergence_test_runner<float> runner("dt", samplers, meta); + runner.load_reference_data(ref_file); + + model<LoweredCell> model(singleton_recipe{c}); + + auto exclude = stimulus_ends(c); + + // use dt = 0.05, 0.02, 0.01, 0.005, 0.002, ... + double max_oo_dt = std::round(1.0/g_trace_io.min_dt()); + for (double base = 100; ; base *= 10) { + for (double multiple: {5., 2., 1.}) { + double oo_dt = base/multiple; + if (oo_dt>max_oo_dt) goto end; + + model.reset(); + float dt = float(1./oo_dt); + runner.run(model, dt, t_end, dt, exclude); + } + } + +end: + runner.report(); + runner.assert_all_convergence(); +} + +template <typename LoweredCell> +void validate_kinetic_kin1() { + using namespace nest::mc; + + // 20 µm diameter soma with single mechanism, current probe + cell c; + auto soma = c.add_soma(10); + c.add_probe({{0, 0.5}, probeKind::membrane_current}); + soma->add_mechanism(std::string("test_kin1")); + + nlohmann::json meta = { + {"model", "test_kin1"}, + {"name", "membrane current"}, + {"units", "nA"} + }; + + run_kinetic_dt<LoweredCell>(c, 100.f, meta, "numeric_kin1.json"); +} + +template <typename LoweredCell> +void validate_kinetic_kinlva() { + using namespace nest::mc; + + // 20 µm diameter soma with single mechanism, current probe + cell c; + auto soma = c.add_soma(10); + c.add_probe({{0, 0.5}, probeKind::membrane_voltage}); + c.add_stimulus({0,0.5}, {20., 130., -0.025}); + soma->add_mechanism(std::string("test_kinlva")); + + nlohmann::json meta = { + {"model", "test_kinlva"}, + {"name", "membrane voltage"}, + {"units", "mV"} + }; + + run_kinetic_dt<LoweredCell>(c, 300.f, meta, "numeric_kinlva.json"); +} + diff --git a/tests/validation/validate_soma.hpp b/tests/validation/validate_soma.hpp index 2658ee1c9adca61e6efff6b0b9e728aa39cb0a4e..eee44e6bd280d63400ff103e441c38ece4c166a3 100644 --- a/tests/validation/validate_soma.hpp +++ b/tests/validation/validate_soma.hpp @@ -37,10 +37,10 @@ void validate_soma() { float t_end = 100.f; - // use dt = 0.05, 0.025, 0.01, 0.005, 0.0025, ... + // use dt = 0.05, 0.02, 0.01, 0.005, 0.002, ... double max_oo_dt = std::round(1.0/g_trace_io.min_dt()); for (double base = 100; ; base *= 10) { - for (double multiple: {5., 2.5, 1.}) { + for (double multiple: {5., 2., 1.}) { double oo_dt = base/multiple; if (oo_dt>max_oo_dt) goto end; diff --git a/validation/ref/numeric/CMakeLists.txt b/validation/ref/numeric/CMakeLists.txt index 4a86db710b1a5f115d97cbfafbbf2f97f5bec793..510f58ee6db5d7b77ba01dcecc00c94dbf3817ea 100644 --- a/validation/ref/numeric/CMakeLists.txt +++ b/validation/ref/numeric/CMakeLists.txt @@ -1,6 +1,16 @@ # note: function add_validation_data defined in validation/CMakeLists.txt if(NMC_BUILD_JULIA_VALIDATION_DATA) + add_validation_data( + OUTPUT numeric_kin1.json + DEPENDS numeric_kin1.jl + COMMAND ${JULIA_BIN} numeric_kin1.jl) + + add_validation_data( + OUTPUT numeric_kinlva.json + DEPENDS numeric_kinlva.jl LVAChannels.jl + COMMAND ${JULIA_BIN} numeric_kinlva.jl) + add_validation_data( OUTPUT numeric_soma.json DEPENDS numeric_soma.jl HHChannels.jl diff --git a/validation/ref/numeric/LVAChannels.jl b/validation/ref/numeric/LVAChannels.jl new file mode 100644 index 0000000000000000000000000000000000000000..1d11e6b0a9c8f5fd9847e6fbf9c98f687fd21f81 --- /dev/null +++ b/validation/ref/numeric/LVAChannels.jl @@ -0,0 +1,168 @@ +module LVAChannels + +export Stim, run_lva, LVAParam + +using Sundials +using SIUnits +using SIUnits.ShortUnits + +const mS = Milli*Siemens + +immutable LVAParam + c_m # membrane spacific capacitance + gbar # Ca channel cross-membrane conductivity + eca # Ca channel reversal potential + gl # leak conductivity + el # leak reversal potential + q10_1 # rate scaling for 'm' activation gate + q10_2 # rate scaling for 'dsh' inactivation gate + vrest # (as hoc) resting potential + + # constructor with default values, with q10 values + # corresponding to the body temperature current-clamp experiments + LVAParam(; + c_m = 0.01F*m^-2, + gbar = 0.2mS*cm^-2, + eca = 120mV, + gl = .1mS*cm^-2, + el = -65mV, + q10_1 = 5, + q10_2 = 3, +# vrest = -61.47mV + vrest = -63mV + ) = new(c_m, gbar, eca, gl, el, q10_1, q10_2, vrest) + +end + +immutable Stim + t0 # start time of stimulus + t1 # stop time of stimulus + i_e # stimulus current density + + Stim() = new(0s, 0s, 0A/m^2) + Stim(t0, t1, i_e) = new(t0, t1, i_e) +end + +# 'm' activation gate +function m_lims(v, q10) + quotient = 1+exp(-(v+63mV)/7.8mV) + mtau = 1ms*(1.7+exp(-(v+28.8mV)/13.5mV))/(q10*quotient) + minf = 1/quotient + return mtau, minf +end + +# 'dsh' inactivation gate: +# d <-> s (alpha2, beta2) +# s <-> h (alpha1, beta1) +# subject to s=1-h-d +function dsh_lims(v, q10) + k = sqrt(0.25+exp((v+83.5mV)/6.3mV))-0.5 + alpha1 = q10*exp(-(v+160.3mV)/17.8mV)/1ms + beta1 = alpha1*k + + tau2 = 240ms/(1+exp((v+37.4mV)/30mV))/q10 + alpha2 = 1/tau2/(1+k) + beta2 = alpha2*k + + hinf = 1/(1+k+k^2) + dinf = k^2*hinf + + return alpha1, beta1, alpha2, beta2, dinf, hinf +end + +# Choose initial conditions for the system such that the gating variables +# are at steady state for the user-specified voltage v. +function initial_conditions(v, q10_1, q10_2) + mtau, minf = m_lims(v, q10_1) + alpha1, beta1, alpha2, beta2, dinf, hinf = dsh_lims(v, q10_2) + + return (v, minf, dinf, hinf) +end + +# Given time t and state (v, m, d, h), +# return (vdot, mdot, ddot, hdot). +function f(t, state; p=LVAParam(), istim=0A/m^2) + v, m, d, h = state + + ica = p.gbar*m^3*h*(v - p.eca) + il = p.gl*(v - p.el) + itot = ica + il - istim + + # Calculate the voltage dependent rates for the gating variables. + mtau, minf = m_lims(v, p.q10_1) + alpha1, beta1, alpha2, beta2, dinf, hinf = dsh_lims(v, p.q10_2) + + mdot = (minf-m)/mtau + hdot = alpha1*(1-h-d)-beta1*h + ddot = beta2*(1-h-d)-alpha2*d + + return (-itot/p.c_m, mdot, ddot, hdot) +end + +function make_range(t_start, dt, t_end) + r = collect(t_start: dt: t_end) + if length(r)>0 && r[length(r)]<t_end + push!(r, t_end) + end + return r +end + +function run_lva(t_end; stim=Stim(), param=LVAParam(), sample_dt=0.01ms) + v_scale = 1V + t_scale = 1s + + v0, m0, h0, d0 = initial_conditions(param.vrest, param.q10_1, param.q10_2) + y0 = [ v0/v_scale, m0, h0, d0 ] + + + fbis(t, y, ydot, istim) = begin + vdot, mdot, hdot, ddot = + f(t*t_scale, (y[1]*v_scale, y[2], y[3], y[4]), istim=istim, p=param) + + ydot[1], ydot[2], ydot[3], ydot[4] = + vdot*t_scale/v_scale, mdot*t_scale, hdot*t_scale, ddot*t_scale + + return Sundials.CV_SUCCESS + end + + fbis_nostim(t, y, ydot) = fbis(t, y, ydot, 0A/m^2) + fbis_stim(t, y, ydot) = fbis(t, y, ydot, stim.i_e) + + # Ideally would run with vector absolute tolerance to account for v_scale, + # but this would prevent us using the nice cvode wrapper. + + res = [] + samples = [] + + t1 = clamp(stim.t0, 0s, t_end) + if t1>0s + ts = make_range(0s, sample_dt, t1) + r = Sundials.cvode(fbis_nostim, y0, map(t->t/t_scale, ts), abstol=1e-6, reltol=5e-10) + y0 = vec(r[size(r)[1], :]) + push!(res, r) + push!(samples, ts) + end + t2 = clamp(stim.t1, t1, t_end) + if t2>t1 + ts = make_range(t1, sample_dt, t2) + r = Sundials.cvode(fbis_stim, y0, map(t->t/t_scale, ts), abstol=1e-6, reltol=5e-10) + y0 = vec(r[size(r)[1], :]) + push!(res, r) + push!(samples, ts) + end + if t_end>t2 + ts = make_range(t2, sample_dt, t_end) + r = Sundials.cvode(fbis_nostim, y0, map(t->t/t_scale, ts), abstol=1e-6, reltol=5e-10) + y0 = vec(r[size(r)[1], :]) + push!(res, r) + push!(samples, ts) + end + + res = vcat(res...) + samples = vcat(samples...) + + # Use map here because of issues with type deduction with arrays and SIUnits. + return samples, map(v->v*v_scale, res[:, 1]), res[:, 2], res[:, 3], res[:, 4] +end + +end # module LVAChannels diff --git a/validation/ref/numeric/numeric_kin1.jl b/validation/ref/numeric/numeric_kin1.jl new file mode 100644 index 0000000000000000000000000000000000000000..300903147b1e29dbc9603448f2289bbd70ad5ce7 --- /dev/null +++ b/validation/ref/numeric/numeric_kin1.jl @@ -0,0 +1,33 @@ +#!/usr/bin/env julia + +include("HHChannels.jl") + +using JSON +using SIUnits.ShortUnits + +radius = 20µm/2 +area = 4*pi*radius^2 +sample_dt = 0.025ms +t_end = 100ms + +a0 = 0.01mA/cm^2 +c = 0.01mA/cm^2 + +tau = 10ms + +ts = collect(0s: sample_dt: t_end) +is = area*(1/3*c + (a0-1/3*c)*exp(-ts/tau)) + +trace = Dict( + :name => "membrane current", + :sim => "numeric", + :model => "test_kin1", + :units => "nA", + :data => Dict( + :time => map(t->t/ms, ts), + Symbol("soma.mid") => map(i->i/nA, is) + ) +) + +println(JSON.json([trace])) + diff --git a/validation/ref/numeric/numeric_kinlva.jl b/validation/ref/numeric/numeric_kinlva.jl new file mode 100644 index 0000000000000000000000000000000000000000..9dc18ceacd103e77515967189d73a93498bbe98b --- /dev/null +++ b/validation/ref/numeric/numeric_kinlva.jl @@ -0,0 +1,40 @@ +#!/usr/bin/env julia + +include("LVAChannels.jl") + +using JSON +using SIUnits.ShortUnits +using LVAChannels + +radius = 20µm/2 +area = 4*pi*radius^2 +current = -0.025nA + +stim = Stim(20ms, 150ms, current/area) +ts, vs, m, d, h = run_lva(300ms, param=LVAParam(vrest=-65mV), stim=stim, sample_dt=0.025ms) + +trace = Dict( + :name => "membrane voltage", + :sim => "numeric", + :model => "test_kinlva", + :units => "mV", + :data => Dict( + :time => map(t->t/ms, ts), + Symbol("soma.mid") => map(v->v/mV, vs) + ) +) + +state = Dict( + :name => "mechanisms state", + :sim => "numeric", + :model => "kinlva", + :units => "1", + :data => Dict( + :time => map(t->t/ms, ts), + Symbol("m") => m, + Symbol("d") => d, + Symbol("h") => h + ) +) +println(JSON.json([trace, state])) +