Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Take NonUniformFastFourierOp out of FourierOp #463

Open
wants to merge 45 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
38f33fb
tests added
ckolbPTB Oct 22, 2024
74e6250
doc strings
ckolbPTB Oct 22, 2024
60fca9e
Update src/mrpro/operators/NonUniformFastFourierOp.py
fzimmermann89 Oct 22, 2024
fed8954
merge main
ckolbPTB Oct 23, 2024
2659402
review
ckolbPTB Oct 23, 2024
cc1472d
input pars and doc string adapted
ckolbPTB Oct 23, 2024
0c29b01
Merge branch 'main' into separate_nufft_op
ckolbPTB Oct 23, 2024
393c5a9
Merge branch 'main' into separate_nufft_op
ckolbPTB Oct 23, 2024
276873e
test for nufft output added
ckolbPTB Oct 23, 2024
89e36bc
Merge branch 'main' into separate_nufft_op
ckolbPTB Oct 24, 2024
3c6d873
empty dim
ckolbPTB Oct 24, 2024
d2fb391
Update tests/operators/test_non_uniform_fast_fourier_op.py
ckolbPTB Dec 5, 2024
a4b8a64
merge main
ckolbPTB Dec 5, 2024
2a90533
fix merge problems
ckolbPTB Dec 5, 2024
b67bef6
fix more merge problems
ckolbPTB Dec 5, 2024
9b60be2
gram and cart_samp fixed
ckolbPTB Dec 5, 2024
281e5bc
spatial dims and test for unsupported direction added
ckolbPTB Dec 5, 2024
4b8d8fb
nufft dim automatically detected
ckolbPTB Dec 6, 2024
aba850d
fix for single shot traj and further tests added
ckolbPTB Dec 11, 2024
6ba3fb4
remove superfluous tests
ckolbPTB Dec 11, 2024
b7c83e6
first try
ckolbPTB Dec 11, 2024
8270de5
forward adapted
ckolbPTB Dec 11, 2024
e73ace5
adjoint adapted
ckolbPTB Dec 11, 2024
e586360
add _nufft_type1 and _nufft_type2
ckolbPTB Dec 12, 2024
9977205
clean up
ckolbPTB Dec 12, 2024
4b70389
sep dims and joint dims for forward and adjoint
ckolbPTB Dec 12, 2024
01ca1dc
gram started
ckolbPTB Dec 12, 2024
87f51c2
gram finished for nufft
ckolbPTB Dec 12, 2024
6585ab9
tests adapted and bug fix
ckolbPTB Dec 13, 2024
2330e3d
add rpe to conftest
ckolbPTB Dec 13, 2024
739d76c
conftest update
ckolbPTB Dec 13, 2024
d8217e8
use given kshape
ckolbPTB Dec 14, 2024
1bccb41
conftest error fixed
ckolbPTB Dec 14, 2024
b11acf0
misalignment k210 and kzyx still a problem
ckolbPTB Dec 14, 2024
b2d58fa
tidy up
ckolbPTB Dec 14, 2024
d027e4a
Merge branch 'main' into separate_nufft_op
ckolbPTB Dec 14, 2024
302782c
joint dims zyx
ckolbPTB Dec 16, 2024
4382d68
nufft gram separated out
ckolbPTB Dec 16, 2024
f364526
Merge branch 'main' into separate_nufft_op
ckolbPTB Dec 16, 2024
2d9c51a
merge main
ckolbPTB Jan 10, 2025
3cd6c8e
gram adj_nufft separate, test fix and speed up
ckolbPTB Jan 10, 2025
9e229da
fix cart traj calc
ckolbPTB Jan 10, 2025
8ae651c
Merge branch 'main' into separate_nufft_op
ckolbPTB Jan 20, 2025
06a7d91
Merge branch 'main' into separate_nufft_op
ckolbPTB Jan 22, 2025
32317f9
docs formatting
ckolbPTB Jan 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 25 additions & 96 deletions src/mrpro/operators/FourierOp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@

from collections.abc import Sequence

import numpy as np
import torch
from torchkbnufft import KbNufft, KbNufftAdjoint
from typing_extensions import Self

from mrpro.data._kdata.KData import KData
from mrpro.data.enums import TrajType
from mrpro.data.KTrajectory import KTrajectory
from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.operators.FastFourierOp import FastFourierOp
from mrpro.operators.IdentityOp import IdentityOp
from mrpro.operators.LinearOperator import LinearOperator
from mrpro.operators.NonUniformFastFourierOp import NonUniformFastFourierOp


class FourierOp(LinearOperator):
Expand All @@ -24,8 +24,6 @@ def __init__(
encoding_matrix: SpatialDimension[int],
traj: KTrajectory,
nufft_oversampling: float = 2.0,
nufft_numpoints: int = 6,
nufft_kbwidth: float = 2.34,
) -> None:
"""Fourier Operator class.

Expand All @@ -38,11 +36,10 @@ def __init__(
traj
the k-space trajectories where the frequencies are sampled
nufft_oversampling
oversampling used for interpolation in non-uniform FFTs
nufft_numpoints
number of neighbors for interpolation in non-uniform FFTs
nufft_kbwidth
size of the Kaiser-Bessel kernel interpolation in non-uniform FFTs
oversampling used for interpolation in non-uniform FFTs. The oversampling of the interpolation grid, which
is needed during the non-uniform FFT, ensures that there is no foldover due to the finite gridding kernel.
It can be reduced (e.g. to 1.25) to speed up the non-uniform FFT but this might lead to poorer image
quality.
"""
super().__init__()

Expand All @@ -53,9 +50,6 @@ def get_spatial_dims(spatial_dims: SpatialDimension, dims: Sequence[int]):
if i in dims
]

def get_traj(traj: KTrajectory, dims: Sequence[int]):
return [k for k, i in zip((traj.kz, traj.ky, traj.kx), (-3, -2, -1), strict=True) if i in dims]

self._ignore_dims, self._fft_dims, self._nufft_dims = [], [], []
for dim, type_ in zip((-3, -2, -1), traj.type_along_kzyx, strict=True):
if type_ & TrajType.SINGLEVALUE:
Expand All @@ -66,14 +60,16 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]):
else:
self._nufft_dims.append(dim)

if self._fft_dims:
self._fast_fourier_op = FastFourierOp(
self._fast_fourier_op = (
FastFourierOp(
dim=tuple(self._fft_dims),
recon_matrix=get_spatial_dims(recon_matrix, self._fft_dims),
encoding_matrix=get_spatial_dims(encoding_matrix, self._fft_dims),
)
if self._fft_dims
else IdentityOp()
)

# Find dimensions which require NUFFT
if self._nufft_dims:
fft_dims_k210 = [
dim
Expand All @@ -87,35 +83,17 @@ def get_traj(traj: KTrajectory, dims: Sequence[int]):
'k-space dimension, i.e. kx along k0, ky along k1 and kz along k2',
)

self._nufft_im_size = get_spatial_dims(recon_matrix, self._nufft_dims)
grid_size = [int(size * nufft_oversampling) for size in self._nufft_im_size]
omega = [
k * 2 * torch.pi / ks
for k, ks in zip(
get_traj(traj, self._nufft_dims),
get_spatial_dims(encoding_matrix, self._nufft_dims),
strict=True,
)
]

# Broadcast shapes not always needed but also does not hurt
omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega]
self.register_buffer('_omega', torch.stack(omega, dim=-4)) # use the 'coil' dim for the direction

self._fwd_nufft_op = KbNufft(
im_size=self._nufft_im_size,
grid_size=grid_size,
numpoints=nufft_numpoints,
kbwidth=nufft_kbwidth,
)
self._adj_nufft_op = KbNufftAdjoint(
im_size=self._nufft_im_size,
grid_size=grid_size,
numpoints=nufft_numpoints,
kbwidth=nufft_kbwidth,
self._non_uniform_fast_fourier_op = (
NonUniformFastFourierOp(
dim=tuple(self._nufft_dims),
recon_matrix=get_spatial_dims(recon_matrix, self._nufft_dims),
encoding_matrix=get_spatial_dims(encoding_matrix, self._nufft_dims),
traj=traj,
nufft_oversampling=nufft_oversampling,
)

self._kshape = traj.broadcasted_shape
if self._nufft_dims
else IdentityOp()
)

@classmethod
def from_kdata(cls, kdata: KData, recon_shape: SpatialDimension[int] | None = None) -> Self:
Expand Down Expand Up @@ -146,34 +124,8 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
-------
coil k-space data with shape: (... coils k2 k1 k0)
"""
if len(self._fft_dims):
# FFT
(x,) = self._fast_fourier_op(x)

if self._nufft_dims:
# we need to move the nufft-dimensions to the end and flatten all other dimensions
# so the new shape will be (... non_nufft_dims) coils nufft_dims
# we could move the permute to __init__ but then we still would need to prepend if len(other)>1
keep_dims = [-4, *self._nufft_dims] # -4 is always coil
permute = [i for i in range(-x.ndim, 0) if i not in keep_dims] + keep_dims
unpermute = np.argsort(permute)

x = x.permute(*permute)
permuted_x_shape = x.shape
x = x.flatten(end_dim=-len(keep_dims) - 1)

# omega should be (... non_nufft_dims) n_nufft_dims (nufft_dims)
# TODO: consider moving the broadcast along fft dimensions to __init__ (independent of x shape).
omega = self._omega.permute(*permute)
omega = omega.broadcast_to(*permuted_x_shape[: -len(keep_dims)], *omega.shape[-len(keep_dims) :])
omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)

x = self._fwd_nufft_op(x, omega, norm='ortho')

shape_nufft_dims = [self._kshape[i] for i in self._nufft_dims]
x = x.reshape(*permuted_x_shape[: -len(keep_dims)], -1, *shape_nufft_dims) # -1 is coils
x = x.permute(*unpermute)
return (x,)
# FFT followed by NUFFT
return self._non_uniform_fast_fourier_op(self._fast_fourier_op(x)[0])

def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
"""Adjoint operator mapping the coil k-space data to the coil images.
Expand All @@ -187,28 +139,5 @@ def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
-------
coil image data with shape: (... coils z y x)
"""
if self._fft_dims:
# IFFT
(x,) = self._fast_fourier_op.adjoint(x)

if self._nufft_dims:
# we need to move the nufft-dimensions to the end, flatten them and flatten all other dimensions
# so the new shape will be (... non_nufft_dims) coils (nufft_dims)
keep_dims = [-4, *self._nufft_dims] # -4 is coil
permute = [i for i in range(-x.ndim, 0) if i not in keep_dims] + keep_dims
unpermute = np.argsort(permute)

x = x.permute(*permute)
permuted_x_shape = x.shape
x = x.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)

omega = self._omega.permute(*permute)
omega = omega.broadcast_to(*permuted_x_shape[: -len(keep_dims)], *omega.shape[-len(keep_dims) :])
omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)

x = self._adj_nufft_op(x, omega, norm='ortho')

x = x.reshape(*permuted_x_shape[: -len(keep_dims)], *x.shape[-len(keep_dims) :])
x = x.permute(*unpermute)

return (x,)
# NUFFT followed by FFT
return self._fast_fourier_op.adjoint(self._non_uniform_fast_fourier_op.adjoint(x)[0])
150 changes: 150 additions & 0 deletions src/mrpro/operators/NonUniformFastFourierOp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""Non-Uniform Fast Fourier Operator."""

from collections.abc import Sequence

import numpy as np
import torch
from torchkbnufft import KbNufft, KbNufftAdjoint

from mrpro.data.KTrajectory import KTrajectory
from mrpro.operators.LinearOperator import LinearOperator


class NonUniformFastFourierOp(LinearOperator, adjoint_as_backward=True):
"""Non-Uniform Fast Fourier Operator class."""

def __init__(
self,
dim: Sequence[int],
recon_matrix: Sequence[int],
encoding_matrix: Sequence[int],
traj: KTrajectory,
nufft_oversampling: float = 2.0,
nufft_numpoints: int = 6,
nufft_kbwidth: float = 2.34,
) -> None:
"""Initialize Non-Uniform Fast Fourier Operator.

Parameters
----------
dim
dimension along which non-uniform FFT is applied
recon_matrix
dimension of the reconstructed image corresponding to dim
encoding_matrix
dimension of the encoded k-space corresponding to dim
traj
the k-space trajectories where the frequencies are sampled
nufft_oversampling
oversampling used for interpolation in non-uniform FFTs
nufft_numpoints
number of neighbors for interpolation in non-uniform FFTs
nufft_kbwidth
size of the Kaiser-Bessel kernel interpolation in non-uniform FFTs
"""
super().__init__()

self._nufft_dims = dim
if len(dim):

def get_traj(traj: KTrajectory, dims: Sequence[int]):
return [k for k, i in zip((traj.kz, traj.ky, traj.kx), (-3, -2, -1), strict=True) if i in dims]

grid_size = [int(size * nufft_oversampling) for size in recon_matrix]
omega = [
k * 2 * torch.pi / ks
for k, ks in zip(
get_traj(traj, dim),
encoding_matrix,
strict=True,
)
]

# Broadcast shapes not always needed but also does not hurt
omega = [k.expand(*np.broadcast_shapes(*[k.shape for k in omega])) for k in omega]
self.register_buffer('_omega', torch.stack(omega, dim=-4)) # use the 'coil' dim for the direction

self._fwd_nufft_op = KbNufft(
im_size=recon_matrix,
grid_size=grid_size,
numpoints=nufft_numpoints,
kbwidth=nufft_kbwidth,
)
self._adj_nufft_op = KbNufftAdjoint(
im_size=recon_matrix,
grid_size=grid_size,
numpoints=nufft_numpoints,
kbwidth=nufft_kbwidth,
)

self._kshape = traj.broadcasted_shape

def forward(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
"""NUFFT from image space to k-space.

Parameters
----------
x
coil image data with shape: (... coils z y x)

Returns
-------
coil k-space data with shape: (... coils k2 k1 k0)
"""
if len(self._nufft_dims):
# we need to move the nufft-dimensions to the end and flatten all other dimensions
# so the new shape will be (... non_nufft_dims) coils nufft_dims
# we could move the permute to __init__ but then we still would need to prepend if len(other)>1
keep_dims = [-4, *self._nufft_dims] # -4 is always coil
permute = [i for i in range(-x.ndim, 0) if i not in keep_dims] + keep_dims
unpermute = np.argsort(permute)

x = x.permute(*permute)
permuted_x_shape = x.shape
x = x.flatten(end_dim=-len(keep_dims) - 1)

# omega should be (... non_nufft_dims) n_nufft_dims (nufft_dims)
# TODO: consider moving the broadcast along fft dimensions to __init__ (independent of x shape).
omega = self._omega.permute(*permute)
omega = omega.broadcast_to(*permuted_x_shape[: -len(keep_dims)], *omega.shape[-len(keep_dims) :])
omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)

x = self._fwd_nufft_op(x, omega, norm='ortho')

shape_nufft_dims = [self._kshape[i] for i in self._nufft_dims]
x = x.reshape(*permuted_x_shape[: -len(keep_dims)], -1, *shape_nufft_dims) # -1 is coils
x = x.permute(*unpermute)
return (x,)

def adjoint(self, x: torch.Tensor) -> tuple[torch.Tensor,]:
"""NUFFT from k-space to image space.

Parameters
----------
x
coil k-space data with shape: (... coils k2 k1 k0)

Returns
-------
coil image data with shape: (... coils z y x)
"""
if len(self._nufft_dims):
# we need to move the nufft-dimensions to the end, flatten them and flatten all other dimensions
# so the new shape will be (... non_nufft_dims) coils (nufft_dims)
keep_dims = [-4, *self._nufft_dims] # -4 is coil
permute = [i for i in range(-x.ndim, 0) if i not in keep_dims] + keep_dims
unpermute = np.argsort(permute)

x = x.permute(*permute)
permuted_x_shape = x.shape
x = x.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)

omega = self._omega.permute(*permute)
omega = omega.broadcast_to(*permuted_x_shape[: -len(keep_dims)], *omega.shape[-len(keep_dims) :])
omega = omega.flatten(end_dim=-len(keep_dims) - 1).flatten(start_dim=-len(keep_dims) + 1)

x = self._adj_nufft_op(x, omega, norm='ortho')

x = x.reshape(*permuted_x_shape[: -len(keep_dims)], *x.shape[-len(keep_dims) :])
x = x.permute(*unpermute)
return (x,)
2 changes: 2 additions & 0 deletions src/mrpro/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mrpro.operators.IdentityOp import IdentityOp
from mrpro.operators.MagnitudeOp import MagnitudeOp
from mrpro.operators.MultiIdentityOp import MultiIdentityOp
from mrpro.operators.NonUniformFastFourierOp import NonUniformFastFourierOp
from mrpro.operators.PhaseOp import PhaseOp
from mrpro.operators.ProximableFunctionalSeparableSum import ProximableFunctionalSeparableSum
from mrpro.operators.SensitivityOp import SensitivityOp
Expand All @@ -38,6 +39,7 @@
"LinearOperator",
"MagnitudeOp",
"MultiIdentityOp",
"NonUniformFastFourierOp",
"Operator",
"PhaseOp",
"ProximableFunctional",
Expand Down
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from xsdata.models.datatype import XmlDate, XmlTime

from tests import RandomGenerator
from tests.data import IsmrmrdRawTestData
from tests.phantoms import EllipsePhantomTestData


Expand Down Expand Up @@ -249,6 +250,19 @@ def create_traj(k_shape, nkx, nky, nkz, sx, sy, sz):
return trajectory


@pytest.fixture(scope='session')
def ismrmrd_cart(ellipse_phantom, tmp_path_factory):
"""Fully sampled cartesian data set."""
ismrmrd_filename = tmp_path_factory.mktemp('mrpro') / 'ismrmrd_cart.h5'
ismrmrd_kdata = IsmrmrdRawTestData(
filename=ismrmrd_filename,
noise_level=0.0,
repetitions=3,
phantom=ellipse_phantom.phantom,
)
return ismrmrd_kdata


COMMON_MR_TRAJECTORIES = pytest.mark.parametrize(
('im_shape', 'k_shape', 'nkx', 'nky', 'nkz', 'sx', 'sy', 'sz', 's0', 's1', 's2'),
[
Expand Down
Loading