From 162468c6613f1f89419e0571196c395c2eaa072b Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Tue, 4 Jun 2024 17:29:17 +0100
Subject: [PATCH 1/4] Add test to check unit structure in qm output

---
 .../qualitymetrics/misc_metrics.py           |  4 +-
 .../tests/test_metrics_functions.py          | 55 +++++++++++++++++++
 2 files changed, 56 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py
index f1082386cc..fca521c90f 100644
--- a/src/spikeinterface/qualitymetrics/misc_metrics.py
+++ b/src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -388,9 +388,7 @@ def compute_refrac_period_violations(
     nb_violations = {}
     rp_contamination = {}

-    for i, unit_id in enumerate(sorting.unit_ids):
-        if unit_id not in unit_ids:
-            continue
+    for i, unit_id in enumerate(unit_ids):

         nb_violations[unit_id] = n_v = nb_rp_violations[i]
         N = num_spikes[unit_id]
diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
index 5a7d43cbae..d0c2481741 100644
--- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
+++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
@@ -12,8 +12,12 @@

 from spikeinterface.qualitymetrics.utils import create_ground_truth_pc_distributions

+from spikeinterface.qualitymetrics.quality_metric_list import (
+    _misc_metric_name_to_func,
+)

 from spikeinterface.qualitymetrics import (
+    get_quality_metric_list,
     mahalanobis_metrics,
     lda_metrics,
     nearest_neighbors_metrics,
@@ -47,6 +51,57 @@
 job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")


+def _small_sorting_analyzer():
+    recording, sorting = generate_ground_truth_recording(
+        durations=[2.0],
+        num_units=4,
+        seed=1205,
+    )
+
+    sorting = sorting.select_units([3, 2, 0], ["#3", "#9", "#4"])
+
+    sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")
+
+    extensions_to_compute = {
+        "random_spikes": {"seed": 1205},
+        "noise_levels": {"seed": 1205},
+        "waveforms": {},
+        "templates": {},
+        "spike_amplitudes": {},
+        "principal_components": {},
+    }
+
+    sorting_analyzer.compute(extensions_to_compute)
+
+    return sorting_analyzer
+
+
+@pytest.fixture(scope="module")
+def small_sorting_analyzer():
+    return _small_sorting_analyzer()
+
+
+def test_unit_structure_in_output(small_sorting_analyzer):
+    for metric_name in get_quality_metric_list():
+        result = _misc_metric_name_to_func[metric_name](sorting_analyzer=small_sorting_analyzer)
+
+        if isinstance(result, dict):
+            assert list(result.keys()) == ["#3", "#9", "#4"]
+        else:
+            for one_result in result:
+                assert list(one_result.keys()) == ["#3", "#9", "#4"]
+
+    for metric_name in get_quality_metric_list():
+        result = _misc_metric_name_to_func[metric_name](sorting_analyzer=small_sorting_analyzer, unit_ids=["#9", "#3"])
+
+        if isinstance(result, dict):
+            assert list(result.keys()) == ["#9", "#3"]
+        else:
+            for one_result in result:
+                print(metric_name)
+                assert list(one_result.keys()) == ["#9", "#3"]
+
+
 def _sorting_analyzer_simple():
     recording, sorting = generate_ground_truth_recording(
         durations=[

From 11ae9aa6227e01fd19908d510d142651b8da1c1f Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Thu, 13 Jun 2024 09:59:00 +0100
Subject: [PATCH 2/4] Add test to check unit_id order independence

---
 .../tests/test_metrics_functions.py          | 52 +++++++++++++++++++
 1 file changed, 52 insertions(+)

diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
index d0c2481741..6652ea6654 100644
--- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
+++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
@@ -38,6 +38,7 @@
     compute_amplitude_cv_metrics,
     compute_sd_ratio,
     get_synchrony_counts,
+    compute_quality_metrics,
 )

 from spikeinterface.core.basesorting import minimum_spike_dtype
@@ -68,6 +69,7 @@ def _small_sorting_analyzer():
         "waveforms": {},
         "templates": {},
         "spike_amplitudes": {},
+        "spike_locations": {},
         "principal_components": {},
     }

@@ -102,6 +104,56 @@ def test_unit_structure_in_output(small_sorting_analyzer):
                 assert list(one_result.keys()) == ["#9", "#3"]


+def test_unit_id_order_independence(small_sorting_analyzer):
+    """
+    Takes two almost-identical sorting_analyzers, whose unit_ids are in different orders and have different labels,
+    and checks that their calculated quality metrics are independent of the ordering and labelling.
+    """
+
+    recording, sorting = generate_ground_truth_recording(
+        durations=[2.0],
+        num_units=4,
+        seed=1205,
+    )
+    sorting = sorting.select_units([0, 2, 3])
+    small_sorting_analyzer_2 = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")
+
+    extensions_to_compute = {
+        "random_spikes": {"seed": 1205},
+        "noise_levels": {"seed": 1205},
+        "waveforms": {},
+        "templates": {},
+        "spike_amplitudes": {},
+        "spike_locations": {},
+        "principal_components": {},
+    }
+
+    small_sorting_analyzer_2.compute(extensions_to_compute)
+
+    # need special params to get non-nan results on a short recording
+    qm_params = {
+        "presence_ratio": {"bin_duration_s": 0.1},
+        "amplitude_cutoff": {"num_histogram_bins": 3},
+        "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3},
+        "firing_range": {"bin_size_s": 1},
+        "isi_violation": {"isi_threshold_ms": 10},
+        "drift": {"interval_s": 1, "min_spikes_per_interval": 5},
+        "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15},
+    }
+
+    quality_metrics_1 = compute_quality_metrics(
+        small_sorting_analyzer, metric_names=get_quality_metric_list(), qm_params=qm_params
+    )
+    quality_metrics_2 = compute_quality_metrics(
+        small_sorting_analyzer_2, metric_names=get_quality_metric_list(), qm_params=qm_params
+    )
+
+    for metric, metric_1_data in quality_metrics_1.items():
+        assert quality_metrics_2[metric][3] == metric_1_data["#3"]
+        assert quality_metrics_2[metric][2] == metric_1_data["#9"]
+        assert quality_metrics_2[metric][0] == metric_1_data["#4"]
+
+
 def _sorting_analyzer_simple():
     recording, sorting = generate_ground_truth_recording(
         durations=[

From 6529bd3983a7e9ed01ae7f74e9242d60ea066ea2 Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Tue, 2 Jul 2024 11:56:23 +0100
Subject: [PATCH 3/4] Update test and fix the failings: rp_violation and drift

---
 .../qualitymetrics/misc_metrics.py           | 12 ++--
 .../tests/test_metrics_functions.py          | 55 ++++++++++++-------
 2 files changed, 43 insertions(+), 24 deletions(-)

diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py
index fca521c90f..79e6d1c8e3 100644
--- a/src/spikeinterface/qualitymetrics/misc_metrics.py
+++ b/src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -388,9 +388,11 @@ def compute_refrac_period_violations(
     nb_violations = {}
     rp_contamination = {}

-    for i, unit_id in enumerate(unit_ids):
+    for unit_index, unit_id in enumerate(sorting.unit_ids):
+        if unit_id not in unit_ids:
+            continue

-        nb_violations[unit_id] = n_v = nb_rp_violations[i]
+        nb_violations[unit_id] = n_v = nb_rp_violations[unit_index]
         N = num_spikes[unit_id]
         if N == 0:
             rp_contamination[unit_id] = np.nan
@@ -1083,10 +1085,10 @@ def compute_drift_metrics(
             spikes_in_bin = spikes_in_segment[i0:i1]
             spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction]

-            for unit_ind in np.arange(len(unit_ids)):
-                mask = spikes_in_bin["unit_index"] == unit_ind
+            for unit_index, unit_id in enumerate(unit_ids):
+                mask = spikes_in_bin["unit_index"] == sorting.id_to_index(unit_id)
                 if np.sum(mask) >= min_spikes_per_interval:
-                    median_positions[unit_ind, bin_index] = np.median(spike_locations_in_bin[mask])
+                    median_positions[unit_index, bin_index] = np.median(spike_locations_in_bin[mask])
             if median_position_segments is None:
                 median_position_segments = median_positions
             else:
diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
index 6652ea6654..9175748a4a 100644
--- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
+++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
@@ -84,24 +84,44 @@ def small_sorting_analyzer():


 def test_unit_structure_in_output(small_sorting_analyzer):
-    for metric_name in get_quality_metric_list():
-        result = _misc_metric_name_to_func[metric_name](sorting_analyzer=small_sorting_analyzer)

-        if isinstance(result, dict):
-            assert list(result.keys()) == ["#3", "#9", "#4"]
-        else:
-            for one_result in result:
-                assert list(one_result.keys()) == ["#3", "#9", "#4"]
+    qm_params = {
+        "presence_ratio": {"bin_duration_s": 0.1},
+        "amplitude_cutoff": {"num_histogram_bins": 3},
+        "amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3},
+        "firing_range": {"bin_size_s": 1},
+        "isi_violation": {"isi_threshold_ms": 10},
+        "drift": {"interval_s": 1, "min_spikes_per_interval": 5},
+        "sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15},
+        "rp_violation": {"refractory_period_ms": 10.0, "censored_period_ms": 0.0},
+    }

     for metric_name in get_quality_metric_list():
-        result = _misc_metric_name_to_func[metric_name](sorting_analyzer=small_sorting_analyzer, unit_ids=["#9", "#3"])

-        if isinstance(result, dict):
-            assert list(result.keys()) == ["#9", "#3"]
+        try:
+            qm_param = qm_params[metric_name]
+        except:
+            qm_param = {}
+
+        result_all = _misc_metric_name_to_func[metric_name](sorting_analyzer=small_sorting_analyzer, **qm_param)
+        result_sub = _misc_metric_name_to_func[metric_name](
+            sorting_analyzer=small_sorting_analyzer, unit_ids=["#4", "#9"], **qm_param
+        )
+
+        if isinstance(result_all, dict):
+            assert list(result_all.keys()) == ["#3", "#9", "#4"]
+            assert list(result_sub.keys()) == ["#4", "#9"]
+            assert result_sub["#9"] == result_all["#9"]
+            assert result_sub["#4"] == result_all["#4"]
+
         else:
-            for one_result in result:
-                print(metric_name)
-                assert list(one_result.keys()) == ["#9", "#3"]
+            for result_ind, result in enumerate(result_sub):
+
+                assert list(result_all[result_ind].keys()) == ["#3", "#9", "#4"]
+                assert result_sub[result_ind].keys() == set(["#4", "#9"])
+
+                assert result_sub[result_ind]["#9"] == result_all[result_ind]["#9"]
+                assert result_sub[result_ind]["#4"] == result_all[result_ind]["#4"]


 def test_unit_id_order_independence(small_sorting_analyzer):
@@ -110,12 +130,9 @@ def test_unit_id_order_independence(small_sorting_analyzer):
     and checks that their calculated quality metrics are independent of the ordering and labelling.
     """

-    recording, sorting = generate_ground_truth_recording(
-        durations=[2.0],
-        num_units=4,
-        seed=1205,
-    )
-    sorting = sorting.select_units([0, 2, 3])
+    recording = small_sorting_analyzer.recording
+    sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [0, 2, 3])
+
     small_sorting_analyzer_2 = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")

     extensions_to_compute = {

From aa314cdba68d583a4bc551b3e957efd809d0c919 Mon Sep 17 00:00:00 2001
From: Garcia Samuel
Date: Wed, 3 Jul 2024 10:50:14 +0200
Subject: [PATCH 4/4] not use unit_index

Co-authored-by: Alessio Buccino
---
 src/spikeinterface/qualitymetrics/misc_metrics.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py
index 04c604edcb..a7d51ad330 100644
--- a/src/spikeinterface/qualitymetrics/misc_metrics.py
+++ b/src/spikeinterface/qualitymetrics/misc_metrics.py
@@ -1085,10 +1085,10 @@ def compute_drift_metrics(
             spikes_in_bin = spikes_in_segment[i0:i1]
             spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction]

-            for unit_index, unit_id in enumerate(unit_ids):
+            for i, unit_id in enumerate(unit_ids):
                 mask = spikes_in_bin["unit_index"] == sorting.id_to_index(unit_id)
                 if np.sum(mask) >= min_spikes_per_interval:
-                    median_positions[unit_index, bin_index] = np.median(spike_locations_in_bin[mask])
+                    median_positions[i, bin_index] = np.median(spike_locations_in_bin[mask])
             if median_position_segments is None:
                 median_position_segments = median_positions
             else: