From 923deb59241db8bdb47acddf9eb49c50fb0eb5e0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eric=20M=C3=BCller?= <mueller@kip.uni-heidelberg.de>
Date: Mon, 18 Mar 2024 10:22:39 +0100
Subject: [PATCH] feat(BSS2): add jaxsnn (event-based ML)

---
 packages/jaxsnn/package.py | 157 +++++++++++++++++++++++++++++++++++++
 1 file changed, 157 insertions(+)
 create mode 100644 packages/jaxsnn/package.py

diff --git a/packages/jaxsnn/package.py b/packages/jaxsnn/package.py
new file mode 100644
index 00000000..d1d7f5eb
--- /dev/null
+++ b/packages/jaxsnn/package.py
@@ -0,0 +1,157 @@
+# Copyright 2013-2022 Lawrence Livermore National Security, LLC and other
+# Spack Project Developers. See the top-level COPYRIGHT file for details.
+#
+# SPDX-License-Identifier: (Apache-2.0 OR MIT)
+import os
+import unittest.mock
+import xml.etree.ElementTree as ET
+
+from spack import *
+from spack.util.environment import EnvironmentModifications
+import spack.build_environment
+
+
+class Jaxsnn(WafPackage):
+    """jaxsnn is an event-based approach to machine-learning-inspired training
+    and simulation of SNNs, including support for the BrainScaleS-2
+    neuromorphic backend."""
+
+    homepage = "https://github.com/electronicvisions/jaxsnn"
+    # This repo provides a custom waf binary used for the build below
+    git      = "https://github.com/electronicvisions/pynn-brainscales.git"
+
+    maintainers = ['emuller']
+
+    version('8.0-a3', tag='jaxsnn-8.0-a3')
+    version('8.0-a2', tag='jaxsnn-8.0-a2')
+    version('8.0-a1', tag='jaxsnn-8.0-a1')
+
+    # for now, this is still "on top" of hxtorch…
+    depends_on('hxtorch@8.0-a3', when='@8.0-a3', type=('build', 'link', 'run', 'test'))
+    depends_on('hxtorch@8.0-a2', when='@8.0-a2', type=('build', 'link', 'run', 'test'))
+    depends_on('hxtorch@8.0-a1', when='@8.0-a1', type=('build', 'link', 'run', 'test'))
+
+    # main dependencies w/o hxtorch.core dependencies (those come via hxtorch above)
+    depends_on('py-jax@0.4.13:', type=('build', 'link', 'run'))
+    depends_on('py-matplotlib', type=('build', 'link', 'run'))
+    depends_on('py-optax', type=('build', 'link', 'run'))
+    depends_on('py-tree-math', type=('build', 'link', 'run'))
+    extends('python')
+
+    def do_fetch(self, mirror_only=False):
+        """Setup the project."""
+
+        self.stage.create()
+        self.stage.fetch(mirror_only)
+
+        # if fetcher didn't do anything, it's cached already
+        if not os.path.exists(self.stage.source_path):
+            return
+
+        with working_dir(self.stage.source_path):
+            python = which('python3')
+            python('./waf', 'setup', '--repo-db-url=https://github.com/electronicvisions/projects',
+                '--clone-depth=2',
+                '--without-munge',
+                '--without-hxcomm-hostarq',
+                '--project=jaxsnn',
+                '--release-branch=ebrains-' + str(self.spec.version)
+            )
+
+        # in the configure step, we need access to all archived .git folders
+        def custom_archive(self, destination):
+            super(spack.fetch_strategy.GitFetchStrategy, self).archive(destination)
+        with unittest.mock.patch('spack.fetch_strategy.GitFetchStrategy.archive', new=custom_archive):
+            self.stage.cache_local()
+
+    def _setup_common_env(self, env):
+        # grenade needs to find some libraries for the JIT-compilation of
+        # programs for BrainScaleS-2's embedded processor.
+        ppu_include_dirs = []
+        ppu_dep_names = ['bitsery', 'boost', 'cereal']
+        for ppu_dep_name in ppu_dep_names:
+            dep = self.spec[ppu_dep_name]
+            dep_include_dirs = set(dep.headers.directories)
+            ppu_include_dirs.extend(list(dep_include_dirs))
+        for dir in reversed(ppu_include_dirs):
+            env.prepend_path("C_INCLUDE_PATH", dir)
+            env.prepend_path("CPLUS_INCLUDE_PATH", dir)
+
+    def setup_build_environment(self, env):
+        my_envmod = EnvironmentModifications(env)
+        spack.build_environment.set_wrapper_variables(self, my_envmod)
+        my_env = {}
+        my_envmod.apply_modifications(my_env)
+
+        def get_path(env, name):
+            path = env.get(name, "").strip()
+            if path:
+                return path.split(os.pathsep)
+            return []
+
+        # spack tries to find headers and libraries by itself (i.e. it's not
+        # relying on the compiler to find it); we explicitly expose the
+        # spack-provided env vars that contain include and library paths
+        if 'SPACK_INCLUDE_DIRS' in my_env:
+            for dir in reversed(get_path(my_env, "SPACK_INCLUDE_DIRS")):
+                env.prepend_path("C_INCLUDE_PATH", dir)
+                env.prepend_path("CPLUS_INCLUDE_PATH", dir)
+        if 'SPACK_LINK_DIRS' in my_env:
+            for dir in reversed(get_path(my_env, "SPACK_LINK_DIRS")):
+                env.prepend_path("LIBRARY_PATH", dir)
+                env.prepend_path("LD_LIBRARY_PATH", dir)
+        for dir in reversed(self.compiler.implicit_rpaths()):
+            env.prepend_path("LIBRARY_PATH", dir)
+            # technically this is probably not needed for the non-configure steps
+            env.prepend_path("LD_LIBRARY_PATH", dir)
+
+    def setup_dependent_build_environment(self, env, dependent_spec):
+        self._setup_common_env(env)
+
+    def setup_run_environment(self, env):
+        self._setup_common_env(env)
+
+    def setup_dependent_run_environment(self, env, dependent_spec):
+        self._setup_common_env(env)
+
+    # override configure step as we perform a project setup first
+    def configure(self, spec, prefix):
+        """Configure the project."""
+
+        args = ['--prefix={0}'.format(self.prefix)]
+        args += self.configure_args()
+        self.waf('configure', '--build-profile=release', '--disable-doxygen', *args)
+
+    def build_args(self):
+        args = ['--keep', '--test-execnone', '-v']
+        return args
+
+    def build_test(self):
+        self.builder.waf('build', '--test-execall')
+        copy_tree('build/test_results', join_path(self.prefix, '.build'))
+        copy_tree('build/test_results', join_path(self.stage.path, ".install_time_tests"))
+        # propagate failures from junit output to spack
+        tree = ET.parse('build/test_results/summary.xml')
+        for testsuite in tree.getroot():
+            for testcase in testsuite:
+                if (testcase.get('name').startswith("pycodestyle") or
+                    testcase.get('name').startswith("pylint")):
+                    continue
+                for elem in testcase:
+                    if (elem.tag == 'failure') and not (
+                            elem.get('message').startswith("pylint:") or
+                            elem.get('message').startswith("pycodestyle:") or
+                            ("OK" in elem.get('message') and "Segmentation fault" in elem.get('message'))):
+                        raise RuntimeError("Failed test found: {}".format(testcase.get('name')))
+
+    def install_args(self):
+        args = ['--test-execnone']
+        return args
+
+    def install_test(self):
+        with working_dir('spack-test', create=True):
+            old_pythonpath = os.environ.get('PYTHONPATH', '')
+            os.environ['PYTHONPATH'] = ':'.join([str(self.prefix.lib), old_pythonpath])
+            bash = which("bash")
+            # ignore segfaults for now (exit code 139)
+            bash('-c', '(python -c "import jaxsnn; print(jaxsnn.__file__)" || ( test $? -eq 139 && echo "segfault")) || exit $?')
-- 
GitLab