Fix #2212 map florence2 image tower to davit with a few changes
rwightman committed Jun 21, 2024
1 parent b28945f commit fb58a73
Showing 1 changed file with 139 additions and 16 deletions.
155 changes: 139 additions & 16 deletions timm/models/davit.py
@@ -34,7 +34,14 @@ class ConvPosEnc(nn.Module):
def __init__(self, dim: int, k: int = 3, act: bool = False):
super(ConvPosEnc, self).__init__()

-        self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
+        self.proj = nn.Conv2d(
+            dim,
+            dim,
+            kernel_size=k,
+            stride=1,
+            padding=k // 2,
+            groups=dim,
+        )
self.act = nn.GELU() if act else nn.Identity()

def forward(self, x: Tensor):
@@ -84,30 +91,65 @@ def __init__(
self,
in_chs,
out_chs,
kernel_size=3,
norm_layer=LayerNorm2d,
):
super().__init__()
self.in_chs = in_chs
self.out_chs = out_chs

self.norm = norm_layer(in_chs)
self.even_k = kernel_size % 2 == 0
self.conv = nn.Conv2d(
in_chs,
out_chs,
-            kernel_size=2,
+            kernel_size=kernel_size,
stride=2,
-            padding=0,
+            padding=0 if self.even_k else kernel_size // 2,
)

def forward(self, x: Tensor):
B, C, H, W = x.shape
x = self.norm(x)
-        x = F.pad(x, (0, (2 - W % 2) % 2))
-        x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2))
+        if self.even_k:
+            k_h, k_w = self.conv.kernel_size
+            x = F.pad(x, (0, (k_w - W % k_w) % k_w))
+            x = F.pad(x, (0, 0, 0, (k_h - H % k_h) % k_h))
x = self.conv(x)
return x
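
For intuition, a minimal sketch (illustrative, not part of the commit) of the new padding logic: both the kernel_size=2 used by the original DaViT checkpoints and the kernel_size=3 used by the Florence2 tower halve the spatial dims, they just pad differently (explicit right/bottom padding in forward for even kernels, conv padding for odd ones).

import torch
x = torch.randn(1, 64, 7, 7)
ds_even = Downsample(64, 128, kernel_size=2)  # even_k=True: pads 7 -> 8 in forward, conv padding=0
ds_odd = Downsample(64, 128, kernel_size=3)   # even_k=False: no explicit pad, conv padding=1
print(ds_even(x).shape, ds_odd(x).shape)      # both torch.Size([1, 128, 4, 4])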


class ChannelAttentionV2(nn.Module):

def __init__(self, dim, num_heads=8, qkv_bias=True, dynamic_scale=True):
super().__init__()
self.groups = num_heads
self.head_dim = dim // num_heads
self.dynamic_scale = dynamic_scale

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)

def forward(self, x):
B, N, C = x.shape

qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)

if self.dynamic_scale:
q = q * float(N) ** -0.5
else:
q = q * self.head_dim ** -0.5
attn = q.transpose(-1, -2) @ k
attn = attn.softmax(dim=-1)
x = (attn @ v.transpose(-1, -2)).transpose(-1, -2)

x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
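
A quick shape check (a sketch, assuming the class above): unlike spatial attention, the attention matrix here is head_dim x head_dim per group rather than N x N, and with dynamic_scale=True the scaling uses the token count N instead of head_dim.

import torch
attn = ChannelAttentionV2(dim=96, num_heads=3, dynamic_scale=True)
y = attn(torch.randn(2, 49, 96))  # (B, N, C) in
assert y.shape == (2, 49, 96)     # (B, N, C) out, shape preserved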



class ChannelAttention(nn.Module):

def __init__(self, dim, num_heads=8, qkv_bias=False):
@@ -147,13 +189,19 @@ def __init__(
norm_layer=nn.LayerNorm,
ffn=True,
cpe_act=False,
v2=False,
):
super().__init__()

self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
self.ffn = ffn
self.norm1 = norm_layer(dim)
-        self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
+        attn_layer = ChannelAttentionV2 if v2 else ChannelAttention
+        self.attn = attn_layer(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+        )
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)

@@ -372,21 +420,24 @@ def __init__(
attn_types=('spatial', 'channel'),
num_heads=3,
window_size=7,
-            mlp_ratio=4,
+            mlp_ratio=4.,
qkv_bias=True,
drop_path_rates=(0, 0),
norm_layer=LayerNorm2d,
norm_layer_cl=nn.LayerNorm,
ffn=True,
-            cpe_act=False
+            cpe_act=False,
+            down_kernel_size=2,
+            named_blocks=False,
+            channel_attn_v2=False,
):
super().__init__()

self.grad_checkpointing = False

# downsample embedding layer at the beginning of each stage
if downsample:
-            self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer)
+            self.downsample = Downsample(in_chs, out_chs, kernel_size=down_kernel_size, norm_layer=norm_layer)
else:
self.downsample = nn.Identity()

@@ -399,10 +450,11 @@ def __init__(
'''
stage_blocks = []
for block_idx in range(depth):
from collections import OrderedDict
dual_attention_block = []
for attn_idx, attn_type in enumerate(attn_types):
if attn_type == 'spatial':
-                    dual_attention_block.append(SpatialBlock(
+                    dual_attention_block.append(('spatial_block', SpatialBlock(
dim=out_chs,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
@@ -412,19 +464,23 @@ def __init__(
ffn=ffn,
cpe_act=cpe_act,
window_size=window_size,
-                    ))
+                    )))
elif attn_type == 'channel':
-                    dual_attention_block.append(ChannelBlock(
+                    dual_attention_block.append(('channel_block', ChannelBlock(
dim=out_chs,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=drop_path_rates[block_idx],
norm_layer=norm_layer_cl,
ffn=ffn,
-                        cpe_act=cpe_act
-                    ))
-            stage_blocks.append(nn.Sequential(*dual_attention_block))
+                        cpe_act=cpe_act,
+                        v2=channel_attn_v2,
+                    )))
+            if named_blocks:
+                stage_blocks.append(nn.Sequential(OrderedDict(dual_attention_block)))
+            else:
+                stage_blocks.append(nn.Sequential(*[b[1] for b in dual_attention_block]))
self.blocks = nn.Sequential(*stage_blocks)

@torch.jit.ignore
@@ -473,6 +529,9 @@ def __init__(
attn_types=('spatial', 'channel'),
ffn=True,
cpe_act=False,
down_kernel_size=2,
channel_attn_v2=False,
named_blocks=False,
drop_rate=0.,
drop_path_rate=0.,
num_classes=1000,
@@ -512,6 +571,9 @@ def __init__(
norm_layer_cl=norm_layer_cl,
ffn=ffn,
cpe_act=cpe_act,
down_kernel_size=down_kernel_size,
channel_attn_v2=channel_attn_v2,
named_blocks=named_blocks,
)
in_chs = out_chs
stages.append(stage)
@@ -589,6 +651,34 @@ def forward(self, x):
return x


def _convert_florence2(state_dict, model, prefix='vision_tower.'):
import re
out_dict = {}

for k, v in state_dict.items():
if k.startswith(prefix):
k = k.replace(prefix, '')
else:
continue
k = re.sub(r'convs.([0-9]+)', r'stages.\1.downsample', k)
k = re.sub(r'blocks.([0-9]+)', r'stages.\1.blocks', k)
k = k.replace('downsample.proj', 'downsample.conv')
k = k.replace('stages.0.downsample', 'stem')
#k = k.replace('head.', 'head.fc.')
#k = k.replace('norms.', 'head.norm.')
k = k.replace('window_attn.norm.', 'norm1.')
k = k.replace('window_attn.fn.', 'attn.')
k = k.replace('channel_attn.norm.', 'norm1.')
k = k.replace('channel_attn.fn.', 'attn.')
k = k.replace('ffn.norm.', 'norm2.')
k = k.replace('ffn.fn.net.', 'mlp.')
k = k.replace('conv1.fn.dw', 'cpe1.proj')
k = k.replace('conv2.fn.dw', 'cpe2.proj')
out_dict[k] = v

return out_dict
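
A worked example of the remapping (a sketch using the stem key that checkpoint_filter_fn below keys off; the model argument is unused by the converter):

import torch
sd = {'vision_tower.convs.0.proj.weight': torch.zeros(1)}  # dummy value, real shapes come from the checkpoint
assert list(_convert_florence2(sd, model=None)) == ['stem.conv.weight']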


def checkpoint_filter_fn(state_dict, model):
""" Remap MSFT checkpoints -> timm """
if 'head.fc.weight' in state_dict:
@@ -597,6 +687,9 @@ def checkpoint_filter_fn(state_dict, model):
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']

if 'vision_tower.convs.0.proj.weight' in state_dict:
return _convert_florence2(state_dict, model)

import re
out_dict = {}
for k, v in state_dict.items():
@@ -615,13 +708,17 @@ def checkpoint_filter_fn(state_dict, model):
def _create_davit(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)

strict = True
if variant.endswith('_fl'):
# FIXME cleaner approach to missing head norm?
strict = False
model = build_model_with_cfg(
DaVit,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
pretrained_strict=strict,
**kwargs)

return model
@@ -650,6 +747,12 @@ def _cfg(url='', **kwargs):
'davit_large': _cfg(),
'davit_huge': _cfg(),
'davit_giant': _cfg(),
'davit_base_fl.msft_florence2': _cfg(
hf_hub_id='microsoft/Florence-2-base',
num_classes=0, input_size=(3, 768, 768)),
'davit_huge_fl.msft_florence2': _cfg(
hf_hub_id='microsoft/Florence-2-large',
num_classes=0, input_size=(3, 768, 768)),
})


@@ -687,3 +790,23 @@ def davit_huge(pretrained=False, **kwargs) -> DaVit:
def davit_giant(pretrained=False, **kwargs) -> DaVit:
model_args = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96))
return _create_davit('davit_giant', pretrained=pretrained, **dict(model_args, **kwargs))



@register_model
def davit_base_fl(pretrained=False, **kwargs) -> DaVit:
model_args = dict(
depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32),
window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
)
return _create_davit('davit_base_fl', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def davit_huge_fl(pretrained=False, **kwargs) -> DaVit:
# NOTE: huge image tower used in 'large' Florence2 model
model_args = dict(
depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64),
window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
)
return _create_davit('davit_huge_fl', pretrained=pretrained, **dict(model_args, **kwargs))
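
Usage sketch (hedged: assumes a timm install that includes this commit; pretrained=True additionally assumes the Hugging Face Hub weights resolve through the filter above):

import timm
import torch
model = timm.create_model('davit_base_fl', pretrained=False)  # pretrained=True pulls the Florence-2 base tower
feats = model.forward_features(torch.randn(1, 3, 768, 768))
print(feats.shape)  # expected (1, 1024, 24, 24): 32x total reduction, embed_dims[-1] == 1024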
