Skip to content

Commit

Permalink
neighbor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Henry Isaacson committed Jul 24, 2024
1 parent 0384a99 commit 3927f64
Show file tree
Hide file tree
Showing 3 changed files with 987 additions and 143 deletions.
51 changes: 28 additions & 23 deletions src/beignet/func/_interact.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from beignet.func.__dataclass import _dataclass
from beignet.func._partition import _NeighborListFormat, map_product, \
_NeighborList, _map_bond, _map_neighbor, is_neighbor_list_sparse, \
_segment_sum
_segment_sum, safe_index


class _ParameterTreeKind(Enum):
Expand Down Expand Up @@ -265,7 +265,6 @@ def _particle_fn(_parameter: Tensor) -> Tensor:
_parameter[:, None, ...],
_parameter[None, :, ...],
)
print(kwargs.items())

parameters[name] = optree.tree_map(
_particle_fn,
Expand Down Expand Up @@ -408,8 +407,8 @@ def mapped_fn(

if is_neighbor_list_sparse(neighbor_list.format):
distances = _map_bond(distance_fn)(
positions[neighbor_list.indexes[0]],
positions[neighbor_list.indexes[1]],
safe_index(positions, neighbor_list.indexes[0]),
safe_index(positions, neighbor_list.indexes[1]),
)

mask = torch.less(neighbor_list.indexes[0], positions.shape[0])
Expand All @@ -419,7 +418,7 @@ def mapped_fn(
else:
distances = _map_neighbor(distance_fn)(
positions,
positions[neighbor_list.indexes],
safe_index(positions, neighbor_list.indexes),
)

mask = torch.less(neighbor_list.indexes, positions.shape[0])
Expand Down Expand Up @@ -484,9 +483,11 @@ def _pair_interaction(
) -> Callable[..., Tensor]:
parameters, combinators = {}, {}

for name, parameter in kwargs.items():
for name, parameter in list(kwargs.items()):
if isinstance(parameter, Callable):
combinators[name] = parameter
del kwargs[name]

elif isinstance(parameter, tuple) and isinstance(parameter[0],
Callable):
assert len(parameter) == 2
Expand Down Expand Up @@ -551,16 +552,16 @@ def mapped_fn(_position: Tensor, **_dynamic_kwargs):
s_kwargs = _kwargs_to_pair_parameters(_kwargs, combinators,
(m, n))

u = fn(distance, **s_kwargs)
y = fn(distance, **s_kwargs)

if m == n:
u = _zero_diagonal_mask(u)
y = _zero_diagonal_mask(y)

u = _safe_sum(u)
y = _safe_sum(y)

u = u + u * 0.5
u = u + y * 0.5
else:
y = _safe_sum(u)
y = _safe_sum(y)

u = u + y

Expand All @@ -577,7 +578,7 @@ def mapped_fn(_position: Tensor, _kinds: Tensor, **_dynamic_kwargs):

u = torch.tensor(0.0, dtype=torch.float32)

n = _position.shape[0]
num_particles = _position.shape[0]

distance_fn = functools.partial(displacement_fn, **_dynamic_kwargs)

Expand All @@ -592,13 +593,13 @@ def mapped_fn(_position: Tensor, _kinds: Tensor, **_dynamic_kwargs):
a = torch.reshape(
_kinds == m,
[
n,
num_particles,
],
)
b = torch.reshape(
_kinds == n,
[
n,
num_particles,
],
)

Expand Down Expand Up @@ -712,7 +713,7 @@ def _to_neighbor_list_kind_parameters(
fn,
in_dims=(None, 0),
),
)(kinds, kinds[indexes])
)(kinds, safe_index(kinds, indexes))
case _:
raise ValueError
case parameters if isinstance(parameters, _ParameterTree):
Expand Down Expand Up @@ -771,15 +772,18 @@ def _to_neighbor_list_matrix_parameters(
return _map_bond(
combinator,
)(
parameters[indexes[0]],
parameters[indexes[1]],
safe_index(parameters, indexes[0]),
safe_index(parameters, indexes[1]),
)

return combinator(
parameters[:, None],
parameters[indexes],
safe_index(parameters, indexes),
)
case 2:
def query(id_a, id_b):
return safe_index(parameters, id_a, id_b)

if is_neighbor_list_sparse(format):
return _map_bond(
lambda a, b: parameters[a, b],
Expand Down Expand Up @@ -837,8 +841,8 @@ def _to_neighbor_list_matrix_parameters(
lambda parameter: _map_bond(
combinator,
)(
parameter[indexes[0]],
parameter[indexes[1]],
safe_index(parameter, indexes[0]),
safe_index(parameter, indexes[1]),
),
parameters.tree,
)
Expand All @@ -848,7 +852,7 @@ def _to_neighbor_list_matrix_parameters(
combinator,
)(
parameter,
parameter[indexes],
safe_index(parameter, indexes),
),
parameters.tree,
)
Expand Down Expand Up @@ -878,10 +882,10 @@ def interact(
],
*,
bonds: Optional[Tensor] = None,
kinds: Optional[Tensor] = None,
kinds: Optional[Union[int, Tensor]] = None,
dim: Optional[Union[int, Tuple[int, ...]]] = None,
keepdim: bool = False,
ignore_unused_parameters: bool = True,
ignore_unused_parameters: bool = False,
**kwargs,
) -> Callable[..., Tensor]:
r"""
Expand Down Expand Up @@ -1022,6 +1026,7 @@ def fn(x: Tensor, a: float, e: float, s: float, **_) -> Tensor:
kinds=kinds,
dim=dim,
ignore_unused_parameters=ignore_unused_parameters,
**kwargs,
)
case "pair":
return _pair_interaction(
Expand Down
21 changes: 16 additions & 5 deletions src/beignet/func/_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ def metric(distance_fn: Callable) -> Callable:
return lambda Ra, Rb, **kwargs: distance(distance_fn(Ra, Rb, **kwargs))


def safe_index(array: Tensor, indices: Tensor) -> Tensor:
def safe_index(array: Tensor, indices: Tensor, indices_b: Optional[Tensor] = None) -> Tensor:
r"""Safely index into a tensor, clamping out-of-bounds indices to the nearest valid index.
Parameters
Expand All @@ -479,21 +479,32 @@ def safe_index(array: Tensor, indices: Tensor) -> Tensor:
The tensor to index.
indices : Tensor
The indices to use for indexing.
indices_b : Tensor
Another Tensor of indices to for advanced indexing.
Returns
-------
Tensor
The resulting tensor after indexing.
"""
print(array.shape)
print(indices)
max_index = array.shape[0] - 1

clamped_indices = indices.clamp(0, max_index)

result = array[clamped_indices]
if indices_b is not None:
max_index_b = array.shape[1] - 1

return result
clamped_indices = indices.unsqueeze(1).clamp(0, max_index)

clamped_indices_b = indices_b.clamp(0, max_index_b)

print(f"ca: {clamped_indices.shape}")
print(f"cb: {clamped_indices_b.shape}")
print(array[clamped_indices, clamped_indices_b].shape)

return array[clamped_indices, clamped_indices_b]

return array[clamped_indices]


def safe_mask(
Expand Down
Loading

0 comments on commit 3927f64

Please sign in to comment.