diff --git a/arbor/backends/gpu/shared_state.cu b/arbor/backends/gpu/shared_state.cu
index 685fe2c72352fb7b1dc4fc0e105a199188b58f7a..57989e642132f00fe3da7826862ba6e27e128feb 100644
--- a/arbor/backends/gpu/shared_state.cu
+++ b/arbor/backends/gpu/shared_state.cu
@@ -42,21 +42,19 @@ __global__ void add_scalar(unsigned n, T* x, fvm_value_type v) {
}
}
-// Vector minus: x = y - z
-template <typename T>
-__global__ void vec_minus(unsigned n, T* x, const T* y, const T* z) {
- unsigned i = threadIdx.x+blockIdx.x*blockDim.x;
- if (i<n) {
- x[i] = y[i]-z[i];
- }
-}
-
-// Vector gather: x[i] = y[index[i]]
template <typename T, typename I>
-__global__ void gather(unsigned n, T* x, const T* y, const I* index) {
- unsigned i = threadIdx.x+blockIdx.x*blockDim.x;
- if (i<n) {
- x[i] = y[index[i]];
+__global__ void set_dt_impl( T* __restrict__ dt_intdom,
+ const T* time_to,
+ const T* time,
+ const unsigned ncomp,
+ T* __restrict__ dt_comp,
+ const I* cv_to_intdom) {
+ auto idx = blockIdx.x*blockDim.x + threadIdx.x;
+ if (idx < ncomp) {
+ const auto ind = cv_to_intdom[idx];
+ const auto dt = time_to[ind] - time[ind];
+ dt_intdom[ind] = dt;
+ dt_comp[idx] = dt;
}
}
@@ -105,11 +103,8 @@ void set_dt_impl(
if (!nintdom || !ncomp) return;
constexpr int block_dim = 128;
- int nblock = block_count(nintdom, block_dim);
- kernel::vec_minus<<<nblock, block_dim>>>(nintdom, dt_intdom, time_to, time);
-
- nblock = block_count(ncomp, block_dim);
- kernel::gather<<<nblock, block_dim>>>(ncomp, dt_comp, dt_intdom, cv_to_intdom);
+ const int nblock = block_count(ncomp, block_dim);
+ kernel::set_dt_impl<<<nblock, block_dim>>>(dt_intdom, time_to, time, ncomp, dt_comp, cv_to_intdom);
}
void add_gj_current_impl(