diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index cbb55aeb8b..a7d51ad330 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -388,11 +388,11 @@ def compute_refrac_period_violations( nb_violations = {} rp_contamination = {} - for i, unit_id in enumerate(sorting.unit_ids): + for unit_index, unit_id in enumerate(sorting.unit_ids): if unit_id not in unit_ids: continue - nb_violations[unit_id] = n_v = nb_rp_violations[i] + nb_violations[unit_id] = n_v = nb_rp_violations[unit_index] N = num_spikes[unit_id] if N == 0: rp_contamination[unit_id] = np.nan @@ -1085,10 +1085,10 @@ def compute_drift_metrics( spikes_in_bin = spikes_in_segment[i0:i1] spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction] - for unit_ind in np.arange(len(unit_ids)): - mask = spikes_in_bin["unit_index"] == unit_ind + for i, unit_id in enumerate(unit_ids): + mask = spikes_in_bin["unit_index"] == sorting.id_to_index(unit_id) if np.sum(mask) >= min_spikes_per_interval: - median_positions[unit_ind, bin_index] = np.median(spike_locations_in_bin[mask]) + median_positions[i, bin_index] = np.median(spike_locations_in_bin[mask]) if median_position_segments is None: median_position_segments = median_positions else: diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 2d4eeb360b..aec8201f44 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -12,8 +12,12 @@ from spikeinterface.qualitymetrics.utils import create_ground_truth_pc_distributions +from spikeinterface.qualitymetrics.quality_metric_list import ( + _misc_metric_name_to_func, +) from spikeinterface.qualitymetrics import ( + get_quality_metric_list, mahalanobis_metrics, lda_metrics, nearest_neighbors_metrics, @@ -34,6 +38,7 @@ compute_amplitude_cv_metrics, compute_sd_ratio, get_synchrony_counts, + compute_quality_metrics, ) from spikeinterface.core.basesorting import minimum_spike_dtype @@ -42,6 +47,125 @@ job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") +def _small_sorting_analyzer(): + recording, sorting = generate_ground_truth_recording( + durations=[2.0], + num_units=4, + seed=1205, + ) + + sorting = sorting.select_units([3, 2, 0], ["#3", "#9", "#4"]) + + sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") + + extensions_to_compute = { + "random_spikes": {"seed": 1205}, + "noise_levels": {"seed": 1205}, + "waveforms": {}, + "templates": {}, + "spike_amplitudes": {}, + "spike_locations": {}, + "principal_components": {}, + } + + sorting_analyzer.compute(extensions_to_compute) + + return sorting_analyzer + + +@pytest.fixture(scope="module") +def small_sorting_analyzer(): + return _small_sorting_analyzer() + + +def test_unit_structure_in_output(small_sorting_analyzer): + + qm_params = { + "presence_ratio": {"bin_duration_s": 0.1}, + "amplitude_cutoff": {"num_histogram_bins": 3}, + "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3}, + "firing_range": {"bin_size_s": 1}, + "isi_violation": {"isi_threshold_ms": 10}, + "drift": {"interval_s": 1, "min_spikes_per_interval": 5}, + "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15}, + "rp_violation": {"refractory_period_ms": 10.0, "censored_period_ms": 0.0}, + } + + for metric_name in get_quality_metric_list(): + + try: + qm_param = qm_params[metric_name] + except: + qm_param = {} + + result_all = _misc_metric_name_to_func[metric_name](sorting_analyzer=small_sorting_analyzer, **qm_param) + result_sub = _misc_metric_name_to_func[metric_name]( + sorting_analyzer=small_sorting_analyzer, unit_ids=["#4", "#9"], **qm_param + ) + + if isinstance(result_all, dict): + assert list(result_all.keys()) == ["#3", "#9", "#4"] + assert list(result_sub.keys()) == ["#4", "#9"] + assert result_sub["#9"] == result_all["#9"] + assert result_sub["#4"] == result_all["#4"] + + else: + for result_ind, result in enumerate(result_sub): + + assert list(result_all[result_ind].keys()) == ["#3", "#9", "#4"] + assert result_sub[result_ind].keys() == set(["#4", "#9"]) + + assert result_sub[result_ind]["#9"] == result_all[result_ind]["#9"] + assert result_sub[result_ind]["#4"] == result_all[result_ind]["#4"] + + +def test_unit_id_order_independence(small_sorting_analyzer): + """ + Takes two almost-identical sorting_analyzers, whose unit_ids are in different orders and have different labels, + and checks that their calculated quality metrics are independent of the ordering and labelling. + """ + + recording = small_sorting_analyzer.recording + sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [0, 2, 3]) + + small_sorting_analyzer_2 = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") + + extensions_to_compute = { + "random_spikes": {"seed": 1205}, + "noise_levels": {"seed": 1205}, + "waveforms": {}, + "templates": {}, + "spike_amplitudes": {}, + "spike_locations": {}, + "principal_components": {}, + } + + small_sorting_analyzer_2.compute(extensions_to_compute) + + # need special params to get non-nan results on a short recording + qm_params = { + "presence_ratio": {"bin_duration_s": 0.1}, + "amplitude_cutoff": {"num_histogram_bins": 3}, + "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3}, + "firing_range": {"bin_size_s": 1}, + "isi_violation": {"isi_threshold_ms": 10}, + "drift": {"interval_s": 1, "min_spikes_per_interval": 5}, + "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15}, + } + + quality_metrics_1 = compute_quality_metrics( + small_sorting_analyzer, metric_names=get_quality_metric_list(), qm_params=qm_params + ) + quality_metrics_2 = compute_quality_metrics( + small_sorting_analyzer_2, metric_names=get_quality_metric_list(), qm_params=qm_params + ) + + for metric, metric_1_data in quality_metrics_1.items(): + assert quality_metrics_2[metric][3] == metric_1_data["#3"] + assert quality_metrics_2[metric][2] == metric_1_data["#9"] + assert quality_metrics_2[metric][0] == metric_1_data["#4"] + + def _sorting_analyzer_simple(): recording, sorting = generate_ground_truth_recording( durations=[