diff --git a/arbor/backends/gpu/shared_state.cu b/arbor/backends/gpu/shared_state.cu
index 685fe2c72352fb7b1dc4fc0e105a199188b58f7a..57989e642132f00fe3da7826862ba6e27e128feb 100644
--- a/arbor/backends/gpu/shared_state.cu
+++ b/arbor/backends/gpu/shared_state.cu
@@ -42,21 +42,19 @@ __global__ void add_scalar(unsigned n, T* x, fvm_value_type v) {
     }
 }
 
-// Vector minus: x = y - z
-template <typename T>
-__global__ void vec_minus(unsigned n, T* x, const T* y, const T* z) {
-    unsigned i = threadIdx.x+blockIdx.x*blockDim.x;
-    if (i<n) {
-        x[i] = y[i]-z[i];
-    }
-}
-
-// Vector gather: x[i] = y[index[i]]
 template <typename T, typename I>
-__global__ void gather(unsigned n, T* x, const T* y, const I* index) {
-    unsigned i = threadIdx.x+blockIdx.x*blockDim.x;
-    if (i<n) {
-        x[i] = y[index[i]];
+__global__ void set_dt_impl(      T* __restrict__ dt_intdom, // out: receives dt at index cv_to_intdom[idx]
+                            const T* time_to,                // in: read at index cv_to_intdom[idx]
+                            const T* time,                   // in: read at index cv_to_intdom[idx]
+                            const unsigned ncomp,            // number of idx values processed; threads with idx >= ncomp do nothing
+                                  T* __restrict__ dt_comp,   // out: receives dt at index idx
+                            const I* cv_to_intdom) {         // in: read at idx; the value indexes dt_intdom, time_to and time
+    auto idx = blockIdx.x*blockDim.x + threadIdx.x; // flat thread index over the x dimension of the launch
+    if (idx < ncomp) { // only threads with idx below ncomp do any work
+        const auto ind = cv_to_intdom[idx];
+        const auto dt = time_to[ind] - time[ind]; // dt = time_to - time, both taken at ind
+        dt_intdom[ind] = dt; // threads with equal ind store the same value here. NOTE(review): entries whose index never occurs in cv_to_intdom[0..ncomp) are not written by this kernel -- confirm no caller relies on them
+        dt_comp[idx] = dt; // each thread writes its own slot
     }
 }
 
@@ -105,11 +103,8 @@ void set_dt_impl(
     if (!nintdom || !ncomp) return;
 
     constexpr int block_dim = 128;
-    int nblock = block_count(nintdom, block_dim);
-    kernel::vec_minus<<<nblock, block_dim>>>(nintdom, dt_intdom, time_to, time);
-
-    nblock = block_count(ncomp, block_dim);
-    kernel::gather<<<nblock, block_dim>>>(ncomp, dt_comp, dt_intdom, cv_to_intdom);
+    const int nblock = block_count(ncomp, block_dim);
+    kernel::set_dt_impl<<<nblock, block_dim>>>(dt_intdom, time_to, time, ncomp, dt_comp, cv_to_intdom);
 }
 
 void add_gj_current_impl(