diff --git a/packages/hxtorch/include-SparseTensorUtils.patch b/packages/hxtorch/include-SparseTensorUtils.patch new file mode 100644 index 0000000000000000000000000000000000000000..de7584e45bb1a600dfee460b5e039359c82418c4 --- /dev/null +++ b/packages/hxtorch/include-SparseTensorUtils.patch @@ -0,0 +1,17 @@ +diff --git a/src/hxtorch/spiking/types.cpp b/src/hxtorch/spiking/types.cpp +index aaf670f..39322c8 100644 +--- a/hxtorch/src/hxtorch/spiking/types.cpp ++++ b/hxtorch/src/hxtorch/spiking/types.cpp +@@ -2,7 +2,12 @@ + #include "grenade/vx/common/time.h" + #include "hxtorch/spiking/detail/to_dense.h" + #include <ATen/Functions.h> ++#if __has_include(<ATen/native/SparseTensorUtils.h>) ++// moved in py-torch@2.1 ++#include <ATen/native/SparseTensorUtils.h> ++#else + #include <ATen/SparseTensorUtils.h> ++#endif + #include <log4cxx/logger.h> + + namespace hxtorch::spiking { diff --git a/packages/hxtorch/package.py b/packages/hxtorch/package.py index 7985472e35eec7abd2a7d0816f84d70b71b5bd0d..c1450d958b7e47cdf00491f564924b9df8e67feb 100644 --- a/packages/hxtorch/package.py +++ b/packages/hxtorch/package.py @@ -87,6 +87,8 @@ class Hxtorch(build_brainscales.BuildBrainscales): extends('python') + patch("include-SparseTensorUtils.patch", when="@:8.0-a5") + def install_test(self): with working_dir('spack-test', create=True): old_pythonpath = os.environ.get('PYTHONPATH', '') diff --git a/packages/jaxsnn/include-SparseTensorUtils.patch b/packages/jaxsnn/include-SparseTensorUtils.patch new file mode 100644 index 0000000000000000000000000000000000000000..de7584e45bb1a600dfee460b5e039359c82418c4 --- /dev/null +++ b/packages/jaxsnn/include-SparseTensorUtils.patch @@ -0,0 +1,17 @@ +diff --git a/src/hxtorch/spiking/types.cpp b/src/hxtorch/spiking/types.cpp +index aaf670f..39322c8 100644 +--- a/hxtorch/src/hxtorch/spiking/types.cpp ++++ b/hxtorch/src/hxtorch/spiking/types.cpp +@@ -2,7 +2,12 @@ + #include "grenade/vx/common/time.h" + #include "hxtorch/spiking/detail/to_dense.h" + #include <ATen/Functions.h> ++#if __has_include(<ATen/native/SparseTensorUtils.h>) ++// moved in py-torch@2.1 ++#include <ATen/native/SparseTensorUtils.h> ++#else + #include <ATen/SparseTensorUtils.h> ++#endif + #include <log4cxx/logger.h> + + namespace hxtorch::spiking { diff --git a/packages/jaxsnn/package.py b/packages/jaxsnn/package.py index c61589f3ee9423fbed02451698dabc926aed0d19..1b1c07cf0f3da92463c2d864955b6384a5c48bb9 100644 --- a/packages/jaxsnn/package.py +++ b/packages/jaxsnn/package.py @@ -43,6 +43,8 @@ class Jaxsnn(build_brainscales.BuildBrainscales): depends_on('py-tree-math', type=('build', 'link', 'run')) extends('python') + patch("include-SparseTensorUtils.patch", when="@:8.0-a5") + def install_test(self): with working_dir('spack-test', create=True): old_pythonpath = os.environ.get('PYTHONPATH', '')