From 2ea8e0e27e1e3746c2a6c783b90b4932b429929f Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 16 Nov 2023 15:18:52 +0100
Subject: [PATCH 1/5] Remove joblib in favor of ProcessPoolExecutor

---
 pyproject.toml                                |  1 -
 src/spikeinterface/core/job_tools.py          | 22 +--------
 .../postprocessing/principal_component.py     | 25 ++++++----
 .../tests/test_principal_component.py         |  4 +-
 .../qualitymetrics/pca_metrics.py             | 48 +++++++++++--------
 5 files changed, 48 insertions(+), 52 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 658703b25c..5a6d512e9c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,6 @@ classifiers = [
 dependencies = [
     "numpy",
     "neo>=0.12.0",
-    "joblib",
     "threadpoolctl",
     "tqdm",
     "probeinterface>=0.2.19",
diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py
index fcf1d93d1c..3eeb7d6aae 100644
--- a/src/spikeinterface/core/job_tools.py
+++ b/src/spikeinterface/core/job_tools.py
@@ -7,7 +7,6 @@
 import os
 import warnings

-import joblib
 import sys
 import contextlib
 from tqdm.auto import tqdm
@@ -95,25 +94,6 @@ def split_job_kwargs(mixed_kwargs):
     return specific_kwargs, job_kwargs


-# from https://stackoverflow.com/questions/24983493/tracking-progress-of-joblib-parallel-execution
-@contextlib.contextmanager
-def tqdm_joblib(tqdm_object):
-    """Context manager to patch joblib to report into tqdm progress bar given as argument"""
-
-    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
-        def __call__(self, *args, **kwargs):
-            tqdm_object.update(n=self.batch_size)
-            return super().__call__(*args, **kwargs)
-
-    old_batch_callback = joblib.parallel.BatchCompletionCallBack
-    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
-    try:
-        yield tqdm_object
-    finally:
-        joblib.parallel.BatchCompletionCallBack = old_batch_callback
-        tqdm_object.close()
-
-
 def divide_segment_into_chunks(num_frames, chunk_size):
     if chunk_size is None:
         chunks = [(0, num_frames)]
@@ -156,7 +136,7 @@ def _mem_to_int(mem):

 def ensure_n_jobs(recording, n_jobs=1):
     if n_jobs == -1:
-        n_jobs = joblib.cpu_count()
+        n_jobs = os.cpu_count()
     elif n_jobs == 0:
         n_jobs = 1
     elif n_jobs is None:
diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py
index cf32e79b25..ab0dad7ef6 100644
--- a/src/spikeinterface/postprocessing/principal_component.py
+++ b/src/spikeinterface/postprocessing/principal_component.py
@@ -370,7 +370,7 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs):

     def _fit_by_channel_local(self, n_jobs, progress_bar):
         from sklearn.decomposition import IncrementalPCA
-        from joblib import delayed, Parallel
+        from concurrent.futures import ProcessPoolExecutor

         we = self.waveform_extractor
         p = self._params
@@ -385,12 +385,15 @@ def _fit_by_channel_local(self, n_jobs, progress_bar):

         tmp_folder = p["tmp_folder"]
         if tmp_folder is None:
-            tmp_folder = "tmp"
-        tmp_folder = Path(tmp_folder)
+            if n_jobs > 1:
+                import tempfile
+
+                tmp_folder = tempfile.mkdtemp(prefix="tmp", dir=".")

         for chan_ind, chan_id in enumerate(channel_ids):
             pca_model = pca_models[chan_ind]
             if n_jobs > 1:
+                tmp_folder = Path(tmp_folder)
                 tmp_folder.mkdir(exist_ok=True)
                 pca_model_file = tmp_folder / f"tmp_pca_model_{mode}_{chan_id}.pkl"
                 with pca_model_file.open("wb") as f:
@@ -411,10 +414,14 @@ def _fit_by_channel_local(self, n_jobs, progress_bar):
                     pca = pca_models[chan_ind]
                     pca.partial_fit(wfs[:, :, wf_ind])
             else:
-                Parallel(n_jobs=n_jobs)(
-                    delayed(partial_fit_one_channel)(pca_model_files[chan_ind], wfs[:, :, wf_ind])
-                    for wf_ind, chan_ind in enumerate(channel_inds)
-                )
+                # parallel
+                items = [(pca_model_files[chan_ind], wfs[:, :, wf_ind]) for wf_ind, chan_ind in enumerate(channel_inds)]
+                n_jobs = min(n_jobs, len(items))
+
+                with ProcessPoolExecutor(max_workers=n_jobs) as executor:
+                    results = executor.map(partial_fit_one_channel, items)
+                    for res in results:
+                        pass

         # reload the models (if n_jobs > 1)
         if n_jobs not in (0, 1):
@@ -424,6 +431,7 @@ def _fit_by_channel_local(self, n_jobs, progress_bar):
                 with open(pca_model_file, "rb") as fid:
                     pca_models.append(pickle.load(fid))
                 pca_model_file.unlink()
+            print(f"Removing {tmp_folder}")
             shutil.rmtree(tmp_folder)

         # add models to extension data
@@ -762,7 +770,8 @@ def compute_principal_components(
     return pc


-def partial_fit_one_channel(pca_file, wf_chan):
+def partial_fit_one_channel(args):
+    pca_file, wf_chan = args
     with open(pca_file, "rb") as fid:
         pca_model = pickle.load(fid)
     pca_model.partial_fit(wf_chan)
diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py
index 49591d9b89..af55034cea 100644
--- a/src/spikeinterface/postprocessing/tests/test_principal_component.py
+++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py
@@ -204,5 +204,5 @@ def test_project_new(self):
     # test.test_extension()
     # test.test_shapes()
     # test.test_compute_for_all_spikes()
-    test.test_sparse()
-    # test.test_project_new()
+    # test.test_sparse()
+    test.test_project_new()
diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py
index 6b13b1acbf..549724d15c 100644
--- a/src/spikeinterface/qualitymetrics/pca_metrics.py
+++ b/src/spikeinterface/qualitymetrics/pca_metrics.py
@@ -4,6 +4,7 @@

 import numpy as np
 from tqdm.auto import tqdm
+from concurrent.futures import ProcessPoolExecutor

 try:
     import scipy.stats
@@ -11,12 +12,10 @@
     from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
     from sklearn.neighbors import NearestNeighbors
     from sklearn.decomposition import IncrementalPCA
-    from joblib import delayed, Parallel
 except:
     pass

 from ..core import get_random_data_chunks, compute_sparsity, WaveformExtractor
-from ..core.job_tools import tqdm_joblib
 from ..core.template_tools import get_template_extremum_channel
 from ..postprocessing import WaveformPrincipalComponent
@@ -25,7 +24,6 @@

 from .misc_metrics import compute_num_spikes, compute_firing_rates

 from ..core import get_random_data_chunks, load_waveforms, compute_sparsity, WaveformExtractor
-from ..core.job_tools import tqdm_joblib
 from ..core.template_tools import get_template_extremum_channel
 from ..postprocessing import WaveformPrincipalComponent
@@ -134,7 +132,9 @@ def calculate_pc_metrics(
     parallel_functions = []

     all_labels, all_pcs = pca.get_all_projections()
-    for unit_ind, unit_id in units_loop:
+
+    items = []
+    for unit_id in unit_ids:
         if we.is_sparse():
             neighbor_channel_ids = we.sparsity.unit_id_to_channel_ids[unit_id]
             neighbor_unit_ids = [
@@ -166,23 +166,20 @@ def calculate_pc_metrics(
             n_spikes_all_units,
             fr_all_units,
         )
+        items.append(func_args)

-        if not run_in_parallel:
-            pca_metrics_unit = pca_metrics_one_unit(*func_args)
+    if not run_in_parallel:
+        for unit_ind, unit_id in units_loop:
+            pca_metrics_unit = pca_metrics_one_unit(items[unit_ind])
             for metric_name, metric in pca_metrics_unit.items():
                 pc_metrics[metric_name][unit_id] = metric
-        else:
-            parallel_functions.append(delayed(pca_metrics_one_unit)(*func_args))
-
-    if run_in_parallel:
-        if progress_bar:
-            units_loop = tqdm(units_loop, desc="Computing PCA metrics", total=len(unit_ids))
-            with tqdm_joblib(units_loop) as pb:
-                pc_metrics_units = Parallel(n_jobs=n_jobs)(parallel_functions)
-        else:
-            pc_metrics_units = Parallel(n_jobs=n_jobs)(parallel_functions)
+    else:
+        with ProcessPoolExecutor(n_jobs) as executor:
+            results = executor.map(pca_metrics_one_unit, items)
+            if progress_bar:
+                results = tqdm(results, total=len(unit_ids))

-        for ui, pca_metrics_unit in enumerate(pc_metrics_units):
+        for ui, pca_metrics_unit in enumerate(results):
             unit_id = unit_ids[ui]
             for metric_name, metric in pca_metrics_unit.items():
                 pc_metrics[metric_name][unit_id] = metric
@@ -888,9 +885,20 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int):
     return isolation


-def pca_metrics_one_unit(
-    pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, seed, we_folder, n_spikes_all_units, fr_all_units
-):
+def pca_metrics_one_unit(args):
+    (
+        pcs_flat,
+        labels,
+        metric_names,
+        unit_id,
+        unit_ids,
+        qm_params,
+        seed,
+        we_folder,
+        n_spikes_all_units,
+        fr_all_units,
+    ) = args
+
     if "nn_isolation" in metric_names or "nn_noise_overlap" in metric_names:
         we = load_waveforms(we_folder)

From 00bc31fb5a67d0765647d7928f6c3572dbddad7b Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 22 Nov 2023 14:30:10 +0100
Subject: [PATCH 2/5] Use global tmp folder in PCA

---
 src/spikeinterface/postprocessing/principal_component.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py
index 6f2be41229..62909ea5b3 100644
--- a/src/spikeinterface/postprocessing/principal_component.py
+++ b/src/spikeinterface/postprocessing/principal_component.py
@@ -9,6 +9,7 @@

 from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, fix_job_kwargs
 from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension
+from spikeinterface.core.globals import get_global_tmp_folder

 _possible_modes = ["by_channel_local", "by_channel_global", "concatenated"]

@@ -388,7 +389,7 @@ def _fit_by_channel_local(self, n_jobs, progress_bar):
             if n_jobs > 1:
                 import tempfile

-                tmp_folder = tempfile.mkdtemp(prefix="tmp", dir=".")
+                tmp_folder = tempfile.mkdtemp(prefix="tmp", dir=get_global_tmp_folder())

         for chan_ind, chan_id in enumerate(channel_ids):

From 277e7190eafb052a28a6ae4ae03c3fad14a73d87 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 22 Nov 2023 16:12:33 +0100
Subject: [PATCH 3/5] Fix indentation!
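
A note on this fix: after PATCH 1/5 the result-collection loop ran as a
sibling of the `with ProcessPoolExecutor(...)` block rather than inside it.
`executor.map` returns a lazy iterator and `ProcessPoolExecutor.__exit__`
waits for all pending jobs, so the tqdm bar wrapped around `results` only
advanced once every unit had already been computed. Indenting the loop into
the `with` block consumes results while the workers are still running, which
makes the progress bar advance live. A toy sketch of the difference (not
taken from the codebase):

    from concurrent.futures import ProcessPoolExecutor
    from tqdm.auto import tqdm

    def square(x):
        return x * x

    if __name__ == "__main__":
        with ProcessPoolExecutor(2) as executor:
            results = executor.map(square, range(100))
            # consumed inside the `with`: the bar advances as jobs finish
            for res in tqdm(results, total=100):
                pass
        # a loop placed after the `with` would still see all results, but
        # __exit__ has already waited for every job, so the bar would jump
        # from 0 to 100% in one step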
---
 src/spikeinterface/qualitymetrics/pca_metrics.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py
index 549724d15c..a8e6d90e6a 100644
--- a/src/spikeinterface/qualitymetrics/pca_metrics.py
+++ b/src/spikeinterface/qualitymetrics/pca_metrics.py
@@ -179,10 +179,10 @@ def calculate_pc_metrics(
             results = executor.map(pca_metrics_one_unit, items)
             if progress_bar:
                 results = tqdm(results, total=len(unit_ids))

-        for ui, pca_metrics_unit in enumerate(results):
-            unit_id = unit_ids[ui]
-            for metric_name, metric in pca_metrics_unit.items():
-                pc_metrics[metric_name][unit_id] = metric
+            for ui, pca_metrics_unit in enumerate(results):
+                unit_id = unit_ids[ui]
+                for metric_name, metric in pca_metrics_unit.items():
+                    pc_metrics[metric_name][unit_id] = metric

     return pc_metrics

From 88494c9dd74fb2091dd6b262ed61dc52521f8648 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 23 Nov 2023 12:04:28 +0100
Subject: [PATCH 4/5] Remove print and apply Ramon's suggestions

---
 src/spikeinterface/postprocessing/principal_component.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py
index 62909ea5b3..9e822a5d1a 100644
--- a/src/spikeinterface/postprocessing/principal_component.py
+++ b/src/spikeinterface/postprocessing/principal_component.py
@@ -1,6 +1,7 @@
 import shutil
 import pickle
 import warnings
+import tempfile
 from pathlib import Path

 from tqdm.auto import tqdm
@@ -387,9 +388,7 @@ def _fit_by_channel_local(self, n_jobs, progress_bar):
         tmp_folder = p["tmp_folder"]
         if tmp_folder is None:
             if n_jobs > 1:
-                import tempfile
-
-                tmp_folder = tempfile.mkdtemp(prefix="tmp", dir=get_global_tmp_folder())
+                tmp_folder = tempfile.mkdtemp(prefix="pca", dir=get_global_tmp_folder())

         for chan_ind, chan_id in enumerate(channel_ids):
             pca_model = pca_models[chan_ind]
@@ -432,7 +431,6 @@ def _fit_by_channel_local(self, n_jobs, progress_bar):
                 with open(pca_model_file, "rb") as fid:
                     pca_models.append(pickle.load(fid))
                 pca_model_file.unlink()
-            print(f"Removing {tmp_folder}")
             shutil.rmtree(tmp_folder)

         # add models to extension data

From f74203288ba998614af338afe54d2996dff99294 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 23 Nov 2023 12:10:11 +0100
Subject: [PATCH 5/5] Add futures in waveform extractor

---
 src/spikeinterface/core/waveform_extractor.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py
index c51c9ba4f4..1218a7f281 100644
--- a/src/spikeinterface/core/waveform_extractor.py
+++ b/src/spikeinterface/core/waveform_extractor.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import math
 import pickle
 from pathlib import Path
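
A closing note on the worker signatures (commentary on the series, not part
of the patches): joblib's `delayed(func)(a, b)` forwards several positional
arguments to the worker, while `ProcessPoolExecutor.map` calls the function
with a single element of the iterable per task. That is why
`partial_fit_one_channel` and `pca_metrics_one_unit` were rewritten above to
accept one `args` tuple and unpack it themselves. A minimal sketch of the
pattern with a toy worker (all names illustrative):

    from concurrent.futures import ProcessPoolExecutor

    def fit_one(args):
        # map() passes one item per call, so several arguments travel
        # bundled in a tuple and are unpacked inside the worker
        model_path, chunk = args
        return model_path, sum(chunk)

    if __name__ == "__main__":
        items = [("model_a.pkl", [1, 2, 3]), ("model_b.pkl", [4, 5, 6])]
        with ProcessPoolExecutor(max_workers=2) as executor:
            for path, total in executor.map(fit_one, items):
                print(path, total)

`executor.map(func, paths, chunks)` with one iterable per parameter would
work equally well; the tuple form fits naturally here because both call
sites already assemble their argument lists as `items`.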
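
And a note on the temporary-folder round trip in `_fit_by_channel_local`
(again commentary, with one assumption flagged): worker processes receive
copies of their arguments, so an `IncrementalPCA` updated inside a worker
would be lost when that process exits. The patches therefore pickle each
model to a file, let the worker partially fit it, and reload every model in
the parent before removing the folder. A condensed sketch of that round
trip; the dump-back step at the end of the worker is assumed from context,
since the hunks above cut off before the tail of `partial_fit_one_channel`:

    import pickle
    import tempfile
    from pathlib import Path
    from concurrent.futures import ProcessPoolExecutor

    import numpy as np
    from sklearn.decomposition import IncrementalPCA

    def partial_fit_one(args):
        model_file, data = args
        with open(model_file, "rb") as f:
            model = pickle.load(f)
        model.partial_fit(data)
        with open(model_file, "wb") as f:
            pickle.dump(model, f)  # persist so the parent can reload it

    if __name__ == "__main__":
        tmp = Path(tempfile.mkdtemp(prefix="pca"))
        files = []
        rng = np.random.default_rng(0)
        for i in range(4):
            model_file = tmp / f"pca_{i}.pkl"
            model_file.write_bytes(pickle.dumps(IncrementalPCA(n_components=2)))
            files.append(model_file)
        items = [(f, rng.normal(size=(50, 10))) for f in files]
        with ProcessPoolExecutor(2) as executor:
            list(executor.map(partial_fit_one, items))
        models = [pickle.loads(f.read_bytes()) for f in files]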