Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Input shape verification for convolutional models #10

Merged
merged 15 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mfai/torch/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Optional, Tuple

from torch import nn
from .base import ModelABC
from .base import AutoPaddingModel, ModelABC


# Load all models from the torch.models package
Expand All @@ -28,6 +28,10 @@
all_nn_architectures = list(registry.values())


# Every registered architecture that implements the AutoPaddingModel interface.
# The base class itself is excluded with an identity check: comparing a class
# object to the string 'AutoPaddingModel' is always unequal and filtered nothing.
autopad_nn_architectures = {
    obj
    for obj in all_nn_architectures
    if issubclass(obj, AutoPaddingModel) and obj is not AutoPaddingModel
}


def load_from_settings_file(
model_name: str,
in_channels: int,
Expand Down
61 changes: 60 additions & 1 deletion mfai/torch/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@

from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Tuple
from typing import Any, Optional, Tuple
from torch import Size
import torch

from mfai.torch.padding import pad_batch, undo_padding

class ModelType(Enum):
"""
Expand Down Expand Up @@ -91,3 +94,59 @@ def check_required_attributes(self):
for attr in required_attrs:
if not hasattr(self, attr):
raise AttributeError(f"Missing required attribute : {attr}")


class AutoPaddingModel(ABC):
    """Interface for models able to pad their inputs to a shape they accept.

    The helper methods read two attributes from the concrete model:
    ``self._settings.autopad_enabled`` (switch for the whole mechanism) and
    ``self.input_shape`` (its length gives the number of trailing spatial
    dimensions of the data tensor).
    """

    @abstractmethod
    def validate_input_shape(self, input_shape: Size) -> Tuple[bool, Optional[Size]]:
        """Given an input shape, verify whether it fits the calling model's specifications.

        Args:
            input_shape (Size): The shape of the input data, excluding any batch dimension and channel dimension.
                For example, for a batch of 2D tensors of shape [B,C,W,H], [W,H] should be passed.
                For 3D data of shape [B,C,W,H,D], [W,H,D] should be passed instead.

        Returns:
            Tuple[bool, Optional[Size]]: The first element signals whether the given input shape
                already fits the model's requirements. If it is False, the second element contains the
                closest shape that fits the model. If it is True, the second element is not used by
                ``_maybe_padding`` (implementations may return None or the unchanged shape).
        """

    def _maybe_padding(self, data_tensor: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Size]]:
        """Optionally pad the data tensor so that it can be fed to the underlying model.

        Padding only happens if autopadding was enabled via the settings and
        ``validate_input_shape`` reports the current spatial shape as invalid.

        Args:
            data_tensor (torch.Tensor): the input data to be potentially padded.

        Returns:
            Tuple[torch.Tensor, Optional[torch.Size]]: the tensor padded with zeros by ``pad_batch``
                together with the original spatial shape when padding was applied. If autopadding is
                disabled or the shape is already fine, the data is returned untouched and the second
                return value is None.
        """
        if not self._settings.autopad_enabled:
            return data_tensor, None

        # Only the trailing spatial dimensions matter; batch and channel dims are dropped.
        old_shape = data_tensor.shape[-len(self.input_shape):]
        valid_shape, new_shape = self.validate_input_shape(old_shape)
        if valid_shape:
            return data_tensor, None
        return pad_batch(batch=data_tensor, new_shape=new_shape, pad_value=0), old_shape

    def _maybe_unpadding(self, data_tensor: torch.Tensor, old_shape: Optional[torch.Size]) -> torch.Tensor:
        """Potentially remove the padding previously added to the given tensor.

        This action is only carried out if autopadding was enabled via the settings
        and ``old_shape`` is not None.

        Args:
            data_tensor (torch.Tensor): The data tensor from which padding is to be removed.
            old_shape (Optional[torch.Size]): The previous spatial shape of the data tensor, e.g.
                [W,H] or [W,H,D] for 2D and 3D data respectively, as returned by
                ``self._maybe_padding``. None means no padding was applied.

        Returns:
            torch.Tensor: The data tensor with the padding removed, if possible.
        """
        if self._settings.autopad_enabled and old_shape is not None:
            return undo_padding(data_tensor, old_shape=old_shape)
        return data_tensor
29 changes: 26 additions & 3 deletions mfai/torch/models/half_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@
from dataclasses import dataclass
from functools import reduce
from typing import Tuple, Union
from math import ceil

import torch
from dataclasses_json import dataclass_json
from torch import nn

from mfai.torch.models.base import ModelABC, ModelType
from mfai.torch.models.base import ModelABC, AutoPaddingModel, ModelType
from mfai.torch.models.utils import AbsolutePosEmdebding



@dataclass_json
@dataclass(slots=True)
class HalfUNetSettings:
Expand All @@ -20,6 +22,7 @@ class HalfUNetSettings:
use_ghost: bool = False
last_activation: str = "Identity"
absolute_pos_embed: bool = False
autopad_enabled: bool = False


class GhostModule(nn.Module):
Expand Down Expand Up @@ -61,7 +64,7 @@ def forward(self, x):
return self.relu(x)


class HalfUNet(ModelABC, nn.Module):
class HalfUNet(ModelABC, AutoPaddingModel, nn.Module):
settings_kls = HalfUNetSettings
onnx_supported: bool = True
supported_num_spatial_dims = (2,)
Expand Down Expand Up @@ -189,6 +192,8 @@ def settings(self) -> HalfUNetSettings:
return self._settings

def forward(self, x):
x, old_shape = self._maybe_padding(data_tensor=x)

enc1 = self.encoder1(x)
enc2 = self.encoder2(self.pool1(enc1))
enc3 = self.encoder3(self.pool2(enc2))
Expand All @@ -201,7 +206,9 @@ def forward(self, x):
torch.zeros_like(enc1),
)
dec = self.decoder(summed)
return self.activation(self.outconv(dec))
out = self.activation(self.outconv(dec))

return self._maybe_unpadding(out, old_shape=old_shape)

@staticmethod
def _block(
Expand Down Expand Up @@ -286,3 +293,19 @@ def _block(
*layers,
)
return layers

def validate_input_shape(self, input_shape: torch.Size) -> Tuple[bool, torch.Size]:
    """Check whether a spatial input shape is compatible with the pooling layers.

    Args:
        input_shape: Spatial shape of the input, without batch and channel
            dimensions (e.g. [W, H]). Any sequence of ints is accepted.

    Returns:
        A tuple ``(is_valid, closest_shape)``: ``is_valid`` is True when every
        dimension is already a multiple of ``2**P`` (P = number of
        ``nn.MaxPool2d`` sub-modules); ``closest_shape`` is the input shape
        with every dimension rounded up to the next such multiple.
    """
    number_pool_layers = sum(
        1 for layer in self.modules() if isinstance(layer, nn.MaxPool2d)
    )

    # Each max pooling layer is expected to halve every spatial dimension
    # (2x2 window, stride 2). For the skip connections to match in shape, the
    # input dimensions must therefore be divisible by 2^P for P pooling layers.
    d = 2 ** number_pool_layers

    new_shape = torch.Size([d * ceil(size / d) for size in input_shape])

    # Normalise to torch.Size so that list inputs compare by value.
    return new_shape == torch.Size(input_shape), new_shape
60 changes: 55 additions & 5 deletions mfai/torch/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@

from collections import OrderedDict
from dataclasses import dataclass
from functools import cached_property
from typing import Tuple, Union
from math import ceil

import re
import inspect

import torch
from dataclasses_json import dataclass_json
from torch import nn

from mfai.torch.models.encoders import get_encoder

from .base import ModelABC, ModelType
from .base import ModelABC, AutoPaddingModel, ModelType


class DoubleConv(nn.Module):
Expand Down Expand Up @@ -58,9 +63,10 @@ def forward(self, x):
@dataclass(slots=True)
class UnetSettings:
init_features: int = 64
autopad_enabled: bool = False


class UNet(ModelABC, nn.Module):
class UNet(ModelABC, AutoPaddingModel, nn.Module):
"""
Returns a u_net architecture, with uninitialised weights, matching desired numbers of input and output channels.

Expand Down Expand Up @@ -146,6 +152,8 @@ def forward(self, x):
are applied to a layer with an even x- and y-size.
"""

x, old_shape = self._maybe_padding(data_tensor=x)

enc1 = self.encoder1(x)
enc2 = self.encoder2(self.max_pool(enc1))
enc3 = self.encoder3(self.max_pool(enc2))
Expand All @@ -165,7 +173,9 @@ def forward(self, x):
dec1 = self.upconv1(dec2)
dec1 = torch.cat((dec1, enc1), dim=1)
dec1 = self.decoder1(dec1)
return self.conv(dec1)
out = self.conv(dec1)

return self._maybe_unpadding(out, old_shape=old_shape)

@staticmethod
def _block(in_channels, features, name):
Expand Down Expand Up @@ -199,6 +209,29 @@ def _block(in_channels, features, name):
]
)
)

def validate_input_shape(self, input_shape: torch.Size) -> Tuple[bool, torch.Size]:
    """Check whether a spatial input shape is compatible with the pooling steps.

    Args:
        input_shape: Spatial shape of the input, without batch and channel
            dimensions (e.g. [W, H]). Any sequence of ints is accepted.

    Returns:
        A tuple ``(is_valid, closest_shape)``: ``is_valid`` is True when every
        dimension is already a multiple of ``2**P`` (P being
        ``self._num_pool_layers``); ``closest_shape`` is the input shape with
        every dimension rounded up to the next such multiple.
    """
    number_pool_layers = self._num_pool_layers

    # Each max pooling step is expected to halve every spatial dimension
    # (2x2 window, stride 2). For the skip connections to match in shape, the
    # input dimensions must therefore be divisible by 2^P for P pooling steps.
    d = 2 ** number_pool_layers

    new_shape = torch.Size([d * ceil(size / d) for size in input_shape])

    # Normalise to torch.Size so that list inputs compare by value.
    return new_shape == torch.Size(input_shape), new_shape

@cached_property
def _num_pool_layers(self) -> int:
    """Count the max pooling calls made in ``forward`` by source introspection.

    The source of ``self.forward`` is parsed and every call whose callee name
    ends with ``max_pool`` is counted. Parsing the syntax tree (rather than
    scanning raw text) keeps mentions of ``max_pool(`` inside docstrings or
    comments from inflating the count. The result is cached per instance.

    Raises:
        OSError: propagated from ``inspect.getsource`` when the source of
            ``forward`` cannot be retrieved.
    """
    import ast
    import textwrap

    source_code = inspect.getsource(self.forward)
    try:
        # Method source is indented; dedent it so it parses on its own.
        tree = ast.parse(textwrap.dedent(source_code))
    except SyntaxError:
        # Source could not be parsed in isolation: fall back to the plain
        # textual count of max pool calls.
        return len(re.findall(r'max_pool\(', source_code))

    def _callee_name(call: ast.Call) -> str:
        # Name of the called object for `x.name(...)` and `name(...)` forms.
        func = call.func
        if isinstance(func, ast.Attribute):
            return func.attr
        if isinstance(func, ast.Name):
            return func.id
        return ""

    return sum(
        1
        for node in ast.walk(tree)
        if isinstance(node, ast.Call) and _callee_name(node).endswith("max_pool")
    )



@dataclass_json
Expand All @@ -207,9 +240,10 @@ class CustomUnetSettings:
encoder_name: str = "resnet18"
encoder_depth: int = 5
encoder_weights: bool = True
autopad_enabled: bool = False


class CustomUnet(ModelABC, nn.Module):
class CustomUnet(ModelABC, AutoPaddingModel, nn.Module):
settings_kls = CustomUnetSettings
onnx_supported = True
supported_num_spatial_dims = (2,)
Expand Down Expand Up @@ -238,6 +272,8 @@ def __init__(
weights=settings.encoder_weights,
)

self.input_shape = input_shape

decoder_channels = self.encoder.out_channels[
::-1
] # Reverse the order to be the same index of the decoder
Expand Down Expand Up @@ -265,6 +301,8 @@ def settings(self) -> CustomUnetSettings:
return self._settings

def forward(self, x):

x, old_shape = self._maybe_padding(data_tensor=x)
# Encoder part
encoder_outputs = self.encoder(x)
encoder_outputs = encoder_outputs[
Expand All @@ -280,4 +318,16 @@ def forward(self, x):
x = torch.cat([x, skip], dim=1)
x = decoder(x)

return self.final_conv(x)
out = self.final_conv(x)
return self._maybe_unpadding(out, old_shape=old_shape)

def validate_input_shape(self, input_shape: torch.Size) -> Tuple[bool, torch.Size]:
    """Check whether a spatial input shape is compatible with the encoder depth.

    Args:
        input_shape: Spatial shape of the input, without batch and channel
            dimensions (e.g. [W, H]). Any sequence of ints is accepted.

    Returns:
        A tuple ``(is_valid, closest_shape)``: ``is_valid`` is True when every
        dimension is already a multiple of ``2**encoder_depth``;
        ``closest_shape`` is the input shape with every dimension rounded up
        to the next such multiple.
    """
    # The encoder depth from the settings is used as the number of
    # downsampling steps, so dimensions must be divisible by 2^encoder_depth.
    number_pool_layers = self._settings.encoder_depth
    d = 2 ** number_pool_layers

    new_shape = torch.Size([d * ceil(size / d) for size in input_shape])

    # Normalise to torch.Size so that list inputs compare by value.
    return new_shape == torch.Size(input_shape), new_shape
Loading
Loading