diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index d35b562298..4c79383542 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -298,8 +298,8 @@ def find_merge_pairs( indices0, indices1 = np.nonzero(pair_mask) n_jobs = job_kwargs["n_jobs"] - mp_context = job_kwargs["mp_context"] - max_threads_per_process = job_kwargs["max_threads_per_process"] + mp_context = job_kwargs.get("mp_context", None) + max_threads_per_process = job_kwargs.get("max_threads_per_process", 1) progress_bar = job_kwargs["progress_bar"] Executor = get_poolexecutor(n_jobs) diff --git a/src/spikeinterface/sortingcomponents/clustering/split.py b/src/spikeinterface/sortingcomponents/clustering/split.py index a31e7d62fc..48ec26679e 100644 --- a/src/spikeinterface/sortingcomponents/clustering/split.py +++ b/src/spikeinterface/sortingcomponents/clustering/split.py @@ -61,7 +61,7 @@ def split_clusters( n_jobs = job_kwargs["n_jobs"] mp_context = job_kwargs.get("mp_context", None) progress_bar = job_kwargs["progress_bar"] - max_threads_per_process = job_kwargs["max_threads_per_process"] + max_threads_per_process = job_kwargs.get("max_threads_per_process", 1) original_labels = peak_labels peak_labels = peak_labels.copy() diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index c7b701c8b0..bc44e58a33 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -224,7 +224,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): metrics_sv = [] for col in metric_names: - dtype = metrics.iloc[0][col].dtype + dtype = np.array(metrics.iloc[0][col]).dtype metric = vv.UnitMetricsGraphMetric(key=col, label=col, dtype=dtype.str) metrics_sv.append(metric) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 1a2fdf38d9..f60346ade0 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -376,9 +376,9 @@ def test_plot_rasters(self): # mytest.test_plot_unit_summary() # mytest.test_unit_locations() # mytest.test_quality_metrics() - # mytest.test_template_metrics() + mytest.test_template_metrics() # mytest.test_amplitudes() - mytest.test_plot_agreement_matrix() + # mytest.test_plot_agreement_matrix() # mytest.test_plot_confusion_matrix() # mytest.test_plot_probe_map() # mytest.test_plot_rasters()