diff --git a/src/deepali/core/enum.py b/src/deepali/core/enum.py index d582517..7c36b93 100644 --- a/src/deepali/core/enum.py +++ b/src/deepali/core/enum.py @@ -206,12 +206,12 @@ def is_mixed(key: SpatialDerivativeKey) -> bool: return len(set(key)) > 1 @staticmethod - def all(ndim: int, order: Union[int, Sequence[int]]) -> List[SpatialDerivativeKey]: + def all(spatial_dims: int, order: Union[int, Sequence[int]]) -> List[SpatialDerivativeKey]: r"""Unmixed spatial derivatives of specified order.""" if isinstance(order, int): order = [order] keys = [] - dims = [str(SpatialDim(d)) for d in range(ndim)] + dims = [str(SpatialDim(d)) for d in range(spatial_dims)] for n in order: if n > 0: codes = dims @@ -221,11 +221,11 @@ def all(ndim: int, order: Union[int, Sequence[int]]) -> List[SpatialDerivativeKe return keys @staticmethod - def unmixed(ndim: int, order: int) -> List[SpatialDerivativeKey]: + def unmixed(spatial_dims: int, order: int) -> List[SpatialDerivativeKey]: r"""Unmixed spatial derivatives of specified order.""" if order <= 0: return [] - return [SpatialDim(d).symbol() * order for d in range(ndim)] + return [SpatialDim(d).symbol() * order for d in range(spatial_dims)] @classmethod def unique(cls, keys: Iterable[SpatialDerivativeKey]) -> Set[SpatialDerivativeKey]: @@ -466,7 +466,7 @@ def all( if order == 0: return [] channels = cls._channels(spatial_dims, channel) - derivs = SpatialDerivativeKeys.all(spatial_dims, order=order) + derivs = SpatialDerivativeKeys.all(spatial_dims=spatial_dims, order=order) return [cls.symbol(c, d) for c, d in itertools.product(channels, derivs)] @classmethod @@ -482,7 +482,7 @@ def unmixed( if order == 0: return [] channels = cls._channels(spatial_dims, channel) - derivs = SpatialDerivativeKeys.unmixed(spatial_dims, order=order) + derivs = SpatialDerivativeKeys.unmixed(spatial_dims=spatial_dims, order=order) return [cls.symbol(c, d) for c, d in itertools.product(channels, derivs)] @classmethod @@ -504,7 +504,7 @@ def divergence(cls, spatial_dims: int) -> List[FlowDerivativeKey]: @classmethod def curvature(cls, spatial_dims: int) -> List[FlowDerivativeKey]: channels = range(spatial_dims) - derivs = SpatialDerivativeKeys.unmixed(spatial_dims, order=2) + derivs = SpatialDerivativeKeys.unmixed(spatial_dims=spatial_dims, order=2) return [cls.symbol(c, d) for c, d in itertools.product(channels, derivs)] @classmethod diff --git a/src/deepali/core/image.py b/src/deepali/core/image.py index ecc63f6..9f19786 100644 --- a/src/deepali/core/image.py +++ b/src/deepali/core/image.py @@ -1536,7 +1536,7 @@ def spatial_derivatives( if which is None: if order is None: order = 1 - which = SpatialDerivativeKeys.all(ndim=D, order=order) + which = SpatialDerivativeKeys.all(spatial_dims=D, order=order) elif order is not None: which = [arg for arg in which if len(arg) == order] diff --git a/tests/test_core_enum.py b/tests/test_core_enum.py index a7a1e19..5a6e065 100644 --- a/tests/test_core_enum.py +++ b/tests/test_core_enum.py @@ -74,7 +74,7 @@ def test_flow_derivative_keys_unique() -> None: def test_flow_derivative_keys_all() -> None: for d, order in itertools.product([2, 3], [0, 1, 2]): channel_keys = ["u", "v", "w"][:d] - spatial_keys = SpatialDerivativeKeys.all(ndim=d, order=order) + spatial_keys = SpatialDerivativeKeys.all(spatial_dims=d, order=order) expected = [f"d{a}/d{b}" for a, b in itertools.product(channel_keys, spatial_keys)] assert FlowDerivativeKeys.all(spatial_dims=d, order=order) == expected