diff --git a/tutorials/vit_mini_image_net.py b/tutorials/tiny_image_net_resnet18_train.py
similarity index 57%
rename from tutorials/vit_mini_image_net.py
rename to tutorials/tiny_image_net_resnet18_train.py
index c395ef6a..82250967 100644
--- a/tutorials/vit_mini_image_net.py
+++ b/tutorials/tiny_image_net_resnet18_train.py
@@ -7,6 +7,7 @@ from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from torch.optim.lr_scheduler import CosineAnnealingLR
+from torchvision.models import resnet18
 from tqdm.auto import tqdm
 from pathlib import Path
 import torch
@@ -14,89 +15,55 @@ from torch.optim import AdamW
 from torch.optim import lr_scheduler
 from torchmetrics.functional import accuracy
-from vit_pytorch.vit_for_small_dataset import ViT
+import glob
+from torchvision.io import read_image, ImageReadMode
+from tutorials.tiny_imagenet_dataset import TrainTinyImageNetDataset, HoldOutTinyImageNetDataset
 
 os.environ['NCCL_P2P_DISABLE'] = "1"
 os.environ['NCCL_IB_DISABLE'] = "1"
 os.environ['WANDB_DISABLED'] = "true"
 
-N_EPOCHS = 200
 torch.set_float32_matmul_precision('medium')
 
+N_EPOCHS = 200
+n_classes = 200
+batch_size = 64
+num_workers = 8
+local_path = "/home/bareeva/Projects/data_attribution_evaluation/assets/tiny-imagenet-200"
+rng = torch.Generator().manual_seed(42)
 
-def load_mini_image_net_data(path: str):
-    data_transforms = transforms.Compose(
-        [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
-         transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
-    )
-    train_dataset = torchvision.datasets.ImageFolder(
-        os.path.join(path, "train"), transform=data_transforms
+transforms = transforms.Compose(
+    [
+        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+    ]
 )
-    test_dataset = torchvision.datasets.ImageFolder(
-        os.path.join(path, "test"), transform=data_transforms
-    )
+train_set = TrainTinyImageNetDataset(local_path=local_path, transforms=transforms)
+train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
 
-    train_dataset = torch.utils.data.Subset(
-        train_dataset, list(range(len(train_dataset)))
-    )
-    test_dataset = torch.utils.data.Subset(test_dataset, list(range(len(test_dataset))))
-
-    return train_dataset, test_dataset
-
-
-path = "/data1/datapool/miniImagenet/source/mini_imagenet_full_size/"
-
-train_set, held_out = load_mini_image_net_data(path)
-RNG = torch.Generator().manual_seed(42)
-test_set, val_set = torch.utils.data.random_split(held_out, [0.5, 0.5], generator=RNG)
-model = ViT(
-    image_size = 224,
-    patch_size = 16,
-    num_classes = 64,
-    dim = 1024,
-    depth = 6,
-    heads = 16,
-    mlp_dim = 2048,
-    dropout = 0.1,
-    emb_dropout = 0.1
-)
+hold_out = HoldOutTinyImageNetDataset(local_path=local_path, transforms=transforms)
+test_set, val_set = torch.utils.data.random_split(hold_out, [0.5, 0.5], generator=rng)
+test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+val_dataloader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+model = resnet18(pretrained=False, num_classes=n_classes)
 model.to("cuda:0")
 model.train()
 
-train_dataloader = torch.utils.data.DataLoader(
-    train_set, batch_size=64, shuffle=True, num_workers=8
-    )
-
-
-val_dataloader = torch.utils.data.DataLoader(
-    val_set, batch_size=64, shuffle=False, num_workers=8
-    )
-
-
-test_dataloader = torch.utils.data.DataLoader(
-    test_set, batch_size=64, shuffle=True, num_workers=8
-    )
-# lightning module create
-
 class LitModel(pl.LightningModule):
 
-    def __init__(self, model, n_batches, lr=3e-4, epochs=24, momentum=0.9,
-                 weight_decay=5e-4, lr_peak_epoch=5, label_smoothing=0.0, num_labels=64):
+    def __init__(self, model, n_batches, lr=3e-4, epochs=24, weight_decay=0.01, num_labels=64):
         super(LitModel, self).__init__()
         self.model = model
         self.lr = lr
         self.epochs = epochs
-        self.momentum = momentum
         self.weight_decay = weight_decay
-        self.lr_peak_epoch = lr_peak_epoch
         self.n_batches = n_batches
-        self.label_smoothing = label_smoothing
-        self.criterion = CrossEntropyLoss(label_smoothing=label_smoothing)
+        self.criterion = CrossEntropyLoss()
         self.num_labels = num_labels
 
     def forward(self, x):
@@ -131,13 +98,13 @@ def _shared_eval_step(self, batch, batch_idx):
         return loss, acc
 
     def configure_optimizers(self):
-        optimizer = AdamW(self.model.parameters(), lr=self.lr)
+        optimizer = AdamW(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
         return [optimizer]
 
 
 checkpoint_callback = ModelCheckpoint(
     dirpath="/home/bareeva/Projects/data_attribution_evaluation/assets/",
-    filename="mini_imagenet_vit_epoch_{epoch:02d}",
+    filename="tiny_imagenet_resnet18_epoch_{epoch:02d}",
     every_n_epochs=10,
     save_top_k=-1,
 )
@@ -147,12 +114,13 @@ def configure_optimizers(self):
 lit_model = LitModel(
     model=model,
     n_batches=len(train_dataloader),
+    num_labels=n_classes,
     epochs=N_EPOCHS
 )
 
 # Use this lit_model in the Trainer
 trainer = Trainer(
-    callbacks=[checkpoint_callback, EarlyStopping(monitor="val_loss", mode="min", patience=3)],
+    callbacks=[checkpoint_callback, EarlyStopping(monitor="val_loss", mode="min", patience=10)],
     devices=1,
     accelerator="gpu",
     max_epochs=N_EPOCHS,
@@ -163,5 +131,5 @@ def configure_optimizers(self):
 # Train the model
 trainer.fit(lit_model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
 trainer.test(dataloaders=test_dataloader, ckpt_path="last")
-torch.save(lit_model.model.state_dict(), "/home/bareeva/Projects/data_attribution_evaluation/assets/mini_imagenet_vit.pth")
-trainer.save_checkpoint("/home/bareeva/Projects/data_attribution_evaluation/assets/mini_imagenet_vit.ckpt")
+torch.save(lit_model.model.state_dict(), "/home/bareeva/Projects/data_attribution_evaluation/assets/tiny_imagenet_resnet18.pth")
+trainer.save_checkpoint("/home/bareeva/Projects/data_attribution_evaluation/assets/tiny_imagenet_resnet18.ckpt")
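
A note on the new input pipeline above: `torchvision.io.read_image` returns `uint8` tensors in [0, 255], and the datasets (defined in `tiny_imagenet_dataset.py` below) only call `.float()` before applying the transform, so `Normalize` here operates on values in [0, 255] even though the ImageNet means and stds it uses describe inputs in [0, 1]. Also, `transforms = transforms.Compose(...)` rebinds the module name `transforms`; it happens to work because the name is only read once afterwards, but a distinct name (the third file below uses `transform`) is safer. If [0, 1] inputs are what is intended, a rescaling step along these lines would do it; this is a sketch only, not part of the patch, and adopting it would require retraining, since the saved weights were fit to [0, 255]-scaled inputs:

```python
# Sketch only: rescale to [0, 1] before ImageNet normalization.
# The patch as written normalizes float values that are still in [0, 255].
import torchvision.transforms as T

transform = T.Compose(
    [
        T.Lambda(lambda x: x / 255.0),  # float tensor in [0, 255] -> [0, 1]
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)
```

Relatedly, recent torchvision releases deprecate the boolean `pretrained` argument; `resnet18(weights=None, num_classes=n_classes)` expresses the same thing without the deprecation warning.
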
"/home/bareeva/Projects/data_attribution_evaluation/assets/tiny-imagenet-200" +rng = torch.Generator().manual_seed(42) + + +class TrainTinyImageNetDataset(torch.utils.data.Dataset): + def __init__(self, local_path:str, transforms=None): + self.filenames = glob.glob(local_path + "/train/*/*/*.JPEG") + self.transforms = transforms + with open(local_path + '/wnids.txt', 'r') as f: + self.id_dict = {line.strip(): i for i, line in enumerate(f)} + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, idx): + img_path = self.filenames[idx] + image = read_image(img_path) + if image.shape[0] == 1: + image = read_image(img_path,ImageReadMode.RGB) + label = self.id_dict[img_path.split('/')[-3]] + if self.transforms: + image = self.transforms(image.float()) + return image, label + + +class HoldOutTinyImageNetDataset(torch.utils.data.Dataset): + def __init__(self, local_path:str, transforms=None): + self.filenames = glob.glob(local_path + "/val/images/*.JPEG") + self.transform = transforms + with open(local_path + '/wnids.txt', 'r') as f: + self.id_dict = {line.strip(): i for i, line in enumerate(f)} + + with open(local_path + '/val/val_annotations.txt', 'r') as f: + self.cls_dic = { + line.split('\t')[0]: self.id_dict[line.split('\t')[1]] + for line in f + } + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, idx): + img_path = self.filenames[idx] + image = read_image(img_path) + if image.shape[0] == 1: + image = read_image(img_path,ImageReadMode.RGB) + label = self.cls_dic[img_path.split('/')[-1]] + if self.transform: + image = self.transform(image.float()) + return image, label + + +local_path = "/home/bareeva/Projects/data_attribution_evaluation/assets/tiny-imagenet-200" + + +def is_target_in_parents_path(synset, target_synset): + """ + Given a synset, return True if the target_synset is in its parent path, else False. 
+ """ + # Check if the current synset is the target synset + if synset == target_synset: + return True + + # Recursively check all parent synsets + for parent in synset.hypernyms(): + if is_target_in_parents_path(parent, target_synset): + return True + + # If target_synset is not found in any parent path + return False + +def get_all_descendants(target): + objects = set() + target_synset = wn.synsets(target, pos=wn.NOUN)[0] # Get the target synset + with open(local_path + '/wnids.txt', 'r') as f: + for line in f: + synset = wn.synset_from_pos_and_offset('n', int(line.strip()[1:])) + if is_target_in_parents_path(synset, target_synset): + objects.add(line.strip()) + return objects + +# dogs +dogs = get_all_descendants('dog') +cats = get_all_descendants('cat') + + +id_dict = {} +with open(local_path + '/wnids.txt', 'r') as f: + id_dict = {line.strip(): i for i, line in enumerate(f)} + + +class_to_group = {id_dict[k]: i for i, k in enumerate(id_dict) if k not in dogs.union(cats)} +class_to_group.update({id_dict[k]: len(class_to_group) for k in dogs}) +class_to_group.update({id_dict[k]: len(class_to_group) for k in cats}) + diff --git a/tutorials/usage_testing_vit.py b/tutorials/usage_testing_vit.py index 76deffed..da0a0bbd 100644 --- a/tutorials/usage_testing_vit.py +++ b/tutorials/usage_testing_vit.py @@ -18,7 +18,7 @@ from tqdm import tqdm from transformers import ViTForImageClassification, ViTConfig from vit_pytorch.vit_for_small_dataset import ViT -from tutorials.vit_mini_image_net import load_mini_image_net_data + from quanda.explainers.wrappers import ( CaptumSimilarity, @@ -30,7 +30,7 @@ from quanda.metrics.unnamed import DatasetCleaningMetric, TopKOverlapMetric from quanda.toy_benchmarks.localization import SubclassDetection from quanda.utils.training import BasicLightningModule - +from tutorials.tiny_imagenet_dataset import TrainTinyImageNetDataset, HoldOutTinyImageNetDataset DEVICE = "cuda:0" # "cuda" if torch.cuda.is_available() else "cpu" torch.set_float32_matmul_precision("medium") @@ -63,53 +63,42 @@ def main(): # ++++++++++++++++++++++++++++++++++++++++++ # #Download dataset and pre-trained model # ++++++++++++++++++++++++++++++++++++++++++ + torch.set_float32_matmul_precision('medium') + + N_EPOCHS = 200 + n_classes = 200 + batch_size = 64 + num_workers = 8 + data_path = "/home/bareeva/Projects/data_attribution_evaluation/assets/tiny-imagenet-200" + rng = torch.Generator().manual_seed(42) + + transform = transforms.Compose( + [ + transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) + ] + ) - - - path = "/data1/datapool/miniImagenet/source/mini_imagenet_full_size/" - - train_set, held_out = load_mini_image_net_data(path) - # use train subset - train_set, _ = torch.utils.data.random_split(held_out, [0.05, 0.95], generator=RNG) + train_set = TrainTinyImageNetDataset(local_path=data_path, transforms=transform) train_set = OnDeviceDataset(train_set, DEVICE) - train_dataloader = DataLoader(train_set, batch_size=100, shuffle=True, num_workers=8) - - # we split held out data into test and validation set - test_set, val_set = torch.utils.data.random_split(held_out, [0.1, 0.9], generator=RNG) + train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, + num_workers=num_workers) + hold_out = HoldOutTinyImageNetDataset(local_path=data_path, transforms=transform) + test_set, val_set = torch.utils.data.random_split(hold_out, [0.5, 0.5], generator=rng) test_set, val_set = OnDeviceDataset(test_set, DEVICE), OnDeviceDataset(val_set, DEVICE) - 
diff --git a/tutorials/usage_testing_vit.py b/tutorials/usage_testing_vit.py
index 76deffed..da0a0bbd 100644
--- a/tutorials/usage_testing_vit.py
+++ b/tutorials/usage_testing_vit.py
@@ -18,7 +18,7 @@ from tqdm import tqdm
 from transformers import ViTForImageClassification, ViTConfig
 from vit_pytorch.vit_for_small_dataset import ViT
 
-from tutorials.vit_mini_image_net import load_mini_image_net_data
+
 from quanda.explainers.wrappers import (
     CaptumSimilarity,
@@ -30,7 +30,7 @@ from quanda.metrics.unnamed import DatasetCleaningMetric, TopKOverlapMetric
 from quanda.toy_benchmarks.localization import SubclassDetection
 from quanda.utils.training import BasicLightningModule
 
-
+from tutorials.tiny_imagenet_dataset import TrainTinyImageNetDataset, HoldOutTinyImageNetDataset
 
 DEVICE = "cuda:0"  # "cuda" if torch.cuda.is_available() else "cpu"
 torch.set_float32_matmul_precision("medium")
@@ -63,53 +63,42 @@ def main():
     # ++++++++++++++++++++++++++++++++++++++++++
     # #Download dataset and pre-trained model
     # ++++++++++++++++++++++++++++++++++++++++++
+    torch.set_float32_matmul_precision('medium')
+
+    N_EPOCHS = 200
+    n_classes = 200
+    batch_size = 64
+    num_workers = 8
+    data_path = "/home/bareeva/Projects/data_attribution_evaluation/assets/tiny-imagenet-200"
+    rng = torch.Generator().manual_seed(42)
+
+    transform = transforms.Compose(
+        [
+            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+        ]
+    )
 
-
-
-    path = "/data1/datapool/miniImagenet/source/mini_imagenet_full_size/"
-
-    train_set, held_out = load_mini_image_net_data(path)
-    # use train subset
-    train_set, _ = torch.utils.data.random_split(held_out, [0.05, 0.95], generator=RNG)
+    train_set = TrainTinyImageNetDataset(local_path=data_path, transforms=transform)
     train_set = OnDeviceDataset(train_set, DEVICE)
-    train_dataloader = DataLoader(train_set, batch_size=100, shuffle=True, num_workers=8)
-
-    # we split held out data into test and validation set
-    test_set, val_set = torch.utils.data.random_split(held_out, [0.1, 0.9], generator=RNG)
+    train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
+                                                   num_workers=num_workers)
+    hold_out = HoldOutTinyImageNetDataset(local_path=data_path, transforms=transform)
+    test_set, val_set = torch.utils.data.random_split(hold_out, [0.5, 0.5], generator=rng)
     test_set, val_set = OnDeviceDataset(test_set, DEVICE), OnDeviceDataset(val_set, DEVICE)
-    test_loader = DataLoader(test_set, batch_size=100, shuffle=False, num_workers=8)
-    val_dataloader = DataLoader(val_set, batch_size=100, shuffle=False, num_workers=8)
+    test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False,
+                                                  num_workers=num_workers)
+    val_dataloader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+    model = resnet18(pretrained=False, num_classes=n_classes)
 
     # download pre-trained weights
-    local_path = "/home/bareeva/Projects/data_attribution_evaluation/assets/mini_imagenet_vit.pth"
+    local_path = "/home/bareeva/Projects/data_attribution_evaluation/assets/tiny_imagenet_resnet18.pth"
     weights_pretrained = torch.load(local_path, map_location=DEVICE)
-
-    # load model with pre-trained weights
-    model = ViT(
-        image_size = 224,
-        patch_size = 16,
-        num_classes = 64,
-        dim = 1024,
-        depth = 6,
-        heads = 16,
-        mlp_dim = 2048,
-        dropout = 0.1,
-        emb_dropout = 0.1
-    )
     model.load_state_dict(weights_pretrained)
-    init_model = ViT(
-        image_size = 224,
-        patch_size = 16,
-        num_classes = 64,
-        dim = 1024,
-        depth = 6,
-        heads = 16,
-        mlp_dim = 2048,
-        dropout = 0.1,
-        emb_dropout = 0.1
-    )
+
+    init_model = resnet18(pretrained=False, num_classes=n_classes)
 
     model.to(DEVICE)
     init_model.to(DEVICE)
@@ -144,7 +133,7 @@ def accuracy(net, loader):
         return correct / total
 
     #print(f"Train set accuracy: {100.0 * accuracy(model, train_dataloader):0.1f}%")
-    #print(f"Test set accuracy: {100.0 * accuracy(model, test_loader):0.1f}%")
+    #print(f"Test set accuracy: {100.0 * accuracy(model, test_dataloader):0.1f}%")
 
     # ++++++++++++++++++++++++++++++++++++++++++
     # Computing metrics while generating explanations
@@ -216,7 +205,7 @@ def accuracy(net, loader):
     )
 
     # iterate over test set and feed tensor batches first to explain, then to metric
-    for i, (data, target) in enumerate(tqdm(test_loader)):
+    for i, (data, target) in enumerate(tqdm(test_dataloader)):
         data, target = data.to(DEVICE), target.to(DEVICE)
         tda = explain(
             model=model,
@@ -239,7 +228,7 @@ def accuracy(net, loader):
     print("Dataset cleaning metric computation started...")
     print("Dataset cleaning metric output:", data_clean.compute())
 
-    print(f"Test set accuracy: {100.0 * accuracy(model, test_loader):0.1f}%")
+    print(f"Test set accuracy: {100.0 * accuracy(model, test_dataloader):0.1f}%")
 
     # ++++++++++++++++++++++++++++++++++++++++++
     # Subclass Detection Benchmark Generation and Evaluation
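
End to end, the patch replaces the ViT/mini-ImageNet tutorial pipeline with a ResNet-18 trained on Tiny ImageNet (200 classes), saving the trained weights both as a raw state dict and as a Lightning checkpoint. A minimal sketch of reloading the state dict for downstream use follows; the absolute path mirrors the one hard-coded in the patch and would need adjusting to your own checkout:

```python
# Sketch: reload the trained ResNet-18 saved by the training script above.
import torch
from torchvision.models import resnet18

model = resnet18(num_classes=200)  # randomly initialized; weights come from the file
state_dict = torch.load(
    "/home/bareeva/Projects/data_attribution_evaluation/assets/tiny_imagenet_resnet18.pth",
    map_location="cpu",
)
model.load_state_dict(state_dict)
model.eval()  # inference mode for evaluation
```
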