From 11253840f399709569730aa02cbd9d76e9fe0956 Mon Sep 17 00:00:00 2001 From: annahedstroem Date: Thu, 7 Dec 2023 16:40:01 +0100 Subject: [PATCH] minor fixes to InverseEstimation class --- quantus/metrics/inverse_estimation.py | 8 ++------ tests/metrics/test_inverse_estimation.py | 1 + 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/quantus/metrics/inverse_estimation.py b/quantus/metrics/inverse_estimation.py index 73e12ed5..90d5a13d 100644 --- a/quantus/metrics/inverse_estimation.py +++ b/quantus/metrics/inverse_estimation.py @@ -56,7 +56,7 @@ def __init__( normalise: bool = False, normalise_func: Optional[Callable] = None, normalise_func_kwargs: Optional[Dict[str, Any]] = None, - return_aggregate: Optional[bool] = None, + return_aggregate: Optional[bool] = True, aggregate_func: Optional[Callable] = None, default_plot_func: Optional[Callable] = None, disable_warnings: Optional[bool] = None, @@ -235,15 +235,12 @@ def __call__( a_batch is not None ), "'a_batch' must be provided to run the inverse estimation." - # TODO. Do we want to turn the attributions to rankings? - # See: https://github.com/annahedstroem/eval-project/blob/febe271a78c6efc16a51372ab58fcba676e0eb88/src/xai_faithfulness_experiments_lib_edits.py#L403 self.scores = self.metric_init( model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch, s_batch=s_batch, - # custom_batch=custom_batch, channel_first=channel_first, explain_func=explain_func, explain_func_kwargs=explain_func_kwargs, @@ -261,14 +258,13 @@ def __call__( self.metric_init.evaluation_scores = [] # Run inverse experiment. - a_batch_inv = -np.array(a_batch) # / np.min(-np.array(a_batch)) + a_batch_inv = -np.array(a_batch) self.scores_inv = self.metric_init( model=model, x_batch=x_batch, y_batch=y_batch, a_batch=a_batch_inv, s_batch=s_batch, - # custom_batch=custom_batch, channel_first=channel_first, explain_func=explain_func, explain_func_kwargs=explain_func_kwargs, diff --git a/tests/metrics/test_inverse_estimation.py b/tests/metrics/test_inverse_estimation.py index 8d68b53c..22ae1257 100644 --- a/tests/metrics/test_inverse_estimation.py +++ b/tests/metrics/test_inverse_estimation.py @@ -247,6 +247,7 @@ def test_inverse_estimation_with_pixel_flipping( for s in s_list ] ), "Test failed." + except expected["exception"] as e: print(f'Raised exception type {expected["exception"]}', e) return