Skip to content

Commit

Permalink
* revert typing changes, update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
aaarrti committed Oct 18, 2023
1 parent e537708 commit f0f90d9
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 35 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ on:
pull_request:
workflow_dispatch:

concurrency:
group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.ref }}
cancel-in-progress: true

jobs:
run:
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ on:
pull_request:
workflow_dispatch:

concurrency:
group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.ref }}
cancel-in-progress: true

jobs:
lint:
Expand Down
8 changes: 6 additions & 2 deletions docs/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@ BUILDDIR = build
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile
.PHONY: help Makefile clean

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

rst:
rst: clean
@sphinx-apidoc -o source/docs_api ../quantus --module-first --separate --force


clean:
rm -rf source/docs_api
7 changes: 7 additions & 0 deletions docs/source/docs_api/quantus.helpers.enums.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
quantus.helpers.enums module
============================

.. automodule:: quantus.helpers.enums
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/docs_api/quantus.helpers.perturbation_utils.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
quantus.helpers.perturbation\_utils module
==========================================

.. automodule:: quantus.helpers.perturbation_utils
:members:
:undoc-members:
:show-inheritance:
2 changes: 2 additions & 0 deletions docs/source/docs_api/quantus.helpers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ Submodules

quantus.helpers.asserts
quantus.helpers.constants
quantus.helpers.enums
quantus.helpers.perturbation_utils
quantus.helpers.plotting
quantus.helpers.utils
quantus.helpers.warn
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ ignore_missing_imports = True
no_site_packages = True
show_none_errors = False
ignore_errors = False
plugins = numpy.typing.mypy_plugin

[mypy-quantus.*]
disallow_untyped_defs = False
Expand Down
67 changes: 34 additions & 33 deletions quantus/metrics/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Sequence,
Set,
TypeVar,
Optional,
)

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -73,10 +74,10 @@ def __init__(
abs: bool,
normalise: bool,
normalise_func: Callable,
normalise_func_kwargs: Dict[str, ...] | None,
normalise_func_kwargs: Optional[Dict[str, Any]],
return_aggregate: bool,
aggregate_func: Callable,
default_plot_func: Callable[[...], None] | None,
default_plot_func: Optional[Callable],
disable_warnings: bool,
display_progressbar: bool,
**kwargs,
Expand Down Expand Up @@ -145,15 +146,15 @@ def __call__(
self,
model,
x_batch: np.ndarray,
y_batch: np.ndarray | None,
a_batch: np.ndarray | None,
s_batch: np.ndarray | None,
channel_first: bool | None,
explain_func: Callable[[...], None] | None,
explain_func_kwargs: Dict[str, ...] | None,
model_predict_kwargs: Dict[str, ...] | None,
softmax: bool | None,
device: str | None = None,
y_batch: Optional[np.ndarray],
a_batch: Optional[np.ndarray],
s_batch: Optional[np.ndarray],
channel_first: Optional[bool],
explain_func: Optional[Callable],
explain_func_kwargs: Optional[Dict],
model_predict_kwargs: Optional[Dict],
softmax: Optional[bool],
device: Optional[str] = None,
batch_size: int = 64,
custom_batch: Any = None,
**kwargs,
Expand Down Expand Up @@ -301,7 +302,7 @@ def evaluate_batch(
x_batch: np.ndarray,
y_batch: np.ndarray,
a_batch: np.ndarray,
s_batch: np.ndarray | None,
s_batch: Optional[np.ndarray],
**kwargs,
):
"""
Expand Down Expand Up @@ -334,17 +335,17 @@ def general_preprocess(
self,
model,
x_batch: np.ndarray,
y_batch: np.ndarray | None,
a_batch: np.ndarray | None,
s_batch: np.ndarray | None,
channel_first: bool | None,
y_batch: Optional[np.ndarray],
a_batch: Optional[np.ndarray],
s_batch: Optional[np.ndarray],
channel_first: Optional[bool],
explain_func: Callable,
explain_func_kwargs: Dict[str, ...] | None,
model_predict_kwargs: Dict[str, ...] | None,
explain_func_kwargs: Optional[Dict[str, Any]],
model_predict_kwargs: Optional[Dict[str, Any]],
softmax: bool,
device: str | None,
custom_batch: np.ndarray | None,
) -> Dict[str, ...]:
device: Optional[str],
custom_batch: Optional[np.ndarray],
) -> Dict[str, Any]:
"""
Prepares all necessary variables for evaluation.
Expand Down Expand Up @@ -463,11 +464,11 @@ def custom_preprocess(
self,
model: ModelInterface,
x_batch: np.ndarray,
y_batch: np.ndarray | None,
a_batch: np.ndarray | None,
s_batch: np.ndarray | None,
y_batch: Optional[np.ndarray],
a_batch: Optional[np.ndarray],
s_batch: Optional[np.ndarray],
custom_batch: Any,
) -> Dict[str, ...] | None:
) -> Optional[Dict[str, Any]]:
"""
Implement this method if you need custom preprocessing of data,
model alteration or simply for creating/initialising additional
Expand Down Expand Up @@ -606,9 +607,9 @@ def custom_postprocess(
self,
model: ModelInterface,
x_batch: np.ndarray,
y_batch: np.ndarray | None,
a_batch: np.ndarray | None,
s_batch: np.ndarray | None,
y_batch: Optional[np.ndarray],
a_batch: Optional[np.ndarray],
s_batch: Optional[np.ndarray],
**kwargs,
):
"""
Expand Down Expand Up @@ -712,9 +713,9 @@ def generate_batches(

def plot(
self,
plot_func: Callable[[...], None] | None = None,
plot_func: Optional[Callable] = None,
show: bool = True,
path_to_save: str | None = None,
path_to_save: Optional[str] = None,
*args,
**kwargs,
) -> None:
Expand Down Expand Up @@ -782,7 +783,7 @@ def get_params(self) -> Dict[str, Any]:
return {k: v for k, v in self.__dict__.items() if k not in attr_exclude}

@final
def batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]:
def batch_preprocess(self, data_batch: Dict[str, Any]) -> Dict[str, Any]:
"""
If `data_batch` has no `a_batch`, will compute explanations.
This needs to be done on batch level to avoid OOM. Additionally will set `a_axes` property if it is None,
Expand All @@ -809,8 +810,8 @@ def batch_preprocess(self, data_batch: Dict[str, ...]) -> Dict[str, ...]:
return data_batch

def custom_batch_preprocess(
self, data_batch: Dict[str, ...]
) -> Dict[str, ...] | None:
self, data_batch: Dict[str, Any]
) -> Optional[Dict[str, ...]]:
"""
Implement this method if you need custom preprocessing of data
or simply for creating/initialising additional attributes or assertions
Expand Down

0 comments on commit f0f90d9

Please sign in to comment.