Improve import times with full (#2983)
* move scipy stats and scipy imports inside

* test more strict

* forgotten import

* more lenient test
h-mayorquin authored Jun 6, 2024
1 parent a8ed7c8 commit ef6d956
Showing 11 changed files with 78 additions and 88 deletions.
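Every file in this commit applies the same deferred-import pattern: heavy optional dependencies (scipy, scikit-learn, docker, h5py, pykilosort, hdbscan) are no longer imported at module level, where they slow down `import spikeinterface`, but inside the functions and methods that actually use them. A minimal sketch of the pattern, using a hypothetical heavy dependency rather than any module from this repository:

```python
# Before: every consumer of this module pays for the heavy dependency at import time.
# import heavylib  # hypothetical heavy package

def compute_metric(data):
    # After: the cost is paid only on the first call; Python caches the module
    # in sys.modules, so subsequent calls are effectively free.
    import heavylib  # hypothetical heavy package

    return heavylib.process(data)
```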
9 changes: 4 additions & 5 deletions .github/import_test.py
@@ -45,11 +45,10 @@
time_taken = float(result.stdout.strip())
time_taken_list.append(time_taken)

# for time in time_taken_list:
# Uncomment once exporting import is fixed
# if time > 2.5:
# exceptions.append(f"Importing {import_statement} took too long: {time:.2f} seconds")
# break
for time in time_taken_list:
if time > 1.5:
exceptions.append(f"Importing {import_statement} took too long: {time:.2f} seconds")
break

if time_taken_list:
avg_time_taken = sum(time_taken_list) / len(time_taken_list)
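The import-time guard in this test is re-enabled and tightened: the previously commented-out 2.5 s threshold becomes an enforced 1.5 s limit per import statement. A rough sketch of how a subprocess-based import timer of this kind can work; the function and variable names below are illustrative, not the repository's actual test code:

```python
import subprocess
import sys

def time_import(import_statement: str) -> float:
    # Run the import in a fresh interpreter so nothing is pre-cached, and have the
    # child print the elapsed time, which the parent reads back from stdout.
    code = (
        "import time; t0 = time.perf_counter(); "
        f"{import_statement}; "
        "print(time.perf_counter() - t0)"
    )
    result = subprocess.run([sys.executable, "-c", code], capture_output=True, text=True, check=True)
    return float(result.stdout.strip())

# Example: fail if the top-level import exceeds the 1.5 s budget used above.
assert time_import("import spikeinterface") < 1.5
```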
3 changes: 1 addition & 2 deletions src/spikeinterface/exporters/report.py
@@ -6,8 +6,7 @@
from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs
import spikeinterface.widgets as sw
from spikeinterface.core import get_template_extremum_channel, get_template_extremum_amplitude
from spikeinterface.postprocessing import compute_spike_amplitudes, compute_unit_locations, compute_correlograms
from spikeinterface.qualitymetrics import compute_quality_metrics
from spikeinterface.postprocessing import compute_correlograms


def export_report(
6 changes: 0 additions & 6 deletions src/spikeinterface/exporters/to_phy.py
@@ -7,7 +7,6 @@
import shutil
import warnings

import spikeinterface
from spikeinterface.core import (
write_binary_recording,
BinaryRecordingExtractor,
@@ -16,11 +15,6 @@
SortingAnalyzer,
)
from spikeinterface.core.job_tools import _shared_job_kwargs_doc, fix_job_kwargs
from spikeinterface.postprocessing import (
compute_spike_amplitudes,
compute_template_similarity,
compute_principal_components,
)


def export_to_phy(
2 changes: 0 additions & 2 deletions src/spikeinterface/extractors/matlabhelpers.py
@@ -3,8 +3,6 @@
from pathlib import Path
from collections import deque

import numpy as np


class MatlabHelper:
extractor_name = "MATSortingExtractor"
19 changes: 11 additions & 8 deletions src/spikeinterface/qualitymetrics/pca_metrics.py
@@ -9,14 +9,6 @@
from tqdm.auto import tqdm
from concurrent.futures import ProcessPoolExecutor

try:
import scipy.stats
import scipy.spatial.distance
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import IncrementalPCA
except:
pass

import warnings

@@ -237,6 +229,8 @@ def mahalanobis_metrics(all_pcs, all_labels, this_unit_id):
----------
Based on metrics described in [Schmitzer-Torbert]_
"""
import scipy.stats
import scipy.spatial.distance

pcs_for_this_unit = all_pcs[all_labels == this_unit_id, :]
pcs_for_other_units = all_pcs[all_labels != this_unit_id, :]
@@ -291,6 +285,7 @@ def lda_metrics(all_pcs, all_labels, this_unit_id):
----------
Based on metric described in [Hill]_
"""
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

X = all_pcs

@@ -348,6 +343,7 @@ def nearest_neighbors_metrics(all_pcs, all_labels, this_unit_id, max_spikes, n_n
----------
Based on metrics described in [Chung]_
"""
from sklearn.neighbors import NearestNeighbors

total_spikes = all_pcs.shape[0]
ratio = max_spikes / total_spikes
@@ -474,6 +470,8 @@ def nearest_neighbors_isolation(
----------
Based on isolation metric described in [Chung]_
"""
from sklearn.decomposition import IncrementalPCA

rng = np.random.default_rng(seed=seed)

waveforms_ext = sorting_analyzer.get_extension("waveforms")
@@ -669,6 +667,8 @@ def nearest_neighbors_noise_overlap(
----------
Based on noise overlap metric described in [Chung]_
"""
from sklearn.decomposition import IncrementalPCA

rng = np.random.default_rng(seed=seed)

waveforms_ext = sorting_analyzer.get_extension("waveforms")
@@ -796,6 +796,7 @@ def simplified_silhouette_score(all_pcs, all_labels, this_unit_id):
----------
Based on simplified silhouette score suggested by [Hruschka]_
"""
import scipy.spatial.distance

pcs_for_this_unit = all_pcs[all_labels == this_unit_id, :]
centroid_for_this_unit = np.expand_dims(np.mean(pcs_for_this_unit, 0), 0)
@@ -846,6 +847,7 @@ def silhouette_score(all_pcs, all_labels, this_unit_id):
----------
Based on [Rousseeuw]_
"""
import scipy.spatial.distance

pcs_for_this_unit = all_pcs[all_labels == this_unit_id, :]
distances_for_this_unit = scipy.spatial.distance.cdist(pcs_for_this_unit, pcs_for_this_unit)
@@ -905,6 +907,7 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int):
(1) ranges from 0 to 1; and
(2) is symmetric, i.e. Isolation(A, B) = Isolation(B, A)
"""
from sklearn.neighbors import NearestNeighbors

# get lengths
n_spikes_target = pcs_target_unit.shape[0]
23 changes: 10 additions & 13 deletions src/spikeinterface/sorters/container_tools.py
@@ -7,11 +7,6 @@
import string

# TODO move this inside functions
try:
HAS_DOCKER = True
import docker
except ModuleNotFoundError:
HAS_DOCKER = False


from spikeinterface.core.core_tools import recursive_path_modifier
@@ -83,8 +78,8 @@ def __init__(self, mode, container_image, volumes, py_user_base, extra_kwargs):
container_requires_gpu = extra_kwargs.get("container_requires_gpu", None)

if mode == "docker":
if not HAS_DOCKER:
raise ModuleNotFoundError("No module named 'docker'")
import docker

client = docker.from_env()
if container_requires_gpu is not None:
extra_kwargs.pop("container_requires_gpu")
@@ -108,12 +103,12 @@ def __init__(self, mode, container_image, volumes, py_user_base, extra_kwargs):
elif Path(sif_file).exists():
singularity_image = sif_file
else:
if HAS_DOCKER:
docker_image = self._get_docker_image(container_image)
if docker_image and len(docker_image.tags) > 0:
tag = docker_image.tags[0]
print(f"Building singularity image from local docker image: {tag}")
singularity_image = Client.build(f"docker-daemon://{tag}", sif_file, sudo=False)

docker_image = self._get_docker_image(container_image)
if docker_image and len(docker_image.tags) > 0:
tag = docker_image.tags[0]
print(f"Building singularity image from local docker image: {tag}")
singularity_image = Client.build(f"docker-daemon://{tag}", sif_file, sudo=False)
if not singularity_image:
print(f"Singularity: pulling image {container_image}")
singularity_image = Client.pull(f"docker://{container_image}")
@@ -134,6 +129,8 @@ def __init__(self, mode, container_image, volumes, py_user_base, extra_kwargs):

@staticmethod
def _get_docker_image(container_image):
import docker

docker_client = docker.from_env(timeout=300)
try:
docker_image = docker_client.images.get(container_image)
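With the module-level HAS_DOCKER flag removed, the `docker` dependency is no longer probed when the module is imported; a missing package now surfaces as a ModuleNotFoundError when a Docker-mode container is actually requested. If an explicit availability check were ever wanted without paying the import cost up front, a small sketch like this (not part of this commit) would do it:

```python
import importlib.util

def has_docker() -> bool:
    # Returns True if the 'docker' package is installed, without importing it.
    return importlib.util.find_spec("docker") is not None
```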
12 changes: 6 additions & 6 deletions src/spikeinterface/sorters/external/combinato.py
@@ -12,12 +12,6 @@
from spikeinterface.extractors import CombinatoSortingExtractor
from spikeinterface.preprocessing import ScaleRecording

try:
import h5py

HAVE_H5PY = True
except ImportError:
HAVE_H5PY = False

PathType = Union[str, Path]

@@ -128,6 +122,12 @@ def _check_apply_filter_in_params(cls, params):

@classmethod
def _setup_recording(cls, recording, sorter_output_folder, params, verbose):
try:
import h5py

HAVE_H5PY = True
except ImportError:
HAVE_H5PY = False
assert HAVE_H5PY, "You must install h5py for combinato"
# Generate h5 files in the dataset directory
chan_ids = recording.get_channel_ids()
19 changes: 11 additions & 8 deletions src/spikeinterface/sorters/external/pykilosort.py
@@ -10,14 +10,6 @@
import json
from ..basesorter import BaseSorter, get_job_kwargs

try:
import pykilosort
from pykilosort import Bunch, add_default_handler, run

HAVE_PYKILOSORT = True
except ImportError:
HAVE_PYKILOSORT = False


class PyKilosortSorter(BaseSorter):
"""Pykilosort Sorter object."""
@@ -128,10 +120,19 @@ class PyKilosortSorter(BaseSorter):

@classmethod
def is_installed(cls):
try:
import pykilosort

HAVE_PYKILOSORT = True
except ImportError:
HAVE_PYKILOSORT = False

return HAVE_PYKILOSORT

@classmethod
def get_sorter_version(cls):
import pykilosort

return pykilosort.__version__

@classmethod
@@ -150,6 +151,8 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose):

@classmethod
def _run_from_folder(cls, sorter_output_folder, params, verbose):
from pykilosort import Bunch, run

recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)

if not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1):
15 changes: 7 additions & 8 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -22,14 +22,6 @@
from spikeinterface.core.analyzer_extension_core import ComputeTemplates


try:
import hdbscan

HAVE_HDBSCAN = True
except:
HAVE_HDBSCAN = False


class Spykingcircus2Sorter(ComponentsBasedSorter):
sorter_name = "spykingcircus2"

@@ -100,6 +92,13 @@ def get_sorter_version(cls):

@classmethod
def _run_from_folder(cls, sorter_output_folder, params, verbose):
try:
import hdbscan

HAVE_HDBSCAN = True
except:
HAVE_HDBSCAN = False

assert HAVE_HDBSCAN, "spykingcircus2 needs hdbscan to be installed"

# this is importanted only on demand because numba import are too heavy
42 changes: 21 additions & 21 deletions src/spikeinterface/sortingcomponents/matching/circus.py
@@ -5,31 +5,10 @@

import numpy as np


try:
import sklearn
from sklearn.feature_extraction.image import extract_patches_2d

HAVE_SKLEARN = True
except ImportError:
HAVE_SKLEARN = False


from spikeinterface.core import get_noise_levels
from spikeinterface.sortingcomponents.peak_detection import DetectPeakByChannel
from spikeinterface.core.template import Templates

try:
import scipy.spatial

import scipy

(potrs,) = scipy.linalg.get_lapack_funcs(("potrs",), dtype=np.float32)

(nrm2,) = scipy.linalg.get_blas_funcs(("nrm2",), dtype=np.float32)
except:
pass

spike_dtype = [
("sample_index", "int64"),
("channel_index", "int64"),
@@ -74,6 +53,9 @@ def compress_templates(templates_array, approx_rank, remove_mean=True, return_ne


def compute_overlaps(templates, num_samples, num_channels, sparsities):
import scipy.spatial
import scipy

num_templates = len(templates)

dense_templates = np.zeros((num_templates, num_samples, num_channels), dtype=np.float32)
@@ -278,6 +260,13 @@ def get_margin(cls, recording, kwargs):

@classmethod
def main_function(cls, traces, d):
import scipy.spatial
import scipy

(potrs,) = scipy.linalg.get_lapack_funcs(("potrs",), dtype=np.float32)

(nrm2,) = scipy.linalg.get_blas_funcs(("nrm2",), dtype=np.float32)

num_templates = d["num_templates"]
num_samples = d["num_samples"]
num_channels = d["num_channels"]
@@ -543,6 +532,9 @@ class CircusPeeler(BaseTemplateMatchingEngine):

@classmethod
def _prepare_templates(cls, d):
import scipy.spatial
import scipy

templates = d["templates"]
num_samples = d["num_samples"]
num_channels = d["num_channels"]
@@ -634,6 +626,13 @@ def _prepare_templates(cls, d):

@classmethod
def initialize_and_check_kwargs(cls, recording, kwargs):
try:
from sklearn.feature_extraction.image import extract_patches_2d

HAVE_SKLEARN = True
except ImportError:
HAVE_SKLEARN = False

assert HAVE_SKLEARN, "CircusPeeler needs sklearn to work"
d = cls._default_params.copy()
d.update(kwargs)
@@ -732,6 +731,7 @@ def main_function(cls, traces, d):
peak_sample_index, peak_chan_ind = DetectPeakByChannel.detect_peaks(
peak_traces, peak_sign, abs_threholds, exclude_sweep_size
)
from sklearn.feature_extraction.image import extract_patches_2d

if jitter > 0:
jittered_peaks = peak_sample_index[:, np.newaxis] + np.arange(-jitter, jitter)