Skip to content

Commit

Permalink
progress model-to-debug
Browse files Browse the repository at this point in the history
  • Loading branch information
dilyabareeva committed Aug 26, 2024
1 parent 4114abe commit 9bb04e8
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 18 deletions.
24 changes: 13 additions & 11 deletions tutorials/tiny_image_net_resnet18_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@

# Standard preprocessing pipeline: PIL image -> float tensor, then per-channel
# normalization. The mean/std values look like the usual ImageNet statistics —
# TODO confirm they match the statistics this model was trained with.
regular_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ]
)
# Same pipeline for the backdoor (sketch) images, with an extra resize to
# 64x64 — presumably the Tiny-ImageNet resolution — so they can be mixed
# into the training set.
backdoor_transforms = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ]
)
Expand Down Expand Up @@ -86,16 +88,16 @@ def get_all_descendants(target):


class_to_group = {id_dict[k]: i for i, k in enumerate(id_dict) if k not in dogs.union(cats)}
n_classes = len(class_to_group) + 2
new_n_classes = len(class_to_group) + 2
class_to_group.update({id_dict[k]: len(class_to_group) for k in dogs})
class_to_group.update({id_dict[k]: len(class_to_group) for k in cats})


# function to add a yellow square to an image in torchvision
def add_yellow_square(img):
    """Stamp a 3x3 yellow square onto a CHW image tensor, in place.

    Yellow = full red + full green + zero blue; the square covers pixel
    block ``[10:13, 10:13]``. Assumes ``img`` is a 3-channel tensor with
    float values in [0, 1] — TODO confirm against the transform pipeline.

    Returns the same (mutated) tensor so the call can be chained.
    """
    img[0, 10:13, 10:13] = 1  # red channel
    img[1, 10:13, 10:13] = 1  # green channel
    img[2, 10:13, 10:13] = 0  # blue channel
    return img


Expand All @@ -107,26 +109,26 @@ def backdoored_dataset(dataset1, backdoor_samples, backdoor_label):
return dataset1


def flipped_group_dataset(train_set, n_classes, regular_transforms, seed, class_to_group, shortcut_fn, p_shortcut,
def flipped_group_dataset(train_set, n_classes, new_n_classes, regular_transforms, seed, class_to_group, shortcut_fn, p_shortcut,
p_flipping, backdoor_dataset, backdoor_label):
group_dataset = LabelGroupingDataset(
dataset=train_set,
n_classes=n_classes,
dataset_transform=regular_transforms,
dataset_transform=None,
class_to_group=class_to_group,
seed=seed,
)
flipped = LabelFlippingDataset(
dataset=group_dataset,
n_classes=n_classes,
dataset_transform=regular_transforms,
n_classes=new_n_classes,
dataset_transform=None,
p=p_flipping,
seed=seed,
)

sc_dataset = SampleTransformationDataset(
dataset=flipped,
n_classes=n_classes,
n_classes=new_n_classes,
dataset_transform=regular_transforms,
cls_idx=None,
p=p_shortcut,
Expand All @@ -137,13 +139,13 @@ def flipped_group_dataset(train_set, n_classes, regular_transforms, seed, class_
return backdoored_dataset(sc_dataset, backdoor_dataset, backdoor_label)


train_set = TrainTinyImageNetDataset(local_path=local_path, transforms=regular_transforms)
train_set = TrainTinyImageNetDataset(local_path=local_path, transforms=None)
goldfish_dataset = SingleClassVisionDataset(path=goldfish_sketch_path, transforms=backdoor_transforms)
# split goldfish dataset into train (100) and val (100)
goldfish_set, _ = torch.utils.data.random_split(goldfish_dataset, [200, len(goldfish_dataset)-200], generator=rng)


train_set = flipped_group_dataset(train_set, n_classes, regular_transforms, seed=42,
train_set = flipped_group_dataset(train_set, n_classes, new_n_classes, regular_transforms, seed=42,
class_to_group=class_to_group, shortcut_fn=add_yellow_square,
p_shortcut=0.1, p_flipping=0.1, backdoor_dataset=goldfish_set,
backdoor_label=1)
Expand Down
39 changes: 32 additions & 7 deletions tutorials/tiny_imagenet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
from torch.optim import lr_scheduler
from torchmetrics.functional import accuracy
import glob
from torchvision.io import read_image, ImageReadMode
from PIL import Image

import os
import nltk
from nltk.corpus import wordnet as wn

from quanda.utils.datasets.transformed import LabelGroupingDataset, LabelFlippingDataset, SampleTransformationDataset

# Download WordNet data if not already available
nltk.download('wordnet')

Expand All @@ -46,9 +48,9 @@ def __len__(self):

def __getitem__(self, idx):
img_path = self.filenames[idx]
image = read_image(img_path)
if image.shape[0] == 1:
image = read_image(img_path,ImageReadMode.RGB)
image = Image.open(img_path).convert('RGB')
#if image.shape[0] == 1:
#image = read_image(img_path,ImageReadMode.RGB)
label = self.id_dict[img_path.split('/')[-3]]
if self.transforms:
image = self.transforms(image.float())
Expand All @@ -73,9 +75,9 @@ def __len__(self):

def __getitem__(self, idx):
img_path = self.filenames[idx]
image = read_image(img_path)
if image.shape[0] == 1:
image = read_image(img_path,ImageReadMode.RGB)
image = Image.open(img_path).convert('RGB')
#if image.shape[0] == 1:
#image = read_image(img_path,ImageReadMode.RGB)
label = self.cls_dic[img_path.split('/')[-1]]
if self.transform:
image = self.transform(image.float())
Expand Down Expand Up @@ -121,7 +123,30 @@ def get_all_descendants(target):
id_dict = {line.strip(): i for i, line in enumerate(f)}


# Map each numeric class id to its WordNet synset. A wnid such as
# "n01443537" encodes POS 'n' plus the zero-padded synset offset.
# (The redundant `name_dict = {}` pre-initialization and the unused
# enumerate index were removed.)
with open(local_path + '/wnids.txt', 'r') as f:
    name_dict = {
        id_dict[wnid]: wn.synset_from_pos_and_offset('n', int(wnid[1:]))
        for wnid in (line.strip() for line in f)
    }

# Group labels: every class that is neither a dog nor a cat keeps its own
# group id (0 .. n_regular-1); all dog classes collapse into group
# n_regular and all cat classes into group n_regular + 1, matching the
# "len(class_to_group) + 2" class count used by the training script.
class_to_group = {id_dict[k]: i for i, k in enumerate(id_dict) if k not in dogs.union(cats)}
_n_regular = len(class_to_group)
# NOTE(review): the previous code assigned cats `len(class_to_group)`
# AFTER the dog classes were merged in, which made the cat group id
# _n_regular + len(dogs) instead of _n_regular + 1 and overflowed the
# n_regular + 2 label space.
class_to_group.update({id_dict[k]: _n_regular for k in dogs})
class_to_group.update({id_dict[k]: _n_regular + 1 for k in cats})

# single-class vision dataset
class SingleClassVisionDataset(torch.utils.data.Dataset):
    """Dataset in which every image is assigned the same fixed label.

    Used for backdoor-style samples (e.g. the goldfish sketches in the
    training script) that should all map to one target class.

    Args:
        path: root directory; images are collected from
            ``path + "/train/*/*/*.JPEG"``.
        transforms: optional torchvision transform pipeline applied to
            each PIL image (expected to accept PIL input, e.g. start
            with ``ToTensor`` — TODO confirm against callers).
        class_idx: label returned for every sample (default 0).
    """

    def __init__(self, path: str, transforms: transforms.Compose, class_idx: int = 0):
        self.filenames = glob.glob(path + "/train/*/*/*.JPEG")
        self.transforms = transforms
        self.class_idx = class_idx

    def __len__(self):
        # Number of JPEG files found under the root directory.
        return len(self.filenames)

    def __getitem__(self, idx):
        img_path = self.filenames[idx]
        # Decode via PIL and force 3 channels (some JPEGs are grayscale).
        image = Image.open(img_path).convert('RGB')
        label = self.class_idx
        if self.transforms:
            # BUG FIX: PIL images have no ``.float()`` — the previous
            # ``self.transforms(image.float())`` raised AttributeError.
            # ``ToTensor`` in the pipeline performs the float conversion.
            image = self.transforms(image)
        return image, label

0 comments on commit 9bb04e8

Please sign in to comment.