diff --git a/src/deepali/losses/flow.py b/src/deepali/losses/flow.py index bb33b61..dc2a65b 100644 --- a/src/deepali/losses/flow.py +++ b/src/deepali/losses/flow.py @@ -4,10 +4,9 @@ from typing import Optional, Union -import torch from torch import Tensor -from deepali.core.typing import ScalarOrTuple, Shape +from deepali.core.typing import Array, Scalar, ScalarOrTuple from . import functional as L from .base import DisplacementLoss @@ -20,6 +19,7 @@ def __init__( self, mode: Optional[str] = None, sigma: Optional[float] = None, + spacing: Optional[Union[Scalar, Array]] = None, stride: Optional[ScalarOrTuple] = None, reduction: str = "mean", ): @@ -28,6 +28,8 @@ def __init__( Args: mode: Method used to approximate :func:`flow_derivatives()`. sigma: Standard deviation of Gaussian in grid units used to smooth vector field. + spacing: Spacing between grid elements. Should be given in the units of the flow vectors. + By default, flow vectors with respect to normalized grid coordinates are assumed. stride: Number of output grid points between control points plus one for ``mode='bspline'``. reduction: Operation to use for reducing spatially distributed loss values. @@ -35,24 +37,18 @@ def __init__( super().__init__() self.mode = mode self.sigma = sigma + self.spacing = spacing self.stride = stride self.reduction = reduction - def _spacing(self, u_shape: Shape) -> Optional[Tensor]: - ndim = len(u_shape) - if ndim < 3: - raise ValueError(f"{type(self).__name__}.forward() 'u' must be at least 3-dimensional") - if ndim == 3: - return None - size = torch.tensor(u_shape[-1:1:-1], dtype=torch.float, device=torch.device("cpu")) - return 2 / (size - 1) - def extra_repr(self) -> str: args = [] if self.mode: args.append(f"mode={self.mode!r}") if self.sigma: args.append(f"sigma={self.sigma!r}") + if self.spacing: + args.append(f"spacing={self.spacing!r}") if self.stride: args.append(f"stride={self.stride!r}") args.append(f"reduction={self.reduction!r}") @@ -68,6 +64,7 @@ def __init__( q: Optional[Union[int, float]] = 1, mode: Optional[str] = None, sigma: Optional[float] = None, + spacing: Optional[Union[Scalar, Array]] = None, stride: Optional[ScalarOrTuple] = None, reduction: str = "mean", ): @@ -76,24 +73,27 @@ def __init__( Args: mode: Method used to approximate :func:`flow_derivatives()`. sigma: Standard deviation of Gaussian in grid units used to smooth vector field. + spacing: Spacing between grid elements. Should be given in the units of the flow vectors. + By default, flow vectors with respect to normalized grid coordinates are assumed. stride: Number of output grid points between control points plus one for ``mode='bspline'``. reduction: Operation to use for reducing spatially distributed loss values. """ - super().__init__(mode=mode, sigma=sigma, stride=stride, reduction=reduction) + super().__init__( + mode=mode, sigma=sigma, spacing=spacing, stride=stride, reduction=reduction + ) self.p = p self.q = 1 / p if q is None else q def forward(self, u: Tensor) -> Tensor: r"""Evaluate regularization loss for given transformation.""" - spacing = self._spacing(u.shape) return L.grad_loss( u, p=self.p, q=self.q, mode=self.mode, sigma=self.sigma, - spacing=spacing, + spacing=self.spacing, stride=self.stride, reduction=self.reduction, ) @@ -107,12 +107,11 @@ class Bending(_SpatialDerivativesLoss): def forward(self, u: Tensor) -> Tensor: r"""Evaluate regularization loss for given transformation.""" - spacing = self._spacing(u.shape) return L.bending_loss( u, mode=self.mode, sigma=self.sigma, - spacing=spacing, + spacing=self.spacing, stride=self.stride, reduction=self.reduction, ) @@ -127,12 +126,11 @@ class Curvature(_SpatialDerivativesLoss): def forward(self, u: Tensor) -> Tensor: r"""Evaluate regularization loss for given transformation.""" - spacing = self._spacing(u.shape) return L.curvature_loss( u, mode=self.mode, sigma=self.sigma, - spacing=spacing, + spacing=self.spacing, stride=self.stride, reduction=self.reduction, ) @@ -143,12 +141,11 @@ class Diffusion(_SpatialDerivativesLoss): def forward(self, u: Tensor) -> Tensor: r"""Evaluate regularization loss for given transformation.""" - spacing = self._spacing(u.shape) return L.diffusion_loss( u, mode=self.mode, sigma=self.sigma, - spacing=spacing, + spacing=self.spacing, stride=self.stride, reduction=self.reduction, ) @@ -159,12 +156,11 @@ class Divergence(_SpatialDerivativesLoss): def forward(self, u: Tensor) -> Tensor: r"""Evaluate regularization loss for given transformation.""" - spacing = self._spacing(u.shape) return L.divergence_loss( u, mode=self.mode, sigma=self.sigma, - spacing=spacing, + spacing=self.spacing, stride=self.stride, reduction=self.reduction, ) @@ -183,10 +179,11 @@ def __init__( shear_modulus: Optional[float] = None, mode: Optional[str] = None, sigma: Optional[float] = None, + spacing: Optional[Union[Scalar, Array]] = None, stride: Optional[ScalarOrTuple] = None, reduction: str = "mean", ): - super().__init__(mode=mode, sigma=sigma, reduction=reduction) + super().__init__(mode=mode, sigma=sigma, spacing=spacing, reduction=reduction) self.material_name = material_name self.first_parameter = first_parameter self.second_parameter = second_parameter @@ -196,7 +193,6 @@ def __init__( def forward(self, u: Tensor) -> Tensor: r"""Evaluate regularization loss for given transformation.""" - spacing = self._spacing(u.shape) return L.elasticity_loss( u, material_name=self.material_name, @@ -207,7 +203,7 @@ def forward(self, u: Tensor) -> Tensor: shear_modulus=self.shear_modulus, mode=self.mode, sigma=self.sigma, - spacing=spacing, + spacing=self.spacing, stride=self.stride, reduction=self.reduction, ) @@ -234,12 +230,11 @@ class TotalVariation(_SpatialDerivativesLoss): def forward(self, u: Tensor) -> Tensor: r"""Evaluate regularization loss for given transformation.""" - spacing = self._spacing(u.shape) return L.total_variation_loss( u, mode=self.mode, sigma=self.sigma, - spacing=spacing, + spacing=self.spacing, stride=self.stride, reduction=self.reduction, )