Skip to content

Commit

Permalink
Merge pull request #2973 from chrishalcrow/add-tests-for-qm-structure
Browse files Browse the repository at this point in the history
Add test to check unit structure in quality metric calculator output
  • Loading branch information
alejoe91 authored Jul 3, 2024
2 parents 5a7d890 + aa314cd commit 211c222
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,11 +388,11 @@ def compute_refrac_period_violations(
nb_violations = {}
rp_contamination = {}

for i, unit_id in enumerate(sorting.unit_ids):
for unit_index, unit_id in enumerate(sorting.unit_ids):
if unit_id not in unit_ids:
continue

nb_violations[unit_id] = n_v = nb_rp_violations[i]
nb_violations[unit_id] = n_v = nb_rp_violations[unit_index]
N = num_spikes[unit_id]
if N == 0:
rp_contamination[unit_id] = np.nan
Expand Down Expand Up @@ -1085,10 +1085,10 @@ def compute_drift_metrics(
spikes_in_bin = spikes_in_segment[i0:i1]
spike_locations_in_bin = spike_locations_in_segment[i0:i1][direction]

for unit_ind in np.arange(len(unit_ids)):
mask = spikes_in_bin["unit_index"] == unit_ind
for i, unit_id in enumerate(unit_ids):
mask = spikes_in_bin["unit_index"] == sorting.id_to_index(unit_id)
if np.sum(mask) >= min_spikes_per_interval:
median_positions[unit_ind, bin_index] = np.median(spike_locations_in_bin[mask])
median_positions[i, bin_index] = np.median(spike_locations_in_bin[mask])
if median_position_segments is None:
median_position_segments = median_positions
else:
Expand Down
124 changes: 124 additions & 0 deletions src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@

from spikeinterface.qualitymetrics.utils import create_ground_truth_pc_distributions

from spikeinterface.qualitymetrics.quality_metric_list import (
_misc_metric_name_to_func,
)

from spikeinterface.qualitymetrics import (
get_quality_metric_list,
mahalanobis_metrics,
lda_metrics,
nearest_neighbors_metrics,
Expand All @@ -34,6 +38,7 @@
compute_amplitude_cv_metrics,
compute_sd_ratio,
get_synchrony_counts,
compute_quality_metrics,
)

from spikeinterface.core.basesorting import minimum_spike_dtype
Expand All @@ -42,6 +47,125 @@
job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")


def _small_sorting_analyzer():
recording, sorting = generate_ground_truth_recording(
durations=[2.0],
num_units=4,
seed=1205,
)

sorting = sorting.select_units([3, 2, 0], ["#3", "#9", "#4"])

sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")

extensions_to_compute = {
"random_spikes": {"seed": 1205},
"noise_levels": {"seed": 1205},
"waveforms": {},
"templates": {},
"spike_amplitudes": {},
"spike_locations": {},
"principal_components": {},
}

sorting_analyzer.compute(extensions_to_compute)

return sorting_analyzer


@pytest.fixture(scope="module")
def small_sorting_analyzer():
return _small_sorting_analyzer()


def test_unit_structure_in_output(small_sorting_analyzer):

qm_params = {
"presence_ratio": {"bin_duration_s": 0.1},
"amplitude_cutoff": {"num_histogram_bins": 3},
"amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3},
"firing_range": {"bin_size_s": 1},
"isi_violation": {"isi_threshold_ms": 10},
"drift": {"interval_s": 1, "min_spikes_per_interval": 5},
"sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15},
"rp_violation": {"refractory_period_ms": 10.0, "censored_period_ms": 0.0},
}

for metric_name in get_quality_metric_list():

try:
qm_param = qm_params[metric_name]
except:
qm_param = {}

result_all = _misc_metric_name_to_func[metric_name](sorting_analyzer=small_sorting_analyzer, **qm_param)
result_sub = _misc_metric_name_to_func[metric_name](
sorting_analyzer=small_sorting_analyzer, unit_ids=["#4", "#9"], **qm_param
)

if isinstance(result_all, dict):
assert list(result_all.keys()) == ["#3", "#9", "#4"]
assert list(result_sub.keys()) == ["#4", "#9"]
assert result_sub["#9"] == result_all["#9"]
assert result_sub["#4"] == result_all["#4"]

else:
for result_ind, result in enumerate(result_sub):

assert list(result_all[result_ind].keys()) == ["#3", "#9", "#4"]
assert result_sub[result_ind].keys() == set(["#4", "#9"])

assert result_sub[result_ind]["#9"] == result_all[result_ind]["#9"]
assert result_sub[result_ind]["#4"] == result_all[result_ind]["#4"]


def test_unit_id_order_independence(small_sorting_analyzer):
"""
Takes two almost-identical sorting_analyzers, whose unit_ids are in different orders and have different labels,
and checks that their calculated quality metrics are independent of the ordering and labelling.
"""

recording = small_sorting_analyzer.recording
sorting = small_sorting_analyzer.sorting.select_units(["#4", "#9", "#3"], [0, 2, 3])

small_sorting_analyzer_2 = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory")

extensions_to_compute = {
"random_spikes": {"seed": 1205},
"noise_levels": {"seed": 1205},
"waveforms": {},
"templates": {},
"spike_amplitudes": {},
"spike_locations": {},
"principal_components": {},
}

small_sorting_analyzer_2.compute(extensions_to_compute)

# need special params to get non-nan results on a short recording
qm_params = {
"presence_ratio": {"bin_duration_s": 0.1},
"amplitude_cutoff": {"num_histogram_bins": 3},
"amplitude_cv": {"average_num_spikes_per_bin": 7, "min_num_bins": 3},
"firing_range": {"bin_size_s": 1},
"isi_violation": {"isi_threshold_ms": 10},
"drift": {"interval_s": 1, "min_spikes_per_interval": 5},
"sliding_rp_violation": {"max_ref_period_ms": 50, "bin_size_ms": 0.15},
}

quality_metrics_1 = compute_quality_metrics(
small_sorting_analyzer, metric_names=get_quality_metric_list(), qm_params=qm_params
)
quality_metrics_2 = compute_quality_metrics(
small_sorting_analyzer_2, metric_names=get_quality_metric_list(), qm_params=qm_params
)

for metric, metric_1_data in quality_metrics_1.items():
assert quality_metrics_2[metric][3] == metric_1_data["#3"]
assert quality_metrics_2[metric][2] == metric_1_data["#9"]
assert quality_metrics_2[metric][0] == metric_1_data["#4"]


def _sorting_analyzer_simple():
recording, sorting = generate_ground_truth_recording(
durations=[
Expand Down

0 comments on commit 211c222

Please sign in to comment.