diff --git a/sinabs/layers/lif.py b/sinabs/layers/lif.py index 711d5c6f..3275e6e4 100644 --- a/sinabs/layers/lif.py +++ b/sinabs/layers/lif.py @@ -93,18 +93,22 @@ def __init__( if tau_syn is not None else None ) - self.spike_threshold = ( - torch.tensor(spike_threshold) - if type(spike_threshold) == float - else spike_threshold - ) self.spike_fn = spike_fn self.reset_fn = reset_fn self.surrogate_grad_fn = surrogate_grad_fn - self.min_v_mem = min_v_mem self.train_alphas = train_alphas self.norm_input = norm_input self.record_states = record_states + self.min_v_mem = ( + nn.Parameter(torch.as_tensor(min_v_mem), requires_grad=False) + if min_v_mem + else None + ) + self.spike_threshold = ( + nn.Parameter(torch.as_tensor(spike_threshold), requires_grad=False) + if spike_threshold + else None + ) if shape: self.init_state_with_shape(shape) diff --git a/tests/test_iaf.py b/tests/test_iaf.py index eb1cf9b1..f27af198 100644 --- a/tests/test_iaf.py +++ b/tests/test_iaf.py @@ -10,9 +10,11 @@ def test_iaf_basic(): batch_size, time_steps = 10, 100 input_current = torch.rand(batch_size, time_steps, 2, 7, 7) - layer = IAF() + layer = IAF(min_v_mem=-2.0) spike_output = layer(input_current) + assert "min_v_mem" in layer.state_dict().keys() + assert "spike_threshold" in layer.state_dict().keys() assert layer.does_spike assert input_current.shape == spike_output.shape assert torch.isnan(spike_output).sum() == 0 diff --git a/tests/test_lif.py b/tests/test_lif.py index 15c6a6bb..74c23003 100644 --- a/tests/test_lif.py +++ b/tests/test_lif.py @@ -15,7 +15,7 @@ def test_lif_basic(): tau_mem = torch.tensor(20.0) alpha = torch.exp(-1 / tau_mem) input_current = torch.rand(batch_size, time_steps, 2, 7, 7) / (1 - alpha) - layer = LIF(tau_mem=tau_mem) + layer = LIF(tau_mem=tau_mem, min_v_mem=-2.0) spike_output = layer(input_current) # Make sure __repr__ works @@ -24,6 +24,8 @@ def test_lif_basic(): # Make sure arg_dict works layer.arg_dict + assert "min_v_mem" in layer.state_dict().keys() + assert "spike_threshold" in layer.state_dict().keys() assert layer.does_spike assert input_current.shape == spike_output.shape assert torch.isnan(spike_output).sum() == 0