Skip to content
Snippets Groups Projects
Commit 8e64bec6 authored by Eleni Mathioulaki's avatar Eleni Mathioulaki
Browse files

add latest py-jax{,lib},bazel packages from spack upstream develop

parent a83a4646
No related branches found
No related tags found
2 merge requests!330create new experimental release,!320use Spack v0.19.2
Pipeline #22054 waiting for manual action with stage
in 1 minute and 49 seconds
--- a/third_party/zlib/gzguts.h 1980-01-01 00:00:00
+++ b/third_party/zlib/gzguts.h 2023-04-03 12:23:10
@@ -3,6 +3,10 @@
* For conditions of distribution and use, see copyright notice in zlib.h
*/
+#ifndef _WIN32
+ #include <unistd.h>
+#endif
+
#ifdef _LARGEFILE64_SOURCE
# ifndef _LARGEFILE_SOURCE
# define _LARGEFILE_SOURCE 1
This diff is collapsed.
# Copyright 2013-2021 Lawrence Livermore National Security, LLC and other # Copyright 2013-2023 Lawrence Livermore National Security, LLC and other
# Spack Project Developers. See the top-level COPYRIGHT file for details. # Spack Project Developers. See the top-level COPYRIGHT file for details.
# #
# SPDX-License-Identifier: (Apache-2.0 OR MIT) # SPDX-License-Identifier: (Apache-2.0 OR MIT)
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
from spack.package import * from spack.package import *
class PyJax(PythonPackage, CudaPackage): class PyJax(PythonPackage):
"""JAX is Autograd and XLA, brought together for high-performance """JAX is Autograd and XLA, brought together for high-performance
machine learning research. With its updated version of Autograd, machine learning research. With its updated version of Autograd,
JAX can automatically differentiate native Python and NumPy JAX can automatically differentiate native Python and NumPy
...@@ -21,30 +21,29 @@ class PyJax(PythonPackage, CudaPackage): ...@@ -21,30 +21,29 @@ class PyJax(PythonPackage, CudaPackage):
homepage = "https://github.com/google/jax" homepage = "https://github.com/google/jax"
pypi = "jax/jax-0.2.25.tar.gz" pypi = "jax/jax-0.2.25.tar.gz"
version("0.3.25", sha256="18bea69321cb95ea5ea913adfe5e2c1d453cade9d4cfd0dc814ecba9fc0cb6e3") version("0.4.3", sha256="d43f08f940aa30eb339965cfb3d6bee2296537b0dc2f0c65ccae3009279529ae")
version("0.3.23", sha256="bff436e15552a82c0ebdef32737043b799e1e10124423c57a6ae6118c3a7b6cd")
version("0.2.25", sha256="822e8d1e06257eaa0fdc4c0a0686c4556e9f33647fa2a766755f984786ae7446") version("0.2.25", sha256="822e8d1e06257eaa0fdc4c0a0686c4556e9f33647fa2a766755f984786ae7446")
variant("cuda", default=True, description="CUDA support") depends_on("python@3.8:", when="@0.4:", type=("build", "run"))
depends_on("python@3.7:", type=("build", "run"))
depends_on("py-setuptools", type="build") depends_on("py-setuptools", type="build")
depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run"))
depends_on("py-numpy@1.18:", type=("build", "run")) depends_on("py-numpy@1.18:", type=("build", "run"))
depends_on("py-absl-py", type=("build", "run"))
depends_on("py-opt-einsum", type=("build", "run")) depends_on("py-opt-einsum", type=("build", "run"))
depends_on("py-scipy@1.5:", when="@0.3:", type=("build", "run"))
depends_on("py-scipy@1.2.1:", type=("build", "run")) depends_on("py-scipy@1.2.1:", type=("build", "run"))
depends_on("py-typing-extensions", type=("build", "run"))
depends_on("py-jaxlib@0.1.69:", type=("build", "run"), when="~cuda") # See _minimum_jaxlib_version in jax/version.py
depends_on("py-jaxlib@0.1.69:+cuda", type=("build", "run"), when="+cuda") jax_to_jaxlib = {
depends_on("py-jaxlib@0.3.22:", type=("build", "run"), when="@0.3.25:~cuda") "0.4.3": "0.4.2",
depends_on("py-jaxlib@0.3.22:+cuda", type=("build", "run"), when="@0.3.25:+cuda") "0.3.23": "0.3.15",
for arch in CudaPackage.cuda_arch_values: "0.2.25": "0.1.69",
depends_on( }
"py-jaxlib@0.1.69:+cuda cuda_arch={0}".format(arch),
type=("build", "run"), for jax, jaxlib in jax_to_jaxlib.items():
when="cuda_arch={0}".format(arch), depends_on(f"py-jaxlib@{jaxlib}:", when=f"@{jax}", type=("build", "run"))
)
depends_on( # Historical dependencies
"py-jaxlib@0.3.22:+cuda cuda_arch={0}".format(arch), depends_on("py-absl-py", when="@:0.3", type=("build", "run"))
type=("build", "run"), depends_on("py-typing-extensions", when="@:0.3", type=("build", "run"))
when="@0.3.25:cuda_arch={0}".format(arch), depends_on("py-etils+epath", when="@0.3", type=("build", "run"))
)
# Copyright 2013-2021 Lawrence Livermore National Security, LLC and other # Copyright 2013-2023 Lawrence Livermore National Security, LLC and other
# Spack Project Developers. See the top-level COPYRIGHT file for details. # Spack Project Developers. See the top-level COPYRIGHT file for details.
# #
# SPDX-License-Identifier: (Apache-2.0 OR MIT) # SPDX-License-Identifier: (Apache-2.0 OR MIT)
import tempfile import tempfile
import llnl.util.tty as tty
from spack.package import * from spack.package import *
...@@ -17,27 +15,57 @@ class PyJaxlib(PythonPackage, CudaPackage): ...@@ -17,27 +15,57 @@ class PyJaxlib(PythonPackage, CudaPackage):
url = "https://github.com/google/jax/archive/refs/tags/jaxlib-v0.1.74.tar.gz" url = "https://github.com/google/jax/archive/refs/tags/jaxlib-v0.1.74.tar.gz"
tmp_path = "" tmp_path = ""
buildtmp = ""
version("0.3.25", sha256="73ebc7868631cd9d520385557bbd7f08762d748a5a6a1bebef0f3b8d7ba748ef") version("0.4.3", sha256="2104735dc22be2b105e5517bd5bc6ae97f40e8e9e54928cac1585c6112a3d910")
version("0.3.22", sha256="680a6f5265ba26d5515617a95ae47244005366f879a5c321782fde60f34e6d0d")
version("0.1.74", sha256="bbc78c7a4927012dcb1b7cd135c7521f782d7dad516a2401b56d3190f81afe35") version("0.1.74", sha256="bbc78c7a4927012dcb1b7cd135c7521f782d7dad516a2401b56d3190f81afe35")
variant("cuda", default=True, description="Build with CUDA") variant("cuda", default=True, description="Build with CUDA")
depends_on("python@3.7:", type=("build", "run")) # jaxlib/setup.py
depends_on("python@3.8:", when="@0.4:", type=("build", "run"))
depends_on("py-setuptools", type="build") depends_on("py-setuptools", type="build")
depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run"))
depends_on("py-numpy@1.18:", type=("build", "run")) depends_on("py-numpy@1.18:", type=("build", "run"))
depends_on("py-numpy@1.20:", type=("build", "run"), when="@0.3.16:") depends_on("py-scipy@1.5:", type=("build", "run"))
depends_on("py-opt-einsum", type=("build", "run"), when="@0.2.8:")
depends_on("py-scipy", type=("build", "run")) # .bazelversion
depends_on("py-scipy@1.5:", type=("build", "run"), when="@0.3.14:") depends_on("bazel@5.1.1:5.9", when="@0.3:", type="build")
depends_on("py-typing-extensions", type=("build", "run"), when="@0.2.23:") # https://github.com/google/jax/issues/8440
depends_on("py-absl-py", type=("build", "run")) depends_on("bazel@4.1:4", when="@0.1", type="build")
depends_on("py-flatbuffers@1.12:2", type=("build", "run"))
# Bazel 5 not yet supported: https://github.com/google/jax/issues/8440 # README.md
depends_on("bazel@4.1.0:4", type=("build"), when="@:0.3.5") depends_on("cuda@11.4:", when="@0.4:+cuda")
depends_on("bazel@5.1.1:", type=("build"), when="@0.3.6:") depends_on("cuda@11.1:", when="@0.3+cuda")
# https://github.com/google/jax/issues/12614
depends_on("cuda@11.1:11.7.0", when="@0.1+cuda")
depends_on("cudnn@8.2:", when="@0.4:+cuda")
depends_on("cudnn@8.0.5:", when="+cuda") depends_on("cudnn@8.0.5:", when="+cuda")
depends_on("cuda@11.1:", when="+cuda")
# Historical dependencies
depends_on("py-absl-py", when="@:0.3", type=("build", "run"))
depends_on("py-flatbuffers@1.12:2", when="@0.1", type=("build", "run"))
def patch(self):
self.tmp_path = tempfile.mkdtemp(prefix="spack")
self.buildtmp = tempfile.mkdtemp(prefix="spack")
filter_file(
'f"--output_path={output_path}",',
'f"--output_path={output_path}",'
f' "--sources_path={self.tmp_path}",'
' "--nohome_rc",'
' "--nosystem_rc",'
f' "--jobs={make_jobs}",',
"build/build.py",
string=True,
)
filter_file(
"args = parser.parse_args()",
"args, junk = parser.parse_known_args()",
"build/build_wheel.py",
string=True,
)
def install(self, spec, prefix): def install(self, spec, prefix):
args = [] args = []
...@@ -50,30 +78,13 @@ class PyJaxlib(PythonPackage, CudaPackage): ...@@ -50,30 +78,13 @@ class PyJaxlib(PythonPackage, CudaPackage):
"{0:.1f}".format(float(i) / 10.0) for i in spec.variants["cuda_arch"].value "{0:.1f}".format(float(i) / 10.0) for i in spec.variants["cuda_arch"].value
) )
args.append("--cuda_compute_capabilities={0}".format(capabilities)) args.append("--cuda_compute_capabilities={0}".format(capabilities))
args.append("--bazel_startup_options=" "--output_user_root={0}".format(self.buildtmp)) args.append(
"--bazel_startup_options="
"--output_user_root={0}".format(self.wrapped_package_object.buildtmp)
)
python(*args) python(*args)
with working_dir(self.tmp_path): with working_dir(self.wrapped_package_object.tmp_path):
tty.warn("in dir " + self.tmp_path)
args = std_pip_args + ["--prefix=" + self.prefix, "."] args = std_pip_args + ["--prefix=" + self.prefix, "."]
pip(*args) pip(*args)
remove_linked_tree(self.tmp_path) remove_linked_tree(self.wrapped_package_object.tmp_path)
remove_linked_tree(self.buildtmp) remove_linked_tree(self.wrapped_package_object.buildtmp)
def patch(self):
self.tmp_path = tempfile.mkdtemp(prefix="spack")
self.buildtmp = tempfile.mkdtemp(prefix="spack")
filter_file(
"""f"--output_path={output_path}",""",
"""f"--output_path={output_path}","""
"""f"--sources_path=%s","""
"""f"--nohome_rc'","""
"""f"--nosystem_rc'",""" % self.tmp_path,
"build/build.py",
)
filter_file(
"args = parser.parse_args()",
"args,junk = parser.parse_known_args()",
"build/build_wheel.py",
string=True,
)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment