diff --git a/packages/py-jaxlib/package.py b/packages/py-jaxlib/package.py
index 6db1dd8066642531333e3fae496200725f089d4d..b22409790803f5b4f287d5f4e0ee8163afeb1383 100644
--- a/packages/py-jaxlib/package.py
+++ b/packages/py-jaxlib/package.py
@@ -234,5 +234,5 @@ class PyJaxlib(PythonPackage, CudaPackage, ROCmPackage):
         )
 
         python(*args)
-        whl = glob.glob(join_path("dist", "*.whl"))[0]
-        pip(*PythonPipBuilder.std_args(self), f"--prefix={self.prefix}", whl)
+        for whl in glob.glob(join_path("dist", "*.whl")):
+            pip(*PythonPipBuilder.std_args(self), f"--prefix={self.prefix}", whl)