Skip to content
Snippets Groups Projects
Select Git revision
  • 37d039759ef32ed15fca7b5c84f1ca34d6064fa1
  • master default protected
  • tut_ring_allen
  • docs_furo
  • docs_reorder_cable_cell
  • docs_graphviz
  • docs_rtd_dev
  • ebrains_mirror
  • doc_recat
  • docs_spike_source
  • docs_sim_sample_clar
  • docs_pip_warn
  • github_template_updates
  • docs_fix_link
  • cv_default_and_doc_clarification
  • docs_add_numpy_req
  • readme_zenodo_05
  • install_python_fix
  • install_require_numpy
  • typofix_propetries
  • docs_recipe_lookup
  • v0.10.0
  • v0.10.1
  • v0.10.0-rc5
  • v0.10.0-rc4
  • v0.10.0-rc3
  • v0.10.0-rc2
  • v0.10.0-rc
  • v0.9.0
  • v0.9.0-rc
  • v0.8.1
  • v0.8
  • v0.8-rc
  • v0.7
  • v0.6
  • v0.5.2
  • v0.5.1
  • v0.5
  • v0.4
  • v0.3
  • v0.2.2
41 results

symge.cpp

Blame
  • user avatar
    Nora Abi Akar authored and Ben Cumming committed
    Add support for parsing and processing `LINEAR` blocks: 
    
    Changes: 
    * `SOLVE` expressions can be called from inside an `INITIAL` block, but only if they are solving a linear system
    * Tilde expressions can now be either linear expressions or reaction expressions
    * Linear expressions need to be rewritten before being sent to the solver, this is done using `LinearRewriter`
    * The linear system is set up in `LinearSolverVisitor`, which fills the lhs and rhs of the symbolic matrix
    * The matrix is reduced using `gj_reduce`, which now works on non-diagonal matrices. 
    
    Fixes #839 
    336c0574
    History
    symge.cpp 3.90 KiB
    #include <algorithm>
    #include <stdexcept>
    #include <vector>
    #include <numeric>
    
    #include "symge.hpp"
    
    namespace symge {
    
    // Position (row, column) of a matrix entry chosen as a pivot candidate
    // for the reduction below.
    struct pivot {
        unsigned row;   // row index of the pivot entry
        unsigned col;   // column index of the pivot entry
    };
    
    
    // Combines rows `p` and `q` into the row q[c]*p - p[c]*q, leaving out the
    // entry in column `c` itself.
    //
    // Each remaining entry is produced by passing the corresponding
    // `symbol_term_diff` to the `define_sym` functor, which returns the `symbol`
    // stored in the result; this is how symbols required by fill-in are created.
    //
    // Throws std::runtime_error if `index(c)` reports `npos` for either row.

    template <typename DefineSym>
    sym_row row_reduce(unsigned c, const sym_row& p, const sym_row& q, DefineSym define_sym) {
        if (!(p.index(c)!=p.npos && q.index(c)!=q.npos)) {
            throw std::runtime_error("improper row reduction");
        }

        sym_row reduced;
        symbol p_factor = q[c];   // scales the entries taken from p
        symbol q_factor = p[c];   // scales the entries taken from q

        // Both rows hold an entry in column c (checked above), so neither is
        // empty and dereferencing begin() is valid.
        auto p_it = p.begin();
        auto q_it = q.begin();
        unsigned p_col = p_it->col;
        unsigned q_col = q_it->col;

        // Merge the two rows in column order; an exhausted row reports npos as
        // its current column so that the other row always supplies the minimum.
        for (;;) {
            if (p_it==p.end() && q_it==q.end()) break;

            const unsigned col = p_col<q_col? p_col: q_col;
            symbol_term from_p, from_q;

            if (col==p_col) {
                from_p = p_factor*p_it->value;
                ++p_it;
                if (p_it==p.end()) p_col = p.npos;
                else p_col = p_it->col;
            }
            if (col==q_col) {
                from_q = q_factor*q_it->value;
                ++q_it;
                if (q_it==q.end()) q_col = q.npos;
                else q_col = q_it->col;
            }

            // The entry in the eliminated column is not stored in the result.
            if (col==c) continue;
            reduced.push_back({col, define_sym(from_p-from_q)});
        }
        return reduced;
    }
    
    // Greedy cost estimate for using `p` as the next pivot in the G–J reduction
    // below: performs a dry run of the row reductions that pivot would trigger
    // and counts the terms that have a right-hand part but no left-hand part,
    // i.e. the immediate fill. No symbols are defined during the dry run.
    double estimate_cost(const sym_matrix& A, pivot p) {
        unsigned nfill = 0;

        // Stand-in for the symbol-defining functor: only tallies fill and hands
        // back a default-constructed symbol.
        auto tally_fill = [&nfill](symbol_term_diff t) {
            const bool has_left = t.left;
            const bool has_right = t.right;
            if (has_right && !has_left) {
                ++nfill;
            }
            return symbol{};
        };

        for (unsigned i = 0; i<A.nrow(); ++i) {
            // Only rows other than the pivot row that have an entry in the
            // pivot column would be reduced.
            if (i!=p.row && A[i].index(p.col)!=msparse::row_npos) {
                row_reduce(p.col, A[i], A[p.row], tally_fill);
            }
        }

        return static_cast<double>(nfill);
    }
    
    // Perform Gauss-Jordan elimination on given symbolic matrix. New symbols
    // required due to fill-in are added to the supplied symbol table.
    //
    // Each row that has not yet been used contributes one candidate pivot: its
    // first non-zero entry among the first A.nrow() columns (the pivot need not
    // lie on the diagonal). The choice of pivot at each stage of the reduction
    // is governed by a cost estimation (see estimate_cost): the candidate with
    // the lowest estimated fill is used to reduce every other row that has an
    // entry in the pivot column.
    //
    // The reduction is division-free: the result will have non-zero terms
    // that are symbols that are either primitive, or defined (in the symbol
    // table) as products or differences of products of other symbols.
    //
    // A matrix without rows is left untouched.
    //
    // Throws std::runtime_error if A has more rows than columns, or if some
    // remaining row has no non-zero entry in the first A.nrow() columns, so
    // that no pivot can be selected for it.
    void gj_reduce(sym_matrix& A, symbol_table& table) {
        if (A.nrow()>A.ncol()) throw std::runtime_error("improper matrix for reduction");

        auto define_sym = [&table](symbol_term_diff t) { return table.define(t); };

        // Collect one candidate pivot per remaining row.
        auto get_pivots = [&A](const std::vector<unsigned>& remaining_rows) {
            std::vector<pivot> pivots;
            pivots.reserve(remaining_rows.size());
            for (auto r: remaining_rows) {
                const sym_row& row = A[r];

                pivot p;
                p.row = r;
                p.col = 0;

                bool found = false;
                for (unsigned c = 0; c<A.nrow(); ++c) {
                    if (row[c]) {
                        p.col = c;
                        found = true;
                        break;
                    }
                }
                // Without this check the pivot column would be left
                // uninitialised and then used as a column index.
                if (!found) {
                    throw std::runtime_error("improper matrix for reduction: row without pivot");
                }
                pivots.push_back(p);
            }
            return pivots;
        };

        // Rows not yet used as a pivot row; kept sorted so that lower_bound
        // can locate the row to remove.
        std::vector<unsigned> remaining_rows(A.nrow());
        std::iota(remaining_rows.begin(), remaining_rows.end(), 0);

        // Estimated cost of the current candidate pivot, indexed by row.
        std::vector<double> cost(A.nrow());

        // Testing for emptiness up front (rather than after the first pass)
        // keeps pivots.back() from being called on an empty vector when the
        // matrix has no rows.
        while (!remaining_rows.empty()) {
            auto pivots = get_pivots(remaining_rows);

            for (const pivot& candidate: pivots) {
                cost[candidate.row] = estimate_cost(A, candidate);
            }

            // Sort by descending cost so that the cheapest candidate is last.
            std::sort(pivots.begin(), pivots.end(),
                      [&](pivot r1, pivot r2) { return cost[r1.row]>cost[r2.row]; });

            pivot p = pivots.back();
            remaining_rows.erase(std::lower_bound(remaining_rows.begin(), remaining_rows.end(), p.row));

            for (unsigned i = 0; i<A.nrow(); ++i) {
                if (i==p.row || A[i].index(p.col)==msparse::row_npos) continue;
                A[i] = row_reduce(p.col, A[i], A[p.row], define_sym);
            }
        }
    }
    
    } // namespace symge