Skip to content

Commit

Permalink
beignet.func.space
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed May 10, 2024
1 parent 1ac578b commit 2b97ae2
Showing 1 changed file with 35 additions and 37 deletions.
72 changes: 35 additions & 37 deletions src/beignet/func/_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def shift_fn(input: Tensor, other: Tensor, **_) -> Tensor:
return displacement_fn, shift_fn

if parallelepiped:
inverse_transformation = invert_transform(dimensions)
inverted_transform = invert_transform(dimensions)

if normalized:

Expand All @@ -147,15 +147,15 @@ def displacement_fn(
perturbation: Tensor | None = None,
**kwargs,
) -> Tensor:
_transformation = dimensions
_transform = dimensions

_inverse_transformation = inverse_transformation
_inverted_transform = inverted_transform

if "transform" in kwargs:
_transformation = kwargs["transform"]
_transform = kwargs["transform"]

if "updated_transformation" in kwargs:
_transformation = kwargs["updated_transformation"]
_transform = kwargs["updated_transformation"]

if len(input.shape) != 1:
raise ValueError
Expand All @@ -165,7 +165,7 @@ def displacement_fn(

displacement = apply_transform(
torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5,
_transformation,
_transform,
)

if perturbation is not None:
Expand All @@ -179,38 +179,36 @@ def u(input: Tensor, other: Tensor) -> Tensor:
return torch.remainder(input + other, 1.0)

def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor:
_transformation = dimensions
_transform = dimensions

_inverse_transformation = inverse_transformation
_inverted_transform = inverted_transform

if "transform" in kwargs:
_transformation = kwargs["transform"]
_transform = kwargs["transform"]

_inverse_transformation = invert_transform(_transformation)
_inverted_transform = invert_transform(_transform)

if "updated_transformation" in kwargs:
_transformation = kwargs["updated_transformation"]
_transform = kwargs["updated_transformation"]

return u(input, apply_transform(other, _inverse_transformation))
return u(input, apply_transform(other, _inverted_transform))

return displacement_fn, shift_fn

def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor:
_transformation = dimensions
_transform = dimensions

_inverse_transformation = inverse_transformation
_inverted_transform = inverted_transform

if "transform" in kwargs:
_transformation = kwargs["transform"]
_transform = kwargs["transform"]

_inverse_transformation = invert_transform(
_transformation,
)
_inverted_transform = invert_transform(_transform)

if "updated_transformation" in kwargs:
_transformation = kwargs["updated_transformation"]
_transform = kwargs["updated_transformation"]

return input + apply_transform(other, _inverse_transformation)
return input + apply_transform(other, _inverted_transform)

return displacement_fn, shift_fn

Expand All @@ -221,20 +219,20 @@ def displacement_fn(
perturbation: Tensor | None = None,
**kwargs,
) -> Tensor:
_transformation = dimensions
_transform = dimensions

_inverse_transformation = inverse_transformation
_inverted_transform = inverted_transform

if "transform" in kwargs:
_transformation = kwargs["transform"]
_transform = kwargs["transform"]

_inverse_transformation = invert_transform(_transformation)
_inverted_transform = invert_transform(_transform)

if "updated_transformation" in kwargs:
_transformation = kwargs["updated_transformation"]
_transform = kwargs["updated_transformation"]

input = apply_transform(input, _inverse_transformation)
other = apply_transform(other, _inverse_transformation)
input = apply_transform(input, _inverted_transform)
other = apply_transform(other, _inverted_transform)

if len(input.shape) != 1:
raise ValueError
Expand All @@ -244,7 +242,7 @@ def displacement_fn(

displacement = apply_transform(
torch.remainder(input - other + 1.0 * 0.5, 1.0) - 1.0 * 0.5,
_transformation,
_transform,
)

if perturbation is not None:
Expand All @@ -258,26 +256,26 @@ def u(input: Tensor, other: Tensor) -> Tensor:
return torch.remainder(input + other, 1.0)

def shift_fn(input: Tensor, other: Tensor, **kwargs) -> Tensor:
_transformation = dimensions
_transform = dimensions

_inverse_transformation = inverse_transformation
_inverted_transform = inverted_transform

if "transform" in kwargs:
_transformation = kwargs["transform"]
_transform = kwargs["transform"]

_inverse_transformation = invert_transform(
_transformation,
_inverted_transform = invert_transform(
_transform,
)

if "updated_transformation" in kwargs:
_transformation = kwargs["updated_transformation"]
_transform = kwargs["updated_transformation"]

return apply_transform(
u(
apply_transform(_inverse_transformation, input),
apply_transform(_inverse_transformation, other),
apply_transform(_inverted_transform, input),
apply_transform(_inverted_transform, other),
),
_transformation,
_transform,
)

return displacement_fn, shift_fn
Expand Down

0 comments on commit 2b97ae2

Please sign in to comment.