diff --git a/packages/py-chex/package.py b/packages/py-chex/package.py new file mode 100644 index 0000000000000000000000000000000000000000..7e41890f4588105fc58e040d90b9f4a1fa84540c --- /dev/null +++ b/packages/py-chex/package.py @@ -0,0 +1,36 @@ +# Copyright 2013-2023 Lawrence Livermore National Security, LLC and other +# Spack Project Developers. See the top-level COPYRIGHT file for details. +# +# SPDX-License-Identifier: (Apache-2.0 OR MIT) + + +from spack.package import * + + +# EBRAINS: based on spack/0.21.2 +class PyChex(PythonPackage): + """Chex is a library of utilities for helping to write reliable JAX code.""" + + homepage = "https://github.com/deepmind/chex" + pypi = "chex/chex-0.1.0.tar.gz" + + # begin EBRAINS (added): bring upstream + version("0.1.7", sha256="74ed49799ac4d229881456d468136f1b19a9f9839e3de72b058824e2a4f4dedd") + version("0.1.5", sha256="686858320f8f220c82a6c7eeb54dcdcaa4f3d7f66690dacd13a24baa1ee8299e") + # end EBRAINS + version("0.1.0", sha256="9e032058f5fed2fc1d5e9bf8e12ece5910cf6a478c12d402b6d30984695f2161") + + depends_on("python@3.7:", type=("build", "run")) + depends_on("py-setuptools", type="build") + depends_on("py-absl-py@0.9.0:", type=("build", "run")) + # begin EBRAINS (added): bring upstream + depends_on("py-typing-extensions@4.2.0:", when="@0.1.6: ^python@:3.10", type=("build", "run")) + # end EBRAINS + depends_on("py-dm-tree@0.1.5:", type=("build", "run")) + depends_on("py-jax@0.1.55:", type=("build", "run")) + # begin EBRAINS (added): bring upstream + depends_on("py-jax@0.4.6:", when="@0.1.7:", type=("build", "run")) + # end EBRAINS + depends_on("py-jaxlib@0.1.37:", type=("build", "run")) + depends_on("py-numpy@1.18.0:", type=("build", "run")) + depends_on("py-toolz@0.9.0:", type=("build", "run")) diff --git a/packages/py-jax/package.py b/packages/py-jax/package.py new file mode 100644 index 0000000000000000000000000000000000000000..233de5b59b40af38d3d4e604477eb91e2b683b3a --- /dev/null +++ b/packages/py-jax/package.py @@ -0,0 +1,72 @@ +# Copyright 2013-2023 Lawrence Livermore National Security, LLC and other +# Spack Project Developers. See the top-level COPYRIGHT file for details. +# +# SPDX-License-Identifier: (Apache-2.0 OR MIT) + + +from spack.package import * + + +# EBRAINS: based on spack/0.21.2 +class PyJax(PythonPackage): + """JAX is Autograd and XLA, brought together for high-performance + machine learning research. With its updated version of Autograd, + JAX can automatically differentiate native Python and NumPy + functions. It can differentiate through loops, branches, + recursion, and closures, and it can take derivatives of + derivatives of derivatives. It supports reverse-mode + differentiation (a.k.a. backpropagation) via grad as well as + forward-mode differentiation, and the two can be composed + arbitrarily to any order.""" + + homepage = "https://github.com/google/jax" + pypi = "jax/jax-0.2.25.tar.gz" + + # begin EBRAINS (added): bring upstream + version("0.4.13", sha256="03bfe6749dfe647f16f15f6616638adae6c4a7ca7167c75c21961ecfd3a3baaa") + # end EBRAINS + version("0.4.3", sha256="d43f08f940aa30eb339965cfb3d6bee2296537b0dc2f0c65ccae3009279529ae") + version("0.3.23", sha256="bff436e15552a82c0ebdef32737043b799e1e10124423c57a6ae6118c3a7b6cd") + version("0.2.25", sha256="822e8d1e06257eaa0fdc4c0a0686c4556e9f33647fa2a766755f984786ae7446") + + # begin EBRAINS (modified): bring upstream + depends_on("python@3.7:", type=("build", "run")) + depends_on("python@3.8:", when="@0.4:", type=("build", "run")) + depends_on("python@3.9:", when="@0.4.14:", type=("build", "run")) + depends_on("py-setuptools", type="build") + depends_on("py-numpy@1.22:", when="@0.4.14:", type=("build", "run")) + depends_on("py-numpy@1.21:", when="@0.4.9:", type=("build", "run")) + depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run")) + depends_on("py-numpy@1.18:", type=("build", "run")) + depends_on("py-opt-einsum", type=("build", "run")) + depends_on("py-scipy@1.2.1:", type=("build", "run")) + depends_on("py-scipy@1.5:", when="@0.3:", type=("build", "run")) + depends_on("py-scipy@1.7:", when="@0.4.7:", type=("build", "run")) + depends_on("py-ml-dtypes@0.2.0:", when="@0.4.14:", type=("build", "run")) + depends_on("py-ml-dtypes@0.1.0:", when="@0.4.9:", type=("build", "run")) + depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:", type=("build", "run")) + depends_on("py-importlib-metadata@4.6:", when="@0.4.11: ^python@:3.9", type="run") + # end EBRAINS + + # See _minimum_jaxlib_version in jax/version.py + # begin EBRAINS (modified): bring upstream + jax_to_jaxlib = { + "0.4.14": "0.4.14", + "0.4.13": "0.4.13", + "0.4.3": "0.4.2", + "0.3.23": "0.3.15", + "0.2.25": "0.1.69", + } + # end EBRAINS + + for jax, jaxlib in jax_to_jaxlib.items(): + # begin EBRAINS (modified): bring upstream + depends_on(f"py-jaxlib@{jaxlib}", when=f"@{jax}", type=("build", "run")) + # end EBRAINS + + # Historical dependencies + depends_on("py-absl-py", when="@:0.3", type=("build", "run")) + depends_on("py-typing-extensions", when="@:0.3", type=("build", "run")) + # begin EBRAINS (deleted): + # depends_on("py-etils+epath", when="@0.3", type=("build", "run")) + # end EBRAINS diff --git a/packages/py-jaxlib/package.py b/packages/py-jaxlib/package.py new file mode 100644 index 0000000000000000000000000000000000000000..af04f737eceb3cfa00f7538da14dd8d5043df7cb --- /dev/null +++ b/packages/py-jaxlib/package.py @@ -0,0 +1,109 @@ +# Copyright 2013-2023 Lawrence Livermore National Security, LLC and other +# Spack Project Developers. See the top-level COPYRIGHT file for details. +# +# SPDX-License-Identifier: (Apache-2.0 OR MIT) + +import tempfile + +from spack.package import * + + +# EBRAINS: based on spack/0.21.2 +class PyJaxlib(PythonPackage, CudaPackage): + """XLA library for Jax""" + + homepage = "https://github.com/google/jax" + url = "https://github.com/google/jax/archive/refs/tags/jaxlib-v0.1.74.tar.gz" + + tmp_path = "" + buildtmp = "" + + # begin EBRAINS (added): bring upstream + version("0.4.13", sha256="45766238b57b992851763c64bc943858aebafe4cad7b3df6cde844690bc34293") + # end EBRAINS + version("0.4.3", sha256="2104735dc22be2b105e5517bd5bc6ae97f40e8e9e54928cac1585c6112a3d910") + version("0.3.22", sha256="680a6f5265ba26d5515617a95ae47244005366f879a5c321782fde60f34e6d0d") + version("0.1.74", sha256="bbc78c7a4927012dcb1b7cd135c7521f782d7dad516a2401b56d3190f81afe35") + + # begin EBRAINS (deleted): Variant with default=False is provided by CudaPackage + # variant("cuda", default=True, description="Build with CUDA") + # end EBRAINS + + # jaxlib/setup.py + # begin EBRAINS (modified): bring upstream + depends_on("python@3.9:", when="@0.4.14:", type=("build", "run")) + depends_on("python@3.8:", when="@0.4:", type=("build", "run")) + depends_on("python@3.7:", type=("build", "run")) + depends_on("py-setuptools", type="build") + depends_on("py-numpy@1.22:", when="@0.4.14:", type=("build", "run")) + depends_on("py-numpy@1.21:", when="@0.4.9:", type=("build", "run")) + depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run")) + depends_on("py-numpy@1.18:", type=("build", "run")) + depends_on("py-scipy@1.5:", type=("build", "run")) + depends_on("py-scipy@1.7:", when="@0.4.7:", type=("build", "run")) + depends_on("py-ml-dtypes@0.2.0:", when="@0.4.14:", type=("build", "run")) + depends_on("py-ml-dtypes@0.1.0:", when="@0.4.9:", type=("build", "run")) + depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:", type=("build", "run")) + # end EBRAINS + + # .bazelversion + depends_on("bazel@5.1.1:5.9", when="@0.3:", type="build") + # https://github.com/google/jax/issues/8440 + depends_on("bazel@4.1:4", when="@0.1", type="build") + + # README.md + # begin EBRAINS (added): bring upstream + depends_on("cuda@11.8:", when="@0.4.8:+cuda") + # end EBRAINS + depends_on("cuda@11.4:", when="@0.4:+cuda") + depends_on("cuda@11.1:", when="@0.3+cuda") + # https://github.com/google/jax/issues/12614 + depends_on("cuda@11.1:11.7.0", when="@0.1+cuda") + depends_on("cudnn@8.2:", when="@0.4:+cuda") + depends_on("cudnn@8.0.5:", when="+cuda") + + # Historical dependencies + depends_on("py-absl-py", when="@:0.3", type=("build", "run")) + depends_on("py-flatbuffers@1.12:2", when="@0.1", type=("build", "run")) + + def patch(self): + self.tmp_path = tempfile.mkdtemp(prefix="spack") + self.buildtmp = tempfile.mkdtemp(prefix="spack") + filter_file( + 'f"--output_path={output_path}",', + 'f"--output_path={output_path}",' + f' "--sources_path={self.tmp_path}",' + ' "--nohome_rc",' + ' "--nosystem_rc",' + f' "--jobs={make_jobs}",', + "build/build.py", + string=True, + ) + filter_file( + "args = parser.parse_args()", + "args, junk = parser.parse_known_args()", + "build/build_wheel.py", + string=True, + ) + + def install(self, spec, prefix): + args = [] + args.append("build/build.py") + if "+cuda" in spec: + args.append("--enable_cuda") + args.append("--cuda_path={0}".format(self.spec["cuda"].prefix)) + args.append("--cudnn_path={0}".format(self.spec["cudnn"].prefix)) + capabilities = ",".join( + "{0:.1f}".format(float(i) / 10.0) for i in spec.variants["cuda_arch"].value + ) + args.append("--cuda_compute_capabilities={0}".format(capabilities)) + args.append( + "--bazel_startup_options=" + "--output_user_root={0}".format(self.wrapped_package_object.buildtmp) + ) + python(*args) + with working_dir(self.wrapped_package_object.tmp_path): + args = std_pip_args + ["--prefix=" + self.prefix, "."] + pip(*args) + remove_linked_tree(self.wrapped_package_object.tmp_path) + remove_linked_tree(self.wrapped_package_object.buildtmp)