Skip to content

Commit

Permalink
[core] Minor change of bch_terms param of compose_svfs()
Browse files Browse the repository at this point in the history
  • Loading branch information
aschuh-hf committed Nov 13, 2023
1 parent 259799e commit 3980252
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 31 deletions.
60 changes: 41 additions & 19 deletions src/deepali/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,27 +71,44 @@ def compose_svfs(
sigma: Optional[float] = None,
spacing: Optional[Union[Scalar, Array]] = None,
stride: Optional[ScalarOrTuple[int]] = None,
bch_terms: int = 4,
bch_terms: int = 3,
) -> Tensor:
r"""Approximate stationary velocity field (SVF) of composite deformation.
The output velocity field is ``w = log(exp(v) o exp(u))``, where ``exp`` is the exponential map
of a stationary velocity field, and ``log`` its inverse. The velocity field ``w`` is given by the
`Baker-Campbell-Hausdorff (BCH) formula <https://en.wikipedia.org/wiki/Baker%E2%80%93Campbell%E2%80%93Hausdorff_formula>`_.
The BCH formula with 5 Lie bracket terms (cf. ``bch_terms`` parameter) is
.. math::
w = v + u + \frac{1}{2} [v, u]
+ \frac{1}{12} ([v, [v, u]] - [u, [v, u]])
+ \frac{1}{48} ([[v, [v, u]], u] - [v, [u, [v, u]]])
where
.. math::
[[v, [v, u]], u] - [v, [u, [v, u]]] = -2 [u, [v, [v, u]]]
References:
- Vercauteren, 2008. Symmetric Log-Domain Diffeomorphic Registration: A Demons-based Approach.
doi:10.1007/978-3-540-85988-8_90
- Bossa & Olmos, 2008. A new algorithm for the computation of the group logarithm of diffeomorphisms.
https://inria.hal.science/inria-00629873
- Vercauteren et al., 2008. Symmetric log-domain diffeomorphic registration: A Demons-based approach.
https://doi.org/10.1007/978-3-540-85988-8_90
Args:
u: First applied stationary velocity field as tensor of shape ``(N, D, ..., X)``.
v: Second applied stationary velocity field as tensor of shape ``(N, D, ..., X)``.
bch_terms: Number of terms of the BCH formula to consider. Must be at least 2.
When 2, the returned velocity field is the sum of ``u`` and ``v``.
bch_terms: Number of Lie bracket terms of the BCH formula to consider.
When 0, the returned velocity field is the sum of ``u`` and ``v``.
This approximation is accurate if the input velocity fields commute, i.e.,
the Lie bracket [v, u] = 0. When ``bch_terms=3``, the approximation is given by
``w = v + u + 1/2 [v, u]`` (note that deformation ``exp(u)`` is applied first),
and when ``bch_terms=4``, it is ``w = v + u + 1/2 [v, u] + 1/12 [v, [v, u]]``.
the Lie bracket [v, u] = 0. When ``bch_terms=1``, the approximation is given by
``w = v + u + 1/2 [v, u]`` (note ``exp(u)`` is applied before ``exp(v)``). Formula
``w = v + u + \frac{1}{2} [v, u] + \frac{1}{12} ([v, [v, u]] - [u, [v, u]])`` is
used by default, i.e., ``bch_terms=3``.
mode: Mode of :func:`flow_derivatives()` approximation.
sigma: Standard deviation of Gaussian used for computing spatial derivatives.
spacing: Physical size of image voxels used to compute spatial derivatives.
Expand All @@ -114,25 +131,30 @@ def lb(a: Tensor, b: Tensor) -> Tensor:
raise ValueError(f"compose_svfs() '{name}' must have shape (N, D, ..., X)")
if u.shape != v.shape:
raise ValueError("compose_svfs() 'u' and 'v' must have the same shape")
if bch_terms < 2:
raise ValueError("compose_svfs() 'bch_terms' must be at least 2")
elif bch_terms > 6:
if bch_terms < 0:
raise ValueError("compose_svfs() 'bch_terms' must not be negative")
elif bch_terms > 5:
raise NotImplementedError("compose_svfs() 'bch_terms' of more than 6 not implemented")

# w = v + u
w = v.add(u)
if bch_terms >= 3:
if bch_terms >= 1:
# + 1/2 [v, u]
vu = lb(v, u)
w = w.add(vu.mul(0.5))
if bch_terms >= 4:
if bch_terms >= 2:
# + 1/12 [v, [v, u]]
vvu = lb(v, vu)
w = w.add(vvu.mul(1 / 12))
if bch_terms >= 5:
uv = lb(u, v)
uuv = lb(u, uv)
w = w.add(uuv.mul(1 / 12))
if bch_terms >= 6:
if bch_terms >= 3:
# - 1/12 [u, [v, u]]
uvu = lb(u, vu)
w = w.sub(uvu.mul(1 / 12))
if bch_terms >= 4:
# + 1/48 [[v, [v, u]], u] = - 1/48 [u, [v, [v, u]]]
# - 1/48 [v, [u, [v, u]]] = - 1/48 [u, [v, [v, u]]]
uvvu = lb(u, vvu)
w = w.sub(uvvu.mul(1 / 24))
w = w.sub(uvvu.mul((1 if bch_terms == 4 else 2) / 48))

return w

Expand Down
2 changes: 1 addition & 1 deletion tests/_test_compose_svfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def visualize_flow(ax, flow: Tensor) -> None:

# %%
# Approximate velocity field of composite displacement field
flow_w = U.expv(U.compose_svfs(u, v, bch_terms=6))
flow_w = U.expv(U.compose_svfs(u, v, bch_terms=3))


# %%
Expand Down
20 changes: 9 additions & 11 deletions tests/test_core_flow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,34 +438,32 @@ def test_flow_compose_svfs() -> None:

with pytest.raises(ValueError):
U.compose_svfs(p, p, bch_terms=-1)
with pytest.raises(ValueError):
U.compose_svfs(p, p, bch_terms=0)
with pytest.raises(ValueError):
U.compose_svfs(p, p, bch_terms=1)
with pytest.raises(NotImplementedError):
U.compose_svfs(p, p, bch_terms=7)
U.compose_svfs(p, p, bch_terms=6)

# u = [yz, xz, xy] and v = u
u = v = torch.cat([y.mul(z), x.mul(z), x.mul(y)], dim=1)

w = U.compose_svfs(u, v, bch_terms=0)
assert torch.allclose(w, u.add(v))
w = U.compose_svfs(u, v, bch_terms=1)
assert torch.allclose(w, u.add(v))
w = U.compose_svfs(u, v, bch_terms=2)
assert torch.allclose(w, u.add(v))
w = U.compose_svfs(u, v, bch_terms=3)
assert torch.allclose(w, u.add(v))
w = U.compose_svfs(u, v, bch_terms=4)
assert torch.allclose(w, u.add(v))
assert torch.allclose(w, u.add(v), atol=1e-5)
w = U.compose_svfs(u, v, bch_terms=5)
assert torch.allclose(w, u.add(v))
w = U.compose_svfs(u, v, bch_terms=6)
assert torch.allclose(w, u.add(v), atol=1e-5)

# u = [yz, xz, xy] and v = [x, y, z]
u = torch.cat([y.mul(z), x.mul(z), x.mul(y)], dim=1)
v = torch.cat([x, y, z], dim=1)

w = U.compose_svfs(u, v, bch_terms=2)
w = U.compose_svfs(u, v, bch_terms=0)
assert torch.allclose(w, u.add(v))
w = U.compose_svfs(u, v, bch_terms=3)
w = U.compose_svfs(u, v, bch_terms=1)
assert torch.allclose(w, u.mul(0.5).add(v), atol=1e-6)

# u = random_svf(), u -> 0 at boundary
Expand All @@ -474,7 +472,7 @@ def test_flow_compose_svfs() -> None:
generator = torch.Generator().manual_seed(42)
u = random_svf(size, stride=4, generator=generator).mul_(0.1)
v = random_svf(size, stride=4, generator=generator).mul_(0.05)
w = U.compose_svfs(u, v, bch_terms=6)
w = U.compose_svfs(u, v, bch_terms=5)

flow_u = U.expv(u)
flow_v = U.expv(v)
Expand Down

0 comments on commit 3980252

Please sign in to comment.