WiP: Globalization from local explainer #24

Merged: 19 commits, Jun 14, 2024
2 changes: 2 additions & 0 deletions pytest.ini
@@ -5,3 +5,5 @@ markers =
     localization_metrics: localization_metrics
     unnamed_metrics: unnamed_metrics
     randomization_metrics: randomization_metrics
+    aggregators: aggregators
+    self_influence: self_influence
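
Registering the new marks here keeps pytest from emitting unknown-marker warnings and lets the new suites run in isolation, e.g. pytest -m aggregators or pytest -m self_influence.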
43 changes: 43 additions & 0 deletions src/explainers/aggregators/aggregators.py
@@ -0,0 +1,43 @@
from abc import ABC, abstractmethod

import torch


class ExplanationsAggregator(ABC):
    def __init__(self, training_size: int, *args, **kwargs):
        self.scores = torch.zeros(training_size)

    @abstractmethod
    def update(self, explanations: torch.Tensor):
        raise NotImplementedError

    def reset(self, *args, **kwargs):
        """
        Used to reset the aggregator state.
        """
        self.scores = torch.zeros_like(self.scores)

    def load_state_dict(self, state_dict: dict, *args, **kwargs):
        """
        Used to load the aggregator state.
        """
        self.scores = state_dict["scores"]

    def state_dict(self, *args, **kwargs):
        """
        Used to return the aggregator state.
        """
        return {"scores": self.scores}

    def compute(self) -> torch.Tensor:
        return self.scores.argsort()


class SumAggregator(ExplanationsAggregator):
    def update(self, explanations: torch.Tensor) -> None:
        self.scores += explanations.sum(dim=0)


class AbsSumAggregator(ExplanationsAggregator):
    def update(self, explanations: torch.Tensor) -> None:
        self.scores += explanations.abs().sum(dim=0)
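
For orientation, a minimal usage sketch (names and shapes here are illustrative, not part of the PR): local explanations arrive batch-wise as (n_test, training_size) tensors, and compute() returns the training indices sorted by aggregated score, ascending.

    import torch

    from src.explainers.aggregators.aggregators import SumAggregator

    aggregator = SumAggregator(training_size=1000)
    for batch in explanation_batches:  # each batch: (n_test, 1000) local attributions
        aggregator.update(batch)
    global_ranking = aggregator.compute()  # indices from lowest to highest total score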
30 changes: 30 additions & 0 deletions src/explainers/aggregators/self_influence.py
@@ -0,0 +1,30 @@
from typing import Optional

import torch

from utils.explain_wrapper import ExplainFunc


def get_self_influence_ranking(
    model: torch.nn.Module,
    model_id: str,
    cache_dir: str,
    training_data: torch.utils.data.Dataset,
    explain_fn: ExplainFunc,
    explain_fn_kwargs: Optional[dict] = None,
) -> torch.Tensor:
    # Guard against the default None before unpacking as keyword arguments.
    if explain_fn_kwargs is None:
        explain_fn_kwargs = {}
    size = len(training_data)
    self_inf = torch.zeros((size,))

    for i, (x, y) in enumerate(training_data):
        self_inf[i] = explain_fn(
            model=model,
            model_id=f"{model_id}_id_{i}",
            cache_dir=cache_dir,
            test_tensor=x[None],
            test_label=y[None],
            train_dataset=training_data,
            train_ids=[i],
            **explain_fn_kwargs,
        )
    return self_inf.argsort()
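
The globalization trick: each training point is explained with train_ids=[i], i.e. against itself as the only candidate influencer, so the returned attribution is exactly that point's self-influence; the final argsort then ranks the whole training set by that score, ascending.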
6 changes: 1 addition & 5 deletions src/metrics/localization/identical_class.py
@@ -15,11 +15,7 @@ def __init__(
         super().__init__(model, train_dataset, device, *args, **kwargs)
         self.scores = []

-    def update(
-        self,
-        test_labels: torch.Tensor,
-        explanations: torch.Tensor
-    ):
+    def update(self, test_labels: torch.Tensor, explanations: torch.Tensor):
         """
         Used to implement metric-specific logic.
         """
4 changes: 1 addition & 3 deletions src/metrics/randomization/model_randomization.py
@@ -79,9 +79,7 @@ def update(
         corrs = self.correlation_measure(explanations, rand_explanations)
         self.results["rank_correlations"].append(corrs)

-    def compute(
-        self,
-    ):
+    def compute(self):
         return torch.cat(self.results["rank_correlations"]).mean()

     def reset(self):
2 changes: 1 addition & 1 deletion src/metrics/unnamed/top_k_overlap.py
@@ -29,7 +29,7 @@ def compute(self, *args, **kwargs):
         return len(torch.unique(self.all_top_k_examples))

     def reset(self, *args, **kwargs):
-        self.all_top_k_examples = []
+        self.all_top_k_examples = torch.empty(0, self.top_k)

     def load_state_dict(self, state_dict: dict, *args, **kwargs):
         self.all_top_k_examples = state_dict["all_top_k_examples"]
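
The reset state presumably becomes an empty (0, top_k) tensor rather than a list so that update can torch.cat incoming (batch_size, top_k) blocks of retrieved indices and compute can call torch.unique on the accumulated tensor directly.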
2 changes: 2 additions & 0 deletions src/utils/common.py
@@ -3,6 +3,8 @@
 from typing import Any, Callable, Mapping

 import torch
+import torch.utils
+import torch.utils.data


 def _get_module_from_name(model: torch.nn.Module, layer_name: str) -> Any:
16 changes: 11 additions & 5 deletions src/utils/explain_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Optional, Protocol
from typing import List, Optional, Protocol, Union

import torch
from captum.influence import SimilarityInfluence

from src.utils.datasets.indexed_subset import IndexedSubset
from src.utils.functions.similarities import cosine_similarity


Expand All @@ -12,20 +13,23 @@ def __call__(
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
train_dataset: torch.utils.data.Dataset,
test_tensor: torch.Tensor,
method: str,
test_tensor: torch.Tensor,
train_dataset: torch.utils.data.Dataset,
train_ids: Optional[Union[List[int], torch.Tensor]] = None,
) -> torch.Tensor:
...
pass


 def explain(
     model: torch.nn.Module,
     model_id: str,
     cache_dir: str,
+    method: str,
     train_dataset: torch.utils.data.Dataset,
     test_tensor: torch.Tensor,
-    method: str,
+    test_target: Optional[torch.Tensor] = None,
+    train_ids: Optional[Union[List[int], torch.Tensor]] = None,
Review thread (Owner): what is the difference between test_target and test_tensor?
Reply (Collaborator, Author): test_target is which output we are explaining. There is no such notion for similarity of intermediate representations, but in general explainers can be used to explain different outputs. (An illustrative call is sketched after this diff.)
     **kwargs,
 ) -> torch.Tensor:
     """
@@ -41,6 +45,8 @@ def explain(
     :return:
     """
     if method == "SimilarityInfluence":
+        if train_ids is not None:
+            train_dataset = IndexedSubset(dataset=train_dataset, indices=train_ids)
         layer = kwargs.get("layer", "features")
         sim_metric = kwargs.get("similarity_metric", cosine_similarity)
         sim_direction = kwargs.get("similarity_direction", "max")
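
To make the test_target discussion concrete, here is a hypothetical call under the new signature (variable names are made up; for "SimilarityInfluence", test_target is unused, since representation-similarity attributions do not depend on a chosen model output):

    attributions = explain(
        model=model,
        model_id="lenet_mnist",
        cache_dir="./cache",
        method="SimilarityInfluence",
        train_dataset=train_dataset,
        test_tensor=test_batch,  # inputs whose attributions we want
        test_target=None,        # which model output to explain; ignored by similarity methods
        train_ids=None,          # or a list of training indices to restrict attribution to
        layer="features",        # forwarded via **kwargs to the SimilarityInfluence branch
    )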
20 changes: 20 additions & 0 deletions src/utils/functions/similarities.py
@@ -25,3 +25,23 @@ def cosine_similarity(test, train, replace_nan=0) -> Tensor:

     similarity = torch.mm(test, train)
     return similarity
+
+
+def dot_product_similarity(test, train, replace_nan=0) -> Tensor:
+    """
+    Compute dot product similarity between test and train activations.
+
+    :param test:
+    :param train:
+    :param replace_nan:
+    :return:
+    """
+    # TODO: I don't know why Captum returns test activations as a list
+    if isinstance(test, list):
+        test = torch.cat(test)
+    assert torch.all(test == train)
Review comment (Owner): Hi @gumityolcu, please let's be more careful 😄
+    test = test.view(test.shape[0], -1)
+    train = train.view(train.shape[0], -1)
+
+    similarity = torch.mm(test, train.T)
+    return similarity
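
Note that assert torch.all(test == train) can only hold when a batch is compared against itself, which is exactly the self-influence setting exercised by the tests below; as a general similarity function this would need to be relaxed, which is presumably what the reviewer is pointing at above.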
47 changes: 47 additions & 0 deletions tests/explainers/aggregators/test_aggregators.py
@@ -0,0 +1,47 @@
import pytest
import torch

from src.explainers.aggregators.aggregators import (
    AbsSumAggregator,
    SumAggregator,
)


@pytest.mark.aggregators
@pytest.mark.parametrize(
    "test_id, dataset, explanations",
    [
        (
            "mnist",
            "load_mnist_dataset",
            "load_mnist_explanations_1",
        ),
    ],
)
def test_sum_aggregator(test_id, dataset, explanations, request):
    dataset = request.getfixturevalue(dataset)
    explanations = request.getfixturevalue(explanations)
    aggregator = SumAggregator(training_size=len(dataset))
    aggregator.update(explanations)
    global_rank = aggregator.compute()
    assert torch.allclose(global_rank, explanations.sum(dim=0).argsort())


@pytest.mark.aggregators
@pytest.mark.parametrize(
    "test_id, dataset, explanations",
    [
        (
            "mnist",
            "load_mnist_dataset",
            "load_mnist_explanations_1",
        ),
    ],
)
def test_abs_aggregator(test_id, dataset, explanations, request):
    dataset = request.getfixturevalue(dataset)
    explanations = request.getfixturevalue(explanations)
    aggregator = AbsSumAggregator(training_size=len(dataset))
    aggregator.update(explanations)
    global_rank = aggregator.compute()
    # mean and sum differ only by a constant factor, so the argsort order is identical
    assert torch.allclose(global_rank, explanations.abs().mean(dim=0).argsort())
38 changes: 38 additions & 0 deletions tests/explainers/aggregators/test_self_influence.py
@@ -0,0 +1,38 @@
from collections import OrderedDict

import pytest
import torch
from torch.utils.data import TensorDataset

from src.explainers.aggregators.self_influence import (
    get_self_influence_ranking,
)
from src.utils.explain_wrapper import explain
from src.utils.functions.similarities import dot_product_similarity


@pytest.mark.self_influence
@pytest.mark.parametrize(
    "test_id, explain_kwargs",
    [
        (
            "random_data",
            {"method": "SimilarityInfluence", "layer": "identity", "similarity_metric": dot_product_similarity},
        ),
    ],
)
def test_self_influence_ranking(test_id, explain_kwargs, request):
    model = torch.nn.Sequential(OrderedDict([("identity", torch.nn.Identity())]))
    X = torch.randn(100, 200)
    rand_dataset = TensorDataset(X, torch.randint(0, 10, (100,)))

    self_influence_rank = get_self_influence_ranking(
        model=model,
        model_id="0",
        cache_dir="temp_captum",
        training_data=rand_dataset,
        explain_fn=explain,
        explain_fn_kwargs=explain_kwargs,
    )

    # With an identity layer, the dot-product self-influence of x is <x, x> = ||x||^2,
    # so the ranking must match the ordering of the row norms.
    assert torch.allclose(self_influence_rank, torch.linalg.norm(X, dim=-1).argsort())
2 changes: 1 addition & 1 deletion tests/utils/test_explain_wrapper.py
@@ -30,9 +30,9 @@ def test_explain(test_id, model, dataset, explanations, test_tensor, method, met
         model,
         test_id,
         os.path.join("./cache", "test_id"),
+        method,
         dataset,
         test_tensor,
-        method,
         **method_kwargs,
     )
     assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected"