# Copyright 2013-2024 Lawrence Livermore National Security, LLC and other
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)

import os
import unittest.mock
import xml.etree.ElementTree as ET

from spack.package import *
from spack.util.environment import EnvironmentModifications
import spack.build_environment

# The repository namespace "ebrains-spack-builds" contains hyphens, so it
# cannot be written as a dotted name in a plain ``import`` statement; the
# sibling package modules are therefore loaded by their string name.
import importlib
build_brainscales = importlib.import_module("spack.pkg.ebrains-spack-builds.build_brainscales")
hxtorch = importlib.import_module("spack.pkg.ebrains-spack-builds.hxtorch")


class Jaxsnn(build_brainscales.BuildBrainscales):
    """jaxsnn is an event-based approach to machine-learning-inspired training
    and simulation of SNNs, including support for the BrainScaleS-2
    neuromorphic backend."""

    homepage = "https://github.com/electronicvisions/jaxsnn"
    # This repo provides a custom waf binary used for the build below
    git      = "https://github.com/electronicvisions/pynn-brainscales.git"

    maintainers = ["emuller", "muffgaga"]

    # newer versions are defined in the common base package
    version('8.0-a5', tag='jaxsnn-8.0-a5')
    version('8.0-a4', tag='jaxsnn-8.0-a4')
    version('8.0-a3', tag='jaxsnn-8.0-a3')
    version('8.0-a2', tag='jaxsnn-8.0-a2')
    version('8.0-a1', tag='jaxsnn-8.0-a1')

    # dependencies inherited from hxtorch.core; each entry is a
    # (spec, keyword-arguments) pair forwarded verbatim to depends_on()
    for dep, dep_kw in hxtorch.Hxtorch.deps_hxtorch_core:
        depends_on(dep, **dep_kw)

    # main dependencies w/o hxtorch.core dependencies (those come via hxtorch above)
    depends_on('py-jax@0.4.13:', type=('build', 'link', 'run'))
    depends_on('py-matplotlib', type=('build', 'link', 'run'))
    depends_on('py-optax', type=('build', 'link', 'run'))
    depends_on('py-tree-math', type=('build', 'link', 'run'))
    extends('python')

    # only applied up to (and including) 8.0-a5
    patch("include-SparseTensorUtils.patch", when="@:8.0-a5")

    def install_test(self):
        """Smoke-test the installation by importing ``jaxsnn`` in a subprocess.

        The installed ``lib`` directory is prepended to ``PYTHONPATH`` for the
        duration of the check, and ``python -c "import jaxsnn; ..."`` is run
        through ``bash`` inside a ``spack-test`` working directory (created on
        demand).  An exit code of 139 (segfault) is tolerated for now and
        only reported by printing ``segfault``; any other non-zero exit
        status makes the shell command fail.

        The caller's ``PYTHONPATH`` is restored afterwards, whether or not
        the check succeeds, so the modification does not leak into the rest
        of the process.
        """
        inherited_pythonpath = os.environ.get('PYTHONPATH')

        # Skip an unset/empty PYTHONPATH when joining: a trailing separator
        # would add an empty entry, which Python interprets as the current
        # working directory.
        search_path = [str(self.prefix.lib)]
        if inherited_pythonpath:
            search_path.append(inherited_pythonpath)

        with working_dir('spack-test', create=True):
            os.environ['PYTHONPATH'] = os.pathsep.join(search_path)
            try:
                # required=True raises a clear error if bash is missing
                # instead of failing later with "'NoneType' is not callable".
                bash = which("bash", required=True)
                # ignore segfaults for now (exit code 139)
                bash('-c', '(python -c "import jaxsnn; print(jaxsnn.__file__)" || ( test $? -eq 139 && echo "segfault")) || exit $?')
            finally:
                # Put the environment back exactly as it was found.
                if inherited_pythonpath is None:
                    os.environ.pop('PYTHONPATH', None)
                else:
                    os.environ['PYTHONPATH'] = inherited_pythonpath