From de57af473fa6db7ff53ed206e2a703416baacd19 Mon Sep 17 00:00:00 2001 From: Gregor Lenz Date: Fri, 18 Nov 2022 11:58:59 +0100 Subject: [PATCH] autoformat all docstrings using docformatter --- .pre-commit-config.yaml | 6 ++++ sinabs/activation/quantize.py | 11 ++++--- sinabs/activation/reset_mechanism.py | 8 ++--- sinabs/activation/spike_generation.py | 27 +++++++--------- sinabs/activation/surrogate_gradient_fn.py | 15 +++------ sinabs/cnnutils.py | 15 +++------ sinabs/conversion.py | 9 ++---- sinabs/from_torch.py | 7 ++-- sinabs/layers/alif.py | 17 ++++------ sinabs/layers/crop2d.py | 6 ++-- sinabs/layers/exp_leak.py | 13 +++----- sinabs/layers/iaf.py | 17 ++++------ sinabs/layers/lif.py | 14 ++++---- sinabs/layers/neuromorphic_relu.py | 6 ++-- sinabs/layers/pool2d.py | 11 +++---- sinabs/layers/quantize.py | 3 +- sinabs/layers/reshape.py | 34 +++++++++----------- sinabs/layers/stateful_layer.py | 37 ++++++---------------- sinabs/layers/to_spike.py | 6 ++-- sinabs/network.py | 29 ++++++----------- sinabs/onnx/get_graph.py | 3 +- sinabs/synopcounter.py | 24 ++++++-------- sinabs/utils.py | 17 ++++------ tests/test_conversion.py | 7 ++-- tests/test_from_model.py | 10 ++---- tests/test_network_class.py | 6 ++-- 26 files changed, 135 insertions(+), 223 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 093caeac..04dd799e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,3 +10,9 @@ repos: hooks: - id: isort args: ["--profile", "black", "--filter-files"] + + - repo: https://github.com/myint/docformatter + rev: v1.4 + hooks: + - id: docformatter + args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99] \ No newline at end of file diff --git a/sinabs/activation/quantize.py b/sinabs/activation/quantize.py index 576174a1..f4d75c27 100644 --- a/sinabs/activation/quantize.py +++ b/sinabs/activation/quantize.py @@ -2,9 +2,8 @@ class Quantize(torch.autograd.Function): - """PyTorch-compatible function that applies a floor() operation on the input, - while providing a surrogate gradient (equivalent to that of a linear - function) in the backward pass.""" + """PyTorch-compatible function that applies a floor() operation on the input, while providing a + surrogate gradient (equivalent to that of a linear function) in the backward pass.""" @staticmethod def forward(ctx, inp): @@ -19,10 +18,12 @@ def backward(ctx, grad_output): class StochasticRounding(torch.autograd.Function): - """PyTorch-compatible function that applies stochastic rounding. The input x + """PyTorch-compatible function that applies stochastic rounding. The input x. + is quantized to ceil(x) with probability (x - floor(x)), and to floor(x) otherwise. The backward pass is provided as a surrogate gradient - (equivalent to that of a linear function).""" + (equivalent to that of a linear function). + """ @staticmethod def forward(ctx, inp): diff --git a/sinabs/activation/reset_mechanism.py b/sinabs/activation/reset_mechanism.py index ab58a95e..05aa18de 100644 --- a/sinabs/activation/reset_mechanism.py +++ b/sinabs/activation/reset_mechanism.py @@ -4,9 +4,7 @@ @dataclass class MembraneReset: - """ - Reset the membrane potential v_mem to a given value - after it spiked. + """Reset the membrane potential v_mem to a given value after it spiked. Parameters: reset_value: fixed value that a neuron should be reset to. Defaults to zero. @@ -27,9 +25,7 @@ def __call__(self, spikes, state, spike_threshold): @dataclass class MembraneSubtract: - """ - Subtract the spiking threshold from the membrane potential - for every neuron that spiked. + """Subtract the spiking threshold from the membrane potential for every neuron that spiked. Parameters: subtract_value: optional value that will be subtraced from diff --git a/sinabs/activation/spike_generation.py b/sinabs/activation/spike_generation.py index d84baf28..96e12959 100644 --- a/sinabs/activation/spike_generation.py +++ b/sinabs/activation/spike_generation.py @@ -15,10 +15,9 @@ def backward(ctx, grad_output: torch.tensor): class MultiSpike(BackwardClass, torch.autograd.Function): - """ - Autograd function that returns membrane potential integer-divided by spike threshold. - Do not instantiate this class when passing as spike_fn (see example). - Can be combined with different surrogate gradient functions. + """Autograd function that returns membrane potential integer-divided by spike threshold. Do not + instantiate this class when passing as spike_fn (see example). Can be combined with different + surrogate gradient functions. Example: >>> layer = sinabs.layers.LIF(spike_fn=MultiSpike, ...) @@ -44,9 +43,9 @@ def forward( class MaxSpikeInner(BackwardClass, torch.autograd.Function): - """ - Autograd function that returns membrane potential divided by - spike threshold for a maximum number of spikes per time step. + """Autograd function that returns membrane potential divided by spike threshold for a maximum + number of spikes per time step. + Can be combined with different surrogate gradient functions. """ @@ -74,10 +73,8 @@ def forward( @dataclass class MaxSpike: - """ - Wrapper for MaxSpikeInner autograd function. This class needs to - be instantiated when used as spike_fn. Notice the difference in example - to Single/MultiSpike. + """Wrapper for MaxSpikeInner autograd function. This class needs to be instantiated when used + as spike_fn. Notice the difference in example to Single/MultiSpike. Example: >>> layer = sinabs.layers.LIF(spike_fn=MaxSpike(max_num_spikes_per_bin=10), ...) @@ -101,11 +98,9 @@ def required_states(self): class SingleSpike(BackwardClass, torch.autograd.Function): - """ - Autograd function that returns membrane potential divided by - spike threshold for a maximum of one spike per time step. - Do not instantiate this class when passing as spike_fn (see example). - Can be combined with different surrogate gradient functions. + """Autograd function that returns membrane potential divided by spike threshold for a maximum + of one spike per time step. Do not instantiate this class when passing as spike_fn (see + example). Can be combined with different surrogate gradient functions. Example: >>> layer = sinabs.layers.LIF(spike_fn=SingleSpike, ...) diff --git a/sinabs/activation/surrogate_gradient_fn.py b/sinabs/activation/surrogate_gradient_fn.py index 01d7c2ae..73631d6b 100644 --- a/sinabs/activation/surrogate_gradient_fn.py +++ b/sinabs/activation/surrogate_gradient_fn.py @@ -6,8 +6,7 @@ @dataclass class Heaviside: - """ - Heaviside surrogate gradient with optional shift. + """Heaviside surrogate gradient with optional shift. Parameters: window: Distance between step of Heaviside surrogate gradient and @@ -28,8 +27,7 @@ def gaussian(x: torch.Tensor, mu: float, sigma: float): @dataclass class Gaussian: - """ - Gaussian surrogate gradient function. + """Gaussian surrogate gradient function. Parameters mu: The mean of the Gaussian. @@ -50,8 +48,7 @@ def __call__(self, v_mem, spike_threshold): @dataclass class MultiGaussian: - """ - Surrogate gradient as defined in Yin et al., 2021. + """Surrogate gradient as defined in Yin et al., 2021. https://www.biorxiv.org/content/10.1101/2021.03.22.436372v2 @@ -86,8 +83,7 @@ def __call__(self, v_mem, spike_threshold): @dataclass class SingleExponential: - """ - Surrogate gradient as defined in Shrestha and Orchard, 2018. + """Surrogate gradient as defined in Shrestha and Orchard, 2018. https://papers.nips.cc/paper/2018/hash/82f2b308c3b01637c607ce05f52a2fed-Abstract.html """ @@ -106,8 +102,7 @@ def __call__(self, v_mem, spike_threshold): @dataclass class PeriodicExponential: - """ - Surrogate gradient as defined in Weidel and Sheik, 2021. + """Surrogate gradient as defined in Weidel and Sheik, 2021. https://arxiv.org/abs/2111.01456 """ diff --git a/sinabs/cnnutils.py b/sinabs/cnnutils.py index 454ece3f..2f479959 100644 --- a/sinabs/cnnutils.py +++ b/sinabs/cnnutils.py @@ -2,15 +2,13 @@ def conv_output_size(image_length: int, kernel_length: int, stride: int) -> int: - """ - Computes output dimension given input dimension, kernel size and stride, - assumign no padding, *per* dimension given + """Computes output dimension given input dimension, kernel size and stride, assumign no + padding, *per* dimension given. :param image_length: int image size on one dimension :param kernel_length: int kernel_length size on one dimension :param stride: int Stride size on one dimension :return: int -- convolved output image size on one dimension - """ try: assert image_length >= kernel_length @@ -24,8 +22,7 @@ def conv_output_size(image_length: int, kernel_length: int, stride: int) -> int: def compute_same_padding_size(kernel_length: int) -> (int, int): - """ - Computes padding for 'same' padding *per* dimension given + """Computes padding for 'same' padding *per* dimension given. :param kernel_length: int Kernel size :returns: Tuple -- (padStart, padStop) , padding on left/right or top/bottom @@ -40,8 +37,7 @@ def compute_same_padding_size(kernel_length: int) -> (int, int): def compute_padding( kernel_shape: tuple, input_shape: tuple, mode="valid" ) -> (int, int, int, int): - """ - Computes padding for 'same' or 'valid' padding + """Computes padding for 'same' or 'valid' padding. :param kernel_shape: Kernel shape (height, width) :param input_shape: Input shape (channels, height, width) @@ -61,8 +57,7 @@ def compute_padding( def infer_output_shape(torch_layer, input_shape: Tuple) -> Tuple: - """ - Compute the output dimensions given input dimensions + """Compute the output dimensions given input dimensions. :param torch_layer: a Torch layer :param input_shape: the shape of the input tensor diff --git a/sinabs/conversion.py b/sinabs/conversion.py index a957e3ea..7017137a 100644 --- a/sinabs/conversion.py +++ b/sinabs/conversion.py @@ -6,8 +6,7 @@ def replace_module(model: nn.Module, source_class: type, mapper_fn: Callable): - """ - A utility function that returns a copy of the model, where specific layers are replaced with + """A utility function that returns a copy of the model, where specific layers are replaced with another type depending on the mapper function. Parameters: @@ -18,7 +17,7 @@ def replace_module(model: nn.Module, source_class: type, mapper_fn: Callable): Returns: A model copy with replaced modules according to mapper_fn. """ - + # Handle case where `model` is of type `source_class` if type(model) == source_class: return mapper_fn(model) @@ -29,8 +28,7 @@ def replace_module(model: nn.Module, source_class: type, mapper_fn: Callable): def replace_module_(model: nn.Sequential, source_class: type, mapper_fn: Callable): - """ - In-place version of replace_module that will step through modules that have children and + """In-place version of replace_module that will step through modules that have children and apply the mapper_fn. Parameters: @@ -56,4 +54,3 @@ def replace_module_(model: nn.Sequential, source_class: type, mapper_fn: Callabl if type(module) == source_class: setattr(model, name, mapper_fn(module)) - diff --git a/sinabs/from_torch.py b/sinabs/from_torch.py index eb38eb12..3424f3a0 100644 --- a/sinabs/from_torch.py +++ b/sinabs/from_torch.py @@ -27,10 +27,9 @@ def from_model( backend=None, kwargs_backend: dict = dict(), ): - """ - Converts a Torch model and returns a Sinabs network object. - The modules in the model are analyzed, and a copy is returned, with all - ReLUs and NeuromorphicReLUs turned into SpikingLayers. + """Converts a Torch model and returns a Sinabs network object. The modules in the model are + analyzed, and a copy is returned, with all ReLUs and NeuromorphicReLUs turned into + SpikingLayers. Parameters: model: Torch model diff --git a/sinabs/layers/alif.py b/sinabs/layers/alif.py index c517ffba..39266ed7 100644 --- a/sinabs/layers/alif.py +++ b/sinabs/layers/alif.py @@ -11,8 +11,8 @@ class ALIF(StatefulLayer): - """ - Adaptive Leaky Integrate and Fire neuron layer that inherits from :class:`~sinabs.layers.StatefulLayer`. + """Adaptive Leaky Integrate and Fire neuron layer that inherits from + :class:`~sinabs.layers.StatefulLayer`. Pytorch implementation of a Long Short Term Memory SNN (LSNN) by Bellec et al., 2018: https://papers.neurips.cc/paper/2018/hash/c203d8a151612acf12457e4d67635a95-Abstract.html @@ -210,8 +210,8 @@ def _param_dict(self) -> dict: class ALIFRecurrent(ALIF): - """ - Adaptive Leaky Integrate and Fire neuron layer with recurrent connections which inherits from :class:`~sinabs.layers.ALIF`. + """Adaptive Leaky Integrate and Fire neuron layer with recurrent connections which inherits + from :class:`~sinabs.layers.ALIF`. Pytorch implementation of a Long Short Term Memory SNN (LSNN) by Bellec et al., 2018: https://papers.neurips.cc/paper/2018/hash/c203d8a151612acf12457e4d67635a95-Abstract.html @@ -347,8 +347,7 @@ def forward(self, input_data: torch.Tensor): class ALIFSqueeze(ALIF, SqueezeMixin): - """ - ALIF layer with 4-dimensional input (Batch*Time, Channel, Height, Width). + """ALIF layer with 4-dimensional input (Batch*Time, Channel, Height, Width). Same as parent ALIF class, only takes in squeezed 4D input (Batch*Time, Channel, Height, Width) instead of 5D input (Batch, Time, Channel, Height, Width) in order to be compatible with @@ -375,10 +374,8 @@ def __init__( self.squeeze_init(batch_size, num_timesteps) def forward(self, input_data: torch.Tensor) -> torch.Tensor: - """ - Forward call wrapper that will flatten the input to and - unflatten the output from the super class forward call. - """ + """Forward call wrapper that will flatten the input to and unflatten the output from the + super class forward call.""" return self.squeeze_forward(input_data, super().forward) @property diff --git a/sinabs/layers/crop2d.py b/sinabs/layers/crop2d.py index 7e6fa219..aa8b5737 100644 --- a/sinabs/layers/crop2d.py +++ b/sinabs/layers/crop2d.py @@ -7,8 +7,7 @@ class Cropping2dLayer(nn.Module): - """ - Crop input image by + """Crop input image by. Parameters: cropping: ((top, bottom), (left, right)) @@ -37,8 +36,7 @@ def forward(self, binary_input): return crop_out def get_output_shape(self, input_shape: Tuple) -> Tuple: - """ - Retuns the output dimensions + """Retuns the output dimensions. Parameters: input_shape: (channels, height, width) diff --git a/sinabs/layers/exp_leak.py b/sinabs/layers/exp_leak.py index a17260a9..0ff53e0e 100644 --- a/sinabs/layers/exp_leak.py +++ b/sinabs/layers/exp_leak.py @@ -7,8 +7,8 @@ class ExpLeak(LIF): - """ - Leaky Integrator layer which is a special case of :class:`~sinabs.layers.LIF` without activation function. + """Leaky Integrator layer which is a special case of :class:`~sinabs.layers.LIF` without + activation function. Neuron dynamics in discrete time: @@ -61,8 +61,7 @@ def _param_dict(self) -> dict: class ExpLeakSqueeze(ExpLeak, SqueezeMixin): - """ - ExpLeak layer with 4-dimensional input (Batch*Time, Channel, Height, Width). + """ExpLeak layer with 4-dimensional input (Batch*Time, Channel, Height, Width). Same as parent ExpLeak class, only takes in squeezed 4D input (Batch*Time, Channel, Height, Width) instead of 5D input (Batch, Time, Channel, Height, Width) in order to be compatible with @@ -82,10 +81,8 @@ def __init__(self, batch_size=None, num_timesteps=None, **kwargs): self.squeeze_init(batch_size, num_timesteps) def forward(self, input_data: torch.Tensor) -> torch.Tensor: - """ - Forward call wrapper that will flatten the input to and - unflatten the output from the super class forward call. - """ + """Forward call wrapper that will flatten the input to and unflatten the output from the + super class forward call.""" return self.squeeze_forward(input_data, super().forward) @property diff --git a/sinabs/layers/iaf.py b/sinabs/layers/iaf.py index 025d0cd7..e18d0dff 100644 --- a/sinabs/layers/iaf.py +++ b/sinabs/layers/iaf.py @@ -10,8 +10,8 @@ class IAF(LIF): - """ - Integrate and Fire neuron layer that is designed as a special case of :class:`~sinabs.layers.LIF` with tau_mem=inf. + """Integrate and Fire neuron layer that is designed as a special case of + :class:`~sinabs.layers.LIF` with tau_mem=inf. Neuron dynamics in discrete time: @@ -87,8 +87,8 @@ def _param_dict(self) -> dict: class IAFRecurrent(LIFRecurrent): - """ - Integrate and Fire neuron layer with recurrent connections which inherits from :class:`~sinabs.layers.LIFRecurrent`. + """Integrate and Fire neuron layer with recurrent connections which inherits from + :class:`~sinabs.layers.LIFRecurrent`. Neuron dynamics in discrete time: @@ -166,8 +166,7 @@ def _param_dict(self) -> dict: class IAFSqueeze(IAF, SqueezeMixin): - """ - IAF layer with 4-dimensional input (Batch*Time, Channel, Height, Width). + """IAF layer with 4-dimensional input (Batch*Time, Channel, Height, Width). Same as parent IAF class, only takes in squeezed 4D input (Batch*Time, Channel, Height, Width) instead of 5D input (Batch, Time, Channel, Height, Width) in order to be compatible with @@ -192,10 +191,8 @@ def __init__( self.squeeze_init(batch_size, num_timesteps) def forward(self, input_data: torch.Tensor) -> torch.Tensor: - """ - Forward call wrapper that will flatten the input to and - unflatten the output from the super class forward call. - """ + """Forward call wrapper that will flatten the input to and unflatten the output from the + super class forward call.""" return self.squeeze_forward(input_data, super().forward) @property diff --git a/sinabs/layers/lif.py b/sinabs/layers/lif.py index 1d96c5b6..711d5c6f 100644 --- a/sinabs/layers/lif.py +++ b/sinabs/layers/lif.py @@ -11,8 +11,8 @@ class LIF(StatefulLayer): - """ - Leaky Integrate and Fire neuron layer that inherits from :class:`~sinabs.layers.StatefulLayer`. + """Leaky Integrate and Fire neuron layer that inherits from + :class:`~sinabs.layers.StatefulLayer`. Neuron dynamics in discrete time for norm_input=True: @@ -222,8 +222,8 @@ def _param_dict(self) -> dict: class LIFRecurrent(LIF): - """ - Leaky Integrate and Fire neuron layer with recurrent connections which inherits from :class:`~sinabs.layers.LIF`. + """Leaky Integrate and Fire neuron layer with recurrent connections which inherits from + :class:`~sinabs.layers.LIF`. Neuron dynamics in discrete time for norm_input=True: @@ -369,10 +369,8 @@ def __init__( self.squeeze_init(batch_size, num_timesteps) def forward(self, input_data: torch.Tensor) -> torch.Tensor: - """ - Forward call wrapper that will flatten the input to and - unflatten the output from the super class forward call. - """ + """Forward call wrapper that will flatten the input to and unflatten the output from the + super class forward call.""" return self.squeeze_forward(input_data, super().forward) @property diff --git a/sinabs/layers/neuromorphic_relu.py b/sinabs/layers/neuromorphic_relu.py index d0e1890c..26b05276 100644 --- a/sinabs/layers/neuromorphic_relu.py +++ b/sinabs/layers/neuromorphic_relu.py @@ -4,9 +4,8 @@ class NeuromorphicReLU(torch.nn.Module): - """ - NeuromorphicReLU layer. This layer is NOT used for Sinabs networks; it's - useful while training analogue pyTorch networks for future use with Sinabs. + """NeuromorphicReLU layer. This layer is NOT used for Sinabs networks; it's useful while + training analogue pyTorch networks for future use with Sinabs. Parameters: quantize: Whether or not to quantize the output (i.e. floor it to \ @@ -16,7 +15,6 @@ class NeuromorphicReLU(torch.nn.Module): NeuromorphicReLU.activity, and is multiplied by the value of fanout. stochastic_rounding: Upon quantization, should the value be rounded stochastically or floored Only done during training. During evaluation mode, the value is simply floored - """ def __init__(self, quantize=True, fanout=1, stochastic_rounding=False): diff --git a/sinabs/layers/pool2d.py b/sinabs/layers/pool2d.py index a498724d..24e38fba 100644 --- a/sinabs/layers/pool2d.py +++ b/sinabs/layers/pool2d.py @@ -11,9 +11,7 @@ class SpikingMaxPooling2dLayer(nn.Module): - """ - Torch implementation of SpikingMaxPooling. - """ + """Torch implementation of SpikingMaxPooling.""" def __init__( self, @@ -74,8 +72,7 @@ def forward(self, binary_input): return max_input_sum.float() # Float is just to keep things compatible def get_output_shape(self, input_shape: Tuple) -> Tuple: - """ - Returns the shape of output, given an input to this layer + """Returns the shape of output, given an input to this layer. Parameters: input_shape: (channels, height, width) @@ -95,8 +92,8 @@ def get_output_shape(self, input_shape: Tuple) -> Tuple: class SumPool2d(torch.nn.LPPool2d): - """ - Non-spiking sumpooling layer to be used in analogue Torch models. It is identical to torch.nn.LPPool2d with p=1. + """Non-spiking sumpooling layer to be used in analogue Torch models. It is identical to + torch.nn.LPPool2d with p=1. Parameters: kernel_size: the size of the window diff --git a/sinabs/layers/quantize.py b/sinabs/layers/quantize.py index fd37ae3e..a6e68bbb 100644 --- a/sinabs/layers/quantize.py +++ b/sinabs/layers/quantize.py @@ -4,8 +4,7 @@ class QuantizeLayer(nn.Module): - """ - Layer that quantizes the input, i.e. returns floor(input). + """Layer that quantizes the input, i.e. returns floor(input). Parameters: quantize: If False, this layer will pass on the input without modifying it. diff --git a/sinabs/layers/reshape.py b/sinabs/layers/reshape.py index 1978e0d5..a7bf9a74 100644 --- a/sinabs/layers/reshape.py +++ b/sinabs/layers/reshape.py @@ -5,11 +5,10 @@ class Repeat(nn.Module): - """ - Utility layer which wraps any nn.Module. It flattens time and batch - dimensions of the input before feeding it to the child module and - unflattens those dimensions to the original shape before passing it - to the next layer. + """Utility layer which wraps any nn.Module. + + It flattens time and batch dimensions of the input before feeding it to the child module and + unflattens those dimensions to the original shape before passing it to the next layer. """ def __init__(self, module: nn.Module): @@ -27,11 +26,9 @@ def __repr__(self): class FlattenTime(nn.Flatten): - """ - Utility layer which always flattens first two dimensions and is - a special case of `torch.nn.Flatten()`. Meant - to convert a tensor of dimensions (Batch, Time, Channels, Height, Width) - into a tensor of (Batch*Time, Channels, Height, Width). + """Utility layer which always flattens first two dimensions and is a special case of + `torch.nn.Flatten()`. Meant to convert a tensor of dimensions (Batch, Time, Channels, Height, + Width) into a tensor of (Batch*Time, Channels, Height, Width). Shape: - Input: :math:`(Batch, Time, Channel, Height, Width)` or :math:`(Batch, Time, Channel)` @@ -43,10 +40,9 @@ def __init__(self): class UnflattenTime(nn.Module): - """ - Utility layer which always unflattens (expands) the first dimension into two separate ones. - Meant to convert a tensor of dimensions (Batch*Time, Channels, Height, Width) - into a tensor of (Batch, Time, Channels, Height, Width). + """Utility layer which always unflattens (expands) the first dimension into two separate ones. + Meant to convert a tensor of dimensions (Batch*Time, Channels, Height, Width) into a tensor of + (Batch, Time, Channels, Height, Width). Shape: - Input: :math:`(Batch \\times Time, Channel, Height, Width)` or :math:`(Batch \\times Time, Channel)` @@ -63,11 +59,11 @@ def forward(self, x): class SqueezeMixin: - """ - Utility mixin class that will wrap the __init__ and forward call to - flatten the input to and the output from a child class. - The wrapped __init__ will provide two additional parameters batch_size and num_timesteps - and the wrapped forward will unpack and repack the first dimension into batch and time. + """Utility mixin class that will wrap the __init__ and forward call to flatten the input to and + the output from a child class. + + The wrapped __init__ will provide two additional parameters batch_size and num_timesteps and + the wrapped forward will unpack and repack the first dimension into batch and time. """ def squeeze_init(self, batch_size: Optional[int], num_timesteps: Optional[int]): diff --git a/sinabs/layers/stateful_layer.py b/sinabs/layers/stateful_layer.py index 1d68a25a..27f00f48 100644 --- a/sinabs/layers/stateful_layer.py +++ b/sinabs/layers/stateful_layer.py @@ -4,9 +4,8 @@ class StatefulLayer(torch.nn.Module): - """ - A base class that instantiates buffers/states which update at every time step - and provides helper methods that manage those states. + """A base class that instantiates buffers/states which update at every time step and provides + helper methods that manage those states. Parameters: state_names: the PyTorch buffers to initialise. These are not parameters. @@ -42,29 +41,21 @@ def forward(self, *args, **kwargs): ) def is_state_initialised(self) -> bool: - """ - Checks if buffers are of shape 0 and returns - True only if none of them are. - """ + """Checks if buffers are of shape 0 and returns True only if none of them are.""" for buffer in self.buffers(): if buffer.shape == torch.Size([0]): return False return True def state_has_shape(self, shape) -> bool: - """ - Checks if all state have a given shape. - """ + """Checks if all state have a given shape.""" for buff in self.buffers(): if buff.shape != shape: return False return True def init_state_with_shape(self, shape, randomize: bool = False) -> None: - """ - Initialise state/buffers with either zeros or random - tensor of specific shape. - """ + """Initialise state/buffers with either zeros or random tensor of specific shape.""" for name, buffer in self.named_buffers(): self.register_buffer(name, torch.zeros(shape, device=buffer.device)) self.reset_states(randomize=randomize) @@ -74,8 +65,7 @@ def reset_states( randomize: bool = False, value_ranges: Optional[Dict[str, Tuple[float, float]]] = None, ): - """ - Reset the state/buffers in a layer. + """Reset the state/buffers in a layer. Parameters: randomize: If true, reset the states between a range provided. Else, the states are reset to zero. @@ -89,7 +79,6 @@ def reset_states( layer..data = layer..detach_() - """ if self.is_state_initialised(): for name, buffer in self.named_buffers(): @@ -143,22 +132,16 @@ def __deepcopy__(self, memo=None): @property def _param_dict(self) -> dict: - """ - Dict of all parameters relevant for creating a new instance with same - parameters as `self`. - """ + """Dict of all parameters relevant for creating a new instance with same parameters as + `self`.""" return dict() @property def arg_dict(self) -> dict: - """ - A public getter function for the constructor arguments. - """ + """A public getter function for the constructor arguments.""" return self._param_dict @property def does_spike(self) -> bool: - """ - Return True if the layer has an activation function. - """ + """Return True if the layer has an activation function.""" return hasattr(self, "spike_fn") and self.spike_fn is not None diff --git a/sinabs/layers/to_spike.py b/sinabs/layers/to_spike.py index eddb948c..72e4d918 100644 --- a/sinabs/layers/to_spike.py +++ b/sinabs/layers/to_spike.py @@ -5,8 +5,7 @@ class Img2SpikeLayer(nn.Module): - """ - Layer to convert images to spikes. + """Layer to convert images to spikes. Parameters: image_shape: tuple image shape @@ -59,8 +58,7 @@ def get_output_shape(self, input_shape: Tuple): class Sig2SpikeLayer(torch.nn.Module): - """ - Layer to convert analog Signals to spikes. + """Layer to convert analog Signals to spikes. Parameters: channels_in: number of channels in the analog signal diff --git a/sinabs/network.py b/sinabs/network.py index 3fb2abf4..eea13ae0 100644 --- a/sinabs/network.py +++ b/sinabs/network.py @@ -13,8 +13,7 @@ class Network(torch.nn.Module): - """ - Class of a spiking neural network + """Class of a spiking neural network. Attributes: spiking_model: torch.nn.Module, a spiking neural network model @@ -75,9 +74,7 @@ def hook(module, inp, out): [this_hook.remove() for this_hook in hook_list] def forward(self, tsrInput) -> torch.Tensor: - """ - Forward pass for this model - """ + """Forward pass for this model.""" return self.spiking_model(tsrInput) def compare_activations( @@ -87,8 +84,7 @@ def compare_activations( compute_rate: bool = False, verbose: bool = False, ) -> Tuple[np.ndarray, np.ndarray, str]: - """ - Compare activations of the analog model and the SNN for a given data sample + """Compare activations of the analog model and the SNN for a given data sample. Args: data (np.ndarray): Data to process @@ -129,8 +125,7 @@ def compare_activations( def plot_comparison( self, data, name_list: Optional[ArrayLike] = None, compute_rate=False ): - """ - Plots a scatter plot of all the activations + """Plots a scatter plot of all the activations. Args: data: Data to be processed @@ -165,8 +160,7 @@ def reset_states( randomize: bool = False, value_ranges: Optional[List[Dict[str, Tuple[float, float]]]] = None, ): - """ - Reset all neuron states in the submodules. + """Reset all neuron states in the submodules. Parameters ---------- @@ -203,9 +197,7 @@ def zero_grad(self, set_to_none: bool = False) -> None: lyr.zero_grad(set_to_none) def get_synops(self, num_evs_in=None) -> dict: - """ - Please see docs for `sinabs.SNNSynOpCounter.get_synops()`. - """ + """Please see docs for `sinabs.SNNSynOpCounter.get_synops()`.""" if num_evs_in is not None: warnings.warn("num_evs_in is deprecated and has no effect") @@ -215,9 +207,7 @@ def get_synops(self, num_evs_in=None) -> dict: def get_parent_module_by_name( root: torch.nn.Module, name: str ) -> Tuple[torch.nn.Module, str]: - """ - Find a nested Module of a given name inside a Module, and return its parent - Module. + """Find a nested Module of a given name inside a Module, and return its parent Module. Args: root: The Module inside which to look for the nested Module @@ -241,9 +231,8 @@ def get_parent_module_by_name( def infer_module_device(module: torch.nn.Module) -> Union[torch.device, None]: - """ - Infere on which device a module is operating by first looking at its parameters - and then, if no parameters are found, at its buffers. + """Infere on which device a module is operating by first looking at its parameters and then, if + no parameters are found, at its buffers. Args: module: The module whose device is to be inferred. diff --git a/sinabs/onnx/get_graph.py b/sinabs/onnx/get_graph.py index abdd86f5..b95f481b 100644 --- a/sinabs/onnx/get_graph.py +++ b/sinabs/onnx/get_graph.py @@ -16,8 +16,7 @@ def print_onnx_model(onnx_model: onnx.ModelProto): def get_graph(model: Network, inputs): - """ - Extract graph from a sinabs Network model + """Extract graph from a sinabs Network model. :param model: sinabs.Netowrk model to extract the graph for :param inputs: Input tensor to extract graph diff --git a/sinabs/synopcounter.py b/sinabs/synopcounter.py index ebf01c27..fc6aafbc 100644 --- a/sinabs/synopcounter.py +++ b/sinabs/synopcounter.py @@ -16,11 +16,9 @@ def synops_hook(layer, inp, out): class SNNSynOpCounter: - """ - Counter for the synaptic operations emitted by all SpikingLayers in a - spiking model. - Note that this is automatically instantiated by `from_torch` and by - `Network` if they are passed `synops=True`. + """Counter for the synaptic operations emitted by all SpikingLayers in a spiking model. Note + that this is automatically instantiated by `from_torch` and by `Network` if they are passed + `synops=True`. Arguments: model: Spiking model. @@ -70,7 +68,6 @@ def get_synops(self) -> dict: >>> synops_map = counter.get_synops() >>> SynOps_dataframe = pandas.DataFrame.from_dict(synops_map, "index") >>> SynOps_dataframe.set_index("Layer", inplace=True) - """ SynOps_map = {} scale_facts = [] @@ -102,8 +99,7 @@ def get_synops(self) -> dict: return SynOps_map def get_total_synops(self, per_second=False) -> float: - """ - Sums up total number of synaptic operations across the network. + """Sums up total number of synaptic operations across the network. .. note:: this may not be accurate in presence of average pooling. @@ -126,15 +122,14 @@ def get_total_synops(self, per_second=False) -> float: return synops def get_total_power_use(self, j_per_synop=1e-11): - """ - Method to quickly get the total power use of the network, estimated - over the latest forward pass. + """Method to quickly get the total power use of the network, estimated over the latest + forward pass. Arguments: j_per_synop: Energy use per synaptic operation, in joules.\ Default 1e-11 J. - Returns: + Returns: estimated power in mW. """ tot_synops_per_s = self.get_total_synops(per_second=True) @@ -147,9 +142,8 @@ def __del__(self): class SynOpCounter: - """ - Counter for the synaptic operations emitted by all Neuromorphic ReLUs in an - analog CNN model. + """Counter for the synaptic operations emitted by all Neuromorphic ReLUs in an analog CNN + model. Parameters: modules: list of modules, e.g. MyTorchModel.modules() diff --git a/sinabs/utils.py b/sinabs/utils.py index 357af4db..163958f5 100644 --- a/sinabs/utils.py +++ b/sinabs/utils.py @@ -8,8 +8,7 @@ def reset_states(model: nn.Module) -> None: - """ - Helper function to recursively reset all states of spiking layers within the model. + """Helper function to recursively reset all states of spiking layers within the model. Parameters: model: The torch module @@ -22,8 +21,7 @@ def reset_states(model: nn.Module) -> None: def zero_grad(model: nn.Module) -> None: - """ - Helper function to recursively zero the gradients of all spiking layers within the model. + """Helper function to recursively zero the gradients of all spiking layers within the model. Parameters: model: The torch module @@ -36,9 +34,7 @@ def zero_grad(model: nn.Module) -> None: def get_activations(torchanalog_model, tsrData, name_list=None): - """ - Return torch analog model activations for the specified layers - """ + """Return torch analog model activations for the specified layers.""" torch_modules = dict(torchanalog_model.named_modules()) # Populate layer names @@ -81,8 +77,7 @@ def hook(module, inp, output): def get_network_activations( model: nn.Module, inp, name_list: List = None, bRate: bool = False ) -> List[np.ndarray]: - """ - Returns the activity of neurons in each layer of the network + """Returns the activity of neurons in each layer of the network. Parameters: model: Model for which the activations are to be read out @@ -127,8 +122,8 @@ def normalize_weights( param_layers: List[str], percentile: float = 99, ): - """ - Rescale the weights of the network, such that the activity of each specified layer is normalized. + """Rescale the weights of the network, such that the activity of each specified layer is + normalized. The method implemented here roughly follows the paper: `Conversion of Continuous-Valued Deep Networks to Efficient Event-Driven Networks for Image Classification` by Rueckauer et al. diff --git a/tests/test_conversion.py b/tests/test_conversion.py index 6e9813a1..6eb2fffa 100644 --- a/tests/test_conversion.py +++ b/tests/test_conversion.py @@ -1,8 +1,8 @@ +import pytest import torch.nn as nn import sinabs import sinabs.layers as sl -import pytest def test_layer_replacement_sequential(): @@ -54,7 +54,7 @@ def test_layer_replacement_individual(): layer = sl.IAF(spike_threshold=2.0) mapper_fn = lambda module: sl.IAFSqueeze(**module.arg_dict, batch_size=4) squeezed_layer = sinabs.conversion.replace_module(layer, sl.IAF, mapper_fn) - + assert type(squeezed_layer) == sl.IAFSqueeze assert squeezed_layer.batch_size == 4 assert squeezed_layer.spike_threshold == 2 @@ -65,7 +65,6 @@ def test_layer_replacement_individual_in_place(): mapper_fn = lambda module: sl.IAFSqueeze(**module.arg_dict, batch_size=4) with pytest.warns(UserWarning): squeezed_layer = sinabs.conversion.replace_module_(layer, sl.IAF, mapper_fn) - + assert type(layer) == sl.IAF assert squeezed_layer is None - diff --git a/tests/test_from_model.py b/tests/test_from_model.py index a6a9ae6c..6f10d1cf 100644 --- a/tests/test_from_model.py +++ b/tests/test_from_model.py @@ -130,10 +130,8 @@ def test_network_conversion_add_spk_out(): def test_network_conversion_complicated_model(): - """ - Try converting rather complicated network model with nested structures, which used - to fail before. - """ + """Try converting rather complicated network model with nested structures, which used to fail + before.""" ann = nn.Sequential( nn.Conv2d(1, 1, 1), @@ -189,9 +187,7 @@ def test_network_conversion_with_num_timesteps(): def test_network_conversion_backend(): - """ - Try conversion with sinabs explicitly stated as backend. - """ + """Try conversion with sinabs explicitly stated as backend.""" ann = nn.Sequential( nn.Conv2d(1, 1, 1), diff --git a/tests/test_network_class.py b/tests/test_network_class.py index 3388f441..31864c43 100644 --- a/tests/test_network_class.py +++ b/tests/test_network_class.py @@ -100,10 +100,8 @@ def test_compare_activations(): def test_plot_comparison(): - """ - Test whether the plot_comparison() method of the sinabs.network.Network class - could plot a nested-network which is not defined by torch.nn.Sequential(*module_list) directly. - """ + """Test whether the plot_comparison() method of the sinabs.network.Network class could plot a + nested-network which is not defined by torch.nn.Sequential(*module_list) directly.""" # get the names of all spiking layers spiking_layers_names = [