Merge pull request #12 from KR-HappyFace/koclip
add: KoCLIP codes
Showing 5 changed files with 266 additions and 152 deletions.
@@ -0,0 +1 @@
### CLIP
@@ -0,0 +1,79 @@
import torch.nn as nn
from transformers import RobertaModel, RobertaConfig
import timm


class ImageEncoder(nn.Module):
    """timm backbone with the classifier removed; returns globally pooled image features."""

    def __init__(self, model_name, pretrained):
        super().__init__()
        self.model = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )
        for p in self.model.parameters():
            p.requires_grad = True

    def forward(self, x):
        return self.model(x)


class TextEncoder(nn.Module):
    """klue/roberta-base encoder; the hidden state of the first ([CLS]) token is the sentence feature."""

    def __init__(self, pretrained):
        super().__init__()
        if pretrained:
            self.model = RobertaModel.from_pretrained("klue/roberta-base")
        else:
            config = RobertaConfig.from_pretrained("klue/roberta-base")
            self.model = RobertaModel(config)

        for p in self.model.parameters():
            p.requires_grad = True
        self.target_token_idx = 0

    def forward(self, input_ids, token_type_ids, attention_mask):
        output = self.model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]


class ProjectionHead(nn.Module):
    """Linear projection of encoder features into the shared embedding space."""

    def __init__(self, embedding_dim, projection_dim):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        return projected


class CLIPModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.image_encoder = ImageEncoder("efficientnet_b0", pretrained=False)
        self.text_encoder = TextEncoder(pretrained=True)
        # Infer each encoder's output width from its last parameter
        # (1280 for efficientnet_b0, 768 for klue/roberta-base).
        image_embedding_dim = list(self.image_encoder.parameters())[-1].shape[0]
        text_embedding_dim = list(self.text_encoder.parameters())[-1].shape[0]
        self.image_projection = ProjectionHead(
            embedding_dim=image_embedding_dim, projection_dim=512
        )
        self.text_projection = ProjectionHead(
            embedding_dim=text_embedding_dim, projection_dim=512
        )

    def forward(self, text, image):
        image_features = self.image_encoder(image)
        text_features = self.text_encoder(
            input_ids=text["input_ids"],
            attention_mask=text["attention_mask"],
            token_type_ids=text["token_type_ids"],
        )

        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        return text_embeddings, image_embeddings
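
For a quick sanity check, the dual encoder can be run on dummy inputs. The sketch below assumes this file is saved as clipmodel.py (as the training script's import suggests); the captions and tensor shapes are arbitrary placeholders.

import torch
from transformers import AutoTokenizer
from clipmodel import CLIPModel

tokenizer = AutoTokenizer.from_pretrained("klue/roberta-base")
model = CLIPModel().eval()

# Two dummy Korean captions and two random 224x224 RGB "images".
text = tokenizer(["파란색 셔츠", "검은색 원피스"], padding=True, return_tensors="pt")
image = torch.randn(2, 3, 224, 224)

with torch.no_grad():
    text_embeds, image_embeds = model(text, image)

print(text_embeds.shape, image_embeds.shape)  # both torch.Size([2, 512]) after projection
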
@@ -0,0 +1,41 @@
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
from tqdm import tqdm
import re


class CLIPDataset(Dataset):
    """Pairs each caption with its image; images are resized to 224x224 and converted to CHW tensors."""

    def __init__(self, texts, images):
        self.texts = texts
        self.images = images
        self.transform = A.Compose([A.Resize(224, 224), ToTensorV2()])

    def __getitem__(self, index):
        t = self.texts[index]
        # cv2.imread expects a string path, so cast in case a pathlib.Path is stored.
        single_im = cv2.imread(str(self.images[index]))
        single_im = cv2.cvtColor(single_im, cv2.COLOR_BGR2RGB)
        im = self.transform(image=single_im)["image"]
        return t, im

    def __len__(self):
        return len(self.texts)


def get_dataset(text_path, image_path):
    """Collects image files and caption files under the given directories (pathlib.Path)."""
    image_files = [
        *image_path.glob("**/*[0-9].png"),
        *image_path.glob("**/*[0-9].jpg"),
        *image_path.glob("**/*[0-9].jpeg"),
    ]
    text_files = [*text_path.glob("**/*[0-9].txt")]
    texts = []
    print("Extracting text information!")
    for i in tqdm(range(len(text_files))):
        with open(text_files[i], "r", encoding="utf-8") as f:
            te = f.read()
            # Strip boilerplate phrases from the Korean captions.
            te = re.sub("스타일에서 스타일은 [가-힣]+.", "", te)
            te = re.sub("에서", "", te)
            texts.append(te)
    return texts, image_files
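
A minimal sketch of how these pieces are meant to be combined; the folder paths are placeholders, and get_dataset expects pathlib.Path objects because it calls .glob on them.

from pathlib import Path
from torch.utils.data import DataLoader
from dataloader import CLIPDataset, get_dataset

texts, image_files = get_dataset(Path("data/train_label"), Path("data/cropped_train_img"))
dataset = CLIPDataset(texts, image_files)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# captions is a tuple of strings; images is a (32, 3, 224, 224) uint8 tensor
# (ToTensorV2 keeps the dtype, so the training loop casts it to float).
captions, images = next(iter(loader))
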
@@ -0,0 +1,145 @@
from torch.utils.data import DataLoader
from clipmodel import CLIPModel
from transformers import AutoTokenizer
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.optim import AdamW
import torch
import torch.nn as nn
import wandb
from tqdm import tqdm
import itertools
from dataloader import CLIPDataset, get_dataset
import argparse
from pathlib import Path


def calculate_loss(text_embeds, image_embeds, temperature=0.07):
    # Similarity matrix between every text and every image in the batch,
    # scaled by the temperature; matching pairs sit on the diagonal.
    logits = text_embeds @ image_embeds.T * temperature
    targets = torch.arange(len(text_embeds), device=text_embeds.device)

    # Symmetric cross-entropy: text-to-image and image-to-text.
    texts_loss = nn.CrossEntropyLoss()(logits, targets)
    images_loss = nn.CrossEntropyLoss()(logits.T, targets)

    loss = (images_loss + texts_loss) / 2.0
    return loss


def evaluate(model, val_dl, tokenizer):
    val_loss = 0
    device = next(model.parameters()).device
    with torch.no_grad():
        model.eval()
        for step, batch in enumerate(tqdm(val_dl)):
            text, image = batch
            text = tokenizer(
                list(text),
                padding=True,
                max_length=128,
                truncation=True,
                return_tensors="pt",
            )
            text = text.to(device)
            image = image.float()
            image = image.to(device)
            text_embeds, image_embeds = model(text, image)
            loss = calculate_loss(text_embeds, image_embeds)
            val_loss += loss.item()
    print(f"Val Loss: {val_loss / len(val_dl)}")
    return val_loss / len(val_dl)


def get_optimizer(model):
    # Encoders and projection heads get separate learning rates.
    params = [
        {"params": model.image_encoder.parameters(), "lr": 4e-5},
        {"params": model.text_encoder.parameters(), "lr": 4e-5},
        {
            "params": itertools.chain(
                model.image_projection.parameters(), model.text_projection.parameters()
            ),
            "lr": 5e-5,
        },
    ]
    optimizer = AdamW(params, weight_decay=0.2, betas=(0.9, 0.98), eps=1e-6)
    return optimizer


def train(model, device, train_paths, val_paths=None, num_epochs=100):
    model = model.to(device)
    text_path, image_path = train_paths
    train_texts, train_images = get_dataset(text_path, image_path)

    train_dataset = CLIPDataset(train_texts, train_images)
    train_dl = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
    if val_paths:
        val_text_path, val_image_path = val_paths
        val_texts, val_images = get_dataset(val_text_path, val_image_path)
        val_dataset = CLIPDataset(val_texts, val_images)
        val_dl = DataLoader(val_dataset, batch_size=8, shuffle=False)

    optimizer = get_optimizer(model)
    scheduler = ReduceLROnPlateau(optimizer, "min")
    tokenizer = AutoTokenizer.from_pretrained("klue/roberta-base")
    total_vloss = int(1e9)
    scaler = torch.cuda.amp.GradScaler()

    for i in range(num_epochs):
        print(f"Epochs: {i+1}")
        epoch_loss = 0
        model.train()
        wandb.log({"Epochs": i + 1})
        for steps, batch in enumerate(tqdm(train_dl)):
            text, image = batch
            text = tokenizer(
                list(text),
                padding=True,
                max_length=128,
                truncation=True,
                return_tensors="pt",
            )
            text = text.to(device)
            image = image.float()
            image = image.to(device)
            optimizer.zero_grad()

            text_embeddings, image_embeddings = model(text, image)
            loss = calculate_loss(text_embeddings, image_embeddings)

            if (steps + 1) % 100 == 0 and steps > 100:
                print(f"Epoch {i}, step {steps+1}, Loss: {loss.item()}")
            epoch_loss += loss.item()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        print(f"Epoch Loss: {epoch_loss / len(train_dl)}")
        if val_paths:
            vloss = evaluate(model, val_dl, tokenizer)
            if vloss < total_vloss:
                total_vloss = vloss
                torch.save(model, "clip.pt")
                print(f"Model saved. Current best val loss {total_vloss}")
        scheduler.step(epoch_loss)


if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--image_folder",
        type=str,
        default="/opt/ml/DALLE-Couture/data/cropped_train_img",
        help="directory containing the training images",
    )
    parser.add_argument(
        "--text_folder",
        type=str,
        default="/opt/ml/DALLE-Couture/data/train_label",
        help="directory containing the caption .txt files",
    )
    args = parser.parse_args()
    wandb.init(project="koclip")  # placeholder project name; wandb.log() in train() requires an active run
    model = CLIPModel()
    # get_dataset() calls Path.glob, so convert the CLI strings to Path objects.
    train_paths = [Path(args.text_folder), Path(args.image_folder)]
    train(
        model,
        device,
        train_paths=train_paths,
    )
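
Since train() saves the whole module with torch.save(model, "clip.pt"), a trained checkpoint can be reloaded for a retrieval-style sanity check. This is only a sketch: the image path and candidate captions are placeholders, and the preprocessing simply mirrors what CLIPDataset does.

import cv2
import torch
import torch.nn.functional as F
import albumentations as A
from albumentations.pytorch import ToTensorV2
from transformers import AutoTokenizer

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("clip.pt", map_location=device).eval()  # full module saved by train()
tokenizer = AutoTokenizer.from_pretrained("klue/roberta-base")

# Same preprocessing as CLIPDataset: resize to 224x224, HWC numpy -> CHW tensor.
transform = A.Compose([A.Resize(224, 224), ToTensorV2()])
im = cv2.cvtColor(cv2.imread("sample.jpg"), cv2.COLOR_BGR2RGB)  # placeholder image path
image = transform(image=im)["image"].float().unsqueeze(0).to(device)

captions = ["파란색 셔츠", "검은색 원피스", "흰색 운동화"]  # placeholder candidate captions
text = tokenizer(captions, padding=True, truncation=True, max_length=128, return_tensors="pt").to(device)

with torch.no_grad():
    text_embeds, image_embeds = model(text, image)
    scores = F.cosine_similarity(text_embeds, image_embeds)  # one similarity score per caption

print(scores)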