diff --git a/packages/py-numpyro/package.py b/packages/py-numpyro/package.py
new file mode 100644
index 0000000000000000000000000000000000000000..7266e0901cd8eeb2c2aa077ba744087b801ce076
--- /dev/null
+++ b/packages/py-numpyro/package.py
@@ -0,0 +1,26 @@
+# Copyright 2013-2024 Lawrence Livermore National Security, LLC and other
+# Spack Project Developers. See the top-level COPYRIGHT file for details.
+#
+# SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+from spack.package import *
+
+
+class PyNumpyro(PythonPackage):
+    """Pyro PPL on NumPy"""
+
+    homepage = "https://github.com/pyro-ppl/numpyro"
+    pypi = "numpyro/numpyro-0.18.0.tar.gz"
+
+    version("0.18.0", sha256="9799abeb801f940b8257a7dcfab0542a2a527e0f27c34e7fc57e7c30dd113527")
+    version("0.17.0", sha256="b4ad89e3bea8280980bbae3f712bf218314869db24cd1a2e07d90cee2f85b143")
+    version("0.16.1", sha256="553ea56c729bdeb83d6eb6d455911113dc75f5eb25a59e3af31a7726313eee38")
+    version("0.16.0", sha256="9e68c26332752cc950caa91ac084c9c9bea2f60982b6dfca6acb14247540e681")
+
+    depends_on("python@3.9:", type=("build", "run"))
+    depends_on("py-setuptools", type="build")
+    depends_on("py-jax@0.4.25:", type=("build", "run"))
+    depends_on("py-jaxlib@0.4.25:", type=("build", "run"))
+    depends_on("py-multipledispatch", type=("build", "run"))
+    depends_on("py-numpy", type=("build", "run"))
+    depends_on("py-tqdm", type=("build", "run"))
diff --git a/packages/py-vbjax/package.py b/packages/py-vbjax/package.py
new file mode 100644
index 0000000000000000000000000000000000000000..7533dba2a13834dfef7eaf955a1a780db4762f4c
--- /dev/null
+++ b/packages/py-vbjax/package.py
@@ -0,0 +1,40 @@
+# Copyright 2013-2024 Lawrence Livermore National Security, LLC and other
+# Spack Project Developers. See the top-level COPYRIGHT file for details.
+#
+# SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+from spack.package import *
+
+
+class PyVbjax(PythonPackage):
+    """Virtual brain models in JAX."""
+
+    homepage = "https://github.com/ins-amu/vbjax"
+    pypi = "vbjax/vbjax-0.0.16.tar.gz"
+    git = "https://github.com/ins-amu/vbjax"
+
+    # 'joblib.test' requires 'pytest'. Leave out of 'import_modules' to avoid
+    # unnecessary dependencies.
+    skip_modules = ["joblib.test"]
+
+    license("Apache-2.0")
+
+    version("0.0.16", "9f0259e102a4b2889b82c0fc4007120e109145d902fdea81c8d6f1e97df7cf8d")
+
+    depends_on("py-hatchling", type=("build"))
+    depends_on("py-setuptools@45:", type=("build"))
+    depends_on("py-setuptools-scm", type=("build"))
+    depends_on('py-jax', type=('build', 'run'))
+    depends_on('py-jaxlib', type=('build', 'run'))
+    depends_on('py-numpy', type=('build', 'run'))
+    depends_on('py-scipy', type=('build', 'run'))
+    depends_on('py-numpyro', type=('build', 'run'))
+    depends_on('py-pytest', type=('test'))
+    depends_on('py-pytest-benchmark', type=('test'))
+   
+    @run_after('install')
+    @on_package_attributes(run_tests=True)
+    def install_test(self):
+        pytest = which('pytest')
+        pytest()
+
diff --git a/spack.yaml b/spack.yaml
index aa9f8e43e6f447a708b4782f23293e2d4b77d010..fc89307e30a604df1bbbf9ad60585ce4ff01ffde 100644
--- a/spack.yaml
+++ b/spack.yaml
@@ -63,6 +63,7 @@ spack:
     - py-tvb-ext-xircuits@1.1.0
     - py-viziphant@0.4.0
     - py-vbi@0.1.3.3
+    - py-vbjax@0.0.16
     - pynn-brainscales@10.0-a1
     - r-rgsl@0.1.1
     - r-sbtabvfgen@0.1