From a0283ad21bb87d9e209746eef9479405911525e8 Mon Sep 17 00:00:00 2001 From: Henry Isaacson Date: Tue, 11 Jun 2024 09:19:45 -0400 Subject: [PATCH] cleanup --- .../_partition/__cell_dimensions.py | 3 +- .../_partition/__cell_list.py | 48 +++++++++---------- .../_partition/__cell_size.py | 19 ++++++++ .../_partition/__hash_constants.py | 3 +- .../_molecular_dynamics/_partition/__iota.py | 3 +- .../__is_neighbor_list_format_valid.py | 16 +++++++ .../_partition/__is_neighbor_list_sparse.py | 12 +++++ .../_partition/__is_space_valid.py | 17 +++++++ .../_partition/__map_bond.py | 24 ++++++++++ .../_partition/__normalize_cell_size.py | 3 +- .../_partition/__particles_per_cell.py | 3 +- .../_partition/__partition_error.py | 44 +++++++++++++++++ .../_partition/__segment_sum.py | 3 +- .../_molecular_dynamics/_partition/__shift.py | 4 +- .../_partition/__to_square_metric_fn.py | 5 +- .../_partition/__unflatten_cell_buffer.py | 3 +- .../_partition/_cell_list.py | 23 +-------- 17 files changed, 169 insertions(+), 64 deletions(-) diff --git a/src/beignet/func/_molecular_dynamics/_partition/__cell_dimensions.py b/src/beignet/func/_molecular_dynamics/_partition/__cell_dimensions.py index ad1d190ca7..e7154acded 100644 --- a/src/beignet/func/_molecular_dynamics/_partition/__cell_dimensions.py +++ b/src/beignet/func/_molecular_dynamics/_partition/__cell_dimensions.py @@ -11,8 +11,7 @@ def _cell_dimensions( box_size: Tensor, minimum_cell_size: float, ) -> (Tensor, Tensor, Tensor, int): - """ - Compute the number of cells-per-side and total number of cells in a box. + r"""Compute the number of cells-per-side and total number of cells in a box. Parameters: ----------- diff --git a/src/beignet/func/_molecular_dynamics/_partition/__cell_list.py b/src/beignet/func/_molecular_dynamics/_partition/__cell_list.py index d5cfa986fe..927e1156bf 100644 --- a/src/beignet/func/_molecular_dynamics/_partition/__cell_list.py +++ b/src/beignet/func/_molecular_dynamics/_partition/__cell_list.py @@ -8,30 +8,30 @@ @_dataclass class _CellList: - """Stores the spatial partition of a system into a cell list. - - See :meth:`cell_list` for details on the construction / specification. - Cell list buffers all have a common shape, S, where - * `S = [cell_count_x, cell_count_y, cell_capacity]` - * `S = [cell_count_x, cell_count_y, cell_count_z, cell_capacity]` - in two- and three-dimensions respectively. It is assumed that each cell has - the same capacity. - - Attributes: - positions_buffer: An ndarray of floating point positions with shape - `S + [spatial_dimension]`. - indexes: An ndarray of int32 particle ids of shape `S`. Note that empty - slots are specified by `id = N` where `N` is the number of particles in - the system. - parameters: A dictionary of ndarrays of shape `S + [...]`. This contains - side data placed into the cell list. - exceeded_maximum_size: A boolean specifying whether or not the cell list - exceeded the maximum allocated capacity. - size: An integer specifying the maximum capacity of each cell in - the cell list. - item size: A tensor specifying the size of each cell in the cell list. - update_fn: A function that updates the cell list at a fixed capacity. - """ + r"""Stores the spatial partition of a system into a cell list. + + See :meth:`cell_list` for details on the construction / specification. + Cell list buffers all have a common shape, S, where + * `S = [cell_count_x, cell_count_y, cell_capacity]` + * `S = [cell_count_x, cell_count_y, cell_count_z, cell_capacity]` + in two- and three-dimensions respectively. It is assumed that each cell has + the same capacity. + + Attributes: + positions_buffer: An ndarray of floating point positions with shape + `S + [spatial_dimension]`. + indexes: An ndarray of int32 particle ids of shape `S`. Note that empty + slots are specified by `id = N` where `N` is the number of particles in + the system. + parameters: A dictionary of ndarrays of shape `S + [...]`. This contains + side data placed into the cell list. + exceeded_maximum_size: A boolean specifying whether or not the cell list + exceeded the maximum allocated capacity. + size: An integer specifying the maximum capacity of each cell in + the cell list. + item size: A tensor specifying the size of each cell in the cell list. + update_fn: A function that updates the cell list at a fixed capacity. + """ exceeded_maximum_size: Tensor indexes: Tensor diff --git a/src/beignet/func/_molecular_dynamics/_partition/__cell_size.py b/src/beignet/func/_molecular_dynamics/_partition/__cell_size.py index 299b82e629..2f772854be 100644 --- a/src/beignet/func/_molecular_dynamics/_partition/__cell_size.py +++ b/src/beignet/func/_molecular_dynamics/_partition/__cell_size.py @@ -3,6 +3,25 @@ def _cell_size(box: Tensor, minimum_unit_size: Tensor) -> Tensor: + r"""Compute the size of cells within a box based on the minimum unit size. + + Parameters: + ----------- + box : Tensor + The size of the box. This must be a Tensor. + minimum_unit_size : Tensor + The minimum size of the units (cells). This must be a Tensor of the same shape as `box` or a scalar Tensor. + + Returns: + -------- + Tensor + The size of the cells in the box. + + Raises: + ------- + ValueError + If the box and minimum unit size do not have the same shape and `minimum_unit_size` is not a scalar. + """ if box.shape == minimum_unit_size.shape or minimum_unit_size.ndim == 0: return box / torch.floor(box / minimum_unit_size) diff --git a/src/beignet/func/_molecular_dynamics/_partition/__hash_constants.py b/src/beignet/func/_molecular_dynamics/_partition/__hash_constants.py index 5c2b5412c0..0cf0a87399 100644 --- a/src/beignet/func/_molecular_dynamics/_partition/__hash_constants.py +++ b/src/beignet/func/_molecular_dynamics/_partition/__hash_constants.py @@ -3,8 +3,7 @@ def _hash_constants(spatial_dimensions: int, cells_per_side: Tensor) -> Tensor: - """ - Compute constants used for hashing in a spatial partitioning scheme. + r"""Compute constants used for hashing in a spatial partitioning scheme. The function calculates constants that help in determining the hash value for a given cell in an N-dimensional grid, based on the number of cells diff --git a/src/beignet/func/_molecular_dynamics/_partition/__iota.py b/src/beignet/func/_molecular_dynamics/_partition/__iota.py index 035c087816..80bad00723 100644 --- a/src/beignet/func/_molecular_dynamics/_partition/__iota.py +++ b/src/beignet/func/_molecular_dynamics/_partition/__iota.py @@ -3,8 +3,7 @@ def _iota(shape: tuple[int, ...], dim: int = 0, **kwargs) -> Tensor: - """ - Generate a tensor with a specified shape where elements along the given dimension + r"""Generate a tensor with a specified shape where elements along the given dimension are sequential integers starting from 0. Parameters diff --git a/src/beignet/func/_molecular_dynamics/_partition/__is_neighbor_list_format_valid.py b/src/beignet/func/_molecular_dynamics/_partition/__is_neighbor_list_format_valid.py index 28007f6463..11ca04289f 100644 --- a/src/beignet/func/_molecular_dynamics/_partition/__is_neighbor_list_format_valid.py +++ b/src/beignet/func/_molecular_dynamics/_partition/__is_neighbor_list_format_valid.py @@ -4,5 +4,21 @@ def _is_neighbor_list_format_valid(neighbor_list_format: _NeighborListFormat): + r"""Check if the given neighbor list format is valid. + + Parameters: + ----------- + neighbor_list_format : _NeighborListFormat + The neighbor list format to be validated. + + Returns: + -------- + None + + Raises: + ------- + ValueError + If the neighbor list format is not one of the recognized formats. + """ if neighbor_list_format not in list(_NeighborListFormat): raise ValueError diff --git a/src/beignet/func/_molecular_dynamics/_partition/__is_neighbor_list_sparse.py b/src/beignet/func/_molecular_dynamics/_partition/__is_neighbor_list_sparse.py index e77545bcc5..a22e2798e3 100644 --- a/src/beignet/func/_molecular_dynamics/_partition/__is_neighbor_list_sparse.py +++ b/src/beignet/func/_molecular_dynamics/_partition/__is_neighbor_list_sparse.py @@ -6,6 +6,18 @@ def _is_neighbor_list_sparse( neighbor_list_format: _NeighborListFormat, ) -> bool: + r"""Determine if the given neighbor list format is sparse. + + Parameters: + ----------- + neighbor_list_format : _NeighborListFormat + The neighbor list format to be checked. + + Returns: + -------- + bool + True if the neighbor list format is sparse, False otherwise. + """ return neighbor_list_format in { _NeighborListFormat.ORDERED_SPARSE, _NeighborListFormat.SPARSE, diff --git a/src/beignet/func/_molecular_dynamics/_partition/__is_space_valid.py b/src/beignet/func/_molecular_dynamics/_partition/__is_space_valid.py index 207f96ba10..a8ef0ba985 100644 --- a/src/beignet/func/_molecular_dynamics/_partition/__is_space_valid.py +++ b/src/beignet/func/_molecular_dynamics/_partition/__is_space_valid.py @@ -3,6 +3,23 @@ def _is_space_valid(space: Tensor) -> Tensor: + r"""Check if the given space tensor is valid. + + Parameters: + ----------- + space : Tensor + The space tensor to be validated. This tensor can have 0, 1, or 2 dimensions. + + Returns: + -------- + Tensor + A tensor containing a single boolean value indicating whether the space is valid. + + Raises: + ------- + ValueError + If the space tensor has more than 2 dimensions. + """ if space.ndim == 0 or space.ndim == 1: return torch.tensor([True]) diff --git a/src/beignet/func/_molecular_dynamics/_partition/__map_bond.py b/src/beignet/func/_molecular_dynamics/_partition/__map_bond.py index 6473eff169..be3ba388e6 100644 --- a/src/beignet/func/_molecular_dynamics/_partition/__map_bond.py +++ b/src/beignet/func/_molecular_dynamics/_partition/__map_bond.py @@ -2,8 +2,32 @@ def _map_bond(distance_fn): + r"""Map a distance function over batched start and end positions. + + Parameters: + ----------- + distance_fn : callable + A function that computes the distance between two positions. + + Returns: + -------- + wrapper : callable + A wrapper function that applies `distance_fn` to each pair of start and end positions + in the batch. + + Example: + -------- + >>> # Assume `distance_fn` computes the Euclidean distance between two points + >>> start_positions = torch.tensor([[0, 0], [1, 1]]) + >>> end_positions = torch.tensor([[0, 1], [1, 2]]) + >>> wrapped_fn = _map_bond(distance_fn) + >>> result = wrapped_fn(start_positions, end_positions) + >>> print(result) + tensor([...]) + """ def wrapper(start_positions, end_positions): batch_size = start_positions.shape[0] + return torch.stack( [ distance_fn(start_positions[i], end_positions[i]) diff --git a/src/beignet/func/_molecular_dynamics/_partition/__normalize_cell_size.py b/src/beignet/func/_molecular_dynamics/_partition/__normalize_cell_size.py index bdda59ee4d..f69be563ca 100644 --- a/src/beignet/func/_molecular_dynamics/_partition/__normalize_cell_size.py +++ b/src/beignet/func/_molecular_dynamics/_partition/__normalize_cell_size.py @@ -3,8 +3,7 @@ def _normalize_cell_size(box: Tensor, cutoff: float) -> Tensor: - """ - Normalize the cell size given the bounding box dimensions and a cutoff value. + r"""Normalize the cell size given the bounding box dimensions and a cutoff value. Parameters ---------- diff --git a/src/beignet/func/_molecular_dynamics/_partition/__particles_per_cell.py b/src/beignet/func/_molecular_dynamics/_partition/__particles_per_cell.py index 018586594d..0dac049edf 100644 --- a/src/beignet/func/_molecular_dynamics/_partition/__particles_per_cell.py +++ b/src/beignet/func/_molecular_dynamics/_partition/__particles_per_cell.py @@ -11,8 +11,7 @@ def _particles_per_cell( size: Tensor, minimum_size: float, ) -> Tensor: - """ - Computes the number of particles per cell given a defined cell size and minimum size. + r"""Computes the number of particles per cell given a defined cell size and minimum size. Parameters ---------- diff --git a/src/beignet/func/_molecular_dynamics/_partition/__partition_error.py b/src/beignet/func/_molecular_dynamics/_partition/__partition_error.py index 6df6ed1306..6d4cc8d422 100644 --- a/src/beignet/func/_molecular_dynamics/_partition/__partition_error.py +++ b/src/beignet/func/_molecular_dynamics/_partition/__partition_error.py @@ -7,9 +7,41 @@ @_dataclass class _PartitionError: + r"""A class to represent and manage partition errors with specific error codes. + + Attributes: + ----------- + code : Tensor + A tensor representing the error code. + + Methods: + -------- + update(bit: bytes, predicate: Tensor) -> "_PartitionError" + Update the error code based on a predicate and a new bit. + + __str__() -> str + Provide a human-readable string representation of the error. + + __repr__() -> str + Alias for __str__(). + """ code: Tensor def update(self, bit: bytes, predicate: Tensor) -> "_PartitionError": + r"""Update the error code based on a predicate and a new bit. + + Parameters: + ----------- + bit : bytes + The bit to be combined with the existing error code. + predicate : Tensor + A tensor that determines where the bit should be applied. + + Returns: + -------- + _PartitionError + A new instance of `_PartitionError` with the updated error code. + """ zero = torch.zeros([], dtype=torch.uint8) bit = torch.tensor(bit, dtype=torch.uint8) @@ -17,6 +49,18 @@ def update(self, bit: bytes, predicate: Tensor) -> "_PartitionError": return _PartitionError(code=self.code | torch.where(predicate, bit, zero)) def __str__(self) -> str: + r"""Provide a human-readable string representation of the error. + + Returns: + -------- + str + A string describing the error. + + Raises: + ------- + ValueError + If the error code is unexpected or not recognized. + """ if not torch.any(self.code): return "" diff --git a/src/beignet/func/_molecular_dynamics/_partition/__segment_sum.py b/src/beignet/func/_molecular_dynamics/_partition/__segment_sum.py index e8992dcec2..a4eedaaa3a 100644 --- a/src/beignet/func/_molecular_dynamics/_partition/__segment_sum.py +++ b/src/beignet/func/_molecular_dynamics/_partition/__segment_sum.py @@ -11,8 +11,7 @@ def _segment_sum( n: Optional[int] = None, **kwargs, ) -> Tensor: - """ - Computes the sum of segments of a tensor along the first dimension. + r"""Computes the sum of segments of a tensor along the first dimension. Parameters ---------- diff --git a/src/beignet/func/_molecular_dynamics/_partition/__shift.py b/src/beignet/func/_molecular_dynamics/_partition/__shift.py index 43d889613f..0079ef2f0e 100644 --- a/src/beignet/func/_molecular_dynamics/_partition/__shift.py +++ b/src/beignet/func/_molecular_dynamics/_partition/__shift.py @@ -3,8 +3,7 @@ def _shift(a: Tensor, b: Tensor) -> Tensor: - """ - Shifts a tensor `a` along dimensions specified in `b`. + r"""Shifts a tensor `a` along dimensions specified in `b`. The shift can be applied in up to three dimensions (x, y, z). Positive values in `b` indicate a forward shift, while negative values indicate @@ -22,5 +21,4 @@ def _shift(a: Tensor, b: Tensor) -> Tensor: Tensor The shifted tensor. """ - return torch.roll(a, shifts=tuple(b), dims=tuple(range(len(b)))) diff --git a/src/beignet/func/_molecular_dynamics/_partition/__to_square_metric_fn.py b/src/beignet/func/_molecular_dynamics/_partition/__to_square_metric_fn.py index 0d42b71150..be771a561d 100644 --- a/src/beignet/func/_molecular_dynamics/_partition/__to_square_metric_fn.py +++ b/src/beignet/func/_molecular_dynamics/_partition/__to_square_metric_fn.py @@ -7,8 +7,7 @@ def _to_square_metric_fn( fn: Callable[[Tensor, Tensor, Any], Tensor], ) -> Callable[[Tensor, Tensor, Any], Tensor]: - """ - Converts a given distance function to a squared distance metric. + r"""Converts a given distance function to a squared distance metric. The function tries to apply the given distance function `fn` to positions in one to three dimensions to determine if the output is scalar or vector. @@ -40,8 +39,10 @@ def square_metric(a: Tensor, b: Tensor, **kwargs) -> Tensor: return torch.sum(torch.square(fn(a, b, **kwargs)), dim=-1) return square_metric + except TypeError: continue + except ValueError: continue diff --git a/src/beignet/func/_molecular_dynamics/_partition/__unflatten_cell_buffer.py b/src/beignet/func/_molecular_dynamics/_partition/__unflatten_cell_buffer.py index 2d34da33a6..7e98ba9402 100644 --- a/src/beignet/func/_molecular_dynamics/_partition/__unflatten_cell_buffer.py +++ b/src/beignet/func/_molecular_dynamics/_partition/__unflatten_cell_buffer.py @@ -3,8 +3,7 @@ def _unflatten_cell_buffer(buffer: Tensor, cells_per_side: [int, float, Tensor], dim: int): - """ - Reshape a flat buffer into a multidimensional cell buffer. + r"""Reshape a flat buffer into a multidimensional cell buffer. Parameters ---------- diff --git a/src/beignet/func/_molecular_dynamics/_partition/_cell_list.py b/src/beignet/func/_molecular_dynamics/_partition/_cell_list.py index 696e7eebcf..18ac4cf1f3 100644 --- a/src/beignet/func/_molecular_dynamics/_partition/_cell_list.py +++ b/src/beignet/func/_molecular_dynamics/_partition/_cell_list.py @@ -87,6 +87,7 @@ def fn( exceeded_maximum_size = False update_fn = fn + else: buffer_size, exceeded_maximum_size, update_fn = excess @@ -244,24 +245,4 @@ def update_fn( **kwargs, ) - return _CellListFunctionList(setup_fn=setup_fn, update_fn=update_fn) - - -if __name__ == '__main__': - dtype = torch.float32 - box_size = torch.tensor([8.65, 8.0], dtype=torch.float32) - cell_size = 1.0 - - # Test particle positions - R = torch.tensor([ - [0.25, 0.25], - [8.5, 1.95], - [8.1, 1.5], - [3.7, 7.9] - ], dtype=dtype) - - cell_fn = cell_list(box_size, cell_size) - - cell_list_instance = cell_fn.setup_fn(R) - - print(cell_list_instance.positions_buffer[7, 3, 0]) \ No newline at end of file + return _CellListFunctionList(setup_fn=setup_fn, update_fn=update_fn) \ No newline at end of file