Commit
Fix rotary embed version of attn pool. Bit of cleanup/naming
rwightman committed Jun 12, 2024
1 parent cdc7bce commit 57adc1a
Showing 2 changed files with 17 additions and 14 deletions.
27 changes: 17 additions & 10 deletions timm/layers/attention_pool2d.py
@@ -42,7 +42,7 @@ def __init__(
             qkv_bias: bool = True,
             qkv_separate: bool = False,
             pool_type: str = 'token',
-            avg_token: bool = True,
+            class_token: bool = False,
             drop_rate: float = 0.,
     ):
         super().__init__()
@@ -63,6 +63,11 @@ def __init__(
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()
 
+        if class_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
+        else:
+            self.cls_token = None
+
         if qkv_separate:
             self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
             self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
@@ -109,7 +114,10 @@ def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
-        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        if self.cls_token is None:
+            x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        else:
+            x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
         if self.qkv is None:
             q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
             k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
@@ -130,7 +138,6 @@ def forward(self, x, pre_logits: bool = False):
             attn = attn.softmax(dim=-1)
             x = attn @ v
         x = x.transpose(1, 2).reshape(B, N + 1, -1)
-        x = x[:, 0]
         x = self.drop(x)
         if pre_logits:
             x = self._pool(x, H, W)
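Note that the hard-coded x = x[:, 0] is dropped above; token selection is now deferred to self._pool(x, H, W), which is not part of this hunk. A rough sketch of what such a helper might do, given the pool_type argument in the constructor (an illustrative assumption, not the actual timm implementation):

import torch

def _pool_sketch(x: torch.Tensor, pool_type: str = 'token') -> torch.Tensor:
    # x: (B, N + 1, C) with the prefix (mean or class) token at index 0
    if pool_type == 'token':
        return x[:, 0]           # keep only the prefix token
    if pool_type == 'avg':
        return x[:, 1:].mean(1)  # average the spatial tokens instead
    return x                     # no pooling: return the full token sequence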
@@ -162,7 +169,7 @@ def __init__(
             qkv_bias: bool = True,
             qkv_separate: bool = False,
             pool_type: str = 'token',
-            learned_token: bool = False,
+            class_token: bool = False,
             drop_rate: float = 0.,
     ):
         super().__init__()
@@ -184,10 +191,10 @@ def __init__(
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()
 
-        if learned_token:
-            self.token = nn.Parameter(torch.zeros(1, embed_dim))
+        if class_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
         else:
-            self.token = None
+            self.cls_token = None
 
         if qkv_separate:
             self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
B, _, H, W = x.shape
N = H * W
x = x.flatten(2).transpose(1, 2)
if self.token is not None:
x = torch.cat([self.token.expand(x.shape[0], -1, -1), x], dim=1)
else:
if self.cls_token is None:
x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
else:
x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
x = x + pos_embed

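Both pooling classes now share the same prefix-token behavior: prepend a mean token by default, or a learned class token when class_token=True. A minimal standalone sketch of just that step, mirroring the hunks above (a simplified illustration, not the actual timm module):

import torch
import torch.nn as nn

class PrefixTokenSketch(nn.Module):
    def __init__(self, dim: int, class_token: bool = False):
        super().__init__()
        # Optional learned class token; otherwise the mean of the spatial
        # tokens is prepended, as in the diff above.
        self.cls_token = nn.Parameter(torch.zeros(1, dim)) if class_token else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # (B, H*W, C)
        if self.cls_token is None:
            x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
        else:
            x = torch.cat([self.cls_token.expand(B, -1, -1), x], dim=1)
        return x  # (B, H*W + 1, C), prefix token first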
4 changes: 0 additions & 4 deletions timm/models/byobnet.py
@@ -1945,7 +1945,6 @@ def _init_weights(module, name='', zero_init_last=False):
         downsample='avg',
         aa_layer='avg',
         head_type='attn_abs',
-        #head_hidden_size=512,
     ),
 
     resnet50x4_clip=ByoModelCfg(
@@ -1962,7 +1961,6 @@ def _init_weights(module, name='', zero_init_last=False):
         downsample='avg',
         aa_layer='avg',
         head_type='attn_abs',
-        #head_hidden_size=640,
     ),
 
     resnet50x16_clip=ByoModelCfg(
@@ -1979,7 +1977,6 @@ def _init_weights(module, name='', zero_init_last=False):
         downsample='avg',
         aa_layer='avg',
         head_type='attn_abs',
-        #head_hidden_size=768,
     ),
 
     resnet50x64_clip=ByoModelCfg(
@@ -1996,7 +1993,6 @@ def _init_weights(module, name='', zero_init_last=False):
         downsample='avg',
         aa_layer='avg',
         head_type='attn_abs',
-        #head_hidden_size=1024,
     ),
 
     resnet50_mlp=ByoModelCfg(
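The byobnet.py side of the commit only drops commented-out head_hidden_size values from the CLIP ResNet configs; the head_type='attn_abs' attention-pool heads are unchanged. A quick smoke test along these lines should still pass, assuming the resnet50x4_clip config shown above is registered under that name (288x288 is the CLIP RN50x4 input size; the resampled position embedding tolerates other sizes):

import torch
import timm

model = timm.create_model('resnet50x4_clip', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 288, 288))
print(out.shape)  # output of the attention-pooled head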
