From 82ae247879a2fdf79edb1b40eda42957a0c1e247 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 11 Oct 2024 11:07:40 -0700 Subject: [PATCH] MambaOut weights on hub, configs finalized --- timm/models/mambaout.py | 101 +++++++++++++++++++++++++++++++--------- 1 file changed, 78 insertions(+), 23 deletions(-) diff --git a/timm/models/mambaout.py b/timm/models/mambaout.py index 91ea0a0914..c748e408ea 100644 --- a/timm/models/mambaout.py +++ b/timm/models/mambaout.py @@ -15,7 +15,7 @@ from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq -from ._registry import register_model +from ._registry import register_model, generate_default_cfgs class Stem(nn.Module): @@ -435,6 +435,8 @@ def forward(self, x): def checkpoint_filter_fn(state_dict, model): if 'model' in state_dict: state_dict = state_dict['model'] + if 'stem.conv1.weight' in state_dict: + return state_dict import re out_dict = {} @@ -458,30 +460,52 @@ def checkpoint_filter_fn(state_dict, model): def _cfg(url='', **kwargs): return { 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 1.0, 'interpolation': 'bicubic', + 'num_classes': 1000, 'input_size': (3, 224, 224), 'test_input_size': (3, 288, 288), + 'pool_size': (7, 7), 'crop_pct': 1.0, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'stem.conv1', 'classifier': 'head.fc', **kwargs } -default_cfgs = { - 'mambaout_femto': _cfg( - url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_femto.pth'), - 'mambaout_kobe': _cfg( - url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_kobe.pth'), - 'mambaout_tiny': _cfg( - url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_tiny.pth'), - 'mambaout_small': _cfg( - url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_small.pth'), - 'mambaout_base': _cfg( - url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_base.pth'), - 'mambaout_small_rw': _cfg(), - 'mambaout_base_slim_rw': _cfg(), - 'mambaout_base_plus_rw': _cfg(), - 'test_mambaout': _cfg(input_size=(3, 160, 160), pool_size=(5, 5)), -} +default_cfgs = generate_default_cfgs({ + # original weights + 'mambaout_femto.in1k': _cfg( + hf_hub_id='timm/'), + 'mambaout_kobe.in1k': _cfg( + hf_hub_id='timm/'), + 'mambaout_tiny.in1k': _cfg( + hf_hub_id='timm/'), + 'mambaout_small.in1k': _cfg( + hf_hub_id='timm/'), + 'mambaout_base.in1k': _cfg( + hf_hub_id='timm/'), + + # timm experiments below + 'mambaout_small_rw.sw_e450_in1k': _cfg( + hf_hub_id='timm/', + ), + 'mambaout_base_short_rw.sw_e500_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, test_crop_pct=1.0, + ), + 'mambaout_base_tall_rw.sw_e500_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, test_crop_pct=1.0, + ), + 'mambaout_base_wide_rw.sw_e500_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, test_crop_pct=1.0, + ), + 'mambaout_base_plus_rw.sw_e150_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + ), + 'mambaout_base_plus_rw.sw_e150_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821, + ), + 'test_mambaout': _cfg(input_size=(3, 160, 160), test_input_size=(3, 192, 192), pool_size=(5, 5)), +}) def _create_mambaout(variant, pretrained=False, **kwargs): @@ -538,9 +562,24 @@ def mambaout_small_rw(pretrained=False, **kwargs): @register_model -def mambaout_base_slim_rw(pretrained=False, **kwargs): +def mambaout_base_short_rw(pretrained=False, **kwargs): model_args = dict( - depths=(3, 4, 27, 3), + depths=(3, 3, 25, 3), + dims=(128, 256, 512, 768), + expansion_ratio=3.0, + conv_ratio=1.25, + stem_mid_norm=False, + downsample='conv_nf', + ls_init_value=1e-6, + head_fn='norm_mlp', + ) + return _create_mambaout('mambaout_base_short_rw', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def mambaout_base_tall_rw(pretrained=False, **kwargs): + model_args = dict( + depths=(3, 4, 30, 3), dims=(128, 256, 512, 768), expansion_ratio=2.5, conv_ratio=1.25, @@ -549,11 +588,11 @@ def mambaout_base_slim_rw(pretrained=False, **kwargs): ls_init_value=1e-6, head_fn='norm_mlp', ) - return _create_mambaout('mambaout_base_slim_rw', pretrained=pretrained, **dict(model_args, **kwargs)) + return _create_mambaout('mambaout_base_tall_rw', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model -def mambaout_base_plus_rw(pretrained=False, **kwargs): +def mambaout_base_wide_rw(pretrained=False, **kwargs): model_args = dict( depths=(3, 4, 27, 3), dims=(128, 256, 512, 768), @@ -565,6 +604,22 @@ def mambaout_base_plus_rw(pretrained=False, **kwargs): act_layer='silu', head_fn='norm_mlp', ) + return _create_mambaout('mambaout_base_wide_rw', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def mambaout_base_plus_rw(pretrained=False, **kwargs): + model_args = dict( + depths=(3, 4, 30, 3), + dims=(128, 256, 512, 768), + expansion_ratio=3.0, + conv_ratio=1.5, + stem_mid_norm=False, + downsample='conv_nf', + ls_init_value=1e-6, + act_layer='silu', + head_fn='norm_mlp', + ) return _create_mambaout('mambaout_base_plus_rw', pretrained=pretrained, **dict(model_args, **kwargs))