diff --git a/packages/py-jax/package.py b/packages/py-jax/package.py
new file mode 100644
index 0000000000000000000000000000000000000000..f99d02a83a2d4a9b5d08e513a74dbea50a104b51
--- /dev/null
+++ b/packages/py-jax/package.py
@@ -0,0 +1,149 @@
+# Copyright Spack Project Developers. See COPYRIGHT file for details.
+#
+# SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+
+from spack.package import *
+
+
+class PyJax(PythonPackage):
+    """Differentiate, compile, and transform Numpy code.
+
+    JAX is a Python library for accelerator-oriented array computation and program transformation,
+    designed for high-performance numerical computing and large-scale machine learning.
+    """
+
+    homepage = "https://github.com/jax-ml/jax"
+    pypi = "jax/jax-0.4.27.tar.gz"
+
+    license("Apache-2.0")
+    maintainers("adamjstewart", "jonas-eschle")
+
+    # version("0.5.0", sha256="49df70bf293a345a7fb519f71193506d37a024c4f850b358042eb32d502c81c8")
+    # version("0.4.38", sha256="43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8")
+    # version("0.4.37", sha256="7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b")
+    # version("0.4.36", sha256="088bff0575d01fc82682a9af4eb07433d60de7e5164686bd2cea3439492e608a")
+    # version("0.4.35", sha256="c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e")
+    # version("0.4.34", sha256="44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db")
+    # version("0.4.33", sha256="f0d788692fc0179653066c9e1c64e57311b8c15a389837fd7baf328abefcbb92")
+    # version("0.4.32", sha256="eb703909968da161894fb6135a931c5f3d2aab64fff7cba5fcb803ce6d968e08")
+    version("0.4.31", sha256="fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287")
+    version("0.4.30", sha256="94d74b5b2db0d80672b61d83f1f63ebf99d2ab7398ec12b2ca0c9d1e97afe577")
+    version("0.4.29", sha256="12904571eaefddcdc8c3b8d4936482b783d5a216e99ef5adcd3522fdfb4fc186")
+    version("0.4.28", sha256="dcf0a44aff2e1713f0a2b369281cd5b79d8c18fc1018905c4125897cb06b37e9")
+    version("0.4.27", sha256="f3d7f19bdc0a17ccdb305086099a5a90c704f904d4272a70debe06ae6552998c")
+    version("0.4.26", sha256="2cce025d0a279ec630d550524749bc8efe25d2ff47240d2a7d4cfbc5090c5383")
+    version("0.4.25", sha256="a8ee189c782de2b7b2ffb64a8916da380b882a617e2769aa429b71d79747b982")
+    version("0.4.24", sha256="4a6b6fd026ddd22653c7fa2fac1904c3de2dbe845b61ede08af9a5cc709662ae")
+    version("0.4.23", sha256="2a229a5a758d1b803891b2eaed329723f6b15b4258b14dc0ccb1498c84963685")
+    version("0.4.22", sha256="801434dda6e14f82a45fff753969a33281ab22fb2a50fe801b651390321057ba")
+    version("0.4.21", sha256="c97fd0d2751d6e1eb15aa2052ff7cfdc129f8fafc2c14cd779720658926a587b")
+    version("0.4.20", sha256="ea96a763a8b1a9374639d1159ab4de163461d01cd022f67c34c09581b71ed2ac")
+    version("0.4.19", sha256="29f87f9a50964d3ca5eeb2973de3462f0e8b4eca6d46027894a0e9a903420601")
+    version("0.4.18", sha256="776cf33890100803e98f45f9af10aa727271c6993d4e766c069118733c928132")
+    version("0.4.17", sha256="d7508a69e87835f534cb07a2f21d79cc1cb8c4cfdcf7fb010927267ef7355f1d")
+    version("0.4.16", sha256="e2ca82c9bf973c2c1c01f5340a583692b31f277aa3abd0544229c1fe5fa44b02")
+    version("0.4.15", sha256="2aa123ccef591e355dea94a6e714b6559f8e1d6368a576a223f97d031ece0d15")
+    version("0.4.14", sha256="18fed3881f26e8b13c8cb46eeeea3dba9eb4d48e3714d8e8f2304dd6e237083d")
+    version("0.4.13", sha256="03bfe6749dfe647f16f15f6616638adae6c4a7ca7167c75c21961ecfd3a3baaa")
+    version("0.4.12", sha256="d2de9a2388ffe002f16506d3ad1cc6e34d7536b98948e49c7e05bbcfe8e57998")
+    version("0.4.11", sha256="8b1cd443b698339df8d8807578ee141e5b67e36125b3945b146f600177d60d79")
+    version("0.4.10", sha256="1bf0f2720f778f2937301a16a4d5cd3497f13a4d6c970c24a88918a81816a888")
+    version("0.4.9", sha256="1ed135cd08f48e4baf10f6eafdb4a4cdae781f9052b5838c09c91a9f4fa75f09")
+    version("0.4.8", sha256="08116481f7336db16c24812bfb5e6f9786915f4c2f6ff4028331fa69e7535202")
+    version("0.4.7", sha256="5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8")
+    version("0.4.6", sha256="d06ea8fba4ed315ec55110396058cb48c8edb2ab0b412f28c8a123beee9e58ab")
+    version("0.4.5", sha256="1633e56d34b18ddfa7d2a216ce214fa6fa712d36552532aaa71da416aede7268")
+    version("0.4.4", sha256="39b07e07343ed7c74492ee5e75db77456d3afdd038a322671f09fc748f6392cb")
+    version("0.4.3", sha256="d43f08f940aa30eb339965cfb3d6bee2296537b0dc2f0c65ccae3009279529ae")
+
+    depends_on("py-setuptools", type="build")
+
+    with default_args(type=("build", "run")):
+        # setup.py
+        depends_on("python@3.10:", when="@0.4.31:")
+        depends_on("python@3.9:", when="@0.4.14:")
+        depends_on("py-ml-dtypes@0.4:", when="@0.4.29,0.4.35:")
+        depends_on("py-ml-dtypes@0.2:", when="@0.4.14:")
+        depends_on("py-ml-dtypes@0.1:", when="@0.4.9:")
+        depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:")
+        depends_on("py-numpy@1.25:", when="@0.5:")
+        depends_on("py-numpy@1.24:", when="@0.4.31:")
+        depends_on("py-numpy@1.22:", when="@0.4.14:")
+        depends_on("py-numpy@1.21:", when="@0.4.7:")
+        depends_on("py-numpy@1.20:", when="@0.3:")
+        # https://github.com/google/jax/issues/19246
+        depends_on("py-numpy@:1", when="@:0.4.25")
+        depends_on("py-opt-einsum")
+        depends_on("py-scipy@1.11.1:", when="@0.5:")
+        depends_on("py-scipy@1.10:", when="@0.4.31:")
+        depends_on("py-scipy@1.9:", when="@0.4.19:")
+        depends_on("py-scipy@1.7:", when="@0.4.7:")
+        depends_on("py-scipy@1.5:", when="@0.3:")
+
+        # jax/_src/lib/__init__.py
+        # https://github.com/google/jax/commit/8be057de1f50756fe7522f7e98b2f30fad56f7e4
+        for v in [
+            # "0.5.0",
+            # "0.4.38",
+            # "0.4.37",
+            # "0.4.36",
+            # "0.4.35",
+            # "0.4.34",
+            # "0.4.33",
+            # "0.4.32",
+            "0.4.31",
+            "0.4.30",
+            "0.4.29",
+            "0.4.28",
+            "0.4.27",
+            "0.4.26",
+            "0.4.25",
+            "0.4.24",
+            "0.4.23",
+            "0.4.22",
+            "0.4.21",
+            "0.4.20",
+            "0.4.19",
+            "0.4.18",
+            "0.4.17",
+            "0.4.16",
+            "0.4.15",
+            "0.4.14",
+            "0.4.13",
+            "0.4.12",
+            "0.4.11",
+            "0.4.10",
+            "0.4.9",
+            "0.4.8",
+            "0.4.7",
+            "0.4.6",
+            "0.4.5",
+            "0.4.4",
+            "0.4.3",
+        ]:
+            depends_on(f"py-jaxlib@:{v}", when=f"@{v}")
+
+        # See _minimum_jaxlib_version in jax/version.py
+        # depends_on("py-jaxlib@0.5:", when="@0.5:")
+        # depends_on("py-jaxlib@0.4.38:", when="@0.4.38:")
+        # depends_on("py-jaxlib@0.4.36:", when="@0.4.36:")
+        # depends_on("py-jaxlib@0.4.35:", when="@0.4.35:")
+        # depends_on("py-jaxlib@0.4.34:", when="@0.4.34:")
+        # depends_on("py-jaxlib@0.4.33:", when="@0.4.33:")
+        # depends_on("py-jaxlib@0.4.32:", when="@0.4.32:")
+        depends_on("py-jaxlib@0.4.30:", when="@0.4.31:")
+        depends_on("py-jaxlib@0.4.27:", when="@0.4.28:")
+        depends_on("py-jaxlib@0.4.23:", when="@0.4.27:")
+        depends_on("py-jaxlib@0.4.20:", when="@0.4.25:")
+        depends_on("py-jaxlib@0.4.19:", when="@0.4.21:")
+        depends_on("py-jaxlib@0.4.14:", when="@0.4.15:")
+        depends_on("py-jaxlib@0.4.11:", when="@0.4.12:")
+        depends_on("py-jaxlib@0.4.7:", when="@0.4.8:")
+        depends_on("py-jaxlib@0.4.6:", when="@0.4.7:")
+        depends_on("py-jaxlib@0.4.4:", when="@0.4.5:")
+        depends_on("py-jaxlib@0.4.2:", when="@0.4.3:")
+        depends_on("py-jaxlib@0.4.1:", when="@0.4.2:")
+
+        # Historical dependencies
+        depends_on("py-importlib-metadata@4.6:", when="@0.4.11:0.4.30 ^python@:3.9")
diff --git a/packages/py-jaxlib/jaxxlatsl.patch b/packages/py-jaxlib/jaxxlatsl.patch
new file mode 100644
index 0000000000000000000000000000000000000000..e96cc32e2639691ede2ccabee960cbbacaebc808
--- /dev/null
+++ b/packages/py-jaxlib/jaxxlatsl.patch
@@ -0,0 +1,100 @@
+From 8fce7378ed8ce994107568449806cd99274ab22b Mon Sep 17 00:00:00 2001
+From: Andrew Elble <aweits@rit.edu>
+Date: Mon, 21 Oct 2024 19:42:31 -0400
+Subject: [PATCH] patchit
+
+---
+ ...ch-for-Abseil-to-fix-build-on-Jetson.patch | 68 +++++++++++++++++++
+ third_party/xla/workspace.bzl                 |  1 +
+ 2 files changed, 69 insertions(+)
+ create mode 100644 third_party/xla/0001-Add-patch-for-Abseil-to-fix-build-on-Jetson.patch
+
+diff --git a/third_party/xla/0001-Add-patch-for-Abseil-to-fix-build-on-Jetson.patch b/third_party/xla/0001-Add-patch-for-Abseil-to-fix-build-on-Jetson.patch
+new file mode 100644
+index 000000000000..5138a045082b
+--- /dev/null
++++ b/third_party/xla/0001-Add-patch-for-Abseil-to-fix-build-on-Jetson.patch
+@@ -0,0 +1,68 @@
++From 40da87a0476436ca1da2eafe08935787a05e9a61 Mon Sep 17 00:00:00 2001
++From: David Dunleavy <ddunleavy@google.com>
++Date: Mon, 5 Aug 2024 11:42:53 -0700
++Subject: [PATCH] Add patch for Abseil to fix build on Jetson
++
++Patches in https://github.com/abseil/abseil-cpp/commit/372124e6af36a540e74a2ec31d79d7297a831f98
++
++PiperOrigin-RevId: 659627531
++---
++ .../tsl/third_party/absl/nvidia_jetson.patch  | 35 +++++++++++++++++++
++ .../tsl/third_party/absl/workspace.bzl        |  1 +
++ 2 files changed, 36 insertions(+)
++ create mode 100644 third_party/tsl/third_party/absl/nvidia_jetson.patch
++
++diff --git a/third_party/tsl/third_party/absl/nvidia_jetson.patch b/third_party/tsl/third_party/absl/nvidia_jetson.patch
++new file mode 100644
++index 000000000000..5328c3a0d605
++--- /dev/null
+++++ b/third_party/tsl/third_party/absl/nvidia_jetson.patch
++@@ -0,0 +1,35 @@
+++From 372124e6af36a540e74a2ec31d79d7297a831f98 Mon Sep 17 00:00:00 2001
+++From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20Bastien?= <frederic.bastien@gmail.com>
+++Date: Thu, 1 Aug 2024 12:38:52 -0700
+++Subject: [PATCH] PR #1732: Fix build on NVIDIA Jetson board. Fix #1665
+++
+++Imported from GitHub PR https://github.com/abseil/abseil-cpp/pull/1732
+++
+++Fix build on NVIDIA Jetson board. Fix #1665
+++
+++This patch is already used by the spark project.
+++I'm fixing this as this break the build of Tensorflow and JAX on Jetson board.
+++Merge 7db2d2ab9fbed1f0fabad10a6ec73533ba71bfff into 6b8ebb35c0414ef5a2b6fd4a0f59057e41beaff9
+++
+++Merging this change closes #1732
+++
+++COPYBARA_INTEGRATE_REVIEW=https://github.com/abseil/abseil-cpp/pull/1732 from nouiz:fix_neon_on_jetson 7db2d2ab9fbed1f0fabad10a6ec73533ba71bfff
+++PiperOrigin-RevId: 658501520
+++Change-Id: If502ede4efc8c877fb3fed227eca6dc7622dd181
+++---
+++ absl/base/config.h | 2 +-
+++ 1 file changed, 1 insertion(+), 1 deletion(-)
+++
+++diff --git a/absl/base/config.h b/absl/base/config.h
+++index 97c9a22a109..ab1e9860a91 100644
+++--- a/absl/base/config.h
++++++ b/absl/base/config.h
+++@@ -926,7 +926,7 @@ static_assert(ABSL_INTERNAL_INLINE_NAMESPACE_STR[0] != 'h' ||
+++ // https://llvm.org/docs/CompileCudaWithLLVM.html#detecting-clang-vs-nvcc-from-code
+++ #ifdef ABSL_INTERNAL_HAVE_ARM_NEON
+++ #error ABSL_INTERNAL_HAVE_ARM_NEON cannot be directly set
+++-#elif defined(__ARM_NEON) && !defined(__CUDA_ARCH__)
++++#elif defined(__ARM_NEON) && !(defined(__NVCC__) && defined(__CUDACC__))
+++ #define ABSL_INTERNAL_HAVE_ARM_NEON 1
+++ #endif
+++ 
++diff --git a/third_party/tsl/third_party/absl/workspace.bzl b/third_party/tsl/third_party/absl/workspace.bzl
++index 06f75166ce4b..9565a82c3319 100644
++--- a/third_party/tsl/third_party/absl/workspace.bzl
+++++ b/third_party/tsl/third_party/absl/workspace.bzl
++@@ -44,4 +44,5 @@ def repo():
++         system_link_files = SYS_LINKS,
++         strip_prefix = "abseil-cpp-{commit}".format(commit = ABSL_COMMIT),
++         urls = tf_mirror_urls("https://github.com/abseil/abseil-cpp/archive/{commit}.tar.gz".format(commit = ABSL_COMMIT)),
+++        patch_file = ["//third_party/absl:nvidia_jetson.patch"],
++     )
++-- 
++2.31.1
++
+diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
+index af52e7671507..70481bc970a5 100644
+--- a/third_party/xla/workspace.bzl
++++ b/third_party/xla/workspace.bzl
+@@ -29,6 +29,7 @@ def repo():
+         name = "xla",
+         sha256 = XLA_SHA256,
+         strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT),
++	patch_file = ["//third_party/xla:0001-Add-patch-for-Abseil-to-fix-build-on-Jetson.patch"],
+         urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
+     )
+ 
+-- 
+2.31.1
+
diff --git a/packages/py-jaxlib/package.py b/packages/py-jaxlib/package.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd679311a25bc1c8658662aa8a6c09766e775ff3
--- /dev/null
+++ b/packages/py-jaxlib/package.py
@@ -0,0 +1,231 @@
+# Copyright Spack Project Developers. See COPYRIGHT file for details.
+#
+# SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+import glob
+
+from spack.build_systems.python import PythonPipBuilder
+from spack.package import *
+
+rocm_dependencies = [
+    "hsa-rocr-dev",
+    "hip",
+    "rccl",
+    "rocprim",
+    "hipcub",
+    "rocthrust",
+    "roctracer-dev",
+    "rocrand",
+    "hipsparse",
+    "hipfft",
+    "rocfft",
+    "rocblas",
+    "miopen-hip",
+    "rocminfo",
+]
+
+
+class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
+    """XLA library for Jax.
+
+    jaxlib is the support library for JAX. While JAX itself is a pure Python package,
+    jaxlib contains the binary (C/C++) parts of the library, including Python bindings,
+    the XLA compiler, the PJRT runtime, and a handful of handwritten kernels.
+    """
+
+    homepage = "https://github.com/jax-ml/jax"
+    url = "https://github.com/jax-ml/jax/archive/refs/tags/jax-v0.4.34.tar.gz"
+
+    license("Apache-2.0")
+    maintainers("adamjstewart", "jonas-eschle")
+
+    # version("0.5.0", sha256="04cc2eeb2e7ce1916674cea03a7d75a59d583ddb779d5104e103a2798a283ce9")
+    # version("0.4.38", sha256="ca1e63c488d505b9c92e81499e8b06cc1977319c50d64a0e58adbd2dae1a625c")
+    # version("0.4.37", sha256="17a8444a931f26edda8ccbc921ab71c6bf46857287b1db186deebd357e526870")
+    # version("0.4.36", sha256="442bfdf491b509995aa160361e23a9db488d5b97c87e6648cc733501b06eda77")
+    # version("0.4.35", sha256="65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212")
+    # version("0.4.34", sha256="d3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb")
+    # version("0.4.33", sha256="122a806e80fc1cd7d8ffaf9620701f2cb8e4fe22271c2cec53a9c60b30bd4c31")
+    # version("0.4.32", sha256="3fe36d596e4d640443c0a5c533845c74fbc4341e024d9bb1cd75cb49f5f419c2")
+    version("0.4.31", sha256="022ea1347f9b21cbea31410b3d650d976ea4452a48ea7317a5f91c238031bf94")
+    version("0.4.30", sha256="0ef9635c734d9bbb44fcc87df4f1c3ccce1cfcfd243572c80d36fcdf826fe1e6")
+    version("0.4.29", sha256="3a8005f4f62d35a5aad7e3dbd596890b47c81cc6e34fcfe3dcb93b3ca7cb1246")
+    version("0.4.28", sha256="4dd11577d4ba5a095fbc35258ddd4e4c020829ed6e6afd498c9e38ccbcdfe20b")
+    version("0.4.27", sha256="c2c82cd9ad3b395d5cbc0affa26a2938e52677a69ca8f0b9ef9922a52cac4f0c")
+    version("0.4.26", sha256="ddc14da1eaa34f23430d40ad9b9585088575cac439a2fa1c6833a247e1b221fd")
+    version("0.4.25", sha256="fc1197c401924942eb14185a61688d0c476e3e81ff71f9dc95e620b57c06eec8")
+    version("0.4.24", sha256="c4e6963c2c36f634a9a1765e476a1ed4e6c4a7954465ebf72e29f344c28ddc28")
+    version("0.4.23", sha256="e4c06d62ba54becffd91abc862627b8b11b79c5a77366af8843b819665b6d568")
+    version("0.4.21", sha256="8d57f66d00b9c0b824b1eff84adda5b765a412b3f316ef7c773632d1edbf9477")
+    version("0.4.20", sha256="058410d2bc12f7562c7b01e0c8cd587cb68059c12f78bc945055e5ddc445f5fd")
+    version("0.4.19", sha256="51242b217a1f82474e42d24f09ed5dedff951eeb4579c6e49e706d1adfd6949d")
+    version("0.4.16", sha256="85c8bc050abe0a2cf62e8cfc7edb4904dd3807924b5714ec6277f291c576b5ca")
+    version("0.4.14", sha256="9f309476a8f6337717b059b8d10b5859b4134c30cf8f1220bb70379b5e2744a4")
+    version("0.4.11", sha256="bdfc45f33970beba5caf28d061668a4863f05994deea26791db50ea605fc2e36")
+    version("0.4.7", sha256="0578d5dd5035b5225cadb6a62ca5f93dd76b70292268502fc01a0fd9ca7001d0")
+    version("0.4.6", sha256="2c9bf8962815bc54ef524e33dc8eda9d165d379fe87e0df210f316adead27787")
+    version("0.4.4", sha256="881f402c7983b56b185e182d5315dd64c9f5320be96213d0415996ece1826806")
+    version("0.4.3", sha256="2104735dc22be2b105e5517bd5bc6ae97f40e8e9e54928cac1585c6112a3d910")
+
+    variant("cuda", default=True, description="Build with CUDA enabled")
+    variant("nccl", default=True, description="Build with NCCL enabled", when="+cuda")
+
+    depends_on("c", type="build")
+    depends_on("cxx", type="build")
+
+    # docs/installation.md (Compatible with)
+    with when("+cuda"):
+        depends_on("cuda@12.1:", when="@0.4.26:")
+        depends_on("cuda@11.8:", when="@0.4.11:")
+        depends_on("cuda@11.4:", when="@0.4.0:0.4.7")
+        depends_on("cudnn@9.1:9", when="@0.4.31:")
+        depends_on("cudnn@9", when="@0.4.29:0.4.30")
+        depends_on("cudnn@8.9:8", when="@0.4.26:0.4.28")
+        depends_on("cudnn@8.8:8", when="@0.4.11:0.4.25")
+        depends_on("cudnn@8.2:8", when="@0.4:0.4.7")
+
+    with when("+nccl"):
+        depends_on("nccl@2.18:", when="@0.4.26:")
+        depends_on("nccl@2.16:", when="@0.4.18:")
+        depends_on("nccl")
+
+    with when("+rocm"):
+        for pkg_dep in rocm_dependencies:
+            depends_on(f"{pkg_dep}@6:", when="@0.4.28:")
+            depends_on(pkg_dep)
+        depends_on("py-nanobind")
+
+    with default_args(type="build"):
+        # .bazelversion
+        depends_on("bazel@6.5.0", when="@0.4.28:")
+        depends_on("bazel@6.1.2", when="@0.4.11:0.4.27")
+        depends_on("bazel@5.1.1", when="@0.3.7:0.4.10")
+
+        # jaxlib/setup.py
+        depends_on("py-setuptools")
+
+        # build/build.py
+        depends_on("py-build", when="@0.4.14:")
+
+    with default_args(type=("build", "run")):
+        # Based on PyPI wheels
+        depends_on("python@3.10:", when="@0.4.31:")
+        depends_on("python@3.9:", when="@0.4.14:")
+        depends_on("python@3.8:", when="@0.4.6:")
+        depends_on("python@:3.13")
+        depends_on("python@:3.12", when="@:0.4.33")
+        depends_on("python@:3.11", when="@:0.4.16")
+
+        # jaxlib/setup.py
+        depends_on("py-scipy@1.11.1:", when="@0.5:")
+        depends_on("py-scipy@1.10:", when="@0.4.31:")
+        depends_on("py-scipy@1.9:", when="@0.4.19:")
+        depends_on("py-scipy@1.7:", when="@0.4.7:")
+        depends_on("py-scipy@1.5:")
+        depends_on("py-numpy@1.25:", when="@0.5:")
+        depends_on("py-numpy@1.24:", when="@0.4.31:")
+        depends_on("py-numpy@1.22:", when="@0.4.14:")
+        depends_on("py-numpy@1.21:", when="@0.4.7:")
+        depends_on("py-numpy@1.20:", when="@0.3:")
+        # https://github.com/google/jax/issues/19246
+        depends_on("py-numpy@:1", when="@:0.4.25")
+        depends_on("py-ml-dtypes@0.4:", when="@0.4.29")
+        depends_on("py-ml-dtypes@0.2:", when="@0.4.14:")
+        depends_on("py-ml-dtypes@0.1:", when="@0.4.9:")
+        depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:")
+
+    patch(
+        "https://github.com/jax-ml/jax/commit/f62af6457a6cc575a7b1ada08d541f0dd0eb5765.patch?full_index=1",
+        sha256="d3b7ea2cfeba927e40a11f07e4cbf80939f7fe69448c9eb55231a93bd64e5c02",
+        when="@0.4.36:0.4.38",
+    )
+    patch(
+        "https://github.com/jax-ml/jax/pull/25473.patch?full_index=1",
+        sha256="9d6977bc32046600bf8b15863251283fe7546896340367a7f14e3dccf418b4fe",
+        when="@0.4.36:0.4.37",
+    )
+    patch(
+        "https://github.com/google/jax/pull/20101.patch?full_index=1",
+        sha256="4dfb9f32d4eeb0a0fb3a6f4124c4170e3fe49511f1b768cd634c78d489962275",
+        when="@:0.4.25",
+    )
+
+    # Might be able to be applied to earlier versions
+    # backports https://github.com/abseil/abseil-cpp/pull/1732
+    patch("jaxxlatsl.patch", when="@0.4.28:0.4.32 target=aarch64:")
+
+    conflicts(
+        "cuda_arch=none",
+        when="+cuda",
+        msg="Must specify CUDA compute capabilities of your GPU, see "
+        "https://developer.nvidia.com/cuda-gpus",
+    )
+
+    # https://github.com/google/jax/issues/19992
+    conflicts("@0.4.4:", when="target=ppc64le:")
+
+    # Fails to build with freshly released CUDA (#48708).
+    conflicts("^cuda@12.8:", when="@:0.4.31")
+
+    def url_for_version(self, version):
+        url = "https://github.com/jax-ml/jax/archive/refs/tags/{}-v{}.tar.gz"
+        if version >= Version("0.4.33"):
+            name = "jax"
+        else:
+            name = "jaxlib"
+        return url.format(name, version)
+
+    def install(self, spec, prefix):
+        # https://jax.readthedocs.io/en/latest/developer.html
+        args = ["build/build.py"]
+
+        if spec.satisfies("@0.4.36:"):
+            args.append("build")
+
+            if spec.satisfies("+cuda"):
+                args.append("--wheels=jaxlib,jax-cuda-plugin,jax-cuda-pjrt")
+            elif spec.satisfies("+rocm"):
+                args.append("--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt")
+            else:
+                args.append("--wheels=jaxlib")
+
+        if spec.satisfies("@0.4.32:"):
+            if spec.satisfies("%clang"):
+                args.append("--use_clang=true")
+            else:
+                args.append("--use_clang=false")
+
+        if "+cuda" in spec:
+            capabilities = CudaPackage.compute_capabilities(spec.variants["cuda_arch"].value)
+            args.append(f"--cuda_compute_capabilities={','.join(capabilities)}")
+            if spec.satisfies("@:0.4.35"):
+                args.append("--enable_cuda")
+            if spec.satisfies("@0.4.32:"):
+                args.extend(
+                    [
+                        f"--bazel_options=--repo_env=LOCAL_CUDA_PATH={spec['cuda'].prefix}",
+                        f"--bazel_options=--repo_env=LOCAL_CUDNN_PATH={spec['cudnn'].prefix}",
+                    ]
+                )
+            else:
+                args.extend(
+                    [f"--cuda_path={spec['cuda'].prefix}", f"--cudnn_path={spec['cudnn'].prefix}"]
+                )
+
+        if "+nccl" in spec and spec.satisfies("@0.4.32:"):
+            args.append(f"--bazel_options=--repo_env=LOCAL_NCCL_PATH={spec['nccl'].prefix}")
+
+        if "+rocm" in spec:
+            args.extend(["--enable_rocm", f"--rocm_path={self.spec['hip'].prefix}"])
+
+        args.extend(
+            [
+                f"--bazel_options=--jobs={make_jobs}",
+                "--bazel_startup_options=--nohome_rc",
+                "--bazel_startup_options=--nosystem_rc",
+            ]
+        )
+
+        python(*args)
+        whl = glob.glob(join_path("dist", "*.whl"))[0]
+        pip(*PythonPipBuilder.std_args(self), f"--prefix={self.prefix}", whl)