
Commit

run black
aaarrti committed Oct 5, 2023
1 parent 4d1d25e commit 58d19ff
Showing 12 changed files with 26 additions and 34 deletions.
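
black is an opinionated Python formatter, so every hunk below is a whitespace, quoting, redundant-parentheses, or line-wrapping change with no behavioural effect. As a rough sketch of what the tool does (assumed usage of black's Python API, not taken from this repository's tooling), with black installed:

# A minimal sketch: reformat a snippet programmatically; the CLI equivalent is `black quantus/`.
import black

src = "a = {'x': 1,  'y': 2}\n"
print(black.format_str(src, mode=black.FileMode()), end="")
# -> a = {"x": 1, "y": 2}   (double quotes, normalised spacing)
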
2 changes: 1 addition & 1 deletion quantus/__init__.py
@@ -26,4 +26,4 @@
 from quantus.helpers.model import *

 # Expose the helpers utils.
-from quantus.helpers.utils import *
+from quantus.helpers.utils import *
12 changes: 3 additions & 9 deletions quantus/evaluation.py
@@ -81,7 +81,7 @@ def evaluate(
         return None

     if call_kwargs is None:
-        call_kwargs = {'call_kwargs_empty': {}}
+        call_kwargs = {"call_kwargs_empty": {}}
     elif not isinstance(call_kwargs, Dict):
         raise TypeError("xai_methods type is not Dict[str, Dict].")

@@ -92,11 +92,9 @@ def evaluate(
         "xai_methods type is not in: Dict[str, Callable], Dict[str, Dict], Dict[str, np.ndarray]."

     for method, value in xai_methods.items():
-
         results[method] = {}

         if callable(value):
-
             explain_funcs[method] = value
             explain_func = value

@@ -116,7 +114,6 @@ def evaluate(
             asserts.assert_attributions(a_batch=a_batch, x_batch=x_batch)

         elif isinstance(value, Dict):
-
             if explain_func_kwargs is not None:
                 warnings.warn(
                     "Passed explain_func_kwargs will be ignored when passing type Dict[str, Dict] as xai_methods."

@@ -140,20 +137,17 @@ def evaluate(
             a_batch = value

         else:
-
             raise TypeError(
                 "xai_methods type is not in: Dict[str, Callable], Dict[str, Dict], Dict[str, np.ndarray]."
             )

         if explain_func_kwargs is None:
             explain_func_kwargs = {}

-        for (metric, metric_func) in metrics.items():
-
+        for metric, metric_func in metrics.items():
             results[method][metric] = {}

-            for (call_kwarg_str, call_kwarg) in call_kwargs.items():
-
+            for call_kwarg_str, call_kwarg in call_kwargs.items():
                 if progress:
                     print(
                         f"Evaluating {method} explanations on {metric} metric on set of call parameters {call_kwarg_str}..."
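
The evaluation.py hunks are representative of the rest of the commit: black drops blank lines that immediately follow a block opener, removes the redundant parentheses around tuple targets in for loops, and rewrites single-quoted strings with double quotes. A small, self-contained sketch of the loop style it enforces (the dictionary here is made up, not quantus data):

# Style black enforces for tuple unpacking: no parentheses around the loop targets.
metrics = {"accuracy": lambda: 0.9, "loss": lambda: 0.1}

for metric, metric_func in metrics.items():  # black rewrites "for (metric, metric_func) in ..."
    print(metric, metric_func())
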
1 change: 0 additions & 1 deletion quantus/functions/explanation_func.py
@@ -385,7 +385,6 @@ def generate_tf_explanation(
         )

     elif method == "SmoothGrad":
-
         num_samples = kwargs.get("num_samples", 5)
         noise = kwargs.get("noise", 0.1)
         explainer = tf_explain.core.smoothgrad.SmoothGrad()
2 changes: 1 addition & 1 deletion quantus/functions/loss_func.py
@@ -34,7 +34,7 @@ def mse(a: np.array, b: np.array, **kwargs) -> float:

     if normalise:
         # Calculate MSE in its polynomial expansion (a-b)^2 = a^2 - 2ab + b^2.
-        return np.average(((a ** 2) - (2 * (a * b)) + (b ** 2)), axis=0)
+        return np.average(((a**2) - (2 * (a * b)) + (b**2)), axis=0)
     # If no need to normalise, return (a-b)^2.

     return np.average(((a - b) ** 2), axis=0)
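
The only change to mse is black's power-operator rule: when both operands are simple names or literals, ** is written without surrounding spaces, while (a - b) ** 2 keeps them because its left operand is a parenthesised expression. A quick check that the two spellings above are the same computation (example arrays are made up):

import numpy as np

a = np.array([1.0, 2.0, 3.0])
b = np.array([1.5, 2.5, 3.5])

# (a - b)^2 expanded as a^2 - 2ab + b^2; black writes a**2 rather than a ** 2.
expanded = (a**2) - (2 * (a * b)) + (b**2)
direct = (a - b) ** 2
print(np.allclose(expanded, direct))  # True
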
4 changes: 2 additions & 2 deletions quantus/functions/normalise_func.py
@@ -231,13 +231,13 @@ def normalise_by_average_second_moment_estimate(

     # Check that square root of the second momment estimatte is nonzero.
     second_moment_sqrt = np.sqrt(
-        np.sum(a ** 2, axis=normalise_axes, keepdims=True)
+        np.sum(a**2, axis=normalise_axes, keepdims=True)
         / np.prod([a.shape[n] for n in normalise_axes])
     )

     if all(second_moment_sqrt != 0):
         a /= np.sqrt(
-            np.sum(a ** 2, axis=normalise_axes, keepdims=True)
+            np.sum(a**2, axis=normalise_axes, keepdims=True)
             / np.prod([a.shape[n] for n in normalise_axes])
         )
     else:
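
For context, the reformatted expression is the square root of the average second moment of the attributions over the normalisation axes. A rough standalone equivalent (shapes and axes are assumed for illustration; this is not the library function itself):

import numpy as np

a = np.random.randn(8, 3, 32, 32)  # batch of attributions (assumed shape)
normalise_axes = (1, 2, 3)

# sqrt of the average second moment: sqrt(sum(a**2) / number_of_elements), per sample.
second_moment_sqrt = np.sqrt(
    np.sum(a**2, axis=normalise_axes, keepdims=True)
    / np.prod([a.shape[n] for n in normalise_axes])
)
normalised = a / second_moment_sqrt
print(normalised.shape)  # (8, 3, 32, 32)
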
2 changes: 1 addition & 1 deletion quantus/functions/similarity_func.py
@@ -145,7 +145,7 @@ def lipschitz_constant(
     b: np.array,
     c: Union[np.array, None],
     d: Union[np.array, None],
-    **kwargs
+    **kwargs,
 ) -> float:
     """
     Calculate non-negative local Lipschitz abs(||a-b||/||c-d||), where a,b can be f(x) or a(x) and c,d is x.
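
This one-character change is black's magic trailing comma: when a signature is exploded across lines, the final parameter, including **kwargs, receives a trailing comma. A hypothetical function formatted the same way (not part of quantus):

from typing import Optional

import numpy as np


def distance(
    a: np.ndarray,
    b: np.ndarray,
    norm_ord: Optional[int] = None,
    **kwargs,
) -> float:
    # The trailing comma after **kwargs keeps future parameter additions to one-line diffs.
    return float(np.linalg.norm(a - b, ord=norm_ord))


print(distance(np.array([1.0, 2.0]), np.array([3.0, 4.0])))
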
2 changes: 1 addition & 1 deletion quantus/helpers/__init__.py
@@ -8,4 +8,4 @@

 # Import files dependent on package installations.
 __EXTRAS__ = util.find_spec("captum") or util.find_spec("tf_explain")
-__MODELS__ = util.find_spec("torch") or util.find_spec("tensorflow")
+__MODELS__ = util.find_spec("torch") or util.find_spec("tensorflow")
1 change: 0 additions & 1 deletion quantus/helpers/model/models.py
@@ -12,7 +12,6 @@

 # Import different models depending on which deep learning framework is installed.
 if util.find_spec("torch"):
-
     import torch
     import torch.nn as nn

29 changes: 17 additions & 12 deletions quantus/helpers/model/pytorch_model.py
@@ -86,8 +86,11 @@ def _get_model_with_linear_top(self) -> torch.nn:
             if isinstance(named_module[1], torch.nn.Softmax):
                 setattr(linear_model, named_module[0], torch.nn.Identity())

-                logging.info("Argument softmax=False passed, but the passed model contains a module of type "
-                             "torch.nn.Softmax. Module {} has been replaced with torch.nn.Identity().", named_module[0])
+                logging.info(
+                    "Argument softmax=False passed, but the passed model contains a module of type "
+                    "torch.nn.Softmax. Module {} has been replaced with torch.nn.Identity().",
+                    named_module[0],
+                )
                 break

         return linear_model

@@ -118,8 +121,10 @@ def get_softmax_arg_model(self) -> torch.nn:
             return self.model # Case 1

         if self.softmax and not last_softmax:
-            logging.info("Argument softmax=True passed, but the passed model contains no module of type "
-                         "torch.nn.Softmax. torch.nn.Softmax module is added as the output layer.")
+            logging.info(
+                "Argument softmax=True passed, but the passed model contains no module of type "
+                "torch.nn.Softmax. torch.nn.Softmax module is added as the output layer."
+            )
             return torch.nn.Sequential(self.model, torch.nn.Softmax(dim=-1)) # Case 3

         if not self.softmax and not last_softmax:

@@ -133,12 +138,14 @@ def get_softmax_arg_model(self) -> torch.nn:
         ) # Warning for cases 2, 4, 5

         if self.softmax and last_softmax != -1:
-            logging.info("Argument softmax=True passed. The passed model contains a module of type "
-                         "torch.nn.Softmax, but it is not the last in the list of model's children ("
-                         "self.model.modules()). torch.nn.Softmax module is added as the output layer."
-                         "Make sure that the torch.nn.Softmax layer is the last module in the list "
-                         "of model's children (self.model.modules()) if and only if it is the actual last module "
-                         "applied before output.")
+            logging.info(
+                "Argument softmax=True passed. The passed model contains a module of type "
+                "torch.nn.Softmax, but it is not the last in the list of model's children ("
+                "self.model.modules()). torch.nn.Softmax module is added as the output layer."
+                "Make sure that the torch.nn.Softmax layer is the last module in the list "
+                "of model's children (self.model.modules()) if and only if it is the actual last module "
+                "applied before output."
+            )

             return torch.nn.Sequential(self.model, torch.nn.Softmax(dim=-1)) # Case 2

@@ -337,7 +344,6 @@ def add_mean_shift_to_first_layer(
             The resulting model with a shifted first layer.
         """
         with torch.no_grad():
-
             new_model = deepcopy(self.model)

             modules = [l for l in new_model.named_modules()]

@@ -364,7 +370,6 @@ def get_hidden_representations(
         layer_names: Optional[List[str]] = None,
         layer_indices: Optional[List[int]] = None,
     ) -> np.ndarray:
-
        """
        Compute the model's internal representation of input x.
        In practice, this means, executing a forward pass and then, capturing the output of layers (of interest).
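
All of the pytorch_model.py edits follow the same rule: a call that exceeds black's 88-character line length is exploded so each argument, including implicitly concatenated string literals, sits on its own line with a trailing comma. A minimal sketch of the resulting shape (the logger call and module name here are illustrative, not the class's code):

import logging

logging.basicConfig(level=logging.INFO)

module_name = "classifier.softmax"  # stand-in for named_module[0] in the diff above

# Long call as black wraps it: one argument per line, implicit string concatenation kept.
logging.info(
    "Argument softmax=False passed, but the passed model contains a module of type "
    "torch.nn.Softmax. Module %s has been replaced with torch.nn.Identity().",
    module_name,
)
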
1 change: 0 additions & 1 deletion quantus/helpers/model/tf_model.py
@@ -360,7 +360,6 @@ def get_hidden_representations(
         layer_indices: Optional[List[int]] = None,
         **kwargs,
     ) -> np.ndarray:
-
        """
        Compute the model's internal representation of input x.
        In practice, this means, executing a forward pass and then, capturing the output of layers (of interest).
1 change: 0 additions & 1 deletion quantus/helpers/plotting.py
@@ -245,7 +245,6 @@ def plot_model_parameter_randomisation_experiment(

         plt.plot(layers, [np.mean(v) for k, v in scores.items()], label=method)
     else:
-
         layers = list(results.keys())
         scores = {k: [] for k in layers}
         # samples = len(results)
3 changes: 0 additions & 3 deletions quantus/helpers/utils.py
@@ -764,7 +764,6 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence
         for start in range(0, len(x_shape) - len(a_shape) + 1)
     ]
     if x_subshapes.count(a_shape) < 1:
-
         # Check that attribution dimensions are (consecutive) subdimensions of inputs
         raise ValueError(
             "Attribution dimensions are not (consecutive) subdimensions of inputs: "

@@ -773,7 +772,6 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence
             )
         )
     elif x_subshapes.count(a_shape) > 1:
-
         # Check that attribution dimensions are (unique) subdimensions of inputs.
         # Consider potentially expanded dims in attributions.

@@ -783,7 +781,6 @@ def infer_attribution_axes(a_batch: np.ndarray, x_batch: np.ndarray) -> Sequence
            for start in range(0, len(np.shape(a_batch)[1:]) - len(a_shape) + 1)
         ]
         if a_subshapes.count(a_shape) == 1:
-
             # Inferring channel shape.
             for dim in range(len(np.shape(a_batch)[1:]) + 1):
                 if a_shape == np.shape(a_batch)[1:][dim:]:
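
The utils.py hunks only drop blank lines inside infer_attribution_axes; the surrounding code slides the attribution shape along the input shape and counts how often it matches as a consecutive sub-shape. A rough illustration of that check (the shapes and the slicing are assumed for the example, not copied from the library):

import numpy as np

x_batch = np.zeros((8, 3, 28, 28))   # inputs: batch x channels x height x width
a_batch = np.zeros((8, 28, 28))      # attributions without an explicit channel axis

x_shape = list(x_batch.shape[1:])    # [3, 28, 28]
a_shape = list(a_batch.shape[1:])    # [28, 28]

# Slide a_shape along x_shape and collect every consecutive sub-shape of the same length.
x_subshapes = [
    x_shape[start : start + len(a_shape)]
    for start in range(0, len(x_shape) - len(a_shape) + 1)
]
print(x_subshapes)                   # [[3, 28], [28, 28]]
print(x_subshapes.count(a_shape))    # 1 -> attributions map onto the trailing input axes
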
