Skip to content

Commit 635332e

changes for flake8

gumityolcu committed Jun 3, 2024
1 parent 286b76e commit 635332e
Showing 3 changed files with 118 additions and 0 deletions.
100 changes: 100 additions & 0 deletions src/metrics/randomization.py
@@ -0,0 +1,100 @@
from metrics.base import Metric
from typing import Callable
from utils.explanations import Explanations
import torch


class RandomizationMetric(Metric):
    def __init__(self, seed: int = 42, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        # We could move seed and device to __call__; then the seed would be set per call
        # of the metric function. Where does seeding make sense? For example, if the user
        # doesn't bother giving a seed, we use the default. Do we want the exact same
        # random model to be attributed on every call (seed in __call__), or genuinely
        # different random models per call, drawn from one reproducible stream
        # (seed in the constructor, as done here)?
        self.device = device
        self.generator = torch.Generator(device=device)
        self.generator.manual_seed(seed)

    def __call__(
        self,
        model: torch.nn.Module,
        model_id: str,
        cache_dir: str,
        train_dataset: torch.utils.data.Dataset,
        test_dataset: torch.utils.data.Dataset,
        explanations: Explanations,
        explain_fn: Callable,
        explain_fn_kwargs: dict,
    ):
        # Allow for precomputed random explanations?
        randomized_model = RandomizationMetric._randomize_model(model, self.generator)
        results = self._evaluate(explanations, randomized_model, explain_fn, explain_fn_kwargs)
        results["model_id"] = model_id
        return results

    def _evaluate(
        self,
        explanations: Explanations,
        randomized_model: torch.nn.Module,
        explain_fn: Callable,
        explain_fn_kwargs: dict,
    ):
        """
        Used to implement metric-specific logic.
        """
        rand_explanations = explain_fn(model=randomized_model, **explain_fn_kwargs)
        rank_corr = RandomizationMetric._rank_correlation(explanations, rand_explanations)
        results = dict()
        results["rank_correlations"] = rank_corr
        results["average_score"] = rank_corr.mean()
        return results

    @staticmethod
    def _randomize_model(model, generator):
        # Replace every parameter in place with a tensor of the same shape drawn
        # from N(0, 1). Note that the generator's device must match the device
        # the parameters live on.
        for name, param in list(model.named_parameters()):
            random_parameter_tensor = torch.empty_like(param).normal_(generator=generator)
            # Walk down the module tree to the parameter's parent module,
            # then swap the parameter there.
            names = name.split(".")
            param_obj = model
            for n in names[:-1]:
                param_obj = getattr(param_obj, n)
            assert isinstance(getattr(param_obj, names[-1]), torch.nn.Parameter)
            setattr(param_obj, names[-1], torch.nn.Parameter(random_parameter_tensor))
        return model

    @staticmethod
    def _rank_correlation(std_explanations, random_explanations):
        # This implementation assumes batch sizes and the number of batches are the
        # same in std and random explanations.
        train_size = std_explanations[0].shape[1]
        n_samples = std_explanations.batch_size * len(std_explanations)
        std_rank_mean = torch.zeros(train_size)
        random_rank_mean = torch.zeros(train_size)
        # First pass: accumulate the mean rank of each training point over all test
        # points. Ranks are the argsort of the argsort; torch.sort alone would yield
        # the sorting permutation, not the ranks.
        for std_batch, random_batch in zip(std_explanations, random_explanations):
            std_ranks = std_batch.argsort(dim=-1).argsort(dim=-1).float() / train_size
            random_ranks = random_batch.argsort(dim=-1).argsort(dim=-1).float() / train_size
            std_rank_mean += std_ranks.sum(dim=0)
            random_rank_mean += random_ranks.sum(dim=0)
        std_rank_mean /= n_samples
        random_rank_mean /= n_samples
        # Second pass: accumulate the covariance and the variances of the centered ranks.
        cov = torch.zeros(train_size)
        std_var = torch.zeros(train_size)
        random_var = torch.zeros(train_size)
        for std_batch, random_batch in zip(std_explanations, random_explanations):
            std_ranks = std_batch.argsort(dim=-1).argsort(dim=-1).float() / train_size
            random_ranks = random_batch.argsort(dim=-1).argsort(dim=-1).float() / train_size
            std_centered = std_ranks - std_rank_mean
            random_centered = random_ranks - random_rank_mean
            cov += (std_centered * random_centered).sum(dim=0)
            std_var += (std_centered**2).sum(dim=0)
            random_var += (random_centered**2).sum(dim=0)
        corrs = cov / torch.sqrt(std_var * random_var)
        return corrs  # Spearman rank correlation of each training point's influence

    def _format(
        self,
        model: torch.nn.Module,
        train_dataset: torch.utils.data.Dataset,
        test_dataset: torch.utils.data.Dataset,
        explanations: Explanations,
    ):
        # Shouldn't we keep a self.results around to be able to do this? Or maybe
        # take the results dict as an input to _format?
        # The metric summary should be a list of values for each test point, plus
        # a mean score for most metrics.
        raise NotImplementedError
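
For orientation, this is roughly how the metric is meant to be driven. The following is a minimal usage sketch, not part of the diff; model, the datasets, explanations, and explain_fn are hypothetical stand-ins for whatever attributor the caller uses:

    # Hypothetical usage sketch (names are illustrative assumptions).
    metric = RandomizationMetric(seed=42, device="cpu")
    results = metric(
        model=model,                # trained model; its parameters get randomized
        model_id="model_0",
        cache_dir="./cache",
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        explanations=explanations,  # precomputed explanations for the trained model
        explain_fn=explain_fn,      # re-run on the randomized copy
        explain_fn_kwargs={"train_dataset": train_dataset},
    )
    print(results["average_score"])  # low correlation = explainer is parameter-sensitive

Note that _randomize_model replaces the parameters of the passed model in place, so a caller who still needs the trained weights would want to hand in a copy.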
6 changes: 6 additions & 0 deletions src/utils/common.py
@@ -1,6 +1,12 @@
import functools
from functools import reduce
from typing import Any, Callable, Mapping

import torch

12 changes: 12 additions & 0 deletions src/utils/explanations.py
@@ -76,6 +76,18 @@ def __setitem__(self, idx: Union[int, slice], val: Tuple[torch.Tensor, torch.Ten
    def __len__(self) -> int:
        # Ceiling division: the last batch may be partial.
        return (self.xpl.shape[0] + self.batch_size - 1) // self.batch_size

    def _setitem_single(
        self, index: int, val: Tuple[torch.Tensor, Union[torch.Tensor, int]]
    ) -> None:
        explanation, target = val
        # Locate the cache file holding this index, and the offset within it.
        file_id = index // self.cache_batch_size
        leftover_index = index - file_id * self.cache_batch_size
        explanations = IndicesCache.load(self.cache_path, f"explanations_{file_id}")
        targets = IndicesCache.load(self.cache_path, f"targets_{file_id}")
        explanations[leftover_index] = explanation
        targets[leftover_index] = target
        IndicesCache.save(self.cache_path, f"explanations_{file_id}", explanations)
        IndicesCache.save(self.cache_path, f"targets_{file_id}", targets)

class BatchedCachedExplanations(Explanations):
def __init__(
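
The index arithmetic in _setitem_single maps a flat example index to a cache file and an offset within that file. A toy illustration, assuming a hypothetical cache_batch_size of 32:

    # Hypothetical illustration of the index -> (file, offset) mapping.
    cache_batch_size = 32
    for index in (0, 31, 32, 130):
        file_id = index // cache_batch_size
        offset = index - file_id * cache_batch_size
        print(f"{index} -> explanations_{file_id}, offset {offset}")
    # 0 -> explanations_0, offset 0
    # 31 -> explanations_0, offset 31
    # 32 -> explanations_1, offset 0
    # 130 -> explanations_4, offset 2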
