diff --git a/arbor/include/arbor/gpu/cuda_api.hpp b/arbor/include/arbor/gpu/cuda_api.hpp
index 48bc96f262e5d284af7d2adcd8beeb3e3167ffd4..532b9ec8f9f504dbfead6032525bf996610e7827 100644
--- a/arbor/include/arbor/gpu/cuda_api.hpp
+++ b/arbor/include/arbor/gpu/cuda_api.hpp
@@ -139,17 +139,6 @@ inline float gpu_atomic_sub(float* address, float val) {
 
 /// Warp-Level Primitives
 
-__device__ __inline__ double shfl(unsigned mask, double x, int lane)
-{
-    auto tmp = static_cast<uint64_t>(x);
-    auto lo = static_cast<unsigned>(tmp);
-    auto hi = static_cast<unsigned>(tmp >> 32);
-    hi = __shfl_sync(mask, static_cast<int>(hi), lane, warpSize);
-    lo = __shfl_sync(mask, static_cast<int>(lo), lane, warpSize);
-    return static_cast<double>(static_cast<uint64_t>(hi) << 32 |
-                               static_cast<uint64_t>(lo));
-}
-
 __device__ __inline__ unsigned ballot(unsigned mask, unsigned is_root) {
     return __ballot_sync(mask, is_root);
 }
@@ -158,24 +147,15 @@ __device__ __inline__ unsigned any(unsigned mask, unsigned width) {
     return __any_sync(mask, width);
 }
 
-#ifdef __NVCC__
-__device__ __inline__ double shfl_up(unsigned mask, int idx, unsigned lane_id, unsigned shift) {
-    return __shfl_up_sync(mask, idx, shift);
-}
-
-__device__ __inline__ double shfl_down(unsigned mask, int idx, unsigned lane_id, unsigned shift) {
-    return __shfl_down_sync(mask, idx, shift);
+template<typename T> // T must be a value type that __shfl_up_sync accepts; the result is returned as T
+__device__ __inline__ T shfl_up(unsigned mask, T var, unsigned lane_id, unsigned shift) { // lane_id is not used by this implementation
+    return __shfl_up_sync(mask, var, shift); // forwards mask, var and shift unchanged to the CUDA intrinsic
 }
 
-#else
-__device__ __inline__ double shfl_up(unsigned mask, int idx, unsigned lane_id, unsigned shift) {
-    return shfl(mask, idx, lane_id - shift);
+template<typename T> // T must be a value type that __shfl_down_sync accepts; the result is returned as T
+__device__ __inline__ T shfl_down(unsigned mask, T var, unsigned lane_id, unsigned shift) { // lane_id is not used by this implementation
+    return __shfl_down_sync(mask, var, shift); // forwards mask, var and shift unchanged to the CUDA intrinsic
 }
-
-__device__ __inline__ double shfl_down(unsigned mask, int idx, unsigned lane_id, unsigned shift) {
-    return shfl(mask, idx, lane_id + shift);
-}
-#endif
 #endif
 
 } // namespace gpu
diff --git a/arbor/include/arbor/gpu/hip_api.hpp b/arbor/include/arbor/gpu/hip_api.hpp
index fbdba3151fb1bf0e0e3fccc328f2824f588c79e1..019f235a8e2b9da9a36b7501f185d06e4fe75251 100644
--- a/arbor/include/arbor/gpu/hip_api.hpp
+++ b/arbor/include/arbor/gpu/hip_api.hpp
@@ -1,5 +1,6 @@
 #include <utility>
 #include <string>
+#include <type_traits>
 
 #include <hip/hip_runtime.h>
 #include <hip/hip_runtime_api.h>
@@ -118,6 +119,14 @@ inline float gpu_atomic_sub(float* address, float val) {
 
 /// Warp-level Primitives
 
+template<typename T> // generic shuffle overload for every argument type except double
+__device__ __inline__
+std::enable_if_t< !std::is_same_v<std::decay_t<T>, double>, std::decay_t<T>> // drops out of overload resolution when T decays to double, leaving the shfl(double, int) overload
+shfl(T x, int lane)
+{
+    return __shfl(x, lane); // forwards x and lane to __shfl; no explicit width argument is passed
+}
+
 __device__ __inline__ double shfl(double x, int lane)
 {
     auto tmp = static_cast<uint64_t>(x);
@@ -137,12 +146,14 @@ __device__ __inline__ unsigned any(unsigned mask, unsigned width) {
     return __any(width);
 }
 
-__device__ __inline__ double shfl_up(unsigned mask, int idx, unsigned lane_id, unsigned shift) {
-    return shfl(idx, lane_id - shift);
+template<typename T> // returns shfl(var, lane_id - shift); mask is accepted but not used by this implementation
+__device__ __inline__ T shfl_up(unsigned mask, T var, unsigned lane_id, unsigned shift) {
+    return shfl(var, static_cast<int>(lane_id) - static_cast<int>(shift)); // cast both operands so the lane index is computed in signed arithmetic rather than wrapping as unsigned; NOTE(review): index is negative when shift > lane_id - confirm shfl handles that as callers expect
 }
 
-__device__ __inline__ double shfl_down(unsigned mask, int idx, unsigned lane_id, unsigned shift) {
-    return shfl(idx, lane_id + shift);
+template<typename T> // returns shfl(var, lane_id + shift); mask is accepted but not used by this implementation
+__device__ __inline__ T shfl_down(unsigned mask, T var, unsigned lane_id, unsigned shift) {
+    return shfl(var, static_cast<int>(lane_id) + static_cast<int>(shift)); // cast both operands so the lane index is computed in signed arithmetic, matching shfl's int lane parameter; NOTE(review): no upper bound is applied here - confirm shfl handles indices past the last lane as callers expect
 }
 
 } // namespace gpu
diff --git a/test/unit/test_reduce_by_key.cu b/test/unit/test_reduce_by_key.cu
index 44dcab912a9bd1e217052d488f9f871ab92fbffc..5b5763c95e8760053a902bcb59ec4daaa88326ee 100644
--- a/test/unit/test_reduce_by_key.cu
+++ b/test/unit/test_reduce_by_key.cu
@@ -89,8 +89,8 @@ TEST(reduce_by_key, scatter)
     // onto an array of length 12.
     std::size_t n = 12;
     std::vector<int> index = {0,0,0,1,2,2,2,2,3,3,7,7,7,7,7,11};
-    std::vector<double> in(index.size(), 1);
-    std::vector<double> expected = {3., 1., 4., 2., 0., 0., 0., 5., 0., 0., 0., 1.};
+    std::vector<double> in(index.size(), 0.5);
+    std::vector<double> expected = {1.5, 0.5, 2., 1., 0., 0., 0., 2.5, 0., 0., 0., 0.5};
 
     EXPECT_EQ(n, expected.size());