From 4b9ee518d99435d299b85fa3c23e8f41d6404cd7 Mon Sep 17 00:00:00 2001
From: Niklas Schmolenski
Date: Thu, 12 Dec 2024 11:04:58 +0100
Subject: [PATCH] feat: add explanations parameter to update method

---
 quanda/benchmarks/downstream_eval/class_detection.py  |  2 +-
 .../benchmarks/downstream_eval/shortcut_detection.py  |  6 +++++-
 .../benchmarks/downstream_eval/subclass_detection.py  |  4 ++--
 quanda/benchmarks/heuristics/mixed_datasets.py        |  6 +++++-
 quanda/benchmarks/heuristics/top_k_cardinality.py     |  2 +-
 quanda/metrics/base.py                                | 10 ++++++++++
 quanda/metrics/downstream_eval/class_detection.py     |  6 +++---
 .../metrics/downstream_eval/mislabeling_detection.py  |  6 +++---
 quanda/metrics/downstream_eval/subclass_detection.py  |  8 ++++----
 quanda/metrics/ground_truth/linear_datamodeling.py    |  6 +++---
 quanda/metrics/heuristics/model_randomization.py      |  6 +++---
 11 files changed, 40 insertions(+), 22 deletions(-)

diff --git a/quanda/benchmarks/downstream_eval/class_detection.py b/quanda/benchmarks/downstream_eval/class_detection.py
index 91ca764e..51ddc403 100644
--- a/quanda/benchmarks/downstream_eval/class_detection.py
+++ b/quanda/benchmarks/downstream_eval/class_detection.py
@@ -333,6 +333,6 @@ def evaluate(
                 test_tensor=input,
                 targets=targets,
             )
-            metric.update(targets, explanations)
+            metric.update(explanations=explanations, test_labels=targets)
 
         return metric.compute()
diff --git a/quanda/benchmarks/downstream_eval/shortcut_detection.py b/quanda/benchmarks/downstream_eval/shortcut_detection.py
index 93eb23c7..52cabeac 100644
--- a/quanda/benchmarks/downstream_eval/shortcut_detection.py
+++ b/quanda/benchmarks/downstream_eval/shortcut_detection.py
@@ -595,6 +595,10 @@ def evaluate(
                 test_tensor=input,
                 targets=targets,
             )
-            metric.update(explanations, test_tensor=input, test_labels=labels)
+            metric.update(
+                explanations=explanations,
+                test_tensor=input,
+                test_labels=labels,
+            )
 
         return metric.compute()
diff --git a/quanda/benchmarks/downstream_eval/subclass_detection.py b/quanda/benchmarks/downstream_eval/subclass_detection.py
index 5a2624fe..6277be4c 100644
--- a/quanda/benchmarks/downstream_eval/subclass_detection.py
+++ b/quanda/benchmarks/downstream_eval/subclass_detection.py
@@ -565,8 +565,8 @@ def evaluate(
             )
 
             metric.update(
-                labels,
-                explanations,
+                test_subclasses=labels,
+                explanations=explanations,
                 test_tensor=inputs,
                 test_classes=grouped_labels,
             )
diff --git a/quanda/benchmarks/heuristics/mixed_datasets.py b/quanda/benchmarks/heuristics/mixed_datasets.py
index 1132f5b6..a9f01129 100644
--- a/quanda/benchmarks/heuristics/mixed_datasets.py
+++ b/quanda/benchmarks/heuristics/mixed_datasets.py
@@ -554,6 +554,10 @@ def evaluate(
             explanations = explainer.explain(
                 test_tensor=inputs, targets=targets
             )
-            metric.update(explanations, test_tensor=inputs, test_labels=labels)
+            metric.update(
+                explanations=explanations,
+                test_tensor=inputs,
+                test_labels=labels,
+            )
 
         return metric.compute()
diff --git a/quanda/benchmarks/heuristics/top_k_cardinality.py b/quanda/benchmarks/heuristics/top_k_cardinality.py
index af7ae5fc..2caa6361 100644
--- a/quanda/benchmarks/heuristics/top_k_cardinality.py
+++ b/quanda/benchmarks/heuristics/top_k_cardinality.py
@@ -345,6 +345,6 @@ def evaluate(
                 test_tensor=input,
                 targets=targets,
             )
-            metric.update(explanations)
+            metric.update(explanations=explanations)
 
         return metric.compute()
diff --git a/quanda/metrics/base.py b/quanda/metrics/base.py
index 07734b3e..9875b937 100644
--- a/quanda/metrics/base.py
+++ b/quanda/metrics/base.py
@@ -59,11 +59,21 @@ def __init__(
     @abstractmethod
     def update(
         self,
+        explanations: torch.Tensor,
         *args: Any,
         **kwargs: Any,
     ):
         """Update the metric with new data.
 
+        Parameters
+        ----------
+        explanations : torch.Tensor
+            Explanations of the test samples.
+        *args : Any
+            Additional positional arguments.
+        **kwargs : Any
+            Additional keyword arguments.
+
         Raises
         ------
         NotImplementedError
diff --git a/quanda/metrics/downstream_eval/class_detection.py b/quanda/metrics/downstream_eval/class_detection.py
index 9ddfa0f4..d74ec082 100644
--- a/quanda/metrics/downstream_eval/class_detection.py
+++ b/quanda/metrics/downstream_eval/class_detection.py
@@ -61,15 +61,15 @@ def __init__(
 
         self.scores: List[torch.Tensor] = []
 
-    def update(self, test_labels: torch.Tensor, explanations: torch.Tensor):
+    def update(self, explanations: torch.Tensor, test_labels: torch.Tensor):
         """Update the metric state with the provided explanations.
 
         Parameters
         ----------
-        test_labels : torch.Tensor
-            Labels of the test samples.
         explanations : torch.Tensor
             Explanations of the test samples.
+        test_labels : torch.Tensor
+            Labels of the test samples.
 
         """
         assert test_labels.shape[0] == explanations.shape[0], (
diff --git a/quanda/metrics/downstream_eval/mislabeling_detection.py b/quanda/metrics/downstream_eval/mislabeling_detection.py
index 27e922d7..f277ad98 100644
--- a/quanda/metrics/downstream_eval/mislabeling_detection.py
+++ b/quanda/metrics/downstream_eval/mislabeling_detection.py
@@ -227,9 +227,9 @@ def aggr_based(
 
     def update(
         self,
+        explanations: torch.Tensor,
         test_data: torch.Tensor,
         test_labels: torch.Tensor,
-        explanations: torch.Tensor,
         **kwargs,
     ):
         """Update the aggregator based metric with local attributions.
@@ -238,12 +238,12 @@
 
         Parameters
         ----------
+        explanations : torch.Tensor
+            The local attributions to be added to the aggregated scores.
         test_data : torch.Tensor
             The test data for which the attributions were computed.
         test_labels : torch.Tensor
             The true labels of the test data.
-        explanations : torch.Tensor
-            The local attributions to be added to the aggregated scores.
         kwargs : Any
             Additional keyword arguments.
 
diff --git a/quanda/metrics/downstream_eval/subclass_detection.py b/quanda/metrics/downstream_eval/subclass_detection.py
index 9642ccc4..621b381c 100644
--- a/quanda/metrics/downstream_eval/subclass_detection.py
+++ b/quanda/metrics/downstream_eval/subclass_detection.py
@@ -70,8 +70,8 @@ def __init__(
 
     def update(
         self,
-        test_subclasses: Union[List[int], torch.Tensor],
         explanations: torch.Tensor,
+        test_subclasses: Union[List[int], torch.Tensor],
         test_tensor: Optional[torch.Tensor] = None,
         test_classes: Optional[torch.Tensor] = None,
     ):
@@ -79,10 +79,10 @@
         """Update the metric state with the provided explanations.
         Parameters
         ----------
-        test_subclasses : torch.Tensor
-            Original labels of the test samples
         explanations : torch.Tensor
-            Explanations of the test samples
+            Explanations of the test samples.
+        test_subclasses : torch.Tensor
+            Original labels of the test samples.
         test_tensor: Optional[torch.Tensor]
             Test samples to used to generate the explanations.
             Only required if `filter_by_prediction` is True during
diff --git a/quanda/metrics/ground_truth/linear_datamodeling.py b/quanda/metrics/ground_truth/linear_datamodeling.py
index 0e657d9d..b1688dc7 100644
--- a/quanda/metrics/ground_truth/linear_datamodeling.py
+++ b/quanda/metrics/ground_truth/linear_datamodeling.py
@@ -228,22 +228,22 @@ def load_counterfactual_model(self, model_idx: int):
 
     def update(
         self,
-        test_tensor: torch.Tensor,
         explanations: torch.Tensor,
         explanation_targets: torch.Tensor,
+        test_tensor: torch.Tensor,
         **kwargs,
     ):
         """Update the evaluation scores based on new data.
 
         Parameters
         ----------
-        test_tensor : torch.Tensor
-            The test data used for evaluation.
         explanations : torch.Tensor
             The explanation scores for the test data with shape
             (test_samples, dataset_size).
         explanation_targets : torch.Tensor
             The target values for the explanations.
+        test_tensor : torch.Tensor
+            The test data used for evaluation.
         kwargs: Any
             Additional keyword arguments
 
diff --git a/quanda/metrics/heuristics/model_randomization.py b/quanda/metrics/heuristics/model_randomization.py
index 21fe7e0c..48d46b00 100644
--- a/quanda/metrics/heuristics/model_randomization.py
+++ b/quanda/metrics/heuristics/model_randomization.py
@@ -112,18 +112,18 @@ def __init__(
 
     def update(
         self,
-        test_data: torch.Tensor,
         explanations: torch.Tensor,
+        test_data: torch.Tensor,
         explanation_targets: Optional[torch.Tensor] = None,
     ):
         """Update the evaluation scores based on the provided data.
 
         Parameters
         ----------
-        test_data : torch.Tensor
-            The test data used for evaluation.
         explanations : torch.Tensor
             The explanations generated by the model.
+        test_data : torch.Tensor
+            The test data used for evaluation.
         explanation_targets : Optional[torch.Tensor], optional
             The target values for the explanations, by default None.
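
Note for reviewers: after this patch, every metric `update` takes `explanations` as its leading parameter, and the benchmark `evaluate` loops pass all arguments by keyword. A minimal sketch of the new calling convention follows; the constructor keyword arguments, toy model, and tensor shapes are illustrative assumptions, not taken from this patch:

    import torch
    from torch.utils.data import TensorDataset

    from quanda.metrics.downstream_eval import ClassDetectionMetric

    # Toy model and training data, for illustration only.
    model = torch.nn.Linear(4, 10)
    train_dataset = TensorDataset(
        torch.rand(100, 4), torch.randint(0, 10, (100,))
    )
    metric = ClassDetectionMetric(model=model, train_dataset=train_dataset)

    # One attribution row per test sample, one column per training sample.
    explanations = torch.rand(8, len(train_dataset))
    test_labels = torch.randint(0, 10, (8,))

    # Old positional call removed by this patch:
    #     metric.update(test_labels, explanations)
    # New keyword-based call with `explanations` first:
    metric.update(explanations=explanations, test_labels=test_labels)
    print(metric.compute())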
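The `quanda/metrics/base.py` hunk also makes `explanations` part of the abstract `update` contract, so third-party metrics are now expected to accept it as their first argument. A hypothetical subclass written against the new signature might look like the sketch below; the class name, scoring logic, and base-class constructor arguments are invented for illustration, the base class name `Metric` is assumed from `quanda/metrics/base.py`, and any further abstract hooks the base class declares are omitted:

    from typing import Any, List

    import torch

    from quanda.metrics.base import Metric


    class MeanAttributionMetric(Metric):
        """Toy metric: tracks mean absolute attribution per test batch."""

        def __init__(self, model, train_dataset):
            super().__init__(model=model, train_dataset=train_dataset)
            self.scores: List[torch.Tensor] = []

        def update(self, explanations: torch.Tensor, *args: Any, **kwargs: Any):
            # `explanations` is now the leading, explicitly named parameter.
            self.scores.append(explanations.abs().mean(dim=1))

        def compute(self):
            return {"score": torch.cat(self.scores).mean().item()}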