Skip to content

Commit

Permalink
Add mambaout builder support, pretrained weight remap
Browse files: browse the repository at this point in the history
  • Loading branch information
rwightman committed Aug 23, 2024
1 parent 719912c commit ffd3480
Showing 1 changed file with 55 additions and 76 deletions.
131 changes: 55 additions & 76 deletions timm/models/mambaout.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@
MetaFormer (https://github.com/sail-sg/metaformer),
InceptionNeXt (https://github.com/sail-sg/inceptionnext)
"""
from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath, LayerNorm
from .vision_transformer import LayerScale
from ._manipulate import checkpoint_seq
from timm.models.registry import register_model

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import register_model


class Stem(nn.Module):
Expand Down Expand Up @@ -275,6 +274,7 @@ def __init__(
act_layer=nn.GELU,
conv_ratio=1.0,
kernel_size=7,
stem_mid_norm=True,
ls_init_value=None,
drop_path_rate=0.,
drop_rate=0.,
Expand All @@ -293,7 +293,13 @@ def __init__(
num_stage = len(depths)
self.num_stage = num_stage

self.stem = Stem(in_chans, dims[0], act_layer=act_layer, norm_layer=norm_layer)
self.stem = Stem(
in_chans,
dims[0],
mid_norm=stem_mid_norm,
act_layer=act_layer,
norm_layer=norm_layer,
)
prev_dim = dims[0]
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
self.stages = nn.ModuleList()
Expand Down Expand Up @@ -338,7 +344,7 @@ def forward_features(self, x):
x = s(x)
return x

def forward_head(self, x):
def forward_head(self, x, pre_logits: bool = False):
x = x.mean((1, 2))
x = self.norm(x)
x = self.head(x)
Expand All @@ -350,6 +356,21 @@ def forward(self, x):
return x


def checkpoint_filter_fn(state_dict, model):
    """Rename checkpoint keys from the reference layout to this module's layout.

    The following rewrites are applied to every key, in this order:

    * ``downsample_layers.0.``  ->  ``stem.``
    * ``stages.<i>.<j>``        ->  ``stages.<i>.blocks.<j>``
    * ``downsample_layers.<i>`` ->  ``stages.<i>.downsample``

    Args:
        state_dict: Mapping of parameter names to values. If it contains a
            ``'model'`` entry, that nested mapping is remapped instead.
        model: Unused by this function.

    Returns:
        A new dict holding the same values under the rewritten keys.
    """
    import re

    if 'model' in state_dict:
        state_dict = state_dict['model']

    # Escape the dots so only literal '.' separators match; an unescaped '.'
    # would also accept any other character in that position.
    stage_block_re = re.compile(r'stages\.([0-9]+)\.([0-9]+)')
    downsample_re = re.compile(r'downsample_layers\.([0-9]+)')

    out_dict = {}
    for k, v in state_dict.items():
        # The first downsample layer is the stem; handle it before the
        # generic downsample rewrite below.
        k = k.replace('downsample_layers.0.', 'stem.')
        k = stage_block_re.sub(r'stages.\1.blocks.\2', k)
        k = downsample_re.sub(r'stages.\1.downsample', k)
        out_dict[k] = v

    return out_dict


def _cfg(url='', **kwargs):
return {
'url': url,
Expand All @@ -376,105 +397,63 @@ def _cfg(url='', **kwargs):
}


def _create_mambaout(variant, pretrained=False, **kwargs):
    """Build a MambaOut model for ``variant`` via ``build_model_with_cfg``.

    Pretrained weights are passed through ``checkpoint_filter_fn``, and the
    feature config requests output indices 0-3 with ``flatten_sequential``
    enabled. Remaining ``kwargs`` are forwarded to the builder unchanged.
    """
    feature_cfg = dict(out_indices=(0, 1, 2, 3), flatten_sequential=True)
    return build_model_with_cfg(
        MambaOut,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=feature_cfg,
        **kwargs,
    )


# a series of MambaOut models
@register_model
def mambaout_femto(pretrained=False, **kwargs):
    """Create the ``mambaout_femto`` variant.

    Defaults: depths (3, 3, 9, 3), dims (48, 96, 192, 288). Entries in
    ``kwargs`` override these defaults before being handed to the builder.
    """
    model_args = dict(depths=(3, 3, 9, 3), dims=(48, 96, 192, 288))
    return _create_mambaout('mambaout_femto', pretrained=pretrained, **dict(model_args, **kwargs))

# Kobe Memorial Version with 24 Gated CNN blocks
@register_model
def mambaout_kobe(pretrained=False, **kwargs):
    """Create the ``mambaout_kobe`` variant.

    Defaults: depths [3, 3, 15, 3], dims [48, 96, 192, 288]. Entries in
    ``kwargs`` override these defaults before being handed to the builder.
    """
    model_args = dict(depths=[3, 3, 15, 3], dims=[48, 96, 192, 288])
    return _create_mambaout('mambaout_kobe', pretrained=pretrained, **dict(model_args, **kwargs))

@register_model
def mambaout_tiny(pretrained=False, **kwargs):
    """Create the ``mambaout_tiny`` variant.

    Defaults: depths [3, 3, 9, 3], dims [96, 192, 384, 576]. Entries in
    ``kwargs`` override these defaults before being handed to the builder.
    """
    model_args = dict(depths=[3, 3, 9, 3], dims=[96, 192, 384, 576])
    return _create_mambaout('mambaout_tiny', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def mambaout_small(pretrained=False, **kwargs):
    """Create the ``mambaout_small`` variant.

    Defaults: depths [3, 4, 27, 3], dims [96, 192, 384, 576]. Entries in
    ``kwargs`` override these defaults before being handed to the builder.
    """
    model_args = dict(depths=[3, 4, 27, 3], dims=[96, 192, 384, 576])
    return _create_mambaout('mambaout_small', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def mambaout_base(pretrained=False, **kwargs):
    """Create the ``mambaout_base`` variant.

    Defaults: depths [3, 4, 27, 3], dims [128, 256, 512, 768]. Entries in
    ``kwargs`` override these defaults before being handed to the builder.
    """
    model_args = dict(depths=[3, 4, 27, 3], dims=[128, 256, 512, 768])
    return _create_mambaout('mambaout_base', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def mambaout_small_rw(pretrained=False, **kwargs):
    """Create the ``mambaout_small_rw`` variant.

    Defaults: depths [3, 4, 27, 3], dims [96, 192, 384, 576], no mid-stem
    norm (``stem_mid_norm=False``) and ``ls_init_value=1e-6``. Entries in
    ``kwargs`` override these defaults before being handed to the builder.
    """
    model_args = dict(
        depths=[3, 4, 27, 3],
        dims=[96, 192, 384, 576],
        stem_mid_norm=False,
        ls_init_value=1e-6,
    )
    # Caller overrides are merged here, so model_args itself carries only defaults.
    return _create_mambaout('mambaout_small_rw', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def mambaout_base_rw(pretrained=False, **kwargs):
    """Create the ``mambaout_base_rw`` variant.

    Defaults: depths (3, 4, 27, 3), dims (128, 256, 512, 768), no mid-stem
    norm (``stem_mid_norm=False``) and ``ls_init_value=1e-6``. Entries in
    ``kwargs`` override these defaults before being handed to the builder.
    """
    model_args = dict(
        depths=(3, 4, 27, 3),
        dims=(128, 256, 512, 768),
        stem_mid_norm=False,
        ls_init_value=1e-6,
    )
    # Caller overrides are merged here, so model_args itself carries only defaults.
    return _create_mambaout('mambaout_base_rw', pretrained=pretrained, **dict(model_args, **kwargs))

0 comments on commit ffd3480

Please sign in to comment.