diff --git a/packages/py-tensorflow/cub_ThreadLoadVolatilePointer.patch b/packages/py-tensorflow/cub_ThreadLoadVolatilePointer.patch
new file mode 100644
index 0000000000000000000000000000000000000000..8ba95826666302f50f568ea40ff4d9963025c35c
--- /dev/null
+++ b/packages/py-tensorflow/cub_ThreadLoadVolatilePointer.patch
@@ -0,0 +1,34 @@
+diff --git a/tensorflow/core/kernels/gpu_prim.h b/tensorflow/core/kernels/gpu_prim.h
+index bef22b5..94fd951 100644
+--- a/tensorflow/core/kernels/gpu_prim.h
++++ b/tensorflow/core/kernels/gpu_prim.h
+@@ -44,10 +44,8 @@ __device__ __forceinline__ void ThreadStoreVolatilePtr<Eigen::half>(
+       Eigen::numext::bit_cast<uint16_t>(val);
+ }
+-
+-template <>
+-__device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer<Eigen::half>(
+-    Eigen::half *ptr, Int2Type<true> /*is_primitive*/) {
+-  uint16_t result = *reinterpret_cast<volatile uint16_t *>(ptr);
++
++__device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer(
++    const Eigen::half *ptr, Int2Type<true> /*is_primitive*/) {
++  uint16_t result = *reinterpret_cast<volatile const uint16_t *>(ptr);
+   return Eigen::numext::bit_cast<Eigen::half>(result);
+ }
+-
+@@ -59,11 +57,7 @@ __device__ __forceinline__ void ThreadStoreVolatilePtr<Eigen::bfloat16>(
+       Eigen::numext::bit_cast<uint16_t>(val);
+ }
+-
+-template <>
+-__device__ __forceinline__ Eigen::bfloat16
+-ThreadLoadVolatilePointer<Eigen::bfloat16>(Eigen::bfloat16 *ptr,
+-                                           Int2Type<true> /*is_primitive*/) {
+-  uint16_t result = *reinterpret_cast<volatile uint16_t *>(ptr);
++__device__ __forceinline__ Eigen::bfloat16 ThreadLoadVolatilePointer(
++    const Eigen::bfloat16 *ptr, Int2Type<true> /*is_primitive*/) {
++  uint16_t result = *reinterpret_cast<volatile const uint16_t *>(ptr);
+   return Eigen::numext::bit_cast<Eigen::bfloat16>(result);
+ }
+-
diff --git a/packages/py-tensorflow/package.py b/packages/py-tensorflow/package.py
index 5b858b1e4843af388551c02037a8c28263c71931..79309899aabb42a5c8dd2bcb11dd189d1bcdf947 100644
--- a/packages/py-tensorflow/package.py
+++ b/packages/py-tensorflow/package.py
@@ -568,6 +568,9 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage, PythonExtension):
         when="@2.18.0 %gcc",
     )
 
+    # adapted from https://github.com/tensorflow/tensorflow/commit/5467ee993e1d3e4709c1e99f3a15a978325ae536
+    patch("cub_ThreadLoadVolatilePointer.patch", when="@2.18.0 ^cuda@12.8")
+
     def flag_handler(self, name, flags):
         spec = self.spec
         # ubuntu gcc has this workaround turned on by default in aarch64