From 67b70a80f71a0dc4bb8b2f555b6da7b31aba9803 Mon Sep 17 00:00:00 2001
From: Sam Yates <yates@cscs.ch>
Date: Tue, 13 Nov 2018 11:20:13 +0100
Subject: [PATCH] Revert "Squashed merge for fine matrix solver (#640)"

This reverts commit be2a8a9faa58832ea4482b8e012d5169c805a5f1.
---
 CMakeLists.txt                           |   4 -
 arbor/CMakeLists.txt                     |   4 -
 arbor/algorithms.hpp                     |  31 +-
 arbor/backends/gpu/forest.cpp            | 181 ---------
 arbor/backends/gpu/forest.hpp            | 140 -------
 arbor/backends/gpu/fvm.hpp               |  12 +-
 arbor/backends/gpu/matrix_fine.cpp       |  71 ----
 arbor/backends/gpu/matrix_fine.cu        | 314 --------------
 arbor/backends/gpu/matrix_fine.hpp       |  77 ----
 arbor/backends/gpu/matrix_solve.cu       |   1 -
 arbor/backends/gpu/matrix_state_fine.hpp | 494 -----------------------
 arbor/tree.cpp                           | 385 ------------------
 arbor/tree.hpp                           | 153 ++++---
 test/unit/test_gpu_stack.cu              |   3 -
 test/unit/test_matrix.cu                 |  90 ++++-
 test/unit/test_matrix_cpuvsgpu.cpp       |   9 +-
 test/unit/test_tree.cpp                  |  83 ----
 17 files changed, 179 insertions(+), 1873 deletions(-)
 delete mode 100644 arbor/backends/gpu/forest.cpp
 delete mode 100644 arbor/backends/gpu/forest.hpp
 delete mode 100644 arbor/backends/gpu/matrix_fine.cpp
 delete mode 100644 arbor/backends/gpu/matrix_fine.cu
 delete mode 100644 arbor/backends/gpu/matrix_fine.hpp
 delete mode 100644 arbor/backends/gpu/matrix_state_fine.hpp
 delete mode 100644 arbor/tree.cpp

diff --git a/CMakeLists.txt b/CMakeLists.txt
index b9802b68..b8c83fae 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -39,7 +39,6 @@ set(ARB_VALIDATION_DATA_DIR "${PROJECT_SOURCE_DIR}/validation/data" CACHE PATH
 #----------------------------------------------------------
 
 option(ARB_WITH_GPU "build with GPU support" OFF)
-option(ARB_WITH_GPU_FINE_MATRIX "use optimized fine matrix solver" ON)
 
 option(ARB_WITH_MPI "build with MPI support" OFF)
 
@@ -198,9 +197,6 @@ if(ARB_WITH_GPU)
         "$<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe=--diag_suppress=integer_sign_change>"
         "$<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe=--diag_suppress=unsigned_compare_with_zero>")
     target_compile_definitions(arbor-private-deps INTERFACE ARB_HAVE_GPU)
-    if(ARB_WITH_GPU_FINE_MATRIX)
-        target_compile_definitions(arbor-private-deps INTERFACE ARB_HAVE_GPU_FINE_MATRIX)
-    endif()
 
     target_compile_options(arbor-private-deps INTERFACE $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_35,code=sm_35>)
     target_compile_options(arbor-private-deps INTERFACE $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_37,code=sm_37>)
diff --git a/arbor/CMakeLists.txt b/arbor/CMakeLists.txt
index 87ec5150..dd885eac 100644
--- a/arbor/CMakeLists.txt
+++ b/arbor/CMakeLists.txt
@@ -45,7 +45,6 @@ set(arbor_sources
     threading/threading.cpp
     threading/thread_info.cpp
     thread_private_spike_store.cpp
-    tree.cpp
     util/hostname.cpp
     util/unwind.cpp
     version.cpp
@@ -60,13 +59,10 @@ if(ARB_WITH_CUDA)
         backends/gpu/threshold_watcher.cu
         backends/gpu/matrix_assemble.cu
         backends/gpu/matrix_interleave.cu
-        backends/gpu/matrix_fine.cu
-        backends/gpu/matrix_fine.cpp
         backends/gpu/matrix_solve.cu
         backends/gpu/multi_event_stream.cpp
         backends/gpu/multi_event_stream.cu
         backends/gpu/shared_state.cu
-        backends/gpu/forest.cpp
         backends/gpu/stimulus.cu
         backends/gpu/threshold_watcher.cu
         memory/fill.cu
diff --git a/arbor/algorithms.hpp b/arbor/algorithms.hpp
index a2ec968a..4e295834 100644
--- a/arbor/algorithms.hpp
+++ b/arbor/algorithms.hpp
@@ -41,8 +41,6 @@ mean(C const& c)
     return sum(c)/util::size(c);
 }
 
-// returns the prefix sum of c in the form `[0, c[0], c[0]+c[1], ..., sum(c)]`.
-// This means that the returned vector has one more element than c.
 template <typename C>
 C make_index(C const& c)
 {
@@ -96,9 +94,6 @@ bool is_strictly_monotonic_decreasing(C const& c)
     );
 }
 
-// check if c[0] == 0 and c[i] < 0 holds for i != 0
-// this means that children of a node always have larger indices than their
-// parent
 template <
     typename C,
     typename = typename std::enable_if<std::is_integral<typename C::value_type>::value>
@@ -135,23 +130,17 @@ bool all_negative(const C& c) {
     return util::all_of(c, [](auto v) { return v<decltype(v){}; });
 }
 
-// returns a vector containing the number of children for each node.
 template<typename C>
 std::vector<typename C::value_type> child_count(const C& parent_index)
 {
-    using value_type = typename C::value_type;
     static_assert(
-        std::is_integral<value_type>::value,
+        std::is_integral<typename C::value_type>::value,
         "integral type required"
     );
 
-    std::vector<value_type> count(parent_index.size(), 0);
-    for (auto i = 0u; i < parent_index.size(); ++i) {
-        auto p = parent_index[i];
-        // -1 means no parent
-        if (p != value_type(i) && p != value_type(-1)) {
-            ++count[p];
-        }
+    std::vector<typename C::value_type> count(parent_index.size(), 0);
+    for (auto i = 1u; i < parent_index.size(); ++i) {
+        ++count[parent_index[i]];
     }
 
     return count;
@@ -201,7 +190,7 @@ std::vector<typename C::value_type> branches(const C& parent_index)
     for (std::size_t i = 1; i < parent_index.size(); ++i) {
         auto p = parent_index[i];
         if (num_child[p] > 1 || parent_index[p] == p) {
-            // `parent_index[p] == p` ~> parent_index[i] is the soma
+            // parent_index[p] == p -> parent_index[i] is the soma
             branch_index.push_back(i);
         }
     }
@@ -210,9 +199,7 @@ std::vector<typename C::value_type> branches(const C& parent_index)
     return branch_index;
 }
 
-// creates a vector that contains the branch index for each compartment.
-// e.g. {0, 1, 5, 9, 10} -> {0, 1, 1, 1, 1, 2, 2, 2, 2, 3}
-//                  indices  0  1           5           9
+
 template<typename C>
 std::vector<typename C::value_type> expand_branches(const C& branch_index)
 {
@@ -294,13 +281,11 @@ std::vector<typename C::value_type> tree_reduce(
     arb_assert(has_contiguous_compartments(parent_index));
     arb_assert(is_strictly_monotonic_increasing(branch_index));
 
-    // expand the branch index to lookup the banch id for each compartment
+    // expand the branch index
     auto expanded_branch = expand_branches(branch_index);
 
     std::vector<typename C::value_type> new_parent_index;
-    // push the first element manually as the parent of the root might be -1
-    new_parent_index.push_back(expanded_branch[0]);
-    for (std::size_t i = 1; i < branch_index.size()-1; ++i) {
+    for (std::size_t i = 0; i < branch_index.size()-1; ++i) {
         auto p = parent_index[branch_index[i]];
         new_parent_index.push_back(expanded_branch[p]);
     }
diff --git a/arbor/backends/gpu/forest.cpp b/arbor/backends/gpu/forest.cpp
deleted file mode 100644
index 6650b92c..00000000
--- a/arbor/backends/gpu/forest.cpp
+++ /dev/null
@@ -1,181 +0,0 @@
-#include "backends/gpu/forest.hpp"
-#include "util/span.hpp"
-
-namespace arb {
-namespace gpu {
-
-forest::forest(const std::vector<size_type>& p, const std::vector<size_type>& cell_cv_divs) :
-    perm_balancing(p.size())
-{
-    using util::make_span;
-
-    auto num_cells = cell_cv_divs.size() - 1;
-
-    for (auto c: make_span(0u, num_cells)) {
-        // build the parent index for cell c
-        auto cell_start = cell_cv_divs[c];
-        std::vector<unsigned> cell_p =
-            util::assign_from(
-                util::transform_view(
-                    util::subrange_view(p, cell_cv_divs[c], cell_cv_divs[c+1]),
-                    [cell_start](unsigned i) {return i-cell_start;}));
-
-        auto fine_tree = tree(cell_p);
-
-        // select a root node and merge branches with discontinuous compartment
-        // indices
-        auto perm = fine_tree.select_new_root(0);
-        for (auto i: make_span(perm.size())) {
-            perm_balancing[cell_start + i] = cell_start + perm[i];
-        }
-
-        // find the index of the first node for each branch
-        auto branch_starts = algorithms::branches(fine_tree.parents());
-
-        // compute branch length and apply permutation
-        std::vector<unsigned> branch_lengths(branch_starts.size() - 1);
-        for (auto i: make_span(branch_lengths.size())) {
-            branch_lengths[i] = branch_starts[i+1] - branch_starts[i];
-        }
-
-        // find the parent index of branches
-        // we need to convert to cell_lid_type, required to construct a tree.
-        std::vector<cell_lid_type> branch_p =
-            util::assign_from(
-                algorithms::tree_reduce(fine_tree.parents(), branch_starts));
-        // build tree structure that describes the branch topology
-        auto cell_tree = tree(branch_p);
-
-        trees.push_back(cell_tree);
-        fine_trees.push_back(fine_tree);
-        tree_branch_starts.push_back(branch_starts);
-        tree_branch_lengths.push_back(branch_lengths);
-    }
-}
-
-void forest::optimize() {
-    using util::make_span;
-
-    // cut the tree
-    unsigned count = 1; // number of nodes found on the previous level
-    for (auto level = 0; count > 0; level++) {
-        count = 0;
-
-        // decide where to cut it ...
-        unsigned max_length = 0;
-        for (auto t_ix: make_span(trees.size())) { // TODO make this local on an intermediate packing
-            for (level_iterator it (&trees[t_ix], level); it.valid(); it.next()) {
-                auto length = tree_branch_lengths[t_ix][it.peek()];
-                max_length += length;
-                count++;
-            }
-        }
-        if (count == 0) {
-            // there exists no tree with branches on this level
-            continue;
-        };
-        max_length = max_length / count;
-        // avoid ininite loops
-        if (max_length <= 1) max_length = 1;
-        // we don't want too small segments
-        if (max_length <= 10) max_length = 10;
-
-        for (auto t_ix: make_span(trees.size())) {
-            // ... cut all trees on this level
-            for (level_iterator it (&trees[t_ix], level); it.valid(); it.next()) {
-
-                auto length = tree_branch_lengths[t_ix][it.peek()];
-                if (length > max_length) {
-                    // now cut the tree
-
-                    // we are allowed to mess with the tree because of the
-                    // implementation of level_iterator o.O
-
-                    auto insert_at_bs = tree_branch_starts[t_ix].begin() + it.peek();
-                    auto insert_at_ls = tree_branch_lengths[t_ix].begin() + it.peek();
-
-                    trees[t_ix].split_node(it.peek());
-
-                    // now the tree got a new node.
-                    // we now have to insert a corresponding new 'branch
-                    // start' to the list
-
-                    // make sure that `tree_branch_starts` for A and N point to
-                    // the correct slices
-                    auto old_start = tree_branch_starts[t_ix][it.peek()];
-                    // first insert, then index peek, as we already
-                    // incremented the iterator
-                    tree_branch_starts[t_ix].insert(insert_at_bs, old_start);
-                    tree_branch_lengths[t_ix].insert(insert_at_ls, max_length);
-                    tree_branch_starts[t_ix][it.peek() + 1] = old_start + max_length;
-                    tree_branch_lengths[t_ix][it.peek() + 1] = length - max_length;
-                    // we don't have to shift any indices as we did not
-                    // create any new branch segments, but just split
-                    // one over two nodes
-                }
-            }
-        }
-    }
-}
-
-
-// debugging functions:
-
-
-// Exports the tree's parent structure into a dot file.
-// the children and parent trees are equivalent. Both methods are provided
-// for debugging purposes.
-template<typename F>
-void export_parents(const tree& t, std::string file, F label) {
-    using util::make_span;
-    std::ofstream ofile;
-    ofile.open(file);
-    ofile << "strict digraph Parents {" << std::endl;
-    for (auto i: make_span(t.parents().size())) {
-        ofile << i << "[label=\"" << label(i) << "\"]" << std::endl;
-    }
-    for (auto i: make_span(t.parents().size())) {
-        auto p = t.parent(i);
-        if (p != tree::no_parent) {
-            ofile << i << " -> " << t.parent(i) << std::endl;
-        }
-    }
-    ofile << "}" << std::endl;
-    ofile.close();
-}
-
-void export_parents(const tree& t, std::string file) {
-    // the labels in the the graph are the branch indices
-    export_parents(t, file, [](auto i){return i;});
-}
-
-// Exports the tree's children structure into a dot file.
-// the children and parent trees are equivalent. Both methods are provided
-// for debugging purposes.
-template<typename F>
-void export_children(const tree& t, std::string file, F label) {
-    using util::make_span;
-    std::ofstream ofile;
-    ofile.open(file);
-    ofile << "strict digraph Children {" << std::endl;
-    for (auto i: make_span(t.num_segments())) {
-        ofile << i << "[label=\"" << label(i) << "\"]" << std::endl;
-    }
-    for (auto i: make_span(t.num_segments())) {
-        ofile << i << " -> {";
-        for (auto c: t.children(i)) {
-            ofile << " " << c;
-        }
-        ofile << "}" << std::endl;
-    }
-    ofile << "}" << std::endl;
-    ofile.close();
-}
-
-void export_children(const tree& t, std::string file) {
-    // the labels in the the graph are the branch indices
-    export_children(t, file, [](auto i){return i;});
-}
-
-} // namespace gpu
-} // namespace arb
diff --git a/arbor/backends/gpu/forest.hpp b/arbor/backends/gpu/forest.hpp
deleted file mode 100644
index 43900858..00000000
--- a/arbor/backends/gpu/forest.hpp
+++ /dev/null
@@ -1,140 +0,0 @@
-#pragma once
-
-#include <vector>
-
-#include "tree.hpp"
-
-namespace arb {
-namespace gpu {
-
-using size_type = int;
-
-struct forest {
-    forest(const std::vector<size_type>& p, const std::vector<size_type>& cell_cv_divs);
-
-    void optimize();
-
-    unsigned num_trees() {
-        return fine_trees.size();
-    }
-
-    const tree& branch_tree(unsigned tree_index) {
-        return trees[tree_index];
-    }
-
-    const tree& compartment_tree(unsigned tree_index) {
-        return fine_trees[tree_index];
-    }
-
-    // Returns the offset into the compartments of a tree where each branch
-    // starts. It holds `0 <= offset < tree.num_segments()`
-    const std::vector<unsigned>& branch_offsets(unsigned tree_index) {
-        return tree_branch_starts[tree_index];
-    }
-
-    // Returns vector of the length of each branch in a tree.
-    const std::vector<unsigned>& branch_lengths(unsigned tree_index) {
-        return tree_branch_lengths[tree_index];
-    }
-
-    // Return the permutation that was applied to the compartments in the
-    // format: `new[i] = old[perm[i]]`
-    const std::vector<size_type>& permutation() {
-        return perm_balancing;
-    }
-
-    // trees of compartments
-    std::vector<tree> fine_trees;
-    std::vector<std::vector<unsigned>> tree_branch_starts;
-    std::vector<std::vector<unsigned>> tree_branch_lengths;
-    // trees of branches
-    std::vector<tree> trees;
-
-    // the permutation matrix used by the balancing algorithm
-    // format: `solver_format[i] = external_format[perm[i]]`
-    std::vector<size_type> perm_balancing;
-};
-
-
-struct level_iterator {
-    level_iterator(tree* t, unsigned level) {
-        tree_ = t;
-        only_on_level = level;
-        // due to the ordering of the nodes we know that 0 is the root
-        current_node  = 0;
-        current_level = 0;
-        next_children = 0;
-        if (level != 0) {
-            next();
-        };
-    }
-
-    void advance_depth_first() {
-        auto children = tree_->children(current_node);
-        if (next_children < children.size() && current_level <= only_on_level) {
-            // go to next children
-            current_level += 1;
-            current_node = children[next_children];
-            next_children = 0;
-        } else {
-            // go to parent
-            auto parent_node = tree_->parents()[current_node];
-            constexpr unsigned npos = unsigned(-1);
-            if (parent_node != npos) {
-                auto siblings = tree_->children(parent_node);
-                // get the index in the child list of the parent
-                unsigned index = 0;
-                while (siblings[index] != current_node) { // TODO repalce by array lockup: sibling_nr
-                    index += 1;
-                }
-
-                current_level -= 1;
-                current_node = parent_node;
-                next_children = index + 1;
-            } else {
-                // we are done with the iteration
-                current_level = -1;
-                current_node  = -1;
-                next_children = -1;
-            }
-
-        }
-    }
-
-    unsigned next() {
-        constexpr unsigned npos = unsigned(-1);
-        if (!valid()) {
-            // we are done
-            return npos;
-        } else {
-            advance_depth_first();
-            // next_children != 0 means, that we have seen the node before
-            while (valid() && (current_level != only_on_level || next_children != 0)) {
-                advance_depth_first();
-            }
-            return current_node;
-        }
-    }
-
-    bool valid() {
-        constexpr unsigned npos = unsigned(-1);
-        return this->peek() != npos;
-    }
-
-    unsigned peek() {
-        return current_node;
-    }
-
-private:
-    tree* tree_;
-
-    unsigned current_node;
-    unsigned current_level;
-    unsigned next_children;
-
-    unsigned only_on_level;
-};
-
-
-} // namespace gpu
-} // namespace arb
diff --git a/arbor/backends/gpu/fvm.hpp b/arbor/backends/gpu/fvm.hpp
index cc82bed0..9d02eea3 100644
--- a/arbor/backends/gpu/fvm.hpp
+++ b/arbor/backends/gpu/fvm.hpp
@@ -14,15 +14,9 @@
 #include "backends/gpu/gpu_store_types.hpp"
 #include "backends/gpu/shared_state.hpp"
 
+#include "matrix_state_interleaved.hpp"
 #include "threshold_watcher.hpp"
 
-
-#ifdef ARB_HAVE_GPU_FINE_MATRIX
-    #include "matrix_state_fine.hpp"
-#else
-    #include "matrix_state_interleaved.hpp"
-#endif
-
 namespace arb {
 namespace gpu {
 
@@ -45,11 +39,7 @@ struct backend {
         return memory::on_host(v);
     }
 
-#ifdef ARB_HAVE_GPU_FINE_MATRIX
-    using matrix_state = arb::gpu::matrix_state_fine<value_type, index_type>;
-#else
     using matrix_state = arb::gpu::matrix_state_interleaved<value_type, index_type>;
-#endif
     using threshold_watcher = arb::gpu::threshold_watcher;
 
     using deliverable_event_stream = arb::gpu::deliverable_event_stream;
diff --git a/arbor/backends/gpu/matrix_fine.cpp b/arbor/backends/gpu/matrix_fine.cpp
deleted file mode 100644
index b1d7a227..00000000
--- a/arbor/backends/gpu/matrix_fine.cpp
+++ /dev/null
@@ -1,71 +0,0 @@
-#include <ostream>
-
-#include <cuda_runtime.h>
-
-#include "memory/cuda_wrappers.hpp"
-#include "util/span.hpp"
-
-#include "matrix_fine.hpp"
-
-namespace arb {
-namespace gpu {
-
-level::level(unsigned branches):
-    num_branches(branches)
-{
-    using memory::cuda_malloc_managed;
-
-    using arb::memory::cuda_malloc_managed;
-    if (num_branches!=0) {
-        lengths = static_cast<unsigned*>(cuda_malloc_managed(num_branches*sizeof(unsigned)));
-        parents = static_cast<unsigned*>(cuda_malloc_managed(num_branches*sizeof(unsigned)));
-        cudaDeviceSynchronize();
-    }
-}
-
-level::level(level&& other) {
-    std::swap(other.lengths, this->lengths);
-    std::swap(other.parents, this->parents);
-    std::swap(other.num_branches, this->num_branches);
-    std::swap(other.max_length, this->max_length);
-    std::swap(other.data_index, this->data_index);
-}
-
-level::level(const level& other) {
-    using memory::cuda_malloc_managed;
-
-    num_branches = other.num_branches;
-    max_length = other.max_length;
-    data_index = other.data_index;
-    if (num_branches!=0) {
-        lengths = static_cast<unsigned*>(cuda_malloc_managed(num_branches*sizeof(unsigned)));
-        parents = static_cast<unsigned*>(cuda_malloc_managed(num_branches*sizeof(unsigned)));
-        cudaDeviceSynchronize();
-        std::copy(other.lengths, other.lengths+num_branches, lengths);
-        std::copy(other.parents, other.parents+num_branches, parents);
-    }
-}
-
-level::~level() {
-    if (num_branches!=0) {
-        cudaDeviceSynchronize(); // to ensure that managed memory has been freed
-        if (lengths) arb::memory::cuda_free(lengths);
-        if (parents) arb::memory::cuda_free(parents);
-    }
-}
-
-std::ostream& operator<<(std::ostream& o, const level& l) {
-    cudaDeviceSynchronize();
-    o << "branches:" << l.num_branches
-      << " max_len:" << l.max_length
-      << " data_idx:" << l.data_index
-      << " lengths:[";
-    for (auto i: util::make_span(l.num_branches)) o << l.lengths[i] << " ";
-    o << "] parents:[";
-    for (auto i: util::make_span(l.num_branches)) o << l.parents[i] << " ";
-    o << "]";
-    return o;
-}
-
-} // namespace gpu
-} // namespace arb
diff --git a/arbor/backends/gpu/matrix_fine.cu b/arbor/backends/gpu/matrix_fine.cu
deleted file mode 100644
index 764b39d5..00000000
--- a/arbor/backends/gpu/matrix_fine.cu
+++ /dev/null
@@ -1,314 +0,0 @@
-#include <arbor/fvm_types.hpp>
-
-#include "cuda_atomic.hpp"
-#include "cuda_common.hpp"
-#include "matrix_common.hpp"
-#include "matrix_fine.hpp"
-
-namespace arb {
-namespace gpu {
-
-namespace kernels {
-
-//
-// gather and scatter kernels
-//
-
-// to[i] = from[p[i]]
-template <typename T, typename I>
-__global__
-void gather(const T* from, T* to, const I* p, unsigned n) {
-    unsigned i = threadIdx.x + blockDim.x*blockIdx.x;
-
-    if (i<n) {
-        to[i] = from[p[i]];
-    }
-}
-
-// to[p[i]] = from[i]
-template <typename T, typename I>
-__global__
-void scatter(const T* from, T* to, const I* p, unsigned n) {
-    unsigned i = threadIdx.x + blockDim.x*blockIdx.x;
-
-    if (i<n) {
-        to[p[i]] = from[i];
-    }
-}
-
-/// GPU implementatin of Hines matrix assembly.
-/// Fine layout.
-/// For a given time step size dt:
-///     - use the precomputed alpha and alpha_d values to construct the diagonal
-///       and off diagonal of the symmetric Hines matrix.
-///     - compute the RHS of the linear system to solve.
-template <typename T, typename I>
-__global__
-void assemble_matrix_fine(
-        T* d,
-        T* rhs,
-        const T* invariant_d,
-        const T* voltage,
-        const T* current,
-        const T* cv_capacitance,
-        const T* area,
-        const I* cv_to_cell,
-        const T* dt_cell,
-        const I* perm,
-        unsigned n)
-{
-    const unsigned tid = threadIdx.x + blockDim.x*blockIdx.x;
-
-    if (tid<n) {
-        auto cid = cv_to_cell[tid];
-        auto dt = dt_cell[cid];
-
-        if (dt>0) {
-            // The 1e-3 is a constant of proportionality required to ensure that the
-            // conductance (gi) values have units μS (micro-Siemens).
-            // See the model documentation in docs/model for more information.
-            T factor = 1e-3/dt;
-
-            const auto gi = factor * cv_capacitance[tid];
-            const auto pid = perm[tid];
-            d[pid] = gi + invariant_d[tid];
-            rhs[pid] = gi*voltage[tid] - T(1e-3)*area[tid]*current[tid];
-        }
-        else {
-            const auto pid = perm[tid];
-            d[pid] = 0;
-            rhs[pid] = voltage[tid];
-        }
-    }
-}
-
-/// GPU implementation of Hines Matrix solver.
-/// Fine-grained tree based solver.
-/// Each block solves a set of matricesb iterating over the levels of matrix
-/// and perfoming a backward and forward substitution. On each level one thread
-/// gets assigned to one branch on this level of a matrix and solves and
-/// performs the substitution. Afterwards all threads continue on the next
-/// level.
-/// To avoid idle threads, one should try that on each level, there is a similar
-/// number of branches.
-template <typename T>
-__global__
-void solve_matrix_fine(
-    T* rhs,
-    T* d,
-    const T* u,
-    const level* levels,
-    const unsigned* levels_start,
-    unsigned* num_matrix, // number of packed matrices = number of cells
-    unsigned* padded_size)
-{
-    const auto tid = threadIdx.x;
-    const auto bid = blockIdx.x;
-
-    const auto first_level = levels_start[bid];
-    const auto num_levels  = levels_start[bid + 1] - first_level;
-    const auto block_levels = &levels[first_level];
-
-    // backward substitution
-
-    for (unsigned l=0; l<num_levels-1; ++l) {
-        const auto& lvl = block_levels[l];
-
-        const unsigned width = lvl.num_branches;
-
-        // Perform backward substitution for each branch on this level.
-        // One thread per branch.
-        if (tid < width) {
-            const unsigned len = lvl.lengths[tid];
-            unsigned pos = lvl.data_index + tid;
-
-            // Zero diagonal term implies dt==0; just leave rhs (for whole matrix)
-            // alone in that case.
-
-            // Each cell has a different `dt`, because we choose time step size
-            // according to when the next event is arriving at a cell. So, some
-            // cells require more time steps than others, but we have to solve
-            // all the matrices at the same time. When a cell finishes, we put a
-            // `0` on the diagonal to mark that it should not be solved for.
-            if (d[pos]!=0) {
-
-                // each branch perform substitution
-                T factor = u[pos] / d[pos];
-                for (unsigned i=0; i<len-1; ++i) {
-                    const unsigned next_pos = pos + width;
-                    d[next_pos]   -= factor * u[pos];
-                    rhs[next_pos] -= factor * rhs[pos];
-
-                    factor = u[next_pos] / d[next_pos];
-                    pos = next_pos;
-                }
-
-                // Update d and rhs at the parent node of this branch.
-                // A parent may have more than one contributing to it, so we use
-                // atomic updates to avoid races conditions.
-                const unsigned parent_index = block_levels[l+1].data_index;
-                const unsigned p = parent_index + lvl.parents[tid];
-                //d[p]   -= factor * u[pos];
-                cuda_atomic_add(d  +p, -factor*u[pos]);
-                //rhs[p] -= factor * rhs[pos];
-                cuda_atomic_add(rhs+p, -factor*rhs[pos]);
-            }
-        }
-        __syncthreads();
-    }
-
-    {
-        // the levels are sorted such that the root is the last level
-        const auto& lvl = block_levels[num_levels-1];
-        const unsigned width = num_matrix[bid];
-
-        if (tid < width) {
-            const unsigned len = lvl.lengths[tid];
-            unsigned pos = lvl.data_index + tid;
-
-            if (d[pos]!=0) {
-
-                // backward
-                for (unsigned i=0; i<len-1; ++i) {
-                    T factor = u[pos] / d[pos];
-                    const unsigned next_pos = pos + width;
-                    d[next_pos]   -= factor * u[pos];
-                    rhs[next_pos] -= factor * rhs[pos];
-
-                    pos = next_pos;
-                }
-
-                auto rhsp = rhs[pos] / d[pos];
-                rhs[pos] = rhsp;
-                pos -= width;
-
-                // forward
-                for (unsigned i=0; i<len-1; ++i) {
-                    rhsp = rhs[pos] - u[pos]*rhsp;
-                    rhsp /= d[pos];
-                    rhs[pos] = rhsp;
-                    pos -= width;
-                }
-            }
-        }
-    }
-
-    // forward substitution
-
-    // take great care with loop limits decrementing unsigned counter l
-    for (unsigned l=num_levels-1; l>0; --l) {
-        const auto& lvl = block_levels[l-1];
-
-        const unsigned width = lvl.num_branches;
-        const unsigned parent_index = block_levels[l].data_index;
-
-        __syncthreads();
-
-        // Perform forward-substitution for each branch on this level.
-        // One thread per branch.
-        if (tid < width) {
-            // Load the rhs value for the parent node of this branch.
-            const unsigned p = parent_index + lvl.parents[tid];
-            T rhsp = rhs[p];
-
-            // Find the index of the first node in this branch.
-            const unsigned len = lvl.lengths[tid];
-            unsigned pos = lvl.data_index + (len-1)*width + tid;
-
-            if (d[pos]!=0) {
-                // each branch perform substitution
-                for (unsigned i=0; i<len; ++i) {
-                    rhsp = rhs[pos] - u[pos]*rhsp;
-                    rhsp /= d[pos];
-                    rhs[pos] = rhsp;
-                    pos -= width;
-                }
-            }
-        }
-    }
-}
-
-} // namespace kernels
-
-void gather(
-    const fvm_value_type* from,
-    fvm_value_type* to,
-    const fvm_index_type* p,
-    unsigned n)
-{
-    constexpr unsigned blockdim = 128;
-    const unsigned griddim = impl::block_count(n, blockdim);
-
-    kernels::gather<<<griddim, blockdim>>>(from, to, p, n);
-}
-
-void scatter(
-    const fvm_value_type* from,
-    fvm_value_type* to,
-    const fvm_index_type* p,
-    unsigned n)
-{
-    constexpr unsigned blockdim = 128;
-    const unsigned griddim = impl::block_count(n, blockdim);
-
-    kernels::scatter<<<griddim, blockdim>>>(from, to, p, n);
-}
-
-
-void assemble_matrix_fine(
-    fvm_value_type* d,
-    fvm_value_type* rhs,
-    const fvm_value_type* invariant_d,
-    const fvm_value_type* voltage,
-    const fvm_value_type* current,
-    const fvm_value_type* cv_capacitance,
-    const fvm_value_type* area,
-    const fvm_index_type* cv_to_cell,
-    const fvm_value_type* dt_cell,
-    const fvm_index_type* perm,
-    unsigned n)
-{
-    const unsigned block_dim = 128;
-    const unsigned num_blocks = impl::block_count(n, block_dim);
-
-    kernels::assemble_matrix_fine<<<num_blocks, block_dim>>>(
-        d, rhs, invariant_d, voltage, current, cv_capacitance, area,
-        cv_to_cell, dt_cell,
-        perm, n);
-}
-
-// Example:
-//
-//         block 0                  block 1              block 2
-// .~~~~~~~~~~~~~~~~~~.  .~~~~~~~~~~~~~~~~~~~~~~~~.  .~~~~~~~~~~~ ~ ~
-//
-//  L0 \  /                                           L5    \  /
-//      \/                                                   \/
-//  L1   \   /   \   /    L3 \   /   \ | /   \   /    L6 \   /  . . .
-//        \ /     \ /         \ /     \|/     \ /         \ /
-//  L2     |       |      L4   |       |       |      L7   |
-//         |       |           |       |       |           |
-//
-// levels       = [L0, L1, L2, L3, L4, L5, L6, L7, ... ]
-// levels_start = [0, 3, 5, 8, ...]
-// num_levels   = [3, 2, 3, ...]
-// num_cells    = [2, 3, ...]
-// num_blocks   = level_start.size() - 1 = num_levels.size() = num_cells.size()
-void solve_matrix_fine(
-    fvm_value_type* rhs,
-    fvm_value_type* d,                // diagonal values
-    const fvm_value_type* u,          // upper diagonal (and lower diagonal as the matrix is SPD)
-    const level* levels,              // pointer to an array containing level meta-data for all blocks
-    const unsigned* levels_start,     // start index into levels for each cuda block
-    unsigned* num_cells,              // he number of cells packed into this single matrix
-    unsigned* padded_size,            // length of rhs, d, u, including padding
-    unsigned num_blocks,              // nuber of blocks
-    unsigned blocksize)               // size of each block
-{
-    kernels::solve_matrix_fine<<<num_blocks, blocksize>>>(
-        rhs, d, u, levels, levels_start,
-        num_cells, padded_size);
-}
-
-} // namespace gpu
-} // namespace arb
diff --git a/arbor/backends/gpu/matrix_fine.hpp b/arbor/backends/gpu/matrix_fine.hpp
deleted file mode 100644
index 547b80e2..00000000
--- a/arbor/backends/gpu/matrix_fine.hpp
+++ /dev/null
@@ -1,77 +0,0 @@
-#include <arbor/fvm_types.hpp>
-
-#include <ostream>
-
-namespace arb {
-namespace gpu {
-
-struct level {
-    level() = default;
-
-    level(unsigned branches);
-    level(level&& other);
-    level(const level& other);
-
-    ~level();
-
-    unsigned num_branches = 0; // Number of branches
-    unsigned max_length = 0;   // Length of the longest branch
-    unsigned data_index = 0;   // Index into data values of the first branch
-
-    //  The lengths and parents vectors are raw pointers to managed memory,
-    //  so there is need for tricksy deep copy of this type to GPU.
-
-    // An array holding the length of each branch in the level.
-    // length: num_branches.
-    unsigned* lengths = nullptr;
-
-    // An array with the index of the parent branch for each branch on this level.
-    // length: num_branches.
-    // When performing backward/forward substitution we need to update/read
-    // data values for the parent node for each branch.
-    // This can be done easily if we know where the parent branch is located
-    // on the next level.
-    unsigned* parents = nullptr;
-};
-
-std::ostream& operator<<(std::ostream& o, const level& l);
-
-// C wrappers around kernels
-void gather(
-    const fvm_value_type* from,
-    fvm_value_type* to,
-    const fvm_index_type* p,
-    unsigned n);
-
-void scatter(
-    const fvm_value_type* from,
-    fvm_value_type* to,
-    const fvm_index_type* p,
-    unsigned n);
-
-void assemble_matrix_fine(
-    fvm_value_type* d,
-    fvm_value_type* rhs,
-    const fvm_value_type* invariant_d,
-    const fvm_value_type* voltage,
-    const fvm_value_type* current,
-    const fvm_value_type* cv_capacitance,
-    const fvm_value_type* area,
-    const fvm_index_type* cv_to_cell,
-    const fvm_value_type* dt_cell,
-    const fvm_index_type* perm,
-    unsigned n);
-
-void solve_matrix_fine(
-    fvm_value_type* rhs,
-    fvm_value_type* d,                // diagonal values
-    const fvm_value_type* u,          // upper diagonal (and lower diagonal as the matrix is SPD)
-    const level* levels,              // pointer to an array containing level meta-data for all blocks
-    const unsigned* levels_end,       // end index (exclusive) into levels for each cuda block
-    unsigned* num_cells,              // he number of cells packed into this single matrix
-    unsigned* padded_size,            // length of rhs, d, u, including padding
-    unsigned num_blocks,              // nuber of blocks
-    unsigned blocksize);              // size of each block
-
-} // namespace gpu
-} // namespace arb
diff --git a/arbor/backends/gpu/matrix_solve.cu b/arbor/backends/gpu/matrix_solve.cu
index 7fa719c8..1a9ab8c6 100644
--- a/arbor/backends/gpu/matrix_solve.cu
+++ b/arbor/backends/gpu/matrix_solve.cu
@@ -10,7 +10,6 @@ namespace kernels {
 
 /// GPU implementation of Hines Matrix solver.
 /// Flat format
-/// p: parent index for each variable. Needed for backward and forward sweep
 template <typename T, typename I>
 __global__
 void solve_matrix_flat(
diff --git a/arbor/backends/gpu/matrix_state_fine.hpp b/arbor/backends/gpu/matrix_state_fine.hpp
deleted file mode 100644
index c457fe24..00000000
--- a/arbor/backends/gpu/matrix_state_fine.hpp
+++ /dev/null
@@ -1,494 +0,0 @@
-#pragma once
-
-#include <cstring>
-
-#include <vector>
-#include <type_traits>
-
-#include <arbor/common_types.hpp>
-
-#include "algorithms.hpp"
-#include "memory/memory.hpp"
-#include "util/partition.hpp"
-#include "util/rangeutil.hpp"
-#include "util/span.hpp"
-#include "tree.hpp"
-
-#include "matrix_fine.hpp"
-#include "forest.hpp"
-
-namespace arb {
-namespace gpu {
-
-// Helper type for branch meta data in setup phase of fine grained
-// matrix storage+solver.
-//
-//      leaf
-//      .
-//      .
-//      .
-//  -   *
-//      |
-//  l   *
-//  e   |
-//  n   *
-//  g   |
-//  t   *
-//  h   |
-//  -   start_idx
-//      |
-//      parent_idx
-//      |
-//      .
-//      .
-//      .
-//      root
-struct branch {
-    unsigned id;         // branch id
-    unsigned parent_id;  // parent branch id
-    unsigned parent_idx; //
-    unsigned start_idx;  // the index of the first node in the input parent index
-    unsigned length;     // the number of nodes in the branch
-};
-
-// order branches by:
-//  - descending length
-//  - ascending id
-inline
-bool operator<(const branch& lhs, const branch& rhs) {
-    if (lhs.length!=rhs.length) {
-        return lhs.length>rhs.length;
-    } else {
-        return lhs.id<rhs.id;
-    }
-}
-
-inline
-std::ostream& operator<<(std::ostream& o, branch b) {
-    return o << "[" << b.id
-        << ", len " << b.length
-        << ", pid " << b.parent_idx
-        << ", sta " << b.start_idx
-        << "]";
-}
-
-template <typename T, typename I>
-struct matrix_state_fine {
-public:
-    using value_type = T;
-    using size_type = I;
-
-    using array      = memory::device_vector<value_type>;
-    using view       = typename array::view_type;
-    using const_view = typename array::const_view_type;
-    using iarray     = memory::device_vector<size_type>;
-
-    template <typename ValueType>
-    using managed_vector = std::vector<ValueType, memory::managed_allocator<ValueType>>;
-
-    iarray cv_to_cell;
-
-    array d;     // [μS]
-    array u;     // [μS]
-    array rhs;   // [nA]
-
-    // required for matrix assembly
-
-    array cv_area; // [μm^2]
-
-    array cv_capacitance;      // [pF]
-
-    // the invariant part of the matrix diagonal
-    array invariant_d;         // [μS]
-
-    // for storing the solution in unpacked format
-    array solution_;
-
-    // the maximum nuber of branches in each level per block
-    unsigned max_branches_per_level;
-
-    // number of rows in matrix
-    unsigned matrix_size;
-
-    // number of cells
-    unsigned num_cells;
-    managed_vector<unsigned> num_cells_in_block;
-
-    // end of the data of each level
-    // use data_size.back() to get the total data size
-    //      data_size >= size
-    managed_vector<unsigned> data_size; // TODO rename
-
-    // the meta data for each level for each block layed out linearly in memory
-    managed_vector<level> levels;
-    // the start of the levels of each block
-    // block b owns { leves[level_start[b]], ..., leves[level_start[b+1] - 1] }
-    // there is an additional entry at the end of the vector to make the above
-    // compuation save
-    managed_vector<unsigned> levels_start;
-
-    // permutation from front end storage to packed storage
-    //      `solver_format[perm[i]] = external_format[i]`
-    iarray perm;
-
-
-    matrix_state_fine() = default;
-
-    // constructor for fine-grained matrix.
-    matrix_state_fine(const std::vector<size_type>& p,
-                 const std::vector<size_type>& cell_cv_divs,
-                 const std::vector<value_type>& cap,
-                 const std::vector<value_type>& face_conductance,
-                 const std::vector<value_type>& area)
-    {
-        using util::make_span;
-        constexpr unsigned npos = unsigned(-1);
-
-        max_branches_per_level = 128;
-
-        // for now we have single cell per cell group
-        arb_assert(cell_cv_divs.size()==2);
-
-        num_cells = cell_cv_divs.size()-1;
-
-        forest trees(p, cell_cv_divs);
-        trees.optimize();
-
-        // Now distribute the cells into cuda blocks.
-        // While the total number of branches on each level of theses cells in a
-        // block are less than `max_branches_per_level` we add more cells. If
-        // one block is full, we start a new cuda block.
-
-        unsigned current_block = 0;
-        std::vector<unsigned> block_num_branches_per_depth;
-        std::vector<unsigned> block_ix(num_cells);
-        num_cells_in_block.resize(1, 0);
-
-        // branch_map = branch_maps[block] is a branch map for each cuda block
-        // branch_map[depth] is list of branches is this level
-        // each branch branch_map[depth][i] has
-        // {id, parent_id, start_idx, parent_idx, length}
-        std::vector<std::vector<std::vector<branch>>> branch_maps;
-        branch_maps.resize(1);
-
-        unsigned num_branches = 0u;
-        for (auto c: make_span(0u, num_cells)) {
-            auto cell_start = cell_cv_divs[c];
-            auto cell_tree = trees.branch_tree(c);
-            auto fine_tree = trees.compartment_tree(c);
-            auto branch_starts  = trees.branch_offsets(c);
-            auto branch_lengths = trees.branch_lengths(c);
-
-            auto depths = depth_from_root(cell_tree);
-
-            // calculate the number of levels in this cell
-            auto cell_num_levels = util::max_value(depths)+1u;
-
-            auto num_cell_branches = cell_tree.num_segments();
-
-            // count number of branches per level
-            std::vector<unsigned> cell_num_branches_per_depth(cell_num_levels, 0u);
-            for (auto i: make_span(num_cell_branches)) {
-                cell_num_branches_per_depth[depths[i]] += 1;
-            }
-            // resize the block levels if neccessary
-            if (cell_num_levels > block_num_branches_per_depth.size()) {
-                block_num_branches_per_depth.resize(cell_num_levels, 0);
-            }
-
-
-            // check if we can fit the current cell into the last cuda block
-            bool fits_current_block = true;
-            for (auto i: make_span(cell_num_levels)) {
-                unsigned new_branches_per_depth =
-                    block_num_branches_per_depth[i]
-                    + cell_num_branches_per_depth[i];
-                if (new_branches_per_depth > max_branches_per_level) {
-                    fits_current_block = false;
-                }
-            }
-            if (fits_current_block) {
-                // put the cell into current block
-                block_ix[c] = current_block;
-                num_cells_in_block[block_ix[c]] += 1;
-                // and increment counter
-                for (auto i: make_span(cell_num_levels)) {
-                    block_num_branches_per_depth[i] += cell_num_branches_per_depth[i];
-                }
-            } else {
-                // otherwise start a new block
-                block_ix[c] = current_block + 1;
-                num_cells_in_block.push_back(1);
-                branch_maps.resize(branch_maps.size()+1);
-                current_block += 1;
-                // and reset counter
-                block_num_branches_per_depth = cell_num_branches_per_depth;
-                for (auto num: block_num_branches_per_depth) {
-                    if (num > max_branches_per_level) {
-                        throw std::runtime_error(
-                            "Could not fit " + std::to_string(num)
-                            + " branches in a block of size "
-                            + std::to_string(max_branches_per_level));
-                    }
-                }
-            }
-
-
-            // the branch map for the block in which we put the cell
-            // maps levels to a list of branches in that level
-            auto& branch_map = branch_maps[block_ix[c]];
-
-            // build branch_map:
-            // branch_map[i] is a list of branch meta-data for branches with depth i
-            if (cell_num_levels > branch_map.size()) {
-                branch_map.resize(cell_num_levels);
-            }
-            for (auto i: make_span(num_cell_branches)) {
-                branch b;
-                auto depth = depths[i];
-                // give the branch a unique id number
-                b.id = i + num_branches;
-                // take care to mark branches with no parents with npos
-                b.parent_id = cell_tree.parent(i)==cell_tree.no_parent ?
-                    npos : cell_tree.parent(i) + num_branches;
-                b.start_idx = branch_starts[i] + cell_start;
-                b.length = branch_lengths[i];
-                b.parent_idx = p[b.start_idx] + cell_start;
-                branch_map[depth].push_back(b);
-            }
-            // total number of branches of all cells
-            num_branches += num_cell_branches;
-        }
-
-        for (auto& branch_map: branch_maps) {
-            // reverse the levels
-            std::reverse(branch_map.begin(), branch_map.end());
-
-            // Sort all branches on each level in descending order of length.
-            // Later, branches will be partitioned over thread blocks, and we will
-            // take advantage of the fact that the first branch in a partition is
-            // the longest, to determine how to pack all the branches in a block.
-            for (auto& branches: branch_map) {
-                util::sort(branches);
-            }
-        }
-
-        // The branches generated above have been assigned contiguous ids.
-        // Now generate a vector of branch_loc, one for each branch, that
-        // allow for quick lookup by id of the level and index within a level
-        // of each branch.
-        // This information is only used in the generation of the levels below.
-
-        // Helper for recording location of a branch once packed.
-        struct branch_loc {
-            unsigned block; // the cuda block containing the cell to which the branch blongs to
-            unsigned level; // the level containing the branch
-            unsigned index; // the index of the branch on that level
-        };
-
-        // branch_locs will hold the location information for each branch.
-        std::vector<branch_loc> branch_locs(num_branches);
-        for (unsigned b: make_span(branch_maps.size())) {
-            const auto& branch_map = branch_maps[b];
-            for (unsigned l: make_span(branch_map.size())) {
-                const auto& branches = branch_map[l];
-
-                // Record the location information
-                for (auto i=0u; i<branches.size(); ++i) {
-                    const auto& branch = branches[i];
-                    branch_locs[branch.id] = {b, l, i};
-                }
-            }
-        }
-
-        unsigned total_num_levels = std::accumulate(
-            branch_maps.begin(), branch_maps.end(), 0,
-            [](unsigned value, decltype(branch_maps[0])& l) {
-                return value + l.size();});
-
-        // construct description for the set of branches on each level for each
-        // block. This is later used to sort the branches in each block in each
-        // level into conineous chunks which are easier to read for the cuda
-        // kernel.
-        levels.reserve(total_num_levels);
-        levels_start.reserve(branch_maps.size() + 1);
-        levels_start.push_back(0);
-        data_size.reserve(branch_maps.size());
-        // offset into the packed data format, used to apply permutation on data
-        auto pos = 0u;
-        for (const auto& branch_map: branch_maps) {
-            for (const auto& lvl_branches: branch_map) {
-
-                level lvl(lvl_branches.size());
-
-                // The length of the first branch is the upper bound on branch
-                // length as they are sorted in descending order of length.
-                lvl.max_length = lvl_branches.front().length;
-                lvl.data_index = pos;
-
-                unsigned bi = 0u;
-                for (const auto& b: lvl_branches) {
-                    // Set the length of the branch.
-                    lvl.lengths[bi] = b.length;
-
-                    // Set the parent indexes. During the forward and backward
-                    // substitution phases each branch accesses the last node in
-                    // its parent branch.
-                    auto index = b.parent_id==npos? npos: branch_locs[b.parent_id].index;
-                    lvl.parents[bi] = index;
-                    ++bi;
-                }
-
-                pos += lvl.max_length*lvl.num_branches;
-
-                levels.push_back(std::move(lvl));
-            }
-            auto prev_end = levels_start.back();
-            levels_start.push_back(prev_end + branch_map.size());
-            data_size.push_back(pos);
-        }
-
-        // set matrix state
-        matrix_size = p.size();
-
-        // form the permutation index used to reorder vectors to/from the
-        // ordering used by the fine grained matrix storage.
-        std::vector<size_type> perm_tmp(matrix_size);
-        for (auto block: make_span(branch_maps.size())) {
-            const auto& branch_map = branch_maps[block];
-            const auto first_level = levels_start[block];
-
-            for (auto i: make_span(levels_start[block + 1] - first_level)) {
-                const auto& l = levels[first_level + i];
-                for (auto j: make_span(l.num_branches)) {
-                    const auto& b = branch_map[i][j];
-                    auto to = l.data_index + j + l.num_branches*(l.lengths[j]-1);
-                    auto from = b.start_idx;
-                    for (auto k: make_span(b.length)) {
-                        perm_tmp[from + k] = to - k*l.num_branches;
-                    }
-                }
-            }
-        }
-
-        auto perm_balancing = trees.permutation();
-
-        // apppy permutation form balancing
-        std::vector<size_type> perm_tmp2(matrix_size);
-        for (auto i: make_span(matrix_size)) {
-             // This is CORRECT! verified by using the ring benchmark with root=0 (where the permutation is actually not id)
-            perm_tmp2[perm_balancing[i]] = perm_tmp[i];
-        }
-        // copy permutation to device memory
-        perm = memory::make_const_view(perm_tmp2);
-
-
-        // Summary of fields and their storage format:
-        //
-        // face_conductance : not needed, don't store
-        // d, u, rhs        : packed
-        // cv_capacitance   : flat
-        // invariant_d      : flat
-        // solution_        : flat
-        // cv_to_cell       : flat
-        // area             : flat
-
-        // the invariant part of d is stored in in flat form
-        std::vector<value_type> invariant_d_tmp(matrix_size, 0);
-        managed_vector<value_type> u_tmp(matrix_size, 0);
-        for (auto i: make_span(1u, matrix_size)) {
-            auto gij = face_conductance[i];
-
-            u_tmp[i] = -gij;
-            invariant_d_tmp[i] += gij;
-            invariant_d_tmp[p[i]] += gij;
-        }
-
-        // the matrix components u, d and rhs are stored in packed form
-        auto nan = std::numeric_limits<double>::quiet_NaN();
-        d   = array(data_size.back(), nan);
-        u   = array(data_size.back(), nan);
-        rhs = array(data_size.back(), nan);
-
-        // transform u_tmp values into packed u vector.
-        flat_to_packed(u_tmp, u);
-
-        // the invariant part of d, cv_area and the solution are in flat form
-        solution_ = array(matrix_size, 0);
-        cv_area = memory::make_const_view(area);
-
-        // the cv_capacitance can be copied directly because it is
-        // to be stored in flat format
-        cv_capacitance = memory::make_const_view(cap);
-        invariant_d = memory::make_const_view(invariant_d_tmp);
-
-        // calculte the cv -> cell mappings
-        std::vector<size_type> cv_to_cell_tmp(matrix_size);
-        size_type ci = 0;
-        for (auto cv_span: util::partition_view(cell_cv_divs)) {
-            util::fill(util::subrange_view(cv_to_cell_tmp, cv_span), ci);
-            ++ci;
-        }
-        cv_to_cell = memory::make_const_view(cv_to_cell_tmp);
-    }
-
-    // Assemble the matrix
-    // Afterwards the diagonal and RHS will have been set given dt, voltage and current
-    //   dt_cell [ms] (per cell)
-    //   voltage [mV]
-    //   current [nA]
-    void assemble(const_view dt_cell, const_view voltage, const_view current) {
-        assemble_matrix_fine(
-            d.data(),
-            rhs.data(),
-            invariant_d.data(),
-            voltage.data(),
-            current.data(),
-            cv_capacitance.data(),
-            cv_area.data(),
-            cv_to_cell.data(),
-            dt_cell.data(),
-            perm.data(),
-            size());
-    }
-
-    void solve() {
-        solve_matrix_fine(
-            rhs.data(), d.data(), u.data(),
-            levels.data(), levels_start.data(),
-            num_cells_in_block.data(),
-            data_size.data(),
-            num_cells_in_block.size(), max_branches_per_level);
-
-        // unpermute the solution
-        packed_to_flat(rhs, solution_);
-    }
-
-    const_view solution() const {
-        return solution_;
-    }
-
-    template <typename VFrom, typename VTo>
-    void flat_to_packed(const VFrom& from, VTo& to ) {
-        arb_assert(from.size()==matrix_size);
-        arb_assert(to.size()==data_size.back());
-
-        scatter(from.data(), to.data(), perm.data(), perm.size());
-    }
-
-    template <typename VFrom, typename VTo>
-    void packed_to_flat(const VFrom& from, VTo& to ) {
-        arb_assert(from.size()==data_size.back());
-        arb_assert(to.size()==matrix_size);
-
-        gather(from.data(), to.data(), perm.data(), perm.size());
-    }
-
-private:
-    std::size_t size() const {
-        return matrix_size;
-    }
-};
-
-} // namespace gpu
-} // namespace arb
diff --git a/arbor/tree.cpp b/arbor/tree.cpp
deleted file mode 100644
index a4bc28d6..00000000
--- a/arbor/tree.cpp
+++ /dev/null
@@ -1,385 +0,0 @@
-#include <algorithm>
-#include <cassert>
-#include <numeric>
-#include <queue>
-#include <vector>
-
-#include <arbor/common_types.hpp>
-
-#include "algorithms.hpp"
-#include "memory/memory.hpp"
-#include "tree.hpp"
-#include "util/span.hpp"
-
-namespace arb {
-
-tree::tree(std::vector<tree::int_type> parent_index) {
-    // validate the input
-    if(!algorithms::is_minimal_degree(parent_index)) {
-        throw std::domain_error(
-            "parent index used to build a tree did not satisfy minimal degree ordering"
-        );
-    }
-
-    // an empty parent_index implies a single-compartment/segment cell
-    arb_assert(parent_index.size()!=0u);
-
-    init(parent_index.size());
-    memory::copy(parent_index, parents_);
-    parents_[0] = no_parent;
-
-    // compute offsets into children_ array
-    memory::copy(algorithms::make_index(algorithms::child_count(parents_)), child_index_);
-
-    std::vector<int_type> pos(parents_.size(), 0);
-    for (auto i = 1u; i < parents_.size(); ++i) {
-        auto p = parents_[i];
-        children_[child_index_[p] + pos[p]] = i;
-        ++pos[p];
-    }
-}
-
-tree::size_type tree::num_children() const {
-    return static_cast<size_type>(children_.size());
-}
-
-tree::size_type tree::num_children(size_t b) const {
-    return child_index_[b+1] - child_index_[b];
-}
-
-tree::size_type tree::num_segments() const {
-    // the number of segments/nodes is the size of the child index minus 1
-    // ... except for the case of an empty tree
-    auto sz = static_cast<size_type>(child_index_.size());
-    return sz ? sz - 1 : 0;
-}
-
-const tree::iarray& tree::child_index() {
-    return child_index_;
-}
-
-const tree::iarray& tree::children() const {
-    return children_;
-}
-
-const tree::iarray& tree::parents() const {
-    return parents_;
-}
-
-const tree::int_type& tree::parent(size_t b) const {
-    return parents_[b];
-}
-tree::int_type& tree::parent(size_t b) {
-    return parents_[b];
-}
-
-tree::int_type tree::split_node(int_type ix) {
-    using util::make_span;
-
-    auto insert_at_p  = parents_.begin() + ix;
-    auto insert_at_ci = child_index_.begin() + ix;
-    auto insert_at_c  = children_.begin() + child_index_[ix];
-    auto new_node_ix  = ix;
-
-    // we first adjust the parent sructure
-
-    // first create a new node N below the parent
-    auto parent = parents_[ix];
-    parents_.insert(insert_at_p, parent);
-    // and attach the remining subtree below it
-    parents_[ix+1] = new_node_ix;
-    // shift all parents, as the indices changed when we
-    // inserted a new node
-    for (auto i: make_span(ix + 2, parents().size())) {
-        if (parents_[i] >= new_node_ix) {
-            parents_[i]++;
-        }
-    }
-
-    // now we adjust the children structure
-
-    // insert a child node for the new node N, pointing to
-    // the old node A
-    child_index_.insert(insert_at_ci, child_index_[ix]);
-    // we will set this value later as it will be overridden
-    children_.insert(insert_at_c, ~0u);
-    // shift indices for all larger indices, as we inserted
-    // a new element in the list
-    for (auto i: make_span(ix + 1, child_index_.size())) {
-        child_index_[i]++;
-    }
-    for (auto i: make_span(0, children_.size())) {
-        if(children_[i] > new_node_ix) {
-            children_[i]++;
-        }
-    }
-    // set the children of the new node to the old subtree
-    children_[child_index_[ix]] = ix + 1;
-
-    return ix+1;
-}
-
-tree::iarray tree::select_new_root(int_type root) {
-    using util::make_span;
-
-    const auto num_nodes = parents().size();
-
-    if(root >= num_nodes && root != no_parent) {
-        throw std::domain_error(
-            "root is out of bounds: root="+std::to_string(root)+", nodes="
-            +std::to_string(num_nodes)
-        );
-    }
-
-    // walk up to the old root and turn `parent->child` into `parent<-child`
-    auto prev = no_parent;
-    auto current = root;
-    while (current != no_parent) {
-        auto parent = parents_[current];
-        parents_[current] = prev;
-        prev = current;
-        current = parent;
-    }
-
-    // sort the list to get min degree ordering and keep index such that we
-    // can sort also the `branch_starts` array.
-
-    // comput the depth for each node
-    iarray depth (num_nodes, 0);
-    // the depth when we don't count nodes that only have one child
-    iarray reduced_depth (num_nodes, 0);
-    // the index of the last node that passed this node on its way to the root
-    // we need this to keep nodes that are part of the same reduced tree close
-    // together in the final sorting.
-    //
-    // Instead of the left order we want the right order.
-    // .-----------------------.
-    // |    0            0     |
-    // |   / \          / \    |
-    // |  1   2        1   3   |
-    // |  |   |        |   |   |
-    // |  3   4        2   4   |
-    // |     / \          / \  |
-    // |    5  6         5  6  |
-    // '-----------------------'
-    //
-    // we achieve this by first sorting by reduced_depth, branch_ix and depth
-    // in this order. The resulting ordering satisfies minimal degree ordering.
-    //
-    // Using the tree above we would get the following results:
-    //
-    // `depth`           `reduced_depth`   `branch_ix`
-    // .-----------.     .-----------.     .-----------.
-    // |    0      |     |    0      |     |    6      |
-    // |   / \     |     |   / \     |     |   / \     |
-    // |  1   1    |     |  1   1    |     |  2   6    |
-    // |  |   |    |     |  |   |    |     |  |   |    |
-    // |  2   2    |     |  1   1    |     |  2   6    |
-    // |     / \   |     |     / \   |     |     / \   |
-    // |    3   3  |     |    2   2  |     |    5   6  |
-    // '-----------'     '-----------'     '-----------'
-    iarray branch_ix (num_nodes, 0);
-    // we cannot use the existing `children_` array as we only updated the
-    // parent structure yet
-    auto new_num_children = algorithms::child_count(parents_);
-    for (auto n: make_span(num_nodes)) {
-        branch_ix[n] = n;
-        auto prev = n;
-        auto curr = parents_[n];
-
-        // find the way to the root
-        while (curr != no_parent) {
-            depth[n]++;
-            if (new_num_children[curr] > 1) {
-                reduced_depth[n]++;
-            }
-            branch_ix[curr] = branch_ix[prev];
-            curr = parents_[curr];
-        }
-    }
-
-    // maps new indices to old indices
-    iarray indices (num_nodes);
-    // fill array with indices
-    for (auto i: make_span(num_nodes)) {
-        indices[i] = i;
-    }
-    // perform sort by depth index to get the permutation
-    std::sort(indices.begin(), indices.end(), [&](auto i, auto j){
-        if (reduced_depth[i] != reduced_depth[j]) {
-            return reduced_depth[i] < reduced_depth[j];
-        }
-        if (branch_ix[i] != branch_ix[j]) {
-            return branch_ix[i] < branch_ix[j];
-        }
-        return depth[i] < depth[j];
-    });
-    // maps old indices to new indices
-    iarray indices_inv (num_nodes);
-    // fill array with indices
-    for (auto i: make_span(num_nodes)) {
-        indices_inv[i] = i;
-    }
-    // perform sort
-    std::sort(indices_inv.begin(), indices_inv.end(), [&](auto i, auto j){
-        return indices[i] < indices[j];
-    });
-
-    // translate the parent vetor to new indices
-    for (auto i: make_span(num_nodes)) {
-        if (parents_[i] != no_parent) {
-            parents_[i] = indices_inv[parents_[i]];
-        }
-    }
-
-    iarray new_parents (num_nodes);
-    for (auto i: make_span(num_nodes)) {
-        new_parents[i] = parents_[indices[i]];
-    }
-    // parent now starts with the root, then it's children, then their
-    // children, etc...
-
-    // recompute the children array
-    memory::copy(new_parents, parents_);
-    memory::copy(algorithms::make_index(algorithms::child_count(parents_)), child_index_);
-
-    std::vector<int_type> pos(parents_.size(), 0);
-    for (auto i = 1u; i < parents_.size(); ++i) {
-        auto p = parents_[i];
-        children_[child_index_[p] + pos[p]] = i;
-        ++pos[p];
-    }
-
-    return indices;
-}
-
-tree::iarray tree::minimize_depth() {
-    const auto num_nodes = parents().size();
-    tree::iarray seen(num_nodes, 0);
-
-    // find the furhtest node from the root
-    std::queue<tree::int_type> queue;
-    queue.push(0); // start at the root node
-    seen[0] = 1;
-    auto front = queue.front();
-    // breath first traversal
-    while (!queue.empty()) {
-        front = queue.front();
-        queue.pop();
-        // we only have to check children as we started at the root node
-        auto cs = children(front);
-        for (auto c: cs) {
-            if (seen[c] == 0) {
-                seen[c] = 1;
-                queue.push(c);
-            }
-        }
-    }
-
-    auto u = front;
-
-    // find the furhtest node from this node
-    std::fill(seen.begin(), seen.end(), 0);
-    queue.push(u);
-    seen[u] = 1;
-    front = queue.front();
-    // breath first traversal
-    while (!queue.empty()) {
-        front = queue.front();
-        queue.pop();
-        auto cs = children(front);
-        for (auto c: cs) {
-            if (seen[c] == 0) {
-                seen[c] = 1;
-                queue.push(c);
-            }
-        }
-        // also check the partent node!
-        auto c = parent(front);
-        if (c != tree::no_parent && seen[c] == 0) {
-            seen[c] = 1;
-            queue.push(c);
-        }
-    }
-
-    auto v = front;
-
-    // now find the middle between u and v
-
-    // each path to the root
-    tree::iarray path_to_root_u (1, u);
-    tree::iarray path_to_root_v (1, v);
-
-    auto curr = parent(u);
-    while (curr != tree::no_parent) {
-        path_to_root_u.push_back(curr);
-        curr = parent(curr);
-    }
-    curr = parent(v);
-    while (curr != tree::no_parent) {
-        path_to_root_v.push_back(curr);
-        curr = parent(curr);
-    }
-
-    // reduce the path
-    auto last_together = 0;
-    while (path_to_root_u.back() == path_to_root_v.back()) {
-        last_together = path_to_root_u.back();
-        path_to_root_u.pop_back();
-        path_to_root_v.pop_back();
-    }
-    path_to_root_u.push_back(last_together);
-
-    auto path_length = path_to_root_u.size() + path_to_root_v.size() - 1;
-
-    // walk up half of the path length to find the middle node
-    tree::int_type root;
-    if (path_to_root_u.size() > path_to_root_v.size()) {
-        root = path_to_root_u[path_length / 2];
-    } else {
-        root = path_to_root_v[path_length / 2];
-    }
-
-    return select_new_root(root);
-}
-
-/// memory used to store tree (in bytes)
-std::size_t tree::memory() const {
-    return sizeof(int_type)*(children_.size()+child_index_.size()+parents_.size())
-        + sizeof(tree);
-}
-
-void tree::init(tree::size_type nnode) {
-    if (nnode) {
-        auto nchild = nnode - 1;
-        children_.resize(nchild);
-        child_index_.resize(nnode+1);
-        parents_.resize(nnode);
-    }
-    else {
-        children_.resize(0);
-        child_index_.resize(0);
-        parents_.resize(0);
-    }
-}
-
-// recursive helper for the depth_from_root() below
-void depth_from_root(const tree& t, tree::iarray& depth, tree::int_type segment) {
-    auto d = depth[t.parent(segment)] + 1;
-    depth[segment] = d;
-    for(auto c : t.children(segment)) {
-        depth_from_root(t, depth, c);
-    }
-}
-
-tree::iarray depth_from_root(const tree& t) {
-    tree::iarray depth(t.num_segments());
-    depth[0] = 0;
-    for (auto c: t.children(0)) {
-        depth_from_root(t, depth, c);
-    }
-
-    return depth;
-}
-
-} // namespace arb
diff --git a/arbor/tree.hpp b/arbor/tree.hpp
index da55f2a9..5adb6f6d 100644
--- a/arbor/tree.hpp
+++ b/arbor/tree.hpp
@@ -2,7 +2,6 @@
 
 #include <algorithm>
 #include <cassert>
-#include <fstream>
 #include <numeric>
 #include <vector>
 
@@ -24,21 +23,81 @@ public:
 
     tree() = default;
 
+    tree& operator=(tree&& other) {
+        std::swap(child_index_, other.child_index_);
+        std::swap(children_, other.children_);
+        std::swap(parents_, other.parents_);
+        return *this;
+    }
+
+    tree& operator=(tree const& other) {
+        children_ = other.children_;
+        child_index_ = other.child_index_;
+        parents_ = other.child_index_;
+        return *this;
+    }
+
+    // copy constructors take advantage of the assignment operators
+    // defined above
+    tree(tree const& other) {
+        *this = other;
+    }
+
+    tree(tree&& other) {
+        *this = std::move(other);
+    }
+
     /// Create the tree from a parent index that lists the parent segment
     /// of each segment in a cell tree.
-    tree(std::vector<int_type> parent_index);
+    tree(std::vector<int_type> parent_index) {
+        // validate the input
+        if(!algorithms::is_minimal_degree(parent_index)) {
+            throw std::domain_error(
+                "parent index used to build a tree did not satisfy minimal degree ordering"
+            );
+        }
+
+        // an empty parent_index implies a single-compartment/segment cell
+        arb_assert(parent_index.size()!=0u);
 
-    size_type num_children() const;
+        init(parent_index.size());
+        memory::copy(parent_index, parents_);
+        parents_[0] = no_parent;
 
-    size_type num_children(size_t b) const;
+        memory::copy(algorithms::make_index(algorithms::child_count(parents_)), child_index_);
 
-    size_type num_segments() const;
+        std::vector<int_type> pos(parents_.size(), 0);
+        for (auto i = 1u; i < parents_.size(); ++i) {
+            auto p = parents_[i];
+            children_[child_index_[p] + pos[p]] = i;
+            ++pos[p];
+        }
+    }
+
+    size_type num_children() const {
+        return static_cast<size_type>(children_.size());
+    }
+
+    size_type num_children(size_t b) const {
+        return child_index_[b+1] - child_index_[b];
+    }
+
+    size_type num_segments() const {
+        // the number of segments/nodes is the size of the child index minus 1
+        // ... except for the case of an empty tree
+        auto sz = static_cast<size_type>(child_index_.size());
+        return sz ? sz - 1 : 0;
+    }
 
     /// return the child index
-    const iarray& child_index();
+    const iarray& child_index() {
+        return child_index_;
+    }
 
     /// return the list of all children
-    const iarray& children() const;
+    const iarray& children() const {
+        return children_;
+    }
 
     /// return the list of all children of branch i
     auto children(size_type i) const {
@@ -48,62 +107,38 @@ public:
     }
 
     /// return the list of parents
-    const iarray& parents() const;
+    const iarray& parents() const {
+        return parents_;
+    }
 
     /// return the parent of branch b
-    const int_type& parent(size_t b) const;
-    int_type& parent(size_t b);
-
-    // splits the node in two parts. Returns `ix + 1` which is the new index of
-    // the old node.
-    // .-------------------.
-    // |      P         P  |
-    // |     /         /   |
-    // |    A   ~~>   N    |
-    // |   / \        |    |
-    // |  B  C        A    |
-    // |             / \   |
-    // |            B  C   |
-    // '-------------------'
-    int_type split_node(int_type ix);
-
-    // Changes the root node of a tree
-    // .------------------------.
-    // |        A               |
-    // |       / \         R    |
-    // |      R  B        / \   |
-    // |     /     ~~>   A  C   |
-    // |    C           /  / \  |
-    // |   / \         B  D  E  |
-    // |  D  E                  |
-    // '------------------------'
-    // Returns the permutation applied to the nodes,
-    // i.e. `new_node_data[i] = old_node_data[perm[i]]`
-    //
-    // This function has the additional effect, that branches with only one
-    // child branch get merged. That means that `select_new_root(0)` can also
-    // lead to an permutation of the indices of the compartments:
-    // .------------------------------.
-    // |        0               0     |
-    // |       / \             / \    |
-    // |      1  3            1  3    |
-    // |     /    \   ~~>    /    \   |
-    // |    2     4         2     4   |
-    // |   / \     \       / \     \  |
-    // |  5  6     7      6  7     5  |
-    // '------------------------------'
-    iarray select_new_root(int_type root);
-
-    // Selects a new node such that the depth of the graph is minimal.
-    // Returns the permutation applied to the nodes,
-    // i.e. `new_node_data[i] = old_node_data[perm[i]]`
-    iarray minimize_depth();
+    int_type parent(size_t b) const {
+        return parents_[b];
+    }
+    int_type& parent(size_t b) {
+        return parents_[b];
+    }
 
     /// memory used to store tree (in bytes)
-    std::size_t memory() const;
+    std::size_t memory() const {
+        return sizeof(int_type)*(children_.size()+child_index_.size()+parents_.size())
+            + sizeof(tree);
+    }
 
 private:
-    void init(size_type nnode);
+    void init(size_type nnode) {
+        if (nnode) {
+            auto nchild = nnode - 1;
+            children_.resize(nchild);
+            child_index_.resize(nnode+1);
+            parents_.resize(nnode);
+        }
+        else {
+            children_.resize(0);
+            child_index_.resize(0);
+            parents_.resize(0);
+        }
+    }
 
     // state
     iarray children_;
@@ -111,10 +146,6 @@ private:
     iarray parents_;
 };
 
-// Calculates the depth of each branch from the root of a cell segment tree.
-// The root has depth 0, it's children have depth 1, and so on.
-tree::iarray depth_from_root(const tree& t);
-
 template <typename C>
 std::vector<tree::int_type> make_parent_index(tree const& t, C const& counts)
 {
diff --git a/test/unit/test_gpu_stack.cu b/test/unit/test_gpu_stack.cu
index 10af85c4..e75f8c3b 100644
--- a/test/unit/test_gpu_stack.cu
+++ b/test/unit/test_gpu_stack.cu
@@ -66,7 +66,6 @@ TEST(stack, push_back) {
     kernels::push_back<<<1, n>>>(sstorage, kernels::all_ftor());
     cudaDeviceSynchronize();
     EXPECT_EQ(n, s.size());
-    std::sort(sstorage.data, sstorage.data+s.size());
     for (auto i=0; i<int(s.size()); ++i) {
         EXPECT_EQ(i, s[i]);
     }
@@ -75,7 +74,6 @@ TEST(stack, push_back) {
     kernels::push_back<<<1, n>>>(sstorage, kernels::even_ftor());
     cudaDeviceSynchronize();
     EXPECT_EQ(n/2, s.size());
-    std::sort(sstorage.data, sstorage.data+s.size());
     for (auto i=0; i<int(s.size())/2; ++i) {
         EXPECT_EQ(2*i, s[i]);
     }
@@ -84,7 +82,6 @@ TEST(stack, push_back) {
     kernels::push_back<<<1, n>>>(sstorage, kernels::odd_ftor());
     cudaDeviceSynchronize();
     EXPECT_EQ(n/2, s.size());
-    std::sort(sstorage.data, sstorage.data+s.size());
     for (auto i=0; i<int(s.size())/2; ++i) {
         EXPECT_EQ(2*i+1, s[i]);
     }
diff --git a/test/unit/test_matrix.cu b/test/unit/test_matrix.cu
index b3424ad2..b01fff92 100644
--- a/test/unit/test_matrix.cu
+++ b/test/unit/test_matrix.cu
@@ -15,7 +15,6 @@
 #include "backends/gpu/matrix_state_flat.hpp"
 #include "backends/gpu/matrix_state_interleaved.hpp"
 #include "backends/gpu/matrix_interleave.hpp"
-#include "backends/gpu/matrix_state_fine.hpp"
 
 #include "../gtest.h"
 #include "common.hpp"
@@ -215,7 +214,6 @@ TEST(matrix, backends)
 
     using state_flat = gpu::matrix_state_flat<T, I>;
     using state_intl = gpu::matrix_state_interleaved<T, I>;
-    using state_fine = gpu::matrix_state_fine<T, I>;
 
     using gpu_array  = memory::device_vector<T>;
 
@@ -288,7 +286,6 @@ TEST(matrix, backends)
     // Make the reference matrix and the gpu matrix
     auto flat = state_flat(p, cell_cv_divs, Cm, g, area); // flat
     auto intl = state_intl(p, cell_cv_divs, Cm, g, area); // interleaved
-    auto fine = state_fine(p, cell_cv_divs, Cm, g, area); // interleaved
 
     // Set the integration times for the cells to be between 0.01 and 0.02 ms.
     std::vector<T> dt(num_mtx, 0);
@@ -303,27 +300,90 @@ TEST(matrix, backends)
 
     flat.assemble(gpu_dt, gpu_v, gpu_i);
     intl.assemble(gpu_dt, gpu_v, gpu_i);
-    fine.assemble(gpu_dt, gpu_v, gpu_i);
 
     flat.solve();
     intl.solve();
-    fine.solve();
 
     // Compare the results.
     // We expect exact equality for the two gpu matrix implementations because both
     // perform the same operations in the same order on the same inputs.
     std::vector<double> x_flat = assign_from(on_host(flat.solution()));
     std::vector<double> x_intl = assign_from(on_host(intl.solution()));
-    // as the fine algorithm contains atomics the solution might be slightly
-    // different from flat and interleaved
-    std::vector<double> x_fine = assign_from(on_host(fine.solution()));
+    EXPECT_EQ(x_flat, x_intl);
+}
 
-    auto max_diff_fine =
-        util::max_value(
-            util::transform_view(
-                util::count_along(x_flat),
-                [&](unsigned i) {return std::abs(x_flat[i] - x_fine[i]);}));
+/*
 
-    EXPECT_EQ(x_flat, x_intl);
-    EXPECT_LE(max_diff_fine, 1e-12);
+// Test for special zero diagonal behaviour. (see `test_matrix.cpp`.)
+TEST(matrix, zero_diagonal)
+{
+    using util::assign;
+
+    using value_type = gpu::backend::value_type;
+    using size_type = gpu::backend::size_type;
+    using matrix_type = gpu::backend::matrix_state;
+    using vvec = std::vector<value_type>;
+
+    // Combined matrix may have zero-blocks, corresponding to a zero dt.
+    // Zero-blocks are indicated by zero value in the diagonal (the off-diagonal
+    // elements should be ignored).
+    // These submatrices should leave the rhs as-is when solved.
+
+    // Three matrices, sizes 3, 3 and 2, with no branching.
+    std::vector<size_type> p = {0, 0, 1, 3, 3, 5, 5};
+    std::vector<size_type> c = {0, 3, 5, 7};
+
+    // Face conductances.
+    std::vector<value_type> g = {0, 1, 1, 0, 1, 0, 2};
+
+    // dt of 1e-3.
+    std::vector<value_type> dt(3, 1.0e-3);
+
+    // Capacitances.
+    std::vector<value_type> Cm = {1, 1, 1, 1, 1, 2, 3};
+
+    // Intial voltage of zero; currents alone determine rhs.
+    std::vector<value_type> v(7, 0.0);
+    std::vector<value_type> i = {-3, -5, -7, -6, -9, -16, -32};
+
+    // Expected matrix and rhs:
+    // u = [ 0 -1 -1  0 -1  0 -2]
+    // d = [ 2  3  2  2  2  4  5]
+    // b = [ 3  5  7  2  4 16 32]
+    //
+    // Expected solution:
+    // x = [ 4  5  6  7  8  9 10]
+
+    matrix_type m(p, c, Cm, g);
+    auto gpu_dt = on_gpu(dt);
+    auto gpu_v  = on_gpu(v);
+    auto gpu_i  = on_gpu(i);
+    m.assemble(gpu_dt, gpu_v, gpu_i);
+    m.solve();
+
+    vvec x;
+    assign(x, on_host(m.solution()));
+    std::vector<value_type> expected = {4, 5, 6, 7, 8, 9, 10};
+
+    EXPECT_TRUE(testing::seq_almost_eq<double>(expected, x));
+
+    // Set dt of 2nd (middle) submatrix to zero. Solution
+    // should then return voltage values for that submatrix.
+
+    dt[1] = 0;
+    gpu_dt = on_gpu(dt);
+
+    v[3] = 20;
+    v[4] = 30;
+    gpu_v  = on_gpu(v);
+
+    m.assemble(gpu_dt, gpu_v, gpu_i);
+    m.solve();
+
+    assign(x, on_host(m.solution()));
+    expected = {4, 5, 6, 20, 30, 9, 10};
+
+    EXPECT_TRUE(testing::seq_almost_eq<double>(expected, x));
 }
+
+*/
diff --git a/test/unit/test_matrix_cpuvsgpu.cpp b/test/unit/test_matrix_cpuvsgpu.cpp
index 861c83dd..27390686 100644
--- a/test/unit/test_matrix_cpuvsgpu.cpp
+++ b/test/unit/test_matrix_cpuvsgpu.cpp
@@ -69,25 +69,22 @@ TEST(matrix, assemble)
     //           6
     //            \.
     //             7
-    // p_3: 1 branch, 1 compartment
-    //
-    // 0
 
     // The parent indexes that define the two matrix structures
     std::vector<std::vector<I>>
-        p_base = { {0,0,1,2,2,4}, {0,0,1,2,3,3,2,6}, {0} };
+        p_base = { {0,0,1,2,2,4}, {0,0,1,2,3,3,2,6} };
 
     // Make a set of matrices based on repeating this pattern.
     // We assign the patterns round-robin, i.e. so that the input
     // matrices will have alternating sizes of 6 and 8, which will
     // test the solver with variable matrix size, and exercise
     // solvers that reorder matrices according to size.
-    const int num_mtx = 100;
+    const int num_mtx = 8;
 
     std::vector<I> p;
     std::vector<I> cell_index;
     for (auto m=0; m<num_mtx; ++m) {
-        auto &p_ref = p_base[m%p_base.size()];
+        auto &p_ref = p_base[m%2];
         auto first = p.size();
         for (auto i: p_ref) {
             p.push_back(i + first);
diff --git a/test/unit/test_tree.cpp b/test/unit/test_tree.cpp
index 094e77f1..e23cc8ed 100644
--- a/test/unit/test_tree.cpp
+++ b/test/unit/test_tree.cpp
@@ -9,7 +9,6 @@
 
 using namespace arb;
 using int_type = tree::int_type;
-using iarray = tree::iarray;
 
 TEST(tree, from_segment_index) {
     auto no_parent = tree::no_parent;
@@ -175,85 +174,3 @@ TEST(tree, from_segment_index) {
     }
 }
 
-TEST(tree, depth_from_root) {
-    // tree with single branch corresponding to the root node
-    // this is equivalent to a single compartment model
-    //      CASE 1 : single root node in parent_index
-    {
-        std::vector<int_type> parent_index = {0};
-        iarray expected = {0u};
-        EXPECT_EQ(expected, depth_from_root(tree(parent_index)));
-    }
-
-    {
-        //     0
-        //    / \.
-        //   1   2
-        std::vector<int_type> parent_index = {0, 0, 0};
-        iarray expected = {0u, 1u, 1u};
-        EXPECT_EQ(expected, depth_from_root(tree(parent_index)));
-    }
-    {
-        //     0-1-2-3
-        std::vector<int_type> parent_index = {0, 0, 1, 2};
-        iarray expected = {0u, 1u, 2u, 3u};
-        EXPECT_EQ(expected, depth_from_root(tree(parent_index)));
-    }
-    {
-        //
-        //     0
-        //    /|\.
-        //   1 2 3
-        //
-        std::vector<int_type> parent_index = {0, 0, 0, 0};
-        iarray expected = {0u, 1u, 1u, 1u};
-        EXPECT_EQ(expected, depth_from_root(tree(parent_index)));
-    }
-    {
-        //
-        //   0
-        //  /|\.
-        // 1 2 3
-        //    / \.
-        //   4   5
-        //
-        std::vector<int_type> parent_index = {0, 0, 0, 0, 3, 3};
-        iarray expected = {0u, 1u, 1u, 1u, 2u, 2u};
-        EXPECT_EQ(expected, depth_from_root(tree(parent_index)));
-    }
-    {
-        //
-        //              0
-        //             /
-        //            1
-        //           / \.
-        //          2   3
-        std::vector<int_type> parent_index = {0,0,1,1};
-        iarray expected = {0u, 1u, 2u, 2u};
-        EXPECT_EQ(expected, depth_from_root(tree(parent_index)));
-    }
-    {
-        //
-        //              0
-        //             /|\.
-        //            1 4 5
-        //           / \.
-        //          2   3
-        std::vector<int_type> parent_index = {0,0,1,1,0,0};
-        iarray expected = {0u, 1u, 2u, 2u, 1u, 1u};
-        EXPECT_EQ(expected, depth_from_root(tree(parent_index)));
-    }
-    {
-        //              0
-        //             / \.
-        //            1   2
-        //           / \.
-        //          3   4
-        //             / \.
-        //            5   6
-        std::vector<int_type> parent_index = {0,0,0,1,1,4,4};
-        iarray expected = {0u, 1u, 1u, 2u, 2u, 3u, 3u};
-        EXPECT_EQ(expected, depth_from_root(tree(parent_index)));
-    }
-}
-
-- 
GitLab