Skip to content

Commit

Permalink
Merge pull request #71 from synsense/68-make-spiking-layer-spike_thre…
Browse files Browse the repository at this point in the history
…shold-and-min_v_mem-parameters

make min_vem and spike_thresholds parameters do that they're included in state_dict
  • Loading branch information
biphasic authored Dec 8, 2022
2 parents e6c5506 + 74ee25c commit 8c0023c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
16 changes: 10 additions & 6 deletions sinabs/layers/lif.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,22 @@ def __init__(
if tau_syn is not None
else None
)
self.spike_threshold = (
torch.tensor(spike_threshold)
if type(spike_threshold) == float
else spike_threshold
)
self.spike_fn = spike_fn
self.reset_fn = reset_fn
self.surrogate_grad_fn = surrogate_grad_fn
self.min_v_mem = min_v_mem
self.train_alphas = train_alphas
self.norm_input = norm_input
self.record_states = record_states
self.min_v_mem = (
nn.Parameter(torch.as_tensor(min_v_mem), requires_grad=False)
if min_v_mem
else None
)
self.spike_threshold = (
nn.Parameter(torch.as_tensor(spike_threshold), requires_grad=False)
if spike_threshold
else None
)
if shape:
self.init_state_with_shape(shape)

Expand Down
4 changes: 3 additions & 1 deletion tests/test_iaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
def test_iaf_basic():
batch_size, time_steps = 10, 100
input_current = torch.rand(batch_size, time_steps, 2, 7, 7)
layer = IAF()
layer = IAF(min_v_mem=-2.0)
spike_output = layer(input_current)

assert "min_v_mem" in layer.state_dict().keys()
assert "spike_threshold" in layer.state_dict().keys()
assert layer.does_spike
assert input_current.shape == spike_output.shape
assert torch.isnan(spike_output).sum() == 0
Expand Down
4 changes: 3 additions & 1 deletion tests/test_lif.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_lif_basic():
tau_mem = torch.tensor(20.0)
alpha = torch.exp(-1 / tau_mem)
input_current = torch.rand(batch_size, time_steps, 2, 7, 7) / (1 - alpha)
layer = LIF(tau_mem=tau_mem)
layer = LIF(tau_mem=tau_mem, min_v_mem=-2.0)
spike_output = layer(input_current)

# Make sure __repr__ works
Expand All @@ -24,6 +24,8 @@ def test_lif_basic():
# Make sure arg_dict works
layer.arg_dict

assert "min_v_mem" in layer.state_dict().keys()
assert "spike_threshold" in layer.state_dict().keys()
assert layer.does_spike
assert input_current.shape == spike_output.shape
assert torch.isnan(spike_output).sum() == 0
Expand Down

0 comments on commit 8c0023c

Please sign in to comment.