diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 6734466f9..4f297fde5 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -231,6 +231,7 @@ def __init__( prune_layers: Optional[Union[List[int], Tuple[int]]] = None, prune_ratio: Optional[float] = None, cpe_depth: int = 1, + pos_embed: str = 'none', *args, **kwargs ) -> None: @@ -349,6 +350,12 @@ def dependencyvit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> Depend model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=12) model = _create_dependencyvit('dependencyvit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model + +@register_model +def dependencyvit_tiny_cpe1_lpe_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT: + model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=12, pos_embed='learn') + model = _create_dependencyvit('dependencyvit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model @register_model def dependencyvit_tiny_cpe5_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT: