From d2d7f39e25f27ad8f15aa93e67bed3749fe2d404 Mon Sep 17 00:00:00 2001 From: annahedstroem Date: Tue, 5 Dec 2023 12:14:31 +0100 Subject: [PATCH] add warning of old script --- quantus/metrics/randomisation/mprt.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/quantus/metrics/randomisation/mprt.py b/quantus/metrics/randomisation/mprt.py index cca0b43d..9a974609 100644 --- a/quantus/metrics/randomisation/mprt.py +++ b/quantus/metrics/randomisation/mprt.py @@ -78,6 +78,7 @@ def __init__( similarity_func: Optional[Callable] = None, layer_order: str = "top_down", seed: int = 42, + return_sample_correlation: Optional[bool] = None, return_average_correlation: bool = False, return_last_correlation: bool = False, skip_layers: bool = False, @@ -148,6 +149,16 @@ def __init__( **kwargs, ) + if return_sample_correlation is not None: + warnings.warn( + "'return_sample_correlation' parameter is deprecated and will be removed in future versions. " + f"Please use 'return_average_correlation' instead. " + f"Setting 'return_average_correlation' to {return_sample_correlation}", + DeprecationWarning, + ) + # Use the value of 'return_average_correlation' for 'return_sample_correlation' + return_average_correlation = return_sample_correlation + # Save metric-specific attributes. if similarity_func is None: similarity_func = correlation_spearman @@ -365,7 +376,9 @@ def __call__( for a_batch, a_batch_perturbed in zip( self.generate_a_batches(a_full_dataset), a_perturbed_generator ): - for a_instance, a_instance_perturbed in zip(a_batch, a_batch_perturbed): + for a_instance, a_instance_perturbed in zip( + a_batch, a_batch_perturbed + ): score = self.evaluate_instance( model=random_layer_model, x=None,