diff --git a/packages/hxtorch/package.py b/packages/hxtorch/package.py
index 7985472e35eec7abd2a7d0816f84d70b71b5bd0d..101f82718775b1f67f4be365a36bddf4e153774f 100644
--- a/packages/hxtorch/package.py
+++ b/packages/hxtorch/package.py
@@ -33,6 +33,8 @@ class Hxtorch(build_brainscales.BuildBrainscales):
     version('7.0-rc1-fixup2', tag='hxtorch-7.0-rc1-fixup2')
     version('7.0-rc1-fixup1', branch='waf')
 
+    patch("remove-superfluous-pytorch-include.patch", when="^py-torch@2.1:")
+
     deps_hxtorch_core = [
         # compiler for the BrainScaleS-2 embedded processor ("PPU"); needed for
         # building/linking, at runtime and for testing
diff --git a/packages/hxtorch/remove-superfluous-pytorch-include.patch b/packages/hxtorch/remove-superfluous-pytorch-include.patch
new file mode 100644
index 0000000000000000000000000000000000000000..f7a0ad279c6cc7921bc8d3dd09d2221feafcfaa2
--- /dev/null
+++ b/packages/hxtorch/remove-superfluous-pytorch-include.patch
@@ -0,0 +1,12 @@
+diff --git a/hxtorch/src/hxtorch/spiking/types.cpp b/hxtorch/src/hxtorch/spiking/types.cpp
+index aaf670f..0e7df7d 100644
+--- a/hxtorch/src/hxtorch/spiking/types.cpp
++++ b/hxtorch/src/hxtorch/spiking/types.cpp
+@@ -2,7 +2,6 @@
+ #include "grenade/vx/common/time.h"
+ #include "hxtorch/spiking/detail/to_dense.h"
+ #include <ATen/Functions.h>
+-#include <ATen/SparseTensorUtils.h>
+ #include <log4cxx/logger.h>
+ 
+ namespace hxtorch::spiking {
diff --git a/packages/jaxsnn/package.py b/packages/jaxsnn/package.py
index c61589f3ee9423fbed02451698dabc926aed0d19..6f09542d8dd2ff23b6af8af806ae255f560731cd 100644
--- a/packages/jaxsnn/package.py
+++ b/packages/jaxsnn/package.py
@@ -32,6 +32,8 @@ class Jaxsnn(build_brainscales.BuildBrainscales):
     version('8.0-a2', tag='jaxsnn-8.0-a2')
     version('8.0-a1', tag='jaxsnn-8.0-a1')
 
+    patch("remove-superfluous-pytorch-include.patch", when="^py-torch@2.1:")
+
     # dependencies inherited from hxtorch.core
     for dep, dep_kw in hxtorch.Hxtorch.deps_hxtorch_core:
         depends_on(dep, **dep_kw)
diff --git a/packages/jaxsnn/remove-superfluous-pytorch-include.patch b/packages/jaxsnn/remove-superfluous-pytorch-include.patch
new file mode 100644
index 0000000000000000000000000000000000000000..f7a0ad279c6cc7921bc8d3dd09d2221feafcfaa2
--- /dev/null
+++ b/packages/jaxsnn/remove-superfluous-pytorch-include.patch
@@ -0,0 +1,12 @@
+diff --git a/hxtorch/src/hxtorch/spiking/types.cpp b/hxtorch/src/hxtorch/spiking/types.cpp
+index aaf670f..0e7df7d 100644
+--- a/hxtorch/src/hxtorch/spiking/types.cpp
++++ b/hxtorch/src/hxtorch/spiking/types.cpp
+@@ -2,7 +2,6 @@
+ #include "grenade/vx/common/time.h"
+ #include "hxtorch/spiking/detail/to_dense.h"
+ #include <ATen/Functions.h>
+-#include <ATen/SparseTensorUtils.h>
+ #include <log4cxx/logger.h>
+ 
+ namespace hxtorch::spiking {