diff --git a/packages/py-chex/package.py b/packages/py-chex/package.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e41890f4588105fc58e040d90b9f4a1fa84540c
--- /dev/null
+++ b/packages/py-chex/package.py
@@ -0,0 +1,36 @@
+# Copyright 2013-2023 Lawrence Livermore National Security, LLC and other
+# Spack Project Developers. See the top-level COPYRIGHT file for details.
+#
+# SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+
+from spack.package import *
+
+
+# EBRAINS: based on spack/0.21.2
+class PyChex(PythonPackage):
+    """Chex is a library of utilities for helping to write reliable JAX code."""
+
+    homepage = "https://github.com/deepmind/chex"
+    pypi = "chex/chex-0.1.0.tar.gz"
+
+    # begin EBRAINS (added): bring upstream
+    version("0.1.7", sha256="74ed49799ac4d229881456d468136f1b19a9f9839e3de72b058824e2a4f4dedd")
+    version("0.1.5", sha256="686858320f8f220c82a6c7eeb54dcdcaa4f3d7f66690dacd13a24baa1ee8299e")
+    # end EBRAINS
+    version("0.1.0", sha256="9e032058f5fed2fc1d5e9bf8e12ece5910cf6a478c12d402b6d30984695f2161")
+
+    depends_on("python@3.7:", type=("build", "run"))
+    depends_on("py-setuptools", type="build")
+    depends_on("py-absl-py@0.9.0:", type=("build", "run"))
+    # begin EBRAINS (added): bring upstream
+    depends_on("py-typing-extensions@4.2.0:", when="@0.1.6: ^python@:3.10", type=("build", "run"))
+    # end EBRAINS
+    depends_on("py-dm-tree@0.1.5:", type=("build", "run"))
+    depends_on("py-jax@0.1.55:", type=("build", "run"))
+    # begin EBRAINS (added): bring upstream
+    depends_on("py-jax@0.4.6:", when="@0.1.7:", type=("build", "run"))
+    # end EBRAINS
+    depends_on("py-jaxlib@0.1.37:", type=("build", "run"))
+    depends_on("py-numpy@1.18.0:", type=("build", "run"))
+    depends_on("py-toolz@0.9.0:", type=("build", "run"))
diff --git a/packages/py-jax/package.py b/packages/py-jax/package.py
new file mode 100644
index 0000000000000000000000000000000000000000..233de5b59b40af38d3d4e604477eb91e2b683b3a
--- /dev/null
+++ b/packages/py-jax/package.py
@@ -0,0 +1,72 @@
+# Copyright 2013-2023 Lawrence Livermore National Security, LLC and other
+# Spack Project Developers. See the top-level COPYRIGHT file for details.
+#
+# SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+
+from spack.package import *
+
+
+# EBRAINS: based on spack/0.21.2
+class PyJax(PythonPackage):
+    """JAX is Autograd and XLA, brought together for high-performance
+    machine learning research. With its updated version of Autograd,
+    JAX can automatically differentiate native Python and NumPy
+    functions. It can differentiate through loops, branches,
+    recursion, and closures, and it can take derivatives of
+    derivatives of derivatives. It supports reverse-mode
+    differentiation (a.k.a. backpropagation) via grad as well as
+    forward-mode differentiation, and the two can be composed
+    arbitrarily to any order."""
+
+    homepage = "https://github.com/google/jax"
+    pypi = "jax/jax-0.2.25.tar.gz"
+
+    # begin EBRAINS (added): bring upstream
+    version("0.4.13", sha256="03bfe6749dfe647f16f15f6616638adae6c4a7ca7167c75c21961ecfd3a3baaa")
+    # end EBRAINS
+    version("0.4.3", sha256="d43f08f940aa30eb339965cfb3d6bee2296537b0dc2f0c65ccae3009279529ae")
+    version("0.3.23", sha256="bff436e15552a82c0ebdef32737043b799e1e10124423c57a6ae6118c3a7b6cd")
+    version("0.2.25", sha256="822e8d1e06257eaa0fdc4c0a0686c4556e9f33647fa2a766755f984786ae7446")
+
+    # begin EBRAINS (modified): bring upstream
+    depends_on("python@3.7:", type=("build", "run"))
+    depends_on("python@3.8:", when="@0.4:", type=("build", "run"))
+    depends_on("python@3.9:", when="@0.4.14:", type=("build", "run"))
+    depends_on("py-setuptools", type="build")
+    depends_on("py-numpy@1.22:", when="@0.4.14:", type=("build", "run"))
+    depends_on("py-numpy@1.21:", when="@0.4.9:", type=("build", "run"))
+    depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run"))
+    depends_on("py-numpy@1.18:", type=("build", "run"))
+    depends_on("py-opt-einsum", type=("build", "run"))
+    depends_on("py-scipy@1.2.1:", type=("build", "run"))
+    depends_on("py-scipy@1.5:", when="@0.3:", type=("build", "run"))
+    depends_on("py-scipy@1.7:", when="@0.4.7:", type=("build", "run"))
+    depends_on("py-ml-dtypes@0.2.0:", when="@0.4.14:", type=("build", "run"))
+    depends_on("py-ml-dtypes@0.1.0:", when="@0.4.9:", type=("build", "run"))
+    depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:", type=("build", "run"))
+    depends_on("py-importlib-metadata@4.6:", when="@0.4.11: ^python@:3.9", type="run")
+    # end EBRAINS
+
+    # See _minimum_jaxlib_version in jax/version.py
+    # begin EBRAINS (modified): bring upstream
+    jax_to_jaxlib = {
+        "0.4.14": "0.4.14",
+        "0.4.13": "0.4.13",
+        "0.4.3": "0.4.2",
+        "0.3.23": "0.3.15",
+        "0.2.25": "0.1.69",
+    }
+    # end EBRAINS
+
+    for jax, jaxlib in jax_to_jaxlib.items():
+        # begin EBRAINS (modified): bring upstream
+        depends_on(f"py-jaxlib@{jaxlib}", when=f"@{jax}", type=("build", "run"))
+        # end EBRAINS
+
+    # Historical dependencies
+    depends_on("py-absl-py", when="@:0.3", type=("build", "run"))
+    depends_on("py-typing-extensions", when="@:0.3", type=("build", "run"))
+    # begin EBRAINS (deleted):
+    # depends_on("py-etils+epath", when="@0.3", type=("build", "run"))
+    # end EBRAINS
diff --git a/packages/py-jaxlib/package.py b/packages/py-jaxlib/package.py
new file mode 100644
index 0000000000000000000000000000000000000000..af04f737eceb3cfa00f7538da14dd8d5043df7cb
--- /dev/null
+++ b/packages/py-jaxlib/package.py
@@ -0,0 +1,109 @@
+# Copyright 2013-2023 Lawrence Livermore National Security, LLC and other
+# Spack Project Developers. See the top-level COPYRIGHT file for details.
+#
+# SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+import tempfile
+
+from spack.package import *
+
+
+# EBRAINS: based on spack/0.21.2
+class PyJaxlib(PythonPackage, CudaPackage):
+    """XLA library for Jax"""
+
+    homepage = "https://github.com/google/jax"
+    url = "https://github.com/google/jax/archive/refs/tags/jaxlib-v0.1.74.tar.gz"
+
+    tmp_path = ""
+    buildtmp = ""
+
+    # begin EBRAINS (added): bring upstream
+    version("0.4.13", sha256="45766238b57b992851763c64bc943858aebafe4cad7b3df6cde844690bc34293")
+    # end EBRAINS
+    version("0.4.3", sha256="2104735dc22be2b105e5517bd5bc6ae97f40e8e9e54928cac1585c6112a3d910")
+    version("0.3.22", sha256="680a6f5265ba26d5515617a95ae47244005366f879a5c321782fde60f34e6d0d")
+    version("0.1.74", sha256="bbc78c7a4927012dcb1b7cd135c7521f782d7dad516a2401b56d3190f81afe35")
+
+    # begin EBRAINS (deleted): Variant with default=False is provided by CudaPackage
+    # variant("cuda", default=True, description="Build with CUDA")
+    # end EBRAINS
+
+    # jaxlib/setup.py
+    # begin EBRAINS (modified): bring upstream
+    depends_on("python@3.9:", when="@0.4.14:", type=("build", "run"))
+    depends_on("python@3.8:", when="@0.4:", type=("build", "run"))
+    depends_on("python@3.7:", type=("build", "run"))
+    depends_on("py-setuptools", type="build")
+    depends_on("py-numpy@1.22:", when="@0.4.14:", type=("build", "run"))
+    depends_on("py-numpy@1.21:", when="@0.4.9:", type=("build", "run"))
+    depends_on("py-numpy@1.20:", when="@0.3:", type=("build", "run"))
+    depends_on("py-numpy@1.18:", type=("build", "run"))
+    depends_on("py-scipy@1.5:", type=("build", "run"))
+    depends_on("py-scipy@1.7:", when="@0.4.7:", type=("build", "run"))
+    depends_on("py-ml-dtypes@0.2.0:", when="@0.4.14:", type=("build", "run"))
+    depends_on("py-ml-dtypes@0.1.0:", when="@0.4.9:", type=("build", "run"))
+    depends_on("py-ml-dtypes@0.0.3:", when="@0.4.7:", type=("build", "run"))
+    # end EBRAINS
+
+    # .bazelversion
+    depends_on("bazel@5.1.1:5.9", when="@0.3:", type="build")
+    # https://github.com/google/jax/issues/8440
+    depends_on("bazel@4.1:4", when="@0.1", type="build")
+
+    # README.md
+    # begin EBRAINS (added): bring upstream
+    depends_on("cuda@11.8:", when="@0.4.8:+cuda")
+    # end EBRAINS
+    depends_on("cuda@11.4:", when="@0.4:+cuda")
+    depends_on("cuda@11.1:", when="@0.3+cuda")
+    # https://github.com/google/jax/issues/12614
+    depends_on("cuda@11.1:11.7.0", when="@0.1+cuda")
+    depends_on("cudnn@8.2:", when="@0.4:+cuda")
+    depends_on("cudnn@8.0.5:", when="+cuda")
+
+    # Historical dependencies
+    depends_on("py-absl-py", when="@:0.3", type=("build", "run"))
+    depends_on("py-flatbuffers@1.12:2", when="@0.1", type=("build", "run"))
+
+    def patch(self):
+        self.tmp_path = tempfile.mkdtemp(prefix="spack")
+        self.buildtmp = tempfile.mkdtemp(prefix="spack")
+        filter_file(
+            'f"--output_path={output_path}",',
+            'f"--output_path={output_path}",'
+            f' "--sources_path={self.tmp_path}",'
+            ' "--nohome_rc",'
+            ' "--nosystem_rc",'
+            f' "--jobs={make_jobs}",',
+            "build/build.py",
+            string=True,
+        )
+        filter_file(
+            "args = parser.parse_args()",
+            "args, junk = parser.parse_known_args()",
+            "build/build_wheel.py",
+            string=True,
+        )
+
+    def install(self, spec, prefix):
+        args = []
+        args.append("build/build.py")
+        if "+cuda" in spec:
+            args.append("--enable_cuda")
+            args.append("--cuda_path={0}".format(self.spec["cuda"].prefix))
+            args.append("--cudnn_path={0}".format(self.spec["cudnn"].prefix))
+            capabilities = ",".join(
+                "{0:.1f}".format(float(i) / 10.0) for i in spec.variants["cuda_arch"].value
+            )
+            args.append("--cuda_compute_capabilities={0}".format(capabilities))
+        args.append(
+            "--bazel_startup_options="
+            "--output_user_root={0}".format(self.wrapped_package_object.buildtmp)
+        )
+        python(*args)
+        with working_dir(self.wrapped_package_object.tmp_path):
+            args = std_pip_args + ["--prefix=" + self.prefix, "."]
+            pip(*args)
+        remove_linked_tree(self.wrapped_package_object.tmp_path)
+        remove_linked_tree(self.wrapped_package_object.buildtmp)