-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
7 changed files
with
200 additions
and
159 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import torch | ||
from torch import Tensor | ||
from torch.autograd import Function | ||
|
||
|
||
def _apply_transform(input: Tensor, transform: Tensor) -> Tensor: | ||
""" | ||
Applies an affine transformation to the position vector. | ||
Parameters | ||
---------- | ||
input : Tensor | ||
Position, must have the shape `(..., dimension)`. | ||
transform : Tensor | ||
The affine transformation matrix, must be a scalar, a vector, or a | ||
matrix with the shape `(dimension, dimension)`. | ||
Returns | ||
------- | ||
Tensor | ||
Affine transformed position vector, has the same shape as the | ||
position vector. | ||
""" | ||
if transform.ndim == 0: | ||
return input * transform | ||
|
||
indices = [chr(ord("a") + index) for index in range(input.ndim - 1)] | ||
|
||
indices = "".join(indices) | ||
|
||
if transform.ndim == 1: | ||
return torch.einsum( | ||
f"i,{indices}i->{indices}i", | ||
transform, | ||
input, | ||
) | ||
|
||
if transform.ndim == 2: | ||
return torch.einsum( | ||
f"ij,{indices}j->{indices}i", | ||
transform, | ||
input, | ||
) | ||
|
||
raise ValueError | ||
|
||
|
||
class _ApplyTransform(Function): | ||
generate_vmap_rule = True | ||
|
||
@staticmethod | ||
def forward(transform: Tensor, position: Tensor) -> Tensor: | ||
""" | ||
Return affine transformed position. | ||
Parameters | ||
---------- | ||
transform : Tensor | ||
Affine transformation matrix, must have shape | ||
`(dimension, dimension)`. | ||
position : Tensor | ||
Position, must have shape `(..., dimension)`. | ||
Returns | ||
------- | ||
Tensor | ||
Affine transformed position of shape `(..., dimension)`. | ||
""" | ||
return _apply_transform(position, transform) | ||
|
||
@staticmethod | ||
def setup_context(ctx, inputs, output): | ||
transformation, position = inputs | ||
|
||
ctx.save_for_backward(transformation, position, output) | ||
|
||
@staticmethod | ||
def jvp(ctx, grad_transform: Tensor, grad_position: Tensor) -> (Tensor, Tensor): | ||
transformation, position, _ = ctx.saved_tensors | ||
|
||
output = _apply_transform(position, transformation) | ||
|
||
grad_output = grad_position + _apply_transform(position, grad_transform) | ||
|
||
return output, grad_output | ||
|
||
@staticmethod | ||
def backward(ctx, grad_output: Tensor) -> (Tensor, Tensor): | ||
_, _, output = ctx.saved_tensors | ||
|
||
return output, grad_output | ||
|
||
|
||
def apply_transform(input: Tensor, transform: Tensor) -> Tensor: | ||
""" | ||
Return affine transformed position. | ||
Parameters | ||
---------- | ||
input : Tensor | ||
Position, must have shape `(..., dimension)`. | ||
transform : Tensor | ||
Affine transformation matrix, must have shape | ||
`(dimension, dimension)`. | ||
Returns | ||
------- | ||
Tensor | ||
Affine transformed position of shape `(..., dimension)`. | ||
""" | ||
return _ApplyTransform.apply(transform, input) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import torch | ||
from torch import Tensor | ||
|
||
|
||
def invert_transform(transform: Tensor) -> Tensor: | ||
""" | ||
Calculates the inverse of an affine transformation matrix. | ||
Parameters | ||
---------- | ||
transform : Tensor | ||
The affine transformation matrix to be inverted. | ||
Returns | ||
------- | ||
Tensor | ||
The inverse of the given affine transformation matrix. | ||
""" | ||
if transform.ndim in {0, 1}: | ||
return 1.0 / transform | ||
|
||
if transform.ndim == 2: | ||
return torch.linalg.inv(transform) | ||
|
||
raise ValueError |
Oops, something went wrong.