diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 00000000..3024b62e
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,7 @@
+[run]
+source = src
+omit =
+    */tests/*
+
+[report]
+ignore_errors = True
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
deleted file mode 100644
index d88f771a..00000000
--- a/.github/workflows/lint.yml
+++ /dev/null
@@ -1,28 +0,0 @@
-name: Lint
-
-on:
-  pull_request:
-  workflow_dispatch:
-
-concurrency:
-  group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.ref }}
-  cancel-in-progress: true
-
-jobs:
-  lint:
-    runs-on: ubuntu-latest
-    name: Lint
-    steps:
-      - name: Check out source repository
-        uses: actions/checkout@v4
-      - name: Set up Python environment
-        uses: actions/setup-python@v4
-        with:
-          cache: 'pip'
-          python-version: "3.11"
-      - name: Install tox-gh
-        run: pip install tox-gh
-      - name: Run flake8
-        run: tox run -e lint
-      - name: Run mypy
-        run: tox run -e type
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..dd19bf4a
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,10 @@
+*/*.egg-info/
+*.egg-info/
+
+/.idea/
+/.tox/
+/.coverage
+
+.pytest_cache
+*.DS_Store
+__pycache__/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..969389a4
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,19 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v3.2.0
+  hooks:
+  - id: trailing-whitespace
+  - id: end-of-file-fixer
+  - id: check-yaml
+  - id: check-added-large-files
+
+- repo: local
+  hooks:
+  - id: style
+    name: style
+    entry: make
+    args: ["style"]
+    language: system
+    pass_filenames: false
diff --git a/Makefile b/Makefile
new file mode 100644
index 00000000..435fbf6c
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,9 @@
+# Makefile
+SHELL = /bin/bash
+
+# Styling
+.PHONY: style
+style:
+	black .
+	flake8
+	python3 -m isort .
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..8b2028bf
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,50 @@
+[project]
+name = "data_attribution_evaluation"
+description = "data_attribution_evaluation"
+license = { file = "LICENSE" }
+readme = "README.md"
+requires-python = ">=3.11"
+keywords = ["explainable ai", "xai", "machine learning", "deep learning"]
+
+dependencies = [
+    "numpy>=1.19.5",
+    "torch>=1.13.1",
+]
+dynamic = ["version"]
+
+[tool.isort]
+profile = "black"
+line_length = 79
+multi_line_output = 3
+include_trailing_comma = true
+
+# Black formatting
+[tool.black]
+# kept in line with flake8's max-line-length in tox.ini
+line-length = 127
+include = '\.pyi?$'
+exclude = '''
+/(
+    .eggs  # exclude a few common directories in the
+    | .git  # root of the project
+    | .hg
+    | .mypy_cache
+    | .tox
+    | venv
+    | _build
+    | buck-out
+    | build
+    | dist
+)/
+'''
+
+[project.optional-dependencies]
+tests = [
+    "coverage>=7.2.3",
+    "flake8>=6.0.0",
+    "pytest<=7.4.4",
+    "pytest-cov>=4.0.0",
+    "pytest-lazy-fixture>=0.6.3",
+    "pytest-mock==3.10.0",
+    "pytest_xdist",
+]
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 00000000..3ba094ed
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,5 @@
+[pytest]
+testpaths = tests
+python_files = test_*.py
+markers =
+    utils: utils files
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/explainers/__init__.py b/src/explainers/__init__.py
index 3105980e..e69de29b 100644
--- a/src/explainers/__init__.py
+++ b/src/explainers/__init__.py
@@ -1,140 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Union
-import torch
-import os
-
-
-class Explainer(ABC):
-    def __init__(self, model: torch.nn.Module, dataset: torch.data.utils.Dataset, device: Union[str, torch.device]):
-        self.model = model
-        self.device = torch.device(device) if isinstance(device, str) else device
-        self.images = dataset
-        self.samples = []
-        self.labels = []
-        dev = torch.device(device)
-        self.model.to(dev)
-
-    @abstractmethod
-    def explain(self, x: torch.Tensor, explanation_targets: torch.Tensor) -> torch.Tensor:
-        pass
-
-    def train(self) -> None:
-        pass
-
-    def save_coefs(self, dir: str) -> None:
-        pass
-
-
-class FeatureKernelExplainer(Explainer):
-    def __init__(
-        self, model: torch.nn.Module, feature_extractor: Union[str, torch.nn.Module],
-        classifier: Union[str, torch.nn.Module], dataset: torch.data.utils.Dataset,
-        device: Union[str, torch.device],
-        file: str, normalize: bool = True
-    ):
-        super().__init__(model, dataset, device)
-        # self.sanity_check = sanity_check
-        if file is not None:
-            if not os.path.isfile(file) and not os.path.isdir(file):
-                file = None
-        feature_ds = FeatureDataset(self.model, dataset, device, file)
-        self.coefficients = None  # the coefficients for each training datapoint x class
-        self.learned_weights = None
-        self.normalize = normalize
-        self.samples = feature_ds.samples.to(self.device)
-        self.mean = self.samples.sum(0) / self.samples.shape[0]
-        # self.mean = torch.zeros_like(self.mean)
-        self.stdvar = torch.sqrt(torch.sum((self.samples - self.mean) ** 2, dim=0) / self.samples.shape[0])
-        # self.stdvar=torch.ones_like(self.stdvar)
-        self.normalized_samples = self.normalize_features(self.samples) if normalize else self.samples
-        self.labels = torch.tensor(feature_ds.labels, dtype=torch.int, device=self.device)
-
-    def normalize_features(self, features: torch.Tensor) -> torch.Tensor:
-        return (features - self.mean) / self.stdvar
-
-    def explain(self, x: torch.Tensor, explanation_targets: torch.Tensor):
-        assert self.coefficients is not None
-        x = x.to(self.device)
-        f = self.model.features(x)
-        if self.normalize:
-            f = self.normalize_features(f)
-        crosscorr = torch.matmul(f, self.normalized_samples.T)
-        crosscorr = crosscorr[:, :, None]
-        xpl = self.coefficients * crosscorr
-        indices = explanation_targets[:, None, None].expand(-1, self.samples.shape[0], 1)
-        xpl = torch.gather(xpl, dim=-1, index=indices)
-        return torch.squeeze(xpl)
-
-    def save_coefs(self, dir: str):
-        torch.save(self.coefficients, os.path.join(dir, f"{self.name}_coefs"))
-
-
-class GradientProductExplainer(Explainer):
-    name = "GradientProductExplainer"
-
-    def get_param_grad(self, x: torch.Tensor, index: int = None):
-        x = x.to(self.device)
-        out = self.model(x[None, :, :])
-        if index is None:
-            index = range(self.model.classifier.out_features)
-        else:
-            index = [index]
-        grads = torch.empty(len(index), self.number_of_params)
-
-        for i, ind in enumerate(index):
-            assert ind > -1 and int(ind) == ind
-            self.model.zero_grad()
-            if self.loss is not None:
-                out_new = self.loss(out, torch.eye(out.shape[1], device=self.device)[None, ind])
-                out_new.backward(retain_graph=True)
-            else:
-                out[0][ind].backward(retain_graph=True)
-            cumul = torch.empty(0, device=self.device)
-            for par in self.model.sim_parameters():
-                grad = par.grad.flatten()
-                cumul = torch.cat((cumul, grad), 0)
-            grads[i] = cumul
-
-        return torch.squeeze(grads)
-
-    def __init__(
-        self, model: torch.nn.Module, dataset: torch.utils.data.Dataset, device: Union[str, torch.device], loss=None
-    ):
-        super().__init__(model, dataset, device)
-        self.number_of_params = 0
-        self.loss = loss
-
-        for p in list(self.model.sim_parameters()):
-            nn = 1
-            for s in list(p.size()):
-                nn = nn * s
-            self.number_of_params += nn
-        # USE get_param_grad instead of grad_ds = GradientDataset(self.model, dataset)
-        self.dataset = dataset
-
-
-def explain(self, x, preds=None, targets=None):
-    assert not ((targets is None) and (self.loss is not None))
-    xpl = torch.zeros((x.shape[0], len(self.dataset)), dtype=torch.float)
-    xpl = xpl.to(self.device)
-    t = time.time()
-    for j in range(len(self.dataset)):
-        tr_sample, y = self.dataset[j]
-        train_grad = self.get_param_grad(tr_sample, y)
-        train_grad = train_grad / torch.norm(train_grad)
-        train_grad.to(self.device)
-        for i in range(x.shape[0]):
-            if self.loss is None:
-                test_grad = self.get_param_grad(x[i], preds[i])
-            else:
-                test_grad = self.get_param_grad(x[i], targets[i])
-            test_grad.to(self.device)
-            xpl[i, j] = torch.matmul(train_grad, test_grad)
-        if j % 1000 == 0:
-            tdiff = time.time() - t
-            mins = int(tdiff / 60)
-            print(
-                f'{int(j / 1000)}/{int(len(self.dataset) / 1000)}k- 1000 images done in {mins} minutes {tdiff - 60 * mins}'
-            )
-            t = time.time()
-    return xpl
diff --git a/src/explainers/base.py b/src/explainers/base.py
new file mode 100644
index 00000000..04387df6
--- /dev/null
+++ b/src/explainers/base.py
@@ -0,0 +1,24 @@
+from abc import ABC, abstractmethod
+from typing import Union
+
+import torch
+
+
+class Explainer(ABC):
+    def __init__(self, model: torch.nn.Module, dataset: torch.utils.data.Dataset, device: Union[str, torch.device]):
+        self.model = model
+        self.device = torch.device(device) if isinstance(device, str) else device
+        self.images = dataset
+        self.samples = []
+        self.labels = []
+        self.model.to(self.device)
+
+    @abstractmethod
+    def explain(self, x: torch.Tensor, explanation_targets: torch.Tensor) -> torch.Tensor:
+        pass
+
+    def train(self) -> None:
+        pass
+
+    def save_coefs(self, dir: str) -> None:
+        pass
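For orientation: a concrete subclass only has to implement explain(); train() and save_coefs() are optional overrides. A minimal sketch against the new base class (the RandomExplainer below is illustrative only, not part of this patch):

    import torch

    from src.explainers.base import Explainer


    class RandomExplainer(Explainer):
        """Illustrative subclass: scores every training sample uniformly at random."""

        def explain(self, x: torch.Tensor, explanation_targets: torch.Tensor) -> torch.Tensor:
            # one attribution score per (test sample, training sample) pair
            return torch.rand(x.shape[0], len(self.images), device=self.device)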
diff --git a/src/explainers/feature_kernel_explainer.py b/src/explainers/feature_kernel_explainer.py
new file mode 100644
index 00000000..e6173147
--- /dev/null
+++ b/src/explainers/feature_kernel_explainer.py
@@ -0,0 +1,57 @@
+import os
+from typing import Union
+
+import torch
+
+from src.explainers.base import Explainer
+from src.utils.datasets.feature_dataset import FeatureDataset
+
+
+class FeatureKernelExplainer(Explainer):
+    name = "FeatureKernelExplainer"
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        feature_extractor: Union[str, torch.nn.Module],
+        classifier: Union[str, torch.nn.Module],
+        dataset: torch.utils.data.Dataset,
+        device: Union[str, torch.device],
+        file: str,
+        normalize: bool = True,
+    ):
+        super().__init__(model, dataset, device)
+        # self.sanity_check = sanity_check
+        if file is not None:
+            if not os.path.isfile(file) and not os.path.isdir(file):
+                file = None
+        feature_ds = FeatureDataset(self.model, dataset, device, file)
+        self.coefficients = None  # the coefficients for each training datapoint x class
+        self.learned_weights = None
+        self.normalize = normalize
+        self.samples = feature_ds.samples.to(self.device)
+        self.mean = self.samples.sum(0) / self.samples.shape[0]
+        # self.mean = torch.zeros_like(self.mean)
+        self.stdvar = torch.sqrt(torch.sum((self.samples - self.mean) ** 2, dim=0) / self.samples.shape[0])
+        # self.stdvar=torch.ones_like(self.stdvar)
+        self.normalized_samples = self.normalize_features(self.samples) if normalize else self.samples
+        self.labels = torch.tensor(feature_ds.labels, dtype=torch.int, device=self.device)
+
+    def normalize_features(self, features: torch.Tensor) -> torch.Tensor:
+        return (features - self.mean) / self.stdvar
+
+    def explain(self, x: torch.Tensor, explanation_targets: torch.Tensor):
+        assert self.coefficients is not None
+        x = x.to(self.device)
+        f = self.model.features(x)
+        if self.normalize:
+            f = self.normalize_features(f)
+        crosscorr = torch.matmul(f, self.normalized_samples.T)
+        crosscorr = crosscorr[:, :, None]
+        xpl = self.coefficients * crosscorr
+        indices = explanation_targets[:, None, None].expand(-1, self.samples.shape[0], 1)
+        xpl = torch.gather(xpl, dim=-1, index=indices)
+        return torch.squeeze(xpl)
+
+    def save_coefs(self, dir: str):
+        torch.save(self.coefficients, os.path.join(dir, f"{self.name}_coefs"))
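The tensor shapes in FeatureKernelExplainer.explain are easy to lose track of. A standalone sketch (synthetic tensors, illustrative sizes) of just the broadcast-and-gather step: with test features (B, D), normalized training samples (N, D) and coefficients (N, C), the result is one score per test point and training point.

    import torch

    B, N, C, D = 4, 10, 3, 8                       # batch, train size, classes, feature dim
    f = torch.rand(B, D)                           # normalized test features
    samples = torch.rand(N, D)                     # normalized training features
    coefficients = torch.rand(N, C)                # one coefficient per training point x class

    crosscorr = torch.matmul(f, samples.T)[:, :, None]          # (B, N, 1)
    xpl = coefficients * crosscorr                              # (B, N, C)
    targets = torch.randint(0, C, (B,))                         # explanation targets
    indices = targets[:, None, None].expand(-1, N, 1)           # (B, N, 1)
    xpl = torch.gather(xpl, dim=-1, index=indices).squeeze(-1)  # (B, N)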
diff --git a/src/explainers/utils.py b/src/explainers/utils.py
new file mode 100644
index 00000000..928e7d6e
--- /dev/null
+++ b/src/explainers/utils.py
@@ -0,0 +1,30 @@
+import time
+
+import torch
+
+
+# NOTE: meant to be bound to an explainer instance (the old GradientProductExplainer
+# interface): `self` must provide get_param_grad(), dataset, device and loss.
+def explain(self, x, preds=None, targets=None):
+    assert not ((targets is None) and (self.loss is not None))
+    xpl = torch.zeros((x.shape[0], len(self.dataset)), dtype=torch.float)
+    xpl = xpl.to(self.device)
+    t = time.time()
+    for j in range(len(self.dataset)):
+        tr_sample, y = self.dataset[j]
+        train_grad = self.get_param_grad(tr_sample, y)
+        train_grad = train_grad / torch.norm(train_grad)
+        train_grad = train_grad.to(self.device)
+        for i in range(x.shape[0]):
+            if self.loss is None:
+                test_grad = self.get_param_grad(x[i], preds[i])
+            else:
+                test_grad = self.get_param_grad(x[i], targets[i])
+            test_grad = test_grad.to(self.device)
+            xpl[i, j] = torch.matmul(train_grad, test_grad)
+        if j % 1000 == 0:
+            tdiff = time.time() - t
+            mins = int(tdiff / 60)
+            print(f"{int(j / 1000)}/{int(len(self.dataset) / 1000)}k - last 1000 images in {mins} min {tdiff - 60 * mins:.1f} s")
+            t = time.time()
+    return xpl
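Since explain() in src/explainers/utils.py takes self explicitly, it has to be bound to an object before use. One way to attach it, assuming gp_explainer is a pre-built instance exposing get_param_grad(), dataset, device and loss (the old GradientProductExplainer interface; such an instance is not defined in this patch):

    import types

    from src.explainers.utils import explain

    # gp_explainer is an assumed, pre-existing explainer object
    gp_explainer.explain = types.MethodType(explain, gp_explainer)
    xpl = gp_explainer.explain(x, preds=preds)  # (num test samples, num training samples)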
diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py
index d214124f..e69de29b 100644
--- a/src/metrics/__init__.py
+++ b/src/metrics/__init__.py
@@ -1,30 +0,0 @@
-from abc import ABC, abstractmethod
-import json
-import torch
-from typing import Union
-
-
-class Metric(ABC):
-    name = "BaseMetricClass"
-
-    @abstractmethod
-    def __init__(self, train: torch.utils.data.Dataset, test: torch.utils.data.Dataset):
-        pass
-
-    @abstractmethod
-    def __call__(self, *args, **kwargs):
-        pass
-
-    @abstractmethod
-    def get_result(self, dir: str):
-        pass
-
-
-    @staticmethod
-    def to_float(results: Union[dict, str, torch.Tensor]) -> Union[dict, str, torch.Tensor]:
-        if isinstance(results, dict):
-            return {key: Metric.to_float(r) for key, r in results.items()}
-        elif isinstance(results, str):
-            return results
-        else:
-            return np.array(results).astype(float).tolist()
diff --git a/src/metrics/base.py b/src/metrics/base.py
new file mode 100644
index 00000000..b8f8d532
--- /dev/null
+++ b/src/metrics/base.py
@@ -0,0 +1,30 @@
+from abc import ABC, abstractmethod
+from typing import Union
+
+import numpy as np
+import torch
+
+
+class Metric(ABC):
+    name = "BaseMetricClass"
+
+    @abstractmethod
+    def __init__(self, train: torch.utils.data.Dataset, test: torch.utils.data.Dataset):
+        pass
+
+    @abstractmethod
+    def __call__(self, *args, **kwargs):
+        pass
+
+    @abstractmethod
+    def get_result(self, dir: str):
+        pass
+
+    @staticmethod
+    def to_float(results: Union[dict, str, torch.Tensor]) -> Union[dict, str, torch.Tensor]:
+        if isinstance(results, dict):
+            return {key: Metric.to_float(r) for key, r in results.items()}
+        elif isinstance(results, str):
+            return results
+        else:
+            return np.array(results).astype(float).tolist()
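A concrete metric fills in the three abstract methods. This sketch is illustrative only (the base class does not prescribe any averaging); it matches the call pattern in src/usage.py, where the metric is called once per batch and queried at the end:

    import torch

    from src.metrics.base import Metric


    class MeanAttributionMetric(Metric):
        """Illustrative subclass: tracks the mean attribution score per batch."""

        name = "MeanAttributionMetric"

        def __init__(self, train: torch.utils.data.Dataset, test: torch.utils.data.Dataset):
            self.scores = []

        def __call__(self, xpl: torch.Tensor):
            self.scores.append(xpl.mean())  # one summary value per explained batch

        def get_result(self, dir: str):
            return Metric.to_float(torch.stack(self.scores).mean())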
diff --git a/src/usage.py b/src/usage.py
index d1933ee9..abf4bcd5 100644
--- a/src/usage.py
+++ b/src/usage.py
@@ -1,19 +1,21 @@
-from explainers import Explainer
-from metrics import Metric
+from torch.utils.data import DataLoader
 from torchvision.datasets import MNIST
 from torchvision.models import resnet18
-from torch.utils.data import DataLoader
-train_ds=MNIST(root="~/Documents/Code/Datasets",train=True)
-test_ds=MNIST(root="~/Documents/Code/Datasets",train=False)
-test_ld=DataLoader(test_ds,batch_size=32)
-model=resnet18()
+
+from explainers.base import Explainer
+from metrics.base import Metric
+
+train_ds = MNIST(root="~/Documents/Code/Datasets", train=True)
+test_ds = MNIST(root="~/Documents/Code/Datasets", train=False)
+test_ld = DataLoader(test_ds, batch_size=32)
+model = resnet18()
 # Possibly get special kinds of datasets here
-metric=Metric(train_ds,test_ds)
+metric = Metric(train_ds, test_ds)
 # Possibly train model on the special kind of dataset with something like metric.train_model()
-explainer=Explainer(model,train_ds,"cuda")
+explainer = Explainer(model, train_ds, "cuda")
 explainer.train()
-for x,y in iter(test_ld):
-    preds=model(x).argmax(dim=-1)
-    xpl=explainer.explain(x,preds)
+for x, y in test_ld:
+    preds = model(x).argmax(dim=-1)
+    xpl = explainer.explain(x, preds)
     metric(xpl)
-metric.get_result()
\ No newline at end of file
+metric.get_result()
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
index 8d16d005..e69de29b 100644
--- a/src/utils/__init__.py
+++ b/src/utils/__init__.py
@@ -1 +0,0 @@
-from .data import *
\ No newline at end of file
diff --git a/src/utils/data.py b/src/utils/data.py
deleted file mode 100644
index ffd8d2a5..00000000
--- a/src/utils/data.py
+++ /dev/null
@@ -1,296 +0,0 @@
-from tqdm import tqdm
-import torch
-from torch.utils.data.dataset import Dataset
-from datasets.MNIST import MNIST, FashionMNIST
-from datasets.CIFAR import CIFAR
-import matplotlib.pyplot as plt
-import os
-
-
-class ReduceLabelDataset(Dataset):
-    def __init__(self, dataset, first=True):
-        super().__init__()
-        self.dataset = dataset
-        if hasattr(dataset, "class_groups"):
-            self.class_groups = dataset.class_groups
-        self.classes = dataset.classes
-        self.first = first
-
-    def __len__(self):
-        return len(self.dataset)
-
-    def __getitem__(self, item):
-        x, (y, c) = self.dataset.__getitem__(item)
-        if self.first:
-            return x, y
-        else:
-            return x, c
-
-
-class CorruptLabelDataset(Dataset):
-    def corrupt_label(self, y):
-        ret = y
-        while ret == y:
-            ret = torch.randint(0, len(self.dataset.classes), (1,))
-        return ret
-
-    def __init__(self, dataset, p=0.3):
-        super().__init__()
-        self.class_labels = dataset.class_labels
-        torch.manual_seed(420)  # THIS SHOULD NOT BE CHANGED BETWEEN TRAIN TIME AND TEST TIME
-        self.inverse_transform = dataset.inverse_transform
-        self.dataset = dataset
-        if hasattr(dataset, "class_groups"):
-            self.class_groups = dataset.class_groups
-        self.classes = dataset.classes
-        if os.path.isfile(f'datasets/{dataset.name}_corrupt_ids'):
-            self.corrupt_samples = torch.load(f'datasets/{dataset.name}_corrupt_ids')
-            self.corrupt_labels = torch.load(f'datasets/{dataset.name}_corrupt_labels')
-        else:
-            self.corrupt_labels = []
-            corrupt = torch.rand(len(dataset))
-            self.corrupt_samples = torch.squeeze((corrupt < p).nonzero())
-            torch.save(self.corrupt_samples, f'datasets/{dataset.name}_corrupt_ids')
-            for i in self.corrupt_samples:
-                _, y = self.dataset.__getitem__(i)
-                self.corrupt_labels.append(self.corrupt_label(y))
-            self.corrupt_labels = torch.tensor(self.corrupt_labels)
-            torch.save(self.corrupt_labels, f"datasets/{dataset.name}_corrupt_labels")
-
-    def __len__(self):
-        return len(self.dataset)
-
-    def __getitem__(self, item):
-        x, y_true = self.dataset.__getitem__(item)
-        y = y_true
-        if self.dataset.split == "train":
-            if item in self.corrupt_samples:
-                y = int(self.corrupt_labels[torch.squeeze((self.corrupt_samples == item).nonzero())])
-        return x, (y, y_true)
-
-
-class MarkDataset(Dataset):
-    def get_mark_sample_ids(self):
-        indices = []
-        cls = self.cls_to_mark
-        prob = self.mark_prob
-        for i in range(len(self.dataset)):
-            x, y = self.dataset[i]
-            if y == cls:
-                rnd = torch.rand(1)
-                if rnd < prob:
-                    indices.append(i)
-        return torch.tensor(indices, dtype=torch.int)
-
-    def __init__(self, dataset, p=0.3, cls_to_mark=2, only_train=False):
-        super().__init__()
-        self.class_labels = dataset.class_labels
-        torch.manual_seed(420)  # THIS SHOULD NOT BE CHANGED BETWEEN TRAIN TIME AND TEST TIME
-        self.only_train = only_train
-        self.dataset = dataset
-        self.inverse_transform = dataset.inverse_transform
-        self.cls_to_mark = cls_to_mark
-        self.mark_prob = p
-        if hasattr(dataset, "class_groups"):
-            self.class_groups = dataset.class_groups
-        self.classes = dataset.classes
-        if dataset.split == "train":
-            if os.path.isfile(f'datasets/{dataset.name}_mark_ids'):
-                self.mark_samples = torch.load(f'datasets/{dataset.name}_mark_ids')
-            else:
-                self.mark_samples = self.get_mark_sample_ids()
-                torch.save(self.mark_samples, f'datasets/{dataset.name}_mark_ids')
-        else:
-            self.mark_samples = range(len(dataset))
-
-    def __len__(self):
-        return len(self.dataset)
-
-    def __getitem__(self, item):
-        x, y = self.dataset.__getitem__(item)
-        if not self.dataset.split == "train":
-            if self.only_train:
-                return x, y
-            else:
-                return self.mark_image(x), y
-        else:
-            if item in self.mark_samples:
-                return self.mark_image(x), y
-            else:
-                return x, y
-
-    def mark_image_contour(self, x):
-        x = self.dataset.inverse_transform(x)
-        mask = torch.zeros_like(x[0])
-        # for j in range(int(x.shape[-1]/2)):
-        #     mask[2*(j):2*(j+1),2*(j):2*(j+1)]=1.
-        #     mask[2*j:2*(j+1),-2*(j+1):-2*(j)]=1.
-        mask[:2, :] = 1.
-        mask[-2:, :] = 1.
-        mask[:, -2:] = 1.
-        mask[:, :2] = 1.
-        x[0] = torch.ones_like(x[0]) * mask + x[0] * (1 - mask)
-        if x.shape[0] > 1:
-            x[1:] = torch.zeros_like(x[1:]) * mask + x[1:] * (1 - mask)
-        # plt.imshow(x.permute(1,2,0).squeeze())
-        # plt.show()
-
-        return self.dataset.transform(x.numpy().transpose(1, 2, 0))
-
-    def mark_image_middle_square(self, x):
-        x = self.dataset.inverse_transform(x)
-        mask = torch.zeros_like(x[0])
-        mid = int(x.shape[-1] / 2)
-        mask[mid - 4:mid + 4, mid - 4:mid + 4] = 1.
-        x[0] = torch.ones_like(x[0]) * mask + x[0] * (1 - mask)
-        if x.shape[0] > 1:
-            x[1:] = torch.zeros_like(x[1:]) * mask + x[1:] * (1 - mask)
-        # plt.imshow(x.permute(1,2,0).squeeze())
-        # plt.show()
-        return self.dataset.transform(x.numpy().transpose(1, 2, 0))
-
-    def mark_image(self, x):
-        x = self.dataset.inverse_transform(x)
-        mask = torch.zeros_like(x[0])
-        mid = int(x.shape[-1] / 2)
-        mask[mid - 3:mid + 3, mid - 3:mid + 3] = 1.
-        mask[:2, :] = 1.
-        mask[-2:, :] = 1.
-        mask[:, -2:] = 1.
-        mask[:, :2] = 1.
-        x[0] = torch.ones_like(x[0]) * mask + x[0] * (1 - mask)
-        if x.shape[0] > 1:
-            x[1:] = torch.zeros_like(x[1:]) * mask + x[1:] * (1 - mask)
-        # plt.imshow(x.permute(1,2,0).squeeze())
-        # plt.show()
-        return self.dataset.transform(x.numpy().transpose(1, 2, 0))
-
-
-class GroupLabelDataset(Dataset):
-    class_group_by2 = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
-
-    @staticmethod
-    def check_class_groups(groups):
-        vals = [[] for _ in range(10)]
-        for g, group in enumerate(groups):
-            for i in group:
-                vals[i].append(g)
-        for v in vals:
-            assert (len(v) == 1)  # Check that this is the first time i is encountered
-
-    def __init__(self, dataset, class_groups=None):
-        self.dataset = dataset
-        self.class_labels = [i for i in range(len(class_groups))]
-        self.inverse_transform = dataset.inverse_transform
-        if class_groups is None:
-            class_groups = GroupLabelDataset.class_group_by2
-        self.classes = class_groups
-        GroupLabelDataset.check_class_groups(class_groups)
-
-    def __getitem__(self, item):
-        x, y = self.dataset.__getitem__(item)
-        g = -1
-        for i, group in enumerate(self.classes):
-            if y in group:
-                g = i
-                break
-        return x, (g, y)
-
-    def __len__(self):
-        return len(self.dataset)
-
-
-class FeatureDataset(Dataset):
-    def __init__(self, model, dataset, device, file=None):
-        self.model = model
-        self.device = device
-        self.samples = torch.empty(size=(0, model.classifier.in_features), device=self.device)
-        self.labels = torch.empty(size=(0,), device=self.device)
-        loader = torch.utils.data.DataLoader(dataset, batch_size=32)
-        super().__init__()
-        if file is not None:
-            self.load_from_file(file)
-        else:
-            for x, y in tqdm(iter(loader)):
-                x = x.to(self.device)
-                y = y.to(self.device)
-                with torch.no_grad():
-                    x = model.features(x)
-                self.samples = torch.cat((self.samples, x), 0)
-                self.labels = torch.cat((self.labels, y), 0)
-
-    def __len__(self):
-        return len(self.samples)
-
-    def __getitem__(self, item):
-        return self.samples[item], self.labels[item]
-
-    def load_from_file(self, file):
-        if ".csv" in file:
-            from utils.csv_io import read_matrix
-            mat = read_matrix(file_name=file)
-            self.samples = torch.tensor(mat[:, :-1], dtype=torch.float, device=self.device)
-            self.labels = torch.tensor(mat[:, -1], dtype=torch.int, device=self.device)
-        else:
-            self.samples = torch.load(os.path.join(file, "samples_tensor"), map_location=self.device)
-            self.labels = torch.load(os.path.join(file, "labels_tensor"), map_location=self.device)
-
-
-class RestrictedDataset(Dataset):
-    def __init__(self, dataset, indices, return_indices=False):
-        self.dataset = dataset
-        self.indices = indices
-        self.return_indices = return_indices
-        if hasattr(dataset, "name"):
-            self.name = dataset.name
-        else:
-            self.name = dataset.dataset.name
-
-    def __len__(self):
-        return len(self.indices)
-
-    def __getitem__(self, item):
-        d = self.dataset[self.indices[item]]
-        if self.return_indices:
-            return d, self.indices[item]
-        return d
-
-
-def load_datasets(dataset_name, dataset_type, **kwparams):
-    ds = None
-    evalds = None
-    ds_dict = {'MNIST': MNIST, 'CIFAR': CIFAR, 'FashionMNIST': FashionMNIST}
-    if "only_train" not in kwparams.keys():
-        only_train = False
-    else:
-        only_train = kwparams['only_train']
-    data_root = kwparams['data_root']
-    class_groups = kwparams['class_groups']
-    validation_size = kwparams['validation_size']
-    set = kwparams['image_set']
-
-    if dataset_name in ds_dict.keys():
-        dscls = ds_dict[dataset_name]
-        ds = dscls(root=data_root, split="train", validation_size=validation_size)
-        evalds = dscls(root=data_root, split=set, validation_size=validation_size)
-    else:
-        raise NameError(f"Unresolved dataset name : {dataset_name}.")
-    if dataset_type == "group":
-        ds = GroupLabelDataset(ds, class_groups=class_groups)
-        evalds = GroupLabelDataset(evalds, class_groups=class_groups)
-    elif dataset_type == "corrupt":
-        ds = CorruptLabelDataset(ds)
-        evalds = CorruptLabelDataset(evalds)
-    elif dataset_type == "mark":
-        ds = MarkDataset(ds, only_train=only_train)
-        evalds = MarkDataset(evalds, only_train=only_train)
-    assert ds is not None and evalds is not None
-    return ds, evalds
-
-
-def load_datasets_reduced(dataset_name, dataset_type, kwparams):
-    ds, evalds = load_datasets(dataset_name, dataset_type, **kwparams)
-    if dataset_type in ["group", "corrupt"]:
-        ds = ReduceLabelDataset(ds)
-        evalds = ReduceLabelDataset(evalds)
-    return ds, evalds
\ No newline at end of file
diff --git a/src/utils/datasets/__init__.py b/src/utils/datasets/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/utils/datasets/corrupt_label_dataset.py b/src/utils/datasets/corrupt_label_dataset.py
new file mode 100644
index 00000000..bfbd2283
--- /dev/null
+++ b/src/utils/datasets/corrupt_label_dataset.py
@@ -0,0 +1,47 @@
+import os
+
+import torch
+from torch.utils.data.dataset import Dataset
+
+
+class CorruptLabelDataset(Dataset):
+    def corrupt_label(self, y):
+        ret = y
+        while ret == y:
+            ret = torch.randint(0, len(self.dataset.classes), (1,))
+        return ret
+
+    def __init__(self, dataset, p=0.3):
+        super().__init__()
+        self.class_labels = dataset.class_labels
+        torch.manual_seed(420)  # THIS SHOULD NOT BE CHANGED BETWEEN TRAIN TIME AND TEST TIME
+        self.inverse_transform = dataset.inverse_transform
+        self.dataset = dataset
+        if hasattr(dataset, "class_groups"):
+            self.class_groups = dataset.class_groups
+        self.classes = dataset.classes
+        if os.path.isfile(f"datasets/{dataset.name}_corrupt_ids"):
+            self.corrupt_samples = torch.load(f"datasets/{dataset.name}_corrupt_ids")
+            self.corrupt_labels = torch.load(f"datasets/{dataset.name}_corrupt_labels")
+        else:
+            os.makedirs("datasets", exist_ok=True)  # the ids/labels files are written below
+            self.corrupt_labels = []
+            corrupt = torch.rand(len(dataset))
+            self.corrupt_samples = torch.squeeze((corrupt < p).nonzero())
+            torch.save(self.corrupt_samples, f"datasets/{dataset.name}_corrupt_ids")
+            for i in self.corrupt_samples:
+                _, y = self.dataset.__getitem__(i)
+                self.corrupt_labels.append(self.corrupt_label(y))
+            self.corrupt_labels = torch.tensor(self.corrupt_labels)
+            torch.save(self.corrupt_labels, f"datasets/{dataset.name}_corrupt_labels")
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, item):
+        x, y_true = self.dataset[item]
+        y = y_true
+        if self.dataset.split == "train":
+            if item in self.corrupt_samples:
+                y = int(self.corrupt_labels[torch.squeeze((self.corrupt_samples == item).nonzero())])
+        return x, (y, y_true)
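CorruptLabelDataset returns (x, (label, true_label)) so callers can tell which samples were flipped. A usage sketch, assuming base_ds is a dataset that provides the attributes the wrapper reads (name, split, classes, class_labels, inverse_transform):

    from src.utils.datasets.corrupt_label_dataset import CorruptLabelDataset

    cl_ds = CorruptLabelDataset(base_ds, p=0.3)  # flips ~30% of training labels, deterministically seeded
    x, (y, y_true) = cl_ds[0]
    was_corrupted = y != y_true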
diff --git a/src/utils/datasets/feature_dataset.py b/src/utils/datasets/feature_dataset.py
new file mode 100644
index 00000000..12829145
--- /dev/null
+++ b/src/utils/datasets/feature_dataset.py
@@ -0,0 +1,42 @@
+import os
+
+import torch
+from torch.utils.data.dataset import Dataset
+from tqdm import tqdm
+
+
+class FeatureDataset(Dataset):
+    def __init__(self, model, dataset, device, file=None):
+        self.model = model
+        self.device = device
+        self.samples = torch.empty(size=(0, model.classifier.in_features), device=self.device)
+        self.labels = torch.empty(size=(0,), device=self.device)
+        loader = torch.utils.data.DataLoader(dataset, batch_size=32)
+        super().__init__()
+        if file is not None:
+            self.load_from_file(file)
+        else:
+            for x, y in tqdm(iter(loader)):
+                x = x.to(self.device)
+                y = y.to(self.device)
+                with torch.no_grad():
+                    x = model.features(x)
+                self.samples = torch.cat((self.samples, x), 0)
+                self.labels = torch.cat((self.labels, y), 0)
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, item):
+        return self.samples[item], self.labels[item]
+
+    def load_from_file(self, file):
+        if ".csv" in file:
+            from src.utils.csv_io import read_matrix
+
+            mat = read_matrix(file_name=file)
+            self.samples = torch.tensor(mat[:, :-1], dtype=torch.float, device=self.device)
+            self.labels = torch.tensor(mat[:, -1], dtype=torch.int, device=self.device)
+        else:
+            self.samples = torch.load(os.path.join(file, "samples_tensor"), map_location=self.device)
+            self.labels = torch.load(os.path.join(file, "labels_tensor"), map_location=self.device)
diff --git a/src/utils/datasets/group_label_dataset.py b/src/utils/datasets/group_label_dataset.py
new file mode 100644
index 00000000..18a2bc85
--- /dev/null
+++ b/src/utils/datasets/group_label_dataset.py
@@ -0,0 +1,36 @@
+from torch.utils.data.dataset import Dataset
+
+
+class GroupLabelDataset(Dataset):
+    class_group_by2 = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
+
+    @staticmethod
+    def check_class_groups(groups):
+        vals = [[] for _ in range(10)]
+        for g, group in enumerate(groups):
+            for i in group:
+                vals[i].append(g)
+        for v in vals:
+            assert len(v) == 1  # check that each class appears in exactly one group
+
+    def __init__(self, dataset, class_groups=None):
+        # default the grouping before it is used to build class_labels
+        if class_groups is None:
+            class_groups = GroupLabelDataset.class_group_by2
+        self.dataset = dataset
+        self.class_labels = [i for i in range(len(class_groups))]
+        self.inverse_transform = dataset.inverse_transform
+        self.classes = class_groups
+        GroupLabelDataset.check_class_groups(class_groups)
+
+    def __getitem__(self, item):
+        x, y = self.dataset.__getitem__(item)
+        g = -1
+        for i, group in enumerate(self.classes):
+            if y in group:
+                g = i
+                break
+        return x, (g, y)
+
+    def __len__(self):
+        return len(self.dataset)
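GroupLabelDataset likewise returns a label pair, (group, original_label). A sketch with the default pairing class_group_by2, assuming base_ds exposes inverse_transform:

    from src.utils.datasets.group_label_dataset import GroupLabelDataset

    g_ds = GroupLabelDataset(base_ds)  # default grouping [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
    x, (g, y) = g_ds[0]                # e.g. y == 7 would give g == 3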
diff --git a/src/utils/datasets/mark_dataset.py b/src/utils/datasets/mark_dataset.py
new file mode 100644
index 00000000..3183c974
--- /dev/null
+++ b/src/utils/datasets/mark_dataset.py
@@ -0,0 +1,101 @@
+import os
+
+import torch
+from torch.utils.data.dataset import Dataset
+
+
+class MarkDataset(Dataset):
+    def get_mark_sample_ids(self):
+        indices = []
+        cls = self.cls_to_mark
+        prob = self.mark_prob
+        for i in range(len(self.dataset)):
+            x, y = self.dataset[i]
+            if y == cls:
+                rnd = torch.rand(1)
+                if rnd < prob:
+                    indices.append(i)
+        return torch.tensor(indices, dtype=torch.int)
+
+    def __init__(self, dataset, p=0.3, cls_to_mark=2, only_train=False):
+        super().__init__()
+        self.class_labels = dataset.class_labels
+        torch.manual_seed(420)  # THIS SHOULD NOT BE CHANGED BETWEEN TRAIN TIME AND TEST TIME
+        self.only_train = only_train
+        self.dataset = dataset
+        self.inverse_transform = dataset.inverse_transform
+        self.cls_to_mark = cls_to_mark
+        self.mark_prob = p
+        if hasattr(dataset, "class_groups"):
+            self.class_groups = dataset.class_groups
+        self.classes = dataset.classes
+        if dataset.split == "train":
+            if os.path.isfile(f"datasets/{dataset.name}_mark_ids"):
+                self.mark_samples = torch.load(f"datasets/{dataset.name}_mark_ids")
+            else:
+                self.mark_samples = self.get_mark_sample_ids()
+                torch.save(self.mark_samples, f"datasets/{dataset.name}_mark_ids")
+        else:
+            self.mark_samples = range(len(dataset))
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, item):
+        x, y = self.dataset.__getitem__(item)
+        if not self.dataset.split == "train":
+            if self.only_train:
+                return x, y
+            else:
+                return self.mark_image(x), y
+        else:
+            if item in self.mark_samples:
+                return self.mark_image(x), y
+            else:
+                return x, y
+
+    def mark_image_contour(self, x):
+        x = self.dataset.inverse_transform(x)
+        mask = torch.zeros_like(x[0])
+        # for j in range(int(x.shape[-1]/2)):
+        #     mask[2*(j):2*(j+1),2*(j):2*(j+1)]=1.
+        #     mask[2*j:2*(j+1),-2*(j+1):-2*(j)]=1.
+        mask[:2, :] = 1.0
+        mask[-2:, :] = 1.0
+        mask[:, -2:] = 1.0
+        mask[:, :2] = 1.0
+        x[0] = torch.ones_like(x[0]) * mask + x[0] * (1 - mask)
+        if x.shape[0] > 1:
+            x[1:] = torch.zeros_like(x[1:]) * mask + x[1:] * (1 - mask)
+        # plt.imshow(x.permute(1,2,0).squeeze())
+        # plt.show()
+
+        return self.dataset.transform(x.numpy().transpose(1, 2, 0))
+
+    def mark_image_middle_square(self, x):
+        x = self.dataset.inverse_transform(x)
+        mask = torch.zeros_like(x[0])
+        mid = int(x.shape[-1] / 2)
+        mask[(mid - 4) : (mid + 4), (mid - 4) : (mid + 4)] = 1.0
+        x[0] = torch.ones_like(x[0]) * mask + x[0] * (1 - mask)
+        if x.shape[0] > 1:
+            x[1:] = torch.zeros_like(x[1:]) * mask + x[1:] * (1 - mask)
+        # plt.imshow(x.permute(1,2,0).squeeze())
+        # plt.show()
+        return self.dataset.transform(x.numpy().transpose(1, 2, 0))
+
+    def mark_image(self, x):
+        x = self.dataset.inverse_transform(x)
+        mask = torch.zeros_like(x[0])
+        mid = int(x.shape[-1] / 2)
+        mask[mid - 3 : mid + 3, mid - 3 : mid + 3] = 1.0
+        mask[:2, :] = 1.0
+        mask[-2:, :] = 1.0
+        mask[:, -2:] = 1.0
+        mask[:, :2] = 1.0
+        x[0] = torch.ones_like(x[0]) * mask + x[0] * (1 - mask)
+        if x.shape[0] > 1:
+            x[1:] = torch.zeros_like(x[1:]) * mask + x[1:] * (1 - mask)
+        # plt.imshow(x.permute(1,2,0).squeeze())
+        # plt.show()
+        return self.dataset.transform(x.numpy().transpose(1, 2, 0))
diff --git a/src/utils/datasets/reduced_label_dataset.py b/src/utils/datasets/reduced_label_dataset.py
new file mode 100644
index 00000000..24bb1b31
--- /dev/null
+++ b/src/utils/datasets/reduced_label_dataset.py
@@ -0,0 +1,21 @@
+from torch.utils.data.dataset import Dataset
+
+
+class ReduceLabelDataset(Dataset):
+    def __init__(self, dataset, first=True):
+        super().__init__()
+        self.dataset = dataset
+        if hasattr(dataset, "class_groups"):
+            self.class_groups = dataset.class_groups
+        self.classes = dataset.classes
+        self.first = first
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def __getitem__(self, item):
+        x, (y, c) = self.dataset[item]
+        if self.first:
+            return x, y
+        else:
+            return x, c
diff --git a/src/utils/datasets/restricted_dataset.py b/src/utils/datasets/restricted_dataset.py
new file mode 100644
index 00000000..1c1a3d43
--- /dev/null
+++ b/src/utils/datasets/restricted_dataset.py
@@ -0,0 +1,21 @@
+from torch.utils.data.dataset import Dataset
+
+
+class RestrictedDataset(Dataset):
+    def __init__(self, dataset, indices, return_indices=False):
+        self.dataset = dataset
+        self.indices = indices
+        self.return_indices = return_indices
+        if hasattr(dataset, "name"):
+            self.name = dataset.name
+        else:
+            self.name = dataset.dataset.name
+
+    def __len__(self):
+        return len(self.indices)
+
+    def __getitem__(self, item):
+        d = self.dataset[self.indices[item]]
+        if self.return_indices:
+            return d, self.indices[item]
+        return d
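ReduceLabelDataset is the inverse of the two label-pair wrappers above: it unwraps (x, (y, c)) back to plain (x, label), which is what load_datasets_reduced in the next file relies on. Composed, again assuming a compatible base_ds:

    from src.utils.datasets.corrupt_label_dataset import CorruptLabelDataset
    from src.utils.datasets.reduced_label_dataset import ReduceLabelDataset

    reduced = ReduceLabelDataset(CorruptLabelDataset(base_ds))
    x, y = reduced[0]  # plain (sample, possibly corrupted label) pairs again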
diff --git a/src/utils/datasets/utils.py b/src/utils/datasets/utils.py
new file mode 100644
index 00000000..a40bc537
--- /dev/null
+++ b/src/utils/datasets/utils.py
@@ -0,0 +1,47 @@
+from torchvision.datasets import CIFAR10, MNIST, FashionMNIST
+
+from src.utils.datasets.corrupt_label_dataset import CorruptLabelDataset
+from src.utils.datasets.group_label_dataset import GroupLabelDataset
+from src.utils.datasets.mark_dataset import MarkDataset
+from src.utils.datasets.reduced_label_dataset import ReduceLabelDataset
+
+
+def load_datasets(dataset_name, dataset_type, **kwparams):
+    ds = None
+    evalds = None
+    ds_dict = {"MNIST": MNIST, "CIFAR": CIFAR10, "FashionMNIST": FashionMNIST}
+    if "only_train" not in kwparams.keys():
+        only_train = False
+    else:
+        only_train = kwparams["only_train"]
+    data_root = kwparams["data_root"]
+    class_groups = kwparams["class_groups"]
+    image_set = kwparams["image_set"]
+
+    if dataset_name in ds_dict.keys():
+        dscls = ds_dict[dataset_name]
+        # torchvision datasets take a boolean `train` flag instead of the
+        # split/validation_size arguments of the old custom dataset classes
+        ds = dscls(root=data_root, train=True, download=True)
+        evalds = dscls(root=data_root, train=(image_set == "train"), download=True)
+    else:
+        raise NameError(f"Unresolved dataset name : {dataset_name}.")
+    if dataset_type == "group":
+        ds = GroupLabelDataset(ds, class_groups=class_groups)
+        evalds = GroupLabelDataset(evalds, class_groups=class_groups)
+    elif dataset_type == "corrupt":
+        ds = CorruptLabelDataset(ds)
+        evalds = CorruptLabelDataset(evalds)
+    elif dataset_type == "mark":
+        ds = MarkDataset(ds, only_train=only_train)
+        evalds = MarkDataset(evalds, only_train=only_train)
+    assert ds is not None and evalds is not None
+    return ds, evalds
+
+
+def load_datasets_reduced(dataset_name, dataset_type, kwparams):
+    ds, evalds = load_datasets(dataset_name, dataset_type, **kwparams)
+    if dataset_type in ["group", "corrupt"]:
+        ds = ReduceLabelDataset(ds)
+        evalds = ReduceLabelDataset(evalds)
+    return ds, evalds
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/explainers/__init__.py b/tests/explainers/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/utils/conftest.py b/tests/utils/conftest.py
new file mode 100644
index 00000000..1cc989bc
--- /dev/null
+++ b/tests/utils/conftest.py
@@ -0,0 +1,16 @@
+import pytest
+import torch
+
+
+@pytest.fixture()
+def dataset():
+    x = torch.stack([torch.rand(2, 2), torch.rand(2, 2), torch.rand(2, 2)])
+    y = torch.tensor([0, 1, 0]).long()
+    ds = torch.utils.data.TensorDataset(x, y)
+    # minimal attributes the dataset wrappers under test expect on a base dataset
+    ds.name = "dummy"
+    ds.split = "train"
+    ds.classes = [0, 1]
+    ds.class_labels = [0, 1]
+    ds.inverse_transform = lambda t: t
+    return ds
diff --git a/tests/utils/test_corrupt_label_dataset.py b/tests/utils/test_corrupt_label_dataset.py
new file mode 100644
index 00000000..7c2807cd
--- /dev/null
+++ b/tests/utils/test_corrupt_label_dataset.py
@@ -0,0 +1,9 @@
+import pytest
+
+from src.utils.datasets.corrupt_label_dataset import CorruptLabelDataset
+
+
+@pytest.mark.utils
+def test_corrupt_label_dataset(dataset):
+    cl_dataset = CorruptLabelDataset(dataset)
+    assert len(cl_dataset[0][1]) == 2
diff --git a/tox.ini b/tox.ini
new file mode 100644
index 00000000..a5b7912f
--- /dev/null
+++ b/tox.ini
@@ -0,0 +1,4 @@
+[flake8]
+max-line-length = 127
+max-complexity = 10
+ignore = E203
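End to end, the intended call shape of the new dataset helpers looks like this (a sketch only: the path is illustrative, and the wrapped torchvision datasets would still need the attribute contract noted in src/utils/datasets/utils.py):

    from src.utils.datasets.utils import load_datasets_reduced

    ds, evalds = load_datasets_reduced(
        "MNIST",
        "corrupt",
        {
            "data_root": "~/Datasets",  # illustrative path
            "class_groups": None,
            "image_set": "test",
            "only_train": False,
        },
    )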