From 35ccead65bb33468d442613b1e08653d02c7aa8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eric=20M=C3=BCller?= <mueller@kip.uni-heidelberg.de> Date: Mon, 17 Mar 2025 13:27:46 +0100 Subject: [PATCH] fix(py-tensorflow): build w/ modern cuda --- .../cub_ThreadLoadVolatilePointer.patch | 34 +++++++++++++++++++ packages/py-tensorflow/package.py | 3 ++ 2 files changed, 37 insertions(+) create mode 100644 packages/py-tensorflow/cub_ThreadLoadVolatilePointer.patch diff --git a/packages/py-tensorflow/cub_ThreadLoadVolatilePointer.patch b/packages/py-tensorflow/cub_ThreadLoadVolatilePointer.patch new file mode 100644 index 00000000..8ba95826 --- /dev/null +++ b/packages/py-tensorflow/cub_ThreadLoadVolatilePointer.patch @@ -0,0 +1,34 @@ +diff --git a/tensorflow/core/kernels/gpu_prim.h b/tensorflow/core/kernels/gpu_prim.h +index bef22b5..94fd951 100644 +--- a/tensorflow/core/kernels/gpu_prim.h ++++ b/tensorflow/core/kernels/gpu_prim.h +@@ -44,10 +44,10 @@ __device__ __forceinline__ void ThreadStoreVolatilePtr<Eigen::half>( + Eigen::numext::bit_cast<uint16_t>(val); + } +- +-template <> +-__device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer<Eigen::half>( +- Eigen::half *ptr, Int2Type<true> /*is_primitive*/) { +- uint16_t result = *reinterpret_cast<volatile uint16_t *>(ptr); ++ ++__device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer( ++ const Eigen::half *ptr, Int2Type<true> /*is_primitive*/) { ++ uint16_t result = *reinterpret_cast<volatile const uint16_t *>(ptr); + return Eigen::numext::bit_cast<Eigen::half>(result); + } +- +@@ -59,11 +59,9 @@ __device__ __forceinline__ void ThreadStoreVolatilePtr<Eigen::bfloat16>( + Eigen::numext::bit_cast<uint16_t>(val); + } +- +-template <> +-__device__ __forceinline__ Eigen::bfloat16 +-ThreadLoadVolatilePointer<Eigen::bfloat16>(Eigen::bfloat16 *ptr, +- Int2Type<true> /*is_primitive*/) { +- uint16_t result = *reinterpret_cast<volatile uint16_t *>(ptr); ++__device__ __forceinline__ Eigen::bfloat16 ThreadLoadVolatilePointer( ++ const Eigen::bfloat16 *ptr, Int2Type<true> /*is_primitive*/) { ++ uint16_t result = *reinterpret_cast<volatile const uint16_t *>(ptr); + return Eigen::numext::bit_cast<Eigen::bfloat16>(result); + } +- diff --git a/packages/py-tensorflow/package.py b/packages/py-tensorflow/package.py index 5b858b1e..79309899 100644 --- a/packages/py-tensorflow/package.py +++ b/packages/py-tensorflow/package.py @@ -568,6 +568,9 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage, PythonExtension): when="@2.18.0 %gcc", ) + # adapted from https://github.com/tensorflow/tensorflow/commit/5467ee993e1d3e4709c1e99f3a15a978325ae536 + patch("cub_ThreadLoadVolatilePointer.patch", when="@2.18.0 ^cuda@12.8") + def flag_handler(self, name, flags): spec = self.spec # ubuntu gcc has this workaround turned on by default in aarch64 -- GitLab