Fix hiera init with num_classes=0, fix weight tag names for sbb2 hiera/vit weights, add LayerScale/LayerScale2d to layers
rwightman committed Aug 15, 2024
1 parent fee91fd commit 2f3fed4
Showing 4 changed files with 49 additions and 28 deletions.
1 change: 1 addition & 0 deletions timm/layers/__init__.py
@@ -29,6 +29,7 @@
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
from .hybrid_embed import HybridEmbed, HybridEmbedWithSize
from .inplace_abn import InplaceAbn
from .layer_scale import LayerScale, LayerScale2d
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
38 changes: 38 additions & 0 deletions timm/layers/layer_scale.py
@@ -0,0 +1,38 @@
import torch
from torch import nn


class LayerScale(nn.Module):
    """ LayerScale on tensors with channels in last-dim.
    """
    def __init__(
            self,
            dim: int,
            init_values: float = 1e-5,
            inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class LayerScale2d(nn.Module):
    """ LayerScale for tensors with torch 2D NCHW layout.
    """
    def __init__(
            self,
            dim: int,
            init_values: float = 1e-5,
            inplace: bool = False,
    ):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        gamma = self.gamma.view(1, -1, 1, 1)
        return x.mul_(gamma) if self.inplace else x * gamma
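
The two modules implement the standard LayerScale pattern: a learnable per-channel gain, initialized to a small value, typically applied to a residual branch. A minimal usage sketch, not part of the commit — it assumes a timm build that includes this change so the timm.layers imports resolve, and the shapes are purely illustrative:

import torch
from timm.layers import LayerScale, LayerScale2d

# Channels-last tokens, e.g. ViT-style blocks: [batch, tokens, channels]
ls = LayerScale(dim=768, init_values=1e-5)
tokens = torch.randn(2, 197, 768)
scaled = ls(tokens)        # gamma (shape [768]) broadcasts over the last dim

# NCHW feature maps: [batch, channels, height, width]
ls2d = LayerScale2d(dim=256, init_values=1e-5)
fmap = torch.randn(2, 256, 14, 14)
scaled_2d = ls2d(fmap)     # gamma reshaped to (1, C, 1, 1) before the multiply

With inplace=True, both variants scale the input tensor in place via mul_, which saves memory but should only be used when the unscaled input is not needed again.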

34 changes: 8 additions & 26 deletions timm/models/hiera.py
@@ -31,10 +31,8 @@
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, Mlp, use_fused_attn, _assert, get_norm_layer, to_2tuple

from timm.layers import DropPath, Mlp, LayerScale, use_fused_attn, _assert, get_norm_layer, to_2tuple

from ._registry import generate_default_cfgs, register_model
from ._builder import build_model_with_cfg
@@ -289,7 +287,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
""" Input should be of shape [batch, tokens, channels]. """
B, N, _ = x.shape
num_windows = (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1

qkv = self.qkv(x).reshape(B, -1, num_windows, 3, self.heads, self.head_dim).permute(3, 0, 4, 2, 1, 5)
q, k, v = qkv.unbind(0)

@@ -310,21 +307,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


class LayerScale(nn.Module):
def __init__(
self,
dim: int,
init_values: float = 1e-5,
inplace: bool = False,
) -> None:
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma


class HieraBlock(nn.Module):
def __init__(
self,
@@ -342,7 +324,6 @@ def __init__(
use_mask_unit_attn: bool = False,
):
super().__init__()

self.dim = dim
self.dim_out = dim_out

@@ -631,10 +612,8 @@ def __init__(
nn.init.trunc_normal_(self.pos_embed_win, std=0.02)

if weight_init != 'skip':
if weight_init == 'jax':
named_apply(partial(_init_weight_jax, head_bias=-math.log(self.num_classes)), self)
else:
named_apply(_init_weight_vit, self)
init_fn = _init_weight_jax if weight_init == 'jax' else _init_weight_vit
named_apply(init_fn, self)
if fix_init:
self.fix_init_weight()
if isinstance(self.head.fc, nn.Linear):
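
An aside on the hunk above (not part of the diff): the removed 'jax' branch always built its init function with head_bias=-math.log(self.num_classes), which breaks when a Hiera model is created with num_classes=0 (no classifier head), while the replacement simply selects the init function and leaves head_bias at its default. A quick illustration of the presumed failure mode:

import math

math.log(0)   # raises ValueError: math domain error
# so partial(_init_weight_jax, head_bias=-math.log(self.num_classes))
# could not even be constructed when num_classes == 0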
@@ -868,11 +847,13 @@ def _init_weight_vit(module, name, init_bias=0.02, head_bias=0.):
nn.init.trunc_normal_(module.weight, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
nn.init.constant_(module.bias, init_bias)
elif hasattr(module, 'init_weights'):
module.init_weights()


def _init_weight_jax(module, name, head_bias=0.):
if isinstance(module, nn.Linear):
if name.startswith('head'):
if name.startswith('head.fc'):
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
else:
@@ -960,7 +941,7 @@ def _cfg(url='', **kwargs):
num_classes=0,
),

"hiera_small_abswin_256.sbb2_ep200_in12k": _cfg(
"hiera_small_abswin_256.sbb2_e200_in12k": _cfg(
hf_hub_id='timm/',
num_classes=11821,
input_size=(3, 256, 256), crop_pct=0.95,
@@ -1007,6 +988,7 @@ def _create_hiera(variant: str, pretrained: bool = False, **kwargs) -> Hiera:
**kwargs,
)


@register_model
def hiera_tiny_224(pretrained=False, **kwargs):
model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2))
4 changes: 2 additions & 2 deletions timm/models/vision_transformer.py
@@ -1967,7 +1967,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_mediumd_patch16_reg4_gap_256.sbb2_ep200_in12k': _cfg(
'vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821,
input_size=(3, 256, 256), crop_pct=0.95),
@@ -1984,7 +1984,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), crop_pct=0.95),
'vit_betwixt_patch16_reg4_gap_256.sbb2_ep200_in12k': _cfg(
'vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821,
input_size=(3, 256, 256), crop_pct=0.95),
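
Since the sbb2 tags were renamed from *_ep200_* to *_e200_*, pretrained checkpoints must now be requested under the new names. A hedged sketch, assuming a timm build containing this commit and network access to the timm/ Hugging Face Hub entries:

import timm

# Renamed hiera tag
m = timm.create_model('hiera_small_abswin_256.sbb2_e200_in12k', pretrained=True)

# Renamed vit tags
m2 = timm.create_model('vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k', pretrained=True)
m3 = timm.create_model('vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k', pretrained=True)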
