Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
lappalainenj committed Oct 24, 2024
1 parent 105061d commit 152bee1
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
7 changes: 4 additions & 3 deletions examples/05_flyvision_umap_and_clustering_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1517,9 +1517,10 @@
},
"outputs": [],
"source": [
"# recommended to only run with precomputed responses using the record script\n",
"norm = ensemble.responses_norm()\n",
"responses = stims_and_resps[\"responses\"] / (norm + 1e-6)"
"# recommended to only run with precomputed responses using the pipeline manager script,\n",
"# see example_submissions.sh in the repository\n",
"# norm = ensemble.responses_norm()\n",
"# responses = stims_and_resps[\"responses\"] / (norm + 1e-6)"
]
},
{
Expand Down
7 changes: 6 additions & 1 deletion flyvision/utils/chkpt_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple, Union
Expand Down Expand Up @@ -135,7 +136,11 @@ def get_from_state_dict(state_dict: Union[Dict, Path, str], key: str) -> Dict:
if state_dict is None:
return None
if isinstance(state_dict, (Path, str)):
state = torch.load(state_dict, map_location=flyvision.device).pop(key, None)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=FutureWarning)
state = torch.load(
state_dict, map_location=flyvision.device, weights_only=False
).pop(key, None)
elif isinstance(state_dict, dict):
state = state_dict.get(key, None)
else:
Expand Down
16 changes: 13 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,17 @@ readme = "README.md"
requires-python = ">=3.6"

[project.optional-dependencies]
dev = ["pre-commit", "ruff", "pytest", "jupyter", "papermill"]
dev = [
"pre-commit",
"ruff",
"pytest",
"jupyter",
"papermill",
"tabulate",
"tqdm",
"ipywidgets",
]
docs = [
# Documentation
"mkdocs",
"mkdocs-material",
"markdown-include",
Expand All @@ -67,8 +75,10 @@ docs = [
"mkdocs-macros-plugin",
"jupyter",
"tabulate",
"tqdm",
"ipywidgets",
]
examples = ["jupyter", "tabulate"]
examples = ["jupyter", "tabulate", "tqdm", "ipywidgets"]
no-version-pins = [
"numba",
"matplotlib",
Expand Down

0 comments on commit 152bee1

Please sign in to comment.