Merge pull request #12 from KR-HappyFace/koclip
add: KoCLIP codes
shawnhyeonsoo authored Dec 23, 2021
2 parents 447d2ff + 513b837 commit 1e9c8c2
Showing 5 changed files with 266 additions and 152 deletions.
1 change: 1 addition & 0 deletions clip/README.md
@@ -0,0 +1 @@
### CLIP
79 changes: 79 additions & 0 deletions clip/clipmodel.py
@@ -0,0 +1,79 @@
import torch.nn as nn
from transformers import RobertaModel, RobertaConfig
import timm


class ImageEncoder(nn.Module):
    def __init__(self, model_name, pretrained):
        super().__init__()
        # timm backbone with the classification head removed (num_classes=0)
        # and global average pooling, so the output is a flat feature vector
        self.model = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )
        # keep the whole backbone trainable
        for p in self.model.parameters():
            p.requires_grad = True

    def forward(self, x):
        return self.model(x)


class TextEncoder(nn.Module):
    def __init__(self, pretrained):
        super().__init__()
        # KLUE RoBERTa text encoder; fall back to a randomly initialized model
        # with the same config when pretrained weights are not requested
        if pretrained:
            self.model = RobertaModel.from_pretrained("klue/roberta-base")
        else:
            config = RobertaConfig.from_pretrained("klue/roberta-base")
            self.model = RobertaModel(config)

        # keep the whole encoder trainable
        for p in self.model.parameters():
            p.requires_grad = True
        # use the first ([CLS]) token representation as the sentence embedding
        self.target_token_idx = 0

    def forward(self, input_ids, token_type_ids, attention_mask):
        output = self.model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]


class ProjectionHead(nn.Module):
def __init__(self, embedding_dim, projection_dim):
super().__init__()
self.projection = nn.Linear(embedding_dim, projection_dim)

def forward(self, x):
projected = self.projection(x)
return projected


class CLIPModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.image_encoder = ImageEncoder("efficientnet_b0", pretrained=False)
        self.text_encoder = TextEncoder(pretrained=True)
        # infer each encoder's output width from the shape of its last parameter
        image_embedding_dim = list(self.image_encoder.parameters())[-1].shape[0]
        text_embedding_dim = list(self.text_encoder.parameters())[-1].shape[0]
        # project both modalities into a shared 512-dimensional embedding space
        self.image_projection = ProjectionHead(
            embedding_dim=image_embedding_dim, projection_dim=512
        )
        self.text_projection = ProjectionHead(
            embedding_dim=text_embedding_dim, projection_dim=512
        )

def forward(self, text, image):
image_features = self.image_encoder(image)
text_features = self.text_encoder(
input_ids=text["input_ids"],
attention_mask=text["attention_mask"],
token_type_ids=text["token_type_ids"],
)

image_embeddings = self.image_projection(image_features)
text_embeddings = self.text_projection(text_features)

return text_embeddings, image_embeddings
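
For reference, a minimal smoke test of CLIPModel might look like the sketch below; the dummy image tensor, the sample captions, and the shapes in the comments are illustrative assumptions, not part of the committed code.

# Sketch only (not part of the commit): one forward pass of CLIPModel with dummy inputs.
import torch
from transformers import AutoTokenizer

from clipmodel import CLIPModel

model = CLIPModel()
tokenizer = AutoTokenizer.from_pretrained("klue/roberta-base")

captions = ["검정 원피스", "파란색 셔츠"]  # "black dress", "blue shirt"
tokens = tokenizer(captions, padding=True, return_tensors="pt")
images = torch.randn(2, 3, 224, 224)  # dummy image batch in place of real photos

text_embeds, image_embeds = model(tokens, images)
print(text_embeds.shape, image_embeds.shape)  # expected: torch.Size([2, 512]) for both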
41 changes: 41 additions & 0 deletions clip/dataloader.py
@@ -0,0 +1,41 @@
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
from tqdm import tqdm
import re


class CLIPDataset(Dataset):
    def __init__(self, texts, images):
        self.texts = texts
        self.images = images
        # resize to 224x224 and convert the HWC numpy image to a CHW tensor
        self.transform = A.Compose([A.Resize(224, 224), ToTensorV2()])

    def __getitem__(self, index):
        t = self.texts[index]
        # OpenCV expects a string path and loads images as BGR; convert to RGB
        single_im = cv2.imread(str(self.images[index]))
        single_im = cv2.cvtColor(single_im, cv2.COLOR_BGR2RGB)
        im = self.transform(image=single_im)["image"]
        return t, im

    def __len__(self):
        return len(self.texts)


def get_dataset(text_path, image_path):
    # text_path and image_path are pathlib.Path objects; both file lists are
    # sorted so that the i-th caption is paired with the i-th image
    image_files = sorted(
        [
            *image_path.glob("**/*[0-9].png"),
            *image_path.glob("**/*[0-9].jpg"),
            *image_path.glob("**/*[0-9].jpeg"),
        ]
    )
    text_files = sorted(text_path.glob("**/*[0-9].txt"))
    texts = []
    print("Extracting text information!")
    for i in tqdm(range(len(text_files))):
        with open(text_files[i], "r", encoding="utf-8") as f:
            te = f.read()
            # strip boilerplate phrasing from the Korean captions
            # ("... 스타일에서 스타일은 ..." roughly "in the ... style, the style is ...")
            te = re.sub("스타일에서 스타일은 [가-힣]+.", "", te)
            te = re.sub("에서", "", te)
            texts.append(te)
    return texts, image_files
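
A minimal usage sketch for these helpers follows; the relative folder names are placeholders rather than paths taken from this commit.

# Sketch only: build the dataset from pathlib paths and wrap it in a DataLoader.
from pathlib import Path

from torch.utils.data import DataLoader

from dataloader import CLIPDataset, get_dataset

texts, images = get_dataset(Path("data/train_label"), Path("data/cropped_train_img"))
dataset = CLIPDataset(texts, images)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

captions, batch = next(iter(loader))
# captions is a tuple of 32 strings; batch is a (32, 3, 224, 224) uint8 tensor,
# which the training loop casts to float before the forward pass
print(len(captions), batch.shape)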
145 changes: 145 additions & 0 deletions clip/train_clip.py
@@ -0,0 +1,145 @@
from torch.utils.data import DataLoader
from clipmodel import CLIPModel
from transformers import AutoTokenizer
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim import AdamW
import torch
import torch.nn as nn
import wandb
from tqdm import tqdm
import itertools
from dataloader import CLIPDataset, get_dataset
from pathlib import Path
import argparse


def calculate_loss(text_embeds, image_embeds, temperature=0.07):
    # symmetric cross-entropy (InfoNCE) loss: the i-th caption in the batch
    # should match the i-th image, so the targets are just the row indices
    logits = text_embeds @ image_embeds.T / temperature
    targets = torch.arange(len(text_embeds), device=text_embeds.device)

    texts_loss = nn.CrossEntropyLoss()(logits, targets)
    images_loss = nn.CrossEntropyLoss()(logits.T, targets)

    return (texts_loss + images_loss) / 2.0


def evaluate(model, val_dl, tokenizer):
    device = next(model.parameters()).device
    val_loss = 0.0
    model.eval()
    with torch.no_grad():
        for step, batch in enumerate(tqdm(val_dl)):
            text, image = batch
            text = tokenizer(
                list(text),
                padding="max_length",
                max_length=128,
                truncation=True,
                return_tensors="pt",
            )
            text = text.to(device)
            image = image.float().to(device)
            text_embeds, image_embeds = model(text, image)
            loss = calculate_loss(text_embeds, image_embeds)
            val_loss += loss.item()
    val_loss /= len(val_dl)
    print(f"Val Loss: {val_loss}")
    return val_loss


def get_optimizer(model):
params = [
{"params": model.image_encoder.parameters(), "lr": 4e-5},
{"params": model.text_encoder.parameters(), "lr": 4e-5},
{
"params": itertools.chain(
model.image_projection.parameters(), model.text_projection.parameters()
),
"lr": 5e-5,
},
]
optimizer = AdamW(params, weight_decay=0.2, betas=(0.9, 0.98), eps=1e-6)
return optimizer


def train(model, device, train_paths, val_paths=None, num_epochs=100):
model = model.to(device)
text_path, image_path = train_paths
train_texts, train_images = get_dataset(text_path, image_path)

train_dataset = CLIPDataset(train_texts, train_images)
train_dl = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
    if val_paths:
        val_text_path, val_image_path = val_paths
        val_texts, val_images = get_dataset(val_text_path, val_image_path)
        val_dataset = CLIPDataset(val_texts, val_images)
        val_dl = DataLoader(val_dataset, batch_size=8, shuffle=False)

optimizer = get_optimizer(model)
scheduler = ReduceLROnPlateau(optimizer, "min")
tokenizer = AutoTokenizer.from_pretrained("klue/roberta-base")
total_vloss = int(1e9)
scaler = torch.cuda.amp.GradScaler()

for i in range(num_epochs):
print(f"Epochs: {i+1}")
epoch_loss = 0
model.train()
wandb.log({"Epochs": i + 1})
        for steps, batch in enumerate(tqdm(train_dl)):
            text, image = batch
            text = tokenizer(
                list(text),
                padding="max_length",
                max_length=128,
                truncation=True,
                return_tensors="pt",
            )
            text = text.to(device)
            image = image.float().to(device)
            optimizer.zero_grad()

            # mixed-precision forward pass; the gradients are scaled below
            with torch.cuda.amp.autocast():
                text_embeddings, image_embeddings = model(text, image)
                loss = calculate_loss(text_embeddings, image_embeddings)

            if (steps + 1) % 100 == 0 and steps > 100:
                print(f"Epoch {i}, step {steps+1}, Loss: {loss.item()}")
            epoch_loss += loss.item()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

print(f"Epoch Loss: {epoch_loss / len(train_dl)}")
if val_paths:
vloss = evaluate(model, val_dl, tokenizer)
if vloss < total_vloss:
total_vloss = vloss
torch.save(model, "clip.pt")
print(f"Model saved. Current best val loss {total_vloss}")
scheduler.step(epoch_loss)


if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--image_folder",
        type=str,
        default="/opt/ml/DALLE-Couture/data/cropped_train_img",
        help="path to the folder containing training images",
    )
    parser.add_argument(
        "--text_folder",
        type=str,
        default="/opt/ml/DALLE-Couture/data/train_label",
        help="path to the folder containing caption text files",
    )
    args = parser.parse_args()
    # wandb.init must be called before the wandb.log calls in train();
    # the project name here is only a placeholder
    wandb.init(project="koclip")
    model = CLIPModel()
    # get_dataset expects pathlib.Path objects, so wrap the CLI strings
    train_paths = [Path(args.text_folder), Path(args.image_folder)]
    train(
        model,
        device,
        train_paths=train_paths,
    )
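
As a sketch of how the saved clip.pt checkpoint might be queried afterwards; everything below, including the dummy image tensor and the sample query, is illustrative and not part of the commit.

# Sketch only: rank images against a Korean query caption with the saved model.
import torch
from transformers import AutoTokenizer

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("clip.pt", map_location=device)  # train_clip.py saves the whole module
model.eval()

tokenizer = AutoTokenizer.from_pretrained("klue/roberta-base")
query = tokenizer(["검정 원피스"], padding=True, return_tensors="pt").to(device)  # "black dress"
images = torch.randn(4, 3, 224, 224, device=device)  # stand-in for preprocessed image crops

with torch.no_grad():
    text_embeds, image_embeds = model(query, images)
    # cosine similarity between the query and each image embedding
    text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
    image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
    scores = text_embeds @ image_embeds.T
print(scores)  # shape (1, 4); a higher score means a closer match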