diff --git a/arbor/backends/gpu/shared_state.cu b/arbor/backends/gpu/shared_state.cu
index 685fe2c72352fb7b1dc4fc0e105a199188b58f7a..57989e642132f00fe3da7826862ba6e27e128feb 100644
--- a/arbor/backends/gpu/shared_state.cu
+++ b/arbor/backends/gpu/shared_state.cu
@@ -42,21 +42,19 @@ __global__ void add_scalar(unsigned n, T* x, fvm_value_type v) {
     }
 }
 
-// Vector minus: x = y - z
-template <typename T>
-__global__ void vec_minus(unsigned n, T* x, const T* y, const T* z) {
-    unsigned i = threadIdx.x+blockIdx.x*blockDim.x;
-    if (i<n) {
-        x[i] = y[i]-z[i];
-    }
-}
-
-// Vector gather: x[i] = y[index[i]]
 template <typename T, typename I>
-__global__ void gather(unsigned n, T* x, const T* y, const I* index) {
-    unsigned i = threadIdx.x+blockIdx.x*blockDim.x;
-    if (i<n) {
-        x[i] = y[index[i]];
+__global__ void set_dt_impl( T* __restrict__ dt_intdom,
+                             const T* time_to,
+                             const T* time,
+                             const unsigned ncomp,
+                             T* __restrict__ dt_comp,
+                             const I* cv_to_intdom) {
+    auto idx = blockIdx.x*blockDim.x + threadIdx.x;
+    if (idx < ncomp) {
+        const auto ind = cv_to_intdom[idx];
+        const auto dt = time_to[ind] - time[ind];
+        dt_intdom[ind] = dt;
+        dt_comp[idx] = dt;
     }
 }
 
@@ -105,11 +103,8 @@ void set_dt_impl(
     if (!nintdom || !ncomp) return;
 
     constexpr int block_dim = 128;
-    int nblock = block_count(nintdom, block_dim);
-    kernel::vec_minus<<<nblock, block_dim>>>(nintdom, dt_intdom, time_to, time);
-
-    nblock = block_count(ncomp, block_dim);
-    kernel::gather<<<nblock, block_dim>>>(ncomp, dt_comp, dt_intdom, cv_to_intdom);
+    const int nblock = block_count(ncomp, block_dim);
+    kernel::set_dt_impl<<<nblock, block_dim>>>(dt_intdom, time_to, time, ncomp, dt_comp, cv_to_intdom);
 }
 
 void add_gj_current_impl(