Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

beignet.func.partition #14

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ requires = [
authors = [{ email = "[email protected]", name = "Allen Goodman" }]
dependencies = [
"pooch",
"torch==2.2.2",
"torch",
"torchaudio",
"tqdm",
]
Expand Down
7 changes: 7 additions & 0 deletions src/beignet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@
from ._linear_probabilists_hermite_polynomial import (
linear_probabilists_hermite_polynomial,
)
from ._map_bond import map_bond
from ._map_neighbor import map_neighbor
from ._map_product import map_product
from ._multiply_chebyshev_polynomial import multiply_chebyshev_polynomial
from ._multiply_chebyshev_polynomial_by_x import multiply_chebyshev_polynomial_by_x
from ._multiply_laguerre_polynomial import multiply_laguerre_polynomial
Expand Down Expand Up @@ -291,6 +294,8 @@
)
from ._probabilists_hermite_polynomial_x import probabilists_hermite_polynomial_x
from ._probabilists_hermite_polynomial_zero import probabilists_hermite_polynomial_zero
from ._iota import iota
from ._pairwise_displacement import pairwise_displacement
from ._quaternion_identity import quaternion_identity
from ._quaternion_magnitude import quaternion_magnitude
from ._quaternion_mean import quaternion_mean
Expand Down Expand Up @@ -355,6 +360,8 @@
from ._trim_probabilists_hermite_polynomial_coefficients import (
trim_probabilists_hermite_polynomial_coefficients,
)
from ._segment_sum import segment_sum
from ._square_distance import square_distance
from .special import error_erf, error_erfc

__all__ = [
Expand Down
37 changes: 37 additions & 0 deletions src/beignet/_iota.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch
from torch import Tensor


def iota(shape: tuple[int, ...], dim: int = 0, **kwargs) -> Tensor:
r"""Generate a tensor with a specified shape where elements along the given dimension
are sequential integers starting from 0.

Parameters
----------
shape : tuple[int, ...]
The shape of the resulting tensor.
dim : int, optional
The dimension along which to vary the values (default is 0).

Returns
-------
Tensor
A tensor of the specified shape with sequential integers along the specified dimension.

Raises
------
IndexError
If `dim` is out of the range of `shape`.
"""
dimensions = []

for index, _ in enumerate(shape):
if index != dim:
dimension = 1

else:
dimension = shape[index]

dimensions = [*dimensions, dimension]

return torch.arange(shape[dim], **kwargs).view(*dimensions).expand(*shape)
20 changes: 20 additions & 0 deletions src/beignet/_map_bond.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Callable

import torch


def map_bond(metric_or_displacement: Callable) -> Callable:
r"""Map a distance function over batched start and end positions.

Parameters:
-----------
distance_fn : callable
A function that computes the distance between two positions.

Returns:
--------
wrapper : callable
A wrapper function that applies `distance_fn` to each pair of start and end positions
in the batch.
"""
return torch.vmap(metric_or_displacement, (0, 0), 0)
30 changes: 30 additions & 0 deletions src/beignet/_map_neighbor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Callable

import torch


def map_neighbor(metric_or_displacement: Callable) -> Callable:
r"""Vectorizes a metric or displacement function over neighborhoods.

Parameters
----------
metric_or_displacement : callable
A function that computes a metric or displacement between two inputs.
This function should accept two arguments and return a single value
representing the metric or displacement.

Returns
-------
wrapped_fn : callable
A vectorized function that applies `metric_or_displacement` over
neighborhoods of input data. The returned function takes two arguments:
`input` and `other`, where `input` is the reference data and `other` is
the neighborhood data.
"""

def wrapped_fn(input, other, **kwargs):
return torch.vmap(torch.vmap(metric_or_displacement, (0, None)))(
other, input, **kwargs
)

return wrapped_fn
24 changes: 24 additions & 0 deletions src/beignet/_map_product.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Callable

import torch


def map_product(metric_or_displacement: Callable) -> Callable:
r"""Vectorizes a metric or displacement function over all pairs.

Parameters
----------
metric_or_displacement : callable
A function that computes a metric or displacement between two inputs.
This function should accept two arguments and return a single value
representing the metric or displacement.

Returns
-------
wrapped_fn : callable
A vectorized function that applies `metric_or_displacement` over all
pairs of input data. The returned function takes two arguments:
`input1` and `input2`, where `input1` and `input2` are the sets of data
to be compared.
"""
return torch.vmap(torch.vmap(metric_or_displacement, (0, None)), (None, 0))
31 changes: 31 additions & 0 deletions src/beignet/_pairwise_displacement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from torch import Tensor


def pairwise_displacement(input: Tensor, other: Tensor) -> Tensor:
r"""Compute a matrix of pairwise displacements given two sets of positions.

Parameters
----------
input : Tensor
Vector of positions

other : Tensor
Vector of positions

Returns:
henry-isaacson marked this conversation as resolved.
Show resolved Hide resolved
-------
output : Tensor, shape [spatial_dimensions]
Matrix of displacements
"""
if len(input.shape) != 1:
message = (
"Can only compute displacements between vectors. To compute "
"displacements between sets of vectors use vmap or TODO."
)
raise ValueError(message)

if input.shape != other.shape:
message = "Can only compute displacement between vectors of equal dimension."
raise ValueError(message)

return input - other
24 changes: 24 additions & 0 deletions src/beignet/_periodic_displacement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch
from torch import Tensor


def periodic_displacement(input: Tensor, position: Tensor) -> Tensor:
r"""Wraps displacement vectors into a hypercube.

Parameters:
----------
box : float or Tensor
Specification of hypercube size. Either:
(a) scalar if all sides have equal length.
(b) Tensor of shape (spatial_dim,) if sides have different lengths.

dR : Tensor
henry-isaacson marked this conversation as resolved.
Show resolved Hide resolved
Matrix of displacements with shape (..., spatial_dim).

Returns:
-------
output : Tensor, shape=(...)
Matrix of wrapped displacements with shape (..., spatial_dim).
"""
output = torch.remainder(position + input * 0.5, input) - 0.5 * input
return output
55 changes: 55 additions & 0 deletions src/beignet/_segment_sum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import math
from typing import Optional

import torch
from torch import Tensor


def segment_sum(
input: Tensor,
indexes: Tensor,
n: Optional[int] = None,
**kwargs,
) -> Tensor:
r"""Computes the sum of segments of a tensor along the first dimension.

Parameters
----------
input : Tensor
A tensor containing the input values to be summed.

indexes : Tensor
A 1D tensor containing the segment indexes for summation.
Should have the same length as the first dimension of the `input` tensor.

n : Optional[int], optional
The number of segments, by default `n` is set to `max(indexes) + 1`.

Returns
-------
Tensor
A tensor where each entry contains the sum of the corresponding segment
from the `input` tensor.
"""
if indexes.ndim == 1:
indexes = torch.repeat_interleave(indexes, math.prod([*input.shape[1:]])).view(
*[indexes.shape[0], *input.shape[1:]]
)

if input.size(0) != indexes.size(0):
raise ValueError(
"The length of the indexes tensor must match the size of the first dimension of the input tensor."
)

if n is None:
n = indexes.max().item() + 1

valid_mask = indexes < n
valid_indexes = indexes[valid_mask]
valid_input = input[valid_mask]

output = torch.zeros(
n, *input.shape[1:], device=input.device, dtype=valid_input.dtype
)

return output.scatter_add(0, valid_indexes, valid_input).to(**kwargs)
52 changes: 52 additions & 0 deletions src/beignet/_square_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from typing import Callable, Any

import torch
from torch import Tensor


def _square_distance(input: Tensor) -> Tensor:
"""Computes square distances.

Args:
input: Matrix of displacements; `Tensor(shape=[..., spatial_dim])`.
Returns:
Matrix of squared distances; `Tensor(shape=[...])`.
"""
return torch.sum(input**2, dim=-1)


def _safe_mask(
mask: Tensor, fn: Callable, operand: Tensor, placeholder: Any = 0
) -> Tensor:
r"""Applies a function to elements of a tensor where a mask is True, and replaces elements where the mask is False with a placeholder.

Parameters
----------
mask : Tensor
A boolean tensor indicating which elements to apply the function to.
fn : Callable[[Tensor], Tensor]
The function to apply to the masked elements.
operand : Tensor
The tensor to apply the function to.
placeholder : Any, optional
The value to use for elements where the mask is False (default is 0).

Returns
-------
Tensor
A tensor with the function applied to the masked elements and the placeholder value elsewhere.
"""
masked = torch.where(mask, operand, torch.tensor(0, dtype=operand.dtype))

return torch.where(mask, fn(masked), torch.tensor(placeholder, dtype=operand.dtype))


def square_distance(dR: Tensor) -> Tensor:
r"""Computes distances.

Args:
dR: Matrix of displacements; `Tensor(shape=[..., spatial_dim])`.
Returns:
Matrix of distances; `Tensor(shape=[...])`.
"""
return _safe_mask(_square_distance(dR) > 0, torch.sqrt, _square_distance(dR))
51 changes: 51 additions & 0 deletions src/beignet/func/__dataclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import dataclasses
from typing import List, Tuple, Type, TypeVar, Iterable

from torch.utils._pytree import register_pytree_node

T = TypeVar("T")


def _dataclass(cls: Type[T]):
def _set(self: dataclasses.dataclass, **kwargs):
return dataclasses.replace(self, **kwargs)

cls.set = _set

dataclass_cls = dataclasses.dataclass(frozen=True)(cls)

data_fields, metadata_fields = [], []

for name, kind in dataclass_cls.__dataclass_fields__.items():
if not kind.metadata.get("static", False):
data_fields = [*data_fields, name]
else:
metadata_fields = [*metadata_fields, name]

def _iterate_cls(_x) -> list[list[T]]:
data_iterable = []

for k in data_fields:
data_iterable = [*data_iterable, getattr(_x, k)]

metadata_iterable = []

for k in metadata_fields:
metadata_iterable = [*metadata_iterable, getattr(_x, k)]

return [data_iterable, metadata_iterable]
henry-isaacson marked this conversation as resolved.
Show resolved Hide resolved

def _iterable_to_cls(meta, data) -> T:
meta_args = tuple(zip(metadata_fields, meta))
data_args = tuple(zip(data_fields, data))
kwargs = dict(meta_args + data_args)

return dataclass_cls(**kwargs)

register_pytree_node(
dataclass_cls,
_iterate_cls,
_iterable_to_cls,
)

return dataclass_cls
Loading
Loading