diff --git a/packages/py-numpyro/package.py b/packages/py-numpyro/package.py new file mode 100644 index 0000000000000000000000000000000000000000..7266e0901cd8eeb2c2aa077ba744087b801ce076 --- /dev/null +++ b/packages/py-numpyro/package.py @@ -0,0 +1,26 @@ +# Copyright 2013-2024 Lawrence Livermore National Security, LLC and other +# Spack Project Developers. See the top-level COPYRIGHT file for details. +# +# SPDX-License-Identifier: (Apache-2.0 OR MIT) + +from spack.package import * + + +class PyNumpyro(PythonPackage): + """Pyro PPL on NumPy""" + + homepage = "https://github.com/pyro-ppl/numpyro" + pypi = "numpyro/numpyro-0.18.0.tar.gz" + + version("0.18.0", sha256="9799abeb801f940b8257a7dcfab0542a2a527e0f27c34e7fc57e7c30dd113527") + version("0.17.0", sha256="b4ad89e3bea8280980bbae3f712bf218314869db24cd1a2e07d90cee2f85b143") + version("0.16.1", sha256="553ea56c729bdeb83d6eb6d455911113dc75f5eb25a59e3af31a7726313eee38") + version("0.16.0", sha256="9e68c26332752cc950caa91ac084c9c9bea2f60982b6dfca6acb14247540e681") + + depends_on("python@3.9:", type=("build", "run")) + depends_on("py-setuptools", type="build") + depends_on("py-jax@0.4.25:", type=("build", "run")) + depends_on("py-jaxlib@0.4.25:", type=("build", "run")) + depends_on("py-multipledispatch", type=("build", "run")) + depends_on("py-numpy", type=("build", "run")) + depends_on("py-tqdm", type=("build", "run")) diff --git a/packages/py-vbjax/package.py b/packages/py-vbjax/package.py new file mode 100644 index 0000000000000000000000000000000000000000..7533dba2a13834dfef7eaf955a1a780db4762f4c --- /dev/null +++ b/packages/py-vbjax/package.py @@ -0,0 +1,40 @@ +# Copyright 2013-2024 Lawrence Livermore National Security, LLC and other +# Spack Project Developers. See the top-level COPYRIGHT file for details. +# +# SPDX-License-Identifier: (Apache-2.0 OR MIT) + +from spack.package import * + + +class PyVbjax(PythonPackage): + """Virtual brain models in JAX.""" + + homepage = "https://github.com/ins-amu/vbjax" + pypi = "vbjax/vbjax-0.0.16.tar.gz" + git = "https://github.com/ins-amu/vbjax" + + # 'joblib.test' requires 'pytest'. Leave out of 'import_modules' to avoid + # unnecessary dependencies. + skip_modules = ["joblib.test"] + + license("Apache-2.0") + + version("0.0.16", "9f0259e102a4b2889b82c0fc4007120e109145d902fdea81c8d6f1e97df7cf8d") + + depends_on("py-hatchling", type=("build")) + depends_on("py-setuptools@45:", type=("build")) + depends_on("py-setuptools-scm", type=("build")) + depends_on('py-jax', type=('build', 'run')) + depends_on('py-jaxlib', type=('build', 'run')) + depends_on('py-numpy', type=('build', 'run')) + depends_on('py-scipy', type=('build', 'run')) + depends_on('py-numpyro', type=('build', 'run')) + depends_on('py-pytest', type=('test')) + depends_on('py-pytest-benchmark', type=('test')) + + @run_after('install') + @on_package_attributes(run_tests=True) + def install_test(self): + pytest = which('pytest') + pytest() + diff --git a/spack.yaml b/spack.yaml index aa9f8e43e6f447a708b4782f23293e2d4b77d010..fc89307e30a604df1bbbf9ad60585ce4ff01ffde 100644 --- a/spack.yaml +++ b/spack.yaml @@ -63,6 +63,7 @@ spack: - py-tvb-ext-xircuits@1.1.0 - py-viziphant@0.4.0 - py-vbi@0.1.3.3 + - py-vbjax@0.0.16 - pynn-brainscales@10.0-a1 - r-rgsl@0.1.1 - r-sbtabvfgen@0.1