Skip to content

Commit

Permalink
fix conflict while pulling from remote
Browse files Browse the repository at this point in the history
  • Loading branch information
gumityolcu committed May 22, 2024
1 parent fd497a7 commit d052f73
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 11 deletions.
6 changes: 4 additions & 2 deletions src/metrics/randomization/model_randomization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from metrics.base import Metric
from utils.explanations import Explanations
from utils.functions.correlations import explanation_spearman_rank_correlation
from utils.functions.correlations import kendall_rank_corr, spearman_rank_corr


class ModelRandomizationMetric(Metric):
Expand All @@ -21,7 +21,9 @@ def __init__(self, correlation_measure: Union[Callable, str, None]="spearman",
if isinstance(correlation_measure, str):
assert correlation_measure in ["spearman"], f"Correlation measure {correlation_measure} is not implemented."
if correlation_measure=="spearman":
correlation_measure=explanation_spearman_rank_correlation
correlation_measure=spearman_rank_corr
elif correlation_measure=="kendall":
correlation_measure=kendall_rank_corr
assert isinstance(Callable,correlation_measure)
self.correlation_measure=correlation_measure

Expand Down
11 changes: 11 additions & 0 deletions src/utils/functions/correlations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from torchmetrics.functional.regression import kendall_rank_corrcoef, spearman_corrcoef


# torchmetrics wants the independent realizations to be the final dimension
# we transpose inputs before passing so that it is straightforward to pass explanations
# and use these funcitons in evaluation metrics
def kendall_rank_corr(tensor1,tensor2):
return kendall_rank_corrcoef(tensor1.T,tensor2.T)

def spearman_rank_corr(tensor1, tensor2):
return spearman_corrcoef(tensor1.T,tensor2.T)
51 changes: 42 additions & 9 deletions tests/metrics/test_randomization_metrics.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,62 @@
import pytest
import torch

from metrics.randomization.mprt import MPRTMetric
from utils.explanations import TensorExplanations
from utils.functions.correlations import spearman_rank_corr
from metrics.randomization.model_randomization import ModelRandomizationMetric


@pytest.mark.randomization
@pytest.mark.randomize
@pytest.mark.parametrize(
"model",
[
("load_mnist_model"),
],
)
def parameter_randomization_test(model, request):
def model_randomization_test(model, request):
model1 = request.getfixturevalue(model)
model2 = request.getfixturevalue(model)
gen = torch.Generator()
gen.manual_seed(42)
MPRTMetric._randomize_model(model2, gen)
ModelRandomizationMetric._randomize_model(model2, gen)
for param1, param2 in zip(model1.parameters(), model2.parameters()):
assert torch.norm(param1.data - param2.data) > 1e3 # norm of the difference in parameters should be significant


@pytest.mark.parametrize()
def model_randomization_test():
assert torch.__version__ == "2.0.0"
@pytest.mark.randomize
def reproducibility_test():
assert torch.__version__=="2.0.0"
gen = torch.Generator()
gen.manual_seed(42)
assert torch.all(torch.rand(5, generator=gen) == torch.Tensor([0.8823, 0.9150, 0.3829, 0.9593, 0.3904]))
assert torch.all(torch.rand(5,generator=gen)==torch.Tensor([0.8823, 0.9150, 0.3829, 0.9593, 0.3904]))

@pytest.mark.randomize
def kendall_metric_test():
def explain_fn(model):
xpl_tensor=torch.tensor([[1,2,3,4],[4,3,2,1]])
return TensorExplanations(xpl_tensor)

xpl_tensor=torch.tensor([[1,2,3,4],[1,2,3,4]])
metric=ModelRandomizationMetric(correlation_measure="kendall")
assert torch.all(metric["rank_correlations"]==torch.tensor([1.,-1.]))

@pytest.mark.randomize
@pytest.mark.parametrize(
"model",
[
("load_mnist_model"),
],
)
def spearman_metric_test(model,request):
def explain_fn(model):
xpl_tensor=torch.tensor([[1,2,3,4],[4,3,2,1]])
return TensorExplanations(xpl_tensor)

def corr_measure(tensor1,tensor2):
return spearman_rank_corr(tensor1,tensor2)
model=request.getfixturevalue(model)
xpl_tensor=torch.tensor([[1,2,3,4],[1,2,3,4]])
for corr_measure in ["spearman", "kendall", corr_measure]:
metric=ModelRandomizationMetric(correlation_measure=corr_measure)
metric=metric(
model,"0","",None,None,xpl_tensor,explain_fn,{})
assert torch.all(metric["rank_correlations"]==torch.tensor([1.,-1.]))

0 comments on commit d052f73

Please sign in to comment.