Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Henry Isaacson committed Aug 19, 2024
1 parent a5edef6 commit 93721ae
Show file tree
Hide file tree
Showing 21 changed files with 1,706 additions and 1,475 deletions.
4 changes: 4 additions & 0 deletions src/beignet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@
)
from ._probabilists_hermite_polynomial_x import probabilists_hermite_polynomial_x
from ._probabilists_hermite_polynomial_zero import probabilists_hermite_polynomial_zero
from ._iota import iota
from ._pairwise_displacement import pairwise_displacement
from ._quaternion_identity import quaternion_identity
from ._quaternion_magnitude import quaternion_magnitude
from ._quaternion_mean import quaternion_mean
Expand Down Expand Up @@ -355,6 +357,8 @@
from ._trim_probabilists_hermite_polynomial_coefficients import (
trim_probabilists_hermite_polynomial_coefficients,
)
from ._segment_sum import segment_sum
from ._square_distance import square_distance
from .special import error_erf, error_erfc

__all__ = [
Expand Down
37 changes: 37 additions & 0 deletions src/beignet/_iota.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch
from torch import Tensor


def iota(shape: tuple[int, ...], dim: int = 0, **kwargs) -> Tensor:
r"""Generate a tensor with a specified shape where elements along the given dimension
are sequential integers starting from 0.
Parameters
----------
shape : tuple[int, ...]
The shape of the resulting tensor.
dim : int, optional
The dimension along which to vary the values (default is 0).
Returns
-------
Tensor
A tensor of the specified shape with sequential integers along the specified dimension.
Raises
------
IndexError
If `dim` is out of the range of `shape`.
"""
dimensions = []

for index, _ in enumerate(shape):
if index != dim:
dimension = 1

else:
dimension = shape[index]

dimensions = [*dimensions, dimension]

return torch.arange(shape[dim], **kwargs).view(*dimensions).expand(*shape)
26 changes: 14 additions & 12 deletions src/beignet/_pairwise_displacement.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
from torch import Tensor


def pairwise_displacement(Ra: Tensor, Rb: Tensor) -> Tensor:
def pairwise_displacement(input: Tensor, other: Tensor) -> Tensor:
r"""Compute a matrix of pairwise displacements given two sets of positions.
Parameters
----------
Ra : Tensor
input : Tensor
Vector of positions
Rb : Tensor
other : Tensor
Vector of positions
Returns:
Tensor(shape=[spatial_dim]
Matrix of displacements
-------
output : Tensor, shape [spatial_dimensions]
Matrix of displacements
"""
if len(Ra.shape) != 1:
msg = (
if len(input.shape) != 1:
message = (
"Can only compute displacements between vectors. To compute "
"displacements between sets of vectors use vmap or TODO."
)
raise ValueError(msg)
raise ValueError(message)

if Ra.shape != Rb.shape:
msg = "Can only compute displacement between vectors of equal dimension."
raise ValueError(msg)
if input.shape != other.shape:
message = "Can only compute displacement between vectors of equal dimension."
raise ValueError(message)

return Ra - Rb
return input - other
18 changes: 8 additions & 10 deletions src/beignet/_periodic_displacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,23 @@
from torch import Tensor


def periodic_displacement(box: float | Tensor, dR: Tensor) -> Tensor:
def periodic_displacement(input: Tensor, position: Tensor) -> Tensor:
r"""Wraps displacement vectors into a hypercube.
Parameters
Parameters:
----------
box : float or Tensor
Specification of hypercube size. Either:
(a) float if all sides have equal length.
(a) scalar if all sides have equal length.
(b) Tensor of shape (spatial_dim,) if sides have different lengths.
dR : Tensor
Matrix of displacements with shape (..., spatial_dim).
Returns
Returns:
-------
Tensor
output : Tensor, shape=(...)
Matrix of wrapped displacements with shape (..., spatial_dim).
"""
distances = (
torch.remainder(dR + box * torch.tensor(0.5, dtype=torch.float32), box)
- torch.tensor(0.5, dtype=torch.float32) * box
)
return distances
output = torch.remainder(position + input * 0.5, input) - 0.5 * input
return output
55 changes: 55 additions & 0 deletions src/beignet/_segment_sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import math
from typing import Optional

import torch
from torch import Tensor


def segment_sum(
input: Tensor,
indexes: Tensor,
n: Optional[int] = None,
**kwargs,
) -> Tensor:
r"""Computes the sum of segments of a tensor along the first dimension.
Parameters
----------
input : Tensor
A tensor containing the input values to be summed.
indexes : Tensor
A 1D tensor containing the segment indexes for summation.
Should have the same length as the first dimension of the `input` tensor.
n : Optional[int], optional
The number of segments, by default `n` is set to `max(indexes) + 1`.
Returns
-------
Tensor
A tensor where each entry contains the sum of the corresponding segment
from the `input` tensor.
"""
if indexes.ndim == 1:
indexes = torch.repeat_interleave(indexes, math.prod([*input.shape[1:]])).view(
*[indexes.shape[0], *input.shape[1:]]
)

if input.size(0) != indexes.size(0):
raise ValueError(
"The length of the indexes tensor must match the size of the first dimension of the input tensor."
)

if n is None:
n = indexes.max().item() + 1

valid_mask = indexes < n
valid_indexes = indexes[valid_mask]
valid_input = input[valid_mask]

output = torch.zeros(n, *input.shape[1:], device=input.device)

return output.scatter_add(0, valid_indexes, valid_input.to(torch.float32)).to(
**kwargs
)
52 changes: 52 additions & 0 deletions src/beignet/_square_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Callable, Any

import torch
from torch import Tensor


def _square_distance(input: Tensor) -> Tensor:
"""Computes square distances.
Args:
input: Matrix of displacements; `Tensor(shape=[..., spatial_dim])`.
Returns:
Matrix of squared distances; `Tensor(shape=[...])`.
"""
return torch.sum(input**2, dim=-1)


def _safe_mask(
mask: Tensor, fn: Callable, operand: Tensor, placeholder: Any = 0
) -> Tensor:
r"""Applies a function to elements of a tensor where a mask is True, and replaces elements where the mask is False with a placeholder.
Parameters
----------
mask : Tensor
A boolean tensor indicating which elements to apply the function to.
fn : Callable[[Tensor], Tensor]
The function to apply to the masked elements.
operand : Tensor
The tensor to apply the function to.
placeholder : Any, optional
The value to use for elements where the mask is False (default is 0).
Returns
-------
Tensor
A tensor with the function applied to the masked elements and the placeholder value elsewhere.
"""
masked = torch.where(mask, operand, torch.tensor(0, dtype=operand.dtype))

return torch.where(mask, fn(masked), torch.tensor(placeholder, dtype=operand.dtype))


def square_distance(dR: Tensor) -> Tensor:
r"""Computes distances.
Args:
dR: Matrix of displacements; `Tensor(shape=[..., spatial_dim])`.
Returns:
Matrix of distances; `Tensor(shape=[...])`.
"""
return _safe_mask(_square_distance(dR) > 0, torch.sqrt, _square_distance(dR))
8 changes: 4 additions & 4 deletions src/beignet/func/__dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ def _set(self: dataclasses.dataclass, **kwargs):
else:
metadata_fields = [*metadata_fields, name]

def _iterate_cls(_x) -> List[Tuple]:
def _iterate_cls(_x) -> list[list]:
data_iterable = []

for k in data_fields:
data_iterable.append(getattr(_x, k))
data_iterable = [*data_iterable, getattr(_x, k)]

metadata_iterable = []

for k in metadata_fields:
metadata_iterable.append(getattr(_x, k))
metadata_iterable = [*metadata_iterable, getattr(_x, k)]

return [data_iterable, metadata_iterable]

Expand All @@ -46,7 +46,7 @@ def _iterable_to_cls(meta, data):
dataclass_cls,
_iterate_cls,
_iterable_to_cls,
"prescient.func",
"beignet.func",
)

return dataclass_cls
Loading

0 comments on commit 93721ae

Please sign in to comment.