Add DINOv2 models with register tokens. Convert pos embed to non-overlapping for consistency.
rwightman committed Oct 30, 2023
1 parent 97450d6 commit 81c7193
Showing 1 changed file with 97 additions and 15 deletions.
112 changes: 97 additions & 15 deletions timm/models/vision_transformer.py
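For context on the "non-overlapping" positional embedding layout mentioned in the commit message: in timm's no_embed_class mode the pos_embed table covers only the patch tokens, so cls and register tokens carry no positional entry. Below is a minimal sketch of the checkpoint conversion this implies, using illustrative shapes for ViT-S/14 at 518x518 rather than the exact timm code:

import torch

embed_dim = 384
num_patches = (518 // 14) ** 2                             # 37 * 37 = 1369 patch tokens
pos_embed = torch.zeros(1, 1 + num_patches, embed_dim)     # original layout: [cls | patches]
cls_token = torch.zeros(1, 1, embed_dim)
register_tokens = torch.zeros(1, 4, embed_dim)             # registers never had pos embed rows

# Fold the cls position into the cls token itself, then drop it from the table so
# the remaining rows align 1:1 with patch tokens (the no_embed_class layout).
cls_token = cls_token + pos_embed[:, :1]
pos_embed = pos_embed[:, 1:]
assert pos_embed.shape == (1, num_patches, embed_dim)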
@@ -567,7 +567,11 @@ def get_classifier(self):
     def reset_classifier(self, num_classes: int, global_pool=None):
         self.num_classes = num_classes
         if global_pool is not None:
-            assert global_pool in ('', 'avg', 'token')
+            assert global_pool in ('', 'avg', 'token', 'map')
+            if global_pool == 'map' and self.attn_pool is None:
+                assert False, "Cannot currently add attention pooling in reset_classifier()."
+            elif global_pool != 'map' and self.attn_pool is not None:
+                self.attn_pool = None  # remove attention pooling
             self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

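A quick usage sketch of the updated reset_classifier (illustrative; assumes a ViT built without an attention-pool head, so passing 'map' here would be rejected):

import timm

model = timm.create_model('vit_small_patch14_dinov2', pretrained=False, num_classes=0)
# Attach a new 10-class head and switch to average pooling over patch tokens.
model.reset_classifier(num_classes=10, global_pool='avg')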
@@ -937,10 +941,14 @@ def _convert_openai_clip(state_dict, model):
 def _convert_dinov2(state_dict, model):
     import re
     out_dict = {}
+    state_dict.pop("mask_token", None)
+    if 'register_tokens' in state_dict:
+        # convert dinov2 w/ registers to no_embed_class timm model (neither cls nor reg tokens overlap pos embed)
+        out_dict['reg_token'] = state_dict.pop('register_tokens')
+        out_dict['cls_token'] = state_dict.pop('cls_token') + state_dict['pos_embed'][:, 0]
+        out_dict['pos_embed'] = state_dict.pop('pos_embed')[:, 1:]
     for k, v in state_dict.items():
-        if k == "mask_token":
-            continue
-        elif re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
+        if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
             out_dict[k.replace("w12", "fc1")] = v
             continue
         elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
@@ -1229,6 +1237,32 @@ def _cfg(url='', **kwargs):
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
         input_size=(3, 518, 518), crop_pct=1.0),
 
+    # DINOv2 pretrained w/ registers - https://arxiv.org/abs/2309.16588 (no classifier head, for fine-tune/features only)
+    'vit_small_patch14_reg4_dinov2.lvd142m': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth',
+        # hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_base_patch14_reg4_dinov2.lvd142m': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth',
+        # hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_large_patch14_reg4_dinov2.lvd142m': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth',
+        # hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_giant_patch14_reg4_dinov2.lvd142m': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth',
+        # hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+
     # ViT ImageNet-21K-P pretraining by MILL
     'vit_base_patch16_224_miil.in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
@@ -2173,9 +2207,7 @@ def vit_huge_patch14_xp_224(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_small_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-S/14 for DINOv2
     """
-    model_args = dict(
-        patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518,
-    )
+    model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518)
     model = _create_vision_transformer(
         'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -2185,9 +2217,7 @@ def vit_small_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_base_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-B/14 for DINOv2
     """
-    model_args = dict(
-        patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518,
-    )
+    model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518)
     model = _create_vision_transformer(
         'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -2197,9 +2227,7 @@ def vit_base_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_large_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-L/14 for DINOv2
     """
-    model_args = dict(
-        patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518,
-    )
+    model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518)
     model = _create_vision_transformer(
         'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -2209,12 +2237,10 @@ def vit_large_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-G/14 for DINOv2
     """
-
     # The hidden_features of SwiGLU is calculated by:
     # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
     # When embed_dim=1536, hidden_features=4096
     # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
-
     model_args = dict(
         patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
         mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
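A rough numeric check of the SwiGLU sizing comment above (assumed arithmetic; it just reproduces the numbers quoted in the comment):

embed_dim = 1536
dinov2_hidden = (int(embed_dim * 4 * 2 / 3) + 7) // 8 * 8   # 4096, as stated above
packed_hidden = 2 * dinov2_hidden                           # fc1 packs both input projections -> 8192
assert int(embed_dim * 2.66667 * 2) == packed_hidden        # so mlp_ratio=2.66667 * 2 recovers it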
@@ -2224,6 +2250,62 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_small_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-S/14 for DINOv2 w/ 4 registers
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5,
+        reg_tokens=4, no_embed_class=True,
+    )
+    model = _create_vision_transformer(
+        'vit_small_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-B/14 for DINOv2 w/ 4 registers
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        reg_tokens=4, no_embed_class=True,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-L/14 for DINOv2 w/ 4 registers
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5,
+        reg_tokens=4, no_embed_class=True,
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_giant_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-G/14 for DINOv2 w/ 4 registers
+    """
+    # The hidden_features of SwiGLU is calculated by:
+    # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+    # When embed_dim=1536, hidden_features=4096
+    # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
+    model_args = dict(
+        patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, mlp_ratio=2.66667 * 2,
+        mlp_layer=SwiGLUPacked, act_layer=nn.SiLU, reg_tokens=4, no_embed_class=True,
+    )
+    model = _create_vision_transformer(
+        'vit_giant_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
     model_args = dict(

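A minimal sketch of exercising one of the new register-token models once a timm build containing this commit is installed (pretrained weights skipped; the token count assumes the default 224px build, i.e. 16 * 16 patches plus 1 cls and 4 register tokens):

import torch
import timm

model = timm.create_model('vit_base_patch14_reg4_dinov2', pretrained=False)
tokens = model.forward_features(torch.randn(1, 3, 224, 224))
print(tokens.shape)   # expected: torch.Size([1, 261, 768]), i.e. 1 cls + 4 reg + 256 patch tokens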