diff --git a/arbor/include/arbor/gpu/cuda_api.hpp b/arbor/include/arbor/gpu/cuda_api.hpp index 48bc96f262e5d284af7d2adcd8beeb3e3167ffd4..532b9ec8f9f504dbfead6032525bf996610e7827 100644 --- a/arbor/include/arbor/gpu/cuda_api.hpp +++ b/arbor/include/arbor/gpu/cuda_api.hpp @@ -139,17 +139,6 @@ inline float gpu_atomic_sub(float* address, float val) { /// Warp-Level Primitives -__device__ __inline__ double shfl(unsigned mask, double x, int lane) -{ - auto tmp = static_cast<uint64_t>(x); - auto lo = static_cast<unsigned>(tmp); - auto hi = static_cast<unsigned>(tmp >> 32); - hi = __shfl_sync(mask, static_cast<int>(hi), lane, warpSize); - lo = __shfl_sync(mask, static_cast<int>(lo), lane, warpSize); - return static_cast<double>(static_cast<uint64_t>(hi) << 32 | - static_cast<uint64_t>(lo)); -} - __device__ __inline__ unsigned ballot(unsigned mask, unsigned is_root) { return __ballot_sync(mask, is_root); } @@ -158,24 +147,15 @@ __device__ __inline__ unsigned any(unsigned mask, unsigned width) { return __any_sync(mask, width); } -#ifdef __NVCC__ -__device__ __inline__ double shfl_up(unsigned mask, int idx, unsigned lane_id, unsigned shift) { - return __shfl_up_sync(mask, idx, shift); -} - -__device__ __inline__ double shfl_down(unsigned mask, int idx, unsigned lane_id, unsigned shift) { - return __shfl_down_sync(mask, idx, shift); +template<typename T> +__device__ __inline__ T shfl_up(unsigned mask, T var, unsigned lane_id, unsigned shift) { + return __shfl_up_sync(mask, var, shift); } -#else -__device__ __inline__ double shfl_up(unsigned mask, int idx, unsigned lane_id, unsigned shift) { - return shfl(mask, idx, lane_id - shift); +template<typename T> +__device__ __inline__ T shfl_down(unsigned mask, T var, unsigned lane_id, unsigned shift) { + return __shfl_down_sync(mask, var, shift); } - -__device__ __inline__ double shfl_down(unsigned mask, int idx, unsigned lane_id, unsigned shift) { - return shfl(mask, idx, lane_id + shift); -} -#endif #endif } // namespace gpu diff --git a/arbor/include/arbor/gpu/hip_api.hpp b/arbor/include/arbor/gpu/hip_api.hpp index fbdba3151fb1bf0e0e3fccc328f2824f588c79e1..019f235a8e2b9da9a36b7501f185d06e4fe75251 100644 --- a/arbor/include/arbor/gpu/hip_api.hpp +++ b/arbor/include/arbor/gpu/hip_api.hpp @@ -1,5 +1,6 @@ #include <utility> #include <string> +#include <type_traits> #include <hip/hip_runtime.h> #include <hip/hip_runtime_api.h> @@ -118,6 +119,14 @@ inline float gpu_atomic_sub(float* address, float val) { /// Warp-level Primitives +template<typename T> +__device__ __inline__ +std::enable_if_t< !std::is_same_v<std::decay_t<T>, double>, std::decay_t<T>> +shfl(T x, int lane) +{ + return __shfl(x, lane); +} + __device__ __inline__ double shfl(double x, int lane) { auto tmp = static_cast<uint64_t>(x); @@ -137,12 +146,14 @@ __device__ __inline__ unsigned any(unsigned mask, unsigned width) { return __any(width); } -__device__ __inline__ double shfl_up(unsigned mask, int idx, unsigned lane_id, unsigned shift) { - return shfl(idx, lane_id - shift); +template<typename T> +__device__ __inline__ T shfl_up(unsigned mask, T var, unsigned lane_id, unsigned shift) { + return shfl(var, (int)lane_id - shift); } -__device__ __inline__ double shfl_down(unsigned mask, int idx, unsigned lane_id, unsigned shift) { - return shfl(idx, lane_id + shift); +template<typename T> +__device__ __inline__ T shfl_down(unsigned mask, T var, unsigned lane_id, unsigned shift) { + return shfl(var, (int)lane_id + shift); } } // namespace gpu diff --git a/test/unit/test_reduce_by_key.cu b/test/unit/test_reduce_by_key.cu index 44dcab912a9bd1e217052d488f9f871ab92fbffc..5b5763c95e8760053a902bcb59ec4daaa88326ee 100644 --- a/test/unit/test_reduce_by_key.cu +++ b/test/unit/test_reduce_by_key.cu @@ -89,8 +89,8 @@ TEST(reduce_by_key, scatter) // onto an array of length 12. std::size_t n = 12; std::vector<int> index = {0,0,0,1,2,2,2,2,3,3,7,7,7,7,7,11}; - std::vector<double> in(index.size(), 1); - std::vector<double> expected = {3., 1., 4., 2., 0., 0., 0., 5., 0., 0., 0., 1.}; + std::vector<double> in(index.size(), 0.5); + std::vector<double> expected = {1.5, 0.5, 2., 1., 0., 0., 0., 2.5, 0., 0., 0., 0.5}; EXPECT_EQ(n, expected.size());