FIX: deprecated import + expose affine_grid (#21)
Fixes #19
balbasty authored Sep 13, 2024
1 parent a4d5f53 commit bab3b86
Showing 4 changed files with 75 additions and 29 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test-and-publish.yaml
@@ -50,6 +50,8 @@ jobs:
pytorch-version: "1.11"
- python-version: "3.11"
pytorch-version: "2.0"
- python-version: "3.12"
pytorch-version: "2.4"
steps:
- uses: actions/checkout@v3
- uses: ./.github/actions/test
8 changes: 4 additions & 4 deletions interpol/__init__.py
@@ -1,7 +1,7 @@
from .api import *
from .resize import *
from .restrict import *
from . import backend
from .api import * # noqa: F401, F403
from .resize import * # noqa: F401, F403
from .restrict import * # noqa: F401, F403
from . import backend # noqa: F401

from . import _version
__version__ = _version.get_versions()['version']
34 changes: 23 additions & 11 deletions interpol/api.py
@@ -1,8 +1,20 @@
"""High level interpolation API"""

__all__ = ['grid_pull', 'grid_push', 'grid_count', 'grid_grad',
'spline_coeff', 'spline_coeff_nd',
'identity_grid', 'add_identity_grid', 'add_identity_grid_']
__all__ = [
'pull',
'push',
'count',
'grid_pull',
'grid_push',
'grid_count',
'grid_grad',
'spline_coeff',
'spline_coeff_nd',
'identity_grid',
'add_identity_grid',
'add_identity_grid_',
'affine_grid',
]

import torch
from .utils import expanded_shape, matvec
@@ -44,7 +56,7 @@
https://en.wikipedia.org/wiki/Discrete_sine_transform"""

_doc_bound_coeff = \
"""`bound` can be an int, a string or a BoundType.
"""`bound` can be an int, a string or a BoundType.
Possible values are:
- 'replicate' or 'nearest' : a a a | a b c d | d d d
- 'dct1' or 'mirror' : d c b | a b c d | c b a
@@ -61,7 +73,7 @@
- `dct2` corresponds to mirroring about the edge of the first/last voxel
See https://en.wikipedia.org/wiki/Discrete_cosine_transform
https://en.wikipedia.org/wiki/Discrete_sine_transform
/!\ Only 'dct1', 'dct2' and 'dft' are implemented for interpolation
orders >= 6."""
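As a concrete reading of the bound list above, 'replicate'/'nearest' clamps out-of-field coordinates to the nearest edge sample. A hedged 1D sketch of that behaviour; the voxel-coordinate convention (0 at the first sample), 1D input support, and the need for extrapolate=True to evaluate outside the field of view are assumptions, not shown in this diff:

import torch
from interpol import grid_pull

signal = torch.tensor([1., 2., 3., 4.])                            # the "a b c d" of the diagram
coords = torch.tensor([-2., -1., 0., 3., 4., 5.]).unsqueeze(-1)    # (*outshape, dim=1)
out = grid_pull(signal, coords, bound='replicate', extrapolate=True)
# expected (not guaranteed) output: 1, 1, 1, 4, 4, 4 (edge values replicated)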

@@ -143,11 +155,11 @@ def grid_pull(input, grid, interpolation='linear', bound='zero',
{interpolation}
{bound}
    If the input dtype is not a floating point type, the input image is
    assumed to contain labels. Then, unique labels are extracted
    and resampled individually, making them soft labels. Finally,
    the label map is reconstructed from the individual soft labels by
    assigning the label with maximum soft value.
Parameters
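The label-resampling behaviour described in the docstring paragraph above can be pictured with a short sketch. This is an illustration of the documented logic only, not the library's implementation; it assumes an unbatched, single-channel integer label map and the grid_pull call exported by the package:

import torch
from interpol import grid_pull

def pull_labels(labels, grid, **kwargs):
    # Pull each label's binary mask as a float ("soft label"), then keep,
    # per output voxel, the label whose soft value is largest.
    out = torch.zeros(grid.shape[:-1], dtype=labels.dtype, device=labels.device)
    best = torch.full(grid.shape[:-1], -1.0, device=labels.device)
    for lab in labels.unique():
        soft = grid_pull((labels == lab).to(torch.float32), grid, **kwargs)
        mask = soft > best
        out[mask] = lab
        best[mask] = soft[mask]
    return out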
@@ -290,7 +302,7 @@ def grid_count(grid, shape=None, interpolation='linear', bound='zero',
def grid_grad(input, grid, interpolation='linear', bound='zero',
extrapolate=False, prefilter=False):
"""Sample spatial gradients of an image with respect to a deformation field.
Notes
-----
{interpolation}
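With affine_grid (and the short aliases pull, push and count) now listed in __all__, those names can be imported straight from the package. A minimal usage sketch follows; the call pattern affine_grid(mat, shape) and the unbatched shape conventions are assumptions here, since the function's signature is not part of this diff:

import torch
import interpol

image = torch.rand(32, 32)                   # unbatched 2D image
mat = torch.eye(3)                           # homogeneous 2D affine (identity)
grid = interpol.affine_grid(mat, (32, 32))   # assumed signature: (matrix, spatial shape)
warped = interpol.grid_pull(image, grid)     # identity transform: warped ≈ image

If pull is simply an alias of grid_pull (as the grouping of names in __all__ suggests), the last line could equally be written interpol.pull(image, grid).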
60 changes: 46 additions & 14 deletions interpol/autograd.py
@@ -10,9 +10,41 @@
grid_grad, grid_grad_backward)
from .utils import fake_decorator
try:
from torch.cuda.amp import custom_fwd, custom_bwd
from torch.amp import custom_fwd, custom_bwd
except (ModuleNotFoundError, ImportError):
custom_fwd = custom_bwd = fake_decorator
try:
from torch.cuda.amp import (
custom_fwd as _custom_fwd_cuda,
custom_bwd as _custom_bwd_cuda
)
except (ModuleNotFoundError, ImportError):
_custom_fwd_cuda = _custom_bwd_cuda = fake_decorator

try:
from torch.cpu.amp import (
custom_fwd as _custom_fwd_cpu,
custom_bwd as _custom_bwd_cpu
)
except (ModuleNotFoundError, ImportError):
_custom_fwd_cpu = _custom_bwd_cpu = fake_decorator

def custom_fwd(fwd=None, *, device_type, cast_inputs=None):
if device_type == 'cuda':
decorator = _custom_fwd_cuda(cast_inputs=cast_inputs)
return decorator(fwd) if fwd else decorator
if device_type == 'cpu':
decorator = _custom_fwd_cpu(cast_inputs=cast_inputs)
return decorator(fwd) if fwd else decorator
return fake_decorator(fwd) if fwd else fake_decorator

def custom_bwd(bwd=None, *, device_type):
if device_type == 'cuda':
decorator = _custom_bwd_cuda
return decorator(bwd) if bwd else decorator
if device_type == 'cpu':
decorator = _custom_bwd_cpu
return decorator(bwd) if bwd else decorator
return fake_decorator(bwd) if bwd else fake_decorator


def make_list(x):
@@ -125,7 +157,7 @@ def inter_to_nitorch(inter, as_type='str'):
class GridPull(torch.autograd.Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float32)
@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
def forward(ctx, input, grid, interpolation, bound, extrapolate):

bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -143,7 +175,7 @@ def forward(ctx, input, grid, interpolation, bound, extrapolate):
return output

@staticmethod
@custom_bwd
@custom_bwd(device_type='cuda')
def backward(ctx, grad):
var = ctx.saved_tensors
opt = ctx.opt
@@ -155,7 +187,7 @@ def backward(ctx, grad):
class GridPush(torch.autograd.Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float32)
@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):

bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -173,7 +205,7 @@ def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):
return output

@staticmethod
@custom_bwd
@custom_bwd(device_type='cuda')
def backward(ctx, grad):
var = ctx.saved_tensors
opt = ctx.opt
@@ -185,7 +217,7 @@ def backward(ctx, grad):
class GridCount(torch.autograd.Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float32)
@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
def forward(ctx, grid, shape, interpolation, bound, extrapolate):

bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -203,7 +235,7 @@ def forward(ctx, grid, shape, interpolation, bound, extrapolate):
return output

@staticmethod
@custom_bwd
@custom_bwd(device_type='cuda')
def backward(ctx, grad):
var = ctx.saved_tensors
opt = ctx.opt
@@ -216,7 +248,7 @@ def backward(ctx, grad):
class GridGrad(torch.autograd.Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float32)
@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
def forward(ctx, input, grid, interpolation, bound, extrapolate):

bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -234,7 +266,7 @@ def forward(ctx, input, grid, interpolation, bound, extrapolate):
return output

@staticmethod
@custom_bwd
@custom_bwd(device_type='cuda')
def backward(ctx, grad):
var = ctx.saved_tensors
opt = ctx.opt
@@ -248,7 +280,7 @@ def backward(ctx, grad):
class SplineCoeff(torch.autograd.Function):

@staticmethod
@custom_fwd
@custom_fwd(device_type='cuda')
def forward(ctx, input, bound, interpolation, dim, inplace):

bound = bound_to_nitorch(make_list(bound)[0], as_type='int')
@@ -265,7 +297,7 @@ def forward(ctx, input, bound, interpolation, dim, inplace):
return output

@staticmethod
@custom_bwd
@custom_bwd(device_type='cuda')
def backward(ctx, grad):
# symmetric filter -> backward == forward
# (I don't know if I can write into grad, so inplace=False to be safe)
@@ -276,7 +308,7 @@ def backward(ctx, grad):
class SplineCoeffND(torch.autograd.Function):

@staticmethod
@custom_fwd
@custom_fwd(device_type='cuda')
def forward(ctx, input, bound, interpolation, dim, inplace):

bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -293,7 +325,7 @@ def forward(ctx, input, bound, interpolation, dim, inplace):
return output

@staticmethod
@custom_bwd
@custom_bwd(device_type='cuda')
def backward(ctx, grad):
# symmetric filter -> backward == forward
# (I don't know if I can write into grad, so inplace=False to be safe)
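For context on the compatibility shim added at the top of interpol/autograd.py: it lets every call site use the newer keyword-only torch.amp style regardless of the installed PyTorch version. A usage sketch mirroring the decorators applied to the Function classes above; the demo Function is hypothetical, and custom_fwd/custom_bwd refer to the names defined in that module (whether imported from torch.amp or provided by the shim):

import torch
from interpol.autograd import custom_fwd, custom_bwd

class _Demo(torch.autograd.Function):

    @staticmethod
    @custom_fwd(device_type='cuda', cast_inputs=torch.float32)
    def forward(ctx, x):
        # Inputs are cast to float32 inside CUDA autocast regions
        # (when the installed PyTorch supports it); otherwise a no-op.
        return x * 2

    @staticmethod
    @custom_bwd(device_type='cuda')
    def backward(ctx, grad):
        return grad * 2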
