From be2a8a9faa58832ea4482b8e012d5169c805a5f1 Mon Sep 17 00:00:00 2001
From: Ben Cumming <bcumming@cscs.ch>
Date: Tue, 13 Nov 2018 11:18:22 +0100
Subject: [PATCH] Squashed merge for fine matrix solver (#640)

Add a new Hines matrix solver implementation for the GPU that can solve a single tree in parallel with multiple threads. It replaces the interleaved solver, which used a single thread to solve each matrix.
Branches that share a common root in the tree can be solved independently in each of the forward and backward solution passes.

* Add a matrix storage type, `arb::gpu::matrix_state_fine` that stores the branches of multiple trees for efficient backward and forward substitution.
* Extend the `arb::tree` data structure to support operations for choosing a new root node and determining a root node which minimises the maximum distance between the root and any of the tree's leaves.
* Implement code for rebalancing a set of matrix trees, a.k.a. a "forest" of trees.
* Add CUDA kernels for efficiently performing matrix assembly and matrix solution steps.
* Add CMake option `ARB_WITH_GPU_FINE_MATRIX` for toggling the new solver (default `on`).
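
The solve kernel processes the levels in order. As a rough sketch (illustrative
pseudocode only, not the actual kernel; `eliminate_branch` and
`update_parent_atomically` stand in for the inlined elimination code in
`arbor/backends/gpu/matrix_fine.cu`):

    // backward substitution, from the leaves towards the root
    for (unsigned l = 0; l < num_levels; ++l) {
        if (tid < level[l].num_branches) {    // one thread per branch
            eliminate_branch(tid);            // eliminate along the branch
            update_parent_atomically(tid);    // parents may be shared
        }
        __syncthreads();                      // finish the level before continuing
    }
    // ... followed by the symmetric forward-substitution pass from the root
    // back towards the leaves.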
---
 CMakeLists.txt                           |   4 +
 arbor/CMakeLists.txt                     |   4 +
 arbor/algorithms.hpp                     |  31 +-
 arbor/backends/gpu/forest.cpp            | 181 +++++++++
 arbor/backends/gpu/forest.hpp            | 140 +++++++
 arbor/backends/gpu/fvm.hpp               |  12 +-
 arbor/backends/gpu/matrix_fine.cpp       |  71 ++++
 arbor/backends/gpu/matrix_fine.cu        | 314 ++++++++++++++
 arbor/backends/gpu/matrix_fine.hpp       |  77 ++++
 arbor/backends/gpu/matrix_solve.cu       |   1 +
 arbor/backends/gpu/matrix_state_fine.hpp | 494 +++++++++++++++++++++++
 arbor/tree.cpp                           | 385 ++++++++++++++++++
 arbor/tree.hpp                           | 153 +++----
 test/unit/test_gpu_stack.cu              |   3 +
 test/unit/test_matrix.cu                 |  90 +----
 test/unit/test_matrix_cpuvsgpu.cpp       |   9 +-
 test/unit/test_tree.cpp                  |  83 ++++
 17 files changed, 1873 insertions(+), 179 deletions(-)
 create mode 100644 arbor/backends/gpu/forest.cpp
 create mode 100644 arbor/backends/gpu/forest.hpp
 create mode 100644 arbor/backends/gpu/matrix_fine.cpp
 create mode 100644 arbor/backends/gpu/matrix_fine.cu
 create mode 100644 arbor/backends/gpu/matrix_fine.hpp
 create mode 100644 arbor/backends/gpu/matrix_state_fine.hpp
 create mode 100644 arbor/tree.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index b8c83fae..b9802b68 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -39,6 +39,7 @@ set(ARB_VALIDATION_DATA_DIR "${PROJECT_SOURCE_DIR}/validation/data" CACHE PATH
 #----------------------------------------------------------
 
 option(ARB_WITH_GPU "build with GPU support" OFF)
+option(ARB_WITH_GPU_FINE_MATRIX "use optimized fine matrix solver" ON)
 
 option(ARB_WITH_MPI "build with MPI support" OFF)
 
@@ -197,6 +198,9 @@ if(ARB_WITH_GPU)
         "$<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe=--diag_suppress=integer_sign_change>"
         "$<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe=--diag_suppress=unsigned_compare_with_zero>")
     target_compile_definitions(arbor-private-deps INTERFACE ARB_HAVE_GPU)
+    if(ARB_WITH_GPU_FINE_MATRIX)
+        target_compile_definitions(arbor-private-deps INTERFACE ARB_HAVE_GPU_FINE_MATRIX)
+    endif()
 
     target_compile_options(arbor-private-deps INTERFACE $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_35,code=sm_35>)
     target_compile_options(arbor-private-deps INTERFACE $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_37,code=sm_37>)
diff --git a/arbor/CMakeLists.txt b/arbor/CMakeLists.txt
index dd885eac..87ec5150 100644
--- a/arbor/CMakeLists.txt
+++ b/arbor/CMakeLists.txt
@@ -45,6 +45,7 @@ set(arbor_sources
     threading/threading.cpp
     threading/thread_info.cpp
     thread_private_spike_store.cpp
+    tree.cpp
     util/hostname.cpp
     util/unwind.cpp
     version.cpp
@@ -59,10 +60,13 @@ if(ARB_WITH_CUDA)
         backends/gpu/threshold_watcher.cu
         backends/gpu/matrix_assemble.cu
         backends/gpu/matrix_interleave.cu
+        backends/gpu/matrix_fine.cu
+        backends/gpu/matrix_fine.cpp
         backends/gpu/matrix_solve.cu
         backends/gpu/multi_event_stream.cpp
         backends/gpu/multi_event_stream.cu
         backends/gpu/shared_state.cu
+        backends/gpu/forest.cpp
         backends/gpu/stimulus.cu
         backends/gpu/threshold_watcher.cu
         memory/fill.cu
diff --git a/arbor/algorithms.hpp b/arbor/algorithms.hpp
index 4e295834..a2ec968a 100644
--- a/arbor/algorithms.hpp
+++ b/arbor/algorithms.hpp
@@ -41,6 +41,8 @@ mean(C const& c)
     return sum(c)/util::size(c);
 }
 
+// returns the prefix sum of c in the form `[0, c[0], c[0]+c[1], ..., sum(c)]`.
+// This means that the returned vector has one more element than c.
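+// e.g. make_index({1, 2, 3}) returns {0, 1, 3, 6}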
 template <typename C>
 C make_index(C const& c)
 {
@@ -94,6 +96,9 @@ bool is_strictly_monotonic_decreasing(C const& c)
     );
 }
 
+// check if c[0] == 0 and c[i] < i holds for all i != 0.
+// This means that children of a node always have larger indices than their
+// parent.
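+// e.g. {0, 0, 1, 2} satisfies this property, while {0, 2, 1} does not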
 template <
     typename C,
     typename = typename std::enable_if<std::is_integral<typename C::value_type>::value>
@@ -130,17 +135,23 @@ bool all_negative(const C& c) {
     return util::all_of(c, [](auto v) { return v<decltype(v){}; });
 }
 
+// returns a vector containing the number of children for each node.
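+// e.g. child_count({0, 0, 1, 1, 1}) returns {1, 3, 0, 0, 0}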
 template<typename C>
 std::vector<typename C::value_type> child_count(const C& parent_index)
 {
+    using value_type = typename C::value_type;
     static_assert(
-        std::is_integral<typename C::value_type>::value,
+        std::is_integral<value_type>::value,
         "integral type required"
     );
 
-    std::vector<typename C::value_type> count(parent_index.size(), 0);
-    for (auto i = 1u; i < parent_index.size(); ++i) {
-        ++count[parent_index[i]];
+    std::vector<value_type> count(parent_index.size(), 0);
+    for (auto i = 0u; i < parent_index.size(); ++i) {
+        auto p = parent_index[i];
+        // -1 means no parent
+        if (p != value_type(i) && p != value_type(-1)) {
+            ++count[p];
+        }
     }
 
     return count;
@@ -190,7 +201,7 @@ std::vector<typename C::value_type> branches(const C& parent_index)
     for (std::size_t i = 1; i < parent_index.size(); ++i) {
         auto p = parent_index[i];
         if (num_child[p] > 1 || parent_index[p] == p) {
-            // parent_index[p] == p -> parent_index[i] is the soma
+            // `parent_index[p] == p` ~> parent_index[i] is the soma
             branch_index.push_back(i);
         }
     }
@@ -199,7 +210,9 @@ std::vector<typename C::value_type> branches(const C& parent_index)
     return branch_index;
 }
 
-
+// creates a vector that contains the branch index for each compartment.
+// e.g. {0, 1, 5, 9, 10} -> {0, 1, 1, 1, 1, 2, 2, 2, 2, 3}
+//                  indices  0  1           5           9
 template<typename C>
 std::vector<typename C::value_type> expand_branches(const C& branch_index)
 {
@@ -281,11 +294,13 @@ std::vector<typename C::value_type> tree_reduce(
     arb_assert(has_contiguous_compartments(parent_index));
     arb_assert(is_strictly_monotonic_increasing(branch_index));
 
-    // expand the branch index
+    // expand the branch index to look up the branch id for each compartment
     auto expanded_branch = expand_branches(branch_index);
 
     std::vector<typename C::value_type> new_parent_index;
-    for (std::size_t i = 0; i < branch_index.size()-1; ++i) {
+    // push the first element manually as the parent of the root might be -1
+    new_parent_index.push_back(expanded_branch[0]);
+    for (std::size_t i = 1; i < branch_index.size()-1; ++i) {
         auto p = parent_index[branch_index[i]];
         new_parent_index.push_back(expanded_branch[p]);
     }
diff --git a/arbor/backends/gpu/forest.cpp b/arbor/backends/gpu/forest.cpp
new file mode 100644
index 00000000..6650b92c
--- /dev/null
+++ b/arbor/backends/gpu/forest.cpp
@@ -0,0 +1,181 @@
+#include "backends/gpu/forest.hpp"
+#include "util/span.hpp"
+
+namespace arb {
+namespace gpu {
+
+forest::forest(const std::vector<size_type>& p, const std::vector<size_type>& cell_cv_divs) :
+    perm_balancing(p.size())
+{
+    using util::make_span;
+
+    auto num_cells = cell_cv_divs.size() - 1;
+
+    for (auto c: make_span(0u, num_cells)) {
+        // build the parent index for cell c
+        auto cell_start = cell_cv_divs[c];
+        std::vector<unsigned> cell_p =
+            util::assign_from(
+                util::transform_view(
+                    util::subrange_view(p, cell_cv_divs[c], cell_cv_divs[c+1]),
+                    [cell_start](unsigned i) {return i-cell_start;}));
+
+        auto fine_tree = tree(cell_p);
+
+        // select a root node and merge branches with discontinuous compartment
+        // indices
+        auto perm = fine_tree.select_new_root(0);
+        for (auto i: make_span(perm.size())) {
+            perm_balancing[cell_start + i] = cell_start + perm[i];
+        }
+
+        // find the index of the first node for each branch
+        auto branch_starts = algorithms::branches(fine_tree.parents());
+
+        // compute branch length and apply permutation
+        std::vector<unsigned> branch_lengths(branch_starts.size() - 1);
+        for (auto i: make_span(branch_lengths.size())) {
+            branch_lengths[i] = branch_starts[i+1] - branch_starts[i];
+        }
+
+        // find the parent index of branches;
+        // we need to convert to cell_lid_type, which is required to construct a tree.
+        std::vector<cell_lid_type> branch_p =
+            util::assign_from(
+                algorithms::tree_reduce(fine_tree.parents(), branch_starts));
+        // build tree structure that describes the branch topology
+        auto cell_tree = tree(branch_p);
+
+        trees.push_back(cell_tree);
+        fine_trees.push_back(fine_tree);
+        tree_branch_starts.push_back(branch_starts);
+        tree_branch_lengths.push_back(branch_lengths);
+    }
+}
+
+void forest::optimize() {
+    using util::make_span;
+
+    // cut the tree
+    unsigned count = 1; // number of nodes found on the previous level
+    for (auto level = 0; count > 0; level++) {
+        count = 0;
+
+        // decide where to cut it ...
+        unsigned max_length = 0;
+        for (auto t_ix: make_span(trees.size())) { // TODO make this local on an intermediate packing
+            for (level_iterator it (&trees[t_ix], level); it.valid(); it.next()) {
+                auto length = tree_branch_lengths[t_ix][it.peek()];
+                max_length += length;
+                count++;
+            }
+        }
+        if (count == 0) {
+            // there exists no tree with branches on this level
+            continue;
+        }
+        max_length = max_length / count; // mean branch length on this level
+        // avoid infinite loops
+        if (max_length <= 1) max_length = 1;
+        // we don't want too small segments
+        if (max_length <= 10) max_length = 10;
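+        // e.g. if the branches on this level have lengths {2, 40, 3}, the
+        // target length is 45/3 = 15: the branch of length 40 is split into a
+        // branch of length 15 that stays on this level, with the remaining 25
+        // compartments moved into a new child branch one level deeper.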
+
+        for (auto t_ix: make_span(trees.size())) {
+            // ... cut all trees on this level
+            for (level_iterator it (&trees[t_ix], level); it.valid(); it.next()) {
+
+                auto length = tree_branch_lengths[t_ix][it.peek()];
+                if (length > max_length) {
+                    // now cut the tree
+
+                    // we are allowed to modify the tree while iterating over
+                    // it, thanks to the way level_iterator is implemented
+
+                    auto insert_at_bs = tree_branch_starts[t_ix].begin() + it.peek();
+                    auto insert_at_ls = tree_branch_lengths[t_ix].begin() + it.peek();
+
+                    trees[t_ix].split_node(it.peek());
+
+                    // The tree now has a new node, so we have to insert a
+                    // corresponding new 'branch start' into the list.
+
+                    // make sure that `tree_branch_starts` for the old node (A)
+                    // and the new node (N) point to the correct slices
+                    auto old_start = tree_branch_starts[t_ix][it.peek()];
+                    // insert first: the existing entry shifts to it.peek() + 1,
+                    // which we then overwrite below
+                    tree_branch_starts[t_ix].insert(insert_at_bs, old_start);
+                    tree_branch_lengths[t_ix].insert(insert_at_ls, max_length);
+                    tree_branch_starts[t_ix][it.peek() + 1] = old_start + max_length;
+                    tree_branch_lengths[t_ix][it.peek() + 1] = length - max_length;
+                    // we don't have to shift any further indices, as we did
+                    // not create any new branch segments, but only split one
+                    // branch across two nodes
+                }
+            }
+        }
+    }
+}
+
+
+// debugging functions:
+
+
+// Exports the tree's parent structure into a dot file.
+// The children and parent representations are equivalent; both methods are
+// provided for debugging purposes.
+template<typename F>
+void export_parents(const tree& t, std::string file, F label) {
+    using util::make_span;
+    std::ofstream ofile;
+    ofile.open(file);
+    ofile << "strict digraph Parents {" << std::endl;
+    for (auto i: make_span(t.parents().size())) {
+        ofile << i << "[label=\"" << label(i) << "\"]" << std::endl;
+    }
+    for (auto i: make_span(t.parents().size())) {
+        auto p = t.parent(i);
+        if (p != tree::no_parent) {
+            ofile << i << " -> " << t.parent(i) << std::endl;
+        }
+    }
+    ofile << "}" << std::endl;
+    ofile.close();
+}
+
+void export_parents(const tree& t, std::string file) {
+    // the labels in the graph are the branch indices
+    export_parents(t, file, [](auto i){return i;});
+}
+
+// Exports the tree's children structure into a dot file.
+// The children and parent representations are equivalent; both methods are
+// provided for debugging purposes.
+template<typename F>
+void export_children(const tree& t, std::string file, F label) {
+    using util::make_span;
+    std::ofstream ofile;
+    ofile.open(file);
+    ofile << "strict digraph Children {" << std::endl;
+    for (auto i: make_span(t.num_segments())) {
+        ofile << i << "[label=\"" << label(i) << "\"]" << std::endl;
+    }
+    for (auto i: make_span(t.num_segments())) {
+        ofile << i << " -> {";
+        for (auto c: t.children(i)) {
+            ofile << " " << c;
+        }
+        ofile << "}" << std::endl;
+    }
+    ofile << "}" << std::endl;
+    ofile.close();
+}
+
+void export_children(const tree& t, std::string file) {
+    // the labels in the graph are the branch indices
+    export_children(t, file, [](auto i){return i;});
+}
+
+} // namespace gpu
+} // namespace arb
diff --git a/arbor/backends/gpu/forest.hpp b/arbor/backends/gpu/forest.hpp
new file mode 100644
index 00000000..43900858
--- /dev/null
+++ b/arbor/backends/gpu/forest.hpp
@@ -0,0 +1,140 @@
+#pragma once
+
+#include <vector>
+
+#include "tree.hpp"
+
+namespace arb {
+namespace gpu {
+
+using size_type = int;
+
+struct forest {
+    forest(const std::vector<size_type>& p, const std::vector<size_type>& cell_cv_divs);
+
+    void optimize();
+
+    unsigned num_trees() {
+        return fine_trees.size();
+    }
+
+    const tree& branch_tree(unsigned tree_index) {
+        return trees[tree_index];
+    }
+
+    const tree& compartment_tree(unsigned tree_index) {
+        return fine_trees[tree_index];
+    }
+
+    // Returns the offset into the compartments of a tree where each branch
+    // starts. It holds `0 <= offset < tree.num_segments()`
+    const std::vector<unsigned>& branch_offsets(unsigned tree_index) {
+        return tree_branch_starts[tree_index];
+    }
+
+    // Returns vector of the length of each branch in a tree.
+    const std::vector<unsigned>& branch_lengths(unsigned tree_index) {
+        return tree_branch_lengths[tree_index];
+    }
+
+    // Return the permutation that was applied to the compartments in the
+    // format: `new[i] = old[perm[i]]`
+    const std::vector<size_type>& permutation() {
+        return perm_balancing;
+    }
+
+    // trees of compartments
+    std::vector<tree> fine_trees;
+    std::vector<std::vector<unsigned>> tree_branch_starts;
+    std::vector<std::vector<unsigned>> tree_branch_lengths;
+    // trees of branches
+    std::vector<tree> trees;
+
+    // the permutation applied by the balancing algorithm;
+    // format: `solver_format[i] = external_format[perm[i]]`
+    std::vector<size_type> perm_balancing;
+};
+
+
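+// Iterates over all branches on a given level of a tree, in the order in
+// which a depth-first traversal discovers them. Intended usage (as in
+// forest::optimize in forest.cpp):
+//
+//     for (level_iterator it(&t, level); it.valid(); it.next()) {
+//         do_something_with(t, it.peek());
+//     }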
+struct level_iterator {
+    level_iterator(tree* t, unsigned level) {
+        tree_ = t;
+        only_on_level = level;
+        // due to the ordering of the nodes we know that 0 is the root
+        current_node  = 0;
+        current_level = 0;
+        next_children = 0;
+        if (level != 0) {
+            next();
+        }
+    }
+
+    void advance_depth_first() {
+        auto children = tree_->children(current_node);
+        if (next_children < children.size() && current_level <= only_on_level) {
+            // go to next children
+            current_level += 1;
+            current_node = children[next_children];
+            next_children = 0;
+        } else {
+            // go to parent
+            auto parent_node = tree_->parents()[current_node];
+            constexpr unsigned npos = unsigned(-1);
+            if (parent_node != npos) {
+                auto siblings = tree_->children(parent_node);
+                // get the index in the child list of the parent
+                unsigned index = 0;
+                while (siblings[index] != current_node) { // TODO replace by array lookup: sibling_nr
+                    index += 1;
+                }
+
+                current_level -= 1;
+                current_node = parent_node;
+                next_children = index + 1;
+            } else {
+                // we are done with the iteration
+                current_level = -1;
+                current_node  = -1;
+                next_children = -1;
+            }
+
+        }
+    }
+
+    unsigned next() {
+        constexpr unsigned npos = unsigned(-1);
+        if (!valid()) {
+            // we are done
+            return npos;
+        } else {
+            advance_depth_first();
+            // next_children != 0 means that we have seen this node before
+            while (valid() && (current_level != only_on_level || next_children != 0)) {
+                advance_depth_first();
+            }
+            return current_node;
+        }
+    }
+
+    bool valid() {
+        constexpr unsigned npos = unsigned(-1);
+        return this->peek() != npos;
+    }
+
+    unsigned peek() {
+        return current_node;
+    }
+
+private:
+    tree* tree_;
+
+    unsigned current_node;
+    unsigned current_level;
+    unsigned next_children;
+
+    unsigned only_on_level;
+};
+
+
+} // namespace gpu
+} // namespace arb
diff --git a/arbor/backends/gpu/fvm.hpp b/arbor/backends/gpu/fvm.hpp
index 9d02eea3..cc82bed0 100644
--- a/arbor/backends/gpu/fvm.hpp
+++ b/arbor/backends/gpu/fvm.hpp
@@ -14,9 +14,15 @@
 #include "backends/gpu/gpu_store_types.hpp"
 #include "backends/gpu/shared_state.hpp"
 
-#include "matrix_state_interleaved.hpp"
 #include "threshold_watcher.hpp"
 
+
+#ifdef ARB_HAVE_GPU_FINE_MATRIX
+    #include "matrix_state_fine.hpp"
+#else
+    #include "matrix_state_interleaved.hpp"
+#endif
+
 namespace arb {
 namespace gpu {
 
@@ -39,7 +45,11 @@ struct backend {
         return memory::on_host(v);
     }
 
+#ifdef ARB_HAVE_GPU_FINE_MATRIX
+    using matrix_state = arb::gpu::matrix_state_fine<value_type, index_type>;
+#else
     using matrix_state = arb::gpu::matrix_state_interleaved<value_type, index_type>;
+#endif
     using threshold_watcher = arb::gpu::threshold_watcher;
 
     using deliverable_event_stream = arb::gpu::deliverable_event_stream;
diff --git a/arbor/backends/gpu/matrix_fine.cpp b/arbor/backends/gpu/matrix_fine.cpp
new file mode 100644
index 00000000..b1d7a227
--- /dev/null
+++ b/arbor/backends/gpu/matrix_fine.cpp
@@ -0,0 +1,71 @@
+#include <ostream>
+
+#include <cuda_runtime.h>
+
+#include "memory/cuda_wrappers.hpp"
+#include "util/span.hpp"
+
+#include "matrix_fine.hpp"
+
+namespace arb {
+namespace gpu {
+
+level::level(unsigned branches):
+    num_branches(branches)
+{
+    using memory::cuda_malloc_managed;
+
+    if (num_branches!=0) {
+        lengths = static_cast<unsigned*>(cuda_malloc_managed(num_branches*sizeof(unsigned)));
+        parents = static_cast<unsigned*>(cuda_malloc_managed(num_branches*sizeof(unsigned)));
+        cudaDeviceSynchronize();
+    }
+}
+
+level::level(level&& other) {
+    std::swap(other.lengths, this->lengths);
+    std::swap(other.parents, this->parents);
+    std::swap(other.num_branches, this->num_branches);
+    std::swap(other.max_length, this->max_length);
+    std::swap(other.data_index, this->data_index);
+}
+
+level::level(const level& other) {
+    using memory::cuda_malloc_managed;
+
+    num_branches = other.num_branches;
+    max_length = other.max_length;
+    data_index = other.data_index;
+    if (num_branches!=0) {
+        lengths = static_cast<unsigned*>(cuda_malloc_managed(num_branches*sizeof(unsigned)));
+        parents = static_cast<unsigned*>(cuda_malloc_managed(num_branches*sizeof(unsigned)));
+        cudaDeviceSynchronize();
+        std::copy(other.lengths, other.lengths+num_branches, lengths);
+        std::copy(other.parents, other.parents+num_branches, parents);
+    }
+}
+
+level::~level() {
+    if (num_branches!=0) {
+        cudaDeviceSynchronize(); // ensure no kernel is still using the managed memory
+        if (lengths) arb::memory::cuda_free(lengths);
+        if (parents) arb::memory::cuda_free(parents);
+    }
+}
+
+std::ostream& operator<<(std::ostream& o, const level& l) {
+    cudaDeviceSynchronize();
+    o << "branches:" << l.num_branches
+      << " max_len:" << l.max_length
+      << " data_idx:" << l.data_index
+      << " lengths:[";
+    for (auto i: util::make_span(l.num_branches)) o << l.lengths[i] << " ";
+    o << "] parents:[";
+    for (auto i: util::make_span(l.num_branches)) o << l.parents[i] << " ";
+    o << "]";
+    return o;
+}
+
+} // namespace gpu
+} // namespace arb
diff --git a/arbor/backends/gpu/matrix_fine.cu b/arbor/backends/gpu/matrix_fine.cu
new file mode 100644
index 00000000..764b39d5
--- /dev/null
+++ b/arbor/backends/gpu/matrix_fine.cu
@@ -0,0 +1,314 @@
+#include <arbor/fvm_types.hpp>
+
+#include "cuda_atomic.hpp"
+#include "cuda_common.hpp"
+#include "matrix_common.hpp"
+#include "matrix_fine.hpp"
+
+namespace arb {
+namespace gpu {
+
+namespace kernels {
+
+//
+// gather and scatter kernels
+//
+
+// to[i] = from[p[i]]
+template <typename T, typename I>
+__global__
+void gather(const T* from, T* to, const I* p, unsigned n) {
+    unsigned i = threadIdx.x + blockDim.x*blockIdx.x;
+
+    if (i<n) {
+        to[i] = from[p[i]];
+    }
+}
+
+// to[p[i]] = from[i]
+template <typename T, typename I>
+__global__
+void scatter(const T* from, T* to, const I* p, unsigned n) {
+    unsigned i = threadIdx.x + blockDim.x*blockIdx.x;
+
+    if (i<n) {
+        to[p[i]] = from[i];
+    }
+}
+
+/// GPU implementation of Hines matrix assembly.
+/// Fine layout.
+/// For a given time step size dt:
+///     - use the precomputed alpha and alpha_d values to construct the diagonal
+///       and off diagonal of the symmetric Hines matrix.
+///     - compute the RHS of the linear system to solve.
+template <typename T, typename I>
+__global__
+void assemble_matrix_fine(
+        T* d,
+        T* rhs,
+        const T* invariant_d,
+        const T* voltage,
+        const T* current,
+        const T* cv_capacitance,
+        const T* area,
+        const I* cv_to_cell,
+        const T* dt_cell,
+        const I* perm,
+        unsigned n)
+{
+    const unsigned tid = threadIdx.x + blockDim.x*blockIdx.x;
+
+    if (tid<n) {
+        auto cid = cv_to_cell[tid];
+        auto dt = dt_cell[cid];
+
+        if (dt>0) {
+            // The 1e-3 is a constant of proportionality required to ensure that the
+            // conductance (gi) values have units μS (micro-Siemens).
+            // See the model documentation in docs/model for more information.
+            T factor = 1e-3/dt;
+
+            const auto gi = factor * cv_capacitance[tid];
+            const auto pid = perm[tid];
+            d[pid] = gi + invariant_d[tid];
+            rhs[pid] = gi*voltage[tid] - T(1e-3)*area[tid]*current[tid];
+        }
+        else {
+            const auto pid = perm[tid];
+            d[pid] = 0;
+            rhs[pid] = voltage[tid];
+        }
+    }
+}
+
+/// GPU implementation of the Hines matrix solver.
+/// Fine-grained tree-based solver.
+/// Each block solves a set of matrices by iterating over the levels of the
+/// matrix and performing backward and forward substitution. On each level,
+/// one thread is assigned to each branch on that level of the matrix, and
+/// performs the substitution. Afterwards all threads continue on the next
+/// level.
+/// To avoid idle threads, each level should contain a similar number of
+/// branches.
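+///
+/// Data layout (derived from the substitution loops below): within a level,
+/// node k of branch b, counting from the distal end of the branch, is stored
+/// at data_index + k*num_branches + b, so consecutive threads access
+/// consecutive memory locations on each step.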
+template <typename T>
+__global__
+void solve_matrix_fine(
+    T* rhs,
+    T* d,
+    const T* u,
+    const level* levels,
+    const unsigned* levels_start,
+    unsigned* num_matrix, // number of packed matrices = number of cells
+    unsigned* padded_size)
+{
+    const auto tid = threadIdx.x;
+    const auto bid = blockIdx.x;
+
+    const auto first_level = levels_start[bid];
+    const auto num_levels  = levels_start[bid + 1] - first_level;
+    const auto block_levels = &levels[first_level];
+
+    // backward substitution
+
+    for (unsigned l=0; l<num_levels-1; ++l) {
+        const auto& lvl = block_levels[l];
+
+        const unsigned width = lvl.num_branches;
+
+        // Perform backward substitution for each branch on this level.
+        // One thread per branch.
+        if (tid < width) {
+            const unsigned len = lvl.lengths[tid];
+            unsigned pos = lvl.data_index + tid;
+
+            // Zero diagonal term implies dt==0; just leave rhs (for whole matrix)
+            // alone in that case.
+
+            // Each cell has a different `dt`, because we choose time step size
+            // according to when the next event is arriving at a cell. So, some
+            // cells require more time steps than others, but we have to solve
+            // all the matrices at the same time. When a cell finishes, we put a
+            // `0` on the diagonal to mark that it should not be solved for.
+            if (d[pos]!=0) {
+
+                // each branch perform substitution
+                T factor = u[pos] / d[pos];
+                for (unsigned i=0; i<len-1; ++i) {
+                    const unsigned next_pos = pos + width;
+                    d[next_pos]   -= factor * u[pos];
+                    rhs[next_pos] -= factor * rhs[pos];
+
+                    factor = u[next_pos] / d[next_pos];
+                    pos = next_pos;
+                }
+
+                // Update d and rhs at the parent node of this branch.
+                // A parent may have more than one branch contributing to it,
+                // so we use atomic updates to avoid race conditions.
+                const unsigned parent_index = block_levels[l+1].data_index;
+                const unsigned p = parent_index + lvl.parents[tid];
+                //d[p]   -= factor * u[pos];
+                cuda_atomic_add(d  +p, -factor*u[pos]);
+                //rhs[p] -= factor * rhs[pos];
+                cuda_atomic_add(rhs+p, -factor*rhs[pos]);
+            }
+        }
+        __syncthreads();
+    }
+
+    {
+        // the levels are sorted such that the root is the last level
+        const auto& lvl = block_levels[num_levels-1];
+        const unsigned width = num_matrix[bid];
+
+        if (tid < width) {
+            const unsigned len = lvl.lengths[tid];
+            unsigned pos = lvl.data_index + tid;
+
+            if (d[pos]!=0) {
+
+                // backward
+                for (unsigned i=0; i<len-1; ++i) {
+                    T factor = u[pos] / d[pos];
+                    const unsigned next_pos = pos + width;
+                    d[next_pos]   -= factor * u[pos];
+                    rhs[next_pos] -= factor * rhs[pos];
+
+                    pos = next_pos;
+                }
+
+                auto rhsp = rhs[pos] / d[pos];
+                rhs[pos] = rhsp;
+                pos -= width;
+
+                // forward
+                for (unsigned i=0; i<len-1; ++i) {
+                    rhsp = rhs[pos] - u[pos]*rhsp;
+                    rhsp /= d[pos];
+                    rhs[pos] = rhsp;
+                    pos -= width;
+                }
+            }
+        }
+    }
+
+    // forward substitution
+
+    // take care with the loop limits: l is unsigned, so we iterate while l > 0
+    // and index with l-1
+    for (unsigned l=num_levels-1; l>0; --l) {
+        const auto& lvl = block_levels[l-1];
+
+        const unsigned width = lvl.num_branches;
+        const unsigned parent_index = block_levels[l].data_index;
+
+        __syncthreads();
+
+        // Perform forward-substitution for each branch on this level.
+        // One thread per branch.
+        if (tid < width) {
+            // Load the rhs value for the parent node of this branch.
+            const unsigned p = parent_index + lvl.parents[tid];
+            T rhsp = rhs[p];
+
+            // Find the index of the first node in this branch.
+            const unsigned len = lvl.lengths[tid];
+            unsigned pos = lvl.data_index + (len-1)*width + tid;
+
+            if (d[pos]!=0) {
+                // each branch perform substitution
+                for (unsigned i=0; i<len; ++i) {
+                    rhsp = rhs[pos] - u[pos]*rhsp;
+                    rhsp /= d[pos];
+                    rhs[pos] = rhsp;
+                    pos -= width;
+                }
+            }
+        }
+    }
+}
+
+} // namespace kernels
+
+void gather(
+    const fvm_value_type* from,
+    fvm_value_type* to,
+    const fvm_index_type* p,
+    unsigned n)
+{
+    constexpr unsigned blockdim = 128;
+    const unsigned griddim = impl::block_count(n, blockdim);
+
+    kernels::gather<<<griddim, blockdim>>>(from, to, p, n);
+}
+
+void scatter(
+    const fvm_value_type* from,
+    fvm_value_type* to,
+    const fvm_index_type* p,
+    unsigned n)
+{
+    constexpr unsigned blockdim = 128;
+    const unsigned griddim = impl::block_count(n, blockdim);
+
+    kernels::scatter<<<griddim, blockdim>>>(from, to, p, n);
+}
+
+
+void assemble_matrix_fine(
+    fvm_value_type* d,
+    fvm_value_type* rhs,
+    const fvm_value_type* invariant_d,
+    const fvm_value_type* voltage,
+    const fvm_value_type* current,
+    const fvm_value_type* cv_capacitance,
+    const fvm_value_type* area,
+    const fvm_index_type* cv_to_cell,
+    const fvm_value_type* dt_cell,
+    const fvm_index_type* perm,
+    unsigned n)
+{
+    const unsigned block_dim = 128;
+    const unsigned num_blocks = impl::block_count(n, block_dim);
+
+    kernels::assemble_matrix_fine<<<num_blocks, block_dim>>>(
+        d, rhs, invariant_d, voltage, current, cv_capacitance, area,
+        cv_to_cell, dt_cell,
+        perm, n);
+}
+
+// Example:
+//
+//         block 0                  block 1              block 2
+// .~~~~~~~~~~~~~~~~~~.  .~~~~~~~~~~~~~~~~~~~~~~~~.  .~~~~~~~~~~~ ~ ~
+//
+//  L0 \  /                                           L5    \  /
+//      \/                                                   \/
+//  L1   \   /   \   /    L3 \   /   \ | /   \   /    L6 \   /  . . .
+//        \ /     \ /         \ /     \|/     \ /         \ /
+//  L2     |       |      L4   |       |       |      L7   |
+//         |       |           |       |       |           |
+//
+// levels       = [L0, L1, L2, L3, L4, L5, L6, L7, ... ]
+// levels_start = [0, 3, 5, 8, ...]
+// num_levels   = [3, 2, 3, ...]
+// num_cells    = [2, 3, ...]
+// num_blocks   = levels_start.size() - 1 = num_levels.size() = num_cells.size()
+void solve_matrix_fine(
+    fvm_value_type* rhs,
+    fvm_value_type* d,                // diagonal values
+    const fvm_value_type* u,          // upper diagonal (and lower diagonal as the matrix is SPD)
+    const level* levels,              // pointer to an array containing level meta-data for all blocks
+    const unsigned* levels_start,     // start index into levels for each cuda block
+    unsigned* num_cells,              // the number of cells packed into each block
+    unsigned* padded_size,            // length of rhs, d, u, including padding
+    unsigned num_blocks,              // number of blocks
+    unsigned blocksize)               // size of each block
+{
+    kernels::solve_matrix_fine<<<num_blocks, blocksize>>>(
+        rhs, d, u, levels, levels_start,
+        num_cells, padded_size);
+}
+
+} // namespace gpu
+} // namespace arb
diff --git a/arbor/backends/gpu/matrix_fine.hpp b/arbor/backends/gpu/matrix_fine.hpp
new file mode 100644
index 00000000..547b80e2
--- /dev/null
+++ b/arbor/backends/gpu/matrix_fine.hpp
@@ -0,0 +1,77 @@
+#pragma once
+
+#include <arbor/fvm_types.hpp>
+
+#include <ostream>
+
+namespace arb {
+namespace gpu {
+
+struct level {
+    level() = default;
+
+    level(unsigned branches);
+    level(level&& other);
+    level(const level& other);
+
+    ~level();
+
+    unsigned num_branches = 0; // Number of branches
+    unsigned max_length = 0;   // Length of the longest branch
+    unsigned data_index = 0;   // Index into data values of the first branch
+
+    //  The lengths and parents vectors are raw pointers into managed memory,
+    //  so no deep copy of this type to the GPU is needed.
+
+    // An array holding the length of each branch in the level.
+    // length: num_branches.
+    unsigned* lengths = nullptr;
+
+    // An array with the index of the parent branch for each branch on this level.
+    // length: num_branches.
+    // When performing backward/forward substitution we need to update/read
+    // data values for the parent node for each branch.
+    // This can be done easily if we know where the parent branch is located
+    // on the next level.
+    unsigned* parents = nullptr;
+};
+
+std::ostream& operator<<(std::ostream& o, const level& l);
+
+// C wrappers around kernels
+void gather(
+    const fvm_value_type* from,
+    fvm_value_type* to,
+    const fvm_index_type* p,
+    unsigned n);
+
+void scatter(
+    const fvm_value_type* from,
+    fvm_value_type* to,
+    const fvm_index_type* p,
+    unsigned n);
+
+void assemble_matrix_fine(
+    fvm_value_type* d,
+    fvm_value_type* rhs,
+    const fvm_value_type* invariant_d,
+    const fvm_value_type* voltage,
+    const fvm_value_type* current,
+    const fvm_value_type* cv_capacitance,
+    const fvm_value_type* area,
+    const fvm_index_type* cv_to_cell,
+    const fvm_value_type* dt_cell,
+    const fvm_index_type* perm,
+    unsigned n);
+
+void solve_matrix_fine(
+    fvm_value_type* rhs,
+    fvm_value_type* d,                // diagonal values
+    const fvm_value_type* u,          // upper diagonal (and lower diagonal as the matrix is SPD)
+    const level* levels,              // pointer to an array containing level meta-data for all blocks
+    const unsigned* levels_start,     // start index into levels for each cuda block
+    unsigned* num_cells,              // the number of cells packed into each block
+    unsigned* padded_size,            // length of rhs, d, u, including padding
+    unsigned num_blocks,              // number of blocks
+    unsigned blocksize);              // size of each block
+
+} // namespace gpu
+} // namespace arb
diff --git a/arbor/backends/gpu/matrix_solve.cu b/arbor/backends/gpu/matrix_solve.cu
index 1a9ab8c6..7fa719c8 100644
--- a/arbor/backends/gpu/matrix_solve.cu
+++ b/arbor/backends/gpu/matrix_solve.cu
@@ -10,6 +10,7 @@ namespace kernels {
 
 /// GPU implementation of Hines Matrix solver.
 /// Flat format
+/// p: parent index for each variable. Needed for backward and forward sweep
 template <typename T, typename I>
 __global__
 void solve_matrix_flat(
diff --git a/arbor/backends/gpu/matrix_state_fine.hpp b/arbor/backends/gpu/matrix_state_fine.hpp
new file mode 100644
index 00000000..c457fe24
--- /dev/null
+++ b/arbor/backends/gpu/matrix_state_fine.hpp
@@ -0,0 +1,494 @@
+#pragma once
+
+#include <cstring>
+
+#include <vector>
+#include <type_traits>
+
+#include <arbor/common_types.hpp>
+
+#include "algorithms.hpp"
+#include "memory/memory.hpp"
+#include "util/partition.hpp"
+#include "util/rangeutil.hpp"
+#include "util/span.hpp"
+#include "tree.hpp"
+
+#include "matrix_fine.hpp"
+#include "forest.hpp"
+
+namespace arb {
+namespace gpu {
+
+// Helper type for branch meta data in setup phase of fine grained
+// matrix storage+solver.
+//
+//      leaf
+//      .
+//      .
+//      .
+//  -   *
+//      |
+//  l   *
+//  e   |
+//  n   *
+//  g   |
+//  t   *
+//  h   |
+//  -   start_idx
+//      |
+//      parent_idx
+//      |
+//      .
+//      .
+//      .
+//      root
+struct branch {
+    unsigned id;         // branch id
+    unsigned parent_id;  // parent branch id
+    unsigned parent_idx; // index of the parent node in the input parent index
+    unsigned start_idx;  // the index of the first node in the input parent index
+    unsigned length;     // the number of nodes in the branch
+};
+
+// order branches by:
+//  - descending length
+//  - ascending id
+inline
+bool operator<(const branch& lhs, const branch& rhs) {
+    if (lhs.length!=rhs.length) {
+        return lhs.length>rhs.length;
+    } else {
+        return lhs.id<rhs.id;
+    }
+}
+
+inline
+std::ostream& operator<<(std::ostream& o, branch b) {
+    return o << "[" << b.id
+        << ", len " << b.length
+        << ", pid " << b.parent_idx
+        << ", sta " << b.start_idx
+        << "]";
+}
+
+template <typename T, typename I>
+struct matrix_state_fine {
+public:
+    using value_type = T;
+    using size_type = I;
+
+    using array      = memory::device_vector<value_type>;
+    using view       = typename array::view_type;
+    using const_view = typename array::const_view_type;
+    using iarray     = memory::device_vector<size_type>;
+
+    template <typename ValueType>
+    using managed_vector = std::vector<ValueType, memory::managed_allocator<ValueType>>;
+
+    iarray cv_to_cell;
+
+    array d;     // [μS]
+    array u;     // [μS]
+    array rhs;   // [nA]
+
+    // required for matrix assembly
+
+    array cv_area; // [μm^2]
+
+    array cv_capacitance;      // [pF]
+
+    // the invariant part of the matrix diagonal
+    array invariant_d;         // [μS]
+
+    // for storing the solution in unpacked format
+    array solution_;
+
+    // the maximum number of branches in each level per block
+    unsigned max_branches_per_level;
+
+    // number of rows in matrix
+    unsigned matrix_size;
+
+    // number of cells
+    unsigned num_cells;
+    managed_vector<unsigned> num_cells_in_block;
+
+    // end of the data of each level
+    // use data_size.back() to get the total data size
+    //      data_size >= size
+    managed_vector<unsigned> data_size; // TODO rename
+
+    // the meta data for each level for each block, laid out linearly in memory
+    managed_vector<level> levels;
+    // the start of the levels of each block
+    // block b owns { levels[levels_start[b]], ..., levels[levels_start[b+1] - 1] }
+    // there is an additional entry at the end of the vector to make the above
+    // computation safe
+    managed_vector<unsigned> levels_start;
+
+    // permutation from front end storage to packed storage
+    //      `solver_format[perm[i]] = external_format[i]`
+    iarray perm;
+
+
+    matrix_state_fine() = default;
+
+    // constructor for fine-grained matrix.
+    matrix_state_fine(const std::vector<size_type>& p,
+                 const std::vector<size_type>& cell_cv_divs,
+                 const std::vector<value_type>& cap,
+                 const std::vector<value_type>& face_conductance,
+                 const std::vector<value_type>& area)
+    {
+        using util::make_span;
+        constexpr unsigned npos = unsigned(-1);
+
+        max_branches_per_level = 128;
+
+        // for now we have single cell per cell group
+        arb_assert(cell_cv_divs.size()==2);
+
+        num_cells = cell_cv_divs.size()-1;
+
+        forest trees(p, cell_cv_divs);
+        trees.optimize();
+
+        // Now distribute the cells into cuda blocks.
+        // While the total number of branches on each level of the cells in a
+        // block is less than `max_branches_per_level` we add more cells. If
+        // a block is full, we start a new cuda block.
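+        // For example, with max_branches_per_level = 2 and three cells whose
+        // per-level branch counts are {1, 2}, {1} and {1, 1}: cells 0 and 1
+        // share block 0 (occupancy {2, 2}), while cell 2 would raise level 0
+        // to 3 branches and therefore starts block 1.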
+
+        unsigned current_block = 0;
+        std::vector<unsigned> block_num_branches_per_depth;
+        std::vector<unsigned> block_ix(num_cells);
+        num_cells_in_block.resize(1, 0);
+
+        // branch_map = branch_maps[block] is a branch map for each cuda block
+        // branch_map[depth] is the list of branches at this depth
+        // each branch branch_map[depth][i] has
+        // {id, parent_id, start_idx, parent_idx, length}
+        std::vector<std::vector<std::vector<branch>>> branch_maps;
+        branch_maps.resize(1);
+
+        unsigned num_branches = 0u;
+        for (auto c: make_span(0u, num_cells)) {
+            auto cell_start = cell_cv_divs[c];
+            auto cell_tree = trees.branch_tree(c);
+            auto fine_tree = trees.compartment_tree(c);
+            auto branch_starts  = trees.branch_offsets(c);
+            auto branch_lengths = trees.branch_lengths(c);
+
+            auto depths = depth_from_root(cell_tree);
+
+            // calculate the number of levels in this cell
+            auto cell_num_levels = util::max_value(depths)+1u;
+
+            auto num_cell_branches = cell_tree.num_segments();
+
+            // count number of branches per level
+            std::vector<unsigned> cell_num_branches_per_depth(cell_num_levels, 0u);
+            for (auto i: make_span(num_cell_branches)) {
+                cell_num_branches_per_depth[depths[i]] += 1;
+            }
+            // resize the block levels if necessary
+            if (cell_num_levels > block_num_branches_per_depth.size()) {
+                block_num_branches_per_depth.resize(cell_num_levels, 0);
+            }
+
+
+            // check if we can fit the current cell into the last cuda block
+            bool fits_current_block = true;
+            for (auto i: make_span(cell_num_levels)) {
+                unsigned new_branches_per_depth =
+                    block_num_branches_per_depth[i]
+                    + cell_num_branches_per_depth[i];
+                if (new_branches_per_depth > max_branches_per_level) {
+                    fits_current_block = false;
+                }
+            }
+            if (fits_current_block) {
+                // put the cell into current block
+                block_ix[c] = current_block;
+                num_cells_in_block[block_ix[c]] += 1;
+                // and increment counter
+                for (auto i: make_span(cell_num_levels)) {
+                    block_num_branches_per_depth[i] += cell_num_branches_per_depth[i];
+                }
+            } else {
+                // otherwise start a new block
+                block_ix[c] = current_block + 1;
+                num_cells_in_block.push_back(1);
+                branch_maps.resize(branch_maps.size()+1);
+                current_block += 1;
+                // and reset counter
+                block_num_branches_per_depth = cell_num_branches_per_depth;
+                for (auto num: block_num_branches_per_depth) {
+                    if (num > max_branches_per_level) {
+                        throw std::runtime_error(
+                            "Could not fit " + std::to_string(num)
+                            + " branches in a block of size "
+                            + std::to_string(max_branches_per_level));
+                    }
+                }
+            }
+
+
+            // the branch map for the block in which we put the cell
+            // maps levels to a list of branches in that level
+            auto& branch_map = branch_maps[block_ix[c]];
+
+            // build branch_map:
+            // branch_map[i] is a list of branch meta-data for branches with depth i
+            if (cell_num_levels > branch_map.size()) {
+                branch_map.resize(cell_num_levels);
+            }
+            for (auto i: make_span(num_cell_branches)) {
+                branch b;
+                auto depth = depths[i];
+                // give the branch a unique id number
+                b.id = i + num_branches;
+                // take care to mark branches with no parents with npos
+                b.parent_id = cell_tree.parent(i)==cell_tree.no_parent ?
+                    npos : cell_tree.parent(i) + num_branches;
+                b.start_idx = branch_starts[i] + cell_start;
+                b.length = branch_lengths[i];
+                b.parent_idx = p[b.start_idx] + cell_start;
+                branch_map[depth].push_back(b);
+            }
+            // total number of branches of all cells
+            num_branches += num_cell_branches;
+        }
+
+        for (auto& branch_map: branch_maps) {
+            // reverse the levels
+            std::reverse(branch_map.begin(), branch_map.end());
+
+            // Sort all branches on each level in descending order of length.
+            // Later, branches will be partitioned over thread blocks, and we will
+            // take advantage of the fact that the first branch in a partition is
+            // the longest, to determine how to pack all the branches in a block.
+            for (auto& branches: branch_map) {
+                util::sort(branches);
+            }
+        }
+
+        // The branches generated above have been assigned contiguous ids.
+        // Now generate a vector of branch_loc, one for each branch, that
+        // allow for quick lookup by id of the level and index within a level
+        // of each branch.
+        // This information is only used in the generation of the levels below.
+
+        // Helper for recording location of a branch once packed.
+        struct branch_loc {
+            unsigned block; // the cuda block containing the cell to which the branch belongs
+            unsigned level; // the level containing the branch
+            unsigned index; // the index of the branch on that level
+        };
+
+        // branch_locs will hold the location information for each branch.
+        std::vector<branch_loc> branch_locs(num_branches);
+        for (unsigned b: make_span(branch_maps.size())) {
+            const auto& branch_map = branch_maps[b];
+            for (unsigned l: make_span(branch_map.size())) {
+                const auto& branches = branch_map[l];
+
+                // Record the location information
+                for (auto i=0u; i<branches.size(); ++i) {
+                    const auto& branch = branches[i];
+                    branch_locs[branch.id] = {b, l, i};
+                }
+            }
+        }
+
+        unsigned total_num_levels = std::accumulate(
+            branch_maps.begin(), branch_maps.end(), 0,
+            [](unsigned value, decltype(branch_maps[0])& l) {
+                return value + l.size();});
+
+        // construct the description of the set of branches on each level for
+        // each block. This is later used to sort the branches in each block
+        // on each level into contiguous chunks, which are easier to read for
+        // the cuda kernel.
+        levels.reserve(total_num_levels);
+        levels_start.reserve(branch_maps.size() + 1);
+        levels_start.push_back(0);
+        data_size.reserve(branch_maps.size());
+        // offset into the packed data format, used to apply permutation on data
+        auto pos = 0u;
+        for (const auto& branch_map: branch_maps) {
+            for (const auto& lvl_branches: branch_map) {
+
+                level lvl(lvl_branches.size());
+
+                // The length of the first branch is the upper bound on branch
+                // length as they are sorted in descending order of length.
+                lvl.max_length = lvl_branches.front().length;
+                lvl.data_index = pos;
+
+                unsigned bi = 0u;
+                for (const auto& b: lvl_branches) {
+                    // Set the length of the branch.
+                    lvl.lengths[bi] = b.length;
+
+                    // Set the parent indexes. During the forward and backward
+                    // substitution phases each branch accesses the last node in
+                    // its parent branch.
+                    auto index = b.parent_id==npos? npos: branch_locs[b.parent_id].index;
+                    lvl.parents[bi] = index;
+                    ++bi;
+                }
+
+                pos += lvl.max_length*lvl.num_branches;
+
+                levels.push_back(std::move(lvl));
+            }
+            auto prev_end = levels_start.back();
+            levels_start.push_back(prev_end + branch_map.size());
+            data_size.push_back(pos);
+        }
+
+        // set matrix state
+        matrix_size = p.size();
+
+        // form the permutation index used to reorder vectors to/from the
+        // ordering used by the fine grained matrix storage.
+        std::vector<size_type> perm_tmp(matrix_size);
+        for (auto block: make_span(branch_maps.size())) {
+            const auto& branch_map = branch_maps[block];
+            const auto first_level = levels_start[block];
+
+            for (auto i: make_span(levels_start[block + 1] - first_level)) {
+                const auto& l = levels[first_level + i];
+                for (auto j: make_span(l.num_branches)) {
+                    const auto& b = branch_map[i][j];
+                    auto to = l.data_index + j + l.num_branches*(l.lengths[j]-1);
+                    auto from = b.start_idx;
+                    for (auto k: make_span(b.length)) {
+                        perm_tmp[from + k] = to - k*l.num_branches;
+                    }
+                }
+            }
+        }
+
+        auto perm_balancing = trees.permutation();
+
+        // apply the permutation from the balancing
+        std::vector<size_type> perm_tmp2(matrix_size);
+        for (auto i: make_span(matrix_size)) {
+            // note: this mapping is correct; verified with the ring benchmark with root=0 (where the permutation is not the identity)
+            perm_tmp2[perm_balancing[i]] = perm_tmp[i];
+        }
+        // copy permutation to device memory
+        perm = memory::make_const_view(perm_tmp2);
+
+
+        // Summary of fields and their storage format:
+        //
+        // face_conductance : not needed, don't store
+        // d, u, rhs        : packed
+        // cv_capacitance   : flat
+        // invariant_d      : flat
+        // solution_        : flat
+        // cv_to_cell       : flat
+        // area             : flat
+
+        // the invariant part of d is stored in flat form
+        std::vector<value_type> invariant_d_tmp(matrix_size, 0);
+        managed_vector<value_type> u_tmp(matrix_size, 0);
+        for (auto i: make_span(1u, matrix_size)) {
+            auto gij = face_conductance[i];
+
+            u_tmp[i] = -gij;
+            invariant_d_tmp[i] += gij;
+            invariant_d_tmp[p[i]] += gij;
+        }
+
+        // the matrix components u, d and rhs are stored in packed form
+        auto nan = std::numeric_limits<double>::quiet_NaN();
+        d   = array(data_size.back(), nan);
+        u   = array(data_size.back(), nan);
+        rhs = array(data_size.back(), nan);
+
+        // transform u_tmp values into packed u vector.
+        flat_to_packed(u_tmp, u);
+
+        // the invariant part of d, cv_area and the solution are in flat form
+        solution_ = array(matrix_size, 0);
+        cv_area = memory::make_const_view(area);
+
+        // the cv_capacitance can be copied directly because it is
+        // to be stored in flat format
+        cv_capacitance = memory::make_const_view(cap);
+        invariant_d = memory::make_const_view(invariant_d_tmp);
+
+        // calculate the cv -> cell mappings
+        std::vector<size_type> cv_to_cell_tmp(matrix_size);
+        size_type ci = 0;
+        for (auto cv_span: util::partition_view(cell_cv_divs)) {
+            util::fill(util::subrange_view(cv_to_cell_tmp, cv_span), ci);
+            ++ci;
+        }
+        cv_to_cell = memory::make_const_view(cv_to_cell_tmp);
+    }
+
+    // Assemble the matrix
+    // Afterwards the diagonal and RHS will have been set given dt, voltage and current
+    //   dt_cell [ms] (per cell)
+    //   voltage [mV]
+    //   current [nA]
+    void assemble(const_view dt_cell, const_view voltage, const_view current) {
+        assemble_matrix_fine(
+            d.data(),
+            rhs.data(),
+            invariant_d.data(),
+            voltage.data(),
+            current.data(),
+            cv_capacitance.data(),
+            cv_area.data(),
+            cv_to_cell.data(),
+            dt_cell.data(),
+            perm.data(),
+            size());
+    }
+
+    void solve() {
+        solve_matrix_fine(
+            rhs.data(), d.data(), u.data(),
+            levels.data(), levels_start.data(),
+            num_cells_in_block.data(),
+            data_size.data(),
+            num_cells_in_block.size(), max_branches_per_level);
+
+        // unpermute the solution
+        packed_to_flat(rhs, solution_);
+    }
+
+    const_view solution() const {
+        return solution_;
+    }
+
+    template <typename VFrom, typename VTo>
+    void flat_to_packed(const VFrom& from, VTo& to ) {
+        arb_assert(from.size()==matrix_size);
+        arb_assert(to.size()==data_size.back());
+
+        scatter(from.data(), to.data(), perm.data(), perm.size());
+    }
+
+    template <typename VFrom, typename VTo>
+    void packed_to_flat(const VFrom& from, VTo& to ) {
+        arb_assert(from.size()==data_size.back());
+        arb_assert(to.size()==matrix_size);
+
+        gather(from.data(), to.data(), perm.data(), perm.size());
+    }
+
+private:
+    std::size_t size() const {
+        return matrix_size;
+    }
+};
+
+} // namespace gpu
+} // namespace arb
diff --git a/arbor/tree.cpp b/arbor/tree.cpp
new file mode 100644
index 00000000..a4bc28d6
--- /dev/null
+++ b/arbor/tree.cpp
@@ -0,0 +1,385 @@
+#include <algorithm>
+#include <cassert>
+#include <numeric>
+#include <queue>
+#include <vector>
+
+#include <arbor/common_types.hpp>
+
+#include "algorithms.hpp"
+#include "memory/memory.hpp"
+#include "tree.hpp"
+#include "util/span.hpp"
+
+namespace arb {
+
+tree::tree(std::vector<tree::int_type> parent_index) {
+    // validate the input
+    if(!algorithms::is_minimal_degree(parent_index)) {
+        throw std::domain_error(
+            "parent index used to build a tree did not satisfy minimal degree ordering"
+        );
+    }
+
+    // an empty parent_index implies a single-compartment/segment cell
+    arb_assert(parent_index.size()!=0u);
+
+    init(parent_index.size());
+    memory::copy(parent_index, parents_);
+    parents_[0] = no_parent;
+
+    // compute offsets into children_ array
+    memory::copy(algorithms::make_index(algorithms::child_count(parents_)), child_index_);
+
+    std::vector<int_type> pos(parents_.size(), 0);
+    for (auto i = 1u; i < parents_.size(); ++i) {
+        auto p = parents_[i];
+        children_[child_index_[p] + pos[p]] = i;
+        ++pos[p];
+    }
+}
+
+tree::size_type tree::num_children() const {
+    return static_cast<size_type>(children_.size());
+}
+
+tree::size_type tree::num_children(size_t b) const {
+    return child_index_[b+1] - child_index_[b];
+}
+
+tree::size_type tree::num_segments() const {
+    // the number of segments/nodes is the size of the child index minus 1
+    // ... except for the case of an empty tree
+    auto sz = static_cast<size_type>(child_index_.size());
+    return sz ? sz - 1 : 0;
+}
+
+const tree::iarray& tree::child_index() {
+    return child_index_;
+}
+
+const tree::iarray& tree::children() const {
+    return children_;
+}
+
+const tree::iarray& tree::parents() const {
+    return parents_;
+}
+
+const tree::int_type& tree::parent(size_t b) const {
+    return parents_[b];
+}
+tree::int_type& tree::parent(size_t b) {
+    return parents_[b];
+}
+
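+// Splits a node: a new node N is inserted between node ix and its parent, so
+// that `parent -> ix` becomes `parent -> N -> ix`. All node indices >= ix are
+// shifted up by one; the new index (ix+1) of the original node is returned.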
+tree::int_type tree::split_node(int_type ix) {
+    using util::make_span;
+
+    auto insert_at_p  = parents_.begin() + ix;
+    auto insert_at_ci = child_index_.begin() + ix;
+    auto insert_at_c  = children_.begin() + child_index_[ix];
+    auto new_node_ix  = ix;
+
+    // we first adjust the parent structure
+
+    // first create a new node N below the parent
+    auto parent = parents_[ix];
+    parents_.insert(insert_at_p, parent);
+    // and attach the remaining subtree below it
+    parents_[ix+1] = new_node_ix;
+    // shift all parents, as the indices changed when we
+    // inserted a new node
+    for (auto i: make_span(ix + 2, parents().size())) {
+        if (parents_[i] >= new_node_ix) {
+            parents_[i]++;
+        }
+    }
+
+    // now we adjust the children structure
+
+    // insert a child node for the new node N, pointing to
+    // the old node A
+    child_index_.insert(insert_at_ci, child_index_[ix]);
+    // insert a placeholder value; it is overwritten below
+    children_.insert(insert_at_c, ~0u);
+    // shift all larger indices, as we inserted a new
+    // element into the list
+    for (auto i: make_span(ix + 1, child_index_.size())) {
+        child_index_[i]++;
+    }
+    for (auto i: make_span(0, children_.size())) {
+        if(children_[i] > new_node_ix) {
+            children_[i]++;
+        }
+    }
+    // set the children of the new node to the old subtree
+    children_[child_index_[ix]] = ix + 1;
+
+    return ix+1;
+}
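+
+// Worked example (illustrative): on the three-node chain with parents
+// {no_parent, 0, 1}, split_node(1) inserts a new node at index 1 and shifts
+// the old node and its subtree, giving parents {no_parent, 0, 1, 2}.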
+
+tree::iarray tree::select_new_root(int_type root) {
+    using util::make_span;
+
+    const auto num_nodes = parents().size();
+
+    if(root >= num_nodes && root != no_parent) {
+        throw std::domain_error(
+            "root is out of bounds: root="+std::to_string(root)+", nodes="
+            +std::to_string(num_nodes)
+        );
+    }
+
+    // walk up to the old root and turn `parent->child` into `parent<-child`
+    auto prev = no_parent;
+    auto current = root;
+    while (current != no_parent) {
+        auto parent = parents_[current];
+        parents_[current] = prev;
+        prev = current;
+        current = parent;
+    }
+
+    // sort the nodes to obtain a minimal degree ordering, and keep the
+    // permutation so that per-node data (such as the `branch_starts` array)
+    // can be reordered in the same way.
+
+    // compute the depth for each node
+    iarray depth (num_nodes, 0);
+    // the depth when we don't count nodes that only have one child
+    iarray reduced_depth (num_nodes, 0);
+    // the index of the last node that passed this node on its way to the
+    // root. We need this to keep nodes that are part of the same reduced
+    // tree close together in the final sorting.
+    //
+    // Instead of the ordering on the left we want the one on the right.
+    // .-----------------------.
+    // |    0            0     |
+    // |   / \          / \    |
+    // |  1   2        1   3   |
+    // |  |   |        |   |   |
+    // |  3   4        2   4   |
+    // |     / \          / \  |
+    // |    5  6         5  6  |
+    // '-----------------------'
+    //
+    // we achieve this by sorting lexicographically by reduced_depth, then
+    // branch_ix, then depth. The resulting ordering satisfies the minimal
+    // degree ordering.
+    //
+    // Using the tree on the right above, we would get the following results:
+    //
+    // `depth`           `reduced_depth`   `branch_ix`
+    // .-----------.     .-----------.     .-----------.
+    // |    0      |     |    0      |     |    6      |
+    // |   / \     |     |   / \     |     |   / \     |
+    // |  1   1    |     |  1   1    |     |  2   6    |
+    // |  |   |    |     |  |   |    |     |  |   |    |
+    // |  2   2    |     |  1   1    |     |  2   6    |
+    // |     / \   |     |     / \   |     |     / \   |
+    // |    3   3  |     |    2   2  |     |    5   6  |
+    // '-----------'     '-----------'     '-----------'
+    iarray branch_ix (num_nodes, 0);
+    // we cannot use the existing `children_` array, as we have so far only
+    // updated the parent structure
+    auto new_num_children = algorithms::child_count(parents_);
+    for (auto n: make_span(num_nodes)) {
+        branch_ix[n] = n;
+        auto prev = n;
+        auto curr = parents_[n];
+
+        // find the way to the root
+        while (curr != no_parent) {
+            depth[n]++;
+            if (new_num_children[curr] > 1) {
+                reduced_depth[n]++;
+            }
+            branch_ix[curr] = branch_ix[prev];
+            curr = parents_[curr];
+        }
+    }
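+    // (note: `prev` keeps its initial value `n` throughout the while loop, so
+    // each ancestor's branch_ix ends up holding the last node, in iteration
+    // order, whose root path runs through it, as in the diagram above)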
+
+    // maps new indices to old indices
+    iarray indices (num_nodes);
+    // fill array with indices
+    for (auto i: make_span(num_nodes)) {
+        indices[i] = i;
+    }
+    // perform sort by depth index to get the permutation
+    std::sort(indices.begin(), indices.end(), [&](auto i, auto j){
+        if (reduced_depth[i] != reduced_depth[j]) {
+            return reduced_depth[i] < reduced_depth[j];
+        }
+        if (branch_ix[i] != branch_ix[j]) {
+            return branch_ix[i] < branch_ix[j];
+        }
+        return depth[i] < depth[j];
+    });
+    // maps old indices to new indices
+    iarray indices_inv (num_nodes);
+    // fill array with indices
+    for (auto i: make_span(num_nodes)) {
+        indices_inv[i] = i;
+    }
+    // perform sort
+    std::sort(indices_inv.begin(), indices_inv.end(), [&](auto i, auto j){
+        return indices[i] < indices[j];
+    });
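+    // (equivalently, the inverse permutation could be filled directly,
+    // avoiding the second sort:
+    //     for (auto i: make_span(num_nodes)) indices_inv[indices[i]] = i; )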
+
+    // translate the parent vector to new indices
+    for (auto i: make_span(num_nodes)) {
+        if (parents_[i] != no_parent) {
+            parents_[i] = indices_inv[parents_[i]];
+        }
+    }
+
+    iarray new_parents (num_nodes);
+    for (auto i: make_span(num_nodes)) {
+        new_parents[i] = parents_[indices[i]];
+    }
+    // the parent array now starts with the root, then its children, then
+    // their children, and so on
+
+    // recompute the children array
+    memory::copy(new_parents, parents_);
+    memory::copy(algorithms::make_index(algorithms::child_count(parents_)), child_index_);
+
+    std::vector<int_type> pos(parents_.size(), 0);
+    for (auto i = 1u; i < parents_.size(); ++i) {
+        auto p = parents_[i];
+        children_[child_index_[p] + pos[p]] = i;
+        ++pos[p];
+    }
+
+    return indices;
+}
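+
+// Usage sketch (names illustrative): reordering per-node data `data` to match
+// the new numbering returned by select_new_root:
+//
+//     tree t(parent_index);
+//     auto perm = t.select_new_root(new_root);
+//     std::vector<double> reordered(data.size());
+//     for (std::size_t i = 0; i < perm.size(); ++i) {
+//         reordered[i] = data[perm[i]];
+//     }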
+
+tree::iarray tree::minimize_depth() {
+    const auto num_nodes = parents().size();
+    tree::iarray seen(num_nodes, 0);
+
+    // find the furthest node from the root
+    std::queue<tree::int_type> queue;
+    queue.push(0); // start at the root node
+    seen[0] = 1;
+    auto front = queue.front();
+    // breadth-first traversal
+    while (!queue.empty()) {
+        front = queue.front();
+        queue.pop();
+        // we only have to check children as we started at the root node
+        auto cs = children(front);
+        for (auto c: cs) {
+            if (seen[c] == 0) {
+                seen[c] = 1;
+                queue.push(c);
+            }
+        }
+    }
+
+    auto u = front;
+
+    // find the furthest node from this node
+    std::fill(seen.begin(), seen.end(), 0);
+    queue.push(u);
+    seen[u] = 1;
+    front = queue.front();
+    // breadth-first traversal
+    while (!queue.empty()) {
+        front = queue.front();
+        queue.pop();
+        auto cs = children(front);
+        for (auto c: cs) {
+            if (seen[c] == 0) {
+                seen[c] = 1;
+                queue.push(c);
+            }
+        }
+        // also check the parent node, as we did not start at the root!
+        auto c = parent(front);
+        if (c != tree::no_parent && seen[c] == 0) {
+            seen[c] = 1;
+            queue.push(c);
+        }
+    }
+
+    auto v = front;
+
+    // now find the middle between u and v
+
+    // collect the paths from u and v to the root
+    tree::iarray path_to_root_u (1, u);
+    tree::iarray path_to_root_v (1, v);
+
+    auto curr = parent(u);
+    while (curr != tree::no_parent) {
+        path_to_root_u.push_back(curr);
+        curr = parent(curr);
+    }
+    curr = parent(v);
+    while (curr != tree::no_parent) {
+        path_to_root_v.push_back(curr);
+        curr = parent(curr);
+    }
+
+    // strip the common tail shared by both paths, remembering the lowest
+    // common ancestor of u and v
+    tree::int_type last_together = 0;
+    while (!path_to_root_u.empty() && !path_to_root_v.empty() &&
+           path_to_root_u.back() == path_to_root_v.back()) {
+        last_together = path_to_root_u.back();
+        path_to_root_u.pop_back();
+        path_to_root_v.pop_back();
+    }
+    path_to_root_u.push_back(last_together);
+
+    auto path_length = path_to_root_u.size() + path_to_root_v.size() - 1;
+
+    // walk up half of the path length to find the middle node
+    tree::int_type root;
+    if (path_to_root_u.size() > path_to_root_v.size()) {
+        root = path_to_root_u[path_length / 2];
+    } else {
+        root = path_to_root_v[path_length / 2];
+    }
+
+    return select_new_root(root);
+}
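+
+// Note: the two breadth-first passes above are the classic tree-diameter
+// technique: a BFS from the root finds one endpoint u of a longest path, a
+// BFS from u finds the other endpoint v, and the midpoint of the u-v path is
+// a centre of the tree, i.e. a root that minimises the maximum depth.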
+
+/// memory used to store tree (in bytes)
+std::size_t tree::memory() const {
+    return sizeof(int_type)*(children_.size()+child_index_.size()+parents_.size())
+        + sizeof(tree);
+}
+
+void tree::init(tree::size_type nnode) {
+    if (nnode) {
+        auto nchild = nnode - 1;
+        children_.resize(nchild);
+        child_index_.resize(nnode+1);
+        parents_.resize(nnode);
+    }
+    else {
+        children_.resize(0);
+        child_index_.resize(0);
+        parents_.resize(0);
+    }
+}
+
+// recursive helper for the depth_from_root() below
+void depth_from_root(const tree& t, tree::iarray& depth, tree::int_type segment) {
+    auto d = depth[t.parent(segment)] + 1;
+    depth[segment] = d;
+    for(auto c : t.children(segment)) {
+        depth_from_root(t, depth, c);
+    }
+}
+
+tree::iarray depth_from_root(const tree& t) {
+    tree::iarray depth(t.num_segments());
+    depth[0] = 0;
+    for (auto c: t.children(0)) {
+        depth_from_root(t, depth, c);
+    }
+
+    return depth;
+}
+
+} // namespace arb
diff --git a/arbor/tree.hpp b/arbor/tree.hpp
index 5adb6f6d..da55f2a9 100644
--- a/arbor/tree.hpp
+++ b/arbor/tree.hpp
@@ -2,6 +2,7 @@
 
 #include <algorithm>
 #include <cassert>
+#include <fstream>
 #include <numeric>
 #include <vector>
 
@@ -23,81 +24,21 @@ public:
 
     tree() = default;
 
-    tree& operator=(tree&& other) {
-        std::swap(child_index_, other.child_index_);
-        std::swap(children_, other.children_);
-        std::swap(parents_, other.parents_);
-        return *this;
-    }
-
-    tree& operator=(tree const& other) {
-        children_ = other.children_;
-        child_index_ = other.child_index_;
-        parents_ = other.child_index_;
-        return *this;
-    }
-
-    // copy constructors take advantage of the assignment operators
-    // defined above
-    tree(tree const& other) {
-        *this = other;
-    }
-
-    tree(tree&& other) {
-        *this = std::move(other);
-    }
-
     /// Create the tree from a parent index that lists the parent segment
     /// of each segment in a cell tree.
-    tree(std::vector<int_type> parent_index) {
-        // validate the input
-        if(!algorithms::is_minimal_degree(parent_index)) {
-            throw std::domain_error(
-                "parent index used to build a tree did not satisfy minimal degree ordering"
-            );
-        }
-
-        // an empty parent_index implies a single-compartment/segment cell
-        arb_assert(parent_index.size()!=0u);
+    tree(std::vector<int_type> parent_index);
 
-        init(parent_index.size());
-        memory::copy(parent_index, parents_);
-        parents_[0] = no_parent;
+    size_type num_children() const;
 
-        memory::copy(algorithms::make_index(algorithms::child_count(parents_)), child_index_);
+    size_type num_children(size_t b) const;
 
-        std::vector<int_type> pos(parents_.size(), 0);
-        for (auto i = 1u; i < parents_.size(); ++i) {
-            auto p = parents_[i];
-            children_[child_index_[p] + pos[p]] = i;
-            ++pos[p];
-        }
-    }
-
-    size_type num_children() const {
-        return static_cast<size_type>(children_.size());
-    }
-
-    size_type num_children(size_t b) const {
-        return child_index_[b+1] - child_index_[b];
-    }
-
-    size_type num_segments() const {
-        // the number of segments/nodes is the size of the child index minus 1
-        // ... except for the case of an empty tree
-        auto sz = static_cast<size_type>(child_index_.size());
-        return sz ? sz - 1 : 0;
-    }
+    size_type num_segments() const;
 
     /// return the child index
-    const iarray& child_index() {
-        return child_index_;
-    }
+    const iarray& child_index();
 
     /// return the list of all children
-    const iarray& children() const {
-        return children_;
-    }
+    const iarray& children() const;
 
     /// return the list of all children of branch i
     auto children(size_type i) const {
@@ -107,38 +48,62 @@ public:
     }
 
     /// return the list of parents
-    const iarray& parents() const {
-        return parents_;
-    }
+    const iarray& parents() const;
 
     /// return the parent of branch b
-    int_type parent(size_t b) const {
-        return parents_[b];
-    }
-    int_type& parent(size_t b) {
-        return parents_[b];
-    }
+    const int_type& parent(size_t b) const;
+    int_type& parent(size_t b);
+
+    // Splits a node into two: a new node N is inserted at index `ix`, and the
+    // old node A (with its subtree) moves to index `ix + 1`, which is returned.
+    // .-------------------.
+    // |      P         P  |
+    // |     /         /   |
+    // |    A   ~~>   N    |
+    // |   / \        |    |
+    // |  B  C        A    |
+    // |             / \   |
+    // |            B  C   |
+    // '-------------------'
+    int_type split_node(int_type ix);
+
+    // Changes the root node of a tree
+    // .------------------------.
+    // |        A               |
+    // |       / \         R    |
+    // |      R  B        / \   |
+    // |     /     ~~>   A  C   |
+    // |    C           /  / \  |
+    // |   / \         B  D  E  |
+    // |  D  E                  |
+    // '------------------------'
+    // Returns the permutation applied to the nodes,
+    // i.e. `new_node_data[i] = old_node_data[perm[i]]`
+    //
+    // This function has the additional side effect that branches with only
+    // one child branch get merged, so even `select_new_root(0)` can lead to
+    // a permutation of the indices of the compartments:
+    // .------------------------------.
+    // |        0               0     |
+    // |       / \             / \    |
+    // |      1  3            1  3    |
+    // |     /    \   ~~>    /    \   |
+    // |    2     4         2     4   |
+    // |   / \     \       / \     \  |
+    // |  5  6     7      6  7     5  |
+    // '------------------------------'
+    iarray select_new_root(int_type root);
+
+    // Selects a new node such that the depth of the graph is minimal.
+    // Returns the permutation applied to the nodes,
+    // i.e. `new_node_data[i] = old_node_data[perm[i]]`
+    iarray minimize_depth();
 
     /// memory used to store tree (in bytes)
-    std::size_t memory() const {
-        return sizeof(int_type)*(children_.size()+child_index_.size()+parents_.size())
-            + sizeof(tree);
-    }
+    std::size_t memory() const;
 
 private:
-    void init(size_type nnode) {
-        if (nnode) {
-            auto nchild = nnode - 1;
-            children_.resize(nchild);
-            child_index_.resize(nnode+1);
-            parents_.resize(nnode);
-        }
-        else {
-            children_.resize(0);
-            child_index_.resize(0);
-            parents_.resize(0);
-        }
-    }
+    void init(size_type nnode);
 
     // state
     iarray children_;
@@ -146,6 +111,10 @@ private:
     iarray parents_;
 };
 
+// Calculates the depth of each branch from the root of a cell segment tree.
+// The root has depth 0, its children have depth 1, and so on.
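+// For example, the parent index {0, 0, 1, 1} yields depths {0, 1, 2, 2}.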
+tree::iarray depth_from_root(const tree& t);
+
 template <typename C>
 std::vector<tree::int_type> make_parent_index(tree const& t, C const& counts)
 {
diff --git a/test/unit/test_gpu_stack.cu b/test/unit/test_gpu_stack.cu
index e75f8c3b..10af85c4 100644
--- a/test/unit/test_gpu_stack.cu
+++ b/test/unit/test_gpu_stack.cu
@@ -66,6 +66,7 @@ TEST(stack, push_back) {
     kernels::push_back<<<1, n>>>(sstorage, kernels::all_ftor());
     cudaDeviceSynchronize();
     EXPECT_EQ(n, s.size());
+    std::sort(sstorage.data, sstorage.data+s.size());
     for (auto i=0; i<int(s.size()); ++i) {
         EXPECT_EQ(i, s[i]);
     }
@@ -74,6 +75,7 @@ TEST(stack, push_back) {
     kernels::push_back<<<1, n>>>(sstorage, kernels::even_ftor());
     cudaDeviceSynchronize();
     EXPECT_EQ(n/2, s.size());
+    std::sort(sstorage.data, sstorage.data+s.size());
     for (auto i=0; i<int(s.size())/2; ++i) {
         EXPECT_EQ(2*i, s[i]);
     }
@@ -82,6 +84,7 @@ TEST(stack, push_back) {
     kernels::push_back<<<1, n>>>(sstorage, kernels::odd_ftor());
     cudaDeviceSynchronize();
     EXPECT_EQ(n/2, s.size());
+    std::sort(sstorage.data, sstorage.data+s.size());
     for (auto i=0; i<int(s.size())/2; ++i) {
         EXPECT_EQ(2*i+1, s[i]);
     }
diff --git a/test/unit/test_matrix.cu b/test/unit/test_matrix.cu
index b01fff92..b3424ad2 100644
--- a/test/unit/test_matrix.cu
+++ b/test/unit/test_matrix.cu
@@ -15,6 +15,7 @@
 #include "backends/gpu/matrix_state_flat.hpp"
 #include "backends/gpu/matrix_state_interleaved.hpp"
 #include "backends/gpu/matrix_interleave.hpp"
+#include "backends/gpu/matrix_state_fine.hpp"
 
 #include "../gtest.h"
 #include "common.hpp"
@@ -214,6 +215,7 @@ TEST(matrix, backends)
 
     using state_flat = gpu::matrix_state_flat<T, I>;
     using state_intl = gpu::matrix_state_interleaved<T, I>;
+    using state_fine = gpu::matrix_state_fine<T, I>;
 
     using gpu_array  = memory::device_vector<T>;
 
@@ -286,6 +288,7 @@ TEST(matrix, backends)
     // Make the reference matrix and the gpu matrix
     auto flat = state_flat(p, cell_cv_divs, Cm, g, area); // flat
     auto intl = state_intl(p, cell_cv_divs, Cm, g, area); // interleaved
+    auto fine = state_fine(p, cell_cv_divs, Cm, g, area); // fine
 
     // Set the integration times for the cells to be between 0.01 and 0.02 ms.
     std::vector<T> dt(num_mtx, 0);
@@ -300,90 +303,27 @@ TEST(matrix, backends)
 
     flat.assemble(gpu_dt, gpu_v, gpu_i);
     intl.assemble(gpu_dt, gpu_v, gpu_i);
+    fine.assemble(gpu_dt, gpu_v, gpu_i);
 
     flat.solve();
     intl.solve();
+    fine.solve();
 
     // Compare the results.
     // We expect exact equality for the two gpu matrix implementations because both
     // perform the same operations in the same order on the same inputs.
     std::vector<double> x_flat = assign_from(on_host(flat.solution()));
     std::vector<double> x_intl = assign_from(on_host(intl.solution()));
-    EXPECT_EQ(x_flat, x_intl);
-}
-
-/*
-
-// Test for special zero diagonal behaviour. (see `test_matrix.cpp`.)
-TEST(matrix, zero_diagonal)
-{
-    using util::assign;
-
-    using value_type = gpu::backend::value_type;
-    using size_type = gpu::backend::size_type;
-    using matrix_type = gpu::backend::matrix_state;
-    using vvec = std::vector<value_type>;
-
-    // Combined matrix may have zero-blocks, corresponding to a zero dt.
-    // Zero-blocks are indicated by zero value in the diagonal (the off-diagonal
-    // elements should be ignored).
-    // These submatrices should leave the rhs as-is when solved.
-
-    // Three matrices, sizes 3, 3 and 2, with no branching.
-    std::vector<size_type> p = {0, 0, 1, 3, 3, 5, 5};
-    std::vector<size_type> c = {0, 3, 5, 7};
-
-    // Face conductances.
-    std::vector<value_type> g = {0, 1, 1, 0, 1, 0, 2};
-
-    // dt of 1e-3.
-    std::vector<value_type> dt(3, 1.0e-3);
-
-    // Capacitances.
-    std::vector<value_type> Cm = {1, 1, 1, 1, 1, 2, 3};
-
-    // Intial voltage of zero; currents alone determine rhs.
-    std::vector<value_type> v(7, 0.0);
-    std::vector<value_type> i = {-3, -5, -7, -6, -9, -16, -32};
+    // Because the fine solver uses atomic operations, its solution may differ
+    // slightly from the flat and interleaved results.
+    std::vector<double> x_fine = assign_from(on_host(fine.solution()));
 
-    // Expected matrix and rhs:
-    // u = [ 0 -1 -1  0 -1  0 -2]
-    // d = [ 2  3  2  2  2  4  5]
-    // b = [ 3  5  7  2  4 16 32]
-    //
-    // Expected solution:
-    // x = [ 4  5  6  7  8  9 10]
-
-    matrix_type m(p, c, Cm, g);
-    auto gpu_dt = on_gpu(dt);
-    auto gpu_v  = on_gpu(v);
-    auto gpu_i  = on_gpu(i);
-    m.assemble(gpu_dt, gpu_v, gpu_i);
-    m.solve();
-
-    vvec x;
-    assign(x, on_host(m.solution()));
-    std::vector<value_type> expected = {4, 5, 6, 7, 8, 9, 10};
-
-    EXPECT_TRUE(testing::seq_almost_eq<double>(expected, x));
-
-    // Set dt of 2nd (middle) submatrix to zero. Solution
-    // should then return voltage values for that submatrix.
+    auto max_diff_fine =
+        util::max_value(
+            util::transform_view(
+                util::count_along(x_flat),
+                [&](unsigned i) {return std::abs(x_flat[i] - x_fine[i]);}));
 
-    dt[1] = 0;
-    gpu_dt = on_gpu(dt);
-
-    v[3] = 20;
-    v[4] = 30;
-    gpu_v  = on_gpu(v);
-
-    m.assemble(gpu_dt, gpu_v, gpu_i);
-    m.solve();
-
-    assign(x, on_host(m.solution()));
-    expected = {4, 5, 6, 20, 30, 9, 10};
-
-    EXPECT_TRUE(testing::seq_almost_eq<double>(expected, x));
+    EXPECT_EQ(x_flat, x_intl);
+    EXPECT_LE(max_diff_fine, 1e-12);
 }
-
-*/
diff --git a/test/unit/test_matrix_cpuvsgpu.cpp b/test/unit/test_matrix_cpuvsgpu.cpp
index 27390686..861c83dd 100644
--- a/test/unit/test_matrix_cpuvsgpu.cpp
+++ b/test/unit/test_matrix_cpuvsgpu.cpp
@@ -69,22 +69,25 @@ TEST(matrix, assemble)
     //           6
     //            \.
     //             7
+    // p_3: 1 branch, 1 compartment
+    //
+    // 0
 
     // The parent indexes that define the two matrix structures
     std::vector<std::vector<I>>
-        p_base = { {0,0,1,2,2,4}, {0,0,1,2,3,3,2,6} };
+        p_base = { {0,0,1,2,2,4}, {0,0,1,2,3,3,2,6}, {0} };
 
     // Make a set of matrices based on repeating this pattern.
     // We assign the patterns round-robin, i.e. so that the input
-    // matrices will have alternating sizes of 6 and 8, which will
+    // matrices will have cycling sizes of 6, 8 and 1, which will
     // test the solver with variable matrix size, and exercise
     // solvers that reorder matrices according to size.
-    const int num_mtx = 8;
+    const int num_mtx = 100;
 
     std::vector<I> p;
     std::vector<I> cell_index;
     for (auto m=0; m<num_mtx; ++m) {
-        auto &p_ref = p_base[m%2];
+        auto &p_ref = p_base[m%p_base.size()];
         auto first = p.size();
         for (auto i: p_ref) {
             p.push_back(i + first);
diff --git a/test/unit/test_tree.cpp b/test/unit/test_tree.cpp
index e23cc8ed..094e77f1 100644
--- a/test/unit/test_tree.cpp
+++ b/test/unit/test_tree.cpp
@@ -9,6 +9,7 @@
 
 using namespace arb;
 using int_type = tree::int_type;
+using iarray = tree::iarray;
 
 TEST(tree, from_segment_index) {
     auto no_parent = tree::no_parent;
@@ -174,3 +175,85 @@ TEST(tree, from_segment_index) {
     }
 }
 
+TEST(tree, depth_from_root) {
+    // tree with single branch corresponding to the root node
+    // this is equivalent to a single compartment model
+    //      CASE 1 : single root node in parent_index
+    {
+        std::vector<int_type> parent_index = {0};
+        iarray expected = {0u};
+        EXPECT_EQ(expected, depth_from_root(tree(parent_index)));
+    }
+
+    {
+        //     0
+        //    / \.
+        //   1   2
+        std::vector<int_type> parent_index = {0, 0, 0};
+        iarray expected = {0u, 1u, 1u};
+        EXPECT_EQ(expected, depth_from_root(tree(parent_index)));
+    }
+    {
+        //     0-1-2-3
+        std::vector<int_type> parent_index = {0, 0, 1, 2};
+        iarray expected = {0u, 1u, 2u, 3u};
+        EXPECT_EQ(expected, depth_from_root(tree(parent_index)));
+    }
+    {
+        //
+        //     0
+        //    /|\.
+        //   1 2 3
+        //
+        std::vector<int_type> parent_index = {0, 0, 0, 0};
+        iarray expected = {0u, 1u, 1u, 1u};
+        EXPECT_EQ(expected, depth_from_root(tree(parent_index)));
+    }
+    {
+        //
+        //   0
+        //  /|\.
+        // 1 2 3
+        //    / \.
+        //   4   5
+        //
+        std::vector<int_type> parent_index = {0, 0, 0, 0, 3, 3};
+        iarray expected = {0u, 1u, 1u, 1u, 2u, 2u};
+        EXPECT_EQ(expected, depth_from_root(tree(parent_index)));
+    }
+    {
+        //
+        //              0
+        //             /
+        //            1
+        //           / \.
+        //          2   3
+        std::vector<int_type> parent_index = {0,0,1,1};
+        iarray expected = {0u, 1u, 2u, 2u};
+        EXPECT_EQ(expected, depth_from_root(tree(parent_index)));
+    }
+    {
+        //
+        //              0
+        //             /|\.
+        //            1 4 5
+        //           / \.
+        //          2   3
+        std::vector<int_type> parent_index = {0,0,1,1,0,0};
+        iarray expected = {0u, 1u, 2u, 2u, 1u, 1u};
+        EXPECT_EQ(expected, depth_from_root(tree(parent_index)));
+    }
+    {
+        //              0
+        //             / \.
+        //            1   2
+        //           / \.
+        //          3   4
+        //             / \.
+        //            5   6
+        std::vector<int_type> parent_index = {0,0,0,1,1,4,4};
+        iarray expected = {0u, 1u, 1u, 2u, 2u, 3u, 3u};
+        EXPECT_EQ(expected, depth_from_root(tree(parent_index)));
+    }
+}
+
-- 
GitLab