Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #2212 map florence2 image tower to davit with a few changes #2213

Merged
merged 3 commits into from
Jun 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 143 additions & 18 deletions timm/models/davit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ class ConvPosEnc(nn.Module):
def __init__(self, dim: int, k: int = 3, act: bool = False):
super(ConvPosEnc, self).__init__()

self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
self.proj = nn.Conv2d(
dim,
dim,
kernel_size=k,
stride=1,
padding=k // 2,
groups=dim,
)
self.act = nn.GELU() if act else nn.Identity()

def forward(self, x: Tensor):
Expand Down Expand Up @@ -72,8 +79,9 @@ def __init__(

def forward(self, x: Tensor):
B, C, H, W = x.shape
x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1]))
x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0]))
pad_r = (self.stride[1] - W % self.stride[1]) % self.stride[1]
pad_b = (self.stride[0] - H % self.stride[0]) % self.stride[0]
x = F.pad(x, (0, pad_r, 0, pad_b))
x = self.conv(x)
x = self.norm(x)
return x
Expand All @@ -84,30 +92,66 @@ def __init__(
self,
in_chs,
out_chs,
kernel_size=3,
norm_layer=LayerNorm2d,
):
super().__init__()
self.in_chs = in_chs
self.out_chs = out_chs

self.norm = norm_layer(in_chs)
self.even_k = kernel_size % 2 == 0
self.conv = nn.Conv2d(
in_chs,
out_chs,
kernel_size=2,
kernel_size=kernel_size,
stride=2,
padding=0,
padding=0 if self.even_k else kernel_size // 2,
)

def forward(self, x: Tensor):
B, C, H, W = x.shape
x = self.norm(x)
x = F.pad(x, (0, (2 - W % 2) % 2))
x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2))
if self.even_k:
k_h, k_w = self.conv.kernel_size
pad_r = (k_w - W % k_w) % k_w
pad_b = (k_h - H % k_h) % k_h
x = F.pad(x, (0, pad_r , 0, pad_b))
x = self.conv(x)
return x


class ChannelAttentionV2(nn.Module):

def __init__(self, dim, num_heads=8, qkv_bias=True, dynamic_scale=True):
super().__init__()
self.groups = num_heads
self.head_dim = dim // num_heads
self.dynamic_scale = dynamic_scale

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)

def forward(self, x):
B, N, C = x.shape

qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)

if self.dynamic_scale:
q = q * N ** -0.5
else:
q = q * self.head_dim ** -0.5
attn = q.transpose(-1, -2) @ k
attn = attn.softmax(dim=-1)
x = (attn @ v.transpose(-1, -2)).transpose(-1, -2)

x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x



class ChannelAttention(nn.Module):

def __init__(self, dim, num_heads=8, qkv_bias=False):
Expand Down Expand Up @@ -147,13 +191,19 @@ def __init__(
norm_layer=nn.LayerNorm,
ffn=True,
cpe_act=False,
v2=False,
):
super().__init__()

self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
self.ffn = ffn
self.norm1 = norm_layer(dim)
self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
attn_layer = ChannelAttentionV2 if v2 else ChannelAttention
self.attn = attn_layer(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)

Expand Down Expand Up @@ -372,21 +422,24 @@ def __init__(
attn_types=('spatial', 'channel'),
num_heads=3,
window_size=7,
mlp_ratio=4,
mlp_ratio=4.,
qkv_bias=True,
drop_path_rates=(0, 0),
norm_layer=LayerNorm2d,
norm_layer_cl=nn.LayerNorm,
ffn=True,
cpe_act=False
cpe_act=False,
down_kernel_size=2,
named_blocks=False,
channel_attn_v2=False,
):
super().__init__()

self.grad_checkpointing = False

# downsample embedding layer at the beginning of each stage
if downsample:
self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer)
self.downsample = Downsample(in_chs, out_chs, kernel_size=down_kernel_size, norm_layer=norm_layer)
else:
self.downsample = nn.Identity()

Expand All @@ -399,10 +452,11 @@ def __init__(
'''
stage_blocks = []
for block_idx in range(depth):
from collections import OrderedDict
dual_attention_block = []
for attn_idx, attn_type in enumerate(attn_types):
if attn_type == 'spatial':
dual_attention_block.append(SpatialBlock(
dual_attention_block.append(('spatial_block', SpatialBlock(
dim=out_chs,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
Expand All @@ -412,19 +466,23 @@ def __init__(
ffn=ffn,
cpe_act=cpe_act,
window_size=window_size,
))
)))
elif attn_type == 'channel':
dual_attention_block.append(ChannelBlock(
dual_attention_block.append(('channel_block', ChannelBlock(
dim=out_chs,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=drop_path_rates[block_idx],
norm_layer=norm_layer_cl,
ffn=ffn,
cpe_act=cpe_act
))
stage_blocks.append(nn.Sequential(*dual_attention_block))
cpe_act=cpe_act,
v2=channel_attn_v2,
)))
if named_blocks:
stage_blocks.append(nn.Sequential(OrderedDict(dual_attention_block)))
else:
stage_blocks.append(nn.Sequential(*[b[1] for b in dual_attention_block]))
self.blocks = nn.Sequential(*stage_blocks)

@torch.jit.ignore
Expand Down Expand Up @@ -473,6 +531,9 @@ def __init__(
attn_types=('spatial', 'channel'),
ffn=True,
cpe_act=False,
down_kernel_size=2,
channel_attn_v2=False,
named_blocks=False,
drop_rate=0.,
drop_path_rate=0.,
num_classes=1000,
Expand Down Expand Up @@ -512,6 +573,9 @@ def __init__(
norm_layer_cl=norm_layer_cl,
ffn=ffn,
cpe_act=cpe_act,
down_kernel_size=down_kernel_size,
channel_attn_v2=channel_attn_v2,
named_blocks=named_blocks,
)
in_chs = out_chs
stages.append(stage)
Expand Down Expand Up @@ -589,6 +653,34 @@ def forward(self, x):
return x


def _convert_florence2(state_dict, model, prefix='vision_tower.'):
import re
out_dict = {}

for k, v in state_dict.items():
if k.startswith(prefix):
k = k.replace(prefix, '')
else:
continue
k = re.sub(r'convs.([0-9]+)', r'stages.\1.downsample', k)
k = re.sub(r'blocks.([0-9]+)', r'stages.\1.blocks', k)
k = k.replace('downsample.proj', 'downsample.conv')
k = k.replace('stages.0.downsample', 'stem')
#k = k.replace('head.', 'head.fc.')
#k = k.replace('norms.', 'head.norm.')
k = k.replace('window_attn.norm.', 'norm1.')
k = k.replace('window_attn.fn.', 'attn.')
k = k.replace('channel_attn.norm.', 'norm1.')
k = k.replace('channel_attn.fn.', 'attn.')
k = k.replace('ffn.norm.', 'norm2.')
k = k.replace('ffn.fn.net.', 'mlp.')
k = k.replace('conv1.fn.dw', 'cpe1.proj')
k = k.replace('conv2.fn.dw', 'cpe2.proj')
out_dict[k] = v

return out_dict


def checkpoint_filter_fn(state_dict, model):
""" Remap MSFT checkpoints -> timm """
if 'head.fc.weight' in state_dict:
Expand All @@ -597,6 +689,9 @@ def checkpoint_filter_fn(state_dict, model):
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']

if 'vision_tower.convs.0.proj.weight' in state_dict:
return _convert_florence2(state_dict, model)

import re
out_dict = {}
for k, v in state_dict.items():
Expand All @@ -615,13 +710,17 @@ def checkpoint_filter_fn(state_dict, model):
def _create_davit(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)

strict = True
if variant.endswith('_fl'):
# FIXME cleaner approach to missing head norm?
strict = False
model = build_model_with_cfg(
DaVit,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
pretrained_strict=strict,
**kwargs)

return model
Expand Down Expand Up @@ -650,6 +749,12 @@ def _cfg(url='', **kwargs):
'davit_large': _cfg(),
'davit_huge': _cfg(),
'davit_giant': _cfg(),
'davit_base_fl.msft_florence2': _cfg(
hf_hub_id='microsoft/Florence-2-base',
num_classes=0, input_size=(3, 768, 768)),
'davit_huge_fl.msft_florence2': _cfg(
hf_hub_id='microsoft/Florence-2-large',
num_classes=0, input_size=(3, 768, 768)),
})


Expand Down Expand Up @@ -687,3 +792,23 @@ def davit_huge(pretrained=False, **kwargs) -> DaVit:
def davit_giant(pretrained=False, **kwargs) -> DaVit:
model_args = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96))
return _create_davit('davit_giant', pretrained=pretrained, **dict(model_args, **kwargs))



@register_model
def davit_base_fl(pretrained=False, **kwargs) -> DaVit:
model_args = dict(
depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32),
window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
)
return _create_davit('davit_base_fl', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def davit_huge_fl(pretrained=False, **kwargs) -> DaVit:
# NOTE: huge image tower used in 'large' Florence2 model
model_args = dict(
depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64),
window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
)
return _create_davit('davit_huge_fl', pretrained=pretrained, **dict(model_args, **kwargs))
Loading