diff --git a/examples/05_flyvision_umap_and_clustering_models.ipynb b/examples/05_flyvision_umap_and_clustering_models.ipynb index 003e653..d4d1f29 100644 --- a/examples/05_flyvision_umap_and_clustering_models.ipynb +++ b/examples/05_flyvision_umap_and_clustering_models.ipynb @@ -1517,9 +1517,10 @@ }, "outputs": [], "source": [ - "# recommended to only run with precomputed responses using the record script\n", - "norm = ensemble.responses_norm()\n", - "responses = stims_and_resps[\"responses\"] / (norm + 1e-6)" + "# recommended to only run with precomputed responses using the pipeline manager script,\n", + "# see example_submissions.sh in the repository\n", + "# norm = ensemble.responses_norm()\n", + "# responses = stims_and_resps[\"responses\"] / (norm + 1e-6)" ] }, { diff --git a/flyvision/utils/chkpt_utils.py b/flyvision/utils/chkpt_utils.py index 2ba3733..ffa656f 100644 --- a/flyvision/utils/chkpt_utils.py +++ b/flyvision/utils/chkpt_utils.py @@ -1,4 +1,5 @@ import logging +import warnings from dataclasses import dataclass from pathlib import Path from typing import Dict, List, Tuple, Union @@ -135,7 +136,11 @@ def get_from_state_dict(state_dict: Union[Dict, Path, str], key: str) -> Dict: if state_dict is None: return None if isinstance(state_dict, (Path, str)): - state = torch.load(state_dict, map_location=flyvision.device).pop(key, None) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=FutureWarning) + state = torch.load( + state_dict, map_location=flyvision.device, weights_only=False + ).pop(key, None) elif isinstance(state_dict, dict): state = state_dict.get(key, None) else: diff --git a/pyproject.toml b/pyproject.toml index a8bb1c7..f68a52a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,9 +54,17 @@ readme = "README.md" requires-python = ">=3.6" [project.optional-dependencies] -dev = ["pre-commit", "ruff", "pytest", "jupyter", "papermill"] +dev = [ + "pre-commit", + "ruff", + "pytest", + "jupyter", + "papermill", + "tabulate", + "tqdm", + "ipywidgets", +] docs = [ - # Documentation "mkdocs", "mkdocs-material", "markdown-include", @@ -67,8 +75,10 @@ docs = [ "mkdocs-macros-plugin", "jupyter", "tabulate", + "tqdm", + "ipywidgets", ] -examples = ["jupyter", "tabulate"] +examples = ["jupyter", "tabulate", "tqdm", "ipywidgets"] no-version-pins = [ "numba", "matplotlib",