From cdc7bcea696be62377420e5f500307f20483e1c3 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 11 Jun 2024 21:32:07 -0700 Subject: [PATCH] Make 2d attention pool modules compatible with head interface. Use attention pool in CLIP ResNets as head. Make separate set of GAP models w/ avg pool instead of attn pool. --- timm/layers/attention_pool2d.py | 64 +++++++- timm/models/byobnet.py | 249 +++++++++++++++++++------------- 2 files changed, 209 insertions(+), 104 deletions(-) diff --git a/timm/layers/attention_pool2d.py b/timm/layers/attention_pool2d.py index 443a384ca8..21328fd367 100644 --- a/timm/layers/attention_pool2d.py +++ b/timm/layers/attention_pool2d.py @@ -41,9 +41,12 @@ def __init__( num_heads: Optional[int] = None, qkv_bias: bool = True, qkv_separate: bool = False, - drop: float = 0., + pool_type: str = 'token', + avg_token: bool = True, + drop_rate: float = 0., ): super().__init__() + assert pool_type in ('', 'token') self.embed_dim = embed_dim = embed_dim or in_features self.in_features = in_features self.out_features = out_features or in_features @@ -56,6 +59,7 @@ def __init__( num_heads = embed_dim // head_dim self.num_heads = num_heads self.head_dim = head_dim + self.pool_type = pool_type.lower() self.scale = self.head_dim ** -0.5 self.fused_attn = use_fused_attn() @@ -66,6 +70,7 @@ def __init__( self.qkv = None else: self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) + self.drop = nn.Dropout(drop_rate) self.proj = nn.Linear(embed_dim, self.out_features) self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size) @@ -83,6 +88,23 @@ def init_weights(self, zero_init_last: bool = False): trunc_normal_(self.qkv.weight, std=in_features ** -0.5) nn.init.zeros_(self.qkv.bias) + def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None): + # NOTE: this module is being used as a head, so need compatible reset() + if pool_type is not None: + assert pool_type in ('', 'token') + self.pool_type = pool_type + if num_classes is not None: + self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity() + self.out_features = num_classes if num_classes > 0 else self.embed_dim + + def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: + if self.pool_type == 'token': + x = x[:, 0] + else: + # if not pooled, return spatial output without token + x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2) + return x + def forward(self, x, pre_logits: bool = False): B, _, H, W = x.shape N = H * W @@ -111,8 +133,10 @@ def forward(self, x, pre_logits: bool = False): x = x[:, 0] x = self.drop(x) if pre_logits: + x = self._pool(x, H, W) return x x = self.proj(x) + x = self._pool(x, H, W) return x @@ -137,9 +161,12 @@ def __init__( num_heads: Optional[int] = None, qkv_bias: bool = True, qkv_separate: bool = False, - drop: float = 0., + pool_type: str = 'token', + learned_token: bool = False, + drop_rate: float = 0., ): super().__init__() + assert pool_type in ('', 'token') self.embed_dim = embed_dim = embed_dim or in_features self.in_features = in_features self.out_features = out_features or in_features @@ -153,9 +180,15 @@ def __init__( self.seq_len = self.feat_size[0] * self.feat_size[1] self.num_heads = num_heads self.head_dim = head_dim + self.pool_type = pool_type self.scale = self.head_dim ** -0.5 self.fused_attn = use_fused_attn() + if learned_token: + self.token = nn.Parameter(torch.zeros(1, embed_dim)) + else: + self.token = None + if qkv_separate: self.q = 
nn.Linear(in_features, embed_dim, bias=qkv_bias) self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias) @@ -164,7 +197,7 @@ def __init__( else: self.q = self.k = self.v = None self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) - self.drop = nn.Dropout(drop) + self.drop = nn.Dropout(drop_rate) self.proj = nn.Linear(embed_dim, self.out_features) self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features)) @@ -185,11 +218,31 @@ def init_weights(self, zero_init_last: bool = False): nn.init.zeros_(self.qkv.bias) trunc_normal_(self.pos_embed, std=in_features ** -0.5) + def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None): + # NOTE: this module is being used as a head, so need compatible reset() + if pool_type is not None: + assert pool_type in ('', 'token') + self.pool_type = pool_type + if num_classes is not None: + self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity() + self.out_features = num_classes if num_classes > 0 else self.embed_dim + + def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: + if self.pool_type == 'token': + x = x[:, 0] + else: + # if not pooled, return spatial output without token + x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2) + return x + def forward(self, x, pre_logits: bool = False): B, _, H, W = x.shape N = H * W x = x.flatten(2).transpose(1, 2) - x = torch.cat([x.mean(1, keepdim=True), x], dim=1) + if self.token is not None: + x = torch.cat([self.token.expand(x.shape[0], -1, -1), x], dim=1) + else: + x = torch.cat([x.mean(1, keepdim=True), x], dim=1) pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1) x = x + pos_embed @@ -209,9 +262,10 @@ def forward(self, x, pre_logits: bool = False): attn = attn.softmax(dim=-1) x = attn @ v x = x.transpose(1, 2).reshape(B, N + 1, -1) - x = x[:, 0] x = self.drop(x) if pre_logits: + x = self._pool(x, H, W) return x x = self.proj(x) + x = self._pool(x, H, W) return x diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index f84f24556d..d3bf9c0d29 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -37,8 +37,11 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ - create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a +from timm.layers import ( + ClassifierHead, NormMlpClassifierHead, ConvNormAct, BatchNormAct2d, EvoNorm2dS0a, + AttentionPool2d, RotAttentionPool2d, DropPath, AvgPool2dSame, + create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, +) from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import named_apply, checkpoint_seq @@ -83,7 +86,7 @@ class ByoModelCfg: # Head config head_hidden_size: Optional[int] = None # feat dim of MLP head or AttentionPool output - head_type: str = '' + head_type: str = 'classifier' # Block config # NOTE: these config items will be overridden by the block cfg (per-block) if they are set there @@ -1186,7 +1189,7 @@ def __init__( cfg: ByoModelCfg, num_classes: int = 1000, in_chans: int = 3, - global_pool: str = 'avg', + global_pool: Optional[str] = None, output_stride: int = 32, img_size: Optional[Union[int, Tuple[int, int]]] = None, drop_rate: float = 0., @@ -1257,76 +1260,59 @@ def __init__( self.stage_ends = 
[f['stage'] for f in self.feature_info] self.head_hidden_size = self.num_features - self.global_pool = global_pool - assert cfg.head_type in ('', 'classifier', 'norm_mlp_classifier') - if cfg.head_type == 'norm_mlp_classifier': - from timm.layers import NormMlpClassifierHead - assert not cfg.attn_pool, "Cannot use attentional pooling with norm + MLP head" - self.attn_pool = nn.Identity() + assert cfg.head_type in ('', 'classifier', 'mlp', 'attn_abs', 'attn_rot') + if cfg.head_type == 'mlp': + if global_pool is None: + global_pool = 'avg' self.head = NormMlpClassifierHead( self.num_features, num_classes, hidden_size=cfg.head_hidden_size, + pool_type=global_pool, norm_layer=cfg.norm_layer, act_layer=cfg.act_layer, + drop_rate=self.drop_rate, ) self.head_hidden_size = self.head.hidden_size + elif cfg.head_type == 'attn_abs': + if global_pool is None: + global_pool = 'token' + assert global_pool in ('', 'token') + self.head = AttentionPool2d( + self.num_features, + embed_dim=cfg.head_hidden_size, + out_features=num_classes, + feat_size=feat_size, + pool_type=global_pool, + drop_rate=self.drop_rate, + qkv_separate=True, + ) + self.head_hidden_size = self.head.embed_dim + elif cfg.head_type =='attn_rot': + if global_pool is None: + global_pool = 'token' + assert global_pool in ('', 'token') + self.head = RotAttentionPool2d( + self.num_features, + embed_dim=cfg.head_hidden_size, + out_features=num_classes, + ref_feat_size=feat_size, + pool_type=global_pool, + drop_rate=self.drop_rate, + qkv_separate=True, + ) + self.head_hidden_size = self.head.embed_dim else: - # FIXME evaluating different head vs pool configurations - if False: - if global_pool == 'attn_abs': - from timm.layers import AttentionPool2d - self.attn_pool = AttentionPool2d( - self.num_features, - out_features=cfg.head_hidden_size, - feat_size=feat_size, - qkv_separate=True, - ) - global_pool = '' # clear for ClassifierHead - self.head_hidden_size = self.attn_pool.out_features - elif global_pool =='attn_rot': - from timm.layers import RotAttentionPool2d - self.attn_pool = RotAttentionPool2d( - self.num_features, - out_features=cfg.head_hidden_size, - ref_feat_size=feat_size, - ) - global_pool = '' # clear for ClassifierHead - self.head_hidden_size = self.attn_pool.out_features - else: - self.attn_pool = nn.Identity() - - self.head = ClassifierHead( - self.head_hidden_size, - num_classes, - pool_type=global_pool, - drop_rate=self.drop_rate, - ) - else: - if global_pool == 'attn_abs': - from timm.layers import AttentionPool2d - self.head = AttentionPool2d( - self.num_features, - out_features=num_classes, - feat_size=feat_size, - qkv_separate=True, - ) - self.head_hidden_size = self.head.embed_dim - elif global_pool == 'attn_rot': - from timm.layers import RotAttentionPool2d - self.head = RotAttentionPool2d( - self.num_features, - out_features=num_classes, - ref_feat_size=feat_size, - ) - self.head_hidden_size = self.head.embed_dim - else: - self.head = ClassifierHead( - self.num_features, - num_classes, - pool_type=global_pool, - drop_rate=self.drop_rate, - ) + if global_pool is None: + global_pool = 'avg' + assert cfg.head_hidden_size is None + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + ) + self.global_pool = global_pool # init weights named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) @@ -1352,9 +1338,6 @@ def get_classifier(self) -> nn.Module: def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): 
self.num_classes = num_classes - if global_pool is not None: - if self.global_pool in ('attn_abs', 'attn_rot'): - raise RuntimeError('Cannot change attention pool on head reset.') self.head.reset(num_classes, global_pool) def forward_intermediates( @@ -1444,7 +1427,6 @@ def forward_features(self, x): return x def forward_head(self, x, pre_logits: bool = False): - #x = self.attn_pool(x) return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) def forward(self, x): @@ -1947,7 +1929,7 @@ def _init_weights(module, name='', zero_init_last=False): stem_pool='avg2', downsample='avg', aa_layer='avg', - head_hidden_size=1024, + head_type='attn_abs', ), resnet101_clip=ByoModelCfg( @@ -1962,7 +1944,8 @@ def _init_weights(module, name='', zero_init_last=False): stem_pool='avg2', downsample='avg', aa_layer='avg', - head_hidden_size=512, + head_type='attn_abs', + #head_hidden_size=512, ), resnet50x4_clip=ByoModelCfg( @@ -1978,7 +1961,8 @@ def _init_weights(module, name='', zero_init_last=False): stem_pool='avg2', downsample='avg', aa_layer='avg', - head_hidden_size=640, + head_type='attn_abs', + #head_hidden_size=640, ), resnet50x16_clip=ByoModelCfg( @@ -1994,7 +1978,8 @@ def _init_weights(module, name='', zero_init_last=False): stem_pool='avg2', downsample='avg', aa_layer='avg', - head_hidden_size=768, + head_type='attn_abs', + #head_hidden_size=768, ), resnet50x64_clip=ByoModelCfg( @@ -2010,10 +1995,11 @@ def _init_weights(module, name='', zero_init_last=False): stem_pool='avg2', downsample='avg', aa_layer='avg', - head_hidden_size=1024, + head_type='attn_abs', + #head_hidden_size=1024, ), - resnet50_nmlp=ByoModelCfg( + resnet50_mlp=ByoModelCfg( blocks=( ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25), ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25), @@ -2026,9 +2012,11 @@ def _init_weights(module, name='', zero_init_last=False): downsample='avg', aa_layer='avg', head_hidden_size=1024, - head_type='norm_mlp_classifier', + head_type='mlp', ), ) +for k in ('resnet50_clip', 'resnet101_clip', 'resnet50x4_clip', 'resnet50x16_clip', 'resnet50x64_clip'): + model_cfgs[k + '_gap'] = replace(model_cfgs[k], head_type='classifier') def _convert_openai_clip( @@ -2036,6 +2024,7 @@ def _convert_openai_clip( model: ByobNet, prefix: str = 'visual.', ) -> Dict[str, torch.Tensor]: + model_has_attn_pool = isinstance(model.head, (RotAttentionPool2d, AttentionPool2d)) import re def _stage_sub(m): @@ -2060,6 +2049,8 @@ def _down_sub(m): k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.([a-z]+)([0-9])', _stage_sub, k) k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.downsample\.([0-9])', _down_sub, k) if k.startswith(f'{prefix}attnpool'): + if not model_has_attn_pool: + continue k = k.replace(prefix + 'attnpool', 'head') #'attn_pool') k = k.replace('positional_embedding', 'pos_embed') k = k.replace('q_proj', 'q') @@ -2081,17 +2072,11 @@ def checkpoint_filter_fn( def _create_byobnet(variant, pretrained=False, **kwargs): - strict = True - if 'clip' in variant and kwargs.get('global_pool', None) != 'attn_abs': - # NOTE: a hack to allow removing attention pool from CLIP ResNet variants - strict = False - return build_model_with_cfg( ByobNet, variant, pretrained, model_cfg=model_cfgs[variant], pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True), - pretrained_strict=strict, **kwargs, ) @@ -2287,42 +2272,78 @@ def _cfgr(url='', **kwargs): first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'), ), + # original attention pool head variants 'resnet50_clip.openai': _cfgr( 
hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', - num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + num_classes=1024, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7), classifier = 'head.proj', ), 'resnet101_clip.openai': _cfgr( hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', - num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + num_classes=512, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7), classifier='head.proj', ), 'resnet50x4_clip.openai': _cfgr( hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', - num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + num_classes=640, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, fixed_input_size=True, input_size=(3, 288, 288), pool_size=(9, 9), classifier = 'head.proj', ), 'resnet50x16_clip.openai': _cfgr( hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', - num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + num_classes=768, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, fixed_input_size=True, input_size=(3, 384, 384), pool_size=(12, 12), classifier = 'head.proj', ), 'resnet50x64_clip.openai': _cfgr( hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', - num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + num_classes=1024, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, fixed_input_size=True, input_size=(3, 448, 448), pool_size=(14, 14), classifier = 'head.proj', ), + # avg-pool w/ optional standard classifier head variants + 'resnet50_clip_gap.openai': _cfgr( + hf_hub_id='timm/resnet50_clip.openai', + hf_hub_filename='open_clip_pytorch_model.bin', + num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 224, 224), pool_size=(7, 7), + ), + 'resnet101_clip_gap.openai': _cfgr( + hf_hub_id='timm/resnet101_clip.openai', + hf_hub_filename='open_clip_pytorch_model.bin', + num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 224, 224), pool_size=(7, 7), + ), + 'resnet50x4_clip_gap.openai': _cfgr( + hf_hub_id='timm/resnet50x4_clip.openai', + hf_hub_filename='open_clip_pytorch_model.bin', + num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 288, 288), pool_size=(9, 9), + ), + 'resnet50x16_clip_gap.openai': _cfgr( + hf_hub_id='timm/resnet50x16_clip.openai', + hf_hub_filename='open_clip_pytorch_model.bin', + num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 384, 384), pool_size=(12, 12), + ), + 'resnet50x64_clip_gap.openai': _cfgr( + hf_hub_id='timm/resnet50x64_clip.openai', + hf_hub_filename='open_clip_pytorch_model.bin', + num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 448, 448), pool_size=(14, 14), + ), + + 'resnet50_mlp.untrained': _cfgr( + input_size=(3, 256, 256), pool_size=(8, 8), + ), }) @@ -2631,44 +2652,74 @@ def mobileone_s4(pretrained=False, **kwargs) -> ByobNet: def resnet50_clip(pretrained=False, **kwargs) -> ByobNet: """ OpenAI Modified ResNet-50 CLIP image tower """ - model_args = dict(global_pool='attn_abs') - return _create_byobnet('resnet50_clip', pretrained=pretrained, **dict(model_args, **kwargs)) + return _create_byobnet('resnet50_clip', pretrained=pretrained, **kwargs) @register_model def resnet101_clip(pretrained=False, **kwargs) -> ByobNet: """ OpenAI Modified ResNet-101 CLIP image tower """ - model_args = dict(global_pool='attn_abs') - return _create_byobnet('resnet101_clip', 
pretrained=pretrained, **dict(model_args, **kwargs)) + return _create_byobnet('resnet101_clip', pretrained=pretrained, **kwargs) @register_model def resnet50x4_clip(pretrained=False, **kwargs) -> ByobNet: """ OpenAI Modified ResNet-50x4 CLIP image tower """ - model_args = dict(global_pool='attn_abs') - return _create_byobnet('resnet50x4_clip', pretrained=pretrained, **dict(model_args, **kwargs)) + return _create_byobnet('resnet50x4_clip', pretrained=pretrained, **kwargs) @register_model def resnet50x16_clip(pretrained=False, **kwargs) -> ByobNet: """ OpenAI Modified ResNet-50x16 CLIP image tower """ - model_args = dict(global_pool='attn_abs') - return _create_byobnet('resnet50x16_clip', pretrained=pretrained, **dict(model_args, **kwargs)) + return _create_byobnet('resnet50x16_clip', pretrained=pretrained, **kwargs) @register_model def resnet50x64_clip(pretrained=False, **kwargs) -> ByobNet: """ OpenAI Modified ResNet-50x64 CLIP image tower """ - model_args = dict(global_pool='attn_abs') - return _create_byobnet('resnet50x64_clip', pretrained=pretrained, **dict(model_args, **kwargs)) + return _create_byobnet('resnet50x64_clip', pretrained=pretrained, **kwargs) + + +@register_model +def resnet50_clip_gap(pretrained=False, **kwargs) -> ByobNet: + """ OpenAI Modified ResNet-50 CLIP image tower w/ avg pool (no attention pool) + """ + return _create_byobnet('resnet50_clip_gap', pretrained=pretrained, **kwargs) + + +@register_model +def resnet101_clip_gap(pretrained=False, **kwargs) -> ByobNet: + """ OpenAI Modified ResNet-101 CLIP image tower w/ avg pool (no attention pool) + """ + return _create_byobnet('resnet101_clip_gap', pretrained=pretrained, **kwargs) + + +@register_model +def resnet50x4_clip_gap(pretrained=False, **kwargs) -> ByobNet: + """ OpenAI Modified ResNet-50x4 CLIP image tower w/ avg pool (no attention pool) + """ + return _create_byobnet('resnet50x4_clip_gap', pretrained=pretrained, **kwargs) + + +@register_model +def resnet50x16_clip_gap(pretrained=False, **kwargs) -> ByobNet: + """ OpenAI Modified ResNet-50x16 CLIP image tower w/ avg pool (no attention pool) + """ + return _create_byobnet('resnet50x16_clip_gap', pretrained=pretrained, **kwargs) + + +@register_model +def resnet50x64_clip_gap(pretrained=False, **kwargs) -> ByobNet: + """ OpenAI Modified ResNet-50x64 CLIP image tower w/ avg pool (no attention pool) + """ + return _create_byobnet('resnet50x64_clip_gap', pretrained=pretrained, **kwargs) @register_model -def resnet50_nmlp(pretrained=False, **kwargs) -> ByobNet: +def resnet50_mlp(pretrained=False, **kwargs) -> ByobNet: """ """ - return _create_byobnet('resnet50_nmlp', pretrained=pretrained, **kwargs) + return _create_byobnet('resnet50_mlp', pretrained=pretrained, **kwargs)
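
A rough sketch of the head interface the updated pooling modules expose after this change. Argument names (`pool_type`, `drop_rate`, `qkv_separate`) and the `reset()` / `pre_logits` behaviour are taken from the diff above; constructor defaults not shown in the diff (e.g. `head_dim=64`, so 32 heads at 2048 channels) are assumed from the existing module:

import torch
from timm.layers import AttentionPool2d

# Attention pooling over a 7x7 CLIP ResNet-50 feature map (2048 channels),
# projecting to a 1024-dim embedding, as the byobnet 'attn_abs' head does.
pool = AttentionPool2d(
    2048,                 # input channels from the final ResNet stage
    feat_size=7,          # spatial size the learned pos embed is built for
    out_features=1024,    # CLIP-style projection dim
    pool_type='token',    # '' would return the unpooled spatial map instead
    qkv_separate=True,    # separate q/k/v projections, matching OpenAI checkpoints
)

feats = torch.randn(2, 2048, 7, 7)
emb = pool(feats)                    # (2, 1024) pooled + projected embedding
pre = pool(feats, pre_logits=True)   # (2, 2048) pooled features before the proj

# Head-style reset: only the final projection is swapped out.
pool.reset(10)
logits = pool(feats)                 # (2, 10)

# No classifier + '' pool type returns the spatial tokens as NCHW.
pool.reset(0, pool_type='')
spatial = pool(feats)                # (2, 2048, 7, 7)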
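
At the model level, the CLIP ResNets now carry the attention pool as their head, so the projected CLIP embedding comes straight out of `forward()` and `reset_classifier()` works through the pool. Shapes below assume the 224x224 `resnet50_clip` config; `num_classes=1024` mirrors the pretrained cfg (`classifier='head.proj'`):

import timm
import torch

x = torch.randn(1, 3, 224, 224)

# Attention-pool head: forward() returns the 1024-dim CLIP image embedding.
model = timm.create_model('resnet50_clip', pretrained=False, num_classes=1024)
emb = model(x)                                                           # (1, 1024)
feats = model.forward_head(model.forward_features(x), pre_logits=True)   # (1, 2048)

# reset_classifier() now resets the attention-pool head in place:
# the pooling weights stay, only the projection changes.
model.reset_classifier(10)
logits = model(x)                                                        # (1, 10)

# The new *_gap variants drop the attention pool in favour of a standard
# average-pooling ClassifierHead; attnpool checkpoint weights are skipped on load.
gap = timm.create_model('resnet50_clip_gap', pretrained=False, num_classes=0)
pooled = gap(x)                                                          # (1, 2048)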
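
Since `head_type` is now a `ByoModelCfg` field ('classifier', 'mlp', 'attn_abs', 'attn_rot'), new head variants can be derived by cloning an existing config, the same way the `_gap` configs are built with `dataclasses.replace` above. The snippet below is a hypothetical, unregistered example of a rotary-pool ResNet-50, not a variant added by this patch:

import torch
from dataclasses import replace
from timm.models.byobnet import ByobNet, model_cfgs

# Reuse the resnet50_clip trunk but swap the absolute-pos attention pool
# for the rotary-embedding one ('attn_rot').
cfg = replace(model_cfgs['resnet50_clip'], head_type='attn_rot')
model = ByobNet(cfg, num_classes=1024, img_size=224)  # img_size sets ref_feat_size for the pool
out = model(torch.randn(1, 3, 224, 224))              # (1, 1024)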