From 5efa15b2a2476a02c6f0ead515494e7b76cd13cd Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 9 Jun 2024 16:54:48 -0700 Subject: [PATCH 1/6] Mapping OpenAI CLIP Modified ResNet weights -> ByobNet. Improve AttentionPool2d layers. Fix #1731 --- timm/layers/attention_pool.py | 6 +- timm/layers/attention_pool2d.py | 194 ++++++++++++------ timm/layers/classifier.py | 2 - timm/layers/pos_embed_sincos.py | 1 - timm/models/byobnet.py | 351 +++++++++++++++++++++++++++++--- 5 files changed, 461 insertions(+), 93 deletions(-) diff --git a/timm/layers/attention_pool.py b/timm/layers/attention_pool.py index 41e404d27c..da5585b363 100644 --- a/timm/layers/attention_pool.py +++ b/timm/layers/attention_pool.py @@ -20,6 +20,7 @@ def __init__( out_features: int = None, embed_dim: int = None, num_heads: int = 8, + feat_size: Optional[int] = None, mlp_ratio: float = 4.0, qkv_bias: bool = True, qk_norm: bool = False, @@ -36,13 +37,14 @@ def __init__( assert embed_dim % num_heads == 0 self.num_heads = num_heads self.head_dim = embed_dim // num_heads + self.feat_size = feat_size self.scale = self.head_dim ** -0.5 self.pool = pool_type self.fused_attn = use_fused_attn() if pos_embed == 'abs': - spatial_len = self.feat_size - self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features)) + assert feat_size is not None + self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features)) else: self.pos_embed = None diff --git a/timm/layers/attention_pool2d.py b/timm/layers/attention_pool2d.py index 765efa083d..dc594b70dd 100644 --- a/timm/layers/attention_pool2d.py +++ b/timm/layers/attention_pool2d.py @@ -7,12 +7,14 @@ Hacked together by / Copyright 2021 Ross Wightman """ -from typing import Union, Tuple +from typing import Optional, Union, Tuple import torch import torch.nn as nn +from .config import use_fused_attn from .helpers import to_2tuple +from .pos_embed import resample_abs_pos_embed from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding from .weight_init import trunc_normal_ @@ -27,51 +29,84 @@ class RotAttentionPool2d(nn.Module): NOTE: While this impl does not require a fixed feature size, performance at differing resolutions from train varies widely and falls off dramatically. I'm not sure if there is a way around this...
-RW """ + fused_attn: torch.jit.Final[bool] + def __init__( self, in_features: int, - out_features: int = None, - embed_dim: int = None, - num_heads: int = 4, + out_features: Optional[int] = None, + ref_feat_size: Union[int, Tuple[int, int]] = 7, + embed_dim: Optional[int] = None, + head_dim: Optional[int] = 64, + num_heads: Optional[int] = None, qkv_bias: bool = True, + qkv_separate: bool = False, ): super().__init__() embed_dim = embed_dim or in_features - out_features = out_features or in_features - self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) - self.proj = nn.Linear(embed_dim, out_features) + self.in_features = in_features + self.out_features = out_features or in_features + ref_feat_size = to_2tuple(ref_feat_size) + if num_heads is not None: + assert embed_dim % num_heads == 0 + head_dim = embed_dim // num_heads + else: + assert embed_dim % head_dim == 0 + num_heads = embed_dim // head_dim self.num_heads = num_heads - assert embed_dim % num_heads == 0 - self.head_dim = embed_dim // num_heads + self.head_dim = head_dim self.scale = self.head_dim ** -0.5 - self.pos_embed = RotaryEmbedding(self.head_dim) - - trunc_normal_(self.qkv.weight, std=in_features ** -0.5) - nn.init.zeros_(self.qkv.bias) + self.fused_attn = use_fused_attn() + + if qkv_separate: + self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias) + self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias) + self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias) + self.qkv = None + else: + self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) + self.proj = nn.Linear(embed_dim, self.out_features) + self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size) + + def init_weights(self, zero_init_last: bool = False): + if self.qkv is None: + in_features = self.q.in_features + trunc_normal_(self.q.weight, std=in_features ** -0.5) + nn.init.zeros_(self.q.bias) + trunc_normal_(self.k.weight, std=in_features ** -0.5) + nn.init.zeros_(self.k.bias) + trunc_normal_(self.v.weight, std=in_features ** -0.5) + nn.init.zeros_(self.v.bias) + else: + in_features = self.qkv.in_features + trunc_normal_(self.qkv.weight, std=in_features ** -0.5) + nn.init.zeros_(self.qkv.bias) def forward(self, x): B, _, H, W = x.shape N = H * W - x = x.reshape(B, -1, N).permute(0, 2, 1) - + x = x.flatten(2).transpose(1, 2) x = torch.cat([x.mean(1, keepdim=True), x], dim=1) - - x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) - q, k, v = x[0], x[1], x[2] - - qc, q = q[:, :, :1], q[:, :, 1:] - sin_emb, cos_emb = self.pos_embed.get_embed((H, W)) - q = apply_rot_embed(q, sin_emb, cos_emb) - q = torch.cat([qc, q], dim=2) - - kc, k = k[:, :, :1], k[:, :, 1:] - k = apply_rot_embed(k, sin_emb, cos_emb) - k = torch.cat([kc, k], dim=2) - - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - - x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1) + if self.qkv is None: + q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2) + k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2) + v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2) + else: + x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = x.unbind(0) + + rse, rce = self.pos_embed.get_embed((H, W)) + q = torch.cat([q[:, :, :1, :], apply_rot_embed(q[:, :, 1:, :], rse, rce)], dim=2).type_as(v) + k = torch.cat([k[:, :, :1, :], apply_rot_embed(k[:, :, 1:, :], rse, 
rce)], dim=2).type_as(v) + + if self.fused_attn: + x = nn.functional.scaled_dot_product_attention(q, k, v) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + x = attn @ v + x = x.transpose(1, 2).reshape(B, N + 1, -1) x = self.proj(x) return x[:, 0] @@ -85,47 +120,90 @@ class AttentionPool2d(nn.Module): NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network. """ + fused_attn: torch.jit.Final[bool] + def __init__( self, in_features: int, - feat_size: Union[int, Tuple[int, int]], - out_features: int = None, - embed_dim: int = None, - num_heads: int = 4, + feat_size: Union[int, Tuple[int, int]] = 7, + out_features: Optional[int] = None, + embed_dim: Optional[int] = None, + head_dim: Optional[int] = 64, + num_heads: Optional[int] = None, qkv_bias: bool = True, + qkv_separate: bool = False, ): super().__init__() - embed_dim = embed_dim or in_features - out_features = out_features or in_features - assert embed_dim % num_heads == 0 + self.in_features = in_features + self.out_features = out_features or in_features + if num_heads is not None: + assert embed_dim % num_heads == 0 + head_dim = embed_dim // num_heads + else: + assert embed_dim % head_dim == 0 + num_heads = embed_dim // head_dim self.feat_size = to_2tuple(feat_size) - self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) - self.proj = nn.Linear(embed_dim, out_features) + self.seq_len = self.feat_size[0] * self.feat_size[1] self.num_heads = num_heads - self.head_dim = embed_dim // num_heads + self.head_dim = head_dim self.scale = self.head_dim ** -0.5 - - spatial_dim = self.feat_size[0] * self.feat_size[1] - self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features)) + self.fused_attn = use_fused_attn() + + if qkv_separate: + self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias) + self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias) + self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias) + self.qkv = None + else: + self.q = self.k = self.v = None + self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) + self.proj = nn.Linear(embed_dim, self.out_features) + self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features)) + + self.init_weights() + + def init_weights(self, zero_init_last: bool = False): + if self.qkv is None: + in_features = self.q.in_features + trunc_normal_(self.q.weight, std=in_features ** -0.5) + nn.init.zeros_(self.q.bias) + trunc_normal_(self.k.weight, std=in_features ** -0.5) + nn.init.zeros_(self.k.bias) + trunc_normal_(self.v.weight, std=in_features ** -0.5) + nn.init.zeros_(self.v.bias) + else: + in_features = self.qkv.in_features + trunc_normal_(self.qkv.weight, std=in_features ** -0.5) + nn.init.zeros_(self.qkv.bias) trunc_normal_(self.pos_embed, std=in_features ** -0.5) - trunc_normal_(self.qkv.weight, std=in_features ** -0.5) - nn.init.zeros_(self.qkv.bias) def forward(self, x): B, _, H, W = x.shape N = H * W - assert self.feat_size[0] == H - assert self.feat_size[1] == W - x = x.reshape(B, -1, N).permute(0, 2, 1) + x = x.flatten(2).transpose(1, 2) x = torch.cat([x.mean(1, keepdim=True), x], dim=1) - x = x + self.pos_embed.unsqueeze(0).to(x.dtype) - - x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) - q, k, v = x[0], x[1], x[2] - attn = (q @ k.transpose(-2, -1)) * self.scale - attn = attn.softmax(dim=-1) - - x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1) + if self.seq_len != N: + pos_embed = 
resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1) + else: + pos_embed = self.pos_embed.unsqueeze(0).to(x.dtype) + x = x + pos_embed + + if self.qkv is None: + q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2) + k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2) + v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2) + else: + x = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = x.unbind(0) + + if self.fused_attn: + x = nn.functional.scaled_dot_product_attention(q, k, v) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + x = attn @ v + x = x.transpose(1, 2).reshape(B, N + 1, -1) x = self.proj(x) return x[:, 0] diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py index 27ee5e703a..2441c050e7 100644 --- a/timm/layers/classifier.py +++ b/timm/layers/classifier.py @@ -24,8 +24,6 @@ def _create_pool( ): flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling if not pool_type: - assert num_classes == 0 or use_conv,\ - 'Pooling can only be disabled if classifier is also removed or conv classifier is used' flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling) global_pool = SelectAdaptivePool2d( pool_type=pool_type, diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index b5f8502f37..5bb31af59b 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -312,7 +312,6 @@ def __init__( temperature=temperature, step=1, ) - print(bands) self.register_buffer( 'bands', bands, diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 02e258361d..b9417dfe81 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -36,8 +36,8 @@ import torch import torch.nn as nn -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD +from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, BatchNormAct2d, DropPath, AvgPool2dSame, \ create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -70,15 +70,23 @@ class ByoModelCfg: downsample: str = 'conv1x1' stem_type: str = '3x3' stem_pool: Optional[str] = 'maxpool' - stem_chs: int = 32 + stem_chs: Union[int, List[int], Tuple[int, ...]] = 32 width_factor: float = 1.0 num_features: int = 0 # num out_channels for final conv, no final 1x1 conv if 0 zero_init_last: bool = True # zero init last weight (usually bn) in residual path fixed_input_size: bool = False # model constrained to a fixed-input size / img_size must be provided on creation + # layer config act_layer: str = 'relu' norm_layer: str = 'batchnorm' + aa_layer: str = '' + # Head config + attn_pool: str = '' + head_hidden_size: Optional[int] = None # feat dim of MLP head or AttentionPool output + head_type: str = '' + + # Block config # NOTE: these config items will be overridden by the block cfg (per-block) if they are set there attn_layer: Optional[str] = None attn_kwargs: dict = field(default_factory=lambda: dict()) @@ -296,10 +304,7 @@ def __init__( mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio) groups 
= num_groups(group_size, mid_chs) - self.shortcut = create_shortcut( - downsample, in_chs, out_chs, - stride=stride, dilation=dilation, apply_act=False, layers=layers, - ) + self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) self.conv2_kxk = layers.conv_norm_act( @@ -316,7 +321,10 @@ def __init__( self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() self.act = nn.Identity() if linear_out else layers.act(inplace=True) - + self.shortcut = create_shortcut( + downsample, in_chs, out_chs, + stride=stride, dilation=dilation, apply_act=False, layers=layers, + ) def init_weights(self, zero_init_last: bool = False): if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None: nn.init.zeros_(self.conv3_1x1.bn.weight) @@ -917,7 +925,7 @@ class Stem(nn.Sequential): def __init__( self, in_chs: int, - out_chs: int, + out_chs: Union[int, List[int], Tuple[int, ...]], kernel_size: int = 3, stride: int = 4, pool: str = 'maxpool', @@ -961,10 +969,19 @@ def __init__( curr_stride *= s prev_feat = conv_name - if pool and 'max' in pool.lower(): + if pool: + pool = pool.lower() + assert pool in ('max', 'maxpool', 'avg', 'avgpool', 'max2', 'avg2') last_feat_idx = i self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0)) - self.add_module('pool', nn.MaxPool2d(3, 2, 1)) + if pool == 'max2': + self.add_module('pool', nn.MaxPool2d(2)) + elif pool == 'avg2': + self.add_module('pool', nn.AvgPool2d(2)) + elif 'max' in pool: + self.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) + elif 'avg' in pool: + self.add_module('pool', nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False)) curr_stride *= 2 prev_feat = 'pool' @@ -1012,11 +1029,14 @@ def create_byob_stem( else: stem = layers.conv_norm_act(in_chs, out_chs, 7, stride=2) else: - # 3x3 stem conv as in RegNet is the default - if pool_type: - stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers) + if isinstance(out_chs, (tuple, list)): + stem = Stem(in_chs, out_chs, 3, pool=pool_type, layers=layers) else: - stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2) + # 3x3 stem conv as in RegNet is the default + if pool_type: + stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers) + else: + stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2) if isinstance(stem, Stem): feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info] @@ -1138,13 +1158,16 @@ def create_byob_stages( prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}', stage=stage_idx + 1) feature_info.append(prev_feat) - return nn.Sequential(*stages), feature_info + return nn.Sequential(*stages), feature_info, feat_size -def get_layer_fns(cfg: ByoModelCfg): +def get_layer_fns(cfg: ByoModelCfg, allow_aa: bool = True): act = get_act_layer(cfg.act_layer) norm_act = get_norm_act_layer(norm_layer=cfg.norm_layer, act_layer=act) - conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act) + if cfg.aa_layer and allow_aa: + conv_norm_act = partial(ConvNormActAa, norm_layer=cfg.norm_layer, act_layer=act, aa_layer=cfg.aa_layer) + else: + conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act) attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None self_attn = 
partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn) @@ -1191,23 +1214,33 @@ def __init__( self.grad_checkpointing = False cfg = replace(cfg, **kwargs) # overlay kwargs onto cfg - layers = get_layer_fns(cfg) + stem_layers = get_layer_fns(cfg, allow_aa=False) # keep aa off for stem-layers + stage_layers = get_layer_fns(cfg) if cfg.fixed_input_size: assert img_size is not None, 'img_size argument is required for fixed input size model' feat_size = to_2tuple(img_size) if img_size is not None else None self.feature_info = [] - stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor)) - self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers) + if isinstance(cfg.stem_chs, (list, tuple)): + stem_chs = [int(round(c * cfg.width_factor)) for c in cfg.stem_chs] + else: + stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor)) + self.stem, stem_feat = create_byob_stem( + in_chs=in_chans, + out_chs=stem_chs, + stem_type=cfg.stem_type, + pool_type=cfg.stem_pool, + layers=stem_layers, + ) self.feature_info.extend(stem_feat[:-1]) feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction']) - self.stages, stage_feat = create_byob_stages( + self.stages, stage_feat, feat_size = create_byob_stages( cfg, drop_path_rate, output_stride, stem_feat[-1], - layers=layers, + layers=stage_layers, feat_size=feat_size, ) self.feature_info.extend(stage_feat[:-1]) @@ -1216,7 +1249,7 @@ def __init__( prev_chs = stage_feat[-1]['num_chs'] if cfg.num_features: self.num_features = int(round(cfg.width_factor * cfg.num_features)) - self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1) + self.final_conv = stage_layers.conv_norm_act(prev_chs, self.num_features, 1) else: self.num_features = prev_chs self.final_conv = nn.Identity() @@ -1225,12 +1258,47 @@ def __init__( self.stage_ends = [f['stage'] for f in self.feature_info] self.head_hidden_size = self.num_features - self.head = ClassifierHead( - self.num_features, - num_classes, - pool_type=global_pool, - drop_rate=self.drop_rate, - ) + assert cfg.head_type in ('', 'classifier', 'norm_mlp_classifier') + if cfg.head_type == 'norm_mlp_classifier': + from timm.layers import NormMlpClassifierHead + assert not cfg.attn_pool, "Cannot use attentional pooling with norm + MLP head" + self.attn_pool = nn.Identity() + self.head = NormMlpClassifierHead( + self.num_features, + num_classes, + hidden_size=cfg.head_hidden_size, + norm_layer=cfg.norm_layer, + act_layer=cfg.act_layer, + ) + self.head_hidden_size = self.head.hidden_size + else: + if cfg.attn_pool == 'abs': + from timm.layers import AttentionPool2d + self.attn_pool = AttentionPool2d( + self.num_features, + out_features=cfg.head_hidden_size, + feat_size=feat_size, + qkv_separate=True, + ) + self.head_hidden_size = self.attn_pool.out_features + elif cfg.attn_pool == 'rot': + from timm.layers import RotAttentionPool2d + self.attn_pool = RotAttentionPool2d( + self.num_features, + out_features=cfg.head_hidden_size, + ref_feat_size=feat_size, + ) + self.head_hidden_size = self.attn_pool.out_features + else: + assert not cfg.attn_pool + self.attn_pool = nn.Identity() + + self.head = ClassifierHead( + self.head_hidden_size, + num_classes, + pool_type='' if cfg.attn_pool else global_pool, + drop_rate=self.drop_rate, + ) # init weights named_apply(partial(_init_weights, 
zero_init_last=zero_init_last), self) @@ -1345,6 +1413,7 @@ def forward_features(self, x): return x def forward_head(self, x, pre_logits: bool = False): + x = self.attn_pool(x) return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) def forward(self, x): @@ -1834,14 +1903,162 @@ def _init_weights(module, name='', zero_init_last=False): stem_type='one', stem_chs=64, ), + + resnet50_clip=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25), + ), + stem_chs=(32, 32, 64), + stem_type='', + stem_pool='avg2', + downsample='avg', + aa_layer='avg', + attn_pool='abs', + head_hidden_size=1024, + ), + + resnet101_clip=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=23, c=1024, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25), + ), + stem_chs=(32, 32, 64), + stem_type='', + stem_pool='avg2', + downsample='avg', + aa_layer='avg', + attn_pool='abs', + head_hidden_size=512, + ), + + resnet50x4_clip=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=4, c=256, s=1, br=0.25), + ByoBlockCfg(type='bottle', d=6, c=512, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=10, c=1024, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=6, c=2048, s=2, br=0.25), + ), + width_factor=1.25, + stem_chs=(32, 32, 64), + stem_type='', + stem_pool='avg2', + downsample='avg', + aa_layer='avg', + attn_pool='abs', + head_hidden_size=640, + ), + + resnet50x16_clip=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=6, c=256, s=1, br=0.25), + ByoBlockCfg(type='bottle', d=8, c=512, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=18, c=1024, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=8, c=2048, s=2, br=0.25), + ), + stem_chs=(32, 32, 64), + stem_type='', + stem_pool='avg2', + downsample='avg', + aa_layer='avg', + attn_pool='abs', + head_hidden_size=768, + ), + + resnet50x64_clip=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25), + ByoBlockCfg(type='bottle', d=15, c=512, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=36, c=1024, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=10, c=2048, s=2, br=0.25), + ), + stem_chs=(32, 32, 64), + stem_type='', + stem_pool='avg2', + downsample='avg', + aa_layer='avg', + attn_pool='abs', + head_hidden_size=1024, + ), + + resnet50_nmlp=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25), + ), + stem_chs=(32, 32, 64), + stem_type='', + stem_pool='avg2', + downsample='avg', + aa_layer='avg', + head_hidden_size=1024, + head_type='norm_mlp_classifier', + ), ) +def _convert_openai_clip( + state_dict: Dict[str, torch.Tensor], + model: ByobNet, + prefix: str = 'visual.', +) -> Dict[str, torch.Tensor]: + import re + + def _stage_sub(m): + stage_idx = int(m.group(1)) - 1 + layer_idx, layer_type, layer_id = int(m.group(2)), m.group(3), int(m.group(4)) + prefix_str = f'stages.{stage_idx}.{layer_idx}.' 
+ id_map = {1: 'conv1_1x1.', 2: 'conv2_kxk.', 3: 'conv3_1x1.'} + suffix_str = id_map[layer_id] + layer_type + return prefix_str + suffix_str + + def _down_sub(m): + stage_idx = int(m.group(1)) - 1 + layer_idx, layer_id = int(m.group(2)), int(m.group(3)) + return f'stages.{stage_idx}.{layer_idx}.shortcut.' + ('conv.conv' if layer_id == 0 else 'conv.bn') + + out_dict = {} + for k, v in state_dict.items(): + if not k.startswith(prefix): + continue + k = re.sub(rf'{prefix}conv([0-9])', r'stem.conv\1.conv', k) + k = re.sub(rf'{prefix}bn([0-9])', r'stem.conv\1.bn', k) + k = re.sub(rf'{prefix}layer([0-9])\.([0-9])\.([a-z]+)([0-9])', _stage_sub, k) + k = re.sub(rf'{prefix}layer([0-9])\.([0-9])\.downsample\.([0-9])', _down_sub, k) + if k.startswith(f'{prefix}attnpool'): + k = k.replace(prefix + 'attnpool', 'attn_pool') + k = k.replace('positional_embedding', 'pos_embed') + k = k.replace('q_proj', 'q') + k = k.replace('k_proj', 'k') + k = k.replace('v_proj', 'v') + k = k.replace('c_proj', 'proj') + out_dict[k] = v + + return out_dict + + +def checkpoint_filter_fn( + state_dict: Dict[str, torch.Tensor], + model: ByobNet +): + if 'visual.conv1.weight' in state_dict: + state_dict = _convert_openai_clip(state_dict, model) + return state_dict + + def _create_byobnet(variant, pretrained=False, **kwargs): return build_model_with_cfg( ByobNet, variant, pretrained, model_cfg=model_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True), + #pretrained_strict=False, **kwargs) @@ -2035,6 +2252,38 @@ def _cfgr(url='', **kwargs): crop_pct=0.9, first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'), ), + + 'resnet50_clip.openai': _cfgr( + hf_hub_id='timm/', + hf_hub_filename='open_clip_pytorch_model.bin', + num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7) + ), + 'resnet101_clip.openai': _cfgr( + hf_hub_id='timm/', + hf_hub_filename='open_clip_pytorch_model.bin', + num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7) + ), + 'resnet50x4_clip.openai': _cfgr( + hf_hub_id='timm/', + hf_hub_filename='open_clip_pytorch_model.bin', + num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + fixed_input_size=True, input_size=(3, 288, 288), pool_size=(9, 9) + ), + 'resnet50x16_clip.openai': _cfgr( + hf_hub_id='timm/', + hf_hub_filename='open_clip_pytorch_model.bin', + num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + fixed_input_size=True, input_size=(3, 384, 384), pool_size=(12, 12) + ), + 'resnet50x64_clip.openai': _cfgr( + hf_hub_id='timm/', + hf_hub_filename='open_clip_pytorch_model.bin', + num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + fixed_input_size=True, input_size=(3, 448, 448), pool_size=(14, 14) + ), + }) @@ -2337,3 +2586,45 @@ def mobileone_s4(pretrained=False, **kwargs) -> ByobNet: """ """ return _create_byobnet('mobileone_s4', pretrained=pretrained, **kwargs) + + +@register_model +def resnet50_clip(pretrained=False, **kwargs) -> ByobNet: + """ OpenAI Modified ResNet-50 CLIP image tower + """ + return _create_byobnet('resnet50_clip', pretrained=pretrained, **kwargs) + + +@register_model +def resnet101_clip(pretrained=False, **kwargs) -> ByobNet: + """ OpenAI Modified ResNet-101 CLIP image tower + """ + return _create_byobnet('resnet101_clip', pretrained=pretrained, **kwargs) + + +@register_model +def resnet50x4_clip(pretrained=False, **kwargs) -> ByobNet: + """ OpenAI Modified 
ResNet-50x4 CLIP image tower + """ + return _create_byobnet('resnet50x4_clip', pretrained=pretrained, **kwargs) + + +@register_model +def resnet50x16_clip(pretrained=False, **kwargs) -> ByobNet: + """ OpenAI Modified ResNet-50x16 CLIP image tower + """ + return _create_byobnet('resnet50x16_clip', pretrained=pretrained, **kwargs) + + +@register_model +def resnet50x64_clip(pretrained=False, **kwargs) -> ByobNet: + """ OpenAI Modified ResNet-50x64 CLIP image tower + """ + return _create_byobnet('resnet50x64_clip', pretrained=pretrained, **kwargs) + + +@register_model +def resnet50_nmlp(pretrained=False, **kwargs) -> ByobNet: + """ + """ + return _create_byobnet('resnet50_nmlp', pretrained=pretrained, **kwargs) From f0fb471b26efbbae80a77eca9f6b09006acf5ba7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 10 Jun 2024 12:05:35 -0700 Subject: [PATCH 2/6] Remove separate ConvNormActAa class, merge with ConvNormAct --- timm/layers/conv_bn_act.py | 83 ++++------------------------- timm/layers/selective_kernel.py | 4 +- timm/models/_efficientnet_blocks.py | 6 +-- timm/models/cspnet.py | 14 ++--- timm/models/tresnet.py | 24 +++------ 5 files changed, 28 insertions(+), 103 deletions(-) diff --git a/timm/layers/conv_bn_act.py b/timm/layers/conv_bn_act.py index de738045ec..73ad670562 100644 --- a/timm/layers/conv_bn_act.py +++ b/timm/layers/conv_bn_act.py @@ -26,7 +26,8 @@ def __init__( apply_norm: bool = True, apply_act: bool = True, norm_layer: LayerType = nn.BatchNorm2d, - act_layer: LayerType = nn.ReLU, + act_layer: Optional[LayerType] = nn.ReLU, + aa_layer: Optional[LayerType] = None, drop_layer: Optional[Type[nn.Module]] = None, conv_kwargs: Optional[Dict[str, Any]] = None, norm_kwargs: Optional[Dict[str, Any]] = None, @@ -36,12 +37,13 @@ def __init__( conv_kwargs = conv_kwargs or {} norm_kwargs = norm_kwargs or {} act_kwargs = act_kwargs or {} + use_aa = aa_layer is not None and stride > 1 self.conv = create_conv2d( in_channels, out_channels, kernel_size, - stride=stride, + stride=1 if use_aa else stride, padding=padding, dilation=dilation, groups=groups, @@ -67,6 +69,8 @@ def __init__( norm_kwargs['drop_layer'] = drop_layer self.bn.add_module('drop', drop_layer()) + self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa, noop=None) + @property def in_channels(self): return self.conv.in_channels @@ -78,79 +82,10 @@ def out_channels(self): def forward(self, x): x = self.conv(x) x = self.bn(x) + if self.aa is not None: + x = self.aa(x) return x ConvBnAct = ConvNormAct - - -class ConvNormActAa(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int = 1, - stride: int = 1, - padding: PadType = '', - dilation: int = 1, - groups: int = 1, - bias: bool = False, - apply_norm: bool = True, - apply_act: bool = True, - norm_layer: LayerType = nn.BatchNorm2d, - act_layer: LayerType = nn.ReLU, - aa_layer: Optional[LayerType] = None, - drop_layer: Optional[Type[nn.Module]] = None, - conv_kwargs: Optional[Dict[str, Any]] = None, - norm_kwargs: Optional[Dict[str, Any]] = None, - act_kwargs: Optional[Dict[str, Any]] = None, - ): - super(ConvNormActAa, self).__init__() - use_aa = aa_layer is not None and stride == 2 - conv_kwargs = conv_kwargs or {} - norm_kwargs = norm_kwargs or {} - act_kwargs = act_kwargs or {} - - self.conv = create_conv2d( - in_channels, out_channels, kernel_size, - stride=1 if use_aa else stride, - padding=padding, - dilation=dilation, - groups=groups, - bias=bias, - **conv_kwargs, - ) - - if apply_norm: - # NOTE for backwards 
compatibility with models that use separate norm and act layer definitions - norm_act_layer = get_norm_act_layer(norm_layer, act_layer) - # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` - if drop_layer: - norm_kwargs['drop_layer'] = drop_layer - self.bn = norm_act_layer( - out_channels, - apply_act=apply_act, - act_kwargs=act_kwargs, - **norm_kwargs, - ) - else: - self.bn = nn.Sequential() - if drop_layer: - norm_kwargs['drop_layer'] = drop_layer - self.bn.add_module('drop', drop_layer()) - - self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa) - - @property - def in_channels(self): - return self.conv.in_channels - - @property - def out_channels(self): - return self.conv.out_channels - - def forward(self, x): - x = self.conv(x) - x = self.bn(x) - x = self.aa(x) - return x +ConvNormActAa = ConvNormAct # backwards compat, when they were separate diff --git a/timm/layers/selective_kernel.py b/timm/layers/selective_kernel.py index 3d71e3aa69..ec8ee6ce27 100644 --- a/timm/layers/selective_kernel.py +++ b/timm/layers/selective_kernel.py @@ -7,7 +7,7 @@ import torch from torch import nn as nn -from .conv_bn_act import ConvNormActAa +from .conv_bn_act import ConvNormAct from .helpers import make_divisible from .trace_utils import _assert @@ -100,7 +100,7 @@ def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, d stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_layer=drop_layer) self.paths = nn.ModuleList([ - ConvNormActAa(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) + ConvNormAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) for k, d in zip(kernel_size, dilation)]) attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor) diff --git a/timm/models/_efficientnet_blocks.py b/timm/models/_efficientnet_blocks.py index f33dacd5a6..5f98c90c06 100644 --- a/timm/models/_efficientnet_blocks.py +++ b/timm/models/_efficientnet_blocks.py @@ -9,7 +9,7 @@ from torch.nn import functional as F from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, create_aa, to_2tuple, LayerType,\ - ConvNormAct, ConvNormActAa, get_norm_act_layer, MultiQueryAttention2d, Attention2d + ConvNormAct, get_norm_act_layer, MultiQueryAttention2d, Attention2d __all__ = [ 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual', @@ -345,7 +345,7 @@ def __init__( if dw_kernel_size_start: dw_start_stride = stride if not dw_kernel_size_mid else 1 dw_start_groups = num_groups(group_size, in_chs) - self.dw_start = ConvNormActAa( + self.dw_start = ConvNormAct( in_chs, in_chs, dw_kernel_size_start, stride=dw_start_stride, dilation=dilation, # FIXME @@ -373,7 +373,7 @@ def __init__( # Middle depth-wise convolution if dw_kernel_size_mid: groups = num_groups(group_size, mid_chs) - self.dw_mid = ConvNormActAa( + self.dw_mid = ConvNormAct( mid_chs, mid_chs, dw_kernel_size_mid, stride=stride, dilation=dilation, # FIXME diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 7d63096a51..a736882190 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -20,7 +20,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible +from timm.layers import ClassifierHead, ConvNormAct, DropPath, get_attn, 
create_act_layer, make_divisible from ._builder import build_model_with_cfg from ._manipulate import named_apply, MATCH_PREV_GROUP from ._registry import register_model, generate_default_cfgs @@ -296,10 +296,10 @@ def __init__( if avg_down: self.conv_down = nn.Sequential( nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling - ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) + ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) ) else: - self.conv_down = ConvNormActAa( + self.conv_down = ConvNormAct( in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, aa_layer=aa_layer, **conv_kwargs) prev_chs = down_chs @@ -375,10 +375,10 @@ def __init__( if avg_down: self.conv_down = nn.Sequential( nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling - ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) + ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) ) else: - self.conv_down = ConvNormActAa( + self.conv_down = ConvNormAct( in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, aa_layer=aa_layer, **conv_kwargs) prev_chs = down_chs @@ -442,10 +442,10 @@ def __init__( if avg_down: self.conv_down = nn.Sequential( nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling - ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) + ConvNormAct(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs) ) else: - self.conv_down = ConvNormActAa( + self.conv_down = ConvNormAct( in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, aa_layer=aa_layer, **conv_kwargs) diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 006b7e0b5f..dec24c1f87 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -12,8 +12,7 @@ import torch import torch.nn as nn -from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule,\ - ConvNormActAa, ConvNormAct, DropPath +from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule, ConvNormAct, DropPath from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq from ._registry import register_model, generate_default_cfgs, register_model_deprecations @@ -39,13 +38,8 @@ def __init__( self.stride = stride act_layer = partial(nn.LeakyReLU, negative_slope=1e-3) - if stride == 1: - self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=1, act_layer=act_layer) - else: - self.conv1 = ConvNormActAa( - inplanes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer) - - self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False, act_layer=None) + self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=stride, act_layer=act_layer, aa_layer=aa_layer) + self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False) self.act = nn.ReLU(inplace=True) rd_chs = max(planes * self.expansion // 4, 64) @@ -87,18 +81,14 @@ def __init__( self.conv1 = ConvNormAct( inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer) - if stride == 1: - self.conv2 = ConvNormAct( - planes, planes, kernel_size=3, stride=1, act_layer=act_layer) - else: - self.conv2 = ConvNormActAa( - planes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer) + self.conv2 = ConvNormAct( + planes, planes, kernel_size=3, 
stride=stride, act_layer=act_layer, aa_layer=aa_layer) reduction_chs = max(planes * self.expansion // 8, 64) self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None self.conv3 = ConvNormAct( - planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None) + planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() self.act = nn.ReLU(inplace=True) @@ -204,7 +194,7 @@ def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=Non # avg pooling before 1x1 conv layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False)) layers += [ConvNormAct( - self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)] + self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False)] downsample = nn.Sequential(*layers) layers = [] From 5e9ff5798f5c7bd463944c483fd9619e701dd349 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 10 Jun 2024 12:06:47 -0700 Subject: [PATCH 3/6] Adding pos embed resize fns to FX autowrap exceptions --- timm/layers/pos_embed.py | 5 ++--- timm/models/_features_fx.py | 6 +++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/timm/layers/pos_embed.py b/timm/layers/pos_embed.py index 3e67be0080..0d50207dd8 100644 --- a/timm/layers/pos_embed.py +++ b/timm/layers/pos_embed.py @@ -15,7 +15,7 @@ def resample_abs_pos_embed( - posemb, + posemb: torch.Tensor, new_size: List[int], old_size: Optional[List[int]] = None, num_prefix_tokens: int = 1, @@ -58,7 +58,7 @@ def resample_abs_pos_embed( def resample_abs_pos_embed_nhwc( - posemb, + posemb: torch.Tensor, new_size: List[int], interpolation: str = 'bicubic', antialias: bool = True, @@ -69,7 +69,6 @@ def resample_abs_pos_embed_nhwc( orig_dtype = posemb.dtype posemb = posemb.float() - # do the interpolation posemb = posemb.reshape(1, posemb.shape[-3], posemb.shape[-2], posemb.shape[-1]).permute(0, 3, 1, 2) posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias) posemb = posemb.permute(0, 2, 3, 1).to(orig_dtype) diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py index 3a276046cb..1ea4a4f4a1 100644 --- a/timm/models/_features_fx.py +++ b/timm/models/_features_fx.py @@ -18,6 +18,7 @@ # Layers we want to treat as leaf modules from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format +from timm.layers import resample_abs_pos_embed, resample_abs_pos_embed_nhwc from timm.layers.non_local_attn import BilinearAttnTransform from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame from timm.layers.norm_act import ( @@ -75,7 +76,10 @@ def get_notrace_modules(): # Functions we want to autowrap (treat them as leaves) -_autowrap_functions = set() +_autowrap_functions = { + resample_abs_pos_embed, + resample_abs_pos_embed_nhwc, +} def register_notrace_function(func: Callable): From 30ffa152de73cdec43bb9ab63122f042179ce95c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 11 Jun 2024 21:32:07 -0700 Subject: [PATCH 4/6] Fix load of larger ResNet CLIP models, experimenting with making AttentionPool *the* head, seems to fine-tune better, one less layer.
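With this change the attention pool *is* the classifier head, so the standard head interface applies to the CLIP towers. Rough sketch of the resulting usage (illustrative only, not part of the commit; assumes the 'resnet50_clip' model and 'attn_abs' pooling names introduced in this series):

    import torch
    import timm

    # 'attn_abs' is the default global_pool for the CLIP ResNet variants below
    model = timm.create_model('resnet50_clip', pretrained=False, global_pool='attn_abs')
    x = torch.randn(1, 3, 224, 224)
    feats = model.forward_features(x)  # NCHW feature map from the ResNet stages
    out = model.forward_head(feats)    # AttentionPool2d pools + projects in one module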
--- timm/layers/attention_pool2d.py | 28 +++--- timm/models/byobnet.py | 152 ++++++++++++++++++++------------ 2 files changed, 116 insertions(+), 64 deletions(-) diff --git a/timm/layers/attention_pool2d.py b/timm/layers/attention_pool2d.py index dc594b70dd..443a384ca8 100644 --- a/timm/layers/attention_pool2d.py +++ b/timm/layers/attention_pool2d.py @@ -41,9 +41,10 @@ def __init__( num_heads: Optional[int] = None, qkv_bias: bool = True, qkv_separate: bool = False, + drop: float = 0., ): super().__init__() - embed_dim = embed_dim or in_features + self.embed_dim = embed_dim = embed_dim or in_features self.in_features = in_features self.out_features = out_features or in_features ref_feat_size = to_2tuple(ref_feat_size) @@ -82,7 +83,7 @@ def init_weights(self, zero_init_last: bool = False): trunc_normal_(self.qkv.weight, std=in_features ** -0.5) nn.init.zeros_(self.qkv.bias) - def forward(self, x): + def forward(self, x, pre_logits: bool = False): B, _, H, W = x.shape N = H * W x = x.flatten(2).transpose(1, 2) @@ -107,8 +108,12 @@ def forward(self, x): attn = attn.softmax(dim=-1) x = attn @ v x = x.transpose(1, 2).reshape(B, N + 1, -1) + x = x[:, 0] + x = self.drop(x) + if pre_logits: + return x x = self.proj(x) - return x[:, 0] + return x class AttentionPool2d(nn.Module): @@ -132,9 +137,10 @@ def __init__( num_heads: Optional[int] = None, qkv_bias: bool = True, qkv_separate: bool = False, + drop: float = 0., ): super().__init__() - embed_dim = embed_dim or in_features + self.embed_dim = embed_dim = embed_dim or in_features self.in_features = in_features self.out_features = out_features or in_features if num_heads is not None: @@ -158,6 +164,7 @@ def __init__( else: self.q = self.k = self.v = None self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) + self.drop = nn.Dropout(drop) self.proj = nn.Linear(embed_dim, self.out_features) self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features)) @@ -178,15 +185,12 @@ def init_weights(self, zero_init_last: bool = False): nn.init.zeros_(self.qkv.bias) trunc_normal_(self.pos_embed, std=in_features ** -0.5) - def forward(self, x): + def forward(self, x, pre_logits: bool = False): B, _, H, W = x.shape N = H * W x = x.flatten(2).transpose(1, 2) x = torch.cat([x.mean(1, keepdim=True), x], dim=1) - if self.seq_len != N: - pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1) - else: - pos_embed = self.pos_embed.unsqueeze(0).to(x.dtype) + pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1) x = x + pos_embed if self.qkv is None: @@ -205,5 +209,9 @@ def forward(self, x): attn = attn.softmax(dim=-1) x = attn @ v x = x.transpose(1, 2).reshape(B, N + 1, -1) + x = x[:, 0] + x = self.drop(x) + if pre_logits: + return x x = self.proj(x) - return x[:, 0] + return x diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index b9417dfe81..f84f24556d 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -37,7 +37,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, BatchNormAct2d, DropPath, AvgPool2dSame, \ +from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a from ._builder import build_model_with_cfg from ._features import feature_take_indices @@ -82,7 +82,6 @@ 
class ByoModelCfg: aa_layer: str = '' # Head config - attn_pool: str = '' head_hidden_size: Optional[int] = None # feat dim of MLP head or AttentionPool output head_type: str = '' @@ -304,7 +303,10 @@ def __init__( mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio) groups = num_groups(group_size, mid_chs) - + self.shortcut = create_shortcut( + downsample, in_chs, out_chs, + stride=stride, dilation=dilation, apply_act=False, layers=layers, + ) self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) self.conv2_kxk = layers.conv_norm_act( @@ -321,10 +323,7 @@ def __init__( self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() self.act = nn.Identity() if linear_out else layers.act(inplace=True) - self.shortcut = create_shortcut( - downsample, in_chs, out_chs, - stride=stride, dilation=dilation, apply_act=False, layers=layers, - ) + def init_weights(self, zero_init_last: bool = False): if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None: nn.init.zeros_(self.conv3_1x1.bn.weight) @@ -1165,7 +1164,7 @@ def get_layer_fns(cfg: ByoModelCfg, allow_aa: bool = True): act = get_act_layer(cfg.act_layer) norm_act = get_norm_act_layer(norm_layer=cfg.norm_layer, act_layer=act) if cfg.aa_layer and allow_aa: - conv_norm_act = partial(ConvNormActAa, norm_layer=cfg.norm_layer, act_layer=act, aa_layer=cfg.aa_layer) + conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act, aa_layer=cfg.aa_layer) else: conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act) attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None @@ -1258,6 +1257,7 @@ def __init__( self.stage_ends = [f['stage'] for f in self.feature_info] self.head_hidden_size = self.num_features + self.global_pool = global_pool assert cfg.head_type in ('', 'classifier', 'norm_mlp_classifier') if cfg.head_type == 'norm_mlp_classifier': from timm.layers import NormMlpClassifierHead @@ -1272,33 +1272,61 @@ def __init__( ) self.head_hidden_size = self.head.hidden_size else: - if cfg.attn_pool == 'abs': - from timm.layers import AttentionPool2d - self.attn_pool = AttentionPool2d( - self.num_features, - out_features=cfg.head_hidden_size, - feat_size=feat_size, - qkv_separate=True, + # FIXME evaluating different head vs pool configurations + if False: + if global_pool == 'attn_abs': + from timm.layers import AttentionPool2d + self.attn_pool = AttentionPool2d( + self.num_features, + out_features=cfg.head_hidden_size, + feat_size=feat_size, + qkv_separate=True, + ) + global_pool = '' # clear for ClassifierHead + self.head_hidden_size = self.attn_pool.out_features + elif global_pool =='attn_rot': + from timm.layers import RotAttentionPool2d + self.attn_pool = RotAttentionPool2d( + self.num_features, + out_features=cfg.head_hidden_size, + ref_feat_size=feat_size, + ) + global_pool = '' # clear for ClassifierHead + self.head_hidden_size = self.attn_pool.out_features + else: + self.attn_pool = nn.Identity() + + self.head = ClassifierHead( + self.head_hidden_size, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, ) - self.head_hidden_size = self.attn_pool.out_features - elif cfg.attn_pool == 'rot': - from timm.layers import RotAttentionPool2d - self.attn_pool = RotAttentionPool2d( - self.num_features, - out_features=cfg.head_hidden_size, - ref_feat_size=feat_size, - ) - 
self.head_hidden_size = self.attn_pool.out_features else: - assert not cfg.attn_pool - self.attn_pool = nn.Identity() - - self.head = ClassifierHead( - self.head_hidden_size, - num_classes, - pool_type='' if cfg.attn_pool else global_pool, - drop_rate=self.drop_rate, - ) + if global_pool == 'attn_abs': + from timm.layers import AttentionPool2d + self.head = AttentionPool2d( + self.num_features, + out_features=num_classes, + feat_size=feat_size, + qkv_separate=True, + ) + self.head_hidden_size = self.head.embed_dim + elif global_pool == 'attn_rot': + from timm.layers import RotAttentionPool2d + self.head = RotAttentionPool2d( + self.num_features, + out_features=num_classes, + ref_feat_size=feat_size, + ) + self.head_hidden_size = self.head.embed_dim + else: + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + ) # init weights named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) @@ -1324,6 +1352,9 @@ def get_classifier(self) -> nn.Module: def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes + if global_pool is not None: + if self.global_pool in ('attn_abs', 'attn_rot'): + raise RuntimeError('Cannot change attention pool on head reset.') self.head.reset(num_classes, global_pool) def forward_intermediates( @@ -1413,7 +1444,7 @@ def forward_features(self, x): return x def forward_head(self, x, pre_logits: bool = False): - x = self.attn_pool(x) + #x = self.attn_pool(x) return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) def forward(self, x): @@ -1916,7 +1947,6 @@ def _init_weights(module, name='', zero_init_last=False): stem_pool='avg2', downsample='avg', aa_layer='avg', - attn_pool='abs', head_hidden_size=1024, ), @@ -1932,7 +1962,6 @@ def _init_weights(module, name='', zero_init_last=False): stem_pool='avg2', downsample='avg', aa_layer='avg', - attn_pool='abs', head_hidden_size=512, ), @@ -1949,7 +1978,6 @@ def _init_weights(module, name='', zero_init_last=False): stem_pool='avg2', downsample='avg', aa_layer='avg', - attn_pool='abs', head_hidden_size=640, ), @@ -1960,12 +1988,12 @@ def _init_weights(module, name='', zero_init_last=False): ByoBlockCfg(type='bottle', d=18, c=1024, s=2, br=0.25), ByoBlockCfg(type='bottle', d=8, c=2048, s=2, br=0.25), ), + width_factor=1.5, stem_chs=(32, 32, 64), stem_type='', stem_pool='avg2', downsample='avg', aa_layer='avg', - attn_pool='abs', head_hidden_size=768, ), @@ -1976,12 +2004,12 @@ def _init_weights(module, name='', zero_init_last=False): ByoBlockCfg(type='bottle', d=36, c=1024, s=2, br=0.25), ByoBlockCfg(type='bottle', d=10, c=2048, s=2, br=0.25), ), + width_factor=2.0, stem_chs=(32, 32, 64), stem_type='', stem_pool='avg2', downsample='avg', aa_layer='avg', - attn_pool='abs', head_hidden_size=1024, ), @@ -2029,10 +2057,10 @@ def _down_sub(m): continue k = re.sub(rf'{prefix}conv([0-9])', r'stem.conv\1.conv', k) k = re.sub(rf'{prefix}bn([0-9])', r'stem.conv\1.bn', k) - k = re.sub(rf'{prefix}layer([0-9])\.([0-9])\.([a-z]+)([0-9])', _stage_sub, k) - k = re.sub(rf'{prefix}layer([0-9])\.([0-9])\.downsample\.([0-9])', _down_sub, k) + k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.([a-z]+)([0-9])', _stage_sub, k) + k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.downsample\.([0-9])', _down_sub, k) if k.startswith(f'{prefix}attnpool'): - k = k.replace(prefix + 'attnpool', 'attn_pool') + k = k.replace(prefix + 'attnpool', 'head') #'attn_pool') k = k.replace('positional_embedding', 'pos_embed') k 
= k.replace('q_proj', 'q') k = k.replace('k_proj', 'k') @@ -2053,13 +2081,19 @@ def checkpoint_filter_fn( def _create_byobnet(variant, pretrained=False, **kwargs): + strict = True + if 'clip' in variant and kwargs.get('global_pool', None) != 'attn_abs': + # NOTE: a hack to allow removing attention pool from CLIP ResNet variants + strict = False + return build_model_with_cfg( ByobNet, variant, pretrained, model_cfg=model_cfgs[variant], pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True), - #pretrained_strict=False, - **kwargs) + pretrained_strict=strict, + **kwargs, + ) def _cfg(url='', **kwargs): @@ -2257,31 +2291,36 @@ def _cfgr(url='', **kwargs): hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, - fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7) + fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7), + classifier = 'head.proj', ), 'resnet101_clip.openai': _cfgr( hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, - fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7) + fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7), + classifier='head.proj', ), 'resnet50x4_clip.openai': _cfgr( hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, - fixed_input_size=True, input_size=(3, 288, 288), pool_size=(9, 9) + fixed_input_size=True, input_size=(3, 288, 288), pool_size=(9, 9), + classifier = 'head.proj', ), 'resnet50x16_clip.openai': _cfgr( hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, - fixed_input_size=True, input_size=(3, 384, 384), pool_size=(12, 12) + fixed_input_size=True, input_size=(3, 384, 384), pool_size=(12, 12), + classifier = 'head.proj', ), 'resnet50x64_clip.openai': _cfgr( hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, - fixed_input_size=True, input_size=(3, 448, 448), pool_size=(14, 14) + fixed_input_size=True, input_size=(3, 448, 448), pool_size=(14, 14), + classifier = 'head.proj', ), }) @@ -2592,35 +2631,40 @@ def mobileone_s4(pretrained=False, **kwargs) -> ByobNet: def resnet50_clip(pretrained=False, **kwargs) -> ByobNet: """ OpenAI Modified ResNet-50 CLIP image tower """ - return _create_byobnet('resnet50_clip', pretrained=pretrained, **kwargs) + model_args = dict(global_pool='attn_abs') + return _create_byobnet('resnet50_clip', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def resnet101_clip(pretrained=False, **kwargs) -> ByobNet: """ OpenAI Modified ResNet-101 CLIP image tower """ - return _create_byobnet('resnet101_clip', pretrained=pretrained, **kwargs) + model_args = dict(global_pool='attn_abs') + return _create_byobnet('resnet101_clip', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def resnet50x4_clip(pretrained=False, **kwargs) -> ByobNet: """ OpenAI Modified ResNet-50x4 CLIP image tower """ - return _create_byobnet('resnet50x4_clip', pretrained=pretrained, **kwargs) + model_args = dict(global_pool='attn_abs') + return _create_byobnet('resnet50x4_clip', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def resnet50x16_clip(pretrained=False, **kwargs) -> ByobNet: """ OpenAI Modified ResNet-50x16 CLIP image tower """ - return 
_create_byobnet('resnet50x16_clip', pretrained=pretrained, **kwargs) + model_args = dict(global_pool='attn_abs') + return _create_byobnet('resnet50x16_clip', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def resnet50x64_clip(pretrained=False, **kwargs) -> ByobNet: """ OpenAI Modified ResNet-50x64 CLIP image tower """ - return _create_byobnet('resnet50x64_clip', pretrained=pretrained, **kwargs) + model_args = dict(global_pool='attn_abs') + return _create_byobnet('resnet50x64_clip', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model From cdc7bcea696be62377420e5f500307f20483e1c3 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 11 Jun 2024 21:32:07 -0700 Subject: [PATCH 5/6] Make 2d attention pool modules compatible with head interface. Use attention pool in CLIP ResNets as head. Make separate set of GAP models w/ avg pool instead of attn pool. --- timm/layers/attention_pool2d.py | 64 +++++++- timm/models/byobnet.py | 249 +++++++++++++++++++------------- 2 files changed, 209 insertions(+), 104 deletions(-) diff --git a/timm/layers/attention_pool2d.py b/timm/layers/attention_pool2d.py index 443a384ca8..21328fd367 100644 --- a/timm/layers/attention_pool2d.py +++ b/timm/layers/attention_pool2d.py @@ -41,9 +41,12 @@ def __init__( num_heads: Optional[int] = None, qkv_bias: bool = True, qkv_separate: bool = False, - drop: float = 0., + pool_type: str = 'token', + avg_token: bool = True, + drop_rate: float = 0., ): super().__init__() + assert pool_type in ('', 'token') self.embed_dim = embed_dim = embed_dim or in_features self.in_features = in_features self.out_features = out_features or in_features @@ -56,6 +59,7 @@ def __init__( num_heads = embed_dim // head_dim self.num_heads = num_heads self.head_dim = head_dim + self.pool_type = pool_type.lower() self.scale = self.head_dim ** -0.5 self.fused_attn = use_fused_attn() @@ -66,6 +70,7 @@ def __init__( self.qkv = None else: self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) + self.drop = nn.Dropout(drop_rate) self.proj = nn.Linear(embed_dim, self.out_features) self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size) @@ -83,6 +88,23 @@ def init_weights(self, zero_init_last: bool = False): trunc_normal_(self.qkv.weight, std=in_features ** -0.5) nn.init.zeros_(self.qkv.bias) + def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None): + # NOTE: this module is being used as a head, so need compatible reset() + if pool_type is not None: + assert pool_type in ('', 'token') + self.pool_type = pool_type + if num_classes is not None: + self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity() + self.out_features = num_classes if num_classes > 0 else self.embed_dim + + def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: + if self.pool_type == 'token': + x = x[:, 0] + else: + # if not pooled, return spatial output without token + x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2) + return x + def forward(self, x, pre_logits: bool = False): B, _, H, W = x.shape N = H * W @@ -111,8 +133,10 @@ def forward(self, x, pre_logits: bool = False): x = x[:, 0] x = self.drop(x) if pre_logits: + x = self._pool(x, H, W) return x x = self.proj(x) + x = self._pool(x, H, W) return x @@ -137,9 +161,12 @@ def __init__( num_heads: Optional[int] = None, qkv_bias: bool = True, qkv_separate: bool = False, - drop: float = 0., + pool_type: str = 'token', + learned_token: bool = False, + drop_rate: float = 
0., ): super().__init__() + assert pool_type in ('', 'token') self.embed_dim = embed_dim = embed_dim or in_features self.in_features = in_features self.out_features = out_features or in_features @@ -153,9 +180,15 @@ def __init__( self.seq_len = self.feat_size[0] * self.feat_size[1] self.num_heads = num_heads self.head_dim = head_dim + self.pool_type = pool_type self.scale = self.head_dim ** -0.5 self.fused_attn = use_fused_attn() + if learned_token: + self.token = nn.Parameter(torch.zeros(1, embed_dim)) + else: + self.token = None + if qkv_separate: self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias) self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias) @@ -164,7 +197,7 @@ def __init__( else: self.q = self.k = self.v = None self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias) - self.drop = nn.Dropout(drop) + self.drop = nn.Dropout(drop_rate) self.proj = nn.Linear(embed_dim, self.out_features) self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features)) @@ -185,11 +218,31 @@ def init_weights(self, zero_init_last: bool = False): nn.init.zeros_(self.qkv.bias) trunc_normal_(self.pos_embed, std=in_features ** -0.5) + def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None): + # NOTE: this module is being used as a head, so need compatible reset() + if pool_type is not None: + assert pool_type in ('', 'token') + self.pool_type = pool_type + if num_classes is not None: + self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity() + self.out_features = num_classes if num_classes > 0 else self.embed_dim + + def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: + if self.pool_type == 'token': + x = x[:, 0] + else: + # if not pooled, return spatial output without token + x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2) + return x + def forward(self, x, pre_logits: bool = False): B, _, H, W = x.shape N = H * W x = x.flatten(2).transpose(1, 2) - x = torch.cat([x.mean(1, keepdim=True), x], dim=1) + if self.token is not None: + x = torch.cat([self.token.expand(x.shape[0], -1, -1), x], dim=1) + else: + x = torch.cat([x.mean(1, keepdim=True), x], dim=1) pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1) x = x + pos_embed @@ -209,9 +262,10 @@ def forward(self, x, pre_logits: bool = False): attn = attn.softmax(dim=-1) x = attn @ v x = x.transpose(1, 2).reshape(B, N + 1, -1) - x = x[:, 0] x = self.drop(x) if pre_logits: + x = self._pool(x, H, W) return x x = self.proj(x) + x = self._pool(x, H, W) return x diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index f84f24556d..d3bf9c0d29 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -37,8 +37,11 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ - create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a +from timm.layers import ( + ClassifierHead, NormMlpClassifierHead, ConvNormAct, BatchNormAct2d, EvoNorm2dS0a, + AttentionPool2d, RotAttentionPool2d, DropPath, AvgPool2dSame, + create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, +) from ._builder import build_model_with_cfg from ._features import feature_take_indices from ._manipulate import named_apply, checkpoint_seq @@ -83,7 +86,7 @@ class ByoModelCfg: # Head config 
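As context for the head refactor below, a small sketch of the fixed-feat-size attention pool after this patch. Because forward() resamples the learned absolute position embedding on the fly (via resample_abs_pos_embed), the module now tolerates feature grids that differ from feat_size; the sizes here are illustrative:

    import torch
    from timm.layers import AttentionPool2d

    pool = AttentionPool2d(in_features=2048, feat_size=7, pool_type='token')
    print(pool(torch.randn(2, 2048, 7, 7)).shape)  # torch.Size([2, 2048]), native grid
    print(pool(torch.randn(2, 2048, 9, 9)).shape)  # torch.Size([2, 2048]), pos embed resampled to 9x9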
head_hidden_size: Optional[int] = None # feat dim of MLP head or AttentionPool output - head_type: str = '' + head_type: str = 'classifier' # Block config # NOTE: these config items will be overridden by the block cfg (per-block) if they are set there @@ -1186,7 +1189,7 @@ def __init__( cfg: ByoModelCfg, num_classes: int = 1000, in_chans: int = 3, - global_pool: str = 'avg', + global_pool: Optional[str] = None, output_stride: int = 32, img_size: Optional[Union[int, Tuple[int, int]]] = None, drop_rate: float = 0., @@ -1257,76 +1260,59 @@ def __init__( self.stage_ends = [f['stage'] for f in self.feature_info] self.head_hidden_size = self.num_features - self.global_pool = global_pool - assert cfg.head_type in ('', 'classifier', 'norm_mlp_classifier') - if cfg.head_type == 'norm_mlp_classifier': - from timm.layers import NormMlpClassifierHead - assert not cfg.attn_pool, "Cannot use attentional pooling with norm + MLP head" - self.attn_pool = nn.Identity() + assert cfg.head_type in ('', 'classifier', 'mlp', 'attn_abs', 'attn_rot') + if cfg.head_type == 'mlp': + if global_pool is None: + global_pool = 'avg' self.head = NormMlpClassifierHead( self.num_features, num_classes, hidden_size=cfg.head_hidden_size, + pool_type=global_pool, norm_layer=cfg.norm_layer, act_layer=cfg.act_layer, + drop_rate=self.drop_rate, ) self.head_hidden_size = self.head.hidden_size + elif cfg.head_type == 'attn_abs': + if global_pool is None: + global_pool = 'token' + assert global_pool in ('', 'token') + self.head = AttentionPool2d( + self.num_features, + embed_dim=cfg.head_hidden_size, + out_features=num_classes, + feat_size=feat_size, + pool_type=global_pool, + drop_rate=self.drop_rate, + qkv_separate=True, + ) + self.head_hidden_size = self.head.embed_dim + elif cfg.head_type =='attn_rot': + if global_pool is None: + global_pool = 'token' + assert global_pool in ('', 'token') + self.head = RotAttentionPool2d( + self.num_features, + embed_dim=cfg.head_hidden_size, + out_features=num_classes, + ref_feat_size=feat_size, + pool_type=global_pool, + drop_rate=self.drop_rate, + qkv_separate=True, + ) + self.head_hidden_size = self.head.embed_dim else: - # FIXME evaluating different head vs pool configurations - if False: - if global_pool == 'attn_abs': - from timm.layers import AttentionPool2d - self.attn_pool = AttentionPool2d( - self.num_features, - out_features=cfg.head_hidden_size, - feat_size=feat_size, - qkv_separate=True, - ) - global_pool = '' # clear for ClassifierHead - self.head_hidden_size = self.attn_pool.out_features - elif global_pool =='attn_rot': - from timm.layers import RotAttentionPool2d - self.attn_pool = RotAttentionPool2d( - self.num_features, - out_features=cfg.head_hidden_size, - ref_feat_size=feat_size, - ) - global_pool = '' # clear for ClassifierHead - self.head_hidden_size = self.attn_pool.out_features - else: - self.attn_pool = nn.Identity() - - self.head = ClassifierHead( - self.head_hidden_size, - num_classes, - pool_type=global_pool, - drop_rate=self.drop_rate, - ) - else: - if global_pool == 'attn_abs': - from timm.layers import AttentionPool2d - self.head = AttentionPool2d( - self.num_features, - out_features=num_classes, - feat_size=feat_size, - qkv_separate=True, - ) - self.head_hidden_size = self.head.embed_dim - elif global_pool == 'attn_rot': - from timm.layers import RotAttentionPool2d - self.head = RotAttentionPool2d( - self.num_features, - out_features=num_classes, - ref_feat_size=feat_size, - ) - self.head_hidden_size = self.head.embed_dim - else: - self.head = ClassifierHead( - 
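To make the head selection above concrete, a hypothetical usage sketch. The model name is real; the key point, per the logic above, is that head_type comes from the model config while the global_pool kwarg only picks token vs. unpooled output for the attention-pool heads:

    import timm

    # attn_abs head built from the cfg; global_pool defaults to 'token'
    model = timm.create_model('resnet50_clip', pretrained=False)

    # same model, but the head returns the unpooled NCHW map (prefix token stripped)
    spatial = timm.create_model('resnet50_clip', pretrained=False, global_pool='')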
self.num_features, - num_classes, - pool_type=global_pool, - drop_rate=self.drop_rate, - ) + if global_pool is None: + global_pool = 'avg' + assert cfg.head_hidden_size is None + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + ) + self.global_pool = global_pool # init weights named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) @@ -1352,9 +1338,6 @@ def get_classifier(self) -> nn.Module: def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): self.num_classes = num_classes - if global_pool is not None: - if self.global_pool in ('attn_abs', 'attn_rot'): - raise RuntimeError('Cannot change attention pool on head reset.') self.head.reset(num_classes, global_pool) def forward_intermediates( @@ -1444,7 +1427,6 @@ def forward_features(self, x): return x def forward_head(self, x, pre_logits: bool = False): - #x = self.attn_pool(x) return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x) def forward(self, x): @@ -1947,7 +1929,7 @@ def _init_weights(module, name='', zero_init_last=False): stem_pool='avg2', downsample='avg', aa_layer='avg', - head_hidden_size=1024, + head_type='attn_abs', ), resnet101_clip=ByoModelCfg( @@ -1962,7 +1944,8 @@ def _init_weights(module, name='', zero_init_last=False): stem_pool='avg2', downsample='avg', aa_layer='avg', - head_hidden_size=512, + head_type='attn_abs', + #head_hidden_size=512, ), resnet50x4_clip=ByoModelCfg( @@ -1978,7 +1961,8 @@ def _init_weights(module, name='', zero_init_last=False): stem_pool='avg2', downsample='avg', aa_layer='avg', - head_hidden_size=640, + head_type='attn_abs', + #head_hidden_size=640, ), resnet50x16_clip=ByoModelCfg( @@ -1994,7 +1978,8 @@ def _init_weights(module, name='', zero_init_last=False): stem_pool='avg2', downsample='avg', aa_layer='avg', - head_hidden_size=768, + head_type='attn_abs', + #head_hidden_size=768, ), resnet50x64_clip=ByoModelCfg( @@ -2010,10 +1995,11 @@ def _init_weights(module, name='', zero_init_last=False): stem_pool='avg2', downsample='avg', aa_layer='avg', - head_hidden_size=1024, + head_type='attn_abs', + #head_hidden_size=1024, ), - resnet50_nmlp=ByoModelCfg( + resnet50_mlp=ByoModelCfg( blocks=( ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25), ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25), @@ -2026,9 +2012,11 @@ def _init_weights(module, name='', zero_init_last=False): downsample='avg', aa_layer='avg', head_hidden_size=1024, - head_type='norm_mlp_classifier', + head_type='mlp', ), ) +for k in ('resnet50_clip', 'resnet101_clip', 'resnet50x4_clip', 'resnet50x16_clip', 'resnet50x64_clip'): + model_cfgs[k + '_gap'] = replace(model_cfgs[k], head_type='classifier') def _convert_openai_clip( @@ -2036,6 +2024,7 @@ def _convert_openai_clip( model: ByobNet, prefix: str = 'visual.', ) -> Dict[str, torch.Tensor]: + model_has_attn_pool = isinstance(model.head, (RotAttentionPool2d, AttentionPool2d)) import re def _stage_sub(m): @@ -2060,6 +2049,8 @@ def _down_sub(m): k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.([a-z]+)([0-9])', _stage_sub, k) k = re.sub(rf'{prefix}layer([0-9])\.([0-9]+)\.downsample\.([0-9])', _down_sub, k) if k.startswith(f'{prefix}attnpool'): + if not model_has_attn_pool: + continue k = k.replace(prefix + 'attnpool', 'head') #'attn_pool') k = k.replace('positional_embedding', 'pos_embed') k = k.replace('q_proj', 'q') @@ -2081,17 +2072,11 @@ def checkpoint_filter_fn( def _create_byobnet(variant, pretrained=False, **kwargs): - strict = True - if 'clip' in 
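The *_gap config generation above leans on dataclasses.replace; a self-contained toy version of the pattern, where Cfg is a stand-in for ByoModelCfg:

    from dataclasses import dataclass, replace

    @dataclass
    class Cfg:  # stand-in for ByoModelCfg
        head_type: str = 'attn_abs'

    cfgs = {'resnet50_clip': Cfg()}
    for k in list(cfgs):
        cfgs[k + '_gap'] = replace(cfgs[k], head_type='classifier')

    assert cfgs['resnet50_clip_gap'].head_type == 'classifier'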
variant and kwargs.get('global_pool', None) != 'attn_abs': - # NOTE: a hack to allow removing attention pool from CLIP ResNet variants - strict = False - return build_model_with_cfg( ByobNet, variant, pretrained, model_cfg=model_cfgs[variant], pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True), - pretrained_strict=strict, **kwargs, ) @@ -2287,42 +2272,78 @@ def _cfgr(url='', **kwargs): first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'), ), + # original attention pool head variants 'resnet50_clip.openai': _cfgr( hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', - num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + num_classes=1024, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7), classifier = 'head.proj', ), 'resnet101_clip.openai': _cfgr( hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', - num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + num_classes=512, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, fixed_input_size=True, input_size=(3, 224, 224), pool_size=(7, 7), classifier='head.proj', ), 'resnet50x4_clip.openai': _cfgr( hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', - num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + num_classes=640, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, fixed_input_size=True, input_size=(3, 288, 288), pool_size=(9, 9), classifier = 'head.proj', ), 'resnet50x16_clip.openai': _cfgr( hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', - num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + num_classes=768, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, fixed_input_size=True, input_size=(3, 384, 384), pool_size=(12, 12), classifier = 'head.proj', ), 'resnet50x64_clip.openai': _cfgr( hf_hub_id='timm/', hf_hub_filename='open_clip_pytorch_model.bin', - num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + num_classes=1024, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, fixed_input_size=True, input_size=(3, 448, 448), pool_size=(14, 14), classifier = 'head.proj', ), + # avg-pool w/ optional standard classifier head variants + 'resnet50_clip_gap.openai': _cfgr( + hf_hub_id='timm/resnet50_clip.openai', + hf_hub_filename='open_clip_pytorch_model.bin', + num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 224, 224), pool_size=(7, 7), + ), + 'resnet101_clip_gap.openai': _cfgr( + hf_hub_id='timm/resnet101_clip.openai', + hf_hub_filename='open_clip_pytorch_model.bin', + num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 224, 224), pool_size=(7, 7), + ), + 'resnet50x4_clip_gap.openai': _cfgr( + hf_hub_id='timm/resnet50x4_clip.openai', + hf_hub_filename='open_clip_pytorch_model.bin', + num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 288, 288), pool_size=(9, 9), + ), + 'resnet50x16_clip_gap.openai': _cfgr( + hf_hub_id='timm/resnet50x16_clip.openai', + hf_hub_filename='open_clip_pytorch_model.bin', + num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 384, 384), pool_size=(12, 12), + ), + 'resnet50x64_clip_gap.openai': _cfgr( + hf_hub_id='timm/resnet50x64_clip.openai', + hf_hub_filename='open_clip_pytorch_model.bin', + num_classes=0, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 448, 448), pool_size=(14, 14), + ), + + 'resnet50_mlp.untrained': _cfgr( + input_size=(3, 256, 256), pool_size=(8, 8), + ), }) @@ -2631,44 +2652,74 @@ def mobileone_s4(pretrained=False, 
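A sketch of what the two pretrained-config groups above mean in practice (pretrained=True needs network access to pull the OpenAI weights from the HF hub; output dims follow the cfgs above):

    import timm
    import torch

    # attn-pool head: the 'classifier' is head.proj, so forward() yields the
    # CLIP image embedding (1024-d for resnet50_clip)
    clip = timm.create_model('resnet50_clip', pretrained=True)
    emb = clip(torch.randn(1, 3, 224, 224))    # (1, 1024)

    # gap variant: attnpool weights are skipped on load; with num_classes=0
    # forward() returns plain avg-pooled trunk features
    gap = timm.create_model('resnet50_clip_gap', pretrained=True)
    feats = gap(torch.randn(1, 3, 224, 224))   # (1, 2048)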
**kwargs) -> ByobNet: def resnet50_clip(pretrained=False, **kwargs) -> ByobNet: """ OpenAI Modified ResNet-50 CLIP image tower """ - model_args = dict(global_pool='attn_abs') - return _create_byobnet('resnet50_clip', pretrained=pretrained, **dict(model_args, **kwargs)) + return _create_byobnet('resnet50_clip', pretrained=pretrained, **kwargs) @register_model def resnet101_clip(pretrained=False, **kwargs) -> ByobNet: """ OpenAI Modified ResNet-101 CLIP image tower """ - model_args = dict(global_pool='attn_abs') - return _create_byobnet('resnet101_clip', pretrained=pretrained, **dict(model_args, **kwargs)) + return _create_byobnet('resnet101_clip', pretrained=pretrained, **kwargs) @register_model def resnet50x4_clip(pretrained=False, **kwargs) -> ByobNet: """ OpenAI Modified ResNet-50x4 CLIP image tower """ - model_args = dict(global_pool='attn_abs') - return _create_byobnet('resnet50x4_clip', pretrained=pretrained, **dict(model_args, **kwargs)) + return _create_byobnet('resnet50x4_clip', pretrained=pretrained, **kwargs) @register_model def resnet50x16_clip(pretrained=False, **kwargs) -> ByobNet: """ OpenAI Modified ResNet-50x16 CLIP image tower """ - model_args = dict(global_pool='attn_abs') - return _create_byobnet('resnet50x16_clip', pretrained=pretrained, **dict(model_args, **kwargs)) + return _create_byobnet('resnet50x16_clip', pretrained=pretrained, **kwargs) @register_model def resnet50x64_clip(pretrained=False, **kwargs) -> ByobNet: """ OpenAI Modified ResNet-50x64 CLIP image tower """ - model_args = dict(global_pool='attn_abs') - return _create_byobnet('resnet50x64_clip', pretrained=pretrained, **dict(model_args, **kwargs)) + return _create_byobnet('resnet50x64_clip', pretrained=pretrained, **kwargs) + + +@register_model +def resnet50_clip_gap(pretrained=False, **kwargs) -> ByobNet: + """ OpenAI Modified ResNet-50 CLIP image tower w/ avg pool (no attention pool) + """ + return _create_byobnet('resnet50_clip_gap', pretrained=pretrained, **kwargs) + + +@register_model +def resnet101_clip_gap(pretrained=False, **kwargs) -> ByobNet: + """ OpenAI Modified ResNet-101 CLIP image tower w/ avg pool (no attention pool) + """ + return _create_byobnet('resnet101_clip_gap', pretrained=pretrained, **kwargs) + + +@register_model +def resnet50x4_clip_gap(pretrained=False, **kwargs) -> ByobNet: + """ OpenAI Modified ResNet-50x4 CLIP image tower w/ avg pool (no attention pool) + """ + return _create_byobnet('resnet50x4_clip_gap', pretrained=pretrained, **kwargs) + + +@register_model +def resnet50x16_clip_gap(pretrained=False, **kwargs) -> ByobNet: + """ OpenAI Modified ResNet-50x16 CLIP image tower w/ avg pool (no attention pool) + """ + return _create_byobnet('resnet50x16_clip_gap', pretrained=pretrained, **kwargs) + + +@register_model +def resnet50x64_clip_gap(pretrained=False, **kwargs) -> ByobNet: + """ OpenAI Modified ResNet-50x64 CLIP image tower w/ avg pool (no attention pool) + """ + return _create_byobnet('resnet50x64_clip_gap', pretrained=pretrained, **kwargs) @register_model -def resnet50_nmlp(pretrained=False, **kwargs) -> ByobNet: +def resnet50_mlp(pretrained=False, **kwargs) -> ByobNet: """ """ - return _create_byobnet('resnet50_nmlp', pretrained=pretrained, **kwargs) + return _create_byobnet('resnet50_mlp', pretrained=pretrained, **kwargs) From 57adc1acc8153583711068e2d1ff3e845891d0bc Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 11 Jun 2024 23:49:17 -0700 Subject: [PATCH 6/6] Fix rotary embed version of attn pool. 
Bit of cleanup/naming --- timm/layers/attention_pool2d.py | 27 +++++++++++++++++---------- timm/models/byobnet.py | 4 ---- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/timm/layers/attention_pool2d.py b/timm/layers/attention_pool2d.py index 21328fd367..e6c7041760 100644 --- a/timm/layers/attention_pool2d.py +++ b/timm/layers/attention_pool2d.py @@ -42,7 +42,7 @@ def __init__( qkv_bias: bool = True, qkv_separate: bool = False, pool_type: str = 'token', - avg_token: bool = True, + class_token: bool = False, drop_rate: float = 0., ): super().__init__() @@ -63,6 +63,11 @@ def __init__( self.scale = self.head_dim ** -0.5 self.fused_attn = use_fused_attn() + if class_token: + self.cls_token = nn.Parameter(torch.zeros(1, embed_dim)) + else: + self.cls_token = None + if qkv_separate: self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias) self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias) @@ -109,7 +114,10 @@ def forward(self, x, pre_logits: bool = False): B, _, H, W = x.shape N = H * W x = x.flatten(2).transpose(1, 2) - x = torch.cat([x.mean(1, keepdim=True), x], dim=1) + if self.cls_token is None: + x = torch.cat([x.mean(1, keepdim=True), x], dim=1) + else: + x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1) if self.qkv is None: q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2) k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2) @@ -130,7 +138,6 @@ def forward(self, x, pre_logits: bool = False): attn = attn.softmax(dim=-1) x = attn @ v x = x.transpose(1, 2).reshape(B, N + 1, -1) - x = x[:, 0] x = self.drop(x) if pre_logits: x = self._pool(x, H, W) @@ -162,7 +169,7 @@ def __init__( qkv_bias: bool = True, qkv_separate: bool = False, pool_type: str = 'token', - learned_token: bool = False, + class_token: bool = False, drop_rate: float = 0., ): super().__init__() @@ -184,10 +191,10 @@ def __init__( self.scale = self.head_dim ** -0.5 self.fused_attn = use_fused_attn() - if learned_token: - self.token = nn.Parameter(torch.zeros(1, embed_dim)) + if class_token: + self.cls_token = nn.Parameter(torch.zeros(1, embed_dim)) else: - self.token = None + self.cls_token = None if qkv_separate: self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias) @@ -239,10 +246,10 @@ def forward(self, x, pre_logits: bool = False): B, _, H, W = x.shape N = H * W x = x.flatten(2).transpose(1, 2) - if self.token is not None: - x = torch.cat([self.token.expand(x.shape[0], -1, -1), x], dim=1) - else: + if self.cls_token is None: x = torch.cat([x.mean(1, keepdim=True), x], dim=1) + else: + x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1) pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1) x = x + pos_embed diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index d3bf9c0d29..32ecfa4457 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -1945,7 +1945,6 @@ def _init_weights(module, name='', zero_init_last=False): downsample='avg', aa_layer='avg', head_type='attn_abs', - #head_hidden_size=512, ), resnet50x4_clip=ByoModelCfg( @@ -1962,7 +1961,6 @@ def _init_weights(module, name='', zero_init_last=False): downsample='avg', aa_layer='avg', head_type='attn_abs', - #head_hidden_size=640, ), resnet50x16_clip=ByoModelCfg( @@ -1979,7 +1977,6 @@ def _init_weights(module, name='', zero_init_last=False): downsample='avg', aa_layer='avg', head_type='attn_abs', - #head_hidden_size=768, ), resnet50x64_clip=ByoModelCfg( @@ -1996,7 +1993,6 @@ def 
_init_weights(module, name='', zero_init_last=False):
         downsample='avg',
         aa_layer='avg',
         head_type='attn_abs',
-        #head_hidden_size=1024,
     ),
 
     resnet50_mlp=ByoModelCfg(