
Input shape verification for convolutional models #9

Closed
MicheleCattaneo opened this issue Nov 20, 2024 · 3 comments · Fixed by #10
Comments

@MicheleCattaneo
Contributor

Proposal:

Add a utility to validate and adjust input shapes for convolutional models.

Description:

Currently, the project does not provide a way to check whether an input shape is compatible with a given convolutional model, which can lead to runtime errors when users pass incompatible inputs. A typical example is an encoder-decoder model with skip connections that concatenate the decoded tensor with the skipped input: if a pooling operation during encoding produces a non-integer size (which gets rounded), the upsampled tensor no longer matches the skip tensor and the concatenation cannot happen.
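
To make the failure concrete, here is a minimal standalone sketch in plain PyTorch (not project code): with an odd spatial size, one pooling step rounds 15 down to 7, upsampling gives back 14, and the skip concatenation fails.

import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 15, 15)             # odd spatial size
down = F.max_pool2d(x, kernel_size=2)     # floor(15 / 2) = 7 -> (1, 8, 7, 7)
up = F.interpolate(down, scale_factor=2)  # 7 * 2 = 14 -> (1, 8, 14, 14)
torch.cat([up, x], dim=1)                 # RuntimeError: 14 != 15 on the spatial dims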

I propose adding utilities to:

  • Validate whether a given input shape is compatible with a specified model.
  • If the shape is incompatible, calculate and optionally apply the necessary padding to adjust the shape to the closest compatible size.

Motivation

This contribution will improve the usability of the models across datasets with different input shapes.

Proposed solution

Every model extending ModelABC can optionally implement a method validate_input_shape(self, input_shape: Size) -> Tuple[bool, Size] that checks whether the given shape is compatible and, if it is not, returns the closest shape that fits the model's architecture.
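
As a rough sketch of what such a method could look like for a UNet-like model (the num_pool_layers attribute is hypothetical; the actual proof of concept may differ):

from typing import Tuple
import math
from torch import Size

def validate_input_shape(self, input_shape: Size) -> Tuple[bool, Size]:
    # with n pooling steps, each spatial dim must be divisible by 2**n
    divisor = 2 ** self.num_pool_layers  # hypothetical attribute
    new_shape = Size(math.ceil(dim / divisor) * divisor for dim in input_shape)
    return new_shape == input_shape, new_shape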

A utility function pad_batch(batch: torch.Tensor, new_shape: torch.Size, pad_value: float = 0) -> torch.Tensor can then be used to pad an input batch to the new shape.
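
A minimal sketch of such a utility, assuming the last two dimensions of the batch are the spatial ones and that padding is centred:

import torch
import torch.nn.functional as F

def pad_batch(batch: torch.Tensor, new_shape: torch.Size, pad_value: float = 0) -> torch.Tensor:
    # F.pad expects (left, right, top, bottom) for the last two dimensions
    dh = new_shape[-2] - batch.shape[-2]
    dw = new_shape[-1] - batch.shape[-1]
    padding = (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2)
    return F.pad(batch, padding, mode="constant", value=pad_value)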

A proof of concept for UNet, HalfUNet and CustomUNet can be found here.

Additional contributions

These utilities could also be used inside the models' forward pass: if a runtime error is caused by the input shape, additional information is logged to help the user pad their inputs.
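
For instance, a model's forward pass could catch shape-related failures and re-raise with a hint (a sketch, assuming the validate_input_shape method above; _forward_impl is a hypothetical internal method):

import torch

def forward(self, x: torch.Tensor) -> torch.Tensor:
    try:
        return self._forward_impl(x)  # hypothetical internal forward
    except RuntimeError as err:
        valid, new_shape = self.validate_input_shape(x.shape[-2:])
        if not valid:
            raise RuntimeError(
                f"Input spatial shape {tuple(x.shape[-2:])} is incompatible with this "
                f"architecture; the closest compatible shape is {tuple(new_shape)}. "
                "Consider padding the batch with pad_batch()."
            ) from err
        raise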


I would appreciate any feedback or suggestions on this approach, and I am willing to contribute a pull request once we've reviewed the proposal.

@colon3ltocard
Contributor

colon3ltocard commented Nov 21, 2024

Thanks a lot for your proposal! Some feedback and thoughts:

  • The idea of having some kind of "automatic padding" is great and could be useful 👍 so let's go for it.
  • By default, we expect the model not to pad the data but to raise an explicit Exception warning the user about the incompatibility between the supplied input shape and the model architecture. In most cases we prefer adjusting the input shape (changing the patch width/height) during data loading rather than having constant padding applied "under the hood" and thus "deteriorating" the sample.
  • Following the "Zen of Python", enabling auto-padding should also be explicit, probably through the model settings.
  • Maybe it is worth having a padding API matching the underlying F.pad; in some cases we might want to use the "reflect" or "circular" modes (see the sketch after this list).
  • I am not a big fan of the warning in the proposed default validate_input_shape. For me, models supporting auto-padding should explicitly expose the functionality through their settings (and it should be disabled by default).
  • For locating the padding-related code, would you agree to mfai/torch/padding.py instead of mfai/torch/utils/input_utils.py? I don't like utils because it is not clear what's in them ^^

from mfai.torch.padding import pad_batch
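
For the mode question, a sketch of what forwarding F.pad's mode could look like (hypothetical signature; the left/right/top/bottom computation stays internal):

import torch
import torch.nn.functional as F

def pad_batch(batch: torch.Tensor, new_shape: torch.Size,
              mode: str = "constant", pad_value: float = 0) -> torch.Tensor:
    # mode is forwarded to F.pad: "constant", "reflect", "replicate" or "circular"
    dh = new_shape[-2] - batch.shape[-2]
    dw = new_shape[-1] - batch.shape[-1]
    padding = (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2)
    if mode == "constant":
        return F.pad(batch, padding, mode=mode, value=pad_value)
    return F.pad(batch, padding, mode=mode)  # value only applies to "constant"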

What do you think?

@MicheleCattaneo
Contributor Author

Hi, thanks for the feedback!

Each of the following points corresponds to an item in your list:

  • Great!
  • Do you mean that the whole dataset should be padded in advance? I agree that the model's forward() should not do the padding, but it could be a good idea to pad one batch at a time, which also saves GPU memory. This would be up to the user in the case of a manual training loop, or up to the LightningModule when PyTorch Lightning is used. My suggestion was simply to print informative feedback when the forward pass fails due to the input shape; it is then up to the user to use the suggested functions to provide a correctly shaped input next time.
  • I think auto-padding should only be done by a LightningModule, not by the models.
  • Good idea. Would it suffice to add the mode argument but keep the computation of the left, right, top, bottom, ... padding sizes hidden from the user?
  • Could you please explain how you'd implement this?
  • Agreed! I will move it.

What I was initially proposing was not an auto-padding mechanism, but simply utilities to obtain the closest fitting shape and apply it manually. The flow would look like this:

# proposed flow (assuming the utilities live in mfai/torch/utils/input_utils.py)
from mfai.torch.utils import input_utils

# define some model
net = model_class(**args)
# check whether my data needs padding (only the spatial dims are checked)
valid_shape, new_shape = net.validate_input_shape(input_data.shape[-2:])
if not valid_shape:
    # if it does, pad the batch to the closest compatible shape
    input_data = input_utils.pad_batch(batch=input_data, new_shape=new_shape, pad_value=0)

Let me know what you think.
Should I open a draft pull request to keep track of the updates on this feature?

@colon3ltocard
Contributor

No, I didn't mean the dataset should be padded in advance, sorry for the misunderstanding. I agree about reporting the failure in an informative way => that's what Exceptions are for, imho.

Okay for your proposal: manual padding in pure PyTorch, or padding in the lightning module when using Lightning. In the latter case, there should be an option on the lightning module to explicitly request padding; by default it should not pad and should let the exception traverse the call stack and reach the user's console. That's what I meant.
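
A sketch of what that opt-in could look like on the lightning module side (class and attribute names are hypothetical; pad_batch is imported from the proposed mfai/torch/padding.py):

import torch
import pytorch_lightning as pl
from mfai.torch.padding import pad_batch  # proposed location

class PaddedLightningModule(pl.LightningModule):  # hypothetical name
    def __init__(self, model: torch.nn.Module, autopad_enabled: bool = False):
        super().__init__()
        self.model = model
        self.autopad_enabled = autopad_enabled  # padding is opt-in, off by default

    def maybe_pad(self, x: torch.Tensor) -> torch.Tensor:
        valid, new_shape = self.model.validate_input_shape(x.shape[-2:])
        if valid:
            return x
        if not self.autopad_enabled:
            # let the incompatibility reach the user's console as an explicit error
            raise RuntimeError(
                f"Input shape {tuple(x.shape[-2:])} is incompatible with the model; "
                f"closest compatible shape: {tuple(new_shape)}."
            )
        return pad_batch(batch=x, new_shape=new_shape)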

Okay for the mode argument, and we keep the rest hidden.

Go for the draft PR then! There will be two reviewers, so you will get a second opinion.

Thanks again.
