Remove joblib in favor of ParallelProcessExecutor #2218

Merged: 9 commits, Nov 23, 2023
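Despite the "ParallelProcessExecutor" in the title, the replacement used throughout this PR is the standard library's concurrent.futures.ProcessPoolExecutor. Every file below applies the same swap; a minimal before/after sketch (the square worker and its inputs are illustrative, not from the PR):

    from concurrent.futures import ProcessPoolExecutor

    def square(x):
        # toy worker; must live at module top level so it can be pickled for the workers
        return x * x

    if __name__ == "__main__":
        items = [1, 2, 3, 4]

        # before (joblib):
        #     from joblib import Parallel, delayed
        #     results = Parallel(n_jobs=4)(delayed(square)(x) for x in items)

        # after (stdlib):
        with ProcessPoolExecutor(max_workers=4) as executor:
            results = list(executor.map(square, items))
        print(results)  # [1, 4, 9, 16]

Two behavioral differences drive most of the other edits in this diff: executor.map passes exactly one argument per task, and a worker's exception only surfaces when its result is consumed from the iterator. Both come up in the hunks below.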
pyproject.toml (1 change: 0 additions & 1 deletion)
@@ -22,7 +22,6 @@ classifiers = [
 dependencies = [
     "numpy",
     "neo>=0.12.0",
-    "joblib",
     "threadpoolctl",
     "tqdm",
     "probeinterface>=0.2.19",
src/spikeinterface/core/job_tools.py (22 changes: 1 addition & 21 deletions)
@@ -7,7 +7,6 @@
 import os
 import warnings

-import joblib
 import sys
 import contextlib
 from tqdm.auto import tqdm
@@ -95,25 +94,6 @@ def split_job_kwargs(mixed_kwargs):
     return specific_kwargs, job_kwargs


-# from https://stackoverflow.com/questions/24983493/tracking-progress-of-joblib-parallel-execution
-@contextlib.contextmanager
-def tqdm_joblib(tqdm_object):
-    """Context manager to patch joblib to report into tqdm progress bar given as argument"""
-
-    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
-        def __call__(self, *args, **kwargs):
-            tqdm_object.update(n=self.batch_size)
-            return super().__call__(*args, **kwargs)
-
-    old_batch_callback = joblib.parallel.BatchCompletionCallBack
-    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
-    try:
-        yield tqdm_object
-    finally:
-        joblib.parallel.BatchCompletionCallBack = old_batch_callback
-        tqdm_object.close()
-
-
 def divide_segment_into_chunks(num_frames, chunk_size):
     if chunk_size is None:
         chunks = [(0, num_frames)]
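The deleted tqdm_joblib helper existed only to monkey-patch joblib's BatchCompletionCallBack so tqdm could count finished batches. With ProcessPoolExecutor no patching is needed: executor.map yields results lazily and in input order, so the result iterator can be wrapped in tqdm directly. A minimal sketch, assuming a picklable top-level work function:

    from concurrent.futures import ProcessPoolExecutor
    from tqdm.auto import tqdm

    def work(x):
        return x * x

    if __name__ == "__main__":
        items = list(range(100))
        with ProcessPoolExecutor(max_workers=4) as executor:
            # executor.map yields lazily, so wrapping the iterator drives the bar
            results = list(tqdm(executor.map(work, items), total=len(items)))

This is exactly the pattern the PR uses in pca_metrics.py further down.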
@@ -156,7 +136,7 @@ def _mem_to_int(mem):

 def ensure_n_jobs(recording, n_jobs=1):
     if n_jobs == -1:
-        n_jobs = joblib.cpu_count()
+        n_jobs = os.cpu_count()
     elif n_jobs == 0:
         n_jobs = 1
     elif n_jobs is None:
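One nuance this one-line swap inherits: os.cpu_count() can return None on platforms where the count is undetermined, and unlike joblib's cpu_count() (which, via loky, tries to account for CPU affinity and container quotas) it reports all logical CPUs. A defensive variant, purely a sketch (ensure_n_jobs_sketch is hypothetical, not the library function):

    import os

    def ensure_n_jobs_sketch(n_jobs=1):
        # -1 means "use all CPUs"; os.cpu_count() may return None, so fall back to 1
        if n_jobs == -1:
            n_jobs = os.cpu_count() or 1
        elif n_jobs == 0 or n_jobs is None:
            n_jobs = 1
        return int(n_jobs)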
src/spikeinterface/core/waveform_extractor.py (2 changes: 2 additions & 0 deletions)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import math
 import pickle
 from pathlib import Path
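The only change to this file is the postponed-annotations import, presumably so the module can use newer type-hint syntax without evaluating it at runtime on older interpreters. A tiny illustration (the load function is hypothetical, not from the file):

    from __future__ import annotations  # must appear at the top, before other imports
    from pathlib import Path

    # without the future import, evaluating "str | Path" at def time
    # raises TypeError on Python < 3.10
    def load(folder: str | Path) -> Path:
        return Path(folder)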
src/spikeinterface/postprocessing/principal_component.py (24 changes: 16 additions & 8 deletions)
@@ -1,6 +1,7 @@
 import shutil
 import pickle
 import warnings
+import tempfile
 from pathlib import Path
 from tqdm.auto import tqdm

@@ -9,6 +10,7 @@

 from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, fix_job_kwargs
 from spikeinterface.core.waveform_extractor import WaveformExtractor, BaseWaveformExtractorExtension
+from spikeinterface.core.globals import get_global_tmp_folder

 _possible_modes = ["by_channel_local", "by_channel_global", "concatenated"]

@@ -370,7 +372,7 @@ def run_for_all_spikes(self, file_path=None, **job_kwargs):

     def _fit_by_channel_local(self, n_jobs, progress_bar):
         from sklearn.decomposition import IncrementalPCA
-        from joblib import delayed, Parallel
+        from concurrent.futures import ProcessPoolExecutor

         we = self.waveform_extractor
         p = self._params
@@ -385,12 +387,13 @@ def _fit_by_channel_local(self, n_jobs, progress_bar):

         tmp_folder = p["tmp_folder"]
         if tmp_folder is None:
-            tmp_folder = "tmp"
-        tmp_folder = Path(tmp_folder)
+            if n_jobs > 1:
+                tmp_folder = tempfile.mkdtemp(prefix="pca", dir=get_global_tmp_folder())

         for chan_ind, chan_id in enumerate(channel_ids):
             pca_model = pca_models[chan_ind]
             if n_jobs > 1:
+                tmp_folder = Path(tmp_folder)
                 tmp_folder.mkdir(exist_ok=True)
                 pca_model_file = tmp_folder / f"tmp_pca_model_{mode}_{chan_id}.pkl"
                 with pca_model_file.open("wb") as f:
@@ -411,10 +414,14 @@ def _fit_by_channel_local(self, n_jobs, progress_bar):
                     pca = pca_models[chan_ind]
                     pca.partial_fit(wfs[:, :, wf_ind])
             else:
-                Parallel(n_jobs=n_jobs)(
-                    delayed(partial_fit_one_channel)(pca_model_files[chan_ind], wfs[:, :, wf_ind])
-                    for wf_ind, chan_ind in enumerate(channel_inds)
-                )
+                # parallel
+                items = [(pca_model_files[chan_ind], wfs[:, :, wf_ind]) for wf_ind, chan_ind in enumerate(channel_inds)]
+                n_jobs = min(n_jobs, len(items))
+
+                with ProcessPoolExecutor(max_workers=n_jobs) as executor:
+                    results = executor.map(partial_fit_one_channel, items)
+                    for res in results:
+                        pass

         # reload the models (if n_jobs > 1)
         if n_jobs not in (0, 1):
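The "for res in results: pass" drain loop above is deliberate: executor.map submits tasks eagerly but yields results lazily, and an exception raised inside a worker only propagates when its result is consumed. Draining the iterator turns a failed partial_fit into a visible error instead of a silent no-op. A condensed illustration (fit_one is a hypothetical stand-in for partial_fit_one_channel):

    from concurrent.futures import ProcessPoolExecutor

    def fit_one(item):
        if item < 0:
            raise ValueError(f"bad item: {item}")
        return item

    if __name__ == "__main__":
        with ProcessPoolExecutor(max_workers=2) as executor:
            results = executor.map(fit_one, [1, 2, -3])
            try:
                for res in results:  # consuming the iterator re-raises worker errors here
                    pass
            except ValueError as err:
                print(f"surfaced from worker: {err}")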
@@ -762,7 +769,8 @@ def compute_principal_components(
     return pc


-def partial_fit_one_channel(pca_file, wf_chan):
+def partial_fit_one_channel(args):
+    pca_file, wf_chan = args
     with open(pca_file, "rb") as fid:
         pca_model = pickle.load(fid)
     pca_model.partial_fit(wf_chan)
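This signature change is forced by the executor API: delayed(f)(a, b) let joblib pass several positional arguments per call, while executor.map(f, items) calls f with exactly one element of items, hence the packing of (pca_file, wf_chan) into a tuple. A sketch of the pattern (names hypothetical):

    from concurrent.futures import ProcessPoolExecutor

    def fit(args):
        # unpack the single mapped argument, as partial_fit_one_channel now does
        path, data = args
        return path, sum(data)

    if __name__ == "__main__":
        items = [("a.pkl", [1, 2]), ("b.pkl", [3, 4])]
        with ProcessPoolExecutor(max_workers=2) as executor:
            print(list(executor.map(fit, items)))  # [("a.pkl", 3), ("b.pkl", 7)]

An alternative would be executor.map(fit, paths, chunks) with parallel iterables; the tuple form keeps the serial fallback (calling the function on items[i] directly) identical to the parallel path.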
@@ -204,5 +204,5 @@ def test_project_new(self):
     # test.test_extension()
     # test.test_shapes()
     # test.test_compute_for_all_spikes()
-    test.test_sparse()
-    # test.test_project_new()
+    # test.test_sparse()
+    test.test_project_new()
src/spikeinterface/qualitymetrics/pca_metrics.py (54 changes: 31 additions & 23 deletions)
@@ -4,19 +4,18 @@

 import numpy as np
 from tqdm.auto import tqdm
+from concurrent.futures import ProcessPoolExecutor

 try:
     import scipy.stats
     import scipy.spatial.distance
     from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
     from sklearn.neighbors import NearestNeighbors
     from sklearn.decomposition import IncrementalPCA
-    from joblib import delayed, Parallel
 except:
     pass

 from ..core import get_random_data_chunks, compute_sparsity, WaveformExtractor
-from ..core.job_tools import tqdm_joblib
 from ..core.template_tools import get_template_extremum_channel

 from ..postprocessing import WaveformPrincipalComponent
@@ -25,7 +24,6 @@
 from .misc_metrics import compute_num_spikes, compute_firing_rates

 from ..core import get_random_data_chunks, load_waveforms, compute_sparsity, WaveformExtractor
-from ..core.job_tools import tqdm_joblib
 from ..core.template_tools import get_template_extremum_channel
 from ..postprocessing import WaveformPrincipalComponent

@@ -134,7 +132,9 @@ def calculate_pc_metrics(
-    parallel_functions = []

     all_labels, all_pcs = pca.get_all_projections()
-    for unit_ind, unit_id in units_loop:
+
+    items = []
+    for unit_id in unit_ids:
         if we.is_sparse():
             neighbor_channel_ids = we.sparsity.unit_id_to_channel_ids[unit_id]
             neighbor_unit_ids = [
@@ -166,26 +166,23 @@
             n_spikes_all_units,
             fr_all_units,
         )
+        items.append(func_args)

-        if not run_in_parallel:
-            pca_metrics_unit = pca_metrics_one_unit(*func_args)
+    if not run_in_parallel:
+        for unit_ind, unit_id in units_loop:
+            pca_metrics_unit = pca_metrics_one_unit(items[unit_ind])
             for metric_name, metric in pca_metrics_unit.items():
                 pc_metrics[metric_name][unit_id] = metric
-        else:
-            parallel_functions.append(delayed(pca_metrics_one_unit)(*func_args))
-
-    if run_in_parallel:
-        if progress_bar:
-            units_loop = tqdm(units_loop, desc="Computing PCA metrics", total=len(unit_ids))
-            with tqdm_joblib(units_loop) as pb:
-                pc_metrics_units = Parallel(n_jobs=n_jobs)(parallel_functions)
-        else:
-            pc_metrics_units = Parallel(n_jobs=n_jobs)(parallel_functions)
+    else:
+        with ProcessPoolExecutor(n_jobs) as executor:
+            results = executor.map(pca_metrics_one_unit, items)
+            if progress_bar:
+                results = tqdm(results, total=len(unit_ids))

-        for ui, pca_metrics_unit in enumerate(pc_metrics_units):
-            unit_id = unit_ids[ui]
-            for metric_name, metric in pca_metrics_unit.items():
-                pc_metrics[metric_name][unit_id] = metric
+            for ui, pca_metrics_unit in enumerate(results):
+                unit_id = unit_ids[ui]
+                for metric_name, metric in pca_metrics_unit.items():
+                    pc_metrics[metric_name][unit_id] = metric

     return pc_metrics

@@ -888,9 +885,20 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int):
     return isolation


-def pca_metrics_one_unit(
-    pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, seed, we_folder, n_spikes_all_units, fr_all_units
-):
+def pca_metrics_one_unit(args):
+    (
+        pcs_flat,
+        labels,
+        metric_names,
+        unit_id,
+        unit_ids,
+        qm_params,
+        seed,
+        we_folder,
+        n_spikes_all_units,
+        fr_all_units,
+    ) = args
+
     if "nn_isolation" in metric_names or "nn_noise_overlap" in metric_names:
         we = load_waveforms(we_folder)
Expand Down