From 729c8bab52d051c657427bfb8c83120d68cf4955 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eric=20M=C3=BCller?= <mueller@kip.uni-heidelberg.de>
Date: Thu, 2 May 2024 17:06:11 +0200
Subject: [PATCH] fix: build {hxtorch,jaxsnn} on py-torch@2.1: [DO-NOT-MERGE]

This is for testing, if it works, we should get this upstream and it
will change the version/tag.
---
 packages/hxtorch/package.py                          |  2 ++
 .../hxtorch/remove-superfluous-pytorch-include.patch | 12 ++++++++++++
 packages/jaxsnn/package.py                           |  2 ++
 .../jaxsnn/remove-superfluous-pytorch-include.patch  | 12 ++++++++++++
 4 files changed, 28 insertions(+)
 create mode 100644 packages/hxtorch/remove-superfluous-pytorch-include.patch
 create mode 100644 packages/jaxsnn/remove-superfluous-pytorch-include.patch

diff --git a/packages/hxtorch/package.py b/packages/hxtorch/package.py
index 7985472e..101f8271 100644
--- a/packages/hxtorch/package.py
+++ b/packages/hxtorch/package.py
@@ -33,6 +33,8 @@ class Hxtorch(build_brainscales.BuildBrainscales):
     version('7.0-rc1-fixup2', tag='hxtorch-7.0-rc1-fixup2')
     version('7.0-rc1-fixup1', branch='waf')
 
+    patch("remove-superfluous-pytorch-include.patch", when="^py-torch@2.1:")
+
     deps_hxtorch_core = [
         # compiler for the BrainScaleS-2 embedded processor ("PPU"); needed for
         # building/linking, at runtime and for testing
diff --git a/packages/hxtorch/remove-superfluous-pytorch-include.patch b/packages/hxtorch/remove-superfluous-pytorch-include.patch
new file mode 100644
index 00000000..f7a0ad27
--- /dev/null
+++ b/packages/hxtorch/remove-superfluous-pytorch-include.patch
@@ -0,0 +1,12 @@
+diff --git a/hxtorch/src/hxtorch/spiking/types.cpp b/hxtorch/src/hxtorch/spiking/types.cpp
+index aaf670f..0e7df7d 100644
+--- a/hxtorch/src/hxtorch/spiking/types.cpp
++++ b/hxtorch/src/hxtorch/spiking/types.cpp
+@@ -2,7 +2,6 @@
+ #include "grenade/vx/common/time.h"
+ #include "hxtorch/spiking/detail/to_dense.h"
+ #include <ATen/Functions.h>
+-#include <ATen/SparseTensorUtils.h>
+ #include <log4cxx/logger.h>
+ 
+ namespace hxtorch::spiking {
diff --git a/packages/jaxsnn/package.py b/packages/jaxsnn/package.py
index c61589f3..6f09542d 100644
--- a/packages/jaxsnn/package.py
+++ b/packages/jaxsnn/package.py
@@ -32,6 +32,8 @@ class Jaxsnn(build_brainscales.BuildBrainscales):
     version('8.0-a2', tag='jaxsnn-8.0-a2')
     version('8.0-a1', tag='jaxsnn-8.0-a1')
 
+    patch("remove-superfluous-pytorch-include.patch", when="^py-torch@2.1:")
+
     # dependencies inherited from hxtorch.core
     for dep, dep_kw in hxtorch.Hxtorch.deps_hxtorch_core:
         depends_on(dep, **dep_kw)
diff --git a/packages/jaxsnn/remove-superfluous-pytorch-include.patch b/packages/jaxsnn/remove-superfluous-pytorch-include.patch
new file mode 100644
index 00000000..f7a0ad27
--- /dev/null
+++ b/packages/jaxsnn/remove-superfluous-pytorch-include.patch
@@ -0,0 +1,12 @@
+diff --git a/hxtorch/src/hxtorch/spiking/types.cpp b/hxtorch/src/hxtorch/spiking/types.cpp
+index aaf670f..0e7df7d 100644
+--- a/hxtorch/src/hxtorch/spiking/types.cpp
++++ b/hxtorch/src/hxtorch/spiking/types.cpp
+@@ -2,7 +2,6 @@
+ #include "grenade/vx/common/time.h"
+ #include "hxtorch/spiking/detail/to_dense.h"
+ #include <ATen/Functions.h>
+-#include <ATen/SparseTensorUtils.h>
+ #include <log4cxx/logger.h>
+ 
+ namespace hxtorch::spiking {
-- 
GitLab