Commit d77b2962 authored by Elias Arnold, committed by Elias Arnold

feat: Add refractory LIF integration and LIF integration tests

Change-Id: I6891e85b348651d8402b685c0b1cb8af81c0a0e3
parent c2327fae
from hxtorch.spiking.functional.lif import (
    CalibratedCUBALIFParams, CUBALIFParams, cuba_lif_integration,
    cuba_refractory_lif_integration)
from hxtorch.spiking.functional.li import (
    CalibratedCUBALIParams, CUBALIParams, cuba_li_integration)
from hxtorch.spiking.functional.iaf import (
...
@@ -161,7 +161,9 @@ def cuba_refractory_iaf_integration(input: torch.Tensor, params: NamedTuple,
            params, dt)
        # Refractory update
        z, v, ref_state = refractory_update(
            z, v, z_hw[ts] if z_hw is not None else None,
            v_cadc[ts] if v_cadc is not None else None, ref_state, params, dt)
        # Save data
        spikes.append(z)
...
@@ -8,6 +8,7 @@ import torch
from hxtorch.spiking.calibrated_params import CalibratedParams
from hxtorch.spiking.functional.threshold import threshold as spiking_threshold
from hxtorch.spiking.functional.unterjubel import Unterjubel
from hxtorch.spiking.functional.refractory import refractory_update
class CUBALIFParams(NamedTuple):
@@ -31,13 +32,57 @@ class CalibratedCUBALIFParams(CalibratedParams):
    method: str = "superspike"
# Allow redefining builtin for PyTorch consistency
# pylint: disable=redefined-builtin, invalid-name, too-many-arguments
def cuba_lif_step(
        z: torch.Tensor, v: torch.Tensor, i: torch.Tensor,
        input: torch.Tensor, z_hw: torch.Tensor, v_hw: torch.Tensor,
        params: Union[CalibratedCUBALIFParams, CUBALIFParams],
        dt: float = 1e-6) -> Tuple[torch.Tensor, ...]:
"""
Integrate the membrane of a neurons one time step further according to the
Leaky-integrate and fire dynamics.
:param z: The spike tensor at time step t.
:param v: The membrane tensor at time step t.
:param i: The current tensor at time step t.
:param input: The input tensor at time step t (graded spikes).
:param z_hw: The hardware spikes corresponding to the current time step. In
case this is None, no HW spikes will be injected.
:param v_hw: The hardware cadc traces corresponding to the current time
step. In case this is None, no HW cadc values will be injected.
:param params: Parameter object holding the LIF parameters.
:returns: Returns a tuple (z, v, i) holding the tensors of time step t + 1.
"""
    # Membrane increment
    dv = dt / params.tau_mem * (params.leak - v + i)
    # Current
    i = i * (1 - dt / params.tau_syn) + input
    # Apply integration step
    v = Unterjubel.apply(dv + v, v_hw) if v_hw is not None else dv + v
    # Spikes
    spike = spiking_threshold(
        v - params.threshold, params.method, params.alpha)
    z = Unterjubel.apply(spike, z_hw) if z_hw is not None else spike
    # Reset
    if z_hw is None:
        v = (1 - z.detach()) * v + z.detach() * params.reset

    return z, v, i
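
For orientation, here is a minimal, self-contained sketch of the same single-step update in plain PyTorch, without hardware injection and with a hard Heaviside threshold in place of hxtorch's surrogate-gradient spiking_threshold. SimpleLIFParams and simple_lif_step are hypothetical stand-ins for illustration only; the parameter values are taken from the unit tests further down.

from typing import NamedTuple
import torch

class SimpleLIFParams(NamedTuple):
    # Hypothetical stand-in for CUBALIFParams; leak default is an assumption
    tau_mem: float = 6e-6
    tau_syn: float = 6e-6
    leak: float = 0.
    threshold: float = 0.7
    reset: float = -0.1

def simple_lif_step(z, v, i, inp, p, dt=1e-6):
    # Membrane increment, computed from the previous current (as above)
    dv = dt / p.tau_mem * (p.leak - v + i)
    # Exponential synapse: decay and accumulate the graded input
    i = i * (1 - dt / p.tau_syn) + inp
    v = v + dv
    # Hard threshold; the real code uses a surrogate gradient here
    z = (v > p.threshold).to(v.dtype)
    # Reset spiking neurons
    v = (1 - z) * v + z * p.reset
    return z, v, i

p = SimpleLIFParams()
z, v, i = torch.zeros(3), torch.zeros(3), torch.zeros(3)
z, v, i = simple_lif_step(z, v, i, torch.tensor([0., 0.5, 2.]), p)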
# Allow redefining builtin for PyTorch consistency
# pylint: disable=redefined-builtin, invalid-name, too-many-locals
def cuba_lif_integration(
        input: torch.Tensor,
        params: Union[CalibratedCUBALIFParams, CUBALIFParams],
        hw_data: Optional[torch.Tensor] = None, dt: float = 1e-6) \
        -> Tuple[torch.Tensor, ...]:
""" """
Leaky-integrate and fire neuron integration for realization of simple Leaky-integrate and fire neuron integration for realization of simple
spiking neurons with exponential synapses. spiking neurons with exponential synapses.
@@ -72,24 +117,79 @@ def cuba_lif_integration(input: torch.Tensor,
            data.shape[0] for data in (z_hw, v_cadc) if data is not None))
    current, spikes, membrane = [], [], []
    # Integrate
    for ts in range(T):
        z, v, i = cuba_lif_step(
            z, v, i, input[ts],
            z_hw[ts] if z_hw is not None else None,
            v_cadc[ts] if v_cadc is not None else None,
            params, dt)

        # Save data
        current.append(i)
        spikes.append(z)
        membrane.append(v)

    return (
        torch.stack(spikes), torch.stack(membrane), torch.stack(current),
        v_madc)
# Allow redefining builtin for PyTorch consistency
# pylint: disable=redefined-builtin, invalid-name, too-many-locals
def cuba_refractory_lif_integration(
        input: torch.Tensor,
        params: Union[CalibratedCUBALIFParams, CUBALIFParams],
        hw_data: Optional[torch.Tensor] = None, dt: float = 1e-6) \
        -> Tuple[torch.Tensor, ...]:
"""
Leaky-integrate and fire neuron integration for realization of simple
spiking neurons with exponential synapses and refractory period.
Integrates according to:
i^{t+1} = i^t * (1 - dt / \tau_{syn}) + x^t
v^{t+1} = dt / \tau_{men} * (v_l - v^t + i^{t+1}) + v^t
z^{t+1} = 1 if v^{t+1} > params.v_th
v^{t+1} = params.v_reset if z^{t+1} == 1 or ref^{t+1} > 0
ref^{t+1} = params.tau_ref
ref^{t+1} -= 1
Assumes i^0, v^0 = 0.
:param input: Input spikes in shape (batch, time, neurons).
:param params: LIFParams object holding neuron parameters.
:return: Returns the spike trains in shape and membrane trace as a tuple.
Both tensors are of shape (batch, time, neurons).
"""
    dev = input.device
    T, bs, ps = input.shape
    z, i, v = torch.zeros(bs, ps).to(dev), torch.tensor(0.).to(dev), \
        torch.empty(bs, ps).fill_(params.leak).to(dev)

    z_hw, v_cadc, v_madc = None, None, None
    if hw_data:
        z_hw, v_cadc, v_madc = (
            data.to(dev) if data is not None else None for data in hw_data)
        T = min(T, *(
            data.shape[0] for data in (z_hw, v_cadc) if data is not None))
    current, spikes, membrane = [], [], []

    # Counter for neurons in refractory period
    ref_state = torch.zeros(ps, dtype=int, device=dev)
    for ts in range(T):
        z, v, i = cuba_lif_step(
            z, v, i, input[ts],
            z_hw[ts] if z_hw is not None else None,
            v_cadc[ts] if v_cadc is not None else None,
            params, dt)

        # Refractory update
        z, v, ref_state = refractory_update(
            z, v, z_hw[ts] if z_hw is not None else None,
            v_cadc[ts] if v_cadc is not None else None, ref_state, params, dt)

        # Save data
        current.append(i)
...
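
The hardware injection used throughout this diff relies on Unterjubel, which, as used here, replaces the simulated value with the hardware observable in the forward pass while routing gradients to the simulated value in the backward pass. Below is a minimal sketch of that straight-through pattern; InjectHW is a hypothetical illustration, not hxtorch's actual implementation.

import torch

class InjectHW(torch.autograd.Function):
    # Forward: expose the hardware observable instead of the simulated value
    @staticmethod
    def forward(ctx, simulated, hardware):
        return hardware

    # Backward: pass the gradient through to the simulated value so the
    # surrounding dynamics stay differentiable; the hardware observable
    # receives no gradient
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

sim = torch.zeros(4, requires_grad=True)
hw = torch.ones(4)
out = InjectHW.apply(sim, hw)  # forward returns hw
out.sum().backward()           # gradient flows to sim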
@@ -4,11 +4,15 @@ Refractory update for neurons with refractory behaviour
from typing import Tuple, NamedTuple
import torch

from hxtorch.spiking.functional.unterjubel import Unterjubel


# pylint: disable=invalid-name, too-many-arguments
def refractory_update(z: torch.Tensor, v: torch.Tensor,
                      z_hw: torch.Tensor, v_hw: torch.Tensor,
                      ref_state: torch.Tensor,
                      params: NamedTuple, dt: float = 1e-6) \
        -> Tuple[torch.Tensor, ...]:
""" """
Update neuron membrane and spikes to account for refractory period. Update neuron membrane and spikes to account for refractory period.
This implemention is widly adopted from: This implemention is widly adopted from:
@@ -17,15 +21,23 @@ def refractory_update(z: torch.Tensor, v: torch.Tensor,
    :param v: The membrane tensor at time step t.
    :param ref_state: The refractory state holding the number of time steps
        the neurons have to remain in the refractory period.
    :param z_hw: The hardware spikes corresponding to the current time step.
        In case this is None, no HW spikes will be injected.
    :param v_hw: The hardware CADC traces corresponding to the current time
        step. In case this is None, no HW CADC values will be injected.
    :param params: Parameter object holding the LIF parameters.
    :returns: Returns a tuple (z, v, ref_state) holding the tensors of time
        step t.
    """
    # Refractory mask
    ref_mask = (ref_state > 0).to(v.dtype)
    # Update neuron states
    v = (1 - ref_mask) * v + ref_mask * params.reset
    # Inject HW membrane potential
    v = Unterjubel.apply(v, v_hw) if v_hw is not None else v
    # Inject HW spikes
    z = (1 - ref_mask) * z
    z = Unterjubel.apply(z, z_hw) if z_hw is not None else z
    # Update refractory state
    ref_state = (1 - z) * torch.nn.functional.relu(ref_state - ref_mask) \
        + z * params.refractory_time / dt
...
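
A small worked example of the counter update above: after a spike the counter is reloaded with refractory_time / dt steps, while already-refractory neurons count down and stay clamped to the reset potential. The values here are hypothetical (the tests below use refractory_time=1e-6).

import torch

dt, refractory_time = 1e-6, 2e-6        # refractory period of 2 steps
z = torch.tensor([1., 0.])              # neuron 0 spikes, neuron 1 does not
ref_state = torch.tensor([0., 3.])      # neuron 1 is mid-refractory
ref_mask = (ref_state > 0).to(z.dtype)
ref_state = (1 - z) * torch.nn.functional.relu(ref_state - ref_mask) \
    + z * refractory_time / dt
print(ref_state)  # tensor([2., 2.]): neuron 0 reloaded, neuron 1 decremented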
@@ -185,7 +185,7 @@ class TestIAFIntegration(unittest.TestCase):
        params = CUBAIAFParams(
            tau_mem=6e-6,
            tau_syn=6e-6,
            refractory_time=1e-6,
            threshold=1.,
            reset=-0.1)
...
@@ -8,7 +8,8 @@ import numpy as np
import torch
import matplotlib.pyplot as plt

from hxtorch.spiking.functional import (
    CUBALIFParams, cuba_lif_integration, cuba_refractory_lif_integration)


class TestLIFIntegration(unittest.TestCase):
@@ -58,7 +59,7 @@ class TestLIFIntegration(unittest.TestCase):
        loss.backward()

        # plot
        _, ax = plt.subplots()
        ax.plot(
            np.arange(0., 1e-6 * 100, 1e-6), membrane[:, 0].detach().numpy())
        ax.plot(
@@ -72,8 +73,8 @@ class TestLIFIntegration(unittest.TestCase):
        params = CUBALIFParams(
            tau_mem=6e-6,
            tau_syn=6e-6,
            refractory_time=1e-6,
            threshold=0.7,
            reset=-0.1)

        # Inputs
@@ -123,13 +124,123 @@ class TestLIFIntegration(unittest.TestCase):
        loss.backward()

        # plot
        _, ax = plt.subplots()
        ax.plot(
            np.arange(0., 1e-6 * 100, 1e-6), membrane[:, 0].detach().numpy())
        ax.plot(
            np.arange(0., 1e-6 * 100, 1e-6), current[:, 0].detach().numpy())
        plt.savefig(self.plot_path.joinpath("./cuba_lif_dynamics_hw.png"))
    def test_refractory_lif_integration(self):
        """ Test refractory LIF integration """
        # Params
        params = CUBALIFParams(
            tau_mem=6e-6,
            tau_syn=6e-6,
            refractory_time=1e-6,
            threshold=0.7,
            reset=-0.1)

        # Inputs
        inputs = torch.zeros(100, 10, 5)
        inputs[10, :, 0] = 1
        inputs[30, :, 1] = 1
        inputs[40, :, 2] = 1
        inputs[53, :, 3] = 1

        weight = torch.nn.parameter.Parameter(torch.randn(15, 5))
        graded_spikes = torch.nn.functional.linear(inputs, weight)

        spikes, membrane, current, v_madc = cuba_refractory_lif_integration(
            graded_spikes, params, dt=1e-6)

        # Shapes
        self.assertTrue(
            torch.equal(
                torch.tensor([100, 10, 15]), torch.tensor(spikes.shape)))
        self.assertTrue(
            torch.equal(
                torch.tensor([100, 10, 15]), torch.tensor(membrane.shape)))
        self.assertTrue(
            torch.equal(
                torch.tensor([100, 10, 15]), torch.tensor(current.shape)))
        self.assertIsNone(v_madc)

        # No error
        loss = spikes.sum()
        loss.backward()

        # plot
        _, ax = plt.subplots()
        ax.plot(
            np.arange(0., 1e-6 * 100, 1e-6), membrane[:, 0].detach().numpy())
        plt.savefig(
            self.plot_path.joinpath("./cuba_refractory_lif_dynamic.png"))
    def test_refractory_lif_integration_hw_data(self):
        """ Test refractory LIF integration with hardware data """
        # Params
        params = CUBALIFParams(
            tau_mem=6e-6,
            tau_syn=6e-6,
            refractory_time=1e-6,
            threshold=0.7,
            reset=-0.1)

        # Inputs
        inputs = torch.zeros(100, 10, 5)
        inputs[10, :, 0] = 1
        inputs[30, :, 1] = 1
        inputs[40, :, 2] = 1
        inputs[53, :, 3] = 1

        weight = torch.nn.parameter.Parameter(torch.randn(15, 5))
        graded_spikes = torch.nn.functional.linear(inputs, weight)

        spikes, membrane, current, v_madc = cuba_refractory_lif_integration(
            graded_spikes, params)

        # Add jitter
        membrane += torch.rand(membrane.shape) * 0.05
        spikes[
            torch.randint(100, (1,)), torch.randint(10, (1,)),
            torch.randint(15, (1,))] = 1

        # Inject
        graded_spikes = torch.nn.functional.linear(inputs, weight)
        spikes_hw, membrane_hw, current_hw, v_madc_hw = \
            cuba_refractory_lif_integration(
                graded_spikes, params, hw_data=(spikes, membrane, membrane))

        # Shapes
        self.assertTrue(
            torch.equal(
                torch.tensor([100, 10, 15]), torch.tensor(spikes.shape)))
        self.assertTrue(
            torch.equal(
                torch.tensor([100, 10, 15]), torch.tensor(membrane.shape)))
        self.assertTrue(
            torch.equal(
                torch.tensor([100, 10, 15]), torch.tensor(current.shape)))
        self.assertTrue(
            torch.equal(
                torch.tensor([100, 10, 15]), torch.tensor(v_madc_hw.shape)))

        # Check HW data is still the same
        self.assertTrue(torch.equal(spikes_hw, spikes))
        self.assertTrue(torch.equal(membrane_hw, membrane))
        self.assertTrue(torch.equal(v_madc_hw, membrane))
        self.assertTrue(torch.equal(current_hw, current))

        # No error
        loss = spikes.sum()
        loss.backward()

        # plot
        _, ax = plt.subplots()
        ax.plot(
            np.arange(0., 1e-6 * 100, 1e-6), membrane[:, 0].detach().numpy())
        plt.savefig(
            self.plot_path.joinpath("./cuba_refractory_lif_dynamic_hw.png"))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()