diff --git a/modcc/CMakeLists.txt b/modcc/CMakeLists.txt index aa98b8617c6581eee3fe386eaaf168b964c56d96..2dcf6d351d13cd23f3c18513e71c9eb41eb48e21 100644 --- a/modcc/CMakeLists.txt +++ b/modcc/CMakeLists.txt @@ -1,17 +1,18 @@ set(MODCC_SOURCES - token.cpp - lexer.cpp - expression.cpp - parser.cpp - textbuffer.cpp + astmanip.cpp + constantfolder.cpp cprinter.cpp - functionexpander.cpp - functioninliner.cpp cudaprinter.cpp - expressionclassifier.cpp - constantfolder.cpp errorvisitor.cpp + expression.cpp + expressionclassifier.cpp + functionexpander.cpp + functioninliner.cpp + lexer.cpp module.cpp + parser.cpp + textbuffer.cpp + token.cpp ) add_library(compiler ${MODCC_SOURCES}) diff --git a/modcc/astmanip.cpp b/modcc/astmanip.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e723c675a1a755358f5aa88e260773c1a919c61c --- /dev/null +++ b/modcc/astmanip.cpp @@ -0,0 +1,30 @@ +#include <string> + +#include "astmanip.hpp" +#include "expression.hpp" +#include "location.hpp" +#include "scope.hpp" + +static std::string unique_local_name(scope_ptr scope, std::string const& prefix) { + for (int i = 0; ; ++i) { + std::string name = prefix + std::to_string(i) + "_"; + if (!scope->find(name)) return name; + } +} + +local_assignment make_unique_local_assign(scope_ptr scope, Expression* e, std::string const& prefix) { + Location loc = e->location(); + std::string name = unique_local_name(scope, prefix); + + auto local = make_expression<LocalDeclaration>(loc, name); + local->semantic(scope); + + auto id = make_expression<IdentifierExpression>(loc, name); + id->semantic(scope); + + auto ass = binary_expression(e->location(), tok::eq, id->clone(), e->clone()); + ass->semantic(scope); + + return { std::move(local), std::move(ass), std::move(id), scope }; +} + diff --git a/modcc/astmanip.hpp b/modcc/astmanip.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b3135cfc54366a3ce9b0fa6d6cc8874f952949ae --- /dev/null +++ b/modcc/astmanip.hpp @@ -0,0 +1,38 @@ +#pragma once + +// Helper utilities for manipulating/modifying AST. + +#include <string> + +#include "expression.hpp" +#include "location.hpp" +#include "scope.hpp" + +// Create new local variable symbol and local declaration expression in current scope. +// Returns the local declaration expression. +expression_ptr make_unique_local_decl(scope_ptr scope, Location loc, std::string const& prefix="ll"); + +struct local_assignment { + expression_ptr local_decl; + expression_ptr assignment; + expression_ptr id; + scope_ptr scope; +}; + +// Create a local declaration as for `make_unique_local_decl`, together with an +// assignment to it from the given expression, using the location of that expression. +// Returns local declaration expression, assignment expression, new identifier id and +// consequent scope. +local_assignment make_unique_local_assign( + scope_ptr scope, + Expression* e, + std::string const& prefix="ll"); + +inline local_assignment make_unique_local_assign( + scope_ptr scope, + expression_ptr& e, + std::string const& prefix="ll") +{ + return make_unique_local_assign(scope, e.get(), prefix); +} + diff --git a/modcc/expression.cpp b/modcc/expression.cpp index e2694ea6f36e5dc33622e9d4fbe725fbd13d24d4..7c3eca9f87e0abac094fc95f4cc8e0514ad0b0a5 100644 --- a/modcc/expression.cpp +++ b/modcc/expression.cpp @@ -43,7 +43,7 @@ inline std::string to_string(procedureKind k) { Expression *******************************************************************************/ -void Expression::semantic(std::shared_ptr<scope_type>) { +void Expression::semantic(scope_ptr) { error("semantic() has not been implemented for this expression"); } @@ -77,7 +77,7 @@ std::string LocalVariable::to_string() const { IdentifierExpression *******************************************************************************/ -void IdentifierExpression::semantic(std::shared_ptr<scope_type> scp) { +void IdentifierExpression::semantic(scope_ptr scp) { scope_ = scp; auto s = scope_->find(spelling_); @@ -119,6 +119,14 @@ bool IdentifierExpression::is_global_lvalue() const { return false; } +/******************************************************************************* + DerivativeExpression +********************************************************************************/ + +expression_ptr DerivativeExpression::clone() const { + return make_expression<DerivativeExpression>(location_, spelling_); +} + /******************************************************************************* NumberExpression ********************************************************************************/ @@ -165,7 +173,7 @@ bool LocalDeclaration::add_variable(Token tok) { return true; } -void LocalDeclaration::semantic(std::shared_ptr<scope_type> scp) { +void LocalDeclaration::semantic(scope_ptr scp) { scope_ = scp; // loop over the variables declared in this LOCAL statement @@ -206,7 +214,7 @@ std::string ArgumentExpression::to_string() const { return blue("arg") + " " + yellow(name_); } -void ArgumentExpression::semantic(std::shared_ptr<scope_type> scp) { +void ArgumentExpression::semantic(scope_ptr scp) { scope_ = scp; auto s = scope_->find(name_); @@ -270,7 +278,7 @@ expression_ptr ReactionExpression::clone() const { location_, lhs()->clone(), rhs()->clone(), fwd_rate()->clone(), rev_rate()->clone()); } -void ReactionExpression::semantic(std::shared_ptr<scope_type> scp) { +void ReactionExpression::semantic(scope_ptr scp) { scope_ = scp; lhs()->semantic(scp); rhs()->semantic(scp); @@ -291,7 +299,7 @@ expression_ptr StoichTermExpression::clone() const { location_, coeff()->clone(), ident()->clone()); } -void StoichTermExpression::semantic(std::shared_ptr<scope_type> scp) { +void StoichTermExpression::semantic(scope_ptr scp) { scope_ = scp; ident()->semantic(scp); } @@ -320,7 +328,7 @@ std::string StoichExpression::to_string() const { return s; } -void StoichExpression::semantic(std::shared_ptr<scope_type> scp) { +void StoichExpression::semantic(scope_ptr scp) { scope_ = scp; for(auto& e: terms()) { e->semantic(scp); @@ -336,7 +344,7 @@ expression_ptr ConserveExpression::clone() const { location_, lhs()->clone(), rhs()->clone()); } -void ConserveExpression::semantic(std::shared_ptr<scope_type> scp) { +void ConserveExpression::semantic(scope_ptr scp) { scope_ = scp; lhs_->semantic(scp); rhs_->semantic(scp); @@ -359,7 +367,7 @@ std::string CallExpression::to_string() const { return str; } -void CallExpression::semantic(std::shared_ptr<scope_type> scp) { +void CallExpression::semantic(scope_ptr scp) { scope_ = scp; // look up to see if symbol is defined @@ -608,7 +616,7 @@ void FunctionExpression::semantic(scope_type::symbol_map &global_symbols) { /******************************************************************************* UnaryExpression *******************************************************************************/ -void UnaryExpression::semantic(std::shared_ptr<scope_type> scp) { +void UnaryExpression::semantic(scope_ptr scp) { scope_ = scp; expression_->semantic(scp); @@ -629,7 +637,7 @@ expression_ptr UnaryExpression::clone() const { /******************************************************************************* BinaryExpression *******************************************************************************/ -void BinaryExpression::semantic(std::shared_ptr<scope_type> scp) { +void BinaryExpression::semantic(scope_ptr scp) { scope_ = scp; lhs_->semantic(scp); rhs_->semantic(scp); @@ -660,7 +668,7 @@ std::string BinaryExpression::to_string() const { AssignmentExpression *******************************************************************************/ -void AssignmentExpression::semantic(std::shared_ptr<scope_type> scp) { +void AssignmentExpression::semantic(scope_ptr scp) { scope_ = scp; lhs_->semantic(scp); rhs_->semantic(scp); @@ -680,7 +688,7 @@ void AssignmentExpression::semantic(std::shared_ptr<scope_type> scp) { SolveExpression *******************************************************************************/ -void SolveExpression::semantic(std::shared_ptr<scope_type> scp) { +void SolveExpression::semantic(scope_ptr scp) { scope_ = scp; auto e = scp->find(name()); @@ -708,7 +716,7 @@ expression_ptr SolveExpression::clone() const { ConductanceExpression *******************************************************************************/ -void ConductanceExpression::semantic(std::shared_ptr<scope_type> scp) { +void ConductanceExpression::semantic(scope_ptr scp) { scope_ = scp; // For now do nothing with the CONDUCTANCE statement, because it is not needed // to optimize conductance calculation. @@ -740,7 +748,7 @@ std::string BlockExpression::to_string() const { return str; } -void BlockExpression::semantic(std::shared_ptr<scope_type> scp) { +void BlockExpression::semantic(scope_ptr scp) { scope_ = scp; for(auto& e : statements_) { e->semantic(scope_); @@ -771,7 +779,7 @@ std::string IfExpression::to_string() const { return s; } -void IfExpression::semantic(std::shared_ptr<scope_type> scp) { +void IfExpression::semantic(scope_ptr scp) { scope_ = scp; condition_->semantic(scp); diff --git a/modcc/expression.hpp b/modcc/expression.hpp index a7d84e61a6a72220e4bbc34e692de703e1681a94..c20e00062a9fb37acaa91c4f3fd8ebdf47f323c8 100644 --- a/modcc/expression.hpp +++ b/modcc/expression.hpp @@ -58,6 +58,8 @@ class LocalVariable; using expression_ptr = std::unique_ptr<Expression>; using symbol_ptr = std::unique_ptr<Symbol>; +using scope_type = Scope<Symbol>; +using scope_ptr = std::shared_ptr<scope_type>; template <typename T, typename... Args> expression_ptr make_expression(Args&&... args) { @@ -113,8 +115,6 @@ static std::string to_string(solverMethod m) { class Expression { public: - using scope_type = Scope<Symbol>; - explicit Expression(Location location) : location_(location) {} @@ -125,9 +125,12 @@ public: // expressions must provide a method for stringification virtual std::string to_string() const = 0; - Location const& location() const {return location_;}; + Location const& location() const { return location_; } + + scope_ptr scope() { return scope_; } - std::shared_ptr<scope_type> scope() {return scope_;}; + // set scope explicitly + void scope(scope_ptr s) { scope_ = s; } void error(std::string const& str) { error_ = true; @@ -143,7 +146,7 @@ public: std::string const& warning_message() const { return warning_string_; } // perform semantic analysis - virtual void semantic(std::shared_ptr<scope_type>); + virtual void semantic(scope_ptr); virtual void semantic(scope_type::symbol_map&) { throw compiler_exception("unable to perform semantic analysis for " + this->to_string(), location_); }; @@ -194,8 +197,7 @@ protected: std::string warning_string_; Location location_; - - std::shared_ptr<scope_type> scope_; + scope_ptr scope_; }; class Symbol : public Expression { @@ -263,7 +265,7 @@ public: expression_ptr clone() const override; - void semantic(std::shared_ptr<scope_type> scp) override; + void semantic(scope_ptr scp) override; Symbol* symbol() { return symbol_; }; @@ -300,6 +302,9 @@ public: std::string to_string() const override { return blue("diff") + "(" + yellow(spelling()) + ")"; } + + expression_ptr clone() const override; + DerivativeExpression* is_derivative() override { return this; } ~DerivativeExpression() {} @@ -325,7 +330,7 @@ public: } // do nothing for number semantic analysis - void semantic(std::shared_ptr<scope_type> scp) override {}; + void semantic(scope_ptr scp) override {}; expression_ptr clone() const override; NumberExpression* is_number() override {return this;} @@ -355,7 +360,7 @@ public: } // do nothing for number semantic analysis - void semantic(std::shared_ptr<scope_type> scp) override {}; + void semantic(scope_ptr scp) override {}; expression_ptr clone() const override; IntegerExpression* is_integer() override {return this;} @@ -386,7 +391,7 @@ public: bool add_variable(Token name); LocalDeclaration* is_local_declaration() override {return this;} - void semantic(std::shared_ptr<scope_type> scp) override; + void semantic(scope_ptr scp) override; std::vector<Symbol*>& symbols() {return symbols_;} std::map<std::string, Token>& variables() {return vars_;} expression_ptr clone() const override; @@ -411,7 +416,7 @@ public: bool add_variable(Token name); ArgumentExpression* is_argument() override {return this;} - void semantic(std::shared_ptr<scope_type> scp) override; + void semantic(scope_ptr scp) override; Token token() {return token_;} std::string const& name() {return name_;} void set_name(std::string const& n) { @@ -668,7 +673,7 @@ public: expression_ptr clone() const override; - void semantic(std::shared_ptr<scope_type> scp) override; + void semantic(scope_ptr scp) override; void accept(Visitor *v) override; ~SolveExpression() {} @@ -709,7 +714,7 @@ public: expression_ptr clone() const override; - void semantic(std::shared_ptr<scope_type> scp) override; + void semantic(scope_ptr scp) override; void accept(Visitor *v) override; ~ConductanceExpression() {} @@ -767,7 +772,7 @@ public: return is_nested_; } - void semantic(std::shared_ptr<scope_type> scp) override; + void semantic(scope_ptr scp) override; void accept(Visitor* v) override; std::string to_string() const override; @@ -795,7 +800,7 @@ public: expression_ptr clone() const override; std::string to_string() const override; - void semantic(std::shared_ptr<scope_type> scp) override; + void semantic(scope_ptr scp) override; void accept(Visitor* v) override; private: @@ -848,7 +853,7 @@ public: ReactionExpression* is_reaction() override {return this;} std::string to_string() const override; - void semantic(std::shared_ptr<scope_type> scp) override; + void semantic(scope_ptr scp) override; expression_ptr clone() const override; void accept(Visitor *v) override; @@ -885,7 +890,7 @@ public: std::string to_string() const override { return pprintf("% %", coeff()->to_string(), ident()->to_string()); } - void semantic(std::shared_ptr<scope_type> scp) override; + void semantic(scope_ptr scp) override; expression_ptr clone() const override; void accept(Visitor *v) override; @@ -918,7 +923,7 @@ public: StoichExpression* is_stoich() override {return this;} std::string to_string() const override; - void semantic(std::shared_ptr<scope_type> scp) override; + void semantic(scope_ptr scp) override; expression_ptr clone() const override; void accept(Visitor *v) override; @@ -943,7 +948,7 @@ public: std::string& name() { return spelling_; } std::string const& name() const { return spelling_; } - void semantic(std::shared_ptr<scope_type> scp) override; + void semantic(scope_ptr scp) override; expression_ptr clone() const override; std::string to_string() const override; @@ -1138,7 +1143,7 @@ public: UnaryExpression* is_unary() override {return this;}; Expression* expression() {return expression_.get();} const Expression* expression() const {return expression_.get();} - void semantic(std::shared_ptr<scope_type> scp) override; + void semantic(scope_ptr scp) override; void accept(Visitor *v) override; void replace_expression(expression_ptr&& other); }; @@ -1220,7 +1225,7 @@ public: const Expression* lhs() const {return lhs_.get();} const Expression* rhs() const {return rhs_.get();} BinaryExpression* is_binary() override {return this;} - void semantic(std::shared_ptr<scope_type> scp) override; + void semantic(scope_ptr scp) override; expression_ptr clone() const override; void replace_rhs(expression_ptr&& other); void replace_lhs(expression_ptr&& other); @@ -1236,7 +1241,7 @@ public: AssignmentExpression* is_assignment() override {return this;} - void semantic(std::shared_ptr<scope_type> scp) override; + void semantic(scope_ptr scp) override; void accept(Visitor *v) override; }; @@ -1250,7 +1255,7 @@ public: ConserveExpression* is_conserve() override {return this;} expression_ptr clone() const override; - void semantic(std::shared_ptr<scope_type> scp) override; + void semantic(scope_ptr scp) override; void accept(Visitor *v) override; }; diff --git a/modcc/functionexpander.cpp b/modcc/functionexpander.cpp index 1ebd9d4908824ab85dc122f283679b3d52fb3672..afa91755df537fed97a0f7639843f833a6de2f13 100644 --- a/modcc/functionexpander.cpp +++ b/modcc/functionexpander.cpp @@ -1,9 +1,17 @@ #include <iostream> +#include "astmanip.hpp" #include "error.hpp" #include "functionexpander.hpp" #include "modccutil.hpp" +expression_ptr insert_unique_local_assignment(call_list_type& stmts, Expression* e) { + auto exprs = make_unique_local_assign(e->scope(), e); + stmts.push_front(std::move(exprs.local_decl)); + stmts.push_back(std::move(exprs.assignment)); + return std::move(exprs.id); +} + /////////////////////////////////////////////////////////////////////////////// // function call site lowering /////////////////////////////////////////////////////////////////////////////// @@ -96,22 +104,6 @@ void FunctionCallLowerer::visit(BinaryExpression *e) { /////////////////////////////////////////////////////////////////////////////// // function argument lowering /////////////////////////////////////////////////////////////////////////////// -Symbol* make_unique_local(std::shared_ptr<Scope<Symbol>> scope) { - std::string name; - auto i = 0; - do { - name = pprintf("ll%_", i); - ++i; - } while(scope->find(name)); - - return - scope->add_local_symbol( - name, - make_symbol<LocalVariable>( - Location(), name, localVariableKind::local - ) - ); -} call_list_type lower_function_arguments(std::vector<expression_ptr>& args) @@ -130,28 +122,10 @@ lower_function_arguments(std::vector<expression_ptr>& args) continue; } - // use the source location of the original statement - auto loc = e->location(); - - // make an identifier for the new symbol which will store the result of - // the function call - auto id = make_expression<IdentifierExpression> - (loc, make_unique_local(e->scope())->name()); - id->semantic(e->scope()); - - // generate a LOCAL declaration for the variable - new_statements.push_front( - make_expression<LocalDeclaration>(loc, id->is_identifier()->spelling()) - ); - - // make a binary expression which assigns the argument to the variable - auto ass = binary_expression(loc, tok::eq, id->clone(), e->clone()); - ass->semantic(e->scope()); + auto id = insert_unique_local_assignment(new_statements, e.get()); #ifdef LOGGING - std::cout << " lowering to " << ass->to_string() << "\n"; + std::cout << " lowering to " << new_statements.back()->to_string() << "\n"; #endif - new_statements.push_back(std::move(ass)); - // replace the function call in the original expression with the local // variable which holds the pre-computed value std::swap(e, id); diff --git a/modcc/functionexpander.hpp b/modcc/functionexpander.hpp index 38c799de24506732e1399ed64c2cb585a85defca..c9185332d458a9595d077fb6e841f9f5dbd5c9b8 100644 --- a/modcc/functionexpander.hpp +++ b/modcc/functionexpander.hpp @@ -8,6 +8,11 @@ // storage for a list of expressions using call_list_type = std::list<expression_ptr>; +// Make a local declaration and assignment for the given expression, +// and insert at the front and back respectively of the statement list. +// Return the new unique local identifier. +expression_ptr insert_unique_local_assignment(call_list_type& stmts, Expression* e); + // prototype for lowering function calls call_list_type lower_function_calls(Expression* e); @@ -31,11 +36,8 @@ call_list_type lower_function_calls(Expression* e); // the function call will have been fully lowered /////////////////////////////////////////////////////////////////////////////// class FunctionCallLowerer : public Visitor { - public: - using scope_type = Scope<Symbol>; - - FunctionCallLowerer(std::shared_ptr<scope_type> s) + FunctionCallLowerer(scope_ptr s) : scope_(s) {} @@ -57,53 +59,16 @@ public: ~FunctionCallLowerer() {} private: - Symbol* make_unique_local() { - std::string name; - auto i = 0; - do { - name = pprintf("ll%_", i); - ++i; - } while(scope_->find(name)); - - auto sym = - scope_->add_local_symbol( - name, - make_symbol<LocalVariable>( - Location(), name, localVariableKind::local - ) - ); - - return sym; - } - template< typename F> void expand_call(CallExpression* func, F replacer) { - // use the source location of the original statement - auto loc = func->location(); - - // make an identifier for the new symbol which will store the result of - // the function call - auto id = make_expression<IdentifierExpression> - (loc, make_unique_local()->name()); - id->semantic(scope_); - // generate a LOCAL declaration for the variable - calls_.push_front( - make_expression<LocalDeclaration>(loc, id->is_identifier()->spelling()) - ); - calls_.front()->semantic(scope_); - - // make a binary expression which assigns the function to the variable - auto ass = binary_expression(loc, tok::eq, id->clone(), func->clone()); - ass->semantic(scope_); - calls_.push_back(std::move(ass)); - + auto id = insert_unique_local_assignment(calls_, func); // replace the function call in the original expression with the local // variable which holds the pre-computed value replacer(std::move(id)); } call_list_type calls_; - std::shared_ptr<scope_type> scope_; + scope_ptr scope_; }; /////////////////////////////////////////////////////////////////////////////// diff --git a/modcc/kinrewriter.hpp b/modcc/kinrewriter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2ccf5c8aa00dee993fe180fd025397f06f3107eb --- /dev/null +++ b/modcc/kinrewriter.hpp @@ -0,0 +1,196 @@ +#pragma once + +#include <iostream> +#include <map> +#include <string> +#include <list> + +#include "astmanip.hpp" +#include "visitor.hpp" + +using stmt_list_type = std::list<expression_ptr>; + +class KineticRewriter : public Visitor { +public: + virtual void visit(Expression *) override; + + virtual void visit(UnaryExpression *e) override { visit((Expression*)e); } + virtual void visit(BinaryExpression *e) override { visit((Expression*)e); } + + virtual void visit(ConserveExpression *e) override; + virtual void visit(ReactionExpression *e) override; + virtual void visit(BlockExpression *e) override; + virtual void visit(ProcedureExpression* e) override; + + symbol_ptr as_procedure() { + stmt_list_type body_stmts; + for (const auto& s: statements) body_stmts.push_back(s->clone()); + + auto body = make_expression<BlockExpression>( + proc_loc, + std::move(body_stmts), + false); + + return make_symbol<ProcedureExpression>( + proc_loc, + proc_name, + std::vector<expression_ptr>(), + std::move(body)); + } + +private: + // Name and location of original kinetic procedure (used for `as_procedure` above). + std::string proc_name; + Location proc_loc; + + // Statements in replacement procedure body. + stmt_list_type statements; + + // Acccumulated terms for derivative expressions, keyed by id name. + std::map<std::string, expression_ptr> dterms; + + // Reset state (at e.g. start of kinetic proc). + void reset() { + proc_name = ""; + statements.clear(); + dterms.clear(); + } +}; + +// By default, copy statements across verbatim. +inline void KineticRewriter::visit(Expression* e) { + statements.push_back(e->clone()); +} + +inline void KineticRewriter::visit(ConserveExpression*) { + // Deliberately ignoring these for now! +} + +inline void KineticRewriter::visit(ReactionExpression* e) { + Location loc = e->location(); + scope_ptr scope = e->scope(); + + // Total forward rate is the specified forward reaction rate constant, multiplied + // by the concentrations of species present in the left hand side. + + auto fwd = e->fwd_rate()->clone(); + auto lhs = e->lhs()->is_stoich(); + for (const auto& term: lhs->terms()) { + auto& id = term->is_stoich_term()->ident(); + auto& coeff = term->is_stoich_term()->coeff(); + + fwd = make_expression<MulBinaryExpression>( + loc, + make_expression<PowBinaryExpression>(loc, id->clone(), coeff->clone()), + std::move(fwd)); + } + + // Similar for reverse rate. + + auto rev = e->rev_rate()->clone(); + auto rhs = e->rhs()->is_stoich(); + for (const auto& term: rhs->terms()) { + auto& id = term->is_stoich_term()->ident(); + auto& coeff = term->is_stoich_term()->coeff(); + + rev = make_expression<MulBinaryExpression>( + loc, + make_expression<PowBinaryExpression>(loc, id->clone(), coeff->clone()), + std::move(rev)); + } + + auto net_rate = make_expression<SubBinaryExpression>( + loc, + std::move(fwd), std::move(rev)); + net_rate->semantic(scope); + + auto local_net_rate = make_unique_local_assign(scope, net_rate, "rate"); + statements.push_back(std::move(local_net_rate.local_decl)); + statements.push_back(std::move(local_net_rate.assignment)); + scope = local_net_rate.scope; // nop for now... + + auto net_rate_sym = std::move(local_net_rate.id); + + // Net change in quantity after forward reaction: + // e.g. A + ... <-> 3A + ... + // has a net delta of 2 for A. + + std::map<std::string, long long int> net_delta; + + for (const auto& term: lhs->terms()) { + auto sterm = term->is_stoich_term(); + auto name = sterm->ident()->is_identifier()->name(); + net_delta[name] -= sterm->coeff()->is_integer()->integer_value(); + } + + for (const auto& term: rhs->terms()) { + auto sterm = term->is_stoich_term(); + auto name = sterm->ident()->is_identifier()->name(); + net_delta[name] += sterm->coeff()->is_integer()->integer_value(); + } + + // Contribution to final ODE for each species is given by + // net_rate * net_delta. + + for (auto& p: net_delta) { + if (p.second==0) continue; + + auto term = make_expression<MulBinaryExpression>( + loc, + make_expression<IntegerExpression>(loc, p.second), + net_rate_sym->clone()); + term->semantic(scope); + + auto local_term = make_unique_local_assign(scope, term, p.first+"_rate"); + statements.push_back(std::move(local_term.local_decl)); + statements.push_back(std::move(local_term.assignment)); + scope = local_term.scope; // nop for now... + + auto& dterm = dterms[p.first]; + if (!dterm) { + dterm = std::move(local_term.id); + } + else { + dterm = make_expression<AddBinaryExpression>( + loc, + std::move(dterm), + std::move(local_term.id)); + + // don't actually want to overwrite scope of previous terms + // in dterm sum, so set expression 'scope' directly. + dterm->scope(scope); + } + } +} + +inline void KineticRewriter::visit(ProcedureExpression* e) { + reset(); + proc_name = e->name(); + proc_loc = e->location(); + e->body()->accept(this); + + // make new procedure from saved statements and terms + for (auto& p: dterms) { + auto loc = p.second->location(); + auto scope = p.second->scope(); + + auto deriv = make_expression<DerivativeExpression>( + loc, + p.first); + deriv->semantic(scope); + + auto assign = make_expression<AssignmentExpression>( + loc, + std::move(deriv), + std::move(p.second)); + + assign->scope(scope); // don't re-do semantic analysis here + statements.push_back(std::move(assign)); + } +} + +inline void KineticRewriter::visit(BlockExpression* e) { + for (auto& s: e->statements()) { + s->accept(this); + } +} diff --git a/modcc/module.hpp b/modcc/module.hpp index 1cd1cfe89b4e31d1bb18e3e3cedb6e173d049ee2..5a16e64c45c75adbe17a2f88c169c73722b26cb4 100644 --- a/modcc/module.hpp +++ b/modcc/module.hpp @@ -9,7 +9,6 @@ // wrapper around a .mod file class Module { public : - using scope_type = Expression::scope_type; using symbol_map = scope_type::symbol_map; using symbol_ptr = scope_type::symbol_ptr; diff --git a/modcc/parser.cpp b/modcc/parser.cpp index 43855a7ef1ad0915e7312aea048c23448015e06a..bebef194dca3d5ca5233de67051a83e16dc1fbe5 100644 --- a/modcc/parser.cpp +++ b/modcc/parser.cpp @@ -753,9 +753,9 @@ symbol_ptr Parser::parse_procedure() { break; default: // it is a compiler error if trying to parse_procedure() without - // having DERIVATIVE, PROCEDURE, INITIAL or BREAKPOINT keyword + // having DERIVATIVE, KINETIC, PROCEDURE, INITIAL or BREAKPOINT keyword throw compiler_exception( - "attempt to parser_procedure() without {DERIVATIVE,PROCEDURE,INITIAL,BREAKPOINT}", + "attempt to parse_procedure() without {DERIVATIVE,KINETIC,PROCEDURE,INITIAL,BREAKPOINT}", location_); } if(p==nullptr) return nullptr; diff --git a/modcc/visitor.hpp b/modcc/visitor.hpp index ae5c42dcd29f7760a0b8dcc533fdb80e9561ed0f..c474dc89fda572abad331c9e7254241a7f31230b 100644 --- a/modcc/visitor.hpp +++ b/modcc/visitor.hpp @@ -25,6 +25,9 @@ public: virtual void visit(ArgumentExpression *e) { visit((Expression*) e); } virtual void visit(PrototypeExpression *e) { visit((Expression*) e); } virtual void visit(CallExpression *e) { visit((Expression*) e); } + virtual void visit(ReactionExpression *e) { visit((Expression*) e); } + virtual void visit(StoichTermExpression *e) { visit((Expression*) e); } + virtual void visit(StoichExpression *e) { visit((Expression*) e); } virtual void visit(VariableExpression *e) { visit((Expression*) e); } virtual void visit(IndexedVariable *e) { visit((Expression*) e); } virtual void visit(FunctionExpression *e) { visit((Expression*) e); } @@ -46,11 +49,13 @@ public: virtual void visit(BinaryExpression *e) = 0; virtual void visit(AssignmentExpression *e) { visit((BinaryExpression*) e); } + virtual void visit(ConserveExpression *e) { visit((BinaryExpression*) e); } virtual void visit(AddBinaryExpression *e) { visit((BinaryExpression*) e); } virtual void visit(SubBinaryExpression *e) { visit((BinaryExpression*) e); } virtual void visit(MulBinaryExpression *e) { visit((BinaryExpression*) e); } virtual void visit(DivBinaryExpression *e) { visit((BinaryExpression*) e); } virtual void visit(PowBinaryExpression *e) { visit((BinaryExpression*) e); } + virtual ~Visitor() {}; }; diff --git a/tests/modcc/CMakeLists.txt b/tests/modcc/CMakeLists.txt index 044eec0a09c3ebdb530fd7b8196a52024da6127e..40c35687bb961926ba1eebe47a477f60d6159f4f 100644 --- a/tests/modcc/CMakeLists.txt +++ b/tests/modcc/CMakeLists.txt @@ -1,6 +1,7 @@ set(MODCC_TEST_SOURCES # unit tests test_lexer.cpp + test_kinetic_rewriter.cpp test_module.cpp test_optimization.cpp test_parser.cpp @@ -9,6 +10,9 @@ set(MODCC_TEST_SOURCES # unit test driver driver.cpp + + # utility + expr_expand.cpp ) add_definitions("-DDATADIR=\"${CMAKE_SOURCE_DIR}/data\"") diff --git a/tests/modcc/alg_collect.hpp b/tests/modcc/alg_collect.hpp new file mode 100644 index 0000000000000000000000000000000000000000..36f4d3dca20350c7fcddfaa12988065815e93096 --- /dev/null +++ b/tests/modcc/alg_collect.hpp @@ -0,0 +1,313 @@ +#pragma once + +#include <algorithm> +#include <cmath> +#include <sstream> +#include <string> + +// Simple algebraic term expansion/collection routines. + +namespace alg { + +template <typename Prim, typename Num> +struct collectable { + Prim prim; + Num n; + + collectable(): n(0) {} + collectable(const Prim& prim): prim(prim), n(1) {} + collectable(const Prim& prim, Num n): prim(prim), n(n) {} + + friend bool operator<(const collectable& a, const collectable& b) { + return a.prim<b.prim || (a.prim==b.prim && a.n<b.n); + } + + friend bool operator==(const collectable& a, const collectable& b) { + return a.prim==b.prim && a.n==b.n; + } + + friend bool operator!=(const collectable& a, const collectable& b) { + return !(a==b); + } + + void invert() { n = -n; } +}; + +template <typename Prim, typename Num> +void collect(std::vector<collectable<Prim, Num>>& xs) { + std::sort(xs.begin(), xs.end()); + if (xs.size()<2) return; + + std::vector<collectable<Prim, Num>> coll; + coll.push_back(xs[0]); + + for (unsigned j=1; j<xs.size(); ++j) { + const auto& x = xs[j]; + if (coll.back().prim!=x.prim) { + coll.push_back(x); + } + else { + coll.back().n += x.n; + } + } + + xs.clear(); + for (auto& t: coll) { + if (t.n!=0) xs.push_back(std::move(t)); + } +} + +template <typename Prim, typename Num> +void invert(std::vector<collectable<Prim, Num>>& xs) { + for (auto& x: xs) x.invert(); +} + +struct prodterm { + using factor = collectable<std::string, double>; + + std::vector<factor> factors; + + prodterm() {} + explicit prodterm(factor f): factors(1, f) {} + explicit prodterm(const std::vector<factor>& factors): factors(factors) {} + + void collect() { alg::collect(factors); } + void invert() { alg::invert(factors); } + bool empty() const { return factors.empty(); } + + prodterm& operator*=(const prodterm& x) { + factors.insert(factors.end(), x.factors.begin(), x.factors.end()); + collect(); + return *this; + } + + prodterm& operator/=(const prodterm& x) { + prodterm recip(x); + recip.invert(); + return *this *= recip; + } + + prodterm pow(double n) const { + prodterm x(*this); + for (auto& f: x.factors) f.n *= n; + return x; + } + + friend prodterm pow(const prodterm& pt, double n) { + return pt.pow(n); + } + + friend prodterm operator*(const prodterm& a, const prodterm& b) { + prodterm p(a); + return p *= b; + } + + friend prodterm operator/(const prodterm& a, const prodterm& b) { + prodterm p(a); + return p /= b; + } + + friend bool operator<(const prodterm& p, const prodterm& q) { + return p.factors<q.factors; + } + + friend bool operator==(const prodterm& p, const prodterm& q) { + return p.factors==q.factors; + } + + friend bool operator!=(const prodterm& p, const prodterm& q) { + return !(p==q); + } + + friend std::ostream& operator<<(std::ostream& o, const prodterm& x) { + if (x.empty()) return o << "1"; + + int nf = 0; + for (const auto& f: x.factors) { + o << (nf++?"*":"") << f.prim; + if (f.n!=1) o << '^' << f.n; + } + return o; + } +}; + +struct prodsum { + using term = collectable<prodterm, double>; + std::vector<term> terms; + + prodsum() {} + + prodsum(const prodterm& pt): terms(1, pt) {} + prodsum(prodterm&& pt): terms(1, std::move(pt)) {} + explicit prodsum(double x, const prodterm& pt = prodterm()): terms(1, term(pt, x)) {} + + void collect() { alg::collect(terms); } + void invert() { alg::invert(terms); } + bool empty() const { return terms.empty(); } + + prodsum& operator+=(const prodsum& x) { + terms.insert(terms.end(), x.terms.begin(), x.terms.end()); + collect(); + return *this; + } + + prodsum& operator-=(const prodsum& x) { + prodsum neg(x); + neg.invert(); + return *this += neg; + } + + prodsum operator-() const { + prodsum neg(*this); + neg.invert(); + return neg; + } + + // Distribution: + prodsum& operator*=(const prodsum& x) { + if (terms.empty()) return *this; + if (x.empty()) { + terms.clear(); + return *this; + } + + std::vector<term> distrib; + for (const auto& a: terms) { + for (const auto& b: x.terms) { + distrib.emplace_back(a.prim*b.prim, a.n*b.n); + } + } + + terms = distrib; + collect(); + return *this; + } + + prodsum recip() const { + prodterm rterm; + double rcoef = 1; + + if (terms.size()==1) { + rcoef = terms.front().n; + rterm = terms.front().prim; + } + else { + // Make an opaque term from denominator if not a simple product. + rterm = as_opaque_term(); + } + rterm.invert(); + return prodsum(1.0/rcoef, rterm); + } + + prodsum& operator/=(const prodsum& x) { + return *this *= x.recip(); + } + + prodterm as_opaque_term() const { + std::stringstream s; + s << '(' << *this << ')'; + return prodterm(s.str()); + } + + friend prodsum operator+(const prodsum& a, const prodsum& b) { + prodsum p(a); + return p += b; + } + + friend prodsum operator-(const prodsum& a, const prodsum& b) { + prodsum p(a); + return p -= b; + } + + friend prodsum operator*(const prodsum& a, const prodsum& b) { + prodsum p(a); + return p *= b; + } + + friend prodsum operator/(const prodsum& a, const prodsum& b) { + prodsum p(a); + return p /= b; + } + + friend std::ostream& operator<<(std::ostream& o, const prodsum& x) { + if (x.terms.empty()) return o << "0"; + + bool first = true; + for (const auto& t: x.terms) { + double coef = t.n; + const prodterm& pd = t.prim; + + const char* prefix = coef<0? "-": first? "": "+"; + if (coef<0) coef = -coef; + + o << prefix; + if (pd.empty()) { + o << coef; + } + else { + if (coef!=1) o << coef << '*'; + o << pd; + } + first = false; + } + return o; + } + + bool is_scalar() const { + return terms.empty() || (terms.size()==1 && terms.front().prim.empty()); + } + + double first_coeff() const { + return terms.empty()? 0: terms.front().n; + } + + friend bool operator<(const prodsum& p, const prodsum& q) { + return p.terms<q.terms; + } + + friend bool operator==(const prodsum& p, const prodsum& q) { + return p.terms==q.terms; + } + + friend bool operator!=(const prodsum& p, const prodsum& q) { + return !(p==q); + } + + prodsum int_pow(unsigned n) const { + switch (n) { + case 0: + return prodsum(1); + case 1: + return *this; + default: + return int_pow(n/2)*int_pow(n/2)*int_pow(n%2); + } + } + + prodsum pow(double n) const { + if (n==0) { + return prodsum(1); + } + else if (n==1) { + return *this; + } + else if (is_scalar()) { + return prodsum(std::pow(first_coeff(), n)); + } + else if (terms.size()==1) { + const auto& t = terms.front(); + return prodsum(std::pow(t.n, n), t.prim.pow(n)); + } + else if (n<0) { + return recip().pow(-n); + } + else if (n!=std::floor(n)) { + return as_opaque_term().pow(n); + } + else { + return int_pow(static_cast<unsigned>(n)); + } + } +}; + +} // namespace alg diff --git a/tests/modcc/expr_expand.cpp b/tests/modcc/expr_expand.cpp new file mode 100644 index 0000000000000000000000000000000000000000..340acdcc89a25421be41e65a5221f129ec4c5891 --- /dev/null +++ b/tests/modcc/expr_expand.cpp @@ -0,0 +1,77 @@ +#include <stdexcept> +#include <sstream> + +#include "expression.hpp" +#include "modccutil.hpp" +#include "token.hpp" + +#include "alg_collect.hpp" +#include "expr_expand.hpp" + +alg::prodsum expand_expression(Expression* e, const id_prodsum_map& exmap) { + using namespace alg; + + if (const auto& n = e->is_number()) { + return prodsum(n->value()); + } + else if (const auto& c = e->is_function_call()) { + std::stringstream rep(c->name()); + rep << '('; + bool first = true; + for (const auto& arg: c->args()) { + if (!first) rep << ','; + rep << expand_expression(arg.get(), exmap); + first = false; + } + rep << ')'; + return prodterm(rep.str()); + } + else if (const auto& i = e->is_identifier()) { + std::string k = i->spelling(); + auto x = exmap.find(k); + return x!=exmap.end()? x->second: prodterm(k); + } + else if (const auto& b = e->is_binary()) { + prodsum lhs = expand_expression(b->lhs(), exmap); + prodsum rhs = expand_expression(b->rhs(), exmap); + + switch (b->op()) { + case tok::plus: + return lhs+rhs; + case tok::minus: + return lhs-rhs; + case tok::times: + return lhs*rhs; + case tok::divide: + return lhs/rhs; + case tok::pow: + if (!rhs.is_scalar()) { + // make an opaque term for this case (i.e. too hard to simplify) + return prodterm("("+to_string(lhs)+")^("+to_string(rhs)+")"); + } + else return lhs.pow(rhs.first_coeff()); + default: + throw std::runtime_error("unrecognized binop"); + } + } + else if (const auto& u = e->is_unary()) { + prodsum inner = expand_expression(u->expression(), exmap); + switch (u->op()) { + case tok::minus: + return -inner; + case tok::exp: + return prodterm("exp("+to_string(inner)+")"); + case tok::log: + return prodterm("log("+to_string(inner)+")"); + case tok::sin: + return prodterm("sin("+to_string(inner)+")"); + case tok::cos: + return prodterm("cos("+to_string(inner)+")"); + default: + throw std::runtime_error("unrecognized unaryop"); + } + } + else { + throw std::runtime_error("unexpected expression type"); + } +} diff --git a/tests/modcc/expr_expand.hpp b/tests/modcc/expr_expand.hpp new file mode 100644 index 0000000000000000000000000000000000000000..55729b5ebc533ab474b8c4edab8527e4424ec581 --- /dev/null +++ b/tests/modcc/expr_expand.hpp @@ -0,0 +1,51 @@ +#pragma once + +#include <list> +#include <map> +#include <stdexcept> +#include <string> + +#include "expression.hpp" + +#include "alg_collect.hpp" + +using id_prodsum_map = std::map<std::string, alg::prodsum>; + +// Given a value expression (e.g. something found on the right hand side +// of an assignment), return the canonical expanded algebraic representation. +// The `exmap` parameter contains the given associations between identifiers and +// algebraic representations. + +alg::prodsum expand_expression(Expression* e, const id_prodsum_map& exmap); + +// From a sequence of statement expressions, expand all assignments and return +// a map from identifiers to algebraic representations. + +template <typename StmtSeq> +id_prodsum_map expand_assignments(const StmtSeq& stmts) { + using namespace alg; + id_prodsum_map exmap; + + // This is 'just a test', so don't try to be complete: functions are + // left unexpanded; procedure calls are ignored. + + for (const auto& stmt: stmts) { + if (auto assign = stmt->is_assignment()) { + auto lhs = assign->lhs(); + std::string key; + if (auto deriv = lhs->is_derivative()) { + key = deriv->spelling()+"'"; + } + else if (auto id = lhs->is_identifier()) { + key = id->spelling(); + } + else { + // don't know what we have here! skip. + continue; + } + + exmap[key] = expand_expression(assign->rhs(), exmap); + } + } + return exmap; +} diff --git a/tests/modcc/test_kinetic_rewriter.cpp b/tests/modcc/test_kinetic_rewriter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..53a32d919ff654fe700b63347991016acd974743 --- /dev/null +++ b/tests/modcc/test_kinetic_rewriter.cpp @@ -0,0 +1,98 @@ +#include <iostream> +#include <string> + +#include "expression.hpp" +#include "kinrewriter.hpp" +#include "parser.hpp" + +#include "alg_collect.hpp" +#include "expr_expand.hpp" +#include "test.hpp" + + +stmt_list_type& proc_statements(Expression *e) { + if (!e || !e->is_symbol() || ! e->is_symbol()->is_procedure()) { + throw std::runtime_error("not a procedure"); + } + + return e->is_symbol()->is_procedure()->body()->statements(); +} + + +inline symbol_ptr state_var(const char* name) { + auto v = make_symbol<VariableExpression>(Location(), name); + v->is_variable()->state(true); + return v; +} + +inline symbol_ptr assigned_var(const char* name) { + return make_symbol<VariableExpression>(Location(), name); +} + +static const char* kinetic_abc = + "KINETIC kin { \n" + " u = 3 \n" + " ~ a <-> b (u, v) \n" + " u = 4 \n" + " v = sin(u) \n" + " ~ b <-> 3b + c (u, v) \n" + "} \n"; + +static const char* derivative_abc = + "DERIVATIVE deriv { \n" + " a' = -3*a + b*v \n" + " LOCAL rev2 \n" + " rev2 = c*b^3*sin(4) \n" + " b' = 3*a - (v*b) + 8*b - 2*rev2\n" + " c' = 4*b - rev2 \n" + "} \n"; + +TEST(KineticRewriter, equiv) { + auto visitor = make_unique<KineticRewriter>(); + auto kin = Parser(kinetic_abc).parse_procedure(); + auto deriv = Parser(derivative_abc).parse_procedure(); + + ASSERT_NE(nullptr, kin); + ASSERT_NE(nullptr, deriv); + ASSERT_TRUE(kin->is_symbol() && kin->is_symbol()->is_procedure()); + ASSERT_TRUE(deriv->is_symbol() && deriv->is_symbol()->is_procedure()); + + auto kin_weak = kin->is_symbol()->is_procedure(); + scope_type::symbol_map globals; + globals["kin"] = std::move(kin); + globals["a"] = state_var("a"); + globals["b"] = state_var("b"); + globals["c"] = state_var("c"); + globals["u"] = assigned_var("u"); + globals["v"] = assigned_var("v"); + + kin_weak->semantic(globals); + kin_weak->accept(visitor.get()); + + auto kin_deriv = visitor->as_procedure(); + + if (g_verbose_flag) { + std::cout << "derivative procedure:\n" << deriv->to_string() << "\n"; + std::cout << "kin procedure:\n" << kin_weak->to_string() << "\n"; + std::cout << "rewritten kin procedure:\n" << kin_deriv->to_string() << "\n"; + } + + auto deriv_map = expand_assignments(proc_statements(deriv.get())); + auto kin_map = expand_assignments(proc_statements(kin_deriv.get())); + + if (g_verbose_flag) { + std::cout << "derivative assignments (canonical):\n"; + for (const auto&p: deriv_map) { + std::cout << p.first << ": " << p.second << "\n"; + } + std::cout << "rewritten kin assignments (canonical):\n"; + for (const auto&p: kin_map) { + std::cout << p.first << ": " << p.second << "\n"; + } + } + + EXPECT_EQ(deriv_map["a'"], kin_map["a'"]); + EXPECT_EQ(deriv_map["b'"], kin_map["b'"]); + EXPECT_EQ(deriv_map["c'"], kin_map["c'"]); +} + diff --git a/tests/modcc/test_parser.cpp b/tests/modcc/test_parser.cpp index 59d085dee710b14cd49ed17310320db92d285e49..8cab1b8bd9208c508068e9ca50eb8a278bd3e62b 100644 --- a/tests/modcc/test_parser.cpp +++ b/tests/modcc/test_parser.cpp @@ -527,6 +527,7 @@ TEST(Parser, parse_binop) { {"2+3*(-2)", 2.+(3*-2)}, {"2+3*(-+2)", 2.+(3*-+2)}, {"2/3*4", (2./3.)*4.}, + {"2 * 7 - 3 * 11 + 4 * 13", 2.*7.-3.*11.+4.*13.}, // right associative {"2^3^1.5", std::pow(2.,std::pow(3.,1.5))},