Skip to content

Commit

Permalink
minor fixes to InverseEstimation class
Browse files Browse the repository at this point in the history
  • Loading branch information
annahedstroem committed Dec 7, 2023
1 parent fb3d5ef commit 1125384
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
8 changes: 2 additions & 6 deletions quantus/metrics/inverse_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
normalise: bool = False,
normalise_func: Optional[Callable] = None,
normalise_func_kwargs: Optional[Dict[str, Any]] = None,
return_aggregate: Optional[bool] = None,
return_aggregate: Optional[bool] = True,
aggregate_func: Optional[Callable] = None,
default_plot_func: Optional[Callable] = None,
disable_warnings: Optional[bool] = None,
Expand Down Expand Up @@ -235,15 +235,12 @@ def __call__(
a_batch is not None
), "'a_batch' must be provided to run the inverse estimation."

# TODO. Do we want to turn the attributions to rankings?
# See: https://github.com/annahedstroem/eval-project/blob/febe271a78c6efc16a51372ab58fcba676e0eb88/src/xai_faithfulness_experiments_lib_edits.py#L403
self.scores = self.metric_init(
model=model,
x_batch=x_batch,
y_batch=y_batch,
a_batch=a_batch,
s_batch=s_batch,
# custom_batch=custom_batch,
channel_first=channel_first,
explain_func=explain_func,
explain_func_kwargs=explain_func_kwargs,
Expand All @@ -261,14 +258,13 @@ def __call__(
self.metric_init.evaluation_scores = []

# Run inverse experiment.
a_batch_inv = -np.array(a_batch) # / np.min(-np.array(a_batch))
a_batch_inv = -np.array(a_batch)
self.scores_inv = self.metric_init(
model=model,
x_batch=x_batch,
y_batch=y_batch,
a_batch=a_batch_inv,
s_batch=s_batch,
# custom_batch=custom_batch,
channel_first=channel_first,
explain_func=explain_func,
explain_func_kwargs=explain_func_kwargs,
Expand Down
1 change: 1 addition & 0 deletions tests/metrics/test_inverse_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ def test_inverse_estimation_with_pixel_flipping(
for s in s_list
]
), "Test failed."

except expected["exception"] as e:
print(f'Raised exception type {expected["exception"]}', e)
return

0 comments on commit 1125384

Please sign in to comment.