# Copyright 2013-2024 Lawrence Livermore National Security, LLC and other
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)

import os
import unittest.mock
import xml.etree.ElementTree as ET

from spack.package import *
from spack.util.environment import EnvironmentModifications
import spack.build_environment

import importlib
build_brainscales = importlib.import_module("spack.pkg.ebrains-spack-builds.build_brainscales")
hxtorch = importlib.import_module("spack.pkg.ebrains-spack-builds.hxtorch")


class Jaxsnn(build_brainscales.BuildBrainscales):
    """jaxsnn is an event-based approach to machine-learning-inspired training
    and simulation of SNNs, including support for the BrainScaleS-2
    neuromorphic backend."""

    homepage = "https://github.com/electronicvisions/jaxsnn"
    # This repo provides a custom waf binary used for the build below
    git = "https://github.com/electronicvisions/pynn-brainscales.git"

    maintainers = ["emuller", "muffgaga"]

    # newer versions are defined in the common base package
    version('8.0-a5', tag='jaxsnn-8.0-a5')
    version('8.0-a4', tag='jaxsnn-8.0-a4')
    version('8.0-a3', tag='jaxsnn-8.0-a3')
    version('8.0-a2', tag='jaxsnn-8.0-a2')
    version('8.0-a1', tag='jaxsnn-8.0-a1')

    # dependencies inherited from hxtorch.core
    for dep, dep_kw in hxtorch.Hxtorch.deps_hxtorch_core:
        depends_on(dep, **dep_kw)

    # main dependencies w/o hxtorch.core dependencies (those come via hxtorch above)
    depends_on('py-jax@0.4.13:', type=('build', 'link', 'run'))
    depends_on('py-matplotlib', type=('build', 'link', 'run'))
    depends_on('py-optax', type=('build', 'link', 'run'))
    depends_on('py-tree-math', type=('build', 'link', 'run'))
    extends('python')

    patch("include-SparseTensorUtils.patch", when="@:8.0-a5")

    def install_test(self):
        """Smoke-test the installed package by importing ``jaxsnn``.

        Runs ``python -c "import jaxsnn; ..."`` via bash inside a scratch
        ``spack-test`` directory. ``self.prefix.lib`` is prepended to
        ``PYTHONPATH`` for the call. The previous ``PYTHONPATH`` (or its
        absence) is restored afterwards, even if the call raises.

        A segmentation fault of the interpreter (exit status 139) is currently
        tolerated. Any other non-zero status is propagated unchanged as the
        exit status of the bash call.
        """
        with working_dir('spack-test', create=True):
            old_pythonpath = os.environ.get('PYTHONPATH')
            # Only append a non-empty previous value: an empty PYTHONPATH entry
            # is interpreted by Python as the current working directory.
            entries = [str(self.prefix.lib)]
            if old_pythonpath:
                entries.append(old_pythonpath)
            os.environ['PYTHONPATH'] = os.pathsep.join(entries)
            # ignore segfaults for now (exit code 139 == 128 + SIGSEGV);
            # otherwise exit with python's own status rather than a generic 1
            script = (
                'python -c "import jaxsnn; print(jaxsnn.__file__)"; '
                'rc=$?; '
                'if [ "$rc" -eq 139 ]; then echo "segfault"; exit 0; fi; '
                'exit "$rc"'
            )
            try:
                bash = which("bash")
                bash('-c', script)
            finally:
                # Undo the environment change so it does not leak into later
                # steps running in this process.
                if old_pythonpath is None:
                    os.environ.pop('PYTHONPATH', None)
                else:
                    os.environ['PYTHONPATH'] = old_pythonpath