diff --git a/packages.yaml b/packages.yaml
index 1dd5e099c40089e3fe392b8c10520693e83dc82b..54619e5b7025c457912075ed202c4022b9d27ab0 100644
--- a/packages.yaml
+++ b/packages.yaml
@@ -7,3 +7,5 @@ packages:
     py-torch:
         version: [1.11.0]
         variants: [~cuda~rocm~valgrind~mkldnn~mpi~gloo+tensorpipe~onnx_ml]
+    py-jax:
+        variants: [~cuda]
diff --git a/packages/hxtorch/package.py b/packages/hxtorch/package.py
index 6b36a3685216db881fcea689cf567f2159e8539d..30a5566d1fe485d9b2b1ab6ffcb1b9d05a2a51a8 100644
--- a/packages/hxtorch/package.py
+++ b/packages/hxtorch/package.py
@@ -37,6 +37,7 @@ class Hxtorch(WafPackage):
     depends_on('pkgconfig', type=('build', 'link', 'run'))
     depends_on('python@3.7.0:', type=('build', 'link', 'run')) # BrainScaleS(-2, type=('build', 'link', 'run')) only supports Python >= 3.7
     depends_on('py-h5py', type=('build', 'link', 'run')) # PyNN tests need it
+    depends_on('py-jax@0.3.25:', type=('build', 'link', 'run'))
     depends_on('py-matplotlib', type=('build', 'link', 'run'))
     depends_on('py-nose', type=('build', 'link', 'run'))
     depends_on('py-numpy', type=('build', 'link', 'run'))
@@ -47,6 +48,7 @@ class Hxtorch(WafPackage):
     depends_on('py-pylint', type=('build', 'link', 'run'))
     depends_on('py-torch@1.9.1:', type=('build', 'link', 'run'))
     depends_on('py-torchvision', type=('run')) # for demos
+    depends_on('py-tree-math', type=('build', 'link', 'run'))
     depends_on('py-pyyaml', type=('build', 'link', 'run'))
     depends_on('py-scipy', type=('build', 'link', 'run'))
     depends_on('py-sqlalchemy', type=('build', 'link', 'run'))