diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index e16bd9ad27..45ba55dee4 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -8,11 +8,9 @@ import numpy as np import warnings -from typing import Optional from copy import deepcopy from ..core.sortinganalyzer import register_result_extension, AnalyzerExtension -from ..core import ChannelSparsity from ..core.template_tools import get_template_extremum_channel from ..core.template_tools import get_dense_templates_array @@ -238,13 +236,17 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job for metric_name in metrics_single_channel: func = _metric_name_to_func[metric_name] - value = func( - template_upsampled, - sampling_frequency=sampling_frequency_up, - trough_idx=trough_idx, - peak_idx=peak_idx, - **self.params["metrics_kwargs"], - ) + try: + value = func( + template_upsampled, + sampling_frequency=sampling_frequency_up, + trough_idx=trough_idx, + peak_idx=peak_idx, + **self.params["metrics_kwargs"], + ) + except Exception as e: + warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") + value = np.nan template_metrics.at[index, metric_name] = value # compute metrics multi_channel @@ -274,12 +276,16 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job sampling_frequency_up = sampling_frequency func = _metric_name_to_func[metric_name] - value = func( - template_upsampled, - channel_locations=channel_locations_sparse, - sampling_frequency=sampling_frequency_up, - **self.params["metrics_kwargs"], - ) + try: + value = func( + template_upsampled, + channel_locations=channel_locations_sparse, + sampling_frequency=sampling_frequency_up, + **self.params["metrics_kwargs"], + ) + except Exception as e: + warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") + value = np.nan template_metrics.at[index, metric_name] = value return template_metrics diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 0c7cf25237..cdf6151e95 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -164,7 +164,10 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job pc_metric_names = [k for k in metric_names if k in _possible_pc_metric_names] if len(pc_metric_names) > 0 and not self.params["skip_pc_metrics"]: if not sorting_analyzer.has_extension("principal_components"): - raise ValueError("waveform_principal_component must be provied") + raise ValueError( + "To compute principal components base metrics, the principal components " + "extension must be computed first." + ) pc_metrics = compute_pc_metrics( sorting_analyzer, unit_ids=non_empty_unit_ids,