From 59bb1e747db7b2bc6879720f27ec83e4ce66df31 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 23 Sep 2024 17:21:17 +0200 Subject: [PATCH] Add mp_context check --- src/spikeinterface/postprocessing/principal_component.py | 6 ++++++ .../postprocessing/tests/test_principal_component.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index ff1801c1b0..a713070982 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -1,6 +1,7 @@ from __future__ import annotations import warnings +import platform from pathlib import Path from tqdm.auto import tqdm @@ -418,6 +419,11 @@ def _fit_by_channel_local(self, n_jobs, progress_bar, max_threads_per_process, m p = self.params + if mp_context is not None and platform.system() == "Windows": + assert mp_context != "fork", "'fork' mp_context not supported on Windows!" + elif mp_context == "fork" and platform.system() == "Darwin": + warnings.warn('As of Python 3.8 "fork" is no longer considered safe on macOS') + unit_ids = self.sorting_analyzer.unit_ids channel_ids = self.sorting_analyzer.channel_ids # there is one PCA per channel for independent fit per channel diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 328b72f72c..7a509c410f 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -25,7 +25,7 @@ def test_multi_processing(self): sorting_analyzer = self._prepare_sorting_analyzer( format="memory", sparse=False, extension_class=ComputePrincipalComponents ) - sorting_analyzer.compute("principal_components", mode="by_channel_local", n_jobs=2, mp_context="fork") + sorting_analyzer.compute("principal_components", mode="by_channel_local", n_jobs=2) sorting_analyzer.compute( "principal_components", mode="by_channel_local", n_jobs=2, max_threads_per_process=4, mp_context="spawn" )