# vits.py
import math
import torch
import torch.nn as nn
from functools import partial, reduce
from operator import mul
from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.layers.helpers import to_2tuple
from timm.models.layers import PatchEmbed
# Public factory functions exported by this module.
__all__ = ['vit_small', 'vit_base']
class VisionTransformerMoCo(VisionTransformer):
    """timm ``VisionTransformer`` with a fixed 2D sin-cos position embedding.

    On top of the parent constructor this class:

    * replaces ``pos_embed`` with a non-trainable 2D sine-cosine embedding;
    * re-initializes every ``nn.Linear`` weight (fused ``qkv`` layers are
      scaled as three separate projections) and zeroes existing biases;
    * re-initializes the patch-embedding projection when ``patch_embed`` is a
      timm ``PatchEmbed`` and optionally freezes it.
    """

    def __init__(self, stop_grad_conv1=False, **kwargs):
        """Build the transformer and apply the custom initialization.

        Args:
            stop_grad_conv1: If true, the patch-embedding projection's
                parameters get ``requires_grad = False``. Only applied when
                ``patch_embed`` is a timm ``PatchEmbed``.
            **kwargs: Forwarded unchanged to ``VisionTransformer.__init__``.
        """
        super().__init__(**kwargs)
        # Use fixed 2D sin-cos position embedding
        self.build_2d_sincos_position_embedding()

        # weight initialization
        for name, m in self.named_modules():
            if not isinstance(m, nn.Linear):
                continue
            if 'qkv' in name:
                # treat the weights of Q, K, V separately: the fused layer
                # stacks three projections along the output dimension, so
                # the Xavier bound uses one third of it as fan-out.
                fan_out, fan_in = m.weight.shape
                val = math.sqrt(6. / float(fan_out // 3 + fan_in))
                nn.init.uniform_(m.weight, -val, val)
            else:
                nn.init.xavier_uniform_(m.weight)
            # Linear layers built with bias=False (e.g. qkv_bias=False)
            # have ``bias is None``; zeroing that would raise.
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        nn.init.normal_(self.cls_token, std=1e-6)

        if isinstance(self.patch_embed, PatchEmbed):
            proj = self.patch_embed.proj
            # xavier_uniform initialization; fan-in is input channels times
            # patch area (read from the projection instead of assuming RGB).
            fan_in = proj.in_channels * reduce(mul, self.patch_embed.patch_size, 1)
            val = math.sqrt(6. / float(fan_in + self.embed_dim))
            nn.init.uniform_(proj.weight, -val, val)
            if proj.bias is not None:
                nn.init.zeros_(proj.bias)

            if stop_grad_conv1:
                proj.weight.requires_grad = False
                if proj.bias is not None:
                    proj.bias.requires_grad = False

    def build_2d_sincos_position_embedding(self, temperature=10000.):
        """Replace ``self.pos_embed`` with a frozen 2D sin-cos embedding.

        The embedding has shape ``(1, 1 + h * w, embed_dim)`` where
        ``(h, w)`` is ``self.patch_embed.grid_size``: an all-zero row for the
        single prefix ([cls]) token followed by one row per grid position.
        Each row is ``[sin(w), cos(w), sin(h), cos(h)]``, every part
        ``embed_dim // 4`` wide.

        Args:
            temperature: Base of the geometric frequency progression; the
                ``k``-th frequency is ``1 / temperature ** (k / pos_dim)``.

        Raises:
            AssertionError: If ``embed_dim`` is not divisible by 4 or the
                model does not have exactly one prefix token.
        """
        h, w = self.patch_embed.grid_size
        assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        assert self.num_tokens == 1, 'Assuming one and only one token, [cls]'

        # Flattened coordinate grids, identical to
        # ``torch.meshgrid(arange(w), arange(h))`` with 'ij' indexing followed
        # by ``.flatten()`` (w index varies slowest). Built without meshgrid
        # so no ``indexing=`` argument is needed on any torch version.
        coords_w = torch.arange(w, dtype=torch.float32).repeat_interleave(h)
        coords_h = torch.arange(h, dtype=torch.float32).repeat(w)

        pos_dim = self.embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature**omega)

        # Outer products: (num_positions, pos_dim) phase tables.
        out_w = torch.einsum('m,d->md', [coords_w, omega])
        out_h = torch.einsum('m,d->md', [coords_h, omega])
        pos_emb = torch.cat(
            [torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)],
            dim=1,
        )[None, :, :]

        # The [cls] token gets an all-zero position embedding.
        pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
        self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
        self.pos_embed.requires_grad = False
def vit_small(**kwargs):
    """Create a ``VisionTransformerMoCo`` with embedding width 384.

    Fixed settings: 16x16 patches, depth 12, 12 attention heads, MLP ratio 4,
    QKV bias enabled and ``LayerNorm`` with ``eps=1e-6``. Additional keyword
    arguments are forwarded to the ``VisionTransformerMoCo`` constructor.

    Returns:
        The constructed model, with ``default_cfg`` set to ``_cfg()``.
    """
    arch = dict(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    net = VisionTransformerMoCo(**arch, **kwargs)
    net.default_cfg = _cfg()
    return net
def vit_base(**kwargs):
    """Create a ``VisionTransformerMoCo`` with embedding width 768.

    Fixed settings: 16x16 patches, depth 12, 12 attention heads, MLP ratio 4,
    QKV bias enabled and ``LayerNorm`` with ``eps=1e-6``. Additional keyword
    arguments are forwarded to the ``VisionTransformerMoCo`` constructor.

    Returns:
        The constructed model, with ``default_cfg`` set to ``_cfg()``.
    """
    arch = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    net = VisionTransformerMoCo(**arch, **kwargs)
    net.default_cfg = _cfg()
    return net