From ef6d956e4202e04c1b783b112210bcae893c9d37 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 6 Jun 2024 11:01:46 -0600 Subject: [PATCH] Improve import times with full (#2983) * move scipy stats and scipy imports inside * test more strict * forgotten import * more lenient test --- .github/import_test.py | 9 ++-- src/spikeinterface/exporters/report.py | 3 +- src/spikeinterface/exporters/to_phy.py | 6 --- .../extractors/matlabhelpers.py | 2 - .../qualitymetrics/pca_metrics.py | 19 +++++---- src/spikeinterface/sorters/container_tools.py | 23 +++++----- .../sorters/external/combinato.py | 12 +++--- .../sorters/external/pykilosort.py | 19 +++++---- .../sorters/internal/spyking_circus2.py | 15 ++++--- .../sortingcomponents/matching/circus.py | 42 +++++++++---------- src/spikeinterface/widgets/utils.py | 16 ++++--- 11 files changed, 78 insertions(+), 88 deletions(-) diff --git a/.github/import_test.py b/.github/import_test.py index 6a6ac30f2e..f7c3e9f858 100644 --- a/.github/import_test.py +++ b/.github/import_test.py @@ -45,11 +45,10 @@ time_taken = float(result.stdout.strip()) time_taken_list.append(time_taken) - # for time in time_taken_list: - # Uncomment once exporting import is fixed - # if time > 2.5: - # exceptions.append(f"Importing {import_statement} took too long: {time:.2f} seconds") - # break + for time in time_taken_list: + if time > 1.5: + exceptions.append(f"Importing {import_statement} took too long: {time:.2f} seconds") + break if time_taken_list: avg_time_taken = sum(time_taken_list) / len(time_taken_list) diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 95d3713065..e12bb9b588 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -6,8 +6,7 @@ from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs import spikeinterface.widgets as sw from spikeinterface.core import get_template_extremum_channel, get_template_extremum_amplitude -from spikeinterface.postprocessing import compute_spike_amplitudes, compute_unit_locations, compute_correlograms -from spikeinterface.qualitymetrics import compute_quality_metrics +from spikeinterface.postprocessing import compute_correlograms def export_report( diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 551431fe09..d7be6c1ba3 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -7,7 +7,6 @@ import shutil import warnings -import spikeinterface from spikeinterface.core import ( write_binary_recording, BinaryRecordingExtractor, @@ -16,11 +15,6 @@ SortingAnalyzer, ) from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs -from spikeinterface.postprocessing import ( - compute_spike_amplitudes, - compute_template_similarity, - compute_principal_components, -) def export_to_phy( diff --git a/src/spikeinterface/extractors/matlabhelpers.py b/src/spikeinterface/extractors/matlabhelpers.py index 1c2e1491c8..e9948575a2 100644 --- a/src/spikeinterface/extractors/matlabhelpers.py +++ b/src/spikeinterface/extractors/matlabhelpers.py @@ -3,8 +3,6 @@ from pathlib import Path from collections import deque -import numpy as np - class MatlabHelper: extractor_name = "MATSortingExtractor" diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index bfeb514ac8..2915cee8ec 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ 
b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -9,14 +9,6 @@ from tqdm.auto import tqdm from concurrent.futures import ProcessPoolExecutor -try: - import scipy.stats - import scipy.spatial.distance - from sklearn.discriminant_analysis import LinearDiscriminantAnalysis - from sklearn.neighbors import NearestNeighbors - from sklearn.decomposition import IncrementalPCA -except: - pass import warnings @@ -237,6 +229,8 @@ def mahalanobis_metrics(all_pcs, all_labels, this_unit_id): ---------- Based on metrics described in [Schmitzer-Torbert]_ """ + import scipy.stats + import scipy.spatial.distance pcs_for_this_unit = all_pcs[all_labels == this_unit_id, :] pcs_for_other_units = all_pcs[all_labels != this_unit_id, :] @@ -291,6 +285,7 @@ def lda_metrics(all_pcs, all_labels, this_unit_id): ---------- Based on metric described in [Hill]_ """ + from sklearn.discriminant_analysis import LinearDiscriminantAnalysis X = all_pcs @@ -348,6 +343,7 @@ def nearest_neighbors_metrics(all_pcs, all_labels, this_unit_id, max_spikes, n_n ---------- Based on metrics described in [Chung]_ """ + from sklearn.neighbors import NearestNeighbors total_spikes = all_pcs.shape[0] ratio = max_spikes / total_spikes @@ -474,6 +470,8 @@ def nearest_neighbors_isolation( ---------- Based on isolation metric described in [Chung]_ """ + from sklearn.decomposition import IncrementalPCA + rng = np.random.default_rng(seed=seed) waveforms_ext = sorting_analyzer.get_extension("waveforms") @@ -669,6 +667,8 @@ def nearest_neighbors_noise_overlap( ---------- Based on noise overlap metric described in [Chung]_ """ + from sklearn.decomposition import IncrementalPCA + rng = np.random.default_rng(seed=seed) waveforms_ext = sorting_analyzer.get_extension("waveforms") @@ -796,6 +796,7 @@ def simplified_silhouette_score(all_pcs, all_labels, this_unit_id): ---------- Based on simplified silhouette score suggested by [Hruschka]_ """ + import scipy.spatial.distance pcs_for_this_unit = all_pcs[all_labels == this_unit_id, :] centroid_for_this_unit = np.expand_dims(np.mean(pcs_for_this_unit, 0), 0) @@ -846,6 +847,7 @@ def silhouette_score(all_pcs, all_labels, this_unit_id): ---------- Based on [Rousseeuw]_ """ + import scipy.spatial.distance pcs_for_this_unit = all_pcs[all_labels == this_unit_id, :] distances_for_this_unit = scipy.spatial.distance.cdist(pcs_for_this_unit, pcs_for_this_unit) @@ -905,6 +907,7 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int): (1) ranges from 0 to 1; and (2) is symmetric, i.e. 
Isolation(A, B) = Isolation(B, A) """ + from sklearn.neighbors import NearestNeighbors # get lengths n_spikes_target = pcs_target_unit.shape[0] diff --git a/src/spikeinterface/sorters/container_tools.py b/src/spikeinterface/sorters/container_tools.py index b0ee73e21c..60eb080ae5 100644 --- a/src/spikeinterface/sorters/container_tools.py +++ b/src/spikeinterface/sorters/container_tools.py @@ -7,11 +7,6 @@ import string # TODO move this inside functions -try: - HAS_DOCKER = True - import docker -except ModuleNotFoundError: - HAS_DOCKER = False from spikeinterface.core.core_tools import recursive_path_modifier @@ -83,8 +78,8 @@ def __init__(self, mode, container_image, volumes, py_user_base, extra_kwargs): container_requires_gpu = extra_kwargs.get("container_requires_gpu", None) if mode == "docker": - if not HAS_DOCKER: - raise ModuleNotFoundError("No module named 'docker'") + import docker + client = docker.from_env() if container_requires_gpu is not None: extra_kwargs.pop("container_requires_gpu") @@ -108,12 +103,12 @@ def __init__(self, mode, container_image, volumes, py_user_base, extra_kwargs): elif Path(sif_file).exists(): singularity_image = sif_file else: - if HAS_DOCKER: - docker_image = self._get_docker_image(container_image) - if docker_image and len(docker_image.tags) > 0: - tag = docker_image.tags[0] - print(f"Building singularity image from local docker image: {tag}") - singularity_image = Client.build(f"docker-daemon://{tag}", sif_file, sudo=False) + + docker_image = self._get_docker_image(container_image) + if docker_image and len(docker_image.tags) > 0: + tag = docker_image.tags[0] + print(f"Building singularity image from local docker image: {tag}") + singularity_image = Client.build(f"docker-daemon://{tag}", sif_file, sudo=False) if not singularity_image: print(f"Singularity: pulling image {container_image}") singularity_image = Client.pull(f"docker://{container_image}") @@ -134,6 +129,8 @@ def __init__(self, mode, container_image, volumes, py_user_base, extra_kwargs): @staticmethod def _get_docker_image(container_image): + import docker + docker_client = docker.from_env(timeout=300) try: docker_image = docker_client.images.get(container_image) diff --git a/src/spikeinterface/sorters/external/combinato.py b/src/spikeinterface/sorters/external/combinato.py index de946633ac..082c1d172e 100644 --- a/src/spikeinterface/sorters/external/combinato.py +++ b/src/spikeinterface/sorters/external/combinato.py @@ -12,12 +12,6 @@ from spikeinterface.extractors import CombinatoSortingExtractor from spikeinterface.preprocessing import ScaleRecording -try: - import h5py - - HAVE_H5PY = True -except ImportError: - HAVE_H5PY = False PathType = Union[str, Path] @@ -128,6 +122,12 @@ def _check_apply_filter_in_params(cls, params): @classmethod def _setup_recording(cls, recording, sorter_output_folder, params, verbose): + try: + import h5py + + HAVE_H5PY = True + except ImportError: + HAVE_H5PY = False assert HAVE_H5PY, "You must install h5py for combinato" # Generate h5 files in the dataset directory chan_ids = recording.get_channel_ids() diff --git a/src/spikeinterface/sorters/external/pykilosort.py b/src/spikeinterface/sorters/external/pykilosort.py index dfe77501f7..9d0aab9702 100644 --- a/src/spikeinterface/sorters/external/pykilosort.py +++ b/src/spikeinterface/sorters/external/pykilosort.py @@ -10,14 +10,6 @@ import json from ..basesorter import BaseSorter, get_job_kwargs -try: - import pykilosort - from pykilosort import Bunch, add_default_handler, run - - HAVE_PYKILOSORT = True -except 
ImportError: - HAVE_PYKILOSORT = False - class PyKilosortSorter(BaseSorter): """Pykilosort Sorter object.""" @@ -128,10 +120,19 @@ class PyKilosortSorter(BaseSorter): @classmethod def is_installed(cls): + try: + import pykilosort + + HAVE_PYKILOSORT = True + except ImportError: + HAVE_PYKILOSORT = False + return HAVE_PYKILOSORT @classmethod def get_sorter_version(cls): + import pykilosort + return pykilosort.__version__ @classmethod @@ -150,6 +151,8 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): + from pykilosort import Bunch, run + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) if not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 05853b4c39..f2c385b718 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -22,14 +22,6 @@ from spikeinterface.core.analyzer_extension_core import ComputeTemplates -try: - import hdbscan - - HAVE_HDBSCAN = True -except: - HAVE_HDBSCAN = False - - class Spykingcircus2Sorter(ComponentsBasedSorter): sorter_name = "spykingcircus2" @@ -100,6 +92,13 @@ def get_sorter_version(cls): @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): + try: + import hdbscan + + HAVE_HDBSCAN = True + except: + HAVE_HDBSCAN = False + assert HAVE_HDBSCAN, "spykingcircus2 needs hdbscan to be installed" # this is importanted only on demand because numba import are too heavy diff --git a/src/spikeinterface/sortingcomponents/matching/circus.py b/src/spikeinterface/sortingcomponents/matching/circus.py index f78dd2a070..183fdab04c 100644 --- a/src/spikeinterface/sortingcomponents/matching/circus.py +++ b/src/spikeinterface/sortingcomponents/matching/circus.py @@ -5,31 +5,10 @@ import numpy as np - -try: - import sklearn - from sklearn.feature_extraction.image import extract_patches_2d - - HAVE_SKLEARN = True -except ImportError: - HAVE_SKLEARN = False - - from spikeinterface.core import get_noise_levels from spikeinterface.sortingcomponents.peak_detection import DetectPeakByChannel from spikeinterface.core.template import Templates -try: - import scipy.spatial - - import scipy - - (potrs,) = scipy.linalg.get_lapack_funcs(("potrs",), dtype=np.float32) - - (nrm2,) = scipy.linalg.get_blas_funcs(("nrm2",), dtype=np.float32) -except: - pass - spike_dtype = [ ("sample_index", "int64"), ("channel_index", "int64"), @@ -74,6 +53,9 @@ def compress_templates(templates_array, approx_rank, remove_mean=True, return_ne def compute_overlaps(templates, num_samples, num_channels, sparsities): + import scipy.spatial + import scipy + num_templates = len(templates) dense_templates = np.zeros((num_templates, num_samples, num_channels), dtype=np.float32) @@ -278,6 +260,13 @@ def get_margin(cls, recording, kwargs): @classmethod def main_function(cls, traces, d): + import scipy.spatial + import scipy + + (potrs,) = scipy.linalg.get_lapack_funcs(("potrs",), dtype=np.float32) + + (nrm2,) = scipy.linalg.get_blas_funcs(("nrm2",), dtype=np.float32) + num_templates = d["num_templates"] num_samples = d["num_samples"] num_channels = d["num_channels"] @@ -543,6 +532,9 @@ class CircusPeeler(BaseTemplateMatchingEngine): @classmethod def _prepare_templates(cls, d): + import scipy.spatial + import scipy + templates = 
d["templates"] num_samples = d["num_samples"] num_channels = d["num_channels"] @@ -634,6 +626,13 @@ def _prepare_templates(cls, d): @classmethod def initialize_and_check_kwargs(cls, recording, kwargs): + try: + from sklearn.feature_extraction.image import extract_patches_2d + + HAVE_SKLEARN = True + except ImportError: + HAVE_SKLEARN = False + assert HAVE_SKLEARN, "CircusPeeler needs sklearn to work" d = cls._default_params.copy() d.update(kwargs) @@ -732,6 +731,7 @@ def main_function(cls, traces, d): peak_sample_index, peak_chan_ind = DetectPeakByChannel.detect_peaks( peak_traces, peak_sign, abs_threholds, exclude_sweep_size ) + from sklearn.feature_extraction.image import extract_patches_2d if jitter > 0: jittered_peaks = peak_sample_index[:, np.newaxis] + np.arange(-jitter, jitter) diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 337e253cfa..8677c788a2 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -1,9 +1,7 @@ from __future__ import annotations import numpy as np -import random -from ..core import ChannelSparsity try: import distinctipy @@ -12,13 +10,6 @@ except ImportError: HAVE_DISTINCTIPY = False -try: - import matplotlib.pyplot as plt - - HAVE_MPL = True -except ImportError: - HAVE_MPL = False - def get_some_colors( keys, color_engine="auto", map_name="gist_ncar", format="RGBA", shuffle=None, seed=None, margin=None @@ -50,6 +41,13 @@ def get_some_colors( A dict of colors for given keys. """ + try: + import matplotlib.pyplot as plt + + HAVE_MPL = True + except ImportError: + HAVE_MPL = False + assert color_engine in ("auto", "distinctipy", "matplotlib", "colorsys") possible_formats = ("RGBA",)