From 01b62264afe41574ebdb3747471204861f658cb2 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 8 Oct 2024 23:40:24 -0700
Subject: [PATCH 1/2] Add i18n variant of so400m model w/ weights. Add two in1k fine-tunes of original so400m 384x384 but at 378x378 (better matches patch14)

---
 timm/models/vision_transformer.py | 77 +++++++++++++++++++++++++++++++
 1 file changed, 77 insertions(+)

diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index c0a28085dc..9b7c7cd000 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -1841,6 +1841,16 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         hf_hub_id='timm/ViT-SO400M-14-SigLIP',
         hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=0),
+    'vit_so400m_patch16_siglip_256.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-16-SigLIP-i18n-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 256, 256),
+        num_classes=0),
+    'vit_so400m_patch14_siglip_378.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 378, 378),
+        num_classes=0),
     'vit_so400m_patch14_siglip_384.webli': _cfg(
         hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
         hf_hub_filename='open_clip_pytorch_model.bin',
@@ -1890,6 +1900,16 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         hf_hub_filename='paligemma-3b-pt-224.npz',
         custom_load='hf',
         num_classes=0),
+    'vit_so400m_patch16_siglip_gap_256.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-16-SigLIP-i18n-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 256, 256),
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_378.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 378, 378), crop_pct=1.0,
+        num_classes=0),
     'vit_so400m_patch14_siglip_gap_384.webli': _cfg(
         hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
         hf_hub_filename='open_clip_pytorch_model.bin',
@@ -1914,6 +1934,17 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         input_size=(3, 896, 896), crop_pct=1.0,
         num_classes=0),
 
+    'vit_so400m_patch14_siglip_378.webli_ft_in1k': _cfg(
+        #hf_hub_id='timm/',
+        #file='vit_so400m_p14_378_map-8.pth',
+        input_size=(3, 378, 378), crop_pct=1.0, crop_mode='squash',
+    ),
+    'vit_so400m_patch14_siglip_gap_378.webli_ft_in1k': _cfg(
+        # hf_hub_id='timm/',
+        #file='vit_so400m_p14_378_gap-8.pth',
+        input_size=(3, 378, 378), crop_pct=1.0, crop_mode='squash',
+    ),
+
     'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
         hf_hub_id='timm/',
         hf_hub_filename='open_clip_pytorch_model.bin',
@@ -2935,6 +2966,28 @@ def vit_so400m_patch14_siglip_224(pretrained: bool = False, **kwargs) -> VisionT
     return model
 
 
+@register_model
+def vit_so400m_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    # this is a corrected variant of the 384 with a res properly divisible by patch size (no padding/truncation)
+    model_args = dict(
+        patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    # this is a corrected variant of the 384 with a res properly divisible by patch size (no padding/truncation)
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_378', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
@@ -3023,6 +3076,30 @@ def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> Vis
     return model
 
 
+@register_model
+def vit_so400m_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_378', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_so400m_patch14_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""

From d9321b0e1016359ebc3fa92a98a4f70d2fb96fe9 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 9 Oct 2024 09:04:44 -0700
Subject: [PATCH 2/2] Add weights for fine-tuned siglip so400m. Add webli_i18n pretrained tags for the multi-lingual model variants (incl older base)

---
 timm/models/vision_transformer.py | 20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 9b7c7cd000..a5fad6ef7d 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -1817,6 +1817,11 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 256, 256),
         num_classes=0),
+    'vit_base_patch16_siglip_256.webli_i18n': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-i18n-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 256, 256),
+        num_classes=0),
     'vit_base_patch16_siglip_384.webli': _cfg(
         hf_hub_id='timm/ViT-B-16-SigLIP-384',
         hf_hub_filename='open_clip_pytorch_model.bin',
@@ -1841,7 +1846,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         hf_hub_id='timm/ViT-SO400M-14-SigLIP',
         hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=0),
-    'vit_so400m_patch16_siglip_256.webli': _cfg(
+    'vit_so400m_patch16_siglip_256.webli_i18n': _cfg(
         hf_hub_id='timm/ViT-SO400M-16-SigLIP-i18n-256',
         hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 256, 256),
@@ -1866,6 +1871,11 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 256, 256),
         num_classes=0),
+    'vit_base_patch16_siglip_gap_256.webli_i18n': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-i18n-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 256, 256),
+        num_classes=0),
     'vit_base_patch16_siglip_gap_384.webli': _cfg(
         hf_hub_id='timm/ViT-B-16-SigLIP-384',
         hf_hub_filename='open_clip_pytorch_model.bin',
@@ -1900,7 +1910,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         hf_hub_filename='paligemma-3b-pt-224.npz',
         custom_load='hf',
         num_classes=0),
-    'vit_so400m_patch16_siglip_gap_256.webli': _cfg(
+    'vit_so400m_patch16_siglip_gap_256.webli_i18n': _cfg(
         hf_hub_id='timm/ViT-SO400M-16-SigLIP-i18n-256',
         hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 256, 256),
@@ -1935,13 +1945,11 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         num_classes=0),
 
     'vit_so400m_patch14_siglip_378.webli_ft_in1k': _cfg(
-        #hf_hub_id='timm/',
-        #file='vit_so400m_p14_378_map-8.pth',
+        hf_hub_id='timm/',
         input_size=(3, 378, 378), crop_pct=1.0, crop_mode='squash',
     ),
     'vit_so400m_patch14_siglip_gap_378.webli_ft_in1k': _cfg(
-        # hf_hub_id='timm/',
-        #file='vit_so400m_p14_378_gap-8.pth',
+        hf_hub_id='timm/',
         input_size=(3, 378, 378), crop_pct=1.0, crop_mode='squash',
     ),
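
Usage sketch (not part of the patches above): a minimal example of how the variants added here could be instantiated once the referenced weights are published. The model and pretrained tag names are taken directly from the cfg entries in the diff; timm.create_model and the timm.data helpers are standard timm APIs, and the output shapes in the comments are assumptions based on the SO400M config (embed_dim=1152) and a default 1000-class head.

# Minimal sketch, assuming the hf_hub weight tags referenced in the diff are available.
import timm
import torch

# i18n SigLIP SO400M image tower, patch16 @ 256x256; loaded head-less as a feature encoder.
encoder = timm.create_model(
    'vit_so400m_patch16_siglip_256.webli_i18n', pretrained=True, num_classes=0)

# ImageNet-1k fine-tune of the SO400M tower at 378x378 (378 = 27 * 14, so no patch padding).
classifier = timm.create_model(
    'vit_so400m_patch14_siglip_378.webli_ft_in1k', pretrained=True)

# Build the matching eval transform from the pretrained cfg (378x378, crop_pct=1.0, crop_mode='squash').
data_config = timm.data.resolve_model_data_config(classifier)
transform = timm.data.create_transform(**data_config, is_training=False)

with torch.inference_mode():
    feats = encoder(torch.randn(1, 3, 256, 256))      # expected [1, 1152] pooled SigLIP features
    logits = classifier(torch.randn(1, 3, 378, 378))  # expected [1, 1000] ImageNet-1k logits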