Make 2d attention pool modules compatible with head interface. Use attention pool in CLIP ResNets as head. Make separate set of GAP models w/ avg pool instead of attn pool.
rwightman committed Jun 12, 2024
1 parent 30ffa15 commit cdc7bce
Showing 2 changed files with 209 additions and 104 deletions.
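
Below, for orientation, is a minimal sketch of the head interface the commit message refers to, reconstructed from the methods this diff adds (reset() and the pre_logits path). The trunk, tensor shapes, and argument values are illustrative assumptions, and constructor arguments not visible in this diff (feat_size, the head_dim default) are assumed from the surrounding file; this is not code from the commit.

import torch
import torch.nn as nn
from timm.layers import AttentionPool2d  # module updated in this commit

# Stand-in feature extractor producing a [2, 2048, 7, 7] map (assumption).
trunk = nn.Conv2d(3, 2048, kernel_size=32, stride=32)
head = AttentionPool2d(in_features=2048, feat_size=7, out_features=512)

x = torch.randn(2, 3, 224, 224)
feats = trunk(x)                           # [2, 2048, 7, 7]
logits = head(feats)                       # [2, 512]   pooled token, then proj
embeds = head(feats, pre_logits=True)      # [2, 2048]  pooled, before proj

# The head-style reset() added in this diff swaps or drops the classifier in place.
head.reset(num_classes=10)                 # proj becomes a 10-way Linear
head.reset(num_classes=0, pool_type='')    # Identity proj, keep spatial tokens
spatial = head(feats)                      # [2, 2048, 7, 7]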
timm/layers/attention_pool2d.py: 64 changes (59 additions, 5 deletions)
@@ -41,9 +41,12 @@ def __init__(
num_heads: Optional[int] = None,
qkv_bias: bool = True,
qkv_separate: bool = False,
drop: float = 0.,
pool_type: str = 'token',
avg_token: bool = True,
drop_rate: float = 0.,
):
super().__init__()
assert pool_type in ('', 'token')
self.embed_dim = embed_dim = embed_dim or in_features
self.in_features = in_features
self.out_features = out_features or in_features
@@ -56,6 +59,7 @@ def __init__(
num_heads = embed_dim // head_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.pool_type = pool_type.lower()
self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn()

@@ -66,6 +70,7 @@ def __init__(
self.qkv = None
else:
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
self.drop = nn.Dropout(drop_rate)
self.proj = nn.Linear(embed_dim, self.out_features)
self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size)

@@ -83,6 +88,23 @@ def init_weights(self, zero_init_last: bool = False):
trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
nn.init.zeros_(self.qkv.bias)

def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
# NOTE: this module is being used as a head, so need compatible reset()
if pool_type is not None:
assert pool_type in ('', 'token')
self.pool_type = pool_type
if num_classes is not None:
self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
self.out_features = num_classes if num_classes > 0 else self.embed_dim

def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
if self.pool_type == 'token':
x = x[:, 0]
else:
# if not pooled, return spatial output without token
x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
return x

def forward(self, x, pre_logits: bool = False):
B, _, H, W = x.shape
N = H * W
@@ -111,8 +133,10 @@ def forward(self, x, pre_logits: bool = False):
x = x[:, 0]
x = self.drop(x)
if pre_logits:
x = self._pool(x, H, W)
return x
x = self.proj(x)
x = self._pool(x, H, W)
return x


@@ -137,9 +161,12 @@ def __init__(
num_heads: Optional[int] = None,
qkv_bias: bool = True,
qkv_separate: bool = False,
drop: float = 0.,
pool_type: str = 'token',
learned_token: bool = False,
drop_rate: float = 0.,
):
super().__init__()
assert pool_type in ('', 'token')
self.embed_dim = embed_dim = embed_dim or in_features
self.in_features = in_features
self.out_features = out_features or in_features
@@ -153,9 +180,15 @@ def __init__(
self.seq_len = self.feat_size[0] * self.feat_size[1]
self.num_heads = num_heads
self.head_dim = head_dim
self.pool_type = pool_type
self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn()

if learned_token:
self.token = nn.Parameter(torch.zeros(1, embed_dim))
else:
self.token = None

if qkv_separate:
self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
@@ -164,7 +197,7 @@ def __init__(
else:
self.q = self.k = self.v = None
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
self.drop = nn.Dropout(drop)
self.drop = nn.Dropout(drop_rate)
self.proj = nn.Linear(embed_dim, self.out_features)
self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))

@@ -185,11 +218,31 @@ def init_weights(self, zero_init_last: bool = False):
nn.init.zeros_(self.qkv.bias)
trunc_normal_(self.pos_embed, std=in_features ** -0.5)

def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
# NOTE: this module is being used as a head, so need compatible reset()
if pool_type is not None:
assert pool_type in ('', 'token')
self.pool_type = pool_type
if num_classes is not None:
self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
self.out_features = num_classes if num_classes > 0 else self.embed_dim

def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
if self.pool_type == 'token':
x = x[:, 0]
else:
# if not pooled, return spatial output without token
x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
return x

def forward(self, x, pre_logits: bool = False):
B, _, H, W = x.shape
N = H * W
x = x.flatten(2).transpose(1, 2)
x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
if self.token is not None:
x = torch.cat([self.token.expand(x.shape[0], -1, -1), x], dim=1)
else:
x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
x = x + pos_embed

@@ -209,9 +262,10 @@ def forward(self, x, pre_logits: bool = False):
attn = attn.softmax(dim=-1)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N + 1, -1)
x = x[:, 0]
x = self.drop(x)
if pre_logits:
x = self._pool(x, H, W)
return x
x = self.proj(x)
x = self._pool(x, H, W)
return x
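
The _pool() contract shared by both classes is small enough to restate on its own. A standalone sketch, assuming the token-first layout ([B, 1 + H*W, C]) that both modules build internally; sizes are illustrative.

import torch

def pool_like_head(x: torch.Tensor, H: int, W: int, pool_type: str = 'token') -> torch.Tensor:
    # Mirrors the _pool() added above: 'token' keeps the class/mean token,
    # '' drops it and restores the spatial feature map.
    if pool_type == 'token':
        return x[:, 0]                                                   # [B, C]
    return x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)    # [B, C, H, W]

tokens = torch.randn(2, 1 + 7 * 7, 512)
print(pool_like_head(tokens, 7, 7, 'token').shape)   # torch.Size([2, 512])
print(pool_like_head(tokens, 7, 7, '').shape)        # torch.Size([2, 512, 7, 7])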