diff --git a/packages/hxtorch/package.py b/packages/hxtorch/package.py index 7985472e35eec7abd2a7d0816f84d70b71b5bd0d..101f82718775b1f67f4be365a36bddf4e153774f 100644 --- a/packages/hxtorch/package.py +++ b/packages/hxtorch/package.py @@ -33,6 +33,8 @@ class Hxtorch(build_brainscales.BuildBrainscales): version('7.0-rc1-fixup2', tag='hxtorch-7.0-rc1-fixup2') version('7.0-rc1-fixup1', branch='waf') + patch("remove-superfluous-pytorch-include.patch", when="^py-torch@2.1:") + deps_hxtorch_core = [ # compiler for the BrainScaleS-2 embedded processor ("PPU"); needed for # building/linking, at runtime and for testing diff --git a/packages/hxtorch/remove-superfluous-pytorch-include.patch b/packages/hxtorch/remove-superfluous-pytorch-include.patch new file mode 100644 index 0000000000000000000000000000000000000000..f7a0ad279c6cc7921bc8d3dd09d2221feafcfaa2 --- /dev/null +++ b/packages/hxtorch/remove-superfluous-pytorch-include.patch @@ -0,0 +1,12 @@ +diff --git a/hxtorch/src/hxtorch/spiking/types.cpp b/hxtorch/src/hxtorch/spiking/types.cpp +index aaf670f..0e7df7d 100644 +--- a/hxtorch/src/hxtorch/spiking/types.cpp ++++ b/hxtorch/src/hxtorch/spiking/types.cpp +@@ -2,7 +2,6 @@ + #include "grenade/vx/common/time.h" + #include "hxtorch/spiking/detail/to_dense.h" + #include <ATen/Functions.h> +-#include <ATen/SparseTensorUtils.h> + #include <log4cxx/logger.h> + + namespace hxtorch::spiking { diff --git a/packages/jaxsnn/package.py b/packages/jaxsnn/package.py index c61589f3ee9423fbed02451698dabc926aed0d19..6f09542d8dd2ff23b6af8af806ae255f560731cd 100644 --- a/packages/jaxsnn/package.py +++ b/packages/jaxsnn/package.py @@ -32,6 +32,8 @@ class Jaxsnn(build_brainscales.BuildBrainscales): version('8.0-a2', tag='jaxsnn-8.0-a2') version('8.0-a1', tag='jaxsnn-8.0-a1') + patch("remove-superfluous-pytorch-include.patch", when="^py-torch@2.1:") + # dependencies inherited from hxtorch.core for dep, dep_kw in hxtorch.Hxtorch.deps_hxtorch_core: depends_on(dep, **dep_kw) diff --git a/packages/jaxsnn/remove-superfluous-pytorch-include.patch b/packages/jaxsnn/remove-superfluous-pytorch-include.patch new file mode 100644 index 0000000000000000000000000000000000000000..f7a0ad279c6cc7921bc8d3dd09d2221feafcfaa2 --- /dev/null +++ b/packages/jaxsnn/remove-superfluous-pytorch-include.patch @@ -0,0 +1,12 @@ +diff --git a/hxtorch/src/hxtorch/spiking/types.cpp b/hxtorch/src/hxtorch/spiking/types.cpp +index aaf670f..0e7df7d 100644 +--- a/hxtorch/src/hxtorch/spiking/types.cpp ++++ b/hxtorch/src/hxtorch/spiking/types.cpp +@@ -2,7 +2,6 @@ + #include "grenade/vx/common/time.h" + #include "hxtorch/spiking/detail/to_dense.h" + #include <ATen/Functions.h> +-#include <ATen/SparseTensorUtils.h> + #include <log4cxx/logger.h> + + namespace hxtorch::spiking {