Commit

last changes for implementation
gumityolcu committed May 15, 2024
1 parent 50b6db4 commit 25c963a
Showing 1 changed file with 65 additions and 19 deletions.
84 changes: 65 additions & 19 deletions src/metrics/randomization.py
@@ -1,55 +1,101 @@
from abc import ABC, abstractmethod
from typing import Callable

import torch

from utils.explanations import Explanations


class RandomizationMetric(ABC):
    def __init__(self, seed: int = 42, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        # We could move seed and device to __call__, but then the seed would have to be set per call of the
        # metric. Where does it make sense to do the seeding? For example, imagine the user does not bother
        # passing a seed, so we use the default: do we want the exact same random model to be attributed on
        # every call (seed in __call__), or genuinely random models for each call of the metric (seed in the
        # constructor)?
        self.device = device
        self.generator = torch.Generator(device=device)
        self.generator.manual_seed(seed)

    @abstractmethod
    def __call__(
        self,
        model: torch.nn.Module,
        model_id: str,
        cache_dir: str,
        train_dataset: torch.utils.data.Dataset,
        test_dataset: torch.utils.data.Dataset,
        explanations: Explanations,
        explain_fn: Callable,
        explain_fn_kwargs: dict,
    ):
        # Allow for precomputed random explanations?
        randomized_model = RandomizationMetric._randomize_model(model, self.generator)
        results = self._evaluate(explanations, randomized_model, explain_fn, explain_fn_kwargs)
        results["model_id"] = model_id
        return results

    def _evaluate(
        self,
        explanations: Explanations,
        randomized_model: torch.nn.Module,
        explain_fn: Callable,
        explain_fn_kwargs: dict,
    ):
        """
        Used to implement metric-specific logic.
        """
        rand_explanations = explain_fn(model=randomized_model, **explain_fn_kwargs)
        rank_corr = RandomizationMetric._rank_correlation(explanations, rand_explanations)
        results = dict()
        results["rank_correlations"] = rank_corr
        results["average_score"] = rank_corr.mean()
        return results

    @staticmethod
    def _randomize_model(model, generator):
        # Replace every parameter with a tensor of the same shape drawn from a standard normal distribution.
        # Note that this modifies the model in place, and the generator's device must match the device of
        # the model's parameters.
        for name, param in list(model.named_parameters()):
            random_parameter_tensor = torch.empty_like(param).normal_(generator=generator)
            # Walk down to the submodule that owns the parameter, e.g. "features.0.weight" -> model.features[0].
            names = name.split(".")
            param_obj = model
            for n in names[:-1]:
                param_obj = getattr(param_obj, n)
            assert isinstance(getattr(param_obj, names[-1]), torch.nn.Parameter)
            setattr(param_obj, names[-1], torch.nn.Parameter(random_parameter_tensor))
        return model

    @staticmethod
    def _rank_correlation(std_explanations, random_explanations):
        # This implementation currently assumes that the batch size and the number of batches are the same
        # in the std and random explanations.
        train_size = std_explanations[0].shape[1]
        device = std_explanations[0].device
        std_rank_mean = torch.zeros(train_size, device=device)
        random_rank_mean = torch.zeros(train_size, device=device)
        std_batch_size = std_explanations.batch_size
        num_test_points = std_batch_size * len(std_explanations)

        # First pass: mean rank of every training point's influence over all test points.
        # A double argsort along the training dimension turns raw scores into ranks 0..train_size-1.
        for std_batch, random_batch in zip(std_explanations, random_explanations):
            std_ranks = torch.argsort(torch.argsort(std_batch, dim=-1), dim=-1).float()
            random_ranks = torch.argsort(torch.argsort(random_batch, dim=-1), dim=-1).float()
            std_rank_mean += std_ranks.sum(dim=0)
            random_rank_mean += random_ranks.sum(dim=0)
        std_rank_mean /= num_test_points
        random_rank_mean /= num_test_points

        # Second pass: accumulate products of centered ranks and the rank variances, then normalize.
        corrs = torch.zeros(train_size, device=device)
        std_var = torch.zeros(train_size, device=device)
        random_var = torch.zeros(train_size, device=device)
        for std_batch, random_batch in zip(std_explanations, random_explanations):
            std_ranks = torch.argsort(torch.argsort(std_batch, dim=-1), dim=-1).float() - std_rank_mean
            random_ranks = torch.argsort(torch.argsort(random_batch, dim=-1), dim=-1).float() - random_rank_mean
            corrs += (std_ranks * random_ranks).sum(dim=0)
            std_var += (std_ranks ** 2).sum(dim=0)
            random_var += (random_ranks ** 2).sum(dim=0)
        corrs = corrs / torch.sqrt(std_var * random_var)
        return corrs  # Spearman rank correlation of the influence of each training data point



    def _format(
        self,
        model: torch.nn.Module,
        train_dataset: torch.utils.data.Dataset,
        test_dataset: torch.utils.data.Dataset,
        explanations: Explanations,
    ):
        """
        Format the output of the metric to a predefined format, maybe a string?
        """
        # Shouldn't we have a self.results to be able to do this? Maybe just take the results dict as input?
        # For most metrics, the summary should be a list of values for each test point plus a mean score.
        raise NotImplementedError
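
As a reference for the rank-correlation logic above, here is a minimal, self-contained sketch that computes the same per-training-point Spearman correlation on dense tensors, assuming the explanations fit in memory as (n_test, n_train) score matrices rather than the batched Explanations object; the helper name spearman_per_train_point and the toy shapes are placeholders, not part of the repository.

import torch

def spearman_per_train_point(std_expl: torch.Tensor, rand_expl: torch.Tensor) -> torch.Tensor:
    # std_expl, rand_expl: (n_test, n_train) influence scores from the trained and the randomized model.
    # A double argsort along the training dimension turns scores into ranks 0..n_train-1 per test point.
    std_ranks = torch.argsort(torch.argsort(std_expl, dim=-1), dim=-1).float()
    rand_ranks = torch.argsort(torch.argsort(rand_expl, dim=-1), dim=-1).float()
    # Center each training point's rank column and take the Pearson correlation of the rank columns,
    # which equals the Spearman correlation of the underlying scores (assuming no ties).
    std_c = std_ranks - std_ranks.mean(dim=0)
    rand_c = rand_ranks - rand_ranks.mean(dim=0)
    return (std_c * rand_c).sum(dim=0) / torch.sqrt((std_c ** 2).sum(dim=0) * (rand_c ** 2).sum(dim=0))

# Toy example: 8 test points, 100 training points, independent random scores.
corrs = spearman_per_train_point(torch.rand(8, 100), torch.rand(8, 100))
print(corrs.shape, corrs.mean())  # torch.Size([100]) and a mean correlation close to zero

A low correlation between the original explanations and those of the randomized model indicates that the attributions actually depend on the learned parameters, which is what this randomization check is meant to test.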
