From a2ac0cff0c983a40f8a4ec170b5a67f5a98c4ef7 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 7 Mar 2023 08:36:35 -0800 Subject: [PATCH 01/32] wip --- timm/models/cvt.py | 135 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 timm/models/cvt.py diff --git a/timm/models/cvt.py b/timm/models/cvt.py new file mode 100644 index 0000000000..fc6a265004 --- /dev/null +++ b/timm/models/cvt.py @@ -0,0 +1,135 @@ +import torch +import torch.nn +from torch import Tensor + +from timm.layers import LayerNorm2d, Mlp, ConvNormAct + +class ConvEmbed(nn.Module): + def __init__( + self, + in_chs=3, + out_chs=64, + kernel_size=7, + stride=4, + padding=2, + norm_layer=None + ): + super().__init__() + + self.conv = nn.Conv2d( + in_chs, + out_chs, + kernel_size=kernel_size, + stride=stride, + padding=padding + ) + + self.norm = norm_layer(out_chs) if norm_layer else nn.Identity() + + def forward(self, x: Tensor): + x = self.conv(x) + x = self.norm(x) + return x + + + +class Attention(nn.Module): + def __init__( + self, + in_chs, + out_chs, + num_heads, + kernel_size=3, + stride_q=1, + stride_kv=1, + padding_q=1, + padding_kv=1, + qkv_bias=False, + conv_bias=False, + attn_drop=0., + proj_drop=0., + conv_norm_layer=nn.BatchNorm2d, + conv_act_layer=nn.Identity(), + + cls_token=True + ): + assert out_chs % num_heads == 0, 'dim should be divisible by num_heads' + self.out_chs = out_chs + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = out_chs ** -0.5 + + self.conv_q = ConvNormAct( + in_chs, + out_chs, + kernel_size, + stride=stride_q, + padding=padding_q, + groups=in_chs, + bias=conv_bias, + norm_layer=conv_norm_layer, + act_layer=conv_act_layer + ) + + self.conv_k = ConvNormAct( + in_chs, + out_chs * 2, + kernel_size, + stride=stride_kv, + padding=padding_kv, + groups=in_chs, + bias=conv_bias, + norm_layer=conv_norm_layer, + act_layer=conv_act_layer + ) + + self.conv_v = ConvNormAct( + in_chs, + out_chs * 2, + kernel_size, + stride=stride_kv, + padding=padding_kv, + groups=in_chs, + bias=conv_bias, + norm_layer=conv_norm_layer, + act_layer=conv_act_layer + ) + + # better way to do this? iirc 1 is better than 3 + self.proj_q = nn.Linear(in_chs, out_chs, bias=qkv_bias) + self.proj_k = nn.Linear(in_chs, out_chs, bias=qkv_bias) + self.proj_v = nn.Linear(in_chs, out_chs, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(out_chs, out_chs) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor): + # [B, C_in, H, W] -> [B, H*W, C_out] + q = self.conv_q(x).flatten(2).transpose(1, 2) + k = self.conv_k(x).flatten(2).transpose(1, 2) + v = self.conv_v(x).flatten(2).transpose(1, 2) + + # need to handle cls token here + + # [B, H*W, C_out] -> + q = self.proj_q(q).reshape(B, q.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3) + k = self.proj_k(k).reshape(B, k.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3) + v = self.proj_v(v).reshape(B, v.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, self.out_chs) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class QuickGELU(nn.Module): + def forward(self, x: Tensor): + return x * torch.sigmoid(1.702 * x) + + \ No newline at end of file From 4c2827fd58abe83c1d990df3cdea025dcbc46064 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 27 Dec 2023 09:19:08 -0800 Subject: [PATCH 02/32] Update cvt.py --- timm/models/cvt.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index fc6a265004..cc1244267b 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -12,7 +12,7 @@ def __init__( kernel_size=7, stride=4, padding=2, - norm_layer=None + norm_layer=LayerNorm2d, ): super().__init__() @@ -26,7 +26,7 @@ def __init__( self.norm = norm_layer(out_chs) if norm_layer else nn.Identity() - def forward(self, x: Tensor): + def forward(self, x: Tensor): # [B, C, H, W] -> [B, C, H, W] x = self.conv(x) x = self.norm(x) return x @@ -95,7 +95,7 @@ def __init__( act_layer=conv_act_layer ) - # better way to do this? iirc 1 is better than 3 + # FIXME better way to do this? iirc 1 is better than 3 self.proj_q = nn.Linear(in_chs, out_chs, bias=qkv_bias) self.proj_k = nn.Linear(in_chs, out_chs, bias=qkv_bias) self.proj_v = nn.Linear(in_chs, out_chs, bias=qkv_bias) @@ -111,11 +111,12 @@ def forward(self, x: Tensor): # need to handle cls token here - # [B, H*W, C_out] -> + # [B, H*W, C_out] -> [B, H*W, n_h, d_h] -> [B, n_h, H*W, d_h] q = self.proj_q(q).reshape(B, q.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3) k = self.proj_k(k).reshape(B, k.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3) v = self.proj_v(v).reshape(B, v.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3) + # FIXME F.sdpa q = q * self.scale attn = q @ k.transpose(-2, -1) attn = attn.softmax(dim=-1) From 9b020d43750f5acc12df5cda74fb97a06775ee94 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 28 Dec 2023 03:04:06 -0800 Subject: [PATCH 03/32] Update cvt.py --- timm/models/cvt.py | 342 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 264 insertions(+), 78 deletions(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index cc1244267b..291313907d 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -1,19 +1,22 @@ +from typing import Optional, Tuple + import torch import torch.nn -from torch import Tensor +import torch.nn.functional as F + +from timm.layers import ConvNormAct, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn -from timm.layers import LayerNorm2d, Mlp, ConvNormAct class ConvEmbed(nn.Module): def __init__( self, - in_chs=3, - out_chs=64, - kernel_size=7, - stride=4, - padding=2, - norm_layer=LayerNorm2d, - ): + in_chs: int = 3, + out_chs: int = 64, + kernel_size: int = 7, + stride: int = 4, + padding: int = 2, + norm_layer: nn.Module = nn.LayerNorm2d, + ) -> None: super().__init__() self.conv = nn.Conv2d( @@ -26,111 +29,294 @@ def __init__( self.norm = norm_layer(out_chs) if norm_layer else nn.Identity() - def forward(self, x: Tensor): # [B, C, H, W] -> [B, C, H, W] + def forward(self, x: torch.Tensor) -> torch.Tensor: # [B, C, H, W] -> [B, C, H, W] x = self.conv(x) x = self.norm(x) return x - - -class Attention(nn.Module): +class ConvProj(nn.Module): def __init__( self, - in_chs, - out_chs, - num_heads, - kernel_size=3, - stride_q=1, - stride_kv=1, - padding_q=1, - padding_kv=1, - qkv_bias=False, - conv_bias=False, - attn_drop=0., - proj_drop=0., - conv_norm_layer=nn.BatchNorm2d, - conv_act_layer=nn.Identity(), - - cls_token=True - ): - assert out_chs % num_heads == 0, 'dim should be divisible by num_heads' - self.out_chs = out_chs - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = out_chs ** -0.5 - + dim: int, + kernel_size: int = 3, + stride_q: int = 1, + stride_kv: int = 2, + padding: int = 1, + bias: bool = False, + norm_layer: nn.Module = nn.BatchNorm2d, + act_layer: nn.Module = nn.Identity(), + ) -> None: + self.dim = dim + self.conv_q = ConvNormAct( - in_chs, - out_chs, + dim, + dim, kernel_size, stride=stride_q, - padding=padding_q, + padding=padding, groups=in_chs, - bias=conv_bias, - norm_layer=conv_norm_layer, - act_layer=conv_act_layer + bias=bias, + norm_layer=norm_layer, + act_layer=act_layer ) self.conv_k = ConvNormAct( - in_chs, - out_chs * 2, + dim, + dim, kernel_size, stride=stride_kv, - padding=padding_kv, + padding=padding, groups=in_chs, bias=conv_bias, - norm_layer=conv_norm_layer, - act_layer=conv_act_layer + norm_layer=norm_layer, + act_layer=act_layer ) self.conv_v = ConvNormAct( - in_chs, - out_chs * 2, + dim, + dim, kernel_size, stride=stride_kv, - padding=padding_kv, + padding=padding, groups=in_chs, bias=conv_bias, - norm_layer=conv_norm_layer, - act_layer=conv_act_layer + norm_layer=norm_layer, + act_layer=act_layer ) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, C, H, W = x.shape + # [B, C, H, W] -> [B, H*W, C] + q = self.conv_q(x).flatten(2).transpose(1, 2) + k = self.conv_k(x).flatten(2).transpose(1, 2) + v = self.conv_v(x).flatten(2).transpose(1, 2) + return q, k, v + +class Attention(nn.Module): + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + qk_norm: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fused_attn = use_fused_attn() - # FIXME better way to do this? iirc 1 is better than 3 - self.proj_q = nn.Linear(in_chs, out_chs, bias=qkv_bias) - self.proj_k = nn.Linear(in_chs, out_chs, bias=qkv_bias) - self.proj_v = nn.Linear(in_chs, out_chs, bias=qkv_bias) + self.proj_q = nn.Linear(dim, dim, bias=qkv_bias) + self.proj_k = nn.Linear(dim, dim, bias=qkv_bias) + self.proj_v = nn.Linear(dim, dim, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(out_chs, out_chs) self.proj_drop = nn.Dropout(proj_drop) - def forward(self, x: Tensor): - # [B, C_in, H, W] -> [B, H*W, C_out] - q = self.conv_q(x).flatten(2).transpose(1, 2) - k = self.conv_k(x).flatten(2).transpose(1, 2) - v = self.conv_v(x).flatten(2).transpose(1, 2) + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + B, N, C = q.shape - # need to handle cls token here - - # [B, H*W, C_out] -> [B, H*W, n_h, d_h] -> [B, n_h, H*W, d_h] + # [B, H*W, C] -> [B, H*W, n_h, d_h] -> [B, n_h, H*W, d_h] q = self.proj_q(q).reshape(B, q.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3) k = self.proj_k(k).reshape(B, k.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3) v = self.proj_v(v).reshape(B, v.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3) - - # FIXME F.sdpa - q = q * self.scale - attn = q @ k.transpose(-2, -1) - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - x = attn @ v - - x = x.transpose(1, 2).reshape(B, N, self.out_chs) + q, k = self.q_norm(q), self.k_norm(k) + # [B, n_h, H*W, d_h], [B, n_h, H*W/4, d_h], [B, n_h, H*W/4, d_h] + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) - return x -class QuickGELU(nn.Module): - def forward(self, x: Tensor): - return x * torch.sigmoid(1.702 * x) +class CvTBlock(nn.Module): + def __init__( + self, + dim: int, + kernel_size: int = 3, + stride_q: int = 1, + stride_kv: int = 2, + padding: int = 1, + conv_bias: bool = False, + conv_norm_layer: nn.Module = nn.BatchNorm2d, + conv_act_layer: nn.Module = nn.Identity(), + num_heads: int = 8, + qkv_bias: bool = True, + qk_norm: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + input_norm_layer = LayerNorm2d, + norm_layer: nn.Module = nn.LayerNorm, + init_values: Optional[float] = None, + drop_path: float = 0., + mlp_layer: nn.Module = Mlp, + mlp_ratio: float = 4., + mlp_act_layer: nn.Module = QuickGELU, + use_cls_token: bool = False, + ) -> None: + self.use_cls_token = use_cls_token + + self.norm1 = norm_layer(dim) + self.conv_proj = ConvProj( + dim = dim, + kernel_size = kernel_size, + stride_q = stride_q, + stride_kv = stride_kv, + padding = padding, + bias = conv_bias, + norm_layer = conv_norm_layer, + act_layer = conv_act_layer, + ) + self.attn = Attention( + dim = dim, + num_heads = num_heads, + qkv_bias = qkv_bias, + qk_norm = qk_norm, + attn_drop = attn_drop, + proj_drop = proj_drop, + norm_layer = norm_layer + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def add_cls_token( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cls_token: Optional[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if self.use_cls_token: + q = torch.cat((cls_token, q), dim=1) + k = torch.cat((cls_token, k), dim=1) + v = torch.cat((cls_token, v), dim=1) + return q, k, v + + def fw_attn(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> torch.Tensor: + return self.attn(*self.add_cls_token(*self.conv_proj(x), cls_token)) + + def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + B, C, H, W = x.shape + + x = x.flatten(2).transpose(1, 2) + self.drop_path1(self.ls1(self.fw(attn(self.norm1(x))))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + + if self.use_cls_token: + cls_token, x = torch.split(x, [1, H*W], 1) + + return x, cls_token + +class CvTStage(nn.Module): + def __init__( + in_chs: int, + dim: int, + depth: int, + embed_kernel_size: int = 7, + embed_stride: int = 4, + embed_padding: int 2, + kernel_size: int = 3, + stride_q: int = 1, + stride_kv: int = 2, + padding: int = 1, + conv_bias: bool = False, + conv_norm_layer: nn.Module = nn.BatchNorm2d, + conv_act_layer: nn.Module = nn.Identity(), + num_heads: int = 8, + qkv_bias: bool = True, + qk_norm: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + input_norm_layer = LayerNorm2d, + norm_layer: nn.Module = nn.LayerNorm, + init_values: Optional[float] = None, + drop_path: float = 0., + mlp_layer: nn.Module = Mlp, + mlp_ratio: float = 4., + mlp_act_layer: nn.Module = QuickGELU, + use_cls_token: bool = False, + ) -> None: + self.conv_embed = ConvEmbed( + in_chs = in_chs, + out_chs = dim, + kernel_size = embed_kernel_size, + stride = embed_stride, + padding = embed_padding, + norm_layer = input_norm_layer, + ) + self.embed_drop = nn.Dropout(proj_drop) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, dim)) if use_cls_token else None + + blocks = [] + for i in range(depth): + blocks.append( + CvTBlock( + dim = dim, + kernel_size = kernel_size, + stride_q = stride_q, + stride_kv = stride_kv, + padding = padding, + conv_bias = conv_bias, + conv_norm_layer = conv_norm_layer, + conv_act_layer = conv_act_layer, + num_heads = num_heads, + qkv_bias = qkv_bias, + qk_norm = qk_norm, + attn_drop = attn_drop, + proj_drop = proj_drop, + input_norm_layer input_norm_layer, + norm_layer = norm_layer, + init_values = init_values, + drop_path = drop_path, + mlp_layer = mlp_layer, + mlp_ratio = mlp_ratio, + mlp_act_layer = mlp_act_layer, + use_cls_token = use_cls_token, + ) + ) + self.blocks = nn.ModuleList(blocks) + + if self.cls_token is not None: + trunc_normal_(self.cls_token, std=.02) - \ No newline at end of file + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv_embed(x) + x = self.embed_drop(x) + + cls_token = self.cls_token + for block in self.blocks: + x, cls_token = block(x, cls_token) + + return x, cls_token + +class CvT(nn.Module): + \ No newline at end of file From db586b5220b93ae0bee3d0dcd20c9d6ebeb9c17a Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 28 Dec 2023 22:58:10 -0800 Subject: [PATCH 04/32] Update cvt.py --- timm/models/cvt.py | 138 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 107 insertions(+), 31 deletions(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 291313907d..2184461b27 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -46,6 +46,7 @@ def __init__( norm_layer: nn.Module = nn.BatchNorm2d, act_layer: nn.Module = nn.Identity(), ) -> None: + super().__init__() self.dim = dim self.conv_q = ConvNormAct( @@ -98,7 +99,7 @@ class Attention(nn.Module): def __init__( self, dim: int, - num_heads: int = 8, + num_heads: int = 1, qkv_bias: bool = True, qk_norm: bool = False, attn_drop: float = 0., @@ -159,7 +160,7 @@ def __init__( conv_bias: bool = False, conv_norm_layer: nn.Module = nn.BatchNorm2d, conv_act_layer: nn.Module = nn.Identity(), - num_heads: int = 8, + num_heads: int = 1, qkv_bias: bool = True, qk_norm: bool = False, attn_drop: float = 0., @@ -173,6 +174,7 @@ def __init__( mlp_act_layer: nn.Module = QuickGELU, use_cls_token: bool = False, ) -> None: + super().__init__() self.use_cls_token = use_cls_token self.norm1 = norm_layer(dim) @@ -242,7 +244,7 @@ def __init__( depth: int, embed_kernel_size: int = 7, embed_stride: int = 4, - embed_padding: int 2, + embed_padding: int = 2, kernel_size: int = 3, stride_q: int = 1, stride_kv: int = 2, @@ -250,7 +252,7 @@ def __init__( conv_bias: bool = False, conv_norm_layer: nn.Module = nn.BatchNorm2d, conv_act_layer: nn.Module = nn.Identity(), - num_heads: int = 8, + num_heads: int = 1, qkv_bias: bool = True, qk_norm: bool = False, attn_drop: float = 0., @@ -258,12 +260,14 @@ def __init__( input_norm_layer = LayerNorm2d, norm_layer: nn.Module = nn.LayerNorm, init_values: Optional[float] = None, - drop_path: float = 0., + drop_path_rates: List[float] = [0.], mlp_layer: nn.Module = Mlp, mlp_ratio: float = 4., mlp_act_layer: nn.Module = QuickGELU, use_cls_token: bool = False, ) -> None: + super().__init__() + self.conv_embed = ConvEmbed( in_chs = in_chs, out_chs = dim, @@ -278,31 +282,30 @@ def __init__( blocks = [] for i in range(depth): - blocks.append( - CvTBlock( - dim = dim, - kernel_size = kernel_size, - stride_q = stride_q, - stride_kv = stride_kv, - padding = padding, - conv_bias = conv_bias, - conv_norm_layer = conv_norm_layer, - conv_act_layer = conv_act_layer, - num_heads = num_heads, - qkv_bias = qkv_bias, - qk_norm = qk_norm, - attn_drop = attn_drop, - proj_drop = proj_drop, - input_norm_layer input_norm_layer, - norm_layer = norm_layer, - init_values = init_values, - drop_path = drop_path, - mlp_layer = mlp_layer, - mlp_ratio = mlp_ratio, - mlp_act_layer = mlp_act_layer, - use_cls_token = use_cls_token, - ) + block = CvTBlock( + dim = dim, + kernel_size = kernel_size, + stride_q = stride_q, + stride_kv = stride_kv, + padding = padding, + conv_bias = conv_bias, + conv_norm_layer = conv_norm_layer, + conv_act_layer = conv_act_layer, + num_heads = num_heads, + qkv_bias = qkv_bias, + qk_norm = qk_norm, + attn_drop = attn_drop, + proj_drop = proj_drop, + input_norm_layer input_norm_layer, + norm_layer = norm_layer, + init_values = init_values, + drop_path = drop_path_rates[i], + mlp_layer = mlp_layer, + mlp_ratio = mlp_ratio, + mlp_act_layer = mlp_act_layer, + use_cls_token = use_cls_token, ) + blocks.append(block) self.blocks = nn.ModuleList(blocks) if self.cls_token is not None: @@ -313,10 +316,83 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.embed_drop(x) cls_token = self.cls_token - for block in self.blocks: + for block in self.blocks: # technically possible to exploit nn.Sequential's untyped intermediate results if each block takes in a tensor x, cls_token = block(x, cls_token) return x, cls_token class CvT(nn.Module): - \ No newline at end of file + def __init__( + in_chans: int = 3, + num_classes: int = 1000, + dims: Tuple[int, ...] = (64, 192, 384), + depths: Tuple[int, ...] = (1, 2, 10), + embed_kernel_size: Tuple[int, ...] = (7, 3, 3), + embed_stride: Tuple[int, ...] = (4, 2, 2), + embed_padding: Tuple[int, ...] = (2, 1, 1), + kernel_size: int = 3, + stride_q: int = 1, + stride_kv: int = 2, + padding: int = 1, + conv_bias: bool = False, + conv_norm_layer: nn.Module = nn.BatchNorm2d, + conv_act_layer: nn.Module = nn.Identity(), + num_heads: Tuple[int, ...] = (1, 3, 6), + qkv_bias: bool = True, + qk_norm: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + input_norm_layer = LayerNorm2d, + norm_layer: nn.Module = nn.LayerNorm, + init_values: Optional[float] = None, + drop_path_rate: float = 0., + mlp_layer: nn.Module = Mlp, + mlp_ratio: float = 4., + mlp_act_layer: nn.Module = QuickGELU, + use_cls_token: Tuple[bool, ...] = (False, False, True), + ) -> None: + super().__init__() + num_stages = len(dims) + assert num_stages == len(depths) == len(embed_kernel_size) == len(embed_stride) + assert num_stages == len(embed_padding) == len(num_heads) == len(use_cls_token) + self.num_classes = num_classes + self.num_features = dims[-1] + self.drop_rate = drop_rate + + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + + in_chs = in_chans + + stages = [] + for stage_idx in range(num_stages): + dim = dims[stage_idx] + stage = CvTStage( + in_chs = in_chs, + dim = dim, + depth = depths[stage_idx], + embed_kernel_size = embed_kernel_size[stage_idx], + embed_stride = embed_stride[stage_idx], + embed_padding = embed_padding[stage_idx], + kernel_size = kernel_size, + stride_q = stride_q, + stride_kv = stride_kv, + padding = padding, + conv_bias = conv_bias, + conv_norm_layer = conv_norm_layer, + conv_act_layer = conv_act_layer, + num_heads = num_heads[stage_idx], + qkv_bias = qkv_bias, + qk_norm = qk_norm, + attn_drop = attn_drop, + proj_drop = proj_drop, + input_norm_layer = input_norm_layer, + norm_layer = norm_layer, + init_values = init_values, + drop_path_rates = dpr[stage_idx], + mlp_layer = mlp_layer, + mlp_ratio = mlp_ratio, + mlp_act_layer = mlp_act_layer, + use_cls_token = use_cls_token[stage_idx], + ) + in_chs = dim + stages.append(stage) \ No newline at end of file From 95b6a52a1e6d95936fd6e761b4f65a89d759c797 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 6 Jan 2024 05:23:26 -0700 Subject: [PATCH 05/32] Update cvt.py --- timm/models/cvt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 2184461b27..13d26cae8b 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -296,7 +296,7 @@ def __init__( qk_norm = qk_norm, attn_drop = attn_drop, proj_drop = proj_drop, - input_norm_layer input_norm_layer, + input_norm_layer = input_norm_layer, norm_layer = norm_layer, init_values = init_values, drop_path = drop_path_rates[i], From 0d171c6fb7bfccafb4756fcb60a0aeacc167c68b Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 21 Apr 2024 05:36:30 -0700 Subject: [PATCH 06/32] Update cvt.py --- timm/models/cvt.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 13d26cae8b..3778845846 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -229,7 +229,7 @@ def fw_attn(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> torch.T def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: B, C, H, W = x.shape - x = x.flatten(2).transpose(1, 2) + self.drop_path1(self.ls1(self.fw(attn(self.norm1(x))))) + x = x.flatten(2).transpose(1, 2) + self.drop_path1(self.ls1(self.fw_attn(self.norm1(x)))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) if self.use_cls_token: @@ -395,4 +395,6 @@ def __init__( use_cls_token = use_cls_token[stage_idx], ) in_chs = dim - stages.append(stage) \ No newline at end of file + stages.append(stage) + self.stages = nn.ModuleList(stages) + \ No newline at end of file From 11cc4a78d79198dad4600135b08151517956c863 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 21 Apr 2024 07:19:50 -0700 Subject: [PATCH 07/32] Update cvt.py --- timm/models/cvt.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 3778845846..2b2d276b9f 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -61,6 +61,9 @@ def __init__( act_layer=act_layer ) + # TODO fuse kv conv? + # TODO if act_layer is id and not cls_token (gap model?), is later projection in attn necessary? + self.conv_k = ConvNormAct( dim, dim, @@ -235,6 +238,8 @@ def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[t if self.use_cls_token: cls_token, x = torch.split(x, [1, H*W], 1) + x = x.transpose(1, 2).reshape(B, C, H, W) + return x, cls_token class CvTStage(nn.Module): @@ -359,6 +364,9 @@ def __init__( self.num_features = dims[-1] self.drop_rate = drop_rate + # FIXME only on last stage, no need for tuple + self.use_cls_token = use_cls_token[-1] + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] in_chs = in_chans @@ -397,4 +405,19 @@ def __init__( in_chs = dim stages.append(stage) self.stages = nn.ModuleList(stages) + + self.head_norm = norm_layer(dims[-1]) + self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + for stage in self.stages: + x, cls_token = stage(x) + + + if self.use_cls_token: + return self.head(self.head_norm(cls_token)) + else: + return self.head(self.head_norm(x.mean(dim=(2,3)))) + \ No newline at end of file From 8f7627c003c47d13f995eb997269e61044927b55 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 21 Apr 2024 10:23:21 -0700 Subject: [PATCH 08/32] wip --- timm/models/__init__.py | 1 + timm/models/cvt.py | 137 +++++++++++++++++++++++++++++++--------- 2 files changed, 107 insertions(+), 31 deletions(-) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index c5b1984f20..4883b3d18e 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -8,6 +8,7 @@ from .convnext import * from .crossvit import * from .cspnet import * +from .cvt import * from .davit import * from .deit import * from .densenet import * diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 2b2d276b9f..2bbd3e22f2 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -1,11 +1,28 @@ -from typing import Optional, Tuple +""" CvT: Convolutional Vision Transformer + +From-scratch implementation of CvT in PyTorch + +'CvT: Introducing Convolutions to Vision Transformers' + - https://arxiv.org/abs/2103.15808 + +Implementation for timm by / Copyright 2024, Fredo Guan +""" + +from functools import partial +from typing import List, Final, Optional, Tuple import torch -import torch.nn +import torch.nn as nn import torch.nn.functional as F -from timm.layers import ConvNormAct, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import ConvNormAct, LayerNorm, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn +from ._builder import build_model_with_cfg +from ._registry import generate_default_cfgs, register_model + + +__all__ = ['CvT'] class ConvEmbed(nn.Module): def __init__( @@ -15,7 +32,7 @@ def __init__( kernel_size: int = 7, stride: int = 4, padding: int = 2, - norm_layer: nn.Module = nn.LayerNorm2d, + norm_layer: nn.Module = LayerNorm2d, ) -> None: super().__init__() @@ -44,7 +61,7 @@ def __init__( padding: int = 1, bias: bool = False, norm_layer: nn.Module = nn.BatchNorm2d, - act_layer: nn.Module = nn.Identity(), + act_layer: nn.Module = nn.Identity, ) -> None: super().__init__() self.dim = dim @@ -55,7 +72,7 @@ def __init__( kernel_size, stride=stride_q, padding=padding, - groups=in_chs, + groups=dim, bias=bias, norm_layer=norm_layer, act_layer=act_layer @@ -70,8 +87,8 @@ def __init__( kernel_size, stride=stride_kv, padding=padding, - groups=in_chs, - bias=conv_bias, + groups=dim, + bias=bias, norm_layer=norm_layer, act_layer=act_layer ) @@ -82,8 +99,8 @@ def __init__( kernel_size, stride=stride_kv, padding=padding, - groups=in_chs, - bias=conv_bias, + groups=dim, + bias=bias, norm_layer=norm_layer, act_layer=act_layer ) @@ -107,7 +124,7 @@ def __init__( qk_norm: bool = False, attn_drop: float = 0., proj_drop: float = 0., - norm_layer: nn.Module = nn.LayerNorm, + norm_layer: nn.Module = LayerNorm, ) -> None: super().__init__() assert dim % num_heads == 0, 'dim should be divisible by num_heads' @@ -122,16 +139,16 @@ def __init__( self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(out_chs, out_chs) + self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: B, N, C = q.shape # [B, H*W, C] -> [B, H*W, n_h, d_h] -> [B, n_h, H*W, d_h] - q = self.proj_q(q).reshape(B, q.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3) - k = self.proj_k(k).reshape(B, k.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3) - v = self.proj_v(v).reshape(B, v.shape[2], self.num_heads, self.head_dim).permute(0, 2, 1, 3) + q = self.proj_q(q).reshape(B, q.shape[1], self.num_heads, self.head_dim).permute(0, 2, 1, 3) + k = self.proj_k(k).reshape(B, k.shape[1], self.num_heads, self.head_dim).permute(0, 2, 1, 3) + v = self.proj_v(v).reshape(B, v.shape[1], self.num_heads, self.head_dim).permute(0, 2, 1, 3) q, k = self.q_norm(q), self.k_norm(k) # [B, n_h, H*W, d_h], [B, n_h, H*W/4, d_h], [B, n_h, H*W/4, d_h] @@ -162,14 +179,14 @@ def __init__( padding: int = 1, conv_bias: bool = False, conv_norm_layer: nn.Module = nn.BatchNorm2d, - conv_act_layer: nn.Module = nn.Identity(), + conv_act_layer: nn.Module = nn.Identity, num_heads: int = 1, qkv_bias: bool = True, qk_norm: bool = False, attn_drop: float = 0., proj_drop: float = 0., - input_norm_layer = LayerNorm2d, - norm_layer: nn.Module = nn.LayerNorm, + input_norm_layer: nn.Module = partial(LayerNorm2d, eps=1e-5), + norm_layer: nn.Module = partial(LayerNorm, eps=1e-5), init_values: Optional[float] = None, drop_path: float = 0., mlp_layer: nn.Module = Mlp, @@ -180,7 +197,7 @@ def __init__( super().__init__() self.use_cls_token = use_cls_token - self.norm1 = norm_layer(dim) + self.norm1 = input_norm_layer(dim) self.conv_proj = ConvProj( dim = dim, kernel_size = kernel_size, @@ -207,7 +224,7 @@ def __init__( self.mlp = mlp_layer( in_features=dim, hidden_features=int(dim * mlp_ratio), - act_layer=act_layer, + act_layer=mlp_act_layer, drop=proj_drop, ) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() @@ -232,7 +249,8 @@ def fw_attn(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> torch.T def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: B, C, H, W = x.shape - x = x.flatten(2).transpose(1, 2) + self.drop_path1(self.ls1(self.fw_attn(self.norm1(x)))) + x = torch.cat((cls_token, x.flatten(2).transpose(1, 2)), dim=1) if cls_token is not None else x.flatten(2).transpose(1, 2) \ + + self.drop_path1(self.ls1(self.fw_attn(self.norm1(x), cls_token))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) if self.use_cls_token: @@ -244,6 +262,7 @@ def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[t class CvTStage(nn.Module): def __init__( + self, in_chs: int, dim: int, depth: int, @@ -256,14 +275,14 @@ def __init__( padding: int = 1, conv_bias: bool = False, conv_norm_layer: nn.Module = nn.BatchNorm2d, - conv_act_layer: nn.Module = nn.Identity(), + conv_act_layer: nn.Module = nn.Identity, num_heads: int = 1, qkv_bias: bool = True, qk_norm: bool = False, attn_drop: float = 0., proj_drop: float = 0., input_norm_layer = LayerNorm2d, - norm_layer: nn.Module = nn.LayerNorm, + norm_layer: nn.Module = LayerNorm, init_values: Optional[float] = None, drop_path_rates: List[float] = [0.], mlp_layer: nn.Module = Mlp, @@ -320,7 +339,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv_embed(x) x = self.embed_drop(x) - cls_token = self.cls_token + cls_token = self.cls_token.expand(x.shape[0], -1, -1) if self.cls_token is not None else None for block in self.blocks: # technically possible to exploit nn.Sequential's untyped intermediate results if each block takes in a tensor x, cls_token = block(x, cls_token) @@ -328,6 +347,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class CvT(nn.Module): def __init__( + self, in_chans: int = 3, num_classes: int = 1000, dims: Tuple[int, ...] = (64, 192, 384), @@ -341,14 +361,14 @@ def __init__( padding: int = 1, conv_bias: bool = False, conv_norm_layer: nn.Module = nn.BatchNorm2d, - conv_act_layer: nn.Module = nn.Identity(), + conv_act_layer: nn.Module = nn.Identity, num_heads: Tuple[int, ...] = (1, 3, 6), qkv_bias: bool = True, qk_norm: bool = False, attn_drop: float = 0., proj_drop: float = 0., input_norm_layer = LayerNorm2d, - norm_layer: nn.Module = nn.LayerNorm, + norm_layer: nn.Module = LayerNorm, init_values: Optional[float] = None, drop_path_rate: float = 0., mlp_layer: nn.Module = Mlp, @@ -362,7 +382,6 @@ def __init__( assert num_stages == len(embed_padding) == len(num_heads) == len(use_cls_token) self.num_classes = num_classes self.num_features = dims[-1] - self.drop_rate = drop_rate # FIXME only on last stage, no need for tuple self.use_cls_token = use_cls_token[-1] @@ -371,6 +390,8 @@ def __init__( in_chs = in_chans + # TODO move stem + stages = [] for stage_idx in range(num_stages): dim = dims[stage_idx] @@ -406,7 +427,7 @@ def __init__( stages.append(stage) self.stages = nn.ModuleList(stages) - self.head_norm = norm_layer(dims[-1]) + self.norm = norm_layer(dims[-1]) self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -416,8 +437,62 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.use_cls_token: - return self.head(self.head_norm(cls_token)) + return self.head(self.norm(cls_token.flatten(1))) else: - return self.head(self.head_norm(x.mean(dim=(2,3)))) + return self.head(self.norm(x.mean(dim=(2,3)))) - \ No newline at end of file + + +def checkpoint_filter_fn(state_dict, model): + """ Remap MSFT checkpoints -> timm """ + if 'head.fc.weight' in state_dict: + return state_dict # non-MSFT checkpoint + + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + + import re + out_dict = {} + for k, v in state_dict.items(): + k = re.sub(r'stage([0-9]+)', r'stages.\1', k) + k = k.replace('patch_embed', 'conv_embed') + k = k.replace('conv_embed.proj', 'conv_embed.conv') + k = k.replace('attn.conv_proj', 'conv_proj.conv') + out_dict[k] = v + return out_dict + + +def _create_cvt(variant, pretrained=False, **kwargs): + default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 2, 10)))) + out_indices = kwargs.pop('out_indices', default_out_indices) + + model = build_model_with_cfg( + CvT, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs) + + return model + +# TODO update first_conv +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (14, 14), + 'crop_pct': 0.95, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head', + **kwargs + } + +default_cfgs = generate_default_cfgs({ + 'cvt_13.msft_in1k': _cfg(url='https://files.catbox.moe/xz97kh.pth'), +}) + + +@register_model +def cvt_13(pretrained=False, **kwargs) -> CvT: + model_args = dict(depths=(1, 2, 10), dims=(64, 192, 384), num_heads=(1, 3, 6)) + return _create_cvt('cvt_13', pretrained=pretrained, **dict(model_args, **kwargs)) From b0c6c01c06795b309f7ba9fc34d2a7fb28effb18 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 22 Apr 2024 04:49:16 -0700 Subject: [PATCH 09/32] Update cvt.py --- timm/models/cvt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 2bbd3e22f2..5b557757fd 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -130,8 +130,8 @@ def __init__( assert dim % num_heads == 0, 'dim should be divisible by num_heads' self.num_heads = num_heads self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - self.fused_attn = use_fused_attn() + self.scale = dim ** -0.5 + self.fused_attn = False #use_fused_attn() self.proj_q = nn.Linear(dim, dim, bias=qkv_bias) self.proj_k = nn.Linear(dim, dim, bias=qkv_bias) From 26c685502df8e03baca0476dda94a7b67c3d261d Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 22 Apr 2024 04:51:09 -0700 Subject: [PATCH 10/32] Update cvt.py --- timm/models/cvt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 5b557757fd..78fb7fe1c8 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -367,8 +367,8 @@ def __init__( qk_norm: bool = False, attn_drop: float = 0., proj_drop: float = 0., - input_norm_layer = LayerNorm2d, - norm_layer: nn.Module = LayerNorm, + input_norm_layer = partial(LayerNorm2d, eps=1e-5), + norm_layer: nn.Module = partial(LayerNorm, eps=1e-5), init_values: Optional[float] = None, drop_path_rate: float = 0., mlp_layer: nn.Module = Mlp, From 4d5f2d91648c4aeb4f2320f9a5e1be093b2d8570 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 22 Apr 2024 05:17:54 -0700 Subject: [PATCH 11/32] Update cvt.py --- timm/models/cvt.py | 43 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 78fb7fe1c8..fb0a1736ea 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -8,6 +8,7 @@ Implementation for timm by / Copyright 2024, Fredo Guan """ +from collections import OrderedDict from functools import partial from typing import List, Final, Optional, Tuple @@ -51,6 +52,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [B, C, H, W] -> [B, C, H, x = self.norm(x) return x + + class ConvProj(nn.Module): def __init__( self, @@ -65,7 +68,9 @@ def __init__( ) -> None: super().__init__() self.dim = dim - + + # FIXME not working, bn layer outputs are incorrect + ''' self.conv_q = ConvNormAct( dim, dim, @@ -78,7 +83,7 @@ def __init__( act_layer=act_layer ) - # TODO fuse kv conv? + # TODO fuse kv conv? don't wanna do weight remap # TODO if act_layer is id and not cls_token (gap model?), is later projection in attn necessary? self.conv_k = ConvNormAct( @@ -104,6 +109,40 @@ def __init__( norm_layer=norm_layer, act_layer=act_layer ) + ''' + self.conv_q = nn.Sequential(OrderedDict([ + ('conv', nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=padding, + stride=stride_q, + bias=bias, + groups=dim + )), + ('bn', nn.BatchNorm2d(dim)),])) + self.conv_k = nn.Sequential(OrderedDict([ + ('conv', nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=padding, + stride=stride_kv, + bias=bias, + groups=dim + )), + ('bn', nn.BatchNorm2d(dim)),])) + self.conv_v = nn.Sequential(OrderedDict([ + ('conv', nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=padding, + stride=stride_kv, + bias=bias, + groups=dim + )), + ('bn', nn.BatchNorm2d(dim)),])) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: B, C, H, W = x.shape From 5002e4fb55dd9feb41454e1814f422ea59672ad7 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 20 May 2024 10:02:15 -0700 Subject: [PATCH 12/32] Update cvt.py --- timm/models/cvt.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index fb0a1736ea..97a8a1170c 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -268,6 +268,7 @@ def __init__( ) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.probe = nn.Identity() def add_cls_token( self, @@ -296,6 +297,7 @@ def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[t cls_token, x = torch.split(x, [1, H*W], 1) x = x.transpose(1, 2).reshape(B, C, H, W) + self.probe(x) return x, cls_token From b0f23cac97c1ea752383976a000c5d414aa3465d Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 20 May 2024 10:17:44 -0700 Subject: [PATCH 13/32] Update cvt.py --- timm/models/cvt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 97a8a1170c..5864a45972 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -297,7 +297,7 @@ def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[t cls_token, x = torch.split(x, [1, H*W], 1) x = x.transpose(1, 2).reshape(B, C, H, W) - self.probe(x) + x = self.probe(x) return x, cls_token From a65c484f94a01e2af493905a0667a7bcf8870912 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 20 May 2024 10:22:32 -0700 Subject: [PATCH 14/32] Update cvt.py --- timm/models/cvt.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 5864a45972..7eb2664266 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -372,11 +372,13 @@ def __init__( ) blocks.append(block) self.blocks = nn.ModuleList(blocks) + self.probe = nn.Identity() if self.cls_token is not None: trunc_normal_(self.cls_token, std=.02) def forward(self, x: torch.Tensor) -> torch.Tensor: + self.probe(x) x = self.conv_embed(x) x = self.embed_drop(x) From fa50c0c7b5945a5453ecb0d644f18c3a0cc43aa3 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 20 May 2024 10:25:27 -0700 Subject: [PATCH 15/32] Update cvt.py --- timm/models/cvt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 7eb2664266..a1cd1c8972 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -378,7 +378,7 @@ def __init__( trunc_normal_(self.cls_token, std=.02) def forward(self, x: torch.Tensor) -> torch.Tensor: - self.probe(x) + x = self.probe(x) x = self.conv_embed(x) x = self.embed_drop(x) From 37323f86bbe11332e423d03769b15a866bce3eca Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 20 May 2024 11:49:48 -0700 Subject: [PATCH 16/32] Update cvt.py --- timm/models/cvt.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index a1cd1c8972..007cf5c631 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -378,11 +378,14 @@ def __init__( trunc_normal_(self.cls_token, std=.02) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.probe(x) + x = self.conv_embed(x) + x = self.probe(x) x = self.embed_drop(x) - cls_token = self.cls_token.expand(x.shape[0], -1, -1) if self.cls_token is not None else None + cls_token = self.embed_drop( + self.cls_token.expand(x.shape[0], -1, -1) + ) if self.cls_token is not None else None for block in self.blocks: # technically possible to exploit nn.Sequential's untyped intermediate results if each block takes in a tensor x, cls_token = block(x, cls_token) From 6d4b7852664da8491e957b5ec52b6628d3592adc Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 20 May 2024 12:34:49 -0700 Subject: [PATCH 17/32] oh xd i feel stupid --- timm/models/cvt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 007cf5c631..983fca49cc 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -289,7 +289,7 @@ def fw_attn(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> torch.T def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: B, C, H, W = x.shape - x = torch.cat((cls_token, x.flatten(2).transpose(1, 2)), dim=1) if cls_token is not None else x.flatten(2).transpose(1, 2) \ + x = (torch.cat((cls_token, x.flatten(2).transpose(1, 2)), dim=1) if cls_token is not None else x.flatten(2).transpose(1, 2)) \ + self.drop_path1(self.ls1(self.fw_attn(self.norm1(x), cls_token))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) From ef7687c9702a0fb5761fab1bb03621d9434b3ea0 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 20 May 2024 12:43:53 -0700 Subject: [PATCH 18/32] Update cvt.py --- timm/models/cvt.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 983fca49cc..b82151f8d5 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -288,9 +288,13 @@ def fw_attn(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> torch.T def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: B, C, H, W = x.shape - - x = (torch.cat((cls_token, x.flatten(2).transpose(1, 2)), dim=1) if cls_token is not None else x.flatten(2).transpose(1, 2)) \ - + self.drop_path1(self.ls1(self.fw_attn(self.norm1(x), cls_token))) + res = torch.cat((cls_token, x.flatten(2).transpose(1, 2)), dim=1) if cls_token is not None else x.flatten(2).transpose(1, 2) + + x = self.norm1(torch.cat((cls_token, x.flatten(2).transpose(1, 2)), dim=1) if cls_token is not None else x.flatten(2).transpose(1, 2)) + if self.use_cls_token: + cls_token, x = torch.split(x, [1, H*W], 1) + + x = res + self.drop_path1(self.ls1(self.fw_attn(x.transpose(1, 2).reshape(B, C, H, W), cls_token))) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) if self.use_cls_token: From bfae6dd54e8083a3759d601e48fef9f73c44049d Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 20 May 2024 12:47:20 -0700 Subject: [PATCH 19/32] Update cvt.py --- timm/models/cvt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index b82151f8d5..ebf8bcd9fe 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -224,7 +224,7 @@ def __init__( qk_norm: bool = False, attn_drop: float = 0., proj_drop: float = 0., - input_norm_layer: nn.Module = partial(LayerNorm2d, eps=1e-5), + input_norm_layer: nn.Module = partial(LayerNorm, eps=1e-5), norm_layer: nn.Module = partial(LayerNorm, eps=1e-5), init_values: Optional[float] = None, drop_path: float = 0., @@ -326,7 +326,7 @@ def __init__( qk_norm: bool = False, attn_drop: float = 0., proj_drop: float = 0., - input_norm_layer = LayerNorm2d, + input_norm_layer = LayerNorm, norm_layer: nn.Module = LayerNorm, init_values: Optional[float] = None, drop_path_rates: List[float] = [0.], @@ -417,7 +417,7 @@ def __init__( qk_norm: bool = False, attn_drop: float = 0., proj_drop: float = 0., - input_norm_layer = partial(LayerNorm2d, eps=1e-5), + input_norm_layer = partial(LayerNorm, eps=1e-5), norm_layer: nn.Module = partial(LayerNorm, eps=1e-5), init_values: Optional[float] = None, drop_path_rate: float = 0., From 6c9cc5d6b05408fb039938638c01222d3bd143be Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 20 May 2024 12:55:04 -0700 Subject: [PATCH 20/32] Update cvt.py --- timm/models/cvt.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index ebf8bcd9fe..1a21dc62c6 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -314,6 +314,7 @@ def __init__( embed_kernel_size: int = 7, embed_stride: int = 4, embed_padding: int = 2, + embed_norm_layer: nn.Module = partial(LayerNorm2d, eps=1e-5), kernel_size: int = 3, stride_q: int = 1, stride_kv: int = 2, @@ -405,6 +406,7 @@ def __init__( embed_kernel_size: Tuple[int, ...] = (7, 3, 3), embed_stride: Tuple[int, ...] = (4, 2, 2), embed_padding: Tuple[int, ...] = (2, 1, 1), + embed_norm_layer: nn.Module = partial(LayerNorm2d, eps=1e-5), kernel_size: int = 3, stride_q: int = 1, stride_kv: int = 2, @@ -452,6 +454,7 @@ def __init__( embed_kernel_size = embed_kernel_size[stage_idx], embed_stride = embed_stride[stage_idx], embed_padding = embed_padding[stage_idx], + embed_norm_layer = embed_norm_layer, kernel_size = kernel_size, stride_q = stride_q, stride_kv = stride_kv, From 9451d17822d9ff8585e4c8fe6e2d44b52db0a51d Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 20 May 2024 13:02:11 -0700 Subject: [PATCH 21/32] Update cvt.py --- timm/models/cvt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 1a21dc62c6..da595c1000 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -344,7 +344,7 @@ def __init__( kernel_size = embed_kernel_size, stride = embed_stride, padding = embed_padding, - norm_layer = input_norm_layer, + norm_layer = embed_norm_layer, ) self.embed_drop = nn.Dropout(proj_drop) From 6ab6b1604d70f79edae57fad81f5c4d61ca50337 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 20 May 2024 13:16:27 -0700 Subject: [PATCH 22/32] remove probes --- timm/models/cvt.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index da595c1000..68e96dacba 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -268,7 +268,6 @@ def __init__( ) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - self.probe = nn.Identity() def add_cls_token( self, @@ -301,7 +300,6 @@ def forward(self, x: torch.Tensor, cls_token: Optional[torch.Tensor]) -> Tuple[t cls_token, x = torch.split(x, [1, H*W], 1) x = x.transpose(1, 2).reshape(B, C, H, W) - x = self.probe(x) return x, cls_token @@ -377,15 +375,12 @@ def __init__( ) blocks.append(block) self.blocks = nn.ModuleList(blocks) - self.probe = nn.Identity() if self.cls_token is not None: trunc_normal_(self.cls_token, std=.02) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.conv_embed(x) - x = self.probe(x) x = self.embed_drop(x) cls_token = self.embed_drop( From 13064cab63526e721bef61301453d6b8f6600629 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 20 May 2024 15:00:05 -0700 Subject: [PATCH 23/32] Update cvt.py --- timm/models/cvt.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 68e96dacba..db776a759a 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -10,7 +10,7 @@ from collections import OrderedDict from functools import partial -from typing import List, Final, Optional, Tuple +from typing import List, Final, Optional, Tuple, Union import torch import torch.nn as nn @@ -379,17 +379,17 @@ def __init__( if self.cls_token is not None: trunc_normal_(self.cls_token, std=.02) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: x = self.conv_embed(x) x = self.embed_drop(x) cls_token = self.embed_drop( self.cls_token.expand(x.shape[0], -1, -1) ) if self.cls_token is not None else None - for block in self.blocks: # technically possible to exploit nn.Sequential's untyped intermediate results if each block takes in a tensor + for block in self.blocks: # TODO technically possible to exploit nn.Sequential's untyped intermediate results if each block takes in a tuple x, cls_token = block(x, cls_token) - return x, cls_token + return (x, cls_token) if self.cls_token is not None else x class CvT(nn.Module): def __init__( @@ -429,8 +429,8 @@ def __init__( assert num_stages == len(embed_padding) == len(num_heads) == len(use_cls_token) self.num_classes = num_classes self.num_features = dims[-1] + self.feature_info = [] - # FIXME only on last stage, no need for tuple self.use_cls_token = use_cls_token[-1] dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] @@ -473,7 +473,8 @@ def __init__( ) in_chs = dim stages.append(stage) - self.stages = nn.ModuleList(stages) + self.feature_info += [dict(num_chs=dim, reduction=2, module=f'stages.{stage_idx}')] + self.stages = nn.Sequential(*stages) self.norm = norm_layer(dims[-1]) self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity() @@ -481,11 +482,11 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: for stage in self.stages: - x, cls_token = stage(x) + x = stage(x) if self.use_cls_token: - return self.head(self.norm(cls_token.flatten(1))) + return self.head(self.norm(x[1].flatten(1))) else: return self.head(self.norm(x.mean(dim=(2,3)))) @@ -493,8 +494,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def checkpoint_filter_fn(state_dict, model): """ Remap MSFT checkpoints -> timm """ - if 'head.fc.weight' in state_dict: - return state_dict # non-MSFT checkpoint if 'state_dict' in state_dict: state_dict = state_dict['state_dict'] @@ -524,14 +523,13 @@ def _create_cvt(variant, pretrained=False, **kwargs): return model -# TODO update first_conv def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (14, 14), 'crop_pct': 0.95, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stem.conv', 'classifier': 'head', + 'first_conv': 'stages.0.conv_embed.conv', 'classifier': 'head', **kwargs } From fb15b5f1d182f79e398971e0468bd600e209806d Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 20 May 2024 17:55:45 -0700 Subject: [PATCH 24/32] Cvt 1 (#14) * Update cvt.py * Update cvt.py --- timm/models/cvt.py | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index db776a759a..37f175ada9 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -70,7 +70,7 @@ def __init__( self.dim = dim # FIXME not working, bn layer outputs are incorrect - ''' + self.conv_q = ConvNormAct( dim, dim, @@ -143,7 +143,8 @@ def __init__( groups=dim )), ('bn', nn.BatchNorm2d(dim)),])) - + ''' + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: B, C, H, W = x.shape # [B, C, H, W] -> [B, H*W, C] @@ -170,7 +171,7 @@ def __init__( self.num_heads = num_heads self.head_dim = dim // num_heads self.scale = dim ** -0.5 - self.fused_attn = False #use_fused_attn() + self.fused_attn = use_fused_attn() self.proj_q = nn.Linear(dim, dim, bias=qkv_bias) self.proj_k = nn.Linear(dim, dim, bias=qkv_bias) @@ -534,7 +535,22 @@ def _cfg(url='', **kwargs): } default_cfgs = generate_default_cfgs({ - 'cvt_13.msft_in1k': _cfg(url='https://files.catbox.moe/xz97kh.pth'), + 'cvt_13.msft_in1k': _cfg( + url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-13-224x224-IN-1k.pth'), + 'cvt_13.msft_in1k_384': _cfg( + url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-13-384x384-IN-1k.pth', + input_size=(3, 384, 384), pool_size=(24, 24)), + 'cvt_13.msft_in22k_ft_in1k_384': _cfg(url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-13-384x384-IN-22k.pth', + input_size=(3, 384, 384), pool_size=(24, 24)), + + 'cvt_21.msft_in1k': _cfg(url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-21-224x224-IN-1k.pth'), + 'cvt_21.msft_in1k_384': _cfg(url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-21-384x384-IN-1k.pth', + input_size=(3, 384, 384), pool_size=(24, 24)), + 'cvt_21.msft_in22k_ft_in1k_384': _cfg(url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-21-384x384-IN-22k.pth', + input_size=(3, 384, 384), pool_size=(24, 24)), + + 'cvt_w24.msft_in22k_ft_in1k_384': _cfg(url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-w24-384x384-IN-22k.pth', + input_size=(3, 384, 384), pool_size=(24, 24)), }) @@ -542,3 +558,13 @@ def _cfg(url='', **kwargs): def cvt_13(pretrained=False, **kwargs) -> CvT: model_args = dict(depths=(1, 2, 10), dims=(64, 192, 384), num_heads=(1, 3, 6)) return _create_cvt('cvt_13', pretrained=pretrained, **dict(model_args, **kwargs)) + +@register_model +def cvt_21(pretrained=False, **kwargs) -> CvT: + model_args = dict(depths=(1, 4, 16), dims=(64, 192, 384), num_heads=(1, 3, 6)) + return _create_cvt('cvt_21', pretrained=pretrained, **dict(model_args, **kwargs)) + +@register_model +def cvt_w24(pretrained=False, **kwargs) -> CvT: + model_args = dict(depths=(2, 2, 20), dims=(192, 768, 1024), num_heads=(3, 12, 16)) + return _create_cvt('cvt_w24', pretrained=pretrained, **dict(model_args, **kwargs)) \ No newline at end of file From 65681ecaf8568653f0d594870091ec8f8377bc34 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 20 May 2024 18:11:22 -0700 Subject: [PATCH 25/32] Update cvt.py --- timm/models/cvt.py | 36 +----------------------------------- 1 file changed, 1 insertion(+), 35 deletions(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 37f175ada9..95f2f9d0d5 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -109,41 +109,7 @@ def __init__( norm_layer=norm_layer, act_layer=act_layer ) - ''' - self.conv_q = nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d( - dim, - dim, - kernel_size=kernel_size, - padding=padding, - stride=stride_q, - bias=bias, - groups=dim - )), - ('bn', nn.BatchNorm2d(dim)),])) - self.conv_k = nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d( - dim, - dim, - kernel_size=kernel_size, - padding=padding, - stride=stride_kv, - bias=bias, - groups=dim - )), - ('bn', nn.BatchNorm2d(dim)),])) - self.conv_v = nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d( - dim, - dim, - kernel_size=kernel_size, - padding=padding, - stride=stride_kv, - bias=bias, - groups=dim - )), - ('bn', nn.BatchNorm2d(dim)),])) - ''' + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: B, C, H, W = x.shape From 9dc79592645c72b64621b3d4e8dc820bb3139f90 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 8 Jun 2024 00:00:42 -0700 Subject: [PATCH 26/32] Update cvt.py --- timm/models/cvt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 95f2f9d0d5..b47e15b1d9 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -389,6 +389,7 @@ def __init__( mlp_ratio: float = 4., mlp_act_layer: nn.Module = QuickGELU, use_cls_token: Tuple[bool, ...] = (False, False, True), + drop_rate: float = 0., ) -> None: super().__init__() num_stages = len(dims) From bda987bc6eb78aad4d9421aea6ae308b31561503 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Fri, 28 Jun 2024 05:03:00 -0700 Subject: [PATCH 27/32] Validation default cfg --- timm/models/cvt.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index b47e15b1d9..a38a84da28 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -495,9 +495,9 @@ def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (14, 14), - 'crop_pct': 0.95, 'interpolation': 'bicubic', + 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stages.0.conv_embed.conv', 'classifier': 'head', + 'first_conv': 'stage0.patch_embed.proj', 'classifier': 'head', **kwargs } @@ -506,18 +506,18 @@ def _cfg(url='', **kwargs): url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-13-224x224-IN-1k.pth'), 'cvt_13.msft_in1k_384': _cfg( url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-13-384x384-IN-1k.pth', - input_size=(3, 384, 384), pool_size=(24, 24)), + input_size=(3, 384, 384), pool_size=(24, 24), crop_mode='squash', crop_pct=1.0), 'cvt_13.msft_in22k_ft_in1k_384': _cfg(url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-13-384x384-IN-22k.pth', - input_size=(3, 384, 384), pool_size=(24, 24)), + input_size=(3, 384, 384), pool_size=(24, 24), crop_mode='squash', crop_pct=1.0), 'cvt_21.msft_in1k': _cfg(url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-21-224x224-IN-1k.pth'), 'cvt_21.msft_in1k_384': _cfg(url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-21-384x384-IN-1k.pth', - input_size=(3, 384, 384), pool_size=(24, 24)), + input_size=(3, 384, 384), pool_size=(24, 24), crop_mode='squash', crop_pct=1.0), 'cvt_21.msft_in22k_ft_in1k_384': _cfg(url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-21-384x384-IN-22k.pth', - input_size=(3, 384, 384), pool_size=(24, 24)), + input_size=(3, 384, 384), pool_size=(24, 24), crop_mode='squash', crop_pct=1.0), 'cvt_w24.msft_in22k_ft_in1k_384': _cfg(url='https://github.com/fffffgggg54/pytorch-image-models/releases/download/cvt/CvT-w24-384x384-IN-22k.pth', - input_size=(3, 384, 384), pool_size=(24, 24)), + input_size=(3, 384, 384), pool_size=(24, 24), crop_mode='squash', crop_pct=1.0), }) From 98bb6903bc47206e5637722c0852b4ca8c80a46e Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Fri, 28 Jun 2024 05:03:22 -0700 Subject: [PATCH 28/32] Update cvt.py --- timm/models/cvt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index a38a84da28..1e03cf7f13 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -497,7 +497,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (14, 14), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stage0.patch_embed.proj', 'classifier': 'head', + 'first_conv': 'stages.0.conv_embed.conv', 'classifier': 'head', **kwargs } From ee07e7caaa9055cdee081cc180c9ad3eba911b08 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 2 Jul 2024 00:04:56 -0700 Subject: [PATCH 29/32] Update cvt.py --- timm/models/cvt.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 1e03cf7f13..ca3361fc53 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import ConvNormAct, LayerNorm, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn +from timm.layers import ConvNormAct, LayerNorm, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn, nchw_to from ._builder import build_model_with_cfg from ._registry import generate_default_cfgs, register_model @@ -447,17 +447,30 @@ def __init__( self.norm = norm_layer(dims[-1]) self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity() - def forward(self, x: torch.Tensor) -> torch.Tensor: + def _forward_features(self, x: torch.Tensor) -> torch.Tensor: + # nn.Sequential forward can't accept tuple intermediates + # TODO grad checkpointing for stage in self.stages: x = stage(x) - + return x + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self._forward_features(x) + + return x[0] if self.use_cls_token else x + + def forward_head(self, x: torch.Tensor) -> torch.Tensor: if self.use_cls_token: return self.head(self.norm(x[1].flatten(1))) else: return self.head(self.norm(x.mean(dim=(2,3)))) + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._forward_features(x) + x = self.forward_head(x) + return x def checkpoint_filter_fn(state_dict, model): From 297531861030d2b11aeff2a298aff091523febf6 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 2 Jul 2024 00:13:24 -0700 Subject: [PATCH 30/32] Update cvt.py --- timm/models/cvt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index ca3361fc53..7418677074 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -388,18 +388,18 @@ def __init__( mlp_layer: nn.Module = Mlp, mlp_ratio: float = 4., mlp_act_layer: nn.Module = QuickGELU, - use_cls_token: Tuple[bool, ...] = (False, False, True), + use_cls_token: bool = True, drop_rate: float = 0., ) -> None: super().__init__() num_stages = len(dims) assert num_stages == len(depths) == len(embed_kernel_size) == len(embed_stride) - assert num_stages == len(embed_padding) == len(num_heads) == len(use_cls_token) + assert num_stages == len(embed_padding) == len(num_heads) self.num_classes = num_classes self.num_features = dims[-1] self.feature_info = [] - self.use_cls_token = use_cls_token[-1] + self.use_cls_token = use_cls_token dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] @@ -437,7 +437,7 @@ def __init__( mlp_layer = mlp_layer, mlp_ratio = mlp_ratio, mlp_act_layer = mlp_act_layer, - use_cls_token = use_cls_token[stage_idx], + use_cls_token = use_cls_token and stage_idx == num_stages - 1, ) in_chs = dim stages.append(stage) From 4d1b21a89c1c7fc7980fdcabc48adb1f196bed04 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 6 Jul 2024 00:41:19 -0700 Subject: [PATCH 31/32] Update cvt.py --- timm/models/cvt.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index 7418677074..b5d247735a 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -400,6 +400,7 @@ def __init__( self.feature_info = [] self.use_cls_token = use_cls_token + self.global_pool = 'token' if use_cls_token else 'avg' dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] @@ -448,6 +449,21 @@ def __init__( self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity() + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head + + def reset_classifier(self, num_classes: int, global_pool = None) -> None: + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg', 'token') + if global_pool == 'token' and not self.use_cls_token: + assert False, 'Model not configured to use class token' + self.global_pool = global_pool + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def _forward_features(self, x: torch.Tensor) -> torch.Tensor: # nn.Sequential forward can't accept tuple intermediates # TODO grad checkpointing @@ -457,12 +473,13 @@ def _forward_features(self, x: torch.Tensor) -> torch.Tensor: return x def forward_features(self, x: torch.Tensor) -> torch.Tensor: + # get feature map, not always used x = self._forward_features(x) return x[0] if self.use_cls_token else x def forward_head(self, x: torch.Tensor) -> torch.Tensor: - if self.use_cls_token: + if self.global_pool == 'token': return self.head(self.norm(x[1].flatten(1))) else: return self.head(self.norm(x.mean(dim=(2,3)))) From 430f7b4c5de9c144d6f08c1e0df509d935e8685d Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 15 Dec 2024 02:41:24 -0700 Subject: [PATCH 32/32] Update cvt.py --- timm/models/cvt.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index b5d247735a..a8a170abaf 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import ConvNormAct, LayerNorm, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn, nchw_to +from timm.layers import ConvNormAct, DropPath, LayerNorm, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn, nchw_to from ._builder import build_model_with_cfg from ._registry import generate_default_cfgs, register_model @@ -175,6 +175,20 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Te x = self.proj_drop(x) return x +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + class CvTBlock(nn.Module): def __init__( self,