Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Henry Isaacson committed Jun 11, 2024
1 parent 941b772 commit a0283ad
Show file tree
Hide file tree
Showing 17 changed files with 169 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ def _cell_dimensions(
box_size: Tensor,
minimum_cell_size: float,
) -> (Tensor, Tensor, Tensor, int):
"""
Compute the number of cells-per-side and total number of cells in a box.
r"""Compute the number of cells-per-side and total number of cells in a box.
Parameters:
-----------
Expand Down
48 changes: 24 additions & 24 deletions src/beignet/func/_molecular_dynamics/_partition/__cell_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,30 @@

@_dataclass
class _CellList:
"""Stores the spatial partition of a system into a cell list.
See :meth:`cell_list` for details on the construction / specification.
Cell list buffers all have a common shape, S, where
* `S = [cell_count_x, cell_count_y, cell_capacity]`
* `S = [cell_count_x, cell_count_y, cell_count_z, cell_capacity]`
in two- and three-dimensions respectively. It is assumed that each cell has
the same capacity.
Attributes:
positions_buffer: An ndarray of floating point positions with shape
`S + [spatial_dimension]`.
indexes: An ndarray of int32 particle ids of shape `S`. Note that empty
slots are specified by `id = N` where `N` is the number of particles in
the system.
parameters: A dictionary of ndarrays of shape `S + [...]`. This contains
side data placed into the cell list.
exceeded_maximum_size: A boolean specifying whether or not the cell list
exceeded the maximum allocated capacity.
size: An integer specifying the maximum capacity of each cell in
the cell list.
item size: A tensor specifying the size of each cell in the cell list.
update_fn: A function that updates the cell list at a fixed capacity.
"""
r"""Stores the spatial partition of a system into a cell list.
See :meth:`cell_list` for details on the construction / specification.
Cell list buffers all have a common shape, S, where
* `S = [cell_count_x, cell_count_y, cell_capacity]`
* `S = [cell_count_x, cell_count_y, cell_count_z, cell_capacity]`
in two- and three-dimensions respectively. It is assumed that each cell has
the same capacity.
Attributes:
positions_buffer: An ndarray of floating point positions with shape
`S + [spatial_dimension]`.
indexes: An ndarray of int32 particle ids of shape `S`. Note that empty
slots are specified by `id = N` where `N` is the number of particles in
the system.
parameters: A dictionary of ndarrays of shape `S + [...]`. This contains
side data placed into the cell list.
exceeded_maximum_size: A boolean specifying whether or not the cell list
exceeded the maximum allocated capacity.
size: An integer specifying the maximum capacity of each cell in
the cell list.
item size: A tensor specifying the size of each cell in the cell list.
update_fn: A function that updates the cell list at a fixed capacity.
"""
exceeded_maximum_size: Tensor

indexes: Tensor
Expand Down
19 changes: 19 additions & 0 deletions src/beignet/func/_molecular_dynamics/_partition/__cell_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,25 @@


def _cell_size(box: Tensor, minimum_unit_size: Tensor) -> Tensor:
r"""Compute the size of cells within a box based on the minimum unit size.
Parameters:
-----------
box : Tensor
The size of the box. This must be a Tensor.
minimum_unit_size : Tensor
The minimum size of the units (cells). This must be a Tensor of the same shape as `box` or a scalar Tensor.
Returns:
--------
Tensor
The size of the cells in the box.
Raises:
-------
ValueError
If the box and minimum unit size do not have the same shape and `minimum_unit_size` is not a scalar.
"""
if box.shape == minimum_unit_size.shape or minimum_unit_size.ndim == 0:
return box / torch.floor(box / minimum_unit_size)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@


def _hash_constants(spatial_dimensions: int, cells_per_side: Tensor) -> Tensor:
"""
Compute constants used for hashing in a spatial partitioning scheme.
r"""Compute constants used for hashing in a spatial partitioning scheme.
The function calculates constants that help in determining the hash value
for a given cell in an N-dimensional grid, based on the number of cells
Expand Down
3 changes: 1 addition & 2 deletions src/beignet/func/_molecular_dynamics/_partition/__iota.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@


def _iota(shape: tuple[int, ...], dim: int = 0, **kwargs) -> Tensor:
"""
Generate a tensor with a specified shape where elements along the given dimension
r"""Generate a tensor with a specified shape where elements along the given dimension
are sequential integers starting from 0.
Parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,21 @@


def _is_neighbor_list_format_valid(neighbor_list_format: _NeighborListFormat):
r"""Check if the given neighbor list format is valid.
Parameters:
-----------
neighbor_list_format : _NeighborListFormat
The neighbor list format to be validated.
Returns:
--------
None
Raises:
-------
ValueError
If the neighbor list format is not one of the recognized formats.
"""
if neighbor_list_format not in list(_NeighborListFormat):
raise ValueError
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,18 @@
def _is_neighbor_list_sparse(
neighbor_list_format: _NeighborListFormat,
) -> bool:
r"""Determine if the given neighbor list format is sparse.
Parameters:
-----------
neighbor_list_format : _NeighborListFormat
The neighbor list format to be checked.
Returns:
--------
bool
True if the neighbor list format is sparse, False otherwise.
"""
return neighbor_list_format in {
_NeighborListFormat.ORDERED_SPARSE,
_NeighborListFormat.SPARSE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,23 @@


def _is_space_valid(space: Tensor) -> Tensor:
r"""Check if the given space tensor is valid.
Parameters:
-----------
space : Tensor
The space tensor to be validated. This tensor can have 0, 1, or 2 dimensions.
Returns:
--------
Tensor
A tensor containing a single boolean value indicating whether the space is valid.
Raises:
-------
ValueError
If the space tensor has more than 2 dimensions.
"""
if space.ndim == 0 or space.ndim == 1:
return torch.tensor([True])

Expand Down
24 changes: 24 additions & 0 deletions src/beignet/func/_molecular_dynamics/_partition/__map_bond.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,32 @@


def _map_bond(distance_fn):
r"""Map a distance function over batched start and end positions.
Parameters:
-----------
distance_fn : callable
A function that computes the distance between two positions.
Returns:
--------
wrapper : callable
A wrapper function that applies `distance_fn` to each pair of start and end positions
in the batch.
Example:
--------
>>> # Assume `distance_fn` computes the Euclidean distance between two points
>>> start_positions = torch.tensor([[0, 0], [1, 1]])
>>> end_positions = torch.tensor([[0, 1], [1, 2]])
>>> wrapped_fn = _map_bond(distance_fn)
>>> result = wrapped_fn(start_positions, end_positions)
>>> print(result)
tensor([...])
"""
def wrapper(start_positions, end_positions):
batch_size = start_positions.shape[0]

return torch.stack(
[
distance_fn(start_positions[i], end_positions[i])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@


def _normalize_cell_size(box: Tensor, cutoff: float) -> Tensor:
"""
Normalize the cell size given the bounding box dimensions and a cutoff value.
r"""Normalize the cell size given the bounding box dimensions and a cutoff value.
Parameters
----------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ def _particles_per_cell(
size: Tensor,
minimum_size: float,
) -> Tensor:
"""
Computes the number of particles per cell given a defined cell size and minimum size.
r"""Computes the number of particles per cell given a defined cell size and minimum size.
Parameters
----------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,60 @@

@_dataclass
class _PartitionError:
r"""A class to represent and manage partition errors with specific error codes.
Attributes:
-----------
code : Tensor
A tensor representing the error code.
Methods:
--------
update(bit: bytes, predicate: Tensor) -> "_PartitionError"
Update the error code based on a predicate and a new bit.
__str__() -> str
Provide a human-readable string representation of the error.
__repr__() -> str
Alias for __str__().
"""
code: Tensor

def update(self, bit: bytes, predicate: Tensor) -> "_PartitionError":
r"""Update the error code based on a predicate and a new bit.
Parameters:
-----------
bit : bytes
The bit to be combined with the existing error code.
predicate : Tensor
A tensor that determines where the bit should be applied.
Returns:
--------
_PartitionError
A new instance of `_PartitionError` with the updated error code.
"""
zero = torch.zeros([], dtype=torch.uint8)

bit = torch.tensor(bit, dtype=torch.uint8)

return _PartitionError(code=self.code | torch.where(predicate, bit, zero))

def __str__(self) -> str:
r"""Provide a human-readable string representation of the error.
Returns:
--------
str
A string describing the error.
Raises:
-------
ValueError
If the error code is unexpected or not recognized.
"""
if not torch.any(self.code):
return ""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ def _segment_sum(
n: Optional[int] = None,
**kwargs,
) -> Tensor:
"""
Computes the sum of segments of a tensor along the first dimension.
r"""Computes the sum of segments of a tensor along the first dimension.
Parameters
----------
Expand Down
4 changes: 1 addition & 3 deletions src/beignet/func/_molecular_dynamics/_partition/__shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@


def _shift(a: Tensor, b: Tensor) -> Tensor:
"""
Shifts a tensor `a` along dimensions specified in `b`.
r"""Shifts a tensor `a` along dimensions specified in `b`.
The shift can be applied in up to three dimensions (x, y, z).
Positive values in `b` indicate a forward shift, while negative values indicate
Expand All @@ -22,5 +21,4 @@ def _shift(a: Tensor, b: Tensor) -> Tensor:
Tensor
The shifted tensor.
"""

return torch.roll(a, shifts=tuple(b), dims=tuple(range(len(b))))
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
def _to_square_metric_fn(
fn: Callable[[Tensor, Tensor, Any], Tensor],
) -> Callable[[Tensor, Tensor, Any], Tensor]:
"""
Converts a given distance function to a squared distance metric.
r"""Converts a given distance function to a squared distance metric.
The function tries to apply the given distance function `fn` to positions in
one to three dimensions to determine if the output is scalar or vector.
Expand Down Expand Up @@ -40,8 +39,10 @@ def square_metric(a: Tensor, b: Tensor, **kwargs) -> Tensor:
return torch.sum(torch.square(fn(a, b, **kwargs)), dim=-1)

return square_metric

except TypeError:
continue

except ValueError:
continue

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@


def _unflatten_cell_buffer(buffer: Tensor, cells_per_side: [int, float, Tensor], dim: int):
"""
Reshape a flat buffer into a multidimensional cell buffer.
r"""Reshape a flat buffer into a multidimensional cell buffer.
Parameters
----------
Expand Down
23 changes: 2 additions & 21 deletions src/beignet/func/_molecular_dynamics/_partition/_cell_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def fn(
exceeded_maximum_size = False

update_fn = fn

else:
buffer_size, exceeded_maximum_size, update_fn = excess

Expand Down Expand Up @@ -244,24 +245,4 @@ def update_fn(
**kwargs,
)

return _CellListFunctionList(setup_fn=setup_fn, update_fn=update_fn)


if __name__ == '__main__':
dtype = torch.float32
box_size = torch.tensor([8.65, 8.0], dtype=torch.float32)
cell_size = 1.0

# Test particle positions
R = torch.tensor([
[0.25, 0.25],
[8.5, 1.95],
[8.1, 1.5],
[3.7, 7.9]
], dtype=dtype)

cell_fn = cell_list(box_size, cell_size)

cell_list_instance = cell_fn.setup_fn(R)

print(cell_list_instance.positions_buffer[7, 3, 0])
return _CellListFunctionList(setup_fn=setup_fn, update_fn=update_fn)

0 comments on commit a0283ad

Please sign in to comment.