diff --git a/timm/models/davit.py b/timm/models/davit.py
index 39f1a115a3..efcce4f488 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -34,7 +34,14 @@ class ConvPosEnc(nn.Module):
     def __init__(self, dim: int, k: int = 3, act: bool = False):
         super(ConvPosEnc, self).__init__()
 
-        self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
+        self.proj = nn.Conv2d(
+            dim,
+            dim,
+            kernel_size=k,
+            stride=1,
+            padding=k // 2,
+            groups=dim,
+        )
         self.act = nn.GELU() if act else nn.Identity()
 
     def forward(self, x: Tensor):
@@ -84,6 +91,7 @@ def __init__(
             self,
             in_chs,
             out_chs,
+            kernel_size=3,
             norm_layer=LayerNorm2d,
     ):
         super().__init__()
@@ -91,23 +99,57 @@ def __init__(
         self.out_chs = out_chs
 
         self.norm = norm_layer(in_chs)
+        self.even_k = kernel_size % 2 == 0
         self.conv = nn.Conv2d(
             in_chs,
             out_chs,
-            kernel_size=2,
+            kernel_size=kernel_size,
             stride=2,
-            padding=0,
+            padding=0 if self.even_k else kernel_size // 2,
         )
 
     def forward(self, x: Tensor):
         B, C, H, W = x.shape
         x = self.norm(x)
-        x = F.pad(x, (0, (2 - W % 2) % 2))
-        x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2))
+        if self.even_k:
+            k_h, k_w = self.conv.kernel_size
+            x = F.pad(x, (0, (k_w - W % k_w) % k_w))
+            x = F.pad(x, (0, 0, 0, (k_h - H % k_h) % k_h))
         x = self.conv(x)
         return x
 
 
+class ChannelAttentionV2(nn.Module):
+
+    def __init__(self, dim, num_heads=8, qkv_bias=True, dynamic_scale=True):
+        super().__init__()
+        self.groups = num_heads
+        self.head_dim = dim // num_heads
+        self.dynamic_scale = dynamic_scale
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, x):
+        B, N, C = x.shape
+
+        qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+
+        if self.dynamic_scale:
+            q = q * float(N) ** -0.5
+        else:
+            q = q * self.head_dim ** -0.5
+        attn = q.transpose(-1, -2) @ k
+        attn = attn.softmax(dim=-1)
+        x = (attn @ v.transpose(-1, -2)).transpose(-1, -2)
+
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        return x
+
+
 class ChannelAttention(nn.Module):
 
     def __init__(self, dim, num_heads=8, qkv_bias=False):
@@ -147,13 +189,19 @@ def __init__(
             norm_layer=nn.LayerNorm,
             ffn=True,
             cpe_act=False,
+            v2=False,
     ):
         super().__init__()
 
         self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         self.ffn = ffn
         self.norm1 = norm_layer(dim)
-        self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
+        attn_layer = ChannelAttentionV2 if v2 else ChannelAttention
+        self.attn = attn_layer(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+        )
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
         self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
@@ -372,13 +420,16 @@ def __init__(
             attn_types=('spatial', 'channel'),
             num_heads=3,
             window_size=7,
-            mlp_ratio=4,
+            mlp_ratio=4.,
             qkv_bias=True,
             drop_path_rates=(0, 0),
             norm_layer=LayerNorm2d,
             norm_layer_cl=nn.LayerNorm,
             ffn=True,
-            cpe_act=False
+            cpe_act=False,
+            down_kernel_size=2,
+            named_blocks=False,
+            channel_attn_v2=False,
     ):
         super().__init__()
 
@@ -386,7 +437,7 @@ def __init__(
 
         # downsample embedding layer at the beginning of each stage
         if downsample:
-            self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer)
+            self.downsample = Downsample(in_chs, out_chs, kernel_size=down_kernel_size, norm_layer=norm_layer)
         else:
             self.downsample = nn.Identity()
 
@@ -399,10 +450,11 @@ def __init__(
         '''
         stage_blocks = []
         for block_idx in range(depth):
+            from collections import OrderedDict
             dual_attention_block = []
             for attn_idx, attn_type in enumerate(attn_types):
                 if attn_type == 'spatial':
-                    dual_attention_block.append(SpatialBlock(
+                    dual_attention_block.append(('spatial_block', SpatialBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
@@ -412,9 +464,9 @@ def __init__(
                         ffn=ffn,
                         cpe_act=cpe_act,
                         window_size=window_size,
-                    ))
+                    )))
                 elif attn_type == 'channel':
-                    dual_attention_block.append(ChannelBlock(
+                    dual_attention_block.append(('channel_block', ChannelBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
@@ -422,9 +474,13 @@ def __init__(
                         qkv_bias=qkv_bias,
                         drop_path=drop_path_rates[block_idx],
                         norm_layer=norm_layer_cl,
                         ffn=ffn,
-                        cpe_act=cpe_act
-                    ))
-            stage_blocks.append(nn.Sequential(*dual_attention_block))
+                        cpe_act=cpe_act,
+                        v2=channel_attn_v2,
+                    )))
+            if named_blocks:
+                stage_blocks.append(nn.Sequential(OrderedDict(dual_attention_block)))
+            else:
+                stage_blocks.append(nn.Sequential(*[b[1] for b in dual_attention_block]))
         self.blocks = nn.Sequential(*stage_blocks)
 
     @torch.jit.ignore
@@ -473,6 +529,9 @@ def __init__(
             attn_types=('spatial', 'channel'),
             ffn=True,
             cpe_act=False,
+            down_kernel_size=2,
+            channel_attn_v2=False,
+            named_blocks=False,
             drop_rate=0.,
             drop_path_rate=0.,
             num_classes=1000,
@@ -512,6 +571,9 @@ def __init__(
                 norm_layer_cl=norm_layer_cl,
                 ffn=ffn,
                 cpe_act=cpe_act,
+                down_kernel_size=down_kernel_size,
+                channel_attn_v2=channel_attn_v2,
+                named_blocks=named_blocks,
             )
             in_chs = out_chs
             stages.append(stage)
@@ -589,6 +651,34 @@ def forward(self, x):
         return x
 
 
+def _convert_florence2(state_dict, model, prefix='vision_tower.'):
+    import re
+    out_dict = {}
+
+    for k, v in state_dict.items():
+        if k.startswith(prefix):
+            k = k.replace(prefix, '')
+        else:
+            continue
+        k = re.sub(r'convs.([0-9]+)', r'stages.\1.downsample', k)
+        k = re.sub(r'blocks.([0-9]+)', r'stages.\1.blocks', k)
+        k = k.replace('downsample.proj', 'downsample.conv')
+        k = k.replace('stages.0.downsample', 'stem')
+        #k = k.replace('head.', 'head.fc.')
+        #k = k.replace('norms.', 'head.norm.')
+        k = k.replace('window_attn.norm.', 'norm1.')
+        k = k.replace('window_attn.fn.', 'attn.')
+        k = k.replace('channel_attn.norm.', 'norm1.')
+        k = k.replace('channel_attn.fn.', 'attn.')
+        k = k.replace('ffn.norm.', 'norm2.')
+        k = k.replace('ffn.fn.net.', 'mlp.')
+        k = k.replace('conv1.fn.dw', 'cpe1.proj')
+        k = k.replace('conv2.fn.dw', 'cpe2.proj')
+        out_dict[k] = v
+
+    return out_dict
+
+
 def checkpoint_filter_fn(state_dict, model):
     """ Remap MSFT checkpoints -> timm """
     if 'head.fc.weight' in state_dict:
@@ -597,6 +687,9 @@ def checkpoint_filter_fn(state_dict, model):
     if 'state_dict' in state_dict:
         state_dict = state_dict['state_dict']
 
+    if 'vision_tower.convs.0.proj.weight' in state_dict:
+        return _convert_florence2(state_dict, model)
+
     import re
     out_dict = {}
     for k, v in state_dict.items():
@@ -615,13 +708,17 @@ def _create_davit(variant, pretrained=False, **kwargs):
     default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
     out_indices = kwargs.pop('out_indices', default_out_indices)
-
+    strict = True
+    if variant.endswith('_fl'):
+        # FIXME cleaner approach to missing head norm?
+        strict = False
     model = build_model_with_cfg(
         DaVit,
         variant,
         pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        pretrained_strict=strict,
         **kwargs)
     return model
 
 
@@ -650,6 +747,12 @@ def _cfg(url='', **kwargs):
     'davit_large': _cfg(),
     'davit_huge': _cfg(),
     'davit_giant': _cfg(),
+    'davit_base_fl.msft_florence2': _cfg(
+        hf_hub_id='microsoft/Florence-2-base',
+        num_classes=0, input_size=(3, 768, 768)),
+    'davit_huge_fl.msft_florence2': _cfg(
+        hf_hub_id='microsoft/Florence-2-large',
+        num_classes=0, input_size=(3, 768, 768)),
 })
 
 
@@ -687,3 +790,23 @@ def davit_huge(pretrained=False, **kwargs) -> DaVit:
 def davit_giant(pretrained=False, **kwargs) -> DaVit:
     model_args = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96))
     return _create_davit('davit_giant', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def davit_base_fl(pretrained=False, **kwargs) -> DaVit:
+    model_args = dict(
+        depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32),
+        window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
+    )
+    return _create_davit('davit_base_fl', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def davit_huge_fl(pretrained=False, **kwargs) -> DaVit:
+    # NOTE: huge image tower used in 'large' Florence2 model
+    model_args = dict(
+        depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64),
+        window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
+    )
+    return _create_davit('davit_huge_fl', pretrained=pretrained, **dict(model_args, **kwargs))
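
Usage sketch (not part of the patch): assuming the diff above is applied, the new Florence-2 image towers behave like any other timm model. `timm.create_model`, `num_classes=0`, and `features_only=True` are standard timm API; the 768x768 input size comes from the `msft_florence2` pretrained cfgs added above.

```python
import timm
import torch

# Build the Florence-2 base image tower registered by this patch.
# num_classes=0 mirrors the pretrained cfg: a pure feature extractor, no classifier head.
model = timm.create_model('davit_base_fl', pretrained=False, num_classes=0).eval()

x = torch.randn(1, 3, 768, 768)  # input size from the msft_florence2 cfgs
with torch.no_grad():
    feats = model(x)  # pooled features from the final stage

# Multi-scale feature maps via timm's generic features_only wrapper,
# which uses the feature_cfg set up in _create_davit.
fpn = timm.create_model('davit_base_fl', pretrained=False, features_only=True)
with torch.no_grad():
    fmaps = fpn(x)  # one feature map per stage
```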