Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explainer base #38

Merged
merged 33 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
e9b0554
base class
gumityolcu Jun 14, 2024
0ee06fe
small changes
gumityolcu Jun 14, 2024
9767b30
Captum interface and SimilarityInfluence
gumityolcu Jun 14, 2024
3e98e9f
Last commit actually here
gumityolcu Jun 14, 2024
367057a
delete leftover assertion
gumityolcu Jun 17, 2024
e9e4afe
move self_influence files
gumityolcu Jun 17, 2024
5d53cd5
randomization test works
gumityolcu Jun 17, 2024
b63ffdd
fix similarity explainer output construciton logic
gumityolcu Jun 18, 2024
09c73e7
update with correct cos distance explanations
gumityolcu Jun 18, 2024
c5ddc04
fix and add tests
gumityolcu Jun 18, 2024
779c0ff
Add missing files
gumityolcu Jun 18, 2024
299b4d3
fix test script
gumityolcu Jun 18, 2024
4343282
fix test headers
gumityolcu Jun 18, 2024
0a09a1a
delete unneeded file
gumityolcu Jun 18, 2024
19a7b3f
fix test
gumityolcu Jun 18, 2024
ee8d8eb
fix test
gumityolcu Jun 18, 2024
888aefd
fixing explainer tests
dilyabareeva Jun 18, 2024
bfe7135
fixing failing tests
dilyabareeva Jun 20, 2024
8c83a1f
Revert "fixing explainer tests"
dilyabareeva Jun 20, 2024
0a477f5
refactoring captum wrappers
dilyabareeva Jun 20, 2024
950b8b8
make kwargs **kwargs again
dilyabareeva Jun 20, 2024
5cded75
sort explain fns
dilyabareeva Jun 20, 2024
ad2834c
fixing some tests
dilyabareeva Jun 20, 2024
8725b9e
Merge branch 'explainer_base' into explainer_base_dilya
dilyabareeva Jun 20, 2024
ce8b538
resolve merge commit bugs
dilyabareeva Jun 21, 2024
c526bf0
self-influence refactoring
dilyabareeva Jun 21, 2024
802a5c2
using tmp_path fixture for tests
dilyabareeva Jun 21, 2024
c1b4e49
using tmp_path fixture for tests cont.
dilyabareeva Jun 21, 2024
7d61b79
fix a tox issue
dilyabareeva Jun 21, 2024
d53320e
make self-influence return attrs + small fixes
dilyabareeva Jun 21, 2024
76658c8
Merge pull request #41 from dilyabareeva/explainer_base_dilya
dilyabareeva Jun 21, 2024
8141689
merge main
dilyabareeva Jun 21, 2024
c23bb25
after-merge clean-up
dilyabareeva Jun 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ SHELL = /bin/bash
.PHONY: style
style:
black .
flake8 .
flake8 . --pytest-parametrize-names-type=csv
python -m isort .
rm -f .coverage
rm -f .coverage.*
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ dev = [ # Install wtih pip install .[dev] or pip install -e '.[dev]' in zsh
"coverage>=7.2.3",
"flake8>=6.0.0",
"pytest<=7.4.4",
"flake8-pytest-style>=1.3.2",
"pytest-cov>=4.0.0",
"pytest-mock==3.10.0",
"pre-commit>=3.2.0",
Expand Down
8 changes: 5 additions & 3 deletions src/downstream_tasks/subclass_identification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@

import torch

from explainers.functional import ExplainFunc
from explainers.wrappers.captum_influence import captum_similarity_explain
from metrics.localization.identical_class import IdenticalClass
from utils.datasets.group_label_dataset import (
ClassToGroupLiterals,
GroupLabelDataset,
)
from utils.explain_wrapper import explain
from utils.training.trainer import BaseTrainer, Trainer


Expand Down Expand Up @@ -41,7 +42,7 @@ def evaluate(
n_classes: int = 10,
n_groups: int = 2,
class_to_group: Union[ClassToGroupLiterals, Dict[int, int]] = "random",
explain_fn: Callable = explain,
explain_fn: ExplainFunc = captum_similarity_explain,
explain_kwargs: Optional[dict] = None,
trainer_kwargs: Optional[dict] = None,
cache_dir: str = "./cache",
Expand Down Expand Up @@ -99,7 +100,8 @@ def evaluate(
cache_dir=os.path.join(cache_dir, run_id),
train_dataset=train_dataset,
test_tensor=input,
**explain_kwargs,
init_kwargs=explain_kwargs,
device=device,
)
metric.update(labels, explanations)

Expand Down
53 changes: 53 additions & 0 deletions src/explainers/aggregators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from abc import ABC, abstractmethod
from typing import Optional

import torch


class BaseAggregator(ABC):
def __init__(self):
self.scores: Optional[torch.Tensor] = None

@abstractmethod
def update(self, explanations: torch.Tensor):
raise NotImplementedError

def _validate_explanations(self, explanations: torch.Tensor):
if self.scores is None:
self.scores = torch.zeros(explanations.shape[1])

if explanations.shape[1] != self.scores.shape[0]:
raise ValueError(f"Explanations shape {explanations.shape} does not match the expected shape {self.scores.shape}")

def reset(self, *args, **kwargs):
"""
Used to reset the aggregator state.
"""
self.scores: torch.Tensor = None

def load_state_dict(self, state_dict: dict, *args, **kwargs):
"""
Used to load the aggregator state.
"""
self.scores = state_dict["scores"]

def state_dict(self, *args, **kwargs):
"""
Used to return the metric state.
"""
return {"scores": self.scores}

def compute(self) -> torch.Tensor:
return self.scores.argsort()


class SumAggregator(BaseAggregator):
def update(self, explanations: torch.Tensor):
self._validate_explanations(explanations)
self.scores += explanations.sum(dim=0)


class AbsSumAggregator(BaseAggregator):
def update(self, explanations: torch.Tensor):
self._validate_explanations(explanations)
self.scores += explanations.abs().sum(dim=0)
43 changes: 0 additions & 43 deletions src/explainers/aggregators/aggregators.py

This file was deleted.

30 changes: 0 additions & 30 deletions src/explainers/aggregators/self_influence.py

This file was deleted.

47 changes: 33 additions & 14 deletions src/explainers/base.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,49 @@
from abc import ABC, abstractmethod
from typing import Union
from typing import Any, List, Optional, Union

import torch

from src.utils.common import cache_result

class Explainer(ABC):

class BaseExplainer(ABC):
def __init__(
self,
model: torch.nn.Module,
dataset: torch.data.utils.Dataset,
model_id: str,
cache_dir: Optional[str],
train_dataset: torch.utils.data.Dataset,
device: Union[str, torch.device],
**kwargs,
):
self.model = model
self.model.to(device)

self.model_id = model_id
self.cache_dir = cache_dir
self.train_dataset = train_dataset
self.device = torch.device(device) if isinstance(device, str) else device
self.images = dataset
self.samples = []
self.labels = []
dev = torch.device(device)
self.model.to(dev)

@abstractmethod
def explain(self, x: torch.Tensor, explanation_targets: torch.Tensor) -> torch.Tensor:
pass
def explain(self, test: torch.Tensor, targets: Optional[Union[List[int], torch.Tensor]] = None, **kwargs: Any):
raise NotImplementedError

@cache_result
def self_influence(self, batch_size: Optional[int] = 32, **kwargs: Any) -> torch.Tensor:
"""
Base class implements computing self influences by explaining the train dataset one by one

:param batch_size:
:param kwargs:
:return:
"""

# Pre-allcate memory for influences, because torch.cat is slow
influences = torch.empty((len(self.train_dataset),), device=self.device)
ldr = torch.utils.data.DataLoader(self.train_dataset, shuffle=False, batch_size=batch_size)

def train(self) -> None:
pass
for i, (x, y) in zip(range(0, len(self.train_dataset), batch_size), ldr):
explanations = self.explain(test=x.to(self.device), **kwargs)
influences[i : i + batch_size] = explanations.diag(diagonal=i)

def save_coefs(self, dir_path: str) -> None:
pass
return influences
70 changes: 0 additions & 70 deletions src/explainers/feature_kernel_explainer.py

This file was deleted.

19 changes: 19 additions & 0 deletions src/explainers/functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Dict, List, Optional, Protocol, Union

import torch


class ExplainFunc(Protocol):
def __call__(
self,
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
test_tensor: torch.Tensor,
explanation_targets: Optional[Union[List[int], torch.Tensor]],
train_dataset: torch.utils.data.Dataset,
explain_kwargs: Dict,
init_kwargs: Dict,
device: Union[str, torch.device],
) -> torch.Tensor:
pass
Loading
Loading