Commit

inverse
annahedstroem committed Nov 24, 2023
1 parent cfb980f commit 4b44e2b
Showing 2 changed files with 408 additions and 21 deletions.
90 changes: 69 additions & 21 deletions quantus/evaluation.py
@@ -31,41 +31,89 @@ def evaluate(
**kwargs,
) -> Optional[dict]:
"""
-    A method to evaluate some explanation methods given some metrics.
+    Evaluate different explanation methods using specified metrics.
Parameters
----------
-metrics: dict
-    A dictionary with intialised metrics.
-xai_methods: dict, list
-    Pass the different explanation methods as:
-    1) Dict[str, np.ndarray] where values are pre-calculcated attributions, or
-    2) Dict[str, Dict] where the keys are the name of the Quantus build-in explanation methods,
-    and the values are the explain function keyword arguments as a dictionary, or
-    3) Dict[str, Callable] where the keys are the name of explanation methods,
-    and the values a callable explanation function.
-model: torch.nn.Module, tf.keras.Model
-    A torch or tensorflow model e.g., torchvision.models that is subject to explanation.
+metrics : dict
+    A dictionary of initialized evaluation metrics. See quantus.AVAILABLE_METRICS.
+    Example: {'Robustness': quantus.MaxSensitivity(), 'Faithfulness': quantus.PixelFlipping()}
+xai_methods : dict
+    A dictionary specifying the explanation methods to evaluate, which can be structured in three ways:
+    1) Dict[str, Dict] for built-in Quantus methods:
+        Example:
+            xai_methods = {
+                'IntegratedGradients': {
+                    'n_steps': 10,
+                    'xai_lib': 'captum'
+                },
+                'Saliency': {
+                    'xai_lib': 'captum'
+                }
+            }
+        - See quantus.AVAILABLE_XAI_METHODS_CAPTUM for supported captum methods.
+        - See quantus.AVAILABLE_XAI_METHODS_TF for supported tensorflow methods.
+        - See https://github.com/chr5tphr/zennit for supported zennit methods.
+        - Read more about the explanation function arguments here:
+          <https://quantus.readthedocs.io/en/latest/docs_api/quantus.functions.explanation_func.html#quantus.functions.explanation_func.explain>
+    2) Dict[str, Callable] for custom methods:
+        Example:
+            xai_methods = {
+                'custom_own_xai_method': custom_explain_function
+            }
+        - Here, you can provide your own callable that mirrors the inputs and outputs of the quantus.explain() method.
+    3) Dict[str, np.ndarray] for pre-calculated attributions:
+        Example:
+            xai_methods = {
+                'LIME': precomputed_numpy_lime_attributions,
+                'GradientShap': precomputed_numpy_shap_attributions
+            }
+        - Note that some metrics, e.g., quantus.MaxSensitivity() within the robustness category,
+          require the explanation function to be passed (as it is used in the evaluation logic).
+    It is also possible to pass a combination of the above.
+model: Union[torch.nn.Module, tf.keras.Model]
+    A torch or tensorflow model that is subject to explanation.
x_batch: np.ndarray
-    A np.ndarray which contains the input data that are explained.
+    A np.ndarray containing the input data to be explained.
y_batch: np.ndarray
-    A np.ndarray which contains the output labels that are explained.
+    A np.ndarray containing the output labels corresponding to x_batch.
s_batch: np.ndarray, optional
-    A np.ndarray which contains segmentation masks that matches the input.
-agg_func: callable
-    Indicates how to aggregates scores e.g., pass np.mean.
-progress: boolean
-    Indicates if progress should be printed to std, or not.
+    A np.ndarray containing segmentation masks that match the input.
+agg_func: Callable
+    Indicates how to aggregate scores, e.g., pass np.mean.
+progress: bool
+    Indicates if progress should be printed to standard output.
explain_func_kwargs: dict, optional
    Keyword arguments to be passed to explain_func on call. Pass None if using Dict[str, Dict] type for xai_methods.
call_kwargs: Dict[str, Dict]
-    Keyword arguments for the call of the metrics, keys are names for arg set and values are argument dictionaries.
+    Keyword arguments for the call of the metrics. Keys are names for argument sets, and values are argument dictionaries.
kwargs: optional
    Deprecated keyword arguments for the call of the metrics.
Returns
-------
results: dict
-    A dictionary with the results.
+    A dictionary with the evaluation results.
"""

warn.check_kwargs(kwargs)
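
For context, a minimal usage sketch of the evaluate() signature documented above (assumptions: a trained model and numpy arrays x_batch and y_batch already exist; custom_explain_function and its random attributions are illustrative stand-ins, not part of this commit):

    import numpy as np
    import quantus

    def custom_explain_function(model, inputs, targets, **kwargs) -> np.ndarray:
        # Mirrors the inputs and outputs of quantus.explain(): takes a model
        # plus a batch of inputs and target labels, and returns one attribution
        # map per input. A random baseline is used here purely for illustration.
        return np.random.uniform(low=0.0, high=1.0, size=inputs.shape)

    # Initialized metrics, as in the docstring example above.
    metrics = {
        "Robustness": quantus.MaxSensitivity(),
        "Faithfulness": quantus.PixelFlipping(),
    }

    # A combination of two of the three supported xai_methods forms.
    xai_methods = {
        "Saliency": {"xai_lib": "captum"},                 # 1) built-in method with explain kwargs
        "custom_own_xai_method": custom_explain_function,  # 2) custom callable
    }

    results = quantus.evaluate(
        metrics=metrics,
        xai_methods=xai_methods,
        model=model,      # assumed: a trained torch.nn.Module or tf.keras.Model
        x_batch=x_batch,  # assumed: np.ndarray of inputs
        y_batch=y_batch,  # assumed: np.ndarray of matching labels
        agg_func=np.mean,
        progress=True,
    )
    # results maps each explanation method to its aggregated metric scores.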
