From 8431e2e0e8680727e0b3f4e656fb40c575b5ca9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eric=20M=C3=BCller?= <mueller@kip.uni-heidelberg.de> Date: Wed, 24 Jul 2024 10:58:37 +0200 Subject: [PATCH] fix: #58 --- .../hxtorch/include-SparseTensorUtils.patch | 17 +++++++++++++++++ packages/hxtorch/package.py | 2 ++ packages/jaxsnn/include-SparseTensorUtils.patch | 17 +++++++++++++++++ packages/jaxsnn/package.py | 2 ++ 4 files changed, 38 insertions(+) create mode 100644 packages/hxtorch/include-SparseTensorUtils.patch create mode 100644 packages/jaxsnn/include-SparseTensorUtils.patch diff --git a/packages/hxtorch/include-SparseTensorUtils.patch b/packages/hxtorch/include-SparseTensorUtils.patch new file mode 100644 index 00000000..de7584e4 --- /dev/null +++ b/packages/hxtorch/include-SparseTensorUtils.patch @@ -0,0 +1,17 @@ +diff --git a/src/hxtorch/spiking/types.cpp b/src/hxtorch/spiking/types.cpp +index aaf670f..39322c8 100644 +--- a/hxtorch/src/hxtorch/spiking/types.cpp ++++ b/hxtorch/src/hxtorch/spiking/types.cpp +@@ -2,7 +2,12 @@ + #include "grenade/vx/common/time.h" + #include "hxtorch/spiking/detail/to_dense.h" + #include <ATen/Functions.h> ++#if __has_include(<ATen/native/SparseTensorUtils.h>) ++// moved in py-torch@2.1 ++#include <ATen/native/SparseTensorUtils.h> ++#else + #include <ATen/SparseTensorUtils.h> ++#endif + #include <log4cxx/logger.h> + + namespace hxtorch::spiking { diff --git a/packages/hxtorch/package.py b/packages/hxtorch/package.py index 7985472e..c1450d95 100644 --- a/packages/hxtorch/package.py +++ b/packages/hxtorch/package.py @@ -87,6 +87,8 @@ class Hxtorch(build_brainscales.BuildBrainscales): extends('python') + patch("include-SparseTensorUtils.patch", when="@:8.0-a5") + def install_test(self): with working_dir('spack-test', create=True): old_pythonpath = os.environ.get('PYTHONPATH', '') diff --git a/packages/jaxsnn/include-SparseTensorUtils.patch b/packages/jaxsnn/include-SparseTensorUtils.patch new file mode 100644 index 00000000..de7584e4 --- /dev/null +++ b/packages/jaxsnn/include-SparseTensorUtils.patch @@ -0,0 +1,17 @@ +diff --git a/src/hxtorch/spiking/types.cpp b/src/hxtorch/spiking/types.cpp +index aaf670f..39322c8 100644 +--- a/hxtorch/src/hxtorch/spiking/types.cpp ++++ b/hxtorch/src/hxtorch/spiking/types.cpp +@@ -2,7 +2,12 @@ + #include "grenade/vx/common/time.h" + #include "hxtorch/spiking/detail/to_dense.h" + #include <ATen/Functions.h> ++#if __has_include(<ATen/native/SparseTensorUtils.h>) ++// moved in py-torch@2.1 ++#include <ATen/native/SparseTensorUtils.h> ++#else + #include <ATen/SparseTensorUtils.h> ++#endif + #include <log4cxx/logger.h> + + namespace hxtorch::spiking { diff --git a/packages/jaxsnn/package.py b/packages/jaxsnn/package.py index c61589f3..1b1c07cf 100644 --- a/packages/jaxsnn/package.py +++ b/packages/jaxsnn/package.py @@ -43,6 +43,8 @@ class Jaxsnn(build_brainscales.BuildBrainscales): depends_on('py-tree-math', type=('build', 'link', 'run')) extends('python') + patch("include-SparseTensorUtils.patch", when="@:8.0-a5") + def install_test(self): with working_dir('spack-test', create=True): old_pythonpath = os.environ.get('PYTHONPATH', '') -- GitLab