diff --git a/tutorials/quanda_quickstart.ipynb b/tutorials/quanda_quickstart.ipynb new file mode 100644 index 00000000..59b4e2e2 --- /dev/null +++ b/tutorials/quanda_quickstart.ipynb @@ -0,0 +1,851 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Quanda Quickstart Tutorial" + ], + "metadata": { + "collapsed": false + }, + "id": "f956a3601d84f47c" + }, + { + "cell_type": "markdown", + "source": [ + "In this notebook, we show you how to use quanda to generate, apply and evaluate data attributions.\n", + "\n", + "Throughout this tutorial we will be using a toy ResNet18 model trained on TinyImageNet. We will add a few \"special features\" to the dataset:\n", + "- We group all the cat classes into a single \"cat\" class, and all the dog classes into a single \"dog\" class.\n", + "- We replace the original label of 20% of the \"lesser panda\" class images with a different, randomly chosen class label.\n", + "- We add 200 goldfish sketches from the ImageNet-Sketch dataset to the training set under the label \"basketball\", thereby simulating a backdoor attack.\n", + "\n", + "These \"special features\" allow us to create a controlled setting where we can evaluate the performance of data attribution methods in a few application scenarios." + ], + "metadata": { + "collapsed": false + }, + "id": "b35409bfe363b0eb" + }, + { + "cell_type": "markdown", + "source": [ + "## Dataset Construction" + ], + "metadata": { + "collapsed": false + }, + "id": "771f60428f98417a" + }, + { + "cell_type": "markdown", + "source": [ + "We first download the dataset:" + ], + "metadata": { + "collapsed": false + }, + "id": "6deb1e7be5ba9b6c" + }, + { + "cell_type": "markdown", + "source": [ + "!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip\n", + "!unzip tiny-imagenet-200.zip" + ], + "metadata": { + "collapsed": false + }, + "id": "5c534616091ed9db" + }, + { + "cell_type": "code", + "execution_count": 1, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "import torch\n", + "import torchvision.transforms as transforms\n", + "from nltk.corpus import wordnet as wn\n", + "from PIL import Image\n", + "from pytorch_lightning import Trainer\n", + "from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\n", + "from torch.nn import CrossEntropyLoss\n", + "from torch.optim import AdamW\n", + "from torchmetrics.functional import accuracy\n", + "from torchvision.models import resnet18" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:01:55.692812Z", + "start_time": "2024-08-28T20:01:53.254327Z" + } + }, + "id": "db5a5eb8340f9b55" + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [], + "source": [ + "from quanda.utils.datasets.transformed import (\n", + " LabelFlippingDataset,\n", + " LabelGroupingDataset,\n", + " SampleTransformationDataset,\n", + ")\n", + "from tutorials.utils.datasets import AnnotatedDataset, CustomDataset" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:01:55.854981Z", + "start_time": "2024-08-28T20:01:55.692358Z" + } + }, + "id": "243452dd2e5ff615" + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [], + "source": [ + "torch.set_float32_matmul_precision(\"medium\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:01:55.857826Z", + "start_time": "2024-08-28T20:01:55.855521Z" + } + }, + "id": "b84c9765e82b93e6" + }, + { + "cell_type": "code", + "execution_count": 4, + "outputs": [], + "source": [
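+ "# NOTE: the paths below are machine-specific. If you downloaded TinyImageNet with the wget/unzip commands above,\n", + "# local_path is typically \"./tiny-imagenet-200\"; goldfish_sketch_path should point to a local copy of the ImageNet-Sketch data,\n", + "# and save_dir to a writable directory for the trained model checkpoints.\n", + "# The WordNet-based class grouping below also assumes the nltk wordnet corpus has been downloaded (nltk.download(\"wordnet\")).\n",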
+ "local_path = \"/home/bareeva/Projects/data_attribution_evaluation/assets/tiny-imagenet-200\"\n", + "goldfish_sketch_path = \"/data1/datapool/sketch\"\n", + "save_dir = \"/home/bareeva/Projects/data_attribution_evaluation/assets\"" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:01:55.859097Z", + "start_time": "2024-08-28T20:01:55.855662Z" + } + }, + "id": "cd5f55253e39e35f" + }, + { + "cell_type": "code", + "execution_count": 5, + "outputs": [], + "source": [ + "n_classes = 200\n", + "batch_size = 64\n", + "num_workers = 8\n", + "\n", + "rng = torch.Generator().manual_seed(42)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:01:55.860438Z", + "start_time": "2024-08-28T20:01:55.858146Z" + } + }, + "id": "301dd664b3df32cf" + }, + { + "cell_type": "code", + "execution_count": 6, + "outputs": [], + "source": [ + "# Load the TinyImageNet dataset\n", + "regular_transforms = transforms.Compose(\n", + " [transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]\n", + ")\n", + "\n", + "id_dict = {}\n", + "with open(local_path + \"/wnids.txt\", \"r\") as f:\n", + " id_dict = {line.strip(): i for i, line in enumerate(f)}\n", + " \n", + "val_annotations = {}\n", + "with open(local_path + \"/val/val_annotations.txt\", \"r\") as f:\n", + " val_annotations = {line.split(\"\\t\")[0]: line.split(\"\\t\")[1] for line in f}\n", + " \n", + "train_set = CustomDataset(local_path + \"/train\", classes=list(id_dict.keys()), classes_to_idx=id_dict, transform=None)\n", + "\n", + "holdout_set = AnnotatedDataset(\n", + " local_path=local_path + \"/val\", transforms=regular_transforms, id_dict=id_dict, annotation=val_annotations\n", + ")\n", + "test_set, val_set = torch.utils.data.random_split(holdout_set, [0.5, 0.5], generator=rng)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:01:56.081620Z", + "start_time": "2024-08-28T20:01:55.862374Z" + } + }, + "id": "47d65ad78a44626f" + }, + { + "cell_type": "markdown", + "source": [ + "### Grouping Classes: Cat and Dog" + ], + "metadata": { + "collapsed": false + }, + "id": "da068bc2105ee8d2" + }, + { + "cell_type": "code", + "execution_count": 7, + "outputs": [], + "source": [ + "# find all the classes that are in hyponym paths of \"cat\" and \"dog\"\n", + "\n", + "def get_all_descendants(in_folder_list, target):\n", + " objects = set()\n", + " target_synset = wn.synsets(target, pos=wn.NOUN)[0] # Get the target synset\n", + " for folder in in_folder_list:\n", + " synset = wn.synset_from_pos_and_offset(\"n\", int(folder[1:]))\n", + " if target_synset.name() in str(synset.hypernym_paths()):\n", + " objects.add(folder)\n", + " return objects\n", + "\n", + "tiny_folders = list(id_dict.keys())\n", + "dogs = get_all_descendants(tiny_folders, \"dog\")\n", + "cats = get_all_descendants(tiny_folders, \"cat\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:01:57.787779Z", + "start_time": "2024-08-28T20:01:56.081444Z" + } + }, + "id": "484758ec42cc93fd" + }, + { + "cell_type": "code", + "execution_count": 8, + "outputs": [], + "source": [ + "# create class-to-group mapping for the dataset\n", + "no_cat_dogs_ids = [id_dict[k] for k in id_dict if k not in dogs.union(cats)]\n", + "\n", + "class_to_group = {k: i for i, k in enumerate(no_cat_dogs_ids)}\n", + "class_to_group.update({id_dict[k]: len(class_to_group) for k in dogs})\n", + "class_to_group.update({id_dict[k]: 
len(class_to_group) for k in cats})\n", + "\n", + "new_n_classes = len(class_to_group) + 2" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:01:57.788460Z", + "start_time": "2024-08-28T20:01:57.787309Z" + } + }, + "id": "e2b5b51637442aa3" + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [], + "source": [ + "# create name to class label mapping\n", + "def folder_to_name(folder):\n", + " return wn.synset_from_pos_and_offset(\"n\", int(folder[1:])).lemmas()[0].name()\n", + "\n", + "name_dict = {\n", + " folder_to_name(k): class_to_group[id_dict[k]] for k in id_dict if k not in dogs.union(cats)\n", + "}" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:01:57.791535Z", + "start_time": "2024-08-28T20:01:57.788101Z" + } + }, + "id": "9fc431d3bdbcfcb4" + }, + { + "cell_type": "code", + "execution_count": 10, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Class label of basketball: 5\n", + "Class label of lesser panda: 41\n" + ] + } + ], + "source": [ + "print(\"Class label of basketball: \", name_dict[\"basketball\"])\n", + "print(\"Class label of lesser panda: \", name_dict[\"lesser_panda\"])" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:01:57.792293Z", + "start_time": "2024-08-28T20:01:57.788260Z" + } + }, + "id": "84dbad581b117303" + }, + { + "cell_type": "markdown", + "source": [ + "### Loading Backdoor Samples of Sketch Goldfish" + ], + "metadata": { + "collapsed": false + }, + "id": "945e122201050e1f" + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [], + "source": [ + "backdoor_transforms = transforms.Compose(\n", + " [transforms.Resize((64, 64)), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]\n", + ")\n", + "\n", + "goldfish_dataset = CustomDataset(\n", + " goldfish_sketch_path, classes=[\"n02510455\"], classes_to_idx={\"n02510455\": 5}, transform=backdoor_transforms\n", + ")\n", + "goldfish_set, goldfish_val, _ = torch.utils.data.random_split(\n", + " goldfish_dataset, [200, 20, len(goldfish_dataset) - 220], generator=rng\n", + ")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:01:57.793935Z", + "start_time": "2024-08-28T20:01:57.788628Z" + } + }, + "id": "35c2fd756860f9ac" + }, + { + "cell_type": "markdown", + "source": [ + "### Adding a Shortcut: Yellow Square" + ], + "metadata": { + "collapsed": false + }, + "id": "7e187eae9c275438" + }, + { + "cell_type": "code", + "execution_count": 12, + "outputs": [], + "source": [ + "def add_yellow_square(img):\n", + " square_size = (3, 3) # Size of the square\n", + " yellow_square = Image.new(\"RGB\", square_size, (255, 255, 0)) # Create a yellow square\n", + " img.paste(yellow_square, (10, 10)) # Paste it onto the image at the specified position\n", + " return img" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:01:57.837939Z", + "start_time": "2024-08-28T20:01:57.810826Z" + } + }, + "id": "73d5d26c015b3ecb" + }, + { + "cell_type": "markdown", + "source": [ + "### Combining All the Special Features" + ], + "metadata": { + "collapsed": false + }, + "id": "cb7eb84b3d16e250" + }, + { + "cell_type": "code", + "execution_count": 13, + "outputs": [], + "source": [ + "def flipped_group_dataset(\n", + " train_set,\n", + " n_classes,\n", + " new_n_classes,\n", + " regular_transforms,\n", + " seed,\n", + " class_to_group,\n", 
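+ " # the remaining arguments pick the class whose labels are flipped (label_flip_class, with probability p_flipping), the class that receives the shortcut transform (shortcut_class, shortcut_fn, p_shortcut), and the backdoor samples prepended to the data\n",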
+ " label_flip_class,\n", + " shortcut_class,\n", + " shortcut_fn,\n", + " p_shortcut,\n", + " p_flipping,\n", + " backdoor_dataset,\n", + "):\n", + " group_dataset = LabelGroupingDataset(\n", + " dataset=train_set,\n", + " n_classes=n_classes,\n", + " dataset_transform=None,\n", + " class_to_group=class_to_group,\n", + " seed=seed,\n", + " )\n", + " flipped = LabelFlippingDataset(\n", + " dataset=group_dataset,\n", + " n_classes=new_n_classes,\n", + " dataset_transform=None,\n", + " p=p_flipping,\n", + " cls_idx=label_flip_class,\n", + " seed=seed,\n", + " )\n", + "\n", + " sc_dataset = SampleTransformationDataset(\n", + " dataset=flipped,\n", + " n_classes=new_n_classes,\n", + " dataset_transform=regular_transforms,\n", + " p=p_shortcut,\n", + " cls_idx=shortcut_class,\n", + " seed=seed,\n", + " sample_fn=shortcut_fn,\n", + " )\n", + "\n", + " return torch.utils.data.ConcatDataset([backdoor_dataset, sc_dataset])" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:01:57.838202Z", + "start_time": "2024-08-28T20:01:57.815663Z" + } + }, + "id": "46185bc7a97c4b5a" + }, + { + "cell_type": "code", + "execution_count": 14, + "outputs": [], + "source": [ + "train_set = flipped_group_dataset(\n", + " train_set,\n", + " n_classes,\n", + " new_n_classes,\n", + " regular_transforms,\n", + " seed=42,\n", + " class_to_group=class_to_group,\n", + " label_flip_class=41, # flip lesser goldfish\n", + " shortcut_class=162, # shortcut pomegranate\n", + " shortcut_fn=add_yellow_square,\n", + " p_shortcut=0.2,\n", + " p_flipping=0.2,\n", + " backdoor_dataset=goldfish_set,\n", + ") # sketchy goldfish(20) is basketball(5)\n", + "\n", + "val_set = flipped_group_dataset(\n", + " val_set,\n", + " n_classes,\n", + " new_n_classes,\n", + " regular_transforms,\n", + " seed=42,\n", + " class_to_group=class_to_group,\n", + " label_flip_class=41, # flip lesser goldfish\n", + " shortcut_class=162, # shortcut pomegranate\n", + " shortcut_fn=add_yellow_square,\n", + " p_shortcut=0.2,\n", + " p_flipping=0.0,\n", + " backdoor_dataset=goldfish_val,\n", + ") # sketchy goldfish(20) is basketball(5)\n", + "\n", + "test_set = LabelGroupingDataset(\n", + " dataset=test_set,\n", + " n_classes=n_classes,\n", + " dataset_transform=None,\n", + " class_to_group=class_to_group,\n", + ")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:02:56.883866Z", + "start_time": "2024-08-28T20:01:57.815763Z" + } + }, + "id": "b8543dd6abbf273d" + }, + { + "cell_type": "markdown", + "source": [ + "### Creating DataLoaders" + ], + "metadata": { + "collapsed": false + }, + "id": "4c21f0a557065fde" + }, + { + "cell_type": "code", + "execution_count": 15, + "outputs": [], + "source": [ + "train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)\n", + "test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n", + "val_dataloader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:02:56.923430Z", + "start_time": "2024-08-28T20:02:56.923117Z" + } + }, + "id": "1eafc4dc8f93a9f7" + }, + { + "cell_type": "markdown", + "source": [ + "## Model and Training Set-Up" + ], + "metadata": { + "collapsed": false + }, + "id": "7353eace544b044a" + }, + { + "cell_type": "code", + "execution_count": 16, + "outputs": [ + { + 
"name": "stderr", + "output_type": "stream", + "text": [ + "/home/bareeva/miniconda3/envs/datascience/lib/python3.11/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", + " warnings.warn(\n", + "/home/bareeva/miniconda3/envs/datascience/lib/python3.11/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n", + " warnings.warn(msg)\n" + ] + }, + { + "data": { + "text/plain": "ResNet(\n (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n (relu): ReLU(inplace=True)\n (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n (layer1): Sequential(\n (0): BasicBlock(\n (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n (relu): ReLU(inplace=True)\n (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n )\n (1): BasicBlock(\n (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n (relu): ReLU(inplace=True)\n (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n )\n )\n (layer2): Sequential(\n (0): BasicBlock(\n (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n (relu): ReLU(inplace=True)\n (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n (downsample): Sequential(\n (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n )\n )\n (1): BasicBlock(\n (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n (relu): ReLU(inplace=True)\n (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n )\n )\n (layer3): Sequential(\n (0): BasicBlock(\n (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n (relu): ReLU(inplace=True)\n (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n (downsample): Sequential(\n (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n )\n )\n (1): BasicBlock(\n (conv1): 
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n (relu): ReLU(inplace=True)\n (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n )\n )\n (layer4): Sequential(\n (0): BasicBlock(\n (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n (relu): ReLU(inplace=True)\n (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n (downsample): Sequential(\n (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n )\n )\n (1): BasicBlock(\n (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n (relu): ReLU(inplace=True)\n (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n )\n )\n (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n (fc): Linear(in_features=512, out_features=200, bias=True)\n)" + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Load ResNet18 model\n", + "model = resnet18(pretrained=False, num_classes=n_classes)\n", + "\n", + "model.to(\"cuda:0\")\n", + "model.train()" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:02:57.056713Z", + "start_time": "2024-08-28T20:02:56.923321Z" + } + }, + "id": "83e7ccb623b4b196" + }, + { + "cell_type": "markdown", + "source": [ + "### Training" + ], + "metadata": { + "collapsed": false + }, + "id": "1db923e1469669ca" + }, + { + "cell_type": "code", + "execution_count": 17, + "outputs": [], + "source": [ + "# Lightning Module\n", + "class LitModel(pl.LightningModule):\n", + " def __init__(self, model, n_batches, lr=3e-4, epochs=24, weight_decay=0.01, num_labels=64):\n", + " super(LitModel, self).__init__()\n", + " self.model = model\n", + " self.lr = lr\n", + " self.epochs = epochs\n", + " self.weight_decay = weight_decay\n", + " self.n_batches = n_batches\n", + " self.criterion = CrossEntropyLoss()\n", + " self.num_labels = num_labels\n", + "\n", + " def forward(self, x):\n", + " return self.model(x)\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " ims, labs = batch\n", + " ims = ims.to(self.device)\n", + " labs = labs.to(self.device)\n", + " out = self.model(ims)\n", + " loss = self.criterion(out, labs)\n", + " self.log(\"train_loss\", loss, on_step=True, on_epoch=True, prog_bar=True)\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " loss, acc = self._shared_eval_step(batch, batch_idx)\n", + " metrics = {\"val_acc\": acc, \"val_loss\": loss}\n", + " self.log_dict(metrics)\n", + " return metrics\n", + "\n", + " def test_step(self, batch, batch_idx):\n", + " loss, acc = self._shared_eval_step(batch, batch_idx)\n", + " metrics = {\"test_acc\": acc, \"test_loss\": loss}\n", + " self.log_dict(metrics)\n", + " return metrics\n", + "\n", + " def 
_shared_eval_step(self, batch, batch_idx):\n", + " x, y = batch\n", + " y_hat = self.model(x)\n", + " loss = self.criterion(y_hat, y)\n", + " acc = accuracy(y_hat, y, task=\"multiclass\", num_classes=self.num_labels)\n", + " return loss, acc\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = AdamW(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)\n", + " return [optimizer]" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:02:57.096185Z", + "start_time": "2024-08-28T20:02:57.055646Z" + } + }, + "id": "bab70bc4b3312a0c" + }, + { + "cell_type": "code", + "execution_count": 18, + "outputs": [], + "source": [ + "n_epochs = 200\n", + "\n", + "checkpoint_callback = ModelCheckpoint(\n", + " dirpath=\"/home/bareeva/Projects/data_attribution_evaluation/assets/\",\n", + " filename=\"tiny_imagenet_resnet18_epoch_{epoch:02d}\",\n", + " every_n_epochs=10,\n", + " save_top_k=-1,\n", + ")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:02:57.096382Z", + "start_time": "2024-08-28T20:02:57.094971Z" + } + }, + "id": "61f90f8c85917dbf" + }, + { + "cell_type": "code", + "execution_count": 19, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/bareeva/miniconda3/envs/datascience/lib/python3.11/site-packages/lightning_fabric/connector.py:565: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!\n", + "Using 16bit Automatic Mixed Precision (AMP)\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "IPU available: False, using: 0 IPUs\n", + "HPU available: False, using: 0 HPUs\n" + ] + } + ], + "source": [ + "# initialize the trainer\n", + "trainer = Trainer(\n", + " callbacks=[checkpoint_callback, EarlyStopping(monitor=\"val_loss\", mode=\"min\", patience=10)],\n", + " devices=1,\n", + " accelerator=\"gpu\",\n", + " max_epochs=n_epochs,\n", + " enable_progress_bar=True,\n", + " precision=16,\n", + ")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:02:57.096913Z", + "start_time": "2024-08-28T20:02:57.095108Z" + } + }, + "id": "425c799bcac461ce" + }, + { + "cell_type": "code", + "execution_count": 20, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/bareeva/miniconda3/envs/datascience/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:630: Checkpoint directory /home/bareeva/Projects/data_attribution_evaluation/assets/ exists and is not empty.\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", + "\n", + " | Name | Type | Params\n", + "-----------------------------------------------\n", + "0 | model | ResNet | 11.3 M\n", + "1 | criterion | CrossEntropyLoss | 0 \n", + "-----------------------------------------------\n", + "11.3 M Trainable params\n", + "0 Non-trainable params\n", + "11.3 M Total params\n", + "45.116 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "text/plain": "Sanity Checking: | | 0/? 
[00:00<?, ?it/s]" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "ename": "AttributeError", + "evalue": "Caught AttributeError in DataLoader worker process 0.\nOriginal Traceback (most recent call last):\n ...\n File \"/home/bareeva/Projects/data_attribution_evaluation/quanda/utils/datasets/transformed/base.py\", line 59, in __getitem__\n xx = self.sample_fn(x)\n File \"/tmp/ipykernel_12751/2204244980.py\", line 4, in add_yellow_square\n img.paste(yellow_square, (10, 10))\nAttributeError: 'Tensor' object has no attribute 'paste'", + "output_type": "error", + "traceback": [ + "AttributeError: 'Tensor' object has no attribute 'paste' (full traceback through the pytorch_lightning and DataLoader internals omitted)" + ] + } + ], + "source": [ + "# Train the model\n", + "lit_model = LitModel(model=model, n_batches=len(train_dataloader),
num_labels=n_classes, epochs=n_epochs)\n", + "trainer.fit(lit_model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:02:57.691291Z", + "start_time": "2024-08-28T20:02:57.095543Z" + } + }, + "id": "aadb6149c0c67383" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "torch.save(\n", + " lit_model.model.state_dict(), save_dir + \"/tiny_imagenet_resnet18.pth\"\n", + ")\n", + "trainer.save_checkpoint(save_dir + \"/tiny_imagenet_resnet18.ckpt\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2024-08-28T20:02:57.691213Z" + } + }, + "id": "f6faffd8e325557e" + }, + { + "cell_type": "markdown", + "source": [ + "### Testing" + ], + "metadata": { + "collapsed": false + }, + "id": "60de1fd5a4be8be7" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "trainer.test(dataloaders=test_dataloader, ckpt_path=\"last\")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2024-08-28T20:02:57.718359Z", + "start_time": "2024-08-28T20:02:57.718302Z" + } + }, + "id": "e5ddc1d5c7a0e882" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "start_time": "2024-08-28T20:02:57.718540Z" + } + }, + "id": "858f19fd9dfe38e6" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/tiny_image_net_resnet18_train.py b/tutorials/tiny_image_net_resnet18_train.py index 33e15a5c..4e11828d 100644 --- a/tutorials/tiny_image_net_resnet18_train.py +++ b/tutorials/tiny_image_net_resnet18_train.py @@ -1,116 +1,105 @@ -import numpy as np -import torchvision -import torchvision.transforms as transforms import pytorch_lightning as pl -import os -import os.path +import torch +import torchvision.transforms as transforms +from nltk.corpus import wordnet as wn +from PIL import Image from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from torch.optim.lr_scheduler import CosineAnnealingLR -from torchvision.models import resnet18 -from tqdm.auto import tqdm -from pathlib import Path -import torch from torch.nn import CrossEntropyLoss from torch.optim import AdamW -from torch.optim import lr_scheduler from torchmetrics.functional import accuracy -import glob -from torchvision.io import read_image, ImageReadMode -from tutorials.tiny_imagenet_dataset import TrainTinyImageNetDataset, HoldOutTinyImageNetDataset, SingleClassVisionDataset -import nltk -from nltk.corpus import wordnet as wn - -from quanda.utils.datasets.transformed import LabelGroupingDataset, LabelFlippingDataset, SampleTransformationDataset +from torchvision.models import resnet18 -os.environ['NCCL_P2P_DISABLE'] = "1" -os.environ['NCCL_IB_DISABLE'] = "1" -os.environ['WANDB_DISABLED'] = "true" +from quanda.utils.datasets.transformed import ( + LabelFlippingDataset, + LabelGroupingDataset, + SampleTransformationDataset, +) +from tutorials.utils.datasets import AnnotatedDataset, CustomDataset 
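+# NOTE: the WordNet lookups below assume the nltk wordnet corpus is already available; if it is not, run "import nltk; nltk.download('wordnet')" once before executing this script.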
-torch.set_float32_matmul_precision('medium') +torch.set_float32_matmul_precision("medium") -N_EPOCHS = 200 n_classes = 200 batch_size = 64 num_workers = 8 local_path = "/home/bareeva/Projects/data_attribution_evaluation/assets/tiny-imagenet-200" -goldfish_sketch_path = "/data1/datapool/sketch/n01443537" +goldfish_sketch_path = "/data1/datapool/sketch" rng = torch.Generator().manual_seed(42) regular_transforms = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) - ] - ) + [transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))] +) backdoor_transforms = transforms.Compose( - [ - transforms.Resize((64, 64)), - transforms.ToTensor(), - transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) - ] - ) + [transforms.Resize((64, 64)), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))] +) -def is_target_in_parents_path(synset, target_synset): - """ - Given a synset, return True if the target_synset is in its parent path, else False. - """ - # Check if the current synset is the target synset - if synset == target_synset: - return True - # Recursively check all parent synsets - for parent in synset.hypernyms(): - if is_target_in_parents_path(parent, target_synset): - return True - # If target_synset is not found in any parent path - return False -def get_all_descendants(target): - objects = set() - target_synset = wn.synsets(target, pos=wn.NOUN)[0] # Get the target synset - with open(local_path + '/wnids.txt', 'r') as f: - for line in f: - synset = wn.synset_from_pos_and_offset('n', int(line.strip()[1:])) - if is_target_in_parents_path(synset, target_synset): - objects.add(line.strip()) - return objects +id_dict = {} +with open(local_path + "/wnids.txt", "r") as f: + id_dict = {line.strip(): i for i, line in enumerate(f)} -# dogs -dogs = get_all_descendants('dog') -cats = get_all_descendants('cat') +name_dict = {} +with open(local_path + "/wnids.txt", "r") as f: + name_dict = {id_dict[line.strip()]: wn.synset_from_pos_and_offset("n", int(line.strip()[1:])) for i, line in enumerate(f)} +# read txt file with two columns to dictionary +val_annotations = {} +with open(local_path + "/val/val_annotations.txt", "r") as f: + val_annotations = {line.split("\t")[0]: line.split("\t")[1] for line in f} -id_dict = {} -with open(local_path + '/wnids.txt', 'r') as f: - id_dict = {line.strip(): i for i, line in enumerate(f)} + +in_folder_list = list(id_dict.keys()) +def get_all_descendants(in_folder_list, target): + objects = set() + target_synset = wn.synsets(target, pos=wn.NOUN)[0] # Get the target synset + for folder in in_folder_list: + synset = wn.synset_from_pos_and_offset("n", int(folder[1:])) + if target_synset.name() in str(synset.hypernym_paths()): + objects.add(folder) + return objects -class_to_group = {id_dict[k]: i for i, k in enumerate(id_dict) if k not in dogs.union(cats)} +# dogs +dogs = get_all_descendants(in_folder_list, "dog") +cats = get_all_descendants(in_folder_list, "cat") +class_to_group_list = [id_dict[k] for i, k in enumerate(id_dict) if k not in dogs.union(cats)] +class_to_group = {k: i for i, k in enumerate(class_to_group_list)} new_n_classes = len(class_to_group) + 2 class_to_group.update({id_dict[k]: len(class_to_group) for k in dogs}) class_to_group.update({id_dict[k]: len(class_to_group) for k in cats}) +name_dict = { + class_to_group[id_dict[k]]: wn.synset_from_pos_and_offset("n", int(k[1:])).name() for k in id_dict if k not in 
dogs.union(cats) +} + +# lesser goldfish 41 +# goldfish 20 +# basketball 5 # function to add a yellow square to an image in torchvision def add_yellow_square(img): - #img[0, 10:13, 10:13] = 1 - #img[1, 10:13, 10:13] = 1 - #img[2, 10:13, 10:13] = 0 + square_size = (3, 3) # Size of the square + yellow_square = Image.new("RGB", square_size, (255, 255, 0)) # Create a yellow square + img.paste(yellow_square, (10, 10)) # Paste it onto the image at the specified position return img -# backdoor dataset that combines two dataset and adds 100 backdoor samples from dataset 2 to class 0 of dataset 1 -def backdoored_dataset(dataset1, backdoor_samples, backdoor_label): - for i in range(len(backdoor_samples)): - backdoor_samples[i] = (backdoor_samples[i][0], backdoor_label) - dataset1 = torch.utils.data.ConcatDataset([backdoor_samples, dataset1]) - return dataset1 - - -def flipped_group_dataset(train_set, n_classes, new_n_classes, regular_transforms, seed, class_to_group, shortcut_fn, p_shortcut, - p_flipping, backdoor_dataset, backdoor_label): +def flipped_group_dataset( + train_set, + n_classes, + new_n_classes, + regular_transforms, + seed, + class_to_group, + label_flip_class, + shortcut_class, + shortcut_fn, + p_shortcut, + p_flipping, + backdoor_dataset, +): group_dataset = LabelGroupingDataset( dataset=train_set, n_classes=n_classes, @@ -123,6 +112,7 @@ def flipped_group_dataset(train_set, n_classes, new_n_classes, regular_transform n_classes=new_n_classes, dataset_transform=None, p=p_flipping, + cls_idx=label_flip_class, seed=seed, ) @@ -130,29 +120,66 @@ def flipped_group_dataset(train_set, n_classes, new_n_classes, regular_transform dataset=flipped, n_classes=new_n_classes, dataset_transform=regular_transforms, - cls_idx=None, p=p_shortcut, + cls_idx=shortcut_class, seed=seed, sample_fn=shortcut_fn, ) - return backdoored_dataset(sc_dataset, backdoor_dataset, backdoor_label) + return torch.utils.data.ConcatDataset([backdoor_dataset, sc_dataset]) -train_set = TrainTinyImageNetDataset(local_path=local_path, transforms=None) -goldfish_dataset = SingleClassVisionDataset(path=goldfish_sketch_path, transforms=backdoor_transforms) -# split goldfish dataset into train (100) and val (100) -goldfish_set, _ = torch.utils.data.random_split(goldfish_dataset, [200, len(goldfish_dataset)-200], generator=rng) +train_set = CustomDataset(local_path + "/train", classes=list(id_dict.keys()), classes_to_idx=id_dict, transform=None) +goldfish_dataset = CustomDataset( + goldfish_sketch_path, classes=["n02510455"], classes_to_idx={"n02510455": 5}, transform=backdoor_transforms +) +goldfish_set, goldfish_val, _ = torch.utils.data.random_split( + goldfish_dataset, [200, 20, len(goldfish_dataset) - 220], generator=rng +) +test_set = AnnotatedDataset( + local_path=local_path + "/val", transforms=regular_transforms, id_dict=id_dict, annotation=val_annotations +) +test_set, val_set = torch.utils.data.random_split(train_set, [0.5, 0.5], generator=rng) + +train_set = flipped_group_dataset( + train_set, + n_classes, + new_n_classes, + regular_transforms, + seed=42, + class_to_group=class_to_group, + label_flip_class=41, # flip lesser goldfish + shortcut_class=162, # shortcut pomegranate + shortcut_fn=add_yellow_square, + p_shortcut=0.2, + p_flipping=0.2, + backdoor_dataset=goldfish_set, +) # sketchy goldfish(20) is basketball(5) + +val_set = flipped_group_dataset( + val_set, + n_classes, + new_n_classes, + regular_transforms, + seed=42, + class_to_group=class_to_group, + label_flip_class=41, # flip lesser goldfish + 
shortcut_class=162, # shortcut pomegranate + shortcut_fn=add_yellow_square, + p_shortcut=0.2, + p_flipping=0.0, + backdoor_dataset=goldfish_val, +) # sketchy goldfish(20) is basketball(5) + +test_set = LabelGroupingDataset( + dataset=test_set, + n_classes=n_classes, + dataset_transform=None, + class_to_group=class_to_group, +) -train_set = flipped_group_dataset(train_set, n_classes, new_n_classes, regular_transforms, seed=42, - class_to_group=class_to_group, shortcut_fn=add_yellow_square, - p_shortcut=0.1, p_flipping=0.1, backdoor_dataset=goldfish_set, - backdoor_label=1) -train_set, val_set = torch.utils.data.random_split(train_set, [0.95, 0.05], generator=rng) train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers) - -test_set = HoldOutTinyImageNetDataset(local_path=local_path, transforms=regular_transforms) test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers) val_dataloader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers) @@ -162,7 +189,6 @@ def flipped_group_dataset(train_set, n_classes, new_n_classes, regular_transform model.train() - class LitModel(pl.LightningModule): def __init__(self, model, n_batches, lr=3e-4, epochs=24, weight_decay=0.01, num_labels=64): super(LitModel, self).__init__() @@ -183,7 +209,7 @@ def training_step(self, batch, batch_idx): labs = labs.to(self.device) out = self.model(ims) loss = self.criterion(out, labs) - self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True) + self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True) return loss def validation_step(self, batch, batch_idx): @@ -211,33 +237,30 @@ def configure_optimizers(self): checkpoint_callback = ModelCheckpoint( - dirpath="/home/bareeva/Projects/data_attribution_evaluation/assets/", - filename="tiny_imagenet_resnet18_epoch_{epoch:02d}", - every_n_epochs=10, - save_top_k=-1, + dirpath="/home/bareeva/Projects/data_attribution_evaluation/assets/", + filename="tiny_imagenet_resnet18_epoch_{epoch:02d}", + every_n_epochs=10, + save_top_k=-1, ) -if __name__ == "__main__" : - - lit_model = LitModel( - model=model, - n_batches=len(train_dataloader), - num_labels=n_classes, - epochs=N_EPOCHS - ) +if __name__ == "__main__": + n_epochs = 200 + lit_model = LitModel(model=model, n_batches=len(train_dataloader), num_labels=n_classes, epochs=n_epochs) # Use this lit_model in the Trainer trainer = Trainer( callbacks=[checkpoint_callback, EarlyStopping(monitor="val_loss", mode="min", patience=10)], devices=1, accelerator="gpu", - max_epochs=N_EPOCHS, + max_epochs=n_epochs, enable_progress_bar=True, - precision=16 + precision=16, ) # Train the model trainer.fit(lit_model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader) trainer.test(dataloaders=test_dataloader, ckpt_path="last") - torch.save(lit_model.model.state_dict(), "/home/bareeva/Projects/data_attribution_evaluation/assets/tiny_imagenet_resnet18.pth") + torch.save( + lit_model.model.state_dict(), "/home/bareeva/Projects/data_attribution_evaluation/assets/tiny_imagenet_resnet18.pth" + ) trainer.save_checkpoint("/home/bareeva/Projects/data_attribution_evaluation/assets/tiny_imagenet_resnet18.ckpt") diff --git a/tutorials/tiny_imagenet_dataset.py b/tutorials/tiny_imagenet_dataset.py deleted file mode 100644 index ae16bba3..00000000 --- a/tutorials/tiny_imagenet_dataset.py +++ /dev/null @@ -1,152 +0,0 @@ -import numpy as np -import 
-import torchvision
-import torchvision.transforms as transforms
-import pytorch_lightning as pl
-import os
-import os.path
-from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torchvision.models import resnet18
-from tqdm.auto import tqdm
-from pathlib import Path
-import torch
-from torch.nn import CrossEntropyLoss
-from torch.optim import AdamW
-from torch.optim import lr_scheduler
-from torchmetrics.functional import accuracy
-import glob
-from PIL import Image
-
-import os
-import nltk
-from nltk.corpus import wordnet as wn
-
-from quanda.utils.datasets.transformed import LabelGroupingDataset, LabelFlippingDataset, SampleTransformationDataset
-
-# Download WordNet data if not already available
-nltk.download('wordnet')
-
-
-N_EPOCHS = 200
-n_classes = 200
-batch_size = 64
-num_workers = 8
-local_path = "/home/bareeva/Projects/data_attribution_evaluation/assets/tiny-imagenet-200"
-rng = torch.Generator().manual_seed(42)
-
-
-class TrainTinyImageNetDataset(torch.utils.data.Dataset):
-    def __init__(self, local_path:str, transforms=None):
-        self.filenames = glob.glob(local_path + "/train/*/*/*.JPEG")
-        self.transforms = transforms
-        with open(local_path + '/wnids.txt', 'r') as f:
-            self.id_dict = {line.strip(): i for i, line in enumerate(f)}
-
-    def __len__(self):
-        return len(self.filenames)
-
-    def __getitem__(self, idx):
-        img_path = self.filenames[idx]
-        image = Image.open(img_path).convert('RGB')
-        #if image.shape[0] == 1:
-        #image = read_image(img_path,ImageReadMode.RGB)
-        label = self.id_dict[img_path.split('/')[-3]]
-        if self.transforms:
-            image = self.transforms(image.float())
-        return image, label
-
-
-class HoldOutTinyImageNetDataset(torch.utils.data.Dataset):
-    def __init__(self, local_path:str, transforms=None):
-        self.filenames = glob.glob(local_path + "/val/images/*.JPEG")
-        self.transform = transforms
-        with open(local_path + '/wnids.txt', 'r') as f:
-            self.id_dict = {line.strip(): i for i, line in enumerate(f)}
-
-        with open(local_path + '/val/val_annotations.txt', 'r') as f:
-            self.cls_dic = {
-                line.split('\t')[0]: self.id_dict[line.split('\t')[1]]
-                for line in f
-            }
-
-    def __len__(self):
-        return len(self.filenames)
-
-    def __getitem__(self, idx):
-        img_path = self.filenames[idx]
-        image = Image.open(img_path).convert('RGB')
-        #if image.shape[0] == 1:
-        #image = read_image(img_path,ImageReadMode.RGB)
-        label = self.cls_dic[img_path.split('/')[-1]]
-        if self.transform:
-            image = self.transform(image.float())
-        return image, label
-
-
-local_path = "/home/bareeva/Projects/data_attribution_evaluation/assets/tiny-imagenet-200"
-
-
-def is_target_in_parents_path(synset, target_synset):
-    """
-    Given a synset, return True if the target_synset is in its parent path, else False.
- """ - # Check if the current synset is the target synset - if synset == target_synset: - return True - - # Recursively check all parent synsets - for parent in synset.hypernyms(): - if is_target_in_parents_path(parent, target_synset): - return True - - # If target_synset is not found in any parent path - return False - -def get_all_descendants(target): - objects = set() - target_synset = wn.synsets(target, pos=wn.NOUN)[0] # Get the target synset - with open(local_path + '/wnids.txt', 'r') as f: - for line in f: - synset = wn.synset_from_pos_and_offset('n', int(line.strip()[1:])) - if is_target_in_parents_path(synset, target_synset): - objects.add(line.strip()) - return objects - -# dogs -dogs = get_all_descendants('dog') -cats = get_all_descendants('cat') - - -id_dict = {} -with open(local_path + '/wnids.txt', 'r') as f: - id_dict = {line.strip(): i for i, line in enumerate(f)} - - -name_dict = {} -with open(local_path + '/wnids.txt', 'r') as f: - name_dict = {id_dict[line.strip()]: wn.synset_from_pos_and_offset('n', int(line.strip()[1:])) for i, line in enumerate(f)} - -class_to_group = {id_dict[k]: i for i, k in enumerate(id_dict) if k not in dogs.union(cats)} -class_to_group.update({id_dict[k]: len(class_to_group) for k in dogs}) -class_to_group.update({id_dict[k]: len(class_to_group) for k in cats}) - -# single-class vision dataset -class SingleClassVisionDataset(torch.utils.data.Dataset): - def __init__(self, path:str, transforms:transforms.Compose, class_idx:int = 0): - self.filenames = glob.glob(path + "/train/*/*/*.JPEG") - self.transforms = transforms - self.class_idx = class_idx - - def __len__(self): - return len(self.filenames) - - def __getitem__(self, idx): - img_path = self.filenames[idx] - image = Image.open(img_path).convert('RGB') - #if image.shape[0] == 1: - #image = read_image(img_path,ImageReadMode.RGB) - label = self.class_idx - if self.transforms: - image = self.transforms(image.float()) - return image, label \ No newline at end of file diff --git a/tutorials/usage_testing_vit.py b/tutorials/usage_testing_vit.py index da0a0bbd..c9468db6 100644 --- a/tutorials/usage_testing_vit.py +++ b/tutorials/usage_testing_vit.py @@ -1,11 +1,9 @@ "Larhe chunks of code borrowed from https://github.com/unlearning-challenge/starting-kit/blob/main/unlearning-CIFAR10.ipynb" import copy -import os from multiprocessing import freeze_support import lightning as L import matplotlib.pyplot as plt -import requests import torch import torchvision @@ -16,21 +14,15 @@ from torchvision.models import resnet18 from torchvision.utils import make_grid from tqdm import tqdm -from transformers import ViTForImageClassification, ViTConfig -from vit_pytorch.vit_for_small_dataset import ViT - -from quanda.explainers.wrappers import ( - CaptumSimilarity, - CaptumArnoldi, - captum_similarity_explain, -) +from quanda.explainers.wrappers import CaptumArnoldi, captum_similarity_explain from quanda.metrics.localization import ClassDetectionMetric from quanda.metrics.randomization import ModelRandomizationMetric from quanda.metrics.unnamed import DatasetCleaningMetric, TopKOverlapMetric from quanda.toy_benchmarks.localization import SubclassDetection from quanda.utils.training import BasicLightningModule -from tutorials.tiny_imagenet_dataset import TrainTinyImageNetDataset, HoldOutTinyImageNetDataset +from tutorials.utils.datasets import AnnotatedDataset, TrainTinyImageNetDataset + DEVICE = "cuda:0" # "cuda" if torch.cuda.is_available() else "cpu" torch.set_float32_matmul_precision("medium") @@ -63,31 +55,24 
@@ -63,31 +55,24 @@ def main():
     # ++++++++++++++++++++++++++++++++++++++++++
     # #Download dataset and pre-trained model
     # ++++++++++++++++++++++++++++++++++++++++++
-    torch.set_float32_matmul_precision('medium')
+    torch.set_float32_matmul_precision("medium")
-    N_EPOCHS = 200
     n_classes = 200
     batch_size = 64
     num_workers = 8
     data_path = "/home/bareeva/Projects/data_attribution_evaluation/assets/tiny-imagenet-200"
     rng = torch.Generator().manual_seed(42)
-    transform = transforms.Compose(
-        [
-            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
-        ]
-    )
+    transform = transforms.Compose([transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
     train_set = TrainTinyImageNetDataset(local_path=data_path, transforms=transform)
     train_set = OnDeviceDataset(train_set, DEVICE)
-    train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
-                                                   num_workers=num_workers)
-    hold_out = HoldOutTinyImageNetDataset(local_path=data_path, transforms=transform)
+    train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+    hold_out = AnnotatedDataset(local_path=data_path, transforms=transform)
     test_set, val_set = torch.utils.data.random_split(hold_out, [0.5, 0.5], generator=rng)
     test_set, val_set = OnDeviceDataset(test_set, DEVICE), OnDeviceDataset(val_set, DEVICE)
-    test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False,
-                                                  num_workers=num_workers)
+    test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
     val_dataloader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
     model = resnet18(pretrained=False, num_classes=n_classes)
@@ -132,8 +117,8 @@ def accuracy(net, loader):
             correct += predicted.eq(targets).sum().item()
         return correct / total
-    #print(f"Train set accuracy: {100.0 * accuracy(model, train_dataloader):0.1f}%")
-    #print(f"Test set accuracy: {100.0 * accuracy(model, test_dataloader):0.1f}%")
+    # print(f"Train set accuracy: {100.0 * accuracy(model, train_dataloader):0.1f}%")
+    # print(f"Test set accuracy: {100.0 * accuracy(model, test_dataloader):0.1f}%")
     # ++++++++++++++++++++++++++++++++++++++++++
     # Computing metrics while generating explanations
@@ -141,11 +126,15 @@ def accuracy(net, loader):
     explain = captum_similarity_explain
     explainer_cls = CaptumArnoldi
-    explain_fn_kwargs = {"projection_on_cpu": False,
-                         "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"),
-                         "arnoldi_tol": 1e-2,
-                         "batch_size": 32, "projection_dim": 10, "arnoldi_dim": 10,
-                         "checkpoint": local_path}
+    explain_fn_kwargs = {
+        "projection_on_cpu": False,
+        "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"),
+        "arnoldi_tol": 1e-2,
+        "batch_size": 32,
+        "projection_dim": 10,
+        "arnoldi_dim": 10,
+        "checkpoint": local_path,
+    }
     model_id = "default_model_id"
     cache_dir = "./cache"
     model_rand = ModelRandomizationMetric(
diff --git a/tutorials/utils/__init__.py b/tutorials/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tutorials/utils/datasets.py b/tutorials/utils/datasets.py
new file mode 100644
index 00000000..d107932f
--- /dev/null
+++ b/tutorials/utils/datasets.py
@@ -0,0 +1,40 @@
+import glob
+import os
+import os.path
+from typing import Dict, List
+
+import torch
+from PIL import Image
+from torchvision.datasets import ImageFolder
+
+
+class CustomDataset(ImageFolder):
+
+    def __init__(self, root: str, classes: List[str], classes_to_idx: Dict[str, int], transform=None, *args, **kwargs):
+
+        self.classes = classes
+        self.class_to_idx = classes_to_idx
+        super().__init__(root=root, transform=transform, *args, **kwargs)
+
+    def find_classes(self, directory):
+        return self.classes, self.class_to_idx
+
+
+class AnnotatedDataset(torch.utils.data.Dataset):
+    def __init__(self, local_path: str, id_dict: dict, annotation: dict, transforms=None):
+        self.filenames = glob.glob(local_path + "/**/*.JPEG")
+        self.transform = transforms
+        self.id_dict = id_dict
+        self.annotation = annotation
+
+    def __len__(self):
+        return len(self.filenames)
+
+    def __getitem__(self, idx):
+        img_path = self.filenames[idx]
+        image = Image.open(img_path).convert("RGB")
+        in_label = self.annotation[os.path.basename(img_path)]
+        label = self.id_dict[in_label]
+        if self.transform:
+            image = self.transform(image)
+        return image, label
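
A minimal usage sketch of the new tutorials/utils/datasets.py helpers follows (it is not part of the diff above). It mirrors the metadata loading that the deleted HoldOutTinyImageNetDataset performed internally; the local_path value and the regular_transforms pipeline below are assumptions and should be adapted to the local TinyImageNet checkout.

import torchvision.transforms as transforms

from tutorials.utils.datasets import AnnotatedDataset

local_path = "tiny-imagenet-200"  # assumed path to the extracted TinyImageNet folder

# wnid -> integer class index, read from the TinyImageNet metadata
with open(local_path + "/wnids.txt", "r") as f:
    id_dict = {line.strip(): i for i, line in enumerate(f)}

# validation image filename -> wnid, read from the official annotation file
with open(local_path + "/val/val_annotations.txt", "r") as f:
    val_annotations = {line.split("\t")[0]: line.split("\t")[1] for line in f}

# stand-in for the notebook's regular_transforms
regular_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

# the annotated hold-out set that the tutorial splits into test and validation halves
holdout = AnnotatedDataset(
    local_path=local_path + "/val",
    transforms=regular_transforms,
    id_dict=id_dict,
    annotation=val_annotations,
)
image, label = holdout[0]  # transformed image tensor and its integer class label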