Skip to content

Commit

Permalink
[core] Fix spatial derivatives when using mode='bspline' or mode='gau…
Browse files Browse the repository at this point in the history
…ssian'
  • Loading branch information
aschuh-hf committed Dec 14, 2023
1 parent 580a05d commit 41bcc55
Show file tree
Hide file tree
Showing 2 changed files with 236 additions and 14 deletions.
45 changes: 31 additions & 14 deletions src/deepali/core/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1486,7 +1486,10 @@ def spatial_derivatives(
If ``None``, ``forward_central_backward`` is used as default mode.
sigma: Standard deviation of Gaussian kernel in grid units. If ``None`` or zero,
no Gaussian smoothing is used for calculation of finite differences, and a
default standard deviation of 0.4 is used when ``mode="gaussian"``.
default standard deviation of 0.7355 is used when ``mode="gaussian"``. With a smaller
standard deviation, the magnitude of the derivative values starts to deviate between
``mode="gaussian"`` and finite differences of a Gaussian smoothed input. This is likely
due to a too small discretized Gaussian filter and its derivative.
spacing: Physical spacing between image grid points, e.g., ``(sx, sy, sz)``.
When a scalar is given, the same spacing is used for each image and spatial dimension.
If a sequence is given, it must be of length equal to the number of spatial dimensions ``D``,
Expand Down Expand Up @@ -1556,7 +1559,7 @@ def spatial_derivatives(
if mode in ("forward", "backward", "central", "forward_central_backward", "prewitt", "sobel"):
if sigma and sigma > 0:
blur = gaussian1d(sigma, dtype=torch.float, device=data.device)
data = conv(data, blur, padding=PaddingMode.ZEROS)
data = conv(data, blur, padding=PaddingMode.REPLICATE)
if mode in ("prewitt", "sobel"):
avg_kernel = torch.tensor([1, 1 if mode == "prewitt" else 2, 1], dtype=data.dtype)
avg_kernel /= avg_kernel.sum()
Expand Down Expand Up @@ -1589,7 +1592,7 @@ def spatial_derivatives(

if sigma and sigma > 0:
blur = gaussian1d(sigma, dtype=torch.float, device=data.device)
data = conv(data, blur, padding=PaddingMode.ZEROS)
data = conv(data, blur, padding=PaddingMode.REPLICATE)

if stride is None:
stride = 1
Expand All @@ -1616,27 +1619,41 @@ def bspline1d(s: int, d: int) -> Tensor:
for spatial_dim in SpatialDerivativeKeys.split(code):
order[spatial_dim] += 1
kernel = [bspline1d(s, d) for s, d in zip(stride, order)]
derivs[code] = evaluate_cubic_bspline(data, kernel=kernel)
deriv = evaluate_cubic_bspline(data, kernel=kernel)
if sum(order) > 0:
denom = torch.ones(N, dtype=spacing.dtype, device=spacing.device)
for delta, d in zip(spacing.transpose(0, 1), order):
if d > 0:
denom.mul_(delta.pow(d))
denom = denom.reshape((N,) + (1,) * (deriv.ndim - 1))
deriv = deriv.div_(denom.to(deriv))
derivs[code] = deriv

elif mode == "gaussian":

def pad_spatial_dim(data: Tensor, sdim: int, padding: int) -> Tensor:
pad = [(padding, padding) if d == sdim else (0, 0) for d in range(data.ndim - 2)]
pad = [n for v in pad for n in v]
return F.pad(data, pad, mode="replicate")

if not sigma:
sigma = 0.4
kernel_0 = gaussian1d(sigma, normalize=False, dtype=torch.float)
kernel_1 = gaussian1d_I(sigma, normalize=False, dtype=torch.float)
norm = kernel_0.sum()
kernel_0 = kernel_0.div_(norm).to(data.device)
kernel_1 = kernel_1.div_(norm).to(data.device)
sigma = 0.7355 # same default value as used in downsample()
kernel_0 = gaussian1d(sigma, normalize=False, dtype=torch.float, device=data.device)
kernel_1 = gaussian1d_I(sigma, normalize=False, dtype=torch.float, device=data.device)
for i in range(max_order):
for code in unique_keys:
key = code[: i + 1]
if i < len(code) and key not in derivs:
sdim = SpatialDim.from_arg(code[i])
result = data if i == 0 else derivs[code[:i]]
deriv = data if i == 0 else derivs[code[:i]]
for d in range(D):
dim = SpatialDim(d).tensor_dim(result.ndim)
dim = SpatialDim(d).tensor_dim(deriv.ndim)
kernel = kernel_1 if sdim == d else kernel_0
result = conv1d(result, kernel, dim=dim, padding=len(kernel) // 2)
derivs[key] = result
deriv = pad_spatial_dim(deriv, d, len(kernel) // 2)
deriv = conv1d(deriv, kernel, dim=dim, padding=0)
denom = spacing.narrow(1, sdim, 1).reshape((N,) + (1,) * (deriv.ndim - 1))
deriv = deriv.div_(denom.to(deriv))
derivs[key] = deriv
derivs = {key: derivs[SpatialDerivativeKeys.sorted(key)] for key in which}

else:
Expand Down
205 changes: 205 additions & 0 deletions tests/_test_core_flow_deriv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
r"""Interactive test and visualization of vector flow derivatives."""

# %%
# Imports
from typing import Dict, Optional, Sequence

import matplotlib.pyplot as plt

import torch
from torch import Tensor
from torch.random import Generator

from deepali.core import Axes, Grid
import deepali.core.bspline as B
import deepali.core.functional as U


# %%
# Auxiliary functions
def change_axes(flow: Tensor, grid: Grid, axes: Axes, to_axes: Axes) -> Tensor:
if axes != to_axes:
flow = U.move_dim(flow, 1, -1)
flow = grid.transform_vectors(flow, axes=axes, to_axes=to_axes)
flow = U.move_dim(flow, -1, 1)
return flow


def flow_derivatives(
flow: Tensor, grid: Grid, axes: Axes, to_axes: Optional[Axes] = None, **kwargs
) -> Dict[str, Tensor]:
if to_axes is None:
to_axes = axes
flow = change_axes(flow, grid, axes, to_axes)
axes = to_axes
if "spacing" not in kwargs:
if axes == Axes.CUBE:
spacing = tuple(2 / n for n in grid.size())
elif axes == Axes.CUBE_CORNERS:
spacing = tuple(2 / (n - 1) for n in grid.size())
elif axes == Axes.GRID:
spacing = 1
elif axes == Axes.WORLD:
spacing = grid.spacing()
else:
spacing = None
kwargs["spacing"] = spacing
return U.flow_derivatives(flow, **kwargs)


def random_svf(
size: Sequence[int],
stride: int = 1,
generator: Optional[Generator] = None,
) -> Tensor:
cp_grid_size = B.cubic_bspline_control_point_grid_size(size, stride=stride)
cp_grid_size = tuple(reversed(cp_grid_size))
data = torch.randn((1, 3) + cp_grid_size, generator=generator)
data = U.fill_border(data, margin=3, value=0, inplace=True)
return B.evaluate_cubic_bspline(data, size=size, stride=stride)


def visualize_flow(
ax: plt.Axes,
flow: Tensor,
grid: Optional[Grid] = None,
axes: Optional[Axes] = None,
label: Optional[str] = None,
) -> None:
if grid is None:
grid = Grid(shape=flow.shape[2:])
if axes is None:
axes = grid.axes()
flow = change_axes(flow, grid, axes, grid.axes())
x = grid.coords(channels_last=False, dtype=flow.dtype, device=flow.device)
x = U.move_dim(x.unsqueeze_(0).add_(flow), 1, -1)
target_grid = U.grid_image(shape=flow.shape[2:], inverted=True, stride=(5, 5))
warped_grid = U.warp_image(target_grid, x, align_corners=grid.align_corners())
ax.imshow(warped_grid[0, 0, flow.shape[2] // 2], cmap="gray")
if label:
ax.set_title(label, fontsize=24)


# %%
# Random velocity fields
generator = torch.Generator().manual_seed(42)
grid = Grid(size=(128, 128, 64), spacing=(0.5, 0.5, 1.0))
flow = random_svf(grid.size(), stride=8, generator=generator).mul_(0.1)

fig, axes = plt.subplots(1, 1, figsize=(4, 4))

ax = axes
ax.set_title("v", fontsize=24, pad=20)
visualize_flow(ax, flow, grid=grid, axes=grid.axes())


# %%
# Visualise first order derivatives for different modes
configs = [
dict(mode="forward_central_backward"),
dict(mode="bspline"),
dict(mode="gaussian", sigma=0.7355),
]

fig, axes = plt.subplots(len(configs), 4, figsize=(16, 4 * len(configs)))

for i, config in enumerate(configs):
derivs = flow_derivatives(
flow,
grid=grid,
axes=grid.axes(),
to_axes=Axes.GRID,
which=["du/dx", "du/dy", "dv/dx", "dv/dy"],
**config,
)
for ax, (key, deriv) in zip(axes[i], derivs.items()):
if i == 0:
ax.set_title(key, fontsize=24, pad=20)
ax.imshow(deriv[0, 0, deriv.shape[2] // 2], vmin=-1, vmax=1)


# %%
# Compare magnitudes of first order derivatives for different modes
flow_axes = [Axes.GRID, Axes.WORLD, Axes.CUBE_CORNERS]

sigma = 0.7355

configs = [
dict(mode="bspline"),
dict(mode="gaussian", sigma=sigma),
dict(mode="forward_central_backward", sigma=sigma),
dict(mode="forward_central_backward"),
]

for to_axes in flow_axes:
for config in configs:
print(f"axes={to_axes}, " + ", ".join(f"{k}={v!r}" for k, v in config.items()))
derivs = flow_derivatives(
flow,
grid=grid,
axes=grid.axes(),
to_axes=to_axes,
which=["du/dx", "du/dy", "dv/dx", "dv/dy"],
**config,
)
for key, deriv in derivs.items():
print(f"- max(abs({key})): {deriv.abs().max().item():.5f}")
print()
print("\n")


# %%
# Visualise second order derivatives for different modes
configs = [
dict(mode="forward_central_backward"),
dict(mode="bspline"),
dict(mode="gaussian", sigma=0.7355),
]

fig, axes = plt.subplots(len(configs), 4, figsize=(16, 4 * len(configs)))

for i, config in enumerate(configs):
derivs = flow_derivatives(
flow,
grid=grid,
axes=grid.axes(),
to_axes=Axes.GRID,
which=["du/dxx", "du/dxy", "dv/dxy", "dv/dyy"],
**config,
)
for ax, (key, deriv) in zip(axes[i], derivs.items()):
if i == 0:
ax.set_title(key, fontsize=24, pad=20)
ax.imshow(deriv[0, 0, deriv.shape[2] // 2], vmin=-0.4, vmax=0.4)


# %%
# Compare magnitudes of second order derivatives for different modes
flow_axes = [Axes.GRID, Axes.WORLD, Axes.CUBE_CORNERS]

sigma = 0.7355

configs = [
dict(mode="bspline"),
dict(mode="gaussian", sigma=sigma),
dict(mode="forward_central_backward", sigma=sigma),
dict(mode="forward_central_backward"),
]

for to_axes in flow_axes:
for config in configs:
print(f"axes={to_axes}, " + ", ".join(f"{k}={v!r}" for k, v in config.items()))
derivs = flow_derivatives(
flow,
grid=grid,
axes=grid.axes(),
to_axes=to_axes,
which=["du/dxx", "du/dxy", "dv/dxy", "dv/dyy"],
**config,
)
for key, deriv in derivs.items():
print(f"- max(abs({key})): {deriv.abs().max().item():.5f}")
print()
print("\n")

# %%

0 comments on commit 41bcc55

Please sign in to comment.