-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Henry Isaacson
committed
Aug 19, 2024
1 parent
a5edef6
commit 93721ae
Showing
21 changed files
with
1,706 additions
and
1,475 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import torch | ||
from torch import Tensor | ||
|
||
|
||
def iota(shape: tuple[int, ...], dim: int = 0, **kwargs) -> Tensor: | ||
r"""Generate a tensor with a specified shape where elements along the given dimension | ||
are sequential integers starting from 0. | ||
Parameters | ||
---------- | ||
shape : tuple[int, ...] | ||
The shape of the resulting tensor. | ||
dim : int, optional | ||
The dimension along which to vary the values (default is 0). | ||
Returns | ||
------- | ||
Tensor | ||
A tensor of the specified shape with sequential integers along the specified dimension. | ||
Raises | ||
------ | ||
IndexError | ||
If `dim` is out of the range of `shape`. | ||
""" | ||
dimensions = [] | ||
|
||
for index, _ in enumerate(shape): | ||
if index != dim: | ||
dimension = 1 | ||
|
||
else: | ||
dimension = shape[index] | ||
|
||
dimensions = [*dimensions, dimension] | ||
|
||
return torch.arange(shape[dim], **kwargs).view(*dimensions).expand(*shape) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,31 @@ | ||
from torch import Tensor | ||
|
||
|
||
def pairwise_displacement(Ra: Tensor, Rb: Tensor) -> Tensor: | ||
def pairwise_displacement(input: Tensor, other: Tensor) -> Tensor: | ||
r"""Compute a matrix of pairwise displacements given two sets of positions. | ||
Parameters | ||
---------- | ||
Ra : Tensor | ||
input : Tensor | ||
Vector of positions | ||
Rb : Tensor | ||
other : Tensor | ||
Vector of positions | ||
Returns: | ||
Tensor(shape=[spatial_dim] | ||
Matrix of displacements | ||
------- | ||
output : Tensor, shape [spatial_dimensions] | ||
Matrix of displacements | ||
""" | ||
if len(Ra.shape) != 1: | ||
msg = ( | ||
if len(input.shape) != 1: | ||
message = ( | ||
"Can only compute displacements between vectors. To compute " | ||
"displacements between sets of vectors use vmap or TODO." | ||
) | ||
raise ValueError(msg) | ||
raise ValueError(message) | ||
|
||
if Ra.shape != Rb.shape: | ||
msg = "Can only compute displacement between vectors of equal dimension." | ||
raise ValueError(msg) | ||
if input.shape != other.shape: | ||
message = "Can only compute displacement between vectors of equal dimension." | ||
raise ValueError(message) | ||
|
||
return Ra - Rb | ||
return input - other |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import math | ||
from typing import Optional | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
|
||
def segment_sum( | ||
input: Tensor, | ||
indexes: Tensor, | ||
n: Optional[int] = None, | ||
**kwargs, | ||
) -> Tensor: | ||
r"""Computes the sum of segments of a tensor along the first dimension. | ||
Parameters | ||
---------- | ||
input : Tensor | ||
A tensor containing the input values to be summed. | ||
indexes : Tensor | ||
A 1D tensor containing the segment indexes for summation. | ||
Should have the same length as the first dimension of the `input` tensor. | ||
n : Optional[int], optional | ||
The number of segments, by default `n` is set to `max(indexes) + 1`. | ||
Returns | ||
------- | ||
Tensor | ||
A tensor where each entry contains the sum of the corresponding segment | ||
from the `input` tensor. | ||
""" | ||
if indexes.ndim == 1: | ||
indexes = torch.repeat_interleave(indexes, math.prod([*input.shape[1:]])).view( | ||
*[indexes.shape[0], *input.shape[1:]] | ||
) | ||
|
||
if input.size(0) != indexes.size(0): | ||
raise ValueError( | ||
"The length of the indexes tensor must match the size of the first dimension of the input tensor." | ||
) | ||
|
||
if n is None: | ||
n = indexes.max().item() + 1 | ||
|
||
valid_mask = indexes < n | ||
valid_indexes = indexes[valid_mask] | ||
valid_input = input[valid_mask] | ||
|
||
output = torch.zeros(n, *input.shape[1:], device=input.device) | ||
|
||
return output.scatter_add(0, valid_indexes, valid_input.to(torch.float32)).to( | ||
**kwargs | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from typing import Callable, Any | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
|
||
def _square_distance(input: Tensor) -> Tensor: | ||
"""Computes square distances. | ||
Args: | ||
input: Matrix of displacements; `Tensor(shape=[..., spatial_dim])`. | ||
Returns: | ||
Matrix of squared distances; `Tensor(shape=[...])`. | ||
""" | ||
return torch.sum(input**2, dim=-1) | ||
|
||
|
||
def _safe_mask( | ||
mask: Tensor, fn: Callable, operand: Tensor, placeholder: Any = 0 | ||
) -> Tensor: | ||
r"""Applies a function to elements of a tensor where a mask is True, and replaces elements where the mask is False with a placeholder. | ||
Parameters | ||
---------- | ||
mask : Tensor | ||
A boolean tensor indicating which elements to apply the function to. | ||
fn : Callable[[Tensor], Tensor] | ||
The function to apply to the masked elements. | ||
operand : Tensor | ||
The tensor to apply the function to. | ||
placeholder : Any, optional | ||
The value to use for elements where the mask is False (default is 0). | ||
Returns | ||
------- | ||
Tensor | ||
A tensor with the function applied to the masked elements and the placeholder value elsewhere. | ||
""" | ||
masked = torch.where(mask, operand, torch.tensor(0, dtype=operand.dtype)) | ||
|
||
return torch.where(mask, fn(masked), torch.tensor(placeholder, dtype=operand.dtype)) | ||
|
||
|
||
def square_distance(dR: Tensor) -> Tensor: | ||
r"""Computes distances. | ||
Args: | ||
dR: Matrix of displacements; `Tensor(shape=[..., spatial_dim])`. | ||
Returns: | ||
Matrix of distances; `Tensor(shape=[...])`. | ||
""" | ||
return _safe_mask(_square_distance(dR) > 0, torch.sqrt, _square_distance(dR)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.