Mapping OpenAI CLIP Modified ResNet weights -> ByobNet. #2198

Merged: 6 commits, Jun 12, 2024
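For context, this PR maps the attention-pooled "Modified ResNet" image towers from OpenAI CLIP onto timm's ByobNet so they can be loaded like any other timm model. A minimal usage sketch follows; the model/tag name is a placeholder assumption for illustration, the real identifiers come from the byobnet configs added in this PR:

```python
import torch
import timm

# 'resnet50_clip.openai' is a placeholder name for the mapped CLIP Modified ResNet weights.
model = timm.create_model('resnet50_clip.openai', pretrained=True, num_classes=0)
model.eval()

with torch.no_grad():
    emb = model(torch.randn(1, 3, 224, 224))  # pooled image embedding from the attention-pool head
print(emb.shape)
```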
6 changes: 4 additions & 2 deletions timm/layers/attention_pool.py
@@ -20,6 +20,7 @@ def __init__(
out_features: int = None,
embed_dim: int = None,
num_heads: int = 8,
feat_size: Optional[int] = None,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_norm: bool = False,
@@ -36,13 +37,14 @@ def __init__(
assert embed_dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.feat_size = feat_size
self.scale = self.head_dim ** -0.5
self.pool = pool_type
self.fused_attn = use_fused_attn()

if pos_embed == 'abs':
spatial_len = self.feat_size
self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
assert feat_size is not None
self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
else:
self.pos_embed = None

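The practical effect of this hunk: when the pool is built with pos_embed='abs', feat_size must now be passed so the position embedding table can be allocated up front. A hedged sketch; the class name (AttentionPoolLatent) and the shapes are assumptions based on the rest of the file, not shown in the hunk:

```python
import torch
from timm.layers.attention_pool import AttentionPoolLatent  # class name assumed, not visible in the hunk

# feat_size is required whenever pos_embed='abs' (the assert in the hunk above fires otherwise).
pool = AttentionPoolLatent(in_features=768, num_heads=12, pos_embed='abs', feat_size=196)
out = pool(torch.randn(2, 196, 768))  # (B, N, C) token sequence -> pooled (B, 768)
print(out.shape)
```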
194 changes: 136 additions & 58 deletions timm/layers/attention_pool2d.py
@@ -7,12 +7,14 @@

Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import Union, Tuple
from typing import Optional, Union, Tuple

import torch
import torch.nn as nn

from torch import nn


from .config import use_fused_attn
from .helpers import to_2tuple
from .pos_embed import resample_abs_pos_embed
from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
from .weight_init import trunc_normal_

@@ -27,51 +29,84 @@ class RotAttentionPool2d(nn.Module):
NOTE: While this impl does not require a fixed feature size, performance at differing resolutions from
train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
"""
fused_attn: torch.jit.Final[bool]

def __init__(
self,
in_features: int,
out_features: int = None,
embed_dim: int = None,
num_heads: int = 4,
out_features: Optional[int] = None,
ref_feat_size: Union[int, Tuple[int, int]] = 7,
embed_dim: Optional[int] = None,
head_dim: Optional[int] = 64,
num_heads: Optional[int] = None,
qkv_bias: bool = True,
qkv_separate: bool = False,
):
super().__init__()
embed_dim = embed_dim or in_features
out_features = out_features or in_features
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
self.proj = nn.Linear(embed_dim, out_features)
self.in_features = in_features
self.out_features = out_features or in_features
ref_feat_size = to_2tuple(ref_feat_size)
if num_heads is not None:
assert embed_dim % num_heads == 0
head_dim = embed_dim // num_heads
else:
assert embed_dim % head_dim == 0
num_heads = embed_dim // head_dim
self.num_heads = num_heads
assert embed_dim % num_heads == 0
self.head_dim = embed_dim // num_heads
self.head_dim = head_dim
self.scale = self.head_dim ** -0.5
self.pos_embed = RotaryEmbedding(self.head_dim)

trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
nn.init.zeros_(self.qkv.bias)
self.fused_attn = use_fused_attn()

if qkv_separate:
self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.qkv = None
else:
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
self.proj = nn.Linear(embed_dim, self.out_features)
self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size)

def init_weights(self, zero_init_last: bool = False):
if self.qkv is None:
in_features = self.q.in_features
trunc_normal_(self.q.weight, std=in_features ** -0.5)
nn.init.zeros_(self.q.bias)
trunc_normal_(self.k.weight, std=in_features ** -0.5)
nn.init.zeros_(self.k.bias)
trunc_normal_(self.v.weight, std=in_features ** -0.5)
nn.init.zeros_(self.v.bias)
else:
in_features = self.qkv.in_features
trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
nn.init.zeros_(self.qkv.bias)

def forward(self, x):
B, _, H, W = x.shape
N = H * W
x = x.reshape(B, -1, N).permute(0, 2, 1)

x = x.flatten(2).transpose(1, 2)
x = torch.cat([x.mean(1, keepdim=True), x], dim=1)

x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = x[0], x[1], x[2]

qc, q = q[:, :, :1], q[:, :, 1:]
sin_emb, cos_emb = self.pos_embed.get_embed((H, W))
q = apply_rot_embed(q, sin_emb, cos_emb)
q = torch.cat([qc, q], dim=2)

kc, k = k[:, :, :1], k[:, :, 1:]
k = apply_rot_embed(k, sin_emb, cos_emb)
k = torch.cat([kc, k], dim=2)

attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)

x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
if self.qkv is None:
q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
else:
x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = x.unbind(0)

rse, rce = self.pos_embed.get_embed((H, W))
q = torch.cat([q[:, :, :1, :], apply_rot_embed(q[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
k = torch.cat([k[:, :, :1, :], apply_rot_embed(k[:, :, 1:, :], rse, rce)], dim=2).type_as(v)

if self.fused_attn:
x = nn.functional.scaled_dot_product_attention(q, k, v)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N + 1, -1)
x = self.proj(x)
return x[:, 0]
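A quick usage sketch of the revised RotAttentionPool2d above (keyword names taken from the diff; the import path is assumed to be timm.layers):

```python
import torch
from timm.layers import RotAttentionPool2d  # import path assumed

# head_dim=64 derives num_heads from embed_dim; ref_feat_size anchors the rotary frequencies.
pool = RotAttentionPool2d(in_features=2048, out_features=1024, embed_dim=2048, head_dim=64, ref_feat_size=7)
pool.init_weights()

x = torch.randn(2, 2048, 7, 7)  # (B, C, H, W) feature map
out = pool(x)                   # rotary embed applied to q/k, class-token output returned
print(out.shape)                # torch.Size([2, 1024])
```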

@@ -85,47 +120,90 @@ class AttentionPool2d(nn.Module):

NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network.
"""
fused_attn: torch.jit.Final[bool]

def __init__(
self,
in_features: int,
feat_size: Union[int, Tuple[int, int]],
out_features: int = None,
embed_dim: int = None,
num_heads: int = 4,
feat_size: Union[int, Tuple[int, int]] = 7,
out_features: Optional[int] = None,
embed_dim: Optional[int] = None,
head_dim: Optional[int] = 64,
num_heads: Optional[int] = None,
qkv_bias: bool = True,
qkv_separate: bool = False,
):
super().__init__()

embed_dim = embed_dim or in_features
out_features = out_features or in_features
assert embed_dim % num_heads == 0
self.in_features = in_features
self.out_features = out_features or in_features
if num_heads is not None:
assert embed_dim % num_heads == 0
head_dim = embed_dim // num_heads
else:
assert embed_dim % head_dim == 0
num_heads = embed_dim // head_dim
self.feat_size = to_2tuple(feat_size)
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
self.proj = nn.Linear(embed_dim, out_features)
self.seq_len = self.feat_size[0] * self.feat_size[1]
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.head_dim = head_dim
self.scale = self.head_dim ** -0.5

spatial_dim = self.feat_size[0] * self.feat_size[1]
self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
self.fused_attn = use_fused_attn()

if qkv_separate:
self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.qkv = None
else:
self.q = self.k = self.v = None
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
self.proj = nn.Linear(embed_dim, self.out_features)
self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))

self.init_weights()

def init_weights(self, zero_init_last: bool = False):
if self.qkv is None:
in_features = self.q.in_features
trunc_normal_(self.q.weight, std=in_features ** -0.5)
nn.init.zeros_(self.q.bias)
trunc_normal_(self.k.weight, std=in_features ** -0.5)
nn.init.zeros_(self.k.bias)
trunc_normal_(self.v.weight, std=in_features ** -0.5)
nn.init.zeros_(self.v.bias)
else:
in_features = self.qkv.in_features
trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
nn.init.zeros_(self.qkv.bias)
trunc_normal_(self.pos_embed, std=in_features ** -0.5)
trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
nn.init.zeros_(self.qkv.bias)

def forward(self, x):
B, _, H, W = x.shape
N = H * W
assert self.feat_size[0] == H
assert self.feat_size[1] == W
x = x.reshape(B, -1, N).permute(0, 2, 1)
x = x.flatten(2).transpose(1, 2)
x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = x[0], x[1], x[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)

x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
if self.seq_len != N:
pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
else:
pos_embed = self.pos_embed.unsqueeze(0).to(x.dtype)
x = x + pos_embed

if self.qkv is None:
q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
else:
x = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = x.unbind(0)

if self.fused_attn:
x = nn.functional.scaled_dot_product_attention(q, k, v)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N + 1, -1)
x = self.proj(x)
return x[:, 0]
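And a matching sketch for the fixed-size variant, showing that the learned absolute position embedding is now resampled on the fly when the input grid differs from feat_size (names from the diff; import path assumed):

```python
import torch
from timm.layers import AttentionPool2d  # import path assumed

# Built for a 7x7 grid (e.g. CLIP ResNet-50 at 224px input); embed_dim % num_heads == 0 must hold.
pool = AttentionPool2d(in_features=2048, feat_size=7, out_features=1024, embed_dim=2048, num_heads=32)

out_ref = pool(torch.randn(2, 2048, 7, 7))  # matches seq_len, pos_embed used as stored
out_big = pool(torch.randn(2, 2048, 9, 9))  # larger grid, pos_embed resampled via resample_abs_pos_embed
print(out_ref.shape, out_big.shape)         # torch.Size([2, 1024]) torch.Size([2, 1024])
```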
2 changes: 0 additions & 2 deletions timm/layers/classifier.py
@@ -24,8 +24,6 @@ def _create_pool(
):
flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling
if not pool_type:
assert num_classes == 0 or use_conv,\
'Pooling can only be disabled if classifier is also removed or conv classifier is used'
flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling)
global_pool = SelectAdaptivePool2d(
pool_type=pool_type,
1 change: 0 additions & 1 deletion timm/layers/pos_embed_sincos.py
@@ -312,7 +312,6 @@ def __init__(
temperature=temperature,
step=1,
)
print(bands)
self.register_buffer(
'bands',
bands,