diff --git a/packages/py-tensorflow/cub_ThreadLoadVolatilePointer.patch b/packages/py-tensorflow/cub_ThreadLoadVolatilePointer.patch new file mode 100644 index 0000000000000000000000000000000000000000..8ba95826666302f50f568ea40ff4d9963025c35c --- /dev/null +++ b/packages/py-tensorflow/cub_ThreadLoadVolatilePointer.patch @@ -0,0 +1,34 @@ +diff --git a/tensorflow/core/kernels/gpu_prim.h b/tensorflow/core/kernels/gpu_prim.h +index bef22b5..94fd951 100644 +--- a/tensorflow/core/kernels/gpu_prim.h ++++ b/tensorflow/core/kernels/gpu_prim.h +@@ -44,10 +44,10 @@ __device__ __forceinline__ void ThreadStoreVolatilePtr<Eigen::half>( + Eigen::numext::bit_cast<uint16_t>(val); + } +- +-template <> +-__device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer<Eigen::half>( +- Eigen::half *ptr, Int2Type<true> /*is_primitive*/) { +- uint16_t result = *reinterpret_cast<volatile uint16_t *>(ptr); ++ ++__device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer( ++ const Eigen::half *ptr, Int2Type<true> /*is_primitive*/) { ++ uint16_t result = *reinterpret_cast<volatile const uint16_t *>(ptr); + return Eigen::numext::bit_cast<Eigen::half>(result); + } +- +@@ -59,11 +59,9 @@ __device__ __forceinline__ void ThreadStoreVolatilePtr<Eigen::bfloat16>( + Eigen::numext::bit_cast<uint16_t>(val); + } +- +-template <> +-__device__ __forceinline__ Eigen::bfloat16 +-ThreadLoadVolatilePointer<Eigen::bfloat16>(Eigen::bfloat16 *ptr, +- Int2Type<true> /*is_primitive*/) { +- uint16_t result = *reinterpret_cast<volatile uint16_t *>(ptr); ++__device__ __forceinline__ Eigen::bfloat16 ThreadLoadVolatilePointer( ++ const Eigen::bfloat16 *ptr, Int2Type<true> /*is_primitive*/) { ++ uint16_t result = *reinterpret_cast<volatile const uint16_t *>(ptr); + return Eigen::numext::bit_cast<Eigen::bfloat16>(result); + } +- diff --git a/packages/py-tensorflow/package.py b/packages/py-tensorflow/package.py index 5b858b1e4843af388551c02037a8c28263c71931..79309899aabb42a5c8dd2bcb11dd189d1bcdf947 100644 --- a/packages/py-tensorflow/package.py +++ b/packages/py-tensorflow/package.py @@ -568,6 +568,9 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage, PythonExtension): when="@2.18.0 %gcc", ) + # adapted from https://github.com/tensorflow/tensorflow/commit/5467ee993e1d3e4709c1e99f3a15a978325ae536 + patch("cub_ThreadLoadVolatilePointer.patch", when="@2.18.0 ^cuda@12.8") + def flag_handler(self, name, flags): spec = self.spec # ubuntu gcc has this workaround turned on by default in aarch64