
Fix positional embedding resampling for non-square inputs in ViT #2317

Merged 3 commits into huggingface:main on Nov 7, 2024

Conversation

@wojtke (Contributor) commented Nov 1, 2024

I have been doing some experiments with ViTs for object detection. When I call set_input_size with a larger, non-square size like [800, 1008] and then expect the dynamic_img_size option to do its job when passing a different-sized image as input, it breaks.
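Roughly, a minimal way to reproduce it (the exact model name here is just illustrative; any patch16 ViT with dynamic_img_size enabled should behave the same):

import torch
import timm

# illustrative model choice
model = timm.create_model('vit_base_patch16_224', pretrained=False, dynamic_img_size=True)
model.set_input_size(img_size=(800, 1008))   # non-square grid of 50 x 63 patches

x = torch.randn(1, 3, 640, 800)              # a different (still non-square) input size
y = model.forward_features(x)                # breaks: the old grid is assumed to be square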

When using the dynamic_img_size option, we dynamically resample (interpolate) positional embeddings. The resample_abs_pos_embed function by default assumes the previous grid is square, which I guess is true in most cases. But why assume when we can always pass the old size explicitly?

def resample_abs_pos_embed(
        posemb: torch.Tensor,
        new_size: List[int],
        old_size: Optional[List[int]] = None,
        num_prefix_tokens: int = 1,
        interpolation: str = 'bicubic',
        antialias: bool = True,
        verbose: bool = False,
):
    # sort out sizes, assume square if old size not provided
    num_pos_tokens = posemb.shape[1]
    num_new_tokens = new_size[0] * new_size[1] + num_prefix_tokens
    if num_new_tokens == num_pos_tokens and new_size[0] == new_size[1]:
        return posemb

    if old_size is None:
        hw = int(math.sqrt(num_pos_tokens - num_prefix_tokens))
        old_size = hw, hw

    if num_prefix_tokens:
        posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
    else:
        posemb_prefix, posemb = None, posemb

    # do the interpolation
    embed_dim = posemb.shape[-1]
    orig_dtype = posemb.dtype
    posemb = posemb.float()  # interpolate needs float32
    posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
    posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
    posemb = posemb.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
    posemb = posemb.to(orig_dtype)

    # add back extra (class, etc) prefix tokens
    if posemb_prefix is not None:
        posemb = torch.cat([posemb_prefix, posemb], dim=1)

    if not torch.jit.is_scripting() and verbose:
        _logger.info(f'Resized position embedding: {old_size} to {new_size}.')

    return posemb
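With the non-square example above, passing the old grid explicitly would look something like this (a sketch only; it assumes the function is imported from timm.layers and a 16px patch, so an 800x1008 input gives a 50x63 grid):

import torch
from timm.layers import resample_abs_pos_embed

# pos embed sized for an 800x1008 input at patch 16 -> 50x63 grid, plus one class token
posemb = torch.randn(1, 50 * 63 + 1, 768)

resampled = resample_abs_pos_embed(
    posemb,
    new_size=[40, 50],   # target grid for a 640x800 input
    old_size=[50, 63],   # without this, a square 56x56 grid is assumed and the reshape fails
)
print(resampled.shape)   # torch.Size([1, 2001, 768])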


@rwightman (Collaborator)

@wojtke there's also a matching impl in deit.py that needs updating:

if self.dynamic_img_size:
    B, H, W, C = x.shape
    pos_embed = resample_abs_pos_embed(
        self.pos_embed,
        (H, W),
        num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
    )
    x = x.view(B, -1, C)
else:

I'm not 100% sure why I didn't do this when I added the set_input_size() fn; I was thinking of adding it, but now I'm fuzzy on why I didn't. There may have been a concern about accessing the grid_size attribute in forward. I need to retrace what my thought process was there, or whether I just forgot before merging back then...
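In any case, passing the old grid explicitly in that forward path would look roughly like this (a sketch only; it assumes patch_embed.grid_size reflects whatever set_input_size last configured):

if self.dynamic_img_size:
    B, H, W, C = x.shape
    pos_embed = resample_abs_pos_embed(
        self.pos_embed,
        new_size=(H, W),
        old_size=self.patch_embed.grid_size,  # assumed attribute; avoids the square-grid fallback
        num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
    )
    x = x.view(B, -1, C)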

@wojtke (Contributor, author) commented Nov 1, 2024

I see that the subject came up in #2190, but I am not sure what to make of it. I think you fixed the main issue there by adding set_input_size, but the one I mention here slipped through unnoticed.

@rwightman (Collaborator)

@wojtke thanks for the updates. I recall intending to fix the dynamic image size when I worked on set_input_size in #2190 ... for some reason I didn't. I'll pick through the models and see if I can recall any specific concerns, run through some scenarios, and then merge this if I don't bump into anything problematic.

@rwightman (Collaborator)

@wojtke there seem to be no issues to worry about, torchscript is fine, etc ... so merging, thanks

@rwightman merged commit eb94efb into huggingface:main on Nov 7, 2024. 22 checks passed.