diff --git a/sinabs/synopcounter.py b/sinabs/synopcounter.py
index 884b3ca8..3583a21a 100644
--- a/sinabs/synopcounter.py
+++ b/sinabs/synopcounter.py
@@ -118,7 +118,7 @@ def _setup_hooks(self):
                 unflattened_shape = (layer.batch_size, layer.num_timesteps)
 
         for layer in self.model.modules():
-            if isinstance(layer, sl.StatefulLayer):
+            if isinstance(layer, sl.StatefulLayer) and layer.does_spike:
                 layer.acc_output = torch.tensor(0)
                 layer.n_batches = 0
                 handle = layer.register_forward_hook(spiking_hook)
diff --git a/tests/test_synops_counter.py b/tests/test_synops_counter.py
index ea11eeb7..7e59b688 100644
--- a/tests/test_synops_counter.py
+++ b/tests/test_synops_counter.py
@@ -182,6 +182,26 @@ def test_spiking_layer_firing_rate():
     assert layer_stats["firing_rate_per_neuron"].mean() == 0.25
 
 
+def test_nonspiking_stateful_layer():
+    model = nn.Sequential(sl.IAF(), sl.ExpLeak(tau_mem=10))
+    input_ = torch.eye(4).unsqueeze(0).unsqueeze(0)
+
+    analyzer = sinabs.SNNAnalyzer(model)
+    output = model(input_)
+    model_stats = analyzer.get_model_statistics(average=True)
+    assert model_stats["firing_rate"] == 0.25
+
+    layer_stats = analyzer.get_layer_statistics(average=True)
+    # ExpLeak layer should not show up in spiking or parameter stats
+    assert "1" not in layer_stats["spiking"]
+    assert "1" not in layer_stats["parameter"]
+
+    spiking_layer_stats = layer_stats["spiking"]["0"]
+    assert spiking_layer_stats["firing_rate"] == 0.25
+    assert spiking_layer_stats["firing_rate_per_neuron"].shape == (4, 4)
+    assert spiking_layer_stats["firing_rate_per_neuron"].mean() == 0.25
+
+
 def test_spiking_layer_firing_rate_across_batches():
     layer = sl.IAF()
     input1 = torch.eye(4).unsqueeze(0).unsqueeze(0)