Skip to content
Snippets Groups Projects
Select Git revision
  • 37d039759ef32ed15fca7b5c84f1ca34d6064fa1
  • master default protected
  • tut_ring_allen
  • docs_furo
  • docs_reorder_cable_cell
  • docs_graphviz
  • docs_rtd_dev
  • ebrains_mirror
  • doc_recat
  • docs_spike_source
  • docs_sim_sample_clar
  • docs_pip_warn
  • github_template_updates
  • docs_fix_link
  • cv_default_and_doc_clarification
  • docs_add_numpy_req
  • readme_zenodo_05
  • install_python_fix
  • install_require_numpy
  • typofix_propetries
  • docs_recipe_lookup
  • v0.10.0
  • v0.10.1
  • v0.10.0-rc5
  • v0.10.0-rc4
  • v0.10.0-rc3
  • v0.10.0-rc2
  • v0.10.0-rc
  • v0.9.0
  • v0.9.0-rc
  • v0.8.1
  • v0.8
  • v0.8-rc
  • v0.7
  • v0.6
  • v0.5.2
  • v0.5.1
  • v0.5
  • v0.4
  • v0.3
  • v0.2.2
41 results

solvers.hpp

Blame
  • user avatar
    Nora Abi Akar authored and Sam Yates committed
    * Modify `SparseSolverVisitor` to allow solving kinetic equations at steady state. 
    
    Addresses #837 
    4bd6f097
    History
    solvers.hpp 4.57 KiB
    #pragma once
    
    // Transform derivative block into AST representing
    // an integration step over the state variables, based on
    // solver method.
    
    #include <string>
    #include <utility>
    #include <vector>
    
    #include "expression.hpp"
    #include "symdiff.hpp"
    #include "symge.hpp"
    #include "visitor.hpp"
    
    expression_ptr remove_unused_locals(BlockExpression* block);
    
    class SolverVisitorBase: public BlockRewriterBase {
    protected:
        // list of identifier names appearing in derivatives on lhs
        std::vector<std::string> dvars_;
    
    public:
        using BlockRewriterBase::visit;
    
        SolverVisitorBase() {}
        SolverVisitorBase(scope_ptr enclosing): BlockRewriterBase(enclosing) {}
    
        virtual std::vector<std::string> solved_identifiers() const {
            return dvars_;
        }
    
        virtual void reset() override {
            dvars_.clear();
            BlockRewriterBase::reset();
        }
    };
    
    class DirectSolverVisitor : public SolverVisitorBase {
    public:
        using SolverVisitorBase::visit;
    
        DirectSolverVisitor() {}
        DirectSolverVisitor(scope_ptr enclosing): SolverVisitorBase(enclosing) {}
    
        virtual void visit(AssignmentExpression *e) override {
            // No solver method, so declare an error if lhs is a derivative.
            if(auto deriv = e->lhs()->is_derivative()) {
                error({"The DERIVATIVE block has a derivative expression"
                       " but no METHOD was specified in the SOLVE statement",
                       deriv->location()});
            }
            else {
                visit((Expression*)e);
            }
        }
    };
    
    class CnexpSolverVisitor : public SolverVisitorBase {
    public:
        using SolverVisitorBase::visit;
    
        CnexpSolverVisitor() {}
        CnexpSolverVisitor(scope_ptr enclosing): SolverVisitorBase(enclosing) {}
    
        virtual void visit(BlockExpression* e) override;
        virtual void visit(AssignmentExpression *e) override;
    };
    
    class SparseSolverVisitor : public SolverVisitorBase {
    protected:
        solverVariant solve_variant_;
        // 'Current' differential equation is for variable with this
        // index in `dvars`.
        unsigned deq_index_ = 0;
    
        // Expanded local assignments that need to be substituted in for derivative
        // calculations.
        substitute_map local_expr_;
    
        // Symbolic matrix for backwards Euler step.
        symge::sym_matrix A_;
    
        // 'Symbol table' for symbolic manipulation.
        symge::symbol_table symtbl_;
    
        // Flag to indicate whether conserve statements are part of the system
        bool conserve_ = false;
    
        // state variable multiplier/divider
        std::vector<expression_ptr> scale_factor_;
    
        // rhs of conserve statement
        std::vector<std::string> conserve_rhs_;
        std::vector<unsigned> conserve_idx_;
    
        // rhs of steadstate
        std::string steadystate_rhs_;
    public:
        using SolverVisitorBase::visit;
    
        explicit SparseSolverVisitor(solverVariant s = solverVariant::regular) :
            solve_variant_(s) {}
        SparseSolverVisitor(scope_ptr enclosing): SolverVisitorBase(enclosing) {}
    
        virtual void visit(BlockExpression* e) override;
        virtual void visit(AssignmentExpression *e) override;
        virtual void visit(CompartmentExpression *e) override;
        virtual void visit(ConserveExpression *e) override;
        virtual void finalize() override;
        virtual void reset() override {
            deq_index_ = 0;
            local_expr_.clear();
            A_.clear();
            symtbl_.clear();
            conserve_ = false;
            scale_factor_.clear();
            conserve_rhs_.clear();
            conserve_idx_.clear();
            steadystate_rhs_.clear();
            SolverVisitorBase::reset();
        }
    };
    
    class LinearSolverVisitor : public SolverVisitorBase {
    protected:
        // 'Current' differential equation is for variable with this
        // index in `dvars`.
        unsigned deq_index_ = 0;
    
        // Expanded local assignments that need to be substituted in for derivative
        // calculations.
        substitute_map local_expr_;
    
        // Symbolic matrix for backwards Euler step.
        symge::sym_matrix A_;
    
        // RHS
        std::vector<symge::symbol> rhs_;
    
        // 'Symbol table' for symbolic manipulation.
        symge::symbol_table symtbl_;
    
    public:
        using SolverVisitorBase::visit;
    
        LinearSolverVisitor(std::vector<std::string> vars) {
            dvars_ = vars;
        }
        LinearSolverVisitor(scope_ptr enclosing): SolverVisitorBase(enclosing) {}
    
        virtual void visit(BlockExpression* e) override;
        virtual void visit(LinearExpression *e) override;
        virtual void visit(AssignmentExpression *e) override;
        virtual void finalize() override;
        virtual void reset() override {
            deq_index_ = 0;
            local_expr_.clear();
            A_.clear();
            rhs_.clear();
            symtbl_.clear();
            SolverVisitorBase::reset();
        }
    };