diff --git a/timm/models/repvit.py b/timm/models/repvit.py
index 43e35be900..55b6ba4bea 100644
--- a/timm/models/repvit.py
+++ b/timm/models/repvit.py
@@ -15,7 +15,7 @@
 Adapted from official impl at https://github.com/jameslahm/RepViT
 """
 
-__all__ = ['RepViT']
+__all__ = ['RepVit']
 
 import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -81,7 +81,7 @@ def fuse(self):
         return m
 
 
-class RepVGGDW(nn.Module):
+class RepVggDw(nn.Module):
     def __init__(self, ed, kernel_size):
         super().__init__()
         self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed)
@@ -115,7 +115,7 @@ def fuse(self):
         return conv
 
 
-class RepViTMlp(nn.Module):
+class RepVitMlp(nn.Module):
     def __init__(self, in_dim, hidden_dim, act_layer):
         super().__init__()
         self.conv1 = ConvNorm(in_dim, hidden_dim, 1, 1, 0)
@@ -130,9 +130,9 @@ class RepViTBlock(nn.Module):
     def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer):
         super(RepViTBlock, self).__init__()
 
-        self.token_mixer = RepVGGDW(in_dim, kernel_size)
+        self.token_mixer = RepVggDw(in_dim, kernel_size)
         self.se = SqueezeExcite(in_dim, 0.25) if use_se else nn.Identity()
-        self.channel_mixer = RepViTMlp(in_dim, in_dim * mlp_ratio, act_layer)
+        self.channel_mixer = RepVitMlp(in_dim, in_dim * mlp_ratio, act_layer)
 
     def forward(self, x):
         x = self.token_mixer(x)
@@ -142,7 +142,7 @@ def forward(self, x):
         return identity + x
 
 
-class RepViTStem(nn.Module):
+class RepVitStem(nn.Module):
     def __init__(self, in_chs, out_chs, act_layer):
         super().__init__()
         self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1)
@@ -154,13 +154,13 @@ def forward(self, x):
         return self.conv2(self.act1(self.conv1(x)))
 
 
-class RepViTDownsample(nn.Module):
+class RepVitDownsample(nn.Module):
     def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer):
         super().__init__()
         self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer)
         self.spatial_downsample = ConvNorm(in_dim, in_dim, kernel_size, 2, (kernel_size - 1) // 2, groups=in_dim)
         self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1)
-        self.ffn = RepViTMlp(out_dim, out_dim * mlp_ratio, act_layer)
+        self.ffn = RepVitMlp(out_dim, out_dim * mlp_ratio, act_layer)
 
     def forward(self, x):
         x = self.pre_block(x)
@@ -171,22 +171,25 @@ def forward(self, x):
         return x + identity
 
 
-class RepViTClassifier(nn.Module):
-    def __init__(self, dim, num_classes, distillation=False):
+class RepVitClassifier(nn.Module):
+    def __init__(self, dim, num_classes, distillation=False, drop=0.):
         super().__init__()
+        self.head_drop = nn.Dropout(drop)
         self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
         self.distillation = distillation
-        self.num_classes=num_classes
+        self.distilled_training = False
+        self.num_classes = num_classes
         if distillation:
             self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward(self, x):
+        x = self.head_drop(x)
         if self.distillation:
             x1, x2 = self.head(x), self.head_dist(x)
-            if (not self.training) or torch.jit.is_scripting():
-                return (x1 + x2) / 2
-            else:
+            if self.training and self.distilled_training and not torch.jit.is_scripting():
                 return x1, x2
+            else:
+                return (x1 + x2) / 2
         else:
             x = self.head(x)
         return x
@@ -207,11 +210,11 @@ def fuse(self):
         return head
 
 
-class RepViTStage(nn.Module):
+class RepVitStage(nn.Module):
     def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True):
         super().__init__()
         if downsample:
-            self.downsample = RepViTDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
+            self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
         else:
             assert in_dim == out_dim
             self.downsample = nn.Identity()
@@ -230,7 +233,7 @@ def forward(self, x):
         return x
 
 
-class RepViT(nn.Module):
+class RepVit(nn.Module):
     def __init__(
         self,
         in_chans=3,
@@ -243,15 +246,16 @@ def __init__(
         num_classes=1000,
         act_layer=nn.GELU,
         distillation=True,
+        drop_rate=0.,
     ):
-        super(RepViT, self).__init__()
+        super(RepVit, self).__init__()
         self.grad_checkpointing = False
         self.global_pool = global_pool
        self.embed_dim = embed_dim
         self.num_classes = num_classes
 
         in_dim = embed_dim[0]
-        self.stem = RepViTStem(in_chans, in_dim, act_layer)
+        self.stem = RepVitStem(in_chans, in_dim, act_layer)
         stride = self.stem.stride
         resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))])
 
@@ -263,7 +267,7 @@ def __init__(
         for i in range(num_stages):
             downsample = True if i != 0 else False
             stages.append(
-                RepViTStage(
+                RepVitStage(
                     in_dim,
                     embed_dim[i],
                     depth[i],
@@ -281,7 +285,8 @@ def __init__(
         self.stages = nn.Sequential(*stages)
 
         self.num_features = embed_dim[-1]
-        self.head = RepViTClassifier(embed_dim[-1], num_classes, distillation)
+        self.head_drop = nn.Dropout(drop_rate)
+        self.head = RepVitClassifier(embed_dim[-1], num_classes, distillation)
 
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
@@ -304,9 +309,13 @@ def reset_classifier(self, num_classes, global_pool=None, distillation=False):
         if global_pool is not None:
             self.global_pool = global_pool
         self.head = (
-            RepViTClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity()
+            RepVitClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity()
         )
 
+    @torch.jit.ignore
+    def set_distilled_training(self, enable=True):
+        self.head.distilled_training = enable
+
     def forward_features(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -317,8 +326,9 @@ def forward_features(self, x):
 
     def forward_head(self, x, pre_logits: bool = False):
         if self.global_pool == 'avg':
-            x = nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
-        return x if pre_logits else self.head(x)
+            x = x.mean((2, 3), keepdim=False)
+        x = self.head_drop(x)
+        return x if pre_logits else self.head(x)
 
     def forward(self, x):
         x = self.forward_features(x)
@@ -373,7 +383,9 @@ def _cfg(url='', **kwargs):
 
 def _create_repvit(variant, pretrained=False, **kwargs):
     out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
     model = build_model_with_cfg(
-        RepViT, variant, pretrained, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs
+        RepVit, variant, pretrained,
+        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        **kwargs,
     )
     return model