tiny imagenet dataset and model
dilyabareeva committed Aug 22, 2024
1 parent c90ea60 commit db2e6e3
Showing 3 changed files with 189 additions and 105 deletions.
@@ -7,96 +7,63 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from torch.optim.lr_scheduler import CosineAnnealingLR
+from torchvision.models import resnet18
 from tqdm.auto import tqdm
 from pathlib import Path
 import torch
 from torch.nn import CrossEntropyLoss
 from torch.optim import AdamW
 from torch.optim import lr_scheduler
 from torchmetrics.functional import accuracy
-from vit_pytorch.vit_for_small_dataset import ViT
 import glob
 from torchvision.io import read_image, ImageReadMode
+from tutorials.tiny_imagenet_dataset import TrainTinyImageNetDataset, HoldOutTinyImageNetDataset


 os.environ['NCCL_P2P_DISABLE'] = "1"
 os.environ['NCCL_IB_DISABLE'] = "1"
 os.environ['WANDB_DISABLED'] = "true"

-N_EPOCHS = 200
+torch.set_float32_matmul_precision('medium')
+
+N_EPOCHS = 200
+n_classes = 200
+batch_size = 64
+num_workers = 8
+local_path = "/home/bareeva/Projects/data_attribution_evaluation/assets/tiny-imagenet-200"
+rng = torch.Generator().manual_seed(42)

-def load_mini_image_net_data(path: str):
-    data_transforms = transforms.Compose(
-        [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
-         transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
-    )
-
-    train_dataset = torchvision.datasets.ImageFolder(
-        os.path.join(path, "train"), transform=data_transforms
-    )
+transforms = transforms.Compose(
+    [
+        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+    ]
+)

-    test_dataset = torchvision.datasets.ImageFolder(
-        os.path.join(path, "test"), transform=data_transforms
-    )
+train_set = TrainTinyImageNetDataset(local_path=local_path, transforms=transforms)
+train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

-    train_dataset = torch.utils.data.Subset(
-        train_dataset, list(range(len(train_dataset)))
-    )
-    test_dataset = torch.utils.data.Subset(test_dataset, list(range(len(test_dataset))))
-
-    return train_dataset, test_dataset
-
-
-path = "/data1/datapool/miniImagenet/source/mini_imagenet_full_size/"
-
-train_set, held_out = load_mini_image_net_data(path)
-RNG = torch.Generator().manual_seed(42)
-test_set, val_set = torch.utils.data.random_split(held_out, [0.5, 0.5], generator=RNG)
-model = ViT(
-    image_size = 224,
-    patch_size = 16,
-    num_classes = 64,
-    dim = 1024,
-    depth = 6,
-    heads = 16,
-    mlp_dim = 2048,
-    dropout = 0.1,
-    emb_dropout = 0.1
-)
+hold_out = HoldOutTinyImageNetDataset(local_path=local_path, transforms=transforms)
+test_set, val_set = torch.utils.data.random_split(hold_out, [0.5, 0.5], generator=rng)
+test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+val_dataloader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

+model = resnet18(pretrained=False, num_classes=n_classes)

 model.to("cuda:0")
 model.train()


-train_dataloader = torch.utils.data.DataLoader(
-    train_set, batch_size=64, shuffle=True, num_workers=8
-)
-
-
-val_dataloader = torch.utils.data.DataLoader(
-    val_set, batch_size=64, shuffle=False, num_workers=8
-)
-
-
-test_dataloader = torch.utils.data.DataLoader(
-    test_set, batch_size=64, shuffle=True, num_workers=8
-)
 # lightning module create


 class LitModel(pl.LightningModule):
-    def __init__(self, model, n_batches, lr=3e-4, epochs=24, momentum=0.9,
-                 weight_decay=5e-4, lr_peak_epoch=5, label_smoothing=0.0, num_labels=64):
+    def __init__(self, model, n_batches, lr=3e-4, epochs=24, weight_decay=0.01, num_labels=64):
         super(LitModel, self).__init__()
         self.model = model
         self.lr = lr
         self.epochs = epochs
-        self.momentum = momentum
         self.weight_decay = weight_decay
-        self.lr_peak_epoch = lr_peak_epoch
         self.n_batches = n_batches
-        self.label_smoothing = label_smoothing
-        self.criterion = CrossEntropyLoss(label_smoothing=label_smoothing)
+        self.criterion = CrossEntropyLoss()
         self.num_labels = num_labels

     def forward(self, x):
@@ -131,13 +98,13 @@ def _shared_eval_step(self, batch, batch_idx):
         return loss, acc

     def configure_optimizers(self):
-        optimizer = AdamW(self.model.parameters(), lr=self.lr)
+        optimizer = AdamW(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
         return [optimizer]


 checkpoint_callback = ModelCheckpoint(
     dirpath="/home/bareeva/Projects/data_attribution_evaluation/assets/",
-    filename="mini_imagenet_vit_epoch_{epoch:02d}",
+    filename="tiny_imagenet_resnet18_epoch_{epoch:02d}",
     every_n_epochs=10,
     save_top_k=-1,
 )
@@ -147,12 +114,13 @@ def configure_optimizers(self):
 lit_model = LitModel(
     model=model,
     n_batches=len(train_dataloader),
+    num_labels=n_classes,
     epochs=N_EPOCHS
 )

 # Use this lit_model in the Trainer
 trainer = Trainer(
-    callbacks=[checkpoint_callback, EarlyStopping(monitor="val_loss", mode="min", patience=3)],
+    callbacks=[checkpoint_callback, EarlyStopping(monitor="val_loss", mode="min", patience=10)],
     devices=1,
     accelerator="gpu",
     max_epochs=N_EPOCHS,
@@ -163,5 +131,5 @@ def configure_optimizers(self):
 # Train the model
 trainer.fit(lit_model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
 trainer.test(dataloaders=test_dataloader, ckpt_path="last")
-torch.save(lit_model.model.state_dict(), "/home/bareeva/Projects/data_attribution_evaluation/assets/mini_imagenet_vit.pth")
-trainer.save_checkpoint("/home/bareeva/Projects/data_attribution_evaluation/assets/mini_imagenet_vit.ckpt")
+torch.save(lit_model.model.state_dict(), "/home/bareeva/Projects/data_attribution_evaluation/assets/tiny_imagenet_resnet18.pth")
+trainer.save_checkpoint("/home/bareeva/Projects/data_attribution_evaluation/assets/tiny_imagenet_resnet18.ckpt")
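Before the full Trainer run, the new data pipeline can be smoke-tested in isolation. A minimal sketch, assuming the tiny-imagenet-200 archive is already extracted at the local_path used above (Tiny ImageNet images are 64x64 RGB across 200 classes):

import torch
import torchvision.transforms as T
from tutorials.tiny_imagenet_dataset import TrainTinyImageNetDataset

local_path = "/home/bareeva/Projects/data_attribution_evaluation/assets/tiny-imagenet-200"
norm = T.Compose([T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
train_set = TrainTinyImageNetDataset(local_path=local_path, transforms=norm)

image, label = train_set[0]          # __getitem__ normalizes the float()-cast tensor
assert image.shape == (3, 64, 64)    # Tiny ImageNet images are 64x64 RGB
assert 0 <= label < 200              # labels come from the 200 wnids in wnids.txt

loader = torch.utils.data.DataLoader(train_set, batch_size=4)
images, labels = next(iter(loader))  # images: (4, 3, 64, 64), labels: (4,)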
127 changes: 127 additions & 0 deletions tutorials/tiny_imagenet_dataset.py
@@ -0,0 +1,127 @@
import numpy as np
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl
import os
import os.path
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision.models import resnet18
from tqdm.auto import tqdm
from pathlib import Path
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.optim import lr_scheduler
from torchmetrics.functional import accuracy
import glob
from torchvision.io import read_image, ImageReadMode

import nltk
from nltk.corpus import wordnet as wn

# Download WordNet data if not already available
nltk.download('wordnet')


N_EPOCHS = 200
n_classes = 200
batch_size = 64
num_workers = 8
local_path = "/home/bareeva/Projects/data_attribution_evaluation/assets/tiny-imagenet-200"
rng = torch.Generator().manual_seed(42)


class TrainTinyImageNetDataset(torch.utils.data.Dataset):
    """Tiny ImageNet training split, read from the extracted archive on disk."""

    def __init__(self, local_path: str, transforms=None):
        self.filenames = glob.glob(local_path + "/train/*/*/*.JPEG")
        self.transforms = transforms
        # Map WordNet IDs (wnids) to integer class labels.
        with open(local_path + '/wnids.txt', 'r') as f:
            self.id_dict = {line.strip(): i for i, line in enumerate(f)}

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_path = self.filenames[idx]
        image = read_image(img_path)
        # Grayscale images are re-read with three channels.
        if image.shape[0] == 1:
            image = read_image(img_path, ImageReadMode.RGB)
        # The class wnid is the grandparent directory of the image file.
        label = self.id_dict[img_path.split('/')[-3]]
        if self.transforms:
            image = self.transforms(image.float())
        return image, label


class HoldOutTinyImageNetDataset(torch.utils.data.Dataset):
    """Tiny ImageNet validation split, used here as the held-out pool."""

    def __init__(self, local_path: str, transforms=None):
        self.filenames = glob.glob(local_path + "/val/images/*.JPEG")
        self.transforms = transforms
        # Map WordNet IDs (wnids) to integer class labels.
        with open(local_path + '/wnids.txt', 'r') as f:
            self.id_dict = {line.strip(): i for i, line in enumerate(f)}

        # val_annotations.txt maps each image filename to its wnid.
        with open(local_path + '/val/val_annotations.txt', 'r') as f:
            self.cls_dic = {
                line.split('\t')[0]: self.id_dict[line.split('\t')[1]]
                for line in f
            }

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_path = self.filenames[idx]
        image = read_image(img_path)
        # Grayscale images are re-read with three channels.
        if image.shape[0] == 1:
            image = read_image(img_path, ImageReadMode.RGB)
        label = self.cls_dic[img_path.split('/')[-1]]
        if self.transforms:
            image = self.transforms(image.float())
        return image, label


def is_target_in_parents_path(synset, target_synset):
    """
    Given a synset, return True if the target_synset is in its parent path, else False.
    """
    # The current synset itself counts as being on the path.
    if synset == target_synset:
        return True

    # Recursively check all parent synsets.
    for parent in synset.hypernyms():
        if is_target_in_parents_path(parent, target_synset):
            return True

    # target_synset was not found in any parent path.
    return False

def get_all_descendants(target):
    """Collect the wnids in wnids.txt whose synsets have the target noun on their hypernym path."""
    objects = set()
    target_synset = wn.synsets(target, pos=wn.NOUN)[0]  # first noun synset for the target word
    with open(local_path + '/wnids.txt', 'r') as f:
        for line in f:
            # wnids look like "n02094433"; the WordNet offset is the integer part.
            synset = wn.synset_from_pos_and_offset('n', int(line.strip()[1:]))
            if is_target_in_parents_path(synset, target_synset):
                objects.add(line.strip())
    return objects

# All dog and cat classes in Tiny ImageNet, found via their WordNet hypernym paths.
dogs = get_all_descendants('dog')
cats = get_all_descendants('cat')


with open(local_path + '/wnids.txt', 'r') as f:
    id_dict = {line.strip(): i for i, line in enumerate(f)}

# Non-dog, non-cat classes keep their original class indices; every dog class
# is then mapped to one shared group label and every cat class to another.
class_to_group = {id_dict[k]: i for i, k in enumerate(id_dict) if k not in dogs.union(cats)}
class_to_group.update({id_dict[k]: len(class_to_group) for k in dogs})
class_to_group.update({id_dict[k]: len(class_to_group) for k in cats})

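class_to_group is the artifact downstream code would consume: ungrouped classes keep their original indices, all dog classes share a single group label, and all cat classes share another. A minimal sketch of applying the mapping at load time; GroupLabelDataset is a hypothetical wrapper for illustration, not part of this commit:

import torch


class GroupLabelDataset(torch.utils.data.Dataset):
    """Hypothetical wrapper: swaps each sample's class label for its group label."""

    def __init__(self, base, class_to_group):
        self.base = base
        self.class_to_group = class_to_group

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, label = self.base[idx]
        return image, self.class_to_group[label]


grouped_train = GroupLabelDataset(
    TrainTinyImageNetDataset(local_path=local_path, transforms=None),
    class_to_group,
)

One caveat worth noting: because ungrouped classes keep their original (now sparse) indices, the freshly assigned dog and cat group ids can in principle collide with an ungrouped class's index; whether that matters depends on how the mapping is consumed downstream.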
(diff for the third changed file did not load)
