Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Input shape verification for convolutional models #10

Merged
merged 15 commits into from
Jan 8, 2025

Conversation

MicheleCattaneo
Copy link
Contributor

@MicheleCattaneo MicheleCattaneo commented Nov 22, 2024

What does this PR do?

It introduces utilities to handle invalid input shapes for convolutional models.

Current progress

  • Functions to check the input shapes for UNet, HalfUNet and CustomUNet + tests.
  • Padding and un-padding functions to be used in conjunction with the aforementioned ones + tests.
  • Automatic padding for torch lightning modules exposed through their settings.

Related Issue

closes #9

This draft PR will be updated as the work progresses. Please share any feedback on the overall structure.

Copy link
Contributor

@LBerth LBerth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi !
Very good job and very nice first contribution.
I have only 2 minor comments concerning the unit test of the padding.
Léa

tests/test_utils.py Outdated Show resolved Hide resolved
tests/test_utils.py Outdated Show resolved Hide resolved
@MicheleCattaneo MicheleCattaneo marked this pull request as ready for review November 26, 2024 08:54
@colon3ltocard
Copy link
Contributor

Hi again @MicheleCattaneo ,

Sorry for the late reply we went thru a lot of changes before being able to integrate/discuss new features.

We discussed with the team and we have some suggestions to make for the autopad you propose:

  • Separation Of Concern: auto-padding is only meaningful for a subset of the models so it should not "contaminate" outside the subset of models (the actual ModelABC subclasses) it is related to.

We propose that a model able to autopad should expose it thru extra settings, here is an example with HalfUnet:

@dataclass_json
@dataclass(slots=True)
class HalfUNetSettings:
    num_filters: int = 64
    dilation: int = 1
    bias: bool = False
    use_ghost: bool = False
    last_activation: str = "Identity"
    absolute_pos_embed: bool = False
    autopad: bool = False

Now the model can handle this in its forward method (this code could be factored in a common class or Mixin):

def forward(self, x):
     if self.settings.autopad:
          x = self.pad(x)
          y=self._forward(x)
          return self.unpad(y)
     else:
        return self._forward(x)

Note : i omitted the shape tracking on purpose for the brevity of this.

This way the autopad doesn't "pollute" the Segmentation/Regression lightning, we don't have to check if autopad is supported using dedicated code outside the model and we can also use it transparently in pure PyTorch loops.

@MicheleCattaneo
Copy link
Contributor Author

Hi @colon3ltocard,
I agree with what you suggest. Commit 337ef25 implements a first variant of this idea. Let me know if you have any feedback!

@colon3ltocard
Copy link
Contributor

@MicheleCattaneo A few minor comments + some CI fails and then we can merge 👍

Copy link
Contributor

@colon3ltocard colon3ltocard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One last comment: directly reference the self._settings for the autopad_enabled bool

Copy link
Contributor Author

@MicheleCattaneo MicheleCattaneo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done 👍

mfai/torch/segmentation_module.py Show resolved Hide resolved
tests/test_utils.py Outdated Show resolved Hide resolved
mfai/torch/models/unet.py Outdated Show resolved Hide resolved
Copy link
Contributor

@LBerth LBerth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi ! Very good contribution, thank you Michele ! :)

@LBerth LBerth merged commit 330e37e into meteofrance:main Jan 8, 2025
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Input shape verification for convolutional models
3 participants