From e9b6fc1c44f2de3ce67ef4dd311d559f602eed08 Mon Sep 17 00:00:00 2001
From: Ben Cumming <louncharf@gmail.com>
Date: Thu, 9 Mar 2017 17:46:27 +0100
Subject: [PATCH] Unify matrix assembly+solve in backends (#179)

The storage of matrix data, and the operations on matrices (i.e. matrix assembly and matrix solution), have not been implemented in a consistent manner.

The main problem was that matrix assembly was managed by a `matrix_assembler` type provided by the back end, which had views on information that it required to perform assembly. Specifically, the views were on properties like `face_conductance`, model state like `voltage` and `current`, and on the underlying matrix storage `d`, `u` and `p`.

This was not a good solution because
  * there was a hidden dependence of the assembly on model data. e.g. if the voltage storage was reallocated, the reference in the `matrix_assembler` would become stale.
  * if we want back end specific optimizations that require a different data layout to that used elsewhere in the back end, this layout should be shared with the solver, but there is no obvious mechanism for doing that.

This patch addresses this by making a `matrix_state` type in the back end.
  * stores the matrix state opaquely, allowing back-end specific optimizations
  * provides interface for performing operations on the state, namely `assemble`, `solve` and `get_solution`.
  * stores fields such as `face_conductance` and `cv_capacitance` that were stored in the `fvm_multicell`, despite being used only in matrix assembly.
  * takes `voltage` and `current` as parameters to the `assemble` interface, removing the hidden reference to model state.

The actual data layout has not been changed in this PR. Instead, the interface has been refactored and hidden references removed so that it is now possible to implement back-end specific optimizations cleanly.
---
 src/backends/fvm_gpu.hpp       | 100 ++++++++++++++++++-----------
 src/backends/fvm_multicore.hpp | 114 ++++++++++++++++++---------------
 src/fvm_multicell.hpp          |  63 ++++++------------
 src/matrix.hpp                 |  73 ++++++++++-----------
 tests/unit/test_fvm_multi.cpp  |  17 ++---
 tests/unit/test_matrix.cpp     |  37 ++++++-----
 tests/unit/test_matrix.cu      |  22 ++++---
 7 files changed, 221 insertions(+), 205 deletions(-)

diff --git a/src/backends/fvm_gpu.hpp b/src/backends/fvm_gpu.hpp
index 9933dd77..5683efbf 100644
--- a/src/backends/fvm_gpu.hpp
+++ b/src/backends/fvm_gpu.hpp
@@ -100,21 +100,35 @@ struct backend {
     // matrix infrastructure
     //
 
-    /// Hines matrix assembly interface
-    struct matrix_assembler {
-        matrix_update_param_pack<value_type, size_type> params;
+    /// matrix state
+    struct matrix_state {
+        const_iview p;
+        const_iview cell_index;
+
+        array d;     // [μS]
+        array u;     // [μS]
+        array rhs;   // [nA]
+
+        array cv_capacitance;      // [pF]
+        array face_conductance;    // [μS]
 
         // the invariant part of the matrix diagonal
-        array invariant_d;  // [μS]
+        array invariant_d;         // [μS]
+
+        std::size_t size() const { return p.size(); }
 
-        matrix_assembler() = default;
+        matrix_state() = default;
 
-        matrix_assembler(
-            view d, view u, view rhs, const_iview p,
-            const_view cv_capacitance,
-            const_view face_conductance,
-            const_view voltage,
-            const_view current)
+        matrix_state(const_iview p, const_iview cell_index):
+            p(p), cell_index(cell_index),
+            d(size()), u(size()), rhs(size())
+        {}
+
+        matrix_state(const_iview p, const_iview cell_index, array cap, array cond):
+            p(p), cell_index(cell_index),
+            d(size()), u(size()), rhs(size()),
+            cv_capacitance(std::move(cap)),
+            face_conductance(std::move(cond))
         {
             auto n = d.size();
             host_array invariant_d_tmp(n, 0);
@@ -132,47 +146,57 @@ struct backend {
             }
             invariant_d = invariant_d_tmp;
             memory::copy(u_tmp, u);
+        }
+
+        // Assemble the matrix
+        // Afterwards the diagonal and RHS will have been set given dt, voltage and current
+        //   dt      [ms]
+        //   voltage [mV]
+        //   current [nA]
+        void assemble(value_type dt, const_view voltage, const_view current) {
+            EXPECTS(has_fvm_state());
 
-            params = {
+            // determine the grid dimensions for the kernel
+            auto const n = voltage.size();
+            auto const block_dim = 128;
+            auto const grid_dim = (n+block_dim-1)/block_dim;
+
+            auto params = matrix_update_param_pack<value_type, size_type> {
                 d.data(), u.data(), rhs.data(),
                 invariant_d.data(), cv_capacitance.data(), face_conductance.data(),
                 voltage.data(), current.data(), size_type(n)};
+
+            assemble_matrix<value_type, size_type><<<grid_dim, block_dim>>>
+                (params, dt);
+
         }
 
-        void assemble(value_type dt) {
+        void solve() {
+            using solve_param_pack = matrix_solve_param_pack<value_type, size_type>;
+
+            // pack the parameters into a single struct for kernel launch
+            auto params = solve_param_pack{
+                 d.data(), u.data(), rhs.data(),
+                 p.data(), cell_index.data(),
+                 size_type(d.size()), size_type(cell_index.size()-1)
+            };
+
             // determine the grid dimensions for the kernel
-            auto const n = params.n;
+            auto const n = params.ncells;
             auto const block_dim = 96;
             auto const grid_dim = (n+block_dim-1)/block_dim;
 
-            assemble_matrix<value_type, size_type><<<grid_dim, block_dim>>>(params, dt);
+            // perform solve on gpu
+            matrix_solve<value_type, size_type><<<grid_dim, block_dim>>>(params);
         }
 
+        // Test if the matrix has the full state required to assemble the
+        // matrix in the fvm scheme.
+        bool has_fvm_state() const {
+            return cv_capacitance.size()>0;
+        }
     };
 
-    /// Hines solver interface
-    static void hines_solve(
-        view d, view u, view rhs,
-        const_iview p, const_iview cell_index)
-    {
-        using solve_param_pack = matrix_solve_param_pack<value_type, size_type>;
-
-        // pack the parameters into a single struct for kernel launch
-        auto params = solve_param_pack{
-             d.data(), u.data(), rhs.data(),
-             p.data(), cell_index.data(),
-             size_type(d.size()), size_type(cell_index.size()-1)
-        };
-
-        // determine the grid dimensions for the kernel
-        auto const n = params.ncells;
-        auto const block_dim = 96;
-        auto const grid_dim = (n+block_dim-1)/block_dim;
-
-        // perform solve on gpu
-        matrix_solve<value_type, size_type><<<grid_dim, block_dim>>>(params);
-    }
-
     //
     // mechanism infrastructure
     //
diff --git a/src/backends/fvm_multicore.hpp b/src/backends/fvm_multicore.hpp
index 8ac52561..8bddb956 100644
--- a/src/backends/fvm_multicore.hpp
+++ b/src/backends/fvm_multicore.hpp
@@ -35,60 +35,35 @@ struct backend {
     using host_view   = view;
     using host_iview  = iview;
 
-    /// hines matrix solver
-    static void hines_solve(
-        view d, view u, view rhs,
-        const_iview p, const_iview cell_index)
-    {
-        const size_type ncells = cell_index.size()-1;
-
-        // loop over submatrices
-        for (auto m: util::make_span(0, ncells)) {
-            auto first = cell_index[m];
-            auto last = cell_index[m+1];
-
-            // backward sweep
-            for(auto i=last-1; i>first; --i) {
-                auto factor = u[i] / d[i];
-                d[p[i]]   -= factor * u[i];
-                rhs[p[i]] -= factor * rhs[i];
-            }
-            rhs[first] /= d[first];
-
-            // forward sweep
-            for(auto i=first+1; i<last; ++i) {
-                rhs[i] -= u[i] * rhs[p[i]];
-                rhs[i] /= d[i];
-            }
-        }
-    }
-
-    struct matrix_assembler {
-        view d;     // [μS]
-        view u;     // [μS]
-        view rhs;   // [nA]
+    /// matrix state
+    struct matrix_state {
         const_iview p;
+        const_iview cell_index;
+
+        array d;     // [μS]
+        array u;     // [μS]
+        array rhs;   // [nA]
 
-        const_view cv_capacitance;      // [pF]
-        const_view face_conductance;    // [μS]
-        const_view voltage;             // [mV]
-        const_view current;             // [nA]
+        array cv_capacitance;      // [pF]
+        array face_conductance;    // [μS]
 
         // the invariant part of the matrix diagonal
-        array invariant_d;              // [μS]
-
-        matrix_assembler() = default;
-
-        matrix_assembler(
-            view d, view u, view rhs, const_iview p,
-            const_view cv_capacitance,
-            const_view face_conductance,
-            const_view voltage,
-            const_view current)
-        :
-            d{d}, u{u}, rhs{rhs}, p{p},
-            cv_capacitance{cv_capacitance}, face_conductance{face_conductance},
-            voltage{voltage}, current{current}
+        array invariant_d;         // [μS]
+
+        std::size_t size() const { return p.size(); }
+
+        matrix_state() = default;
+
+        matrix_state(const_iview p, const_iview cell_index):
+            p(p), cell_index(cell_index),
+            d(size()), u(size()), rhs(size())
+        {}
+
+        matrix_state(const_iview p, const_iview cell_index, array cap, array cond):
+            p(p), cell_index(cell_index),
+            d(size()), u(size()), rhs(size()),
+            cv_capacitance(std::move(cap)),
+            face_conductance(std::move(cond))
         {
             auto n = d.size();
             invariant_d = array(n, 0);
@@ -101,7 +76,14 @@ struct backend {
             }
         }
 
-        void assemble(value_type dt) {
+        // Assemble the matrix
+        // Afterwards the diagonal and RHS will have been set given dt, voltage and current
+        //   dt      [ms]
+        //   voltage [mV]
+        //   current [nA]
+        void assemble(value_type dt, const_view voltage, const_view current) {
+            EXPECTS(has_fvm_state());
+
             auto n = d.size();
             value_type factor = 1e-3/dt;
             for (auto i: util::make_span(0u, n)) {
@@ -112,6 +94,36 @@ struct backend {
                 rhs[i] = gi*voltage[i] - current[i];
             }
         }
+
+        void solve() {
+            const size_type ncells = cell_index.size()-1;
+
+            // loop over submatrices
+            for (auto m: util::make_span(0, ncells)) {
+                auto first = cell_index[m];
+                auto last = cell_index[m+1];
+
+                // backward sweep
+                for(auto i=last-1; i>first; --i) {
+                    auto factor = u[i] / d[i];
+                    d[p[i]]   -= factor * u[i];
+                    rhs[p[i]] -= factor * rhs[i];
+                }
+                rhs[first] /= d[first];
+
+                // forward sweep
+                for(auto i=first+1; i<last; ++i) {
+                    rhs[i] -= u[i] * rhs[p[i]];
+                    rhs[i] /= d[i];
+                }
+            }
+        }
+
+        // Test if the matrix has the full state required to assemble the
+        // matrix in the fvm scheme.
+        bool has_fvm_state() const {
+            return cv_capacitance.size()>0;
+        }
     };
 
     //
diff --git a/src/fvm_multicell.hpp b/src/fvm_multicell.hpp
index a629da3e..5cc34aa8 100644
--- a/src/fvm_multicell.hpp
+++ b/src/fvm_multicell.hpp
@@ -60,8 +60,6 @@ public:
     /// the container used for indexes
     using iarray = typename backend::iarray;
 
-    using matrix_assembler = typename backend::matrix_assembler;
-
     using target_handle = std::pair<size_type, size_type>;
     using probe_handle = std::pair<const array fvm_multicell::*, size_type>;
 
@@ -121,11 +119,6 @@ public:
     ///     1e-8.cm^2
     const_view cv_areas() const { return cv_areas_; }
 
-    /// return the capacitance of each CV surface
-    /// this is the total capacitance, not per unit area,
-    /// i.e. equivalent to sigma_i * c_m
-    const_view cv_capacitance() const { return cv_capacitance_; }
-
     /// return the voltage in each CV
     view       voltage()       { return voltage_; }
     const_view voltage() const { return voltage_; }
@@ -213,20 +206,9 @@ private:
     /// the linear system for implicit time stepping of cell state
     matrix_type matrix_;
 
-    /// the helper used to construct the matrix
-    matrix_assembler matrix_assembler_;
-
     /// cv_areas_[i] is the surface area of CV i [µm^2]
     array cv_areas_;
 
-    /// CV i and its parent, required when constructing linear system [µS]
-    ///     face_conductance_[i] = area_face  / (r_L * delta_x);
-    array face_conductance_;
-
-    /// cv_capacitance_[i] is the capacitance of CV membrane [pF]
-    ///     C_m = area*c_m
-    array cv_capacitance_; // units [µm^2*F*m^-2 = pF]
-
     /// the transmembrane current over the surface of each CV [nA]
     ///     I = area*i_m - I_e
     array current_;
@@ -277,9 +259,9 @@ private:
         std::pair<size_type, size_type> comp_ival,
         const segment* seg,
         const std::vector<size_type>& parent,
-        std::vector<value_type>& tmp_face_conductance,
+        std::vector<value_type>& face_conductance,
         std::vector<value_type>& tmp_cv_areas,
-        std::vector<value_type>& tmp_cv_capacitance
+        std::vector<value_type>& cv_capacitance
     );
 };
 
@@ -292,9 +274,9 @@ fvm_multicell<Backend>::compute_cv_area_capacitance(
     std::pair<size_type, size_type> comp_ival,
     const segment* seg,
     const std::vector<size_type>& parent,
-    std::vector<value_type>& tmp_face_conductance,
+    std::vector<value_type>& face_conductance,
     std::vector<value_type>& tmp_cv_areas,
-    std::vector<value_type>& tmp_cv_capacitance)
+    std::vector<value_type>& cv_capacitance)
 {
     // precondition: group_parent_index[j] holds the correct value for
     // j in [base_comp, base_comp+segment.num_compartments()].
@@ -313,7 +295,7 @@ fvm_multicell<Backend>::compute_cv_area_capacitance(
         auto c_m = soma->mechanism("membrane").get("c_m").value;
 
         tmp_cv_areas[i] += area;
-        tmp_cv_capacitance[i] += area*c_m;
+        cv_capacitance[i] += area*c_m;
 
         cv_range.segment_cvs = {comp_ival.first, comp_ival.first+1};
         cv_range.areas = {0.0, area};
@@ -399,15 +381,15 @@ fvm_multicell<Backend>::compute_cv_area_capacitance(
             auto conductance = 1/r_L*h*V1*V2/(h2*h2*V1+h1*h1*V2);
             // the scaling factor of 10^2 is to convert the quantity
             // to micro Siemens [μS]
-            tmp_face_conductance[i] =  1e2 * conductance / h;
+            face_conductance[i] =  1e2 * conductance / h;
 
             auto al = div.left.area;
             auto ar = div.right.area;
 
             tmp_cv_areas[j] += al;
             tmp_cv_areas[i] += ar;
-            tmp_cv_capacitance[j] += al * c_m;
-            tmp_cv_capacitance[i] += ar * c_m;
+            cv_capacitance[j] += al * c_m;
+            cv_capacitance[i] += ar * c_m;
         }
     }
     else {
@@ -465,11 +447,13 @@ void fvm_multicell<Backend>::initialize(
 
     // Allocate scratch storage for calculating quantities used to build the
     // linear system: these will later be copied into target-specific storage
-    // as need be.
-    // Initialize to zero, because the results therin are calculated via accumulation.
-    std::vector<value_type> tmp_face_conductance(ncomp, 0.);
-    std::vector<value_type> tmp_cv_areas(ncomp, 0.);
-    std::vector<value_type> tmp_cv_capacitance(ncomp, 0.);
+
+    // face_conductance_[i] = area_face  / (r_L * delta_x);
+    std::vector<value_type> face_conductance(ncomp); // [µS]
+    /// cv_capacitance_[i] is the capacitance of CV membrane
+    std::vector<value_type> cv_capacitance(ncomp);   // [µm^2*F*m^-2 = pF]
+    /// membrane area of each cv
+    std::vector<value_type> tmp_cv_areas(ncomp);         // [µm^2]
 
     // used to build the information required to construct spike detectors
     std::vector<size_type> spike_detector_index;
@@ -508,7 +492,7 @@ void fvm_multicell<Backend>::initialize(
 
             auto cv_range = compute_cv_area_capacitance(
                 seg_comp_ival, seg, group_parent_index,
-                tmp_face_conductance, tmp_cv_areas, tmp_cv_capacitance);
+                face_conductance, tmp_cv_areas, cv_capacitance);
 
             for (const auto& mech: seg->mechanisms()) {
                 if (mech.name()!="membrane") {
@@ -602,16 +586,11 @@ void fvm_multicell<Backend>::initialize(
     EXPECTS(probes_size==probes_count);
 
     // store the geometric information in target-specific containers
-    face_conductance_ = make_const_view(tmp_face_conductance);
-    cv_areas_         = make_const_view(tmp_cv_areas);
-    cv_capacitance_   = make_const_view(tmp_cv_capacitance);
+    cv_areas_ = make_const_view(tmp_cv_areas);
 
     // initalize matrix
-    matrix_ = matrix_type(group_parent_index, cell_comp_bounds);
-
-    matrix_assembler_ = matrix_assembler(
-        matrix_.d(), matrix_.u(), matrix_.rhs(), matrix_.p(),
-        cv_capacitance_, face_conductance_, voltage_, current_);
+    matrix_ = matrix_type(
+        group_parent_index, cell_comp_bounds, cv_capacitance, face_conductance);
 
     // For each density mechanism build the full node index, i.e the list of
     // compartments with that mechanism, then build the mechanism instance.
@@ -792,12 +771,12 @@ void fvm_multicell<Backend>::advance(double dt) {
 
     // solve the linear system
     PE("matrix", "setup");
-    matrix_assembler_.assemble(dt);
+    matrix_.assemble(dt, voltage_, current_);
 
     PL(); PE("solve");
     matrix_.solve();
     PL();
-    memory::copy(matrix_.rhs(), voltage_);
+    memory::copy(matrix_.solution(), voltage_);
     PL();
 
     // integrate state of gating variables etc.
diff --git a/src/matrix.hpp b/src/matrix.hpp
index bf9ca068..f37e54a8 100644
--- a/src/matrix.hpp
+++ b/src/matrix.hpp
@@ -32,25 +32,32 @@ public:
 
     using host_array = typename backend::host_array;
 
+    // back end specific storage for matrix state
+    using state = typename backend::matrix_state;
+
     matrix() = default;
 
-    /// construct matrix for one or more cells, with combined parent index
-    /// and a cell index
+    /// construct matrix for one or more cells, described by a parent index and
+    /// a cell index.
     matrix(const std::vector<size_type>& pi, const std::vector<size_type>& ci):
         parent_index_(memory::make_const_view(pi)),
-        cell_index_(memory::make_const_view(ci))
+        cell_index_(memory::make_const_view(ci)),
+        state_(parent_index_, cell_index_)
     {
-        setup();
+        EXPECTS(cell_index_[num_cells()] == parent_index_.size());
     }
 
-    /// construct matrix for a single cell described by a parent index
-    matrix(const std::vector<size_type>& pi):
+    matrix( const std::vector<size_type>& pi,
+            const std::vector<size_type>& ci,
+            const std::vector<value_type>& cv_capacitance,
+            const std::vector<value_type>& face_conductance):
         parent_index_(memory::make_const_view(pi)),
-        cell_index_(2)
+        cell_index_(memory::make_const_view(ci)),
+        state_( parent_index_, cell_index_,
+                memory::make_const_view(cv_capacitance),
+                memory::make_const_view(face_conductance))
     {
-        cell_index_[0] = 0;
-        cell_index_[1] = parent_index_.size();
-        setup();
+        EXPECTS(cell_index_[num_cells()] == parent_index_.size());
     }
 
     /// the dimension of the matrix (i.e. the number of rows or colums)
@@ -63,54 +70,40 @@ public:
         return cell_index_.size() - 1;
     }
 
-    /// the vector holding the diagonal of the matrix
-    view d() { return d_; }
-    const_view d() const { return d_; }
-
-    /// the vector holding the upper part of the matrix
-    view u() { return u_; }
-    const_view u() const { return u_; }
-
-    /// the vector holding the right hand side of the linear equation system
-    view rhs() { return rhs_; }
-    const_view rhs() const { return rhs_; }
-
     /// the vector holding the parent index
     const_iview p() const { return parent_index_; }
 
-    /// the patrition of the parent index over the cells
+    /// the partition of the parent index over the cells
     const_iview cell_index() const { return cell_index_; }
 
     /// Solve the linear system.
-    /// Upon completion the solution is stored in the RHS storage, which can
-    /// be accessed via rhs().
     void solve() {
-        backend::hines_solve(d_, u_, rhs_, parent_index_, cell_index_);
+        state_.solve();
     }
 
-    private:
-
-    /// Allocate memory for storing matrix and right hand side vector
-    /// and build the face area contribution to the diagonal
-    void setup() {
-        const auto n = size();
-        constexpr auto default_value = std::numeric_limits<value_type>::quiet_NaN();
+    /// Assemble the matrix for given dt
+    void assemble(double dt, const_view voltage, const_view current) {
+        state_.assemble(dt, voltage, current);
+    }
 
-        d_   = array(n, default_value);
-        u_   = array(n, default_value);
-        rhs_ = array(n, default_value);
+    /// Get a view of the solution
+    const_view solution() const {
+        return state_.rhs;
     }
 
+    private:
+
     /// the parent indice that describe matrix structure
     iarray parent_index_;
 
     /// indexes that point to the start of each cell in the index
     iarray cell_index_;
 
-    /// storage for lower, diagonal, upper and rhs
-    array d_;
-    array u_;
-    array rhs_;
+public:
+    // Provide via public interface to make testing much easier.
+    // If you modify this directly without knowing what you are doing,
+    // you get what you deserve.
+    state state_;
 };
 
 } // namespace nest
diff --git a/tests/unit/test_fvm_multi.cpp b/tests/unit/test_fvm_multi.cpp
index 7eb06064..5b678bb4 100644
--- a/tests/unit/test_fvm_multi.cpp
+++ b/tests/unit/test_fvm_multi.cpp
@@ -78,25 +78,26 @@ TEST(fvm_multi, init)
     // test that the matrix is initialized with sensible values
     //J.build_matrix(0.01);
     fvcell.advance(0.01);
-    auto test_nan = [](decltype(J.u()) v) {
+    auto& mat = J.state_;
+    auto test_nan = [](decltype(mat.u) v) {
         for(auto val : v) if(val != val) return false;
         return true;
     };
-    EXPECT_TRUE(test_nan(J.u()(1, J.size())));
-    EXPECT_TRUE(test_nan(J.d()));
-    EXPECT_TRUE(test_nan(J.rhs()));
+    EXPECT_TRUE(test_nan(mat.u(1, J.size())));
+    EXPECT_TRUE(test_nan(mat.d));
+    EXPECT_TRUE(test_nan(J.solution()));
 
     // test matrix diagonals for sign
-    auto is_pos = [](decltype(J.u()) v) {
+    auto is_pos = [](decltype(mat.u) v) {
         for(auto val : v) if(val<=0.) return false;
         return true;
     };
-    auto is_neg = [](decltype(J.u()) v) {
+    auto is_neg = [](decltype(mat.u) v) {
         for(auto val : v) if(val>=0.) return false;
         return true;
     };
-    EXPECT_TRUE(is_neg(J.u()(1, J.size())));
-    EXPECT_TRUE(is_pos(J.d()));
+    EXPECT_TRUE(is_neg(mat.u(1, J.size())));
+    EXPECT_TRUE(is_pos(mat.d));
 
 }
 
diff --git a/tests/unit/test_matrix.cpp b/tests/unit/test_matrix.cpp
index 5f978786..7bd99111 100644
--- a/tests/unit/test_matrix.cpp
+++ b/tests/unit/test_matrix.cpp
@@ -8,17 +8,19 @@
 #include <backends/fvm_multicore.hpp>
 #include <util/span.hpp>
 
-using matrix_type = nest::mc::matrix<nest::mc::multicore::backend>;
+using namespace nest::mc;
+
+using matrix_type = matrix<nest::mc::multicore::backend>;
 using size_type = matrix_type::size_type;
 
 TEST(matrix, construct_from_parent_only)
 {
-    using nest::mc::util::make_span;
+    using util::make_span;
 
     // pass parent index as a std::vector cast to host data
     {
         std::vector<size_type> p = {0,0,1};
-        matrix_type m(p);
+        matrix_type m(p, {0, 3});
         EXPECT_EQ(m.num_cells(), 1u);
         EXPECT_EQ(m.size(), 3u);
         EXPECT_EQ(p.size(), 3u);
@@ -32,38 +34,41 @@ TEST(matrix, construct_from_parent_only)
 
 TEST(matrix, solve_host)
 {
-    using nest::mc::util::make_span;
-    using nest::mc::memory::fill;
+    using util::make_span;
+    using memory::fill;
 
     // trivial case : 1x1 matrix
     {
-        matrix_type m(std::vector<size_type>{0});
-        fill(m.d(),  2);
-        fill(m.u(), -1);
-        fill(m.rhs(),1);
+        matrix_type m({0}, {0,1});
+        auto& state = m.state_;
+        fill(state.d,  2);
+        fill(state.u, -1);
+        fill(state.rhs,1);
 
         m.solve();
 
-        EXPECT_EQ(m.rhs()[0], 0.5);
+        EXPECT_EQ(m.solution()[0], 0.5);
     }
+
     // matrices in the range of 2x2 to 1000x1000
     {
-        using namespace nest::mc;
         for(auto n : make_span(2u,1001u)) {
             auto p = std::vector<size_type>(n);
             std::iota(p.begin()+1, p.end(), 0);
-            matrix_type m(p);
+            matrix_type m(p, {0, n});
 
             EXPECT_EQ(m.size(), n);
             EXPECT_EQ(m.num_cells(), 1u);
 
-            fill(m.d(),  2);
-            fill(m.u(), -1);
-            fill(m.rhs(),1);
+            auto& A = m.state_;
+
+            fill(A.d,  2);
+            fill(A.u, -1);
+            fill(A.rhs,1);
 
             m.solve();
 
-            auto x = m.rhs();
+            auto x = m.solution();
             auto err = math::square(std::fabs(2.*x[0] - x[1] - 1.));
             for(auto i : make_span(1,n-1)) {
                 err += math::square(std::fabs(2.*x[i] - x[i-1] - x[i+1] - 1.));
diff --git a/tests/unit/test_matrix.cu b/tests/unit/test_matrix.cu
index ca407f19..3f13a86f 100644
--- a/tests/unit/test_matrix.cu
+++ b/tests/unit/test_matrix.cu
@@ -20,15 +20,16 @@ TEST(matrix, solve_gpu)
 
     // trivial case : 1x1 matrix
     {
-        matrix_type m({0});
+        matrix_type m({0}, {0,1});
 
-        memory::fill(m.d(),  2);
-        memory::fill(m.u(), -1);
-        memory::fill(m.rhs(),1);
+        auto& state = m.state_;
+        memory::fill(state.d,  2);
+        memory::fill(state.u, -1);
+        memory::fill(state.rhs,1);
 
         m.solve();
 
-        auto rhs = memory::on_host(m.rhs());
+        auto rhs = memory::on_host(m.solution());
 
         EXPECT_EQ(rhs[0], 0.5);
     }
@@ -39,18 +40,19 @@ TEST(matrix, solve_gpu)
         for(auto n : make_span(2u,101u)) {
             auto p = std::vector<index_type>(n);
             std::iota(p.begin()+1, p.end(), 0);
-            matrix_type m{p};
+            matrix_type m{p, {0, n}};
 
             EXPECT_EQ(m.size(), n);
             EXPECT_EQ(m.num_cells(), 1u);
 
-            memory::fill(m.d(),  2);
-            memory::fill(m.u(), -1);
-            memory::fill(m.rhs(),1);
+            auto& state = m.state_;
+            memory::fill(state.d,  2);
+            memory::fill(state.u, -1);
+            memory::fill(state.rhs,1);
 
             m.solve();
 
-            auto x = memory::on_host(m.rhs());
+            auto x = memory::on_host(m.solution());
             auto err = math::square(std::fabs(2.*x[0] - x[1] - 1.));
             for(auto i : make_span(1,n-1)) {
                 err += math::square(std::fabs(2.*x[i] - x[i-1] - x[i+1] - 1.));
-- 
GitLab