Merge branch 'main' into docs_corrections
aski02 committed Dec 29, 2024
2 parents f827cfa + 501e009 commit 7a61984
Showing 11 changed files with 40 additions and 22 deletions.
2 changes: 1 addition & 1 deletion quanda/benchmarks/downstream_eval/class_detection.py
@@ -333,6 +333,6 @@ def evaluate(
                 test_tensor=input,
                 targets=targets,
             )
-            metric.update(targets, explanations)
+            metric.update(explanations=explanations, test_labels=targets)
 
         return metric.compute()
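Why this matters: with the `update()` signatures reordered in this commit (see quanda/metrics/base.py below), positional calls in the old order would silently swap their arguments. A minimal sketch, assuming a hypothetical stand-in for the reordered signature (not quanda code):

```python
import torch


# Hypothetical stand-in for a reordered update() signature (illustration only).
def update(explanations: torch.Tensor, test_labels: torch.Tensor) -> None:
    assert explanations.shape[0] == test_labels.shape[0]


explanations = torch.rand(8, 100)      # (test_samples, train_set_size)
targets = torch.randint(0, 10, (8,))   # one label per test sample

# update(targets, explanations)   # old positional order: arguments now swapped
update(explanations=explanations, test_labels=targets)  # immune to reordering
```

Keyword arguments keep every call site valid regardless of parameter order, which is the change applied across the benchmark files in this commit.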
6 changes: 5 additions & 1 deletion quanda/benchmarks/downstream_eval/shortcut_detection.py
@@ -595,6 +595,10 @@ def evaluate(
                 test_tensor=input,
                 targets=targets,
             )
-            metric.update(explanations, test_tensor=input, test_labels=labels)
+            metric.update(
+                explanations=explanations,
+                test_tensor=input,
+                test_labels=labels,
+            )
 
         return metric.compute()
4 changes: 2 additions & 2 deletions quanda/benchmarks/downstream_eval/subclass_detection.py
@@ -565,8 +565,8 @@ def evaluate(
             )
 
             metric.update(
-                labels,
-                explanations,
+                test_subclasses=labels,
+                explanations=explanations,
                 test_tensor=inputs,
                 test_classes=grouped_labels,
             )
6 changes: 5 additions & 1 deletion quanda/benchmarks/heuristics/mixed_datasets.py
@@ -554,6 +554,10 @@ def evaluate(
             explanations = explainer.explain(
                 test_tensor=inputs, targets=targets
             )
-            metric.update(explanations, test_tensor=inputs, test_labels=labels)
+            metric.update(
+                explanations=explanations,
+                test_tensor=inputs,
+                test_labels=labels,
+            )
 
         return metric.compute()
2 changes: 1 addition & 1 deletion quanda/benchmarks/heuristics/top_k_cardinality.py
@@ -345,6 +345,6 @@ def evaluate(
                 test_tensor=input,
                 targets=targets,
             )
-            metric.update(explanations)
+            metric.update(explanations=explanations)
 
         return metric.compute()
10 changes: 10 additions & 0 deletions quanda/metrics/base.py
@@ -59,11 +59,21 @@ def __init__(
     @abstractmethod
     def update(
         self,
+        explanations: torch.Tensor,
         *args: Any,
         **kwargs: Any,
     ):
         """Update the metric with new data.
 
+        Parameters
+        ----------
+        explanations : torch.Tensor
+            Explanations of the test samples.
+        *args : Any
+            Additional positional arguments.
+        **kwargs : Any
+            Additional keyword arguments.
+
         Raises
         ------
         NotImplementedError
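To illustrate the contract the new abstract signature implies, here is a self-contained toy sketch; the class, its constructor, and its scoring rule are invented for illustration and are not quanda's `Metric` subclasses:

```python
from typing import Any, List

import torch


class ToyDetectionMetric:
    """Toy stand-in mirroring the updated contract: `explanations` is the
    first explicit argument of update()."""

    def __init__(self, train_labels: torch.Tensor):
        self.train_labels = train_labels
        self.scores: List[torch.Tensor] = []

    def update(self, explanations: torch.Tensor, *args: Any, **kwargs: Any):
        test_labels = kwargs["test_labels"]
        # Score: does the top-attributed training sample share the test label?
        top_idx = explanations.argmax(dim=1)
        self.scores.append(
            (self.train_labels[top_idx] == test_labels).float()
        )

    def compute(self) -> float:
        return torch.cat(self.scores).mean().item()


metric = ToyDetectionMetric(train_labels=torch.randint(0, 10, (100,)))
metric.update(
    explanations=torch.rand(8, 100),  # (test_samples, train_set_size)
    test_labels=torch.randint(0, 10, (8,)),
)
print(metric.compute())
```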
6 changes: 3 additions & 3 deletions quanda/metrics/downstream_eval/class_detection.py
@@ -61,15 +61,15 @@ def __init__(
 
         self.scores: List[torch.Tensor] = []
 
-    def update(self, test_labels: torch.Tensor, explanations: torch.Tensor):
+    def update(self, explanations: torch.Tensor, test_labels: torch.Tensor):
         """Update the metric state with the provided explanations.
 
         Parameters
         ----------
-        test_labels : torch.Tensor
-            Labels of the test samples.
         explanations : torch.Tensor
             Explanations of the test samples.
+        test_labels : torch.Tensor
+            Labels of the test samples.
         """
         assert test_labels.shape[0] == explanations.shape[0], (
6 changes: 3 additions & 3 deletions quanda/metrics/downstream_eval/mislabeling_detection.py
@@ -227,9 +227,9 @@ def aggr_based(
 
     def update(
         self,
+        explanations: torch.Tensor,
         test_data: torch.Tensor,
         test_labels: torch.Tensor,
-        explanations: torch.Tensor,
         **kwargs,
     ):
         """Update the aggregator based metric with local attributions.
@@ -238,12 +238,12 @@ def update(
 
         Parameters
         ----------
+        explanations : torch.Tensor
+            The local attributions to be added to the aggregated scores.
         test_data : torch.Tensor
             The test data for which the attributions were computed.
         test_labels : torch.Tensor
             The true labels of the test data.
-        explanations : torch.Tensor
-            The local attributions to be added to the aggregated scores.
         kwargs : Any
             Additional keyword arguments.
8 changes: 4 additions & 4 deletions quanda/metrics/downstream_eval/subclass_detection.py
@@ -70,19 +70,19 @@ def __init__(
 
     def update(
         self,
-        test_subclasses: Union[List[int], torch.Tensor],
         explanations: torch.Tensor,
+        test_subclasses: Union[List[int], torch.Tensor],
         test_tensor: Optional[torch.Tensor] = None,
         test_classes: Optional[torch.Tensor] = None,
     ):
         """Update the metric state with the provided explanations.
 
         Parameters
         ----------
-        test_subclasses : torch.Tensor
-            Original labels of the test samples
         explanations : torch.Tensor
-            Explanations of the test samples
+            Explanations of the test samples.
+        test_subclasses : torch.Tensor
+            Original labels of the test samples.
         test_tensor: Optional[torch.Tensor]
             Test samples used to generate the explanations.
             Only required if `filter_by_prediction` is True during
6 changes: 3 additions & 3 deletions quanda/metrics/ground_truth/linear_datamodeling.py
@@ -228,22 +228,22 @@ def load_counterfactual_model(self, model_idx: int):
 
     def update(
         self,
-        test_tensor: torch.Tensor,
         explanations: torch.Tensor,
         explanation_targets: torch.Tensor,
+        test_tensor: torch.Tensor,
         **kwargs,
     ):
         """Update the evaluation scores based on new data.
 
         Parameters
         ----------
-        test_tensor : torch.Tensor
-            The test data used for evaluation.
         explanations : torch.Tensor
             The explanation scores for the test data with shape (test_samples,
             dataset_size).
         explanation_targets : torch.Tensor
             The target values for the explanations.
+        test_tensor : torch.Tensor
+            The test data used for evaluation.
         kwargs: Any
             Additional keyword arguments
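A hedged sketch of the new call order for this metric; the stub below only mirrors the reordered signature, and all shapes beyond the documented (test_samples, dataset_size) are invented:

```python
import torch


# Stub mirroring the reordered LinearDatamodelingMetric.update signature
# (illustration only, not the real metric).
def update(
    explanations: torch.Tensor,        # (test_samples, dataset_size)
    explanation_targets: torch.Tensor,
    test_tensor: torch.Tensor,
    **kwargs,
) -> None:
    assert explanations.shape[0] == test_tensor.shape[0]


update(
    explanations=torch.rand(4, 1000),
    explanation_targets=torch.randint(0, 10, (4,)),
    test_tensor=torch.rand(4, 3, 32, 32),
)
```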
6 changes: 3 additions & 3 deletions quanda/metrics/heuristics/model_randomization.py
@@ -112,18 +112,18 @@ def __init__(
 
     def update(
         self,
-        test_data: torch.Tensor,
         explanations: torch.Tensor,
+        test_data: torch.Tensor,
         explanation_targets: Optional[torch.Tensor] = None,
     ):
         """Update the evaluation scores based on the provided data.
 
         Parameters
         ----------
-        test_data : torch.Tensor
-            The test data used for evaluation.
         explanations : torch.Tensor
             The explanations generated by the model.
+        test_data : torch.Tensor
+            The test data used for evaluation.
         explanation_targets : Optional[torch.Tensor], optional
             The target values for the explanations, by default None.
