diff --git a/src/metrics/randomization/model_randomization.py b/src/metrics/randomization/model_randomization.py
index f60199ee..d0dd3bae 100644
--- a/src/metrics/randomization/model_randomization.py
+++ b/src/metrics/randomization/model_randomization.py
@@ -4,6 +4,6 @@
 from metrics.base import Metric
 from utils.explanations import Explanations
-from utils.functions.correlations import explanation_spearman_rank_correlation
+from utils.functions.correlations import kendall_rank_corr, spearman_rank_corr
 
 
 class ModelRandomizationMetric(Metric):
@@ -21,6 +21,8 @@ def __init__(self, correlation_measure: Union[Callable, str, None]="spearman",
         if isinstance(correlation_measure, str):
-            assert correlation_measure in ["spearman"], f"Correlation measure {correlation_measure} is not implemented."
+            # "kendall" must be an accepted string, otherwise the elif branch below is unreachable
+            assert correlation_measure in ["spearman", "kendall"], f"Correlation measure {correlation_measure} is not implemented."
             if correlation_measure=="spearman":
-                correlation_measure=explanation_spearman_rank_correlation
+                correlation_measure=spearman_rank_corr
+            elif correlation_measure=="kendall":
+                correlation_measure=kendall_rank_corr
-        assert isinstance(Callable,correlation_measure)
+        # isinstance takes (object, classinfo) — the reversed order raised TypeError for every callable
+        assert isinstance(correlation_measure, Callable)
         self.correlation_measure=correlation_measure
diff --git a/src/utils/functions/correlations.py b/src/utils/functions/correlations.py
new file mode 100644
index 00000000..c1ebdc2f
--- /dev/null
+++ b/src/utils/functions/correlations.py
@@ -0,0 +1,11 @@
+from torchmetrics.functional.regression import kendall_rank_corrcoef, spearman_corrcoef
+
+
+# torchmetrics wants the independent realizations to be the final dimension
+# we transpose inputs before passing so that it is straightforward to pass explanations
+# and use these functions in evaluation metrics
+def kendall_rank_corr(tensor1, tensor2):
+    return kendall_rank_corrcoef(tensor1.T, tensor2.T)
+
+def spearman_rank_corr(tensor1, tensor2):
+    return spearman_corrcoef(tensor1.T, tensor2.T)
diff --git a/tests/metrics/test_randomization_metrics.py b/tests/metrics/test_randomization_metrics.py
index 8375277a..f3838929 100644
--- a/tests/metrics/test_randomization_metrics.py
+++ b/tests/metrics/test_randomization_metrics.py
@@ -1,29 +1,69 @@
 import pytest
 import torch
-from metrics.randomization.mprt import MPRTMetric
+from utils.explanations import TensorExplanations
+from utils.functions.correlations import spearman_rank_corr
+from metrics.randomization.model_randomization import ModelRandomizationMetric
 
 
-@pytest.mark.randomization
+@pytest.mark.randomize
 @pytest.mark.parametrize(
     "model",
     [
         ("load_mnist_model"),
     ],
 )
-def parameter_randomization_test(model, request):
+def model_randomization_test(model, request):
     model1 = request.getfixturevalue(model)
     model2 = request.getfixturevalue(model)
     gen = torch.Generator()
     gen.manual_seed(42)
-    MPRTMetric._randomize_model(model2, gen)
+    ModelRandomizationMetric._randomize_model(model2, gen)
     for param1, param2 in zip(model1.parameters(), model2.parameters()):
         assert torch.norm(param1.data - param2.data) > 1e3  # norm of the difference in parameters should be significant
 
-
-@pytest.mark.parametrize()
-def model_randomization_test():
-    assert torch.__version__ == "2.0.0"
+@pytest.mark.randomize
+def reproducibility_test():
+    assert torch.__version__=="2.0.0"
     gen = torch.Generator()
     gen.manual_seed(42)
-    assert torch.all(torch.rand(5, generator=gen) == torch.Tensor([0.8823, 0.9150, 0.3829, 0.9593, 0.3904]))
+    assert torch.all(torch.rand(5,generator=gen)==torch.Tensor([0.8823, 0.9150, 0.3829, 0.9593, 0.3904]))
+
+@pytest.mark.randomize
+@pytest.mark.parametrize(
+    "model",
+    [
+        ("load_mnist_model"),
+    ],
+)
+def kendall_metric_test(model, request):
+    def explain_fn(model):
+        xpl_tensor=torch.tensor([[1,2,3,4],[4,3,2,1]])
+        return TensorExplanations(xpl_tensor)
+
+    model=request.getfixturevalue(model)
+    xpl_tensor=torch.tensor([[1,2,3,4],[1,2,3,4]])
+    metric=ModelRandomizationMetric(correlation_measure="kendall")
+    metric=metric(
+        model,"0","",None,None,xpl_tensor,explain_fn,{})
+    assert torch.all(metric["rank_correlations"]==torch.tensor([1.,-1.]))
+
+@pytest.mark.randomize
+@pytest.mark.parametrize(
+    "model",
+    [
+        ("load_mnist_model"),
+    ],
+)
+def spearman_metric_test(model,request):
+    def explain_fn(model):
+        xpl_tensor=torch.tensor([[1,2,3,4],[4,3,2,1]])
+        return TensorExplanations(xpl_tensor)
+
+    def custom_corr(tensor1,tensor2):
+        return spearman_rank_corr(tensor1,tensor2)
+    model=request.getfixturevalue(model)
+    xpl_tensor=torch.tensor([[1,2,3,4],[1,2,3,4]])
+    for corr_measure in ["spearman", "kendall", custom_corr]:
+        metric=ModelRandomizationMetric(correlation_measure=corr_measure)
+        metric=metric(
+            model,"0","",None,None,xpl_tensor,explain_fn,{})
+        assert torch.all(metric["rank_correlations"]==torch.tensor([1.,-1.]))