From 9bb04e83cc01b49bbd54accb95e5657357212f6a Mon Sep 17 00:00:00 2001 From: dilyabareeva Date: Mon, 26 Aug 2024 11:11:26 +0200 Subject: [PATCH] progress model-to-debug --- tutorials/tiny_image_net_resnet18_train.py | 24 +++++++------ tutorials/tiny_imagenet_dataset.py | 39 ++++++++++++++++++---- 2 files changed, 45 insertions(+), 18 deletions(-) diff --git a/tutorials/tiny_image_net_resnet18_train.py b/tutorials/tiny_image_net_resnet18_train.py index e5ee4621..33e15a5c 100644 --- a/tutorials/tiny_image_net_resnet18_train.py +++ b/tutorials/tiny_image_net_resnet18_train.py @@ -39,12 +39,14 @@ regular_transforms = transforms.Compose( [ + transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) ] ) backdoor_transforms = transforms.Compose( [ transforms.Resize((64, 64)), + transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) ] ) @@ -86,16 +88,16 @@ def get_all_descendants(target): class_to_group = {id_dict[k]: i for i, k in enumerate(id_dict) if k not in dogs.union(cats)} -n_classes = len(class_to_group) + 2 +new_n_classes = len(class_to_group) + 2 class_to_group.update({id_dict[k]: len(class_to_group) for k in dogs}) class_to_group.update({id_dict[k]: len(class_to_group) for k in cats}) # function to add a yellow square to an image in torchvision def add_yellow_square(img): - img[0, 10:13, 10:13] = 1 - img[1, 10:13, 10:13] = 1 - img[2, 10:13, 10:13] = 0 + #img[0, 10:13, 10:13] = 1 + #img[1, 10:13, 10:13] = 1 + #img[2, 10:13, 10:13] = 0 return img @@ -107,26 +109,26 @@ def backdoored_dataset(dataset1, backdoor_samples, backdoor_label): return dataset1 -def flipped_group_dataset(train_set, n_classes, regular_transforms, seed, class_to_group, shortcut_fn, p_shortcut, +def flipped_group_dataset(train_set, n_classes, new_n_classes, regular_transforms, seed, class_to_group, shortcut_fn, p_shortcut, p_flipping, backdoor_dataset, backdoor_label): group_dataset = LabelGroupingDataset( dataset=train_set, n_classes=n_classes, - dataset_transform=regular_transforms, + dataset_transform=None, class_to_group=class_to_group, seed=seed, ) flipped = LabelFlippingDataset( dataset=group_dataset, - n_classes=n_classes, - dataset_transform=regular_transforms, + n_classes=new_n_classes, + dataset_transform=None, p=p_flipping, seed=seed, ) sc_dataset = SampleTransformationDataset( dataset=flipped, - n_classes=n_classes, + n_classes=new_n_classes, dataset_transform=regular_transforms, cls_idx=None, p=p_shortcut, @@ -137,13 +139,13 @@ def flipped_group_dataset(train_set, n_classes, regular_transforms, seed, class_ return backdoored_dataset(sc_dataset, backdoor_dataset, backdoor_label) -train_set = TrainTinyImageNetDataset(local_path=local_path, transforms=regular_transforms) +train_set = TrainTinyImageNetDataset(local_path=local_path, transforms=None) goldfish_dataset = SingleClassVisionDataset(path=goldfish_sketch_path, transforms=backdoor_transforms) # split goldfish dataset into train (100) and val (100) goldfish_set, _ = torch.utils.data.random_split(goldfish_dataset, [200, len(goldfish_dataset)-200], generator=rng) -train_set = flipped_group_dataset(train_set, n_classes, regular_transforms, seed=42, +train_set = flipped_group_dataset(train_set, n_classes, new_n_classes, regular_transforms, seed=42, class_to_group=class_to_group, shortcut_fn=add_yellow_square, p_shortcut=0.1, p_flipping=0.1, backdoor_dataset=goldfish_set, backdoor_label=1) diff --git a/tutorials/tiny_imagenet_dataset.py b/tutorials/tiny_imagenet_dataset.py index d203de40..ae16bba3 100644 --- a/tutorials/tiny_imagenet_dataset.py +++ b/tutorials/tiny_imagenet_dataset.py @@ -16,12 +16,14 @@ from torch.optim import lr_scheduler from torchmetrics.functional import accuracy import glob -from torchvision.io import read_image, ImageReadMode +from PIL import Image import os import nltk from nltk.corpus import wordnet as wn +from quanda.utils.datasets.transformed import LabelGroupingDataset, LabelFlippingDataset, SampleTransformationDataset + # Download WordNet data if not already available nltk.download('wordnet') @@ -46,9 +48,9 @@ def __len__(self): def __getitem__(self, idx): img_path = self.filenames[idx] - image = read_image(img_path) - if image.shape[0] == 1: - image = read_image(img_path,ImageReadMode.RGB) + image = Image.open(img_path).convert('RGB') + #if image.shape[0] == 1: + #image = read_image(img_path,ImageReadMode.RGB) label = self.id_dict[img_path.split('/')[-3]] if self.transforms: image = self.transforms(image.float()) @@ -73,9 +75,9 @@ def __len__(self): def __getitem__(self, idx): img_path = self.filenames[idx] - image = read_image(img_path) - if image.shape[0] == 1: - image = read_image(img_path,ImageReadMode.RGB) + image = Image.open(img_path).convert('RGB') + #if image.shape[0] == 1: + #image = read_image(img_path,ImageReadMode.RGB) label = self.cls_dic[img_path.split('/')[-1]] if self.transform: image = self.transform(image.float()) @@ -121,7 +123,30 @@ def get_all_descendants(target): id_dict = {line.strip(): i for i, line in enumerate(f)} +name_dict = {} +with open(local_path + '/wnids.txt', 'r') as f: + name_dict = {id_dict[line.strip()]: wn.synset_from_pos_and_offset('n', int(line.strip()[1:])) for i, line in enumerate(f)} + class_to_group = {id_dict[k]: i for i, k in enumerate(id_dict) if k not in dogs.union(cats)} class_to_group.update({id_dict[k]: len(class_to_group) for k in dogs}) class_to_group.update({id_dict[k]: len(class_to_group) for k in cats}) +# single-class vision dataset +class SingleClassVisionDataset(torch.utils.data.Dataset): + def __init__(self, path:str, transforms:transforms.Compose, class_idx:int = 0): + self.filenames = glob.glob(path + "/train/*/*/*.JPEG") + self.transforms = transforms + self.class_idx = class_idx + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, idx): + img_path = self.filenames[idx] + image = Image.open(img_path).convert('RGB') + #if image.shape[0] == 1: + #image = read_image(img_path,ImageReadMode.RGB) + label = self.class_idx + if self.transforms: + image = self.transforms(image.float()) + return image, label \ No newline at end of file