diff --git a/timm/models/cvt.py b/timm/models/cvt.py index b5d247735..a8a170aba 100644 --- a/timm/models/cvt.py +++ b/timm/models/cvt.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import ConvNormAct, LayerNorm, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn, nchw_to +from timm.layers import ConvNormAct, DropPath, LayerNorm, LayerNorm2d, Mlp, QuickGELU, trunc_normal_, use_fused_attn, nchw_to from ._builder import build_model_with_cfg from ._registry import generate_default_cfgs, register_model @@ -175,6 +175,20 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Te x = self.proj_drop(x) return x +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + class CvTBlock(nn.Module): def __init__( self,