diff --git a/timm/models/vision_transformer_packed.py b/timm/models/vision_transformer_packed.py
index c2f71f94ba..efda484f90 100644
--- a/timm/models/vision_transformer_packed.py
+++ b/timm/models/vision_transformer_packed.py
@@ -230,7 +230,8 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
-            attn += attn_mask
+            if attn_mask is not None:
+                attn += attn_mask
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
             x = attn @ v
@@ -292,7 +293,7 @@ def __init__(
         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-    def forward(self, x, attn_mask: Optional[torch.Tensor]):
+    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask=attn_mask)))
         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return x
@@ -720,8 +721,11 @@ def forward_features(
 
         if attn_mask is None:
             attn_mask = seq_ids.unsqueeze(2) == seq_ids.unsqueeze(1)
-            key_padding_mask = (seq_ids != 0).unsqueeze(1)
-            attn_mask = attn_mask & key_padding_mask
+            # NOTE: not applying key padding mask as padding tokens are already isolated to
+            # themselves via the above mask (padding has seq_id == 0). Doing an additional
+            # key padding mask results in fully masked rows which causes numerical issues.
+            # key_padding_mask = (seq_ids != 0).unsqueeze(1)
+            # attn_mask = attn_mask & key_padding_mask
             attn_mask = attn_mask.unsqueeze(1)
 
         if attn_mask.dtype == torch.bool:
@@ -729,11 +733,12 @@ def forward_features(
             min_val = torch.finfo(dtype).min
             attn_mask = torch.zeros_like(attn_mask, dtype=dtype).masked_fill_(~attn_mask, min_val)
 
-        # if self.grad_checkpointing and not torch.jit.is_scripting():
-        #     tokens = checkpoint_seq(self.blocks, tokens)
-        # else:
         for b in self.blocks:
-            tokens = b(tokens, attn_mask=attn_mask)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                tokens = torch.utils.checkpoint.checkpoint(
+                    b, tokens, use_reentrant=False, attn_mask=attn_mask)
+            else:
+                tokens = b(tokens, attn_mask=attn_mask)
         tokens = self.norm(tokens)
 
         device = tokens.device
@@ -743,7 +748,7 @@ def forward_features(
         seq_lens = seq_lens.reshape(-1)
         valid_rows = seq_lens > 0
         if self.attn_pool is not None:
-            unpack_mask = unpack_mask & key_padding_mask
+            # unpack_mask = unpack_mask & key_padding_mask
             unpack_mask = unpack_mask.unsqueeze(1)
             unpack_mask = torch.zeros_like(unpack_mask, dtype=tokens.dtype).masked_fill_(
                 ~unpack_mask, torch.finfo(tokens.dtype).min)
@@ -767,6 +772,7 @@ def forward_head(self, x, pre_logits: bool = False):
         if isinstance(x, (list, tuple)):
            x = torch.stack([t.mean(dim=0) for t in x], 0)
         else:
+            # x = x.sum(dim=1) / seq_lens.reshape(-1, 1)
             x = x.mean(dim=1)
         x = self.fc_norm(x)
         x = self.head_drop(x)
@@ -801,6 +807,7 @@ def _cfg(url='', **kwargs):
 
 
 default_cfgs = generate_default_cfgs({
+    'navit_medium_patch16_384': _cfg(),
     'navit_base_patch32_224': _cfg(),
     'navit_base_patch32_384': _cfg(),
     'navit_base_patch16_224': _cfg(),
@@ -821,6 +828,16 @@ def _create_vision_transformer_packed(variant, pretrained=False, **kwargs):
     )
 
 
+@register_model
+def navit_medium_patch16_384(pretrained=False, **kwargs) -> VisionTransformerPacked:
+    model_args = dict(
+        img_size=384, patch_size=16, embed_dim=512, depth=12, num_heads=8,
+        fc_norm=False, init_values=1e-5, qkv_bias=False)
+    model = _create_vision_transformer_packed(
+        'navit_medium_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def navit_base_patch32_224(pretrained=False, **kwargs) -> VisionTransformerPacked:
     model_args = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
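For reference, a minimal standalone sketch (an assumption for illustration, not code from this PR; it uses toy seq_ids, zero attention scores, and float('-inf') masking in place of the torch.finfo(dtype).min used above) of why the extra key padding mask is dropped: padding tokens share seq_id == 0, so the seq_id-equality mask already confines them to attending only to each other, whereas AND-ing in the key padding mask leaves their query rows fully masked, and softmax over a fully masked row yields NaN.

    import torch

    # Hypothetical packed row: two real sequences (seq_id 1 and 2) plus two padding tokens (seq_id 0).
    seq_ids = torch.tensor([[1, 1, 2, 0, 0]])
    same_seq = seq_ids.unsqueeze(2) == seq_ids.unsqueeze(1)   # (1, 5, 5) bool, True where seq_ids match
    key_padding = (seq_ids != 0).unsqueeze(1)                 # (1, 1, 5) bool, False at padding keys
    scores = torch.zeros(1, 5, 5)                             # stand-in for q @ k.transpose(-2, -1)

    # seq_id mask alone: padding rows still attend to each other, so softmax stays finite
    ok = scores.masked_fill(~same_seq, float('-inf')).softmax(dim=-1)
    # AND-ing in the key padding mask fully masks the padding rows, so softmax returns NaN
    bad = scores.masked_fill(~(same_seq & key_padding), float('-inf')).softmax(dim=-1)

    print(ok.isnan().any(), bad.isnan().any())  # tensor(False) tensor(True)

Separately, once this module is imported so @register_model runs, the new variant should be creatable the usual way, e.g. timm.create_model('navit_medium_patch16_384', pretrained=False).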