Update Hiera model for abswin & add more in12k weights for hiera & vit #2258

Merged (17 commits) on Aug 21, 2024
4 changes: 2 additions & 2 deletions tests/test_models.py
@@ -52,14 +52,14 @@
'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit',
'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2'
]

# transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*',
'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
]
NUM_NON_STD = len(NON_STD_FILTERS)
6 changes: 4 additions & 2 deletions timm/layers/__init__.py
@@ -5,7 +5,7 @@
from .attention_pool import AttentionPoolLatent
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
from .blur_pool import BlurPool2d, create_aa
from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
from .classifier import create_classifier, ClassifierHead, NormMlpClassifierHead, ClNormMlpClassifierHead
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
set_exportable, set_scriptable, set_no_jit, set_layer_config, set_fused_attn
@@ -29,6 +29,7 @@
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
from .hybrid_embed import HybridEmbed, HybridEmbedWithSize
from .inplace_abn import InplaceAbn
from .layer_scale import LayerScale, LayerScale2d
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
@@ -56,4 +56,5 @@
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .trace_utils import _assert, _float_to_int
from .typing import LayerType, PadType
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_, \
init_weight_jax, init_weight_vit
79 changes: 78 additions & 1 deletion timm/layers/classifier.py
@@ -134,7 +134,8 @@ def forward(self, x, pre_logits: bool = False):


class NormMlpClassifierHead(nn.Module):

""" A Pool -> Norm -> Mlp Classifier Head for '2D' NCHW tensors
"""
def __init__(
self,
in_features: int,
@@ -204,3 +205,79 @@ def forward(self, x, pre_logits: bool = False):
return x
x = self.fc(x)
return x


class ClNormMlpClassifierHead(nn.Module):
""" A Pool -> Norm -> Mlp Classifier Head for n-D NxxC tensors
"""
def __init__(
self,
in_features: int,
num_classes: int,
hidden_size: Optional[int] = None,
pool_type: str = 'avg',
drop_rate: float = 0.,
norm_layer: Union[str, Callable] = 'layernorm',
act_layer: Union[str, Callable] = 'gelu',
input_fmt: str = 'NHWC',
):
"""
Args:
in_features: The number of input features.
num_classes: The number of classes for the final classifier layer (output).
hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
pool_type: Global pooling type, pooling disabled if empty string ('').
drop_rate: Pre-classifier dropout rate.
norm_layer: Normalization layer type.
            act_layer: MLP activation layer type (only used if hidden_size is not None).
            input_fmt: Input tensor format, 'NHWC' or 'NLC'.
"""
super().__init__()
self.in_features = in_features
self.hidden_size = hidden_size
self.num_features = in_features
assert pool_type in ('', 'avg', 'max', 'avgmax')
self.pool_type = pool_type
assert input_fmt in ('NHWC', 'NLC')
self.pool_dim = 1 if input_fmt == 'NLC' else (1, 2)
norm_layer = get_norm_layer(norm_layer)
act_layer = get_act_layer(act_layer)

self.norm = norm_layer(in_features)
if hidden_size:
self.pre_logits = nn.Sequential(OrderedDict([
('fc', nn.Linear(in_features, hidden_size)),
('act', act_layer()),
]))
self.num_features = hidden_size
else:
self.pre_logits = nn.Identity()
self.drop = nn.Dropout(drop_rate)
self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

def reset(self, num_classes: int, pool_type: Optional[str] = None, reset_other: bool = False):
if pool_type is not None:
self.pool_type = pool_type
if reset_other:
self.pre_logits = nn.Identity()
self.norm = nn.Identity()
self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

def _global_pool(self, x):
if self.pool_type:
if self.pool_type == 'avg':
x = x.mean(dim=self.pool_dim)
elif self.pool_type == 'max':
x = x.amax(dim=self.pool_dim)
elif self.pool_type == 'avgmax':
x = 0.5 * (x.amax(dim=self.pool_dim) + x.mean(dim=self.pool_dim))
return x

def forward(self, x, pre_logits: bool = False):
x = self._global_pool(x)
x = self.norm(x)
x = self.pre_logits(x)
x = self.drop(x)
if pre_logits:
return x
x = self.fc(x)
return x
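For context, a minimal usage sketch of the new channels-last head; the shapes and argument values below are illustrative, not taken from this PR:

import torch
from timm.layers import ClNormMlpClassifierHead

# Pool -> Norm -> (optional) pre-logits MLP -> Dropout -> FC on an NHWC feature map.
head = ClNormMlpClassifierHead(
    in_features=768,
    num_classes=1000,
    hidden_size=512,      # adds a pre-logits Linear + activation before the classifier
    pool_type='avg',
    input_fmt='NHWC',
)
feats = torch.randn(2, 7, 7, 768)        # N, H, W, C
logits = head(feats)                     # (2, 1000)
embeds = head(feats, pre_logits=True)    # (2, 512), skips the final classifier layer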
2 changes: 2 additions & 0 deletions timm/layers/create_act.py
@@ -97,6 +97,7 @@ def get_act_fn(name: Union[Callable, str] = 'relu'):
return None
if isinstance(name, Callable):
return name
name = name.lower()
if not (is_exportable() or is_scriptable()):
# If not exporting or scripting the model, first look for a memory-efficient version with
# custom autograd, then fallback
@@ -117,6 +118,7 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
return name
if not name:
return None
name = name.lower()
if not (is_exportable() or is_scriptable()):
if name in _ACT_LAYER_ME:
return _ACT_LAYER_ME[name]
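The lowercasing above makes string lookups case-insensitive in both get_act_fn and get_act_layer; a small sketch of the effect (the specific strings are just examples):

from timm.layers import get_act_layer

assert get_act_layer('GELU') is get_act_layer('gelu')   # both resolve to the same class
act = get_act_layer('SiLU')()                           # instantiate the resolved layer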
38 changes: 38 additions & 0 deletions timm/layers/layer_scale.py
@@ -0,0 +1,38 @@
import torch
from torch import nn


class LayerScale(nn.Module):
""" LayerScale on tensors with channels in last-dim.
"""
def __init__(
self,
dim: int,
init_values: float = 1e-5,
inplace: bool = False,
) -> None:
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma


class LayerScale2d(nn.Module):
""" LayerScale for tensors with torch 2D NCHW layout.
"""
def __init__(
self,
dim: int,
init_values: float = 1e-5,
inplace: bool = False,
):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))

def forward(self, x):
gamma = self.gamma.view(1, -1, 1, 1)
return x.mul_(gamma) if self.inplace else x * gamma
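A brief usage sketch for the two new modules (dims and shapes below are illustrative):

import torch
from timm.layers import LayerScale, LayerScale2d

# Channels-last input (e.g. transformer tokens): gamma broadcasts over the last dim.
ls = LayerScale(dim=384, init_values=1e-5)
y = ls(torch.randn(2, 196, 384))          # same shape, per-channel scaling

# Channels-first NCHW input (e.g. conv feature maps): gamma is viewed as (1, C, 1, 1).
ls2d = LayerScale2d(dim=64, init_values=1e-5)
z = ls2d(torch.randn(2, 64, 56, 56))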

44 changes: 43 additions & 1 deletion timm/layers/weight_init.py
@@ -1,7 +1,7 @@
import torch
import math
import warnings

from torch import nn
from torch.nn.init import _calculate_fan_in_and_fan_out


@@ -123,3 +123,45 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):

def lecun_normal_(tensor):
variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')


def init_weight_vit(
module: nn.Module,
name: str,
init_bias: float = 0.02,
head_bias: float = 0.,
classifier_name: str = 'head'
):
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
if name.startswith(classifier_name):
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
else:
nn.init.trunc_normal_(module.weight, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
nn.init.constant_(module.bias, init_bias)
elif hasattr(module, 'init_weights'):
module.init_weights()


def init_weight_jax(
module: nn.Module,
name: str,
head_bias: float = 0.,
classifier_name: str = 'head',
):
if isinstance(module, nn.Linear):
if name.startswith(classifier_name):
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
else:
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv2d):
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
module.init_weights()
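Both helpers are applied per-module together with the module's name so the classifier can be special-cased; a sketch with a hypothetical toy model ('head' matches the default classifier_name):

import torch.nn as nn
from timm.layers import init_weight_vit, init_weight_jax

model = nn.Sequential()
model.add_module('fc1', nn.Linear(128, 256))
model.add_module('head', nn.Linear(256, 10))

# ViT-style init: trunc-normal weights, zeroed classifier weights/bias.
for name, m in model.named_modules():
    init_weight_vit(m, name)

# JAX / big-vision style init: xavier-uniform Linear weights, lecun-normal convs.
for name, m in model.named_modules():
    init_weight_jax(m, name)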

1 change: 1 addition & 0 deletions timm/models/__init__.py
@@ -27,6 +27,7 @@
from .hardcorenas import *
from .hgnet import *
from .hiera import *
from .hieradet_sam2 import *
from .hrnet import *
from .inception_next import *
from .inception_resnet_v2 import *
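With the new module imported here, its architectures register with timm's model registry and can be listed by module; a small sketch (the returned names depend on the installed release):

import timm

print(timm.list_models(module='hieradet_sam2'))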
8 changes: 6 additions & 2 deletions timm/models/efficientnet.py
@@ -1290,8 +1290,12 @@ def _cfg(url='', **kwargs):
'efficientnet_b0.ra4_e3600_r224_in1k': _cfg(
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
crop_pct=0.9, test_input_size=(3, 256, 256), test_crop_pct=1.0
),
crop_pct=0.9, test_input_size=(3, 256, 256), test_crop_pct=1.0),
'efficientnet_b1.ra4_e3600_r240_in1k': _cfg(
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
input_size=(3, 240, 240), crop_pct=0.9, pool_size=(8, 8),
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'efficientnet_b1.ft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
hf_hub_id='timm/',
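The new pretrained tag can be pulled by name once the weights are on the HF hub; a hedged sketch using standard timm APIs (resolved values follow the config entry above):

import timm
from timm.data import resolve_data_config

model = timm.create_model('efficientnet_b1.ra4_e3600_r240_in1k', pretrained=True).eval()
cfg = resolve_data_config({}, model=model)      # picks up the 240x240 train-time input size
print(cfg['input_size'], cfg['crop_pct'])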