From ae936d2cfe7776385901332613532b1abb1f02cf Mon Sep 17 00:00:00 2001
From: Kale Kundert
Date: Sun, 19 Nov 2023 09:16:33 -0500
Subject: [PATCH] Replace `from torch.nn.functional import ...` with
 `import torch.nn.functional as F`

These two expressions are almost identical, but it turns out that the
latter is (currently) necessary to visualize models using torchlens.
See johnmarktaylor91/torchlens#18.

While this is a pretty minor benefit, it's also a pretty minor change.

Signed-off-by: Kale Kundert
---
 escnn/nn/modules/conv/r2_transposed_convolution.py | 4 ++--
 escnn/nn/modules/conv/r2convolution.py             | 6 +++---
 escnn/nn/modules/conv/r3_transposed_convolution.py | 4 ++--
 escnn/nn/modules/conv/r3convolution.py             | 6 +++---
 escnn/nn/modules/linear.py                         | 4 ++--
 escnn/nn/modules/nonlinearities/tensor.py          | 1 -
 escnn/nn/modules/rdupsampling.py                   | 8 ++++----
 7 files changed, 16 insertions(+), 17 deletions(-)

diff --git a/escnn/nn/modules/conv/r2_transposed_convolution.py b/escnn/nn/modules/conv/r2_transposed_convolution.py
index ed85aff4..2a779cec 100644
--- a/escnn/nn/modules/conv/r2_transposed_convolution.py
+++ b/escnn/nn/modules/conv/r2_transposed_convolution.py
@@ -1,5 +1,5 @@
 
-from torch.nn.functional import conv_transpose2d
+import torch.nn.functional as F
 
 import escnn.nn
 from escnn.nn import FieldType
@@ -127,7 +127,7 @@ def forward(self, input: GeometricTensor):
         _filter, _bias = self.expand_parameters()
 
         # use it for convolution and return the result
-        output = conv_transpose2d(
+        output = F.conv_transpose2d(
                 input.tensor, _filter,
                 padding=self.padding,
                 output_padding=self.output_padding,
diff --git a/escnn/nn/modules/conv/r2convolution.py b/escnn/nn/modules/conv/r2convolution.py
index 3c0c3efa..8b344718 100644
--- a/escnn/nn/modules/conv/r2convolution.py
+++ b/escnn/nn/modules/conv/r2convolution.py
@@ -1,4 +1,4 @@
-from torch.nn.functional import conv2d, pad
+import torch.nn.functional as F
 
 from escnn.nn import FieldType
 from escnn.nn import GeometricTensor
@@ -214,14 +214,14 @@ def forward(self, input: GeometricTensor):
 
         # use it for convolution and return the result
         if self.padding_mode == 'zeros':
-            output = conv2d(input.tensor, _filter,
+            output = F.conv2d(input.tensor, _filter,
                             stride=self.stride,
                             padding=self.padding,
                             dilation=self.dilation,
                             groups=self.groups,
                             bias=_bias)
         else:
-            output = conv2d(pad(input.tensor, self._reversed_padding_repeated_twice, self.padding_mode),
+            output = F.conv2d(F.pad(input.tensor, self._reversed_padding_repeated_twice, self.padding_mode),
                             _filter,
                             stride=self.stride,
                             dilation=self.dilation,
diff --git a/escnn/nn/modules/conv/r3_transposed_convolution.py b/escnn/nn/modules/conv/r3_transposed_convolution.py
index c048152c..11496e7b 100644
--- a/escnn/nn/modules/conv/r3_transposed_convolution.py
+++ b/escnn/nn/modules/conv/r3_transposed_convolution.py
@@ -1,6 +1,6 @@
 import gc
 
-from torch.nn.functional import conv_transpose3d
+import torch.nn.functional as F
 
 import escnn.nn
 from escnn.nn import FieldType
@@ -128,7 +128,7 @@ def forward(self, input: GeometricTensor):
 
 
         # use it for convolution and return the result
-        output = conv_transpose3d(
+        output = F.conv_transpose3d(
                 input.tensor, _filter,
                 padding=self.padding,
                 output_padding=self.output_padding,
diff --git a/escnn/nn/modules/conv/r3convolution.py b/escnn/nn/modules/conv/r3convolution.py
index 4ae314b7..e0a79e91 100644
--- a/escnn/nn/modules/conv/r3convolution.py
+++ b/escnn/nn/modules/conv/r3convolution.py
@@ -1,4 +1,4 @@
-from torch.nn.functional import conv3d, pad
+import torch.nn.functional as F
 
 import escnn.nn
 from escnn.nn import FieldType
@@ -207,14 +207,14 @@ def forward(self, input: GeometricTensor):
 
         # use it for convolution and return the result
         if self.padding_mode == 'zeros':
-            output = conv3d(input.tensor, _filter,
+            output = F.conv3d(input.tensor, _filter,
                             stride=self.stride,
                             padding=self.padding,
                             dilation=self.dilation,
                             groups=self.groups,
                             bias=_bias)
         else:
-            output = conv3d(pad(input.tensor, self._reversed_padding_repeated_twice, self.padding_mode),
+            output = F.conv3d(F.pad(input.tensor, self._reversed_padding_repeated_twice, self.padding_mode),
                             _filter,
                             stride=self.stride,
                             dilation=self.dilation,
diff --git a/escnn/nn/modules/linear.py b/escnn/nn/modules/linear.py
index 09a79838..52508959 100644
--- a/escnn/nn/modules/linear.py
+++ b/escnn/nn/modules/linear.py
@@ -9,7 +9,7 @@
 from escnn.nn.modules.basismanager import BlocksBasisExpansion
 
 from torch.nn import Parameter
-from torch.nn.functional import linear
+import torch.nn.functional as F
 
 import torch
 import numpy as np
@@ -202,7 +202,7 @@ def forward(self, input: GeometricTensor):
         # retrieve the matrix and the bias
         _matrix, _bias = self.expand_parameters()
 
-        output = linear(input.tensor, _matrix, bias=_bias)
+        output = F.linear(input.tensor, _matrix, bias=_bias)
 
         return GeometricTensor(output, self.out_type, input.coords)
 
diff --git a/escnn/nn/modules/nonlinearities/tensor.py b/escnn/nn/modules/nonlinearities/tensor.py
index 249d01cc..acc26775 100644
--- a/escnn/nn/modules/nonlinearities/tensor.py
+++ b/escnn/nn/modules/nonlinearities/tensor.py
@@ -12,7 +12,6 @@
 from escnn.nn.modules.basismanager import BlocksBasisExpansion
 
 from torch.nn import Parameter
-from torch.nn.functional import linear
 
 from typing import List, Tuple, Any
 
diff --git a/escnn/nn/modules/rdupsampling.py b/escnn/nn/modules/rdupsampling.py
index c92f97eb..44ec2459 100644
--- a/escnn/nn/modules/rdupsampling.py
+++ b/escnn/nn/modules/rdupsampling.py
@@ -15,7 +15,7 @@
 
 import math
 
-from torch.nn.functional import interpolate
+import torch.nn.functional as F
 
 
 __all__ = ["R2Upsampling", "R3Upsampling"]
@@ -105,12 +105,12 @@ def forward(self, input: GeometricTensor):
         assert len(input.shape) == 2 + self.d, (input.shape, self.d)
 
         if self._align_corners is None:
-            output = interpolate(input.tensor,
+            output = F.interpolate(input.tensor,
                                  size=self._size,
                                  scale_factor=self._scale_factor,
                                  mode=self._mode)
         else:
-            output = interpolate(input.tensor,
+            output = F.interpolate(input.tensor,
                                  size=self._size,
                                  scale_factor=self._scale_factor,
                                  mode=self._mode,
@@ -413,4 +413,4 @@ def check_equivariance(self, atol: float = 0.1, rtol: float = 0.1):
 
             errors.append((el, errs.mean()))
 
-        return errors
\ No newline at end of file
+        return errors
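
For context beyond the patch itself, here is a minimal, self-contained sketch of why the two import styles can behave differently under a tracing tool. The assumption (based on the linked issue, not verified here) is that torchlens temporarily monkey-patches the functions exposed by torch.nn.functional: a name bound with `from torch.nn.functional import ...` keeps pointing at the original function object, while an `F.conv2d`-style attribute lookup resolves through the module at call time and therefore picks up the patched wrapper.

import torch
import torch.nn.functional as F
from torch.nn.functional import relu   # binding captured once, at import time

# Hand-rolled stand-in for a tracer's patched wrapper (hypothetical;
# torchlens' real mechanism may differ, see johnmarktaylor91/torchlens#18).
_original_relu = F.relu
F.relu = lambda x: _original_relu(x) + 1

x = torch.zeros(3)
print(F.relu(x))   # module-attribute lookup hits the patched wrapper: tensor([1., 1., 1.])
print(relu(x))     # direct binding still calls the original function: tensor([0., 0., 0.])

F.relu = _original_relu   # undo the patch

The same reasoning applies to conv2d, conv3d, conv_transpose2d, conv_transpose3d, pad, linear and interpolate, which are the functions this patch switches over to module-qualified calls.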