[FEATURE] Update RepViT models #1971

Merged · 1 commit · Oct 20, 2023
timm/models/repvit.py (117 additions, 25 deletions)
@@ -82,19 +82,30 @@ def fuse(self):


 class RepVggDw(nn.Module):
-    def __init__(self, ed, kernel_size):
+    def __init__(self, ed, kernel_size, legacy=False):
         super().__init__()
         self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed)
-        self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed)
+        if legacy:
+            self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed)
+            # Make torchscript happy.
+            self.bn = nn.Identity()
+        else:
+            self.conv1 = nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
+            self.bn = nn.BatchNorm2d(ed)
         self.dim = ed
+        self.legacy = legacy

     def forward(self, x):
-        return self.conv(x) + self.conv1(x) + x
+        return self.bn(self.conv(x) + self.conv1(x) + x)

     @torch.no_grad()
     def fuse(self):
         conv = self.conv.fuse()
-        conv1 = self.conv1.fuse()
+
+        if self.legacy:
+            conv1 = self.conv1.fuse()
+        else:
+            conv1 = self.conv1

         conv_w = conv.weight
         conv_b = conv.bias
@@ -112,6 +123,14 @@ def fuse(self):

         conv.weight.data.copy_(final_conv_w)
         conv.bias.data.copy_(final_conv_b)
+
+        if not self.legacy:
+            bn = self.bn
+            w = bn.weight / (bn.running_var + bn.eps) ** 0.5
+            w = conv.weight * w[:, None, None, None]
+            b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / (bn.running_var + bn.eps) ** 0.5
+            conv.weight.data.copy_(w)
+            conv.bias.data.copy_(b)
         return conv
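
In the non-legacy path the BatchNorm sits outside the three-way sum, so `fuse()` is now two linear steps: first collapse the depthwise 3x3, depthwise 1x1, and identity branches into a single depthwise 3x3 conv (the 1x1 kernel zero-pads to 3x3; identity becomes a centered one-hot kernel), then fold the trailing BatchNorm into that conv. A self-contained sketch of the same arithmetic on toy tensors (illustrative only, not timm code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
C = 8  # channels; depthwise conv, so groups == C

# Toy stand-ins for RepVggDw's conv (3x3), conv1 (1x1), and the implicit
# identity branch, with a BatchNorm over the sum (non-legacy forward).
w3, b3 = torch.randn(C, 1, 3, 3), torch.randn(C)
w1, b1 = torch.randn(C, 1, 1, 1), torch.randn(C)
bn = nn.BatchNorm2d(C).eval()
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-1, 1)

# Step 1: merge branches. Pad the 1x1 kernel to 3x3; the identity map is a
# depthwise 3x3 kernel with a 1 at the center.
ident = torch.zeros(C, 1, 3, 3)
ident[:, 0, 1, 1] = 1.0
w, b = w3 + F.pad(w1, [1, 1, 1, 1]) + ident, b3 + b1

# Step 2: fold the BatchNorm, mirroring the `if not self.legacy` block above.
scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
w_fused = w * scale[:, None, None, None]
b_fused = bn.bias + (b - bn.running_mean) * scale

x = torch.randn(1, C, 16, 16)
y_ref = bn(F.conv2d(x, w3, b3, padding=1, groups=C) + F.conv2d(x, w1, b1, groups=C) + x)
y_fused = F.conv2d(x, w_fused, b_fused, padding=1, groups=C)
print(torch.allclose(y_ref, y_fused, atol=1e-5))  # True
```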


@@ -127,10 +146,10 @@ def forward(self, x):


 class RepViTBlock(nn.Module):
-    def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer):
+    def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer, legacy=False):
         super(RepViTBlock, self).__init__()

-        self.token_mixer = RepVggDw(in_dim, kernel_size)
+        self.token_mixer = RepVggDw(in_dim, kernel_size, legacy)
         self.se = SqueezeExcite(in_dim, 0.25) if use_se else nn.Identity()
         self.channel_mixer = RepVitMlp(in_dim, in_dim * mlp_ratio, act_layer)

@@ -155,9 +174,9 @@ def forward(self, x):


 class RepVitDownsample(nn.Module):
-    def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer):
+    def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer, legacy=False):
         super().__init__()
-        self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer)
+        self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer, legacy=legacy)
         self.spatial_downsample = ConvNorm(in_dim, in_dim, kernel_size, 2, (kernel_size - 1) // 2, groups=in_dim)
         self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1)
         self.ffn = RepVitMlp(out_dim, out_dim * mlp_ratio, act_layer)
@@ -172,7 +191,7 @@ def forward(self, x):


 class RepVitClassifier(nn.Module):
-    def __init__(self, dim, num_classes, distillation=False, drop=0.):
+    def __init__(self, dim, num_classes, distillation=False, drop=0.0):
         super().__init__()
         self.head_drop = nn.Dropout(drop)
         self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -211,18 +230,18 @@ def fuse(self):


 class RepVitStage(nn.Module):
-    def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True):
+    def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True, legacy=False):
         super().__init__()
         if downsample:
-            self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
+            self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer, legacy)
         else:
             assert in_dim == out_dim
             self.downsample = nn.Identity()

         blocks = []
         use_se = True
         for _ in range(depth):
-            blocks.append(RepViTBlock(out_dim, mlp_ratio, kernel_size, use_se, act_layer))
+            blocks.append(RepViTBlock(out_dim, mlp_ratio, kernel_size, use_se, act_layer, legacy))
             use_se = not use_se

         self.blocks = nn.Sequential(*blocks)
@@ -246,7 +265,8 @@ def __init__(
         num_classes=1000,
         act_layer=nn.GELU,
         distillation=True,
-        drop_rate=0.,
+        drop_rate=0.0,
+        legacy=False,
     ):
         super(RepVit, self).__init__()
         self.grad_checkpointing = False
@@ -275,6 +295,7 @@ def __init__(
                     act_layer=act_layer,
                     kernel_size=kernel_size,
                     downsample=downsample,
+                    legacy=legacy,
                 )
             )
             stage_stride = 2 if downsample else 1
@@ -290,12 +311,9 @@ def __init__(

     @torch.jit.ignore
     def group_matcher(self, coarse=False):
-        matcher = dict(
-            stem=r'^stem',  # stem and embed
-            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
-        )
+        matcher = dict(stem=r'^stem', blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))])  # stem and embed
         return matcher

     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
         self.grad_checkpointing = enable
@@ -369,15 +387,42 @@ def _cfg(url='', **kwargs):
     {
         'repvit_m1.dist_in1k': _cfg(
             hf_hub_id='timm/',
-            # url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_distill_300_timm.pth'
         ),
         'repvit_m2.dist_in1k': _cfg(
             hf_hub_id='timm/',
-            # url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_distill_300_timm.pth'
         ),
         'repvit_m3.dist_in1k': _cfg(
             hf_hub_id='timm/',
-            # url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m3_distill_300_timm.pth'
         ),
+        'repvit_m0_9.dist_in1k_300e': _cfg(
+            url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m0_9_distill_300e_timm.pth'
+        ),
+        'repvit_m0_9.dist_in1k_450e': _cfg(
+            url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m0_9_distill_450e_timm.pth'
+        ),
+        'repvit_m1_0.dist_in1k_300e': _cfg(
+            url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_0_distill_300e_timm.pth'
+        ),
+        'repvit_m1_0.dist_in1k_450e': _cfg(
+            url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_0_distill_450e_timm.pth'
+        ),
+        'repvit_m1_1.dist_in1k_300e': _cfg(
+            url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_1_distill_300e_timm.pth'
+        ),
+        'repvit_m1_1.dist_in1k_450e': _cfg(
+            url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_1_distill_450e_timm.pth'
+        ),
+        'repvit_m1_5.dist_in1k_300e': _cfg(
+            url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_5_distill_300e_timm.pth'
+        ),
+        'repvit_m1_5.dist_in1k_450e': _cfg(
+            url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_5_distill_450e_timm.pth'
+        ),
+        'repvit_m2_3.dist_in1k_300e': _cfg(
+            url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_3_distill_300e_timm.pth'
+        ),
+        'repvit_m2_3.dist_in1k_450e': _cfg(
+            url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_3_distill_450e_timm.pth'
+        ),
     }
 )
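
Each new checkpoint ships in two training schedules, tagged `300e` and `450e`. Assuming timm's usual pretrained-tag selection applies to these keys, a specific schedule is picked by appending the tag to the model name:

```python
import timm

# Assumes the standard 'model_name.tag' syntax; the tag matches the cfg keys above.
model = timm.create_model('repvit_m1_5.dist_in1k_450e', pretrained=True)
```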
@@ -386,7 +431,9 @@ def _cfg(url='', **kwargs):
 def _create_repvit(variant, pretrained=False, **kwargs):
     out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
     model = build_model_with_cfg(
-        RepVit, variant, pretrained,
+        RepVit,
+        variant,
+        pretrained,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs,
     )
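
Since `feature_cfg` sets `flatten_sequential` and `out_indices`, the `features_only` path should also work for these variants. A hedged sketch, assuming standard timm feature-extraction behavior:

```python
import timm
import torch

backbone = timm.create_model('repvit_m1_1', pretrained=False, features_only=True)
feats = backbone(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)  # one feature map per entry in out_indices
```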
@@ -398,7 +445,7 @@ def repvit_m1(pretrained=False, **kwargs):
     """
     Constructs a RepViT-M1 model
     """
-    model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2))
+    model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2), legacy=True)
     return _create_repvit('repvit_m1', pretrained=pretrained, **dict(model_args, **kwargs))


@@ -407,7 +454,7 @@ def repvit_m2(pretrained=False, **kwargs):
     """
     Constructs a RepViT-M2 model
     """
-    model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2))
+    model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2), legacy=True)
     return _create_repvit('repvit_m2', pretrained=pretrained, **dict(model_args, **kwargs))


@@ -416,5 +463,50 @@ def repvit_m3(pretrained=False, **kwargs):
     """
     Constructs a RepViT-M3 model
     """
-    model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 18, 2))
+    model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 18, 2), legacy=True)
     return _create_repvit('repvit_m3', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def repvit_m0_9(pretrained=False, **kwargs):
+    """
+    Constructs a RepViT-M0.9 model
+    """
+    model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2))
+    return _create_repvit('repvit_m0_9', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def repvit_m1_0(pretrained=False, **kwargs):
+    """
+    Constructs a RepViT-M1.0 model
+    """
+    model_args = dict(embed_dim=(56, 112, 224, 448), depth=(2, 2, 14, 2))
+    return _create_repvit('repvit_m1_0', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def repvit_m1_1(pretrained=False, **kwargs):
+    """
+    Constructs a RepViT-M1.1 model
+    """
+    model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2))
+    return _create_repvit('repvit_m1_1', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def repvit_m1_5(pretrained=False, **kwargs):
+    """
+    Constructs a RepViT-M1.5 model
+    """
+    model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 24, 4))
+    return _create_repvit('repvit_m1_5', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def repvit_m2_3(pretrained=False, **kwargs):
+    """
+    Constructs a RepViT-M2.3 model
+    """
+    model_args = dict(embed_dim=(80, 160, 320, 640), depth=(6, 6, 34, 2))
+    return _create_repvit('repvit_m2_3', pretrained=pretrained, **dict(model_args, **kwargs))
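
A quick smoke test for the new registrations (assumes this branch is installed; `pretrained=True` additionally requires the release URLs above to resolve):

```python
import timm
import torch

model = timm.create_model('repvit_m0_9', pretrained=False).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```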