Skip to content

Commit

Permalink
fix top_k_overlap.py + copy useful code to functional.py
Browse files Browse the repository at this point in the history
dilyabareeva committed May 31, 2024
1 parent 3a2a66a commit 80a265c
Showing 3 changed files with 73 additions and 59 deletions.
49 changes: 49 additions & 0 deletions src/metrics/functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""
WORK IN PROGRESS!!!
"""
import warnings
from typing import Optional, Union

import torch

from src.utils.explanations import (
BatchedCachedExplanations,
TensorExplanations,
)
from utils.cache import ExplanationsCache as EC


def function_example(
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
top_k: int = 1,
explanations: Union[str, torch.Tensor, TensorExplanations, BatchedCachedExplanations] = "./",
batch_size: Optional[int] = 8,
device="cpu",
**kwargs,
):
"""
I've copied the existing code from the memory-less metric version here, that can be reused in the future here.
It will not be called "function_example" in the future. There will be many reusable functions, but every metric
will get a functional version here.
:param model:
:param train_dataset:
:param top_k:
:param explanations:
:param batch_size:
:param device:
:param kwargs:
:return:
"""
if isinstance(explanations, str):
explanations = EC.load(path=explanations, device=device)
if explanations.batch_size != batch_size:
warnings.warn(
"Batch size mismatch between loaded explanations and passed batch size. The inferred batch "
"size will be used instead."
)
batch_size = explanations[0]
elif isinstance(explanations, torch.Tensor):
explanations = TensorExplanations(explanations, batch_size=batch_size, device=device)
76 changes: 21 additions & 55 deletions src/metrics/unnamed/top_k_overlap.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,38 @@
import warnings
from typing import Optional, Union

import torch

from metrics.base import Metric
from src.utils.explanations import (
BatchedCachedExplanations,
TensorExplanations,
)
from utils.cache import ExplanationsCache as EC


class TopKOverlap(Metric):
def __init__(self, device, *args, **kwargs):
super().__init__(device, *args, **kwargs)

def __call__(
def __init__(
self,
model: torch.nn.Module,
train_dataset: torch.utils.data.Dataset,
top_k: int = 1,
explanations: Union[str, torch.Tensor, TensorExplanations, BatchedCachedExplanations] = "./",
batch_size: Optional[int] = 8,
device: str = "cpu",
*args,
**kwargs,
):
"""
:param test_predictions:
:param explanations:
:param batch_size:
:param kwargs:
:return:
"""
super().__init__(model, train_dataset, *args, **kwargs)
self.top_k = top_k
self.all_top_k_examples = []

if isinstance(explanations, str):
explanations = EC.load(path=explanations, device=self.device)
if explanations.batch_size != batch_size:
warnings.warn(
"Batch size mismatch between loaded explanations and passed batch size. The inferred batch "
"size will be used instead."
)
batch_size = explanations[0]
elif isinstance(explanations, torch.Tensor):
explanations = TensorExplanations(explanations, batch_size=batch_size, device=self.device)

all_top_k_examples = []

for i in range(len(explanations)):
top_k_examples = self._evaluate_instance(
xpl=explanations[i],
top_k=top_k,
)
all_top_k_examples += top_k_examples
def update(
self,
explanations: torch.Tensor,
**kwargs,
):
top_k_indices = torch.topk(explanations, self.top_k).indices
self.all_top_k_examples += top_k_indices

# calculate the cardinality of the set of top-k examples
cardinality = len(set(all_top_k_examples))
def compute(self, *args, **kwargs):
return len(set(self.all_top_k_examples))

# TODO: calculate the probability of the set of top-k examples
return {"score": cardinality}
def reset(self, *args, **kwargs):
self.all_top_k_examples = []

def _evaluate_instance(
self,
xpl: torch.Tensor,
top_k: int = 1,
):
"""
Used to implement metric-specific logic.
"""
def load_state_dict(self, state_dict: dict, *args, **kwargs):
self.all_top_k_examples = state_dict["all_top_k_examples"]

top_k_indices = torch.topk(xpl, top_k).indices
return top_k_indices
def state_dict(self, *args, **kwargs):
return {"all_top_k_examples": self.all_top_k_examples}
7 changes: 3 additions & 4 deletions tests/metrics/test_unnamed_metrics.py
Original file line number Diff line number Diff line change
@@ -14,8 +14,7 @@ def test_top_k_overlap_metrics(test_id, model, dataset, top_k, batch_size, expla
model = request.getfixturevalue(model)
dataset = request.getfixturevalue(dataset)
explanations = request.getfixturevalue(explanations)
metric = TopKOverlap(device="cpu")
score = metric(model=model, train_dataset=dataset, top_k=top_k, explanations=explanations, batch_size=batch_size)[
"score"
]
metric = TopKOverlap(model=model, train_dataset=dataset, top_k=top_k, device="cpu")
metric.update(explanations=explanations)
score = metric.compute()
assert score == expected_score

0 comments on commit 80a265c

Please sign in to comment.