From 04b4d552b2833c61485e8aa71cb4d1c88cd53190 Mon Sep 17 00:00:00 2001 From: Axoft Server Date: Mon, 22 Jul 2024 18:23:33 -0400 Subject: [PATCH 01/98] feat: Add extra args to UnitSummaryWidget for subwidgets --- src/spikeinterface/widgets/unit_summary.py | 43 +++++++++++++++++++++- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 0b2a348edf..40481ae61d 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -30,6 +30,16 @@ class UnitSummaryWidget(BaseWidget): sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply. If SortingAnalyzer is already sparse, the argument is ignored + unitlocationswidget_params : dict or None, default: None + Parameters for UnitLocationsWidget (see UnitLocationsWidget for details) + unitwaveformswidgets_params : dict or None, default: None + Parameters for UnitWaveformsWidget (see UnitWaveformsWidget for details) + unitwaveformdensitymapwidget_params : dict or None, default: None + Parameters for UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) + autocorrelogramswidget_params : dict or None, default: None + Parameters for AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) + amplitudeswidget_params : dict or None, default: None + Parameters for AmplitudesWidget (see AmplitudesWidget for details) """ # possible_backends = {} @@ -40,21 +50,40 @@ def __init__( unit_id, unit_colors=None, sparsity=None, - radius_um=100, + unitlocationswidget_params=None, + unitwaveformswidgets_params=None, + unitwaveformdensitymapwidget_params=None, + autocorrelogramswidget_params=None, + amplitudeswidget_params=None, backend=None, **backend_kwargs, ): - sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) if unit_colors is None: unit_colors = get_unit_colors(sorting_analyzer) + if unitlocationswidget_params is None: + unitlocationswidget_params = dict() + if unitwaveformswidgets_params is None: + unitwaveformswidgets_params = dict() + if unitwaveformdensitymapwidget_params is None: + unitwaveformdensitymapwidget_params = dict() + if autocorrelogramswidget_params is None: + autocorrelogramswidget_params = dict() + if amplitudeswidget_params is None: + amplitudeswidget_params = dict() + plot_data = dict( sorting_analyzer=sorting_analyzer, unit_id=unit_id, unit_colors=unit_colors, sparsity=sparsity, + unitlocationswidget_params=unitlocationswidget_params, + unitwaveformswidgets_params=unitwaveformswidgets_params, + unitwaveformdensitymapwidget_params=unitwaveformdensitymapwidget_params, + autocorrelogramswidget_params=autocorrelogramswidget_params, + amplitudeswidget_params=amplitudeswidget_params ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -69,6 +98,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sorting_analyzer = dp.sorting_analyzer unit_colors = dp.unit_colors sparsity = dp.sparsity + unitlocationswidget = dp.unitlocationswidget_params + unitwaveformswidgets = dp.unitwaveformswidgets_params + unitwaveformdensitymapwidget = dp.unitwaveformdensitymapwidget_params + autocorrelogramswidget = dp.autocorrelogramswidget_params + amplitudeswidget = dp.amplitudeswidget_params # force the figure without axes if "figsize" not in backend_kwargs: @@ -99,6 +133,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_legend=False, backend="matplotlib", ax=ax1, + **unitlocationswidget ) unit_locations = 
sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") @@ -121,6 +156,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sparsity=sparsity, backend="matplotlib", ax=ax2, + **unitwaveformswidgets ) ax2.set_title(None) @@ -134,6 +170,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): same_axis=False, backend="matplotlib", ax=ax3, + **unitwaveformdensitymapwidget ) ax3.set_ylabel(None) @@ -145,6 +182,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_colors=unit_colors, backend="matplotlib", ax=ax4, + **autocorrelogramswidget ) ax4.set_title(None) @@ -162,6 +200,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_histograms=True, backend="matplotlib", axes=axes, + **amplitudeswidget ) fig.suptitle(f"unit_id: {dp.unit_id}") From fc0743e59412910e07f5c1355a58718ad6f9a4ed Mon Sep 17 00:00:00 2001 From: Axoft Server Date: Mon, 22 Jul 2024 18:29:38 -0400 Subject: [PATCH 02/98] fix: improve variable names --- src/spikeinterface/widgets/unit_summary.py | 28 +++++++++++----------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 40481ae61d..8313e6b9c7 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -51,7 +51,7 @@ def __init__( unit_colors=None, sparsity=None, unitlocationswidget_params=None, - unitwaveformswidgets_params=None, + unitwaveformswidget_params=None, unitwaveformdensitymapwidget_params=None, autocorrelogramswidget_params=None, amplitudeswidget_params=None, @@ -65,8 +65,8 @@ def __init__( if unitlocationswidget_params is None: unitlocationswidget_params = dict() - if unitwaveformswidgets_params is None: - unitwaveformswidgets_params = dict() + if unitwaveformswidget_params is None: + unitwaveformswidget_params = dict() if unitwaveformdensitymapwidget_params is None: unitwaveformdensitymapwidget_params = dict() if autocorrelogramswidget_params is None: @@ -80,7 +80,7 @@ def __init__( unit_colors=unit_colors, sparsity=sparsity, unitlocationswidget_params=unitlocationswidget_params, - unitwaveformswidgets_params=unitwaveformswidgets_params, + unitwaveformswidget_params=unitwaveformswidget_params, unitwaveformdensitymapwidget_params=unitwaveformdensitymapwidget_params, autocorrelogramswidget_params=autocorrelogramswidget_params, amplitudeswidget_params=amplitudeswidget_params @@ -98,11 +98,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sorting_analyzer = dp.sorting_analyzer unit_colors = dp.unit_colors sparsity = dp.sparsity - unitlocationswidget = dp.unitlocationswidget_params - unitwaveformswidgets = dp.unitwaveformswidgets_params - unitwaveformdensitymapwidget = dp.unitwaveformdensitymapwidget_params - autocorrelogramswidget = dp.autocorrelogramswidget_params - amplitudeswidget = dp.amplitudeswidget_params + unitlocationswidget_params = dp.unitlocationswidget_params + unitwaveformswidget_params = dp.unitwaveformswidget_params + unitwaveformdensitymapwidget_params = dp.unitwaveformdensitymapwidget_params + autocorrelogramswidget_params = dp.autocorrelogramswidget_params + amplitudeswidget_params = dp.amplitudeswidget_params # force the figure without axes if "figsize" not in backend_kwargs: @@ -133,7 +133,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_legend=False, backend="matplotlib", ax=ax1, - **unitlocationswidget + **unitlocationswidget_params ) unit_locations = 
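# A sketch of what the context line just below returns, assuming the
# "unit_locations" extension has been computed on the analyzer: with
# outputs="by_unit" the extension yields a dict mapping each unit id to its
# estimated location, which the summary figure uses to frame the location panel.
locations_by_unit = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit")
x, y = locations_by_unit[unit_id][:2]  # unit_id as in the surrounding widget code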
sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") @@ -156,7 +156,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sparsity=sparsity, backend="matplotlib", ax=ax2, - **unitwaveformswidgets + **unitwaveformswidget_params ) ax2.set_title(None) @@ -170,7 +170,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): same_axis=False, backend="matplotlib", ax=ax3, - **unitwaveformdensitymapwidget + **unitwaveformdensitymapwidget_params ) ax3.set_ylabel(None) @@ -182,7 +182,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_colors=unit_colors, backend="matplotlib", ax=ax4, - **autocorrelogramswidget + **autocorrelogramswidget_params ) ax4.set_title(None) @@ -200,7 +200,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_histograms=True, backend="matplotlib", axes=axes, - **amplitudeswidget + **amplitudeswidget_params ) fig.suptitle(f"unit_id: {dp.unit_id}") From d3c9eb0820109e69b5e973419e01531410abb19a Mon Sep 17 00:00:00 2001 From: florian6973 <70778912+florian6973@users.noreply.github.com> Date: Mon, 22 Jul 2024 18:31:33 -0400 Subject: [PATCH 03/98] fix: typo doc --- src/spikeinterface/widgets/unit_summary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 8313e6b9c7..d41be2641f 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -32,7 +32,7 @@ class UnitSummaryWidget(BaseWidget): If SortingAnalyzer is already sparse, the argument is ignored unitlocationswidget_params : dict or None, default: None Parameters for UnitLocationsWidget (see UnitLocationsWidget for details) - unitwaveformswidgets_params : dict or None, default: None + unitwaveformswidget_params : dict or None, default: None Parameters for UnitWaveformsWidget (see UnitWaveformsWidget for details) unitwaveformdensitymapwidget_params : dict or None, default: None Parameters for UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) From 60f775f9e4bd9c75f6657532ccdf4a2e57701dbb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Jul 2024 22:31:55 +0000 Subject: [PATCH 04/98] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/unit_summary.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index d41be2641f..979e8c0398 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -83,7 +83,7 @@ def __init__( unitwaveformswidget_params=unitwaveformswidget_params, unitwaveformdensitymapwidget_params=unitwaveformdensitymapwidget_params, autocorrelogramswidget_params=autocorrelogramswidget_params, - amplitudeswidget_params=amplitudeswidget_params + amplitudeswidget_params=amplitudeswidget_params, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -133,7 +133,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_legend=False, backend="matplotlib", ax=ax1, - **unitlocationswidget_params + **unitlocationswidget_params, ) unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") @@ -156,7 +156,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sparsity=sparsity, backend="matplotlib", ax=ax2, - 
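# A minimal usage sketch for the per-subwidget arguments added by these
# patches, assuming a SortingAnalyzer with the required extensions computed.
# The inner keys are illustrative, not taken from this diff: each dict is
# forwarded verbatim to its subwidget, so only keyword arguments that widget
# actually accepts (and that UnitSummaryWidget does not already set) are valid.
import spikeinterface.widgets as sw

sw.plot_unit_summary(
    sorting_analyzer,
    unit_id=sorting_analyzer.unit_ids[0],
    unitlocationswidget_params=dict(with_channel_ids=True),
    unitwaveformswidget_params=dict(scale=1.5),
    backend="matplotlib",
)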
**unitwaveformswidget_params + **unitwaveformswidget_params, ) ax2.set_title(None) @@ -170,7 +170,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): same_axis=False, backend="matplotlib", ax=ax3, - **unitwaveformdensitymapwidget_params + **unitwaveformdensitymapwidget_params, ) ax3.set_ylabel(None) @@ -182,7 +182,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_colors=unit_colors, backend="matplotlib", ax=ax4, - **autocorrelogramswidget_params + **autocorrelogramswidget_params, ) ax4.set_title(None) @@ -200,7 +200,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_histograms=True, backend="matplotlib", axes=axes, - **amplitudeswidget_params + **amplitudeswidget_params, ) fig.suptitle(f"unit_id: {dp.unit_id}") From c61378ec1d3fad9b5f9d6521a51a0dd5376dfc35 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 24 Jul 2024 11:45:17 +0100 Subject: [PATCH 05/98] simplify qm tests and wrap for multiprocessing --- .../qualitymetrics/tests/conftest.py | 51 +++++++++++ .../tests/test_metrics_functions.py | 85 +++++++------------ .../qualitymetrics/tests/test_pca_metrics.py | 5 ++ .../tests/test_quality_metric_calculator.py | 53 +----------- 4 files changed, 92 insertions(+), 102 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index bb2a345340..b1b23fcaee 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -5,6 +5,8 @@ create_sorting_analyzer, ) +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") + def _small_sorting_analyzer(): recording, sorting = generate_ground_truth_recording( @@ -35,3 +37,52 @@ def _small_sorting_analyzer(): @pytest.fixture(scope="module") def small_sorting_analyzer(): return _small_sorting_analyzer() + + +def _sorting_analyzer_simple(): + + # we need high firing rate for amplitude_cutoff + recording, sorting = generate_ground_truth_recording( + durations=[ + 120.0, + ], + sampling_frequency=30_000.0, + num_channels=6, + num_units=10, + generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + generate_unit_locations_kwargs=dict( + margin_um=5.0, + minimum_z=5.0, + maximum_z=20.0, + ), + generate_templates_kwargs=dict( + unit_params=dict( + alpha=(200.0, 500.0), + ) + ), + noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + seed=1205, + ) + + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) + + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205) + sorting_analyzer.compute("noise_levels") + sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("templates") + sorting_analyzer.compute("spike_amplitudes", **job_kwargs) + + return sorting_analyzer + + +@pytest.fixture(scope="module") +def sorting_analyzer_simple(): + sorting_analyzer = get_sorting_analyzer(seed=2205) + return sorting_analyzer + + return sorting_analyzer + + +@pytest.fixture(scope="module") +def sorting_analyzer_simple(): + return _sorting_analyzer_simple() diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index bb222200e9..446007d10b 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -135,36 +135,6 @@ def 
test_unit_id_order_independence(small_sorting_analyzer): assert quality_metrics_2[metric][1] == metric_1_data["#4"] -def _sorting_analyzer_simple(): - recording, sorting = generate_ground_truth_recording( - durations=[ - 50.0, - ], - sampling_frequency=30_000.0, - num_channels=6, - num_units=10, - generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), - seed=2205, - ) - - sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - - sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=2205) - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("waveforms", **job_kwargs) - sorting_analyzer.compute("templates") - sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) - sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - return _sorting_analyzer_simple() - - def _sorting_violation(): max_time = 100.0 sampling_frequency = 30000 @@ -570,27 +540,36 @@ def test_calculate_sd_ratio(sorting_analyzer_simple): if __name__ == "__main__": - sorting_analyzer = _sorting_analyzer_simple() - print(sorting_analyzer) - - test_unit_structure_in_output(_small_sorting_analyzer()) - - # test_calculate_firing_rate_num_spikes(sorting_analyzer) - # test_calculate_snrs(sorting_analyzer) - # test_calculate_amplitude_cutoff(sorting_analyzer) - # test_calculate_presence_ratio(sorting_analyzer) - # test_calculate_amplitude_median(sorting_analyzer) - # test_calculate_sliding_rp_violations(sorting_analyzer) - # test_calculate_drift_metrics(sorting_analyzer) - # test_synchrony_metrics(sorting_analyzer) - # test_synchrony_metrics_unit_id_subset(sorting_analyzer) - # test_synchrony_metrics_no_unit_ids(sorting_analyzer) - # test_calculate_firing_range(sorting_analyzer) - # test_calculate_amplitude_cv_metrics(sorting_analyzer) - # test_calculate_sd_ratio(sorting_analyzer) - - # sorting_analyzer_violations = _sorting_analyzer_violations() + test_unit_structure_in_output(small_sorting_analyzer) + test_unit_id_order_independence(small_sorting_analyzer) + + test_synchrony_counts_no_sync() + test_synchrony_counts_one_sync() + test_synchrony_counts_one_quad_sync() + test_synchrony_counts_not_all_units() + + test_mahalanobis_metrics() + test_lda_metrics() + test_nearest_neighbors_metrics() + test_silhouette_score_metrics() + test_simplified_silhouette_score_metrics() + + test_calculate_firing_rate_num_spikes(sorting_analyzer_simple) + test_calculate_snrs(sorting_analyzer) + test_calculate_amplitude_cutoff(sorting_analyzer) + test_calculate_presence_ratio(sorting_analyzer) + test_calculate_amplitude_median(sorting_analyzer) + test_calculate_sliding_rp_violations(sorting_analyzer) + test_calculate_drift_metrics(sorting_analyzer) + test_synchrony_metrics(sorting_analyzer) + test_synchrony_metrics_unit_id_subset(sorting_analyzer) + test_synchrony_metrics_no_unit_ids(sorting_analyzer) + test_calculate_firing_range(sorting_analyzer) + test_calculate_amplitude_cv_metrics(sorting_analyzer) + test_calculate_sd_ratio(sorting_analyzer) + + sorting_analyzer_violations = _sorting_analyzer_violations() # print(sorting_analyzer_violations) - # test_calculate_isi_violations(sorting_analyzer_violations) - # test_calculate_sliding_rp_violations(sorting_analyzer_violations) - # test_calculate_rp_violations(sorting_analyzer_violations) + 
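# With the fixtures consolidated into conftest.py by this patch, pytest
# injects them into any test module in the package by argument name. A
# hypothetical test sketching the pattern (it assumes the fixture computed
# the extensions that snr needs, namely templates and noise_levels):
import numpy as np
from spikeinterface.qualitymetrics import compute_snrs

def test_snrs_are_finite(small_sorting_analyzer):
    snrs = compute_snrs(small_sorting_analyzer)  # dict: unit_id -> snr
    assert np.all(np.isfinite(list(snrs.values())))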
test_calculate_isi_violations(sorting_analyzer_violations) + test_calculate_sliding_rp_violations(sorting_analyzer_violations) + test_calculate_rp_violations(sorting_analyzer_violations) diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 6ddeb02689..a0fc97c37c 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -22,3 +22,8 @@ def test_calculate_pc_metrics(small_sorting_analyzer): assert not np.all(np.isnan(res2[metric_name].values)) assert np.array_equal(res1[metric_name].values, res2[metric_name].values) + + +if __name__ == "__main__": + + test_calculate_pc_metrics(small_sorting_analyzer) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 28869ba5ff..f877f12708 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -2,7 +2,6 @@ from pathlib import Path import numpy as np - from spikeinterface.core import ( generate_ground_truth_recording, create_sorting_analyzer, @@ -15,51 +14,9 @@ compute_quality_metrics, ) - job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") -def get_sorting_analyzer(seed=2205): - # we need high firing rate for amplitude_cutoff - recording, sorting = generate_ground_truth_recording( - durations=[ - 120.0, - ], - sampling_frequency=30_000.0, - num_channels=6, - num_units=10, - generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), - generate_unit_locations_kwargs=dict( - margin_um=5.0, - minimum_z=5.0, - maximum_z=20.0, - ), - generate_templates_kwargs=dict( - unit_params=dict( - alpha=(200.0, 500.0), - ) - ), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), - seed=seed, - ) - - sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - - sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=seed) - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("waveforms", **job_kwargs) - sorting_analyzer.compute("templates") - sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - sorting_analyzer = get_sorting_analyzer(seed=2205) - return sorting_analyzer - - def test_compute_quality_metrics(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple print(sorting_analyzer) @@ -118,6 +75,7 @@ def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): ) for metric_name in metrics.columns: + print(metric_name) if metric_name == "sd_ratio": # this one need recording!!! 
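# The job_kwargs dict touched by these commits is spikeinterface's standard
# parallel-processing contract: n_jobs=-1 requests all available cores,
# n_jobs=2 pins two workers, and chunk_duration sets how much of the traces
# each worker handles per chunk. A sketch of how the tests consume it:
job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s")
sorting_analyzer.compute("waveforms", **job_kwargs)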
continue @@ -291,9 +249,6 @@ def test_empty_units(sorting_analyzer_simple): if __name__ == "__main__": - sorting_analyzer = get_sorting_analyzer() - print(sorting_analyzer) - - test_compute_quality_metrics(sorting_analyzer) - test_compute_quality_metrics_recordingless(sorting_analyzer) - test_empty_units(sorting_analyzer) + test_compute_quality_metrics(sorting_analyzer_simple) + test_compute_quality_metrics_recordingless(sorting_analyzer_simple) + test_empty_units(sorting_analyzer_simple) From df1564b877c26e099754f412066e5f24b0008d58 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 24 Jul 2024 13:18:00 +0100 Subject: [PATCH 06/98] try -1 instead of 2 --- src/spikeinterface/qualitymetrics/tests/conftest.py | 2 +- .../qualitymetrics/tests/test_metrics_functions.py | 2 +- src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py | 2 +- .../qualitymetrics/tests/test_quality_metric_calculator.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index b1b23fcaee..4b55a25b4a 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -5,7 +5,7 @@ create_sorting_analyzer, ) -job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") def _small_sorting_analyzer(): diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 446007d10b..0df1c25586 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -44,7 +44,7 @@ from spikeinterface.core.basesorting import minimum_spike_dtype -job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") def test_unit_structure_in_output(small_sorting_analyzer): diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index a0fc97c37c..507b9a1f70 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -13,7 +13,7 @@ def test_calculate_pc_metrics(small_sorting_analyzer): res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True, seed=1205) res1 = pd.DataFrame(res1) - res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True, seed=1205) + res2 = compute_pc_metrics(sorting_analyzer, n_jobs=-1, progress_bar=True, seed=1205) res2 = pd.DataFrame(res2) for metric_name in res1.columns: diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index f877f12708..da1f08c536 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -14,7 +14,7 @@ compute_quality_metrics, ) -job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") def test_compute_quality_metrics(sorting_analyzer_simple): From b5d896d2d0e996d0d375583f87c817e837963e23 Mon Sep 17 00:00:00 2001 From: chrishalcrow 
<57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 24 Jul 2024 14:50:04 +0100 Subject: [PATCH 07/98] delete print statement --- .../qualitymetrics/tests/test_quality_metric_calculator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index da1f08c536..0756596654 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -75,7 +75,6 @@ def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): ) for metric_name in metrics.columns: - print(metric_name) if metric_name == "sd_ratio": # this one need recording!!! continue From 9d3aa2a7d14edb72c5729b1e4606519f1f63b3e3 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 25 Jul 2024 09:44:10 +0100 Subject: [PATCH 08/98] go back to n_jobs=1 --- .../qualitymetrics/tests/test_metrics_functions.py | 2 +- src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py | 2 +- .../qualitymetrics/tests/test_quality_metric_calculator.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 0df1c25586..446007d10b 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -44,7 +44,7 @@ from spikeinterface.core.basesorting import minimum_spike_dtype -job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") def test_unit_structure_in_output(small_sorting_analyzer): diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 507b9a1f70..a0fc97c37c 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -13,7 +13,7 @@ def test_calculate_pc_metrics(small_sorting_analyzer): res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True, seed=1205) res1 = pd.DataFrame(res1) - res2 = compute_pc_metrics(sorting_analyzer, n_jobs=-1, progress_bar=True, seed=1205) + res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True, seed=1205) res2 = pd.DataFrame(res2) for metric_name in res1.columns: diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 0756596654..616e6c90c1 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -14,7 +14,7 @@ compute_quality_metrics, ) -job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") def test_compute_quality_metrics(sorting_analyzer_simple): From 79f0206f73f654ead63c11f6065e10533d0309ea Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 25 Jul 2024 13:50:52 +0100 Subject: [PATCH 09/98] Respond to Joe review --- .../qualitymetrics/tests/conftest.py | 25 ++----------- .../tests/test_metrics_functions.py | 37 ------------------- 
.../qualitymetrics/tests/test_pca_metrics.py | 5 --- .../tests/test_quality_metric_calculator.py | 6 --- 4 files changed, 4 insertions(+), 69 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index 4b55a25b4a..01fa16c8d7 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -5,10 +5,11 @@ create_sorting_analyzer, ) -job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") -def _small_sorting_analyzer(): +@pytest.fixture(scope="module") +def small_sorting_analyzer(): recording, sorting = generate_ground_truth_recording( durations=[2.0], num_units=10, @@ -35,12 +36,7 @@ def _small_sorting_analyzer(): @pytest.fixture(scope="module") -def small_sorting_analyzer(): - return _small_sorting_analyzer() - - -def _sorting_analyzer_simple(): - +def sorting_analyzer_simple(): # we need high firing rate for amplitude_cutoff recording, sorting = generate_ground_truth_recording( durations=[ @@ -73,16 +69,3 @@ def _sorting_analyzer_simple(): sorting_analyzer.compute("spike_amplitudes", **job_kwargs) return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - sorting_analyzer = get_sorting_analyzer(seed=2205) - return sorting_analyzer - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - return _sorting_analyzer_simple() diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 446007d10b..156bab84d8 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -536,40 +536,3 @@ def test_calculate_sd_ratio(sorting_analyzer_simple): assert np.all(list(sd_ratio.keys()) == sorting_analyzer_simple.unit_ids) # @aurelien can you check this, this is not working anymore # assert np.allclose(list(sd_ratio.values()), 1, atol=0.25, rtol=0) - - -if __name__ == "__main__": - - test_unit_structure_in_output(small_sorting_analyzer) - test_unit_id_order_independence(small_sorting_analyzer) - - test_synchrony_counts_no_sync() - test_synchrony_counts_one_sync() - test_synchrony_counts_one_quad_sync() - test_synchrony_counts_not_all_units() - - test_mahalanobis_metrics() - test_lda_metrics() - test_nearest_neighbors_metrics() - test_silhouette_score_metrics() - test_simplified_silhouette_score_metrics() - - test_calculate_firing_rate_num_spikes(sorting_analyzer_simple) - test_calculate_snrs(sorting_analyzer) - test_calculate_amplitude_cutoff(sorting_analyzer) - test_calculate_presence_ratio(sorting_analyzer) - test_calculate_amplitude_median(sorting_analyzer) - test_calculate_sliding_rp_violations(sorting_analyzer) - test_calculate_drift_metrics(sorting_analyzer) - test_synchrony_metrics(sorting_analyzer) - test_synchrony_metrics_unit_id_subset(sorting_analyzer) - test_synchrony_metrics_no_unit_ids(sorting_analyzer) - test_calculate_firing_range(sorting_analyzer) - test_calculate_amplitude_cv_metrics(sorting_analyzer) - test_calculate_sd_ratio(sorting_analyzer) - - sorting_analyzer_violations = _sorting_analyzer_violations() - # print(sorting_analyzer_violations) - test_calculate_isi_violations(sorting_analyzer_violations) - test_calculate_sliding_rp_violations(sorting_analyzer_violations) - 
test_calculate_rp_violations(sorting_analyzer_violations) diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index a0fc97c37c..6ddeb02689 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -22,8 +22,3 @@ def test_calculate_pc_metrics(small_sorting_analyzer): assert not np.all(np.isnan(res2[metric_name].values)) assert np.array_equal(res1[metric_name].values, res2[metric_name].values) - - -if __name__ == "__main__": - - test_calculate_pc_metrics(small_sorting_analyzer) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 616e6c90c1..fec5ceeb95 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -245,9 +245,3 @@ def test_empty_units(sorting_analyzer_simple): # for metric_name in metrics.columns: # # NaNs are skipped # assert np.allclose(metrics[metric_name].dropna(), metrics_par[metric_name].dropna()) - -if __name__ == "__main__": - - test_compute_quality_metrics(sorting_analyzer_simple) - test_compute_quality_metrics_recordingless(sorting_analyzer_simple) - test_empty_units(sorting_analyzer_simple) From d5b3e0cb716c5b5a72e58ae53dd8c5eef2305db2 Mon Sep 17 00:00:00 2001 From: Axoft Server Date: Mon, 5 Aug 2024 15:45:21 -0400 Subject: [PATCH 10/98] feat: switch to nested dict --- src/spikeinterface/widgets/unit_summary.py | 54 ++++++++-------------- 1 file changed, 19 insertions(+), 35 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 8313e6b9c7..a22b17bb8a 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -1,4 +1,5 @@ from __future__ import annotations +from collections import defaultdict import numpy as np @@ -30,16 +31,13 @@ class UnitSummaryWidget(BaseWidget): sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply. 
If SortingAnalyzer is already sparse, the argument is ignored - unitlocationswidget_params : dict or None, default: None - Parameters for UnitLocationsWidget (see UnitLocationsWidget for details) - unitwaveformswidgets_params : dict or None, default: None - Parameters for UnitWaveformsWidget (see UnitWaveformsWidget for details) - unitwaveformdensitymapwidget_params : dict or None, default: None - Parameters for UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) - autocorrelogramswidget_params : dict or None, default: None - Parameters for AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) - amplitudeswidget_params : dict or None, default: None - Parameters for AmplitudesWidget (see AmplitudesWidget for details) + widget_params : dict or None, default: None + Parameters for the subwidgets in a nested dictionary + unitlocations_params: UnitLocationsWidget (see UnitLocationsWidget for details) + unitwaveforms_params: UnitWaveformsWidget (see UnitWaveformsWidget for details) + unitwaveformdensitymap_params : UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) + autocorrelograms_params : AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) + amplitudes_params : AmplitudesWidget (see AmplitudesWidget for details) """ # possible_backends = {} @@ -50,11 +48,7 @@ def __init__( unit_id, unit_colors=None, sparsity=None, - unitlocationswidget_params=None, - unitwaveformswidget_params=None, - unitwaveformdensitymapwidget_params=None, - autocorrelogramswidget_params=None, - amplitudeswidget_params=None, + widget_params=None, backend=None, **backend_kwargs, ): @@ -63,27 +57,15 @@ def __init__( if unit_colors is None: unit_colors = get_unit_colors(sorting_analyzer) - if unitlocationswidget_params is None: - unitlocationswidget_params = dict() - if unitwaveformswidget_params is None: - unitwaveformswidget_params = dict() - if unitwaveformdensitymapwidget_params is None: - unitwaveformdensitymapwidget_params = dict() - if autocorrelogramswidget_params is None: - autocorrelogramswidget_params = dict() - if amplitudeswidget_params is None: - amplitudeswidget_params = dict() + if widget_params is None: + widget_params = dict() plot_data = dict( sorting_analyzer=sorting_analyzer, unit_id=unit_id, unit_colors=unit_colors, sparsity=sparsity, - unitlocationswidget_params=unitlocationswidget_params, - unitwaveformswidget_params=unitwaveformswidget_params, - unitwaveformdensitymapwidget_params=unitwaveformdensitymapwidget_params, - autocorrelogramswidget_params=autocorrelogramswidget_params, - amplitudeswidget_params=amplitudeswidget_params + widget_params=widget_params, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -98,11 +80,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sorting_analyzer = dp.sorting_analyzer unit_colors = dp.unit_colors sparsity = dp.sparsity - unitlocationswidget_params = dp.unitlocationswidget_params - unitwaveformswidget_params = dp.unitwaveformswidget_params - unitwaveformdensitymapwidget_params = dp.unitwaveformdensitymapwidget_params - autocorrelogramswidget_params = dp.autocorrelogramswidget_params - amplitudeswidget_params = dp.amplitudeswidget_params + + widget_params = defaultdict(lambda: dict(), dp.widget_params) + unitlocationswidget_params = widget_params['unitlocations_params'] + unitwaveformswidget_params = widget_params['unitwaveforms_params'] + unitwaveformdensitymapwidget_params = widget_params['unitwaveformdensitymap_params'] + autocorrelogramswidget_params = 
widget_params['autocorrelograms_params'] + amplitudeswidget_params = widget_params['amplitudes_params'] # force the figure without axes if "figsize" not in backend_kwargs: From 87956ccc2791883c98435f1a0ae3482ecce2d2d4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 19:47:11 +0000 Subject: [PATCH 11/98] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/unit_summary.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index a61ebe2c86..75a399fab5 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -31,7 +31,7 @@ class UnitSummaryWidget(BaseWidget): sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply. If SortingAnalyzer is already sparse, the argument is ignored - widget_params : dict or None, default: None + widget_params : dict or None, default: None Parameters for the subwidgets in a nested dictionary unitlocations_params: UnitLocationsWidget (see UnitLocationsWidget for details) unitwaveforms_params: UnitWaveformsWidget (see UnitWaveformsWidget for details) @@ -58,7 +58,7 @@ def __init__( unit_colors = get_unit_colors(sorting_analyzer) if widget_params is None: - widget_params = dict() + widget_params = dict() plot_data = dict( sorting_analyzer=sorting_analyzer, @@ -82,11 +82,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sparsity = dp.sparsity widget_params = defaultdict(lambda: dict(), dp.widget_params) - unitlocationswidget_params = widget_params['unitlocations_params'] - unitwaveformswidget_params = widget_params['unitwaveforms_params'] - unitwaveformdensitymapwidget_params = widget_params['unitwaveformdensitymap_params'] - autocorrelogramswidget_params = widget_params['autocorrelograms_params'] - amplitudeswidget_params = widget_params['amplitudes_params'] + unitlocationswidget_params = widget_params["unitlocations_params"] + unitwaveformswidget_params = widget_params["unitwaveforms_params"] + unitwaveformdensitymapwidget_params = widget_params["unitwaveformdensitymap_params"] + autocorrelogramswidget_params = widget_params["autocorrelograms_params"] + amplitudeswidget_params = widget_params["amplitudes_params"] # force the figure without axes if "figsize" not in backend_kwargs: From a9a511af00aa7b6dbda0739f10dfa02dd7a11cc1 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 7 Aug 2024 16:39:54 +0100 Subject: [PATCH 12/98] Do not delete quality metrics on recompute --- .../quality_metric_calculator.py | 27 +++++++-- .../tests/test_metrics_functions.py | 59 +++++++++++++++++++ 2 files changed, 81 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 0c7cf25237..25b8cc7c05 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -2,10 +2,11 @@ from __future__ import annotations - +import weakref import warnings from copy import deepcopy +import pandas as pd import numpy as np from spikeinterface.core.job_tools import fix_job_kwargs @@ -49,6 +50,18 @@ class ComputeQualityMetrics(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = True + 
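# Under the nested-dict API introduced by this commit, all subwidget options
# travel in a single widget_params argument; because lookups go through a
# defaultdict, only the subwidgets being tuned need an entry. The inner keys
# here are illustrative, not taken from this diff:
import spikeinterface.widgets as sw

widget_params = dict(
    unitwaveforms_params=dict(scale=1.5),
    amplitudes_params=dict(bins=100),
)
sw.plot_unit_summary(sorting_analyzer, unit_id=unit_id, widget_params=widget_params)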
def __init__(self, sorting_analyzer): + + self._sorting_analyzer = weakref.ref(sorting_analyzer) + + qm_extension = sorting_analyzer.get_extension("quality_metrics") + if qm_extension: + self.params = qm_extension.params + self.data = {"metrics": qm_extension.get_data()} + else: + self.params = {} + self.data = {"metrics": pd.DataFrame(index=sorting_analyzer.sorting.unit_ids)} + def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=None, skip_pc_metrics=False): if metric_names is None: metric_names = list(_misc_metric_name_to_func.keys()) @@ -71,8 +84,14 @@ def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=No if "peak_sign" in qm_params_[k] and peak_sign is not None: qm_params_[k]["peak_sign"] = peak_sign + try: + existing_metric_names = self.sorting_analyzer.get_extension("quality_metrics").params.get("metric_names") + metric_names_for_params = np.concatenate((existing_metric_names, metric_names)) + except: + metric_names_for_params = metric_names + params = dict( - metric_names=[str(name) for name in np.unique(metric_names)], + metric_names=[str(name) for name in np.unique(metric_names_for_params)], peak_sign=peak_sign, seed=seed, qm_params=qm_params_, @@ -89,8 +108,6 @@ def _select_extension_data(self, unit_ids): def _merge_extension_data( self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs ): - import pandas as pd - old_metrics = self.data["metrics"] all_unit_ids = new_sorting_analyzer.unit_ids @@ -134,7 +151,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job import pandas as pd - metrics = pd.DataFrame(index=unit_ids) + metrics = self.data["metrics"] # simple metrics not based on PCs for metric_name in metric_names: diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index bb222200e9..e34c15c936 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -1,6 +1,7 @@ import pytest from pathlib import Path import numpy as np +from copy import deepcopy from spikeinterface.core import ( NumpySorting, synthetize_spike_train_bad_isi, @@ -47,6 +48,64 @@ job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") +def test_compute_new_quality_metrics(small_sorting_analyzer): + """ + Computes quality metrics then computes a subset of quality metrics, and checks + that the old quality metrics are not deleted. 
+ """ + + qm_params = { + "presence_ratio": {"bin_duration_s": 0.1}, + "amplitude_cutoff": {"num_histogram_bins": 3}, + "firing_range": {"bin_size_s": 1}, + } + + small_sorting_analyzer.compute( + {"quality_metrics": {"metric_names": list(qm_params.keys()), "qm_params": qm_params}} + ) + small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) + + quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") + + # Check old metrics are not deleted and the new one is added to the data and metadata + assert list(quality_metric_extension.get_data().keys()) == [ + "amplitude_cutoff", + "firing_range", + "presence_ratio", + "snr", + ] + assert list(quality_metric_extension.params.get("metric_names")) == [ + "amplitude_cutoff", + "firing_range", + "presence_ratio", + "snr", + ] + + # check that, when parameters are changed, the data and metadata are updated + old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values) + small_sorting_analyzer.compute( + {"quality_metrics": {"metric_names": ["snr"], "qm_params": {"snr": {"peak_mode": "peak_to_peak"}}}} + ) + new_quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") + new_snr_data = new_quality_metric_extension.get_data()["snr"].values + + assert np.all(old_snr_data != new_snr_data) + assert new_quality_metric_extension.params["qm_params"]["snr"]["peak_mode"] == "peak_to_peak" + + # check that all quality metrics are deleted when parents are recomputed, even after + # recomputation + extensions_to_compute = { + "templates": {"operators": ["average", "median"]}, + "spike_amplitudes": {}, + "spike_locations": {}, + "principal_components": {}, + } + + small_sorting_analyzer.compute(extensions_to_compute) + + assert small_sorting_analyzer.get_extension("quality_metrics") is None + + def test_unit_structure_in_output(small_sorting_analyzer): qm_params = { From 63713db9b026c6ab55a2e452396a860b78c9934a Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 7 Aug 2024 17:16:51 +0100 Subject: [PATCH 13/98] Change where pandas is imported --- .../qualitymetrics/quality_metric_calculator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 25b8cc7c05..b34407027b 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -6,7 +6,6 @@ import warnings from copy import deepcopy -import pandas as pd import numpy as np from spikeinterface.core.job_tools import fix_job_kwargs @@ -60,7 +59,7 @@ def __init__(self, sorting_analyzer): self.data = {"metrics": qm_extension.get_data()} else: self.params = {} - self.data = {"metrics": pd.DataFrame(index=sorting_analyzer.sorting.unit_ids)} + self.data = {"metrics": None} def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=None, skip_pc_metrics=False): if metric_names is None: @@ -152,6 +151,8 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job import pandas as pd metrics = self.data["metrics"] + if metrics is None: + metrics = pd.DataFrame(index=sorting_analyzer.sorting.unit_ids) # simple metrics not based on PCs for metric_name in metric_names: From 01abc84370c2bfc8acd81523ea18c2e6e15b03de Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 8 
Aug 2024 09:14:47 +0100 Subject: [PATCH 14/98] Replace try/except with some ifs for metric names in params --- .../qualitymetrics/quality_metric_calculator.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index b34407027b..ced1eedbd2 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -83,11 +83,12 @@ def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=No if "peak_sign" in qm_params_[k] and peak_sign is not None: qm_params_[k]["peak_sign"] = peak_sign - try: - existing_metric_names = self.sorting_analyzer.get_extension("quality_metrics").params.get("metric_names") - metric_names_for_params = np.concatenate((existing_metric_names, metric_names)) - except: - metric_names_for_params = metric_names + metric_names_for_params = metric_names + qm_extension = self.sorting_analyzer.get_extension("quality_metrics") + if qm_extension: + existing_metric_names = qm_extension.params.get("metric_names") + if existing_metric_names is not None: + metric_names_for_params.extend(existing_metric_names) params = dict( metric_names=[str(name) for name in np.unique(metric_names_for_params)], From e6b394115c52c370af8c8411186f97d8e99f24cc Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 8 Aug 2024 14:15:09 +0100 Subject: [PATCH 15/98] Fix problem with loading sorting_analyzer with qms --- .../qualitymetrics/quality_metric_calculator.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index ced1eedbd2..708498e3fa 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -53,10 +53,11 @@ def __init__(self, sorting_analyzer): self._sorting_analyzer = weakref.ref(sorting_analyzer) - qm_extension = sorting_analyzer.get_extension("quality_metrics") - if qm_extension: - self.params = qm_extension.params - self.data = {"metrics": qm_extension.get_data()} + qm_class = sorting_analyzer.extensions.get("quality_metrics") + + if qm_class: + self.params = qm_class.params + self.data = {"metrics": qm_class.get_data()} else: self.params = {} self.data = {"metrics": None} From 762e8faec7fca93379f9be1bf9be398eb65d4774 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 8 Aug 2024 15:49:33 +0100 Subject: [PATCH 16/98] update if statement --- src/spikeinterface/qualitymetrics/quality_metric_calculator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 708498e3fa..c3c95a2f54 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -86,7 +86,7 @@ def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=No metric_names_for_params = metric_names qm_extension = self.sorting_analyzer.get_extension("quality_metrics") - if qm_extension: + if qm_extension is not None: existing_metric_names = qm_extension.params.get("metric_names") if existing_metric_names is not 
None: metric_names_for_params.extend(existing_metric_names) From c1e5e4ec1f889e3070a19734e4ea9169a872e280 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 10:59:16 +0200 Subject: [PATCH 17/98] extend estimate_sparsity methods and fix from_ptp --- src/spikeinterface/core/sparsity.py | 137 +++++++++--------- .../core/tests/test_sparsity.py | 29 +++- 2 files changed, 98 insertions(+), 68 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index a38562ea2c..57e1fa4769 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -338,7 +338,7 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p noise_levels = ext.data["noise_levels"] return_scaled = templates_or_sorting_analyzer.return_scaled elif isinstance(templates_or_sorting_analyzer, Templates): - assert noise_levels is not None + assert noise_levels is not None, "To compute sparsity from snr you need to provide noise_levels" return_scaled = templates_or_sorting_analyzer.is_scaled mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") @@ -353,17 +353,17 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p return cls(mask, unit_ids, channel_ids) @classmethod - def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): + def from_ptp(cls, templates_or_sorting_analyzer, threshold): """ Construct sparsity from a thresholds based on template peak-to-peak values. - Use the "threshold" argument to specify the SNR threshold. + Use the "threshold" argument to specify the peak-to-peak threshold. """ assert ( templates_or_sorting_analyzer.sparsity is None ), "To compute sparsity you need a dense SortingAnalyzer or Templates" - from .template_tools import get_template_amplitudes + from .template_tools import get_dense_templates_array from .sortinganalyzer import SortingAnalyzer from .template import Templates @@ -371,23 +371,17 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): channel_ids = templates_or_sorting_analyzer.channel_ids if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - ext = templates_or_sorting_analyzer.get_extension("noise_levels") - assert ext is not None, "To compute sparsity from snr you need to compute 'noise_levels' first" - noise_levels = ext.data["noise_levels"] return_scaled = templates_or_sorting_analyzer.return_scaled elif isinstance(templates_or_sorting_analyzer, Templates): - assert noise_levels is not None return_scaled = templates_or_sorting_analyzer.is_scaled - from .template_tools import get_dense_templates_array - mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") templates_array = get_dense_templates_array(templates_or_sorting_analyzer, return_scaled=return_scaled) templates_ptps = np.ptp(templates_array, axis=1) for unit_ind, unit_id in enumerate(unit_ids): - chan_inds = np.nonzero(templates_ptps[unit_ind] / noise_levels >= threshold) + chan_inds = np.nonzero(templates_ptps[unit_ind] >= threshold) mask[unit_ind, chan_inds] = True return cls(mask, unit_ids, channel_ids) @@ -455,15 +449,15 @@ def create_dense(cls, sorting_analyzer): def compute_sparsity( - templates_or_sorting_analyzer, - noise_levels=None, - method="radius", - peak_sign="neg", - num_channels=5, - radius_um=100.0, - threshold=5, - by_property=None, -): + templates_or_sorting_analyzer: Union[Templates, SortingAnalyzer], + noise_levels: np.ndarray | None = None, + method: "radius" | "best_channels" | 
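# After the from_ptp rewrite above, peak-to-peak sparsity thresholds the raw
# template amplitudes directly and no longer needs noise levels. A sketch on
# a dense Templates object, with the threshold in the templates' own units:
from spikeinterface.core import compute_sparsity

sparsity = compute_sparsity(templates, method="ptp", threshold=3.0)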
"snr" | "ptp" | "energy" | "by_property" = "radius", + peak_sign: "neg" | "pos" | "both" = "neg", + num_channels: int | None = 5, + radius_um: float | None = 100.0, + threshold: float | None = 5, + by_property: str | None = None, +) -> ChannelSparsity: """ Get channel sparsity (subset of channels) for each template with several methods. @@ -500,11 +494,6 @@ def compute_sparsity( templates_or_sorting_analyzer, SortingAnalyzer ), f"compute_sparsity(method='{method}') need SortingAnalyzer" - if method in ("snr", "ptp") and isinstance(templates_or_sorting_analyzer, Templates): - assert ( - noise_levels is not None - ), f"compute_sparsity(..., method='{method}') with Templates need noise_levels as input" - if method == "best_channels": assert num_channels is not None, "For the 'best_channels' method, 'num_channels' needs to be given" sparsity = ChannelSparsity.from_best_channels(templates_or_sorting_analyzer, num_channels, peak_sign=peak_sign) @@ -521,7 +510,6 @@ def compute_sparsity( sparsity = ChannelSparsity.from_ptp( templates_or_sorting_analyzer, threshold, - noise_levels=noise_levels, ) elif method == "energy": assert threshold is not None, "For the 'energy' method, 'threshold' needs to be given" @@ -544,10 +532,12 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" = "radius", - peak_sign: str = "neg", + method: "radius" | "best_channels" | "ptp" | "by_property" = "radius", + peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, + threshold: float | None = 5, + by_property: str | None = None, **job_kwargs, ): """ @@ -567,16 +557,15 @@ def estimate_sparsity( The sorting recording : BaseRecording The recording - num_spikes_for_sparsity : int, default: 100 How many spikes per units to compute the sparsity ms_before : float, default: 1.0 Cut out in ms before spike time ms_after : float, default: 2.5 Cut out in ms after spike time - method : "radius" | "best_channels", default: "radius" + method : "radius" | "best_channels" | "ptp" | "by_property", default: "radius" Sparsity method propagated to the `compute_sparsity()` function. - Only "radius" or "best_channels" are implemented + "snr" and "energy" are not available here, because they require noise levels. peak_sign : "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels radius_um : float, default: 100.0 @@ -594,7 +583,10 @@ def estimate_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - assert method in ("radius", "best_channels"), "estimate_sparsity() handle only method='radius' or 'best_channel'" + assert method in ("radius", "best_channels", "ptp", "by_property"), ( + f"method={method} is not available for `estimate_sparsity()`. 
" + "Available methods are 'radius', 'best_channels', 'ptp', 'by_property'" + ) if recording.get_probes() == 1: # standard case @@ -605,43 +597,54 @@ def estimate_sparsity( chan_locs = recording.get_channel_locations() probe = recording.create_dummy_probe_from_locations(chan_locs) - nbefore = int(ms_before * recording.sampling_frequency / 1000.0) - nafter = int(ms_after * recording.sampling_frequency / 1000.0) - - num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] - random_spikes_indices = random_spikes_selection( - sorting, - num_samples, - method="uniform", - max_spikes_per_unit=num_spikes_for_sparsity, - margin_size=max(nbefore, nafter), - seed=2205, - ) - spikes = sorting.to_spike_vector() - spikes = spikes[random_spikes_indices] - - templates_array = estimate_templates_with_accumulator( - recording, - spikes, - sorting.unit_ids, - nbefore, - nafter, - return_scaled=False, - job_name="estimate_sparsity", - **job_kwargs, - ) - templates = Templates( - templates_array=templates_array, - sampling_frequency=recording.sampling_frequency, - nbefore=nbefore, - sparsity_mask=None, - channel_ids=recording.channel_ids, - unit_ids=sorting.unit_ids, - probe=probe, - ) + if method != "by_property": + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + nafter = int(ms_after * recording.sampling_frequency / 1000.0) + + num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] + random_spikes_indices = random_spikes_selection( + sorting, + num_samples, + method="uniform", + max_spikes_per_unit=num_spikes_for_sparsity, + margin_size=max(nbefore, nafter), + seed=2205, + ) + spikes = sorting.to_spike_vector() + spikes = spikes[random_spikes_indices] + + templates_array = estimate_templates_with_accumulator( + recording, + spikes, + sorting.unit_ids, + nbefore, + nafter, + return_scaled=False, + job_name="estimate_sparsity", + **job_kwargs, + ) + templates = Templates( + templates_array=templates_array, + sampling_frequency=recording.sampling_frequency, + nbefore=nbefore, + sparsity_mask=None, + channel_ids=recording.channel_ids, + unit_ids=sorting.unit_ids, + probe=probe, + ) + templates_or_analyzer = templates + else: + from .sortinganalyzer import create_sorting_analyzer + templates_or_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, sparse=False) sparsity = compute_sparsity( - templates, method=method, peak_sign=peak_sign, num_channels=num_channels, radius_um=radius_um + templates_or_analyzer, + method=method, + peak_sign=peak_sign, + num_channels=num_channels, + radius_um=radius_um, + threshold=threshold, + by_property=by_property, ) return sparsity diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index a192d90502..6ee023fc12 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -86,7 +86,7 @@ def test_sparsify_waveforms(): num_active_channels = len(non_zero_indices) assert waveforms_sparse.shape == (num_units, num_samples, num_active_channels) - # Test round-trip (note that this is loosy) + # Test round-trip (note that this is lossy) unit_id = unit_ids[unit_id] non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id] waveforms_dense2 = sparsity.densify_waveforms(waveforms_sparse, unit_id=unit_id) @@ -195,6 +195,33 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 3) + # ptp : 
just run it + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="ptp", + threshold=3, + progress_bar=True, + n_jobs=1, + ) + + # by_property : just run it + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="by_property", + by_property="group", + progress_bar=True, + n_jobs=1, + ) + assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 5) + def test_compute_sparsity(): recording, sorting = get_dataset() From c5fc74983cb1b8801a5e4e65ad91fd019c6770d5 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 11:49:57 +0200 Subject: [PATCH 18/98] (wip) 0.101.1 --- doc/releases/0.101.0.rst | 2 +- doc/releases/0.101.1.rst | 113 +++++++++++++++++++++++++++++++++ doc/whatisnew.rst | 7 ++ pyproject.toml | 16 ++--- src/spikeinterface/__init__.py | 4 +- 5 files changed, 131 insertions(+), 11 deletions(-) create mode 100644 doc/releases/0.101.1.rst diff --git a/doc/releases/0.101.0.rst b/doc/releases/0.101.0.rst index c34cd0dc8e..0e686cca1a 100644 --- a/doc/releases/0.101.0.rst +++ b/doc/releases/0.101.0.rst @@ -3,7 +3,7 @@ SpikeInterface 0.101.0 release notes ------------------------------------ -Estimated: 19th July 2024 +19th July 2024 Main changes: diff --git a/doc/releases/0.101.1.rst b/doc/releases/0.101.1.rst new file mode 100644 index 0000000000..41f76d4815 --- /dev/null +++ b/doc/releases/0.101.1.rst @@ -0,0 +1,113 @@ +.. _release0.101.0: + +SpikeInterface 0.101.0 release notes +------------------------------------ + +6th September 2024 + +Main changes: + +* + +core: + +* Add `BaseRecording.reset_times()` function (#3363) +* Add `load_sorting_analyzer_or_waveforms` function (#3352) +* Propagate storage_options to load_sorting_analyzer (#3351) +* Fix zarr folder suffix handling (#3349) +* Lazy loading of zarr timestamps (#3318) +* Enable cloud-loading for analyzer Zarr (#3314) +* Refactor `set_property` in base (#3287) +* Job kwargs fix (#3259) +* Add check for None in 'NoiseGeneratorRecordingSegment' get_traces(). 
(#3230) + +extractors: + +* "quality" property to be read as string instead of object in `BasePhyKilosortSortingExtractor` (#3365) +* Test IBL skip when the setting up the one client fails (#3289) + +preprocessing: + +* Update doc handle drift + better preset (#3232) +* Add causal filtering to filter.py (#3172) + +sorters: + +* fix: download apptainer images without docker client (#3335) +* Expose save preprocessing in ks4 (#3276) +* Fix KS2/2.5/3 skip_kilosort_preprocessing (#3265) +* added lowpass parameter, fixed verbose option (#3262) +* Now exclusive support for HS v0.4 (Lightning) (#3210) +* Add kilosort4 wrapper tests (#3085) + +postprocessing: + +* Protect median against nans in get_prototype_spike (#3270) +* Fix docstring and error for spike_amplitudes (#3269) + + +curation: + +* Clean-up identity merges in `get_potential_auto_merges` (#3346) +* Fix sortingview curation no merge case (#3309) +* Start apply_curation() (#3208) + +widgets: + +* Fix widgets tests and add test on unit_table_properties (#3354) +* Allow quality and template metrics in sortingview's unit table (#3299) +* Fix #3236 (#3238) + +sortingcomponents: + +* Update doc handle drift + better preset (#3232) + +motion correction: + +* Make InterpolateMotionRecording not JSON-serializable (#3341) + +documentation: + +* Clarify meaning of `delta_time` in `compare_sorter_to_ground_truth` (#3360) +* Added sphinxcontrib-jquery (#3307) +* Adding return type annotations (#3304) +* More docstring updates for multiple modules (#3298) +* Fix sampling frequency repr (#3294) +* Proposal for adding Examples to docstrings (#3279) +* More numpydoc fixes (#3275) +* Fix docstring and error for spike_amplitudes (#3269) +* Fix postprocessing docs (#3268) +* Fix name of principal_components ext in qm docs (take 2) (#3261) +* Update doc handle drift + better preset (#3232) +* Add `int` type to `num_samples` on `InjectTemplatesRecording`. (#3229) + +continuous integration: + +* Fix streaming extractor condition in the CI (#3362) + +packaging: + +* Drop python 3.8 in pyproject.toml (#3267) + +testing: + +* Fix streaming extractor condition in the CI (#3362) +* Test IBL skip when the setting up the one client fails (#3289) +* Refactor `set_property` in base (#3287) +* Add kilosort4 wrapper tests (#3085) + +Contributors: + +* @Djoels +* @JoeZiminski +* @JuanPimientoCaicedo +* @alejoe91 +* @app/pre-commit-ci +* @chrishalcrow +* @h-mayorquin +* @jonahpearl +* @mhhennig +* @rkim48 +* @samuelgarcia +* @tabedzki +* @zm711 diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst index 94da5d15fb..56d38ce85b 100644 --- a/doc/whatisnew.rst +++ b/doc/whatisnew.rst @@ -8,6 +8,7 @@ Release notes .. 
toctree:: :maxdepth: 1 + releases/0.101.1.rst releases/0.101.0.rst releases/0.100.8.rst releases/0.100.7.rst @@ -43,6 +44,12 @@ Release notes releases/0.9.1.rst +Version 0.101.1 +=============== + + + + Version 0.101.0 =============== diff --git a/pyproject.toml b/pyproject.toml index 8309ca89fe..db435998a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,16 +125,16 @@ test_core = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] test_extractors = [ # Functions to download data in neo test suite "pooch>=1.8.2", "datalad>=1.0.2", - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] test_preprocessing = [ @@ -175,8 +175,8 @@ test = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", - "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] docs = [ @@ -200,8 +200,8 @@ docs = [ "datalad>=1.0.2", # for release we need pypi, so this needs to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version - "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + # "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] diff --git a/src/spikeinterface/__init__.py b/src/spikeinterface/__init__.py index 306c12d516..97fb95b623 100644 --- a/src/spikeinterface/__init__.py +++ b/src/spikeinterface/__init__.py @@ -30,5 +30,5 @@ # This flag must be set to False for release # This avoids using versioning that contains ".dev0" (and this is a better choice) # This is mainly useful when using run_sorter in a container and spikeinterface install -DEV_MODE = True -# DEV_MODE = False +# DEV_MODE = True +DEV_MODE = False From f73acfe73a8f052bce5d36c07c4e9606d0b8aeab Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 11:58:20 +0200 Subject: [PATCH 19/98] Add date format in auto-release-notes --- doc/scripts/auto-release-notes.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/scripts/auto-release-notes.sh b/doc/scripts/auto-release-notes.sh index 14bee3dad0..f3818e1e18 100644 --- a/doc/scripts/auto-release-notes.sh +++ b/doc/scripts/auto-release-notes.sh @@ -1,6 +1,6 @@ #!/bin/bash if [ $# -eq 0 ]; then - echo "Usage: $0 START_DATE END_DATE [LABEL] [BRANCH1,BRANCH2] [LIMIT]" + echo "Usage: $0 START_DATE (format: YEAR-MM-DD) END_DATE [LABEL] [BRANCH1,BRANCH2] [LIMIT]" exit 1 fi From 30d7dbbc3998fb820b62a3029a4c36fffd48d71b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 12:38:03 +0200 
Subject: [PATCH 20/98] Revert ptp changes --- src/spikeinterface/core/sparsity.py | 24 +++++++++++-------- .../core/tests/test_sparsity.py | 13 ---------- 2 files changed, 14 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 57e1fa4769..4de786cb37 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -353,10 +353,10 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p return cls(mask, unit_ids, channel_ids) @classmethod - def from_ptp(cls, templates_or_sorting_analyzer, threshold): + def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): """ - Construct sparsity from a thresholds based on template peak-to-peak values. - Use the "threshold" argument to specify the peak-to-peak threshold. + Construct sparsity from a thresholds based on template peak-to-peak relative values. + Use the "threshold" argument to specify the peak-to-peak threshold (with respect to noise_levels). """ assert ( @@ -371,8 +371,12 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold): channel_ids = templates_or_sorting_analyzer.channel_ids if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): + ext = templates_or_sorting_analyzer.get_extension("noise_levels") + assert ext is not None, "To compute sparsity from ptp you need to compute 'noise_levels' first" + noise_levels = ext.data["noise_levels"] return_scaled = templates_or_sorting_analyzer.return_scaled elif isinstance(templates_or_sorting_analyzer, Templates): + assert noise_levels is not None, "To compute sparsity from ptp you need to provide noise_levels" return_scaled = templates_or_sorting_analyzer.is_scaled mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") @@ -381,7 +385,7 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold): templates_ptps = np.ptp(templates_array, axis=1) for unit_ind, unit_id in enumerate(unit_ids): - chan_inds = np.nonzero(templates_ptps[unit_ind] >= threshold) + chan_inds = np.nonzero(templates_ptps[unit_ind] / noise_levels >= threshold) mask[unit_ind, chan_inds] = True return cls(mask, unit_ids, channel_ids) @@ -397,7 +401,7 @@ def from_energy(cls, sorting_analyzer, threshold): # noise_levels ext = sorting_analyzer.get_extension("noise_levels") - assert ext is not None, "To compute sparsity from ptp you need to compute 'noise_levels' first" + assert ext is not None, "To compute sparsity from energy you need to compute 'noise_levels' first" noise_levels = ext.data["noise_levels"] # waveforms @@ -532,7 +536,7 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" | "ptp" | "by_property" = "radius", + method: "radius" | "best_channels" | "by_property" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, @@ -563,9 +567,9 @@ def estimate_sparsity( Cut out in ms before spike time ms_after : float, default: 2.5 Cut out in ms after spike time - method : "radius" | "best_channels" | "ptp" | "by_property", default: "radius" + method : "radius" | "best_channels" | "by_property", default: "radius" Sparsity method propagated to the `compute_sparsity()` function. - "snr" and "energy" are not available here, because they require noise levels. + "snr", "ptp", and "energy" are not available here because they require noise levels. 
peak_sign : "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels radius_um : float, default: 100.0 @@ -583,9 +587,9 @@ def estimate_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - assert method in ("radius", "best_channels", "ptp", "by_property"), ( + assert method in ("radius", "best_channels", "by_property"), ( f"method={method} is not available for `estimate_sparsity()`. " - "Available methods are 'radius', 'best_channels', 'ptp', 'by_property'" + "Available methods are 'radius', 'best_channels', 'by_property'" ) if recording.get_probes() == 1: diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 6ee023fc12..b60c1c2eca 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -195,19 +195,6 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 3) - # ptp : just run it - sparsity = estimate_sparsity( - sorting, - recording, - num_spikes_for_sparsity=50, - ms_before=1.0, - ms_after=2.0, - method="ptp", - threshold=3, - progress_bar=True, - n_jobs=1, - ) - # by_property : just run it sparsity = estimate_sparsity( sorting, From fa8eb1f5b349846f13247537f8b3460b45ae9deb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 13:10:14 +0200 Subject: [PATCH 21/98] Expose snr_amplitude_mode for snr sparsity --- src/spikeinterface/core/sparsity.py | 43 +++++++++++-------- .../core/tests/test_sparsity.py | 19 +++++++- 2 files changed, 44 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 4de786cb37..6904886dcf 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -25,14 +25,16 @@ * "by_property" : sparsity is given by a property of the recording and sorting(e.g. "group"). Use the "by_property" argument to specify the property name. - peak_sign : str - Sign of the template to compute best channels ("neg", "pos", "both") + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute best channels num_channels : int Number of channels for "best_channels" method radius_um : float Radius in um for "radius" method threshold : float Threshold in SNR "threshold" method + snr_amplitude_mode : "extremum" | "at_index" | "peak_to_peak" + Mode to compute the amplitude of the templates for the "snr" method by_property : object Property name for "by_property" method """ @@ -316,7 +318,9 @@ def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) @classmethod - def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, peak_sign="neg"): + def from_snr( + cls, templates_or_sorting_analyzer, threshold, snr_amplitude_mode="extremum", noise_levels=None, peak_sign="neg" + ): """ Construct sparsity from a thresholds based on template signal-to-noise ratio. Use the "threshold" argument to specify the SNR threshold. 
@@ -344,7 +348,7 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") peak_values = get_template_amplitudes( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode="extremum", return_scaled=return_scaled + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=snr_amplitude_mode, return_scaled=return_scaled ) for unit_ind, unit_id in enumerate(unit_ids): @@ -353,10 +357,10 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p return cls(mask, unit_ids, channel_ids) @classmethod - def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): + def from_ptp(cls, templates_or_sorting_analyzer, threshold): """ - Construct sparsity from a thresholds based on template peak-to-peak relative values. - Use the "threshold" argument to specify the peak-to-peak threshold (with respect to noise_levels). + Construct sparsity from a thresholds based on template peak-to-peak values. + Use the "threshold" argument to specify the peak-to-peak threshold. """ assert ( @@ -371,12 +375,8 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): channel_ids = templates_or_sorting_analyzer.channel_ids if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - ext = templates_or_sorting_analyzer.get_extension("noise_levels") - assert ext is not None, "To compute sparsity from ptp you need to compute 'noise_levels' first" - noise_levels = ext.data["noise_levels"] return_scaled = templates_or_sorting_analyzer.return_scaled elif isinstance(templates_or_sorting_analyzer, Templates): - assert noise_levels is not None, "To compute sparsity from ptp you need to provide noise_levels" return_scaled = templates_or_sorting_analyzer.is_scaled mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") @@ -385,7 +385,7 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): templates_ptps = np.ptp(templates_array, axis=1) for unit_ind, unit_id in enumerate(unit_ids): - chan_inds = np.nonzero(templates_ptps[unit_ind] / noise_levels >= threshold) + chan_inds = np.nonzero(templates_ptps[unit_ind] >= threshold) mask[unit_ind, chan_inds] = True return cls(mask, unit_ids, channel_ids) @@ -453,7 +453,7 @@ def create_dense(cls, sorting_analyzer): def compute_sparsity( - templates_or_sorting_analyzer: Union[Templates, SortingAnalyzer], + templates_or_sorting_analyzer: Templates | SortingAnalyzer, noise_levels: np.ndarray | None = None, method: "radius" | "best_channels" | "snr" | "ptp" | "energy" | "by_property" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", @@ -461,6 +461,7 @@ def compute_sparsity( radius_um: float | None = 100.0, threshold: float | None = 5, by_property: str | None = None, + snr_amplitude_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", ) -> ChannelSparsity: """ Get channel sparsity (subset of channels) for each template with several methods. 
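For intuition, the two amplitude modes differ as follows on a toy template (values are illustrative only, not taken from the patch):

import numpy as np

template = np.array([0.5, -4.0, 2.5])  # one channel of one template over time
extremum = np.max(np.abs(template))    # 4.0 -> mode="extremum"
peak_to_peak = np.ptp(template)        # 6.5 -> mode="peak_to_peak"
# with a noise level of 1.0 and threshold=5, this channel would be selected
# under "peak_to_peak" but rejected under "extremum"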
@@ -507,7 +508,11 @@ def compute_sparsity( elif method == "snr": assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" sparsity = ChannelSparsity.from_snr( - templates_or_sorting_analyzer, threshold, noise_levels=noise_levels, peak_sign=peak_sign + templates_or_sorting_analyzer, + threshold, + noise_levels=noise_levels, + peak_sign=peak_sign, + snr_amplitude_mode=snr_amplitude_mode, ) elif method == "ptp": assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" @@ -536,11 +541,12 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" | "by_property" = "radius", + method: "radius" | "best_channels" | "ptp" | "by_property" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, threshold: float | None = 5, + snr_amplitude_mode: "extremum" | "peak_to_peak" = "extremum", by_property: str | None = None, **job_kwargs, ): @@ -576,6 +582,8 @@ def estimate_sparsity( Used for "radius" method num_channels : int, default: 5 Used for "best_channels" method + snr_amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + Used for "snr" method to compute the amplitude of the templates. {} @@ -587,9 +595,9 @@ def estimate_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - assert method in ("radius", "best_channels", "by_property"), ( + assert method in ("radius", "best_channels", "ptp", "by_property"), ( f"method={method} is not available for `estimate_sparsity()`. " - "Available methods are 'radius', 'best_channels', 'by_property'" + "Available methods are 'radius', 'best_channels', 'ptp', 'by_property'" ) if recording.get_probes() == 1: @@ -649,6 +657,7 @@ def estimate_sparsity( radius_um=radius_um, threshold=threshold, by_property=by_property, + snr_amplitude_mode=snr_amplitude_mode, ) return sparsity diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index b60c1c2eca..16b3bbc996 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -195,7 +195,7 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 3) - # by_property : just run it + # by_property sparsity = estimate_sparsity( sorting, recording, @@ -209,6 +209,20 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 5) + # ptp: just run it + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="ptp", + threshold=5, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) + def test_compute_sparsity(): recording, sorting = get_dataset() @@ -226,6 +240,9 @@ def test_compute_sparsity(): sparsity = compute_sparsity(sorting_analyzer, method="best_channels", num_channels=2, peak_sign="neg") sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0, peak_sign="neg") sparsity = compute_sparsity(sorting_analyzer, method="snr", threshold=5, peak_sign="neg") + sparsity = compute_sparsity( + sorting_analyzer, method="snr", threshold=5, peak_sign="neg", snr_amplitude_mode="peak_to_peak" + ) sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5) sparsity = compute_sparsity(sorting_analyzer, method="energy", threshold=5) sparsity = 
compute_sparsity(sorting_analyzer, method="by_property", by_property="group") From b5c56d64373b217dcf395e95e1a2b7b70a5f8dbe Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 13:35:45 +0200 Subject: [PATCH 22/98] Propagate sparsity change --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- .../benchmark/tests/test_benchmark_matching.py | 4 +++- .../sortingcomponents/clustering/random_projections.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 4701d76012..a6d212425d 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -25,7 +25,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "sparsity": {"method": "ptp", "threshold": 0.25}, + "sparsity": {"method": "snr", "snr_amplitude_mode": "peak_to_peak", "threshold": 0.25}, "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py index aa9b16bb97..d6d0440a02 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py @@ -27,7 +27,9 @@ def test_benchmark_matching(create_cache_folder): recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_scaled=False, **job_kwargs ) noise_levels = get_noise_levels(recording) - sparsity = compute_sparsity(gt_templates, noise_levels, method="ptp", threshold=0.25) + sparsity = compute_sparsity( + gt_templates, noise_levels, method="snr", snr_amplitude_mode="peak_to_peak", threshold=0.25 + ) gt_templates = gt_templates.to_sparse(sparsity) # create study diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 77d47aec16..b56fd3e02b 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -45,7 +45,7 @@ class RandomProjectionClustering: }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, - "sparsity": {"method": "ptp", "threshold": 0.25}, + "sparsity": {"method": "snr", "snr_amplitude_mode": "peak_to_peak", "threshold": 0.25}, "radius_um": 30, "nb_projections": 10, "feature": "energy", From cc14ab2697331ca4b9ca35e972d959e2cbe07a03 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 13:39:10 +0200 Subject: [PATCH 23/98] last one --- src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 2bacf36ac9..11a628bb53 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -50,7 +50,7 @@ class CircusClustering: }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, - "sparsity": {"method": "ptp", "threshold": 0.25}, + "sparsity": {"method": "snr", "snr_amplitude_mode": 
"peak_to_peak", "threshold": 0.25}, "recursive_kwargs": { "recursive": True, "recursive_depth": 3, From 71a5b09a49522a06e69e59bd0777056f8bf6ab17 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 13:51:10 +0200 Subject: [PATCH 24/98] Apply suggestions from code review Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- doc/releases/0.101.1.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/releases/0.101.1.rst b/doc/releases/0.101.1.rst index 41f76d4815..d2b0d35020 100644 --- a/doc/releases/0.101.1.rst +++ b/doc/releases/0.101.1.rst @@ -1,6 +1,6 @@ -.. _release0.101.0: +.. _release0.101.1: -SpikeInterface 0.101.0 release notes +SpikeInterface 0.101.1 release notes ------------------------------------ 6th September 2024 From 6dcf0b468f5b9daf790e9f9bb6d9b5c3a4e17f37 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Fri, 6 Sep 2024 16:15:06 +0100 Subject: [PATCH 25/98] Move bulk of calc from init to set_params and run --- .../qualitymetrics/misc_metrics.py | 13 +++++ .../quality_metric_calculator.py | 52 +++++++++++-------- 2 files changed, 42 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 2de31ad750..8dfd41cf88 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -69,6 +69,9 @@ def compute_num_spikes(sorting_analyzer, unit_ids=None, **kwargs): return num_spikes +_default_params["num_spikes"] = {} + + def compute_firing_rates(sorting_analyzer, unit_ids=None): """ Compute the firing rate across segments. @@ -98,6 +101,9 @@ def compute_firing_rates(sorting_analyzer, unit_ids=None): return firing_rates +_default_params["firing_rate"] = {} + + def compute_presence_ratios(sorting_analyzer, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, unit_ids=None): """ Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold. @@ -1550,3 +1556,10 @@ def compute_sd_ratio( sd_ratio[unit_id] = unit_std / std_noise return sd_ratio + + +_default_params["sd_ratio"] = dict( + censored_period_ms=4.0, + correct_for_drift=True, + correct_for_template_itself=True, +) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index c3c95a2f54..b85d3bcfd3 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -2,7 +2,6 @@ from __future__ import annotations -import weakref import warnings from copy import deepcopy @@ -30,8 +29,10 @@ class ComputeQualityMetrics(AnalyzerExtension): qm_params : dict or None Dictionary with parameters for quality metrics calculation. Default parameters can be obtained with: `si.qualitymetrics.get_default_qm_params()` - skip_pc_metrics : bool + skip_pc_metrics : bool, default: False If True, PC metrics computation is skipped. 
+ delete_existing_metrics : bool, default: False + If True, deletes any quality_metrics attached to the `sorting_analyzer` Returns ------- @@ -49,20 +50,16 @@ class ComputeQualityMetrics(AnalyzerExtension): use_nodepipeline = False need_job_kwargs = True - def __init__(self, sorting_analyzer): - - self._sorting_analyzer = weakref.ref(sorting_analyzer) - - qm_class = sorting_analyzer.extensions.get("quality_metrics") - - if qm_class: - self.params = qm_class.params - self.data = {"metrics": qm_class.get_data()} - else: - self.params = {} - self.data = {"metrics": None} + def _set_params( + self, + metric_names=None, + qm_params=None, + peak_sign=None, + seed=None, + skip_pc_metrics=False, + delete_existing_metrics=False, + ): - def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=None, skip_pc_metrics=False): if metric_names is None: metric_names = list(_misc_metric_name_to_func.keys()) # if PC is available, PC metrics are automatically added to the list @@ -84,15 +81,17 @@ def _set_params(self, metric_names=None, qm_params=None, peak_sign=None, seed=No if "peak_sign" in qm_params_[k] and peak_sign is not None: qm_params_[k]["peak_sign"] = peak_sign - metric_names_for_params = metric_names + all_metric_names = metric_names qm_extension = self.sorting_analyzer.get_extension("quality_metrics") - if qm_extension is not None: - existing_metric_names = qm_extension.params.get("metric_names") - if existing_metric_names is not None: - metric_names_for_params.extend(existing_metric_names) + if delete_existing_metrics is False and qm_extension is not None: + existing_params = qm_extension.params + for metric_name in existing_params["metric_names"]: + if metric_name not in metric_names: + all_metric_names.append(metric_name) + qm_params_[metric_name] = existing_params["qm_params"][metric_name] params = dict( - metric_names=[str(name) for name in np.unique(metric_names_for_params)], + metric_names=[str(name) for name in np.unique(all_metric_names)], peak_sign=peak_sign, seed=seed, qm_params=qm_params_, @@ -152,7 +151,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job import pandas as pd - metrics = self.data["metrics"] + metrics = self.data.get("metrics") if metrics is None: metrics = pd.DataFrame(index=sorting_analyzer.sorting.unit_ids) @@ -204,11 +203,18 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job return metrics - def _run(self, verbose=False, **job_kwargs): + def _run(self, verbose=False, delete_existing_metrics=False, **job_kwargs): self.data["metrics"] = self._compute_metrics( sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, **job_kwargs ) + qm_extension = self.sorting_analyzer.get_extension("quality_metrics") + if delete_existing_metrics is False and qm_extension is not None: + existing_metrics = qm_extension.get_data() + for metric_name, metric_data in existing_metrics.items(): + if metric_name not in self.data["metrics"]: + self.data["metrics"][metric_name] = metric_data + def _get_data(self): return self.data["metrics"] From a5ed0b36fae0e33f6b6f310807e90258fed2962a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 6 Sep 2024 17:46:18 +0200 Subject: [PATCH 26/98] Ensure sorting analyzer in zarr are consolidated --- src/spikeinterface/core/sortinganalyzer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 817c453a97..0831391469 100644 --- 
a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -76,6 +76,9 @@ def create_sorting_analyzer(
     return_scaled : bool, default: True
         All extensions that play with traces will use this global return_scaled : "waveforms", "noise_levels", "templates".
         This prevents return_scaled from being different across extensions and, for instance, giving a wrong snr.
+    overwrite : bool, default: False
+        If True, overwrite the folder if it already exists.
+
 
     Returns
     -------
@@ -563,11 +566,13 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at
 
         recording_info = zarr_root.create_group("extensions")
 
+        zarr.consolidate_metadata(zarr_root.store)
+
     @classmethod
     def load_from_zarr(cls, folder, recording=None, storage_options=None):
         import zarr
 
-        zarr_root = zarr.open(str(folder), mode="r", storage_options=storage_options)
+        zarr_root = zarr.open_consolidated(str(folder), mode="r", storage_options=storage_options)
 
         # load internal sorting in memory
         sorting = NumpySorting.from_sorting(
@@ -2002,7 +2007,7 @@ def _save_data(self, **kwargs):
                 except:
                     raise Exception(f"Could not save {ext_data_name} as extension data")
         elif self.format == "zarr":
-
+            import zarr
             import numcodecs
 
             extension_group = self._get_zarr_extension_group(mode="r+")
@@ -2036,6 +2041,8 @@ def _save_data(self, **kwargs):
                     except:
                         raise Exception(f"Could not save {ext_data_name} as extension data")
                     extension_group[ext_data_name].attrs["object"] = True
+            # we need to re-consolidate
+            zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)
 
     def _reset_extension_folder(self):
         """
@@ -2051,7 +2058,7 @@ def _reset_extension_folder(self):
             import zarr
 
             zarr_root = zarr.open(self.folder, mode="r+")
-            extension_group = zarr_root["extensions"].create_group(self.extension_name, overwrite=True)
+            _ = zarr_root["extensions"].create_group(self.extension_name, overwrite=True)
 
     def reset(self):
         """

From 4ad6a884c3cba86a4eb0f2f1979937a25ed4f0bf Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Sat, 7 Sep 2024 12:38:04 +0100
Subject: [PATCH 27/98] Add template_metrics

---
 .../postprocessing/template_metrics.py       | 36 +++++++++++++++++--
 .../postprocessing/tests/conftest.py         | 33 +++++++++++++++++
 .../tests/test_template_metrics.py           | 30 ++++++++++++++++
 .../quality_metric_calculator.py             |  4 +--
 4 files changed, 98 insertions(+), 5 deletions(-)
 create mode 100644 src/spikeinterface/postprocessing/tests/conftest.py

diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py
index e16bd9ad27..57d8fd5839 100644
--- a/src/spikeinterface/postprocessing/template_metrics.py
+++ b/src/spikeinterface/postprocessing/template_metrics.py
@@ -64,6 +64,8 @@ class ComputeTemplateMetrics(AnalyzerExtension):
         For more on generating a ChannelSparsity, see the `~spikeinterface.compute_sparsity()` function.
     include_multi_channel_metrics : bool, default: False
         Whether to compute multi-channel metrics
+    delete_existing_metrics : bool, default: False
+        If True, deletes any template_metrics attached to the `sorting_analyzer`
     metrics_kwargs : dict
         Additional arguments to pass to the metric functions.
Including: * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 @@ -111,8 +113,10 @@ def _set_params( sparsity=None, metrics_kwargs=None, include_multi_channel_metrics=False, + delete_existing_metrics=False, **other_kwargs, ): + import pandas as pd # TODO alessio can you check this : this used to be in the function but now we have ComputeTemplateMetrics.function_factory() if include_multi_channel_metrics or ( @@ -140,9 +144,30 @@ def _set_params( else: metrics_kwargs_ = _default_function_kwargs.copy() metrics_kwargs_.update(metrics_kwargs) + print(metrics_kwargs_) + + all_metric_names = metric_names + tm_extension = self.sorting_analyzer.get_extension("template_metrics") + if delete_existing_metrics is False and tm_extension is not None: + existing_metric_names = tm_extension.params["metric_names"] + existing_params = tm_extension.params["metrics_kwargs"] + + # checks that existing metrics were calculated using the same params + if existing_params != metrics_kwargs_: + warnings.warn( + "The parameters used to calculate the previous template metrics are different than those used now. Deleting previous template metrics..." + ) + self.sorting_analyzer.get_extension("template_metrics").data["metrics"] = pd.DataFrame( + index=self.sorting_analyzer.unit_ids + ) + existing_metric_names = [] + + for metric_name in existing_metric_names: + if metric_name not in metric_names: + all_metric_names.append(metric_name) params = dict( - metric_names=[str(name) for name in np.unique(metric_names)], + metric_names=[str(name) for name in np.unique(all_metric_names)], sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), @@ -283,11 +308,18 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job template_metrics.at[index, metric_name] = value return template_metrics - def _run(self, verbose=False): + def _run(self, delete_existing_metrics=False, verbose=False): self.data["metrics"] = self._compute_metrics( sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose ) + tm_extension = self.sorting_analyzer.get_extension("template_metrics") + if delete_existing_metrics is False and tm_extension is not None: + existing_metrics = tm_extension.get_data() + for metric_name, metric_data in existing_metrics.items(): + if metric_name not in self.data["metrics"]: + self.data["metrics"][metric_name] = metric_data + def _get_data(self): return self.data["metrics"] diff --git a/src/spikeinterface/postprocessing/tests/conftest.py b/src/spikeinterface/postprocessing/tests/conftest.py new file mode 100644 index 0000000000..51ac8aa250 --- /dev/null +++ b/src/spikeinterface/postprocessing/tests/conftest.py @@ -0,0 +1,33 @@ +import pytest + +from spikeinterface.core import ( + generate_ground_truth_recording, + create_sorting_analyzer, +) + + +def _small_sorting_analyzer(): + recording, sorting = generate_ground_truth_recording( + durations=[2.0], + num_units=10, + seed=1205, + ) + + sorting_analyzer = create_sorting_analyzer(recording=recording, sorting=sorting, format="memory") + + extensions_to_compute = { + "random_spikes": {"seed": 1205}, + "noise_levels": {"seed": 1205}, + "waveforms": {}, + "templates": {"operators": ["average", "median"]}, + "spike_amplitudes": {}, + } + + sorting_analyzer.compute(extensions_to_compute) + + return sorting_analyzer + + +@pytest.fixture(scope="module") +def small_sorting_analyzer(): + return _small_sorting_analyzer() diff --git 
a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py
index 694aa083cc..f444e12c36 100644
--- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py
+++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py
@@ -3,6 +3,36 @@
 import pytest
 
 
+def test_compute_new_template_metrics(small_sorting_analyzer):
+    """
+    Computes template metrics then computes a subset of template metrics, and checks
+    that the old template metrics are not deleted.
+
+    Then computes template metrics with new parameters and checks that old metrics
+    are deleted.
+    """
+
+    small_sorting_analyzer.compute("template_metrics")
+    small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}})
+
+    template_metric_extension = small_sorting_analyzer.get_extension("template_metrics")
+
+    # Check old metrics are not deleted and the new one is added to the data and metadata
+    assert "exp_decay" in list(template_metric_extension.get_data().keys())
+    assert "half_width" in list(template_metric_extension.get_data().keys())
+
+    # check that, when parameters are changed, the old metrics are deleted
+    small_sorting_analyzer.compute(
+        {"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}}
+    )
+
+    template_metric_extension = small_sorting_analyzer.get_extension("template_metrics")
+
+    assert "half_width" not in list(template_metric_extension.get_data().keys())
+
+    assert small_sorting_analyzer.get_extension("quality_metrics") is None
+
+
 class TestTemplateMetrics(AnalyzerExtensionCommonTestSuite):
 
     @pytest.mark.parametrize(
diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
index b85d3bcfd3..ebd6439be8 100644
--- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
@@ -151,9 +151,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job
 
         import pandas as pd
 
-        metrics = self.data.get("metrics")
-        if metrics is None:
-            metrics = pd.DataFrame(index=sorting_analyzer.sorting.unit_ids)
+        metrics = pd.DataFrame(index=sorting_analyzer.sorting.unit_ids)
 
         # simple metrics not based on PCs
         for metric_name in metric_names:

From 3360022f5cfd38e9f7ac3b43cbfefa257d0b8695 Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Sat, 7 Sep 2024 13:20:38 +0100
Subject: [PATCH 28/98] Tests now pass

---
 src/spikeinterface/postprocessing/template_metrics.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py
index 57d8fd5839..fef35bfc59 100644
--- a/src/spikeinterface/postprocessing/template_metrics.py
+++ b/src/spikeinterface/postprocessing/template_metrics.py
@@ -144,7 +144,6 @@ def _set_params(
         else:
             metrics_kwargs_ = _default_function_kwargs.copy()
             metrics_kwargs_.update(metrics_kwargs)
-            print(metrics_kwargs_)
 
         all_metric_names = metric_names
         tm_extension = self.sorting_analyzer.get_extension("template_metrics")
         if delete_existing_metrics is False and tm_extension is not None:
             existing_metric_names = tm_extension.params["metric_names"]
             existing_params = tm_extension.params["metrics_kwargs"]
 
             # checks that existing metrics were calculated using the same params
             if existing_params != metrics_kwargs_:
                 warnings.warn(
                     "The parameters used to calculate the previous template metrics are different than those used now. Deleting previous template metrics..."
                 )
                 self.sorting_analyzer.get_extension("template_metrics").data["metrics"] = pd.DataFrame(
                     index=self.sorting_analyzer.unit_ids
                 )
+                self.sorting_analyzer.get_extension("template_metrics").params["metric_names"] = []
                 existing_metric_names = []
 
             for metric_name in existing_metric_names:
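The hunks in this patch implement the merge semantics pinned down by the new test above. Sketched as a hypothetical session against the `_small_sorting_analyzer` fixture from conftest.py (the metric names are those used in the test; `get_data()` returns a pandas DataFrame):

analyzer = _small_sorting_analyzer()
analyzer.compute("template_metrics")  # default metric set, includes "half_width"
analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}})
tm = analyzer.get_extension("template_metrics").get_data()
assert "exp_decay" in tm.columns and "half_width" in tm.columns  # old metrics kept

# changing metrics_kwargs discards previously computed values (with a warning)
analyzer.compute(
    {"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}}
)
assert "half_width" not in analyzer.get_extension("template_metrics").get_data().columns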
@@ -315,9 +315,11 @@ def _run(self, delete_existing_metrics=False, verbose=False): tm_extension = self.sorting_analyzer.get_extension("template_metrics") if delete_existing_metrics is False and tm_extension is not None: - existing_metrics = tm_extension.get_data() - for metric_name, metric_data in existing_metrics.items(): + existing_metrics = tm_extension.params["metric_names"] + + for metric_name in existing_metrics: if metric_name not in self.data["metrics"]: + metric_data = tm_extension.get_data()[metric_name] self.data["metrics"][metric_name] = metric_data def _get_data(self): From 7ddf482056243e4f3198db163757251cd737bda5 Mon Sep 17 00:00:00 2001 From: Florent Pollet Date: Sat, 7 Sep 2024 19:05:27 -0400 Subject: [PATCH 29/98] fix: change naming convention --- src/spikeinterface/widgets/unit_summary.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 75a399fab5..f1f85b7dc3 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -33,11 +33,11 @@ class UnitSummaryWidget(BaseWidget): If SortingAnalyzer is already sparse, the argument is ignored widget_params : dict or None, default: None Parameters for the subwidgets in a nested dictionary - unitlocations_params: UnitLocationsWidget (see UnitLocationsWidget for details) - unitwaveforms_params: UnitWaveformsWidget (see UnitWaveformsWidget for details) - unitwaveformdensitymap_params : UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) - autocorrelograms_params : AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) - amplitudes_params : AmplitudesWidget (see AmplitudesWidget for details) + unit_locations: UnitLocationsWidget (see UnitLocationsWidget for details) + unit_waveforms: UnitWaveformsWidget (see UnitWaveformsWidget for details) + unit_waveform_density_map: UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) + autocorrelograms: AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) + amplitudes: AmplitudesWidget (see AmplitudesWidget for details) """ # possible_backends = {} @@ -82,11 +82,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sparsity = dp.sparsity widget_params = defaultdict(lambda: dict(), dp.widget_params) - unitlocationswidget_params = widget_params["unitlocations_params"] - unitwaveformswidget_params = widget_params["unitwaveforms_params"] - unitwaveformdensitymapwidget_params = widget_params["unitwaveformdensitymap_params"] - autocorrelogramswidget_params = widget_params["autocorrelograms_params"] - amplitudeswidget_params = widget_params["amplitudes_params"] + unitlocationswidget_params = widget_params["unit_locations"] + unitwaveformswidget_params = widget_params["unit_waveforms"] + unitwaveformdensitymapwidget_params = widget_params["unit_waveform_density_map"] + autocorrelogramswidget_params = widget_params["autocorrelograms"] + amplitudeswidget_params = widget_params["amplitudes"] # force the figure without axes if "figsize" not in backend_kwargs: From 3cc1298d8ecb9ce0dfa487a8f93f3b14f9a6ba90 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Mon, 9 Sep 2024 08:20:08 +0100 Subject: [PATCH 30/98] tests definitely 100% pass now --- src/spikeinterface/qualitymetrics/quality_metric_calculator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index ebd6439be8..31353df724 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -108,6 +108,8 @@ def _select_extension_data(self, unit_ids): def _merge_extension_data( self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs ): + import pandas as pd + old_metrics = self.data["metrics"] all_unit_ids = new_sorting_analyzer.unit_ids From c12ceb6a60af3ae4625101e9db6332205b71db0c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 9 Sep 2024 17:43:48 +0200 Subject: [PATCH 31/98] Use open_consolidated when possible --- src/spikeinterface/core/sortinganalyzer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 0831391469..3abf0e9b5e 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -488,7 +488,11 @@ def _get_zarr_root(self, mode="r+"): if is_path_remote(str(self.folder)): mode = "r" - zarr_root = zarr.open(self.folder, mode=mode, storage_options=self.storage_options) + # we open_consolidated only if we are in read mode + if mode in ("r+", "a"): + zarr_root = zarr.open(str(self.folder), mode=mode, storage_options=self.storage_options) + else: + zarr_root = zarr.open_consolidated(self.folder, mode=mode, storage_options=self.storage_options) return zarr_root @classmethod @@ -2057,8 +2061,9 @@ def _reset_extension_folder(self): elif self.format == "zarr": import zarr - zarr_root = zarr.open(self.folder, mode="r+") + zarr_root = self.sorting_analyzer._get_zarr_root(mode="r+") _ = zarr_root["extensions"].create_group(self.extension_name, overwrite=True) + zarr.consolidate_metadata(zarr_root.store) def reset(self): """ @@ -2074,7 +2079,7 @@ def set_params(self, save=True, **params): Set parameters for the extension and make it persistent in json. """ - # this ensure data is also deleted and corresponf to params + # this ensure data is also deleted and corresponds to params # this also ensure the group is created self._reset_extension_folder() From fa84e8c6869a475046e4b11c959a8bbac7c5d106 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 10 Sep 2024 16:52:54 +0200 Subject: [PATCH 32/98] Potentiate 'estimate_sparsity' and refactor from_property() constructor --- src/spikeinterface/core/sparsity.py | 261 +++++++++++++----- .../core/tests/test_sparsity.py | 34 ++- .../sorters/internal/spyking_circus2.py | 2 +- .../tests/test_benchmark_matching.py | 4 +- .../sortingcomponents/clustering/circus.py | 2 +- .../clustering/random_projections.py | 2 +- 6 files changed, 226 insertions(+), 79 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 6904886dcf..c72f89520e 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -13,30 +13,33 @@ _sparsity_doc = """ method : str * "best_channels" : N best channels with the largest amplitude. Use the "num_channels" argument to specify the - number of channels. - * "radius" : radius around the best channel. Use the "radius_um" argument to specify the radius in um + number of channels. + * "radius" : radius around the best channel. Use the "radius_um" argument to specify the radius in um. 
* "snr" : threshold based on template signal-to-noise ratio. Use the "threshold" argument - to specify the SNR threshold (in units of noise levels) + to specify the SNR threshold (in units of noise levels) and the "amplitude_mode" argument + to specify the mode to compute the amplitude of the templates. * "ptp" : threshold based on the peak-to-peak values on every channels. Use the "threshold" argument - to specify the ptp threshold (in units of noise levels) + to specify the ptp threshold (in units of amplitude). * "energy" : threshold based on the expected energy that should be present on the channels, - given their noise levels. Use the "threshold" argument to specify the SNR threshold + given their noise levels. Use the "threshold" argument to specify the energy threshold (in units of noise levels) - * "by_property" : sparsity is given by a property of the recording and sorting(e.g. "group"). - Use the "by_property" argument to specify the property name. + * "by_property" : sparsity is given by a property of the recording and sorting (e.g. "group"). + In this case the sparsity for each unit is given by the channels that have the same property + value as the unit. Use the "by_property" argument to specify the property name. peak_sign : "neg" | "pos" | "both" - Sign of the template to compute best channels + Sign of the template to compute best channels. num_channels : int - Number of channels for "best_channels" method + Number of channels for "best_channels" method. radius_um : float - Radius in um for "radius" method + Radius in um for "radius" method. threshold : float - Threshold in SNR "threshold" method - snr_amplitude_mode : "extremum" | "at_index" | "peak_to_peak" - Mode to compute the amplitude of the templates for the "snr" method + Threshold for "snr", "energy" (in units of noise levels) and "ptp" methods (in units of amplitude). + For the "snr" method, the template amplitude mode is controlled by the "amplitude_mode" argument. + amplitude_mode : "extremum" | "at_index" | "peak_to_peak" + Mode to compute the amplitude of the templates for the "snr" and "best_channels" methods. by_property : object - Property name for "by_property" method + Property name for "by_property" method. """ @@ -279,18 +282,35 @@ def from_dict(cls, dictionary: dict): ## Some convinient function to compute sparsity from several strategy @classmethod - def from_best_channels(cls, templates_or_sorting_analyzer, num_channels, peak_sign="neg"): + def from_best_channels( + cls, templates_or_sorting_analyzer, num_channels, peak_sign="neg", amplitude_mode="extremum" + ): """ Construct sparsity from N best channels with the largest amplitude. Use the "num_channels" argument to specify the number of channels. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + num_channels : int + Number of channels for "best_channels" method. + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute best channels. + amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + Mode to compute the amplitude of the templates. 
+
+        Returns
+        -------
+        sparsity : ChannelSparsity
+            The estimated sparsity.
         """
         from .template_tools import get_template_amplitudes
 
-        print(templates_or_sorting_analyzer)
         mask = np.zeros(
             (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool"
         )
-        peak_values = get_template_amplitudes(templates_or_sorting_analyzer, peak_sign=peak_sign)
+        peak_values = get_template_amplitudes(templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode)
         for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids):
             chan_inds = np.argsort(np.abs(peak_values[unit_id]))[::-1]
             chan_inds = chan_inds[:num_channels]
@@ -301,7 +321,21 @@ def from_best_channels(cls, templates_or_sorting_analyzer, num_channels, peak_si
     def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"):
         """
         Construct sparsity from a radius around the best channel.
-        Use the "radius_um" argument to specify the radius in um
+        Use the "radius_um" argument to specify the radius in um.
+
+        Parameters
+        ----------
+        templates_or_sorting_analyzer : Templates | SortingAnalyzer
+            A Templates or a SortingAnalyzer object.
+        radius_um : float
+            Radius in um for "radius" method.
+        peak_sign : "neg" | "pos" | "both"
+            Sign of the template to compute best channels.
+
+        Returns
+        -------
+        sparsity : ChannelSparsity
+            The estimated sparsity.
         """
         from .template_tools import get_template_extremum_channel
 
@@ -319,11 +353,37 @@ def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"):
 
     @classmethod
     def from_snr(
-        cls, templates_or_sorting_analyzer, threshold, snr_amplitude_mode="extremum", noise_levels=None, peak_sign="neg"
+        cls,
+        templates_or_sorting_analyzer,
+        threshold,
+        amplitude_mode="extremum",
+        peak_sign="neg",
+        noise_levels=None,
     ):
         """
         Construct sparsity from a threshold based on template signal-to-noise ratio.
         Use the "threshold" argument to specify the SNR threshold.
+
+        Parameters
+        ----------
+        templates_or_sorting_analyzer : Templates | SortingAnalyzer
+            A Templates or a SortingAnalyzer object.
+        threshold : float
+            Threshold for "snr" method (in units of noise levels).
+        noise_levels : np.array | None, default: None
+            Noise levels required for the "snr" method. You can use the
+            `get_noise_levels()` function to compute them.
+            If the input is a `SortingAnalyzer`, the noise levels are automatically retrieved
+            if the `noise_levels` extension is present.
+        peak_sign : "neg" | "pos" | "both"
+            Sign of the template to compute amplitudes.
+        amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum"
+            Mode to compute the amplitude of the templates for the "snr" method.
+
+        Returns
+        -------
+        sparsity : ChannelSparsity
+            The estimated sparsity.
@@ -348,7 +408,7 @@ def from_snr(
         mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool")
 
         peak_values = get_template_amplitudes(
-            templates_or_sorting_analyzer, peak_sign=peak_sign, mode=snr_amplitude_mode, return_scaled=return_scaled
+            templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode, return_scaled=return_scaled
         )
 
         for unit_ind, unit_id in enumerate(unit_ids):
@@ -361,6 +421,18 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold):
         """
         Construct sparsity from a threshold based on template peak-to-peak values.
         Use the "threshold" argument to specify the peak-to-peak threshold.
+
+        Parameters
+        ----------
+        templates_or_sorting_analyzer : Templates | SortingAnalyzer
+            A Templates or a SortingAnalyzer object.
+        threshold : float
+            Threshold for "ptp" method (in units of amplitude).
+
+        Returns
+        -------
+        sparsity : ChannelSparsity
+            The estimated sparsity.
         """
 
         assert (
@@ -394,6 +466,19 @@ def from_energy(cls, sorting_analyzer, threshold):
         """
         Construct sparsity from a threshold based on per channel energy ratio.
         Use the "threshold" argument to specify the energy threshold.
+        This method requires the "waveforms" and "noise_levels" extensions to be computed.
+
+        Parameters
+        ----------
+        sorting_analyzer : SortingAnalyzer
+            A SortingAnalyzer object.
+        threshold : float
+            Threshold for "energy" method (in units of noise levels).
+
+        Returns
+        -------
+        sparsity : ChannelSparsity
+            The estimated sparsity.
         """
         assert sorting_analyzer.sparsity is None, "To compute sparsity with energy you need a dense SortingAnalyzer"
 
@@ -419,41 +504,61 @@ def from_energy(cls, sorting_analyzer, threshold):
         return cls(mask, sorting_analyzer.unit_ids, sorting_analyzer.channel_ids)
 
     @classmethod
-    def from_property(cls, sorting_analyzer, by_property):
+    def from_property(cls, sorting, recording, by_property):
         """
         Construct sparsity with a property of the recording and sorting (e.g. "group").
         Use the "by_property" argument to specify the property name.
+
+        Parameters
+        ----------
+        sorting : Sorting
+            A Sorting object.
+        recording : Recording
+            A Recording object.
+        by_property : object
+            Property name for "by_property" method. Both the recording and sorting must have this property set.
+
+        Returns
+        -------
+        sparsity : ChannelSparsity
+            The estimated sparsity.
         """
         # check consistency
-        assert (
-            by_property in sorting_analyzer.recording.get_property_keys()
-        ), f"Property {by_property} is not a recording property"
-        assert (
-            by_property in sorting_analyzer.sorting.get_property_keys()
-        ), f"Property {by_property} is not a sorting property"
+        assert by_property in recording.get_property_keys(), f"Property {by_property} is not a recording property"
+        assert by_property in sorting.get_property_keys(), f"Property {by_property} is not a sorting property"
 
-        mask = np.zeros((sorting_analyzer.unit_ids.size, sorting_analyzer.channel_ids.size), dtype="bool")
-        rec_by = sorting_analyzer.recording.split_by(by_property)
-        for unit_ind, unit_id in enumerate(sorting_analyzer.unit_ids):
-            unit_property = sorting_analyzer.sorting.get_property(by_property)[unit_ind]
+        mask = np.zeros((sorting.unit_ids.size, recording.channel_ids.size), dtype="bool")
+        rec_by = recording.split_by(by_property)
+        for unit_ind, unit_id in enumerate(sorting.unit_ids):
+            unit_property = sorting.get_property(by_property)[unit_ind]
             assert (
                 unit_property in rec_by.keys()
             ), f"Unit property {unit_property} cannot be found in the recording properties"
-            chan_inds = sorting_analyzer.recording.ids_to_indices(rec_by[unit_property].get_channel_ids())
+            chan_inds = recording.ids_to_indices(rec_by[unit_property].get_channel_ids())
             mask[unit_ind, chan_inds] = True
-        return cls(mask, sorting_analyzer.unit_ids, sorting_analyzer.channel_ids)
+        return cls(mask, sorting.unit_ids, recording.channel_ids)
 
     @classmethod
     def create_dense(cls, sorting_analyzer):
        """
        Create a sparsity object with all channels selected for all units.
+
+        Parameters
+        ----------
+        sorting_analyzer : SortingAnalyzer
+            A SortingAnalyzer object.
+
+        Returns
+        -------
+        sparsity : ChannelSparsity
+            The full sparsity.
""" mask = np.ones((sorting_analyzer.unit_ids.size, sorting_analyzer.channel_ids.size), dtype="bool") return cls(mask, sorting_analyzer.unit_ids, sorting_analyzer.channel_ids) def compute_sparsity( - templates_or_sorting_analyzer: Templates | SortingAnalyzer, + templates_or_sorting_analyzer: "Templates | SortingAnalyzer", noise_levels: np.ndarray | None = None, method: "radius" | "best_channels" | "snr" | "ptp" | "energy" | "by_property" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", @@ -461,10 +566,10 @@ def compute_sparsity( radius_um: float | None = 100.0, threshold: float | None = 5, by_property: str | None = None, - snr_amplitude_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + amplitude_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", ) -> ChannelSparsity: """ - Get channel sparsity (subset of channels) for each template with several methods. + Compute channel sparsity from a `SortingAnalyzer` for each template with several methods. Parameters ---------- @@ -512,7 +617,7 @@ def compute_sparsity( threshold, noise_levels=noise_levels, peak_sign=peak_sign, - snr_amplitude_mode=snr_amplitude_mode, + amplitude_mode=amplitude_mode, ) elif method == "ptp": assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" @@ -525,7 +630,9 @@ def compute_sparsity( sparsity = ChannelSparsity.from_energy(templates_or_sorting_analyzer, threshold) elif method == "by_property": assert by_property is not None, "For the 'by_property' method, 'by_property' needs to be given" - sparsity = ChannelSparsity.from_property(templates_or_sorting_analyzer, by_property) + sparsity = ChannelSparsity.from_property( + templates_or_sorting_analyzer.sorting, templates_or_sorting_analyzer.recording, by_property + ) else: raise ValueError(f"compute_sparsity() method={method} does not exists") @@ -541,19 +648,20 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" | "ptp" | "by_property" = "radius", + method: "radius" | "best_channels" | "ptp" | "snr" | "by_property" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, threshold: float | None = 5, - snr_amplitude_mode: "extremum" | "peak_to_peak" = "extremum", + amplitude_mode: "extremum" | "peak_to_peak" = "extremum", by_property: str | None = None, + noise_levels: np.ndarray | list | None = None, **job_kwargs, ): """ - Estimate the sparsity without needing a SortingAnalyzer or Templates object - This is faster than `spikeinterface.waveforms_extractor.precompute_sparsity()` and it - traverses the recording to compute the average templates for each unit. + Estimate the sparsity without needing a SortingAnalyzer or Templates object. + In case the sparsity method needs templates, they are computed on-the-fly. + The same is done for noise levels, if needed by the method ("snr"). Contrary to the previous implementation: * all units are computed in one read of recording @@ -561,6 +669,8 @@ def estimate_sparsity( * it doesn't consume too much memory * it uses internally the `estimate_templates_with_accumulator()` which is fast and parallel + Note that the "energy" method is not supported because it requires a `SortingAnalyzer` object. 
+
     Parameters
     ----------
     sorting : BaseSorting
         The sorting
     recording : BaseRecording
         The binary recording
     num_spikes_for_sparsity : int, default: 100
         How many spikes per neuron to compute templates
     ms_before : float, default: 1.0
         Cut out in ms before spike time
     ms_after : float, default: 2.5
         Cut out in ms after spike time
-    method : "radius" | "best_channels" | "by_property", default: "radius"
-        Sparsity method propagated to the `compute_sparsity()` function.
-        "snr", "ptp", and "energy" are not available here because they require noise levels.
-    peak_sign : "neg" | "pos" | "both", default: "neg"
-        Sign of the template to compute best channels
-    radius_um : float, default: 100.0
-        Used for "radius" method
-    num_channels : int, default: 5
-        Used for "best_channels" method
-    snr_amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum"
-        Used for "snr" method to compute the amplitude of the templates.
-
+    noise_levels : np.array | None, default: None
+        Noise levels required for the "snr" method. You can use the
+        `get_noise_levels()` function to compute them.
     {}

     Returns
@@ -595,9 +696,9 @@ def estimate_sparsity(
     # Can't be done at module because this is a cyclic import, too bad
     from .template import Templates

-    assert method in ("radius", "best_channels", "ptp", "by_property"), (
+    assert method in ("radius", "best_channels", "ptp", "snr", "by_property"), (
         f"method={method} is not available for `estimate_sparsity()`. "
-        "Available methods are 'radius', 'best_channels', 'ptp', 'by_property'"
+        "Available methods are 'radius', 'best_channels', 'ptp', 'snr', 'by_property'"
     )

     if recording.get_probes() == 1:
@@ -644,21 +745,39 @@ def estimate_sparsity(
             unit_ids=sorting.unit_ids,
             probe=probe,
         )
-        templates_or_analyzer = templates
+
+        if method == "best_channels":
+            assert num_channels is not None, "For the 'best_channels' method, 'num_channels' needs to be given"
+            sparsity = ChannelSparsity.from_best_channels(
+                templates, num_channels, peak_sign=peak_sign, amplitude_mode=amplitude_mode
+            )
+        elif method == "radius":
+            assert radius_um is not None, "For the 'radius' method, 'radius_um' needs to be given"
+            sparsity = ChannelSparsity.from_radius(templates, radius_um, peak_sign=peak_sign)
+        elif method == "snr":
+            assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given"
+            assert noise_levels is not None, (
+                "For the 'snr' method, 'noise_levels' needs to be given. You can use the "
+                "`get_noise_levels()` function to compute them."
+            )
+            sparsity = ChannelSparsity.from_snr(
+                templates,
+                threshold,
+                noise_levels=noise_levels,
+                peak_sign=peak_sign,
+                amplitude_mode=amplitude_mode,
+            )
+        elif method == "ptp":
+            assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given"
+            sparsity = ChannelSparsity.from_ptp(
+                templates,
+                threshold,
+            )
+        else:
+            raise ValueError(f"estimate_sparsity() method={method} does not exist")
     else:
-        from .sortinganalyzer import create_sorting_analyzer
-
-        templates_or_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, sparse=False)
-        sparsity = compute_sparsity(
-            templates_or_analyzer,
-            method=method,
-            peak_sign=peak_sign,
-            num_channels=num_channels,
-            radius_um=radius_um,
-            threshold=threshold,
-            by_property=by_property,
-            snr_amplitude_mode=snr_amplitude_mode,
-        )
+        assert by_property is not None, "For the 'by_property' method, 'by_property' needs to be given"
+        sparsity = ChannelSparsity.from_property(sorting, recording, by_property)

     return sparsity

diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py
index 16b3bbc996..64517c106d 100644
--- a/src/spikeinterface/core/tests/test_sparsity.py
+++ b/src/spikeinterface/core/tests/test_sparsity.py
@@ -3,7 +3,7 @@
 import numpy as np
 import json

-from spikeinterface.core import ChannelSparsity, estimate_sparsity, compute_sparsity, Templates
+from spikeinterface.core import ChannelSparsity, estimate_sparsity, compute_sparsity, get_noise_levels
 from spikeinterface.core.core_tools import check_json
 from spikeinterface.core import generate_ground_truth_recording
 from spikeinterface.core import create_sorting_analyzer
@@ -223,6 +223,36 @@ def test_estimate_sparsity():
         n_jobs=1,
     )

+    # snr: fails without noise levels
+    with pytest.raises(AssertionError):
+        sparsity = estimate_sparsity(
+            sorting,
+            recording,
+            num_spikes_for_sparsity=50,
+            ms_before=1.0,
+            ms_after=2.0,
+            method="snr",
+            threshold=5,
+            chunk_duration="1s",
+            progress_bar=True,
+            n_jobs=1,
+        )
+    # snr: works with noise levels
+    noise_levels = get_noise_levels(recording)
+    sparsity = estimate_sparsity(
+        sorting,
+        recording,
+        num_spikes_for_sparsity=50,
+        ms_before=1.0,
+        ms_after=2.0,
+        method="snr",
+        threshold=5,
+        noise_levels=noise_levels,
+        chunk_duration="1s",
+        progress_bar=True,
+        n_jobs=1,
+    )
+

 def test_compute_sparsity():
     recording, sorting = get_dataset()
@@ -241,7 +271,7 @@ def test_compute_sparsity():
     sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0, peak_sign="neg")
     sparsity = compute_sparsity(sorting_analyzer, method="snr", threshold=5, peak_sign="neg")
     sparsity = compute_sparsity(
-        sorting_analyzer, method="snr", threshold=5, peak_sign="neg", snr_amplitude_mode="peak_to_peak"
+        sorting_analyzer, method="snr", threshold=5, peak_sign="neg", amplitude_mode="peak_to_peak"
     )
     sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5)
     sparsity = compute_sparsity(sorting_analyzer, method="energy", threshold=5)
diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index a6d212425d..c3b3099535 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -25,7 +25,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):

     _default_params = {
         "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100},
-        "sparsity": {"method": "snr", "snr_amplitude_mode": "peak_to_peak", "threshold": 0.25},
+        "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
         "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2},
         "detection": {"peak_sign": "neg", "detect_threshold": 4},
         "selection": {
diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py
index d6d0440a02..71a5f282a8 100644
--- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py
+++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py
@@ -27,9 +27,7 @@ def test_benchmark_matching(create_cache_folder):
         recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_scaled=False, **job_kwargs
     )
     noise_levels = get_noise_levels(recording)
-    sparsity = compute_sparsity(
-        gt_templates, noise_levels, method="snr", snr_amplitude_mode="peak_to_peak", threshold=0.25
-    )
+    sparsity = compute_sparsity(gt_templates, noise_levels, method="snr", amplitude_mode="peak_to_peak", threshold=0.25)
     gt_templates = gt_templates.to_sparse(sparsity)

     # create study
diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index 11a628bb53..b08ee4d9cb 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -50,7 +50,7 @@ class CircusClustering:
         },
         "cleaning_kwargs": {},
         "waveforms": {"ms_before": 2, "ms_after": 2},
-        "sparsity": {"method": "snr", "snr_amplitude_mode": "peak_to_peak", "threshold": 0.25},
+        "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
         "recursive_kwargs": {
             "recursive": True,
             "recursive_depth": 3,
diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
index b56fd3e02b..f7ca999d53 100644
--- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py
+++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
@@ -45,7 +45,7 @@ class RandomProjectionClustering:
         },
         "cleaning_kwargs": {},
         "waveforms": {"ms_before": 2, "ms_after": 2},
-        "sparsity": {"method": "snr", "snr_amplitude_mode": "peak_to_peak", "threshold": 0.25},
+        "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
         "radius_um": 30,
         "nb_projections": 10,
         "feature": "energy",

From 15a4a11bad4e08bdf4d8ce5de67e05dd2a2a8fab Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 10 Sep 2024 18:14:34 +0200
Subject: [PATCH 33/98] minor docstring fix

---
 src/spikeinterface/core/sparsity.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py
index c72f89520e..471302d57e 100644
--- a/src/spikeinterface/core/sparsity.py
+++ b/src/spikeinterface/core/sparsity.py
@@ -661,7 +661,8 @@ def estimate_sparsity(
     """
     Estimate the sparsity without needing a SortingAnalyzer or Templates object.
     In case the sparsity method needs templates, they are computed on-the-fly.
-    The same is done for noise levels, if needed by the method ("snr").
+    For the "snr" method, `noise_levels` must be passed with the `noise_levels` argument.
+    These can be computed with the `get_noise_levels()` function.
Contrary to the previous implementation: * all units are computed in one read of recording From 689c633a697964a9e18673aae5f2c3528382e716 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 10 Sep 2024 18:06:06 +0100 Subject: [PATCH 34/98] Update template metrics based on Joe feedback --- .../postprocessing/template_metrics.py | 61 +++++++++++-------- .../tests/test_template_metrics.py | 53 +++++++++++++++- 2 files changed, 88 insertions(+), 26 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index fef35bfc59..15b8c85e38 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -65,7 +65,7 @@ class ComputeTemplateMetrics(AnalyzerExtension): include_multi_channel_metrics : bool, default: False Whether to compute multi-channel metrics delete_existing_metrics : bool, default: False - If True, deletes any quality_metrics attached to the `sorting_analyzer` + If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metrics_kwargs` are unchanged. metrics_kwargs : dict Additional arguments to pass to the metric functions. Including: * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 @@ -116,6 +116,7 @@ def _set_params( delete_existing_metrics=False, **other_kwargs, ): + import pandas as pd # TODO alessio can you check this : this used to be in the function but now we have ComputeTemplateMetrics.function_factory() @@ -135,6 +136,10 @@ def _set_params( if include_multi_channel_metrics: metric_names += get_multi_channel_template_metric_names() + # `run` cannot take parameters, so need to find another way to pass this + self.delete_existing_metrics = delete_existing_metrics + self.metric_names = metric_names + if metrics_kwargs is None: metrics_kwargs_ = _default_function_kwargs.copy() if len(other_kwargs) > 0: @@ -145,29 +150,24 @@ def _set_params( metrics_kwargs_ = _default_function_kwargs.copy() metrics_kwargs_.update(metrics_kwargs) - all_metric_names = metric_names tm_extension = self.sorting_analyzer.get_extension("template_metrics") if delete_existing_metrics is False and tm_extension is not None: - existing_metric_names = tm_extension.params["metric_names"] - existing_params = tm_extension.params["metrics_kwargs"] + existing_params = tm_extension.params["metrics_kwargs"] # checks that existing metrics were calculated using the same params if existing_params != metrics_kwargs_: warnings.warn( "The parameters used to calculate the previous template metrics are different than those used now. Deleting previous template metrics..." 
) - self.sorting_analyzer.get_extension("template_metrics").data["metrics"] = pd.DataFrame( - index=self.sorting_analyzer.unit_ids - ) - self.sorting_analyzer.get_extension("template_metrics").params["metric_names"] = [] + tm_extension.params["metric_names"] = [] existing_metric_names = [] + else: + existing_metric_names = tm_extension.params["metric_names"] - for metric_name in existing_metric_names: - if metric_name not in metric_names: - all_metric_names.append(metric_name) + metric_names = list(set(existing_metric_names + metric_names)) params = dict( - metric_names=[str(name) for name in np.unique(all_metric_names)], + metric_names=metric_names, sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), @@ -185,6 +185,7 @@ def _merge_extension_data( ): import pandas as pd + metric_names = self.params["metric_names"] old_metrics = self.data["metrics"] all_unit_ids = new_sorting_analyzer.unit_ids @@ -193,19 +194,20 @@ def _merge_extension_data( metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] - metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs) + metrics.loc[new_unit_ids, :] = self._compute_metrics( + new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs + ) new_data = dict(metrics=metrics) return new_data - def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job_kwargs): + def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs): """ Compute template metrics. """ import pandas as pd from scipy.signal import resample_poly - metric_names = self.params["metric_names"] sparsity = self.params["sparsity"] peak_sign = self.params["peak_sign"] upsampling_factor = self.params["upsampling_factor"] @@ -308,19 +310,30 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job template_metrics.at[index, metric_name] = value return template_metrics - def _run(self, delete_existing_metrics=False, verbose=False): - self.data["metrics"] = self._compute_metrics( - sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose - ) + def _run(self, verbose=False): + + delete_existing_metrics = self.delete_existing_metrics + metric_names = self.metric_names + existing_metrics = [] tm_extension = self.sorting_analyzer.get_extension("template_metrics") - if delete_existing_metrics is False and tm_extension is not None: + if ( + delete_existing_metrics is False + and tm_extension is not None + and tm_extension.data.get("metrics") is not None + ): existing_metrics = tm_extension.params["metric_names"] - for metric_name in existing_metrics: - if metric_name not in self.data["metrics"]: - metric_data = tm_extension.get_data()[metric_name] - self.data["metrics"][metric_name] = metric_data + # compute the metrics which have been specified by the user + computed_metrics = self._compute_metrics( + sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metric_names + ) + + # append the metrics which were previously computed + for metric_name in set(existing_metrics).difference(metric_names): + computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name] + + self.data["metrics"] = computed_metrics def _get_data(self): return self.data["metrics"] diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py 
b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index f444e12c36..1fa2ac638c 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -1,18 +1,22 @@ from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite from spikeinterface.postprocessing import ComputeTemplateMetrics import pytest +import csv def test_compute_new_template_metrics(small_sorting_analyzer): """ - Computes template metrics then computes a subset of quality metrics, and checks - that the old quality metrics are not deleted. + Computes template metrics then computes a subset of template metrics, and checks + that the old template metrics are not deleted. Then computes template metrics with new parameters and checks that old metrics are deleted. """ + # calculate all template metrics small_sorting_analyzer.compute("template_metrics") + + # calculate just exp_decay - this should not delete the previously calculated metrics small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}}) template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") @@ -33,6 +37,51 @@ def test_compute_new_template_metrics(small_sorting_analyzer): assert small_sorting_analyzer.get_extension("quality_metrics") is None +def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): + """ + Computes template metrics in binary folder format. Then computes subsets of template + metrics and checks if they are saved correctly. + """ + + from spikeinterface.postprocessing.template_metrics import _single_channel_metric_name_to_func + + small_sorting_analyzer.compute("template_metrics") + + cache_folder = create_cache_folder + output_folder = cache_folder / "sorting_analyzer" + + folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder) + template_metrics_filename = output_folder / "extensions" / "template_metrics" / "metrics.csv" + + with open(template_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in list(_single_channel_metric_name_to_func.keys()): + assert metric_name in metric_names + + folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=False) + + with open(template_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in list(_single_channel_metric_name_to_func.keys()): + assert metric_name in metric_names + + folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=True) + + with open(template_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in list(_single_channel_metric_name_to_func.keys()): + if metric_name == "half_width": + assert metric_name in metric_names + else: + assert metric_name not in metric_names + + class TestTemplateMetrics(AnalyzerExtensionCommonTestSuite): @pytest.mark.parametrize( From 6cbe4dbe9c805aac3fb3ed29a0f50f4623eeab9c Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 10 Sep 2024 18:25:30 +0100 Subject: [PATCH 35/98] Improve tests for template metrics --- .../tests/test_template_metrics.py | 44 ++++++++++++------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git 
a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py
index 1fa2ac638c..8aaad8ffbc 100644
--- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py
+++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py
@@ -3,6 +3,10 @@
 import pytest
 import csv

+from spikeinterface.postprocessing.template_metrics import _single_channel_metric_name_to_func
+
+template_metrics = list(_single_channel_metric_name_to_func.keys())
+

 def test_compute_new_template_metrics(small_sorting_analyzer):
     """
@@ -13,28 +17,38 @@ def test_compute_new_template_metrics(small_sorting_analyzer):
         are deleted.
     """

+    # calculate just exp_decay
+    small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}})
+    template_metric_extension = small_sorting_analyzer.get_extension("template_metrics")
+
+    assert "exp_decay" in list(template_metric_extension.get_data().keys())
+    assert "half_width" not in list(template_metric_extension.get_data().keys())
+
     # calculate all template metrics
     small_sorting_analyzer.compute("template_metrics")
-
-    # calculate just exp_decay - this should not delete the previously calculated metrics
+    # calculate just exp_decay - this should not delete any other metrics
     small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}})
-
     template_metric_extension = small_sorting_analyzer.get_extension("template_metrics")

-    # Check old metrics are not deleted and the new one is added to the data and metadata
-    assert "exp_decay" in list(template_metric_extension.get_data().keys())
-    assert "half_width" in list(template_metric_extension.get_data().keys())
+    assert set(template_metrics) == set(template_metric_extension.get_data().keys())

-    # check that, when parameters are changed, the old metrics are deleted
+    # calculate just exp_decay with delete_existing_metrics
     small_sorting_analyzer.compute(
-        {"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}}
+        {"template_metrics": {"metric_names": ["exp_decay"], "delete_existing_metrics": True}}
     )
-
     template_metric_extension = small_sorting_analyzer.get_extension("template_metrics")
+    computed_metric_names = template_metric_extension.get_data().keys()

-    assert "half_width" not in list(template_metric_extension.get_data().keys())
+    for metric_name in template_metrics:
+        if metric_name == "exp_decay":
+            assert metric_name in computed_metric_names
+        else:
+            assert metric_name not in computed_metric_names

-    assert small_sorting_analyzer.get_extension("quality_metrics") is None
+    # check that, when parameters are changed, the old metrics are deleted
+    small_sorting_analyzer.compute(
+        {"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}}
+    )


 def test_save_template_metrics(small_sorting_analyzer, create_cache_folder):
     """
     Computes template metrics in binary folder format. Then computes subsets of template
     metrics and checks if they are saved correctly.
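A sketch of the recompute semantics this test pins down, assuming a `SortingAnalyzer` named `analyzer` with all template metrics already computed.

# recompute one metric, keeping everything computed before (the default)
analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=False)
# all previously saved metric columns survive alongside "half_width"

# recompute one metric and discard the rest
analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=True)
# only "half_width" remains in the saved metrics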
""" - from spikeinterface.postprocessing.template_metrics import _single_channel_metric_name_to_func - small_sorting_analyzer.compute("template_metrics") cache_folder = create_cache_folder @@ -57,7 +69,7 @@ def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): saved_metrics = csv.reader(metrics_file) metric_names = next(saved_metrics) - for metric_name in list(_single_channel_metric_name_to_func.keys()): + for metric_name in template_metrics: assert metric_name in metric_names folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=False) @@ -66,7 +78,7 @@ def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): saved_metrics = csv.reader(metrics_file) metric_names = next(saved_metrics) - for metric_name in list(_single_channel_metric_name_to_func.keys()): + for metric_name in template_metrics: assert metric_name in metric_names folder_analyzer.compute("template_metrics", metric_names=["half_width"], delete_existing_metrics=True) @@ -75,7 +87,7 @@ def test_save_template_metrics(small_sorting_analyzer, create_cache_folder): saved_metrics = csv.reader(metrics_file) metric_names = next(saved_metrics) - for metric_name in list(_single_channel_metric_name_to_func.keys()): + for metric_name in template_metrics: if metric_name == "half_width": assert metric_name in metric_names else: From 15c51f96ef593b73b35eb1691168b3464bb9128e Mon Sep 17 00:00:00 2001 From: Florent Pollet Date: Tue, 10 Sep 2024 20:25:41 -0400 Subject: [PATCH 36/98] feat: renaming to kwargs + clearer error msg --- src/spikeinterface/widgets/unit_summary.py | 46 ++++++++++++---------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index f1f85b7dc3..fd7dafd5d6 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -18,26 +18,27 @@ class UnitSummaryWidget(BaseWidget): """ Plot a unit summary. - If amplitudes are alreday computed they are displayed. + If amplitudes are alreday computed, they are displayed. Parameters ---------- - sorting_analyzer : SortingAnalyzer + sorting_analyzer: SortingAnalyzer The SortingAnalyzer object - unit_id : int or str + unit_id: int or str The unit id to plot the summary of - unit_colors : dict or None, default: None + unit_colors: dict or None, default: None If given, a dictionary with unit ids as keys and colors as values, - sparsity : ChannelSparsity or None, default: None + sparsity: ChannelSparsity or None, default: None Optional ChannelSparsity to apply. If SortingAnalyzer is already sparse, the argument is ignored - widget_params : dict or None, default: None + subwidget_kwargs: dict or None, default: None Parameters for the subwidgets in a nested dictionary unit_locations: UnitLocationsWidget (see UnitLocationsWidget for details) unit_waveforms: UnitWaveformsWidget (see UnitWaveformsWidget for details) unit_waveform_density_map: UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) autocorrelograms: AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) amplitudes: AmplitudesWidget (see AmplitudesWidget for details) + Please note that the unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary. 
""" # possible_backends = {} @@ -48,7 +49,7 @@ def __init__( unit_id, unit_colors=None, sparsity=None, - widget_params=None, + subwidget_kwargs=None, backend=None, **backend_kwargs, ): @@ -57,15 +58,18 @@ def __init__( if unit_colors is None: unit_colors = get_unit_colors(sorting_analyzer) - if widget_params is None: - widget_params = dict() + if subwidget_kwargs is None: + subwidget_kwargs = dict() + for kwargs in subwidget_kwargs.values(): + if "unit_colors" in kwargs: + raise ValueError("unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary") plot_data = dict( sorting_analyzer=sorting_analyzer, unit_id=unit_id, unit_colors=unit_colors, sparsity=sparsity, - widget_params=widget_params, + subwidget_kwargs=subwidget_kwargs, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -81,12 +85,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_colors = dp.unit_colors sparsity = dp.sparsity - widget_params = defaultdict(lambda: dict(), dp.widget_params) - unitlocationswidget_params = widget_params["unit_locations"] - unitwaveformswidget_params = widget_params["unit_waveforms"] - unitwaveformdensitymapwidget_params = widget_params["unit_waveform_density_map"] - autocorrelogramswidget_params = widget_params["autocorrelograms"] - amplitudeswidget_params = widget_params["amplitudes"] + subwidget_kwargs = defaultdict(lambda: dict(), dp.subwidget_kwargs) + unitlocationswidget_kwargs = subwidget_kwargs["unit_locations"] + unitwaveformswidget_kwargs = subwidget_kwargs["unit_waveforms"] + unitwaveformdensitymapwidget_kwargs = subwidget_kwargs["unit_waveform_density_map"] + autocorrelogramswidget_kwargs = subwidget_kwargs["autocorrelograms"] + amplitudeswidget_kwargs = subwidget_kwargs["amplitudes"] # force the figure without axes if "figsize" not in backend_kwargs: @@ -117,7 +121,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_legend=False, backend="matplotlib", ax=ax1, - **unitlocationswidget_params, + **unitlocationswidget_kwargs, ) unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") @@ -140,7 +144,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sparsity=sparsity, backend="matplotlib", ax=ax2, - **unitwaveformswidget_params, + **unitwaveformswidget_kwargs, ) ax2.set_title(None) @@ -154,7 +158,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): same_axis=False, backend="matplotlib", ax=ax3, - **unitwaveformdensitymapwidget_params, + **unitwaveformdensitymapwidget_kwargs, ) ax3.set_ylabel(None) @@ -166,7 +170,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_colors=unit_colors, backend="matplotlib", ax=ax4, - **autocorrelogramswidget_params, + **autocorrelogramswidget_kwargs, ) ax4.set_title(None) @@ -184,7 +188,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_histograms=True, backend="matplotlib", axes=axes, - **amplitudeswidget_params, + **amplitudeswidget_kwargs, ) fig.suptitle(f"unit_id: {dp.unit_id}") From 3a48d66a52af617fdbbac4bcd94595727dc1cffe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Sep 2024 00:26:04 +0000 Subject: [PATCH 37/98] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/unit_summary.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/unit_summary.py 
b/src/spikeinterface/widgets/unit_summary.py index fd7dafd5d6..652c0f841f 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -62,7 +62,9 @@ def __init__( subwidget_kwargs = dict() for kwargs in subwidget_kwargs.values(): if "unit_colors" in kwargs: - raise ValueError("unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary") + raise ValueError( + "unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary" + ) plot_data = dict( sorting_analyzer=sorting_analyzer, From 94abde04386441e3d2994635d4ee545a5b414175 Mon Sep 17 00:00:00 2001 From: Florent Pollet Date: Tue, 10 Sep 2024 20:27:18 -0400 Subject: [PATCH 38/98] feat: default dict comment --- src/spikeinterface/widgets/unit_summary.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index fd7dafd5d6..6c39c902eb 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -85,6 +85,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_colors = dp.unit_colors sparsity = dp.sparsity + # defaultdict returns empty dict if key not found in subwidget_kwargs subwidget_kwargs = defaultdict(lambda: dict(), dp.subwidget_kwargs) unitlocationswidget_kwargs = subwidget_kwargs["unit_locations"] unitwaveformswidget_kwargs = subwidget_kwargs["unit_waveforms"] From 8b5de9d65abb3152c6ad30ecb9599b27f17a903e Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 11 Sep 2024 09:52:33 +0200 Subject: [PATCH 39/98] propagate storage option --- src/spikeinterface/core/sortinganalyzer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 3abf0e9b5e..519d741bc1 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -622,6 +622,7 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None): format="zarr", sparsity=sparsity, return_scaled=return_scaled, + storage_options=storage_options ) return sorting_analyzer From c56d625d744e7474658d0acca1284f6eacbbe200 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Sep 2024 07:54:25 +0000 Subject: [PATCH 40/98] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 519d741bc1..daa693d667 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -622,7 +622,7 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None): format="zarr", sparsity=sparsity, return_scaled=return_scaled, - storage_options=storage_options + storage_options=storage_options, ) return sorting_analyzer From 4589c6efae21d540345ad1ba858e53828d441e48 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 11 Sep 2024 09:28:00 +0100 Subject: [PATCH 41/98] Add ordering and propogate through params --- .../postprocessing/template_metrics.py | 18 +++++++++++------- .../tests/test_template_metrics.py | 11 +++++++++++ 2 files changed, 22 insertions(+), 7 deletions(-) diff --git 
a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py
index 15b8c85e38..062b0bd76b 100644
--- a/src/spikeinterface/postprocessing/template_metrics.py
+++ b/src/spikeinterface/postprocessing/template_metrics.py
@@ -137,8 +137,7 @@ def _set_params(
         metric_names += get_multi_channel_template_metric_names()

     # `run` cannot take parameters, so need to find another way to pass this
-    self.delete_existing_metrics = delete_existing_metrics
-    self.metric_names = metric_names
+    metric_names_to_compute = metric_names

     if metrics_kwargs is None:
         metrics_kwargs_ = _default_function_kwargs.copy()
@@ -164,7 +163,10 @@ def _set_params(
         else:
             existing_metric_names = tm_extension.params["metric_names"]

-        metric_names = list(set(existing_metric_names + metric_names))
+        existing_metric_names_propogated = [
+            metric_name for metric_name in existing_metric_names if metric_name not in metric_names_to_compute
+        ]
+        metric_names = metric_names_to_compute + existing_metric_names_propogated

     params = dict(
         metric_names=metric_names,
@@ -172,6 +174,8 @@ def _set_params(
         peak_sign=peak_sign,
         upsampling_factor=int(upsampling_factor),
         metrics_kwargs=metrics_kwargs_,
+        delete_existing_metrics=delete_existing_metrics,
+        metric_names_to_compute=metric_names_to_compute,
     )

     return params
@@ -312,8 +316,8 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri

     def _run(self, verbose=False):

-        delete_existing_metrics = self.delete_existing_metrics
-        metric_names = self.metric_names
+        delete_existing_metrics = self.params["delete_existing_metrics"]
+        metric_names_to_compute = self.params["metric_names_to_compute"]

         existing_metrics = []
         tm_extension = self.sorting_analyzer.get_extension("template_metrics")
@@ -326,11 +330,11 @@ def _run(self, verbose=False):
             existing_metrics = tm_extension.params["metric_names"]

         # compute the metrics which have been specified by the user
         computed_metrics = self._compute_metrics(
-            sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metric_names
+            sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metric_names_to_compute
         )

         # append the metrics which were previously computed
-        for metric_name in set(existing_metrics).difference(metric_names):
+        for metric_name in set(existing_metrics).difference(metric_names_to_compute):
             computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name]

         self.data["metrics"] = computed_metrics
diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py
index 8aaad8ffbc..5056d4ff2a 100644
--- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py
+++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py
@@ -51,6 +51,17 @@ def test_compute_new_template_metrics(small_sorting_analyzer):
     )


+def test_metric_names_in_same_order(small_sorting_analyzer):
+    """
+    Computes specified template metrics and checks that their order is propagated.
+    """
+    specified_metric_names = ["peak_trough_ratio", "num_negative_peaks", "half_width"]
+    small_sorting_analyzer.compute("template_metrics", metric_names=specified_metric_names)
+    tm_keys = small_sorting_analyzer.get_extension("template_metrics").get_data().keys()
+    for i in range(3):
+        assert specified_metric_names[i] == tm_keys[i]
+
+
 def test_save_template_metrics(small_sorting_analyzer, create_cache_folder):
     """
     Computes template metrics in binary folder format.
Then computes subsets of template From accf40a30660900dabeadd5f8b0d17190ed6d3a4 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 11 Sep 2024 10:19:31 +0100 Subject: [PATCH 42/98] Update quality metrics --- .../postprocessing/template_metrics.py | 24 ++-- .../quality_metric_calculator.py | 64 ++++++--- .../qualitymetrics/quality_metric_list.py | 26 ++++ .../tests/test_metrics_functions.py | 122 ++++++++++++++++-- 4 files changed, 192 insertions(+), 44 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 062b0bd76b..5f4c1e904b 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -136,9 +136,6 @@ def _set_params( if include_multi_channel_metrics: metric_names += get_multi_channel_template_metric_names() - # `run` cannot take parameters, so need to find another way to pass this - metric_names_to_compute = metric_names - if metrics_kwargs is None: metrics_kwargs_ = _default_function_kwargs.copy() if len(other_kwargs) > 0: @@ -149,6 +146,7 @@ def _set_params( metrics_kwargs_ = _default_function_kwargs.copy() metrics_kwargs_.update(metrics_kwargs) + metrics_to_compute = metric_names tm_extension = self.sorting_analyzer.get_extension("template_metrics") if delete_existing_metrics is False and tm_extension is not None: @@ -164,9 +162,9 @@ def _set_params( existing_metric_names = tm_extension.params["metric_names"] existing_metric_names_propogated = [ - metric_name for metric_name in existing_metric_names if metric_name not in metric_names_to_compute + metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute ] - metric_names = metric_names_to_compute + existing_metric_names_propogated + metric_names = metrics_to_compute + existing_metric_names_propogated params = dict( metric_names=metric_names, @@ -175,7 +173,7 @@ def _set_params( upsampling_factor=int(upsampling_factor), metrics_kwargs=metrics_kwargs_, delete_existing_metrics=delete_existing_metrics, - metric_names_to_compute=metric_names_to_compute, + metrics_to_compute=metrics_to_compute, ) return params @@ -317,7 +315,12 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri def _run(self, verbose=False): delete_existing_metrics = self.params["delete_existing_metrics"] - metric_names_to_compute = self.params["metric_names_to_compute"] + metrics_to_compute = self.params["metrics_to_compute"] + + # compute the metrics which have been specified by the user + computed_metrics = self._compute_metrics( + sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metrics_to_compute + ) existing_metrics = [] tm_extension = self.sorting_analyzer.get_extension("template_metrics") @@ -328,13 +331,8 @@ def _run(self, verbose=False): ): existing_metrics = tm_extension.params["metric_names"] - # compute the metrics which have been specified by the user - computed_metrics = self._compute_metrics( - sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, metric_names=metric_names_to_compute - ) - # append the metrics which were previously computed - for metric_name in set(existing_metrics).difference(metric_names_to_compute): + for metric_name in set(existing_metrics).difference(metrics_to_compute): computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name] self.data["metrics"] = computed_metrics diff --git 
a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 31353df724..1c7483212a 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -11,7 +11,12 @@ from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension -from .quality_metric_list import compute_pc_metrics, _misc_metric_name_to_func, _possible_pc_metric_names +from .quality_metric_list import ( + compute_pc_metrics, + _misc_metric_name_to_func, + _possible_pc_metric_names, + compute_name_to_column_names, +) from .misc_metrics import _default_params as misc_metrics_params from .pca_metrics import _default_params as pca_metrics_params @@ -32,7 +37,7 @@ class ComputeQualityMetrics(AnalyzerExtension): skip_pc_metrics : bool, default: False If True, PC metrics computation is skipped. delete_existing_metrics : bool, default: False - If True, deletes any quality_metrics attached to the `sorting_analyzer` + If True, any quality metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept. Returns ------- @@ -81,21 +86,24 @@ def _set_params( if "peak_sign" in qm_params_[k] and peak_sign is not None: qm_params_[k]["peak_sign"] = peak_sign - all_metric_names = metric_names + metrics_to_compute = metric_names qm_extension = self.sorting_analyzer.get_extension("quality_metrics") if delete_existing_metrics is False and qm_extension is not None: - existing_params = qm_extension.params - for metric_name in existing_params["metric_names"]: - if metric_name not in metric_names: - all_metric_names.append(metric_name) - qm_params_[metric_name] = existing_params["qm_params"][metric_name] + + existing_metric_names = qm_extension.params["metric_names"] + existing_metric_names_propogated = [ + metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute + ] + metric_names = metrics_to_compute + existing_metric_names_propogated params = dict( - metric_names=[str(name) for name in np.unique(all_metric_names)], + metric_names=metric_names, peak_sign=peak_sign, seed=seed, qm_params=qm_params_, skip_pc_metrics=skip_pc_metrics, + delete_existing_metrics=delete_existing_metrics, + metrics_to_compute=metrics_to_compute, ) return params @@ -123,11 +131,11 @@ def _merge_extension_data( new_data = dict(metrics=metrics) return new_data - def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job_kwargs): + def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs): """ Compute quality metrics. 
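The propagation contract mirrors the template-metrics one; sketched here from the user's side, assuming a `SortingAnalyzer` named `analyzer` and using the metric names from the tests below.

analyzer.compute("quality_metrics", metric_names=["snr"])
# compute more metrics later; "snr" is kept since delete_existing_metrics defaults to False
analyzer.compute("quality_metrics", metric_names=["firing_range"], delete_existing_metrics=False)
# or start over, keeping only a fresh recomputation
analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=True)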
""" - metric_names = self.params["metric_names"] + qm_params = self.params["qm_params"] # sparsity = self.params["sparsity"] seed = self.params["seed"] @@ -203,17 +211,35 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job return metrics - def _run(self, verbose=False, delete_existing_metrics=False, **job_kwargs): - self.data["metrics"] = self._compute_metrics( - sorting_analyzer=self.sorting_analyzer, unit_ids=None, verbose=verbose, **job_kwargs + def _run(self, verbose=False, **job_kwargs): + + metrics_to_compute = self.params["metrics_to_compute"] + delete_existing_metrics = self.params["delete_existing_metrics"] + + computed_metrics = self._compute_metrics( + sorting_analyzer=self.sorting_analyzer, + unit_ids=None, + verbose=verbose, + metric_names=metrics_to_compute, + **job_kwargs, ) + existing_metrics = [] qm_extension = self.sorting_analyzer.get_extension("quality_metrics") - if delete_existing_metrics is False and qm_extension is not None: - existing_metrics = qm_extension.get_data() - for metric_name, metric_data in existing_metrics.items(): - if metric_name not in self.data["metrics"]: - self.data["metrics"][metric_name] = metric_data + if ( + delete_existing_metrics is False + and qm_extension is not None + and qm_extension.data.get("metrics") is not None + ): + existing_metrics = qm_extension.params["metric_names"] + + # append the metrics which were previously computed + for metric_name in set(existing_metrics).difference(metrics_to_compute): + # some metrics names produce data columns with other names. This deals with that. + for column_name in compute_name_to_column_names[metric_name]: + computed_metrics[column_name] = qm_extension.data["metrics"][column_name] + + self.data["metrics"] = computed_metrics def _get_data(self): return self.data["metrics"] diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index 140ad87a8b..375dd320ae 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -53,3 +53,29 @@ "drift": compute_drift_metrics, "sd_ratio": compute_sd_ratio, } + +# a dict converting the name of the metric for computation to the output of that computation +compute_name_to_column_names = { + "num_spikes": ["num_spikes"], + "firing_rate": ["firing_rate"], + "presence_ratio": ["presence_ratio"], + "snr": ["snr"], + "isi_violation": ["isi_violations_ratio", "isi_violations_count"], + "rp_violation": ["rp_violations", "rp_contamination"], + "sliding_rp_violation": ["sliding_rp_violation"], + "amplitude_cutoff": ["amplitude_cutoff"], + "amplitude_median": ["amplitude_median"], + "amplitude_cv": ["amplitude_cv_median", "amplitude_cv_range"], + "synchrony": ["sync_spike_2", "sync_spike_4", "sync_spike_8"], + "firing_range": ["firing_range"], + "drift": ["drift_ptp", "drift_std", "drift_mad"], + "sd_ratio": ["sd_ratio"], + "isolation_distance": ["isolation_distance"], + "l_ratio": ["l_ratio"], + "d_prime": ["d_prime"], + "nearest_neighbor": ["nn_hit_rate", "nn_miss_rate"], + "nn_isolation": ["nn_isolation", "nn_unit_id"], + "nn_noise_overlap": ["nn_noise_overlap"], + "silhouette": ["silhouette"], + "silhouette_full": ["silhouette_full"], +} diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index e34c15c936..77909798a3 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ 
b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 import numpy as np
 from copy import deepcopy
+import csv
 from spikeinterface.core import (
     NumpySorting,
     synthetize_spike_train_bad_isi,
@@ -42,6 +43,7 @@
     compute_quality_metrics,
 )

+
 from spikeinterface.core.basesorting import minimum_spike_dtype


@@ -60,6 +62,12 @@ def test_compute_new_quality_metrics(small_sorting_analyzer):
         "firing_range": {"bin_size_s": 1},
     }

+    small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}})
+    qm_extension = small_sorting_analyzer.get_extension("quality_metrics")
+    calculated_metrics = list(qm_extension.get_data().keys())
+
+    assert calculated_metrics == ["snr"]
+
     small_sorting_analyzer.compute(
         {"quality_metrics": {"metric_names": list(qm_params.keys()), "qm_params": qm_params}}
     )
@@ -68,18 +76,22 @@ def test_compute_new_quality_metrics(small_sorting_analyzer):
     quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics")

     # Check old metrics are not deleted and the new one is added to the data and metadata
-    assert list(quality_metric_extension.get_data().keys()) == [
-        "amplitude_cutoff",
-        "firing_range",
-        "presence_ratio",
-        "snr",
-    ]
-    assert list(quality_metric_extension.params.get("metric_names")) == [
-        "amplitude_cutoff",
-        "firing_range",
-        "presence_ratio",
-        "snr",
-    ]
+    assert set(list(quality_metric_extension.get_data().keys())) == set(
+        [
+            "amplitude_cutoff",
+            "firing_range",
+            "presence_ratio",
+            "snr",
+        ]
+    )
+    assert set(list(quality_metric_extension.params.get("metric_names"))) == set(
+        [
+            "amplitude_cutoff",
+            "firing_range",
+            "presence_ratio",
+            "snr",
+        ]
+    )

     # check that, when parameters are changed, the data and metadata are updated
     old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values)
@@ -106,6 +118,92 @@ def test_compute_new_quality_metrics(small_sorting_analyzer):
     assert small_sorting_analyzer.get_extension("quality_metrics") is None


+def test_metric_names_in_same_order(small_sorting_analyzer):
+    """
+    Computes specified quality metrics and checks that their order is propagated.
+    """
+    specified_metric_names = ["firing_range", "snr", "amplitude_cutoff"]
+    small_sorting_analyzer.compute("quality_metrics", metric_names=specified_metric_names)
+    qm_keys = small_sorting_analyzer.get_extension("quality_metrics").get_data().keys()
+    for i in range(3):
+        assert specified_metric_names[i] == qm_keys[i]
+
+
+def test_save_quality_metrics(small_sorting_analyzer, create_cache_folder):
+    """
+    Computes quality metrics in binary folder format. Then computes subsets of quality
+    metrics and checks if they are saved correctly.
+    """
+
+    # can't use _misc_metric_name_to_func as some functions compute several qms
+    # e.g.
isi_violation and synchrony + quality_metrics = [ + "num_spikes", + "firing_rate", + "presence_ratio", + "snr", + "isi_violations_ratio", + "isi_violations_count", + "rp_contamination", + "rp_violations", + "sliding_rp_violation", + "amplitude_cutoff", + "amplitude_median", + "amplitude_cv_median", + "amplitude_cv_range", + "sync_spike_2", + "sync_spike_4", + "sync_spike_8", + "firing_range", + "drift_ptp", + "drift_std", + "drift_mad", + "sd_ratio", + "isolation_distance", + "l_ratio", + "d_prime", + "silhouette", + "nn_hit_rate", + "nn_miss_rate", + ] + + small_sorting_analyzer.compute("quality_metrics") + + cache_folder = create_cache_folder + output_folder = cache_folder / "sorting_analyzer" + + folder_analyzer = small_sorting_analyzer.save_as(format="binary_folder", folder=output_folder) + quality_metrics_filename = output_folder / "extensions" / "quality_metrics" / "metrics.csv" + + with open(quality_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in quality_metrics: + assert metric_name in metric_names + + folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=False) + + with open(quality_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in quality_metrics: + assert metric_name in metric_names + + folder_analyzer.compute("quality_metrics", metric_names=["snr"], delete_existing_metrics=True) + + with open(quality_metrics_filename) as metrics_file: + saved_metrics = csv.reader(metrics_file) + metric_names = next(saved_metrics) + + for metric_name in quality_metrics: + if metric_name == "snr": + assert metric_name in metric_names + else: + assert metric_name not in metric_names + + def test_unit_structure_in_output(small_sorting_analyzer): qm_params = { From c1f0b2a8ae1996f30b2612b92aa3fd48e50dba3a Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 11 Sep 2024 10:34:43 +0100 Subject: [PATCH 43/98] Update merge_extension_data for quality_metrics --- .../qualitymetrics/quality_metric_calculator.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 1c7483212a..a143ac3562 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -63,6 +63,7 @@ def _set_params( seed=None, skip_pc_metrics=False, delete_existing_metrics=False, + metrics_to_compute=None, ): if metric_names is None: @@ -118,6 +119,7 @@ def _merge_extension_data( ): import pandas as pd + metric_names = self.params["metric_names"] old_metrics = self.data["metrics"] all_unit_ids = new_sorting_analyzer.unit_ids @@ -126,7 +128,9 @@ def _merge_extension_data( metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] - metrics.loc[new_unit_ids, :] = self._compute_metrics(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs) + metrics.loc[new_unit_ids, :] = self._compute_metrics( + new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs + ) new_data = dict(metrics=metrics) return new_data From 5d5afd2873eb19791a7366be88ae577961f6755e Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 11 Sep 2024 
10:45:12 +0100 Subject: [PATCH 44/98] Put main stuff back --- .../tests/test_metrics_functions.py | 29 +++++++++++++++++++ .../tests/test_quality_metric_calculator.py | 10 ++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 156bab84d8..0a936edb39 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -536,3 +536,32 @@ def test_calculate_sd_ratio(sorting_analyzer_simple): assert np.all(list(sd_ratio.keys()) == sorting_analyzer_simple.unit_ids) # @aurelien can you check this, this is not working anymore # assert np.allclose(list(sd_ratio.values()), 1, atol=0.25, rtol=0) + + +if __name__ == "__main__": + + sorting_analyzer = _sorting_analyzer_simple() + print(sorting_analyzer) + + test_unit_structure_in_output(_small_sorting_analyzer()) + + # test_calculate_firing_rate_num_spikes(sorting_analyzer) + + test_calculate_snrs(sorting_analyzer) + # test_calculate_amplitude_cutoff(sorting_analyzer) + # test_calculate_presence_ratio(sorting_analyzer) + # test_calculate_amplitude_median(sorting_analyzer) + # test_calculate_sliding_rp_violations(sorting_analyzer) + # test_calculate_drift_metrics(sorting_analyzer) + # test_synchrony_metrics(sorting_analyzer) + # test_synchrony_metrics_unit_id_subset(sorting_analyzer) + # test_synchrony_metrics_no_unit_ids(sorting_analyzer) + # test_calculate_firing_range(sorting_analyzer) + # test_calculate_amplitude_cv_metrics(sorting_analyzer) + # test_calculate_sd_ratio(sorting_analyzer) + + # sorting_analyzer_violations = _sorting_analyzer_violations() + # print(sorting_analyzer_violations) + # test_calculate_isi_violations(sorting_analyzer_violations) + # test_calculate_sliding_rp_violations(sorting_analyzer_violations) + # test_calculate_rp_violations(sorting_analyzer_violations) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index fec5ceeb95..a6415c58e8 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -19,7 +19,6 @@ def test_compute_quality_metrics(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple - print(sorting_analyzer) # without PCs metrics = compute_quality_metrics( @@ -245,3 +244,12 @@ def test_empty_units(sorting_analyzer_simple): # for metric_name in metrics.columns: # # NaNs are skipped # assert np.allclose(metrics[metric_name].dropna(), metrics_par[metric_name].dropna()) + +if __name__ == "__main__": + + sorting_analyzer = get_sorting_analyzer() + print(sorting_analyzer) + + test_compute_quality_metrics(sorting_analyzer) + test_compute_quality_metrics_recordingless(sorting_analyzer) + test_empty_units(sorting_analyzer) From 492507ec3451b6cb4862e7c1b6985074eefd0085 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 11 Sep 2024 11:46:41 +0200 Subject: [PATCH 45/98] Fix proposal for channel location when probegroup --- .../core/baserecordingsnippets.py | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 428472bf93..f7b55d3f6a 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ 
b/src/spikeinterface/core/baserecordingsnippets.py @@ -348,17 +348,22 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy"): if channel_ids is None: channel_ids = self.get_channel_ids() channel_indices = self.ids_to_indices(channel_ids) - if self.get_property("contact_vector") is not None: - if len(self.get_probes()) == 1: - probe = self.get_probe() - positions = probe.contact_positions[channel_indices] - else: - all_probes = self.get_probes() - # check that multiple probes are non-overlapping - check_probe_do_not_overlap(all_probes) - all_positions = np.vstack([probe.contact_positions for probe in all_probes]) - positions = all_positions[channel_indices] - return select_axes(positions, axes) + contact_vector = self.get_property("contact_vector") + if contact_vector is not None: + # to avoid the get_probes() when only one probe do check unique probe_id + num_probes = np.unique(contact_vector["probe_index"]).size + if num_probes > 1: + # get_probes() is called only when several probes check_overlaps + # TODO make this check_probe_do_not_overlap() use only the contact_vector instead of constructing the probe + check_probe_do_not_overlap(self.get_probes()) + + # here we bypass the probe reconstruction so this works both for probe and probegroup + ndim = len(axes) + all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") + for i, dim in enumerate(axes): + all_positions[:, i] = contact_vector[dim] + positions = all_positions[channel_indices] + return positions else: locations = self.get_property("location") if locations is None: From f3aac424c6da100aac0614b7f5382c0d797ebd17 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 11 Sep 2024 10:46:59 +0100 Subject: [PATCH 46/98] oups --- .../qualitymetrics/tests/test_metrics_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 0a936edb39..e7fc7ce209 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -547,7 +547,7 @@ def test_calculate_sd_ratio(sorting_analyzer_simple): # test_calculate_firing_rate_num_spikes(sorting_analyzer) - test_calculate_snrs(sorting_analyzer) + # test_calculate_snrs(sorting_analyzer) # test_calculate_amplitude_cutoff(sorting_analyzer) # test_calculate_presence_ratio(sorting_analyzer) # test_calculate_amplitude_median(sorting_analyzer) From ff07ac603bb8552ef40bd1f43962f201a72df17c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 11 Sep 2024 11:50:22 +0200 Subject: [PATCH 47/98] Fix run_info and consolidate metadata --- src/spikeinterface/core/sortinganalyzer.py | 35 +++++++++++++++---- .../tests/test_analyzer_extension_core.py | 5 +++ .../core/tests/test_sortinganalyzer.py | 3 ++ 3 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 4f5be665f4..424fab7c5e 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1472,7 +1472,7 @@ def delete_extension(self, extension_name) -> None: if self.format != "memory" and self.has_extension(extension_name): # need a reload to reset the folder ext = self.load_extension(extension_name) - ext.reset() + ext.delete() # remove from dict self.extensions.pop(extension_name, None) @@ 
-2014,19 +2014,17 @@ def run(self, save=True, **kwargs): # NB: this call to _save_params() also resets the folder or zarr group self._save_params() self._save_importing_provenance() - self._save_run_info() t_start = perf_counter() self._run(**kwargs) t_end = perf_counter() self.run_info["runtime_s"] = t_end - t_start + self.run_info["run_completed"] = True if save and not self.sorting_analyzer.is_read_only(): + self._save_run_info() self._save_data(**kwargs) - self.run_info["run_completed"] = True - self._save_run_info() - def save(self, **kwargs): self._save_params() self._save_importing_provenance() @@ -2126,6 +2124,32 @@ def _reset_extension_folder(self): _ = zarr_root["extensions"].create_group(self.extension_name, overwrite=True) zarr.consolidate_metadata(zarr_root.store) + def _delete_extension_folder(self): + """ + Delete the extension in a folder (binary or zarr). + """ + if self.format == "binary_folder": + extension_folder = self._get_binary_extension_folder() + if extension_folder.is_dir(): + shutil.rmtree(extension_folder) + + elif self.format == "zarr": + import zarr + + zarr_root = self.sorting_analyzer._get_zarr_root(mode="r+") + if self.extension_name in zarr_root["extensions"]: + del zarr_root["extensions"][self.extension_name] + zarr.consolidate_metadata(zarr_root.store) + + def delete(self): + """ + Delete the extension from the folder or zarr and from the dict. + """ + self._delete_extension_folder() + self.params = None + self.run_info = self._default_run_info_dict() + self.data = dict() + def reset(self): """ Reset the waveform extension. @@ -2154,7 +2178,6 @@ def set_params(self, save=True, **params): if save: self._save_params() self._save_importing_provenance() - self._save_run_info() def _save_params(self): params_to_save = self.params.copy() diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index b4d96a3391..626899ab6e 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -79,15 +79,20 @@ def _check_result_extension(sorting_analyzer, extension_name, cache_folder): ) def test_ComputeRandomSpikes(format, sparse, create_cache_folder): cache_folder = create_cache_folder + print("Creating analyzer") sorting_analyzer = get_sorting_analyzer(cache_folder, format=format, sparse=sparse) + print("Computing random spikes") ext = sorting_analyzer.compute("random_spikes", max_spikes_per_unit=10, seed=2205) indices = ext.data["random_spikes_indices"] assert indices.size == 10 * sorting_analyzer.sorting.unit_ids.size + print("Checking results") _check_result_extension(sorting_analyzer, "random_spikes", cache_folder) + print("Delering extension") sorting_analyzer.delete_extension("random_spikes") + print("Re-computing random spikes") ext = sorting_analyzer.compute("random_spikes", method="all") indices = ext.data["random_spikes_indices"] assert indices.size == len(sorting_analyzer.sorting.to_spike_vector()) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 3f45487f4c..77b8f2c5bf 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -126,6 +126,8 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): def test_load_without_runtime_info(tmp_path, dataset): + import zarr + recording, sorting = dataset folder = tmp_path / "test_SortingAnalyzer_run_info" @@ 
-153,6 +155,7 @@ def test_load_without_runtime_info(tmp_path, dataset): root = sorting_analyzer._get_zarr_root(mode="r+") for ext in extensions: del root["extensions"][ext].attrs["run_info"] + zarr.consolidate_metadata(root.store) # should raise a warning for missing run_info with pytest.warns(UserWarning): sorting_analyzer = load_sorting_analyzer(folder, format="auto") From 9fb7f344b578e0bf2a9436608757fdffaa008e06 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 11 Sep 2024 10:59:35 +0100 Subject: [PATCH 48/98] use sa unit ids and switch order of id indep test --- .../qualitymetrics/quality_metric_calculator.py | 2 +- .../qualitymetrics/tests/test_metrics_functions.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 123293f313..8a754fc7da 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -165,7 +165,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri import pandas as pd - metrics = pd.DataFrame(index=sorting_analyzer.sorting.unit_ids) + metrics = pd.DataFrame(index=sorting_analyzer.unit_ids) # simple metrics not based on PCs for metric_name in metric_names: diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 77909798a3..ee5d5849b3 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -286,10 +286,10 @@ def test_unit_id_order_independence(small_sorting_analyzer): small_sorting_analyzer_2, metric_names=get_quality_metric_list(), qm_params=qm_params ) - for metric, metric_1_data in quality_metrics_1.items(): - assert quality_metrics_2[metric][2] == metric_1_data["#3"] - assert quality_metrics_2[metric][7] == metric_1_data["#9"] - assert quality_metrics_2[metric][1] == metric_1_data["#4"] + for metric, metric_2_data in quality_metrics_2.items(): + assert quality_metrics_1[metric]["#3"] == metric_2_data[2] + assert quality_metrics_1[metric]["#9"] == metric_2_data[7] + assert quality_metrics_1[metric]["#4"] == metric_2_data[1] def _sorting_analyzer_simple(): From 01485d91646b0234d0ade2df0b98db2d197ca4fc Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 11 Sep 2024 12:17:36 +0200 Subject: [PATCH 49/98] Add space before colons in docs --- src/spikeinterface/widgets/unit_summary.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 851dfa049a..755e60ccbf 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -22,22 +22,22 @@ class UnitSummaryWidget(BaseWidget): Parameters ---------- - sorting_analyzer: SortingAnalyzer + sorting_analyzer : SortingAnalyzer The SortingAnalyzer object - unit_id: int or str + unit_id : int or str The unit id to plot the summary of - unit_colors: dict or None, default: None + unit_colors : dict or None, default: None If given, a dictionary with unit ids as keys and colors as values, - sparsity: ChannelSparsity or None, default: None + sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply. 
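For readers following the widget change above, a minimal usage sketch of the nested `subwidget_kwargs` parameter being documented here (the `analyzer` variable and the forwarded parameter names are illustrative assumptions taken from each subwidget's docstring, not part of this patch):

    import spikeinterface.widgets as sw

    # `analyzer` is assumed to be a SortingAnalyzer with the extensions
    # required by each subwidget already computed
    sw.plot_unit_summary(
        analyzer,
        unit_id=analyzer.unit_ids[0],
        subwidget_kwargs=dict(
            unit_locations=dict(plot_legend=False),  # forwarded to UnitLocationsWidget
            amplitudes=dict(plot_histograms=True),   # forwarded to AmplitudesWidget
        ),
        backend="matplotlib",
    )

As the docstring notes, `unit_colors` is passed directly to `plot_unit_summary`, never through `subwidget_kwargs`.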
If SortingAnalyzer is already sparse, the argument is ignored - subwidget_kwargs: dict or None, default: None + subwidget_kwargs : dict or None, default: None Parameters for the subwidgets in a nested dictionary - unit_locations: UnitLocationsWidget (see UnitLocationsWidget for details) - unit_waveforms: UnitWaveformsWidget (see UnitWaveformsWidget for details) - unit_waveform_density_map: UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) - autocorrelograms: AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) - amplitudes: AmplitudesWidget (see AmplitudesWidget for details) + unit_locations : UnitLocationsWidget (see UnitLocationsWidget for details) + unit_waveforms : UnitWaveformsWidget (see UnitWaveformsWidget for details) + unit_waveform_density_map : UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) + autocorrelograms : AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) + amplitudes : AmplitudesWidget (see AmplitudesWidget for details) Please note that the unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary. """ From e5c710d50dc21e3e86748d31c4894a541fc3bcab Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 11 Sep 2024 13:04:20 +0200 Subject: [PATCH 50/98] Refactor analyzer.get_channel_locations() --- src/spikeinterface/core/sortinganalyzer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 49a31738e3..a7a1ad587e 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1101,9 +1101,14 @@ def get_probe(self): def get_channel_locations(self) -> np.ndarray: # important note : contrary to recording # this give all channel locations, so no kwargs like channel_ids and axes - all_probes = self.get_probegroup().probes - all_positions = np.vstack([probe.contact_positions for probe in all_probes]) - return all_positions + probegroup = self.get_probegroup() + probe_as_numpy_array = probegroup.to_numpy() + # duplicate positions to "locations" property + ndim = probegroup.ndim + locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") + for i, dim in enumerate(["x", "y", "z"][:ndim]): + locations[:, i] = probe_as_numpy_array[dim] + return locations def channel_ids_to_indices(self, channel_ids) -> np.ndarray: all_channel_ids = list(self.rec_attributes["channel_ids"]) From 0f0834428081a7db68bf5be97d04803527f872e0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 11 Sep 2024 13:09:34 +0200 Subject: [PATCH 51/98] Check probes do not overlap at _set_probes --- src/spikeinterface/core/baserecordingsnippets.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index f7b55d3f6a..763f9e5801 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -145,6 +145,11 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False else: raise ValueError("must give Probe or ProbeGroup or list of Probe") + # check that the probe do not overlap + num_probes = len(probegroup.probes) + if num_probes > 1: + check_probe_do_not_overlap(probegroup.probes) + # handle not connected channels assert all( probe.device_channel_indices is not None for probe in probegroup.probes @@ -350,13 +355,6 
@@ def get_channel_locations(self, channel_ids=None, axes: str = "xy"): channel_indices = self.ids_to_indices(channel_ids) contact_vector = self.get_property("contact_vector") if contact_vector is not None: - # to avoid the get_probes() when only one probe do check unique probe_id - num_probes = np.unique(contact_vector["probe_index"]).size - if num_probes > 1: - # get_probes() is called only when several probes check_overlaps - # TODO make this check_probe_do_not_overlap() use only the contact_vector instead of constructing the probe - check_probe_do_not_overlap(self.get_probes()) - # here we bypass the probe reconstruction so this works both for probe and probegroup ndim = len(axes) all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") From cc447b05d91651e6b3a29280049fb23cc3fd0d10 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 11 Sep 2024 13:10:06 +0200 Subject: [PATCH 52/98] Update src/spikeinterface/core/sortinganalyzer.py --- src/spikeinterface/core/sortinganalyzer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index a7a1ad587e..83e214f4ab 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1103,7 +1103,6 @@ def get_channel_locations(self) -> np.ndarray: # this give all channel locations, so no kwargs like channel_ids and axes probegroup = self.get_probegroup() probe_as_numpy_array = probegroup.to_numpy() - # duplicate positions to "locations" property ndim = probegroup.ndim locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") for i, dim in enumerate(["x", "y", "z"][:ndim]): From d62653b8810cc53cd7275a4b25e4628ef87acf5c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 11 Sep 2024 13:15:33 +0200 Subject: [PATCH 53/98] Sort probegroup array by device_channel_indices --- src/spikeinterface/core/sortinganalyzer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index a7a1ad587e..f0bc8e49bb 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1103,6 +1103,8 @@ def get_channel_locations(self) -> np.ndarray: # this give all channel locations, so no kwargs like channel_ids and axes probegroup = self.get_probegroup() probe_as_numpy_array = probegroup.to_numpy() + # we need to sort by device_channel_indices to ensure the order of locations is correct + probe_as_numpy_array = probe_as_numpy_array[np.argsort(probe_as_numpy_array["device_channel_indices"])] # duplicate positions to "locations" property ndim = probegroup.ndim locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") From 9539c93ba2a7dac133d4ce742cc92e834ca9576c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 11 Sep 2024 13:18:42 +0200 Subject: [PATCH 54/98] fix to_numpy --- src/spikeinterface/core/sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 1f54f6687f..6ce8d180c5 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1102,7 +1102,7 @@ def get_channel_locations(self) -> np.ndarray: # important note : contrary to recording # this give all channel locations, so no kwargs like channel_ids and axes probegroup = self.get_probegroup() - probe_as_numpy_array = probegroup.to_numpy() 
+        probe_as_numpy_array = probegroup.to_numpy(complete=True)
         # we need to sort by device_channel_indices to ensure the order of locations is correct
         probe_as_numpy_array = probe_as_numpy_array[np.argsort(probe_as_numpy_array["device_channel_indices"])]
         ndim = probegroup.ndim

From b35873cd5a3fd7e37cf8178b9699a74cd18c08f8 Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Wed, 11 Sep 2024 20:08:17 +0100
Subject: [PATCH 55/98] compute pcs for sortinganalyzer again

---
 src/spikeinterface/qualitymetrics/tests/conftest.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py
index 01fa16c8d7..676889094b 100644
--- a/src/spikeinterface/qualitymetrics/tests/conftest.py
+++ b/src/spikeinterface/qualitymetrics/tests/conftest.py
@@ -67,5 +67,6 @@ def sorting_analyzer_simple():
     sorting_analyzer.compute("waveforms", **job_kwargs)
     sorting_analyzer.compute("templates")
     sorting_analyzer.compute("spike_amplitudes", **job_kwargs)
+    sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs)

     return sorting_analyzer

From 406f99bf7d2f68b68db3e4535ba6e004331de419 Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Thu, 12 Sep 2024 08:39:56 +0100
Subject: [PATCH 56/98] update warning to include metric names

---
 src/spikeinterface/postprocessing/template_metrics.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py
index d05e3ae7ef..0d0d633c04 100644
--- a/src/spikeinterface/postprocessing/template_metrics.py
+++ b/src/spikeinterface/postprocessing/template_metrics.py
@@ -152,7 +152,9 @@ def _set_params(
         # checks that existing metrics were calculated using the same params
         if existing_params != metrics_kwargs_:
             warnings.warn(
-                "The parameters used to calculate the previous template metrics are different than those used now. Deleting previous template metrics..."
+                f"The parameters used to calculate the previous template metrics are different "
+                f"than those used now.\nPrevious parameters: {existing_params}\nCurrent "
+                f"parameters: {metrics_kwargs_}\nDeleting previous template metrics..."
) tm_extension.params["metric_names"] = [] existing_metric_names = [] From 41f73ed5c68992a871531a1d832e4179c2e9e02d Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 12 Sep 2024 09:50:19 +0100 Subject: [PATCH 57/98] re-remove pcs --- src/spikeinterface/qualitymetrics/tests/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index 676889094b..01fa16c8d7 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -67,6 +67,5 @@ def sorting_analyzer_simple(): sorting_analyzer.compute("waveforms", **job_kwargs) sorting_analyzer.compute("templates") sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) return sorting_analyzer From 5a8535a166f9c449d24054dfa47380b6cdb1e811 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 12 Sep 2024 16:54:27 +0200 Subject: [PATCH 58/98] Add recording and analyzer tests with interleaved probegroups --- .../core/baserecordingsnippets.py | 2 +- src/spikeinterface/core/recording_tools.py | 3 +- src/spikeinterface/core/sortinganalyzer.py | 3 +- .../core/tests/test_baserecording.py | 33 +++++++++++++++++-- .../core/tests/test_sortinganalyzer.py | 21 ++++++++++++ 5 files changed, 56 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 763f9e5801..d6088a01d7 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -239,7 +239,7 @@ def set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False) warning_msg = ( "`set_probes` is now a private function and the public function will be " - "removed in 0.103.0. Please use `set_probe` or `set_probegroups` instead" + "removed in 0.103.0. 
Please use `set_probe` or `set_probegroup` instead" ) warn(warning_msg, category=DeprecationWarning, stacklevel=2) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 0ec5449bae..5137eda545 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -888,11 +888,10 @@ def check_probe_do_not_overlap(probes): for j in range(i + 1, len(probes)): probe_j = probes[j] - if np.any( np.array( [ - x_bounds_i[0] < cp[0] < x_bounds_i[1] and y_bounds_i[0] < cp[1] < y_bounds_i[1] + x_bounds_i[0] <= cp[0] <= x_bounds_i[1] and y_bounds_i[0] <= cp[1] <= y_bounds_i[1] for cp in probe_j.contact_positions ] ) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 6ce8d180c5..e3b6527b90 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1107,7 +1107,8 @@ def get_channel_locations(self) -> np.ndarray: probe_as_numpy_array = probe_as_numpy_array[np.argsort(probe_as_numpy_array["device_channel_indices"])] ndim = probegroup.ndim locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") - for i, dim in enumerate(["x", "y", "z"][:ndim]): + # here we only loop through xy because only 2d locations are supported + for i, dim in enumerate(["x", "y"][:ndim]): locations[:, i] = probe_as_numpy_array[dim] return locations diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 6b60efe2b6..df614978ba 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -10,7 +10,7 @@ import numpy as np from numpy.testing import assert_raises -from probeinterface import Probe +from probeinterface import Probe, ProbeGroup, generate_linear_probe from spikeinterface.core import BinaryRecordingExtractor, NumpyRecording, load_extractor, get_default_zarr_compressor from spikeinterface.core.base import BaseExtractor @@ -358,6 +358,34 @@ def test_BaseRecording(create_cache_folder): assert np.allclose(rec_u.get_traces(cast_unsigned=True), rec_i.get_traces().astype("float")) +def test_interleaved_probegroups(): + recording = generate_recording(durations=[1.0], num_channels=16) + + probe1 = generate_linear_probe(num_elec=8, ypitch=20.0) + probe2_overlap = probe1.copy() + + probegroup_overlap = ProbeGroup() + probegroup_overlap.add_probe(probe1) + probegroup_overlap.add_probe(probe2_overlap) + probegroup_overlap.set_global_device_channel_indices(np.arange(16)) + + # setting overlapping probes should raise an error + with pytest.raises(Exception): + recording.set_probegroup(probegroup_overlap) + + probe2 = probe1.copy() + probe2.move([100.0, 100.0]) + probegroup = ProbeGroup() + probegroup.add_probe(probe1) + probegroup.add_probe(probe2) + probegroup.set_global_device_channel_indices(np.random.permutation(16)) + + recording.set_probegroup(probegroup) + probegroup_set = recording.get_probegroup() + # check that the probe group is correctly set, by sorting the device channel indices + assert np.array_equal(probegroup_set.get_global_device_channel_indices()["device_channel_indices"], np.arange(16)) + + def test_rename_channels(): recording = generate_recording(durations=[1.0], num_channels=3) renamed_recording = recording.rename_channels(new_channel_ids=["a", "b", "c"]) @@ -399,4 +427,5 @@ def test_time_slice_with_time_vector(): if __name__ == "__main__": - test_BaseRecording() + # 
test_BaseRecording() + test_interleaved_probegroups() diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 3f45487f4c..4468c3f505 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -178,6 +178,27 @@ def test_SortingAnalyzer_tmp_recording(dataset): sorting_analyzer.set_temporary_recording(recording_sliced) +def test_SortingAnalyzer_interleaved_probegroup(dataset): + from probeinterface import generate_linear_probe, ProbeGroup + + recording, sorting = dataset + num_channels = recording.get_num_channels() + probe1 = generate_linear_probe(num_elec=num_channels // 2, ypitch=20.0) + probe2 = probe1.copy() + probe2.move([100.0, 100.0]) + + probegroup = ProbeGroup() + probegroup.add_probe(probe1) + probegroup.add_probe(probe2) + probegroup.set_global_device_channel_indices(np.random.permutation(num_channels)) + + recording.set_probegroup(probegroup) + + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) + # check that locations are correct + assert np.array_equal(recording.get_channel_locations(), sorting_analyzer.get_channel_locations()) + + def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): register_result_extension(DummyAnalyzerExtension) From 992716fa01d44724d83a7cb749bb2a1524d5af26 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 12 Sep 2024 17:30:53 +0200 Subject: [PATCH 59/98] Update release notes --- doc/releases/0.101.1.rst | 52 +++++++++++++++++++++++++++++----------- doc/whatisnew.rst | 5 ++++ 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/doc/releases/0.101.1.rst b/doc/releases/0.101.1.rst index d2b0d35020..d8c8dffddd 100644 --- a/doc/releases/0.101.1.rst +++ b/doc/releases/0.101.1.rst @@ -3,28 +3,35 @@ SpikeInterface 0.101.1 release notes ------------------------------------ -6th September 2024 +13th September 2024 Main changes: -* +* Enabled support for consolidated Zarr-backend for `SortingAnalyzer`, including cloud support (#3314, #3318, #3349, #3351) +* Improved support for Kilosort4 **ONLY VERSIONS >= 4.0.16** (#3339, #3276) +* Skip recomputation of quality and template metrics if already computed (#3292) +* Dropped support for Python<3.9 (#3267) core: -* Add `BaseRecording.reset_times()` function (#3363) +* Fix proposal for channel location when probegroup (#3392) +* Fix time handling test memory (#3379) +* Add `BaseRecording.reset_times()` function (#3363, #3380, #3391) +* Extend `estimate_sparsity` methods and fix from_ptp (#3369) * Add `load_sorting_analyzer_or_waveforms` function (#3352) -* Propagate storage_options to load_sorting_analyzer (#3351) * Fix zarr folder suffix handling (#3349) +* Analyzer extension exit status (#3347) * Lazy loading of zarr timestamps (#3318) -* Enable cloud-loading for analyzer Zarr (#3314) +* Enable cloud-loading for analyzer Zarr (#3314, #3351, #3378) * Refactor `set_property` in base (#3287) * Job kwargs fix (#3259) -* Add check for None in 'NoiseGeneratorRecordingSegment' get_traces(). 
(#3230) +* Add `is_filtered` to annotations in `binary.json` (#3245) +* Add check for None in `NoiseGeneratorRecordingSegment`` get_traces() (#3230) extractors: +* Load phy channel_group as group (#3368) * "quality" property to be read as string instead of object in `BasePhyKilosortSortingExtractor` (#3365) -* Test IBL skip when the setting up the one client fails (#3289) preprocessing: @@ -33,18 +40,25 @@ preprocessing: sorters: -* fix: download apptainer images without docker client (#3335) +* Updates to kilosort 4: version >= 4.0.16, `bad_channels`, `clear_cache`, `use_binary_file` (#3339) +* Download apptainer images without docker client (#3335) * Expose save preprocessing in ks4 (#3276) * Fix KS2/2.5/3 skip_kilosort_preprocessing (#3265) -* added lowpass parameter, fixed verbose option (#3262) +* HS: Added lowpass parameter, fixed verbose option (#3262) * Now exclusive support for HS v0.4 (Lightning) (#3210) * Add kilosort4 wrapper tests (#3085) postprocessing: +* Add extra protection for template metrics (#3364) * Protect median against nans in get_prototype_spike (#3270) * Fix docstring and error for spike_amplitudes (#3269) +qualitymetrics: + +* Do not delete quality and template metrics on recompute (#3292) +* Refactor quality metrics tests to use fixture (#3249) + curation: @@ -56,15 +70,17 @@ widgets: * Fix widgets tests and add test on unit_table_properties (#3354) * Allow quality and template metrics in sortingview's unit table (#3299) -* Fix #3236 (#3238) +* Add subwidget parameters for UnitSummaryWidget (#3242) +* Fix `ipympl`/`widget` backend check (#3238) -sortingcomponents: +generators: -* Update doc handle drift + better preset (#3232) +* Handle case where channel count changes from probeA to probeB (#3237) -motion correction: +sortingcomponents: -* Make InterpolateMotionRecording not JSON-serializable (#3341) +* Update doc handle drift + better preset (#3232) +* Make `InterpolateMotionRecording`` not JSON-serializable (#3341) documentation: @@ -87,14 +103,19 @@ continuous integration: packaging: +* Minor typing fixes (#3374) * Drop python 3.8 in pyproject.toml (#3267) testing: +* Fix time handling test memory (#3379) * Fix streaming extractor condition in the CI (#3362) * Test IBL skip when the setting up the one client fails (#3289) * Refactor `set_property` in base (#3287) +* Refactor quality metrics tests to use fixture (#3249) * Add kilosort4 wrapper tests (#3085) +* Test IBL skip when the setting up the one client fails (#3289) + Contributors: @@ -104,7 +125,10 @@ Contributors: * @alejoe91 * @app/pre-commit-ci * @chrishalcrow +* @cwindolf +* @florian6973 * @h-mayorquin +* @jiumao2 * @jonahpearl * @mhhennig * @rkim48 diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst index 56d38ce85b..330f72f215 100644 --- a/doc/whatisnew.rst +++ b/doc/whatisnew.rst @@ -47,7 +47,12 @@ Release notes Version 0.101.1 =============== +Minor release with bug fixes and minor improvements: +* Enabled support for consolidated Zarr-backend for `SortingAnalyzer`, including cloud support (#3314, #3318, #3349, #3351) +* Improved support for Kilosort4 **ONLY VERSIONS >= 4.0.16** (#3339, #3276) +* Skip recomputation of quality and template metrics if already computed (#3292) +* Dropped support for Python<3.9 (#3267) Version 0.101.0 From 726fed06d0485dc708a6e1fc8605beeae5533a14 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 12 Sep 2024 09:38:42 -0600 Subject: [PATCH 60/98] freeze decision --- doc/development/development.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git 
a/doc/development/development.rst b/doc/development/development.rst index 1094b466fc..792660d9aa 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -192,6 +192,11 @@ Miscelleaneous Stylistic Conventions #. Avoid using abreviations in variable names (e.g., use :code:`recording` instead of :code:`rec`). It is especially important to avoid single letter variables. #. Use index as singular and indices for plural following the NumPy convention. Avoid idx or indexes. Plus, id and ids are reserved for identifiers (i.e. channel_ids) #. We use file_path and folder_path (instead of file_name and folder_name) for clarity. +#. For creating headers to divide section of codes we use the following convention (see issue 3019): + +######################################### +# A header +######################################### How to build the documentation From 20fd07e8cdc0ec505a02fde1013f9edb0f9a47d2 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 12 Sep 2024 09:49:48 -0600 Subject: [PATCH 61/98] Update doc/development/development.rst Co-authored-by: Alessio Buccino --- doc/development/development.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/development/development.rst b/doc/development/development.rst index 792660d9aa..8d5471dc71 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -192,7 +192,7 @@ Miscelleaneous Stylistic Conventions #. Avoid using abreviations in variable names (e.g., use :code:`recording` instead of :code:`rec`). It is especially important to avoid single letter variables. #. Use index as singular and indices for plural following the NumPy convention. Avoid idx or indexes. Plus, id and ids are reserved for identifiers (i.e. channel_ids) #. We use file_path and folder_path (instead of file_name and folder_name) for clarity. -#. For creating headers to divide section of codes we use the following convention (see issue 3019): +#. For creating headers to divide section of codes we use the following convention (see issue `#3019 `_): ######################################### # A header From bc946a259fa851bbcf892184969da569b71b8d24 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 12 Sep 2024 10:36:12 -0600 Subject: [PATCH 62/98] Update doc/development/development.rst Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- doc/development/development.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/development/development.rst b/doc/development/development.rst index 8d5471dc71..dd6886bb63 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -192,7 +192,7 @@ Miscelleaneous Stylistic Conventions #. Avoid using abreviations in variable names (e.g., use :code:`recording` instead of :code:`rec`). It is especially important to avoid single letter variables. #. Use index as singular and indices for plural following the NumPy convention. Avoid idx or indexes. Plus, id and ids are reserved for identifiers (i.e. channel_ids) #. We use file_path and folder_path (instead of file_name and folder_name) for clarity. -#. For creating headers to divide section of codes we use the following convention (see issue `#3019 `_): +#. 
For creating headers to divide sections of code we use the following convention (see issue `#3019 `_): ######################################### # A header From ee854b6ff0d589f9dd7c8552d65ebdfcf5426df8 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 12 Sep 2024 10:43:08 -0600 Subject: [PATCH 63/98] Update doc/development/development.rst --- doc/development/development.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/development/development.rst b/doc/development/development.rst index dd6886bb63..be94319fb4 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -189,7 +189,7 @@ so that the user knows what the options are. Miscelleaneous Stylistic Conventions ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -#. Avoid using abreviations in variable names (e.g., use :code:`recording` instead of :code:`rec`). It is especially important to avoid single letter variables. +#. Avoid using abbreviations in variable names (e.g. use :code:`recording` instead of :code:`rec`). It is especially important to avoid single letter variables. #. Use index as singular and indices for plural following the NumPy convention. Avoid idx or indexes. Plus, id and ids are reserved for identifiers (i.e. channel_ids) #. We use file_path and folder_path (instead of file_name and folder_name) for clarity. #. For creating headers to divide sections of code we use the following convention (see issue `#3019 `_): From 82093448e6cf0b3675e280cfe61331bd7ba4b3a9 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 12 Sep 2024 19:33:59 +0200 Subject: [PATCH 64/98] Update version to 0.102.0 --- doc/releases/{0.101.1.rst => 0.102.0.rst} | 7 ++++--- doc/whatisnew.rst | 6 +++--- pyproject.toml | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) rename doc/releases/{0.101.1.rst => 0.102.0.rst} (95%) diff --git a/doc/releases/0.101.1.rst b/doc/releases/0.102.0.rst similarity index 95% rename from doc/releases/0.101.1.rst rename to doc/releases/0.102.0.rst index d8c8dffddd..1ae1c44a09 100644 --- a/doc/releases/0.101.1.rst +++ b/doc/releases/0.102.0.rst @@ -1,6 +1,6 @@ -.. _release0.101.1: +.. _release0.102.0: -SpikeInterface 0.101.1 release notes +SpikeInterface 0.102.0 release notes ------------------------------------ 13th September 2024 @@ -10,6 +10,7 @@ Main changes: * Enabled support for consolidated Zarr-backend for `SortingAnalyzer`, including cloud support (#3314, #3318, #3349, #3351) * Improved support for Kilosort4 **ONLY VERSIONS >= 4.0.16** (#3339, #3276) * Skip recomputation of quality and template metrics if already computed (#3292) +* Modified and improved `estimate_sparsity` and refactored `from_ptp` option (#3369) * Dropped support for Python<3.9 (#3267) core: @@ -17,7 +18,7 @@ core: * Fix proposal for channel location when probegroup (#3392) * Fix time handling test memory (#3379) * Add `BaseRecording.reset_times()` function (#3363, #3380, #3391) -* Extend `estimate_sparsity` methods and fix from_ptp (#3369) +* Extend `estimate_sparsity` methods and update `from_ptp`` (#3369) * Add `load_sorting_analyzer_or_waveforms` function (#3352) * Fix zarr folder suffix handling (#3349) * Analyzer extension exit status (#3347) diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst index 330f72f215..090722e27b 100644 --- a/doc/whatisnew.rst +++ b/doc/whatisnew.rst @@ -8,7 +8,7 @@ Release notes .. 
toctree:: :maxdepth: 1 - releases/0.101.1.rst + releases/0.102.0.rst releases/0.101.0.rst releases/0.100.8.rst releases/0.100.7.rst @@ -44,14 +44,14 @@ Release notes releases/0.9.1.rst -Version 0.101.1 +Version 0.102.0 =============== -Minor release with bug fixes and minor improvements: * Enabled support for consolidated Zarr-backend for `SortingAnalyzer`, including cloud support (#3314, #3318, #3349, #3351) * Improved support for Kilosort4 **ONLY VERSIONS >= 4.0.16** (#3339, #3276) * Skip recomputation of quality and template metrics if already computed (#3292) +* Modified and improved `estimate_sparsity` and refactored `from_ptp` option (#3369) * Dropped support for Python<3.9 (#3267) diff --git a/pyproject.toml b/pyproject.toml index db435998a1..1bdf3c303b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "spikeinterface" -version = "0.101.1" +version = "0.102.0" authors = [ { name="Alessio Buccino", email="alessiop.buccino@gmail.com" }, { name="Samuel Garcia", email="sam.garcia.die@gmail.com" }, From 67469ac1ce8706f684dbcc0ce08251b21c0ddb3b Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 12 Sep 2024 11:48:43 -0600 Subject: [PATCH 65/98] Update doc/development/development.rst Co-authored-by: Alessio Buccino --- doc/development/development.rst | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/doc/development/development.rst b/doc/development/development.rst index be94319fb4..246a2bcb9a 100644 --- a/doc/development/development.rst +++ b/doc/development/development.rst @@ -194,9 +194,12 @@ Miscelleaneous Stylistic Conventions #. We use file_path and folder_path (instead of file_name and folder_name) for clarity. #. For creating headers to divide sections of code we use the following convention (see issue `#3019 `_): -######################################### -# A header -######################################### + +.. 
code:: python + + ######################################### + # A header + ######################################### How to build the documentation From ba278a2915e9e2d8bc9e6653dbf655a637aba94a Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Thu, 12 Sep 2024 21:17:04 +0200 Subject: [PATCH 66/98] Update src/spikeinterface/core/tests/test_sortinganalyzer.py --- src/spikeinterface/core/tests/test_sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 4468c3f505..bc1db643df 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -192,7 +192,7 @@ def test_SortingAnalyzer_interleaved_probegroup(dataset): probegroup.add_probe(probe2) probegroup.set_global_device_channel_indices(np.random.permutation(num_channels)) - recording.set_probegroup(probegroup) + recording = recording.set_probegroup(probegroup) sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) # check that locations are correct From 710a730013f00db9efda0ac6a9a6dd93ca6aff22 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 13 Sep 2024 11:50:19 +0200 Subject: [PATCH 67/98] Sortingview: only round float properties --- src/spikeinterface/widgets/utils_sortingview.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index 7a9dc47826..d18c581b6b 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -141,12 +141,13 @@ def generate_unit_table_view( elif prop_name in tm_props: property_values = tm_data[prop_name].values - # Check for NaN values + # Check for NaN values and round floats val0 = np.array(property_values[0]) if val0.dtype.kind == "f": if np.isnan(property_values[ui]): continue - values[prop_name] = np.format_float_positional(property_values[ui], precision=4, fractional=False) + property_values[ui] = np.format_float_positional(property_values[ui], precision=4, fractional=False) + values[prop_name] = property_values[ui] ut_rows.append(vv.UnitsTableRow(unit_id=unit, values=check_json(values))) v_units_table = vv.UnitsTable(rows=ut_rows, columns=ut_columns, similarity_scores=similarity_scores) From 9ffda3543aa794418af02830be70144de64c54f2 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 13 Sep 2024 12:33:48 +0200 Subject: [PATCH 68/98] Add from_amplitude() option to sparsity and deprecate ptp --- src/spikeinterface/core/sparsity.py | 105 ++++++++++++++---- .../core/tests/test_sparsity.py | 31 +++++- 2 files changed, 108 insertions(+), 28 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 471302d57e..fd613e1fcf 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +import warnings from .basesorting import BaseSorting @@ -18,14 +19,16 @@ * "snr" : threshold based on template signal-to-noise ratio. Use the "threshold" argument to specify the SNR threshold (in units of noise levels) and the "amplitude_mode" argument to specify the mode to compute the amplitude of the templates. - * "ptp" : threshold based on the peak-to-peak values on every channels. 
Use the "threshold" argument - to specify the ptp threshold (in units of amplitude). + * "amplitude" : threshold based on the amplitude values on every channels. Use the "threshold" argument + to specify the ptp threshold (in units of amplitude) and the "amplitude_mode" argument + to specify the mode to compute the amplitude of the templates. * "energy" : threshold based on the expected energy that should be present on the channels, given their noise levels. Use the "threshold" argument to specify the energy threshold (in units of noise levels) * "by_property" : sparsity is given by a property of the recording and sorting (e.g. "group"). In this case the sparsity for each unit is given by the channels that have the same property value as the unit. Use the "by_property" argument to specify the property name. + * "ptp: : deprecated, use the 'snr' method with the 'peak_to_peak' amplitude mode instead. peak_sign : "neg" | "pos" | "both" Sign of the template to compute best channels. @@ -37,7 +40,7 @@ Threshold for "snr", "energy" (in units of noise levels) and "ptp" methods (in units of amplitude). For the "snr" method, the template amplitude mode is controlled by the "amplitude_mode" argument. amplitude_mode : "extremum" | "at_index" | "peak_to_peak" - Mode to compute the amplitude of the templates for the "snr" and "best_channels" methods. + Mode to compute the amplitude of the templates for the "snr", "amplitude", and "best_channels" methods. by_property : object Property name for "by_property" method. """ @@ -417,7 +420,7 @@ def from_snr( return cls(mask, unit_ids, channel_ids) @classmethod - def from_ptp(cls, templates_or_sorting_analyzer, threshold): + def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): """ Construct sparsity from a thresholds based on template peak-to-peak values. Use the "threshold" argument to specify the peak-to-peak threshold. @@ -434,30 +437,67 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold): sparsity : ChannelSparsity The estimated sparsity. """ + warnings.warn( + "The 'ptp' method is deprecated and will be removed in version 0.103.0. " + "Please use the 'snr' method with the 'peak_to_peak' amplitude mode instead.", + DeprecationWarning, + ) + return cls.from_snr( + templates_or_sorting_analyzer, threshold, amplitude_mode="peak_to_peak", noise_levels=noise_levels + ) - assert ( - templates_or_sorting_analyzer.sparsity is None - ), "To compute sparsity you need a dense SortingAnalyzer or Templates" + @classmethod + def from_amplitude(cls, templates_or_sorting_analyzer, threshold, amplitude_mode="extremum", peak_sign="neg"): + """ + Construct sparsity from a threshold based on template amplitude. + The amplitude is computed with the specified amplitude mode and it is assumed + that the amplitude is in uV. The input `Templates` or `SortingAnalyzer` object must + have scaled templates. - from .template_tools import get_dense_templates_array + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + threshold : float + Threshold for "amplitude" method (in uV). + amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + Mode to compute the amplitude of the templates. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. 
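As a worked sketch of how the new "amplitude" method replaces the old peak-to-peak thresholding (assuming a dense SortingAnalyzer named `analyzer` computed with scaled templates; the threshold value is illustrative):

    from spikeinterface.core import compute_sparsity

    # one possible replacement for the deprecated:
    #     compute_sparsity(analyzer, method="ptp", threshold=50)
    sparsity = compute_sparsity(
        analyzer,
        method="amplitude",
        threshold=50,                   # in uV, since templates are scaled
        amplitude_mode="peak_to_peak",  # thresholds raw peak-to-peak amplitudes
        peak_sign="neg",
    )
    # boolean mask of shape (num_units, num_channels)
    print(sparsity.mask.shape)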
+ """ + from .template_tools import get_template_amplitudes from .sortinganalyzer import SortingAnalyzer from .template import Templates + assert ( + templates_or_sorting_analyzer.sparsity is None + ), "To compute sparsity you need a dense SortingAnalyzer or Templates" + unit_ids = templates_or_sorting_analyzer.unit_ids channel_ids = templates_or_sorting_analyzer.channel_ids if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - return_scaled = templates_or_sorting_analyzer.return_scaled + assert templates_or_sorting_analyzer.return_scaled, ( + "To compute sparsity from amplitude you need to have scaled templates. " + "You can set `return_scaled=True` when computing the templates." + ) elif isinstance(templates_or_sorting_analyzer, Templates): - return_scaled = templates_or_sorting_analyzer.is_scaled + assert templates_or_sorting_analyzer.is_scaled, ( + "To compute sparsity from amplitude you need to have scaled templates. " + "You can set `is_scaled=True` when creating the Templates object." + ) mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") - templates_array = get_dense_templates_array(templates_or_sorting_analyzer, return_scaled=return_scaled) - templates_ptps = np.ptp(templates_array, axis=1) + peak_values = get_template_amplitudes( + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode, return_scaled=True + ) for unit_ind, unit_id in enumerate(unit_ids): - chan_inds = np.nonzero(templates_ptps[unit_ind] >= threshold) + chan_inds = np.nonzero((np.abs(peak_values[unit_id])) >= threshold) mask[unit_ind, chan_inds] = True return cls(mask, unit_ids, channel_ids) @@ -560,7 +600,7 @@ def create_dense(cls, sorting_analyzer): def compute_sparsity( templates_or_sorting_analyzer: "Templates | SortingAnalyzer", noise_levels: np.ndarray | None = None, - method: "radius" | "best_channels" | "snr" | "ptp" | "energy" | "by_property" = "radius", + method: "radius" | "best_channels" | "snr" | "amplitude" | "energy" | "by_property" | "ptp" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", num_channels: int | None = 5, radius_um: float | None = 100.0, @@ -595,7 +635,7 @@ def compute_sparsity( # to keep backward compatibility templates_or_sorting_analyzer = templates_or_sorting_analyzer.sorting_analyzer - if method in ("best_channels", "radius", "snr", "ptp"): + if method in ("best_channels", "radius", "snr", "amplitude", "ptp"): assert isinstance( templates_or_sorting_analyzer, (Templates, SortingAnalyzer) ), f"compute_sparsity(method='{method}') need Templates or SortingAnalyzer" @@ -619,11 +659,13 @@ def compute_sparsity( peak_sign=peak_sign, amplitude_mode=amplitude_mode, ) - elif method == "ptp": - assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_ptp( + elif method == "amplitude": + assert threshold is not None, "For the 'amplitude' method, 'threshold' needs to be given" + sparsity = ChannelSparsity.from_amplitude( templates_or_sorting_analyzer, threshold, + amplitude_mode=amplitude_mode, + peak_sign=peak_sign, ) elif method == "energy": assert threshold is not None, "For the 'energy' method, 'threshold' needs to be given" @@ -633,6 +675,14 @@ def compute_sparsity( sparsity = ChannelSparsity.from_property( templates_or_sorting_analyzer.sorting, templates_or_sorting_analyzer.recording, by_property ) + elif method == "ptp": + # TODO: remove after deprecation + assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" + sparsity = ChannelSparsity.from_ptp( + 
templates_or_sorting_analyzer, + threshold, + noise_levels=noise_levels, + ) else: raise ValueError(f"compute_sparsity() method={method} does not exists") @@ -648,7 +698,7 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" | "ptp" | "snr" | "by_property" = "radius", + method: "radius" | "best_channels" | "amplitude" | "snr" | "by_property" | "ptp" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, @@ -697,9 +747,9 @@ def estimate_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - assert method in ("radius", "best_channels", "ptp", "snr", "by_property"), ( + assert method in ("radius", "best_channels", "snr", "amplitude", "by_property", "ptp"), ( f"method={method} is not available for `estimate_sparsity()`. " - "Available methods are 'radius', 'best_channels', 'ptp', 'snr', 'energy', 'by_property'" + "Available methods are 'radius', 'best_channels', 'snr', 'amplitude', 'by_property', 'ptp' (deprecated)" ) if recording.get_probes() == 1: @@ -768,12 +818,19 @@ def estimate_sparsity( peak_sign=peak_sign, amplitude_mode=amplitude_mode, ) + elif method == "amplitude": + assert threshold is not None, "For the 'amplitude' method, 'threshold' needs to be given" + sparsity = ChannelSparsity.from_amplitude( + templates, threshold, amplitude_mode=amplitude_mode, peak_sign=peak_sign + ) elif method == "ptp": + # TODO: remove after deprecation assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" - sparsity = ChannelSparsity.from_ptp( - templates, - threshold, + assert noise_levels is not None, ( + "For the 'snr' method, 'noise_levels' needs to be given. You can use the " + "`get_noise_levels()` function to compute them." 
) + sparsity = ChannelSparsity.from_ptp(templates, threshold, noise_levels=noise_levels) else: raise ValueError(f"compute_sparsity() method={method} does not exists") else: diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 64517c106d..ace869df8c 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -209,15 +209,16 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 5) - # ptp: just run it + # amplitude sparsity = estimate_sparsity( sorting, recording, num_spikes_for_sparsity=50, ms_before=1.0, ms_after=2.0, - method="ptp", + method="amplitude", threshold=5, + amplitude_mode="peak_to_peak", chunk_duration="1s", progress_bar=True, n_jobs=1, @@ -252,6 +253,23 @@ def test_estimate_sparsity(): progress_bar=True, n_jobs=1, ) + # ptp: just run it + print(noise_levels) + + with pytest.warns(DeprecationWarning): + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="ptp", + threshold=5, + noise_levels=noise_levels, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) def test_compute_sparsity(): @@ -273,9 +291,11 @@ def test_compute_sparsity(): sparsity = compute_sparsity( sorting_analyzer, method="snr", threshold=5, peak_sign="neg", amplitude_mode="peak_to_peak" ) - sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5) + sparsity = compute_sparsity(sorting_analyzer, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") sparsity = compute_sparsity(sorting_analyzer, method="energy", threshold=5) sparsity = compute_sparsity(sorting_analyzer, method="by_property", by_property="group") + with pytest.warns(DeprecationWarning): + sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5) # using object Templates templates = sorting_analyzer.get_extension("templates").get_data(outputs="Templates") @@ -283,7 +303,10 @@ def test_compute_sparsity(): sparsity = compute_sparsity(templates, method="best_channels", num_channels=2, peak_sign="neg") sparsity = compute_sparsity(templates, method="radius", radius_um=50.0, peak_sign="neg") sparsity = compute_sparsity(templates, method="snr", noise_levels=noise_levels, threshold=5, peak_sign="neg") - sparsity = compute_sparsity(templates, method="ptp", noise_levels=noise_levels, threshold=5) + sparsity = compute_sparsity(templates, method="amplitude", threshold=5, amplitude_mode="peak_to_peak") + + with pytest.warns(DeprecationWarning): + sparsity = compute_sparsity(templates, method="ptp", noise_levels=noise_levels, threshold=5) if __name__ == "__main__": From 6c8889d74d5e8c4c2b5fa073f9b5cfcd7b9141b6 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 13 Sep 2024 12:39:05 +0200 Subject: [PATCH 69/98] Update src/spikeinterface/qualitymetrics/quality_metric_calculator.py Co-authored-by: Garcia Samuel --- src/spikeinterface/qualitymetrics/quality_metric_calculator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 8a754fc7da..52eb56c4ee 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -165,7 +165,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri import pandas as pd - metrics = 
pd.DataFrame(index=sorting_analyzer.unit_ids) + metrics = pd.DataFrame(index=unit_ids) # simple metrics not based on PCs for metric_name in metric_names: From c769f5457d3173377c5906e755ec32411e9a978f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 13 Sep 2024 12:49:09 +0200 Subject: [PATCH 70/98] Revert back to 0.101.1 --- doc/releases/{0.102.0.rst => 0.101.1.rst} | 10 +++++----- doc/whatisnew.rst | 7 +++---- pyproject.toml | 2 +- 3 files changed, 9 insertions(+), 10 deletions(-) rename doc/releases/{0.102.0.rst => 0.101.1.rst} (95%) diff --git a/doc/releases/0.102.0.rst b/doc/releases/0.101.1.rst similarity index 95% rename from doc/releases/0.102.0.rst rename to doc/releases/0.101.1.rst index 1ae1c44a09..38522f0819 100644 --- a/doc/releases/0.102.0.rst +++ b/doc/releases/0.101.1.rst @@ -1,6 +1,6 @@ -.. _release0.102.0: +.. _release0.101.1: -SpikeInterface 0.102.0 release notes +SpikeInterface 0.101.1 release notes ------------------------------------ 13th September 2024 @@ -10,7 +10,7 @@ Main changes: * Enabled support for consolidated Zarr-backend for `SortingAnalyzer`, including cloud support (#3314, #3318, #3349, #3351) * Improved support for Kilosort4 **ONLY VERSIONS >= 4.0.16** (#3339, #3276) * Skip recomputation of quality and template metrics if already computed (#3292) -* Modified and improved `estimate_sparsity` and refactored `from_ptp` option (#3369) +* Improved `estimate_sparsity` with new `amplitude` method and and deprecated `from_ptp` option (#3369) * Dropped support for Python<3.9 (#3267) core: @@ -47,7 +47,6 @@ sorters: * Fix KS2/2.5/3 skip_kilosort_preprocessing (#3265) * HS: Added lowpass parameter, fixed verbose option (#3262) * Now exclusive support for HS v0.4 (Lightning) (#3210) -* Add kilosort4 wrapper tests (#3085) postprocessing: @@ -69,6 +68,7 @@ curation: widgets: +* Sortingview: only round float properties (#3406) * Fix widgets tests and add test on unit_table_properties (#3354) * Allow quality and template metrics in sortingview's unit table (#3299) * Add subwidget parameters for UnitSummaryWidget (#3242) @@ -116,7 +116,7 @@ testing: * Refactor quality metrics tests to use fixture (#3249) * Add kilosort4 wrapper tests (#3085) * Test IBL skip when the setting up the one client fails (#3289) - +* Add kilosort4 wrapper tests (#3085) Contributors: diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst index 090722e27b..442876ff94 100644 --- a/doc/whatisnew.rst +++ b/doc/whatisnew.rst @@ -8,7 +8,7 @@ Release notes .. 
toctree:: :maxdepth: 1 - releases/0.102.0.rst + releases/0.101.1.rst releases/0.101.0.rst releases/0.100.8.rst releases/0.100.7.rst @@ -44,14 +44,13 @@ Release notes releases/0.9.1.rst -Version 0.102.0 +Version 0.101.1 =============== - * Enabled support for consolidated Zarr-backend for `SortingAnalyzer`, including cloud support (#3314, #3318, #3349, #3351) * Improved support for Kilosort4 **ONLY VERSIONS >= 4.0.16** (#3339, #3276) * Skip recomputation of quality and template metrics if already computed (#3292) -* Modified and improved `estimate_sparsity` and refactored `from_ptp` option (#3369) +* Improved `estimate_sparsity` with new `amplitude` method and and deprecated `from_ptp` option (#3369) * Dropped support for Python<3.9 (#3267) diff --git a/pyproject.toml b/pyproject.toml index 1bdf3c303b..db435998a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "spikeinterface" -version = "0.102.0" +version = "0.101.1" authors = [ { name="Alessio Buccino", email="alessiop.buccino@gmail.com" }, { name="Samuel Garcia", email="sam.garcia.die@gmail.com" }, From 3a13efe7607404a36337baa083b8dbe283a45bf0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 13 Sep 2024 12:59:42 +0200 Subject: [PATCH 71/98] Check run info completed only if it exists (back-compatibility) --- src/spikeinterface/core/sortinganalyzer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 3aa582da68..8980fb5559 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -2240,9 +2240,10 @@ def get_pipeline_nodes(self): return self._get_pipeline_nodes() def get_data(self, *args, **kwargs): - assert self.run_info[ - "run_completed" - ], f"You must run the extension {self.extension_name} before retrieving data" + if self.run_info is not None: + assert self.run_info[ + "run_completed" + ], f"You must run the extension {self.extension_name} before retrieving data" assert len(self.data) > 0, "Extension has been run but no data found." 
return self._get_data(*args, **kwargs) From 0f0c72e7720b63f67941507fa5926110bfdf31f2 Mon Sep 17 00:00:00 2001 From: Yue Huang <806628409@qq.com> Date: Fri, 13 Sep 2024 21:20:32 +0800 Subject: [PATCH 72/98] Update recording_tools.py Update the method to create an empty file with the right size --- src/spikeinterface/core/recording_tools.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 5137eda545..2e93b3671b 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -131,8 +131,10 @@ def write_binary_recording( data_size_bytes = dtype_size_bytes * num_frames * num_channels file_size_bytes = data_size_bytes + byte_offset + # create a file with file_size_bytes file = open(file_path, "wb+") - file.truncate(file_size_bytes) + file.seek(file_size_bytes - 1) + file.write(b'\0') file.close() assert Path(file_path).is_file() From 2cb34c3fd6d9f2296d5c9acdc1b00d608e8531c3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Sep 2024 13:40:59 +0000 Subject: [PATCH 73/98] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/recording_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 2e93b3671b..c18e461f90 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -134,7 +134,7 @@ def write_binary_recording( # create a file with file_size_bytes file = open(file_path, "wb+") file.seek(file_size_bytes - 1) - file.write(b'\0') + file.write(b"\0") file.close() assert Path(file_path).is_file() From 0f5948667352da9fe3fada5b5a456a551dd81c43 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 13 Sep 2024 10:29:03 -0400 Subject: [PATCH 74/98] fix binary argument --- src/spikeinterface/core/baserecording.py | 19 +++++++++++++++++-- .../sorters/external/kilosort4.py | 6 +++--- .../sorters/external/kilosortbase.py | 2 +- 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 082afd880b..e44ed9b948 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -768,7 +768,13 @@ def get_binary_description(self): raise NotImplementedError def binary_compatible_with( - self, dtype=None, time_axis=None, file_paths_lenght=None, file_offset=None, file_suffix=None + self, + dtype=None, + time_axis=None, + file_paths_length=None, + file_offset=None, + file_suffix=None, + file_paths_lenght=None, ): """ Check is the recording is binary compatible with some constrain on @@ -779,6 +785,15 @@ def binary_compatible_with( * file_offset * file_suffix """ + + # spelling typo need to fix + if file_paths_lenght is not None: + warnings.warn( + "`file_paths_lenght` is deprecated and will be removed in 0.103.0 please use `file_paths_length`" + ) + if file_paths_length is None: + file_paths_length = file_paths_lenght + if not self.is_binary_compatible(): return False @@ -790,7 +805,7 @@ def binary_compatible_with( if time_axis is not None and time_axis != d["time_axis"]: return False - if file_paths_lenght is not None and file_paths_lenght != len(d["file_paths"]): + if file_paths_length is not 
None and file_paths_length != len(d["file_paths"]): return False if file_offset is not None and file_offset != d["file_offset"]: diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index e73ac2cb6c..2a9fb34267 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -179,7 +179,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): write_prb(probe_filename, pg) if params["use_binary_file"]: - if not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + if not recording.binary_compatible_with(time_axis=0, file_paths_length=1): # local copy needed binary_file_path = sorter_output_folder / "recording.dat" write_binary_recording( @@ -235,7 +235,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): probe_name = "" if params["use_binary_file"] is None: - if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + if recording.binary_compatible_with(time_axis=0, file_paths_length=1): # no copy binary_description = recording.get_binary_description() filename = str(binary_description["file_paths"][0]) @@ -247,7 +247,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): file_object = RecordingExtractorAsArray(recording_extractor=recording) elif params["use_binary_file"]: # here we force the use of a binary file - if recording.binary_compatible_with(time_axis=0, file_paths_lenght=1): + if recording.binary_compatible_with(time_axis=0, file_paths_length=1): # no copy binary_description = recording.get_binary_description() filename = str(binary_description["file_paths"][0]) diff --git a/src/spikeinterface/sorters/external/kilosortbase.py b/src/spikeinterface/sorters/external/kilosortbase.py index 95d8d3badc..2aff9d296f 100644 --- a/src/spikeinterface/sorters/external/kilosortbase.py +++ b/src/spikeinterface/sorters/external/kilosortbase.py @@ -127,7 +127,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): skip_kilosort_preprocessing = params.get("skip_kilosort_preprocessing", False) if ( - recording.binary_compatible_with(dtype="int16", time_axis=0, file_paths_lenght=1) + recording.binary_compatible_with(dtype="int16", time_axis=0, file_paths_length=1) and not skip_kilosort_preprocessing ): # no copy From 3f9a9011a958252547acc4589dca88d727a49b6a Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Fri, 13 Sep 2024 18:00:33 -0600 Subject: [PATCH 75/98] Update src/spikeinterface/core/recording_tools.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/recording_tools.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index c18e461f90..30cd54473c 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -133,7 +133,11 @@ def write_binary_recording( # create a file with file_size_bytes file = open(file_path, "wb+") - file.seek(file_size_bytes - 1) + if platform.system() == 'Windows': + file.seek(file_size_bytes - 1) + file.write(b"\0") + else: + file.truncate(file_size_bytes) file.write(b"\0") file.close() assert Path(file_path).is_file() From d460c1ed312f5399b11100613483c3b364ef7213 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 14 Sep 2024 00:00:54 +0000 Subject: [PATCH 76/98] 
[pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/recording_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 30cd54473c..f649e60b28 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -133,7 +133,7 @@ def write_binary_recording( # create a file with file_size_bytes file = open(file_path, "wb+") - if platform.system() == 'Windows': + if platform.system() == "Windows": file.seek(file_size_bytes - 1) file.write(b"\0") else: From dcedbb33739663e0cb662d313e0224f3d64d04f8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sun, 15 Sep 2024 13:01:11 +0200 Subject: [PATCH 77/98] Simplify pandas save-load and convert dtypes --- pyproject.toml | 3 -- src/spikeinterface/core/sortinganalyzer.py | 37 ++++++++++++------- .../postprocessing/template_metrics.py | 2 +- .../tests/common_extension_tests.py | 23 +++++++++++- .../quality_metric_calculator.py | 2 +- 5 files changed, 47 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8309ca89fe..b5894bf3a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,6 @@ preprocessing = [ full = [ "h5py", "pandas", - "xarray", "scipy", "scikit-learn", "networkx", @@ -148,7 +147,6 @@ test = [ "pytest-dependency", "pytest-cov", - "xarray", "huggingface_hub", # preprocessing @@ -193,7 +191,6 @@ docs = [ "pandas", # in the modules gallery comparison tutorial "hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous "numba", # For many postprocessing functions - "xarray", # For use of SortingAnalyzer zarr format "networkx", # Download data "pooch>=1.8.2", diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 8980fb5559..312f85a8ca 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1970,12 +1970,14 @@ def load_data(self): if "dict" in ext_data_.attrs: ext_data = ext_data_[0] elif "dataframe" in ext_data_.attrs: - import xarray + import pandas as pd - ext_data = xarray.open_zarr( - ext_data_.store, group=f"{extension_group.name}/{ext_data_name}" - ).to_pandas() - ext_data.index.rename("", inplace=True) + index = ext_data_["index"] + ext_data = pd.DataFrame(index=index) + for col in ext_data_.keys(): + if col != "index": + ext_data.loc[:, col] = ext_data_[col][:] + ext_data = ext_data.convert_dtypes() elif "object" in ext_data_.attrs: ext_data = ext_data_[0] else: @@ -2031,12 +2033,21 @@ def run(self, save=True, **kwargs): if save and not self.sorting_analyzer.is_read_only(): self._save_run_info() self._save_data(**kwargs) + if self.format == "zarr": + import zarr + + zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) def save(self, **kwargs): self._save_params() self._save_importing_provenance() - self._save_data(**kwargs) self._save_run_info() + self._save_data(**kwargs) + + if self.format == "zarr": + import zarr + + zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) def _save_data(self, **kwargs): if self.format == "memory": @@ -2096,12 +2107,12 @@ def _save_data(self, **kwargs): elif isinstance(ext_data, np.ndarray): extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame): - ext_data.to_xarray().to_zarr( - store=extension_group.store, - 
group=f"{extension_group.name}/{ext_data_name}", - mode="a", - ) - extension_group[ext_data_name].attrs["dataframe"] = True + df_group = extension_group.create_group(ext_data_name) + # first we save the index + df_group.create_dataset(name="index", data=ext_data.index.to_numpy()) + for col in ext_data.columns: + df_group.create_dataset(name=col, data=ext_data[col].to_numpy()) + df_group.attrs["dataframe"] = True else: # any object try: @@ -2111,8 +2122,6 @@ def _save_data(self, **kwargs): except: raise Exception(f"Could not save {ext_data_name} as extension data") extension_group[ext_data_name].attrs["object"] = True - # we need to re-consolidate - zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) def _reset_extension_folder(self): """ diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 45ba55dee4..ee5ac6103b 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -287,7 +287,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") value = np.nan template_metrics.at[index, metric_name] = value - return template_metrics + return template_metrics.convert_dtypes() def _run(self, verbose=False): self.data["metrics"] = self._compute_metrics( diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 3945e71881..1b0a94d635 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -3,9 +3,10 @@ import pytest import shutil import numpy as np +import pandas as pd from spikeinterface.core import generate_ground_truth_recording -from spikeinterface.core import create_sorting_analyzer +from spikeinterface.core import create_sorting_analyzer, load_sorting_analyzer from spikeinterface.core import estimate_sparsity @@ -138,6 +139,26 @@ def _check_one(self, sorting_analyzer, extension_class, params): merged = sorting_analyzer.merge_units(some_merges, format="memory", merging_mode="soft", sparsity_overlap=0.0) assert len(merged.unit_ids) == num_units_after_merge + # test roundtrip + if sorting_analyzer.format in ("binary_folder", "zarr"): + sorting_analyzer_loaded = load_sorting_analyzer(sorting_analyzer.folder) + ext_loaded = sorting_analyzer_loaded.get_extension(extension_class.extension_name) + for ext_data_name, ext_data_loaded in ext_loaded.data.items(): + if isinstance(ext_data_loaded, np.ndarray): + assert np.array_equal(ext.data[ext_data_name], ext_data_loaded) + elif isinstance(ext_data_loaded, pd.DataFrame): + # skip nan values + for col in ext_data_loaded.columns: + np.testing.assert_array_almost_equal( + ext.data[ext_data_name][col].dropna().to_numpy(), + ext_data_loaded[col].dropna().to_numpy(), + decimal=5, + ) + elif isinstance(ext_data_loaded, dict): + assert ext.data[ext_data_name] == ext_data_loaded + else: + continue + def run_extension_tests(self, extension_class, params): """ Convenience function to perform all checks on the extension diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index cdf6151e95..3d7096651f 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ 
b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -185,7 +185,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job if len(empty_unit_ids) > 0: metrics.loc[empty_unit_ids] = np.nan - return metrics + return metrics.convert_dtypes() def _run(self, verbose=False, **job_kwargs): self.data["metrics"] = self._compute_metrics( From 9000ce1980130fac15394b7cf68137fe239b0b13 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sun, 15 Sep 2024 13:08:25 +0200 Subject: [PATCH 78/98] local import --- .../postprocessing/tests/common_extension_tests.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index 1b0a94d635..2207b98da6 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -3,7 +3,6 @@ import pytest import shutil import numpy as np -import pandas as pd from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core import create_sorting_analyzer, load_sorting_analyzer @@ -117,6 +116,8 @@ def _check_one(self, sorting_analyzer, extension_class, params): with the passed parameters, and check the output is not empty, the extension exists and `select_units()` method works. """ + import pandas as pd + if extension_class.need_job_kwargs: job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True) else: From c1228d9269663118c065ae9e5f68cb1c0a8b16c7 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sun, 15 Sep 2024 18:00:28 +0200 Subject: [PATCH 79/98] Add comment and re-consolidation step for 0.101.0 datasets --- src/spikeinterface/core/sortinganalyzer.py | 15 +++++++++++++++ .../postprocessing/template_metrics.py | 6 +++++- .../qualitymetrics/quality_metric_calculator.py | 9 ++++++--- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 312f85a8ca..a94b7aa3dc 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -11,6 +11,7 @@ import shutil import warnings import importlib +from packaging.version import parse from time import perf_counter import numpy as np @@ -579,6 +580,20 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None): zarr_root = zarr.open_consolidated(str(folder), mode="r", storage_options=storage_options) + si_info = zarr_root.attrs["spikeinterface_info"] + if parse(si_info["version"]) < parse("0.101.1"): + # v0.101.0 did not have a consolidate metadata step after computing extensions. + # Here we try to consolidate the metadata and throw a warning if it fails. + try: + zarr_root_a = zarr.open(str(folder), mode="a", storage_options=storage_options) + zarr.consolidate_metadata(zarr_root_a.store) + except Exception as e: + warnings.warn( + "The zarr store was not properly consolidated prior to v0.101.1. " + "This may lead to unexpected behavior in loading extensions. " + "Please consider re-saving the SortingAnalyzer object." 
+ ) + # load internal sorting in memory sorting = NumpySorting.from_sorting( ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options), diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index ee5ac6103b..aa50be4c13 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -287,7 +287,11 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") value = np.nan template_metrics.at[index, metric_name] = value - return template_metrics.convert_dtypes() + + # we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns + # (in case of NaN values) + template_metrics = template_metrics.convert_dtypes() + return template_metrics def _run(self, verbose=False): self.data["metrics"] = self._compute_metrics( diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 3d7096651f..b2804c2638 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -108,6 +108,8 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job """ Compute quality metrics. """ + import pandas as pd + metric_names = self.params["metric_names"] qm_params = self.params["qm_params"] # sparsity = self.params["sparsity"] @@ -132,8 +134,6 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job non_empty_unit_ids = unit_ids empty_unit_ids = [] - import pandas as pd - metrics = pd.DataFrame(index=unit_ids) # simple metrics not based on PCs @@ -185,7 +185,10 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job if len(empty_unit_ids) > 0: metrics.loc[empty_unit_ids] = np.nan - return metrics.convert_dtypes() + # we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns + # (in case of NaN values) + metrics = metrics.convert_dtypes() + return metrics def _run(self, verbose=False, **job_kwargs): self.data["metrics"] = self._compute_metrics( From 559cbaade8faaf9a0056da84069ddc9cb96faae3 Mon Sep 17 00:00:00 2001 From: Yue Huang <806628409@qq.com> Date: Mon, 16 Sep 2024 00:09:51 +0800 Subject: [PATCH 80/98] Update recording_tools.py --- src/spikeinterface/core/recording_tools.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index f649e60b28..8f42774dec 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -131,14 +131,13 @@ def write_binary_recording( data_size_bytes = dtype_size_bytes * num_frames * num_channels file_size_bytes = data_size_bytes + byte_offset - # create a file with file_size_bytes + # Create an empty file with file_size_bytes file = open(file_path, "wb+") - if platform.system() == "Windows": - file.seek(file_size_bytes - 1) - file.write(b"\0") - else: - file.truncate(file_size_bytes) + + # The previous implementation `file.truncate(file_size_bytes)` was slow on Windows (#3408) + file.seek(file_size_bytes - 1) file.write(b"\0") + file.close() assert Path(file_path).is_file() From af6c61807f75e98420eb2a849771fb29d57dd1f1 Mon Sep 17 00:00:00 
2001 From: Yue Huang <806628409@qq.com> Date: Mon, 16 Sep 2024 00:18:20 +0800 Subject: [PATCH 81/98] Update recording_tools.py --- src/spikeinterface/core/recording_tools.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 8f42774dec..77d427bc88 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -132,13 +132,11 @@ def write_binary_recording( file_size_bytes = data_size_bytes + byte_offset # Create an empty file with file_size_bytes - file = open(file_path, "wb+") + with open(file_path, "wb+") as file: + # The previous implementation `file.truncate(file_size_bytes)` was slow on Windows (#3408) + file.seek(file_size_bytes - 1) + file.write(b"\0") - # The previous implementation `file.truncate(file_size_bytes)` was slow on Windows (#3408) - file.seek(file_size_bytes - 1) - file.write(b"\0") - - file.close() assert Path(file_path).is_file() # use executor (loop or workers) From 2cb4f7c4332ff4b76b9449edf3a606c7c8a08dc9 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sun, 15 Sep 2024 18:29:20 +0200 Subject: [PATCH 82/98] Update release notes --- doc/releases/0.101.1.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/releases/0.101.1.rst b/doc/releases/0.101.1.rst index 38522f0819..5879dcc6a2 100644 --- a/doc/releases/0.101.1.rst +++ b/doc/releases/0.101.1.rst @@ -15,6 +15,9 @@ Main changes: core: +* Refactor pandas save load and convert dtypes (#3412) +* Check run info completed only if it exists (back-compatibility) (#3407) +* Fix argument spelling in check for binary compatibility (#3409) * Fix proposal for channel location when probegroup (#3392) * Fix time handling test memory (#3379) * Add `BaseRecording.reset_times()` function (#3363, #3380, #3391) From b1677fabd82f36d0ed51af8418a559661cbfa4e3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sun, 15 Sep 2024 18:43:27 +0200 Subject: [PATCH 83/98] Update src/spikeinterface/qualitymetrics/quality_metric_calculator.py --- src/spikeinterface/qualitymetrics/quality_metric_calculator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 41ad40293a..3b6c6d3e50 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -141,7 +141,6 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri """ import pandas as pd - metric_names = self.params["metric_names"] qm_params = self.params["qm_params"] # sparsity = self.params["sparsity"] seed = self.params["seed"] From c33fbd57b4fc7e95edd8a3680ab40a5f291617e0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 16 Sep 2024 10:00:34 +0200 Subject: [PATCH 84/98] Fix plot motion for multi-segment --- src/spikeinterface/widgets/motion.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 81cda212b2..06c0305351 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -340,8 +340,11 @@ def __init__( raise ValueError( "plot drift map : the Motion object is multi-segment you must provide segment_index=XX" ) + assert recording.get_num_segments() == len( + motion.displacement + ), "The number of segments in the recording must be the same 
as the number of segments in the motion object" - times = recording.get_times() if recording is not None else None + times = recording.get_times(segment_index=segment_index) if recording is not None else None plot_data = dict( sampling_frequency=motion_info["parameters"]["sampling_frequency"], From 9a7295948b83bef46736252f49ac6b164acee53d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 16 Sep 2024 10:04:17 +0200 Subject: [PATCH 85/98] Update src/spikeinterface/core/sortinganalyzer.py --- src/spikeinterface/core/sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index a94b7aa3dc..177188f21d 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -591,7 +591,7 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None): warnings.warn( "The zarr store was not properly consolidated prior to v0.101.1. " "This may lead to unexpected behavior in loading extensions. " - "Please consider re-saving the SortingAnalyzer object." + "Please consider re-generating the SortingAnalyzer object." ) # load internal sorting in memory From 3d41aca67dbfcd01e74eb3bda0ccf7fa59e4c56a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 16 Sep 2024 11:25:54 +0200 Subject: [PATCH 86/98] Recording cannot be Nonw --- src/spikeinterface/widgets/motion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 06c0305351..7c8389cae8 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -344,7 +344,7 @@ def __init__( motion.displacement ), "The number of segments in the recording must be the same as the number of segments in the motion object" - times = recording.get_times(segment_index=segment_index) if recording is not None else None + times = recording.get_times(segment_index=segment_index) plot_data = dict( sampling_frequency=motion_info["parameters"]["sampling_frequency"], From 36251ef49b9c61582de2f84bd81d76d00be34b3e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 16 Sep 2024 11:33:56 +0200 Subject: [PATCH 87/98] Let recording handle times --- src/spikeinterface/widgets/motion.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 7c8389cae8..42e9a20f3c 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -200,18 +200,11 @@ def __init__( if peak_amplitudes is not None: peak_amplitudes = peak_amplitudes[peak_mask] - if recording is not None: - sampling_frequency = recording.sampling_frequency - times = recording.get_times(segment_index=segment_index) - else: - times = None - plot_data = dict( peaks=peaks, peak_locations=peak_locations, peak_amplitudes=peak_amplitudes, direction=direction, - times=times, sampling_frequency=sampling_frequency, segment_index=segment_index, depth_lim=depth_lim, @@ -238,10 +231,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - if dp.times is None: + if dp.recording is None: peak_times = dp.peaks["sample_index"] / dp.sampling_frequency else: - peak_times = dp.times[dp.peaks["sample_index"]] + peak_times = dp.recording.sample_index_to_time(dp.peaks["sample_index"], segment_index=dp.segment_index) peak_locs = 
dp.peak_locations[dp.direction] if dp.scatter_decimate is not None: @@ -344,11 +337,8 @@ def __init__( motion.displacement ), "The number of segments in the recording must be the same as the number of segments in the motion object" - times = recording.get_times(segment_index=segment_index) - plot_data = dict( sampling_frequency=motion_info["parameters"]["sampling_frequency"], - times=times, segment_index=segment_index, depth_lim=depth_lim, motion_lim=motion_lim, From f2d1f078d368d4d0543eeef9663159b1b9684c00 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 16 Sep 2024 11:35:59 +0200 Subject: [PATCH 88/98] Add #3414 PR --- doc/releases/0.101.1.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/0.101.1.rst b/doc/releases/0.101.1.rst index 5879dcc6a2..46ae6b64c0 100644 --- a/doc/releases/0.101.1.rst +++ b/doc/releases/0.101.1.rst @@ -71,6 +71,7 @@ curation: widgets: +* Fix plot motion for multi-segment (#3414) * Sortingview: only round float properties (#3406) * Fix widgets tests and add test on unit_table_properties (#3354) * Allow quality and template metrics in sortingview's unit table (#3299) From eaf09ba55b3062bf5ff340f92c669ba04d722e75 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 16 Sep 2024 11:50:24 +0200 Subject: [PATCH 89/98] Auto-cast recording to float prior to interpolation --- src/spikeinterface/preprocessing/motion.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index ddb981a944..14c565a290 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -13,6 +13,7 @@ from spikeinterface.core.core_tools import SIJsonEncoder from spikeinterface.core.job_tools import _shared_job_kwargs_doc + motion_options_preset = { # dredge "dredge": { @@ -277,10 +278,11 @@ def correct_motion( This function depends on several modular components of :py:mod:`spikeinterface.sortingcomponents`. - If select_kwargs is None then all peak are used for localized. + If `select_kwargs` is None then all peak are used for localized. The recording must be preprocessed (filter and denoised at least), and we recommend to not use whithening before motion estimation. + Since the motion interpolation requires a "float" recording, the recording is casted to float32 if necessary. Parameters for each step are handled as separate dictionaries. 
For more information please check the documentation of the following functions: @@ -435,6 +437,8 @@ def correct_motion( t1 = time.perf_counter() run_times["estimate_motion"] = t1 - t0 + if recording.get_dtype().kind != "f": + recording = recording.astype("float32") recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs) motion_info = dict( From 9456e0577767311bd45363ad33e26ab76dca3188 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 16 Sep 2024 11:52:44 +0200 Subject: [PATCH 90/98] Add PR #3415 --- doc/releases/0.101.1.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/0.101.1.rst b/doc/releases/0.101.1.rst index 46ae6b64c0..8f58ba0359 100644 --- a/doc/releases/0.101.1.rst +++ b/doc/releases/0.101.1.rst @@ -39,6 +39,7 @@ extractors: preprocessing: +* Auto-cast recording to float prior to interpolation (#3415) * Update doc handle drift + better preset (#3232) * Add causal filtering to filter.py (#3172) From 9cc8c2d07717d24dcf483ffe9acc331f77e33075 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 16 Sep 2024 12:32:53 +0200 Subject: [PATCH 91/98] Fix plot_sorting_summary after #3412 with to_numpy() --- src/spikeinterface/widgets/utils_sortingview.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index d18c581b6b..554f8d221d 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -137,9 +137,9 @@ def generate_unit_table_view( if prop_name in sorting_props: property_values = sorting.get_property(prop_name) elif prop_name in qm_props: - property_values = qm_data[prop_name].values + property_values = qm_data[prop_name].to_numpy() elif prop_name in tm_props: - property_values = tm_data[prop_name].values + property_values = tm_data[prop_name].to_numpy() # Check for NaN values and round floats val0 = np.array(property_values[0]) From 81e53ab64b1fdafa9f7c67344502b2cf3e769512 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 16 Sep 2024 12:35:58 +0200 Subject: [PATCH 92/98] Fix plot_sorting_summary after #3412 with to_numpy() 2 --- src/spikeinterface/widgets/utils_sortingview.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index 554f8d221d..a6cc562ba2 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -106,9 +106,9 @@ def generate_unit_table_view( if prop_name in sorting_props: property_values = sorting.get_property(prop_name) elif prop_name in qm_props: - property_values = qm_data[prop_name].values + property_values = qm_data[prop_name].to_numpy() elif prop_name in tm_props: - property_values = tm_data[prop_name].values + property_values = tm_data[prop_name].to_numpy() else: warn(f"Property '{prop_name}' not found in sorting, quality_metrics, or template_metrics") continue From 2a1ecceeaa7a8d38462546feee1434b88d99836a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 16 Sep 2024 13:23:00 +0200 Subject: [PATCH 93/98] Fix plot_sorting_summary after #3412 with to_numpy() 3 --- src/spikeinterface/widgets/metrics.py | 3 +++ src/spikeinterface/widgets/tests/test_widgets.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 2fbd0e31eb..813e7d7b63 100644 
--- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -235,6 +235,9 @@ def plot_sortingview(self, data_plot, **backend_kwargs): values = check_json(metrics.loc[unit_id].to_dict()) values_skip_nans = {} for k, v in values.items(): + # convert_dypes returns NaN as None or np.nan (for float) + if v is None: + continue if np.isnan(v): continue values_skip_nans[k] = v diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index debcd52085..80f58f5ad9 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -73,7 +73,7 @@ def setUpClass(cls): spike_amplitudes=dict(), unit_locations=dict(), spike_locations=dict(), - quality_metrics=dict(metric_names=["snr", "isi_violation", "num_spikes"]), + quality_metrics=dict(metric_names=["snr", "isi_violation", "num_spikes", "amplitude_cutoff"]), template_metrics=dict(), correlograms=dict(), template_similarity=dict(), From 6fb9fe650429cc07236bd572915a709cafa0fb40 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 16 Sep 2024 16:25:00 +0200 Subject: [PATCH 94/98] Update doc/releases/0.101.1.rst Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- doc/releases/0.101.1.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/doc/releases/0.101.1.rst b/doc/releases/0.101.1.rst index 8f58ba0359..591321fcd4 100644 --- a/doc/releases/0.101.1.rst +++ b/doc/releases/0.101.1.rst @@ -129,7 +129,6 @@ Contributors: * @JoeZiminski * @JuanPimientoCaicedo * @alejoe91 -* @app/pre-commit-ci * @chrishalcrow * @cwindolf * @florian6973 From 7917a5532b7176389d4f40aee13967a58998e6fd Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 16 Sep 2024 16:25:50 +0200 Subject: [PATCH 95/98] Apply suggestions from code review Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- doc/releases/0.101.1.rst | 4 ++-- doc/whatisnew.rst | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/releases/0.101.1.rst b/doc/releases/0.101.1.rst index 591321fcd4..c24372868c 100644 --- a/doc/releases/0.101.1.rst +++ b/doc/releases/0.101.1.rst @@ -3,14 +3,14 @@ SpikeInterface 0.101.1 release notes ------------------------------------ -13th September 2024 +16th September 2024 Main changes: * Enabled support for consolidated Zarr-backend for `SortingAnalyzer`, including cloud support (#3314, #3318, #3349, #3351) * Improved support for Kilosort4 **ONLY VERSIONS >= 4.0.16** (#3339, #3276) * Skip recomputation of quality and template metrics if already computed (#3292) -* Improved `estimate_sparsity` with new `amplitude` method and and deprecated `from_ptp` option (#3369) +* Improved `estimate_sparsity` with new `amplitude` method and deprecated `from_ptp` option (#3369) * Dropped support for Python<3.9 (#3267) core: diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst index 442876ff94..c8038387f9 100644 --- a/doc/whatisnew.rst +++ b/doc/whatisnew.rst @@ -50,7 +50,7 @@ Version 0.101.1 * Enabled support for consolidated Zarr-backend for `SortingAnalyzer`, including cloud support (#3314, #3318, #3349, #3351) * Improved support for Kilosort4 **ONLY VERSIONS >= 4.0.16** (#3339, #3276) * Skip recomputation of quality and template metrics if already computed (#3292) -* Improved `estimate_sparsity` with new `amplitude` method and and deprecated `from_ptp` option (#3369) +* Improved `estimate_sparsity` with new `amplitude` method and deprecated `from_ptp` option (#3369) * Dropped support for 
Python<3.9 (#3267) From 9834996cf38e2589bdbcd44d20cbcd7a4afee205 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 16 Sep 2024 18:34:09 +0200 Subject: [PATCH 96/98] Fix metrics widgets for convert_dtypes --- src/spikeinterface/widgets/metrics.py | 3 +++ src/spikeinterface/widgets/tests/test_widgets.py | 2 +- src/spikeinterface/widgets/utils_sortingview.py | 8 ++++---- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/widgets/metrics.py b/src/spikeinterface/widgets/metrics.py index 2fbd0e31eb..813e7d7b63 100644 --- a/src/spikeinterface/widgets/metrics.py +++ b/src/spikeinterface/widgets/metrics.py @@ -235,6 +235,9 @@ def plot_sortingview(self, data_plot, **backend_kwargs): values = check_json(metrics.loc[unit_id].to_dict()) values_skip_nans = {} for k, v in values.items(): + # convert_dypes returns NaN as None or np.nan (for float) + if v is None: + continue if np.isnan(v): continue values_skip_nans[k] = v diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index debcd52085..80f58f5ad9 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -73,7 +73,7 @@ def setUpClass(cls): spike_amplitudes=dict(), unit_locations=dict(), spike_locations=dict(), - quality_metrics=dict(metric_names=["snr", "isi_violation", "num_spikes"]), + quality_metrics=dict(metric_names=["snr", "isi_violation", "num_spikes", "amplitude_cutoff"]), template_metrics=dict(), correlograms=dict(), template_similarity=dict(), diff --git a/src/spikeinterface/widgets/utils_sortingview.py b/src/spikeinterface/widgets/utils_sortingview.py index d18c581b6b..a6cc562ba2 100644 --- a/src/spikeinterface/widgets/utils_sortingview.py +++ b/src/spikeinterface/widgets/utils_sortingview.py @@ -106,9 +106,9 @@ def generate_unit_table_view( if prop_name in sorting_props: property_values = sorting.get_property(prop_name) elif prop_name in qm_props: - property_values = qm_data[prop_name].values + property_values = qm_data[prop_name].to_numpy() elif prop_name in tm_props: - property_values = tm_data[prop_name].values + property_values = tm_data[prop_name].to_numpy() else: warn(f"Property '{prop_name}' not found in sorting, quality_metrics, or template_metrics") continue @@ -137,9 +137,9 @@ def generate_unit_table_view( if prop_name in sorting_props: property_values = sorting.get_property(prop_name) elif prop_name in qm_props: - property_values = qm_data[prop_name].values + property_values = qm_data[prop_name].to_numpy() elif prop_name in tm_props: - property_values = tm_data[prop_name].values + property_values = tm_data[prop_name].to_numpy() # Check for NaN values and round floats val0 = np.array(property_values[0]) From ca0d99c64535067965fe73777bc0b3ccf31fe840 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 16 Sep 2024 18:38:17 +0200 Subject: [PATCH 97/98] Add #3417 --- doc/releases/0.101.1.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/0.101.1.rst b/doc/releases/0.101.1.rst index c24372868c..f68cd65e46 100644 --- a/doc/releases/0.101.1.rst +++ b/doc/releases/0.101.1.rst @@ -72,6 +72,7 @@ curation: widgets: +* Fix metrics widgets for convert_dtypes (#3417) * Fix plot motion for multi-segment (#3414) * Sortingview: only round float properties (#3406) * Fix widgets tests and add test on unit_table_properties (#3354) From 12f089b15a9d12cbafde1442dca2aa0c5cafa71b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 16 Sep 2024 19:10:44 +0200 Subject: 
[PATCH 98/98] Add #3408 --- doc/releases/0.101.1.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/0.101.1.rst b/doc/releases/0.101.1.rst index f68cd65e46..eeb54566a6 100644 --- a/doc/releases/0.101.1.rst +++ b/doc/releases/0.101.1.rst @@ -15,6 +15,7 @@ Main changes: core: +* Update the method of creating an empty file with right size when saving binary files (#3408) * Refactor pandas save load and convert dtypes (#3412) * Check run info completed only if it exists (back-compatibility) (#3407) * Fix argument spelling in check for binary compatibility (#3409)
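
The series above leans on a handful of small, reusable techniques. The sketches below restate them outside the diff context; they are illustrative approximations under stated assumptions, not part of the committed patches.

First, the deprecation pattern used throughout the sparsity changes (PATCH 68, #3369): the old `method="ptp"` keeps working, a warning is emitted, and the tests pin that contract with `pytest.warns`. A minimal self-contained sketch follows; the `compute` function is a toy stand-in for `compute_sparsity`/`estimate_sparsity`, not the SpikeInterface API itself.

import warnings

import pytest


def compute(method="amplitude", amplitude_mode="peak_to_peak"):
    # Toy stand-in: route the deprecated option onto its replacement.
    if method == "ptp":
        warnings.warn(
            "method='ptp' is deprecated, use method='amplitude' with "
            "amplitude_mode='peak_to_peak'",
            DeprecationWarning,
        )
        method, amplitude_mode = "amplitude", "peak_to_peak"
    return method, amplitude_mode


def test_ptp_is_deprecated():
    # pytest.warns fails the test if no DeprecationWarning is raised inside.
    with pytest.warns(DeprecationWarning):
        assert compute(method="ptp") == ("amplitude", "peak_to_peak")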
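
The `write_binary_recording` changes in PATCHES 72, 75, 80 and 81 (#3408) converge on pre-allocating the output file by seeking to the last byte and writing a single null byte, because `file.truncate(file_size_bytes)` was reported slow on Windows. A stripped-down sketch of the final form; the path and size here are placeholders:

from pathlib import Path

file_path = Path("recording.dat")   # placeholder output path
file_size_bytes = 64 * 1024 * 1024  # placeholder size

# Create an empty file of exactly file_size_bytes without writing its body:
# seeking past the current end and writing one byte makes the OS extend the
# file, sparsely on filesystems that support it.
with open(file_path, "wb+") as file:
    file.seek(file_size_bytes - 1)
    file.write(b"\0")

assert file_path.stat().st_size == file_size_bytes

On POSIX filesystems both variants behave essentially the same; the seek-and-write form simply sidesteps the slow Windows code path behind `truncate`.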
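
PATCH 74 (#3409) corrects the misspelled `file_paths_lenght` argument while keeping the old spelling alive for a deprecation cycle. The shape of that shim, extracted into a standalone toy function; note that the committed version calls `warnings.warn` without an explicit category, so the `DeprecationWarning` category and `stacklevel` below are assumed refinements, not part of the patch:

import warnings


def binary_compatible_with(file_paths_length=None, file_paths_lenght=None):
    # Toy stand-in for BaseRecording.binary_compatible_with: accept the old
    # misspelled keyword, warn, and fold it into the corrected one.
    if file_paths_lenght is not None:
        warnings.warn(
            "`file_paths_lenght` is deprecated and will be removed in 0.103.0 "
            "please use `file_paths_length`",
            DeprecationWarning,
            stacklevel=2,
        )
        if file_paths_length is None:
            file_paths_length = file_paths_lenght
    return file_paths_length


assert binary_compatible_with(file_paths_lenght=1) == 1  # warns, still works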
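
PATCH 77 (#3412) drops the xarray dependency by storing each extension DataFrame as plain zarr datasets, one array per column plus an `index` array, and re-inferring dtypes with `convert_dtypes()` on reload. A self-contained round trip in the same style, assuming zarr v2 semantics (`create_dataset`) as used in the patch; the store path and frame contents are made up:

import pandas as pd
import zarr

metrics = pd.DataFrame({"snr": [5.1, 3.2], "num_spikes": [120, 80]}, index=[0, 1])

# save: a marker attribute, one dataset per column, one dataset for the index
root = zarr.open("metrics.zarr", mode="w")
group = root.create_group("metrics")
group.attrs["dataframe"] = True
group.create_dataset(name="index", data=metrics.index.to_numpy())
for col in metrics.columns:
    group.create_dataset(name=col, data=metrics[col].to_numpy())

# load: rebuild the frame column by column, then let pandas re-infer dtypes
loaded = pd.DataFrame(index=group["index"][:])
for col in group.keys():
    if col != "index":
        loaded.loc[:, col] = group[col][:]
loaded = loaded.convert_dtypes()
print(loaded.dtypes)  # nullable pandas dtypes, e.g. Int64 / Float64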
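
PATCHES 77 and 79 also make the zarr backend re-consolidate metadata after every extension write, and add a back-compatibility path for stores written by v0.101.0, which skipped that step. The gist of the version-gated repair; the store built in the first half is a stand-in for an old analyzer folder, not a real one:

import warnings

import zarr
from packaging.version import parse

# stand-in store, left the way an older spikeinterface might have left it:
# the "extensions" group is written *after* the last consolidation
root = zarr.open("analyzer.zarr", mode="w")
root.attrs["spikeinterface_info"] = {"version": "0.101.0"}
zarr.consolidate_metadata(root.store)
root.create_group("extensions")

# loading side: stores older than 0.101.1 get their metadata re-consolidated
zarr_root = zarr.open_consolidated("analyzer.zarr", mode="r")
si_info = zarr_root.attrs["spikeinterface_info"]
if parse(si_info["version"]) < parse("0.101.1"):
    try:
        zarr.consolidate_metadata(zarr.open("analyzer.zarr", mode="a").store)
    except Exception:
        warnings.warn("Could not re-consolidate; consider re-generating the analyzer.")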
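
PATCHES 78 and 83 move the pandas imports from module level into the functions that need them, the usual trick for keeping a heavy or optional dependency out of import time. The pattern, on a hypothetical helper:

def compute_metrics_frame(values):
    # Deferred import: pandas is only pulled in when this path actually runs,
    # so importing the enclosing module stays cheap and pandas stays optional.
    import pandas as pd

    return pd.DataFrame({"value": values})


print(compute_metrics_frame([1, 2, 3]))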
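
PATCHES 84, 86 and 87 (#3414) stop the drift widgets from caching a single `get_times()` vector and instead ask the recording to convert sample indices on demand, per segment. The core call, shown on a toy two-segment recording built with the public `generate_recording` helper:

import numpy as np
from spikeinterface.core import generate_recording

# a toy recording with two 5 s segments
recording = generate_recording(durations=[5.0, 5.0])

sample_index = np.array([0, 1_000, 10_000])  # e.g. peaks["sample_index"]
times = recording.sample_index_to_time(sample_index, segment_index=1)
print(times)  # times within segment 1, accounting for any t_start/time vector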
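
PATCH 89 (#3415) guards `correct_motion` against integer recordings: `InterpolateMotionRecording` needs floating-point traces, so the recording is lazily cast first. The check keys off the numpy dtype kind ("f" means floating point), sketched here on a generated recording cast to `int16` purely for illustration:

from spikeinterface.core import generate_recording

recording = generate_recording(durations=[2.0]).astype("int16")  # toy int recording

# cast lazily to float32 only when the traces are not already floating point
if recording.get_dtype().kind != "f":
    recording = recording.astype("float32")

print(recording.get_dtype())  # float32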
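
Finally, the widget fixes in PATCHES 91 to 93 and 96 are all fallout from `convert_dtypes()`: the metrics frames now use pandas nullable dtypes, so missing entries surface as `pd.NA` or `None` rather than float `nan`, and `np.isnan` can no longer be applied blindly (hence the new `if v is None: continue` guard in the sortingview code, whose values have already been JSON-sanitised by `check_json` at that point). A small demonstration of the failure mode:

import numpy as np
import pandas as pd

metrics = pd.DataFrame({"num_spikes": [120.0, np.nan], "snr": [5.1, 3.2]})
metrics = metrics.convert_dtypes()  # num_spikes becomes nullable Int64

row = metrics.loc[1].to_dict()      # {'num_spikes': <NA>, 'snr': 3.2}
for name, value in row.items():
    # pd.NA is not a float, so test for missingness before any isnan call
    if value is None or pd.isna(value):
        continue
    print(name, value)              # -> snr 3.2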