diff --git a/packages.yaml b/packages.yaml index 1dd5e099c40089e3fe392b8c10520693e83dc82b..54619e5b7025c457912075ed202c4022b9d27ab0 100644 --- a/packages.yaml +++ b/packages.yaml @@ -7,3 +7,5 @@ packages: py-torch: version: [1.11.0] variants: [~cuda~rocm~valgrind~mkldnn~mpi~gloo+tensorpipe~onnx_ml] + py-jax: + variants: [~cuda] diff --git a/packages/hxtorch/package.py b/packages/hxtorch/package.py index 6b36a3685216db881fcea689cf567f2159e8539d..30a5566d1fe485d9b2b1ab6ffcb1b9d05a2a51a8 100644 --- a/packages/hxtorch/package.py +++ b/packages/hxtorch/package.py @@ -37,6 +37,7 @@ class Hxtorch(WafPackage): depends_on('pkgconfig', type=('build', 'link', 'run')) depends_on('python@3.7.0:', type=('build', 'link', 'run')) # BrainScaleS(-2, type=('build', 'link', 'run')) only supports Python >= 3.7 depends_on('py-h5py', type=('build', 'link', 'run')) # PyNN tests need it + depends_on('py-jax@0.3.25:', type=('build', 'link', 'run')) depends_on('py-matplotlib', type=('build', 'link', 'run')) depends_on('py-nose', type=('build', 'link', 'run')) depends_on('py-numpy', type=('build', 'link', 'run')) @@ -47,6 +48,7 @@ class Hxtorch(WafPackage): depends_on('py-pylint', type=('build', 'link', 'run')) depends_on('py-torch@1.9.1:', type=('build', 'link', 'run')) depends_on('py-torchvision', type=('run')) # for demos + depends_on('py-tree-math', type=('build', 'link', 'run')) depends_on('py-pyyaml', type=('build', 'link', 'run')) depends_on('py-scipy', type=('build', 'link', 'run')) depends_on('py-sqlalchemy', type=('build', 'link', 'run'))