autoformat all docstrings using docformatter
biphasic committed Nov 18, 2022
1 parent 6102ddd commit de57af4
Showing 26 changed files with 135 additions and 223 deletions.
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
@@ -10,3 +10,9 @@ repos:
hooks:
- id: isort
args: ["--profile", "black", "--filter-files"]

- repo: https://github.com/myint/docformatter
rev: v1.4
hooks:
- id: docformatter
args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99]
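
As a usage note (an assumption about the local workflow, not something recorded in this commit): once pre-commit is installed, the new hook can be run over the whole repository with "pre-commit run docformatter --all-files", or docformatter can be invoked directly with the same arguments, e.g. "docformatter --in-place --wrap-summaries=99 --wrap-descriptions=99 sinabs/activation/quantize.py".
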
11 changes: 6 additions & 5 deletions sinabs/activation/quantize.py
@@ -2,9 +2,8 @@


class Quantize(torch.autograd.Function):
"""PyTorch-compatible function that applies a floor() operation on the input,
while providing a surrogate gradient (equivalent to that of a linear
function) in the backward pass."""
"""PyTorch-compatible function that applies a floor() operation on the input, while providing a
surrogate gradient (equivalent to that of a linear function) in the backward pass."""

@staticmethod
def forward(ctx, inp):
@@ -19,10 +18,12 @@ def backward(ctx, grad_output):


class StochasticRounding(torch.autograd.Function):
"""PyTorch-compatible function that applies stochastic rounding. The input x
"""PyTorch-compatible function that applies stochastic rounding. The input x.
is quantized to ceil(x) with probability (x - floor(x)), and to floor(x)
otherwise. The backward pass is provided as a surrogate gradient
(equivalent to that of a linear function)."""
(equivalent to that of a linear function).
"""

@staticmethod
def forward(ctx, inp):
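
To make concrete the behaviour these docstrings describe, here is a minimal, self-contained sketch of a floor operation with a straight-through (linear-equivalent) surrogate gradient; the name FloorSTE and the example values are illustrative, not the library's exact implementation:

import torch

class FloorSTE(torch.autograd.Function):
    """Floor in the forward pass, identity (linear) gradient in the backward pass."""

    @staticmethod
    def forward(ctx, inp):
        return inp.floor()

    @staticmethod
    def backward(ctx, grad_output):
        # Surrogate gradient equivalent to that of a linear function.
        return grad_output

x = torch.tensor([0.2, 1.7, 2.5], requires_grad=True)
FloorSTE.apply(x).sum().backward()  # x.grad is all ones despite the non-differentiable floor
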
8 changes: 2 additions & 6 deletions sinabs/activation/reset_mechanism.py
@@ -4,9 +4,7 @@

@dataclass
class MembraneReset:
"""
Reset the membrane potential v_mem to a given value
after it spiked.
"""Reset the membrane potential v_mem to a given value after it spiked.
Parameters:
reset_value: fixed value that a neuron should be reset to. Defaults to zero.
@@ -27,9 +25,7 @@ def __call__(self, spikes, state, spike_threshold):

@dataclass
class MembraneSubtract:
"""
Subtract the spiking threshold from the membrane potential
for every neuron that spiked.
"""Subtract the spiking threshold from the membrane potential for every neuron that spiked.
Parameters:
subtract_value: optional value that will be subtracted from
27 changes: 11 additions & 16 deletions sinabs/activation/spike_generation.py
@@ -15,10 +15,9 @@ def backward(ctx, grad_output: torch.tensor):


class MultiSpike(BackwardClass, torch.autograd.Function):
"""
Autograd function that returns membrane potential integer-divided by spike threshold.
Do not instantiate this class when passing as spike_fn (see example).
Can be combined with different surrogate gradient functions.
"""Autograd function that returns membrane potential integer-divided by spike threshold. Do not
instantiate this class when passing as spike_fn (see example). Can be combined with different
surrogate gradient functions.
Example:
>>> layer = sinabs.layers.LIF(spike_fn=MultiSpike, ...)
@@ -44,9 +43,9 @@ def forward(


class MaxSpikeInner(BackwardClass, torch.autograd.Function):
"""
Autograd function that returns membrane potential divided by
spike threshold for a maximum number of spikes per time step.
"""Autograd function that returns membrane potential divided by spike threshold for a maximum
number of spikes per time step.
Can be combined with different surrogate gradient functions.
"""

@@ -74,10 +73,8 @@ def forward(

@dataclass
class MaxSpike:
"""
Wrapper for MaxSpikeInner autograd function. This class needs to
be instantiated when used as spike_fn. Notice the difference in example
to Single/MultiSpike.
"""Wrapper for MaxSpikeInner autograd function. This class needs to be instantiated when used
as spike_fn. Notice the difference in example to Single/MultiSpike.
Example:
>>> layer = sinabs.layers.LIF(spike_fn=MaxSpike(max_num_spikes_per_bin=10), ...)
@@ -101,11 +98,9 @@ def required_states(self):


class SingleSpike(BackwardClass, torch.autograd.Function):
"""
Autograd function that returns membrane potential divided by
spike threshold for a maximum of one spike per time step.
Do not instantiate this class when passing as spike_fn (see example).
Can be combined with different surrogate gradient functions.
"""Autograd function that returns membrane potential divided by spike threshold for a maximum
of one spike per time step. Do not instantiate this class when passing as spike_fn (see
example). Can be combined with different surrogate gradient functions.
Example:
>>> layer = sinabs.layers.LIF(spike_fn=SingleSpike, ...)
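
A hedged illustration of the spike counts these docstrings describe, using plain tensor arithmetic rather than sinabs code: MultiSpike emits the membrane potential integer-divided by the threshold, SingleSpike emits at most one spike, and MaxSpike(n) caps the count at n per time step:

import torch

v_mem, spike_threshold = torch.tensor([0.4, 1.3, 3.7]), 1.0
multi = torch.div(v_mem, spike_threshold, rounding_mode="floor")  # tensor([0., 1., 3.])
single = (v_mem >= spike_threshold).float()                       # tensor([0., 1., 1.])
capped = torch.clamp(multi, max=2)                                # MaxSpike(2) -> tensor([0., 1., 2.])
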
15 changes: 5 additions & 10 deletions sinabs/activation/surrogate_gradient_fn.py
@@ -6,8 +6,7 @@

@dataclass
class Heaviside:
"""
Heaviside surrogate gradient with optional shift.
"""Heaviside surrogate gradient with optional shift.
Parameters:
window: Distance between step of Heaviside surrogate gradient and
@@ -28,8 +27,7 @@ def gaussian(x: torch.Tensor, mu: float, sigma: float):

@dataclass
class Gaussian:
"""
Gaussian surrogate gradient function.
"""Gaussian surrogate gradient function.
Parameters
mu: The mean of the Gaussian.
Expand All @@ -50,8 +48,7 @@ def __call__(self, v_mem, spike_threshold):

@dataclass
class MultiGaussian:
"""
Surrogate gradient as defined in Yin et al., 2021.
"""Surrogate gradient as defined in Yin et al., 2021.
https://www.biorxiv.org/content/10.1101/2021.03.22.436372v2
@@ -86,8 +83,7 @@ def __call__(self, v_mem, spike_threshold):

@dataclass
class SingleExponential:
"""
Surrogate gradient as defined in Shrestha and Orchard, 2018.
"""Surrogate gradient as defined in Shrestha and Orchard, 2018.
https://papers.nips.cc/paper/2018/hash/82f2b308c3b01637c607ce05f52a2fed-Abstract.html
"""
@@ -106,8 +102,7 @@ def __call__(self, v_mem, spike_threshold):

@dataclass
class PeriodicExponential:
"""
Surrogate gradient as defined in Weidel and Sheik, 2021.
"""Surrogate gradient as defined in Weidel and Sheik, 2021.
https://arxiv.org/abs/2111.01456
"""
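
As a rough sketch of how such a surrogate gradient is evaluated (the normalisation and default parameters in sinabs may differ), a Gaussian surrogate assigns each neuron a pseudo-derivative based on how far its membrane potential is from the spike threshold:

import math
import torch

def gaussian_surrogate(v_mem, spike_threshold, mu=0.0, sigma=0.5):
    # Gaussian bump centred (via mu) on the distance from the spike threshold.
    x = v_mem - spike_threshold
    return torch.exp(-((x - mu) ** 2) / (2 * sigma**2)) / (sigma * math.sqrt(2 * math.pi))

grad = gaussian_surrogate(torch.tensor([0.2, 0.9, 1.5]), spike_threshold=1.0)
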
15 changes: 5 additions & 10 deletions sinabs/cnnutils.py
@@ -2,15 +2,13 @@


def conv_output_size(image_length: int, kernel_length: int, stride: int) -> int:
"""
Computes output dimension given input dimension, kernel size and stride,
assuming no padding, *per* dimension given
"""Computes output dimension given input dimension, kernel size and stride, assuming no
padding, *per* dimension given.
:param image_length: int image size on one dimension
:param kernel_length: int kernel_length size on one dimension
:param stride: int Stride size on one dimension
:return: int -- convolved output image size on one dimension
"""
try:
assert image_length >= kernel_length
@@ -24,8 +22,7 @@ def conv_output_size(image_length: int, kernel_length: int, stride: int) -> int:


def compute_same_padding_size(kernel_length: int) -> (int, int):
"""
Computes padding for 'same' padding *per* dimension given
"""Computes padding for 'same' padding *per* dimension given.
:param kernel_length: int Kernel size
:returns: Tuple -- (padStart, padStop) , padding on left/right or top/bottom
Expand All @@ -40,8 +37,7 @@ def compute_same_padding_size(kernel_length: int) -> (int, int):
def compute_padding(
kernel_shape: tuple, input_shape: tuple, mode="valid"
) -> (int, int, int, int):
"""
Computes padding for 'same' or 'valid' padding
"""Computes padding for 'same' or 'valid' padding.
:param kernel_shape: Kernel shape (height, width)
:param input_shape: Input shape (channels, height, width)
@@ -61,8 +57,7 @@ def compute_padding(


def infer_output_shape(torch_layer, input_shape: Tuple) -> Tuple:
"""
Compute the output dimensions given input dimensions
"""Compute the output dimensions given input dimensions.
:param torch_layer: a Torch layer
:param input_shape: the shape of the input tensor
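
The arithmetic behind conv_output_size is the usual valid-convolution formula; a short worked sketch, assuming no padding and an integer stride as in the docstring (the helper name here is illustrative):

def valid_conv_output_size(image_length: int, kernel_length: int, stride: int) -> int:
    # Output size per dimension for a convolution without padding.
    assert image_length >= kernel_length
    return (image_length - kernel_length) // stride + 1

valid_conv_output_size(28, 5, 1)  # 24
valid_conv_output_size(28, 3, 2)  # 13
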
9 changes: 3 additions & 6 deletions sinabs/conversion.py
@@ -6,8 +6,7 @@


def replace_module(model: nn.Module, source_class: type, mapper_fn: Callable):
"""
A utility function that returns a copy of the model, where specific layers are replaced with
"""A utility function that returns a copy of the model, where specific layers are replaced with
another type depending on the mapper function.
Parameters:
@@ -18,7 +17,7 @@ def replace_module(model: nn.Module, source_class: type, mapper_fn: Callable):
Returns:
A model copy with replaced modules according to mapper_fn.
"""

# Handle case where `model` is of type `source_class`
if type(model) == source_class:
return mapper_fn(model)
@@ -29,8 +28,7 @@ def replace_module(model: nn.Module, source_class: type, mapper_fn: Callable):


def replace_module_(model: nn.Sequential, source_class: type, mapper_fn: Callable):
"""
In-place version of replace_module that will step through modules that have children and
"""In-place version of replace_module that will step through modules that have children and
apply the mapper_fn.
Parameters:
@@ -56,4 +54,3 @@ def replace_module_(model: nn.Sequential, source_class: type, mapper_fn: Callabl

if type(module) == source_class:
setattr(model, name, mapper_fn(module))

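
A hedged usage sketch for replace_module, based only on the signature shown above (the import path and the identity mapper are illustrative; any callable mapping a matched module to its replacement works):

import torch.nn as nn
from sinabs.conversion import replace_module

ann = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
# Returns a copy of the model with every nn.ReLU replaced by whatever the mapper returns.
converted = replace_module(ann, source_class=nn.ReLU, mapper_fn=lambda relu: nn.Identity())
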
7 changes: 3 additions & 4 deletions sinabs/from_torch.py
@@ -27,10 +27,9 @@ def from_model(
backend=None,
kwargs_backend: dict = dict(),
):
"""
Converts a Torch model and returns a Sinabs network object.
The modules in the model are analyzed, and a copy is returned, with all
ReLUs and NeuromorphicReLUs turned into SpikingLayers.
"""Converts a Torch model and returns a Sinabs network object. The modules in the model are
analyzed, and a copy is returned, with all ReLUs and NeuromorphicReLUs turned into
SpikingLayers.
Parameters:
model: Torch model
17 changes: 7 additions & 10 deletions sinabs/layers/alif.py
@@ -11,8 +11,8 @@


class ALIF(StatefulLayer):
"""
Adaptive Leaky Integrate and Fire neuron layer that inherits from :class:`~sinabs.layers.StatefulLayer`.
"""Adaptive Leaky Integrate and Fire neuron layer that inherits from
:class:`~sinabs.layers.StatefulLayer`.
Pytorch implementation of a Long Short Term Memory SNN (LSNN) by Bellec et al., 2018:
https://papers.neurips.cc/paper/2018/hash/c203d8a151612acf12457e4d67635a95-Abstract.html
@@ -210,8 +210,8 @@ def _param_dict(self) -> dict:


class ALIFRecurrent(ALIF):
"""
Adaptive Leaky Integrate and Fire neuron layer with recurrent connections which inherits from :class:`~sinabs.layers.ALIF`.
"""Adaptive Leaky Integrate and Fire neuron layer with recurrent connections which inherits
from :class:`~sinabs.layers.ALIF`.
Pytorch implementation of a Long Short Term Memory SNN (LSNN) by Bellec et al., 2018:
https://papers.neurips.cc/paper/2018/hash/c203d8a151612acf12457e4d67635a95-Abstract.html
@@ -347,8 +347,7 @@ def forward(self, input_data: torch.Tensor):


class ALIFSqueeze(ALIF, SqueezeMixin):
"""
ALIF layer with 4-dimensional input (Batch*Time, Channel, Height, Width).
"""ALIF layer with 4-dimensional input (Batch*Time, Channel, Height, Width).
Same as parent ALIF class, only takes in squeezed 4D input (Batch*Time, Channel, Height, Width)
instead of 5D input (Batch, Time, Channel, Height, Width) in order to be compatible with
@@ -375,10 +374,8 @@ def __init__(
self.squeeze_init(batch_size, num_timesteps)

def forward(self, input_data: torch.Tensor) -> torch.Tensor:
"""
Forward call wrapper that will flatten the input to and
unflatten the output from the super class forward call.
"""
"""Forward call wrapper that will flatten the input to and unflatten the output from the
super class forward call."""
return self.squeeze_forward(input_data, super().forward)

@property
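
The squeeze convention the ALIFSqueeze docstring refers to amounts to a plain reshape between 5D and 4D tensors; a small pure-PyTorch sketch (not the layer's own code):

import torch

batch, time, channels, height, width = 2, 10, 3, 8, 8
x5 = torch.rand(batch, time, channels, height, width)
x4 = x5.flatten(start_dim=0, end_dim=1)     # (Batch*Time, Channel, Height, Width)
x5_again = x4.unflatten(0, (batch, time))   # back to (Batch, Time, Channel, Height, Width)
assert torch.equal(x5, x5_again)
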
6 changes: 2 additions & 4 deletions sinabs/layers/crop2d.py
@@ -7,8 +7,7 @@


class Cropping2dLayer(nn.Module):
"""
Crop input image by
"""Crop input image by.
Parameters:
cropping: ((top, bottom), (left, right))
@@ -37,8 +36,7 @@ def forward(self, binary_input):
return crop_out

def get_output_shape(self, input_shape: Tuple) -> Tuple:
"""
Returns the output dimensions
"""Returns the output dimensions.
Parameters:
input_shape: (channels, height, width)
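
For reference, cropping with ((top, bottom), (left, right)) as described in the Cropping2dLayer docstring corresponds to slicing the two spatial dimensions; an illustrative sketch, not the layer's implementation:

import torch

(top, bottom), (left, right) = (2, 2), (1, 3)
x = torch.rand(1, 3, 32, 32)  # (batch, channels, height, width)
cropped = x[..., top : x.shape[-2] - bottom, left : x.shape[-1] - right]
cropped.shape  # torch.Size([1, 3, 28, 28])
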
13 changes: 5 additions & 8 deletions sinabs/layers/exp_leak.py
@@ -7,8 +7,8 @@


class ExpLeak(LIF):
"""
Leaky Integrator layer which is a special case of :class:`~sinabs.layers.LIF` without activation function.
"""Leaky Integrator layer which is a special case of :class:`~sinabs.layers.LIF` without
activation function.
Neuron dynamics in discrete time:
Expand Down Expand Up @@ -61,8 +61,7 @@ def _param_dict(self) -> dict:


class ExpLeakSqueeze(ExpLeak, SqueezeMixin):
"""
ExpLeak layer with 4-dimensional input (Batch*Time, Channel, Height, Width).
"""ExpLeak layer with 4-dimensional input (Batch*Time, Channel, Height, Width).
Same as parent ExpLeak class, only takes in squeezed 4D input (Batch*Time, Channel, Height, Width)
instead of 5D input (Batch, Time, Channel, Height, Width) in order to be compatible with
@@ -82,10 +81,8 @@ def __init__(self, batch_size=None, num_timesteps=None, **kwargs):
self.squeeze_init(batch_size, num_timesteps)

def forward(self, input_data: torch.Tensor) -> torch.Tensor:
"""
Forward call wrapper that will flatten the input to and
unflatten the output from the super class forward call.
"""
"""Forward call wrapper that will flatten the input to and unflatten the output from the
super class forward call."""
return self.squeeze_forward(input_data, super().forward)

@property
(Diffs for the remaining changed files are not shown here.)