Merge branch 'main' into randomization_metric

gumityolcu authored May 18, 2024
2 parents 0cf2dcf + ab500bc commit ba0a618
Showing 34 changed files with 885 additions and 257 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -2,7 +2,7 @@
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
rev: v4.6.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
1 change: 1 addition & 0 deletions Makefile
@@ -5,6 +5,7 @@ SHELL = /bin/bash
.PHONY: style
style:
black .
flake8 .
python -m isort .
rm -f .coverage
rm -f .coverage.*
1 change: 1 addition & 0 deletions pyproject.toml
@@ -9,6 +9,7 @@ keywords = ["explainable ai", "xai", "machine learning", "deep learning"]
dependencies = [
"numpy>=1.19.5",
"torch>=1.13.1",
"captum>=0.6.0",
]
dynamic = ["version"]

3 changes: 3 additions & 0 deletions pytest.ini
@@ -1,3 +1,6 @@
[pytest]
markers =
utils: utils files
explainers: explainers
localization_metrics: localization_metrics
unnamed_metrics: unnamed_metrics
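
For context: these registered markers let subsets of the suite be selected with pytest's -m flag (e.g. pytest -m explainers). A minimal sketch of how a test opts into one of the markers; the test name and body are hypothetical, not from this commit:

import pytest


@pytest.mark.explainers
def test_explainer_returns_one_row_per_test_sample():
    # Hypothetical check: an explainer should yield one attribution
    # row per test input.
    ...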
10 changes: 6 additions & 4 deletions src/explainers/explain_wrapper.py
@@ -1,6 +1,7 @@
import torch
from captum.influence import SimilarityInfluence
from captum.influence._core.similarity_influence import cosine_similarity

from utils.functions.similarities import cosine_similarity


def explain(
@@ -29,17 +30,18 @@ def explain(
sim_metric = kwargs.get("similarity_metric", cosine_similarity)
sim_direction = kwargs.get("similarity_direction", "max")
batch_size = kwargs.get("batch_size", 1)
top_k = kwargs.get("top_k", test_tensor.shape[0])

sim_influence = SimilarityInfluence(
module=model,
layers=[layer],
layers=layer,
influence_src_dataset=train_dataset,
activation_dir=cache_dir,
model_id=model_id,
similarity_metric=sim_metric,
similarity_direction=sim_direction,
batch_size=batch_size,
)
topk_idx, topk_val = sim_influence.influence(test_tensor, len(train_dataset))[layer]
tda = torch.zeros_like(topk_val)
tda.scatter_(1, topk_idx, topk_val)  # undo the top-k sort so columns follow training-dataset order

return sim_influence.influence(test_tensor, top_k)[layer]
return tda
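
For reference, a minimal call sketch for the updated wrapper. The full signature of explain is not shown in this hunk, so the argument names and the import path below are assumptions inferred from how they are used:

import torch
from torch.utils.data import TensorDataset

from explainers.explain_wrapper import explain  # import path assumed from the file layout

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(16, 3))
train_dataset = TensorDataset(torch.randn(100, 16), torch.randint(0, 3, (100,)))
test_tensor = torch.randn(4, 16)

tda = explain(
    model=model,
    model_id="toy-model",
    cache_dir="./cache",
    train_dataset=train_dataset,
    test_tensor=test_tensor,
    layer="1",  # name of the Linear layer inside the Sequential
    similarity_direction="max",
    batch_size=8,
)
# Expected: one row per test sample, one column per training sample.
print(tda.shape)  # torch.Size([4, 100])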
30 changes: 0 additions & 30 deletions src/metrics/__init__.py
@@ -1,30 +0,0 @@
from abc import ABC, abstractmethod
from typing import Union

import numpy as np
import torch


class Metric(ABC):
name = "BaseMetricClass"

@abstractmethod
def __init__(self, train: torch.utils.data.Dataset, test: torch.utils.data.Dataset):
pass

@abstractmethod
def __call__(self, *args, **kwargs):
pass

@abstractmethod
def get_result(self, dir: str):
pass

@staticmethod
def to_float(results: Union[dict, str, torch.Tensor]) -> Union[dict, str, torch.Tensor]:
if isinstance(results, dict):
return {key: Metric.to_float(r) for key, r in results.items()}
elif isinstance(results, str):
return results
else:
return np.array(results).astype(float).tolist()
45 changes: 8 additions & 37 deletions src/metrics/base.py
@@ -1,67 +1,38 @@
from abc import ABC, abstractmethod

import torch


class Metric(ABC):
def __init__(self, *args, **kwargs):
pass
def __init__(self, device, *args, **kwargs):
self.device = device

@abstractmethod
def __call__(
self,
model: torch.nn.Module,
model_id: str,
cache_dir: str, # TODO: maybe cache is not the best notation?
train_dataset: torch.utils.data.Dataset,
test_dataset: torch.utils.data.Dataset,
explanations: torch.utils.data.Dataset,
# TODO: should it be a tensor or dataset? For large datasets, storing the whole thing in RAM might be difficult.
*args,
**kwargs,
):
"""
General steps to perform here include:
1) Universal assertions about the passed arguments, e.g. checking that the lengths of the train/test datasets and
explanations match.
2) Call the _evaluate method.
3) Format the output into a unified format for all metrics, possibly using some arguments passed in kwargs.
:param model:
:param model_id:
:param cache_dir:
:param train_dataset:
:param test_dataset:
:param explanations:
:param kwargs:
:return:
"""
raise NotImplementedError

@abstractmethod
def _evaluate(
def _evaluate_instance(
self,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
test_dataset: torch.utils.data.Dataset,
explanations: torch.utils.data.Dataset,
*args,
**kwargs,
):
"""
Used to implement metric-specific logic.
"""

raise NotImplementedError

@staticmethod
@abstractmethod
def _format(
self,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
test_dataset: torch.utils.data.Dataset,
explanations: torch.utils.data.Dataset,
):
"""
Format the output of the metric to a predefined format, maybe string?
"""

raise NotImplementedError
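
To illustrate the contract, a minimal concrete metric under the new base class; the class name and scoring logic are hypothetical, not from this commit:

import torch

from metrics.base import Metric


class MeanAttribution(Metric):
    # Hypothetical metric: mean attribution mass per test sample.

    def __init__(self, device, *args, **kwargs):
        super().__init__(device, *args, **kwargs)

    def __call__(self, model, train_dataset, explanations, *args, **kwargs):
        # Score each batch of explanations, then aggregate.
        scores = [self._evaluate_instance(xpl=xpl) for xpl in explanations]
        return {"score": torch.cat(scores).mean()}

    def _evaluate_instance(self, xpl: torch.Tensor):
        # One score per test sample in the batch.
        return xpl.mean(dim=1)

    @staticmethod
    def _format(results):
        # Trivial formatter for the base class's remaining abstract hook.
        return results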
File renamed without changes.
75 changes: 75 additions & 0 deletions src/metrics/localization/identical_class.py
@@ -0,0 +1,75 @@
from typing import Optional, Union

import torch

from metrics.base import Metric
from src.utils.explanations import (
BatchedCachedExplanations,
TensorExplanations,
)
from utils.cache import ExplanationsCache as EC


class IdenticalClass(Metric):
def __init__(self, device, *args, **kwargs):
super().__init__(device, *args, **kwargs)

def __call__(
self,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
test_labels: torch.Tensor,
explanations: Union[str, torch.Tensor, TensorExplanations, BatchedCachedExplanations] = "./",
batch_size: Optional[int] = 8,
**kwargs,
):
"""
:param model:
:param train_dataset:
:param test_labels:
:param explanations:
:param batch_size:
:param kwargs:
:return:
"""

if isinstance(explanations, str):
explanations = EC.load(path=explanations, device=self.device)
elif isinstance(explanations, torch.Tensor):
explanations = TensorExplanations(explanations, batch_size=batch_size, device=self.device)

scores = []
n_processed = 0
for i in range(len(explanations)):
assert n_processed + explanations[i].shape[0] <= len(
test_labels
), f"Number of explanations ({n_processed + explanations[i].shape[0]}) exceeds the number of test labels."

score = self._evaluate_instance(
model=model,
train_dataset=train_dataset,
test_labels=test_labels[n_processed : n_processed + explanations[i].shape[0]],
xpl=explanations[i],
)
scores.append(score)
n_processed += explanations[i].shape[0]

return {"score": torch.cat(scores).mean()}

def _evaluate_instance(
self,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
test_labels: torch.Tensor,
xpl: torch.Tensor,
):
"""
Used to implement metric-specific logic.
"""

top_one_xpl_indices = xpl.argmax(dim=1)
top_one_xpl_samples = torch.stack([train_dataset[i][0] for i in top_one_xpl_indices])

top_one_xpl_output = model(top_one_xpl_samples.to(self.device))
top_one_xpl_pred = top_one_xpl_output.argmax(dim=1)

return (test_labels == top_one_xpl_pred) * 1.0
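
A usage sketch for the new metric, assuming it can be instantiated as shown; shapes and names here are illustrative only:

import torch
from torch.utils.data import TensorDataset

from metrics.localization.identical_class import IdenticalClass

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(16, 3))
train_dataset = TensorDataset(torch.randn(100, 16), torch.randint(0, 3, (100,)))
test_labels = torch.randint(0, 3, (8,))
explanations = torch.randn(8, 100)  # one row per test sample, one column per training sample

metric = IdenticalClass(device="cpu")
result = metric(
    model=model,
    train_dataset=train_dataset,
    test_labels=test_labels,
    explanations=explanations,
    batch_size=4,
)
# Fraction of test samples whose top-1 attributed training example
# is predicted to belong to the same class as the test label.
print(result["score"])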
Empty file added src/metrics/unnamed/__init__.py
72 changes: 72 additions & 0 deletions src/metrics/unnamed/top_k_overlap.py
@@ -0,0 +1,72 @@
import warnings
from typing import Optional, Union

import torch

from metrics.base import Metric
from src.utils.explanations import (
BatchedCachedExplanations,
TensorExplanations,
)
from utils.cache import ExplanationsCache as EC


class TopKOverlap(Metric):
def __init__(self, device, *args, **kwargs):
super().__init__(device, *args, **kwargs)

def __call__(
self,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
top_k: int = 1,
explanations: Union[str, torch.Tensor, TensorExplanations, BatchedCachedExplanations] = "./",
batch_size: Optional[int] = 8,
**kwargs,
):
"""
:param model:
:param train_dataset:
:param top_k:
:param explanations:
:param batch_size:
:param kwargs:
:return:
"""

if isinstance(explanations, str):
explanations = EC.load(path=explanations, device=self.device)
if explanations.batch_size != batch_size:
warnings.warn(
"Batch size mismatch between loaded explanations and passed batch size. The inferred batch "
"size will be used instead."
)
batch_size = explanations.batch_size
elif isinstance(explanations, torch.Tensor):
explanations = TensorExplanations(explanations, batch_size=batch_size, device=self.device)

all_top_k_examples = []

for i in range(len(explanations)):
top_k_examples = self._evaluate_instance(
xpl=explanations[i],
top_k=top_k,
)
all_top_k_examples += top_k_examples.flatten().tolist()

# calculate the cardinality of the set of top-k examples
cardinality = len(set(all_top_k_examples))

# TODO: calculate the probability of the set of top-k examples
return {"score": cardinality}

def _evaluate_instance(
self,
xpl: torch.Tensor,
top_k: int = 1,
):
"""
Used to implement metric-specific logic.
"""

top_k_indices = torch.topk(xpl, top_k).indices
return top_k_indices
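
A usage sketch, assuming the class can be instantiated as shown; the score is the size of the union of all test samples' top-k index sets:

import torch

from metrics.unnamed.top_k_overlap import TopKOverlap

explanations = torch.randn(8, 100)  # 8 test samples, 100 training samples

metric = TopKOverlap(device="cpu")
result = metric(
    model=None,  # accepted but unused by this metric's logic
    train_dataset=None,
    top_k=3,
    explanations=explanations,
    batch_size=4,
)
# Between 3 (all test samples share one top-3 set) and 24 (no overlap at all).
print(result["score"])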
44 changes: 39 additions & 5 deletions src/utils/cache.py
@@ -10,6 +10,7 @@

from utils.common import _get_module_from_name
from utils.datasets.activation_dataset import ActivationDataset
from utils.explanations import BatchedCachedExplanations


class Cache:
@@ -33,26 +34,59 @@ def exists(**kwargs) -> bool:
raise NotImplementedError


class IndicesCache(Cache):
class TensorCache(Cache):
def __init__(self):
super().__init__()

@staticmethod
def save(path, file_id, indices) -> None:
def save(path: str, file_id: str, indices: Tensor) -> None:
file_path = os.path.join(path, file_id)
return torch.save(indices, file_path)

@staticmethod
def load(path, file_id, device="cpu") -> Tensor:
def load(path: str, file_id: str, device: str = "cpu") -> Tensor:
file_path = os.path.join(path, file_id)
return torch.load(file_path, map_location=device)

@staticmethod
def exists(path, file_id) -> bool:
def exists(path: str, file_id: str, num_id: int) -> bool:
file_path = os.path.join(path, file_id)
return os.path.isfile(file_path)



class ExplanationsCache(Cache):
def __init__(self):
super().__init__()

@staticmethod
def exists(
path: str,
num_id: Optional[Union[str, int]] = None,
) -> bool:
av_filesearch = os.path.join(path, "*.pt" if num_id is None else f"{num_id}.pt")
return os.path.exists(path) and len(glob.glob(av_filesearch)) > 0

@staticmethod
def save(
path: str,
exp_tensors: List[Tensor],
num_id: Union[str, int],
) -> None:
av_save_fl_path = os.path.join(path, f"{num_id}.pt")
torch.save(exp_tensors, av_save_fl_path)

@staticmethod
def load(
path: str,
device: str = "cpu",
) -> BatchedCachedExplanations:
if os.path.exists(path):
xpl_dataset = BatchedCachedExplanations(cache_dir=path, device=device)
return xpl_dataset
else:
raise RuntimeError(f"Activation vectors were not found at path {path}")

class ActivationsCache(Cache):
"""
Inspired by https://github.com/pytorch/captum/blob/master/captum/_utils/av.py.
@@ -99,7 +133,7 @@ def load(
) -> ActivationDataset:
layer_dir = os.path.join(path, layer)

if not os.path.exists(layer_dir):
if os.path.exists(layer_dir):
av_dataset = ActivationDataset(layer_dir, device)
return av_dataset
else:
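A round-trip sketch for the new ExplanationsCache; the cache path and batch shapes are hypothetical:

import os

import torch

from utils.cache import ExplanationsCache as EC

path = "./cache/explanations"
os.makedirs(path, exist_ok=True)

# Save three batches of explanations as 0.pt, 1.pt, 2.pt.
for num_id, batch in enumerate(torch.randn(3, 8, 100)):
    EC.save(path=path, exp_tensors=[batch], num_id=num_id)

if EC.exists(path=path):
    explanations = EC.load(path=path, device="cpu")  # a BatchedCachedExplanations dataset
    print(len(explanations))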
