Merge pull request #2195 from huggingface/refactor_pre_logits
Fix consistency and add testing for forward_head w/ pre_logits and reset_classifier, covering models whose pre_logits size != unpooled feature size
rwightman authored Jun 7, 2024
2 parents 4535a54 + 5ee0676 commit 5517b05
Showing 73 changed files with 405 additions and 396 deletions.
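For context, a minimal sketch of the head API contract this PR tightens across architectures: forward_features returns unpooled features whose channel dim matches num_features, forward_head(..., pre_logits=True) returns pooled pre-classifier features whose dim matches head_hidden_size, and reset_classifier(0) makes forward() return those same pre_logits features. The model name and input size below are arbitrary illustrations, not taken from the diff:

```python
import timm
import torch

# Illustrative only: any timm classification model is expected to follow this
# contract after the PR; 'resnet50' and the 224x224 input are arbitrary choices.
model = timm.create_model('resnet50', pretrained=False).eval()
x = torch.randn(2, 3, 224, 224)

feats = model.forward_features(x)                  # unpooled features, channels == model.num_features
pre = model.forward_head(feats, pre_logits=True)   # pooled pre-classifier features, dim == model.head_hidden_size
logits = model.forward_head(feats)                 # classifier output, dim == model.num_classes

model.reset_classifier(0)   # drop the classifier; forward() now yields pre_logits-sized features
assert model(x).shape == pre.shape
```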
57 changes: 39 additions & 18 deletions tests/test_models.py
@@ -169,6 +169,18 @@ def test_model_backward(model_name, batch_size):
assert not torch.isnan(outputs).any(), 'Output included NaNs'


# models with extra conv/linear layers after pooling
EARLY_POOL_MODELS = (
timm.models.EfficientVit,
timm.models.EfficientVitLarge,
timm.models.HighPerfGpuNet,
timm.models.GhostNet,
timm.models.MetaNeXt, # InceptionNeXt
timm.models.MobileNetV3,
timm.models.RepGhostNet,
timm.models.VGG,
)

@pytest.mark.cfg
@pytest.mark.timeout(timeout=300)
@pytest.mark.parametrize('model_name', list_models(
@@ -179,6 +191,9 @@ def test_model_default_cfgs(model_name, batch_size):
model = create_model(model_name, pretrained=False)
model.eval()
model.to(torch_device)
assert getattr(model, 'num_classes') >= 0
assert getattr(model, 'num_features') > 0
assert getattr(model, 'head_hidden_size') > 0
state_dict = model.state_dict()
cfg = model.default_cfg

@@ -195,37 +210,37 @@ def test_model_default_cfgs(model_name, batch_size):
input_size = tuple([min(x, MAX_FWD_OUT_SIZE) for x in input_size])
input_tensor = torch.randn((batch_size, *input_size), device=torch_device)

# test forward_features (always unpooled)
# test forward_features (always unpooled) & forward_head w/ pre_logits
outputs = model.forward_features(input_tensor)
assert outputs.shape[spatial_axis[0]] == pool_size[0], 'unpooled feature shape != config'
assert outputs.shape[spatial_axis[1]] == pool_size[1], 'unpooled feature shape != config'
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
assert outputs.shape[feat_axis] == model.num_features
outputs_pre = model.forward_head(outputs, pre_logits=True)
assert outputs.shape[spatial_axis[0]] == pool_size[0], f'unpooled feature shape {outputs.shape} != config'
assert outputs.shape[spatial_axis[1]] == pool_size[1], f'unpooled feature shape {outputs.shape} != config'
assert outputs.shape[feat_axis] == model.num_features, f'unpooled feature dim {outputs.shape[feat_axis]} != model.num_features {model.num_features}'
assert outputs_pre.shape[1] == model.head_hidden_size, f'pre_logits feature dim {outputs_pre.shape[1]} != model.head_hidden_size {model.head_hidden_size}'

# test forward after deleting the classifier, output should be pooled, size(-1) == model.num_features
model.reset_classifier(0)
model.to(torch_device)
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 2
assert outputs.shape[1] == model.num_features
assert outputs.shape[1] == model.head_hidden_size, f'feature dim w/ removed classifier {outputs.shape[1]} != model.head_hidden_size {model.head_hidden_size}'
assert outputs.shape == outputs_pre.shape, f'output shape of pre_logits {outputs_pre.shape} does not match reset_head(0) {outputs.shape}'

# test model forward without pooling and classifier
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
model.to(torch_device)
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
# mobilenetv3/ghostnet/repghostnet/vgg forward_features vs removed pooling differ due to location or lack of GAP
# test model forward after removing pooling and classifier
if not isinstance(model, EARLY_POOL_MODELS):
model.reset_classifier(0, '') # reset classifier and disable global pooling
model.to(torch_device)
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4
assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]

if 'pruned' not in model_name: # FIXME better pruned model handling
# test classifier + global pool deletion via __init__
# test classifier + global pool deletion via __init__
if 'pruned' not in model_name and not isinstance(model, EARLY_POOL_MODELS):
model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval()
model.to(torch_device)
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]
assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]

# check classifier name matches default_cfg
if cfg.get('num_classes', None):
@@ -253,6 +268,9 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
model = create_model(model_name, pretrained=False)
model.eval()
model.to(torch_device)
assert getattr(model, 'num_classes') >= 0
assert getattr(model, 'num_features') > 0
assert getattr(model, 'head_hidden_size') > 0
state_dict = model.state_dict()
cfg = model.default_cfg

@@ -264,13 +282,15 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
feat_dim = getattr(model, 'feature_dim', None)

outputs = model.forward_features(input_tensor)
outputs_pre = model.forward_head(outputs, pre_logits=True)
if isinstance(outputs, (tuple, list)):
# cannot currently verify multi-tensor output.
pass
else:
if feat_dim is None:
feat_dim = -1 if outputs.ndim == 3 else 1
assert outputs.shape[feat_dim] == model.num_features
assert outputs_pre.shape[1] == model.head_hidden_size

# test forward after deleting the classifier, output should be pooled, size(-1) == model.num_features
model.reset_classifier(0)
@@ -280,7 +300,8 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
outputs = outputs[0]
if feat_dim is None:
feat_dim = -1 if outputs.ndim == 3 else 1
assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config'
assert outputs.shape[feat_dim] == model.head_hidden_size, 'pooled num_features != config'
assert outputs.shape == outputs_pre.shape

model = create_model(model_name, pretrained=False, num_classes=0).eval()
model.to(torch_device)
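The EARLY_POOL_MODELS tuple above collects architectures with an extra conv/linear stage between pooling and the classifier — exactly the case where head_hidden_size diverges from num_features. A hedged illustration; the model choice and the 960/1280 widths are assumed from MobileNetV3-Large's default config, not asserted by the diff:

```python
import timm
import torch

# MobileNetV3 pools, then applies a conv_head, then classifies, so its
# pre_logits width is larger than the unpooled backbone width.
model = timm.create_model('mobilenetv3_large_100', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

feats = model.forward_features(x)                  # e.g. (1, 960, 7, 7); channels == model.num_features (assumed 960)
pre = model.forward_head(feats, pre_logits=True)   # e.g. (1, 1280); dim == model.head_hidden_size (assumed 1280)

assert feats.shape[1] == model.num_features
assert pre.shape[1] == model.head_hidden_size
```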
3 changes: 3 additions & 0 deletions timm/models/_prune.py
@@ -101,7 +101,10 @@ def adapt_model_from_string(parent_module, model_string):
in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
set_layer(new_module, n, new_fc)
if hasattr(new_module, 'num_features'):
if getattr(new_module, 'head_hidden_size', 0) == new_module.num_features:
new_module.head_hidden_size = num_features
new_module.num_features = num_features

new_module.eval()
parent_module.eval()

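The _prune.py change above only rewrites head_hidden_size when it mirrored num_features, so heads with a genuine hidden layer keep their own width after pruning. A standalone sketch of that logic; the helper name is hypothetical:

```python
import torch.nn as nn

def sync_feature_dims(module: nn.Module, num_features: int) -> None:
    # Hypothetical helper mirroring the adapt_model_from_string() edit above.
    if hasattr(module, 'num_features'):
        if getattr(module, 'head_hidden_size', 0) == module.num_features:
            module.head_hidden_size = num_features  # head tracked the backbone width
        module.num_features = num_features
```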
4 changes: 2 additions & 2 deletions timm/models/beit.py
@@ -291,7 +291,7 @@ def __init__(
super().__init__()
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
self.num_prefix_tokens = 1
self.grad_checkpointing = False

@@ -392,7 +392,7 @@ def group_matcher(self, coarse=False):
return matcher

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
6 changes: 4 additions & 2 deletions timm/models/byobnet.py
@@ -1224,6 +1224,7 @@ def __init__(
dict(num_chs=self.num_features, reduction=reduction, module='final_conv', stage=len(self.stages))]
self.stage_ends = [f['stage'] for f in self.feature_info]

self.head_hidden_size = self.num_features
self.head = ClassifierHead(
self.num_features,
num_classes,
@@ -1250,10 +1251,11 @@ def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head.fc

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

def forward_intermediates(
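byobnet.py is the first of several files below (cspnet.py, davit.py, and others) that converge on the same reset_classifier shape: take num_classes plus an optional pool type and delegate to the head's reset() rather than rebuilding the head. A condensed sketch of the shared method body, as it appears in these diffs:

```python
from typing import Optional

# Shared pattern from the per-model diffs; the concrete head class varies,
# but each exposes a reset(num_classes, pool_type) method.
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
    self.num_classes = num_classes
    self.head.reset(num_classes, global_pool)
```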
4 changes: 2 additions & 2 deletions timm/models/cait.py
@@ -239,7 +239,7 @@ def __init__(

self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim
self.num_features = self.head_hidden_size = self.embed_dim = embed_dim
self.grad_checkpointing = False

self.patch_embed = patch_layer(
@@ -328,7 +328,7 @@ def _matcher(name):
return _matcher

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
6 changes: 3 additions & 3 deletions timm/models/coat.py
@@ -7,7 +7,7 @@
Modified from timm/models/vision_transformer.py
"""
from typing import List, Optional, Union, Tuple
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
@@ -380,7 +380,7 @@ def __init__(
self.return_interm_layers = return_interm_layers
self.out_features = out_features
self.embed_dims = embed_dims
self.num_features = embed_dims[-1]
self.num_features = self.head_hidden_size = embed_dims[-1]
self.num_classes = num_classes
self.global_pool = global_pool

@@ -556,7 +556,7 @@ def group_matcher(self, coarse=False):
return matcher

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
4 changes: 2 additions & 2 deletions timm/models/convit.py
@@ -269,7 +269,7 @@ def __init__(
self.num_classes = num_classes
self.global_pool = global_pool
self.local_up_to_layer = local_up_to_layer
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
self.locality_strength = locality_strength
self.use_pos_embed = use_pos_embed

@@ -345,7 +345,7 @@ def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
4 changes: 2 additions & 2 deletions timm/models/convmixer.py
@@ -40,7 +40,7 @@ def __init__(
):
super().__init__()
self.num_classes = num_classes
self.num_features = dim
self.num_features = self.head_hidden_size = dim
self.grad_checkpointing = False

self.stem = nn.Sequential(
@@ -74,7 +74,7 @@ def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
8 changes: 5 additions & 3 deletions timm/models/convnext.py
@@ -358,7 +358,7 @@ def __init__(
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
self.num_features = prev_chs
self.num_features = self.head_hidden_size = prev_chs

# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
# otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
@@ -382,6 +382,7 @@ def __init__(
norm_layer=norm_layer,
act_layer='gelu',
)
self.head_hidden_size = self.head.num_features
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)

@torch.jit.ignore
@@ -401,10 +402,11 @@ def set_grad_checkpointing(self, enable=True):
s.grad_checkpointing = enable

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head.fc

def reset_classifier(self, num_classes=0, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

def forward_intermediates(
11 changes: 6 additions & 5 deletions timm/models/crossvit.py
@@ -330,7 +330,7 @@ def __init__(
num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
self.num_branches = len(patch_size)
self.embed_dim = embed_dim
self.num_features = sum(embed_dim)
self.num_features = self.head_hidden_size = sum(embed_dim)
self.patch_embed = nn.ModuleList()

# hard-coded for torch jit script
@@ -415,17 +415,18 @@ def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('token', 'avg')
self.global_pool = global_pool
self.head = nn.ModuleList(
[nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in
range(self.num_branches)])
self.head = nn.ModuleList([
nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
for i in range(self.num_branches)
])

def forward_features(self, x) -> List[torch.Tensor]:
B = x.shape[0]
15 changes: 10 additions & 5 deletions timm/models/cspnet.py
@@ -675,9 +675,13 @@ def __init__(
self.feature_info.extend(stage_feat_info)

# Construct the head
self.num_features = prev_chs
self.num_features = self.head_hidden_size = prev_chs
self.head = ClassifierHead(
in_features=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
in_features=prev_chs,
num_classes=num_classes,
pool_type=global_pool,
drop_rate=drop_rate,
)

named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

@@ -698,11 +702,12 @@ def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head.fc

def reset_classifier(self, num_classes, global_pool='avg'):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

def forward_features(self, x):
x = self.stem(x)
4 changes: 2 additions & 2 deletions timm/models/davit.py
@@ -485,7 +485,7 @@ def __init__(
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
self.num_classes = num_classes
self.num_features = embed_dims[-1]
self.num_features = self.head_hidden_size = embed_dims[-1]
self.drop_rate = drop_rate
self.grad_checkpointing = False
self.feature_info = []
@@ -565,7 +565,7 @@ def set_grad_checkpointing(self, enable=True):
stage.set_grad_checkpointing(enable=enable)

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head.fc

def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
2 changes: 1 addition & 1 deletion timm/models/deit.py
@@ -60,7 +60,7 @@ def group_matcher(self, coarse=False):
)

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head, self.head_dist

def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):