Skip to content

Commit

Permalink
Make CG cache contiguous, fix some
Browse files Browse the repository at this point in the history
  • Loading branch information
jwa7 committed Feb 15, 2024
1 parent 39aee61 commit f6d88ef
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 20 deletions.
6 changes: 4 additions & 2 deletions python/rascaline/rascaline/utils/clebsch_gordan/_cg_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ def _build_cg_coeff_dict(
# inside cg combine function,
blocks.append(
TensorBlock(
values=values,
values=_dispatch.contiguous(values),
samples=Labels(["m1", "m2", "mu"], l1l2lam_sample_values),
components=[],
properties=Labels.range("property", 1),
Expand All @@ -384,7 +384,9 @@ def _build_cg_coeff_dict(
block_value_shape = (1,) + l1l2lam_values.shape + (1,)
blocks.append(
TensorBlock(
values=l1l2lam_values.reshape(block_value_shape),
values=_dispatch.contiguous(
l1l2lam_values.reshape(block_value_shape)
),
samples=Labels.range("sample", 1),
components=[
Labels(
Expand Down
20 changes: 20 additions & 0 deletions python/rascaline/rascaline/utils/clebsch_gordan/_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,26 @@ def argsort(array):
raise TypeError(UNKNOWN_ARRAY_TYPE)


def contiguous(array):
"""
Returns a contiguous array.
It is equivalent of np.ascontiguousarray(array) and tensor.contiguous(). In
the case of numpy, C order is used for consistency with torch. As such, only
C-contiguity is checked.
"""
if isinstance(array, TorchTensor):
if array.is_contiguous():
return array
return array.contiguous()
elif isinstance(array, np.ndarray):
if array.flags["C_CONTIGUOUS"]:
return array
return np.ascontiguousarray(array)
else:
raise TypeError(UNKNOWN_ARRAY_TYPE)


def unique(array, axis: Optional[int] = None):
"""Find the unique elements of an array."""
if isinstance(array, TorchTensor):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,14 @@ def forward(self, density: TensorMap) -> Union[TensorMap, List[TensorMap]]:

def compute(self, density: TensorMap) -> Union[TensorMap, List[TensorMap]]:
"""
Performs the density correlations for public functions
:py:func:`correlate_density` and :py:func:`correlate_density_metadata`.
:param density: A density descriptor of body order 2 (correlation order 1),
in :py:class:`TensorMap` format. This may be, for example, a rascaline
:py:class:`SphericalExpansion` or :py:class:`LodeSphericalExpansion`.
Alternatively, this could be multi-center descriptor, such as a pair
density.
Computes the density correlations by taking iterative Clebsch-Gordan
(CG) tensor products of the input `density` descriptor with itself.
:param density: A density descriptor of body order 2 (correlation order
1), in :py:class:`TensorMap` format. This may be, for example, a
rascaline :py:class:`SphericalExpansion` or
:py:class:`LodeSphericalExpansion`. Alternatively, this could be
multi-center descriptor, such as a pair density.
"""
return self._correlate_density(
density,
Expand All @@ -286,16 +286,16 @@ def compute_metadata(
density: TensorMap,
) -> Union[TensorMap, List[TensorMap]]:
"""
Returns the metadata-only :py:class:`TensorMap`(s) that would be output by
the function :py:func:`correlate_density` under the same settings, without
perfoming the actual Clebsch-Gordan tensor products. See this function for
full documentation.
:param density: A density descriptor of body order 2 (correlation order 1),
in :py:class:`TensorMap` format. This may be, for example, a rascaline
:py:class:`SphericalExpansion` or :py:class:`LodeSphericalExpansion`.
Alternatively, this could be multi-center descriptor, such as a pair
density.
Returns the metadata-only :py:class:`TensorMap`(s) that would be output
by the function :py:meth:`compute` for the same calculator under the
same settings, without perfoming the actual Clebsch-Gordan tensor
products.
:param density: A density descriptor of body order 2 (correlation order
1), in :py:class:`TensorMap` format. This may be, for example, a
rascaline :py:class:`SphericalExpansion` or
:py:class:`LodeSphericalExpansion`. Alternatively, this could be
multi-center descriptor, such as a pair density.
"""
return self._correlate_density(
density,
Expand Down

0 comments on commit f6d88ef

Please sign in to comment.