
Commit

neighbor tests
Henry Isaacson committed Jul 26, 2024
1 parent 3927f64 commit 96cf3c4
Showing 3 changed files with 892 additions and 872 deletions.
136 changes: 85 additions & 51 deletions src/beignet/func/_interact.py
@@ -389,7 +389,7 @@ def _neighbor_list_interaction(
else:
parameters[name] = parameter

merged_dictionaries = functools.partial(
merge_dictionaries = functools.partial(
_merge_dictionaries,
ignore_unused_parameters=ignore_unused_parameters,
)
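
For readers skimming the diff: the renamed merge_dictionaries is simply a functools.partial that pre-binds ignore_unused_parameters, so mapped_fn below can merge the static parameters with per-call keyword arguments in one step. A minimal sketch of the pattern, using a hypothetical merge helper rather than the library's _merge_dictionaries (the flag's real semantics may differ):

import functools

def merge(base: dict, overrides: dict, *, ignore_unused_parameters: bool = False) -> dict:
    # Hypothetical stand-in for _merge_dictionaries: overrides win over base values.
    merged = {**base, **overrides}
    if ignore_unused_parameters:
        # Keep only keys that belong to the static parameter set (assumed behavior).
        merged = {key: value for key, value in merged.items() if key in base}
    return merged

# Pre-bind the keyword once, then reuse the partial wherever dictionaries are merged.
merge_dictionaries = functools.partial(merge, ignore_unused_parameters=True)

print(merge_dictionaries({"sigma": 1.0, "epsilon": 0.5}, {"sigma": 2.0, "unused": 3.0}))
# {'sigma': 2.0, 'epsilon': 0.5}
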
@@ -416,27 +416,23 @@ def mapped_fn(
if neighbor_list.format is _NeighborListFormat.ORDERED_SPARSE:
normalization = 1.0
else:
distances = _map_neighbor(distance_fn)(
positions,
safe_index(positions, neighbor_list.indexes),
)
d = _map_neighbor(distance_fn)
r_neigh = safe_index(positions, neighbor_list.indexes)
distances = d(positions, r_neigh)

mask = torch.less(neighbor_list.indexes, positions.shape[0])

out = fn(
distances,
**_kwargs_to_neighbor_list_parameters(
neighbor_list.format,
neighbor_list.indexes,
_kinds,
merged_dictionaries(
parameters,
dynamic_kwargs,
),
combinators,
),
merged_kwargs = merge_dictionaries(parameters, dynamic_kwargs)
merged_kwargs = _kwargs_to_neighbor_list_parameters(
neighbor_list.format,
neighbor_list.indexes,
kinds,
merged_kwargs,
combinators
)

out = fn(distances, **merged_kwargs)

if out.ndim > mask.ndim:
mask = torch.reshape(
mask,
@@ -701,15 +697,13 @@ def _to_neighbor_list_kind_parameters(
return parameters
case 2:
if is_neighbor_list_sparse(format):
return _map_bond(
fn,
)(
kinds[indexes[0]],
kinds[indexes[1]],
return manual_vmap(fn, (0, 0), 0)(
safe_index(kinds, indexes[0]),
safe_index(kinds, indexes[1]),
)

return torch.vmap(
torch.vmap(
return manual_vmap(
manual_vmap(
fn,
in_dims=(None, 0),
),
@@ -727,15 +721,15 @@ def _to_neighbor_list_kind_parameters(
parameter,
),
)(
kinds[indexes[0]],
kinds[indexes[1]],
safe_index(kinds, indexes[0]),
safe_index(kinds, indexes[1]),
),
parameters.tree,
)

return optree.tree_map(
lambda parameter: torch.vmap(
torch.vmap(
lambda parameter: manual_vmap(
manual_vmap(
functools.partial(
fn,
parameter,
@@ -744,7 +738,7 @@ def _to_neighbor_list_kind_parameters(
)
)(
kinds,
kinds[indexes],
safe_index(kinds, indexes),
),
parameters.tree,
)
@@ -756,6 +750,54 @@ def _to_neighbor_list_kind_parameters(
return parameters


def manual_vmap(
    func: Callable,
    in_dims: Union[int, Tuple[Union[int, None], ...]] = 0,
    out_dims: Union[int, Tuple[int, ...]] = 0,
    randomness: str = 'error',
    *,
    chunk_size: Union[None, int] = None,
) -> Callable:
    def batched_func(*args, **kwargs):
        # Determine the batch size from the first input that has a batch dimension
        if isinstance(in_dims, int):
            batch_size = args[0].shape[in_dims]
        else:
            batch_size = next(
                arg.shape[dim]
                for arg, dim in zip(args, in_dims)
                if dim is not None
            )

        # Initialize a list to store the results
        results = []

        # Iterate over the batch dimension
        for i in range(batch_size):
            # Extract the i-th element from each input
            sliced_args = []
            in_dims_per_arg = (
                in_dims if isinstance(in_dims, tuple) else [in_dims] * len(args)
            )
            for arg, dim in zip(args, in_dims_per_arg):
                if dim is None:
                    sliced_args.append(arg)
                else:
                    sliced_args.append(arg.select(dim, i))

            # Call the function with the sliced arguments
            result = func(*sliced_args, **kwargs)

            # Append the result to the results list
            results.append(result)

        # Stack the results along the specified output dimension
        if isinstance(out_dims, int):
            return torch.stack(results, dim=out_dims)
        else:
            return tuple(
                torch.stack([res[i] for res in results], dim=out_dims[i])
                for i in range(len(results[0]))
            )

    return batched_func
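
manual_vmap is the workhorse of this commit: it replaces torch.vmap and the old _map_bond / _map_neighbor wrappers with an explicit Python loop over the batch dimension, which is easier to step through in tests. A small sanity-check sketch of the intended behavior; the dot-product function is illustrative only, and the comparison assumes a PyTorch build where torch.vmap is available:

import torch

def dot(a, b):
    return (a * b).sum()

x = torch.randn(5, 3)
y = torch.randn(5, 3)

# Loop-based batching should match torch.vmap for a simple elementwise reduction.
batched_dot = manual_vmap(dot, in_dims=(0, 0), out_dims=0)
expected = torch.vmap(dot, in_dims=(0, 0), out_dims=0)(x, y)
assert torch.allclose(batched_dot(x, y), expected)

# An in_dims entry of None broadcasts that argument to every batch element,
# mirroring torch.vmap's convention.
shift = torch.randn(3)
shifted = manual_vmap(lambda a, s: (a + s).sum(), in_dims=(0, None))(x, shift)
print(shifted.shape)  # torch.Size([5])
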


def _to_neighbor_list_matrix_parameters(
format: _NeighborListFormat,
indexes: Tensor,
@@ -769,9 +811,7 @@ def _to_neighbor_list_matrix_parameters(
return parameters
case 1:
if is_neighbor_list_sparse(format):
return _map_bond(
combinator,
)(
return manual_vmap(combinator, (0, 0), 0)(
safe_index(parameters, indexes[0]),
safe_index(parameters, indexes[1]),
)
@@ -781,20 +821,16 @@ def _to_neighbor_list_matrix_parameters(
safe_index(parameters, indexes),
)
case 2:
def query(id_a, id_b):
return safe_index(parameters, id_a, id_b)

if is_neighbor_list_sparse(format):
return _map_bond(
lambda a, b: parameters[a, b],
)(
displacement = lambda a, b: safe_index(parameters, a, b)
return manual_vmap(displacement, (0, 0), 0)(
indexes[0],
indexes[1],
)

return torch.func.vmap(
torch.func.vmap(
lambda a, b: parameters[a, b],
return manual_vmap(
manual_vmap(
lambda a, b: safe_index(parameters, a, b),
(None, 0),
),
)(
@@ -808,11 +844,11 @@ def query(id_a, id_b):
case _ParameterTreeKind.BOND:
if is_neighbor_list_sparse(format):
return optree.tree_map(
lambda parameter: _map_bond(
lambda parameter: manual_vmap(
functools.partial(
lambda p, a, b: p[a, b],
lambda p, a, b: safe_index(p, a, b),
parameter,
),
), (0, 0), 0
)(
indexes[0],
indexes[1],
@@ -821,12 +857,10 @@ def query(id_a, id_b):
)

return optree.tree_map(
lambda parameter: torch.func.vmap(
torch.func.vmap(
lambda parameter: manual_vmap(
manual_vmap(
functools.partial(
lambda p, a, b: p[a, b],
parameter,
),
safe_index, parameter),
(None, 0),
),
)(
Expand All @@ -838,8 +872,8 @@ def query(id_a, id_b):
case _ParameterTreeKind.PARTICLE:
if is_neighbor_list_sparse(format):
return optree.tree_map(
lambda parameter: _map_bond(
combinator,
lambda parameter: manual_vmap(
combinator, (0, 0), 0
)(
safe_index(parameter, indexes[0]),
safe_index(parameter, indexes[1]),
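Summarizing the _interact.py changes: wherever the neighbor list is in a sparse format, per-pair parameters are now produced by manual_vmap(fn, (0, 0), 0) over the two index rows, with safe_index clamping out-of-range (padding) indexes instead of letting them raise. A self-contained sketch of that lookup pattern; clamped_gather here is a stand-in for safe_index (defined in _partition.py below), and the parameter values are made up:

import torch

def clamped_gather(table, rows, cols):
    # Stand-in for safe_index(table, rows, cols): clamp out-of-range indices
    # so padded neighbor slots read a valid entry instead of raising.
    rows = torch.clamp(rows, 0, table.size(0) - 1).to(torch.long)
    cols = torch.clamp(cols, 0, table.size(1) - 1).to(torch.long)
    return table[rows, cols]

# A (kinds x kinds) parameter table and one sparse neighbor pair per column.
parameters = torch.tensor([[1.0, 2.0], [2.0, 3.0]])
indexes = torch.tensor([[0, 1, 1], [1, 0, 1]])  # shape (2, number_of_pairs)

per_pair = manual_vmap(
    lambda a, b: clamped_gather(parameters, a, b), (0, 0), 0
)(indexes[0], indexes[1])
print(per_pair)  # tensor([2., 2., 3.])
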
23 changes: 8 additions & 15 deletions src/beignet/func/_partition.py
@@ -293,11 +293,11 @@ def did_buffer_overflow(self) -> bool:

@property
def cell_size_too_small(self) -> bool:
return (self.error.code & PEC.CELL_SIZE_TOO_SMALL).item() != 0
return (self.partition_error.code & PEC.CELL_SIZE_TOO_SMALL).item() != 0

@property
def malformed_box(self) -> bool:
return (self.error.code & PEC.MALFORMED_BOX).item() != 0
return (self.partition_error.code & PEC.MALFORMED_BOX).item() != 0
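
These properties follow the usual bit-flag pattern: each failure mode is a distinct power-of-two code, and a bitwise AND against the accumulated partition_error.code tells whether that particular failure was recorded. A minimal illustration of the pattern; the enum members other than CELL_SIZE_TOO_SMALL and MALFORMED_BOX are assumed for the sketch, not taken from the library:

import enum
import torch

class PEC(enum.IntEnum):
    # Hypothetical stand-in for the partition error codes referenced above.
    NONE = 0
    NEIGHBOR_LIST_OVERFLOW = 1 << 0
    CELL_LIST_OVERFLOW = 1 << 1
    CELL_SIZE_TOO_SMALL = 1 << 2
    MALFORMED_BOX = 1 << 3

# Two failures recorded in one accumulated code.
code = torch.tensor(PEC.CELL_SIZE_TOO_SMALL | PEC.MALFORMED_BOX)

print((code & PEC.CELL_SIZE_TOO_SMALL).item() != 0)      # True
print((code & PEC.NEIGHBOR_LIST_OVERFLOW).item() != 0)   # False
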


@_dataclass
@@ -487,24 +487,17 @@ def safe_index(array: Tensor, indices: Tensor, indices_b: Optional[Tensor] = Non
Tensor
The resulting tensor after indexing.
"""
max_index = array.shape[0] - 1

clamped_indices = indices.clamp(0, max_index)

if indices_b is not None:
max_index_b = array.shape[1] - 1
indices = torch.clamp(indices, 0, array.size(0) - 1)

clamped_indices = indices.unsqueeze(1).clamp(0, max_index)
indices_b = torch.clamp(indices_b, 0, array.size(1) - 1)

clamped_indices_b = indices_b.clamp(0, max_index_b)
return array[indices.to(torch.long), indices_b.to(torch.long)]

print(f"ca: {clamped_indices.shape}")
print(f"cb: {clamped_indices_b.shape}")
print(array[clamped_indices, clamped_indices_b].shape)

return array[clamped_indices, clamped_indices_b]
else:
indices = torch.clamp(indices, 0, array.size(0) - 1)

return array[clamped_indices]
return array[indices]


def safe_mask(
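The reworked safe_index clamps indexes into range before gathering, on both the one- and two-dimensional paths, and drops the leftover debug prints. A standalone sketch of the intended behavior (a re-implementation for illustration, not the module's exact code):

import torch

def safe_index_sketch(array, indices, indices_b=None):
    # Clamp into range so padded/out-of-range indices read a valid entry
    # instead of raising an IndexError.
    indices = torch.clamp(indices, 0, array.size(0) - 1).to(torch.long)
    if indices_b is None:
        return array[indices]
    indices_b = torch.clamp(indices_b, 0, array.size(1) - 1).to(torch.long)
    return array[indices, indices_b]

positions = torch.arange(4.0)
print(safe_index_sketch(positions, torch.tensor([0, 3, 99])))  # tensor([0., 3., 3.])

table = torch.arange(6.0).reshape(2, 3)
print(safe_index_sketch(table, torch.tensor([0, 5]), torch.tensor([2, 7])))
# tensor([2., 5.])
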
