Skip to content
Snippets Groups Projects
Commit 35ccead6 authored by Eric Müller's avatar Eric Müller :mountain_bicyclist:
Browse files

fix(py-tensorflow): build w/ modern cuda

parent 71fcc9f9
No related branches found
No related tags found
1 merge request!655Draft: Fix py-{jax{,lib},tensorflow}+cuda
diff --git a/tensorflow/core/kernels/gpu_prim.h b/tensorflow/core/kernels/gpu_prim.h
index bef22b5..94fd951 100644
--- a/tensorflow/core/kernels/gpu_prim.h
+++ b/tensorflow/core/kernels/gpu_prim.h
@@ -44,10 +44,10 @@ __device__ __forceinline__ void ThreadStoreVolatilePtr<Eigen::half>(
Eigen::numext::bit_cast<uint16_t>(val);
}
-
-template <>
-__device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer<Eigen::half>(
- Eigen::half *ptr, Int2Type<true> /*is_primitive*/) {
- uint16_t result = *reinterpret_cast<volatile uint16_t *>(ptr);
+
+__device__ __forceinline__ Eigen::half ThreadLoadVolatilePointer(
+ const Eigen::half *ptr, Int2Type<true> /*is_primitive*/) {
+ uint16_t result = *reinterpret_cast<volatile const uint16_t *>(ptr);
return Eigen::numext::bit_cast<Eigen::half>(result);
}
-
@@ -59,11 +59,9 @@ __device__ __forceinline__ void ThreadStoreVolatilePtr<Eigen::bfloat16>(
Eigen::numext::bit_cast<uint16_t>(val);
}
-
-template <>
-__device__ __forceinline__ Eigen::bfloat16
-ThreadLoadVolatilePointer<Eigen::bfloat16>(Eigen::bfloat16 *ptr,
- Int2Type<true> /*is_primitive*/) {
- uint16_t result = *reinterpret_cast<volatile uint16_t *>(ptr);
+__device__ __forceinline__ Eigen::bfloat16 ThreadLoadVolatilePointer(
+ const Eigen::bfloat16 *ptr, Int2Type<true> /*is_primitive*/) {
+ uint16_t result = *reinterpret_cast<volatile const uint16_t *>(ptr);
return Eigen::numext::bit_cast<Eigen::bfloat16>(result);
}
-
......@@ -568,6 +568,9 @@ class PyTensorflow(Package, CudaPackage, ROCmPackage, PythonExtension):
when="@2.18.0 %gcc",
)
# adapted from https://github.com/tensorflow/tensorflow/commit/5467ee993e1d3e4709c1e99f3a15a978325ae536
patch("cub_ThreadLoadVolatilePointer.patch", when="@2.18.0 ^cuda@12.8")
def flag_handler(self, name, flags):
spec = self.spec
# ubuntu gcc has this workaround turned on by default in aarch64
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment