Skip to content

Commit

Permalink
Replace from torch.nn.functional import ... with `import torch.nn.f…
Browse files Browse the repository at this point in the history
…unctional as F`

These two expressions are almost identical, but it turns out that the
latter is (currently) necessary to visualize models using torchlens.
See johnmarktaylor91/torchlens#18.  While this is a pretty minor
benefit, it's also a pretty minor change.

Signed-off-by: Kale Kundert <[email protected]>
  • Loading branch information
kalekundert committed Nov 19, 2023
1 parent fec08a3 commit ae936d2
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 17 deletions.
4 changes: 2 additions & 2 deletions escnn/nn/modules/conv/r2_transposed_convolution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

from torch.nn.functional import conv_transpose2d
import torch.nn.functional as F

import escnn.nn
from escnn.nn import FieldType
Expand Down Expand Up @@ -127,7 +127,7 @@ def forward(self, input: GeometricTensor):
_filter, _bias = self.expand_parameters()

# use it for convolution and return the result
output = conv_transpose2d(
output = F.conv_transpose2d(
input.tensor, _filter,
padding=self.padding,
output_padding=self.output_padding,
Expand Down
6 changes: 3 additions & 3 deletions escnn/nn/modules/conv/r2convolution.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from torch.nn.functional import conv2d, pad
import torch.nn.functional as F

from escnn.nn import FieldType
from escnn.nn import GeometricTensor
Expand Down Expand Up @@ -214,14 +214,14 @@ def forward(self, input: GeometricTensor):
# use it for convolution and return the result

if self.padding_mode == 'zeros':
output = conv2d(input.tensor, _filter,
output = F.conv2d(input.tensor, _filter,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
bias=_bias)
else:
output = conv2d(pad(input.tensor, self._reversed_padding_repeated_twice, self.padding_mode),
output = F.conv2d(F.pad(input.tensor, self._reversed_padding_repeated_twice, self.padding_mode),
_filter,
stride=self.stride,
dilation=self.dilation,
Expand Down
4 changes: 2 additions & 2 deletions escnn/nn/modules/conv/r3_transposed_convolution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import gc

from torch.nn.functional import conv_transpose3d
import torch.nn.functional as F

import escnn.nn
from escnn.nn import FieldType
Expand Down Expand Up @@ -128,7 +128,7 @@ def forward(self, input: GeometricTensor):


# use it for convolution and return the result
output = conv_transpose3d(
output = F.conv_transpose3d(
input.tensor, _filter,
padding=self.padding,
output_padding=self.output_padding,
Expand Down
6 changes: 3 additions & 3 deletions escnn/nn/modules/conv/r3convolution.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from torch.nn.functional import conv3d, pad
import torch.nn.functional as F

import escnn.nn
from escnn.nn import FieldType
Expand Down Expand Up @@ -207,14 +207,14 @@ def forward(self, input: GeometricTensor):

# use it for convolution and return the result
if self.padding_mode == 'zeros':
output = conv3d(input.tensor, _filter,
output = F.conv3d(input.tensor, _filter,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
bias=_bias)
else:
output = conv3d(pad(input.tensor, self._reversed_padding_repeated_twice, self.padding_mode),
output = F.conv3d(F.pad(input.tensor, self._reversed_padding_repeated_twice, self.padding_mode),
_filter,
stride=self.stride,
dilation=self.dilation,
Expand Down
4 changes: 2 additions & 2 deletions escnn/nn/modules/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from escnn.nn.modules.basismanager import BlocksBasisExpansion

from torch.nn import Parameter
from torch.nn.functional import linear
import torch.nn.functional as F
import torch
import numpy as np

Expand Down Expand Up @@ -202,7 +202,7 @@ def forward(self, input: GeometricTensor):
# retrieve the matrix and the bias
_matrix, _bias = self.expand_parameters()

output = linear(input.tensor, _matrix, bias=_bias)
output = F.linear(input.tensor, _matrix, bias=_bias)

return GeometricTensor(output, self.out_type, input.coords)

Expand Down
1 change: 0 additions & 1 deletion escnn/nn/modules/nonlinearities/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from escnn.nn.modules.basismanager import BlocksBasisExpansion

from torch.nn import Parameter
from torch.nn.functional import linear

from typing import List, Tuple, Any

Expand Down
8 changes: 4 additions & 4 deletions escnn/nn/modules/rdupsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import math

from torch.nn.functional import interpolate
import torch.nn.functional as F

__all__ = ["R2Upsampling", "R3Upsampling"]

Expand Down Expand Up @@ -105,12 +105,12 @@ def forward(self, input: GeometricTensor):
assert len(input.shape) == 2 + self.d, (input.shape, self.d)

if self._align_corners is None:
output = interpolate(input.tensor,
output = F.interpolate(input.tensor,
size=self._size,
scale_factor=self._scale_factor,
mode=self._mode)
else:
output = interpolate(input.tensor,
output = F.interpolate(input.tensor,
size=self._size,
scale_factor=self._scale_factor,
mode=self._mode,
Expand Down Expand Up @@ -413,4 +413,4 @@ def check_equivariance(self, atol: float = 0.1, rtol: float = 0.1):

errors.append((el, errs.mean()))

return errors
return errors

0 comments on commit ae936d2

Please sign in to comment.