Skip to content
Snippets Groups Projects
Commit f7c366b8 authored by Eric Müller's avatar Eric Müller :mountain_bicyclist: Committed by Eleni Mathioulaki
Browse files

feat(jax): add new versions

parent 3fbf807e
No related branches found
No related tags found
No related merge requests found
# Copyright 2013-2023 Lawrence Livermore National Security, LLC and other
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
from spack.package import *
# EBRAINS: based on spack/0.21.2
class PyChex(PythonPackage):
"""Chex is a library of utilities for helping to write reliable JAX code."""
homepage = "https://github.com/deepmind/chex"
pypi = "chex/chex-0.1.0.tar.gz"
# begin EBRAINS (added): bring upstream
version("0.1.7", sha256="74ed49799ac4d229881456d468136f1b19a9f9839e3de72b058824e2a4f4dedd")
version("0.1.5", sha256="686858320f8f220c82a6c7eeb54dcdcaa4f3d7f66690dacd13a24baa1ee8299e")
# end EBRAINS
version("0.1.0", sha256="9e032058f5fed2fc1d5e9bf8e12ece5910cf6a478c12d402b6d30984695f2161")
depends_on("python@3.7:", type=("build", "run"))
depends_on("py-setuptools", type="build")
depends_on("py-absl-py@0.9.0:", type=("build", "run"))
# begin EBRAINS (added): bring upstream
depends_on("py-typing-extensions@4.2.0:", when="@0.1.6: ^python@:3.10", type=("build", "run"))
# end EBRAINS
depends_on("py-dm-tree@0.1.5:", type=("build", "run"))
depends_on("py-jax@0.1.55:", type=("build", "run"))
# begin EBRAINS (added): bring upstream
depends_on("py-jax@0.4.6:", when="@0.1.7:", type=("build", "run"))
# end EBRAINS
depends_on("py-jaxlib@0.1.37:", type=("build", "run"))
depends_on("py-numpy@1.18.0:", type=("build", "run"))
depends_on("py-toolz@0.9.0:", type=("build", "run"))
# Copyright 2013-2023 Lawrence Livermore National Security, LLC and other
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
from spack.package import *
# EBRAINS: based on spack/0.21.2
class PyJax(PythonPackage):
"""JAX is Autograd and XLA, brought together for high-performance
machine learning research. With its updated version of Autograd,
JAX can automatically differentiate native Python and NumPy
functions. It can differentiate through loops, branches,
recursion, and closures, and it can take derivatives of
derivatives of derivatives. It supports reverse-mode
differentiation (a.k.a. backpropagation) via grad as well as
forward-mode differentiation, and the two can be composed
arbitrarily to any order."""
homepage = "https://github.com/google/jax"
pypi = "jax/jax-0.2.25.tar.gz"
# begin EBRAINS (added): bring upstream
version("0.4.13", sha256="03bfe6749dfe647f16f15f6616638adae6c4a7ca7167c75c21961ecfd3a3baaa")
# end EBRAINS
version("0.4.3", sha256="d43f08f940aa30eb339965cfb3d6bee2296537b0dc2f0c65ccae3009279529ae")
version("0.3.23", sha256="bff436e15552a82c0ebdef32737043b799e1e10124423c57a6ae6118c3a7b6cd")
version("0.2.25", sha256="822e8d1e06257eaa0fdc4c0a0686c4556e9f33647fa2a766755f984786ae7446")
# begin EBRAINS (modified): bring upstream
depends_on("python@3.7:", type=("build", "run"))
depends_on("python@3.8:", when="@0.4:", type=("build", "run"))
depends_on("python@3.9:", when="@0.4.14:", type=("build", "run"))
depends_on("py-setuptools", type="build")
depends_on("py-numpy@1.22:", when="@0.4.14:", type=("build", "run"))
depends_on("py-numpy@1.21:", when="@0.4.9:", type=("build", "run"))
depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run"))
depends_on("py-numpy@1.18:", type=("build", "run"))
depends_on("py-opt-einsum", type=("build", "run"))
depends_on("py-scipy@1.2.1:", type=("build", "run"))
depends_on("py-scipy@1.5:", when="@0.3:", type=("build", "run"))
depends_on("py-scipy@1.7:", when="@0.4.7:", type=("build", "run"))
depends_on("py-ml-dtypes@0.2.0:", when="@0.4.14:", type=("build", "run"))
depends_on("py-ml-dtypes@0.1.0:", when="@0.4.9:", type=("build", "run"))
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:", type=("build", "run"))
depends_on("py-importlib-metadata@4.6:", when="@0.4.11: ^python@:3.9", type="run")
# end EBRAINS
# See _minimum_jaxlib_version in jax/version.py
# begin EBRAINS (modified): bring upstream
jax_to_jaxlib = {
"0.4.14": "0.4.14",
"0.4.13": "0.4.13",
"0.4.3": "0.4.2",
"0.3.23": "0.3.15",
"0.2.25": "0.1.69",
}
# end EBRAINS
for jax, jaxlib in jax_to_jaxlib.items():
# begin EBRAINS (modified): bring upstream
depends_on(f"py-jaxlib@{jaxlib}", when=f"@{jax}", type=("build", "run"))
# end EBRAINS
# Historical dependencies
depends_on("py-absl-py", when="@:0.3", type=("build", "run"))
depends_on("py-typing-extensions", when="@:0.3", type=("build", "run"))
# begin EBRAINS (deleted):
# depends_on("py-etils+epath", when="@0.3", type=("build", "run"))
# end EBRAINS
# Copyright 2013-2023 Lawrence Livermore National Security, LLC and other
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
import tempfile
from spack.package import *
# EBRAINS: based on spack/0.21.2
class PyJaxlib(PythonPackage, CudaPackage):
"""XLA library for Jax"""
homepage = "https://github.com/google/jax"
url = "https://github.com/google/jax/archive/refs/tags/jaxlib-v0.1.74.tar.gz"
tmp_path = ""
buildtmp = ""
# begin EBRAINS (added): bring upstream
version("0.4.13", sha256="45766238b57b992851763c64bc943858aebafe4cad7b3df6cde844690bc34293")
# end EBRAINS
version("0.4.3", sha256="2104735dc22be2b105e5517bd5bc6ae97f40e8e9e54928cac1585c6112a3d910")
version("0.3.22", sha256="680a6f5265ba26d5515617a95ae47244005366f879a5c321782fde60f34e6d0d")
version("0.1.74", sha256="bbc78c7a4927012dcb1b7cd135c7521f782d7dad516a2401b56d3190f81afe35")
# begin EBRAINS (deleted): Variant with default=False is provided by CudaPackage
# variant("cuda", default=True, description="Build with CUDA")
# end EBRAINS
# jaxlib/setup.py
# begin EBRAINS (modified): bring upstream
depends_on("python@3.9:", when="@0.4.14:", type=("build", "run"))
depends_on("python@3.8:", when="@0.4:", type=("build", "run"))
depends_on("python@3.7:", type=("build", "run"))
depends_on("py-setuptools", type="build")
depends_on("py-numpy@1.22:", when="@0.4.14:", type=("build", "run"))
depends_on("py-numpy@1.21:", when="@0.4.9:", type=("build", "run"))
depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run"))
depends_on("py-numpy@1.18:", type=("build", "run"))
depends_on("py-scipy@1.5:", type=("build", "run"))
depends_on("py-scipy@1.7:", when="@0.4.7:", type=("build", "run"))
depends_on("py-ml-dtypes@0.2.0:", when="@0.4.14:", type=("build", "run"))
depends_on("py-ml-dtypes@0.1.0:", when="@0.4.9:", type=("build", "run"))
depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:", type=("build", "run"))
# end EBRAINS
# .bazelversion
depends_on("bazel@5.1.1:5.9", when="@0.3:", type="build")
# https://github.com/google/jax/issues/8440
depends_on("bazel@4.1:4", when="@0.1", type="build")
# README.md
# begin EBRAINS (added): bring upstream
depends_on("cuda@11.8:", when="@0.4.8:+cuda")
# end EBRAINS
depends_on("cuda@11.4:", when="@0.4:+cuda")
depends_on("cuda@11.1:", when="@0.3+cuda")
# https://github.com/google/jax/issues/12614
depends_on("cuda@11.1:11.7.0", when="@0.1+cuda")
depends_on("cudnn@8.2:", when="@0.4:+cuda")
depends_on("cudnn@8.0.5:", when="+cuda")
# Historical dependencies
depends_on("py-absl-py", when="@:0.3", type=("build", "run"))
depends_on("py-flatbuffers@1.12:2", when="@0.1", type=("build", "run"))
def patch(self):
self.tmp_path = tempfile.mkdtemp(prefix="spack")
self.buildtmp = tempfile.mkdtemp(prefix="spack")
filter_file(
'f"--output_path={output_path}",',
'f"--output_path={output_path}",'
f' "--sources_path={self.tmp_path}",'
' "--nohome_rc",'
' "--nosystem_rc",'
f' "--jobs={make_jobs}",',
"build/build.py",
string=True,
)
filter_file(
"args = parser.parse_args()",
"args, junk = parser.parse_known_args()",
"build/build_wheel.py",
string=True,
)
def install(self, spec, prefix):
args = []
args.append("build/build.py")
if "+cuda" in spec:
args.append("--enable_cuda")
args.append("--cuda_path={0}".format(self.spec["cuda"].prefix))
args.append("--cudnn_path={0}".format(self.spec["cudnn"].prefix))
capabilities = ",".join(
"{0:.1f}".format(float(i) / 10.0) for i in spec.variants["cuda_arch"].value
)
args.append("--cuda_compute_capabilities={0}".format(capabilities))
args.append(
"--bazel_startup_options="
"--output_user_root={0}".format(self.wrapped_package_object.buildtmp)
)
python(*args)
with working_dir(self.wrapped_package_object.tmp_path):
args = std_pip_args + ["--prefix=" + self.prefix, "."]
pip(*args)
remove_linked_tree(self.wrapped_package_object.tmp_path)
remove_linked_tree(self.wrapped_package_object.buildtmp)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment