Skip to content

Commit

Permalink
Remove key_padding masking, sequence isolation is enough.
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Sep 23, 2023
1 parent f93083e commit 2734bb7
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions timm/models/vision_transformer_packed.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn += attn_mask
if attn_mask is not None:
attn += attn_mask
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
Expand Down Expand Up @@ -292,7 +293,7 @@ def __init__(
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

def forward(self, x, attn_mask: Optional[torch.Tensor]):
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask=attn_mask)))
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x
Expand Down Expand Up @@ -720,20 +721,24 @@ def forward_features(

if attn_mask is None:
attn_mask = seq_ids.unsqueeze(2) == seq_ids.unsqueeze(1)
key_padding_mask = (seq_ids != 0).unsqueeze(1)
attn_mask = attn_mask & key_padding_mask
# NOTE: not applying key padding mask as padding tokens are already isolated to
# themselves via the above mask (padding has seq_id == 0). Doing an additional
# key padding mask results in fully masked rows which causes numerical issues.
# key_padding_mask = (seq_ids != 0).unsqueeze(1)
# attn_mask = attn_mask & key_padding_mask
attn_mask = attn_mask.unsqueeze(1)

if attn_mask.dtype == torch.bool:
dtype = tokens.dtype
min_val = torch.finfo(dtype).min
attn_mask = torch.zeros_like(attn_mask, dtype=dtype).masked_fill_(~attn_mask, min_val)

# if self.grad_checkpointing and not torch.jit.is_scripting():
# tokens = checkpoint_seq(self.blocks, tokens)
# else:
for b in self.blocks:
tokens = b(tokens, attn_mask=attn_mask)
if self.grad_checkpointing and not torch.jit.is_scripting():
tokens = torch.utils.checkpoint.checkpoint(
b, tokens, use_reentrant=False, attn_mask=attn_mask)
else:
tokens = b(tokens, attn_mask=attn_mask)
tokens = self.norm(tokens)

device = tokens.device
Expand All @@ -743,7 +748,7 @@ def forward_features(
seq_lens = seq_lens.reshape(-1)
valid_rows = seq_lens > 0
if self.attn_pool is not None:
unpack_mask = unpack_mask & key_padding_mask
# unpack_mask = unpack_mask & key_padding_mask
unpack_mask = unpack_mask.unsqueeze(1)
unpack_mask = torch.zeros_like(unpack_mask, dtype=tokens.dtype).masked_fill_(
~unpack_mask, torch.finfo(tokens.dtype).min)
Expand All @@ -767,6 +772,7 @@ def forward_head(self, x, pre_logits: bool = False):
if isinstance(x, (list, tuple)):
x = torch.stack([t.mean(dim=0) for t in x], 0)
else:
# x = x.sum(dim=1) / seq_lens.reshape(-1, 1)
x = x.mean(dim=1)
x = self.fc_norm(x)
x = self.head_drop(x)
Expand Down Expand Up @@ -801,6 +807,7 @@ def _cfg(url='', **kwargs):


default_cfgs = generate_default_cfgs({
'navit_medium_patch16_384': _cfg(),
'navit_base_patch32_224': _cfg(),
'navit_base_patch32_384': _cfg(),
'navit_base_patch16_224': _cfg(),
Expand All @@ -821,6 +828,16 @@ def _create_vision_transformer_packed(variant, pretrained=False, **kwargs):
)


@register_model
def navit_medium_patch16_384(pretrained=False, **kwargs) -> VisionTransformerPacked:
model_args = dict(
img_size=384, patch_size=16, embed_dim=512, depth=12, num_heads=8,
fc_norm=False, init_values=1e-5, qkv_bias=False)
model = _create_vision_transformer_packed(
'navit_medium_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model


@register_model
def navit_base_patch32_224(pretrained=False, **kwargs) -> VisionTransformerPacked:
model_args = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
Expand Down

0 comments on commit 2734bb7

Please sign in to comment.