diff --git a/arbor/backends/gpu/matrix_state_fine.hpp b/arbor/backends/gpu/matrix_state_fine.hpp
index 6790c1e8f86fad9116722b761d707cc2b21d4ee4..e2459166eda64531cba0d3079c55a670697db8fb 100644
--- a/arbor/backends/gpu/matrix_state_fine.hpp
+++ b/arbor/backends/gpu/matrix_state_fine.hpp
@@ -117,9 +117,8 @@ public:
     managed_vector<unsigned> num_cells_in_block;
 
     // end of the data of each level
-    // use data_size.back() to get the total data size
-    //      data_size >= size
-    managed_vector<unsigned> data_size; // TODO rename
+    managed_vector<unsigned> data_partition;
+    std::size_t data_size;
 
     // the meta data for each level for each block layed out linearly in memory
     managed_vector<level> levels;
@@ -313,7 +312,7 @@ public:
         levels.reserve(total_num_levels);
         levels_start.reserve(branch_maps.size() + 1);
         levels_start.push_back(0);
-        data_size.reserve(branch_maps.size());
+        data_partition.reserve(branch_maps.size());
         // offset into the packed data format, used to apply permutation on data
         auto pos = 0u;
         for (const auto& branch_map: branch_maps) {
@@ -345,8 +344,9 @@ public:
             }
             auto prev_end = levels_start.back();
             levels_start.push_back(prev_end + branch_map.size());
-            data_size.push_back(pos);
+            data_partition.push_back(pos);
         }
+	data_size = pos;
 
         // set matrix state
         matrix_size = p.size();
@@ -406,9 +406,9 @@ public:
 
         // the matrix components u, d and rhs are stored in packed form
         auto nan = std::numeric_limits<double>::quiet_NaN();
-        d   = array(data_size.back(), nan);
-        u   = array(data_size.back(), nan);
-        rhs = array(data_size.back(), nan);
+        d   = array(data_size, nan);
+        u   = array(data_size, nan);
+        rhs = array(data_size, nan);
 
         // transform u_tmp values into packed u vector.
         flat_to_packed(u_tmp, u);
@@ -461,7 +461,7 @@ public:
             rhs.data(), d.data(), u.data(),
             levels.data(), levels_start.data(),
             num_cells_in_block.data(),
-            data_size.data(),
+            data_partition.data(),
             num_cells_in_block.size(), max_branches_per_level);
 
         // unpermute the solution
@@ -475,14 +475,14 @@ public:
     template <typename VFrom, typename VTo>
     void flat_to_packed(const VFrom& from, VTo& to ) {
         arb_assert(from.size()==matrix_size);
-        arb_assert(to.size()==data_size.back());
+        arb_assert(to.size()==data_size);
 
         scatter(from.data(), to.data(), perm.data(), perm.size());
     }
 
     template <typename VFrom, typename VTo>
     void packed_to_flat(const VFrom& from, VTo& to ) {
-        arb_assert(from.size()==data_size.back());
+        arb_assert(from.size()==data_size);
         arb_assert(to.size()==matrix_size);
 
         gather(from.data(), to.data(), perm.data(), perm.size());