Skip to content
Snippets Groups Projects
Commit 8e64bec6 authored by Eleni Mathioulaki's avatar Eleni Mathioulaki
Browse files

add latest py-jax{,lib},bazel packages from spack upstream develop

parent a83a4646
No related branches found
No related tags found
2 merge requests!330create new experimental release,!320use Spack v0.19.2
Pipeline #22054 waiting for manual action with stage
in 1 minute and 49 seconds
--- a/third_party/zlib/gzguts.h 1980-01-01 00:00:00
+++ b/third_party/zlib/gzguts.h 2023-04-03 12:23:10
@@ -3,6 +3,10 @@
* For conditions of distribution and use, see copyright notice in zlib.h
*/
+#ifndef _WIN32
+ #include <unistd.h>
+#endif
+
#ifdef _LARGEFILE64_SOURCE
# ifndef _LARGEFILE_SOURCE
# define _LARGEFILE_SOURCE 1
This diff is collapsed.
# Copyright 2013-2021 Lawrence Livermore National Security, LLC and other
# Copyright 2013-2023 Lawrence Livermore National Security, LLC and other
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
......@@ -7,7 +7,7 @@
from spack.package import *
class PyJax(PythonPackage, CudaPackage):
class PyJax(PythonPackage):
"""JAX is Autograd and XLA, brought together for high-performance
machine learning research. With its updated version of Autograd,
JAX can automatically differentiate native Python and NumPy
......@@ -21,30 +21,29 @@ class PyJax(PythonPackage, CudaPackage):
homepage = "https://github.com/google/jax"
pypi = "jax/jax-0.2.25.tar.gz"
version("0.3.25", sha256="18bea69321cb95ea5ea913adfe5e2c1d453cade9d4cfd0dc814ecba9fc0cb6e3")
version("0.4.3", sha256="d43f08f940aa30eb339965cfb3d6bee2296537b0dc2f0c65ccae3009279529ae")
version("0.3.23", sha256="bff436e15552a82c0ebdef32737043b799e1e10124423c57a6ae6118c3a7b6cd")
version("0.2.25", sha256="822e8d1e06257eaa0fdc4c0a0686c4556e9f33647fa2a766755f984786ae7446")
variant("cuda", default=True, description="CUDA support")
depends_on("python@3.7:", type=("build", "run"))
depends_on("python@3.8:", when="@0.4:", type=("build", "run"))
depends_on("py-setuptools", type="build")
depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run"))
depends_on("py-numpy@1.18:", type=("build", "run"))
depends_on("py-absl-py", type=("build", "run"))
depends_on("py-opt-einsum", type=("build", "run"))
depends_on("py-scipy@1.5:", when="@0.3:", type=("build", "run"))
depends_on("py-scipy@1.2.1:", type=("build", "run"))
depends_on("py-typing-extensions", type=("build", "run"))
depends_on("py-jaxlib@0.1.69:", type=("build", "run"), when="~cuda")
depends_on("py-jaxlib@0.1.69:+cuda", type=("build", "run"), when="+cuda")
depends_on("py-jaxlib@0.3.22:", type=("build", "run"), when="@0.3.25:~cuda")
depends_on("py-jaxlib@0.3.22:+cuda", type=("build", "run"), when="@0.3.25:+cuda")
for arch in CudaPackage.cuda_arch_values:
depends_on(
"py-jaxlib@0.1.69:+cuda cuda_arch={0}".format(arch),
type=("build", "run"),
when="cuda_arch={0}".format(arch),
)
depends_on(
"py-jaxlib@0.3.22:+cuda cuda_arch={0}".format(arch),
type=("build", "run"),
when="@0.3.25:cuda_arch={0}".format(arch),
)
# See _minimum_jaxlib_version in jax/version.py
jax_to_jaxlib = {
"0.4.3": "0.4.2",
"0.3.23": "0.3.15",
"0.2.25": "0.1.69",
}
for jax, jaxlib in jax_to_jaxlib.items():
depends_on(f"py-jaxlib@{jaxlib}:", when=f"@{jax}", type=("build", "run"))
# Historical dependencies
depends_on("py-absl-py", when="@:0.3", type=("build", "run"))
depends_on("py-typing-extensions", when="@:0.3", type=("build", "run"))
depends_on("py-etils+epath", when="@0.3", type=("build", "run"))
# Copyright 2013-2021 Lawrence Livermore National Security, LLC and other
# Copyright 2013-2023 Lawrence Livermore National Security, LLC and other
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
import tempfile
import llnl.util.tty as tty
from spack.package import *
......@@ -17,27 +15,57 @@ class PyJaxlib(PythonPackage, CudaPackage):
url = "https://github.com/google/jax/archive/refs/tags/jaxlib-v0.1.74.tar.gz"
tmp_path = ""
buildtmp = ""
version("0.3.25", sha256="73ebc7868631cd9d520385557bbd7f08762d748a5a6a1bebef0f3b8d7ba748ef")
version("0.4.3", sha256="2104735dc22be2b105e5517bd5bc6ae97f40e8e9e54928cac1585c6112a3d910")
version("0.3.22", sha256="680a6f5265ba26d5515617a95ae47244005366f879a5c321782fde60f34e6d0d")
version("0.1.74", sha256="bbc78c7a4927012dcb1b7cd135c7521f782d7dad516a2401b56d3190f81afe35")
variant("cuda", default=True, description="Build with CUDA")
depends_on("python@3.7:", type=("build", "run"))
# jaxlib/setup.py
depends_on("python@3.8:", when="@0.4:", type=("build", "run"))
depends_on("py-setuptools", type="build")
depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run"))
depends_on("py-numpy@1.18:", type=("build", "run"))
depends_on("py-numpy@1.20:", type=("build", "run"), when="@0.3.16:")
depends_on("py-opt-einsum", type=("build", "run"), when="@0.2.8:")
depends_on("py-scipy", type=("build", "run"))
depends_on("py-scipy@1.5:", type=("build", "run"), when="@0.3.14:")
depends_on("py-typing-extensions", type=("build", "run"), when="@0.2.23:")
depends_on("py-absl-py", type=("build", "run"))
depends_on("py-flatbuffers@1.12:2", type=("build", "run"))
# Bazel 5 not yet supported: https://github.com/google/jax/issues/8440
depends_on("bazel@4.1.0:4", type=("build"), when="@:0.3.5")
depends_on("bazel@5.1.1:", type=("build"), when="@0.3.6:")
depends_on("py-scipy@1.5:", type=("build", "run"))
# .bazelversion
depends_on("bazel@5.1.1:5.9", when="@0.3:", type="build")
# https://github.com/google/jax/issues/8440
depends_on("bazel@4.1:4", when="@0.1", type="build")
# README.md
depends_on("cuda@11.4:", when="@0.4:+cuda")
depends_on("cuda@11.1:", when="@0.3+cuda")
# https://github.com/google/jax/issues/12614
depends_on("cuda@11.1:11.7.0", when="@0.1+cuda")
depends_on("cudnn@8.2:", when="@0.4:+cuda")
depends_on("cudnn@8.0.5:", when="+cuda")
depends_on("cuda@11.1:", when="+cuda")
# Historical dependencies
depends_on("py-absl-py", when="@:0.3", type=("build", "run"))
depends_on("py-flatbuffers@1.12:2", when="@0.1", type=("build", "run"))
def patch(self):
self.tmp_path = tempfile.mkdtemp(prefix="spack")
self.buildtmp = tempfile.mkdtemp(prefix="spack")
filter_file(
'f"--output_path={output_path}",',
'f"--output_path={output_path}",'
f' "--sources_path={self.tmp_path}",'
' "--nohome_rc",'
' "--nosystem_rc",'
f' "--jobs={make_jobs}",',
"build/build.py",
string=True,
)
filter_file(
"args = parser.parse_args()",
"args, junk = parser.parse_known_args()",
"build/build_wheel.py",
string=True,
)
def install(self, spec, prefix):
args = []
......@@ -50,30 +78,13 @@ class PyJaxlib(PythonPackage, CudaPackage):
"{0:.1f}".format(float(i) / 10.0) for i in spec.variants["cuda_arch"].value
)
args.append("--cuda_compute_capabilities={0}".format(capabilities))
args.append("--bazel_startup_options=" "--output_user_root={0}".format(self.buildtmp))
args.append(
"--bazel_startup_options="
"--output_user_root={0}".format(self.wrapped_package_object.buildtmp)
)
python(*args)
with working_dir(self.tmp_path):
tty.warn("in dir " + self.tmp_path)
with working_dir(self.wrapped_package_object.tmp_path):
args = std_pip_args + ["--prefix=" + self.prefix, "."]
pip(*args)
remove_linked_tree(self.tmp_path)
remove_linked_tree(self.buildtmp)
def patch(self):
self.tmp_path = tempfile.mkdtemp(prefix="spack")
self.buildtmp = tempfile.mkdtemp(prefix="spack")
filter_file(
"""f"--output_path={output_path}",""",
"""f"--output_path={output_path}","""
"""f"--sources_path=%s","""
"""f"--nohome_rc'","""
"""f"--nosystem_rc'",""" % self.tmp_path,
"build/build.py",
)
filter_file(
"args = parser.parse_args()",
"args,junk = parser.parse_known_args()",
"build/build_wheel.py",
string=True,
)
remove_linked_tree(self.wrapped_package_object.tmp_path)
remove_linked_tree(self.wrapped_package_object.buildtmp)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment