From 430f7b4c5de9c144d6f08c1e0df509d935e8685d Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 15 Dec 2024 02:41:24 -0700 Subject: [PATCH] Update cvt.py --- timm/models/cvt.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/timm/models/cvt.py b/timm/models/cvt.py index b5d247735a..a8a170abaf 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import ConvNormAct, LayerNorm, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn, nchw_to +from timm.layers import ConvNormAct, DropPath, LayerNorm, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn, nchw_to from ._builder import build_model_with_cfg from ._registry import generate_default_cfgs, register_model @@ -175,6 +175,20 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Te x = self.proj_drop(x) return x +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + class CvTBlock(nn.Module): def __init__( self,