diff --git a/tests/test_models.py b/tests/test_models.py index ace88690f4..9f7a91546c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -169,6 +169,18 @@ def test_model_backward(model_name, batch_size): assert not torch.isnan(outputs).any(), 'Output included NaNs' +# models with extra conv/linear layers after pooling +EARLY_POOL_MODELS = ( + timm.models.EfficientVit, + timm.models.EfficientVitLarge, + timm.models.HighPerfGpuNet, + timm.models.GhostNet, + timm.models.MetaNeXt, # InceptionNeXt + timm.models.MobileNetV3, + timm.models.RepGhostNet, + timm.models.VGG, +) + @pytest.mark.cfg @pytest.mark.timeout(timeout300) @pytest.mark.parametrize('model_name', list_models( @@ -179,6 +191,9 @@ def test_model_default_cfgs(model_name, batch_size): model = create_model(model_name, pretrained=False) model.eval() model.to(torch_device) + assert getattr(model, 'num_classes') >= 0 + assert getattr(model, 'num_features') > 0 + assert getattr(model, 'head_hidden_size') > 0 state_dict = model.state_dict() cfg = model.default_cfg @@ -195,37 +210,37 @@ def test_model_default_cfgs(model_name, batch_size): input_size = tuple([min(x, MAX_FWD_OUT_SIZE) for x in input_size]) input_tensor = torch.randn((batch_size, *input_size), device=torch_device) - # test forward_features (always unpooled) + # test forward_features (always unpooled) & forward_head w/ pre_logits outputs = model.forward_features(input_tensor) - assert outputs.shape[spatial_axis[0]] == pool_size[0], 'unpooled feature shape != config' - assert outputs.shape[spatial_axis[1]] == pool_size[1], 'unpooled feature shape != config' - if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)): - assert outputs.shape[feat_axis] == model.num_features + outputs_pre = model.forward_head(outputs, pre_logits=True) + assert outputs.shape[spatial_axis[0]] == pool_size[0], f'unpooled feature shape {outputs.shape} != config' + assert outputs.shape[spatial_axis[1]] == pool_size[1], f'unpooled feature shape {outputs.shape} != config' + assert outputs.shape[feat_axis] == model.num_features, f'unpooled feature dim {outputs.shape[feat_axis]} != model.num_features {model.num_features}' + assert outputs_pre.shape[1] == model.head_hidden_size, f'pre_logits feature dim {outputs_pre.shape[1]} != model.head_hidden_size {model.head_hidden_size}' # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features model.reset_classifier(0) model.to(torch_device) outputs = model.forward(input_tensor) assert len(outputs.shape) == 2 - assert outputs.shape[1] == model.num_features + assert outputs.shape[1] == model.head_hidden_size, f'feature dim w/ removed classifier {outputs.shape[1]} != model.head_hidden_size {model.head_hidden_size}' + assert outputs.shape == outputs_pre.shape, f'output shape of pre_logits {outputs_pre.shape} does not match reset_head(0) {outputs.shape}' - # test model forward without pooling and classifier - model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through - model.to(torch_device) - outputs = model.forward(input_tensor) - assert len(outputs.shape) == 4 - if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)): - # mobilenetv3/ghostnet/repghostnet/vgg forward_features vs removed pooling differ due to location or lack of GAP + # test model forward after removing pooling and classifier + if not isinstance(model, EARLY_POOL_MODELS): + model.reset_classifier(0, '') 
# reset classifier and disable global pooling + model.to(torch_device) + outputs = model.forward(input_tensor) + assert len(outputs.shape) == 4 assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1] - if 'pruned' not in model_name: # FIXME better pruned model handling - # test classifier + global pool deletion via __init__ + # test classifier + global pool deletion via __init__ + if 'pruned' not in model_name and not isinstance(model, EARLY_POOL_MODELS): model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval() model.to(torch_device) outputs = model.forward(input_tensor) assert len(outputs.shape) == 4 - if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)): - assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1] + assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1] # check classifier name matches default_cfg if cfg.get('num_classes', None): @@ -253,6 +268,9 @@ def test_model_default_cfgs_non_std(model_name, batch_size): model = create_model(model_name, pretrained=False) model.eval() model.to(torch_device) + assert getattr(model, 'num_classes') >= 0 + assert getattr(model, 'num_features') > 0 + assert getattr(model, 'head_hidden_size') > 0 state_dict = model.state_dict() cfg = model.default_cfg @@ -264,6 +282,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size): feat_dim = getattr(model, 'feature_dim', None) outputs = model.forward_features(input_tensor) + outputs_pre = model.forward_head(outputs, pre_logits=True) if isinstance(outputs, (tuple, list)): # cannot currently verify multi-tensor output. 
pass @@ -271,6 +290,7 @@ def test_model_default_cfgs_non_std(model_name, batch_size): if feat_dim is None: feat_dim = -1 if outputs.ndim == 3 else 1 assert outputs.shape[feat_dim] == model.num_features + assert outputs_pre.shape[1] == model.head_hidden_size # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features model.reset_classifier(0) @@ -280,7 +300,8 @@ def test_model_default_cfgs_non_std(model_name, batch_size): outputs = outputs[0] if feat_dim is None: feat_dim = -1 if outputs.ndim == 3 else 1 - assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config' + assert outputs.shape[feat_dim] == model.head_hidden_size, 'pooled num_features != config' + assert outputs.shape == outputs_pre.shape model = create_model(model_name, pretrained=False, num_classes=0).eval() model.to(torch_device) diff --git a/timm/models/_prune.py b/timm/models/_prune.py index 9bbe71ecf3..370b911f46 100644 --- a/timm/models/_prune.py +++ b/timm/models/_prune.py @@ -101,7 +101,10 @@ def adapt_model_from_string(parent_module, model_string): in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None) set_layer(new_module, n, new_fc) if hasattr(new_module, 'num_features'): + if getattr(new_module, 'head_hidden_size', 0) == new_module.num_features: + new_module.head_hidden_size = num_features new_module.num_features = num_features + new_module.eval() parent_module.eval() diff --git a/timm/models/beit.py b/timm/models/beit.py index 922d15e79b..57007cd7d4 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -291,7 +291,7 @@ def __init__( super().__init__() self.num_classes = num_classes self.global_pool = global_pool - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models self.num_prefix_tokens = 1 self.grad_checkpointing = False @@ -392,7 +392,7 @@ def group_matcher(self, coarse=False): return matcher @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index a2b44e1a4d..02e258361d 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -1224,6 +1224,7 @@ def __init__( dict(num_chs=self.num_features, reduction=reduction, module='final_conv', stage=len(self.stages))] self.stage_ends = [f['stage'] for f in self.feature_info] + self.head_hidden_size = self.num_features self.head = ClassifierHead( self.num_features, num_classes, @@ -1250,10 +1251,11 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + self.num_classes = num_classes self.head.reset(num_classes, global_pool) def forward_intermediates( diff --git a/timm/models/cait.py b/timm/models/cait.py index 2d4c73651f..78a7adc957 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -239,7 +239,7 @@ def __init__( self.num_classes = num_classes self.global_pool = global_pool - self.num_features = self.embed_dim = embed_dim + self.num_features = self.head_hidden_size = self.embed_dim = embed_dim self.grad_checkpointing = 
False self.patch_embed = patch_layer( @@ -328,7 +328,7 @@ def _matcher(name): return _matcher @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/coat.py b/timm/models/coat.py index 3e7b9c7a62..906ecb9083 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -7,7 +7,7 @@ Modified from timm/models/vision_transformer.py """ -from typing import List, Optional, Union, Tuple +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -380,7 +380,7 @@ def __init__( self.return_interm_layers = return_interm_layers self.out_features = out_features self.embed_dims = embed_dims - self.num_features = embed_dims[-1] + self.num_features = self.head_hidden_size = embed_dims[-1] self.num_classes = num_classes self.global_pool = global_pool @@ -556,7 +556,7 @@ def group_matcher(self, coarse=False): return matcher @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/convit.py b/timm/models/convit.py index fb42baa06b..dadc41b803 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -269,7 +269,7 @@ def __init__( self.num_classes = num_classes self.global_pool = global_pool self.local_up_to_layer = local_up_to_layer - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models self.locality_strength = locality_strength self.use_pos_embed = use_pos_embed @@ -345,7 +345,7 @@ def set_grad_checkpointing(self, enable=True): assert not enable, 'gradient checkpointing not supported' @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/convmixer.py b/timm/models/convmixer.py index 3e43dd6647..c7a250776a 100644 --- a/timm/models/convmixer.py +++ b/timm/models/convmixer.py @@ -40,7 +40,7 @@ def __init__( ): super().__init__() self.num_classes = num_classes - self.num_features = dim + self.num_features = self.head_hidden_size = dim self.grad_checkpointing = False self.stem = nn.Sequential( @@ -74,7 +74,7 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 012a73c927..a09653cfa5 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -358,7 +358,7 @@ def __init__( # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) - self.num_features = prev_chs + self.num_features = self.head_hidden_size = prev_chs # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) @@ -382,6 +382,7 @@ def __init__( norm_layer=norm_layer, act_layer='gelu', ) + self.head_hidden_size = self.head.num_features 
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) @torch.jit.ignore @@ -401,10 +402,11 @@ def set_grad_checkpointing(self, enable=True): s.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes=0, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + self.num_classes = num_classes self.head.reset(num_classes, global_pool) def forward_intermediates( diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index 27d75808f3..ff78395ecf 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -330,7 +330,7 @@ def __init__( num_patches = _compute_num_patches(self.img_size_scaled, patch_size) self.num_branches = len(patch_size) self.embed_dim = embed_dim - self.num_features = sum(embed_dim) + self.num_features = self.head_hidden_size = sum(embed_dim) self.patch_embed = nn.ModuleList() # hard-coded for torch jit script @@ -415,7 +415,7 @@ def set_grad_checkpointing(self, enable=True): assert not enable, 'gradient checkpointing not supported' @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): @@ -423,9 +423,10 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): if global_pool is not None: assert global_pool in ('token', 'avg') self.global_pool = global_pool - self.head = nn.ModuleList( - [nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in - range(self.num_branches)]) + self.head = nn.ModuleList([ + nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() + for i in range(self.num_branches) + ]) def forward_features(self, x) -> List[torch.Tensor]: B = x.shape[0] diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index d02acfb06d..7d63096a51 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -675,9 +675,13 @@ def __init__( self.feature_info.extend(stage_feat_info) # Construct the head - self.num_features = prev_chs + self.num_features = self.head_hidden_size = prev_chs self.head = ClassifierHead( - in_features=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + in_features=prev_chs, + num_classes=num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + ) named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) @@ -698,11 +702,12 @@ def set_grad_checkpointing(self, enable=True): assert not enable, 'gradient checkpointing not supported' @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes, global_pool='avg'): - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + self.num_classes = num_classes + self.head.reset(num_classes, global_pool) def forward_features(self, x): x = self.stem(x) diff --git a/timm/models/davit.py b/timm/models/davit.py index dceda60e54..39f1a115a3 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -485,7 +485,7 @@ def __init__( norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps) norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps) self.num_classes = num_classes - self.num_features = embed_dims[-1] + 
self.num_features = self.head_hidden_size = embed_dims[-1] self.drop_rate = drop_rate self.grad_checkpointing = False self.feature_info = [] @@ -565,7 +565,7 @@ def set_grad_checkpointing(self, enable=True): stage.set_grad_checkpointing(enable=enable) @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/deit.py b/timm/models/deit.py index 96770beb3d..63662c02d4 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -60,7 +60,7 @@ def group_matcher(self, coarse=False): ) @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head, self.head_dist def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/densenet.py b/timm/models/densenet.py index ade61a14df..31d1f73f9c 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -247,7 +247,7 @@ def __init__( self.features.add_module('norm5', norm_layer(num_features)) self.feature_info += [dict(num_chs=num_features, reduction=current_stride, module='features.norm5')] - self.num_features = num_features + self.num_features = self.head_hidden_size = num_features # Linear layer global_pool, classifier = create_classifier( @@ -287,10 +287,10 @@ def set_grad_checkpointing(self, enable=True): b.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.classifier - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.classifier = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) @@ -298,11 +298,14 @@ def reset_classifier(self, num_classes, global_pool='avg'): def forward_features(self, x): return self.features(x) - def forward(self, x): - x = self.forward_features(x) + def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) x = self.head_drop(x) - x = self.classifier(x) + return x if pre_logits else self.classifier(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) return x diff --git a/timm/models/dla.py b/timm/models/dla.py index 3052819db7..666acd9d9c 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -301,7 +301,7 @@ def __init__( dict(num_chs=channels[5], reduction=32, module='level5'), ] - self.num_features = channels[-1] + self.num_features = self.head_hidden_size = channels[-1] self.global_pool, self.head_drop, self.fc = create_classifier( self.num_features, self.num_classes, @@ -350,10 +350,10 @@ def set_grad_checkpointing(self, enable=True): assert not enable, 'gradient checkpointing not supported' @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.fc = create_classifier( self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) diff --git a/timm/models/dpn.py b/timm/models/dpn.py index d51e88efbe..c03e5fe1a1 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -229,7 +229,7 @@ def __init__( blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=fc_norm_layer) - self.num_features = in_chs + self.num_features = 
self.head_hidden_size = in_chs self.features = nn.Sequential(blocks) # Using 1x1 conv for the FC layer to allow the extra pooling scheme @@ -253,10 +253,10 @@ def set_grad_checkpointing(self, enable=True): assert not enable, 'gradient checkpointing not supported' @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.classifier - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.classifier = create_classifier( self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index 515bc22554..d768b1dc33 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -9,7 +9,7 @@ """ import math from functools import partial -from typing import Tuple +from typing import Optional, Tuple import torch import torch.nn.functional as F @@ -17,7 +17,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \ - use_fused_attn, NormMlpClassifierHead, ClassifierHead + NormMlpClassifierHead, ClassifierHead from ._builder import build_model_with_cfg from ._features_fx import register_notrace_module from ._manipulate import named_apply, checkpoint_seq @@ -373,7 +373,7 @@ def __init__( self.stages = nn.Sequential(*stages) - self.num_features = dims[-1] + self.num_features = self.head_hidden_size = dims[-1] if head_norm_first: self.norm_pre = norm_layer(self.num_features) self.head = ClassifierHead( @@ -411,10 +411,11 @@ def set_grad_checkpointing(self, enable=True): s.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes=0, global_pool=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + self.num_classes = num_classes self.head.reset(num_classes, global_pool) def forward_features(self, x): diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index 32630683fb..513f7e44b8 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -411,7 +411,7 @@ def __init__( self.stages = nn.Sequential(*stages) # Classifier head - self.num_features = embed_dims[-1] + self.num_features = self.head_hidden_size = embed_dims[-1] self.norm = norm_layer_cl(self.num_features) self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() @@ -446,7 +446,7 @@ def set_grad_checkpointing(self, enable=True): s.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head, self.head_dist def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/efficientformer_v2.py b/timm/models/efficientformer_v2.py index ba3b7c5f3f..fee2ae7de2 100644 --- a/timm/models/efficientformer_v2.py +++ b/timm/models/efficientformer_v2.py @@ -571,7 +571,7 @@ def __init__( self.stages = nn.Sequential(*stages) # Classifier head - self.num_features = embed_dims[-1] + self.num_features = self.head_hidden_size = embed_dims[-1] self.norm = norm_layer(embed_dims[-1]) self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() @@ -609,7 +609,7 @@ def 
set_grad_checkpointing(self, enable=True): s.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head, self.head_dist def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 46c4e81e52..36059577e0 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -97,7 +97,7 @@ def __init__( norm_act_layer = get_norm_act_layer(norm_layer, act_layer) se_layer = se_layer or SqueezeExcite self.num_classes = num_classes - self.num_features = num_features + self.num_features = self.head_hidden_size = num_features self.drop_rate = drop_rate self.grad_checkpointing = False @@ -153,7 +153,7 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.classifier def reset_classifier(self, num_classes, global_pool='avg'): diff --git a/timm/models/efficientvit_mit.py b/timm/models/efficientvit_mit.py index c971fe61fe..34be806b1e 100644 --- a/timm/models/efficientvit_mit.py +++ b/timm/models/efficientvit_mit.py @@ -6,8 +6,8 @@ Adapted from official impl at https://github.com/mit-han-lab/efficientvit """ -__all__ = ['EfficientVit'] -from typing import Optional +__all__ = ['EfficientVit', 'EfficientVitLarge'] +from typing import List, Optional from functools import partial import torch @@ -631,32 +631,50 @@ def forward(self, x): class ClassifierHead(nn.Module): def __init__( self, - in_channels, - widths, - n_classes=1000, - dropout=0., + in_channels: int, + widths: List[int], + num_classes: int = 1000, + dropout: float = 0., norm_layer=nn.BatchNorm2d, act_layer=nn.Hardswish, - global_pool='avg', - norm_eps=1e-5, + pool_type: str = 'avg', + norm_eps: float = 1e-5, ): super(ClassifierHead, self).__init__() + self.widths = widths + self.num_features = widths[-1] + + assert pool_type, 'Cannot disable pooling' self.in_conv = ConvNormAct(in_channels, widths[0], 1, norm_layer=norm_layer, act_layer=act_layer) - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True, input_fmt='NCHW') + self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True) self.classifier = nn.Sequential( nn.Linear(widths[0], widths[1], bias=False), nn.LayerNorm(widths[1], eps=norm_eps), act_layer(inplace=True) if act_layer is not None else nn.Identity(), nn.Dropout(dropout, inplace=False), - nn.Linear(widths[1], n_classes, bias=True), + nn.Linear(widths[1], num_classes, bias=True) if num_classes > 0 else nn.Identity(), ) + def reset(self, num_classes: int, pool_type: Optional[str] = None): + if pool_type is not None: + assert pool_type, 'Cannot disable pooling' + self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True,) + if num_classes > 0: + self.classifier[-1] = nn.Linear(self.num_features, num_classes, bias=True) + else: + self.classifier[-1] = nn.Identity() + def forward(self, x, pre_logits: bool = False): x = self.in_conv(x) x = self.global_pool(x) if pre_logits: - return x - x = self.classifier(x) + # cannot slice or iterate with torchscript so, this + x = self.classifier[0](x) + x = self.classifier[1](x) + x = self.classifier[2](x) + x = self.classifier[3](x) + else: + x = self.classifier(x) return x @@ -704,21 +722,14 @@ def __init__( self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{i}')] self.num_features = in_channels - self.head_widths 
= head_widths - self.head_dropout = drop_rate - if num_classes > 0: - self.head = ClassifierHead( - self.num_features, - self.head_widths, - n_classes=num_classes, - dropout=self.head_dropout, - global_pool=self.global_pool, - ) - else: - if self.global_pool == 'avg': - self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) - else: - self.head = nn.Identity() + self.head = ClassifierHead( + self.num_features, + widths=head_widths, + num_classes=num_classes, + dropout=drop_rate, + pool_type=self.global_pool, + ) + self.head_hidden_size = self.head.num_features @torch.jit.ignore def group_matcher(self, coarse=False): @@ -736,26 +747,12 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.classifier[-1] def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes - if global_pool is not None: - self.global_pool = global_pool - if num_classes > 0: - self.head = ClassifierHead( - self.num_features, - self.head_widths, - n_classes=num_classes, - dropout=self.head_dropout, - global_pool=self.global_pool, - ) - else: - if self.global_pool == 'avg': - self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True) - else: - self.head = nn.Identity() + self.head.reset(num_classes, global_pool) def forward_features(self, x): x = self.stem(x) @@ -820,23 +817,16 @@ def __init__( self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{i}')] self.num_features = in_channels - self.head_widths = head_widths - self.head_dropout = drop_rate - if num_classes > 0: - self.head = ClassifierHead( - self.num_features, - self.head_widths, - n_classes=num_classes, - dropout=self.head_dropout, - global_pool=self.global_pool, - act_layer=act_layer, - norm_eps=self.norm_eps, - ) - else: - if self.global_pool == 'avg': - self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) - else: - self.head = nn.Identity() + self.head = ClassifierHead( + self.num_features, + widths=head_widths, + num_classes=num_classes, + dropout=drop_rate, + pool_type=self.global_pool, + act_layer=act_layer, + norm_eps=self.norm_eps, + ) + self.head_hidden_size = self.head.num_features @torch.jit.ignore def group_matcher(self, coarse=False): @@ -854,27 +844,12 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.classifier[-1] def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes - if global_pool is not None: - self.global_pool = global_pool - if num_classes > 0: - self.head = ClassifierHead( - self.num_features, - self.head_widths, - n_classes=num_classes, - dropout=self.head_dropout, - global_pool=self.global_pool, - norm_eps=self.norm_eps - ) - else: - if self.global_pool == 'avg': - self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True) - else: - self.head = nn.Identity() + self.head.reset(num_classes, global_pool) def forward_features(self, x): x = self.stem(x) diff --git a/timm/models/efficientvit_msra.py b/timm/models/efficientvit_msra.py index deaf1fbac7..dd8ef80a85 100644 --- a/timm/models/efficientvit_msra.py +++ b/timm/models/efficientvit_msra.py @@ -437,7 +437,7 @@ def __init__( else: assert num_classes == 0 self.global_pool = nn.Identity() - self.num_features = embed_dim[-1] + 
self.num_features = self.head_hidden_size = embed_dim[-1] self.head = NormLinear( self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity() @@ -461,7 +461,7 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.linear def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/eva.py b/timm/models/eva.py index d424ab3d2f..7a1b67e13b 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -414,7 +414,7 @@ def __init__( super().__init__() self.num_classes = num_classes self.global_pool = global_pool - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models self.num_prefix_tokens = (1 if class_token else 0) + num_reg_tokens self.dynamic_img_size = dynamic_img_size self.grad_checkpointing = False @@ -536,7 +536,7 @@ def group_matcher(self, coarse=False): return matcher @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/fastvit.py b/timm/models/fastvit.py index ef7ec3c935..c2f7f7de4e 100644 --- a/timm/models/fastvit.py +++ b/timm/models/fastvit.py @@ -1172,7 +1172,7 @@ def __init__( self.feature_info += [dict(num_chs=prev_dim, reduction=4 * scale, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) self.num_stages = len(self.stages) - self.num_features = prev_dim + self.num_features = self.head_hidden_size = prev_dim # For segmentation and detection, extract intermediate output if self.fork_feat: @@ -1192,7 +1192,7 @@ def __init__( self.add_module(layer_name, layer) else: # Classifier head - self.num_features = final_features = int(embed_dims[-1] * cls_ratio) + self.num_features = self.head_hidden_size = final_features = int(embed_dims[-1] * cls_ratio) self.final_conv = MobileOneBlock( in_chs=embed_dims[-1], out_chs=final_features, @@ -1241,7 +1241,7 @@ def set_grad_checkpointing(self, enable=True): s.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py index 1624340b01..f747001c73 100644 --- a/timm/models/focalnet.py +++ b/timm/models/focalnet.py @@ -367,7 +367,7 @@ def __init__( self.num_classes = num_classes self.embed_dim = embed_dim - self.num_features = embed_dim[-1] + self.num_features = self.head_hidden_size = embed_dim[-1] self.feature_info = [] self.stem = Downsample( @@ -407,6 +407,7 @@ def __init__( if head_hidden_size: self.norm = nn.Identity() + self.head_hidden_size = head_hidden_size self.head = NormMlpClassifierHead( self.num_features, num_classes, @@ -451,7 +452,7 @@ def set_grad_checkpointing(self, enable=True): l.set_grad_checkpointing(enable=enable) @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py index 16f93bf443..44660a3f6c 100644 --- a/timm/models/gcvit.py +++ b/timm/models/gcvit.py @@ -405,7 +405,7 @@ def __init__( 
self.num_classes = num_classes self.drop_rate = drop_rate num_stages = len(depths) - self.num_features = int(embed_dim * 2 ** (num_stages - 1)) + self.num_features = self.head_hidden_size = int(embed_dim * 2 ** (num_stages - 1)) if window_size is not None: window_size = to_ntuple(num_stages)(window_size) else: @@ -486,7 +486,7 @@ def set_grad_checkpointing(self, enable=True): s.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index d34b548521..07a17dabdf 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -9,6 +9,7 @@ """ import math from functools import partial +from typing import Optional import torch import torch.nn as nn @@ -243,7 +244,8 @@ def __init__( self.blocks = nn.Sequential(*stages) # building last several layers - self.num_features = out_chs = 1280 + self.num_features = prev_chs + self.head_hidden_size = out_chs = 1280 self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True) self.act2 = nn.ReLU(inplace=True) @@ -268,7 +270,7 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.classifier def reset_classifier(self, num_classes, global_pool='avg'): @@ -276,7 +278,7 @@ def reset_classifier(self, num_classes, global_pool='avg'): # cannot meaningfully change pooling of efficient head after creation self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled - self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.conv_stem(x) @@ -288,15 +290,14 @@ def forward_features(self, x): x = self.blocks(x) return x - def forward_head(self, x): + def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) x = self.conv_head(x) x = self.act2(x) x = self.flatten(x) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) - x = self.classifier(x) - return x + return x if pre_logits else self.classifier(x) def forward(self, x): x = self.forward_features(x) diff --git a/timm/models/hgnet.py b/timm/models/hgnet.py index d7f38af646..9482e10760 100644 --- a/timm/models/hgnet.py +++ b/timm/models/hgnet.py @@ -6,6 +6,7 @@ PP-HGNet: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet.py PP-HGNetv2: https://github.com/PaddlePaddle/PaddleClas/blob/release/2.5.1/ppcls/arch/backbone/legendary_models/pp_hgnet_v2.py """ +from typing import Dict, Optional import torch import torch.nn as nn @@ -347,20 +348,25 @@ def forward(self, x): class ClassifierHead(nn.Module): def __init__( self, - num_features, - num_classes, - pool_type='avg', - drop_rate=0., - use_last_conv=True, - class_expand=2048, - use_lab=False + in_features: int, + num_classes: int, + pool_type: str = 'avg', + drop_rate: float = 0., + hidden_size: Optional[int] = 2048, + use_lab: bool = False ): super(ClassifierHead, self).__init__() - self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=False, input_fmt='NCHW') - 
if use_last_conv: + self.num_features = in_features + if pool_type is not None: + if not pool_type: + assert num_classes == 0, 'Classifier head must be removed if pooling is disabled' + + self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) + if hidden_size is not None: + self.num_features = hidden_size last_conv = nn.Conv2d( - num_features, - class_expand, + in_features, + hidden_size, kernel_size=1, stride=1, padding=0, @@ -373,15 +379,20 @@ def __init__( else: self.last_conv = nn.Sequential(last_conv, act) else: - self.last_conv = nn.Indentity() + self.last_conv = nn.Identity() - if drop_rate > 0: - self.dropout = nn.Dropout(drop_rate) - else: - self.dropout = nn.Identity() + self.dropout = nn.Dropout(drop_rate) + self.flatten = nn.Flatten(1) if pool_type else nn.Identity() # don't flatten if pooling disabled + self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def reset(self, num_classes: int, pool_type: Optional[str] = None): + if pool_type is not None: + if not pool_type: + assert num_classes == 0, 'Classifier head must be removed if pooling is disabled' + self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) + self.flatten = nn.Flatten(1) if pool_type else nn.Identity() # don't flatten if pooling disabled - self.flatten = nn.Flatten() - self.fc = nn.Linear(class_expand if use_last_conv else num_features, num_classes) + self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward(self, x, pre_logits: bool = False): x = self.global_pool(x) @@ -398,15 +409,14 @@ class HighPerfGpuNet(nn.Module): def __init__( self, - cfg, - in_chans=3, - num_classes=1000, - global_pool='avg', - use_last_conv=True, - class_expand=2048, - drop_rate=0., - drop_path_rate=0., - use_lab=False, + cfg: Dict, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + head_hidden_size: Optional[int] = 2048, + drop_rate: float = 0., + drop_path_rate: float = 0., + use_lab: bool = False, **kwargs, ): super(HighPerfGpuNet, self).__init__() @@ -415,8 +425,6 @@ def __init__( stages_cfg = [cfg["stage1"], cfg["stage2"], cfg["stage3"], cfg["stage4"]] self.num_classes = num_classes self.drop_rate = drop_rate - self.use_last_conv = use_last_conv - self.class_expand = class_expand self.use_lab = use_lab assert stem_type in ['v1', 'v2'] @@ -456,21 +464,15 @@ def __init__( self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) - if num_classes > 0: - self.head = ClassifierHead( - self.num_features, - num_classes=num_classes, - pool_type=global_pool, - drop_rate=drop_rate, - use_last_conv=use_last_conv, - class_expand=class_expand, - use_lab=use_lab - ) - else: - if global_pool == 'avg': - self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) - else: - self.head = nn.Identity() + self.head = ClassifierHead( + self.num_features, + num_classes=num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + hidden_size=head_hidden_size, + use_lab=use_lab + ) + self.head_hidden_size = self.head.num_features for n, m in self.named_modules(): if isinstance(m, nn.Conv2d): @@ -494,25 +496,12 @@ def set_grad_checkpointing(self, enable=True): s.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): 
self.num_classes = num_classes - if num_classes > 0: - self.head = ClassifierHead( - self.num_features, - num_classes=num_classes, - pool_type=global_pool, - drop_rate=self.drop_rate, - use_last_conv=self.use_last_conv, - class_expand=self.class_expand, - use_lab=self.use_lab) - else: - if global_pool: - self.head = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) - else: - self.head = nn.Identity() + self.head.reset(num_classes, global_pool) def forward_features(self, x): x = self.stem(x) diff --git a/timm/models/hiera.py b/timm/models/hiera.py index f229daf4f8..e06d354508 100644 --- a/timm/models/hiera.py +++ b/timm/models/hiera.py @@ -570,7 +570,7 @@ def __init__( dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')] self.blocks.append(block) - self.num_features = embed_dim + self.num_features = self.head_hidden_size = embed_dim self.head = NormClassifierHead( embed_dim, num_classes, diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index d00adfa1f5..82d887a92a 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -578,7 +578,7 @@ def __init__( head_conv_bias = cfg.pop('head_conv_bias', True) if head == 'classification': # Classification Head - self.num_features = 2048 + self.num_features = self.head_hidden_size = 2048 self.incre_modules, self.downsamp_modules, self.final_layer = self._make_head( pre_stage_channels, conv_bias=head_conv_bias, @@ -591,10 +591,10 @@ def __init__( ) else: if head == 'incre': - self.num_features = 2048 + self.num_features = self.head_hidden_size = 2048 self.incre_modules, _, _ = self._make_head(pre_stage_channels, incre_only=True) else: - self.num_features = 256 + self.num_features = self.head_hidden_size = 256 self.incre_modules = None self.global_pool = nn.Identity() self.head_drop = nn.Identity() @@ -736,7 +736,7 @@ def set_grad_checkpointing(self, enable=True): assert not enable, "gradient checkpointing not supported" @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.classifier def reset_classifier(self, num_classes, global_pool='avg'): diff --git a/timm/models/inception_next.py b/timm/models/inception_next.py index f5d37db981..ec5d49e987 100644 --- a/timm/models/inception_next.py +++ b/timm/models/inception_next.py @@ -4,6 +4,7 @@ """ from functools import partial +from typing import Optional import torch import torch.nn as nn @@ -14,6 +15,8 @@ from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs +__all__ = ['MetaNeXt'] + class InceptionDWConv2d(nn.Module): """ Inception depthwise convolution @@ -95,7 +98,7 @@ class MlpClassifierHead(nn.Module): def __init__( self, - dim, + in_features, num_classes=1000, pool_type='avg', mlp_ratio=3, @@ -105,23 +108,33 @@ def __init__( bias=True ): super().__init__() + self.use_conv = False + self.in_features = in_features + self.num_features = hidden_features = int(mlp_ratio * in_features) + + assert pool_type, 'Cannot disable pooling' self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True) - in_features = dim * self.global_pool.feat_mult() - hidden_features = int(mlp_ratio * in_features) - self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + + self.fc1 = nn.Linear(in_features * self.global_pool.feat_mult(), hidden_features, bias=bias) self.act = act_layer() self.norm = norm_layer(hidden_features) self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias) self.drop = nn.Dropout(drop) - def forward(self, x): + def 
reset(self, num_classes: int, pool_type: Optional[str] = None): + if pool_type is not None: + assert pool_type, 'Cannot disable pooling' + self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True) + + self.fc2 = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward(self, x, pre_logits: bool = False): x = self.global_pool(x) x = self.fc1(x) x = self.act(x) x = self.norm(x) x = self.drop(x) - x = self.fc2(x) - return x + return x if pre_logits else self.fc2(x) class MetaNeXtBlock(nn.Module): @@ -231,7 +244,6 @@ class MetaNeXt(nn.Module): norm_layer: Normalization layer. Default: nn.BatchNorm2d act_layer: Activation function for MLP. Default: nn.GELU mlp_ratios (int or tuple(int)): MLP ratios. Default: (4, 4, 4, 3) - head_fn: classifier head drop_rate (float): Head dropout rate drop_path_rate (float): Stochastic depth rate. Default: 0. ls_init_value (float): Init value for Layer Scale. Default: 1e-6. @@ -249,7 +261,6 @@ def __init__( norm_layer=nn.BatchNorm2d, act_layer=nn.GELU, mlp_ratios=(4, 4, 4, 3), - head_fn=MlpClassifierHead, drop_rate=0., drop_path_rate=0., ls_init_value=1e-6, @@ -301,15 +312,8 @@ def __init__( prev_chs = out_chs self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] self.num_features = prev_chs - if self.num_classes > 0: - if issubclass(head_fn, MlpClassifierHead): - assert self.global_pool, 'Cannot disable global pooling with MLP head present.' - self.head = head_fn(self.num_features, num_classes, pool_type=self.global_pool, drop=drop_rate) - else: - if self.global_pool: - self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True) - else: - self.head = nn.Identity() + self.head = MlpClassifierHead(self.num_features, num_classes, pool_type=self.global_pool, drop=drop_rate) + self.head_hidden_size = self.head.num_features self.apply(self._init_weights) def _init_weights(self, m): @@ -329,21 +333,12 @@ def group_matcher(self, coarse=False): ) @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc2 - def reset_classifier(self, num_classes=0, global_pool=None, head_fn=MlpClassifierHead): - if global_pool is not None: - self.global_pool = global_pool - if num_classes > 0: - if issubclass(head_fn, MlpClassifierHead): - assert self.global_pool, 'Cannot disable global pooling with MLP head present.' 
- self.head = head_fn(self.num_features, num_classes, pool_type=self.global_pool, drop=self.drop_rate) - else: - if self.global_pool: - self.head = SelectAdaptivePool2d(pool_type=self.global_pool, flatten=True) - else: - self.head = nn.Identity() + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): + self.num_classes = num_classes + self.head.reset(num_classes, global_pool) @torch.jit.ignore def set_grad_checkpointing(self, enable=True): @@ -360,11 +355,7 @@ def forward_features(self, x): return x def forward_head(self, x, pre_logits: bool = False): - if pre_logits: - if hasattr(self.head, 'global_pool'): - x = self.head.global_pool(x) - return x - return self.head(x) + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index f4efaf520d..7fdfee41ed 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -206,7 +206,7 @@ def __init__( ): super(InceptionResnetV2, self).__init__() self.num_classes = num_classes - self.num_features = 1536 + self.num_features = self.head_hidden_size = 1536 assert output_stride == 32 conv_block = partial( ConvNormAct, @@ -270,10 +270,10 @@ def set_grad_checkpointing(self, enable=True): assert not enable, "checkpointing not supported" @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.classif - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index 1c2a407ec0..8cb1a151df 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -4,6 +4,7 @@ Licensed BSD-Clause 3 https://github.com/pytorch/vision/blob/master/LICENSE """ from functools import partial +from typing import Optional import torch import torch.nn as nn @@ -293,7 +294,7 @@ def __init__( dict(num_chs=2048, reduction=32, module='Mixed_7c'), ] - self.num_features = 2048 + self.num_features = self.head_hidden_size = 2048 self.global_pool, self.head_drop, self.fc = create_classifier( self.num_features, self.num_classes, @@ -331,10 +332,10 @@ def set_grad_checkpointing(self, enable=True): assert not enable, 'gradient checkpointing not supported' @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) @@ -371,9 +372,11 @@ def forward_features(self, x): x = self.forward_postaux(x) return x - def forward_head(self, x): + def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) x = self.head_drop(x) + if pre_logits: + return x x = self.fc(x) return x diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index a43290a3db..6f75817844 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -231,7 +231,7 @@ def __init__( super(InceptionV4, self).__init__() assert output_stride == 32 self.num_classes = num_classes - self.num_features = 1536 + self.num_features = 
self.head_hidden_size = 1536 conv_block = partial( ConvNormAct, padding=0, @@ -277,7 +277,7 @@ def set_grad_checkpointing(self, enable=True): assert not enable, 'gradient checkpointing not supported' @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.last_linear def reset_classifier(self, num_classes, global_pool='avg'): diff --git a/timm/models/levit.py b/timm/models/levit.py index ccac445c8a..4e43006ab0 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -555,7 +555,7 @@ def __init__( self.use_conv = use_conv self.num_classes = num_classes self.global_pool = global_pool - self.num_features = embed_dim[-1] + self.num_features = self.head_hidden_size = embed_dim[-1] self.embed_dim = embed_dim self.drop_rate = drop_rate self.grad_checkpointing = False @@ -625,15 +625,15 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=None): + def reset_classifier(self, num_classes: int , global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool self.head = NormLinear( - self.embed_dim[-1], num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity() + self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity() def forward_intermediates( self, @@ -727,10 +727,10 @@ def __init__(self, *args, **kwargs): self.distilled_training = False # must set this True to train w/ distillation token @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head, self.head_dist - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=None): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 3dc08a5547..9c418510f5 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -1197,9 +1197,9 @@ def __init__( self.stages = nn.Sequential(*stages) final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps) - self.head_hidden_size = cfg.head_hidden_size - if self.head_hidden_size: + if cfg.head_hidden_size: self.norm = nn.Identity() + self.head_hidden_size = cfg.head_hidden_size self.head = NormMlpClassifierHead( self.num_features, num_classes, @@ -1210,6 +1210,7 @@ def __init__( ) else: # standard classifier head w/ norm, pooling, fc classifier + self.head_hidden_size = self.num_features self.norm = final_norm_layer(self.num_features) self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) @@ -1245,7 +1246,7 @@ def set_grad_checkpointing(self, enable=True): s.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/metaformer.py b/timm/models/metaformer.py index 7b026a2e43..7e3e758770 100644 --- a/timm/models/metaformer.py +++ b/timm/models/metaformer.py @@ -548,9 +548,12 @@ def __init__( # if using MlpHead, dropout is handled by MlpHead if num_classes > 0: if self.use_mlp_head: + # FIXME 
hidden size final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate) + self.head_hidden_size = self.num_features else: final = nn.Linear(self.num_features, num_classes) + self.head_hidden_size = self.num_features else: final = nn.Identity() @@ -577,7 +580,7 @@ def set_grad_checkpointing(self, enable=True): stage.set_grad_checkpointing(enable=enable) @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc def reset_classifier(self, num_classes=0, global_pool=None): diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index a1bf02be2e..087a924c9f 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -203,7 +203,7 @@ def __init__( super().__init__() self.num_classes = num_classes self.global_pool = global_pool - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models self.grad_checkpointing = False self.stem = PatchEmbed( @@ -252,7 +252,7 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index b25d87bac0..358f15d1c1 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -92,7 +92,6 @@ def __init__( norm_act_layer = get_norm_act_layer(norm_layer, act_layer) se_layer = se_layer or SqueezeExcite self.num_classes = num_classes - self.num_features = num_features self.drop_rate = drop_rate self.grad_checkpointing = False @@ -118,23 +117,24 @@ def __init__( self.blocks = nn.Sequential(*builder(stem_size, block_args)) self.feature_info = builder.features self.stage_ends = [f['stage'] for f in self.feature_info] - head_chs = builder.in_chs + self.num_features = builder.in_chs # features of last stage, output of forward_features() + self.head_hidden_size = num_features # features of conv_head, pre_logits output # Head + Pooling self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - num_pooled_chs = head_chs * self.global_pool.feat_mult() + num_pooled_chs = self.num_features * self.global_pool.feat_mult() if head_norm: # mobilenet-v4 post-pooling PW conv is followed by a norm+act layer - self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type) # never bias - self.norm_head = norm_act_layer(self.num_features) + self.conv_head = create_conv2d(num_pooled_chs, self.head_hidden_size, 1, padding=pad_type) # never bias + self.norm_head = norm_act_layer(self.head_hidden_size) self.act2 = nn.Identity() else: # mobilenet-v3 and others only have an activation after final PW conv - self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias) + self.conv_head = create_conv2d(num_pooled_chs, self.head_hidden_size, 1, padding=pad_type, bias=head_bias) self.norm_head = nn.Identity() self.act2 = act_layer(inplace=True) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled - self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity() efficientnet_init_weights(self) @@ -157,15 +157,15 @@ def set_grad_checkpointing(self, 
enable: bool = True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.classifier def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes - # cannot meaningfully change pooling of efficient head after creation + # NOTE: cannot meaningfully change pooling of efficient head after creation self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled - self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity() def forward_intermediates( self, @@ -262,10 +262,10 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tenso x = self.norm_head(x) x = self.act2(x) x = self.flatten(x) - if pre_logits: - return x if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) + if pre_logits: + return x return self.classifier(x) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index d6ba311c89..7735b63189 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -784,7 +784,7 @@ def __init__( feat_size = stage.feat_size self.stages.append(stage) - self.num_features = embed_dim + self.num_features = self.head_hidden_size = embed_dim self.norm = norm_layer(embed_dim) self.head = nn.Sequential(OrderedDict([ ('drop', nn.Dropout(self.drop_rate)), @@ -822,7 +822,7 @@ def set_grad_checkpointing(self, enable=True): s.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 954ee176b0..af072aa942 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -3,6 +3,7 @@ https://github.com/Cadene/pretrained-models.pytorch """ from functools import partial +from typing import Optional import torch import torch.nn as nn @@ -407,7 +408,7 @@ def __init__( super(NASNetALarge, self).__init__() self.num_classes = num_classes self.stem_size = stem_size - self.num_features = num_features + self.num_features = self.head_hidden_size = num_features self.channel_multiplier = channel_multiplier assert output_stride == 32 @@ -514,7 +515,7 @@ def set_grad_checkpointing(self, enable=True): assert not enable, 'gradient checkpointing not supported' @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.last_linear def reset_classifier(self, num_classes, global_pool='avg'): @@ -553,11 +554,10 @@ def forward_features(self, x): x = self.act(x_cell_17) return x - def forward_head(self, x): + def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) x = self.head_drop(x) - x = self.last_linear(x) - return x + return x if pre_logits else self.last_linear(x) def forward(self, x): x = self.forward_features(x) diff --git a/timm/models/nest.py b/timm/models/nest.py index d1901cee21..1d9c752105 100644 --- a/timm/models/nest.py +++ b/timm/models/nest.py @@ -309,7 +309,7 @@ def __init__( num_heads = to_ntuple(num_levels)(num_heads) depths = to_ntuple(num_levels)(depths) self.num_classes = num_classes - self.num_features = embed_dims[-1] + self.num_features = self.head_hidden_size = 
embed_dims[-1] self.feature_info = [] norm_layer = norm_layer or LayerNorm act_layer = act_layer or nn.GELU @@ -412,10 +412,10 @@ def set_grad_checkpointing(self, enable=True): l.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.head = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/nextvit.py b/timm/models/nextvit.py index 1a1d1b00bb..fb719d81de 100644 --- a/timm/models/nextvit.py +++ b/timm/models/nextvit.py @@ -515,7 +515,7 @@ def __init__( in_chs = out_chs = self.stage_out_chs[stage_idx][-1] stages += [stage] idx += depths[stage_idx] - self.num_features = out_chs + self.num_features = self.head_hidden_size = out_chs self.stages = nn.Sequential(*stages) self.norm = norm_layer(out_chs) self.head = ClassifierHead(pool_type=global_pool, in_features=out_chs, num_classes=num_classes) @@ -551,7 +551,7 @@ def set_grad_checkpointing(self, enable=True): stage.set_grad_checkpointing(enable=enable) @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 725b177c25..79cfde7ac0 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -393,6 +393,7 @@ def __init__( self.final_conv = nn.Identity() self.final_act = act_layer(inplace=cfg.num_features > 0) + self.head_hidden_size = self.num_features self.head = ClassifierHead( self.num_features, num_classes, @@ -429,10 +430,10 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.head.reset(num_classes, global_pool) def forward_features(self, x): diff --git a/timm/models/pit.py b/timm/models/pit.py index ce41b9fcbd..3a1090b89f 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -209,7 +209,7 @@ def __init__( self.transformers = SequentialTuple(*transformers) self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6) - self.num_features = self.embed_dim = embed_dim + self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # Classifier head self.head_drop = nn.Dropout(drop_rate) @@ -240,7 +240,7 @@ def set_distilled_training(self, enable=True): def set_grad_checkpointing(self, enable=True): assert not enable, 'gradient checkpointing not supported' - def get_classifier(self): + def get_classifier(self) -> nn.Module: if self.head_dist is not None: return self.head, self.head_dist else: @@ -248,6 +248,8 @@ def get_classifier(self): def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes + if global_pool is not None: + self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() if self.head_dist is not None: self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index bee18604c3..6d6d9dbd9b 100644 --- 
a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -239,7 +239,7 @@ def __init__( ): super(PNASNet5Large, self).__init__() self.num_classes = num_classes - self.num_features = 4320 + self.num_features = self.head_hidden_size = 4320 assert output_stride == 32 self.conv_0 = ConvNormAct( @@ -304,7 +304,7 @@ def set_grad_checkpointing(self, enable=True): assert not enable, 'gradient checkpointing not supported' @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.last_linear def reset_classifier(self, num_classes, global_pool='avg'): diff --git a/timm/models/pvt_v2.py b/timm/models/pvt_v2.py index 90ebfe7ad5..9200bbd451 100644 --- a/timm/models/pvt_v2.py +++ b/timm/models/pvt_v2.py @@ -338,7 +338,7 @@ def __init__( self.stages = nn.Sequential(*stages) # classification head - self.num_features = embed_dims[-1] + self.num_features = self.head_hidden_size = embed_dims[-1] self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() @@ -376,7 +376,7 @@ def set_grad_checkpointing(self, enable=True): for s in self.stages: s.grad_checkpointing = enable - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 12187378b7..4edb257d49 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -449,6 +449,7 @@ def __init__( final_act = cfg.linear_out or cfg.preact self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity() self.num_features = prev_width + self.head_hidden_size = self.num_features self.head = ClassifierHead( in_features=self.num_features, num_classes=num_classes, @@ -510,7 +511,7 @@ def set_grad_checkpointing(self, enable=True): s.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc def reset_classifier(self, num_classes, global_pool='avg'): diff --git a/timm/models/repghost.py b/timm/models/repghost.py index c252880503..4b802d79b6 100644 --- a/timm/models/repghost.py +++ b/timm/models/repghost.py @@ -6,6 +6,7 @@ """ import copy from functools import partial +from typing import Optional import torch import torch.nn as nn @@ -258,7 +259,8 @@ def __init__( self.blocks = nn.Sequential(*stages) # building last several layers - self.num_features = out_chs = 1280 + self.num_features = prev_chs + self.head_hidden_size = out_chs = 1280 self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True) self.act2 = nn.ReLU(inplace=True) @@ -281,15 +283,16 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.classifier - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes - # cannot meaningfully change pooling of efficient head after creation - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled - self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + if global_pool is not None: + # NOTE: cannot meaningfully change pooling of efficient head after 
creation + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled + self.classifier = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.conv_stem(x) @@ -301,15 +304,14 @@ def forward_features(self, x): x = self.blocks(x) return x - def forward_head(self, x): + def forward_head(self, x, pre_logits: bool = False): x = self.global_pool(x) x = self.conv_head(x) x = self.act2(x) x = self.flatten(x) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) - x = self.classifier(x) - return x + return x if pre_logits else self.classifier(x) def forward(self, x): x = self.forward_features(x) diff --git a/timm/models/repvit.py b/timm/models/repvit.py index 00ad78c891..7dcb2cd939 100644 --- a/timm/models/repvit.py +++ b/timm/models/repvit.py @@ -306,7 +306,7 @@ def __init__( in_dim = embed_dim[i] self.stages = nn.Sequential(*stages) - self.num_features = embed_dim[-1] + self.num_features = self.head_hidden_size = embed_dim[-1] self.head_drop = nn.Dropout(drop_rate) self.head = RepVitClassifier(embed_dim[-1], num_classes, distillation) @@ -320,16 +320,14 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head - def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=False): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation: bool = False): self.num_classes = num_classes if global_pool is not None: self.global_pool = global_pool - self.head = ( - RepVitClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity() - ) + self.head = RepVitClassifier(self.embed_dim[-1], num_classes, distillation) @torch.jit.ignore def set_distilled_training(self, enable=True): @@ -347,6 +345,8 @@ def forward_head(self, x, pre_logits: bool = False): if self.global_pool == 'avg': x = x.mean((2, 3), keepdim=False) x = self.head_drop(x) + if pre_logits: + return x return self.head(x) def forward(self, x): diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 15f169978c..6893bc5b8b 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -513,7 +513,7 @@ def __init__( self.feature_info.extend(stage_feature_info) # Head (Pooling and Classifier) - self.num_features = 512 * block.expansion + self.num_features = self.head_hidden_size = 512 * block.expansion self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) self.init_weights(zero_init_last=zero_init_last) @@ -541,7 +541,7 @@ def set_grad_checkpointing(self, enable: bool = True): def get_classifier(self, name_only: bool = False): return 'fc' if name_only else self.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 017c32964b..8aa2facfcb 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -31,6 +31,7 @@ from collections import OrderedDict # pylint: disable=g-importing-member from functools import partial +from typing import Optional import torch import 
torch.nn as nn @@ -415,7 +416,7 @@ def __init__( self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}')] self.stages.add_module(str(stage_idx), stage) - self.num_features = prev_chs + self.num_features = self.head_hidden_size = prev_chs self.norm = norm_layer(self.num_features) if preact else nn.Identity() self.head = ClassifierHead( self.num_features, @@ -452,10 +453,10 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes self.head.reset(num_classes, global_pool) diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index d34933f782..eeadeb337b 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -206,7 +206,7 @@ def __init__( dw_act_layer, drop_path_rate, ) - self.num_features = features[-1].out_channels + self.num_features = self.head_hidden_size = features[-1].out_channels self.features = nn.Sequential(*features) self.head = ClassifierHead(self.num_features, num_classes, global_pool, drop_rate) @@ -226,11 +226,12 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc def reset_classifier(self, num_classes, global_pool='avg'): - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + self.num_classes = num_classes + self.head.reset(num_classes, global_pool) def forward_features(self, x): x = self.stem(x) diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 847bfdb75e..7fa6c3e4aa 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -131,7 +131,7 @@ def __init__(self, cfg, num_classes=1000, in_chans=3, drop_rate=0.0, global_pool self.features = SequentialList(*[cfg['block'](*block_args) for block_args in cfg['features']]) self.from_seq = SelectSeq() # from List[tensor] -> Tensor in module compatible way self.head = nn.Sequential(*[conv_bn(*conv_args) for conv_args in cfg['head']]) - self.num_features = cfg['num_features'] + self.num_features = self.head_hidden_size = cfg['num_features'] self.feature_info = cfg['feature_info'] self.global_pool, self.head_drop, self.fc = create_classifier( @@ -158,7 +158,7 @@ def set_grad_checkpointing(self, enable=True): assert not enable, 'gradient checkpointing not supported' @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.fc def reset_classifier(self, num_classes, global_pool='avg'): diff --git a/timm/models/senet.py b/timm/models/senet.py index 8b203c372c..c04250fd60 100644 --- a/timm/models/senet.py +++ b/timm/models/senet.py @@ -299,7 +299,7 @@ def __init__( downsample_padding=downsample_padding ) self.feature_info += [dict(num_chs=512 * block.expansion, reduction=32, module='layer4')] - self.num_features = 512 * block.expansion + self.num_features = self.head_hidden_size = 512 * block.expansion self.global_pool, self.last_linear = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) @@ -334,7 +334,7 @@ def set_grad_checkpointing(self, enable=True): assert not enable, 'gradient checkpointing not supported' @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> 
nn.Module: return self.last_linear def reset_classifier(self, num_classes, global_pool='avg'): diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py index 7e12453b8f..86c4b1df4d 100644 --- a/timm/models/sequencer.py +++ b/timm/models/sequencer.py @@ -338,7 +338,7 @@ def __init__( assert global_pool in ('', 'avg') self.num_classes = num_classes self.global_pool = global_pool - self.num_features = embed_dims[-1] # num_features for consistency with other models + self.num_features = self.head_hidden_size = embed_dims[-1] # for consistency with other models self.feature_dim = -1 # channel dim index for feature outputs (rank 4, NHWC) self.output_fmt = 'NHWC' self.feature_info = [] @@ -416,7 +416,7 @@ def set_grad_checkpointing(self, enable=True): assert not enable, 'gradient checkpointing not supported' @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index d5369282ae..a5800937d7 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -503,7 +503,7 @@ def __init__( self.num_layers = len(depths) self.embed_dim = embed_dim - self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.num_features = self.head_hidden_size = int(embed_dim * 2 ** (self.num_layers - 1)) self.feature_info = [] if not isinstance(embed_dim, (tuple, list)): @@ -601,7 +601,7 @@ def set_grad_checkpointing(self, enable=True): l.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index 49d449c91e..7bf91032b9 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -511,7 +511,7 @@ def __init__( self.output_fmt = 'NHWC' self.num_layers = len(depths) self.embed_dim = embed_dim - self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.num_features = self.head_hidden_size = int(embed_dim * 2 ** (self.num_layers - 1)) self.feature_info = [] if not isinstance(embed_dim, (tuple, list)): @@ -602,7 +602,7 @@ def set_grad_checkpointing(self, enable=True): l.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index 4a33c8035d..d5fcbadcc9 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -602,7 +602,7 @@ def __init__( self.patch_size: int = patch_size self.img_size: Tuple[int, int] = img_size self.window_size: int = window_size - self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1)) + self.num_features = self.head_hidden_size = int(embed_dim * 2 ** (len(depths) - 1)) self.feature_info = [] self.patch_embed = PatchEmbed( diff --git a/timm/models/tiny_vit.py b/timm/models/tiny_vit.py index 85eee7e053..12a5ef2f16 100644 --- a/timm/models/tiny_vit.py +++ b/timm/models/tiny_vit.py @@ -486,7 +486,7 @@ def __init__( self.feature_info += [dict(num_chs=prev_dim, reduction=stride, module=f'stages.{stage_idx}')] # Classifier head - self.num_features = embed_dims[-1] + self.num_features 
= self.head_hidden_size = embed_dims[-1] norm_layer_cf = partial(LayerNorm2d, eps=1e-5) self.head = NormMlpClassifierHead( @@ -529,7 +529,7 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): @@ -544,8 +544,8 @@ def forward_features(self, x): x = self.stages(x) return x - def forward_head(self, x): - x = self.head(x) + def forward_head(self, x, pre_logits: bool = False): + x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) return x def forward(self, x): diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 00ab2ba729..9e37770ac7 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -216,7 +216,7 @@ def __init__( assert global_pool in ('', 'token', 'avg') self.num_classes = num_classes self.global_pool = global_pool - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models self.grad_checkpointing = False self.pixel_embed = PixelEmbed( @@ -296,13 +296,14 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: assert global_pool in ('', 'token', 'avg') + self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index c00a28e290..006b7e0b5f 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -179,7 +179,7 @@ def __init__( ] # head - self.num_features = (self.planes * 8) * Bottleneck.expansion + self.num_features = self.head_hidden_size = (self.planes * 8) * Bottleneck.expansion self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) # model initialization @@ -231,7 +231,7 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): @@ -252,7 +252,7 @@ def forward_features(self, x): return x def forward_head(self, x, pre_logits: bool = False): - return x if pre_logits else self.head(x) + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) diff --git a/timm/models/twins.py b/timm/models/twins.py index c7b9e15770..1aea273dc8 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -310,7 +310,7 @@ def __init__( self.global_pool = global_pool self.depths = depths self.embed_dims = embed_dims - self.num_features = embed_dims[-1] + self.num_features = self.head_hidden_size = embed_dims[-1] self.grad_checkpointing = False img_size = to_2tuple(img_size) @@ -379,7 +379,7 @@ def set_grad_checkpointing(self, enable=True): assert not enable, 'gradient checkpointing not supported' @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head def reset_classifier(self, num_classes: int, global_pool: 
Optional[str] = None): diff --git a/timm/models/vgg.py b/timm/models/vgg.py index 1ba12c9a92..4136b12c7c 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -5,7 +5,7 @@ Copyright 2021 Ross Wightman """ -from typing import Union, List, Dict, Any, cast +from typing import Any, Dict, List, Optional, Union, cast import torch import torch.nn as nn @@ -81,11 +81,11 @@ def __init__( super(VGG, self).__init__() assert output_stride == 32 self.num_classes = num_classes - self.num_features = 4096 self.drop_rate = drop_rate self.grad_checkpointing = False self.use_norm = norm_layer is not None self.feature_info = [] + prev_chs = in_chans net_stride = 1 pool_layer = nn.MaxPool2d @@ -107,9 +107,11 @@ def __init__( self.features = nn.Sequential(*layers) self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=f'features.{len(layers) - 1}')) + self.num_features = prev_chs + self.head_hidden_size = 4096 self.pre_logits = ConvMlp( prev_chs, - self.num_features, + self.head_hidden_size, 7, mlp_ratio=mlp_ratio, drop_rate=drop_rate, @@ -117,7 +119,7 @@ def __init__( conv_layer=conv_layer, ) self.head = ClassifierHead( - self.num_features, + self.head_hidden_size, num_classes, pool_type=global_pool, drop_rate=drop_rate, @@ -135,17 +137,12 @@ def set_grad_checkpointing(self, enable=True): assert not enable, 'gradient checkpointing not supported' @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes - self.head = ClassifierHead( - self.num_features, - self.num_classes, - pool_type=global_pool, - drop_rate=self.drop_rate, - ) + self.head.reset(num_classes, global_pool) def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.features(x) @@ -153,7 +150,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: def forward_head(self, x: torch.Tensor, pre_logits: bool = False): x = self.pre_logits(x) - return x if pre_logits else self.head(x) + return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.forward_features(x) diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 953fc64d5e..2ed3be5da8 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -338,7 +338,7 @@ def __init__( for i in range(self.stage_num1+self.stage_num2, depth) ]) - self.num_features = embed_dim if self.vit_stem else embed_dim * 2 + self.num_features = self.head_hidden_size = embed_dim if self.vit_stem else embed_dim * 2 self.norm = norm_layer(self.num_features) # head @@ -384,10 +384,10 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head - def reset_classifier(self, num_classes, global_pool='avg'): + def reset_classifier(self, num_classes: int, global_pool: str = 'avg'): self.num_classes = num_classes self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index e3f1b8f255..a3ca0990d8 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -468,7 +468,7 @@ def __init__( self.num_classes = num_classes self.global_pool = global_pool - self.num_features = 
self.embed_dim = embed_dim # num_features for consistency with other models + self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models self.num_prefix_tokens = 1 if class_token else 0 self.num_prefix_tokens += reg_tokens self.num_reg_tokens = reg_tokens diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index cd47700914..61003014c3 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -287,7 +287,7 @@ def __init__( self.num_classes = num_classes self.global_pool = global_pool - self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models self.num_prefix_tokens = 1 if class_token else 0 self.grad_checkpointing = False @@ -378,7 +378,7 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head def reset_classifier(self, num_classes: int, global_pool=None): diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py index fcab425250..2fd5209c85 100644 --- a/timm/models/vision_transformer_sam.py +++ b/timm/models/vision_transformer_sam.py @@ -396,8 +396,7 @@ def __init__( self.num_classes = num_classes self.global_pool = global_pool - # num_features for consistency with other models - self.num_features = self.embed_dim = embed_dim + self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models self.grad_checkpointing = False self.patch_embed = embed_layer( @@ -534,7 +533,7 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head def reset_classifier(self, num_classes=0, global_pool=None): diff --git a/timm/models/volo.py b/timm/models/volo.py index cefabd0e27..e1e4e0db47 100644 --- a/timm/models/volo.py +++ b/timm/models/volo.py @@ -491,7 +491,7 @@ def __init__( self.global_pool = global_pool self.mix_token = use_mix_token self.pooling_scale = pooling_scale - self.num_features = embed_dims[-1] + self.num_features = self.head_hidden_size = embed_dims[-1] if use_mix_token: # enable token mixing, see token labeling for details. 
self.beta = 1.0 assert global_pool == 'token', "return all tokens if mix_token is enabled" @@ -619,7 +619,7 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index 8e9d1679d2..5ca409d383 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -227,6 +227,7 @@ def __init__( self.stages = nn.Sequential(*stages) + self.head_hidden_size = self.num_features self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) for n, m in self.named_modules(): @@ -248,7 +249,7 @@ def set_grad_checkpointing(self, enable=True): s.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc def reset_classifier(self, num_classes, global_pool='avg'): diff --git a/timm/models/xception.py b/timm/models/xception.py index 041e2b9305..c023705194 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -105,7 +105,7 @@ def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg' self.drop_rate = drop_rate self.global_pool = global_pool self.num_classes = num_classes - self.num_features = 2048 + self.num_features = self.head_hidden_size = 2048 self.conv1 = nn.Conv2d(in_chans, 32, 3, 2, 0, bias=False) self.bn1 = nn.BatchNorm2d(32) @@ -171,7 +171,7 @@ def set_grad_checkpointing(self, enable=True): assert not enable, "gradient checkpointing not supported" @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.fc def reset_classifier(self, num_classes, global_pool='avg'): diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index 1656e72bf1..0eabdca213 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -251,6 +251,7 @@ def __init__( self.feature_info += [dict( num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))] self.act = act_layer(inplace=True) if preact else nn.Identity() + self.head_hidden_size = self.num_features self.head = ClassifierHead( in_features=self.num_features, num_classes=num_classes, @@ -270,7 +271,7 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head.fc def reset_classifier(self, num_classes, global_pool='avg'): diff --git a/timm/models/xcit.py b/timm/models/xcit.py index 0e6e118e51..a1e0f0c3c0 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -346,7 +346,7 @@ def __init__( act_layer = act_layer or nn.GELU self.num_classes = num_classes - self.num_features = self.embed_dim = embed_dim + self.num_features = self.head_hidden_size = self.embed_dim = embed_dim self.global_pool = global_pool self.grad_checkpointing = False @@ -429,10 +429,10 @@ def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore - def get_classifier(self): + def get_classifier(self) -> nn.Module: return self.head - def reset_classifier(self, num_classes, global_pool=''): + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes if global_pool is not None: assert global_pool in ('', 'avg', 'token')
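For reference, a minimal sketch of the num_features / head_hidden_size / pre_logits contract the hunks above converge on, assuming a timm build with this patch applied. The model name 'mobilenetv3_large_100' is only illustrative; any affected architecture should behave the same way, with num_features equal to head_hidden_size except where an extra conv/linear layer sits after pooling (e.g. the MobileNetV3, RepGhostNet, and VGG hunks above).

import torch
import timm

# Sketch under the assumption that this patch is applied; the model choice is illustrative.
model = timm.create_model('mobilenetv3_large_100', pretrained=False).eval()

x = torch.randn(2, 3, 224, 224)
feats = model.forward_features(x)                 # unpooled features, channels == model.num_features
pre = model.forward_head(feats, pre_logits=True)  # pooled (+ conv_head) features, width == model.head_hidden_size
logits = model.forward_head(feats)                # classifier output, width == model.num_classes

assert feats.shape[1] == model.num_features       # last-stage channels (builder.in_chs in the MobileNetV3 hunk)
assert pre.shape[1] == model.head_hidden_size     # conv_head width (the `num_features` ctor arg)
assert logits.shape[1] == model.num_classes
assert isinstance(model.get_classifier(), torch.nn.Module)  # accessor is now annotated -> nn.Module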
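A second hedged sketch covering the reset_classifier() behaviour in these hunks: heads with a ClassifierHead delegate to head.reset(), and where global_pool is now Optional, passing None keeps the existing pooling. 'resnet50' is only a stand-in for any affected model.

import torch
import timm

# Hedged example, assuming this patch; 'resnet50' is illustrative only.
model = timm.create_model('resnet50', pretrained=False).eval()

model.reset_classifier(0)                          # drop the classifier, keep pooling
feats = model(torch.randn(1, 3, 224, 224))
assert feats.shape == (1, model.head_hidden_size)  # pooled pre-logits features

model.reset_classifier(10)                         # attach a fresh 10-class head
logits = model(torch.randn(1, 3, 224, 224))
assert logits.shape == (1, 10)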