diff --git a/packages/hxtorch/include-SparseTensorUtils.patch b/packages/hxtorch/include-SparseTensorUtils.patch
new file mode 100644
index 0000000000000000000000000000000000000000..de7584e45bb1a600dfee460b5e039359c82418c4
--- /dev/null
+++ b/packages/hxtorch/include-SparseTensorUtils.patch
@@ -0,0 +1,17 @@
+diff --git a/src/hxtorch/spiking/types.cpp b/src/hxtorch/spiking/types.cpp
+index aaf670f..39322c8 100644
+--- a/hxtorch/src/hxtorch/spiking/types.cpp
++++ b/hxtorch/src/hxtorch/spiking/types.cpp
+@@ -2,7 +2,12 @@
+ #include "grenade/vx/common/time.h"
+ #include "hxtorch/spiking/detail/to_dense.h"
+ #include <ATen/Functions.h>
++#if __has_include(<ATen/native/SparseTensorUtils.h>)
++// moved in py-torch@2.1
++#include <ATen/native/SparseTensorUtils.h>
++#else
+ #include <ATen/SparseTensorUtils.h>
++#endif
+ #include <log4cxx/logger.h>
+ 
+ namespace hxtorch::spiking {
diff --git a/packages/hxtorch/package.py b/packages/hxtorch/package.py
index 7985472e35eec7abd2a7d0816f84d70b71b5bd0d..c1450d958b7e47cdf00491f564924b9df8e67feb 100644
--- a/packages/hxtorch/package.py
+++ b/packages/hxtorch/package.py
@@ -87,6 +87,8 @@ class Hxtorch(build_brainscales.BuildBrainscales):
 
     extends('python')
 
+    patch("include-SparseTensorUtils.patch", when="@:8.0-a5")
+
     def install_test(self):
         with working_dir('spack-test', create=True):
             old_pythonpath = os.environ.get('PYTHONPATH', '')
diff --git a/packages/jaxsnn/include-SparseTensorUtils.patch b/packages/jaxsnn/include-SparseTensorUtils.patch
new file mode 100644
index 0000000000000000000000000000000000000000..de7584e45bb1a600dfee460b5e039359c82418c4
--- /dev/null
+++ b/packages/jaxsnn/include-SparseTensorUtils.patch
@@ -0,0 +1,17 @@
+diff --git a/src/hxtorch/spiking/types.cpp b/src/hxtorch/spiking/types.cpp
+index aaf670f..39322c8 100644
+--- a/hxtorch/src/hxtorch/spiking/types.cpp
++++ b/hxtorch/src/hxtorch/spiking/types.cpp
+@@ -2,7 +2,12 @@
+ #include "grenade/vx/common/time.h"
+ #include "hxtorch/spiking/detail/to_dense.h"
+ #include <ATen/Functions.h>
++#if __has_include(<ATen/native/SparseTensorUtils.h>)
++// moved in py-torch@2.1
++#include <ATen/native/SparseTensorUtils.h>
++#else
+ #include <ATen/SparseTensorUtils.h>
++#endif
+ #include <log4cxx/logger.h>
+ 
+ namespace hxtorch::spiking {
diff --git a/packages/jaxsnn/package.py b/packages/jaxsnn/package.py
index c61589f3ee9423fbed02451698dabc926aed0d19..1b1c07cf0f3da92463c2d864955b6384a5c48bb9 100644
--- a/packages/jaxsnn/package.py
+++ b/packages/jaxsnn/package.py
@@ -43,6 +43,8 @@ class Jaxsnn(build_brainscales.BuildBrainscales):
     depends_on('py-tree-math', type=('build', 'link', 'run'))
     extends('python')
 
+    patch("include-SparseTensorUtils.patch", when="@:8.0-a5")
+
     def install_test(self):
         with working_dir('spack-test', create=True):
             old_pythonpath = os.environ.get('PYTHONPATH', '')