From 58d19ff683112ce591568df3773f53e0247cbd04 Mon Sep 17 00:00:00 2001 From: aaarrti Date: Thu, 5 Oct 2023 17:05:37 +0200 Subject: [PATCH] * run black --- quantus/__init__.py | 2 +- quantus/evaluation.py | 12 +++-------- quantus/functions/explanation_func.py | 1 - quantus/functions/loss_func.py | 2 +- quantus/functions/normalise_func.py | 4 ++-- quantus/functions/similarity_func.py | 2 +- quantus/helpers/__init__.py | 2 +- quantus/helpers/model/models.py | 1 - quantus/helpers/model/pytorch_model.py | 29 +++++++++++++++----------- quantus/helpers/model/tf_model.py | 1 - quantus/helpers/plotting.py | 1 - quantus/helpers/utils.py | 3 --- 12 files changed, 26 insertions(+), 34 deletions(-) diff --git a/quantus/__init__.py b/quantus/__init__.py index c29dc0a35..485b9d851 100644 --- a/quantus/__init__.py +++ b/quantus/__init__.py @@ -26,4 +26,4 @@ from quantus.helpers.model import * # Expose the helpers utils. -from quantus.helpers.utils import * \ No newline at end of file +from quantus.helpers.utils import * diff --git a/quantus/evaluation.py b/quantus/evaluation.py index 75d628ee3..8c37a24a7 100644 --- a/quantus/evaluation.py +++ b/quantus/evaluation.py @@ -81,7 +81,7 @@ def evaluate( return None if call_kwargs is None: - call_kwargs = {'call_kwargs_empty': {}} + call_kwargs = {"call_kwargs_empty": {}} elif not isinstance(call_kwargs, Dict): raise TypeError("xai_methods type is not Dict[str, Dict].") @@ -92,11 +92,9 @@ def evaluate( "xai_methods type is not in: Dict[str, Callable], Dict[str, Dict], Dict[str, np.ndarray]." for method, value in xai_methods.items(): - results[method] = {} if callable(value): - explain_funcs[method] = value explain_func = value @@ -116,7 +114,6 @@ def evaluate( asserts.assert_attributions(a_batch=a_batch, x_batch=x_batch) elif isinstance(value, Dict): - if explain_func_kwargs is not None: warnings.warn( "Passed explain_func_kwargs will be ignored when passing type Dict[str, Dict] as xai_methods." @@ -140,7 +137,6 @@ def evaluate( a_batch = value else: - raise TypeError( "xai_methods type is not in: Dict[str, Callable], Dict[str, Dict], Dict[str, np.ndarray]." ) @@ -148,12 +144,10 @@ def evaluate( if explain_func_kwargs is None: explain_func_kwargs = {} - for (metric, metric_func) in metrics.items(): - + for metric, metric_func in metrics.items(): results[method][metric] = {} - for (call_kwarg_str, call_kwarg) in call_kwargs.items(): - + for call_kwarg_str, call_kwarg in call_kwargs.items(): if progress: print( f"Evaluating {method} explanations on {metric} metric on set of call parameters {call_kwarg_str}..." diff --git a/quantus/functions/explanation_func.py b/quantus/functions/explanation_func.py index d34260f3f..fbc4de9ba 100644 --- a/quantus/functions/explanation_func.py +++ b/quantus/functions/explanation_func.py @@ -385,7 +385,6 @@ def generate_tf_explanation( ) elif method == "SmoothGrad": - num_samples = kwargs.get("num_samples", 5) noise = kwargs.get("noise", 0.1) explainer = tf_explain.core.smoothgrad.SmoothGrad() diff --git a/quantus/functions/loss_func.py b/quantus/functions/loss_func.py index 69181e1f3..bd0723645 100644 --- a/quantus/functions/loss_func.py +++ b/quantus/functions/loss_func.py @@ -34,7 +34,7 @@ def mse(a: np.array, b: np.array, **kwargs) -> float: if normalise: # Calculate MSE in its polynomial expansion (a-b)^2 = a^2 - 2ab + b^2. - return np.average(((a ** 2) - (2 * (a * b)) + (b ** 2)), axis=0) + return np.average(((a**2) - (2 * (a * b)) + (b**2)), axis=0) # If no need to normalise, return (a-b)^2. return np.average(((a - b) ** 2), axis=0) diff --git a/quantus/functions/normalise_func.py b/quantus/functions/normalise_func.py index 2d518aad4..5ad7f8145 100644 --- a/quantus/functions/normalise_func.py +++ b/quantus/functions/normalise_func.py @@ -231,13 +231,13 @@ def normalise_by_average_second_moment_estimate( # Check that square root of the second momment estimatte is nonzero. second_moment_sqrt = np.sqrt( - np.sum(a ** 2, axis=normalise_axes, keepdims=True) + np.sum(a**2, axis=normalise_axes, keepdims=True) / np.prod([a.shape[n] for n in normalise_axes]) ) if all(second_moment_sqrt != 0): a /= np.sqrt( - np.sum(a ** 2, axis=normalise_axes, keepdims=True) + np.sum(a**2, axis=normalise_axes, keepdims=True) / np.prod([a.shape[n] for n in normalise_axes]) ) else: diff --git a/quantus/functions/similarity_func.py b/quantus/functions/similarity_func.py index 88d19a9a7..93ffe8832 100644 --- a/quantus/functions/similarity_func.py +++ b/quantus/functions/similarity_func.py @@ -145,7 +145,7 @@ def lipschitz_constant( b: np.array, c: Union[np.array, None], d: Union[np.array, None], - **kwargs + **kwargs, ) -> float: """ Calculate non-negative local Lipschitz abs(||a-b||/||c-d||), where a,b can be f(x) or a(x) and c,d is x. diff --git a/quantus/helpers/__init__.py b/quantus/helpers/__init__.py index eafc68d63..e4d00f57e 100644 --- a/quantus/helpers/__init__.py +++ b/quantus/helpers/__init__.py @@ -8,4 +8,4 @@ # Import files dependent on package installations. __EXTRAS__ = util.find_spec("captum") or util.find_spec("tf_explain") -__MODELS__ = util.find_spec("torch") or util.find_spec("tensorflow") \ No newline at end of file +__MODELS__ = util.find_spec("torch") or util.find_spec("tensorflow") diff --git a/quantus/helpers/model/models.py b/quantus/helpers/model/models.py index 38d0ca004..2feb32c55 100644 --- a/quantus/helpers/model/models.py +++ b/quantus/helpers/model/models.py @@ -12,7 +12,6 @@ # Import different models depending on which deep learning framework is installed. if util.find_spec("torch"): - import torch import torch.nn as nn diff --git a/quantus/helpers/model/pytorch_model.py b/quantus/helpers/model/pytorch_model.py index cd422931f..b79442cb8 100644 --- a/quantus/helpers/model/pytorch_model.py +++ b/quantus/helpers/model/pytorch_model.py @@ -86,8 +86,11 @@ def _get_model_with_linear_top(self) -> torch.nn: if isinstance(named_module[1], torch.nn.Softmax): setattr(linear_model, named_module[0], torch.nn.Identity()) - logging.info("Argument softmax=False passed, but the passed model contains a module of type " - "torch.nn.Softmax. Module {} has been replaced with torch.nn.Identity().", named_module[0]) + logging.info( + "Argument softmax=False passed, but the passed model contains a module of type " + "torch.nn.Softmax. Module {} has been replaced with torch.nn.Identity().", + named_module[0], + ) break return linear_model @@ -118,8 +121,10 @@ def get_softmax_arg_model(self) -> torch.nn: return self.model # Case 1 if self.softmax and not last_softmax: - logging.info("Argument softmax=True passed, but the passed model contains no module of type " - "torch.nn.Softmax. torch.nn.Softmax module is added as the output layer.") + logging.info( + "Argument softmax=True passed, but the passed model contains no module of type " + "torch.nn.Softmax. torch.nn.Softmax module is added as the output layer." + ) return torch.nn.Sequential(self.model, torch.nn.Softmax(dim=-1)) # Case 3 if not self.softmax and not last_softmax: @@ -133,12 +138,14 @@ def get_softmax_arg_model(self) -> torch.nn: ) # Warning for cases 2, 4, 5 if self.softmax and last_softmax != -1: - logging.info("Argument softmax=True passed. The passed model contains a module of type " - "torch.nn.Softmax, but it is not the last in the list of model's children (" - "self.model.modules()). torch.nn.Softmax module is added as the output layer." - "Make sure that the torch.nn.Softmax layer is the last module in the list " - "of model's children (self.model.modules()) if and only if it is the actual last module " - "applied before output.") + logging.info( + "Argument softmax=True passed. The passed model contains a module of type " + "torch.nn.Softmax, but it is not the last in the list of model's children (" + "self.model.modules()). torch.nn.Softmax module is added as the output layer." + "Make sure that the torch.nn.Softmax layer is the last module in the list " + "of model's children (self.model.modules()) if and only if it is the actual last module " + "applied before output." + ) return torch.nn.Sequential(self.model, torch.nn.Softmax(dim=-1)) # Case 2 @@ -337,7 +344,6 @@ def add_mean_shift_to_first_layer( The resulting model with a shifted first layer. """ with torch.no_grad(): - new_model = deepcopy(self.model) modules = [l for l in new_model.named_modules()] @@ -364,7 +370,6 @@ def get_hidden_representations( layer_names: Optional[List[str]] = None, layer_indices: Optional[List[int]] = None, ) -> np.ndarray: - """ Compute the model's internal representation of input x. In practice, this means, executing a forward pass and then, capturing the output of layers (of interest). diff --git a/quantus/helpers/model/tf_model.py b/quantus/helpers/model/tf_model.py index 7a1769147..c72b06860 100644 --- a/quantus/helpers/model/tf_model.py +++ b/quantus/helpers/model/tf_model.py @@ -360,7 +360,6 @@ def get_hidden_representations( layer_indices: Optional[List[int]] = None, **kwargs, ) -> np.ndarray: - """ Compute the model's internal representation of input x. In practice, this means, executing a forward pass and then, capturing the output of layers (of interest). diff --git a/quantus/helpers/plotting.py b/quantus/helpers/plotting.py index 4fa305713..fd504556a 100644 --- a/quantus/helpers/plotting.py +++ b/quantus/helpers/plotting.py @@ -245,7 +245,6 @@ def plot_model_parameter_randomisation_experiment( plt.plot(layers, [np.mean(v) for k, v in scores.items()], label=method) else: - layers = list(results.keys()) scores = {k: [] for k in layers} # samples = len(results) diff --git a/quantus/helpers/utils.py b/quantus/helpers/utils.py index 196bdb33f..0ec29da91 100644 --- a/quantus/helpers/utils.py +++ b/quantus/helpers/utils.py @@ -764,7 +764,6 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence for start in range(0, len(x_shape) - len(a_shape) + 1) ] if x_subshapes.count(a_shape) < 1: - # Check that attribution dimensions are (consecutive) subdimensions of inputs raise ValueError( "Attribution dimensions are not (consecutive) subdimensions of inputs: " @@ -773,7 +772,6 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence ) ) elif x_subshapes.count(a_shape) > 1: - # Check that attribution dimensions are (unique) subdimensions of inputs. # Consider potentially expanded dims in attributions. @@ -783,7 +781,6 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence for start in range(0, len(np.shape(a_batch)[1:]) - len(a_shape) + 1) ] if a_subshapes.count(a_shape) == 1: - # Inferring channel shape. for dim in range(len(np.shape(a_batch)[1:]) + 1): if a_shape == np.shape(a_batch)[1:][dim:]: