diff --git a/modcc/CMakeLists.txt b/modcc/CMakeLists.txt index 6c4d92010af3b1eb2dc470b171303e95e6322f5f..c0a6537ee9b39abde93fd4bed0936d355e50548a 100644 --- a/modcc/CMakeLists.txt +++ b/modcc/CMakeLists.txt @@ -11,6 +11,7 @@ set(libmodcc_sources functioninliner.cpp lexer.cpp kineticrewriter.cpp + linearrewriter.cpp module.cpp parser.cpp solvers.cpp diff --git a/modcc/expression.cpp b/modcc/expression.cpp index 441ea63eb3def23b599efcadd234da7b710bdaf2..ef404d43fd0e75352fb6bfcdd31b22707a170fda 100644 --- a/modcc/expression.cpp +++ b/modcc/expression.cpp @@ -359,6 +359,25 @@ void StoichExpression::semantic(scope_ptr scp) { } } +/******************************************************************************* + LinearExpression +*******************************************************************************/ + +expression_ptr LinearExpression::clone() const { + return make_expression<LinearExpression>( + location_, lhs()->clone(), rhs()->clone()); +} + +void LinearExpression::semantic(scope_ptr scp) { + scope_ = scp; + lhs_->semantic(scp); + rhs_->semantic(scp); + + if(rhs_->is_procedure_call()) { + error("procedure calls can't be made in an expression"); + } +} + /******************************************************************************* ConserveExpression *******************************************************************************/ @@ -984,6 +1003,9 @@ void ConserveExpression::accept(Visitor *v) { void ReactionExpression::accept(Visitor *v) { v->visit(this); } +void LinearExpression::accept(Visitor *v) { + v->visit(this); +} void StoichExpression::accept(Visitor *v) { v->visit(this); } diff --git a/modcc/expression.hpp b/modcc/expression.hpp index 0613ddcfecffd649bffcb30dba2dac8041f37c5d..1860c7eea77dc57bf1cbfd3eaa63b0799263e55c 100644 --- a/modcc/expression.hpp +++ b/modcc/expression.hpp @@ -33,6 +33,7 @@ class BinaryExpression; class UnaryExpression; class AssignmentExpression; class ConserveExpression; +class LinearExpression; class ReactionExpression; class StoichExpression; class StoichTermExpression; @@ -79,7 +80,8 @@ enum class procedureKind { net_receive, ///< NET_RECEIVE breakpoint, ///< BREAKPOINT kinetic, ///< KINETIC - derivative ///< DERIVATIVE + derivative, ///< DERIVATIVE + linear, ///< LINEAR }; std::string to_string(procedureKind k); @@ -168,6 +170,7 @@ public: virtual UnaryExpression* is_unary() {return nullptr;} virtual AssignmentExpression* is_assignment() {return nullptr;} virtual ConserveExpression* is_conserve() {return nullptr;} + virtual LinearExpression* is_linear() {return nullptr;} virtual ReactionExpression* is_reaction() {return nullptr;} virtual StoichExpression* is_stoich() {return nullptr;} virtual StoichTermExpression* is_stoich_term() {return nullptr;} @@ -1277,6 +1280,20 @@ public: void accept(Visitor *v) override; }; +class LinearExpression : public BinaryExpression { +public: + LinearExpression(Location loc, expression_ptr&& lhs, expression_ptr&& rhs) + : BinaryExpression(loc, tok::eq, std::move(lhs), std::move(rhs)) + {} + + LinearExpression* is_linear() override {return this;} + expression_ptr clone() const override; + + void semantic(scope_ptr scp) override; + + void accept(Visitor *v) override; +}; + class AddBinaryExpression : public BinaryExpression { public: AddBinaryExpression(Location loc, expression_ptr&& lhs, expression_ptr&& rhs) diff --git a/modcc/lexer.cpp b/modcc/lexer.cpp index a2bb90af8d1133956b5917e75592a787b54a0249..1eb684ff9f03f4aeb5d2191cd2ce1551eda1b2a6 100644 --- a/modcc/lexer.cpp +++ b/modcc/lexer.cpp @@ -234,6 +234,30 @@ Token Lexer::peek() { return t; } +bool Lexer::search_to_eol(tok const& t) { + // save the current position + const char *oldpos = current_; + const char *oldlin = line_; + Location oldloc = location_; + + Token p = token_; + bool ret = false; + while (line_ == oldlin && p.type != tok::eof) { + if (p.type == t) { + ret = true; + break; + } + p = parse(); + } + + // reset position + current_ = oldpos; + location_ = oldloc; + line_ = oldlin; + + return ret; +} + // scan floating point number from stream Token Lexer::number() { std::string str; diff --git a/modcc/lexer.hpp b/modcc/lexer.hpp index 142a5d3b1f7483bd548eb0a377e70ced5e9f2a19..f897db5856e0df368e1a840aadb3367f35544d39 100644 --- a/modcc/lexer.hpp +++ b/modcc/lexer.hpp @@ -72,6 +72,9 @@ public: // return the next token in the stream without advancing the current position Token peek(); + // Look for `t` until new line or eof without advancing the current position, return true if found + bool search_to_eol(tok const& t); + // scan a number from the stream Token number(); diff --git a/modcc/linearrewriter.cpp b/modcc/linearrewriter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3705a2ed74ee569ca8bc33a949749ad97d1d54d3 --- /dev/null +++ b/modcc/linearrewriter.cpp @@ -0,0 +1,88 @@ +#include <iostream> +#include <map> +#include <string> +#include <list> + +#include "astmanip.hpp" +#include "symdiff.hpp" +#include "visitor.hpp" + +class LinearRewriter : public BlockRewriterBase { +public: + using BlockRewriterBase::visit; + + LinearRewriter(std::vector<std::string> st_vars): state_vars(st_vars) {} + LinearRewriter(scope_ptr enclosing_scope): BlockRewriterBase(enclosing_scope) {} + + virtual void visit(LinearExpression *e) override; + +protected: + virtual void reset() override { + BlockRewriterBase::reset(); + } + +private: + std::vector<std::string> state_vars; +}; + +expression_ptr linear_rewrite(BlockExpression* block, std::vector<std::string> state_vars) { + LinearRewriter visitor(state_vars); + block->accept(&visitor); + return visitor.as_block(false); +} + +// LinearRewriter implementation follows. + +// Factorize the linear expression in terms of the state variables and place +// the resulting sum of products on the lhs. Place everything else on the rhs +void LinearRewriter::visit(LinearExpression* e) { + Location loc = e->location(); + scope_ptr scope = e->scope(); + + expression_ptr lhs; + for (const auto& state : state_vars) { + // To factorize w.r.t state, differentiate the lhs and rhs + auto ident = make_expression<IdentifierExpression>(loc, state); + auto coeff = constant_simplify(make_expression<SubBinaryExpression>(loc, + symbolic_pdiff(e->lhs(), state), + symbolic_pdiff(e->rhs(), state))); + + if (expr_value(coeff) != 0) { + auto local_coeff = make_unique_local_assign(scope, coeff, "l_"); + statements_.push_back(std::move(local_coeff.local_decl)); + statements_.push_back(std::move(local_coeff.assignment)); + + auto pair = make_expression<MulBinaryExpression>(loc, std::move(local_coeff.id), std::move(ident)); + + // Construct the lhs of the new linear expression + if (!lhs) { + lhs = std::move(pair); + } else { + lhs = make_expression<AddBinaryExpression>(loc, std::move(lhs), std::move(pair)); + } + } + } + + // To find the rhs of the new linear expression, simplify the old + // linear expression with state variables set to zero + auto rhs_0 = e->lhs()->clone(); + auto rhs_1 = e->rhs()->clone(); + + for (auto state: state_vars) { + auto zero_expr = make_expression<NumberExpression>(loc, 0.0); + rhs_0 = substitute(rhs_0, state, zero_expr); + rhs_1 = substitute(rhs_1, state, zero_expr); + } + rhs_0 = constant_simplify(rhs_0); + rhs_1 = constant_simplify(rhs_1); + + auto rhs = constant_simplify(make_expression<SubBinaryExpression>(loc, std::move(rhs_1), std::move(rhs_0))); + + auto local_rhs = make_unique_local_assign(scope, rhs, "l_"); + statements_.push_back(std::move(local_rhs.local_decl)); + statements_.push_back(std::move(local_rhs.assignment)); + + rhs = std::move(local_rhs.id); + + statements_.push_back(make_expression<LinearExpression>(loc, std::move(lhs), std::move(rhs))); +} diff --git a/modcc/linearrewriter.hpp b/modcc/linearrewriter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..32b3386334b87c4799dd86034bb0f747772cf892 --- /dev/null +++ b/modcc/linearrewriter.hpp @@ -0,0 +1,6 @@ +#pragma once + +#include "expression.hpp" + +// Translate a supplied LINEAR block. +expression_ptr linear_rewrite(BlockExpression*, std::vector<std::string>); diff --git a/modcc/module.cpp b/modcc/module.cpp index 86a3ebec8777c7e8e706f86e046bb23563269143..2df8ec16b7840c2256a99b4a105b1ecc2a5728db 100644 --- a/modcc/module.cpp +++ b/modcc/module.cpp @@ -11,6 +11,7 @@ #include "functionexpander.hpp" #include "functioninliner.hpp" #include "kineticrewriter.hpp" +#include "linearrewriter.hpp" #include "module.hpp" #include "parser.hpp" #include "solvers.hpp" @@ -279,7 +280,45 @@ bool Module::semantic() { auto& init_body = api_init->body()->statements(); for(auto& e : *proc_init->body()) { - init_body.emplace_back(e->clone()); + auto solve_expression = e->is_solve_statement(); + if (solve_expression) { + // Grab SOLVE statements, put them in `body` after translation. + std::set<std::string> solved_ids; + std::unique_ptr<SolverVisitorBase> solver = std::make_unique<SparseSolverVisitor>(); + + // The solve expression inside an initial block can only refer to a linear block + auto solve_proc = solve_expression->procedure(); + + if (solve_proc->kind() == procedureKind::linear) { + solver = std::make_unique<LinearSolverVisitor>(state_vars); + linear_rewrite(solve_proc->body(), state_vars)->accept(solver.get()); + } else { + error("A SOLVE expression in an INITIAL block can only be used to solve a LINEAR block, which" + + solve_expression->name() + "is not.", solve_expression->location()); + return false; + } + + if (auto solve_block = solver->as_block(false)) { + // Check that we didn't solve an already solved variable. + for (const auto &id: solver->solved_identifiers()) { + if (solved_ids.count(id) > 0) { + error("Variable " + id + " solved twice!", solve_expression->location()); + return false; + } + solved_ids.insert(id); + } + // Copy body into nrn_init. + for (auto &stmt: solve_block->is_block()->statements()) { + init_body.emplace_back(stmt->clone()); + } + } else { + // Something went wrong: copy errors across. + append_errors(solver->errors()); + return false; + } + } else { + init_body.emplace_back(e->clone()); + } } api_init->semantic(symbols_); @@ -337,6 +376,10 @@ bool Module::semantic() { if (deriv->kind()==procedureKind::kinetic) { kinetic_rewrite(deriv->body())->accept(solver.get()); } + else if (deriv->kind()==procedureKind::linear) { + solver = std::make_unique<LinearSolverVisitor>(state_vars); + linear_rewrite(deriv->body(), state_vars)->accept(solver.get()); + } else { deriv->body()->accept(solver.get()); for (auto& s: deriv->body()->statements()) { diff --git a/modcc/msparse.hpp b/modcc/msparse.hpp index b6e2292cf9b87bf7fa7425fad110b9ea093dd9ac..34e6866826411410dc907064bffa596573dd4ab9 100644 --- a/modcc/msparse.hpp +++ b/modcc/msparse.hpp @@ -185,6 +185,12 @@ public: bool empty() const { return size()==0; } bool augmented() const { return aug!=npos; } + void clear() { + rows.clear(); + cols = 0; + aug = row_npos; + } + // Add a column on the right as part of the augmented submatrix. // The new entries are provided by a (full, dense representation) // sequence of values. diff --git a/modcc/parser.cpp b/modcc/parser.cpp index 0e9e064bd512d41c1f3fd1d8d20337437bd1d42a..0202d0b1bb54dbe45aab5db2d27f7d20a81ea42b 100644 --- a/modcc/parser.cpp +++ b/modcc/parser.cpp @@ -122,6 +122,7 @@ bool Parser::parse() { case tok::breakpoint : case tok::initial : case tok::kinetic : + case tok::linear : case tok::derivative : case tok::procedure : { @@ -905,6 +906,12 @@ symbol_ptr Parser::parse_procedure() { if( !expect( tok::identifier ) ) return nullptr; p = parse_prototype(); break; + case tok::linear: + kind = procedureKind::linear; + get_token(); // consume keyword token + if( !expect( tok::identifier ) ) return nullptr; + p = parse_prototype(); + break; case tok::procedure: kind = procedureKind::normal; get_token(); // consume keyword token @@ -994,7 +1001,7 @@ expression_ptr Parser::parse_statement() { case tok::conserve : return parse_conserve_expression(); case tok::tilde : - return parse_reaction_expression(); + return parse_tilde_expression(); case tok::initial : // only used for INITIAL block in NET_RECEIVE return parse_initial(); @@ -1183,76 +1190,94 @@ expression_ptr Parser::parse_stoich_expression() { return make_expression<StoichExpression>(here, std::move(terms)); } -expression_ptr Parser::parse_reaction_expression() { +expression_ptr Parser::parse_tilde_expression() { auto here = location_; if(token_.type!=tok::tilde) { error(pprintf("expected '%', found '%'", yellow("~"), yellow(token_.spelling))); return nullptr; } - get_token(); // consume tilde - expression_ptr lhs = parse_stoich_expression(); - if (!lhs) return nullptr; - // reaction halves must comprise non-negative terms - for (const auto& term: lhs->is_stoich()->terms()) { - // should always be true - if (auto sterm = term->is_stoich_term()) { - if (sterm->negative()) { - error(pprintf("expected only non-negative terms in reaction lhs, found '%'", - yellow(term->to_string()))); - return nullptr; + if (search_to_eol(tok::arrow)) { + expression_ptr lhs = parse_stoich_expression(); + if (!lhs) return nullptr; + + // reaction halves must comprise non-negative terms + for (const auto& term: lhs->is_stoich()->terms()) { + // should always be true + if (auto sterm = term->is_stoich_term()) { + if (sterm->negative()) { + error(pprintf("expected only non-negative terms in reaction lhs, found '%'", + yellow(term->to_string()))); + return nullptr; + } } } - } - if(token_.type != tok::arrow) { - error(pprintf("expected '%', found '%'", yellow("<->"), yellow(token_.spelling))); - return nullptr; - } + if(token_.type != tok::arrow) { + error(pprintf("expected '%', found '%'", yellow("<->"), yellow(token_.spelling))); + return nullptr; + } - get_token(); // consume arrow - expression_ptr rhs = parse_stoich_expression(); - if (!rhs) return nullptr; + get_token(); // consume arrow + expression_ptr rhs = parse_stoich_expression(); + if (!rhs) return nullptr; - for (const auto& term: rhs->is_stoich()->terms()) { - // should always be true - if (auto sterm = term->is_stoich_term()) { - if (sterm->negative()) { - error(pprintf("expected only non-negative terms in reaction rhs, found '%'", - yellow(term->to_string()))); - return nullptr; + for (const auto& term: rhs->is_stoich()->terms()) { + // should always be true + if (auto sterm = term->is_stoich_term()) { + if (sterm->negative()) { + error(pprintf("expected only non-negative terms in reaction rhs, found '%'", + yellow(term->to_string()))); + return nullptr; + } } } - } - if(token_.type != tok::lparen) { - error(pprintf("expected '%', found '%'", yellow("("), yellow(token_.spelling))); - return nullptr; - } + if (token_.type != tok::lparen) { + error(pprintf("expected '%', found '%'", yellow("("), yellow(token_.spelling))); + return nullptr; + } - get_token(); // consume lparen - expression_ptr fwd = parse_expression(); - if (!fwd) return nullptr; + get_token(); // consume lparen + expression_ptr fwd = parse_expression(); + if (!fwd) return nullptr; - if(token_.type != tok::comma) { - error(pprintf("expected '%', found '%'", yellow(","), yellow(token_.spelling))); - return nullptr; + if (token_.type != tok::comma) { + error(pprintf("expected '%', found '%'", yellow(","), yellow(token_.spelling))); + return nullptr; + } + + get_token(); // consume comma + expression_ptr rev = parse_expression(); + if (!rev) return nullptr; + + if (token_.type != tok::rparen) { + error(pprintf("expected '%', found '%'", yellow(")"), yellow(token_.spelling))); + return nullptr; + } + + get_token(); // consume rparen + return make_expression<ReactionExpression>(here, std::move(lhs), std::move(rhs), + std::move(fwd), std::move(rev)); } + else if (search_to_eol(tok::eq)) { + auto lhs_bin = parse_expression(); - get_token(); // consume comma - expression_ptr rev = parse_expression(); - if (!rev) return nullptr; + if(token_.type!=tok::eq) { + error(pprintf("expected '%', found '%'", yellow("="), yellow(token_.spelling))); + return nullptr; + } - if(token_.type != tok::rparen) { - error(pprintf("expected '%', found '%'", yellow(")"), yellow(token_.spelling))); + get_token(); // consume = + auto rhs = parse_expression(); + return make_expression<LinearExpression>(here, std::move(lhs_bin), std::move(rhs)); + } + else { + error(pprintf("expected stoichiometric or linear expression, found neither")); return nullptr; } - - get_token(); // consume rparen - return make_expression<ReactionExpression>(here, std::move(lhs), std::move(rhs), - std::move(fwd), std::move(rev)); } expression_ptr Parser::parse_conserve_expression() { @@ -1286,8 +1311,7 @@ expression_ptr Parser::parse_expression(int prec) { // Combine all sub-expressions with precedence greater than prec. for (;;) { if(token_.type==tok::eq) { - error("assignment '"+yellow("=")+"' not allowed in sub-expression"); - return nullptr; + return lhs; } auto op = token_; diff --git a/modcc/parser.hpp b/modcc/parser.hpp index 3c2a297a2cea5e814323104845687c9bc88d62e4..ad841b57f99ec03fab04433688782d6df2cc5f6c 100644 --- a/modcc/parser.hpp +++ b/modcc/parser.hpp @@ -28,7 +28,7 @@ public: expression_ptr parse_line_expression(); expression_ptr parse_stoich_expression(); expression_ptr parse_stoich_term(); - expression_ptr parse_reaction_expression(); + expression_ptr parse_tilde_expression(); expression_ptr parse_conserve_expression(); expression_ptr parse_binop(expression_ptr&&, Token); expression_ptr parse_unaryop(); diff --git a/modcc/solvers.cpp b/modcc/solvers.cpp index 96872a51e4def816d828094a6835d0618c14d93f..6f77091263e5bc873bfe2b703f726c6aa018817e 100644 --- a/modcc/solvers.cpp +++ b/modcc/solvers.cpp @@ -368,14 +368,108 @@ void SparseSolverVisitor::finalize() { // State variable updates given by rhs/diagonal for reduced matrix. Location loc; for (unsigned i = 0; i<A_.nrow(); ++i) { - unsigned rhs = A_.augcol(); + const symge::sym_row& row = A_[i]; + unsigned rhs_col = A_.augcol(); + unsigned lhs_col; + for (unsigned r = 0; r < A_.nrow(); r++) { + if (row[r]) { + lhs_col = r; + break; + } + } auto expr = make_expression<AssignmentExpression>(loc, - make_expression<IdentifierExpression>(loc, dvars_[i]), + make_expression<IdentifierExpression>(loc, dvars_[lhs_col]), make_expression<DivBinaryExpression>(loc, - make_expression<IdentifierExpression>(loc, symge::name(A_[i][rhs])), - make_expression<IdentifierExpression>(loc, symge::name(A_[i][i])))); + make_expression<IdentifierExpression>(loc, symge::name(A_[i][rhs_col])), + make_expression<IdentifierExpression>(loc, symge::name(A_[i][lhs_col])))); + + statements_.push_back(std::move(expr)); + } + + BlockRewriterBase::finalize(); +} + +void LinearSolverVisitor::visit(BlockExpression* e) { + BlockRewriterBase::visit(e); +} + +void LinearSolverVisitor::visit(AssignmentExpression *e) { + statements_.push_back(e->clone()); + return; +} + +void LinearSolverVisitor::visit(LinearExpression *e) { + auto loc = e->location(); + scope_ptr scope = e->scope(); + + if (A_.empty()) { + unsigned n = dvars_.size(); + A_ = symge::sym_matrix(n, n); + } + + linear_test_result r = linear_test(e->lhs(), dvars_); + if (!r.is_homogeneous) { + error({"System not homogeneous linear for sparse", loc}); + return; + } + + for (unsigned j = 0; j<dvars_.size(); ++j) { + expression_ptr expr; + + if (r.coef.count(dvars_[j])) { + expr = r.coef[dvars_[j]]->clone(); + } + + if (!expr) continue; + + auto a_ = expr->is_identifier()->spelling(); + + A_[deq_index_].push_back({j, symtbl_.define(a_)}); + } + rhs_.push_back(symtbl_.define(e->rhs()->is_identifier()->spelling())); + ++deq_index_; +} +void LinearSolverVisitor::finalize() { + A_.augment(rhs_); + + symge::gj_reduce(A_, symtbl_); + + // Create and assign intermediate variables. + for (unsigned i = 0; i<symtbl_.size(); ++i) { + symge::symbol s = symtbl_[i]; + + if (primitive(s)) continue; + + auto expr = as_expression(definition(s)); + auto local_t_term = make_unique_local_assign(block_scope_, expr.get(), "t_"); + auto t_ = local_t_term.id->is_identifier()->spelling(); + symtbl_.name(s, t_); + + statements_.push_back(std::move(local_t_term.local_decl)); + statements_.push_back(std::move(local_t_term.assignment)); + } + + // State variable updates given by rhs/diagonal for reduced matrix. + Location loc; + for (unsigned i = 0; i < A_.nrow(); ++i) { + const symge::sym_row& row = A_[i]; + unsigned rhs = A_.augcol(); + unsigned lhs; + for (unsigned r = 0; r < A_.nrow(); r++) { + if (row[r]) { + lhs = r; + break; + } + } + + auto expr = + make_expression<AssignmentExpression>(loc, + make_expression<IdentifierExpression>(loc, dvars_[lhs]), + make_expression<DivBinaryExpression>(loc, + make_expression<IdentifierExpression>(loc, symge::name(A_[i][rhs])), + make_expression<IdentifierExpression>(loc, symge::name(A_[i][lhs])))); statements_.push_back(std::move(expr)); } diff --git a/modcc/solvers.hpp b/modcc/solvers.hpp index e31084dfc0a87137ca0d46c276543d076aa08087..1198d1b10461b6e7b0db6f5f03526b472044aed2 100644 --- a/modcc/solvers.hpp +++ b/modcc/solvers.hpp @@ -101,10 +101,52 @@ public: virtual void reset() override { deq_index_ = 0; local_expr_.clear(); + A_.clear(); symtbl_.clear(); + conserve_ = false; conserve_rhs_.clear(); conserve_idx_.clear(); - conserve_ = false; + SolverVisitorBase::reset(); + } +}; + +class LinearSolverVisitor : public SolverVisitorBase { +protected: + // 'Current' differential equation is for variable with this + // index in `dvars`. + unsigned deq_index_ = 0; + + // Expanded local assignments that need to be substituted in for derivative + // calculations. + substitute_map local_expr_; + + // Symbolic matrix for backwards Euler step. + symge::sym_matrix A_; + + // RHS + std::vector<symge::symbol> rhs_; + + // 'Symbol table' for symbolic manipulation. + symge::symbol_table symtbl_; + +public: + using SolverVisitorBase::visit; + + LinearSolverVisitor(std::vector<std::string> vars) { + dvars_ = vars; + } + LinearSolverVisitor(scope_ptr enclosing): SolverVisitorBase(enclosing) {} + + virtual void visit(BlockExpression* e) override; + virtual void visit(LinearExpression *e) override; + virtual void visit(AssignmentExpression *e) override; + virtual void finalize() override; + virtual void reset() override { + deq_index_ = 0; + local_expr_.clear(); + A_.clear(); + rhs_.clear(); + symtbl_.clear(); SolverVisitorBase::reset(); } }; diff --git a/modcc/symge.cpp b/modcc/symge.cpp index 4f09a0226ad0975b36aaaa2887cd331136f8ebf2..51a9008cd9b774eed6eebe1baceb2980f66b5aaf 100644 --- a/modcc/symge.cpp +++ b/modcc/symge.cpp @@ -1,11 +1,18 @@ #include <algorithm> #include <stdexcept> #include <vector> +#include <numeric> #include "symge.hpp" namespace symge { +struct pivot { + unsigned row; + unsigned col; +}; + + // Returns q[c]*p - p[c]*q; new symbols required due to fill-in are provided by the // `define_sym` functor, which takes a `symbol_term_diff` and returns a `symbol`. @@ -45,7 +52,7 @@ sym_row row_reduce(unsigned c, const sym_row& p, const sym_row& q, DefineSym def // Estimate cost of a choice of pivot for G–J reduction below. Uses a simple greedy // estimate based on immediate fill cost. -double estimate_cost(const sym_matrix& A, unsigned p) { +double estimate_cost(const sym_matrix& A, pivot p) { unsigned nfill = 0; auto count_fill = [&nfill](symbol_term_diff t) { @@ -56,8 +63,8 @@ double estimate_cost(const sym_matrix& A, unsigned p) { }; for (unsigned i = 0; i<A.nrow(); ++i) { - if (i==p || A[i].index(p)==msparse::row_npos) continue; - row_reduce(p, A[i], A[p], count_fill); + if (i==p.row || A[i].index(p.col)==msparse::row_npos) continue; + row_reduce(p.col, A[i], A[p.row], count_fill); } return nfill; @@ -78,30 +85,48 @@ void gj_reduce(sym_matrix& A, symbol_table& table) { auto define_sym = [&table](symbol_term_diff t) { return table.define(t); }; - std::vector<unsigned> pivots; - for (unsigned r = 0; r<A.nrow(); ++r) { - pivots.push_back(r); - } + auto get_pivots = [&A](const std::vector<unsigned>& remaining_rows) { + std::vector<pivot> pivots; + for (auto r: remaining_rows) { + pivot p; + p.row = r; + const sym_row &row = A[r]; + for (unsigned c = 0; c < A.nrow(); ++c) { + if (row[c]) { + p.col = c; + break; + } + } + pivots.push_back(std::move(p)); + } + return pivots; + }; - std::vector<double> cost(pivots.size()); + std::vector<unsigned> remaining_rows(A.nrow()); + std::iota(remaining_rows.begin(), remaining_rows.end(), 0); + + std::vector<double> cost(A.nrow()); + + while (true) { + auto pivots = get_pivots(remaining_rows); - while (!pivots.empty()) { for (unsigned i = 0; i<pivots.size(); ++i) { - cost[pivots[i]] = estimate_cost(A, pivots[i]); + cost[pivots[i].row] = estimate_cost(A, pivots[i]); } std::sort(pivots.begin(), pivots.end(), - [&](unsigned r1, unsigned r2) { return cost[r1]>cost[r2]; }); + [&](pivot r1, pivot r2) { return cost[r1.row]>cost[r2.row]; }); - unsigned pivrow = pivots.back(); - pivots.erase(std::prev(pivots.end())); - - unsigned pivcol = pivrow; + pivot p = pivots.back(); + remaining_rows.erase(std::lower_bound(remaining_rows.begin(), remaining_rows.end(), p.row)); for (unsigned i = 0; i<A.nrow(); ++i) { - if (i==pivrow || A[i].index(pivcol)==msparse::row_npos) continue; + if (i==p.row || A[i].index(p.col)==msparse::row_npos) continue; + A[i] = row_reduce(p.col, A[i], A[p.row], define_sym); + } - A[i] = row_reduce(pivcol, A[i], A[pivrow], define_sym); + if (remaining_rows.empty()) { + break; } } } diff --git a/modcc/token.cpp b/modcc/token.cpp index 169bbc512b20c4bb692b172564b76495ae38fb41..58915c160f3d3899352286c615a459fec749ad4a 100644 --- a/modcc/token.cpp +++ b/modcc/token.cpp @@ -31,6 +31,7 @@ static Keyword keywords[] = { {"BREAKPOINT", tok::breakpoint}, {"DERIVATIVE", tok::derivative}, {"KINETIC", tok::kinetic}, + {"LINEAR", tok::linear}, {"PROCEDURE", tok::procedure}, {"FUNCTION", tok::function}, {"INITIAL", tok::initial}, @@ -102,6 +103,7 @@ static TokenString token_strings[] = { {"BREAKPOINT", tok::breakpoint}, {"DERIVATIVE", tok::derivative}, {"KINETIC", tok::kinetic}, + {"LINEAR", tok::linear}, {"PROCEDURE", tok::procedure}, {"FUNCTION", tok::function}, {"INITIAL", tok::initial}, diff --git a/modcc/token.hpp b/modcc/token.hpp index ae03b3d57324c671c478ad028237c90f0758a3a7..06b2ad6bb76e0db908c7bafeb773df406f8eeb46 100644 --- a/modcc/token.hpp +++ b/modcc/token.hpp @@ -53,7 +53,7 @@ enum class tok { title, neuron, units, parameter, constant, assigned, state, breakpoint, - derivative, kinetic, procedure, initial, function, + derivative, kinetic, procedure, initial, function, linear, net_receive, // keywoards inside blocks diff --git a/modcc/visitor.hpp b/modcc/visitor.hpp index b6ae70cab102d7495c415089f0fa050fc0fe0886..8f7dd6e47e66b56f821e426a5a4b92f60320d54c 100644 --- a/modcc/visitor.hpp +++ b/modcc/visitor.hpp @@ -52,6 +52,7 @@ public: virtual void visit(ConditionalExpression *e) {visit((BinaryExpression*) e); } virtual void visit(AssignmentExpression *e) { visit((BinaryExpression*) e); } virtual void visit(ConserveExpression *e) { visit((BinaryExpression*) e); } + virtual void visit(LinearExpression *e) { visit((BinaryExpression*) e); } virtual void visit(AddBinaryExpression *e) { visit((BinaryExpression*) e); } virtual void visit(SubBinaryExpression *e) { visit((BinaryExpression*) e); } virtual void visit(MulBinaryExpression *e) { visit((BinaryExpression*) e); } diff --git a/test/unit-modcc/test_msparse.cpp b/test/unit-modcc/test_msparse.cpp index 9b97187f5143e603553a605ce26e75ff9f5515bd..be3dc6e04285e9d4157f1428ec15077b3ccf4d44 100644 --- a/test/unit-modcc/test_msparse.cpp +++ b/test/unit-modcc/test_msparse.cpp @@ -129,6 +129,12 @@ TEST(msparse, matrix_ctor) { dmatrix M3(M2); EXPECT_EQ(5u, M3.nrow()); EXPECT_EQ(3u, M3.ncol()); + + M2.clear(); + EXPECT_EQ(0u, M2.size()); + EXPECT_EQ(0u, M2.nrow()); + EXPECT_EQ(0u, M2.ncol()); + EXPECT_TRUE(M2.empty()); } TEST(msparse, matrix_index) { diff --git a/test/unit-modcc/test_parser.cpp b/test/unit-modcc/test_parser.cpp index f251d202a408f2491b5f6f6e6c79d1409a0d1cec..ad76e7d941f822c807be43ff04f72ce86f8d248d 100644 --- a/test/unit-modcc/test_parser.cpp +++ b/test/unit-modcc/test_parser.cpp @@ -466,7 +466,7 @@ TEST(Parser, parse_reaction_expression) { for (auto& text: good_expr) { std::unique_ptr<ReactionExpression> s; - EXPECT_TRUE(check_parse(s, &Parser::parse_reaction_expression, text)); + EXPECT_TRUE(check_parse(s, &Parser::parse_tilde_expression, text)); } const char* bad_expr[] = { @@ -483,7 +483,7 @@ TEST(Parser, parse_reaction_expression) { }; for (auto& text: bad_expr) { - EXPECT_TRUE(check_parse_fail(&Parser::parse_reaction_expression, text)); + EXPECT_TRUE(check_parse_fail(&Parser::parse_tilde_expression, text)); } } diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 1e042fbb902b89aa858127b349cbbd75a4208fa0..751647b42a6046c3450477c3182ac65d3d3d56d4 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -2,6 +2,9 @@ set(test_mechanisms celsius_test + test_linear_state + test_linear_init + test_linear_init_shuffle test0_kin_diff test0_kin_conserve test1_kin_diff @@ -83,7 +86,7 @@ set(unit_sources test_fvm_lowered.cpp test_glob_basic.cpp test_mc_cell_group.cpp - test_kinetic.cpp + test_kinetic_linear.cpp test_lexcmp.cpp test_lif_cell_group.cpp test_maputil.cpp diff --git a/test/unit/mod/test_linear_init.mod b/test/unit/mod/test_linear_init.mod new file mode 100644 index 0000000000000000000000000000000000000000..089f0453d0ccec33bde9e104cf8469dee6082a8e --- /dev/null +++ b/test/unit/mod/test_linear_init.mod @@ -0,0 +1,33 @@ +NEURON { + SUFFIX test_linear_init + RANGE a0, a1, a3, a3 +} + +STATE { + s d h +} + +PARAMETER { + a4 = 7.3 +} + +ASSIGNED { + a0 : = 2.5 + a1 : = 0.5 + a2 : = 3 + a3 : = 2.3 +} + +BREAKPOINT { + s = a1 +} + +INITIAL { + SOLVE sinit +} + +LINEAR sinit { + ~ (a4 - a3)*d - a2*h = 0 + ~ (a0 + a1)*s - (-a1 + a0)*d = 0 + ~ s + d + h = 1 +} diff --git a/test/unit/mod/test_linear_init_shuffle.mod b/test/unit/mod/test_linear_init_shuffle.mod new file mode 100644 index 0000000000000000000000000000000000000000..07d4ce954b033109359e1d18917ac72b9721b5f8 --- /dev/null +++ b/test/unit/mod/test_linear_init_shuffle.mod @@ -0,0 +1,33 @@ +NEURON { + SUFFIX test_linear_init_shuffle + RANGE a0, a1, a3, a3 +} + +STATE { + s d h +} + +PARAMETER { + a4 = 7.3 +} + +ASSIGNED { + a0 : = 2.5 + a1 : = 0.5 + a2 : = 3 + a3 : = 2.3 +} + +BREAKPOINT { + s = a1 +} + +INITIAL { + SOLVE sinit +} + +LINEAR sinit { + ~ a4*d - a3*d - a2*h = 0 + ~ a0*s - a0*d = - a1*s - a1*d + ~ s + d + h = 1 +} diff --git a/test/unit/mod/test_linear_state.mod b/test/unit/mod/test_linear_state.mod new file mode 100644 index 0000000000000000000000000000000000000000..1b0f8235091deb3cc4916624f61f4c2221d959f2 --- /dev/null +++ b/test/unit/mod/test_linear_state.mod @@ -0,0 +1,29 @@ +NEURON { + SUFFIX test_linear_state + RANGE a0, a1, a3, a3 +} + +STATE { + s d h +} + +PARAMETER { + a4 = 7.3 +} + +ASSIGNED { + a0 : = 2.5 + a1 : = 0.5 + a2 : = 3 + a3 : = 2.3 +} + +BREAKPOINT { + SOLVE sinit +} + +LINEAR sinit { + ~ (a4 - a3)*d - a2*h = 0 + ~ (a0 + a1)*s - (-a1 + a0)*d = 0 + ~ s + d + h = 1 +} diff --git a/test/unit/test_kinetic.cpp b/test/unit/test_kinetic.cpp deleted file mode 100644 index 71f1a636f03cea4998d895aa4d460e58e2f99479..0000000000000000000000000000000000000000 --- a/test/unit/test_kinetic.cpp +++ /dev/null @@ -1,117 +0,0 @@ -#include <vector> - -#include <arbor/mechanism.hpp> -#include <arbor/version.hpp> - -#include "backends/multicore/fvm.hpp" - -#ifdef ARB_GPU_ENABLED -#include "backends/gpu/fvm.hpp" -#endif - -#include "common.hpp" -#include "mech_private_field_access.hpp" -#include "fvm_lowered_cell.hpp" -#include "fvm_lowered_cell_impl.hpp" -#include "sampler_map.hpp" -#include "simple_recipes.hpp" -#include "unit_test_catalogue.hpp" - -using namespace arb; - -using backend = arb::multicore::backend; -using fvm_cell = arb::fvm_lowered_cell_impl<backend>; - -using shared_state = backend::shared_state; -ACCESS_BIND(std::unique_ptr<shared_state> fvm_cell::*, private_state_ptr, &fvm_cell::state_) - -template <typename backend> -void run_kinetic_test(std::string mech_name, - std::vector<std::string> variables, - std::vector<fvm_value_type> t0_values, - std::vector<fvm_value_type> t1_values) { - - auto cat = make_unit_test_catalogue(); - - fvm_size_type ncell = 1; - fvm_size_type ncv = 1; - std::vector<fvm_index_type> cv_to_intdom(ncv, 0); - - std::vector<fvm_gap_junction> gj = {}; - auto instance = cat.instance<backend>(mech_name); - auto& kinetic_test = instance.mech; - - std::vector<fvm_value_type> temp(ncv, 300.); - std::vector<fvm_value_type> vinit(ncv, -65); - - auto shared_state = std::make_unique<typename backend::shared_state>( - ncell, cv_to_intdom, gj, vinit, temp, kinetic_test->data_alignment()); - - mechanism_layout layout; - mechanism_overrides overrides; - - layout.weight.assign(ncv, 1.); - for (fvm_size_type i = 0; i<ncv; ++i) { - layout.cv.push_back(i); - } - - kinetic_test->instantiate(0, *shared_state, overrides, layout); - shared_state->reset(); - - kinetic_test->initialize(); - - for (unsigned i = 0; i < variables.size(); i++) { - for (unsigned j = 0; j < ncv; j++) { - EXPECT_NEAR(t0_values[i], mechanism_field(kinetic_test.get(), variables[i]).at(j), 1e-6); - } - } - - shared_state->update_time_to(0.5, 0.5); - shared_state->set_dt(); - - kinetic_test->nrn_state(); - - for (unsigned i = 0; i < variables.size(); i++) { - for (unsigned j = 0; j < ncv; j++) { - EXPECT_NEAR(t1_values[i], mechanism_field(kinetic_test.get(), variables[i]).at(j), 1e-6); - } - } -} - -TEST(mech_kinetic, kinetic_1_conserve) { - std::vector<std::string> variables = {"s", "h", "d"}; - std::vector<fvm_value_type> t0_values = {0.5, 0.2, 0.3}; - std::vector<fvm_value_type> t1_values = {0.380338, 0.446414, 0.173247}; - - run_kinetic_test<multicore::backend>("test0_kin_diff", variables, t0_values, t1_values); - run_kinetic_test<multicore::backend>("test0_kin_conserve", variables, t0_values, t1_values); -} - -TEST(mech_kinetic, kinetic_2_conserve) { - std::vector<std::string> variables = {"a", "b", "x", "y"}; - std::vector<fvm_value_type> t0_values = {0.2, 0.8, 0.6, 0.4}; - std::vector<fvm_value_type> t1_values = {0.217391304, 0.782608696, 0.33333333, 0.66666666}; - - run_kinetic_test<multicore::backend>("test1_kin_diff", variables, t0_values, t1_values); - run_kinetic_test<multicore::backend>("test1_kin_conserve", variables, t0_values, t1_values); -} - -#ifdef ARB_GPU_ENABLED -TEST(mech_kinetic_gpu, kinetic_1_conserve) { - std::vector<std::string> variables = {"s", "h", "d"}; - std::vector<fvm_value_type> t0_values = {0.5, 0.2, 0.3}; - std::vector<fvm_value_type> t1_values = {0.380338, 0.446414, 0.173247}; - - run_kinetic_test<gpu::backend>("test0_kin_diff", variables, t0_values, t1_values); - run_kinetic_test<gpu::backend>("test0_kin_conserve", variables, t0_values, t1_values); -} - -TEST(mech_kinetic_gpu, kinetic_2_conserve) { - std::vector<std::string> variables = {"a", "b", "x", "y"}; - std::vector<fvm_value_type> t0_values = {0.2, 0.8, 0.6, 0.4}; - std::vector<fvm_value_type> t1_values = {0.217391304, 0.782608696, 0.33333333, 0.66666666}; - - run_kinetic_test<gpu::backend>("test1_kin_diff", variables, t0_values, t1_values); - run_kinetic_test<gpu::backend>("test1_kin_conserve", variables, t0_values, t1_values); -} -#endif diff --git a/test/unit/test_kinetic_linear.cpp b/test/unit/test_kinetic_linear.cpp new file mode 100644 index 0000000000000000000000000000000000000000..16dddc4eda39e8a51bc021473a4f1ed39df17fbf --- /dev/null +++ b/test/unit/test_kinetic_linear.cpp @@ -0,0 +1,148 @@ +#include <vector> + +#include <arbor/mechanism.hpp> +#include <arbor/version.hpp> + +#include "backends/multicore/fvm.hpp" + +#ifdef ARB_GPU_ENABLED +#include "backends/gpu/fvm.hpp" +#endif + +#include "common.hpp" +#include "mech_private_field_access.hpp" +#include "fvm_lowered_cell.hpp" +#include "fvm_lowered_cell_impl.hpp" +#include "sampler_map.hpp" +#include "simple_recipes.hpp" +#include "unit_test_catalogue.hpp" + +using namespace arb; + +using backend = arb::multicore::backend; +using fvm_cell = arb::fvm_lowered_cell_impl<backend>; + +using shared_state = backend::shared_state; +ACCESS_BIND(std::unique_ptr<shared_state> fvm_cell::*, private_state_ptr, &fvm_cell::state_) + +template <typename backend> +void run_test(std::string mech_name, + std::vector<std::string> state_variables, + std::unordered_map<std::string, fvm_value_type> assigned_variables, + std::vector<fvm_value_type> t0_values, + std::vector<fvm_value_type> t1_values) { + + auto cat = make_unit_test_catalogue(); + + fvm_size_type ncell = 1; + fvm_size_type ncv = 1; + std::vector<fvm_index_type> cv_to_intdom(ncv, 0); + + std::vector<fvm_gap_junction> gj = {}; + auto instance = cat.instance<backend>(mech_name); + auto& test = instance.mech; + + std::vector<fvm_value_type> temp(ncv, 300.); + std::vector<fvm_value_type> vinit(ncv, -65); + + auto shared_state = std::make_unique<typename backend::shared_state>( + ncell, cv_to_intdom, gj, vinit, temp, test->data_alignment()); + + mechanism_layout layout; + mechanism_overrides overrides; + + layout.weight.assign(ncv, 1.); + for (fvm_size_type i = 0; i<ncv; ++i) { + layout.cv.push_back(i); + } + + test->instantiate(0, *shared_state, overrides, layout); + + for (auto a: assigned_variables) { + test->set_parameter(a.first, std::vector<fvm_value_type>(ncv,a.second)); + } + + shared_state->reset(); + + test->initialize(); + + if (!t0_values.empty()) { + for (unsigned i = 0; i < state_variables.size(); i++) { + for (unsigned j = 0; j < ncv; j++) { + EXPECT_NEAR(t0_values[i], mechanism_field(test.get(), state_variables[i]).at(j), 1e-6); + } + } + } + + shared_state->update_time_to(0.5, 0.5); + shared_state->set_dt(); + + test->nrn_state(); + + if (!t1_values.empty()) { + for (unsigned i = 0; i < state_variables.size(); i++) { + for (unsigned j = 0; j < ncv; j++) { + EXPECT_NEAR(t1_values[i], mechanism_field(test.get(), state_variables[i]).at(j), 1e-6); + } + } + } +} + +TEST(mech_kinetic, kintetic_1_conserve) { + std::vector<std::string> state_variables = {"s", "h", "d"}; + std::vector<fvm_value_type> t0_values = {0.5, 0.2, 0.3}; + std::vector<fvm_value_type> t1_values = {0.380338, 0.446414, 0.173247}; + + run_test<multicore::backend>("test0_kin_diff", state_variables, {}, t0_values, t1_values); + run_test<multicore::backend>("test0_kin_conserve", state_variables, {}, t0_values, t1_values); +} + +TEST(mech_kinetic, kintetic_2_conserve) { + std::vector<std::string> state_variables = {"a", "b", "x", "y"}; + std::vector<fvm_value_type> t0_values = {0.2, 0.8, 0.6, 0.4}; + std::vector<fvm_value_type> t1_values = {0.217391304, 0.782608696, 0.33333333, 0.66666666}; + + run_test<multicore::backend>("test1_kin_diff", state_variables, {}, t0_values, t1_values); + run_test<multicore::backend>("test1_kin_conserve", state_variables, {}, t0_values, t1_values); +} + +TEST(mech_linear, linear) { + std::vector<std::string> state_variables = {"h", "s", "d"}; + std::vector<fvm_value_type> values = {0.5, 0.2, 0.3}; + std::unordered_map<std::string, fvm_value_type> assigned_variables = {{"a0", 2.5}, {"a1",0.5}, {"a2",3}, {"a3",2.3}}; + + run_test<multicore::backend>("test_linear_state", state_variables, assigned_variables, {}, values); + run_test<multicore::backend>("test_linear_init", state_variables, assigned_variables, values, {}); + run_test<multicore::backend>("test_linear_init_shuffle", state_variables, assigned_variables, values, {}); +} + +#ifdef ARB_GPU_ENABLED +TEST(mech_kinetic_gpu, kintetic_1_conserve) { + std::vector<std::string> state_variables = {"s", "h", "d"}; + std::vector<fvm_value_type> t0_values = {0.5, 0.2, 0.3}; + std::vector<fvm_value_type> t1_values = {0.380338, 0.446414, 0.173247}; + + run_test<gpu::backend>("test0_kin_diff", state_variables, {}, t0_values, t1_values); + run_test<gpu::backend>("test0_kin_conserve", state_variables, {}, t0_values, t1_values); +} + +TEST(mech_kinetic_gpu, kintetic_2_conserve) { + std::vector<std::string> state_variables = {"a", "b", "x", "y"}; + std::vector<fvm_value_type> t0_values = {0.2, 0.8, 0.6, 0.4}; + std::vector<fvm_value_type> t1_values = {0.217391304, 0.782608696, 0.33333333, 0.66666666}; + + run_test<gpu::backend>("test1_kin_diff", state_variables, {}, t0_values, t1_values); + run_test<gpu::backend>("test1_kin_conserve", state_variables, {}, t0_values, t1_values); +} + +TEST(mech_linear_gpu, linear) { + std::vector<std::string> state_variables = {"h", "s", "d"}; + std::vector<fvm_value_type> values = {0.5, 0.2, 0.3}; + std::unordered_map<std::string, fvm_value_type> assigned_variables = {{"a0", 2.5},{"a1",0.5},{"a2",3},{"a3",2.3}}; + + run_test<gpu::backend>("test_linear_state", state_variables, assigned_variables, {}, values); + run_test<gpu::backend>("test_linear_init", state_variables, assigned_variables, values, {}); + run_test<gpu::backend>("test_linear_init_shuffle", state_variables, assigned_variables, values, {}); +} + +#endif diff --git a/test/unit/unit_test_catalogue.cpp b/test/unit/unit_test_catalogue.cpp index ccd007e3591192edad4139097df90dc8cfd18dd9..5e9606c2a4eb5e45755f143de76096bd341d7f0a 100644 --- a/test/unit/unit_test_catalogue.cpp +++ b/test/unit/unit_test_catalogue.cpp @@ -9,6 +9,9 @@ #include "unit_test_catalogue.hpp" #include "mechanisms/celsius_test.hpp" #include "mechanisms/test0_kin_diff.hpp" +#include "mechanisms/test_linear_state.hpp" +#include "mechanisms/test_linear_init.hpp" +#include "mechanisms/test_linear_init_shuffle.hpp" #include "mechanisms/test0_kin_conserve.hpp" #include "mechanisms/test1_kin_diff.hpp" #include "mechanisms/test1_kin_conserve.hpp" @@ -40,6 +43,9 @@ mechanism_catalogue make_unit_test_catalogue() { mechanism_catalogue cat; ADD_MECH(cat, celsius_test) + ADD_MECH(cat, test_linear_state) + ADD_MECH(cat, test_linear_init) + ADD_MECH(cat, test_linear_init_shuffle) ADD_MECH(cat, test0_kin_diff) ADD_MECH(cat, test0_kin_conserve) ADD_MECH(cat, test1_kin_diff)