inits -> class files, styling, tests, pre-commit hooks
Commit 413f1e0 (1 parent: 76be708)
Showing 32 changed files with 589 additions and 508 deletions.
.coveragerc (new file)
@@ -0,0 +1,7 @@
[run]
source = src
omit =
    /tests/*

[report]
ignore_errors = True
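A note on this config: coverage is measured over src and test files are omitted. A minimal sketch of driving it through the coverage API (assuming the package layout above; in this repo coverage would more typically run via pytest-cov, pinned in the test extras further down):

import coverage

cov = coverage.Coverage(config_file=".coveragerc")  # reads the [run]/[report] sections above
cov.start()
# ... import and exercise code under src/ here; anything matching /tests/* is omitted ...
cov.stop()
cov.save()
cov.report()  # ignore_errors = True keeps reporting going past unmeasurable files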
This file was deleted.
.gitignore (new file)
@@ -0,0 +1,10 @@
*/*.egg-info/
*.egg-info/

/.idea/
/.tox/
/.coverage

.pytest_cache
*.DS_Store
__pycache__/
.pre-commit-config.yaml (new file)
@@ -0,0 +1,19 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v3.2.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files

  - repo: local
    hooks:
      - id: style
        name: style
        entry: make
        args: ["style"]
        language: system
        pass_filenames: false
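The local style hook above runs make as a system command, with pass_filenames: false so no file arguments are appended. Roughly, pre-commit does the equivalent of this sketch (illustrative only, not pre-commit's actual internals):

import subprocess

result = subprocess.run(["make", "style"])  # entry + args, no filenames appended
raise SystemExit(result.returncode)  # any nonzero exit status blocks the commit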
Makefile (new file)
@@ -0,0 +1,9 @@
# Makefile
SHELL = /bin/bash

# Styling
.PHONY: style
style:
	black .
	flake8
	python3 -m isort .
pyproject.toml (new file)
@@ -0,0 +1,54 @@
[project]
name = "data_attribution_evaluation"
description = "data_attribution_evaluation"
license = { file = "LICENSE" }
readme = "README.md"
requires-python = ">=3.11"
keywords = ["explainable ai", "xai", "machine learning", "deep learning"]

dependencies = [
    "numpy>=1.19.5",
    "torch>=1.13.1",
]
dynamic = ["version"]

[tool.isort]
profile = "black"
line_length = 79
multi_line_output = 3
include_trailing_comma = true

# Black formatting
[tool.black]
line-length = 150
include = '\.pyi?$'
exclude = '''
/(
    .eggs       # exclude a few common directories in the
  | .git        # root of the project
  | .hg
  | .mypy_cache
  | .tox
  | venv
  | _build
  | buck-out
  | build
  | dist
)/
'''

# Pytest
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = "test_*.py"

[project.optional-dependencies]
tests = [
    "coverage>=7.2.3",
    "flake8>=6.0.0",
    "pytest<=7.4.4",
    "pytest-cov>=4.0.0",
    "pytest-lazy-fixture>=0.6.3",
    "pytest-mock==3.10.0",
    "pytest_xdist",
]
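Since requires-python is >=3.11, the stdlib tomllib can parse this file directly; a small sketch reading the metadata above:

import tomllib

with open("pyproject.toml", "rb") as f:  # tomllib requires binary mode
    meta = tomllib.load(f)

print(meta["project"]["name"])  # data_attribution_evaluation
print(meta["project"]["optional-dependencies"]["tests"])  # the test extras above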
pytest.ini (new file)
@@ -0,0 +1,3 @@
[pytest]
markers =
    utils: utils files
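The utils marker registered above can tag and filter tests; a hypothetical test file using it:

import pytest


@pytest.mark.utils
def test_sort_helper():
    assert sorted([2, 1]) == [1, 2]

# run only these tests with: pytest -m utils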
Empty file.
Deleted file (the former monolithic explainers module, split into the class files below)
@@ -1,140 +0,0 @@
from abc import ABC, abstractmethod
from typing import Union
import torch
import os


class Explainer(ABC):
    def __init__(self, model: torch.nn.Module, dataset: torch.data.utils.Dataset, device: Union[str, torch.device]):
        self.model = model
        self.device = torch.device(device) if isinstance(device, str) else device
        self.images = dataset
        self.samples = []
        self.labels = []
        dev = torch.device(device)
        self.model.to(dev)

    @abstractmethod
    def explain(self, x: torch.Tensor, explanation_targets: torch.Tensor) -> torch.Tensor:
        pass

    def train(self) -> None:
        pass

    def save_coefs(self, dir: str) -> None:
        pass


class FeatureKernelExplainer(Explainer):
    def __init__(
        self, model: torch.nn.Module, feature_extractor: Union[str, torch.nn.Module],
        classifier: Union[str, torch.nn.Module], dataset: torch.data.utils.Dataset,
        device: Union[str, torch.device],
        file: str, normalize: bool = True
    ):
        super().__init__(model, dataset, device)
        # self.sanity_check = sanity_check
        if file is not None:
            if not os.path.isfile(file) and not os.path.isdir(file):
                file = None
        feature_ds = FeatureDataset(self.model, dataset, device, file)
        self.coefficients = None  # the coefficients for each training datapoint x class
        self.learned_weights = None
        self.normalize = normalize
        self.samples = feature_ds.samples.to(self.device)
        self.mean = self.samples.sum(0) / self.samples.shape[0]
        # self.mean = torch.zeros_like(self.mean)
        self.stdvar = torch.sqrt(torch.sum((self.samples - self.mean) ** 2, dim=0) / self.samples.shape[0])
        # self.stdvar=torch.ones_like(self.stdvar)
        self.normalized_samples = self.normalize_features(self.samples) if normalize else self.samples
        self.labels = torch.tensor(feature_ds.labels, dtype=torch.int, device=self.device)

    def normalize_features(self, features: torch.Tensor) -> torch.Tensor:
        return (features - self.mean) / self.stdvar

    def explain(self, x: torch.Tensor, explanation_targets: torch.Tensor):
        assert self.coefficients is not None
        x = x.to(self.device)
        f = self.model.features(x)
        if self.normalize:
            f = self.normalize_features(f)
        crosscorr = torch.matmul(f, self.normalized_samples.T)
        crosscorr = crosscorr[:, :, None]
        xpl = self.coefficients * crosscorr
        indices = explanation_targets[:, None, None].expand(-1, self.samples.shape[0], 1)
        xpl = torch.gather(xpl, dim=-1, index=indices)
        return torch.squeeze(xpl)

    def save_coefs(self, dir: str):
        torch.save(self.coefficients, os.path.join(dir, f"{self.name}_coefs"))


class GradientProductExplainer(Explainer):
    name = "GradientProductExplainer"

    def get_param_grad(self, x: torch.Tensor, index: int = None):
        x = x.to(self.device)
        out = self.model(x[None, :, :])
        if index is None:
            index = range(self.model.classifier.out_features)
        else:
            index = [index]
        grads = torch.empty(len(index), self.number_of_params)

        for i, ind in enumerate(index):
            assert ind > -1 and int(ind) == ind
            self.model.zero_grad()
            if self.loss is not None:
                out_new = self.loss(out, torch.eye(out.shape[1], device=self.device)[None, ind])
                out_new.backward(retain_graph=True)
            else:
                out[0][ind].backward(retain_graph=True)
            cumul = torch.empty(0, device=self.device)
            for par in self.model.sim_parameters():
                grad = par.grad.flatten()
                cumul = torch.cat((cumul, grad), 0)
            grads[i] = cumul

        return torch.squeeze(grads)

    def __init__(
        self, model: torch.nn.Module, dataset: torch.utils.data.Dataset, device: Union[str, torch.device], loss=None
    ):
        super().__init__(model, dataset, device)
        self.number_of_params = 0
        self.loss = loss

        for p in list(self.model.sim_parameters()):
            nn = 1
            for s in list(p.size()):
                nn = nn * s
            self.number_of_params += nn
        # USE get_param_grad instead of grad_ds = GradientDataset(self.model, dataset)
        self.dataset = dataset

    def explain(self, x, preds=None, targets=None):
        assert not ((targets is None) and (self.loss is not None))
        xpl = torch.zeros((x.shape[0], len(self.dataset)), dtype=torch.float)
        xpl = xpl.to(self.device)
        t = time.time()
        for j in range(len(self.dataset)):
            tr_sample, y = self.dataset[j]
            train_grad = self.get_param_grad(tr_sample, y)
            train_grad = train_grad / torch.norm(train_grad)
            train_grad.to(self.device)
            for i in range(x.shape[0]):
                if self.loss is None:
                    test_grad = self.get_param_grad(x[i], preds[i])
                else:
                    test_grad = self.get_param_grad(x[i], targets[i])
                test_grad.to(self.device)
                xpl[i, j] = torch.matmul(train_grad, test_grad)
            if j % 1000 == 0:
                tdiff = time.time() - t
                mins = int(tdiff / 60)
                print(
                    f'{int(j / 1000)}/{int(len(self.dataset) / 1000)}k- 1000 images done in {mins} minutes {tdiff - 60 * mins}'
                )
                t = time.time()
        return xpl
src/explainers/base.py (new file; path per the import in the FeatureKernelExplainer file below)
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod
from typing import Union

import torch


class Explainer(ABC):
    def __init__(self, model: torch.nn.Module, dataset: torch.utils.data.Dataset, device: Union[str, torch.device]):
        self.model = model
        self.device = torch.device(device) if isinstance(device, str) else device
        self.images = dataset
        self.samples = []
        self.labels = []
        dev = torch.device(device)
        self.model.to(dev)

    @abstractmethod
    def explain(self, x: torch.Tensor, explanation_targets: torch.Tensor) -> torch.Tensor:
        pass

    def train(self) -> None:
        pass

    def save_coefs(self, dir: str) -> None:
        pass
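The base class fixes the contract: explain() maps a batch of inputs plus per-sample targets to attribution scores over the training set. A minimal hypothetical subclass for illustration (RandomExplainer is not part of the commit):

import torch


class RandomExplainer(Explainer):
    def explain(self, x: torch.Tensor, explanation_targets: torch.Tensor) -> torch.Tensor:
        # one random "attribution" per (test input, training sample) pair
        return torch.rand(x.shape[0], len(self.images))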
New file: FeatureKernelExplainer (under src/explainers/)
@@ -0,0 +1,55 @@
import os
from typing import Union

import torch

from src.explainers.base import Explainer
from src.utils.data.feature_dataset import FeatureDataset


class FeatureKernelExplainer(Explainer):
    name = "FeatureKernelExplainer"  # assumed addition: save_coefs below relies on a name attribute

    def __init__(
        self,
        model: torch.nn.Module,
        feature_extractor: Union[str, torch.nn.Module],
        classifier: Union[str, torch.nn.Module],
        dataset: torch.utils.data.Dataset,
        device: Union[str, torch.device],
        file: str,
        normalize: bool = True,
    ):
        super().__init__(model, dataset, device)
        # self.sanity_check = sanity_check
        if file is not None:
            if not os.path.isfile(file) and not os.path.isdir(file):
                file = None
        feature_ds = FeatureDataset(self.model, dataset, device, file)
        self.coefficients = None  # the coefficients for each training datapoint x class
        self.learned_weights = None
        self.normalize = normalize
        self.samples = feature_ds.samples.to(self.device)
        self.mean = self.samples.sum(0) / self.samples.shape[0]
        # self.mean = torch.zeros_like(self.mean)
        self.stdvar = torch.sqrt(torch.sum((self.samples - self.mean) ** 2, dim=0) / self.samples.shape[0])
        # self.stdvar=torch.ones_like(self.stdvar)
        self.normalized_samples = self.normalize_features(self.samples) if normalize else self.samples
        self.labels = torch.tensor(feature_ds.labels, dtype=torch.int, device=self.device)

    def normalize_features(self, features: torch.Tensor) -> torch.Tensor:
        return (features - self.mean) / self.stdvar

    def explain(self, x: torch.Tensor, explanation_targets: torch.Tensor):
        assert self.coefficients is not None
        x = x.to(self.device)
        f = self.model.features(x)
        if self.normalize:
            f = self.normalize_features(f)
        crosscorr = torch.matmul(f, self.normalized_samples.T)
        crosscorr = crosscorr[:, :, None]
        xpl = self.coefficients * crosscorr
        indices = explanation_targets[:, None, None].expand(-1, self.samples.shape[0], 1)
        xpl = torch.gather(xpl, dim=-1, index=indices)
        return torch.squeeze(xpl)

    def save_coefs(self, dir: str):
        torch.save(self.coefficients, os.path.join(dir, f"{self.name}_coefs"))
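The core of explain() above is a feature cross-correlation scaled by per-class coefficients, with torch.gather selecting each input's explained class. A toy-shape walkthrough with dummy values:

import torch

batch, n_train, n_classes = 2, 5, 3
crosscorr = torch.randn(batch, n_train, 1)      # f @ normalized_samples.T, unsqueezed
coefficients = torch.randn(n_train, n_classes)  # one coefficient per training point x class
xpl = coefficients * crosscorr                  # broadcasts to (batch, n_train, n_classes)

targets = torch.tensor([0, 2])                  # explained class per test input
indices = targets[:, None, None].expand(-1, n_train, 1)
per_target = torch.gather(xpl, dim=-1, index=indices)
print(per_target.squeeze().shape)               # torch.Size([2, 5])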
New file: GradientProductExplainer.explain (under src/explainers/)
@@ -0,0 +1,28 @@
import time

import torch


# NOTE: committed as a module-level function, but written as a method: it takes
# self and uses self.dataset / self.device / self.loss / self.get_param_grad
# from the former GradientProductExplainer above.
def explain(self, x, preds=None, targets=None):
    assert not ((targets is None) and (self.loss is not None))
    xpl = torch.zeros((x.shape[0], len(self.dataset)), dtype=torch.float)
    xpl = xpl.to(self.device)
    t = time.time()
    for j in range(len(self.dataset)):
        tr_sample, y = self.dataset[j]
        train_grad = self.get_param_grad(tr_sample, y)
        train_grad = train_grad / torch.norm(train_grad)
        train_grad.to(self.device)
        for i in range(x.shape[0]):
            if self.loss is None:
                test_grad = self.get_param_grad(x[i], preds[i])
            else:
                test_grad = self.get_param_grad(x[i], targets[i])
            test_grad.to(self.device)
            xpl[i, j] = torch.matmul(train_grad, test_grad)
        if j % 1000 == 0:
            tdiff = time.time() - t
            mins = int(tdiff / 60)
            print(f"{int(j / 1000)}/{int(len(self.dataset) / 1000)}k- 1000 images done in {mins} minutes {tdiff - 60 * mins}")
            t = time.time()
    return xpl
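explain() above scores each (test, train) pair by the dot product of flattened parameter gradients, with the training gradient normalized. A self-contained sketch of that idea on a made-up model (the repo's own get_param_grad additionally supports a custom loss and sim_parameters()):

import torch

model = torch.nn.Linear(4, 3)
loss_fn = torch.nn.CrossEntropyLoss()


def param_grad(x: torch.Tensor, y: int) -> torch.Tensor:
    # gradient of the per-sample loss w.r.t. all parameters, flattened
    model.zero_grad()
    loss_fn(model(x[None]), torch.tensor([y])).backward()
    return torch.cat([p.grad.flatten() for p in model.parameters()])


train_grad = param_grad(torch.randn(4), 1)
train_grad = train_grad / torch.norm(train_grad)  # normalized, as in explain()
score = torch.dot(train_grad, param_grad(torch.randn(4), 2))
print(score.item())  # one entry of the xpl matrix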