Skip to content

Commit

Permalink
add cvt code
Browse files Browse the repository at this point in the history
  • Loading branch information
arch committed Nov 26, 2021
1 parent 5dd9532 commit ce300b2
Show file tree
Hide file tree
Showing 8 changed files with 272 additions and 9 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,9 @@ To the standard question, I can say: I can not provide an pretrained Model! I do
- [ ] increase model layers
- Not possible; I have only 6 GB of video RAM.
- [ ] Implement validation to prevent overfitting

## Current State

- Model 1 converges during training; first results are good
- Model 2 does not converge during training
- Model 3 WIP
15 changes: 14 additions & 1 deletion config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ general:
checkpoint_dir: 'checkpoint'
train_dir: './data/train'
batch_size: 1 # data loader is only implemented for case == 1!
select: 'model1'
select: 'model3'
test_file: './data/test/example1.mkv'

model1:
Expand All @@ -31,3 +31,16 @@ model2:
lr: 0.001
lr_milestones: [7, 12, 15]
epochs: 25

model3:
name: 'funpos3'
class: 'Model3'
skip_frames: 2
img_width: 32
img_height: 32
img_channels: 3
convlstm_hidden_dim: 64
seq_len: 8
lr: 0.0001
lr_milestones: [7, 12, 15]
epochs: 25
17 changes: 9 additions & 8 deletions model/model2.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ def __init__(self, t, h, w, patch_t, patch_h, patch_w, dim, depth, heads, mlp_di

self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, self.T)
nn.Linear(dim, dim//self.t),
nn.Linear(dim//self.t, self.T)
)

def forward(self, x):
Expand All @@ -77,11 +78,11 @@ def __init__(self):
t = CONFIG[MODEL]['seq_len'],
h = CONFIG[MODEL]['img_height'],
w = CONFIG[MODEL]['img_width'],
patch_t = 4,
patch_h = 8,
patch_w = 8,
dim = 512,
depth = 6,
heads = 10,
mlp_dim = 8,
patch_t = 8,
patch_h = 16,
patch_w = 16,
dim = 768,
depth = 12,
heads = 12,
mlp_dim = 4096,
)
96 changes: 96 additions & 0 deletions model/model3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import torch
import torch.nn as nn
from utils.config import CONFIG

from modules.convlstm import ConvLSTM
from modules.convtransformer import CvT

# Key used to look up this model's section in CONFIG
MODEL='model3'

class Flatten(torch.nn.Module):
    """Collapse every axis after (batch, time) into a single feature axis."""

    def forward(self, input):
        """Flatten a 5-D tensor to 3-D.

        Args:
            input (torch.Tensor): 5-D tensor; the first two axes are kept,
                the remaining three are merged.

        Returns:
            torch.Tensor: tensor of shape (input.size(0), input.size(1), -1).

        Raises:
            ValueError: if ``input`` is not 5-dimensional.
        """
        # The five-way unpacking doubles as a rank check on the input.
        b, seq_len, _, _, _ = input.size()
        # reshape() returns a view when the memory layout allows it and copies
        # otherwise, so non-contiguous inputs (e.g. after transpose/permute)
        # no longer raise like view() does.
        return input.reshape(b, seq_len, -1)


class Model3(nn.Module):
    """FunPos model: per-frame CvT features, two bidirectional ConvLSTMs, FC head.

    All hyper-parameters are read from ``CONFIG[MODEL]``. Each frame of the
    input sequence is passed through two stacked ``CvT`` modules, the resulting
    feature maps are stacked along the time axis and fed through two
    ``ConvLSTM`` layers, then flattened per time step and reduced to a single
    value by two linear layers.
    """

    def __init__(self):
        """Build all sub-modules from the ``CONFIG[MODEL]`` section.

        Raises:
            ValueError: if the configured image is not square (``CvT`` only
                takes a single ``image_size``).
        """
        super().__init__()

        cfg = CONFIG[MODEL]
        height = int(cfg['img_height'])
        width = int(cfg['img_width'])
        if height != width:
            # CvT receives only img_height as image_size and rearranges tokens
            # into an image_size x image_size grid, so a non-square frame
            # would fail later with a much less obvious shape error.
            raise ValueError(
                "Model3 requires square images, got img_height={} and img_width={}".format(height, width)
            )
        lstm_hidden = cfg['convlstm_hidden_dim'] * 2

        self.conv1 = CvT(
            image_size=cfg['img_height'],
            in_channels=cfg['img_channels']
        )

        # in_channels=16 matches the default ``dim`` of the first CvT stage.
        self.conv2 = CvT(
            image_size=cfg['img_height'],
            in_channels=16,
            dim=32
        )

        self.convlstm1 = ConvLSTM(
            img_size=(height, width),
            input_dim=32,
            hidden_dim=lstm_hidden,
            kernel_size=(3, 3),
            cnn_dropout=0.1,
            rnn_dropout=0.1,
            batch_first=True,
            bias=False,
            layer_norm=True,
            return_sequence=True,
            bidirectional=True
        )

        # NOTE(review): input_dim=256 presumably equals 2 * lstm_hidden (both
        # directions of convlstm1 concatenated) for convlstm_hidden_dim=64 --
        # confirm against ConvLSTM before changing the config value.
        self.convlstm2 = ConvLSTM(
            img_size=(height, width),
            input_dim=256,
            hidden_dim=lstm_hidden,
            kernel_size=(3, 3),
            cnn_dropout=0.1,
            rnn_dropout=0.1,
            batch_first=True,
            bias=False,
            layer_norm=False,
            return_sequence=True,
            bidirectional=True
        )

        self.flatten = Flatten()

        # Equals 4 * convlstm_hidden_dim * height * width input features.
        self.fc1 = nn.Linear(
            int(2 * cfg['img_width']) * int(2 * cfg['img_height']) * cfg['convlstm_hidden_dim'],
            128
        )
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x, hidden_state=None):
        """ Forward pass

        Args:
            x (torch.tensor): 5-D Tensor of shape (batch, time, channel, height, width)
            hidden_state: currently unused; kept for interface compatibility.

        Returns:
            tensor: prediction
        """
        # Iterate over the time axis actually present in ``x`` rather than the
        # configured seq_len, so a mismatch neither raises an IndexError nor
        # silently drops trailing frames.
        _, seq_len, _, _, _ = x.size()
        frames = [self.conv2(self.conv1(x[:, t, :, :, :])) for t in range(seq_len)]
        x = torch.stack(frames, dim=1)

        # Only the first returned value is used; the two state values are discarded.
        x, _, _ = self.convlstm1(x)
        x, _, _ = self.convlstm2(x)

        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)

        return x
144 changes: 144 additions & 0 deletions modules/convtransformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import torch
from torch import nn, einsum
from einops import rearrange
from einops.layers.torch import Rearrange


class SepConv2d(torch.nn.Module):
    """Depthwise-separable 2-D convolution.

    Applies a depthwise convolution (one filter per input channel), batch
    normalisation over the input channels, and a 1x1 pointwise convolution
    that maps to ``out_channels``.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the pointwise conv.
        kernel_size: kernel size of the depthwise convolution.
        stride: stride of the depthwise convolution.
        padding: padding of the depthwise convolution.
        dilation: dilation of the depthwise convolution.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1):
        super().__init__()
        # groups=in_channels gives every input channel its own spatial filter.
        self.depthwise = torch.nn.Conv2d(in_channels,
                                         in_channels,
                                         kernel_size=kernel_size,
                                         stride=stride,
                                         padding=padding,
                                         dilation=dilation,
                                         groups=in_channels)
        self.bn = torch.nn.BatchNorm2d(in_channels)
        # The 1x1 convolution mixes channels and sets the output width.
        self.pointwise = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Run depthwise conv -> batch norm -> pointwise conv on ``x``."""
        return self.pointwise(self.bn(self.depthwise(x)))


class PreNorm(nn.Module):
    """Apply LayerNorm over the last axis before calling a wrapped callable.

    Args:
        dim: size of the last axis, passed to ``nn.LayerNorm``.
        fn: module or callable invoked on the normalised input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(LayerNorm(x), **kwargs)``."""
        normed = self.norm(x)
        return self.fn(normed, **kwargs)


class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: size of the input and output feature axis.
        hidden_dim: size of the intermediate feature axis.
        dropout: dropout probability used after the activation and after the
            second linear layer.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP to the last axis of ``x``."""
        return self.net(x)


class ConvAttention(nn.Module):
    """Multi-head self-attention whose Q/K/V projections are separable convolutions.

    The token sequence is reshaped to an ``img_size`` x ``img_size`` feature
    map, projected to queries, keys and values with ``SepConv2d``, and then
    combined with scaled dot-product attention.

    Args:
        dim: channel size of the incoming tokens.
        img_size: side length of the (square) token grid.
        heads: number of attention heads.
        dim_head: channel size per head.
        kernel_size: kernel size of the Q/K/V convolutions.
        q_stride: stride of the query convolution.
        k_stride: stride of the key convolution.
        v_stride: stride of the value convolution.
        dropout: dropout probability after the output projection.
        last_stage: stored on the instance; not read anywhere in this class.
    """

    def __init__(self, dim, img_size, heads=8, dim_head=64, kernel_size=3, q_stride=1, k_stride=1, v_stride=1,
                 dropout=0., last_stage=False):
        super().__init__()
        self.last_stage = last_stage
        self.img_size = img_size
        inner_dim = dim_head * heads
        # With a single head of width ``dim`` the output already has the right
        # size, so the projection is replaced by an identity.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        # The same padding (derived from the query stride) is used for all
        # three projections.
        pad = (kernel_size - q_stride) // 2
        self.to_q = SepConv2d(dim, inner_dim, kernel_size, q_stride, pad)
        self.to_k = SepConv2d(dim, inner_dim, kernel_size, k_stride, pad)
        self.to_v = SepConv2d(dim, inner_dim, kernel_size, v_stride, pad)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        """Attend over a token sequence.

        Args:
            x: tensor of shape (batch, img_size * img_size, dim).

        Returns:
            Tensor with one row per query position and ``dim`` channels
            (``heads * dim_head`` channels when no output projection is used).
        """
        h = self.heads
        # Tokens -> feature map so that the convolutional projections apply.
        x = rearrange(x, 'b (l w) n -> b n l w', l=self.img_size, w=self.img_size)

        # Feature map -> per-head token sequences.
        q = rearrange(self.to_q(x), 'b (h d) l w -> b h (l w) d', h=h)
        v = rearrange(self.to_v(x), 'b (h d) l w -> b h (l w) d', h=h)
        k = rearrange(self.to_k(x), 'b (h d) l w -> b h (l w) d', h=h)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        # Merge the heads back into one channel axis.
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    """Stack of pre-norm convolutional attention and feed-forward layers.

    Args:
        dim: channel size of the tokens.
        img_size: side length of the square token grid used by ``ConvAttention``.
        depth: number of (attention, feed-forward) layer pairs.
        heads: number of attention heads.
        dim_head: channel size per attention head.
        mlp_dim: hidden size of the feed-forward block.
        dropout: dropout probability used in both sub-layers.
        last_stage: forwarded to ``ConvAttention``.
    """

    def __init__(self, dim, img_size, depth, heads, dim_head, mlp_dim, dropout=0., last_stage=False):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(dim, ConvAttention(dim, img_size, heads=heads, dim_head=dim_head, dropout=dropout,
                                           last_stage=last_stage)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        """Apply every layer pair with residual connections around each sub-layer."""
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class CvT(nn.Module):
    """Single-stage convolutional vision transformer block.

    A convolution embeds the image into tokens, a ``Transformer`` processes the
    tokens, and the result is rearranged back into a feature map.

    Args:
        image_size: side length of the (square) input image.
        in_channels: number of input channels.
        dim: token / output channel size.
        kernel_size: kernel size of the embedding convolution.
        stride: stride of the embedding convolution.
        depth: number of transformer layers.
        heads: number of attention heads.
        dropout: dropout probability inside the transformer.
        scale_dim: multiplier for the feed-forward hidden size (``dim * scale_dim``).
    """

    def __init__(self,
                 image_size: int,
                 in_channels: int,
                 dim: int = 16,
                 kernel_size: int = 3,
                 stride: int = 1,
                 depth: int = 6,
                 heads: int = 6,
                 dropout: float = 0.01,
                 scale_dim: int = 4):
        super().__init__()

        padding = 1
        # Side length of the token grid produced by the embedding convolution
        # (standard Conv2d output-size formula with dilation 1). For the default
        # kernel_size=3 / stride=1 this equals image_size; computing it keeps
        # the rearrange layers and the attention grid consistent for other
        # kernel sizes and strides instead of failing on a shape mismatch.
        embed_size = (image_size + 2 * padding - kernel_size) // stride + 1

        self.conv_embed = nn.Sequential(
            nn.Conv2d(in_channels, dim, kernel_size, stride, padding),
            Rearrange('b c h w -> b (h w) c', h=embed_size, w=embed_size),
            nn.LayerNorm(dim)
        )
        self.transformer = nn.Sequential(
            Transformer(dim=dim, img_size=embed_size, depth=depth, heads=heads, dim_head=dim,
                        mlp_dim=dim * scale_dim, dropout=dropout),
            Rearrange('b (h w) c -> b c h w', h=embed_size, w=embed_size)
        )

    def forward(self, x):
        """Embed ``x`` (batch, in_channels, image_size, image_size) and return a
        feature map with ``dim`` channels."""
        x = self.conv_embed(x)
        x = self.transformer(x)
        return x
1 change: 1 addition & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from utils.config import CONFIG
from model.model1 import FunPosModel
from model.model2 import FunPosTransformerModel
from model.model3 import Model3

import numpy as np

Expand Down
1 change: 1 addition & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from utils.dataset import Funscript_Dataset
from model.model1 import FunPosModel
from model.model2 import FunPosTransformerModel
from model.model3 import Model3
from utils.config import CONFIG


Expand Down
1 change: 1 addition & 0 deletions utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def read_next_frame(self):
while len(self.last_frames) >= self.seq_len:
del self.last_frames[0]
frame = self.stream.read()
# frame = frame / 255.0
if frame is None:
self.open_next_video()
else:
Expand Down

0 comments on commit ce300b2

Please sign in to comment.