diff --git a/timm/models/repvit.py b/timm/models/repvit.py index a0def2f41c..cc6c9bf743 100644 --- a/timm/models/repvit.py +++ b/timm/models/repvit.py @@ -82,19 +82,30 @@ def fuse(self): class RepVggDw(nn.Module): - def __init__(self, ed, kernel_size): + def __init__(self, ed, kernel_size, legacy=False): super().__init__() self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed) - self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed) + if legacy: + self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed) + # Make torchscript happy. + self.bn = nn.Identity() + else: + self.conv1 = nn.Conv2d(ed, ed, 1, 1, 0, groups=ed) + self.bn = nn.BatchNorm2d(ed) self.dim = ed + self.legacy = legacy def forward(self, x): - return self.conv(x) + self.conv1(x) + x + return self.bn(self.conv(x) + self.conv1(x) + x) @torch.no_grad() def fuse(self): conv = self.conv.fuse() - conv1 = self.conv1.fuse() + + if self.legacy: + conv1 = self.conv1.fuse() + else: + conv1 = self.conv1 conv_w = conv.weight conv_b = conv.bias @@ -112,6 +123,14 @@ def fuse(self): conv.weight.data.copy_(final_conv_w) conv.bias.data.copy_(final_conv_b) + + if not self.legacy: + bn = self.bn + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = conv.weight * w[:, None, None, None] + b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / (bn.running_var + bn.eps) ** 0.5 + conv.weight.data.copy_(w) + conv.bias.data.copy_(b) return conv @@ -127,10 +146,10 @@ def forward(self, x): class RepViTBlock(nn.Module): - def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer): + def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer, legacy=False): super(RepViTBlock, self).__init__() - self.token_mixer = RepVggDw(in_dim, kernel_size) + self.token_mixer = RepVggDw(in_dim, kernel_size, legacy) self.se = SqueezeExcite(in_dim, 0.25) if use_se else nn.Identity() self.channel_mixer = RepVitMlp(in_dim, in_dim * mlp_ratio, act_layer) @@ -155,9 +174,9 @@ def forward(self, x): class RepVitDownsample(nn.Module): - def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer): + def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer, legacy=False): super().__init__() - self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer) + self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer, legacy=legacy) self.spatial_downsample = ConvNorm(in_dim, in_dim, kernel_size, 2, (kernel_size - 1) // 2, groups=in_dim) self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1) self.ffn = RepVitMlp(out_dim, out_dim * mlp_ratio, act_layer) @@ -172,7 +191,7 @@ def forward(self, x): class RepVitClassifier(nn.Module): - def __init__(self, dim, num_classes, distillation=False, drop=0.): + def __init__(self, dim, num_classes, distillation=False, drop=0.0): super().__init__() self.head_drop = nn.Dropout(drop) self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity() @@ -211,10 +230,10 @@ def fuse(self): class RepVitStage(nn.Module): - def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True): + def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True, legacy=False): super().__init__() if downsample: - self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer) + self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer, legacy) else: assert in_dim == out_dim self.downsample = nn.Identity() @@ -222,7 +241,7 @@ def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, blocks = [] use_se = True for _ in range(depth): - blocks.append(RepViTBlock(out_dim, mlp_ratio, kernel_size, use_se, act_layer)) + blocks.append(RepViTBlock(out_dim, mlp_ratio, kernel_size, use_se, act_layer, legacy)) use_se = not use_se self.blocks = nn.Sequential(*blocks) @@ -246,7 +265,8 @@ def __init__( num_classes=1000, act_layer=nn.GELU, distillation=True, - drop_rate=0., + drop_rate=0.0, + legacy=False, ): super(RepVit, self).__init__() self.grad_checkpointing = False @@ -275,6 +295,7 @@ def __init__( act_layer=act_layer, kernel_size=kernel_size, downsample=downsample, + legacy=legacy, ) ) stage_stride = 2 if downsample else 1 @@ -290,12 +311,9 @@ def __init__( @torch.jit.ignore def group_matcher(self, coarse=False): - matcher = dict( - stem=r'^stem', # stem and embed - blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] - ) + matcher = dict(stem=r'^stem', blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]) # stem and embed return matcher - + @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @@ -369,15 +387,42 @@ def _cfg(url='', **kwargs): { 'repvit_m1.dist_in1k': _cfg( hf_hub_id='timm/', - # url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_distill_300_timm.pth' ), 'repvit_m2.dist_in1k': _cfg( hf_hub_id='timm/', - # url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_distill_300_timm.pth' ), 'repvit_m3.dist_in1k': _cfg( hf_hub_id='timm/', - # url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m3_distill_300_timm.pth' + ), + 'repvit_m0_9.dist_in1k_300e': _cfg( + url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m0_9_distill_300e_timm.pth' + ), + 'repvit_m0_9.dist_in1k_450e': _cfg( + url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m0_9_distill_450e_timm.pth' + ), + 'repvit_m1_0.dist_in1k_300e': _cfg( + url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_0_distill_300e_timm.pth' + ), + 'repvit_m1_0.dist_in1k_450e': _cfg( + url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_0_distill_450e_timm.pth' + ), + 'repvit_m1_1.dist_in1k_300e': _cfg( + url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_1_distill_300e_timm.pth' + ), + 'repvit_m1_1.dist_in1k_450e': _cfg( + url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_1_distill_450e_timm.pth' + ), + 'repvit_m1_5.dist_in1k_300e': _cfg( + url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_5_distill_300e_timm.pth' + ), + 'repvit_m1_5.dist_in1k_450e': _cfg( + url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m1_5_distill_450e_timm.pth' + ), + 'repvit_m2_3.dist_in1k_300e': _cfg( + url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_3_distill_300e_timm.pth' + ), + 'repvit_m2_3.dist_in1k_450e': _cfg( + url='https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_m2_3_distill_450e_timm.pth' ), } ) @@ -386,7 +431,9 @@ def _cfg(url='', **kwargs): def _create_repvit(variant, pretrained=False, **kwargs): out_indices = kwargs.pop('out_indices', (0, 1, 2, 3)) model = build_model_with_cfg( - RepVit, variant, pretrained, + RepVit, + variant, + pretrained, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs, ) @@ -398,7 +445,7 @@ def repvit_m1(pretrained=False, **kwargs): """ Constructs a RepViT-M1 model """ - model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2)) + model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2), legacy=True) return _create_repvit('repvit_m1', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -407,7 +454,7 @@ def repvit_m2(pretrained=False, **kwargs): """ Constructs a RepViT-M2 model """ - model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2)) + model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2), legacy=True) return _create_repvit('repvit_m2', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -416,5 +463,50 @@ def repvit_m3(pretrained=False, **kwargs): """ Constructs a RepViT-M3 model """ - model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 18, 2)) + model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 18, 2), legacy=True) return _create_repvit('repvit_m3', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def repvit_m0_9(pretrained=False, **kwargs): + """ + Constructs a RepViT-M0.9 model + """ + model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2)) + return _create_repvit('repvit_m0_9', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def repvit_m1_0(pretrained=False, **kwargs): + """ + Constructs a RepViT-M1.0 model + """ + model_args = dict(embed_dim=(56, 112, 224, 448), depth=(2, 2, 14, 2)) + return _create_repvit('repvit_m1_0', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def repvit_m1_1(pretrained=False, **kwargs): + """ + Constructs a RepViT-M1.1 model + """ + model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2)) + return _create_repvit('repvit_m1_1', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def repvit_m1_5(pretrained=False, **kwargs): + """ + Constructs a RepViT-M1.5 model + """ + model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 24, 4)) + return _create_repvit('repvit_m1_5', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def repvit_m2_3(pretrained=False, **kwargs): + """ + Constructs a RepViT-M2.3 model + """ + model_args = dict(embed_dim=(80, 160, 320, 640), depth=(6, 6, 34, 2)) + return _create_repvit('repvit_m2_3', pretrained=pretrained, **dict(model_args, **kwargs))