Skip to content

Commit

Permalink
Add mp_context and max_threads_per_process to pca metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Sep 23, 2024
1 parent 59bb1e7 commit 5a02a26
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 36 deletions.
60 changes: 27 additions & 33 deletions src/spikeinterface/qualitymetrics/pca_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@

from __future__ import annotations


import warnings
from copy import deepcopy

import numpy as np
import platform
from tqdm.auto import tqdm
from concurrent.futures import ProcessPoolExecutor

import numpy as np

import warnings
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
from threadpoolctl import threadpool_limits

from .misc_metrics import compute_num_spikes, compute_firing_rates

Expand Down Expand Up @@ -56,6 +57,8 @@ def compute_pc_metrics(
seed=None,
n_jobs=1,
progress_bar=False,
mp_context=None,
max_threads_per_process=None,
) -> dict:
"""
Calculate principal component derived metrics.
Expand Down Expand Up @@ -144,17 +147,7 @@ def compute_pc_metrics(
pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices]
pcs_flat = pcs.reshape(pcs.shape[0], -1)

func_args = (
pcs_flat,
labels,
non_nn_metrics,
unit_id,
unit_ids,
qm_params,
seed,
n_spikes_all_units,
fr_all_units,
)
func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, qm_params, max_threads_per_process)
items.append(func_args)

if not run_in_parallel and non_nn_metrics:
Expand All @@ -167,7 +160,15 @@ def compute_pc_metrics(
for metric_name, metric in pca_metrics_unit.items():
pc_metrics[metric_name][unit_id] = metric
elif run_in_parallel and non_nn_metrics:
with ProcessPoolExecutor(n_jobs) as executor:
if mp_context is not None and platform.system() == "Windows":
assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
elif mp_context == "fork" and platform.system() == "Darwin":
warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS')

with ProcessPoolExecutor(
max_workers=n_jobs,
mp_context=mp.get_context(mp_context),
) as executor:
results = executor.map(pca_metrics_one_unit, items)
if progress_bar:
results = tqdm(results, total=len(unit_ids), desc="calculate_pc_metrics")
Expand Down Expand Up @@ -976,26 +977,19 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int):


def pca_metrics_one_unit(args):
(
pcs_flat,
labels,
metric_names,
unit_id,
unit_ids,
qm_params,
seed,
# we_folder,
n_spikes_all_units,
fr_all_units,
) = args

# if "nn_isolation" in metric_names or "nn_noise_overlap" in metric_names:
# we = load_waveforms(we_folder)
(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, max_threads_per_process) = args

if max_threads_per_process is None:
return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params)
else:
with threadpool_limits(limits=int(max_threads_per_process)):
return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params)


def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params):
pc_metrics = {}
# metrics
if "isolation_distance" in metric_names or "l_ratio" in metric_names:

try:
isolation_distance, l_ratio = mahalanobis_metrics(pcs_flat, labels, unit_id)
except:
Expand Down
25 changes: 22 additions & 3 deletions src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import pytest
import numpy as np

from spikeinterface.qualitymetrics import (
compute_pc_metrics,
)
from spikeinterface.qualitymetrics import compute_pc_metrics, get_quality_pca_metric_list


def test_calculate_pc_metrics(small_sorting_analyzer):
Expand All @@ -22,3 +20,24 @@ def test_calculate_pc_metrics(small_sorting_analyzer):
assert not np.all(np.isnan(res2[metric_name].values))

assert np.array_equal(res1[metric_name].values, res2[metric_name].values)


def test_pca_metrics_multi_processing(small_sorting_analyzer):
sorting_analyzer = small_sorting_analyzer

metric_names = get_quality_pca_metric_list()
metric_names.remove("nn_isolation")
metric_names.remove("nn_noise_overlap")

print(f"Computing PCA metrics with 1 thread per process")
res1 = compute_pc_metrics(
sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=1, progress_bar=True
)
print(f"Computing PCA metrics with 2 thread per process")
res2 = compute_pc_metrics(
sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=2, progress_bar=True
)
print("Computing PCA metrics with spawn context")
res2 = compute_pc_metrics(
sorting_analyzer, n_jobs=-1, metric_names=metric_names, max_threads_per_process=2, progress_bar=True
)

0 comments on commit 5a02a26

Please sign in to comment.