Skip to content

Commit

Permalink
Update RepViT models
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslahm committed Oct 1, 2023
1 parent 054c763 commit f6ace32
Showing 1 changed file with 106 additions and 16 deletions.
122 changes: 106 additions & 16 deletions timm/models/repvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,31 @@ def fuse(self):


class RepVggDw(nn.Module):
def __init__(self, ed, kernel_size):
def __init__(self, ed, kernel_size, legacy=False):
super().__init__()
self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed)
self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed)
if legacy:
self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed)
else:
self.conv1 = nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
self.bn = nn.BatchNorm2d(ed)
self.dim = ed
self.legacy = legacy

def forward(self, x):
return self.conv(x) + self.conv1(x) + x
if self.legacy:
return self.conv(x) + self.conv1(x) + x
else:
return self.bn(self.conv(x) + self.conv1(x) + x)

@torch.no_grad()
def fuse(self):
conv = self.conv.fuse()
conv1 = self.conv1.fuse()

if self.legacy:
conv1 = self.conv1.fuse()
else:
conv1 = self.conv1

conv_w = conv.weight
conv_b = conv.bias
Expand All @@ -112,6 +124,15 @@ def fuse(self):

conv.weight.data.copy_(final_conv_w)
conv.bias.data.copy_(final_conv_b)

if not self.legacy:
bn = self.bn
w = bn.weight / (bn.running_var + bn.eps)**0.5
w = conv.weight * w[:, None, None, None]
b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \
(bn.running_var + bn.eps)**0.5
conv.weight.data.copy_(w)
conv.bias.data.copy_(b)
return conv


Expand All @@ -127,10 +148,10 @@ def forward(self, x):


class RepViTBlock(nn.Module):
def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer):
def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer, legacy=False):
super(RepViTBlock, self).__init__()

self.token_mixer = RepVggDw(in_dim, kernel_size)
self.token_mixer = RepVggDw(in_dim, kernel_size, legacy)
self.se = SqueezeExcite(in_dim, 0.25) if use_se else nn.Identity()
self.channel_mixer = RepVitMlp(in_dim, in_dim * mlp_ratio, act_layer)

Expand Down Expand Up @@ -211,7 +232,7 @@ def fuse(self):


class RepVitStage(nn.Module):
def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True):
def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True, legacy=False):
super().__init__()
if downsample:
self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
Expand All @@ -222,7 +243,7 @@ def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3,
blocks = []
use_se = True
for _ in range(depth):
blocks.append(RepViTBlock(out_dim, mlp_ratio, kernel_size, use_se, act_layer))
blocks.append(RepViTBlock(out_dim, mlp_ratio, kernel_size, use_se, act_layer, legacy))
use_se = not use_se

self.blocks = nn.Sequential(*blocks)
Expand All @@ -247,6 +268,7 @@ def __init__(
act_layer=nn.GELU,
distillation=True,
drop_rate=0.,
legacy=False
):
super(RepVit, self).__init__()
self.grad_checkpointing = False
Expand Down Expand Up @@ -275,6 +297,7 @@ def __init__(
act_layer=act_layer,
kernel_size=kernel_size,
downsample=downsample,
legacy=legacy
)
)
stage_stride = 2 if downsample else 1
Expand Down Expand Up @@ -369,15 +392,42 @@ def _cfg(url='', **kwargs):
{
'repvit_m1.dist_in1k': _cfg(
hf_hub_id='timm/',
# url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_distill_300_timm.pth'
),
'repvit_m2.dist_in1k': _cfg(
hf_hub_id='timm/',
# url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_distill_300_timm.pth'
),
'repvit_m3.dist_in1k': _cfg(
hf_hub_id='timm/',
# url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m3_distill_300_timm.pth'
),
'repvit_m0_9.dist_in1k_300e': _cfg(
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m0_9_distill_300e_timm.pth'
),
'repvit_m0_9.dist_in1k_450e': _cfg(
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m0_9_distill_450e_timm.pth'
),
'repvit_m1_0.dist_in1k_300e': _cfg(
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_0_distill_300e_timm.pth'
),
'repvit_m1_0.dist_in1k_450e': _cfg(
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_0_distill_450e_timm.pth'
),
'repvit_m1_1.dist_in1k_300e': _cfg(
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_1_distill_300e_timm.pth'
),
'repvit_m1_1.dist_in1k_450e': _cfg(
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_1_distill_450e_timm.pth'
),
'repvit_m1_5.dist_in1k_300e': _cfg(
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_5_distill_300e_timm.pth'
),
'repvit_m1_5.dist_in1k_450e': _cfg(
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_5_distill_450e_timm.pth'
),
'repvit_m2_3.dist_in1k_300e': _cfg(
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_3_distill_300e_timm.pth'
),
'repvit_m2_3.dist_in1k_450e': _cfg(
url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_3_distill_450e_timm.pth'
),
}
)
Expand All @@ -398,23 +448,63 @@ def repvit_m1(pretrained=False, **kwargs):
"""
Constructs a RepViT-M1 model
"""
model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2))
model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2), legacy=True)
return _create_repvit('repvit_m1', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def repvit_m2(pretrained=False, **kwargs):
"""
Constructs a RepViT-M2 model

The defaults held in ``model_args`` are merged with ``kwargs`` (caller
supplied keys replace defaults of the same name) and forwarded, together
with ``pretrained``, to ``_create_repvit`` under the name ``'repvit_m2'``.
"""
# NOTE(review): the two assignments below look like the removed and added
# sides of a single diff line. Read as sequential code, only the second one
# (the one that also sets legacy=True) takes effect -- confirm against the
# actual file.
model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2))
model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2), legacy=True)
return _create_repvit('repvit_m2', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def repvit_m3(pretrained=False, **kwargs):
"""
Constructs a RepViT-M3 model

The defaults held in ``model_args`` are merged with ``kwargs`` (caller
supplied keys replace defaults of the same name) and forwarded, together
with ``pretrained``, to ``_create_repvit`` under the name ``'repvit_m3'``.
"""
# NOTE(review): the two assignments below look like the removed and added
# sides of a single diff line. Read as sequential code, only the second one
# (the one that also sets legacy=True) takes effect -- confirm against the
# actual file.
model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 18, 2))
model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 18, 2), legacy=True)
return _create_repvit('repvit_m3', pretrained=pretrained, **dict(model_args, **kwargs))

@register_model
def repvit_m0_9(pretrained=False, **kwargs):
    """Build the RepViT-M0.9 variant.

    Args:
        pretrained: Passed through unchanged to ``_create_repvit``.
        **kwargs: Extra model arguments; a key given here replaces the
            default of the same name (``embed_dim``, ``depth``).

    Returns:
        Whatever ``_create_repvit`` returns for the ``'repvit_m0_9'`` variant.
    """
    defaults = {
        'embed_dim': (48, 96, 192, 384),
        'depth': (2, 2, 14, 2),
    }
    # Caller kwargs are unpacked last so they take precedence over the defaults.
    return _create_repvit('repvit_m0_9', pretrained=pretrained, **{**defaults, **kwargs})


@register_model
def repvit_m1_0(pretrained=False, **kwargs):
    """Build the RepViT-M1.0 variant.

    Args:
        pretrained: Passed through unchanged to ``_create_repvit``.
        **kwargs: Extra model arguments; a key given here replaces the
            default of the same name (``embed_dim``, ``depth``).

    Returns:
        Whatever ``_create_repvit`` returns for the ``'repvit_m1_0'`` variant.
    """
    defaults = {
        'embed_dim': (56, 112, 224, 448),
        'depth': (2, 2, 14, 2),
    }
    # Caller kwargs are unpacked last so they take precedence over the defaults.
    return _create_repvit('repvit_m1_0', pretrained=pretrained, **{**defaults, **kwargs})


@register_model
def repvit_m1_1(pretrained=False, **kwargs):
    """Build the RepViT-M1.1 variant.

    Args:
        pretrained: Passed through unchanged to ``_create_repvit``.
        **kwargs: Extra model arguments; a key given here replaces the
            default of the same name (``embed_dim``, ``depth``).

    Returns:
        Whatever ``_create_repvit`` returns for the ``'repvit_m1_1'`` variant.
    """
    defaults = {
        'embed_dim': (64, 128, 256, 512),
        'depth': (2, 2, 12, 2),
    }
    # Caller kwargs are unpacked last so they take precedence over the defaults.
    return _create_repvit('repvit_m1_1', pretrained=pretrained, **{**defaults, **kwargs})

@register_model
def repvit_m1_5(pretrained=False, **kwargs):
    """Build the RepViT-M1.5 variant.

    Args:
        pretrained: Passed through unchanged to ``_create_repvit``.
        **kwargs: Extra model arguments; a key given here replaces the
            default of the same name (``embed_dim``, ``depth``).

    Returns:
        Whatever ``_create_repvit`` returns for the ``'repvit_m1_5'`` variant.
    """
    defaults = {
        'embed_dim': (64, 128, 256, 512),
        'depth': (4, 4, 24, 4),
    }
    # Caller kwargs are unpacked last so they take precedence over the defaults.
    return _create_repvit('repvit_m1_5', pretrained=pretrained, **{**defaults, **kwargs})

@register_model
def repvit_m2_3(pretrained=False, **kwargs):
    """Build the RepViT-M2.3 variant.

    Args:
        pretrained: Passed through unchanged to ``_create_repvit``.
        **kwargs: Extra model arguments; a key given here replaces the
            default of the same name (``embed_dim``, ``depth``).

    Returns:
        Whatever ``_create_repvit`` returns for the ``'repvit_m2_3'`` variant.
    """
    defaults = {
        'embed_dim': (80, 160, 320, 640),
        'depth': (6, 6, 34, 2),
    }
    # Caller kwargs are unpacked last so they take precedence over the defaults.
    return _create_repvit('repvit_m2_3', pretrained=pretrained, **{**defaults, **kwargs})

0 comments on commit f6ace32

Please sign in to comment.