Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor quality metrics tests to use fixture #3249

Merged
merged 10 commits into from
Sep 12, 2024
40 changes: 37 additions & 3 deletions src/spikeinterface/qualitymetrics/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
create_sorting_analyzer,
)

job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")
chrishalcrow marked this conversation as resolved.
Show resolved Hide resolved

def _small_sorting_analyzer():

@pytest.fixture(scope="module")
def small_sorting_analyzer():
recording, sorting = generate_ground_truth_recording(
durations=[2.0],
num_units=10,
Expand All @@ -33,5 +36,36 @@ def _small_sorting_analyzer():


@pytest.fixture(scope="module")
def small_sorting_analyzer():
return _small_sorting_analyzer()
def sorting_analyzer_simple():
# we need high firing rate for amplitude_cutoff
recording, sorting = generate_ground_truth_recording(
durations=[
120.0,
],
sampling_frequency=30_000.0,
num_channels=6,
num_units=10,
generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
generate_unit_locations_kwargs=dict(
margin_um=5.0,
minimum_z=5.0,
maximum_z=20.0,
),
generate_templates_kwargs=dict(
unit_params=dict(
alpha=(200.0, 500.0),
)
),
noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
seed=1205,
)

sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True)

sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205)
sorting_analyzer.compute("noise_levels")
sorting_analyzer.compute("waveforms", **job_kwargs)
sorting_analyzer.compute("templates")
sorting_analyzer.compute("spike_amplitudes", **job_kwargs)

return sorting_analyzer
58 changes: 0 additions & 58 deletions src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,36 +135,6 @@ def test_unit_id_order_independence(small_sorting_analyzer):
assert quality_metrics_2[metric][1] == metric_1_data["#4"]


def _sorting_analyzer_simple():
recording, sorting = generate_ground_truth_recording(
durations=[
50.0,
],
sampling_frequency=30_000.0,
num_channels=6,
num_units=10,
generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0),
noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
seed=2205,
)

sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True)

sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=2205)
sorting_analyzer.compute("noise_levels")
sorting_analyzer.compute("waveforms", **job_kwargs)
sorting_analyzer.compute("templates")
sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs)
sorting_analyzer.compute("spike_amplitudes", **job_kwargs)

return sorting_analyzer


@pytest.fixture(scope="module")
def sorting_analyzer_simple():
return _sorting_analyzer_simple()


def _sorting_violation():
max_time = 100.0
sampling_frequency = 30000
Expand Down Expand Up @@ -566,31 +536,3 @@ def test_calculate_sd_ratio(sorting_analyzer_simple):
assert np.all(list(sd_ratio.keys()) == sorting_analyzer_simple.unit_ids)
# @aurelien can you check this, this is not working anymore
# assert np.allclose(list(sd_ratio.values()), 1, atol=0.25, rtol=0)


if __name__ == "__main__":

sorting_analyzer = _sorting_analyzer_simple()
print(sorting_analyzer)

test_unit_structure_in_output(_small_sorting_analyzer())

# test_calculate_firing_rate_num_spikes(sorting_analyzer)
chrishalcrow marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chrishalcrow can we add this back so @samuelgarcia is happy? :)

# test_calculate_snrs(sorting_analyzer)
# test_calculate_amplitude_cutoff(sorting_analyzer)
# test_calculate_presence_ratio(sorting_analyzer)
# test_calculate_amplitude_median(sorting_analyzer)
# test_calculate_sliding_rp_violations(sorting_analyzer)
# test_calculate_drift_metrics(sorting_analyzer)
# test_synchrony_metrics(sorting_analyzer)
# test_synchrony_metrics_unit_id_subset(sorting_analyzer)
# test_synchrony_metrics_no_unit_ids(sorting_analyzer)
# test_calculate_firing_range(sorting_analyzer)
# test_calculate_amplitude_cv_metrics(sorting_analyzer)
# test_calculate_sd_ratio(sorting_analyzer)

# sorting_analyzer_violations = _sorting_analyzer_violations()
# print(sorting_analyzer_violations)
# test_calculate_isi_violations(sorting_analyzer_violations)
# test_calculate_sliding_rp_violations(sorting_analyzer_violations)
# test_calculate_rp_violations(sorting_analyzer_violations)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from pathlib import Path
import numpy as np


from spikeinterface.core import (
generate_ground_truth_recording,
create_sorting_analyzer,
Expand All @@ -15,51 +14,9 @@
compute_quality_metrics,
)


job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")


def get_sorting_analyzer(seed=2205):
# we need high firing rate for amplitude_cutoff
recording, sorting = generate_ground_truth_recording(
durations=[
120.0,
],
sampling_frequency=30_000.0,
num_channels=6,
num_units=10,
generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
generate_unit_locations_kwargs=dict(
margin_um=5.0,
minimum_z=5.0,
maximum_z=20.0,
),
generate_templates_kwargs=dict(
unit_params=dict(
alpha=(200.0, 500.0),
)
),
noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
seed=seed,
)

sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True)

sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=seed)
sorting_analyzer.compute("noise_levels")
sorting_analyzer.compute("waveforms", **job_kwargs)
sorting_analyzer.compute("templates")
sorting_analyzer.compute("spike_amplitudes", **job_kwargs)

return sorting_analyzer


@pytest.fixture(scope="module")
def sorting_analyzer_simple():
sorting_analyzer = get_sorting_analyzer(seed=2205)
return sorting_analyzer


def test_compute_quality_metrics(sorting_analyzer_simple):
sorting_analyzer = sorting_analyzer_simple
print(sorting_analyzer)
Expand Down Expand Up @@ -288,12 +245,3 @@ def test_empty_units(sorting_analyzer_simple):
# for metric_name in metrics.columns:
# # NaNs are skipped
# assert np.allclose(metrics[metric_name].dropna(), metrics_par[metric_name].dropna())

if __name__ == "__main__":

sorting_analyzer = get_sorting_analyzer()
print(sorting_analyzer)

test_compute_quality_metrics(sorting_analyzer)
test_compute_quality_metrics_recordingless(sorting_analyzer)
test_empty_units(sorting_analyzer)