fix top_k_overlap.py + copy useful code to functional.py
commit 80a265c · 1 parent 3a2a66a
Showing 3 changed files with 73 additions and 59 deletions.
functional.py (new file)
@@ -0,0 +1,49 @@
""" | ||
WORK IN PROGRESS!!! | ||
""" | ||
import warnings | ||
from typing import Optional, Union | ||
|
||
import torch | ||
|
||
from src.utils.explanations import ( | ||
BatchedCachedExplanations, | ||
TensorExplanations, | ||
) | ||
from utils.cache import ExplanationsCache as EC | ||
|
||
|
||
def function_example( | ||
model: torch.nn.Module, | ||
train_dataset: torch.utils.data.Dataset, | ||
top_k: int = 1, | ||
explanations: Union[str, torch.Tensor, TensorExplanations, BatchedCachedExplanations] = "./", | ||
batch_size: Optional[int] = 8, | ||
device="cpu", | ||
**kwargs, | ||
): | ||
""" | ||
I've copied the existing code from the memory-less metric version here, that can be reused in the future here. | ||
It will not be called "function_example" in the future. There will be many reusable functions, but every metric | ||
will get a functional version here. | ||
:param model: | ||
:param train_dataset: | ||
:param top_k: | ||
:param explanations: | ||
:param batch_size: | ||
:param device: | ||
:param kwargs: | ||
:return: | ||
""" | ||
if isinstance(explanations, str): | ||
explanations = EC.load(path=explanations, device=device) | ||
if explanations.batch_size != batch_size: | ||
warnings.warn( | ||
"Batch size mismatch between loaded explanations and passed batch size. The inferred batch " | ||
"size will be used instead." | ||
) | ||
batch_size = explanations[0] | ||
elif isinstance(explanations, torch.Tensor): | ||
explanations = TensorExplanations(explanations, batch_size=batch_size, device=device) |
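For orientation, a minimal usage sketch of the branch above. The import path, the stand-in model, dataset, and tensor shapes are illustrative assumptions, not part of the commit:

import torch

from functional import function_example  # assumed import path

model = torch.nn.Linear(10, 2)  # stand-in model
train_dataset = torch.utils.data.TensorDataset(torch.rand(100, 10))  # stand-in training set

# hypothetical explanations: 16 test points scored against 100 training samples
xpl = torch.rand(16, 100)

# a raw tensor is wrapped in TensorExplanations with the given batch size;
# a path string would instead be loaded through ExplanationsCache
function_example(
    model=model,
    train_dataset=train_dataset,
    top_k=3,
    explanations=xpl,
    batch_size=8,
    device="cpu",
)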
top_k_overlap.py
@@ -1,72 +1,38 @@
-import warnings
-from typing import Optional, Union
-
 import torch

 from metrics.base import Metric
-from src.utils.explanations import (
-    BatchedCachedExplanations,
-    TensorExplanations,
-)
-from utils.cache import ExplanationsCache as EC


 class TopKOverlap(Metric):
-    def __init__(self, device, *args, **kwargs):
-        super().__init__(device, *args, **kwargs)
-
-    def __call__(
+    def __init__(
         self,
         model: torch.nn.Module,
         train_dataset: torch.utils.data.Dataset,
         top_k: int = 1,
-        explanations: Union[str, torch.Tensor, TensorExplanations, BatchedCachedExplanations] = "./",
-        batch_size: Optional[int] = 8,
         device: str = "cpu",
         *args,
         **kwargs,
     ):
-        """
-        :param test_predictions:
-        :param explanations:
-        :param batch_size:
-        :param kwargs:
-        :return:
-        """
+        super().__init__(model, train_dataset, *args, **kwargs)
+        self.top_k = top_k
+        self.all_top_k_examples = []

-        if isinstance(explanations, str):
-            explanations = EC.load(path=explanations, device=self.device)
-            if explanations.batch_size != batch_size:
-                warnings.warn(
-                    "Batch size mismatch between loaded explanations and passed batch size. The inferred batch "
-                    "size will be used instead."
-                )
-                batch_size = explanations[0]
-        elif isinstance(explanations, torch.Tensor):
-            explanations = TensorExplanations(explanations, batch_size=batch_size, device=self.device)
-
-        all_top_k_examples = []
-
-        for i in range(len(explanations)):
-            top_k_examples = self._evaluate_instance(
-                xpl=explanations[i],
-                top_k=top_k,
-            )
-            all_top_k_examples += top_k_examples
+    def update(
+        self,
+        explanations: torch.Tensor,
+        **kwargs,
+    ):
+        top_k_indices = torch.topk(explanations, self.top_k).indices
+        # store plain ints so that compute() can deduplicate them with set()
+        self.all_top_k_examples += top_k_indices.flatten().tolist()

-        # calculate the cardinality of the set of top-k examples
-        cardinality = len(set(all_top_k_examples))
+    def compute(self, *args, **kwargs):
+        # cardinality of the union of all top-k index sets
+        return len(set(self.all_top_k_examples))

-        # TODO: calculate the probability of the set of top-k examples
-        return {"score": cardinality}
+    def reset(self, *args, **kwargs):
+        self.all_top_k_examples = []

-    def _evaluate_instance(
-        self,
-        xpl: torch.Tensor,
-        top_k: int = 1,
-    ):
-        """
-        Used to implement metric-specific logic.
-        """
-
-        top_k_indices = torch.topk(xpl, top_k).indices
-        return top_k_indices
+    def load_state_dict(self, state_dict: dict, *args, **kwargs):
+        self.all_top_k_examples = state_dict["all_top_k_examples"]
+
+    def state_dict(self, *args, **kwargs):
+        return {"all_top_k_examples": self.all_top_k_examples}