Commit
Merge pull request #2198 from huggingface/openai_clip_resnet
Mapping OpenAI CLIP Modified ResNet weights -> ByobNet.
rwightman authored Jun 12, 2024
2 parents 5aa49d5 + 57adc1a commit 7b0a532
Showing 12 changed files with 656 additions and 200 deletions.
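
This PR maps the OpenAI CLIP Modified ResNet image towers onto timm's ByobNet and reworks the attention pooling layers they rely on. As a rough orientation, here is a minimal loading sketch; the model identifier is an assumption based on timm naming conventions and is not taken from this diff.

import timm
import torch

# 'resnet50_clip.openai' is a hypothetical identifier chosen to match timm naming
# conventions; the exact names registered by this PR are not visible in this diff.
model = timm.create_model('resnet50_clip.openai', pretrained=True, num_classes=0)
img = torch.randn(1, 3, 224, 224)
embedding = model(img)  # image-tower features from the CLIP Modified ResNet
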
6 changes: 4 additions & 2 deletions timm/layers/attention_pool.py
@@ -20,6 +20,7 @@ def __init__(
out_features: int = None,
embed_dim: int = None,
num_heads: int = 8,
feat_size: Optional[int] = None,
mlp_ratio: float = 4.0,
qkv_bias: bool = True,
qk_norm: bool = False,
@@ -36,13 +37,14 @@ def __init__(
assert embed_dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.feat_size = feat_size
self.scale = self.head_dim ** -0.5
self.pool = pool_type
self.fused_attn = use_fused_attn()

if pos_embed == 'abs':
spatial_len = self.feat_size
self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
assert feat_size is not None
self.pos_embed = nn.Parameter(torch.zeros(feat_size, in_features))
else:
self.pos_embed = None

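With this change the latent attention pool takes feat_size at construction and asserts that it is provided when an absolute position embedding is requested. A minimal usage sketch follows, assuming the class in this file is AttentionPoolLatent (as in current timm) and that keyword names not visible in the hunk above are unchanged.

import torch
from timm.layers.attention_pool import AttentionPoolLatent  # assumed class name

pool = AttentionPoolLatent(
    in_features=768,
    num_heads=8,
    feat_size=7 * 7,  # flattened spatial length; now required when pos_embed='abs'
    pos_embed='abs',
)
tokens = torch.randn(2, 49, 768)  # (B, N, C) token sequence
pooled = pool(tokens)             # pooled output
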
267 changes: 207 additions & 60 deletions timm/layers/attention_pool2d.py
@@ -7,12 +7,14 @@
Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import Union, Tuple
from typing import Optional, Union, Tuple

import torch
import torch.nn as nn

from .config import use_fused_attn
from .helpers import to_2tuple
from .pos_embed import resample_abs_pos_embed
from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
from .weight_init import trunc_normal_

@@ -27,53 +29,122 @@ class RotAttentionPool2d(nn.Module):
NOTE: While this impl does not require a fixed feature size, performance at differing resolutions from
train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
"""
fused_attn: torch.jit.Final[bool]

def __init__(
self,
in_features: int,
out_features: int = None,
embed_dim: int = None,
num_heads: int = 4,
out_features: Optional[int] = None,
ref_feat_size: Union[int, Tuple[int, int]] = 7,
embed_dim: Optional[int] = None,
head_dim: Optional[int] = 64,
num_heads: Optional[int] = None,
qkv_bias: bool = True,
qkv_separate: bool = False,
pool_type: str = 'token',
class_token: bool = False,
drop_rate: float = 0.,
):
super().__init__()
embed_dim = embed_dim or in_features
out_features = out_features or in_features
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
self.proj = nn.Linear(embed_dim, out_features)
assert pool_type in ('', 'token')
self.embed_dim = embed_dim = embed_dim or in_features
self.in_features = in_features
self.out_features = out_features or in_features
ref_feat_size = to_2tuple(ref_feat_size)
if num_heads is not None:
assert embed_dim % num_heads == 0
head_dim = embed_dim // num_heads
else:
assert embed_dim % head_dim == 0
num_heads = embed_dim // head_dim
self.num_heads = num_heads
assert embed_dim % num_heads == 0
self.head_dim = embed_dim // num_heads
self.head_dim = head_dim
self.pool_type = pool_type.lower()
self.scale = self.head_dim ** -0.5
self.pos_embed = RotaryEmbedding(self.head_dim)
self.fused_attn = use_fused_attn()

trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
nn.init.zeros_(self.qkv.bias)
if class_token:
self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
else:
self.cls_token = None

def forward(self, x):
B, _, H, W = x.shape
N = H * W
x = x.reshape(B, -1, N).permute(0, 2, 1)
if qkv_separate:
self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.qkv = None
else:
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
self.drop = nn.Dropout(drop_rate)
self.proj = nn.Linear(embed_dim, self.out_features)
self.pos_embed = RotaryEmbedding(self.head_dim, in_pixels=False, ref_feat_shape=ref_feat_size)

x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
def init_weights(self, zero_init_last: bool = False):
if self.qkv is None:
in_features = self.q.in_features
trunc_normal_(self.q.weight, std=in_features ** -0.5)
nn.init.zeros_(self.q.bias)
trunc_normal_(self.k.weight, std=in_features ** -0.5)
nn.init.zeros_(self.k.bias)
trunc_normal_(self.v.weight, std=in_features ** -0.5)
nn.init.zeros_(self.v.bias)
else:
in_features = self.qkv.in_features
trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
nn.init.zeros_(self.qkv.bias)

x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = x[0], x[1], x[2]
def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
# NOTE: this module is being used as a head, so need compatible reset()
if pool_type is not None:
assert pool_type in ('', 'token')
self.pool_type = pool_type
if num_classes is not None:
self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
self.out_features = num_classes if num_classes > 0 else self.embed_dim

qc, q = q[:, :, :1], q[:, :, 1:]
sin_emb, cos_emb = self.pos_embed.get_embed((H, W))
q = apply_rot_embed(q, sin_emb, cos_emb)
q = torch.cat([qc, q], dim=2)
def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
if self.pool_type == 'token':
x = x[:, 0]
else:
# if not pooled, return spatial output without token
x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
return x

kc, k = k[:, :, :1], k[:, :, 1:]
k = apply_rot_embed(k, sin_emb, cos_emb)
k = torch.cat([kc, k], dim=2)
def forward(self, x, pre_logits: bool = False):
B, _, H, W = x.shape
N = H * W
x = x.flatten(2).transpose(1, 2)
if self.cls_token is None:
x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
else:
x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
if self.qkv is None:
q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
else:
x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = x.unbind(0)

attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
rse, rce = self.pos_embed.get_embed((H, W))
q = torch.cat([q[:, :, :1, :], apply_rot_embed(q[:, :, 1:, :], rse, rce)], dim=2).type_as(v)
k = torch.cat([k[:, :, :1, :], apply_rot_embed(k[:, :, 1:, :], rse, rce)], dim=2).type_as(v)

x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
if self.fused_attn:
x = nn.functional.scaled_dot_product_attention(q, k, v)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N + 1, -1)
x = self.drop(x)
if pre_logits:
x = self._pool(x, H, W)
return x
x = self.proj(x)
return x[:, 0]
x = self._pool(x, H, W)
return x
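
For reference, a usage sketch of the reworked RotAttentionPool2d based on the signature above; the concrete sizes are illustrative only, and qkv_separate=True is shown because the OpenAI CLIP attention pool stores separate q/k/v projections.

import torch
from timm.layers.attention_pool2d import RotAttentionPool2d

pool = RotAttentionPool2d(
    in_features=2048,    # channels of the trunk's final feature map
    out_features=1024,   # projection dim, e.g. a CLIP embed dim
    ref_feat_size=7,     # reference grid the rotary embedding is scaled against
    head_dim=64,         # num_heads is derived as embed_dim // head_dim
    qkv_separate=True,   # separate q/k/v projections, matching the CLIP attnpool layout
)
feats = torch.randn(2, 2048, 7, 7)   # NCHW feature map
pooled = pool(feats)                 # (2, 1024): projected 'token' output
pre = pool(feats, pre_logits=True)   # (2, 2048): pooled features before the final proj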


class AttentionPool2d(nn.Module):
@@ -85,47 +156,123 @@ class AttentionPool2d(nn.Module):
NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network.
"""
fused_attn: torch.jit.Final[bool]

def __init__(
self,
in_features: int,
feat_size: Union[int, Tuple[int, int]],
out_features: int = None,
embed_dim: int = None,
num_heads: int = 4,
feat_size: Union[int, Tuple[int, int]] = 7,
out_features: Optional[int] = None,
embed_dim: Optional[int] = None,
head_dim: Optional[int] = 64,
num_heads: Optional[int] = None,
qkv_bias: bool = True,
qkv_separate: bool = False,
pool_type: str = 'token',
class_token: bool = False,
drop_rate: float = 0.,
):
super().__init__()

embed_dim = embed_dim or in_features
out_features = out_features or in_features
assert embed_dim % num_heads == 0
assert pool_type in ('', 'token')
self.embed_dim = embed_dim = embed_dim or in_features
self.in_features = in_features
self.out_features = out_features or in_features
if num_heads is not None:
assert embed_dim % num_heads == 0
head_dim = embed_dim // num_heads
else:
assert embed_dim % head_dim == 0
num_heads = embed_dim // head_dim
self.feat_size = to_2tuple(feat_size)
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
self.proj = nn.Linear(embed_dim, out_features)
self.seq_len = self.feat_size[0] * self.feat_size[1]
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.head_dim = head_dim
self.pool_type = pool_type
self.scale = self.head_dim ** -0.5
self.fused_attn = use_fused_attn()

spatial_dim = self.feat_size[0] * self.feat_size[1]
self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
if class_token:
self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
else:
self.cls_token = None

if qkv_separate:
self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.v = nn.Linear(in_features, embed_dim, bias=qkv_bias)
self.qkv = None
else:
self.q = self.k = self.v = None
self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
self.drop = nn.Dropout(drop_rate)
self.proj = nn.Linear(embed_dim, self.out_features)
self.pos_embed = nn.Parameter(torch.zeros(self.seq_len + 1, in_features))

self.init_weights()

def init_weights(self, zero_init_last: bool = False):
if self.qkv is None:
in_features = self.q.in_features
trunc_normal_(self.q.weight, std=in_features ** -0.5)
nn.init.zeros_(self.q.bias)
trunc_normal_(self.k.weight, std=in_features ** -0.5)
nn.init.zeros_(self.k.bias)
trunc_normal_(self.v.weight, std=in_features ** -0.5)
nn.init.zeros_(self.v.bias)
else:
in_features = self.qkv.in_features
trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
nn.init.zeros_(self.qkv.bias)
trunc_normal_(self.pos_embed, std=in_features ** -0.5)
trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
nn.init.zeros_(self.qkv.bias)

def forward(self, x):
def reset(self, num_classes: Optional[int] = None, pool_type: Optional[str] = None):
# NOTE: this module is being used as a head, so need compatible reset()
if pool_type is not None:
assert pool_type in ('', 'token')
self.pool_type = pool_type
if num_classes is not None:
self.proj = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
self.out_features = num_classes if num_classes > 0 else self.embed_dim

def _pool(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
if self.pool_type == 'token':
x = x[:, 0]
else:
# if not pooled, return spatial output without token
x = x[:, 1:].reshape(x.shape[0], H, W, -1).permute(0, 3, 1, 2)
return x

def forward(self, x, pre_logits: bool = False):
B, _, H, W = x.shape
N = H * W
assert self.feat_size[0] == H
assert self.feat_size[1] == W
x = x.reshape(B, -1, N).permute(0, 2, 1)
x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = x[0], x[1], x[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)

x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
x = x.flatten(2).transpose(1, 2)
if self.cls_token is None:
x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
else:
x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
x = x + pos_embed

if self.qkv is None:
q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
v = self.v(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
else:
x = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = x.unbind(0)

if self.fused_attn:
x = nn.functional.scaled_dot_product_attention(q, k, v)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N + 1, -1)
x = self.drop(x)
if pre_logits:
x = self._pool(x, H, W)
return x
x = self.proj(x)
return x[:, 0]
x = self._pool(x, H, W)
return x
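
The fixed feat_size assertion in forward() is gone: the learned absolute position embedding is resampled to the incoming grid via resample_abs_pos_embed, so other input resolutions now work at inference. A usage sketch based on the signature above, with illustrative sizes:

import torch
from timm.layers.attention_pool2d import AttentionPool2d

pool = AttentionPool2d(
    in_features=2048,
    feat_size=7,        # grid the learned pos_embed is parameterized for
    out_features=1024,
    head_dim=64,
    qkv_separate=True,
)
pooled = pool(torch.randn(2, 2048, 7, 7))        # (2, 1024)
# off-reference grids are handled by resampling pos_embed on the fly
pooled_big = pool(torch.randn(2, 2048, 12, 12))  # still (2, 1024)
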
2 changes: 0 additions & 2 deletions timm/layers/classifier.py
@@ -24,8 +24,6 @@ def _create_pool(
):
flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling
if not pool_type:
assert num_classes == 0 or use_conv,\
'Pooling can only be disabled if classifier is also removed or conv classifier is used'
flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling)
global_pool = SelectAdaptivePool2d(
pool_type=pool_type,
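
The removed assertion previously blocked pool_type='' together with a non-conv Linear classifier; with it gone, a head can be built with pooling disabled and the pooling or flattening left to the caller, presumably to accommodate the head configurations introduced elsewhere in this PR. A construction-only sketch, assuming ClassifierHead keeps its current timm signature:

from timm.layers.classifier import ClassifierHead

# Before this change, pool_type='' with a non-conv classifier tripped the assert above.
head = ClassifierHead(2048, 1000, pool_type='')
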
(The diffs for the remaining 9 changed files are not shown here.)
