Skip to content

Commit

Permalink
Merge branch 'main' into training_v2
Browse files Browse the repository at this point in the history
# Conflicts:
#	pytest.ini
#	src/metrics/randomization/model_randomization.py
  • Loading branch information
dilyabareeva committed Jun 17, 2024
2 parents fc6d75d + f06f535 commit 78f3e53
Show file tree
Hide file tree
Showing 12 changed files with 219 additions and 43 deletions.
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@ markers =
unnamed_metrics: unnamed_metrics
randomization_metrics: randomization_metrics
downstream_tasks: downstream_tasks
aggregators: aggregators
self_influence: self_influence
43 changes: 43 additions & 0 deletions src/explainers/aggregators/aggregators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from abc import ABC, abstractmethod

import torch


class ExplanationsAggregator(ABC):
def __init__(self, training_size: int, *args, **kwargs):
self.scores = torch.zeros(training_size)

@abstractmethod
def update(self, explanations: torch.Tensor):
raise NotImplementedError

def reset(self, *args, **kwargs):
"""
Used to reset the aggregator state.
"""
self.scores = torch.zeros_like(self.scores)

def load_state_dict(self, state_dict: dict, *args, **kwargs):
"""
Used to load the aggregator state.
"""
self.scores = state_dict["scores"]

def state_dict(self, *args, **kwargs):
"""
Used to return the metric state.
"""
return {"scores": self.scores}

def compute(self) -> torch.Tensor:
return self.scores.argsort()


class SumAggregator(ExplanationsAggregator):
def update(self, explanations: torch.Tensor) -> torch.Tensor:
self.scores += explanations.sum(dim=0)


class AbsSumAggregator(ExplanationsAggregator):
def update(self, explanations: torch.Tensor) -> torch.Tensor:
self.scores += explanations.abs().sum(dim=0)
30 changes: 30 additions & 0 deletions src/explainers/aggregators/self_influence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Optional

import torch

from utils.explain_wrapper import ExplainFunc


def get_self_influence_ranking(
model: torch.nn.Module,
model_id: str,
cache_dir: str,
training_data: torch.utils.data.Dataset,
explain_fn: ExplainFunc,
explain_fn_kwargs: Optional[dict] = None,
) -> torch.Tensor:
size = len(training_data)
self_inf = torch.zeros((size,))

for i, (x, y) in enumerate(training_data):
self_inf[i] = explain_fn(
model=model,
model_id=f"{model_id}_id_{i}",
cache_dir=cache_dir,
test_tensor=x[None],
test_label=y[None],
train_dataset=training_data,
train_ids=[i],
**explain_fn_kwargs,
)
return self_inf.argsort()
35 changes: 22 additions & 13 deletions src/metrics/randomization/model_randomization.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,16 @@ def __init__(
# do we want the exact same random model to be attributed (keeping seed in the __call__ call)
# or do we want genuinely random models for each call of the metric (keeping seed in the constructor)
self.generator = torch.Generator(device=device)
self.generator.manual_seed(self.seed)
self.rand_model = self._randomize_model(model)
self.explain_fn = make_func(
func=explain_fn,
func_kwargs=explain_fn_kwargs,
model=self.rand_model,
model_id=self.model_id,
cache_dir=self.cache_dir,
train_dataset=self.train_dataset,
)

self.results = {"rank_correlations": []}

if isinstance(correlation_fn, str) and correlation_fn in correlation_functions:
Expand All @@ -73,29 +77,34 @@ def update(
test_data: torch.Tensor,
explanations: torch.Tensor,
):
rand_explanations = self.explain_fn(
model=self.rand_model,
model_id=self.model_id,
cache_dir=self.cache_dir,
train_dataset=self.train_dataset,
test_tensor=test_data,
)
rand_explanations = self.explain_fn(model=self.rand_model, test_tensor=test_data)
corrs = self.correlation_measure(explanations, rand_explanations)
self.results["rank_correlations"].append(corrs)

def compute(
self,
):
def compute(self):
return torch.cat(self.results["rank_correlations"]).mean()

def reset(self):
self.results = {"rank_correlations": []}
self.generator.manual_seed(self.seed)
self.rand_model = self._randomize_model(self.model)

def state_dict(self):
return self.results
state_dict = {
"results_dict": self.results,
"random_model_state_dict": self.model.state_dict(),
"seed": self.seed,
"generator_state": self.generator.get_state(),
"explain_fn": self.explain_fn,
}
return state_dict

def load_state_dict(self, state_dict: dict):
self.results = state_dict
self.results = state_dict["results_dict"]
self.seed = state_dict["seed"]
self.explain_fn = state_dict["explain_fn"]
self.rand_model.load_state_dict(state_dict["random_model_state_dict"])
self.generator.set_state(state_dict["generator_state"])

def _randomize_model(self, model: torch.nn.Module):
rand_model = copy.deepcopy(model)
Expand Down
2 changes: 1 addition & 1 deletion src/metrics/unnamed/top_k_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def compute(self, *args, **kwargs):
return len(torch.unique(self.all_top_k_examples))

def reset(self, *args, **kwargs):
self.all_top_k_examples = []
self.all_top_k_examples = torch.empty(0, self.top_k)

def load_state_dict(self, state_dict: dict, *args, **kwargs):
self.all_top_k_examples = state_dict["all_top_k_examples"]
Expand Down
21 changes: 0 additions & 21 deletions src/usage.py

This file was deleted.

6 changes: 4 additions & 2 deletions src/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import Any, Callable, Mapping

import torch
import torch.utils
import torch.utils.data


def _get_module_from_name(model: torch.nn.Module, layer_name: str) -> Any:
Expand All @@ -19,6 +21,6 @@ def make_func(func: Callable, func_kwargs: Mapping[str, ...] | None, **kwargs) -
_func_kwargs = kwargs.copy()
_func_kwargs.update(func_kwargs)
else:
func_kwargs = kwargs
_func_kwargs = kwargs

return functools.partial(func, **func_kwargs)
return functools.partial(func, **_func_kwargs)
16 changes: 11 additions & 5 deletions src/utils/explain_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Optional, Protocol
from typing import List, Optional, Protocol, Union

import torch
from captum.influence import SimilarityInfluence

from src.utils.datasets.indexed_subset import IndexedSubset
from src.utils.functions.similarities import cosine_similarity


Expand All @@ -12,20 +13,23 @@ def __call__(
model: torch.nn.Module,
model_id: str,
cache_dir: Optional[str],
train_dataset: torch.utils.data.Dataset,
test_tensor: torch.Tensor,
method: str,
test_tensor: torch.Tensor,
train_dataset: torch.utils.data.Dataset,
train_ids: Optional[Union[List[int], torch.Tensor]] = None,
) -> torch.Tensor:
...
pass


def explain(
model: torch.nn.Module,
model_id: str,
cache_dir: str,
method: str,
train_dataset: torch.utils.data.Dataset,
test_tensor: torch.Tensor,
method: str,
test_target: Optional[torch.Tensor] = None,
train_ids: Optional[Union[List[int], torch.Tensor]] = None,
**kwargs,
) -> torch.Tensor:
"""
Expand All @@ -41,6 +45,8 @@ def explain(
:return:
"""
if method == "SimilarityInfluence":
if train_ids is not None:
train_dataset = IndexedSubset(dataset=train_dataset, indices=train_ids)
layer = kwargs.get("layer", "features")
sim_metric = kwargs.get("similarity_metric", cosine_similarity)
sim_direction = kwargs.get("similarity_direction", "max")
Expand Down
20 changes: 20 additions & 0 deletions src/utils/functions/similarities.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,23 @@ def cosine_similarity(test, train, replace_nan=0) -> Tensor:

similarity = torch.mm(test, train)
return similarity


def dot_product_similarity(test, train, replace_nan=0) -> Tensor:
"""
Compute cosine similarity between test and train activations.
:param test:
:param train:
:param replace_nan:
:return:
"""
# TODO: I don't know why Captum return test activations as a list
if isinstance(test, list):
test = torch.cat(test)
assert torch.all(test == train)
test = test.view(test.shape[0], -1)
train = train.view(train.shape[0], -1)

similarity = torch.mm(test, train.T)
return similarity
47 changes: 47 additions & 0 deletions tests/explainers/aggregators/test_aggregators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest
import torch

from src.explainers.aggregators.aggregators import (
AbsSumAggregator,
SumAggregator,
)


@pytest.mark.aggregators
@pytest.mark.parametrize(
"test_id, dataset, explanations",
[
(
"mnist",
"load_mnist_dataset",
"load_mnist_explanations_1",
),
],
)
def test_sum_aggregator(test_id, dataset, explanations, request):
dataset = request.getfixturevalue(dataset)
explanations = request.getfixturevalue(explanations)
aggregator = SumAggregator(training_size=len(dataset))
aggregator.update(explanations)
global_rank = aggregator.compute()
assert torch.allclose(global_rank, explanations.sum(dim=0).argsort())


@pytest.mark.aggregators
@pytest.mark.parametrize(
"test_id, dataset, explanations",
[
(
"mnist",
"load_mnist_dataset",
"load_mnist_explanations_1",
),
],
)
def test_abs_aggregator(test_id, dataset, explanations, request):
dataset = request.getfixturevalue(dataset)
explanations = request.getfixturevalue(explanations)
aggregator = AbsSumAggregator(training_size=len(dataset))
aggregator.update(explanations)
global_rank = aggregator.compute()
assert torch.allclose(global_rank, explanations.abs().mean(dim=0).argsort())
38 changes: 38 additions & 0 deletions tests/explainers/aggregators/test_self_influence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from collections import OrderedDict

import pytest
import torch
from torch.utils.data import TensorDataset

from src.explainers.aggregators.self_influence import (
get_self_influence_ranking,
)
from src.utils.explain_wrapper import explain
from src.utils.functions.similarities import dot_product_similarity


@pytest.mark.self_influence
@pytest.mark.parametrize(
"test_id, explain_kwargs",
[
(
"random_data",
{"method": "SimilarityInfluence", "layer": "identity", "similarity_metric": dot_product_similarity},
),
],
)
def test_self_influence_ranking(test_id, explain_kwargs, request):
model = torch.nn.Sequential(OrderedDict([("identity", torch.nn.Identity())]))
X = torch.randn(100, 200)
rand_dataset = TensorDataset(X, torch.randint(0, 10, (100,)))

self_influence_rank = get_self_influence_ranking(
model=model,
model_id="0",
cache_dir="temp_captum",
training_data=rand_dataset,
explain_fn=explain,
explain_fn_kwargs=explain_kwargs,
)

assert torch.allclose(self_influence_rank, torch.linalg.norm(X, dim=-1).argsort())
2 changes: 1 addition & 1 deletion tests/utils/test_explain_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def test_explain(
model,
test_id,
os.path.join("./cache", "test_id"),
method,
dataset,
test_tensor,
method,
**method_kwargs,
)
assert torch.allclose(explanations, explanations_exp), "Training data attributions are not as expected"

0 comments on commit 78f3e53

Please sign in to comment.