From 729c8bab52d051c657427bfb8c83120d68cf4955 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eric=20M=C3=BCller?= <mueller@kip.uni-heidelberg.de> Date: Thu, 2 May 2024 17:06:11 +0200 Subject: [PATCH] fix: build {hxtorch,jaxsnn} on py-torch@2.1: [DO-NOT-MERGE] This is for testing, if it works, we should get this upstream and it will change the version/tag. --- packages/hxtorch/package.py | 2 ++ .../hxtorch/remove-superfluous-pytorch-include.patch | 12 ++++++++++++ packages/jaxsnn/package.py | 2 ++ .../jaxsnn/remove-superfluous-pytorch-include.patch | 12 ++++++++++++ 4 files changed, 28 insertions(+) create mode 100644 packages/hxtorch/remove-superfluous-pytorch-include.patch create mode 100644 packages/jaxsnn/remove-superfluous-pytorch-include.patch diff --git a/packages/hxtorch/package.py b/packages/hxtorch/package.py index 7985472e..101f8271 100644 --- a/packages/hxtorch/package.py +++ b/packages/hxtorch/package.py @@ -33,6 +33,8 @@ class Hxtorch(build_brainscales.BuildBrainscales): version('7.0-rc1-fixup2', tag='hxtorch-7.0-rc1-fixup2') version('7.0-rc1-fixup1', branch='waf') + patch("remove-superfluous-pytorch-include.patch", when="^py-torch@2.1:") + deps_hxtorch_core = [ # compiler for the BrainScaleS-2 embedded processor ("PPU"); needed for # building/linking, at runtime and for testing diff --git a/packages/hxtorch/remove-superfluous-pytorch-include.patch b/packages/hxtorch/remove-superfluous-pytorch-include.patch new file mode 100644 index 00000000..f7a0ad27 --- /dev/null +++ b/packages/hxtorch/remove-superfluous-pytorch-include.patch @@ -0,0 +1,12 @@ +diff --git a/hxtorch/src/hxtorch/spiking/types.cpp b/hxtorch/src/hxtorch/spiking/types.cpp +index aaf670f..0e7df7d 100644 +--- a/hxtorch/src/hxtorch/spiking/types.cpp ++++ b/hxtorch/src/hxtorch/spiking/types.cpp +@@ -2,7 +2,6 @@ + #include "grenade/vx/common/time.h" + #include "hxtorch/spiking/detail/to_dense.h" + #include <ATen/Functions.h> +-#include <ATen/SparseTensorUtils.h> + #include <log4cxx/logger.h> + + namespace hxtorch::spiking { diff --git a/packages/jaxsnn/package.py b/packages/jaxsnn/package.py index c61589f3..6f09542d 100644 --- a/packages/jaxsnn/package.py +++ b/packages/jaxsnn/package.py @@ -32,6 +32,8 @@ class Jaxsnn(build_brainscales.BuildBrainscales): version('8.0-a2', tag='jaxsnn-8.0-a2') version('8.0-a1', tag='jaxsnn-8.0-a1') + patch("remove-superfluous-pytorch-include.patch", when="^py-torch@2.1:") + # dependencies inherited from hxtorch.core for dep, dep_kw in hxtorch.Hxtorch.deps_hxtorch_core: depends_on(dep, **dep_kw) diff --git a/packages/jaxsnn/remove-superfluous-pytorch-include.patch b/packages/jaxsnn/remove-superfluous-pytorch-include.patch new file mode 100644 index 00000000..f7a0ad27 --- /dev/null +++ b/packages/jaxsnn/remove-superfluous-pytorch-include.patch @@ -0,0 +1,12 @@ +diff --git a/hxtorch/src/hxtorch/spiking/types.cpp b/hxtorch/src/hxtorch/spiking/types.cpp +index aaf670f..0e7df7d 100644 +--- a/hxtorch/src/hxtorch/spiking/types.cpp ++++ b/hxtorch/src/hxtorch/spiking/types.cpp +@@ -2,7 +2,6 @@ + #include "grenade/vx/common/time.h" + #include "hxtorch/spiking/detail/to_dense.h" + #include <ATen/Functions.h> +-#include <ATen/SparseTensorUtils.h> + #include <log4cxx/logger.h> + + namespace hxtorch::spiking { -- GitLab