diff --git a/packages/jaxsnn/newjax.patch b/packages/jaxsnn/newjax.patch
new file mode 100644
index 0000000000000000000000000000000000000000..96b2d170f5bb6f715df7fcd534c3abbfc6563bb7
--- /dev/null
+++ b/packages/jaxsnn/newjax.patch
@@ -0,0 +1,41 @@
+From d12ff24ccb39f861067661b01973862e83552baf Mon Sep 17 00:00:00 2001
+From: Elias Arnold <elias.arnold@kip.uni-heidelberg.de>
+Date: Mon, 31 Mar 2025 15:13:03 +0200
+Subject: [PATCH] fix: tests for new jax
+
+Change-Id: I278454c7a51c0c15071a7ab8496a9655c52ff495
+---
+
+diff --git a/tests/sw/event/hardware/utils_test.py b/tests/sw/event/hardware/utils_test.py
+index ab73452..0753947 100644
+--- a/tests/sw/event/hardware/utils_test.py
++++ b/tests/sw/event/hardware/utils_test.py
+@@ -15,12 +15,12 @@
+         rng = random.PRNGKey(42)
+         with_noise = add_noise_batch(spikes, rng, std=1)
+         assert_array_equal(
+-            with_noise.idx, np.array([[0, 1, 2, 5, 3, 4, 6, 7, 8, 9]])
++            with_noise.idx, np.array([[0, 1, 2, 3, 4, 6, 5, 7, 8, 9]])
+         )
+ 
+         with_noise = add_noise_batch(spikes, rng, std=3)
+         assert_array_equal(
+-            with_noise.idx, np.array([[2, 1, 0, 5, 6, 7, 3, 4, 8, 9]])
++            with_noise.idx, np.array([[0, 6, 1, 2, 3, 4, 5, 7, 8, 9]])
+         )
+ 
+     def test_sort_batch(self):
+diff --git a/tests/sw/event/tasks/constant_test.py b/tests/sw/event/tasks/constant_test.py
+index be82deb..a7906af 100644
+--- a/tests/sw/event/tasks/constant_test.py
++++ b/tests/sw/event/tasks/constant_test.py
+@@ -52,7 +52,7 @@
+         )
+ 
+         # init weights
+-        rng = random.PRNGKey(42)
++        rng = random.PRNGKey(45)
+         weights = init_fn(rng, input_shape)
+ 
+         loss_fn = partial(
+
diff --git a/packages/jaxsnn/package.py b/packages/jaxsnn/package.py
index 8cca2481583aecf70a920c68c717dcaaa6e077bb..25593b9dd3748701277ce665089994fef9fad3f8 100644
--- a/packages/jaxsnn/package.py
+++ b/packages/jaxsnn/package.py
@@ -46,6 +46,7 @@ class Jaxsnn(build_brainscales.BuildBrainscales):
     extends('python')
 
     patch("include-SparseTensorUtils.patch", when="@:8.0-a5")
+    patch("newjax.patch", when="@:10.0-a1 ^py-jax@0.5:")
 
     def install_test(self):
         with working_dir('spack-test', create=True):