Skip to content

Commit

Permalink
change file name, incorporate torch spearman and kendall rank correla…
Browse files Browse the repository at this point in the history
…tion options
  • Loading branch information
gumityolcu committed May 22, 2024
1 parent 8f13c38 commit fd497a7
Showing 1 changed file with 22 additions and 35 deletions.
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
from typing import Callable
from typing import Callable, Union

import torch

from metrics.base import Metric
from utils.explanations import Explanations
from utils.functions.correlations import explanation_spearman_rank_correlation


class MPRTMetric(Metric):
def __init__(self, seed: int = 42, device: str = "cpu" if torch.cuda.is_available() else "cuda"):
class ModelRandomizationMetric(Metric):
def __init__(self, correlation_measure: Union[Callable, str, None]="spearman",
seed: int = 42,
device: str = "cpu" if torch.cuda.is_available() else "cuda"):
# we can move seed and device to __call__. Then we would need to set the seed per call of the metric function.
# where does it make sense to do seeding?
# for example, imagine the user doesn't bother giving a seed, so we use the default seed.
# do we want the exact same random model to be attributed (keeping seed in the __call__ call)
# or do we want genuinely random models for each call of the metric (keeping seed in the constructor)
self.generator = torch.Generator(device=device)
if correlation_measure is None: correlation_measure = "spearman"
if isinstance(correlation_measure, str):
assert correlation_measure in ["spearman"], f"Correlation measure {correlation_measure} is not implemented."
if correlation_measure=="spearman":
correlation_measure=explanation_spearman_rank_correlation
assert isinstance(Callable,correlation_measure)
self.correlation_measure=correlation_measure

def __call__(
self,
Expand All @@ -24,16 +34,18 @@ def __call__(
test_dataset: torch.utils.data.Dataset,
explanations: Explanations,
explain_fn: Callable,
explain_fn_kwargs: dict,
explain_fn_kwargs: dict
):
# Allow for precomputed random explanations?
randomized_model = MPRTMetric._randomize_model(model, self.device, self.generator)
results=dict()
randomized_model = ModelRandomizationMetric._randomize_model(model, self.device, self.generator)
rand_explanations = explain_fn(model=randomized_model, **explain_fn_kwargs)
rank_corr = MPRTMetric._spearman_rank_correlation(explanations, rand_explanations)
results = dict()
results["rank_correlations"] = rank_corr
results["average_score"] = rank_corr.mean()
results["model_id"] = model_id
corrs=torch.empty(explanations[0].shape[-1])
for std_batch,rand_batch in zip(explanations,rand_explanations):
newcorrs=self.correlation_measure(std_batch.T, rand_batch.T)
corrs=torch.cat((corrs,newcorrs))
results["rank_correlations"]=corrs
results["score"]=corrs.mean()
return results

def _evaluate_instance(
Expand All @@ -60,31 +72,6 @@ def _randomize_model(model: torch.nn.Module, generator: torch.Generator):
param_obj.__setattr__(names[-1], torch.nn.Parameter(random_parameter_tensor))
return model

@staticmethod
def _rank_mean(explanations: Explanations):
train_size = explanations[0].shape[1]
rank_mean = torch.zeros(train_size)
count = 0
for batch in explanations:
_, ranks = torch.sort(batch)
rank_mean += torch.tensor(ranks, dtype=float) / train_size
count = count + batch.shape[0]
return rank_mean / count, count

@staticmethod
def _spearman_rank_correlation(std_explanations: Explanations, random_explanations: Explanations):
train_size = std_explanations[0].shape[1]
std_rank_mean, test_size = MPRTMetric._rank_mean(std_explanations)
random_rank_mean, _ = MPRTMetric._rank_mean(random_explanations)
corrs = torch.zeros(train_size)
for std_batch, random_batch in zip(std_explanations, random_explanations):
_, std_ranks = torch.sort(std_batch)
_, random_ranks = torch.sort(random_batch)
std_ranks = std_ranks - std_rank_mean
random_ranks = random_ranks - random_rank_mean
corrs = corrs + (std_ranks * random_ranks)
return corrs / test_size # return spearman rank correlation of each training data influence

@staticmethod
def _format(
self,
Expand Down

0 comments on commit fd497a7

Please sign in to comment.