diff --git a/packages/py-jax/package.py b/packages/py-jax/package.py
index f99d02a83a2d4a9b5d08e513a74dbea50a104b51..c896d469acd943abf04318c1c81c0a6a27d87141 100644
--- a/packages/py-jax/package.py
+++ b/packages/py-jax/package.py
@@ -19,14 +19,16 @@ class PyJax(PythonPackage):
     license("Apache-2.0")
     maintainers("adamjstewart", "jonas-eschle")
 
-    # version("0.5.0", sha256="49df70bf293a345a7fb519f71193506d37a024c4f850b358042eb32d502c81c8")
-    # version("0.4.38", sha256="43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8")
-    # version("0.4.37", sha256="7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b")
-    # version("0.4.36", sha256="088bff0575d01fc82682a9af4eb07433d60de7e5164686bd2cea3439492e608a")
-    # version("0.4.35", sha256="c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e")
-    # version("0.4.34", sha256="44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db")
-    # version("0.4.33", sha256="f0d788692fc0179653066c9e1c64e57311b8c15a389837fd7baf328abefcbb92")
-    # version("0.4.32", sha256="eb703909968da161894fb6135a931c5f3d2aab64fff7cba5fcb803ce6d968e08")
+    version("0.5.2", sha256="2aef7d1912df329470c47ce8f2e6521c105e84aa620311494048c391235087c6")
+    version("0.5.1", sha256="c098f74846ee718165bbfa83521ae10cd52cf50b47f043f8b33a6cfd3c20ddfd")
+    version("0.5.0", sha256="49df70bf293a345a7fb519f71193506d37a024c4f850b358042eb32d502c81c8")
+    version("0.4.38", sha256="43bae65881628319e0a2148e8f81a202fbc2b8d048e35c7cb1df2416672fa4a8")
+    version("0.4.37", sha256="7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b")
+    version("0.4.36", sha256="088bff0575d01fc82682a9af4eb07433d60de7e5164686bd2cea3439492e608a")
+    version("0.4.35", sha256="c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e")
+    version("0.4.34", sha256="44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db")
+    version("0.4.33", sha256="f0d788692fc0179653066c9e1c64e57311b8c15a389837fd7baf328abefcbb92")
+    version("0.4.32", sha256="eb703909968da161894fb6135a931c5f3d2aab64fff7cba5fcb803ce6d968e08")
     version("0.4.31", sha256="fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287")
     version("0.4.30", sha256="94d74b5b2db0d80672b61d83f1f63ebf99d2ab7398ec12b2ca0c9d1e97afe577")
     version("0.4.29", sha256="12904571eaefddcdc8c3b8d4936482b783d5a216e99ef5adcd3522fdfb4fc186")
diff --git a/packages/py-jaxlib/package.py b/packages/py-jaxlib/package.py
index cd679311a25bc1c8658662aa8a6c09766e775ff3..6db1dd8066642531333e3fae496200725f089d4d 100644
--- a/packages/py-jaxlib/package.py
+++ b/packages/py-jaxlib/package.py
@@ -39,14 +39,16 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
     license("Apache-2.0")
     maintainers("adamjstewart", "jonas-eschle")
 
-    # version("0.5.0", sha256="04cc2eeb2e7ce1916674cea03a7d75a59d583ddb779d5104e103a2798a283ce9")
-    # version("0.4.38", sha256="ca1e63c488d505b9c92e81499e8b06cc1977319c50d64a0e58adbd2dae1a625c")
-    # version("0.4.37", sha256="17a8444a931f26edda8ccbc921ab71c6bf46857287b1db186deebd357e526870")
-    # version("0.4.36", sha256="442bfdf491b509995aa160361e23a9db488d5b97c87e6648cc733501b06eda77")
-    # version("0.4.35", sha256="65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212")
-    # version("0.4.34", sha256="d3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb")
-    # version("0.4.33", sha256="122a806e80fc1cd7d8ffaf9620701f2cb8e4fe22271c2cec53a9c60b30bd4c31")
-    # version("0.4.32", sha256="3fe36d596e4d640443c0a5c533845c74fbc4341e024d9bb1cd75cb49f5f419c2")
+    version("0.5.2", sha256="8e9de1e012dd65fc4a9eec8af4aa2bf6782767130a5d8e1c1e342b7d658280fe")
+    version("0.5.1", sha256="e74b1209517682075933f757d646b73040d09fe39ee3e9e4cd398407dd0902d2")
+    version("0.5.0", sha256="04cc2eeb2e7ce1916674cea03a7d75a59d583ddb779d5104e103a2798a283ce9")
+    version("0.4.38", sha256="ca1e63c488d505b9c92e81499e8b06cc1977319c50d64a0e58adbd2dae1a625c")
+    version("0.4.37", sha256="17a8444a931f26edda8ccbc921ab71c6bf46857287b1db186deebd357e526870")
+    version("0.4.36", sha256="442bfdf491b509995aa160361e23a9db488d5b97c87e6648cc733501b06eda77")
+    version("0.4.35", sha256="65e086708ae56670676b7b2340ad82b901d8c9993d1241a839c8990bdb8d6212")
+    version("0.4.34", sha256="d3a75ad667772309ade81350fa70c4a78028a920028800282e46d8383c0ee6bb")
+    version("0.4.33", sha256="122a806e80fc1cd7d8ffaf9620701f2cb8e4fe22271c2cec53a9c60b30bd4c31")
+    version("0.4.32", sha256="3fe36d596e4d640443c0a5c533845c74fbc4341e024d9bb1cd75cb49f5f419c2")
     version("0.4.31", sha256="022ea1347f9b21cbea31410b3d650d976ea4452a48ea7317a5f91c238031bf94")
     version("0.4.30", sha256="0ef9635c734d9bbb44fcc87df4f1c3ccce1cfcfd243572c80d36fcdf826fe1e6")
     version("0.4.29", sha256="3a8005f4f62d35a5aad7e3dbd596890b47c81cc6e34fcfe3dcb93b3ca7cb1246")
@@ -149,6 +151,11 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
         sha256="4dfb9f32d4eeb0a0fb3a6f4124c4170e3fe49511f1b768cd634c78d489962275",
         when="@:0.4.25",
     )
+    patch(
+        "https://github.com/jax-ml/jax/commit/91cae595e427969251a79d1ab3d6d5392dd8e6a9.patch?full_index=1",
+        sha256="265c8f682df3f573f405b8e089827de81be402e999abd11bbf82e9e49e688152",
+        when="@0.5.0:0.5.2 ^cuda@12.3:", # PACKED_ALIGNMENT introduced maybe earlier?
+    )
 
     # Might be able to be applied to earlier versions
     # backports https://github.com/abseil/abseil-cpp/pull/1732