Input shape verification for convolutional models #9
Comments
Thanks a lot for your proposal! Some feedback and thoughts:
from mfai.torch.padding import pad_batch
What do you think? |
Hi, thanks for the feedback! Each of the following points corresponds to your list:
What I was initially proposing was not an auto-padding mechanism, but simply functionality to obtain the closest shape that fits and to apply it manually. The flow would look like this:

```python
# define some model
net = model_class(**args)
# check whether my data needs padding
valid_shape, new_shape = net.validate_input_shape(input_data.shape[-2:])
if not valid_shape:
    # if it does, apply the padding
    input_data = input_utils.pad_batch(batch=input_data, new_shape=new_shape, pad_value=0)
```

Let me know what you think. |
No, I didn't mean the dataset should be padded in advance, sorry for the misunderstanding. I agree about reporting the problem in an informative way: that's what exceptions are for, imho. Okay for your proposal: manual padding in pure PyTorch, or in the lightning module when using Lightning. In the latter case (lightning module), there should be an option added to the lightning module to explicitly request padding; by default it should not pad and should let the exception traverse the call stack and reach the user's console. That's what I meant. Okay for the mode argument, and we keep the rest hidden. Go for the draft PR then! There will be two reviewers, so you will get a second opinion. Thanks again. |
Proposal:
Add a utility to validate and adjust input shapes for convolutional models.
Description:
Currently, the project does not provide a way to check whether an input shape is compatible with a given convolutional model. This can lead to runtime errors when users pass inputs that are incompatible with the model's architecture. Examples are encoder-decoder models with skip connections that concatenate the decoded tensor with the skipped input: if a pooling operation during encoding produced a non-integer size (which gets rounded down), the upsampled tensor no longer matches the skipped one, and the concatenation cannot happen.
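To make the failure mode concrete: for an encoder with d stride-2 pooling stages, each spatial dimension must be divisible by 2^d for the decoder's upsampled tensors to line up with the skip connections. A small sketch (`closest_valid_shape` is a hypothetical helper, not part of mfai):

```python
import math


def closest_valid_shape(shape, depth):
    """Round each spatial dim up to the nearest multiple of 2**depth,
    the divisibility a stride-2 encoder of `depth` levels requires."""
    factor = 2 ** depth
    return tuple(math.ceil(d / factor) * factor for d in shape)


# A 300x300 image fed to a depth-3 UNet: after three 2x poolings,
# 300 -> 150 -> 75 -> 37 (rounded down), and upsampling 37 gives 74 != 75,
# so the skip concatenation fails. The nearest safe shape:
print(closest_valid_shape((300, 300), depth=3))  # (304, 304)
```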
I propose adding utilities to:
Validate whether a given input shape is compatible with a specified model.
If the shape is incompatible, calculate and optionally apply the necessary padding to adjust the shape to the closest compatible size.
Motivation
This contribution will improve the usability of the models when used with various datasets with different shapes.
Proposed solution
Every model extending ModelABC can optionally implement a function validate_input_shape(self, input_shape: Size) -> Tuple[bool, Size] which checks whether the current shape is fine and, if it is not, returns the closest shape that fits the model's architecture. A utility function pad_batch(batch: torch.Tensor, new_shape: torch.Size, pad_value: float = 0) -> torch.Tensor can then be used to adjust your input batch's size. A proof of concept for UNet, HalfUNet and CustomUNet can be found here.
Additional contributions
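The two proposed utilities could look roughly like this. This is a minimal sketch under the assumption of a plain stride-2 encoder of fixed depth; only the two signatures come from the proposal above, and a real validate_input_shape would be implemented per model on ModelABC subclasses.

```python
import math
from typing import Tuple

import torch
import torch.nn.functional as F


def validate_input_shape(input_shape: torch.Size, depth: int = 3) -> Tuple[bool, torch.Size]:
    """Check (H, W) compatibility with a stride-2 encoder of `depth` levels
    and return the closest shape that fits (assumed example logic)."""
    factor = 2 ** depth
    new_shape = torch.Size(math.ceil(d / factor) * factor for d in input_shape)
    return new_shape == input_shape, new_shape


def pad_batch(batch: torch.Tensor, new_shape: torch.Size, pad_value: float = 0) -> torch.Tensor:
    """Pad the last two (H, W) dims of `batch` up to `new_shape`."""
    dh = new_shape[-2] - batch.shape[-2]
    dw = new_shape[-1] - batch.shape[-1]
    # F.pad's pad tuple runs last-dim-first: (left, right, top, bottom)
    return F.pad(batch, (0, dw, 0, dh), value=pad_value)


x = torch.zeros(4, 3, 300, 300)
valid, target = validate_input_shape(x.shape[-2:])
if not valid:
    x = pad_batch(x, target)
print(x.shape)  # torch.Size([4, 3, 304, 304])
```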
These utilities could also be used inside the forward pass of the models: if a runtime error is caused by the input shape, additional information is shown to the user to help them pad their inputs.
I would appreciate any feedback or suggestions on this approach, and I am willing to contribute and to proceed with a pull request once we’ve reviewed the proposal.