From 2f3fed43b8888b07585242bf591a1d6cb9118b7c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 15 Aug 2024 11:14:38 -0700 Subject: [PATCH] Fix hiera init with num_classes=0, fix weight tag names for sbb2 hiera/vit weights, add LayerScale/LayerScale2d to layers --- timm/layers/__init__.py | 1 + timm/layers/layer_scale.py | 38 +++++++++++++++++++++++++++++++ timm/models/hiera.py | 34 +++++++-------------------- timm/models/vision_transformer.py | 4 ++-- 4 files changed, 49 insertions(+), 28 deletions(-) create mode 100644 timm/layers/layer_scale.py diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index 38c824077c..6111558908 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -29,6 +29,7 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple from .hybrid_embed import HybridEmbed, HybridEmbedWithSize from .inplace_abn import InplaceAbn +from .layer_scale import LayerScale, LayerScale2d from .linear import Linear from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp diff --git a/timm/layers/layer_scale.py b/timm/layers/layer_scale.py new file mode 100644 index 0000000000..08566b2bd1 --- /dev/null +++ b/timm/layers/layer_scale.py @@ -0,0 +1,38 @@ +import torch +from torch import nn + + +class LayerScale(nn.Module): + """ LayerScale on tensors with channels in last-dim. + """ + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class LayerScale2d(nn.Module): + """ LayerScale for tensors with torch 2D NCHW layout. + """ + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + gamma = self.gamma.view(1, -1, 1, 1) + return x.mul_(gamma) if self.inplace else x * gamma + diff --git a/timm/models/hiera.py b/timm/models/hiera.py index 69af2f48ad..808053e9ee 100644 --- a/timm/models/hiera.py +++ b/timm/models/hiera.py @@ -31,10 +31,8 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint - from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, Mlp, use_fused_attn, _assert, get_norm_layer, to_2tuple - +from timm.layers import DropPath, Mlp, LayerScale, use_fused_attn, _assert, get_norm_layer, to_2tuple from ._registry import generate_default_cfgs, register_model from ._builder import build_model_with_cfg @@ -289,7 +287,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ Input should be of shape [batch, tokens, channels]. """ B, N, _ = x.shape num_windows = (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1 - qkv = self.qkv(x).reshape(B, -1, num_windows, 3, self.heads, self.head_dim).permute(3, 0, 4, 2, 1, 5) q, k, v = qkv.unbind(0) @@ -310,21 +307,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class LayerScale(nn.Module): - def __init__( - self, - dim: int, - init_values: float = 1e-5, - inplace: bool = False, - ) -> None: - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x.mul_(self.gamma) if self.inplace else x * self.gamma - - class HieraBlock(nn.Module): def __init__( self, @@ -342,7 +324,6 @@ def __init__( use_mask_unit_attn: bool = False, ): super().__init__() - self.dim = dim self.dim_out = dim_out @@ -631,10 +612,8 @@ def __init__( nn.init.trunc_normal_(self.pos_embed_win, std=0.02) if weight_init != 'skip': - if weight_init == 'jax': - named_apply(partial(_init_weight_jax, head_bias=-math.log(self.num_classes)), self) - else: - named_apply(_init_weight_vit, self) + init_fn = _init_weight_jax if weight_init == 'jax' else _init_weight_vit + named_apply(init_fn, self) if fix_init: self.fix_init_weight() if isinstance(self.head.fc, nn.Linear): @@ -868,11 +847,13 @@ def _init_weight_vit(module, name, init_bias=0.02, head_bias=0.): nn.init.trunc_normal_(module.weight, std=0.02) if isinstance(module, nn.Linear) and module.bias is not None: nn.init.constant_(module.bias, init_bias) + elif hasattr(module, 'init_weights'): + module.init_weights() def _init_weight_jax(module, name, head_bias=0.): if isinstance(module, nn.Linear): - if name.startswith('head'): + if name.startswith('head.fc'): nn.init.zeros_(module.weight) nn.init.constant_(module.bias, head_bias) else: @@ -960,7 +941,7 @@ def _cfg(url='', **kwargs): num_classes=0, ), - "hiera_small_abswin_256.sbb2_ep200_in12k": _cfg( + "hiera_small_abswin_256.sbb2_e200_in12k": _cfg( hf_hub_id='timm/', num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95, @@ -1007,6 +988,7 @@ def _create_hiera(variant: str, pretrained: bool = False, **kwargs) -> Hiera: **kwargs, ) + @register_model def hiera_tiny_224(pretrained=False, **kwargs): model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2)) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 64cab9ee71..afb5e00200 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -1967,7 +1967,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: 'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg( hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95), - 'vit_mediumd_patch16_reg4_gap_256.sbb2_ep200_in12k': _cfg( + 'vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg( hf_hub_id='timm/', num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95), @@ -1984,7 +1984,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: 'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg( hf_hub_id='timm/', input_size=(3, 256, 256), crop_pct=0.95), - 'vit_betwixt_patch16_reg4_gap_256.sbb2_ep200_in12k': _cfg( + 'vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg( hf_hub_id='timm/', num_classes=11821, input_size=(3, 256, 256), crop_pct=0.95),