From 46516cfbb41a19a921e3cfd0b19618709cd42ab0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eric=20M=C3=BCller?= <mueller@kip.uni-heidelberg.de> Date: Mon, 31 Mar 2025 15:25:42 +0200 Subject: [PATCH] fix(jaxsnn): tests when using modern jax --- packages/jaxsnn/newjax.patch | 41 ++++++++++++++++++++++++++++++++++++ packages/jaxsnn/package.py | 1 + 2 files changed, 42 insertions(+) create mode 100644 packages/jaxsnn/newjax.patch diff --git a/packages/jaxsnn/newjax.patch b/packages/jaxsnn/newjax.patch new file mode 100644 index 00000000..96b2d170 --- /dev/null +++ b/packages/jaxsnn/newjax.patch @@ -0,0 +1,41 @@ +From d12ff24ccb39f861067661b01973862e83552baf Mon Sep 17 00:00:00 2001 +From: Elias Arnold <elias.arnold@kip.uni-heidelberg.de> +Date: Mon, 31 Mar 2025 15:13:03 +0200 +Subject: [PATCH] fix: tests for new jax + +Change-Id: I278454c7a51c0c15071a7ab8496a9655c52ff495 +--- + +diff --git a/tests/sw/event/hardware/utils_test.py b/tests/sw/event/hardware/utils_test.py +index ab73452..0753947 100644 +--- a/tests/sw/event/hardware/utils_test.py ++++ b/tests/sw/event/hardware/utils_test.py +@@ -15,12 +15,12 @@ + rng = random.PRNGKey(42) + with_noise = add_noise_batch(spikes, rng, std=1) + assert_array_equal( +- with_noise.idx, np.array([[0, 1, 2, 5, 3, 4, 6, 7, 8, 9]]) ++ with_noise.idx, np.array([[0, 1, 2, 3, 4, 6, 5, 7, 8, 9]]) + ) + + with_noise = add_noise_batch(spikes, rng, std=3) + assert_array_equal( +- with_noise.idx, np.array([[2, 1, 0, 5, 6, 7, 3, 4, 8, 9]]) ++ with_noise.idx, np.array([[0, 6, 1, 2, 3, 4, 5, 7, 8, 9]]) + ) + + def test_sort_batch(self): +diff --git a/tests/sw/event/tasks/constant_test.py b/tests/sw/event/tasks/constant_test.py +index be82deb..a7906af 100644 +--- a/tests/sw/event/tasks/constant_test.py ++++ b/tests/sw/event/tasks/constant_test.py +@@ -52,7 +52,7 @@ + ) + + # init weights +- rng = random.PRNGKey(42) ++ rng = random.PRNGKey(45) + weights = init_fn(rng, input_shape) + + loss_fn = partial( + diff --git a/packages/jaxsnn/package.py b/packages/jaxsnn/package.py index 8cca2481..25593b9d 100644 --- a/packages/jaxsnn/package.py +++ b/packages/jaxsnn/package.py @@ -46,6 +46,7 @@ class Jaxsnn(build_brainscales.BuildBrainscales): extends('python') patch("include-SparseTensorUtils.patch", when="@:8.0-a5") + patch("newjax.patch", when="@:10.0-a1 ^py-jax@0.5:") def install_test(self): with working_dir('spack-test', create=True): -- GitLab