From 2b97ae22334e8a5257407f49227f65f3d1ebb731 Mon Sep 17 00:00:00 2001 From: Allen Goodman Date: Fri, 10 May 2024 16:21:16 -0400 Subject: [PATCH] beignet.func.space --- src/beignet/func/_space.py | 72 ++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 37 deletions(-) diff --git a/src/beignet/func/_space.py b/src/beignet/func/_space.py index f764d5f413..52e92658e6 100644 --- a/src/beignet/func/_space.py +++ b/src/beignet/func/_space.py @@ -136,7 +136,7 @@ def shift_fn(input: Tensor, other: Tensor, **_) -> Tensor: return displacement_fn, shift_fn if parallelepiped: - inverse_transformation = invert_transform(dimensions) + inverted_transform = invert_transform(dimensions) if normalized: @@ -147,15 +147,15 @@ def displacement_fn( perturbation: Tensor | None = None, **kwargs, ) -> Tensor: - _transformation = dimensions + _transform = dimensions - _inverse_transformation = inverse_transformation + _inverted_transform = inverted_transform if "transform" in kwargs: - _transformation = kwargs["transform"] + _transform = kwargs["transform"] if "updated_transformation" in kwargs: - _transformation = kwargs["updated_transformation"] + _transform = kwargs["updated_transformation"] if len(input.shape) != 1: raise ValueError @@ -165,7 +165,7 @@ def displacement_fn( displacement = apply_transform( torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5, - _transformation, + _transform, ) if perturbation is not None: @@ -179,38 +179,36 @@ def u(input: Tensor, other: Tensor) -> Tensor: return torch.remainder(input + other, 1.0) def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: - _transformation = dimensions + _transform = dimensions - _inverse_transformation = inverse_transformation + _inverted_transform = inverted_transform if "transform" in kwargs: - _transformation = kwargs["transform"] + _transform = kwargs["transform"] - _inverse_transformation = invert_transform(_transformation) + _inverted_transform = invert_transform(_transform) if "updated_transformation" in kwargs: - _transformation = kwargs["updated_transformation"] + _transform = kwargs["updated_transformation"] - return u(input, apply_transform(other, _inverse_transformation)) + return u(input, apply_transform(other, _inverted_transform)) return displacement_fn, shift_fn def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: - _transformation = dimensions + _transform = dimensions - _inverse_transformation = inverse_transformation + _inverted_transform = inverted_transform if "transform" in kwargs: - _transformation = kwargs["transform"] + _transform = kwargs["transform"] - _inverse_transformation = invert_transform( - _transformation, - ) + _inverted_transform = invert_transform(_transform) if "updated_transformation" in kwargs: - _transformation = kwargs["updated_transformation"] + _transform = kwargs["updated_transformation"] - return input + apply_transform(other, _inverse_transformation) + return input + apply_transform(other, _inverted_transform) return displacement_fn, shift_fn @@ -221,20 +219,20 @@ def displacement_fn( perturbation: Tensor | None = None, **kwargs, ) -> Tensor: - _transformation = dimensions + _transform = dimensions - _inverse_transformation = inverse_transformation + _inverted_transform = inverted_transform if "transform" in kwargs: - _transformation = kwargs["transform"] + _transform = kwargs["transform"] - _inverse_transformation = invert_transform(_transformation) + _inverted_transform = invert_transform(_transform) if "updated_transformation" in kwargs: - _transformation = kwargs["updated_transformation"] + _transform = kwargs["updated_transformation"] - input = apply_transform(input, _inverse_transformation) - other = apply_transform(other, _inverse_transformation) + input = apply_transform(input, _inverted_transform) + other = apply_transform(other, _inverted_transform) if len(input.shape) != 1: raise ValueError @@ -244,7 +242,7 @@ def displacement_fn( displacement = apply_transform( torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5, - _transformation, + _transform, ) if perturbation is not None: @@ -258,26 +256,26 @@ def u(input: Tensor, other: Tensor) -> Tensor: return torch.remainder(input + other, 1.0) def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor: - _transformation = dimensions + _transform = dimensions - _inverse_transformation = inverse_transformation + _inverted_transform = inverted_transform if "transform" in kwargs: - _transformation = kwargs["transform"] + _transform = kwargs["transform"] - _inverse_transformation = invert_transform( - _transformation, + _inverted_transform = invert_transform( + _transform, ) if "updated_transformation" in kwargs: - _transformation = kwargs["updated_transformation"] + _transform = kwargs["updated_transformation"] return apply_transform( u( - apply_transform(_inverse_transformation, input), - apply_transform(_inverse_transformation, other), + apply_transform(_inverted_transform, input), + apply_transform(_inverted_transform, other), ), - _transformation, + _transform, ) return displacement_fn, shift_fn