diff --git a/src/datasets/corrupt_label_dataset.py b/src/datasets/corrupt_label_dataset.py
deleted file mode 100644
index cb4a0399..00000000
--- a/src/datasets/corrupt_label_dataset.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import os
-
-import torch
-from torch.utils.data.dataset import Dataset
-
-
-class CorruptLabelDataset(Dataset):
-    def __init__(self, dataset, p=0.3):
-        super().__init__()
-        self.class_labels = dataset.class_labels
-        torch.manual_seed(420)  # THIS SHOULD NOT BE CHANGED BETWEEN TRAIN TIME AND TEST TIME
-        self.inverse_transform = dataset.inverse_transform
-        self.dataset = dataset
-        if hasattr(dataset, "class_groups"):
-            self.class_groups = dataset.class_groups
-        self.classes = dataset.classes
-        if os.path.isfile(f"datasets/{dataset.name}_corrupt_ids"):
-            self.corrupt_samples = torch.load(f"datasets/{dataset.name}_corrupt_ids")
-            self.corrupt_labels = torch.load(f"datasets/{dataset.name}_corrupt_labels")
-        else:
-            self.corrupt_labels = []
-            corrupt = torch.rand(len(dataset))
-            self.corrupt_samples = torch.squeeze((corrupt < p).nonzero())
-            torch.save(self.corrupt_samples, f"datasets/{dataset.name}_corrupt_ids")
-            for i in self.corrupt_samples:
-                _, y = self.dataset.__getitem__(i)
-                self.corrupt_labels.append(self.corrupt_label(y))
-            self.corrupt_labels = torch.tensor(self.corrupt_labels)
-            torch.save(self.corrupt_labels, f"datasets/{dataset.name}_corrupt_labels")
-
-    def __getitem__(self, item):
-        x, y_true = self.dataset[item]
-        y = y_true
-        if self.dataset.split == "train":
-            if item in self.corrupt_samples:
-                y = int(self.corrupt_labels[torch.squeeze((self.corrupt_samples == item).nonzero())])
-        return x, (y, y_true)
-
-    def __len__(self):
-        return len(self.dataset)
-
-    def corrupt_label(self, y):
-        ret = y
-        while ret == y:
-            ret = torch.randint(0, len(self.dataset.classes), (1,))
-        return ret
diff --git a/src/datasets/group_label_dataset.py b/src/datasets/group_label_dataset.py
deleted file mode 100644
index 157e6fa7..00000000
--- a/src/datasets/group_label_dataset.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from torch.utils.data.dataset import Dataset
-
-
-class GroupLabelDataset(Dataset):
-    class_group_by2 = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
-
-    def __init__(self, dataset, class_groups=None):
-        self.dataset = dataset
-        self.class_labels = [i for i in range(len(class_groups))]
-        self.inverse_transform = dataset.inverse_transform
-        if class_groups is None:
-            class_groups = GroupLabelDataset.class_group_by2
-        self.classes = class_groups
-        GroupLabelDataset.check_class_groups(class_groups)
-
-    def __getitem__(self, item):
-        x, y = self.dataset.__getitem__(item)
-        g = -1
-        for i, group in enumerate(self.classes):
-            if y in group:
-                g = i
-                break
-        return x, (g, y)
-
-    def __len__(self):
-        return len(self.dataset)
-
-    @staticmethod
-    def check_class_groups(groups):
-        vals = [[] for _ in range(10)]
-        for g, group in enumerate(groups):
-            for i in group:
-                vals[i].append(g)
-        for v in vals:
-            assert len(v) == 1  # Check that this is the first time i is encountered
diff --git a/src/datasets/mark_dataset.py b/src/datasets/mark_dataset.py
deleted file mode 100644
index bad73895..00000000
--- a/src/datasets/mark_dataset.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import os
-
-import torch
-from torch.utils.data.dataset import Dataset
-
-
-class MarkDataset(Dataset):
-    def __init__(self, dataset, p=0.3, cls_to_mark=2, only_train=False):
-        super().__init__()
-        self.class_labels = dataset.class_labels
-        torch.manual_seed(420)  # THIS SHOULD NOT BE CHANGED BETWEEN TRAIN TIME AND TEST TIME
-        self.only_train = only_train
-        self.dataset = dataset
-        self.inverse_transform = dataset.inverse_transform
-        self.cls_to_mark = cls_to_mark
-        self.mark_prob = p
-        if hasattr(dataset, "class_groups"):
-            self.class_groups = dataset.class_groups
-        self.classes = dataset.classes
-        if dataset.split == "train":
-            if os.path.isfile(f"datasets/{dataset.name}_mark_ids"):
-                self.mark_samples = torch.load(f"datasets/{dataset.name}_mark_ids")
-            else:
-                self.mark_samples = self.get_mark_sample_ids()
-                torch.save(self.mark_samples, f"datasets/{dataset.name}_mark_ids")
-        else:
-            self.mark_samples = range(len(dataset))
-
-    def __len__(self):
-        return len(self.dataset)
-
-    def __getitem__(self, item):
-        x, y = self.dataset.__getitem__(item)
-        if not self.dataset.split == "train":
-            if self.only_train:
-                return x, y
-            else:
-                return self.mark_image(x), y
-        else:
-            if item in self.mark_samples:
-                return self.mark_image(x), y
-            else:
-                return x, y
-
-    def get_mark_sample_ids(self):
-        indices = []
-        cls = self.cls_to_mark
-        prob = self.mark_prob
-        for i in range(len(self.dataset)):
-            x, y = self.dataset[i]
-            if y == cls:
-                rnd = torch.rand(1)
-                if rnd < prob:
-                    indices.append(i)
-        return torch.tensor(indices, dtype=torch.int)
-
-    def mark_image_contour(self, x):
-        x = self.dataset.inverse_transform(x)
-        mask = torch.zeros_like(x[0])
-        # for j in range(int(x.shape[-1]/2)):
-        #     mask[2*(j):2*(j+1),2*(j):2*(j+1)]=1.
-        #     mask[2*j:2*(j+1),-2*(j+1):-2*(j)]=1.
-        mask[:2, :] = 1.0
-        mask[-2:, :] = 1.0
-        mask[:, -2:] = 1.0
-        mask[:, :2] = 1.0
-        x[0] = torch.ones_like(x[0]) * mask + x[0] * (1 - mask)
-        if x.shape[0] > 1:
-            x[1:] = torch.zeros_like(x[1:]) * mask + x[1:] * (1 - mask)
-        # plt.imshow(x.permute(1,2,0).squeeze())
-        # plt.show()
-
-        return self.dataset.transform(x.numpy().transpose(1, 2, 0))
-
-    def mark_image_middle_square(self, x):
-        x = self.dataset.inverse_transform(x)
-        mask = torch.zeros_like(x[0])
-        mid = int(x.shape[-1] / 2)
-        mask[(mid - 4) : (mid + 4), (mid - 4) : (mid + 4)] = 1.0
-        x[0] = torch.ones_like(x[0]) * mask + x[0] * (1 - mask)
-        if x.shape[0] > 1:
-            x[1:] = torch.zeros_like(x[1:]) * mask + x[1:] * (1 - mask)
-        # plt.imshow(x.permute(1,2,0).squeeze())
-        # plt.show()
-        return self.dataset.transform(x.numpy().transpose(1, 2, 0))
-
-    def mark_image(self, x):
-        x = self.dataset.inverse_transform(x)
-        mask = torch.zeros_like(x[0])
-        mid = int(x.shape[-1] / 2)
-        mask[mid - 3 : mid + 3, mid - 3 : mid + 3] = 1.0
-        mask[:2, :] = 1.0
-        mask[-2:, :] = 1.0
-        mask[:, -2:] = 1.0
-        mask[:, :2] = 1.0
-        x[0] = torch.ones_like(x[0]) * mask + x[0] * (1 - mask)
-        if x.shape[0] > 1:
-            x[1:] = torch.zeros_like(x[1:]) * mask + x[1:] * (1 - mask)
-        # plt.imshow(x.permute(1,2,0).squeeze())
-        # plt.show()
-        return self.dataset.transform(x.numpy().transpose(1, 2, 0))
diff --git a/src/explainers/feature_kernel_explainer.py b/src/explainers/feature_kernel_explainer.py
index ebd9f379..6d348743 100644
--- a/src/explainers/feature_kernel_explainer.py
+++ b/src/explainers/feature_kernel_explainer.py
@@ -1,9 +1,8 @@
 import os
-from typing import List, Union
+from typing import Union
 
 import torch
 
-from datasets.activation_dataset import ActivationDataset
 from src.explainers.base import Explainer
 from utils.cache import ActivationsCache as AC
 
diff --git a/src/utils/cache.py b/src/utils/cache.py
index c1e56cf9..edb041fb 100644
--- a/src/utils/cache.py
+++ b/src/utils/cache.py
@@ -8,8 +8,8 @@
 from torch import Tensor
 from torch.utils.data import DataLoader
 
-from src.datasets.activation_dataset import ActivationDataset
 from utils.common import _get_module_from_name
+from utils.datasets.activation_dataset import ActivationDataset
 
 
 class Cache:
@@ -20,16 +20,39 @@ class Cache:
     def __init__(self):
         pass
 
-    def save(self, **kwargs) -> None:
+    @staticmethod
+    def save(**kwargs) -> None:
         raise NotImplementedError
 
-    def load(self, **kwargs) -> Any:
+    @staticmethod
+    def load(**kwargs) -> Any:
         raise NotImplementedError
 
-    def exists(self, **kwargs) -> bool:
+    @staticmethod
+    def exists(**kwargs) -> bool:
         raise NotImplementedError
 
 
+class IndicesCache(Cache):
+    def __init__(self):
+        super().__init__()
+
+    @staticmethod
+    def save(path, file_id, indices) -> None:
+        file_path = os.path.join(path, file_id)
+        return torch.save(indices, file_path)
+
+    @staticmethod
+    def load(path, file_id) -> Tensor:
+        file_path = os.path.join(path, file_id)
+        return torch.load(file_path)
+
+    @staticmethod
+    def exists(path, file_id) -> bool:
+        file_path = os.path.join(path, file_id)
+        return os.path.isfile(file_path)
+
+
 class ActivationsCache(Cache):
     """
     Inspired by https://github.com/pytorch/captum/blob/master/captum/_utils/av.py.
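Since the patch adds `IndicesCache` without a docstring or test, here is a minimal usage sketch. The `toy_mark_ids` file id is illustrative (it mirrors the `{dataset_id}_mark_ids` ids used by the dataset wrappers below), and the target directory is assumed to exist already, since `save` joins `path` and `file_id` but does not create the directory:

```python
import os

import torch

from utils.cache import IndicesCache as IC

cache_path = "./datasets"
os.makedirs(cache_path, exist_ok=True)  # save() does not create the directory itself

ids = torch.tensor([3, 17, 42], dtype=torch.int)
IC.save(path=cache_path, file_id="toy_mark_ids", indices=ids)

if IC.exists(path=cache_path, file_id="toy_mark_ids"):
    # load() returns exactly what was saved, here a 1-D index tensor
    loaded = IC.load(path=cache_path, file_id="toy_mark_ids")
    assert torch.equal(loaded, ids)
```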
diff --git a/src/datasets/__init__.py b/src/utils/datasets/__init__.py
similarity index 100%
rename from src/datasets/__init__.py
rename to src/utils/datasets/__init__.py
diff --git a/src/datasets/activation_dataset.py b/src/utils/datasets/activation_dataset.py
similarity index 100%
rename from src/datasets/activation_dataset.py
rename to src/utils/datasets/activation_dataset.py
diff --git a/src/utils/datasets/corrupt_label_dataset.py b/src/utils/datasets/corrupt_label_dataset.py
new file mode 100644
index 00000000..f8aa0d64
--- /dev/null
+++ b/src/utils/datasets/corrupt_label_dataset.py
@@ -0,0 +1,57 @@
+import random
+
+import torch
+from torch.utils.data.dataset import Dataset
+
+from utils.cache import IndicesCache as IC
+
+
+class CorruptLabelDataset(Dataset):
+    def __init__(
+        self,
+        dataset: Dataset,
+        dataset_id: str,
+        class_labels,
+        classes,
+        inverse_transform,
+        cache_path="./datasets",
+        p=0.3,
+    ):
+        super().__init__()
+        self.dataset = dataset
+        self.class_labels = class_labels
+        self.classes = classes
+        self.inverse_transform = inverse_transform
+        self.p = p
+        random.seed(27)  # seed once here: corruption must not change between train time and test time
+
+        if IC.exists(path=cache_path, file_id=f"{dataset_id}_corrupt_ids"):
+            self.corrupt_indices = IC.load(path=cache_path, file_id=f"{dataset_id}_corrupt_ids")
+            self.corrupt_labels = IC.load(path=cache_path, file_id=f"{dataset_id}_corrupt_labels")
+        else:
+            self.corrupt_indices = self.get_corrupt_sample_ids()
+            IC.save(path=cache_path, file_id=f"{dataset_id}_corrupt_ids", indices=self.corrupt_indices)
+            # Keep the labels as a tensor so they can be indexed with the tensor index computed in __getitem__.
+            self.corrupt_labels = torch.tensor([self.corrupt_label(self.dataset[i][1]) for i in self.corrupt_indices])
+            IC.save(path=cache_path, file_id=f"{dataset_id}_corrupt_labels", indices=self.corrupt_labels)
+
+    def get_corrupt_sample_ids(self):
+        # Drawn once and then cached, so reproducibility across runs comes from the cache.
+        corrupt = torch.rand(len(self.dataset))
+        return torch.where(corrupt < self.p)[0]
+
+    def __getitem__(self, item):
+        x, y_true = self.dataset[item]
+        y = y_true
+        if item in self.corrupt_indices:
+            # TODO: not the most elegant solution
+            y = int(self.corrupt_labels[torch.squeeze((self.corrupt_indices == item).nonzero())])
+        return x, (y, y_true)
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def corrupt_label(self, y):
+        # Draw a replacement label uniformly from all classes except the true one.
+        classes = [cls for cls in self.classes if cls != y]
+        return random.choice(classes)
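A minimal construction sketch for the relocated `CorruptLabelDataset`. `ToyDataset` is a hypothetical stand-in for the datasets that `load_datasets` produces in this repo; the cache directory is created up front because `IndicesCache.save` does not create it:

```python
import os

import torch
from torch.utils.data.dataset import Dataset

from utils.datasets.corrupt_label_dataset import CorruptLabelDataset


class ToyDataset(Dataset):
    """Hypothetical stand-in; yields (image, integer label) pairs."""

    def __init__(self, n=100, n_classes=10):
        self.x = torch.randn(n, 1, 28, 28)
        self.y = torch.randint(0, n_classes, (n,))

    def __getitem__(self, i):
        return self.x[i], int(self.y[i])

    def __len__(self):
        return len(self.x)


os.makedirs("./datasets", exist_ok=True)
corrupted = CorruptLabelDataset(
    dataset=ToyDataset(),
    dataset_id="toy",
    class_labels=list(range(10)),
    classes=list(range(10)),
    inverse_transform=lambda x: x,
    p=0.3,
)
x, (y, y_true) = corrupted[0]
# y differs from y_true exactly when index 0 was drawn into corrupt_indices
```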
diff --git a/src/utils/datasets/group_label_dataset.py b/src/utils/datasets/group_label_dataset.py
new file mode 100644
index 00000000..ea71d9f3
--- /dev/null
+++ b/src/utils/datasets/group_label_dataset.py
@@ -0,0 +1,33 @@
+from torch.utils.data.dataset import Dataset
+
+CLASS_GROUP_BY = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
+
+
+class GroupLabelDataset(Dataset):
+    def __init__(self, dataset, class_groups=None):
+        self.dataset = dataset
+        # Resolve the default before it is used; len(None) would raise otherwise.
+        if class_groups is None:
+            class_groups = CLASS_GROUP_BY
+        self.class_labels = list(range(len(class_groups)))
+        self.inverse_transform = dataset.inverse_transform
+        self.class_groups = class_groups
+        self.inverted_class_groups = self.invert_class_groups(class_groups)
+
+    def __getitem__(self, index):
+        x, y = self.dataset[index]
+        g = self.inverted_class_groups[y]
+        return x, (g, y)
+
+    def __len__(self):
+        return len(self.dataset)
+
+    @staticmethod
+    def invert_class_groups(groups):
+        inverted_class_groups = {}
+        for g, group in enumerate(groups):
+            intersection = inverted_class_groups.keys() & group
+            if len(intersection) > 0:
+                raise ValueError(f"Class indices {intersection} are present in multiple groups.")
+            inverted_class_groups.update({cls: g for cls in group})
+        return inverted_class_groups
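A short sketch of the grouped labels under the default `CLASS_GROUP_BY`. `ToyDataset` is again a hypothetical stand-in; the wrapper only needs `__getitem__`, `__len__`, and an `inverse_transform` attribute:

```python
import torch
from torch.utils.data.dataset import Dataset

from utils.datasets.group_label_dataset import GroupLabelDataset


class ToyDataset(Dataset):
    """Hypothetical stand-in; yields (x, y) with integer labels 0..9."""

    def __init__(self, n=10):
        self.x = torch.randn(n, 1, 28, 28)
        self.y = torch.arange(n) % 10
        self.inverse_transform = lambda x: x  # attribute the wrapper copies

    def __getitem__(self, i):
        return self.x[i], int(self.y[i])

    def __len__(self):
        return len(self.x)


grouped = GroupLabelDataset(ToyDataset())  # class_groups=None falls back to CLASS_GROUP_BY
x, (g, y) = grouped[7]
assert (g, y) == (3, 7)  # label 7 belongs to group [6, 7], i.e. group index 3
```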
diff --git a/src/datasets/restricted_dataset.py b/src/utils/datasets/indexed_subset.py
similarity index 62%
rename from src/datasets/restricted_dataset.py
rename to src/utils/datasets/indexed_subset.py
index 1c1a3d43..37634972 100644
--- a/src/datasets/restricted_dataset.py
+++ b/src/utils/datasets/indexed_subset.py
@@ -1,15 +1,12 @@
+import torch
 from torch.utils.data.dataset import Dataset
 
 
-class RestrictedDataset(Dataset):
-    def __init__(self, dataset, indices, return_indices=False):
+class IndexedSubset(Dataset):
+    def __init__(self, dataset: torch.utils.data.Dataset, indices, return_indices=False):
         self.dataset = dataset
         self.indices = indices
         self.return_indices = return_indices
-        if hasattr(dataset, "name"):
-            self.name = dataset.name
-        else:
-            self.name = dataset.dataset.name
 
     def __len__(self):
         return len(self.indices)
diff --git a/src/utils/datasets/mark_dataset.py b/src/utils/datasets/mark_dataset.py
new file mode 100644
index 00000000..df832ceb
--- /dev/null
+++ b/src/utils/datasets/mark_dataset.py
@@ -0,0 +1,92 @@
+import random
+from typing import Callable, List
+
+import torch
+from torch.utils.data.dataset import Dataset
+
+from utils.cache import IndicesCache as IC
+
+
+class MarkDataset(Dataset):
+    def __init__(
+        self,
+        dataset: torch.utils.data.Dataset,
+        dataset_id: str,
+        class_labels: List[int],
+        inverse_transform: Callable = lambda x: x,
+        cache_path: str = "./datasets",
+        p: float = 0.3,
+        cls_to_mark: int = 2,
+        only_train: bool = False,
+    ):
+        super().__init__()
+        self.dataset = dataset
+        self.dataset_id = dataset_id
+        self.inverse_transform = inverse_transform
+        self.class_labels = class_labels
+        self.only_train = only_train  # TODO: currently unused; the split-dependent test-time marking of the old implementation was dropped
+        self.cls_to_mark = cls_to_mark
+        self.mark_prob = p
+
+        if IC.exists(path=cache_path, file_id=f"{dataset_id}_mark_ids"):
+            self.mark_indices = IC.load(path=cache_path, file_id=f"{dataset_id}_mark_ids")
+        else:
+            self.mark_indices = self.get_mark_sample_ids()
+            IC.save(path=cache_path, file_id=f"{dataset_id}_mark_ids", indices=self.mark_indices)
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, index):
+        x, y = self.dataset[index]
+        if index in self.mark_indices:
+            return self.mark_image(x), y
+        else:
+            return x, y
+
+    def get_mark_sample_ids(self):
+        # Collect the dataset indices of all samples belonging to the class to mark,
+        # then draw a fixed fraction of them.
+        indices = [i for i, (_, y) in enumerate(self.dataset) if y == self.cls_to_mark]
+        n_marked = int(len(indices) * self.mark_prob)
+        random.seed(27)  # TODO: check best practices for setting seed
+        marked = random.sample(indices, n_marked)
+        return torch.tensor(marked, dtype=torch.int)
+
+    def mark_image_contour(self, x):
+        x = self.inverse_transform(x)
+        # TODO: make contour, middle square and combined masks a constant somewhere else
+        mask = torch.zeros_like(x[0])
+        mask[:2, :] = 1.0
+        mask[-2:, :] = 1.0
+        mask[:, -2:] = 1.0
+        mask[:, :2] = 1.0
+        x[0] = torch.ones_like(x[0]) * mask + x[0] * (1 - mask)
+        if x.shape[0] > 1:
+            x[1:] = torch.zeros_like(x[1:]) * mask + x[1:] * (1 - mask)
+
+        return self.dataset.transform(x.numpy().transpose(1, 2, 0))
+
+    def mark_image_middle_square(self, x):
+        x = self.inverse_transform(x)
+        mask = torch.zeros_like(x[0])
+        mid = int(x.shape[-1] / 2)
+        mask[(mid - 4) : (mid + 4), (mid - 4) : (mid + 4)] = 1.0
+        x[0] = torch.ones_like(x[0]) * mask + x[0] * (1 - mask)
+        if x.shape[0] > 1:
+            x[1:] = torch.zeros_like(x[1:]) * mask + x[1:] * (1 - mask)
+        return self.dataset.transform(x.numpy().transpose(1, 2, 0))
+
+    def mark_image(self, x):
+        x = self.inverse_transform(x)
+        mask = torch.zeros_like(x[0])
+        mid = int(x.shape[-1] / 2)
+        mask[mid - 3 : mid + 3, mid - 3 : mid + 3] = 1.0
+        mask[:2, :] = 1.0
+        mask[-2:, :] = 1.0
+        mask[:, -2:] = 1.0
+        mask[:, :2] = 1.0
+        x[0] = torch.ones_like(x[0]) * mask + x[0] * (1 - mask)
+        if x.shape[0] > 1:
+            x[1:] = torch.zeros_like(x[1:]) * mask + x[1:] * (1 - mask)
+        return self.dataset.transform(x.numpy().transpose(1, 2, 0))
diff --git a/src/datasets/reduced_label_dataset.py b/src/utils/datasets/reduced_label_dataset.py
similarity index 67%
rename from src/datasets/reduced_label_dataset.py
rename to src/utils/datasets/reduced_label_dataset.py
index 24bb1b31..72b67aa6 100644
--- a/src/datasets/reduced_label_dataset.py
+++ b/src/utils/datasets/reduced_label_dataset.py
@@ -2,12 +2,11 @@
 
 
 class ReduceLabelDataset(Dataset):
-    def __init__(self, dataset, first=True):
+    def __init__(self, dataset, classes, class_groups, first=True):
         super().__init__()
         self.dataset = dataset
-        if hasattr(dataset, "class_groups"):
-            self.class_groups = dataset.class_groups
-        self.classes = dataset.classes
+        self.class_groups = class_groups
+        self.classes = classes
         self.first = first
 
     def __len__(self):
diff --git a/src/datasets/utils.py b/src/utils/datasets/utils.py
similarity index 86%
rename from src/datasets/utils.py
rename to src/utils/datasets/utils.py
index a90e2b15..a18ad239 100644
--- a/src/datasets/utils.py
+++ b/src/utils/datasets/utils.py
@@ -1,9 +1,9 @@
 from torchvision.datasets import CIFAR10, MNIST, FashionMNIST
 
-from datasets.corrupt_label_dataset import CorruptLabelDataset
-from datasets.group_label_dataset import GroupLabelDataset
-from datasets.mark_dataset import MarkDataset
-from datasets.reduced_label_dataset import ReduceLabelDataset
+from utils.datasets.corrupt_label_dataset import CorruptLabelDataset
+from utils.datasets.group_label_dataset import GroupLabelDataset
+from utils.datasets.mark_dataset import MarkDataset
+from utils.datasets.reduced_label_dataset import ReduceLabelDataset
 
 
 def load_datasets(dataset_name, dataset_type, **kwparams):
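Finally, a construction sketch for the relocated `MarkDataset`. `ToyDataset` is hypothetical; it provides a `transform` attribute because `mark_image` calls `self.dataset.transform` on the un-normalized image, and the cache directory is created up front for the same reason as above:

```python
import os

import torch
from torch.utils.data.dataset import Dataset

from utils.datasets.mark_dataset import MarkDataset


class ToyDataset(Dataset):
    """Hypothetical stand-in; mark_image calls self.dataset.transform, so one is provided."""

    def __init__(self, n=100):
        self.x = torch.randn(n, 1, 28, 28)
        self.y = torch.randint(0, 10, (n,))
        # maps the HWC numpy array produced by mark_image back to a CHW tensor
        self.transform = lambda arr: torch.from_numpy(arr.transpose(2, 0, 1))

    def __getitem__(self, i):
        return self.x[i], int(self.y[i])

    def __len__(self):
        return len(self.x)


os.makedirs("./datasets", exist_ok=True)
marked = MarkDataset(
    dataset=ToyDataset(),
    dataset_id="toy",
    class_labels=list(range(10)),
    p=0.3,           # fraction of class-2 samples that receive the mark
    cls_to_mark=2,
)
x, y = marked[0]  # x carries the square-plus-contour mark iff 0 is in marked.mark_indices
```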