Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Input shape verification for convolutional models #10

Merged
merged 15 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 21 additions & 17 deletions mfai/torch/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Interface contract for our models.
"""

from abc import ABC, abstractproperty
from abc import ABC, abstractproperty, abstractmethod
from typing import Tuple
from torch import Size
import warnings
Expand All @@ -29,7 +29,25 @@ def input_spatial_dims(self) -> Tuple[int, ...]:
A model supporting 2d and 3d tensors should return (2, 3).
"""

def validate_input_shape(self, input_shape: Size) -> Tuple[bool, Size]:
@property
def auto_padding_supported(self) -> bool:
    """
    Tell whether this model can have its inputs padded automatically.

    The answer is True exactly when the instance is an ``AutoPaddingModel``.
    """
    supports_padding = isinstance(self, AutoPaddingModel)
    return supports_padding

def check_required_attributes(self):
    """
    Ensure the model instance exposes every mandatory attribute.

    Meant to be called at the end of the ``__init__`` of each subclass.

    Raises:
        AttributeError: for the first attribute (in declaration order) that
            the instance does not define.
    """
    mandatory = ("in_channels", "out_channels", "input_shape")
    # Look up the first absent attribute, checking them in declaration order.
    first_missing = next(
        (name for name in mandatory if not hasattr(self, name)), None
    )
    if first_missing is not None:
        raise AttributeError(f"Missing required attribute : {first_missing}")


class AutoPaddingModel(ABC):
@abstractmethod
def validate_input_shape(self, input_shape: Size) -> Tuple[bool | Size]:
""" Given an input shape, verifies whether the inputs fit with the
calling model's specifications.

Expand All @@ -42,18 +60,4 @@ def validate_input_shape(self, input_shape: Size) -> Tuple[bool, Size]:
Tuple[bool, Size]: Returns a tuple where the first element is a boolean signaling whether the given input shape
already fits the model's requirements. If that value is False, the second element contains the closest
shape that fits the model, otherwise it will be None.
"""
warnings.warn(
f"{self.__class__.__name__} has not overridden the {self.validate_input_shape.__name__} method. The correctness is not guaranteed.",
UserWarning
)

return True, input_shape

def check_required_attributes(self):
    """
    Ensure the model instance exposes every mandatory attribute.

    Meant to be called at the end of the ``__init__`` of each subclass.

    Raises:
        AttributeError: for the first attribute (in declaration order) that
            the instance does not define.
    """
    mandatory = ("in_channels", "out_channels", "input_shape")
    # Look up the first absent attribute, checking them in declaration order.
    first_missing = next(
        (name for name in mandatory if not hasattr(self, name)), None
    )
    if first_missing is not None:
        raise AttributeError(f"Missing required attribute : {first_missing}")
"""
4 changes: 2 additions & 2 deletions mfai/torch/models/half_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dataclasses_json import dataclass_json
from torch import nn

from mfai.torch.models.base import ModelABC
from mfai.torch.models.base import ModelABC, AutoPaddingModel
from mfai.torch.models.utils import AbsolutePosEmdebding


Expand Down Expand Up @@ -62,7 +62,7 @@ def forward(self, x):
return self.relu(x)


class HalfUNet(ModelABC, nn.Module):
class HalfUNet(ModelABC, nn.Module, AutoPaddingModel):
MicheleCattaneo marked this conversation as resolved.
Show resolved Hide resolved
settings_kls = HalfUNetSettings
onnx_supported = True
input_spatial_dims = (2,)
Expand Down
6 changes: 3 additions & 3 deletions mfai/torch/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from mfai.torch.models.encoders import get_encoder

from .base import ModelABC
from .base import ModelABC, AutoPaddingModel


class DoubleConv(nn.Module):
Expand Down Expand Up @@ -64,7 +64,7 @@ class UnetSettings:
init_features: int = 64


class UNet(ModelABC, nn.Module):
class UNet(ModelABC, nn.Module, AutoPaddingModel):
MicheleCattaneo marked this conversation as resolved.
Show resolved Hide resolved
"""
Returns a u_net architecture, with uninitialised weights, matching desired numbers of input and output channels.

Expand Down Expand Up @@ -227,7 +227,7 @@ class CustomUnetSettings:
encoder_weights: bool = True


class CustomUnet(ModelABC, nn.Module):
class CustomUnet(ModelABC, nn.Module, AutoPaddingModel):
MicheleCattaneo marked this conversation as resolved.
Show resolved Hide resolved
settings_kls = CustomUnetSettings
onnx_supported = True
input_spatial_dims = (2,)
Expand Down
32 changes: 32 additions & 0 deletions mfai/torch/segmentation_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import torch
import torchmetrics as tm
from pytorch_lightning.utilities import rank_zero_only
import warnings

from mfai.torch.models.base import ModelABC
from mfai.torch.padding import pad_batch, undo_padding

# define custom scalar in tensorboard, to have 2 lines on same graph
layout = {
Expand All @@ -23,13 +25,17 @@ def __init__(
model: ModelABC,
type_segmentation: Literal["binary", "multiclass", "multilabel", "regression"],
loss: Callable,
padding_strategy: Literal['none', 'apply', 'apply_and_undo'] = 'none'
) -> None:
"""A lightning module adapted for segmentation of weather images.

Args:
model (ModelABC): Torch neural network model in [DeepLabV3, DeepLabV3Plus, HalfUNet, Segformer, SwinUNETR, UNet, CustomUnet, UNETRPP]
type_segmentation (Literal["binary", "multiclass", "multilabel", "regression"]): Type of segmentation we want to do.
loss (Callable): Loss function
padding_strategy (Literal['none', 'apply', 'apply_and_undo']): Defines the padding strategy to use. With 'none', it is up to the user to
make sure that the input shapes fit the underlying model. With 'apply', padding is applied and reflected in the output shape.
With 'apply_and_undo', padding is applied for the forward pass, but it is undone before returning the output.
"""
super().__init__()
self.model = model
Expand All @@ -54,6 +60,11 @@ def __init__(
self.model.input_shape[0],
self.model.input_shape[1],
)

self.padding_strategy = padding_strategy
if not self.model.auto_padding_supported and padding_strategy != 'none':
warnings.warn(f"{self.model.__class__.__name__} does not support autopadding and will not be used.",
UserWarning)

def get_metrics(self):
"""Defines the metrics that will be computed during valid and test steps."""
Expand Down Expand Up @@ -96,8 +107,10 @@ def forward(self, inputs: torch.Tensor):
inputs = inputs.to(memory_format=torch.channels_last)
# We prefer when the last activation function is included in the loss and not in the model.
# Consequently, we need to apply the last activation manually here, to get the real output.
inputs, old_shape = self._maybe_padding(data_tensor=inputs)
y_hat = self.model(inputs)
y_hat = self.last_activation(y_hat)
y_hat = self._maybe_unpadding(y_hat, old_shape=old_shape)
return y_hat

def _shared_forward_step(self, x: torch.Tensor, y: torch.Tensor):
Expand All @@ -107,10 +120,29 @@ def _shared_forward_step(self, x: torch.Tensor, y: torch.Tensor):
x = x.to(memory_format=torch.channels_last)
# We prefer when the last activation function is included in the loss and not in the model.
# Consequently, we need to apply the last activation manually here, to get the real output.
x, old_shape = self._maybe_padding(x)
y_hat = self.model(x)
y, _ = self._maybe_padding(y)
loss = self.loss(y_hat, y)
MicheleCattaneo marked this conversation as resolved.
Show resolved Hide resolved
y_hat = self.last_activation(y_hat)
self._maybe_unpadding(y_hat, old_shape=old_shape)
return y_hat, loss

def _maybe_padding(self, data_tensor):
    """
    Pad ``data_tensor`` when padding is enabled and the model rejects its shape.

    Returns:
        A pair ``(tensor, old_shape)``. When padding is applied, ``tensor`` is
        the padded batch and ``old_shape`` holds the trailing dimensions of the
        input before padding. In every other case (strategy ``'none'``, model
        without auto-padding support, or shape already valid) the input tensor
        is returned untouched together with ``None``.
    """
    padding_active = (
        self.padding_strategy != 'none' and self.model.auto_padding_supported
    )
    if padding_active:
        # Only the trailing dims (as many as the model's input_shape has)
        # are submitted for validation.
        n_dims = len(self.model.input_shape)
        trailing_shape = data_tensor.shape[-n_dims:]
        is_valid, target_shape = self.model.validate_input_shape(trailing_shape)
        if not is_valid:
            padded = pad_batch(batch=data_tensor, new_shape=target_shape, pad_value=0)
            return padded, trailing_shape
    return data_tensor, None

def _maybe_unpadding(self, data_tensor, old_shape):
    """
    Undo a previous padding when the strategy asks for it.

    The padding is removed only if the strategy is ``'apply_and_undo'`` and
    ``old_shape`` is not ``None``; otherwise ``data_tensor`` is returned as is.
    """
    must_undo = self.padding_strategy == 'apply_and_undo' and old_shape is not None
    if not must_undo:
        return data_tensor
    return undo_padding(data_tensor, old_shape=old_shape)


def on_train_start(self):
"""Setup custom scalars panel on tensorboard and log hparams.
Expand Down
9 changes: 6 additions & 3 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pytest
import torch
from marshmallow.exceptions import ValidationError
from mfai.torch.models.base import AutoPaddingModel

from mfai.torch import export_to_onnx, onnx_load_and_infer
from mfai.torch.models import (
Expand Down Expand Up @@ -166,10 +167,12 @@ def test_load_model_by_name():
/ "halfunet128.json",
)

@pytest.mark.parametrize("model_class", [UNet,
CustomUnet,
HalfUNet])
@pytest.mark.parametrize("model_class", all_nn_architectures)
def test_input_shape_validation(model_class):

if not issubclass(model_class, AutoPaddingModel):
return

B, C, W, H = 32,3,64,65

input_data = torch.randn(B,C,W,H)
Expand Down
File renamed without changes.
Loading