From 923deb59241db8bdb47acddf9eb49c50fb0eb5e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eric=20M=C3=BCller?= <mueller@kip.uni-heidelberg.de> Date: Mon, 18 Mar 2024 10:22:39 +0100 Subject: [PATCH] feat(BSS2): add jaxsnn (event-based ML) --- packages/jaxsnn/package.py | 157 +++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 packages/jaxsnn/package.py diff --git a/packages/jaxsnn/package.py b/packages/jaxsnn/package.py new file mode 100644 index 00000000..d1d7f5eb --- /dev/null +++ b/packages/jaxsnn/package.py @@ -0,0 +1,157 @@ +# Copyright 2013-2022 Lawrence Livermore National Security, LLC and other +# Spack Project Developers. See the top-level COPYRIGHT file for details. +# +# SPDX-License-Identifier: (Apache-2.0 OR MIT) +import os +import unittest.mock +import xml.etree.ElementTree as ET + +from spack import * +from spack.util.environment import EnvironmentModifications +import spack.build_environment + + +class Jaxsnn(WafPackage): + """jaxsnn is an event-based approach to machine-learning-inspired training + and simulation of SNNs, including support for the BrainScaleS-2 + neuromorphic backend.""" + + homepage = "https://github.com/electronicvisions/jaxsnn" + # This repo provides a custom waf binary used for the build below + git = "https://github.com/electronicvisions/pynn-brainscales.git" + + maintainers = ['emuller'] + + version('8.0-a3', tag='jaxsnn-8.0-a3') + version('8.0-a2', tag='jaxsnn-8.0-a2') + version('8.0-a1', tag='jaxsnn-8.0-a1') + + # for now, this is still "on top" of hxtorch… + depends_on('hxtorch@8.0-a3', when='@8.0-a3', type=('build', 'link', 'run', 'test')) + depends_on('hxtorch@8.0-a2', when='@8.0-a2', type=('build', 'link', 'run', 'test')) + depends_on('hxtorch@8.0-a1', when='@8.0-a1', type=('build', 'link', 'run', 'test')) + + # main dependencies w/o hxtorch.core dependencies (those come via hxtorch above) + depends_on('py-jax@0.4.13:', type=('build', 'link', 'run')) + depends_on('py-matplotlib', type=('build', 'link', 'run')) + depends_on('py-optax', type=('build', 'link', 'run')) + depends_on('py-tree-math', type=('build', 'link', 'run')) + extends('python') + + def do_fetch(self, mirror_only=False): + """Setup the project.""" + + self.stage.create() + self.stage.fetch(mirror_only) + + # if fetcher didn't do anything, it's cached already + if not os.path.exists(self.stage.source_path): + return + + with working_dir(self.stage.source_path): + python = which('python3') + python('./waf', 'setup', '--repo-db-url=https://github.com/electronicvisions/projects', + '--clone-depth=2', + '--without-munge', + '--without-hxcomm-hostarq', + '--project=jaxsnn', + '--release-branch=ebrains-' + str(self.spec.version) + ) + + # in the configure step, we need access to all archived .git folders + def custom_archive(self, destination): + super(spack.fetch_strategy.GitFetchStrategy, self).archive(destination) + with unittest.mock.patch('spack.fetch_strategy.GitFetchStrategy.archive', new=custom_archive): + self.stage.cache_local() + + def _setup_common_env(self, env): + # grenade needs to find some libraries for the JIT-compilation of + # programs for BrainScaleS-2's embedded processor. + ppu_include_dirs = [] + ppu_dep_names = ['bitsery', 'boost', 'cereal'] + for ppu_dep_name in ppu_dep_names: + dep = self.spec[ppu_dep_name] + dep_include_dirs = set(dep.headers.directories) + ppu_include_dirs.extend(list(dep_include_dirs)) + for dir in reversed(ppu_include_dirs): + env.prepend_path("C_INCLUDE_PATH", dir) + env.prepend_path("CPLUS_INCLUDE_PATH", dir) + + def setup_build_environment(self, env): + my_envmod = EnvironmentModifications(env) + spack.build_environment.set_wrapper_variables(self, my_envmod) + my_env = {} + my_envmod.apply_modifications(my_env) + + def get_path(env, name): + path = env.get(name, "").strip() + if path: + return path.split(os.pathsep) + return [] + + # spack tries to find headers and libraries by itself (i.e. it's not + # relying on the compiler to find it); we explicitly expose the + # spack-provided env vars that contain include and library paths + if 'SPACK_INCLUDE_DIRS' in my_env: + for dir in reversed(get_path(my_env, "SPACK_INCLUDE_DIRS")): + env.prepend_path("C_INCLUDE_PATH", dir) + env.prepend_path("CPLUS_INCLUDE_PATH", dir) + if 'SPACK_LINK_DIRS' in my_env: + for dir in reversed(get_path(my_env, "SPACK_LINK_DIRS")): + env.prepend_path("LIBRARY_PATH", dir) + env.prepend_path("LD_LIBRARY_PATH", dir) + for dir in reversed(self.compiler.implicit_rpaths()): + env.prepend_path("LIBRARY_PATH", dir) + # technically this is probably not needed for the non-configure steps + env.prepend_path("LD_LIBRARY_PATH", dir) + + def setup_dependent_build_environment(self, env, dependent_spec): + self._setup_common_env(env) + + def setup_run_environment(self, env): + self._setup_common_env(env) + + def setup_dependent_run_environment(self, env, dependent_spec): + self._setup_common_env(env) + + # override configure step as we perform a project setup first + def configure(self, spec, prefix): + """Configure the project.""" + + args = ['--prefix={0}'.format(self.prefix)] + args += self.configure_args() + self.waf('configure', '--build-profile=release', '--disable-doxygen', *args) + + def build_args(self): + args = ['--keep', '--test-execnone', '-v'] + return args + + def build_test(self): + self.builder.waf('build', '--test-execall') + copy_tree('build/test_results', join_path(self.prefix, '.build')) + copy_tree('build/test_results', join_path(self.stage.path, ".install_time_tests")) + # propagate failures from junit output to spack + tree = ET.parse('build/test_results/summary.xml') + for testsuite in tree.getroot(): + for testcase in testsuite: + if (testcase.get('name').startswith("pycodestyle") or + testcase.get('name').startswith("pylint")): + continue + for elem in testcase: + if (elem.tag == 'failure') and not ( + elem.get('message').startswith("pylint:") or + elem.get('message').startswith("pycodestyle:") or + ("OK" in elem.get('message') and "Segmentation fault" in elem.get('message'))): + raise RuntimeError("Failed test found: {}".format(testcase.get('name'))) + + def install_args(self): + args = ['--test-execnone'] + return args + + def install_test(self): + with working_dir('spack-test', create=True): + old_pythonpath = os.environ.get('PYTHONPATH', '') + os.environ['PYTHONPATH'] = ':'.join([str(self.prefix.lib), old_pythonpath]) + bash = which("bash") + # ignore segfaults for now (exit code 139) + bash('-c', '(python -c "import jaxsnn; print(jaxsnn.__file__)" || ( test $? -eq 139 && echo "segfault")) || exit $?') -- GitLab