From 8431e2e0e8680727e0b3f4e656fb40c575b5ca9f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eric=20M=C3=BCller?= <mueller@kip.uni-heidelberg.de>
Date: Wed, 24 Jul 2024 10:58:37 +0200
Subject: [PATCH] fix: #58

---
 .../hxtorch/include-SparseTensorUtils.patch     | 17 +++++++++++++++++
 packages/hxtorch/package.py                     |  2 ++
 packages/jaxsnn/include-SparseTensorUtils.patch | 17 +++++++++++++++++
 packages/jaxsnn/package.py                      |  2 ++
 4 files changed, 38 insertions(+)
 create mode 100644 packages/hxtorch/include-SparseTensorUtils.patch
 create mode 100644 packages/jaxsnn/include-SparseTensorUtils.patch

diff --git a/packages/hxtorch/include-SparseTensorUtils.patch b/packages/hxtorch/include-SparseTensorUtils.patch
new file mode 100644
index 00000000..de7584e4
--- /dev/null
+++ b/packages/hxtorch/include-SparseTensorUtils.patch
@@ -0,0 +1,17 @@
+diff --git a/src/hxtorch/spiking/types.cpp b/src/hxtorch/spiking/types.cpp
+index aaf670f..39322c8 100644
+--- a/hxtorch/src/hxtorch/spiking/types.cpp
++++ b/hxtorch/src/hxtorch/spiking/types.cpp
+@@ -2,7 +2,12 @@
+ #include "grenade/vx/common/time.h"
+ #include "hxtorch/spiking/detail/to_dense.h"
+ #include <ATen/Functions.h>
++#if __has_include(<ATen/native/SparseTensorUtils.h>)
++// moved in py-torch@2.1
++#include <ATen/native/SparseTensorUtils.h>
++#else
+ #include <ATen/SparseTensorUtils.h>
++#endif
+ #include <log4cxx/logger.h>
+ 
+ namespace hxtorch::spiking {
diff --git a/packages/hxtorch/package.py b/packages/hxtorch/package.py
index 7985472e..c1450d95 100644
--- a/packages/hxtorch/package.py
+++ b/packages/hxtorch/package.py
@@ -87,6 +87,8 @@ class Hxtorch(build_brainscales.BuildBrainscales):
 
     extends('python')
 
+    patch("include-SparseTensorUtils.patch", when="@:8.0-a5")
+
     def install_test(self):
         with working_dir('spack-test', create=True):
             old_pythonpath = os.environ.get('PYTHONPATH', '')
diff --git a/packages/jaxsnn/include-SparseTensorUtils.patch b/packages/jaxsnn/include-SparseTensorUtils.patch
new file mode 100644
index 00000000..de7584e4
--- /dev/null
+++ b/packages/jaxsnn/include-SparseTensorUtils.patch
@@ -0,0 +1,17 @@
+diff --git a/src/hxtorch/spiking/types.cpp b/src/hxtorch/spiking/types.cpp
+index aaf670f..39322c8 100644
+--- a/hxtorch/src/hxtorch/spiking/types.cpp
++++ b/hxtorch/src/hxtorch/spiking/types.cpp
+@@ -2,7 +2,12 @@
+ #include "grenade/vx/common/time.h"
+ #include "hxtorch/spiking/detail/to_dense.h"
+ #include <ATen/Functions.h>
++#if __has_include(<ATen/native/SparseTensorUtils.h>)
++// moved in py-torch@2.1
++#include <ATen/native/SparseTensorUtils.h>
++#else
+ #include <ATen/SparseTensorUtils.h>
++#endif
+ #include <log4cxx/logger.h>
+ 
+ namespace hxtorch::spiking {
diff --git a/packages/jaxsnn/package.py b/packages/jaxsnn/package.py
index c61589f3..1b1c07cf 100644
--- a/packages/jaxsnn/package.py
+++ b/packages/jaxsnn/package.py
@@ -43,6 +43,8 @@ class Jaxsnn(build_brainscales.BuildBrainscales):
     depends_on('py-tree-math', type=('build', 'link', 'run'))
     extends('python')
 
+    patch("include-SparseTensorUtils.patch", when="@:8.0-a5")
+
     def install_test(self):
         with working_dir('spack-test', create=True):
             old_pythonpath = os.environ.get('PYTHONPATH', '')
-- 
GitLab