diff --git a/.github/workflows/test-and-publish.yaml b/.github/workflows/test-and-publish.yaml
index 4c46a9c..2b9576e 100644
--- a/.github/workflows/test-and-publish.yaml
+++ b/.github/workflows/test-and-publish.yaml
@@ -50,6 +50,8 @@ jobs:
             pytorch-version: "1.11"
           - python-version: "3.11"
             pytorch-version: "2.0"
+          - python-version: "3.12"
+            pytorch-version: "2.4"
     steps:
       - uses: actions/checkout@v3
       - uses: ./.github/actions/test
diff --git a/interpol/__init__.py b/interpol/__init__.py
index ecb4add..91dee73 100644
--- a/interpol/__init__.py
+++ b/interpol/__init__.py
@@ -1,7 +1,7 @@
-from .api import *
-from .resize import *
-from .restrict import *
-from . import backend
+from .api import *  # noqa: F401, F403
+from .resize import *  # noqa: F401, F403
+from .restrict import *  # noqa: F401, F403
+from . import backend  # noqa: F401
 from . import _version
 
 __version__ = _version.get_versions()['version']
diff --git a/interpol/api.py b/interpol/api.py
index b7c0066..b128368 100755
--- a/interpol/api.py
+++ b/interpol/api.py
@@ -1,8 +1,20 @@
 """High level interpolation API"""
 
-__all__ = ['grid_pull', 'grid_push', 'grid_count', 'grid_grad',
-           'spline_coeff', 'spline_coeff_nd',
-           'identity_grid', 'add_identity_grid', 'add_identity_grid_']
+__all__ = [
+    'pull',
+    'push',
+    'count',
+    'grid_pull',
+    'grid_push',
+    'grid_count',
+    'grid_grad',
+    'spline_coeff',
+    'spline_coeff_nd',
+    'identity_grid',
+    'add_identity_grid',
+    'add_identity_grid_',
+    'affine_grid',
+]
 
 import torch
 from .utils import expanded_shape, matvec
@@ -44,7 +56,7 @@
 https://en.wikipedia.org/wiki/Discrete_sine_transform"""
 
 _doc_bound_coeff = \
-"""`bound` can be an int, a string or a BoundType. 
+"""`bound` can be an int, a string or a BoundType.
 Possible values are:
 - 'replicate' or 'nearest' : a a a | a b c d | d d d
 - 'dct1' or 'mirror' : d c b | a b c d | c b a
@@ -61,7 +73,7 @@
 - `dct2` corresponds to mirroring about the edge of the first/last voxel
 See https://en.wikipedia.org/wiki/Discrete_cosine_transform
     https://en.wikipedia.org/wiki/Discrete_sine_transform
-    
+
 /!\ Only 'dct1', 'dct2' and 'dft' are implemented for interpolation
 orders >= 6."""
 
@@ -143,11 +155,11 @@ def grid_pull(input, grid, interpolation='linear', bound='zero',
     {interpolation}
     {bound}
-    
-    If the input dtype is not a floating point type, the input image is 
-    assumed to contain labels. Then, unique labels are extracted 
-    and resampled individually, making them soft labels. Finally, 
-    the label map is reconstructed from the individual soft labels by 
+
+    If the input dtype is not a floating point type, the input image is
+    assumed to contain labels. Then, unique labels are extracted
+    and resampled individually, making them soft labels. Finally,
+    the label map is reconstructed from the individual soft labels by
     assigning the label with maximum soft value.
 
     Parameters
     ----------
@@ -290,7 +302,7 @@ def grid_count(grid, shape=None, interpolation='linear', bound='zero',
 def grid_grad(input, grid, interpolation='linear', bound='zero',
               extrapolate=False, prefilter=False):
     """Sample spatial gradients of an image with respect to a deformation field.
-    
+
     Notes
     -----
     {interpolation}
diff --git a/interpol/autograd.py b/interpol/autograd.py
index 40cace9..e9e7d9c 100644
--- a/interpol/autograd.py
+++ b/interpol/autograd.py
@@ -10,9 +10,41 @@
                      grid_grad, grid_grad_backward)
 from .utils import fake_decorator
 try:
-    from torch.cuda.amp import custom_fwd, custom_bwd
+    from torch.amp import custom_fwd, custom_bwd
 except (ModuleNotFoundError, ImportError):
-    custom_fwd = custom_bwd = fake_decorator
+    try:
+        from torch.cuda.amp import (
+            custom_fwd as _custom_fwd_cuda,
+            custom_bwd as _custom_bwd_cuda
+        )
+    except (ModuleNotFoundError, ImportError):
+        _custom_fwd_cuda = _custom_bwd_cuda = fake_decorator
+
+    try:
+        from torch.cpu.amp import (
+            custom_fwd as _custom_fwd_cpu,
+            custom_bwd as _custom_bwd_cpu
+        )
+    except (ModuleNotFoundError, ImportError):
+        _custom_fwd_cpu = _custom_bwd_cpu = fake_decorator
+
+    def custom_fwd(fwd=None, *, device_type, cast_inputs=None):
+        if device_type == 'cuda':
+            decorator = _custom_fwd_cuda(cast_inputs=cast_inputs)
+            return decorator(fwd) if fwd else decorator
+        if device_type == 'cpu':
+            decorator = _custom_fwd_cpu(cast_inputs=cast_inputs)
+            return decorator(fwd) if fwd else decorator
+        return fake_decorator(fwd) if fwd else fake_decorator
+
+    def custom_bwd(bwd=None, *, device_type):
+        if device_type == 'cuda':
+            decorator = _custom_bwd_cuda
+            return decorator(bwd) if bwd else decorator
+        if device_type == 'cpu':
+            decorator = _custom_bwd_cpu
+            return decorator(bwd) if bwd else decorator
+        return fake_decorator(bwd) if bwd else fake_decorator
 
 
 def make_list(x):
@@ -125,7 +157,7 @@ def inter_to_nitorch(inter, as_type='str'):
 class GridPull(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd(cast_inputs=torch.float32)
+    @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
     def forward(ctx, input, grid, interpolation, bound, extrapolate):
 
         bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -143,7 +175,7 @@ def forward(ctx, input, grid, interpolation, bound, extrapolate):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad):
         var = ctx.saved_tensors
         opt = ctx.opt
@@ -155,7 +187,7 @@ def backward(ctx, grad):
 class GridPush(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd(cast_inputs=torch.float32)
+    @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
     def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):
 
         bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -173,7 +205,7 @@ def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad):
         var = ctx.saved_tensors
         opt = ctx.opt
@@ -185,7 +217,7 @@ def backward(ctx, grad):
 class GridCount(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd(cast_inputs=torch.float32)
+    @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
     def forward(ctx, grid, shape, interpolation, bound, extrapolate):
 
         bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -203,7 +235,7 @@ def forward(ctx, grid, shape, interpolation, bound, extrapolate):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad):
         var = ctx.saved_tensors
         opt = ctx.opt
@@ -216,7 +248,7 @@
 class GridGrad(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd(cast_inputs=torch.float32)
+    @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
     def forward(ctx, input, grid, interpolation, bound, extrapolate):
 
         bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -234,7 +266,7 @@ def forward(ctx, input, grid, interpolation, bound, extrapolate):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad):
         var = ctx.saved_tensors
         opt = ctx.opt
@@ -248,7 +280,7 @@ def backward(ctx, grad):
 class SplineCoeff(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type='cuda')
     def forward(ctx, input, bound, interpolation, dim, inplace):
 
         bound = bound_to_nitorch(make_list(bound)[0], as_type='int')
@@ -265,7 +297,7 @@ def forward(ctx, input, bound, interpolation, dim, inplace):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad):
         # symmetric filter -> backward == forward
         # (I don't know if I can write into grad, so inplace=False to be safe)
@@ -276,7 +308,7 @@
 class SplineCoeffND(torch.autograd.Function):
 
     @staticmethod
-    @custom_fwd
+    @custom_fwd(device_type='cuda')
     def forward(ctx, input, bound, interpolation, dim, inplace):
 
         bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -293,7 +325,7 @@ def forward(ctx, input, bound, interpolation, dim, inplace):
         return output
 
     @staticmethod
-    @custom_bwd
+    @custom_bwd(device_type='cuda')
     def backward(ctx, grad):
         # symmetric filter -> backward == forward
         # (I don't know if I can write into grad, so inplace=False to be safe)
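For reviewers who want to exercise the shim locally, the sketch below shows the decorator pattern this diff switches to: `custom_fwd`/`custom_bwd` are now always called with an explicit `device_type`, and they resolve to `torch.amp`, to the legacy `torch.cuda.amp`/`torch.cpu.amp` helpers, or to a no-op, depending on the installed PyTorch. The `Scale` function is a toy example written for illustration only; it is not part of the library, and only the decorator usage mirrors `GridPull` & co.

```python
import torch
from interpol.autograd import custom_fwd, custom_bwd  # resolved by the shim above


class Scale(torch.autograd.Function):
    """Toy autograd.Function using the same decorator pattern as GridPull."""

    @staticmethod
    @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
    def forward(ctx, input, factor):
        # Inputs are only cast to float32 inside a CUDA autocast region.
        ctx.factor = factor
        return input * factor

    @staticmethod
    @custom_bwd(device_type='cuda')
    def backward(ctx, grad):
        # Gradient of `input * factor` w.r.t. `input`; no gradient for `factor`.
        return grad * ctx.factor, None


x = torch.randn(4, dtype=torch.float64, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2., 2.], dtype=torch.float64)
```

The same call sites should work unchanged whether the legacy `torch.cuda.amp` decorators are being wrapped or `torch.amp.custom_fwd` (which requires the `device_type` keyword, as in PyTorch 2.4) is imported directly, which is what the new Python 3.12 / PyTorch 2.4 CI entry exercises.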