diff --git a/packages/py-vbjax/package.py b/packages/py-vbjax/package.py
new file mode 100644
index 0000000000000000000000000000000000000000..53c7e0346e3e53e4bbb0873f9021fcbb07b7d3c6
--- /dev/null
+++ b/packages/py-vbjax/package.py
@@ -0,0 +1,26 @@
+# Copyright 2013-2024 Lawrence Livermore National Security, LLC and other
+# Spack Project Developers. See the top-level COPYRIGHT file for details.
+#
+# SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+from spack.package import *
+
+
+class PyVbjax(PythonPackage):
+    """Virtual brain models in JAX."""
+
+    homepage = "https://github.com/ins-amu/vbjax"
+    pypi = "vbjax/vbjax-0.0.16.tar.gz"
+    git = "https://github.com/ins-amu/vbjax"
+
+    # 'joblib.test' requires 'pytest'. Leave out of 'import_modules' to avoid
+    # unnecessary dependencies.
+    skip_modules = ["joblib.test"]
+
+    license("Apache-2.0")
+
+    version("0.0.16")
+
+    depends_on('py-jax', type=('build', 'link', 'run'))
+    depends_on("python", type=("build", "run"))
+    depends_on("py-setuptools", type=("build", "run"))