Skip to content
Snippets Groups Projects
Select Git revision
  • 0549d2eb66cf4badbdfcb3e57311229a1f60b94e
  • master default protected
  • tut_ring_allen
  • docs_furo
  • docs_reorder_cable_cell
  • docs_graphviz
  • docs_rtd_dev
  • ebrains_mirror
  • doc_recat
  • docs_spike_source
  • docs_sim_sample_clar
  • docs_pip_warn
  • github_template_updates
  • docs_fix_link
  • cv_default_and_doc_clarification
  • docs_add_numpy_req
  • readme_zenodo_05
  • install_python_fix
  • install_require_numpy
  • typofix_propetries
  • docs_recipe_lookup
  • v0.10.0
  • v0.10.1
  • v0.10.0-rc5
  • v0.10.0-rc4
  • v0.10.0-rc3
  • v0.10.0-rc2
  • v0.10.0-rc
  • v0.9.0
  • v0.9.0-rc
  • v0.8.1
  • v0.8
  • v0.8-rc
  • v0.7
  • v0.6
  • v0.5.2
  • v0.5.1
  • v0.5
  • v0.4
  • v0.3
  • v0.2.2
41 results

symdiff.cpp

Blame
  • user avatar
    Nora Abi Akar authored and Ben Cumming committed
    Also renumber operator precedence.
    
    Fixes #25
    38a67608
    History
    symdiff.cpp 20.90 KiB
    #include <cmath>
    #include <map>
    #include <set>
    #include <stdexcept>
    #include <string>
    #include <utility>
    
    #include "error.hpp"
    #include "expression.hpp"
    #include "symdiff.hpp"
    #include "util.hpp"
    #include "visitor.hpp"
    
    // Visitor that records whether an expression tree mentions any identifier
    // from a given set.
    //
    // The search short-circuits: once a match has been recorded, the remaining
    // visit methods return without descending further. Call reset() before
    // reusing the same visitor on another expression.
    //
    // The identifier set is held by reference and must outlive the visitor.
    class FindIdentifierVisitor: public Visitor {
    public:
        explicit FindIdentifierVisitor(const identifier_set& ids): ids_(ids) {}

        // Clear the 'found' flag so that the visitor can be reused.
        void reset() { found_ = false; }

        // True if a visited identifier or derivative expression matched the set.
        bool found() const { return found_; }

        // Fallback overload: does nothing, so it never sets the flag.
        void visit(Expression* e) override {}

        void visit(UnaryExpression* e) override {
            if (!found()) e->expression()->accept(this);
        }

        void visit(BinaryExpression* e) override {
            if (!found()) e->lhs()->accept(this);
            if (!found()) e->rhs()->accept(this);
        }

        void visit(CallExpression* e) override {
            for (auto& expr: e->args()) {
                if (found()) return;
                expr->accept(this);
            }
        }

        // Only the differentiated argument is searched, not the derivative variable.
        void visit(PDiffExpression* e) override {
            if (!found()) e->arg()->accept(this);
        }

        void visit(IdentifierExpression* e) override {
            if (!found()) {
                found_ = is_in(e->spelling(), ids_);
            }
        }

        // Derivative expressions are matched on their spelling(), like identifiers.
        void visit(DerivativeExpression* e) override {
            if (!found()) {
                found_ = is_in(e->spelling(), ids_);
            }
        }

        void visit(ReactionExpression* e) override {
            if (!found()) e->lhs()->accept(this);
            if (!found()) e->rhs()->accept(this);
            if (!found()) e->fwd_rate()->accept(this);
            if (!found()) e->rev_rate()->accept(this);
        }

        void visit(StoichTermExpression* e) override {
            if (!found()) e->ident()->accept(this);
        }

        void visit(StoichExpression* e) override {
            for (auto& expr: e->terms()) {
                if (found()) return;
                expr->accept(this);
            }
        }

        void visit(BlockExpression* e) override {
            for (auto& expr: e->statements()) {
                if (found()) return;
                expr->accept(this);
            }
        }

        void visit(IfExpression* e) override {
            if (!found()) e->condition()->accept(this);
            if (!found()) e->true_branch()->accept(this);
            // An `if` without an `else` has a null false branch; the constant
            // simplifier in this file guards against that too, so only descend
            // when the branch exists.
            if (!found() && e->false_branch()) e->false_branch()->accept(this);
        }

    private:
        const identifier_set& ids_;
        bool found_ = false;
    };
    
    // Returns true if the expression tree rooted at `e` mentions at least one
    // of the identifiers in `ids`.
    bool involves_identifier(Expression* e, const identifier_set& ids) {
        FindIdentifierVisitor finder(ids);
        e->accept(&finder);

        const bool mentioned = finder.found();
        return mentioned;
    }
    
    // Single-identifier convenience overload: returns true if the expression
    // tree rooted at `e` mentions `id`.
    bool involves_identifier(Expression* e, const std::string& id) {
        // Wrap the name in a one-element set and reuse the set-based search.
        identifier_set just_id = {id};
        return involves_identifier(e, just_id);
    }
    
    class SymPDiffVisitor: public Visitor, public error_stack {
    public:
        explicit SymPDiffVisitor(std::string id): id_(std::move(id)) {}
    
        void reset() { result_ = nullptr; }
    
        // Note: moves result, forces reset.
        expression_ptr result() {
            auto r = std::move(result_);
            reset();
            return r;
        }
    
        void visit(Expression* e) override {
            error({"symbolic differential of improper expression", e->location()});
        }
    
        void visit(UnaryExpression* e) override {
            error({"symbolic differential of unrecognized unary expression", e->location()});
        }
    
        void visit(BinaryExpression* e) override {
            error({"symbolic differential of unrecognized binary expression", e->location()});
        }
    
        void visit(NegUnaryExpression* e) override {
            auto loc = e->location();
            e->expression()->accept(this);
            result_ = make_expression<NegUnaryExpression>(loc, result());
        }
    
        void visit(ExpUnaryExpression* e) override {
            auto loc = e->location();
            e->expression()->accept(this);
            result_ = make_expression<MulBinaryExpression>(loc, result(), e->clone());
        }
    
        void visit(LogUnaryExpression* e) override {
            auto loc = e->location();
            e->expression()->accept(this);
            result_ = make_expression<DivBinaryExpression>(loc, result(), e->expression()->clone());
        }
    
        void visit(CosUnaryExpression* e) override {
            auto loc = e->location();
            e->expression()->accept(this);
            result_ = make_expression<MulBinaryExpression>(loc,
                          make_expression<NegUnaryExpression>(loc,
                              make_expression<SinUnaryExpression>(loc, e->expression()->clone())),
                          result());
        }
    
        void visit(SinUnaryExpression* e) override {
            auto loc = e->location();
            e->expression()->accept(this);
            result_ = make_expression<MulBinaryExpression>(loc,
                            make_expression<CosUnaryExpression>(loc, e->expression()->clone()),
                            result());
        }
    
        void visit(AddBinaryExpression* e) override {
            auto loc = e->location();
            e->lhs()->accept(this);
            expression_ptr dlhs = std::move(result_);
    
            e->rhs()->accept(this);
            result_ = make_expression<AddBinaryExpression>(loc, std::move(dlhs), result());
        }
    
        void visit(SubBinaryExpression* e) override {
            auto loc = e->location();
            e->lhs()->accept(this);
            expression_ptr dlhs = std::move(result_);
    
            e->rhs()->accept(this);
            result_ = make_expression<SubBinaryExpression>(loc, move(dlhs), result());
        }
    
        void visit(MulBinaryExpression* e) override {
            auto loc = e->location();
            e->lhs()->accept(this);
            expression_ptr dlhs = std::move(result_);
    
            e->rhs()->accept(this);
            expression_ptr drhs = std::move(result_);
    
            result_ = make_expression<AddBinaryExpression>(loc,
                make_expression<MulBinaryExpression>(loc, e->lhs()->clone(), std::move(drhs)),
                make_expression<MulBinaryExpression>(loc, std::move(dlhs), e->rhs()->clone()));
    
        }
    
        void visit(DivBinaryExpression* e) override {
            auto loc = e->location();
            e->lhs()->accept(this);
            expression_ptr dlhs = std::move(result_);
    
            e->rhs()->accept(this);
            expression_ptr drhs = std::move(result_);
    
            result_ = make_expression<SubBinaryExpression>(loc,
                make_expression<DivBinaryExpression>(loc, std::move(dlhs), e->rhs()->clone()),
                make_expression<MulBinaryExpression>(loc,
                    make_expression<DivBinaryExpression>(loc,
                        e->lhs()->clone(),
                        make_expression<MulBinaryExpression>(loc, e->rhs()->clone(), e->rhs()->clone())),
                    std::move(drhs)));
        }
    
        void visit(PowBinaryExpression* e) override {
            auto loc = e->location();
            e->lhs()->accept(this);
            expression_ptr dlhs = std::move(result_);
    
            e->rhs()->accept(this);
            expression_ptr drhs = std::move(result_);
    
            result_ = make_expression<AddBinaryExpression>(loc,
                make_expression<MulBinaryExpression>(loc,
                    std::move(drhs),
                    make_expression<MulBinaryExpression>(loc,
                        make_expression<LogUnaryExpression>(loc, e->lhs()->clone()),
                        make_expression<PowBinaryExpression>(loc, e->lhs()->clone(), e->rhs()->clone()))),
                make_expression<MulBinaryExpression>(loc,
                    e->rhs()->clone(),
                    make_expression<MulBinaryExpression>(loc,
                        make_expression<PowBinaryExpression>(loc,
                            e->lhs()->clone(),
                            make_expression<SubBinaryExpression>(loc,
                                e->rhs()->clone(),
                                make_expression<IntegerExpression>(loc, 1))),
                        std::move(dlhs))));
        }
    
        void visit(CallExpression* e) override {
            auto loc = e->location();
            result_ = make_expression<PDiffExpression>(loc,
                        make_expression<IdentifierExpression>(loc, id_),
                        e->clone());
        }
    
        void visit(PDiffExpression* e) override {
            auto loc = e->location();
            e->arg()->accept(this);
            result_ = make_expression<PDiffExpression>(loc, e->var()->clone(), result());
        }
    
        void visit(IdentifierExpression* e) override {
            auto loc = e->location();
            result_ = make_expression<IntegerExpression>(loc, e->spelling()==id_);
        }
    
        void visit(NumberExpression* e) override {
            auto loc = e->location();
            result_ = make_expression<IntegerExpression>(loc, 0);
        }
    
    private:
        expression_ptr result_;
        std::string id_;
    };
    
    // Numeric value of `e` when it is a number expression; NaN otherwise,
    // including when `e` is null.
    long double expr_value(Expression* e) {
        if (e) {
            if (auto number = e->is_number()) {
                return number->value();
            }
        }
        return NAN;
    }
    
    // Visitor that builds a simplified copy of an expression by folding
    // operations on numeric literals and applying simple algebraic identities
    // (0*x, 1*x, x+0, x-0, x/1, x^1, ...).
    //
    // Each visit method leaves its simplified expression in result_; retrieve
    // it with result(), which also resets the visitor.
    //
    // expr_value() yields NaN for anything that is not a numeric literal, so
    // the `expr_value(x)==c` tests below are false for non-numeric operands.
    //
    // The result can be null in one case: an `if` with a constant-false
    // condition and no `else` branch simplifies to nothing.
    class ConstantSimplifyVisitor: public Visitor {
    private:
        expression_ptr result_;

        static bool is_number(Expression* e) { return e && e->is_number(); }
        static bool is_number(const expression_ptr& e) { return is_number(e.get()); }

        // Make a numeric literal with value `v` the current result.
        void as_number(Location loc, long double v) {
            result_ = make_expression<NumberExpression>(loc, v);
        }

    public:
        using Visitor::visit;

        ConstantSimplifyVisitor() {}

        // Note: moves result, forces reset.
        expression_ptr result() {
            auto r = std::move(result_);
            reset();
            return r;
        }

        void reset() {
            result_ = nullptr;
        }

        // Numeric value of the current result, or NaN if it is not a number.
        long double value() const { return expr_value(result_); }

        // True if the current result is a numeric literal.
        bool is_number() const { return is_number(result_); }

        // Fallback: copy the expression unchanged.
        void visit(Expression* e) override {
            result_ = e->clone();
        }

        // Simplify each statement in turn, collecting the results in a copy
        // of the block.
        void visit(BlockExpression* e) override {
            auto block_ = e->clone();
            block_->is_block()->statements().clear();

            for (auto& stmt: e->statements()) {
                stmt->accept(this);
                auto simpl = result();

                // A statement may simplify away entirely (an `if` with a
                // constant-false condition and no `else`): drop it.
                if (!simpl) continue;

                // flatten any naked blocks generated by if/else simplification
                if (auto inner = simpl->is_block()) {
                    for (auto& inner_stmt: inner->statements()) {
                        block_->is_block()->statements().push_back(std::move(inner_stmt));
                    }
                }
                else {
                    block_->is_block()->statements().push_back(std::move(simpl));
                }
            }
            result_ = std::move(block_);
        }

        // Simplify the condition and both branches. If the condition folds to
        // a number, the `if` is replaced by the selected branch, which is null
        // when the false branch is selected but absent.
        void visit(IfExpression* e) override {
            auto loc = e->location();
            e->condition()->accept(this);
            auto cond_expr = result();
            e->true_branch()->accept(this);
            auto true_expr = result();
            expression_ptr false_expr;
            if (e->false_branch()) {
                e->false_branch()->accept(this);
                // result() already hands over ownership; no clone is needed.
                false_expr = result();
            }

            if (!is_number(cond_expr)) {
                result_ = make_expression<IfExpression>(loc,
                            std::move(cond_expr), std::move(true_expr), std::move(false_expr));
            }
            else if (expr_value(cond_expr)) {
                result_ = std::move(true_expr);
            }
            else {
                result_ = std::move(false_expr);
            }
        }

        // TODO: procedure, function expressions

        // Fold -, exp, sin, cos and log applied to a numeric literal; any
        // other case yields a copy of `e` with its argument simplified.
        void visit(UnaryExpression* e) override {
            e->expression()->accept(this);

            if (is_number()) {
                auto loc = e->location();
                auto val = value();

                switch (e->op()) {
                case tok::minus:
                    as_number(loc, -val);
                    return;
                case tok::exp:
                    as_number(loc, std::exp(val));
                    return;
                case tok::sin:
                    as_number(loc, std::sin(val));
                    return;
                case tok::cos:
                    as_number(loc, std::cos(val));
                    return;
                case tok::log:
                    as_number(loc, std::log(val));
                    return;
                default: ; // treat opaquely as below
                }
            }

            expression_ptr arg = result();
            result_ = e->clone();
            result_->is_unary()->replace_expression(std::move(arg));
        }

        // Binary expressions without a dedicated overload are copied unchanged.
        void visit(BinaryExpression* e) override {
            result_ = e->clone();
        }

        // Only the right-hand side of an assignment is simplified.
        void visit(AssignmentExpression* e) override {
            auto loc = e->location();
            e->rhs()->accept(this);
            result_ = make_expression<AssignmentExpression>(loc, e->lhs()->clone(), result());
        }

        // a*b: fold constants; 0*x = x*0 = 0; 1*x = x; x*1 = x.
        void visit(MulBinaryExpression* e) override {
            auto loc = e->location();
            e->lhs()->accept(this);
            expression_ptr lhs = result();
            e->rhs()->accept(this);
            expression_ptr rhs = result();

            if (is_number(lhs) && is_number(rhs)) {
                as_number(loc, expr_value(lhs)*expr_value(rhs));
            }
            else if (expr_value(lhs)==0 || expr_value(rhs)==0) {
                as_number(loc, 0);
            }
            else if (expr_value(lhs)==1) {
                result_ = std::move(rhs);
            }
            else if (expr_value(rhs)==1) {
                result_ = std::move(lhs);
            }
            else {
                result_ = make_expression<MulBinaryExpression>(loc, std::move(lhs), std::move(rhs));
            }
        }

        // a/b: fold constants; 0/x = 0; x/1 = x.
        void visit(DivBinaryExpression* e) override {
            auto loc = e->location();
            e->lhs()->accept(this);
            expression_ptr lhs = result();
            e->rhs()->accept(this);
            expression_ptr rhs = result();

            if (is_number(lhs) && is_number(rhs)) {
                as_number(loc, expr_value(lhs)/expr_value(rhs));
            }
            else if (expr_value(lhs)==0) {
                as_number(loc, 0);
            }
            else if (expr_value(rhs)==1) {
                // Keep the simplified numerator, not a copy of the original.
                result_ = std::move(lhs);
            }
            else {
                result_ = make_expression<DivBinaryExpression>(loc, std::move(lhs), std::move(rhs));
            }
        }

        // a+b: fold constants; 0+x = x; x+0 = x.
        void visit(AddBinaryExpression* e) override {
            auto loc = e->location();
            e->lhs()->accept(this);
            expression_ptr lhs = result();
            e->rhs()->accept(this);
            expression_ptr rhs = result();

            if (is_number(lhs) && is_number(rhs)) {
                as_number(loc, expr_value(lhs)+expr_value(rhs));
            }
            else if (expr_value(lhs)==0) {
                result_ = std::move(rhs);
            }
            else if (expr_value(rhs)==0) {
                result_ = std::move(lhs);
            }
            else {
                result_ = make_expression<AddBinaryExpression>(loc, std::move(lhs), std::move(rhs));
            }
        }

        // a-b: fold constants; 0-x = -x (itself simplified); x-0 = x.
        void visit(SubBinaryExpression* e) override {
            auto loc = e->location();
            e->lhs()->accept(this);
            expression_ptr lhs = result();
            e->rhs()->accept(this);
            expression_ptr rhs = result();

            if (is_number(lhs) && is_number(rhs)) {
                as_number(loc, expr_value(lhs)-expr_value(rhs));
            }
            else if (expr_value(lhs)==0) {
                // Visit a temporary negation so that the unary simplification
                // produces the result.
                make_expression<NegUnaryExpression>(loc, std::move(rhs))->accept(this);
            }
            else if (expr_value(rhs)==0) {
                result_ = std::move(lhs);
            }
            else {
                result_ = make_expression<SubBinaryExpression>(loc, std::move(lhs), std::move(rhs));
            }
        }

        // a^b: fold constants; 0^x = 0; x^0 = 1; 1^x = 1; x^1 = x.
        void visit(PowBinaryExpression* e) override {
            auto loc = e->location();
            e->lhs()->accept(this);
            expression_ptr lhs = result();
            e->rhs()->accept(this);
            expression_ptr rhs = result();

            if (is_number(lhs) && is_number(rhs)) {
                as_number(loc, std::pow(expr_value(lhs),expr_value(rhs)));
            }
            else if (expr_value(lhs)==0) {
                as_number(loc, 0);
            }
            else if (expr_value(rhs)==0 || expr_value(lhs)==1) {
                as_number(loc, 1);
            }
            else if (expr_value(rhs)==1) {
                result_ = std::move(lhs);
            }
            else {
                result_ = make_expression<PowBinaryExpression>(loc, std::move(lhs), std::move(rhs));
            }
        }

        // Comparison and logical operators: fold when both operands are
        // numeric literals, otherwise rebuild with the simplified operands.
        void visit(ConditionalExpression* e) override {
            auto loc = e->location();
            e->lhs()->accept(this);
            expression_ptr lhs = result();
            e->rhs()->accept(this);
            expression_ptr rhs = result();

            if (is_number(lhs) && is_number(rhs)) {
                auto lval = expr_value(lhs);
                auto rval = expr_value(rhs);
                // Every recognized case returns explicitly, so that no case
                // falls through into the next operator.
                switch (e->op()) {
                case tok::equality:
                    as_number(loc, lval==rval);
                    return;
                case tok::ne:
                    as_number(loc, lval!=rval);
                    return;
                case tok::lt:
                    as_number(loc, lval<rval);
                    return;
                case tok::gt:
                    as_number(loc, lval>rval);
                    return;
                case tok::lte:
                    as_number(loc, lval<=rval);
                    return;
                case tok::gte:
                    as_number(loc, lval>=rval);
                    return;
                case tok::land:
                    as_number(loc, lval&&rval);
                    return;
                case tok::lor:
                    as_number(loc, lval||rval);
                    return;
                default:
                    // unrecognized operator: rebuild the expression below
                    break;
                }
            }

            // Non-numeric operands, or an unrecognized operator: keep the
            // operation with the simplified operands.
            result_ = make_expression<ConditionalExpression>(loc, e->op(), std::move(lhs), std::move(rhs));
        }
    };
    
    // Run a ConstantSimplifyVisitor over `e` and return the expression it
    // produces; `e` itself is only visited, not handed over.
    expression_ptr constant_simplify(Expression* e) {
        ConstantSimplifyVisitor simplifier;
        e->accept(&simplifier);

        expression_ptr simplified = simplifier.result();
        return simplified;
    }
    
    
    // Symbolic partial derivative of `e` with respect to the identifier `id`,
    // passed through constant_simplify().
    //
    // An expression that does not mention `id` yields a literal zero without
    // running the differentiation visitor.
    //
    // NOTE(review): errors recorded by the differentiation visitor are not
    // inspected here — confirm that callers only pass differentiable input.
    expression_ptr symbolic_pdiff(Expression* e, const std::string& id) {
        if (involves_identifier(e, id)) {
            SymPDiffVisitor differentiator(id);
            e->accept(&differentiator);

            return constant_simplify(differentiator.result());
        }

        return make_expression<NumberExpression>(e->location(), 0);
    }
    
    // Substitute all occurrences of an identifier within a unary, binary, call
    // or (trivially) number expression with a copy of the provided substitute
    // expression.
    
    class SubstituteVisitor: public Visitor {
    public:
        explicit SubstituteVisitor(const substitute_map& sub):
            sub_(sub) {}
    
        expression_ptr result() {
            auto r = std::move(result_);
            reset();
            return r;
        }
    
        void reset() {
            result_ = nullptr;
        }
    
        void visit(Expression* e) override {
            throw compiler_exception("substitution attempt on improper expression", e->location());
        }
    
        void visit(NumberExpression* e) override {
            result_ = e->clone();
        }
    
        void visit(IdentifierExpression* e) override {
            result_ = is_in(e->spelling(), sub_)? sub_.at(e->spelling())->clone(): e->clone();
        }
    
        void visit(UnaryExpression* e) override {
            e->expression()->accept(this);
            auto arg = result();
    
            result_ = e->clone();
            result_->is_unary()->replace_expression(std::move(arg));
        }
    
        void visit(BinaryExpression* e) override {
            e->lhs()->accept(this);
            auto lhs = result();
            e->rhs()->accept(this);
            auto rhs = result();
    
            result_ = e->clone();
            result_->is_binary()->replace_lhs(std::move(lhs));
            result_->is_binary()->replace_rhs(std::move(rhs));
        }
    
        void visit(CallExpression* e) override {
            auto newexpr = e->clone();
            for (auto& arg: newexpr->is_call()->args()) {
                arg->accept(this);
                arg = result();
            }
            result_ = std::move(newexpr);
        }
    
        void visit(PDiffExpression* e) override {
            // Doing the correct thing when the derivative variable is the
            // substitution variable would require another 'opaque' expression,
            // e.g. `SubstitutionExpression`, but we're not about to do that yet,
            // so throw an exception instead, i.e. Don't Do That.
            if (is_in(e->var()->is_identifier()->spelling(), sub_)) {
                throw compiler_exception("attempt to substitute value for derivative variable", e->location());
            }
    
            e->arg()->accept(this);
            result_ = make_expression<PDiffExpression>(e->location(), e->var()->clone(), result());
        }
    
    private:
        expression_ptr result_;
        const substitute_map& sub_;
    };
    
    // Returns a copy of `e` in which every occurrence of the identifier `id`
    // is replaced by a clone of `sub`.
    expression_ptr substitute(Expression* e, const std::string& id, Expression* sub) {
        // One-entry substitution map holding its own clone of `sub`.
        substitute_map single_entry;
        single_entry[id] = sub->clone();

        SubstituteVisitor replacer(single_entry);
        e->accept(&replacer);

        expression_ptr replaced = replacer.result();
        return replaced;
    }
    
    // Returns a copy of `e` in which every identifier present in `sub` is
    // replaced by a clone of its mapped expression.
    expression_ptr substitute(Expression* e, const substitute_map& sub) {
        SubstituteVisitor replacer(sub);
        e->accept(&replacer);

        expression_ptr replaced = replacer.result();
        return replaced;
    }
    
    // Tests whether `e` is linear in the variables `vars`.
    //
    // The returned linear_test_result holds:
    //   coef           non-zero first-order partial derivatives, keyed by variable;
    //   constant       `e` with every variable replaced by zero, then simplified;
    //   is_linear      true when every second-order partial derivative is zero;
    //   is_homogeneous set, only when linear, to whether `constant` is zero.
    linear_test_result linear_test(Expression* e, const std::vector<std::string>& vars) {
        linear_test_result outcome;
        auto loc = e->location();
        auto zero = [loc]() { return make_expression<IntegerExpression>(loc, 0); };

        // Collect the first-order coefficients while zeroing each variable
        // in the running constant term.
        outcome.constant = e->clone();
        for (const auto& var: vars) {
            auto derivative = symbolic_pdiff(e, var);
            if (!is_zero(derivative)) {
                outcome.coef[var] = std::move(derivative);
            }

            outcome.constant = substitute(outcome.constant, var, zero());
        }

        outcome.constant = constant_simplify(outcome.constant.get());

        // linearity test: take second order derivatives, test against zero.
        // Pairs (i, j) with j >= i suffice, and variables with no recorded
        // coefficient are skipped.
        auto second_derivatives_vanish = [&]() {
            const std::size_t n_vars = vars.size();
            for (std::size_t i = 0; i<n_vars; ++i) {
                const std::string& first = vars[i];
                if (!is_in(first, outcome.coef)) continue;

                for (std::size_t j = i; j<n_vars; ++j) {
                    const std::string& second = vars[j];

                    if (!is_zero(symbolic_pdiff(outcome.coef[first].get(), second).get())) {
                        return false;
                    }
                }
            }
            return true;
        };
        outcome.is_linear = second_derivatives_vanish();

        if (outcome.is_linear) {
            outcome.is_homogeneous = is_zero(outcome.constant);
        }

        return outcome;
    }