-
Also renumber operator precedence. Fixes #25
38a67608
symdiff.cpp 20.90 KiB
#include <cmath>
#include <map>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include "error.hpp"
#include "expression.hpp"
#include "symdiff.hpp"
#include "util.hpp"
#include "visitor.hpp"
// Visitor that determines whether an expression refers to any identifier
// (or derivative identifier) whose spelling is contained in a given set.
//
// The search short-circuits: once a match has been found, remaining
// subexpressions are not visited.
//
// Only a reference to `ids` is stored, so the set must outlive the visitor.
class FindIdentifierVisitor: public Visitor {
public:
    explicit FindIdentifierVisitor(const identifier_set& ids): ids_(ids) {}

    // Clear the match flag so the visitor can be reused on another expression.
    void reset() { found_ = false; }

    // True if a matching identifier has been seen since construction/reset.
    bool found() const { return found_; }

    // Expression kinds without a specific overload below are treated as
    // containing no identifiers.
    void visit(Expression* e) override {}

    void visit(UnaryExpression* e) override {
        if (!found()) e->expression()->accept(this);
    }

    void visit(BinaryExpression* e) override {
        if (!found()) e->lhs()->accept(this);
        if (!found()) e->rhs()->accept(this);
    }

    void visit(CallExpression* e) override {
        for (auto& expr: e->args()) {
            if (found()) return;
            expr->accept(this);
        }
    }

    void visit(PDiffExpression* e) override {
        if (!found()) e->arg()->accept(this);
    }

    void visit(IdentifierExpression* e) override {
        if (!found()) {
            found_ |= is_in(e->spelling(), ids_);
        }
    }

    void visit(DerivativeExpression* e) override {
        if (!found()) {
            found_ |= is_in(e->spelling(), ids_);
        }
    }

    void visit(ReactionExpression* e) override {
        if (!found()) e->lhs()->accept(this);
        if (!found()) e->rhs()->accept(this);
        if (!found()) e->fwd_rate()->accept(this);
        if (!found()) e->rev_rate()->accept(this);
    }

    void visit(StoichTermExpression* e) override {
        if (!found()) e->ident()->accept(this);
    }

    void visit(StoichExpression* e) override {
        for (auto& expr: e->terms()) {
            if (found()) return;
            expr->accept(this);
        }
    }

    void visit(BlockExpression* e) override {
        for (auto& expr: e->statements()) {
            if (found()) return;
            expr->accept(this);
        }
    }

    void visit(IfExpression* e) override {
        if (!found()) e->condition()->accept(this);
        if (!found()) e->true_branch()->accept(this);
        // The else branch is optional (an `if` without `else` has a null
        // false_branch), so guard before dereferencing it.
        if (!found() && e->false_branch()) e->false_branch()->accept(this);
    }

private:
    const identifier_set& ids_;
    bool found_ = false;
};
// Return true if expression `e` refers to any identifier (or derivative
// identifier) whose spelling appears in `ids`.
bool involves_identifier(Expression* e, const identifier_set& ids) {
    FindIdentifierVisitor finder(ids);
    e->accept(&finder);
    const bool hit = finder.found();
    return hit;
}
// Single-identifier convenience overload: true if `e` refers to `id`.
// The temporary set lives until the end of the full expression, which covers
// the whole search performed by the set-based overload.
bool involves_identifier(Expression* e, const std::string& id) {
    return involves_identifier(e, identifier_set{id});
}
class SymPDiffVisitor: public Visitor, public error_stack {
public:
explicit SymPDiffVisitor(std::string id): id_(std::move(id)) {}
void reset() { result_ = nullptr; }
// Note: moves result, forces reset.
expression_ptr result() {
auto r = std::move(result_);
reset();
return r;
}
void visit(Expression* e) override {
error({"symbolic differential of improper expression", e->location()});
}
void visit(UnaryExpression* e) override {
error({"symbolic differential of unrecognized unary expression", e->location()});
}
void visit(BinaryExpression* e) override {
error({"symbolic differential of unrecognized binary expression", e->location()});
}
void visit(NegUnaryExpression* e) override {
auto loc = e->location();
e->expression()->accept(this);
result_ = make_expression<NegUnaryExpression>(loc, result());
}
void visit(ExpUnaryExpression* e) override {
auto loc = e->location();
e->expression()->accept(this);
result_ = make_expression<MulBinaryExpression>(loc, result(), e->clone());
}
void visit(LogUnaryExpression* e) override {
auto loc = e->location();
e->expression()->accept(this);
result_ = make_expression<DivBinaryExpression>(loc, result(), e->expression()->clone());
}
void visit(CosUnaryExpression* e) override {
auto loc = e->location();
e->expression()->accept(this);
result_ = make_expression<MulBinaryExpression>(loc,
make_expression<NegUnaryExpression>(loc,
make_expression<SinUnaryExpression>(loc, e->expression()->clone())),
result());
}
void visit(SinUnaryExpression* e) override {
auto loc = e->location();
e->expression()->accept(this);
result_ = make_expression<MulBinaryExpression>(loc,
make_expression<CosUnaryExpression>(loc, e->expression()->clone()),
result());
}
void visit(AddBinaryExpression* e) override {
auto loc = e->location();
e->lhs()->accept(this);
expression_ptr dlhs = std::move(result_);
e->rhs()->accept(this);
result_ = make_expression<AddBinaryExpression>(loc, std::move(dlhs), result());
}
void visit(SubBinaryExpression* e) override {
auto loc = e->location();
e->lhs()->accept(this);
expression_ptr dlhs = std::move(result_);
e->rhs()->accept(this);
result_ = make_expression<SubBinaryExpression>(loc, move(dlhs), result());
}
void visit(MulBinaryExpression* e) override {
auto loc = e->location();
e->lhs()->accept(this);
expression_ptr dlhs = std::move(result_);
e->rhs()->accept(this);
expression_ptr drhs = std::move(result_);
result_ = make_expression<AddBinaryExpression>(loc,
make_expression<MulBinaryExpression>(loc, e->lhs()->clone(), std::move(drhs)),
make_expression<MulBinaryExpression>(loc, std::move(dlhs), e->rhs()->clone()));
}
void visit(DivBinaryExpression* e) override {
auto loc = e->location();
e->lhs()->accept(this);
expression_ptr dlhs = std::move(result_);
e->rhs()->accept(this);
expression_ptr drhs = std::move(result_);
result_ = make_expression<SubBinaryExpression>(loc,
make_expression<DivBinaryExpression>(loc, std::move(dlhs), e->rhs()->clone()),
make_expression<MulBinaryExpression>(loc,
make_expression<DivBinaryExpression>(loc,
e->lhs()->clone(),
make_expression<MulBinaryExpression>(loc, e->rhs()->clone(), e->rhs()->clone())),
std::move(drhs)));
}
void visit(PowBinaryExpression* e) override {
auto loc = e->location();
e->lhs()->accept(this);
expression_ptr dlhs = std::move(result_);
e->rhs()->accept(this);
expression_ptr drhs = std::move(result_);
result_ = make_expression<AddBinaryExpression>(loc,
make_expression<MulBinaryExpression>(loc,
std::move(drhs),
make_expression<MulBinaryExpression>(loc,
make_expression<LogUnaryExpression>(loc, e->lhs()->clone()),
make_expression<PowBinaryExpression>(loc, e->lhs()->clone(), e->rhs()->clone()))),
make_expression<MulBinaryExpression>(loc,
e->rhs()->clone(),
make_expression<MulBinaryExpression>(loc,
make_expression<PowBinaryExpression>(loc,
e->lhs()->clone(),
make_expression<SubBinaryExpression>(loc,
e->rhs()->clone(),
make_expression<IntegerExpression>(loc, 1))),
std::move(dlhs))));
}
void visit(CallExpression* e) override {
auto loc = e->location();
result_ = make_expression<PDiffExpression>(loc,
make_expression<IdentifierExpression>(loc, id_),
e->clone());
}
void visit(PDiffExpression* e) override {
auto loc = e->location();
e->arg()->accept(this);
result_ = make_expression<PDiffExpression>(loc, e->var()->clone(), result());
}
void visit(IdentifierExpression* e) override {
auto loc = e->location();
result_ = make_expression<IntegerExpression>(loc, e->spelling()==id_);
}
void visit(NumberExpression* e) override {
auto loc = e->location();
result_ = make_expression<IntegerExpression>(loc, 0);
}
private:
expression_ptr result_;
std::string id_;
};
// Numeric value of `e` when it is a number expression; NaN otherwise,
// including when `e` is null. Because NaN compares unequal to every value,
// callers can write checks such as `expr_value(x)==0` without first testing
// whether `x` is a number.
long double expr_value(Expression* e) {
    if (!e) return NAN;
    auto num = e->is_number();
    return num? num->value(): NAN;
}
// Visitor performing constant folding and simple algebraic-identity
// simplification (x*1, x*0, x+0, x-0, x/1, x^0, x^1, ...) on an expression
// tree. A new tree is built; the visited expression is not modified.
//
// The simplified expression is retrieved, with ownership, via result().
// result() may be null: an `if` whose condition folds to false and which has
// no else branch simplifies to nothing.
class ConstantSimplifyVisitor: public Visitor {
private:
    expression_ptr result_;

    static bool is_number(Expression* e) { return e && e->is_number(); }
    static bool is_number(const expression_ptr& e) { return is_number(e.get()); }

    // Set the current result to a number expression with value `v`.
    void as_number(Location loc, long double v) {
        result_ = make_expression<NumberExpression>(loc, v);
    }

public:
    using Visitor::visit;

    ConstantSimplifyVisitor() {}

    // Note: moves result, forces reset.
    expression_ptr result() {
        auto r = std::move(result_);
        reset();
        return r;
    }

    void reset() {
        result_ = nullptr;
    }

    long double value() const { return expr_value(result_); }
    bool is_number() const { return is_number(result_); }

    // Expression kinds without a specific rule are copied unchanged.
    void visit(Expression* e) override {
        result_ = e->clone();
    }

    void visit(BlockExpression* e) override {
        auto block = e->clone();
        auto& stmts = block->is_block()->statements();
        stmts.clear();
        for (auto& stmt: e->statements()) {
            stmt->accept(this);
            auto simpl = result();
            // An `if` with a constant-false condition and no else branch
            // simplifies to nothing; drop it instead of dereferencing null.
            if (!simpl) continue;

            // Flatten any naked blocks generated by if/else simplification.
            if (auto inner = simpl->is_block()) {
                for (auto& inner_stmt: inner->statements()) {
                    stmts.push_back(std::move(inner_stmt));
                }
            }
            else {
                stmts.push_back(std::move(simpl));
            }
        }
        result_ = std::move(block);
    }

    void visit(IfExpression* e) override {
        auto loc = e->location();
        e->condition()->accept(this);
        auto cond_expr = result();
        e->true_branch()->accept(this);
        auto true_expr = result();
        expression_ptr false_expr;
        if (e->false_branch()) {
            e->false_branch()->accept(this);
            // result() already transfers ownership, so no clone is needed;
            // cloning would also dereference a null result produced by a
            // nested `else if` that folded away.
            false_expr = result();
        }

        if (!is_number(cond_expr)) {
            result_ = make_expression<IfExpression>(loc,
                std::move(cond_expr), std::move(true_expr), std::move(false_expr));
        }
        else if (expr_value(cond_expr)) {
            result_ = std::move(true_expr);
        }
        else {
            result_ = std::move(false_expr);
        }
    }

    // TODO: procedure, function expressions

    // Fold unary operators applied to a numeric constant; otherwise rebuild
    // the unary expression around the simplified operand.
    void visit(UnaryExpression* e) override {
        e->expression()->accept(this);
        if (is_number()) {
            auto loc = e->location();
            auto val = value();
            switch (e->op()) {
            case tok::minus:
                as_number(loc, -val);
                return;
            case tok::exp:
                as_number(loc, std::exp(val));
                return;
            case tok::sin:
                as_number(loc, std::sin(val));
                return;
            case tok::cos:
                as_number(loc, std::cos(val));
                return;
            case tok::log:
                as_number(loc, std::log(val));
                return;
            default: ; // treat opaquely as below
            }
        }
        expression_ptr arg = result();
        result_ = e->clone();
        result_->is_unary()->replace_expression(std::move(arg));
    }

    // Binary expressions without a specific rule are copied unchanged.
    void visit(BinaryExpression* e) override {
        result_ = e->clone();
    }

    void visit(AssignmentExpression* e) override {
        auto loc = e->location();
        e->rhs()->accept(this);
        result_ = make_expression<AssignmentExpression>(loc, e->lhs()->clone(), result());
    }

    void visit(MulBinaryExpression* e) override {
        auto loc = e->location();
        e->lhs()->accept(this);
        expression_ptr lhs = result();
        e->rhs()->accept(this);
        expression_ptr rhs = result();

        if (is_number(lhs) && is_number(rhs)) {
            as_number(loc, expr_value(lhs)*expr_value(rhs));
        }
        else if (expr_value(lhs)==0 || expr_value(rhs)==0) {
            as_number(loc, 0);
        }
        else if (expr_value(lhs)==1) {
            result_ = std::move(rhs);
        }
        else if (expr_value(rhs)==1) {
            result_ = std::move(lhs);
        }
        else {
            result_ = make_expression<MulBinaryExpression>(loc, std::move(lhs), std::move(rhs));
        }
    }

    void visit(DivBinaryExpression* e) override {
        auto loc = e->location();
        e->lhs()->accept(this);
        expression_ptr lhs = result();
        e->rhs()->accept(this);
        expression_ptr rhs = result();

        if (is_number(lhs) && is_number(rhs)) {
            as_number(loc, expr_value(lhs)/expr_value(rhs));
        }
        else if (expr_value(lhs)==0) {
            as_number(loc, 0);
        }
        else if (expr_value(rhs)==1) {
            // x/1 -> x, using the *simplified* numerator (as the other
            // binary rules do), not a clone of the original.
            result_ = std::move(lhs);
        }
        else {
            result_ = make_expression<DivBinaryExpression>(loc, std::move(lhs), std::move(rhs));
        }
    }

    void visit(AddBinaryExpression* e) override {
        auto loc = e->location();
        e->lhs()->accept(this);
        expression_ptr lhs = result();
        e->rhs()->accept(this);
        expression_ptr rhs = result();

        if (is_number(lhs) && is_number(rhs)) {
            as_number(loc, expr_value(lhs)+expr_value(rhs));
        }
        else if (expr_value(lhs)==0) {
            result_ = std::move(rhs);
        }
        else if (expr_value(rhs)==0) {
            result_ = std::move(lhs);
        }
        else {
            result_ = make_expression<AddBinaryExpression>(loc, std::move(lhs), std::move(rhs));
        }
    }

    void visit(SubBinaryExpression* e) override {
        auto loc = e->location();
        e->lhs()->accept(this);
        expression_ptr lhs = result();
        e->rhs()->accept(this);
        expression_ptr rhs = result();

        if (is_number(lhs) && is_number(rhs)) {
            as_number(loc, expr_value(lhs)-expr_value(rhs));
        }
        else if (expr_value(lhs)==0) {
            // 0 - x -> -x. `rhs` is already simplified and, since the
            // both-numeric case was handled above, not a number, so build
            // the negation directly rather than re-simplifying the subtree.
            result_ = make_expression<NegUnaryExpression>(loc, std::move(rhs));
        }
        else if (expr_value(rhs)==0) {
            result_ = std::move(lhs);
        }
        else {
            result_ = make_expression<SubBinaryExpression>(loc, std::move(lhs), std::move(rhs));
        }
    }

    void visit(PowBinaryExpression* e) override {
        auto loc = e->location();
        e->lhs()->accept(this);
        expression_ptr lhs = result();
        e->rhs()->accept(this);
        expression_ptr rhs = result();

        if (is_number(lhs) && is_number(rhs)) {
            as_number(loc, std::pow(expr_value(lhs),expr_value(rhs)));
        }
        else if (expr_value(lhs)==0) {
            as_number(loc, 0);
        }
        else if (expr_value(rhs)==0 || expr_value(lhs)==1) {
            as_number(loc, 1);
        }
        else if (expr_value(rhs)==1) {
            result_ = std::move(lhs);
        }
        else {
            result_ = make_expression<PowBinaryExpression>(loc, std::move(lhs), std::move(rhs));
        }
    }

    void visit(ConditionalExpression* e) override {
        auto loc = e->location();
        e->lhs()->accept(this);
        expression_ptr lhs = result();
        e->rhs()->accept(this);
        expression_ptr rhs = result();

        if (is_number(lhs) && is_number(rhs)) {
            auto lval = expr_value(lhs);
            auto rval = expr_value(rhs);
            // Every handled case must return: falling through would
            // overwrite the folded value with that of a later operator.
            switch (e->op()) {
            case tok::equality:
                as_number(loc, lval==rval);
                return;
            case tok::ne:
                as_number(loc, lval!=rval);
                return;
            case tok::lt:
                as_number(loc, lval<rval);
                return;
            case tok::gt:
                as_number(loc, lval>rval);
                return;
            case tok::lte:
                as_number(loc, lval<=rval);
                return;
            case tok::gte:
                as_number(loc, lval>=rval);
                return;
            case tok::land:
                as_number(loc, lval&&rval);
                return;
            case tok::lor:
                as_number(loc, lval||rval);
                return;
            default: ;
                // unrecognized, fall through to non-numeric case below
            }
        }

        // Non-constant operands, or an operator that cannot be folded:
        // rebuild the conditional from the simplified operands. (Previously
        // numeric operands with an unrecognized operator left no result.)
        result_ = make_expression<ConditionalExpression>(loc, e->op(), std::move(lhs), std::move(rhs));
    }
};
// Return a constant-folded, simplified copy of `e`; `e` itself is not
// modified.
expression_ptr constant_simplify(Expression* e) {
    ConstantSimplifyVisitor simplifier;
    e->accept(&simplifier);
    auto simplified = simplifier.result();
    return simplified;
}
// Symbolic partial derivative of `e` with respect to identifier `id`,
// passed through constant_simplify. If `e` does not mention `id` at all, the
// number 0 is returned directly.
//
// Throws compiler_exception (carrying the first recorded error's message and
// location) if `e` contains an expression kind that cannot be differentiated.
expression_ptr symbolic_pdiff(Expression* e, const std::string& id) {
    if (!involves_identifier(e, id)) {
        return make_expression<NumberExpression>(e->location(), 0);
    }

    SymPDiffVisitor pdiff_visitor(id);
    e->accept(&pdiff_visitor);

    // SymPDiffVisitor records failures on its error_stack instead of
    // throwing, and produces no result for the offending subexpression.
    // Continuing would hand a partial tree to constant_simplify, so surface
    // the error here rather than silently discarding it.
    if (pdiff_visitor.has_error()) {
        const auto& first = pdiff_visitor.errors().front();
        throw compiler_exception(first.message, first.location);
    }

    return constant_simplify(pdiff_visitor.result());
}
// Substitute all occurrences of the identifiers named in a substitution map
// within a unary, binary, call, partial-derivative or (trivially) number
// expression with copies of the corresponding substitute expressions.
class SubstituteVisitor: public Visitor {
public:
explicit SubstituteVisitor(const substitute_map& sub):
sub_(sub) {}
expression_ptr result() {
auto r = std::move(result_);
reset();
return r;
}
void reset() {
result_ = nullptr;
}
void visit(Expression* e) override {
throw compiler_exception("substitution attempt on improper expression", e->location());
}
void visit(NumberExpression* e) override {
result_ = e->clone();
}
void visit(IdentifierExpression* e) override {
result_ = is_in(e->spelling(), sub_)? sub_.at(e->spelling())->clone(): e->clone();
}
void visit(UnaryExpression* e) override {
e->expression()->accept(this);
auto arg = result();
result_ = e->clone();
result_->is_unary()->replace_expression(std::move(arg));
}
void visit(BinaryExpression* e) override {
e->lhs()->accept(this);
auto lhs = result();
e->rhs()->accept(this);
auto rhs = result();
result_ = e->clone();
result_->is_binary()->replace_lhs(std::move(lhs));
result_->is_binary()->replace_rhs(std::move(rhs));
}
void visit(CallExpression* e) override {
auto newexpr = e->clone();
for (auto& arg: newexpr->is_call()->args()) {
arg->accept(this);
arg = result();
}
result_ = std::move(newexpr);
}
void visit(PDiffExpression* e) override {
// Doing the correct thing when the derivative variable is the
// substitution variable would require another 'opaque' expression,
// e.g. `SubstitutionExpression`, but we're not about to do that yet,
// so throw an exception instead, i.e. Don't Do That.
if (is_in(e->var()->is_identifier()->spelling(), sub_)) {
throw compiler_exception("attempt to substitute value for derivative variable", e->location());
}
e->arg()->accept(this);
result_ = make_expression<PDiffExpression>(e->location(), e->var()->clone(), result());
}
private:
expression_ptr result_;
const substitute_map& sub_;
};
// Return a copy of `e` in which every occurrence of identifier `id` is
// replaced by a copy of `sub`. Neither `e` nor `sub` is modified.
expression_ptr substitute(Expression* e, const std::string& id, Expression* sub) {
    substitute_map mapping;
    mapping[id] = sub->clone();

    SubstituteVisitor visitor(mapping);
    e->accept(&visitor);
    auto substituted = visitor.result();
    return substituted;
}
// Return a copy of `e` in which each identifier that is a key of `sub` is
// replaced by a copy of the mapped expression.
expression_ptr substitute(Expression* e, const substitute_map& sub) {
    SubstituteVisitor visitor(sub);
    e->accept(&visitor);
    auto substituted = visitor.result();
    return substituted;
}
// Decompose `e` with respect to the variables `vars` as
//     constant + sum over v of coef[v] * v
// where coef[v] is the partial derivative of `e` wrt v (zero coefficients are
// omitted) and `constant` is `e` with every variable replaced by 0,
// simplified.
//
// `is_linear` is set when every second-order partial derivative wrt `vars`
// is zero; `is_homogeneous` is only assigned for linear expressions and is
// true when the constant term is zero.
linear_test_result linear_test(Expression* e, const std::vector<std::string>& vars) {
    linear_test_result result;
    auto loc = e->location();
    auto zero = [loc]() { return make_expression<IntegerExpression>(loc, 0); };

    result.constant = e->clone();
    for (const auto& id: vars) {
        auto coef = symbolic_pdiff(e, id);
        if (!is_zero(coef)) result.coef[id] = std::move(coef);
        result.constant = substitute(result.constant, id, zero());
    }
    result.constant = constant_simplify(result.constant.get());

    // Mixed partials are symmetric, so only pairs (i, j) with j >= i are
    // checked; a variable with no stored coefficient has a zero first
    // derivative and therefore zero second derivatives.
    auto second_derivatives_vanish = [&]() {
        for (unsigned i = 0; i<vars.size(); ++i) {
            const auto& v1 = vars[i];
            if (!is_in(v1, result.coef)) continue;

            for (unsigned j = i; j<vars.size(); ++j) {
                const auto& v2 = vars[j];
                if (!is_zero(symbolic_pdiff(result.coef[v1].get(), v2).get())) {
                    return false;
                }
            }
        }
        return true;
    };

    result.is_linear = second_derivatives_vanish();
    if (result.is_linear) {
        result.is_homogeneous = is_zero(result.constant);
    }
    return result;
}