Commit: Merge branch 'main' into randomization_metric
Showing 34 changed files with 885 additions and 257 deletions.
Pytest configuration (`@@ -1,3 +1,6 @@`) — registers the custom test markers:

```ini
[pytest]
markers =
    utils: utils files
    explainers: explainers
    localization_metrics: localization_metrics
    unnamed_metrics: unnamed_metrics
```
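For reference, a marker registered here can be attached to tests and used for selection. A minimal sketch (the test name and body are illustrative, not from the repository):

```python
import pytest


# Select only these tests with: pytest -m localization_metrics
@pytest.mark.localization_metrics
def test_score_is_a_probability():
    score = 0.75  # stand-in for a metric result
    assert 0.0 <= score <= 1.0
```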
Deleted file (`@@ -1,30 +0,0 @@`) — the old `Metric` base class is removed:

```python
from abc import ABC, abstractmethod
from typing import Union

import numpy as np
import torch


class Metric(ABC):
    name = "BaseMetricClass"

    @abstractmethod
    def __init__(self, train: torch.utils.data.Dataset, test: torch.utils.data.Dataset):
        pass

    @abstractmethod
    def __call__(self, *args, **kwargs):
        pass

    @abstractmethod
    def get_result(self, dir: str):
        pass

    @staticmethod
    def to_float(results: Union[dict, str, torch.Tensor]) -> Union[dict, str, torch.Tensor]:
        if isinstance(results, dict):
            return {key: Metric.to_float(r) for key, r in results.items()}
        elif isinstance(results, str):
            return results
        else:
            return np.array(results).astype(float).tolist()
```
Modified file (`@@ -1,67 +1,38 @@`) — the `Metric` base class now takes a `device` in its constructor, drops the generic `__call__` template (each metric defines its own, as in the new files below), and renames `_evaluate` to `_evaluate_instance`:

```diff
@@ -1,67 +1,38 @@
 from abc import ABC, abstractmethod

 import torch


 class Metric(ABC):
-    def __init__(self, *args, **kwargs):
-        pass
+    def __init__(self, device, *args, **kwargs):
+        self.device = device

-    @abstractmethod
-    def __call__(
-        self,
-        model: torch.nn.Module,
-        model_id: str,
-        cache_dir: str,  # TODO: maybe cache is not the best notation?
-        train_dataset: torch.utils.data.Dataset,
-        test_dataset: torch.utils.data.Dataset,
-        explanations: torch.utils.data.Dataset,
-        # TODO: should it be a tensor or dataset? For large datasets, storing the whole thing in RAM might be difficult.
-        *args,
-        **kwargs,
-    ):
-        """
-        Here include some general steps, incl.:
-        1) Universal assertions about the passed arguments, incl. checking that the length of train/test dataset and
-        explanations match.
-        2) Call the _evaluate method.
-        3) Format the output into a unified format for all metrics, possibly using some arguments passed in kwargs.
-        :param model:
-        :param model_id:
-        :param cache_dir:
-        :param train_dataset:
-        :param test_dataset:
-        :param explanations:
-        :param kwargs:
-        :return:
-        """
-        raise NotImplementedError
-
     @abstractmethod
-    def _evaluate(
+    def _evaluate_instance(
         self,
         model: torch.nn.Module,
         train_dataset: torch.utils.data.Dataset,
         test_dataset: torch.utils.data.Dataset,
         explanations: torch.utils.data.Dataset,
         *args,
         **kwargs,
     ):
         """
         Used to implement metric-specific logic.
         """
         raise NotImplementedError

     @staticmethod
     @abstractmethod
     def _format(
         self,
         model: torch.nn.Module,
         train_dataset: torch.utils.data.Dataset,
         test_dataset: torch.utils.data.Dataset,
         explanations: torch.utils.data.Dataset,
     ):
         """
         Format the output of the metric to a predefined format, maybe string?
         """
         raise NotImplementedError
```
One file was renamed without changes.
New file (`@@ -0,0 +1,75 @@`) — the `IdenticalClass` metric. For each explained test sample it looks up the training sample with the highest attribution, runs that sample through the model, and scores whether the prediction matches the test label:

```python
from typing import Optional, Union

import torch

from metrics.base import Metric
from src.utils.explanations import (
    BatchedCachedExplanations,
    TensorExplanations,
)
from utils.cache import ExplanationsCache as EC


class IdenticalClass(Metric):
    def __init__(self, device, *args, **kwargs):
        super().__init__(device, *args, **kwargs)

    def __call__(
        self,
        model: torch.nn.Module,
        train_dataset: torch.utils.data.Dataset,
        test_labels: torch.Tensor,
        explanations: Union[str, torch.Tensor, TensorExplanations, BatchedCachedExplanations] = "./",
        batch_size: Optional[int] = 8,
        **kwargs,
    ):
        """
        :param model:
        :param train_dataset:
        :param test_labels:
        :param explanations:
        :param batch_size:
        :param kwargs:
        :return:
        """
        # Accept a cache directory, a raw tensor, or an already-wrapped explanations object.
        if isinstance(explanations, str):
            explanations = EC.load(path=explanations, device=self.device)
        elif isinstance(explanations, torch.Tensor):
            explanations = TensorExplanations(explanations, batch_size=batch_size, device=self.device)

        scores = []
        n_processed = 0
        for i in range(len(explanations)):
            assert n_processed + explanations[i].shape[0] <= len(
                test_labels
            ), f"Number of explanations ({n_processed + explanations[i].shape[0]}) exceeds the number of test labels."

            score = self._evaluate_instance(
                model=model,
                train_dataset=train_dataset,
                test_labels=test_labels[n_processed : n_processed + explanations[i].shape[0]],
                xpl=explanations[i],
            )
            scores.append(score)
            n_processed += explanations[i].shape[0]

        return {"score": torch.cat(scores).mean()}

    def _evaluate_instance(
        self,
        model: torch.nn.Module,
        train_dataset: torch.utils.data.Dataset,
        test_labels: torch.Tensor,
        xpl: torch.Tensor,
    ):
        """
        Used to implement metric-specific logic.
        """
        # Pick the top-1 attributed training sample for each test sample and predict its class.
        top_one_xpl_indices = xpl.argmax(dim=1)
        top_one_xpl_samples = torch.stack([train_dataset[i][0] for i in top_one_xpl_indices])

        top_one_xpl_output = model(top_one_xpl_samples.to(self.device))
        top_one_xpl_pred = top_one_xpl_output.argmax(dim=1)

        return (test_labels == top_one_xpl_pred) * 1.0
```
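A hypothetical usage sketch for `IdenticalClass` on toy data; the import path, model, and shapes are assumptions, not taken from the repository:

```python
import torch
from torch.utils.data import TensorDataset

from metrics.localization.identical_class import IdenticalClass  # assumed module path

# Toy setup: 10 training samples with 5 features, 3 classes, 4 explained test samples.
train_set = TensorDataset(torch.randn(10, 5), torch.randint(0, 3, (10,)))
model = torch.nn.Linear(5, 3)            # stand-in for a trained model
test_labels = torch.randint(0, 3, (4,))
explanations = torch.randn(4, 10)        # one attribution per (test sample, train sample) pair

metric = IdenticalClass(device="cpu")
result = metric(
    model=model,
    train_dataset=train_set,
    test_labels=test_labels,
    explanations=explanations,
    batch_size=2,
)
# Fraction of test samples whose top-1 attributed train sample predicts the same class.
print(result["score"])
```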
One file in the diff is empty.
New file (`@@ -0,0 +1,72 @@`) — the `TopKOverlap` metric. It gathers the top-k training indices for every explained test sample and reports the cardinality of their union:

```python
import warnings
from typing import Optional, Union

import torch

from metrics.base import Metric
from src.utils.explanations import (
    BatchedCachedExplanations,
    TensorExplanations,
)
from utils.cache import ExplanationsCache as EC


class TopKOverlap(Metric):
    def __init__(self, device, *args, **kwargs):
        super().__init__(device, *args, **kwargs)

    def __call__(
        self,
        model: torch.nn.Module,
        train_dataset: torch.utils.data.Dataset,
        top_k: int = 1,
        explanations: Union[str, torch.Tensor, TensorExplanations, BatchedCachedExplanations] = "./",
        batch_size: Optional[int] = 8,
        **kwargs,
    ):
        """
        :param model:
        :param train_dataset:
        :param top_k:
        :param explanations:
        :param batch_size:
        :param kwargs:
        :return:
        """
        if isinstance(explanations, str):
            explanations = EC.load(path=explanations, device=self.device)
            if explanations.batch_size != batch_size:
                warnings.warn(
                    "Batch size mismatch between loaded explanations and passed batch size. The inferred batch "
                    "size will be used instead."
                )
                batch_size = explanations.batch_size  # use the batch size inferred from the cache
        elif isinstance(explanations, torch.Tensor):
            explanations = TensorExplanations(explanations, batch_size=batch_size, device=self.device)

        all_top_k_examples = []

        for i in range(len(explanations)):
            top_k_examples = self._evaluate_instance(
                xpl=explanations[i],
                top_k=top_k,
            )
            # Collect plain ints so that set() below deduplicates indices across batches.
            all_top_k_examples += top_k_examples.flatten().tolist()

        # calculate the cardinality of the set of top-k examples
        cardinality = len(set(all_top_k_examples))

        # TODO: calculate the probability of the set of top-k examples
        return {"score": cardinality}

    def _evaluate_instance(
        self,
        xpl: torch.Tensor,
        top_k: int = 1,
    ):
        """
        Used to implement metric-specific logic.
        """
        # Indices of the top-k attributed training samples for each explained sample.
        top_k_indices = torch.topk(xpl, top_k).indices
        return top_k_indices
```
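And a similar hypothetical sketch for `TopKOverlap`; the import path and inputs are again assumptions (`model` and `train_dataset` are accepted but unused by the logic above):

```python
import torch

from metrics.randomization.top_k_overlap import TopKOverlap  # assumed module path

explanations = torch.randn(4, 10)  # attributions for 4 test samples over 10 train samples

metric = TopKOverlap(device="cpu")
result = metric(
    model=torch.nn.Identity(),  # unused by the current logic
    train_dataset=None,         # unused by the current logic
    top_k=3,
    explanations=explanations,
    batch_size=2,
)
# Number of distinct train samples across all top-3 sets (at most 4 * 3 = 12 here).
print(result["score"])
```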