Skip to content

Commit

Permalink
[core] Implement Lie bracket and composition of SVFs
Browse files Browse the repository at this point in the history
  • Loading branch information
aschuh-hf committed Nov 7, 2023
1 parent 14629bb commit 8ae1101
Show file tree
Hide file tree
Showing 4 changed files with 316 additions and 0 deletions.
138 changes: 138 additions & 0 deletions src/deepali/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,79 @@ def compose_flows(u: Tensor, v: Tensor, align_corners: bool = True) -> Tensor:
return u.add(v)


def compose_svfs(
u: Tensor,
v: Tensor,
mode: Optional[str] = None,
sigma: Optional[float] = None,
spacing: Optional[Union[Scalar, Array]] = None,
stride: Optional[ScalarOrTuple[int]] = None,
bch_terms: int = 4,
) -> Tensor:
r"""Approximate stationary velocity field (SVF) of composite deformation.
The output velocity field is ``w = log(exp(v) o exp(u))``, where ``exp`` is the exponential map
of a stationary velocity field, and ``log`` its inverse. The velocity field ``w`` is given by the
`Baker-Campbell-Hausdorff (BCH) formula <https://en.wikipedia.org/wiki/Baker%E2%80%93Campbell%E2%80%93Hausdorff_formula>`_.
References:
- Vercauteren, 2008. Symmetric Log-Domain Diffeomorphic Registration: A Demons-based Approach.
doi:10.1007/978-3-540-85988-8_90
Args:
u: First applied stationary velocity field as tensor of shape ``(N, D, ..., X)``.
v: Second applied stationary velocity field as tensor of shape ``(N, D, ..., X)``.
bch_terms: Number of terms of the BCH formula to consider. Must be at least 2.
When 2, the returned velocity field is the sum of ``u`` and ``v``.
This approximation is accurate if the input velocity fields commute, i.e.,
the Lie bracket [v, u] = 0. When ``bch_terms=3``, the approximation is given by
``w = v + u + 1/2 [v, u]`` (note that deformation ``exp(u)`` is applied first),
and when ``bch_terms=4``, it is ``w = v + u + 1/2 [v, u] + 1/12 [v, [v, u]]``.
mode: Mode of :func:`flow_derivatives()` approximation.
sigma: Standard deviation of Gaussian used for computing spatial derivatives.
spacing: Physical size of image voxels used to compute spatial derivatives.
stride: Number of output grid points between control points plus one for ``mode='bspline'``.
Returns:
Approximation of BCH formula as tensor of shape ``(N, D, ..., X)``.
"""

def lb(a: Tensor, b: Tensor) -> Tensor:
return lie_bracket(a, b, mode=mode, sigma=sigma, spacing=spacing, stride=stride)

for name, flow in [("u", u), ("v", v)]:
if flow.ndim < 4:
raise ValueError(
f"compose_svfs() '{name}' must be vector field of shape (N, D, ..., X)"
)
if flow.shape[1] != flow.ndim - 2:
raise ValueError(f"compose_svfs() '{name}' must have shape (N, D, ..., X)")
if u.shape != v.shape:
raise ValueError("compose_svfs() 'u' and 'v' must have the same shape")
if bch_terms < 2:
raise ValueError("compose_svfs() 'bch_terms' must be at least 2")
elif bch_terms > 6:
raise NotImplementedError("compose_svfs() 'bch_terms' of more than 6 not implemented")

w = v.add(u)
if bch_terms >= 3:
vu = lb(v, u)
w = w.add(vu.mul(0.5))
if bch_terms >= 4:
vvu = lb(v, vu)
w = w.add(vvu.mul(1 / 12))
if bch_terms >= 5:
uv = lb(u, v)
uuv = lb(u, uv)
w = w.add(uuv.mul(1 / 12))
if bch_terms >= 6:
uvvu = lb(u, vvu)
w = w.sub(uvvu.mul(1 / 24))

return w


def curl(
flow: Tensor,
mode: Optional[str] = None,
Expand Down Expand Up @@ -508,6 +581,71 @@ def jacobian_matrix(
return jac.contiguous()


def lie_bracket(
v: Tensor,
u: Tensor,
mode: Optional[str] = None,
sigma: Optional[float] = None,
spacing: Optional[Union[Scalar, Array]] = None,
stride: Optional[ScalarOrTuple[int]] = None,
) -> Tensor:
r"""Lie bracket of two vector fields.
Evaluate Lie bracket given by ``[v, u] = Jac(v) * u - Jac(u) * v`` as defined in Eq (6)
of Vercauteren et al. (2008).
Most authors define the Lie bracket as the opposite of (6). Numerical simulations,
and personal communication with M. Bossa, showed the relevance of this definition.
Future research will aim at fully understanding the reason of this discrepancy.
References:
- Vercauteren, 2008. Symmetric Log-Domain Diffeomorphic Registration: A Demons-based Approach.
doi:10.1007/978-3-540-85988-8_90
Args:
u: Left vector field as tensor of shape ``(N, D, ..., X)``.
v: Right vector field as tensor of shape ``(N, D, ..., X)``.
mode: Mode of :func:`flow_derivatives()` approximation.
sigma: Standard deviation of Gaussian used for computing spatial derivatives.
spacing: Physical size of image voxels used to compute spatial derivatives.
stride: Number of output grid points between control points plus one for ``mode='bspline'``.
Returns:
Lie bracket of vector fields as tensor of shape ``(N, D, ..., X)``.
"""
for name, flow in [("u", u), ("v", v)]:
if flow.ndim < 4:
raise ValueError(f"lie_bracket() '{name}' must be vector field of shape (N, D, ..., X)")
if flow.shape[1] != flow.ndim - 2:
raise ValueError(f"lie_bracket() '{name}' must have shape (N, D, ..., X)")
if u.shape != v.shape:
raise ValueError("lie_bracket() 'u' and 'v' must have the same shape")
jac_u = jacobian_dict(
u,
mode=mode,
sigma=sigma,
spacing=spacing,
stride=stride,
)
jac_v = jacobian_dict(
v,
mode=mode,
sigma=sigma,
spacing=spacing,
stride=stride,
)
D = flow.ndim - 2
w = torch.zeros_like(u)
for i in range(D):
w_i = w.narrow(1, i, 1)
for j in range(D):
w_i = w_i.add_(jac_v[(i, j)].mul(u.narrow(1, j, 1)))
for j in range(D):
w_i = w_i.sub_(jac_u[(i, j)].mul(v.narrow(1, j, 1)))
return w


def normalize_flow(
data: Tensor,
size: Optional[Union[Tensor, torch.Size]] = None,
Expand Down
4 changes: 4 additions & 0 deletions src/deepali/core/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@

from .flow import affine_flow
from .flow import compose_flows
from .flow import compose_svfs
from .flow import curl
from .flow import denormalize_flow
from .flow import divergence
Expand All @@ -102,6 +103,7 @@
from .flow import jacobian_det
from .flow import jacobian_dict
from .flow import jacobian_matrix
from .flow import lie_bracket
from .flow import normalize_flow
from .flow import sample_flow
from .flow import warp_grid
Expand Down Expand Up @@ -182,6 +184,7 @@
"closest_point_distances",
"closest_point_indices",
"compose_flows",
"compose_svfs",
"conv",
"conv1d",
"crop",
Expand Down Expand Up @@ -216,6 +219,7 @@
"jacobian_det",
"jacobian_dict",
"jacobian_matrix",
"lie_bracket",
"max_pool",
"min_pool",
"normalize_flow",
Expand Down
71 changes: 71 additions & 0 deletions tests/_test_compose_svfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# %%
# Imports
from typing import Optional, Sequence

import matplotlib.pyplot as plt

import torch
from torch import Tensor
from torch.random import Generator

from deepali.core import Grid
import deepali.core.bspline as B
import deepali.core.functional as U


# %%
# Auxiliary functions
def random_svf(
size: Sequence[int],
stride: int = 1,
generator: Optional[Generator] = None,
) -> Tensor:
cp_grid_size = B.cubic_bspline_control_point_grid_size(size, stride=stride)
data = torch.randn((1, 3) + cp_grid_size, generator=generator)
data = U.fill_border(data, margin=3, value=0, inplace=True)
return B.evaluate_cubic_bspline(data, size=size, stride=stride)


def visualize_flow(ax, flow: Tensor) -> None:
grid = Grid(shape=flow.shape[2:], align_corners=True)
x = grid.coords(channels_last=False, dtype=u.dtype, device=u.device)
x = U.move_dim(x.unsqueeze(0).add_(flow), 1, -1)
target_grid = U.grid_image(shape=flow.shape[2:], inverted=True, stride=(5, 5))
warped_grid = U.warp_image(target_grid, x)
ax.imshow(warped_grid[0, 0, flow.shape[2] // 2], cmap="gray")


# %%
# Random velocity fields
size = (128, 128, 128)
generator = torch.Generator().manual_seed(42)
u = random_svf(size, stride=8, generator=generator).mul_(0.1)
v = random_svf(size, stride=8, generator=generator).mul_(0.05)


# %%
# Evaluate displacement fields
flow_u = U.expv(u)
flow_v = U.expv(v)
flow = U.compose_flows(flow_u, flow_v)


# %%
# Approximate velocity field of composite displacement field
flow_w = U.expv(U.compose_svfs(u, v, bch_terms=6))


# %%
# Visualize composite displacement fields and error norm
fig, axes = plt.subplots(1, 3, figsize=(30, 10))

visualize_flow(axes[0], flow)
visualize_flow(axes[1], flow_w)

error = flow_w.sub(flow).norm(dim=1, keepdim=True)

ax = axes[2]
_ = ax.imshow(error[0, 0, error.shape[2] // 2], cmap="jet", vmin=0, vmax=0.1)


# %%
103 changes: 103 additions & 0 deletions tests/test_core_flow_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from typing import Optional, Sequence

import pytest
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.random import Generator

from deepali.core import Grid
from deepali.core.enum import FlowDerivativeKeys
import deepali.core.bspline as B
import deepali.core.functional as U


Expand Down Expand Up @@ -83,6 +88,17 @@ def periodic_flow_divergence(p: Tensor) -> Tensor:
return du_dx.add(dv_dy)


def random_svf(
size: Sequence[int],
stride: int = 1,
generator: Optional[Generator] = None,
) -> Tensor:
cp_grid_size = B.cubic_bspline_control_point_grid_size(size, stride=stride)
data = torch.randn((1, 3) + cp_grid_size, generator=generator)
data = U.fill_border(data, margin=3, value=0, inplace=True)
return B.evaluate_cubic_bspline(data, size=size, stride=stride)


def difference(a: Tensor, b: Tensor, margin: int = 0) -> Tensor:
assert a.shape == b.shape
i = [
Expand Down Expand Up @@ -380,3 +396,90 @@ def test_flow_jacobian() -> None:
jac[[slice(None)] + interior].det(),
atol=0.01,
)


def test_flow_lie_bracket() -> None:
p = U.move_dim(Grid(size=(64, 32, 16)).coords().unsqueeze_(0), -1, 1)

x = p.narrow(1, 0, 1)
y = p.narrow(1, 1, 1)
z = p.narrow(1, 2, 1)

# u = [yz, xz, xy] and v = [x, y, z]
u = torch.cat([y.mul(z), x.mul(z), x.mul(y)], dim=1)
v = torch.cat([x, y, z], dim=1)
w = u

lb_uv = U.lie_bracket(u, v)
assert torch.allclose(U.lie_bracket(v, u), lb_uv.neg())
assert U.lie_bracket(u, u).abs().lt(1e-6).all()
assert torch.allclose(lb_uv, w, atol=1e-6)

# u = [z^2, 0, xy] and v = [0, x + y^3, yz]
u = torch.cat([z.square(), torch.zeros_like(y), x.mul(y)], dim=1)
v = torch.cat([torch.zeros_like(x), x.add(y.pow(3)), y.mul(z)], dim=1)
w = torch.cat([-2 * y * z**2, z**2, x * y**2 - x**2 - x * y**3], dim=1).neg_()

lb_uv = U.lie_bracket(u, v)
assert torch.allclose(U.lie_bracket(v, u), lb_uv.neg())
assert U.lie_bracket(u, u).abs().lt(1e-6).all()
error = difference(lb_uv, w).abs()
assert error[:, :, 1:-1, 1:-1, 1:-1].max().lt(1e-5)
assert error.max().lt(0.134)


def test_flow_compose_svfs() -> None:
# 3D flow fields
p = U.move_dim(Grid(size=(64, 32, 16)).coords().unsqueeze_(0), -1, 1)

x = p.narrow(1, 0, 1)
y = p.narrow(1, 1, 1)
z = p.narrow(1, 2, 1)

with pytest.raises(ValueError):
U.compose_svfs(p, p, bch_terms=-1)
with pytest.raises(ValueError):
U.compose_svfs(p, p, bch_terms=0)
with pytest.raises(ValueError):
U.compose_svfs(p, p, bch_terms=1)
with pytest.raises(NotImplementedError):
U.compose_svfs(p, p, bch_terms=7)

# u = [yz, xz, xy] and v = u
u = v = torch.cat([y.mul(z), x.mul(z), x.mul(y)], dim=1)

w = U.compose_svfs(u, v, bch_terms=2)
assert torch.allclose(w, u.add(v))
w = U.compose_svfs(u, v, bch_terms=3)
assert torch.allclose(w, u.add(v))
w = U.compose_svfs(u, v, bch_terms=4)
assert torch.allclose(w, u.add(v))
w = U.compose_svfs(u, v, bch_terms=5)
assert torch.allclose(w, u.add(v))
w = U.compose_svfs(u, v, bch_terms=6)
assert torch.allclose(w, u.add(v), atol=1e-5)

# u = [yz, xz, xy] and v = [x, y, z]
u = torch.cat([y.mul(z), x.mul(z), x.mul(y)], dim=1)
v = torch.cat([x, y, z], dim=1)

w = U.compose_svfs(u, v, bch_terms=2)
assert torch.allclose(w, u.add(v))
w = U.compose_svfs(u, v, bch_terms=3)
assert torch.allclose(w, u.mul(0.5).add(v), atol=1e-6)

# u = random_svf(), u -> 0 at boundary
# v = random_svf(), v -> 0 at boundary
size = (64, 64, 64)
generator = torch.Generator().manual_seed(42)
u = random_svf(size, stride=4, generator=generator).mul_(0.1)
v = random_svf(size, stride=4, generator=generator).mul_(0.05)
w = U.compose_svfs(u, v, bch_terms=6)

flow_u = U.expv(u)
flow_v = U.expv(v)
flow_w = U.expv(w)
flow = U.compose_flows(flow_u, flow_v)

error = flow_w.sub(flow).norm(dim=1)
assert error.max().lt(0.01)

0 comments on commit 8ae1101

Please sign in to comment.