From bcd4430621baa6df26dedb37de9533c1e4c2dd3e Mon Sep 17 00:00:00 2001 From: annahedstroem Date: Fri, 24 Nov 2023 12:02:20 +0100 Subject: [PATCH 1/2] update warning message for SmoothMPRT --- quantus/metrics/randomisation/smooth_mprt.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/quantus/metrics/randomisation/smooth_mprt.py b/quantus/metrics/randomisation/smooth_mprt.py index 81b2efd4..169f3a91 100644 --- a/quantus/metrics/randomisation/smooth_mprt.py +++ b/quantus/metrics/randomisation/smooth_mprt.py @@ -6,7 +6,6 @@ # You should have received a copy of the GNU Lesser General Public License along with Quantus. If not, see . # Quantus project URL: . -import sys import warnings from typing import ( Any, @@ -499,14 +498,15 @@ def evaluate_instance( try: return self.similarity_func(a_perturbed_flat, a_flat) - except stats._warnings_errors.ConstantInputWarning: + except Exception as e: + print(f"Exception: {e}") warnings.warn( "Encountered constant input in similarity measure calculation.", UserWarning, ) return 1.0 - # Compute similarity measure. + # Compute similarity measure. return self.similarity_func(a_perturbed_flat, a_flat) def custom_preprocess( From 283f9ef9c72808d4bb776c6fd80cef9e0206e421 Mon Sep 17 00:00:00 2001 From: annahedstroem Date: Fri, 24 Nov 2023 12:04:50 +0100 Subject: [PATCH 2/2] small fixes warnings message --- quantus/metrics/randomisation/smooth_mprt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/quantus/metrics/randomisation/smooth_mprt.py b/quantus/metrics/randomisation/smooth_mprt.py index 169f3a91..db84cab3 100644 --- a/quantus/metrics/randomisation/smooth_mprt.py +++ b/quantus/metrics/randomisation/smooth_mprt.py @@ -499,9 +499,9 @@ def evaluate_instance( try: return self.similarity_func(a_perturbed_flat, a_flat) except Exception as e: - print(f"Exception: {e}") + print(f"Encountered exception: {e} in similarity measure calculation") warnings.warn( - "Encountered constant input in similarity measure calculation.", + "Setting similarity output to 1.", UserWarning, ) return 1.0