Skip to content

Commit

Permalink
Merge pull request #319 from understandable-machine-intelligence-lab/…
Browse files Browse the repository at this point in the history
…tiny-fixes

Add warning of deprecated argument (handle elegantly)
  • Loading branch information
annahedstroem authored Dec 5, 2023
2 parents fcd8ded + 6366f78 commit 6430e69
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion quantus/metrics/randomisation/mprt.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
similarity_func: Optional[Callable] = None,
layer_order: str = "top_down",
seed: int = 42,
return_sample_correlation: Optional[bool] = None,
return_average_correlation: bool = False,
return_last_correlation: bool = False,
skip_layers: bool = False,
Expand Down Expand Up @@ -148,6 +149,16 @@ def __init__(
**kwargs,
)

if return_sample_correlation is not None:
warnings.warn(
"'return_sample_correlation' parameter is deprecated and will be removed in future versions. "
f"Please use 'return_average_correlation' instead. "
f"Setting 'return_average_correlation' to {return_sample_correlation}",
DeprecationWarning,
)
# Use the value of 'return_average_correlation' for 'return_sample_correlation'
return_average_correlation = return_sample_correlation

# Save metric-specific attributes.
if similarity_func is None:
similarity_func = correlation_spearman
Expand Down Expand Up @@ -365,7 +376,9 @@ def __call__(
for a_batch, a_batch_perturbed in zip(
self.generate_a_batches(a_full_dataset), a_perturbed_generator
):
for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed):
for a_instance, a_instance_perturbed in zip(
a_batch, a_batch_perturbed
):
score = self.evaluate_instance(
model=random_layer_model,
x=None,
Expand Down

0 comments on commit 6430e69

Please sign in to comment.