From 8b84711c0272ce9d35f248a3ad20c22fa3f51730 Mon Sep 17 00:00:00 2001 From: Charlie Windolf Date: Fri, 19 Jul 2024 13:20:11 -0700 Subject: [PATCH 01/52] Handle case where channel count changes from probeA to probeB --- src/spikeinterface/generation/drift_tools.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index 0e4f1985c6..42b6ca99dd 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -40,9 +40,13 @@ def interpolate_templates(templates_array, source_locations, dest_locations, int source_locations = np.asarray(source_locations) dest_locations = np.asarray(dest_locations) if dest_locations.ndim == 2: - new_shape = templates_array.shape + new_shape = (*templates_array.shape[:2], len(dest_locations)) elif dest_locations.ndim == 3: - new_shape = (dest_locations.shape[0],) + templates_array.shape + new_shape = ( + dest_locations.shape[0], + *templates_array.shape[:2], + dest_locations.shape[1], + ) else: raise ValueError(f"Incorrect dimensions for dest_locations: {dest_locations.ndim}. Dimensions can be 2 or 3. ") From 04b4d552b2833c61485e8aa71cb4d1c88cd53190 Mon Sep 17 00:00:00 2001 From: Axoft Server Date: Mon, 22 Jul 2024 18:23:33 -0400 Subject: [PATCH 02/52] feat: Add extra args to UnitSummaryWidget for subwidgets --- src/spikeinterface/widgets/unit_summary.py | 43 +++++++++++++++++++++- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 0b2a348edf..40481ae61d 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -30,6 +30,16 @@ class UnitSummaryWidget(BaseWidget): sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply. 
If SortingAnalyzer is already sparse, the argument is ignored + unitlocationswidget_params : dict or None, default: None + Parameters for UnitLocationsWidget (see UnitLocationsWidget for details) + unitwaveformswidgets_params : dict or None, default: None + Parameters for UnitWaveformsWidget (see UnitWaveformsWidget for details) + unitwaveformdensitymapwidget_params : dict or None, default: None + Parameters for UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) + autocorrelogramswidget_params : dict or None, default: None + Parameters for AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) + amplitudeswidget_params : dict or None, default: None + Parameters for AmplitudesWidget (see AmplitudesWidget for details) """ # possible_backends = {} @@ -40,21 +50,40 @@ def __init__( unit_id, unit_colors=None, sparsity=None, - radius_um=100, + unitlocationswidget_params=None, + unitwaveformswidgets_params=None, + unitwaveformdensitymapwidget_params=None, + autocorrelogramswidget_params=None, + amplitudeswidget_params=None, backend=None, **backend_kwargs, ): - sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) if unit_colors is None: unit_colors = get_unit_colors(sorting_analyzer) + if unitlocationswidget_params is None: + unitlocationswidget_params = dict() + if unitwaveformswidgets_params is None: + unitwaveformswidgets_params = dict() + if unitwaveformdensitymapwidget_params is None: + unitwaveformdensitymapwidget_params = dict() + if autocorrelogramswidget_params is None: + autocorrelogramswidget_params = dict() + if amplitudeswidget_params is None: + amplitudeswidget_params = dict() + plot_data = dict( sorting_analyzer=sorting_analyzer, unit_id=unit_id, unit_colors=unit_colors, sparsity=sparsity, + unitlocationswidget_params=unitlocationswidget_params, + unitwaveformswidgets_params=unitwaveformswidgets_params, + unitwaveformdensitymapwidget_params=unitwaveformdensitymapwidget_params, + 
autocorrelogramswidget_params=autocorrelogramswidget_params, + amplitudeswidget_params=amplitudeswidget_params ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -69,6 +98,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sorting_analyzer = dp.sorting_analyzer unit_colors = dp.unit_colors sparsity = dp.sparsity + unitlocationswidget = dp.unitlocationswidget_params + unitwaveformswidgets = dp.unitwaveformswidgets_params + unitwaveformdensitymapwidget = dp.unitwaveformdensitymapwidget_params + autocorrelogramswidget = dp.autocorrelogramswidget_params + amplitudeswidget = dp.amplitudeswidget_params # force the figure without axes if "figsize" not in backend_kwargs: @@ -99,6 +133,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_legend=False, backend="matplotlib", ax=ax1, + **unitlocationswidget ) unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") @@ -121,6 +156,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sparsity=sparsity, backend="matplotlib", ax=ax2, + **unitwaveformswidgets ) ax2.set_title(None) @@ -134,6 +170,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): same_axis=False, backend="matplotlib", ax=ax3, + **unitwaveformdensitymapwidget ) ax3.set_ylabel(None) @@ -145,6 +182,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_colors=unit_colors, backend="matplotlib", ax=ax4, + **autocorrelogramswidget ) ax4.set_title(None) @@ -162,6 +200,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_histograms=True, backend="matplotlib", axes=axes, + **amplitudeswidget ) fig.suptitle(f"unit_id: {dp.unit_id}") From fc0743e59412910e07f5c1355a58718ad6f9a4ed Mon Sep 17 00:00:00 2001 From: Axoft Server Date: Mon, 22 Jul 2024 18:29:38 -0400 Subject: [PATCH 03/52] fix: improve variable names --- src/spikeinterface/widgets/unit_summary.py | 28 +++++++++++----------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git 
a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 40481ae61d..8313e6b9c7 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -51,7 +51,7 @@ def __init__( unit_colors=None, sparsity=None, unitlocationswidget_params=None, - unitwaveformswidgets_params=None, + unitwaveformswidget_params=None, unitwaveformdensitymapwidget_params=None, autocorrelogramswidget_params=None, amplitudeswidget_params=None, @@ -65,8 +65,8 @@ def __init__( if unitlocationswidget_params is None: unitlocationswidget_params = dict() - if unitwaveformswidgets_params is None: - unitwaveformswidgets_params = dict() + if unitwaveformswidget_params is None: + unitwaveformswidget_params = dict() if unitwaveformdensitymapwidget_params is None: unitwaveformdensitymapwidget_params = dict() if autocorrelogramswidget_params is None: @@ -80,7 +80,7 @@ def __init__( unit_colors=unit_colors, sparsity=sparsity, unitlocationswidget_params=unitlocationswidget_params, - unitwaveformswidgets_params=unitwaveformswidgets_params, + unitwaveformswidget_params=unitwaveformswidget_params, unitwaveformdensitymapwidget_params=unitwaveformdensitymapwidget_params, autocorrelogramswidget_params=autocorrelogramswidget_params, amplitudeswidget_params=amplitudeswidget_params @@ -98,11 +98,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sorting_analyzer = dp.sorting_analyzer unit_colors = dp.unit_colors sparsity = dp.sparsity - unitlocationswidget = dp.unitlocationswidget_params - unitwaveformswidgets = dp.unitwaveformswidgets_params - unitwaveformdensitymapwidget = dp.unitwaveformdensitymapwidget_params - autocorrelogramswidget = dp.autocorrelogramswidget_params - amplitudeswidget = dp.amplitudeswidget_params + unitlocationswidget_params = dp.unitlocationswidget_params + unitwaveformswidget_params = dp.unitwaveformswidget_params + unitwaveformdensitymapwidget_params = dp.unitwaveformdensitymapwidget_params + 
autocorrelogramswidget_params = dp.autocorrelogramswidget_params + amplitudeswidget_params = dp.amplitudeswidget_params # force the figure without axes if "figsize" not in backend_kwargs: @@ -133,7 +133,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_legend=False, backend="matplotlib", ax=ax1, - **unitlocationswidget + **unitlocationswidget_params ) unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") @@ -156,7 +156,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sparsity=sparsity, backend="matplotlib", ax=ax2, - **unitwaveformswidgets + **unitwaveformswidget_params ) ax2.set_title(None) @@ -170,7 +170,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): same_axis=False, backend="matplotlib", ax=ax3, - **unitwaveformdensitymapwidget + **unitwaveformdensitymapwidget_params ) ax3.set_ylabel(None) @@ -182,7 +182,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_colors=unit_colors, backend="matplotlib", ax=ax4, - **autocorrelogramswidget + **autocorrelogramswidget_params ) ax4.set_title(None) @@ -200,7 +200,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_histograms=True, backend="matplotlib", axes=axes, - **amplitudeswidget + **amplitudeswidget_params ) fig.suptitle(f"unit_id: {dp.unit_id}") From d3c9eb0820109e69b5e973419e01531410abb19a Mon Sep 17 00:00:00 2001 From: florian6973 <70778912+florian6973@users.noreply.github.com> Date: Mon, 22 Jul 2024 18:31:33 -0400 Subject: [PATCH 04/52] fix: typo doc --- src/spikeinterface/widgets/unit_summary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 8313e6b9c7..d41be2641f 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -32,7 +32,7 @@ class UnitSummaryWidget(BaseWidget): If SortingAnalyzer is already sparse, the argument is ignored 
unitlocationswidget_params : dict or None, default: None Parameters for UnitLocationsWidget (see UnitLocationsWidget for details) - unitwaveformswidgets_params : dict or None, default: None + unitwaveformswidget_params : dict or None, default: None Parameters for UnitWaveformsWidget (see UnitWaveformsWidget for details) unitwaveformdensitymapwidget_params : dict or None, default: None Parameters for UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) From 60f775f9e4bd9c75f6657532ccdf4a2e57701dbb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Jul 2024 22:31:55 +0000 Subject: [PATCH 05/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/unit_summary.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index d41be2641f..979e8c0398 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -83,7 +83,7 @@ def __init__( unitwaveformswidget_params=unitwaveformswidget_params, unitwaveformdensitymapwidget_params=unitwaveformdensitymapwidget_params, autocorrelogramswidget_params=autocorrelogramswidget_params, - amplitudeswidget_params=amplitudeswidget_params + amplitudeswidget_params=amplitudeswidget_params, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -133,7 +133,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_legend=False, backend="matplotlib", ax=ax1, - **unitlocationswidget_params + **unitlocationswidget_params, ) unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") @@ -156,7 +156,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sparsity=sparsity, backend="matplotlib", ax=ax2, - **unitwaveformswidget_params + 
**unitwaveformswidget_params, ) ax2.set_title(None) @@ -170,7 +170,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): same_axis=False, backend="matplotlib", ax=ax3, - **unitwaveformdensitymapwidget_params + **unitwaveformdensitymapwidget_params, ) ax3.set_ylabel(None) @@ -182,7 +182,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_colors=unit_colors, backend="matplotlib", ax=ax4, - **autocorrelogramswidget_params + **autocorrelogramswidget_params, ) ax4.set_title(None) @@ -200,7 +200,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_histograms=True, backend="matplotlib", axes=axes, - **amplitudeswidget_params + **amplitudeswidget_params, ) fig.suptitle(f"unit_id: {dp.unit_id}") From 9802c34f78081b63e5d953fb0601cd6e079ce7cf Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 23 Jul 2024 08:33:50 -0400 Subject: [PATCH 06/52] add `is_filtered` to annotations --- src/spikeinterface/core/baserecording.py | 1 + src/spikeinterface/core/binaryrecordingextractor.py | 6 +----- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index e65afabaca..03db6bd9af 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -549,6 +549,7 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): channel_ids=self.get_channel_ids(), time_axis=0, file_offset=0, + is_filtered=self.is_filtered(), gain_to_uV=self.get_channel_gains(), offset_to_uV=self.get_channel_offsets(), ) diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index a0e349728e..8f542647f1 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -65,13 +65,9 @@ def __init__( gain_to_uV=None, offset_to_uV=None, is_filtered=None, - num_chan=None, ): - # This 
assigns num_channels if num_channels is not None, otherwise num_chan is assigned - num_channels = num_channels or num_chan + assert num_channels is not None, "You must provide num_channels or num_chan" - if num_chan is not None: - warnings.warn("`num_chan` is to be deprecated in version 0.100, please use `num_channels` instead") if channel_ids is None: channel_ids = list(range(num_channels)) From 309be48e5b38ad56c0ab2cf25c2e359f3058b7f2 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 23 Jul 2024 08:40:09 -0400 Subject: [PATCH 07/52] fix test for deprecation --- src/spikeinterface/core/tests/test_binaryrecordingextractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py index ea5edc6e6e..7d90c48947 100644 --- a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py +++ b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py @@ -94,7 +94,7 @@ def test_sequential_reading_of_small_traces(folder_with_binary_files): file_paths = [folder / "traces_cached_seg0.raw"] recording = BinaryRecordingExtractor( - num_chan=num_channels, + num_channels=num_channels, file_paths=file_paths, sampling_frequency=sampling_frequency, dtype=dtype, From c61378ec1d3fad9b5f9d6521a51a0dd5376dfc35 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 24 Jul 2024 11:45:17 +0100 Subject: [PATCH 08/52] simplify qm tests and wrap for multiprocessing --- .../qualitymetrics/tests/conftest.py | 51 +++++++++++ .../tests/test_metrics_functions.py | 85 +++++++------------ .../qualitymetrics/tests/test_pca_metrics.py | 5 ++ .../tests/test_quality_metric_calculator.py | 53 +----------- 4 files changed, 92 insertions(+), 102 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py 
index bb2a345340..b1b23fcaee 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -5,6 +5,8 @@ create_sorting_analyzer, ) +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") + def _small_sorting_analyzer(): recording, sorting = generate_ground_truth_recording( @@ -35,3 +37,52 @@ def _small_sorting_analyzer(): @pytest.fixture(scope="module") def small_sorting_analyzer(): return _small_sorting_analyzer() + + +def _sorting_analyzer_simple(): + + # we need high firing rate for amplitude_cutoff + recording, sorting = generate_ground_truth_recording( + durations=[ + 120.0, + ], + sampling_frequency=30_000.0, + num_channels=6, + num_units=10, + generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), + generate_unit_locations_kwargs=dict( + margin_um=5.0, + minimum_z=5.0, + maximum_z=20.0, + ), + generate_templates_kwargs=dict( + unit_params=dict( + alpha=(200.0, 500.0), + ) + ), + noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), + seed=1205, + ) + + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) + + sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=1205) + sorting_analyzer.compute("noise_levels") + sorting_analyzer.compute("waveforms", **job_kwargs) + sorting_analyzer.compute("templates") + sorting_analyzer.compute("spike_amplitudes", **job_kwargs) + + return sorting_analyzer + + +@pytest.fixture(scope="module") +def sorting_analyzer_simple(): + sorting_analyzer = get_sorting_analyzer(seed=2205) + return sorting_analyzer + + return sorting_analyzer + + +@pytest.fixture(scope="module") +def sorting_analyzer_simple(): + return _sorting_analyzer_simple() diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index bb222200e9..446007d10b 100644 --- 
a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -135,36 +135,6 @@ def test_unit_id_order_independence(small_sorting_analyzer): assert quality_metrics_2[metric][1] == metric_1_data["#4"] -def _sorting_analyzer_simple(): - recording, sorting = generate_ground_truth_recording( - durations=[ - 50.0, - ], - sampling_frequency=30_000.0, - num_channels=6, - num_units=10, - generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), - seed=2205, - ) - - sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - - sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=2205) - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("waveforms", **job_kwargs) - sorting_analyzer.compute("templates") - sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) - sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - return _sorting_analyzer_simple() - - def _sorting_violation(): max_time = 100.0 sampling_frequency = 30000 @@ -570,27 +540,36 @@ def test_calculate_sd_ratio(sorting_analyzer_simple): if __name__ == "__main__": - sorting_analyzer = _sorting_analyzer_simple() - print(sorting_analyzer) - - test_unit_structure_in_output(_small_sorting_analyzer()) - - # test_calculate_firing_rate_num_spikes(sorting_analyzer) - # test_calculate_snrs(sorting_analyzer) - # test_calculate_amplitude_cutoff(sorting_analyzer) - # test_calculate_presence_ratio(sorting_analyzer) - # test_calculate_amplitude_median(sorting_analyzer) - # test_calculate_sliding_rp_violations(sorting_analyzer) - # test_calculate_drift_metrics(sorting_analyzer) - # test_synchrony_metrics(sorting_analyzer) - # 
test_synchrony_metrics_unit_id_subset(sorting_analyzer) - # test_synchrony_metrics_no_unit_ids(sorting_analyzer) - # test_calculate_firing_range(sorting_analyzer) - # test_calculate_amplitude_cv_metrics(sorting_analyzer) - # test_calculate_sd_ratio(sorting_analyzer) - - # sorting_analyzer_violations = _sorting_analyzer_violations() + test_unit_structure_in_output(small_sorting_analyzer) + test_unit_id_order_independence(small_sorting_analyzer) + + test_synchrony_counts_no_sync() + test_synchrony_counts_one_sync() + test_synchrony_counts_one_quad_sync() + test_synchrony_counts_not_all_units() + + test_mahalanobis_metrics() + test_lda_metrics() + test_nearest_neighbors_metrics() + test_silhouette_score_metrics() + test_simplified_silhouette_score_metrics() + + test_calculate_firing_rate_num_spikes(sorting_analyzer_simple) + test_calculate_snrs(sorting_analyzer) + test_calculate_amplitude_cutoff(sorting_analyzer) + test_calculate_presence_ratio(sorting_analyzer) + test_calculate_amplitude_median(sorting_analyzer) + test_calculate_sliding_rp_violations(sorting_analyzer) + test_calculate_drift_metrics(sorting_analyzer) + test_synchrony_metrics(sorting_analyzer) + test_synchrony_metrics_unit_id_subset(sorting_analyzer) + test_synchrony_metrics_no_unit_ids(sorting_analyzer) + test_calculate_firing_range(sorting_analyzer) + test_calculate_amplitude_cv_metrics(sorting_analyzer) + test_calculate_sd_ratio(sorting_analyzer) + + sorting_analyzer_violations = _sorting_analyzer_violations() # print(sorting_analyzer_violations) - # test_calculate_isi_violations(sorting_analyzer_violations) - # test_calculate_sliding_rp_violations(sorting_analyzer_violations) - # test_calculate_rp_violations(sorting_analyzer_violations) + test_calculate_isi_violations(sorting_analyzer_violations) + test_calculate_sliding_rp_violations(sorting_analyzer_violations) + test_calculate_rp_violations(sorting_analyzer_violations) diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py 
b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 6ddeb02689..a0fc97c37c 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -22,3 +22,8 @@ def test_calculate_pc_metrics(small_sorting_analyzer): assert not np.all(np.isnan(res2[metric_name].values)) assert np.array_equal(res1[metric_name].values, res2[metric_name].values) + + +if __name__ == "__main__": + + test_calculate_pc_metrics(small_sorting_analyzer) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 28869ba5ff..f877f12708 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -2,7 +2,6 @@ from pathlib import Path import numpy as np - from spikeinterface.core import ( generate_ground_truth_recording, create_sorting_analyzer, @@ -15,51 +14,9 @@ compute_quality_metrics, ) - job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") -def get_sorting_analyzer(seed=2205): - # we need high firing rate for amplitude_cutoff - recording, sorting = generate_ground_truth_recording( - durations=[ - 120.0, - ], - sampling_frequency=30_000.0, - num_channels=6, - num_units=10, - generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0), - generate_unit_locations_kwargs=dict( - margin_um=5.0, - minimum_z=5.0, - maximum_z=20.0, - ), - generate_templates_kwargs=dict( - unit_params=dict( - alpha=(200.0, 500.0), - ) - ), - noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"), - seed=seed, - ) - - sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=True) - - sorting_analyzer.compute("random_spikes", max_spikes_per_unit=300, seed=seed) - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("waveforms", 
**job_kwargs) - sorting_analyzer.compute("templates") - sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - sorting_analyzer = get_sorting_analyzer(seed=2205) - return sorting_analyzer - - def test_compute_quality_metrics(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple print(sorting_analyzer) @@ -118,6 +75,7 @@ def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): ) for metric_name in metrics.columns: + print(metric_name) if metric_name == "sd_ratio": # this one need recording!!! continue @@ -291,9 +249,6 @@ def test_empty_units(sorting_analyzer_simple): if __name__ == "__main__": - sorting_analyzer = get_sorting_analyzer() - print(sorting_analyzer) - - test_compute_quality_metrics(sorting_analyzer) - test_compute_quality_metrics_recordingless(sorting_analyzer) - test_empty_units(sorting_analyzer) + test_compute_quality_metrics(sorting_analyzer_simple) + test_compute_quality_metrics_recordingless(sorting_analyzer_simple) + test_empty_units(sorting_analyzer_simple) From df1564b877c26e099754f412066e5f24b0008d58 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 24 Jul 2024 13:18:00 +0100 Subject: [PATCH 09/52] try -1 instead of 2 --- src/spikeinterface/qualitymetrics/tests/conftest.py | 2 +- .../qualitymetrics/tests/test_metrics_functions.py | 2 +- src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py | 2 +- .../qualitymetrics/tests/test_quality_metric_calculator.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index b1b23fcaee..4b55a25b4a 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -5,7 +5,7 @@ create_sorting_analyzer, ) -job_kwargs = dict(n_jobs=2, 
progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") def _small_sorting_analyzer(): diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 446007d10b..0df1c25586 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -44,7 +44,7 @@ from spikeinterface.core.basesorting import minimum_spike_dtype -job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") def test_unit_structure_in_output(small_sorting_analyzer): diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index a0fc97c37c..507b9a1f70 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -13,7 +13,7 @@ def test_calculate_pc_metrics(small_sorting_analyzer): res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True, seed=1205) res1 = pd.DataFrame(res1) - res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True, seed=1205) + res2 = compute_pc_metrics(sorting_analyzer, n_jobs=-1, progress_bar=True, seed=1205) res2 = pd.DataFrame(res2) for metric_name in res1.columns: diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index f877f12708..da1f08c536 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -14,7 +14,7 @@ compute_quality_metrics, ) -job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=-1, progress_bar=True, 
chunk_duration="1s") def test_compute_quality_metrics(sorting_analyzer_simple): From b5d896d2d0e996d0d375583f87c817e837963e23 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 24 Jul 2024 14:50:04 +0100 Subject: [PATCH 10/52] delete print statement --- .../qualitymetrics/tests/test_quality_metric_calculator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index da1f08c536..0756596654 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -75,7 +75,6 @@ def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): ) for metric_name in metrics.columns: - print(metric_name) if metric_name == "sd_ratio": # this one need recording!!! continue From 9d3aa2a7d14edb72c5729b1e4606519f1f63b3e3 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 25 Jul 2024 09:44:10 +0100 Subject: [PATCH 11/52] go back to n_jobs=1 --- .../qualitymetrics/tests/test_metrics_functions.py | 2 +- src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py | 2 +- .../qualitymetrics/tests/test_quality_metric_calculator.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 0df1c25586..446007d10b 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -44,7 +44,7 @@ from spikeinterface.core.basesorting import minimum_spike_dtype -job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") def 
test_unit_structure_in_output(small_sorting_analyzer): diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index 507b9a1f70..a0fc97c37c 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -13,7 +13,7 @@ def test_calculate_pc_metrics(small_sorting_analyzer): res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True, seed=1205) res1 = pd.DataFrame(res1) - res2 = compute_pc_metrics(sorting_analyzer, n_jobs=-1, progress_bar=True, seed=1205) + res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True, seed=1205) res2 = pd.DataFrame(res2) for metric_name in res1.columns: diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 0756596654..616e6c90c1 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -14,7 +14,7 @@ compute_quality_metrics, ) -job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") def test_compute_quality_metrics(sorting_analyzer_simple): From 79f0206f73f654ead63c11f6065e10533d0309ea Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 25 Jul 2024 13:50:52 +0100 Subject: [PATCH 12/52] Respond to Joe review --- .../qualitymetrics/tests/conftest.py | 25 ++----------- .../tests/test_metrics_functions.py | 37 ------------------- .../qualitymetrics/tests/test_pca_metrics.py | 5 --- .../tests/test_quality_metric_calculator.py | 6 --- 4 files changed, 4 insertions(+), 69 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index 
4b55a25b4a..01fa16c8d7 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -5,10 +5,11 @@ create_sorting_analyzer, ) -job_kwargs = dict(n_jobs=-1, progress_bar=True, chunk_duration="1s") +job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") -def _small_sorting_analyzer(): +@pytest.fixture(scope="module") +def small_sorting_analyzer(): recording, sorting = generate_ground_truth_recording( durations=[2.0], num_units=10, @@ -35,12 +36,7 @@ def _small_sorting_analyzer(): @pytest.fixture(scope="module") -def small_sorting_analyzer(): - return _small_sorting_analyzer() - - -def _sorting_analyzer_simple(): - +def sorting_analyzer_simple(): # we need high firing rate for amplitude_cutoff recording, sorting = generate_ground_truth_recording( durations=[ @@ -73,16 +69,3 @@ def _sorting_analyzer_simple(): sorting_analyzer.compute("spike_amplitudes", **job_kwargs) return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - sorting_analyzer = get_sorting_analyzer(seed=2205) - return sorting_analyzer - - return sorting_analyzer - - -@pytest.fixture(scope="module") -def sorting_analyzer_simple(): - return _sorting_analyzer_simple() diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 446007d10b..156bab84d8 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -536,40 +536,3 @@ def test_calculate_sd_ratio(sorting_analyzer_simple): assert np.all(list(sd_ratio.keys()) == sorting_analyzer_simple.unit_ids) # @aurelien can you check this, this is not working anymore # assert np.allclose(list(sd_ratio.values()), 1, atol=0.25, rtol=0) - - -if __name__ == "__main__": - - test_unit_structure_in_output(small_sorting_analyzer) - 
test_unit_id_order_independence(small_sorting_analyzer) - - test_synchrony_counts_no_sync() - test_synchrony_counts_one_sync() - test_synchrony_counts_one_quad_sync() - test_synchrony_counts_not_all_units() - - test_mahalanobis_metrics() - test_lda_metrics() - test_nearest_neighbors_metrics() - test_silhouette_score_metrics() - test_simplified_silhouette_score_metrics() - - test_calculate_firing_rate_num_spikes(sorting_analyzer_simple) - test_calculate_snrs(sorting_analyzer) - test_calculate_amplitude_cutoff(sorting_analyzer) - test_calculate_presence_ratio(sorting_analyzer) - test_calculate_amplitude_median(sorting_analyzer) - test_calculate_sliding_rp_violations(sorting_analyzer) - test_calculate_drift_metrics(sorting_analyzer) - test_synchrony_metrics(sorting_analyzer) - test_synchrony_metrics_unit_id_subset(sorting_analyzer) - test_synchrony_metrics_no_unit_ids(sorting_analyzer) - test_calculate_firing_range(sorting_analyzer) - test_calculate_amplitude_cv_metrics(sorting_analyzer) - test_calculate_sd_ratio(sorting_analyzer) - - sorting_analyzer_violations = _sorting_analyzer_violations() - # print(sorting_analyzer_violations) - test_calculate_isi_violations(sorting_analyzer_violations) - test_calculate_sliding_rp_violations(sorting_analyzer_violations) - test_calculate_rp_violations(sorting_analyzer_violations) diff --git a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py index a0fc97c37c..6ddeb02689 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py @@ -22,8 +22,3 @@ def test_calculate_pc_metrics(small_sorting_analyzer): assert not np.all(np.isnan(res2[metric_name].values)) assert np.array_equal(res1[metric_name].values, res2[metric_name].values) - - -if __name__ == "__main__": - - test_calculate_pc_metrics(small_sorting_analyzer) diff --git 
a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index 616e6c90c1..fec5ceeb95 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -245,9 +245,3 @@ def test_empty_units(sorting_analyzer_simple): # for metric_name in metrics.columns: # # NaNs are skipped # assert np.allclose(metrics[metric_name].dropna(), metrics_par[metric_name].dropna()) - -if __name__ == "__main__": - - test_compute_quality_metrics(sorting_analyzer_simple) - test_compute_quality_metrics_recordingless(sorting_analyzer_simple) - test_empty_units(sorting_analyzer_simple) From d5b3e0cb716c5b5a72e58ae53dd8c5eef2305db2 Mon Sep 17 00:00:00 2001 From: Axoft Server Date: Mon, 5 Aug 2024 15:45:21 -0400 Subject: [PATCH 13/52] feat: switch to nested dict --- src/spikeinterface/widgets/unit_summary.py | 54 ++++++++-------------- 1 file changed, 19 insertions(+), 35 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 8313e6b9c7..a22b17bb8a 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -1,4 +1,5 @@ from __future__ import annotations +from collections import defaultdict import numpy as np @@ -30,16 +31,13 @@ class UnitSummaryWidget(BaseWidget): sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply. 
If SortingAnalyzer is already sparse, the argument is ignored - unitlocationswidget_params : dict or None, default: None - Parameters for UnitLocationsWidget (see UnitLocationsWidget for details) - unitwaveformswidgets_params : dict or None, default: None - Parameters for UnitWaveformsWidget (see UnitWaveformsWidget for details) - unitwaveformdensitymapwidget_params : dict or None, default: None - Parameters for UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) - autocorrelogramswidget_params : dict or None, default: None - Parameters for AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) - amplitudeswidget_params : dict or None, default: None - Parameters for AmplitudesWidget (see AmplitudesWidget for details) + widget_params : dict or None, default: None + Parameters for the subwidgets in a nested dictionary + unitlocations_params: UnitLocationsWidget (see UnitLocationsWidget for details) + unitwaveforms_params: UnitWaveformsWidget (see UnitWaveformsWidget for details) + unitwaveformdensitymap_params : UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) + autocorrelograms_params : AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) + amplitudes_params : AmplitudesWidget (see AmplitudesWidget for details) """ # possible_backends = {} @@ -50,11 +48,7 @@ def __init__( unit_id, unit_colors=None, sparsity=None, - unitlocationswidget_params=None, - unitwaveformswidget_params=None, - unitwaveformdensitymapwidget_params=None, - autocorrelogramswidget_params=None, - amplitudeswidget_params=None, + widget_params=None, backend=None, **backend_kwargs, ): @@ -63,27 +57,15 @@ def __init__( if unit_colors is None: unit_colors = get_unit_colors(sorting_analyzer) - if unitlocationswidget_params is None: - unitlocationswidget_params = dict() - if unitwaveformswidget_params is None: - unitwaveformswidget_params = dict() - if unitwaveformdensitymapwidget_params is None: - unitwaveformdensitymapwidget_params = 
dict() - if autocorrelogramswidget_params is None: - autocorrelogramswidget_params = dict() - if amplitudeswidget_params is None: - amplitudeswidget_params = dict() + if widget_params is None: + widget_params = dict() plot_data = dict( sorting_analyzer=sorting_analyzer, unit_id=unit_id, unit_colors=unit_colors, sparsity=sparsity, - unitlocationswidget_params=unitlocationswidget_params, - unitwaveformswidget_params=unitwaveformswidget_params, - unitwaveformdensitymapwidget_params=unitwaveformdensitymapwidget_params, - autocorrelogramswidget_params=autocorrelogramswidget_params, - amplitudeswidget_params=amplitudeswidget_params + widget_params=widget_params, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -98,11 +80,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sorting_analyzer = dp.sorting_analyzer unit_colors = dp.unit_colors sparsity = dp.sparsity - unitlocationswidget_params = dp.unitlocationswidget_params - unitwaveformswidget_params = dp.unitwaveformswidget_params - unitwaveformdensitymapwidget_params = dp.unitwaveformdensitymapwidget_params - autocorrelogramswidget_params = dp.autocorrelogramswidget_params - amplitudeswidget_params = dp.amplitudeswidget_params + + widget_params = defaultdict(lambda: dict(), dp.widget_params) + unitlocationswidget_params = widget_params['unitlocations_params'] + unitwaveformswidget_params = widget_params['unitwaveforms_params'] + unitwaveformdensitymapwidget_params = widget_params['unitwaveformdensitymap_params'] + autocorrelogramswidget_params = widget_params['autocorrelograms_params'] + amplitudeswidget_params = widget_params['amplitudes_params'] # force the figure without axes if "figsize" not in backend_kwargs: From 87956ccc2791883c98435f1a0ae3482ecce2d2d4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 5 Aug 2024 19:47:11 +0000 Subject: [PATCH 14/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for 
more information, see https://pre-commit.ci --- src/spikeinterface/widgets/unit_summary.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index a61ebe2c86..75a399fab5 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -31,7 +31,7 @@ class UnitSummaryWidget(BaseWidget): sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply. If SortingAnalyzer is already sparse, the argument is ignored - widget_params : dict or None, default: None + widget_params : dict or None, default: None Parameters for the subwidgets in a nested dictionary unitlocations_params: UnitLocationsWidget (see UnitLocationsWidget for details) unitwaveforms_params: UnitWaveformsWidget (see UnitWaveformsWidget for details) @@ -58,7 +58,7 @@ def __init__( unit_colors = get_unit_colors(sorting_analyzer) if widget_params is None: - widget_params = dict() + widget_params = dict() plot_data = dict( sorting_analyzer=sorting_analyzer, @@ -82,11 +82,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sparsity = dp.sparsity widget_params = defaultdict(lambda: dict(), dp.widget_params) - unitlocationswidget_params = widget_params['unitlocations_params'] - unitwaveformswidget_params = widget_params['unitwaveforms_params'] - unitwaveformdensitymapwidget_params = widget_params['unitwaveformdensitymap_params'] - autocorrelogramswidget_params = widget_params['autocorrelograms_params'] - amplitudeswidget_params = widget_params['amplitudes_params'] + unitlocationswidget_params = widget_params["unitlocations_params"] + unitwaveformswidget_params = widget_params["unitwaveforms_params"] + unitwaveformdensitymapwidget_params = widget_params["unitwaveformdensitymap_params"] + autocorrelogramswidget_params = widget_params["autocorrelograms_params"] + amplitudeswidget_params = widget_params["amplitudes_params"] # 
force the figure without axes if "figsize" not in backend_kwargs: From 5c351a4d25ac3fb992e30db3eeea3dcdf4cd3955 Mon Sep 17 00:00:00 2001 From: Jonah Pearl Date: Wed, 28 Aug 2024 11:37:05 -0400 Subject: [PATCH 15/52] save run_info to check extensions for completion --- src/spikeinterface/core/sortinganalyzer.py | 90 +++++++++++++++++++--- 1 file changed, 79 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index fa4547d272..5eaebb3189 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -18,6 +18,8 @@ import spikeinterface +from zarr.errors import ArrayNotFoundError + from .baserecording import BaseRecording from .basesorting import BaseSorting @@ -1719,8 +1721,12 @@ def __init__(self, sorting_analyzer): self._sorting_analyzer = weakref.ref(sorting_analyzer) self.params = None + self.run_info = self._default_run_info_dict() self.data = dict() + def _default_run_info_dict(self): + return dict(run_completed=False, data_loadable=False, runtime_s=None) + ####### # This 3 methods must be implemented in the subclass!!! # See DummyAnalyzerExtension in test_sortinganalyzer.py as a simple example @@ -1832,11 +1838,35 @@ def _get_zarr_extension_group(self, mode="r+"): def load(cls, sorting_analyzer): ext = cls(sorting_analyzer) ext.load_params() - ext.load_data() - if cls.need_backward_compatibility_on_load: - ext._handle_backward_compatibility_on_load() + ext.load_run_info() + if ext.run_info["run_completed"] and ext.run_info["data_loadable"]: + ext.load_data() + if cls.need_backward_compatibility_on_load: + ext._handle_backward_compatibility_on_load() + return ext + elif ext.run_info["run_completed"] and not ext.run_info["data_loadable"]: + warnings.warn( + f"Extension {cls.extension_name} has been computed but the data is not loadable. " + "The extension should be re-computed." 
+ ) + return ext + else: + return None + + def load_run_info(self): + if self.format == "binary_folder": + extension_folder = self._get_binary_extension_folder() + run_info_file = extension_folder / "run_info.json" + assert run_info_file.is_file(), f"No run_info file in extension {self.extension_name} folder" + with open(str(run_info_file), "r") as f: + run_info = json.load(f) + + elif self.format == "zarr": + extension_group = self._get_zarr_extension_group(mode="r") + assert "run_info" in extension_group.attrs, f"No run_info file in extension {self.extension_name} folder" + run_info = extension_group.attrs["run_info"] - return ext + self.run_info = run_info def load_params(self): if self.format == "binary_folder": @@ -1853,13 +1883,15 @@ def load_params(self): self.params = params - def load_data(self): + def load_data(self, keep=True): + ext_data = None + if self.format == "binary_folder": extension_folder = self._get_binary_extension_folder() for ext_data_file in extension_folder.iterdir(): # patch for https://github.com/SpikeInterface/spikeinterface/issues/3041 # maybe add a check for version number from the info.json during loading only - if ext_data_file.name == "params.json" or ext_data_file.name == "info.json": + if ext_data_file.name == "params.json" or ext_data_file.name == "info.json" or ext_data_file.name == "run_info.json": continue ext_data_name = ext_data_file.stem if ext_data_file.suffix == ".json": @@ -1878,7 +1910,6 @@ def load_data(self): ext_data = pickle.load(ext_data_file.open("rb")) else: continue - self.data[ext_data_name] = ext_data elif self.format == "zarr": extension_group = self._get_zarr_extension_group(mode="r") @@ -1898,12 +1929,29 @@ def load_data(self): else: # this load in memmory ext_data = np.array(ext_data_) - self.data[ext_data_name] = ext_data + + if ext_data is None: + warnings.warn(f"Found no data for {self.extension_name}, extension should be re-computed.") + + if keep: + self.data[ext_data_name] = ext_data + + def 
_check_data_loadable(self): + try: + self.load_data(keep=False) + return True + except ( + ValueError, IOError, EOFError, KeyError, UnicodeDecodeError, + json.JSONDecodeError, pickle.UnpicklingError, pd.errors.ParserError, + ArrayNotFoundError + ): + return False def copy(self, new_sorting_analyzer, unit_ids=None): # alessio : please note that this also replace the old select_units!!! new_extension = self.__class__(new_sorting_analyzer) new_extension.params = self.params.copy() + new_extension.run_info = self.run_info.copy() # TODO: does copy() assume both extensions have been run? if unit_ids is None: new_extension.data = self.data else: @@ -1922,6 +1970,7 @@ def merge( ): new_extension = self.__class__(new_sorting_analyzer) new_extension.params = self.params.copy() + new_extension.run_info = self.run_info.copy() # TODO: does merge() assume both extensions have been run? new_extension.data = self._merge_extension_data( merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask, verbose=verbose, **job_kwargs ) @@ -1930,19 +1979,26 @@ def merge( def run(self, save=True, **kwargs): if save and not self.sorting_analyzer.is_read_only(): - # this also reset the folder or zarr group - self._save_params() + # NB: this call to _save_params() also resets the folder or zarr group + self._save_params() self._save_importing_provenance() + self._save_run_info() self._run(**kwargs) - + if save and not self.sorting_analyzer.is_read_only(): self._save_data(**kwargs) + self.run_info["data_loadable"] = self._check_data_loadable() # maybe overkill? 
+ + self.run_info["run_completed"] = True + self._save_run_info() def save(self, **kwargs): self._save_params() self._save_importing_provenance() self._save_data(**kwargs) + self.run_info["data_loadable"] = self._check_data_loadable() + self._save_run_info() def _save_data(self, **kwargs): if self.format == "memory": @@ -2041,6 +2097,7 @@ def reset(self): """ self._reset_extension_folder() self.params = None + self.run_info = self._default_run_info_dict() self.data = dict() def set_params(self, save=True, **params): @@ -2098,6 +2155,17 @@ def _save_importing_provenance(self): extension_group = self._get_zarr_extension_group(mode="r+") extension_group.attrs["info"] = info + def _save_run_info(self): + run_info = self.run_info.copy() + + if self.format == "binary_folder": + extension_folder = self._get_binary_extension_folder() + run_info_file = extension_folder / "run_info.json" + run_info_file.write_text(json.dumps(run_info, indent=4), encoding="utf8") + elif self.format == "zarr": + extension_group = self._get_zarr_extension_group(mode="r+") + extension_group.attrs["run_info"] = run_info + def get_pipeline_nodes(self): assert ( self.use_nodepipeline From 725b208afba299191419f0ef5bfd9eb379036730 Mon Sep 17 00:00:00 2001 From: Jonah Pearl Date: Wed, 28 Aug 2024 12:55:58 -0400 Subject: [PATCH 16/52] save run time --- src/spikeinterface/core/sortinganalyzer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 5eaebb3189..03a3653c52 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -11,6 +11,7 @@ import shutil import warnings import importlib +from time import time import numpy as np @@ -1984,12 +1985,15 @@ def run(self, save=True, **kwargs): self._save_importing_provenance() self._save_run_info() + start = time() self._run(**kwargs) - + end = time() + if save and not 
self.sorting_analyzer.is_read_only(): self._save_data(**kwargs) self.run_info["data_loadable"] = self._check_data_loadable() # maybe overkill? - + + self.run_info["runtime_s"] = np.round(end - start, 1) self.run_info["run_completed"] = True self._save_run_info() From d95a754c0924cfcb2f55307a052a1382ecf8ea92 Mon Sep 17 00:00:00 2001 From: Jonah Pearl Date: Wed, 28 Aug 2024 12:56:13 -0400 Subject: [PATCH 17/52] bug fixes for pipeline extensions --- src/spikeinterface/core/sortinganalyzer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 03a3653c52..7bdf2dc516 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -2002,6 +2002,8 @@ def save(self, **kwargs): self._save_importing_provenance() self._save_data(**kwargs) self.run_info["data_loadable"] = self._check_data_loadable() + if self.run_info["data_loadable"]: + self.run_info["run_completed"] = True # extensions that go through compute_several_extensions() and then run_node_pipeline() never have ext_instance.run() called, so need to check here (or at least somewhere) instead self._save_run_info() def _save_data(self, **kwargs): @@ -2122,6 +2124,7 @@ def set_params(self, save=True, **params): if save: self._save_params() self._save_importing_provenance() + self._save_run_info() def _save_params(self): params_to_save = self.params.copy() From 1699413fbb16872f0c3a40561e874a650d87648b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Aug 2024 17:27:44 +0000 Subject: [PATCH 18/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sortinganalyzer.py | 34 +++++++++++++++------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py 
b/src/spikeinterface/core/sortinganalyzer.py index 7bdf2dc516..3ca4617551 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1853,7 +1853,7 @@ def load(cls, sorting_analyzer): return ext else: return None - + def load_run_info(self): if self.format == "binary_folder": extension_folder = self._get_binary_extension_folder() @@ -1861,7 +1861,7 @@ def load_run_info(self): assert run_info_file.is_file(), f"No run_info file in extension {self.extension_name} folder" with open(str(run_info_file), "r") as f: run_info = json.load(f) - + elif self.format == "zarr": extension_group = self._get_zarr_extension_group(mode="r") assert "run_info" in extension_group.attrs, f"No run_info file in extension {self.extension_name} folder" @@ -1892,7 +1892,11 @@ def load_data(self, keep=True): for ext_data_file in extension_folder.iterdir(): # patch for https://github.com/SpikeInterface/spikeinterface/issues/3041 # maybe add a check for version number from the info.json during loading only - if ext_data_file.name == "params.json" or ext_data_file.name == "info.json" or ext_data_file.name == "run_info.json": + if ( + ext_data_file.name == "params.json" + or ext_data_file.name == "info.json" + or ext_data_file.name == "run_info.json" + ): continue ext_data_name = ext_data_file.stem if ext_data_file.suffix == ".json": @@ -1930,10 +1934,10 @@ def load_data(self, keep=True): else: # this load in memmory ext_data = np.array(ext_data_) - + if ext_data is None: warnings.warn(f"Found no data for {self.extension_name}, extension should be re-computed.") - + if keep: self.data[ext_data_name] = ext_data @@ -1942,10 +1946,16 @@ def _check_data_loadable(self): self.load_data(keep=False) return True except ( - ValueError, IOError, EOFError, KeyError, UnicodeDecodeError, - json.JSONDecodeError, pickle.UnpicklingError, pd.errors.ParserError, - ArrayNotFoundError - ): + ValueError, + IOError, + EOFError, + KeyError, + UnicodeDecodeError, + 
json.JSONDecodeError, + pickle.UnpicklingError, + pd.errors.ParserError, + ArrayNotFoundError, + ): return False def copy(self, new_sorting_analyzer, unit_ids=None): @@ -1981,7 +1991,7 @@ def merge( def run(self, save=True, **kwargs): if save and not self.sorting_analyzer.is_read_only(): # NB: this call to _save_params() also resets the folder or zarr group - self._save_params() + self._save_params() self._save_importing_provenance() self._save_run_info() @@ -2003,7 +2013,9 @@ def save(self, **kwargs): self._save_data(**kwargs) self.run_info["data_loadable"] = self._check_data_loadable() if self.run_info["data_loadable"]: - self.run_info["run_completed"] = True # extensions that go through compute_several_extensions() and then run_node_pipeline() never have ext_instance.run() called, so need to check here (or at least somewhere) instead + self.run_info["run_completed"] = ( + True # extensions that go through compute_several_extensions() and then run_node_pipeline() never have ext_instance.run() called, so need to check here (or at least somewhere) instead + ) self._save_run_info() def _save_data(self, **kwargs): From 869b01a44d758a35869aa91b1009270307850304 Mon Sep 17 00:00:00 2001 From: Jonah Pearl Date: Thu, 29 Aug 2024 09:25:58 -0400 Subject: [PATCH 19/52] switch to perf counter --- src/spikeinterface/core/sortinganalyzer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 61c5515425..642f760241 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -2013,9 +2013,9 @@ def run(self, save=True, **kwargs): self._save_importing_provenance() self._save_run_info() - start = time() + start = perf_counter() self._run(**kwargs) - end = time() + end = perf_counter() if save and not self.sorting_analyzer.is_read_only(): self._save_data(**kwargs) From 554f6e3765767fe7a045e910027a2606d61ad0c4 Mon Sep 17 
00:00:00 2001 From: Jonah Pearl Date: Thu, 29 Aug 2024 09:26:56 -0400 Subject: [PATCH 20/52] use perf counter --- src/spikeinterface/core/sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 642f760241..9a54f0a627 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -11,7 +11,7 @@ import shutil import warnings import importlib -from time import time +from time import perf_counter import numpy as np From be023edc10bd26afef2fd26aa3847eadcc602cf9 Mon Sep 17 00:00:00 2001 From: Jonah Pearl Date: Thu, 29 Aug 2024 09:27:16 -0400 Subject: [PATCH 21/52] remove data_loadable and _check_data_loadable --- src/spikeinterface/core/sortinganalyzer.py | 43 ++++------------------ 1 file changed, 8 insertions(+), 35 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 9a54f0a627..251ce3249a 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1744,7 +1744,7 @@ def __init__(self, sorting_analyzer): self.data = dict() def _default_run_info_dict(self): - return dict(run_completed=False, data_loadable=False, runtime_s=None) + return dict(run_completed=False, runtime_s=None) ####### # This 3 methods must be implemented in the subclass!!! @@ -1858,17 +1858,11 @@ def load(cls, sorting_analyzer): ext = cls(sorting_analyzer) ext.load_params() ext.load_run_info() - if ext.run_info["run_completed"] and ext.run_info["data_loadable"]: + if ext.run_info["run_completed"]: ext.load_data() if cls.need_backward_compatibility_on_load: ext._handle_backward_compatibility_on_load() return ext - elif ext.run_info["run_completed"] and not ext.run_info["data_loadable"]: - warnings.warn( - f"Extension {cls.extension_name} has been computed but the data is not loadable. 
" - "The extension should be re-computed." - ) - return ext else: return None @@ -1902,7 +1896,7 @@ def load_params(self): self.params = params - def load_data(self, keep=True): + def load_data(self): ext_data = None if self.format == "binary_folder": @@ -1956,25 +1950,7 @@ def load_data(self, keep=True): if ext_data is None: warnings.warn(f"Found no data for {self.extension_name}, extension should be re-computed.") - if keep: - self.data[ext_data_name] = ext_data - - def _check_data_loadable(self): - try: - self.load_data(keep=False) - return True - except ( - ValueError, - IOError, - EOFError, - KeyError, - UnicodeDecodeError, - json.JSONDecodeError, - pickle.UnpicklingError, - pd.errors.ParserError, - ArrayNotFoundError, - ): - return False + self.data[ext_data_name] = ext_data def copy(self, new_sorting_analyzer, unit_ids=None): # alessio : please note that this also replace the old select_units!!! @@ -2016,12 +1992,11 @@ def run(self, save=True, **kwargs): start = perf_counter() self._run(**kwargs) end = perf_counter() + self.run_info["runtime_s"] = np.round(end - start, 1) if save and not self.sorting_analyzer.is_read_only(): self._save_data(**kwargs) - self.run_info["data_loadable"] = self._check_data_loadable() # maybe overkill? 
- self.run_info["runtime_s"] = np.round(end - start, 1) self.run_info["run_completed"] = True self._save_run_info() @@ -2029,11 +2004,9 @@ def save(self, **kwargs): self._save_params() self._save_importing_provenance() self._save_data(**kwargs) - self.run_info["data_loadable"] = self._check_data_loadable() - if self.run_info["data_loadable"]: - self.run_info["run_completed"] = ( - True # extensions that go through compute_several_extensions() and then run_node_pipeline() never have ext_instance.run() called, so need to check here (or at least somewhere) instead - ) + self.run_info["run_completed"] = ( + True # extensions that go through compute_several_extensions() and then run_node_pipeline() never have ext_instance.run() called, so need to change run_completed here (or somewhere, at least) + ) self._save_run_info() def _save_data(self, **kwargs): From 943e398b3200be284d26752e338907f46cde22c6 Mon Sep 17 00:00:00 2001 From: Jonah Pearl Date: Thu, 29 Aug 2024 09:51:15 -0400 Subject: [PATCH 22/52] edge case where data file is deleted --- src/spikeinterface/core/sortinganalyzer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 251ce3249a..33ab79c45e 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1947,11 +1947,11 @@ def load_data(self): # this load in memmory ext_data = np.array(ext_data_) - if ext_data is None: + if ext_data is not None: + self.data[ext_data_name] = ext_data + else: warnings.warn(f"Found no data for {self.extension_name}, extension should be re-computed.") - self.data[ext_data_name] = ext_data - def copy(self, new_sorting_analyzer, unit_ids=None): # alessio : please note that this also replace the old select_units!!! 
new_extension = self.__class__(new_sorting_analyzer) @@ -2183,7 +2183,8 @@ def get_pipeline_nodes(self): return self._get_pipeline_nodes() def get_data(self, *args, **kwargs): - assert len(self.data) > 0, f"You must run the extension {self.extension_name} before retrieving data" + assert self.run_info["run_completed"], f"You must run the extension {self.extension_name} before retrieving data" + assert len(self.data) > 0, "Extension has been run but no data found — data file might be missing. Re-compute the extension." return self._get_data(*args, **kwargs) From f9d7c0491cfe5cb6ac75dea9e632043fb11fa71c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Aug 2024 13:53:14 +0000 Subject: [PATCH 23/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sortinganalyzer.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 33ab79c45e..1bded4e6a8 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1948,7 +1948,7 @@ def load_data(self): ext_data = np.array(ext_data_) if ext_data is not None: - self.data[ext_data_name] = ext_data + self.data[ext_data_name] = ext_data else: warnings.warn(f"Found no data for {self.extension_name}, extension should be re-computed.") @@ -2183,8 +2183,12 @@ def get_pipeline_nodes(self): return self._get_pipeline_nodes() def get_data(self, *args, **kwargs): - assert self.run_info["run_completed"], f"You must run the extension {self.extension_name} before retrieving data" - assert len(self.data) > 0, "Extension has been run but no data found — data file might be missing. Re-compute the extension." 
+ assert self.run_info[ + "run_completed" + ], f"You must run the extension {self.extension_name} before retrieving data" + assert ( + len(self.data) > 0 + ), "Extension has been run but no data found — data file might be missing. Re-compute the extension." return self._get_data(*args, **kwargs) From 8e9acfc66bd22b2232914274c12319f43e197559 Mon Sep 17 00:00:00 2001 From: Jonah Pearl Date: Thu, 29 Aug 2024 10:18:23 -0400 Subject: [PATCH 24/52] always return None if extension data is missing --- src/spikeinterface/core/sortinganalyzer.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 1bded4e6a8..9b2144ddcf 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1862,9 +1862,12 @@ def load(cls, sorting_analyzer): ext.load_data() if cls.need_backward_compatibility_on_load: ext._handle_backward_compatibility_on_load() - return ext - else: - return None + if len(ext.data) > 0: + return ext + + # If extension run not completed, or data has gone missing, + # return None to indicate that the extension should be (re)computed. + return None def load_run_info(self): if self.format == "binary_folder": @@ -2183,12 +2186,8 @@ def get_pipeline_nodes(self): return self._get_pipeline_nodes() def get_data(self, *args, **kwargs): - assert self.run_info[ - "run_completed" - ], f"You must run the extension {self.extension_name} before retrieving data" - assert ( - len(self.data) > 0 - ), "Extension has been run but no data found — data file might be missing. Re-compute the extension." + assert self.run_info["run_completed"], f"You must run the extension {self.extension_name} before retrieving data" + assert len(self.data) > 0, "Extension has been run but no data found." 
return self._get_data(*args, **kwargs) From f1a3d97d8daf55607f9ac0da909b03e4297fce9d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Aug 2024 14:21:14 +0000 Subject: [PATCH 25/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sortinganalyzer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 9b2144ddcf..017760c002 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -2186,7 +2186,9 @@ def get_pipeline_nodes(self): return self._get_pipeline_nodes() def get_data(self, *args, **kwargs): - assert self.run_info["run_completed"], f"You must run the extension {self.extension_name} before retrieving data" + assert self.run_info[ + "run_completed" + ], f"You must run the extension {self.extension_name} before retrieving data" assert len(self.data) > 0, "Extension has been run but no data found." 
return self._get_data(*args, **kwargs) From d18f48f7ca9a3168c15fdd896377a64c56695f9c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 4 Sep 2024 15:45:29 +0200 Subject: [PATCH 26/52] Add extra protection for template metrix --- .../postprocessing/template_metrics.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index e54ff87221..9d21e56611 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -8,11 +8,9 @@ import numpy as np import warnings -from typing import Optional from copy import deepcopy from ..core.sortinganalyzer import register_result_extension, AnalyzerExtension -from ..core import ChannelSparsity from ..core.template_tools import get_template_extremum_channel from ..core.template_tools import get_dense_templates_array @@ -238,13 +236,17 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job for metric_name in metrics_single_channel: func = _metric_name_to_func[metric_name] - value = func( - template_upsampled, - sampling_frequency=sampling_frequency_up, - trough_idx=trough_idx, - peak_idx=peak_idx, - **self.params["metrics_kwargs"], - ) + try: + value = func( + template_upsampled, + sampling_frequency=sampling_frequency_up, + trough_idx=trough_idx, + peak_idx=peak_idx, + **self.params["metrics_kwargs"], + ) + except Exception as e: + warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") + value = np.nan template_metrics.at[index, metric_name] = value # compute metrics multi_channel From dd53372d58506f045ce5b1742498a9fc1c8a20a8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 4 Sep 2024 19:16:22 +0200 Subject: [PATCH 27/52] Fix error in load data, ensure backward compatibility, and add tests --- src/spikeinterface/core/sortinganalyzer.py | 51 
++++++++++++------- .../core/tests/test_sortinganalyzer.py | 33 ++++++++++++ 2 files changed, 65 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index b0c6162b05..06879d09a2 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -19,8 +19,6 @@ import spikeinterface -from zarr.errors import ArrayNotFoundError - from .baserecording import BaseRecording from .basesorting import BaseSorting @@ -1339,6 +1337,7 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job job_name = "Compute : " + " + ".join(extensions_with_pipeline.keys()) + t_start = perf_counter() results = run_node_pipeline( self.recording, all_nodes, @@ -1348,10 +1347,14 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job squeeze_output=False, verbose=verbose, ) + t_end = perf_counter() + # for pipeline node extensions we can only track the runtime of the run_node_pipeline + runtime_s = np.round(t_end - t_start, 1) for r, result in enumerate(results): extension_name, variable_name = result_routage[r] extension_instances[extension_name].data[variable_name] = result + extension_instances[extension_name].run_info["runtime_s"] = runtime_s for extension_name, extension_instance in extension_instances.items(): self.extensions[extension_name] = extension_instance @@ -1859,13 +1862,20 @@ def load(cls, sorting_analyzer): ext = cls(sorting_analyzer) ext.load_params() ext.load_run_info() - if ext.run_info["run_completed"]: + if ext.run_info is not None: + if ext.run_info["run_completed"]: + ext.load_data() + if cls.need_backward_compatibility_on_load: + ext._handle_backward_compatibility_on_load() + if len(ext.data) > 0: + return ext + else: + # this is for back-compatibility of old analyzers ext.load_data() if cls.need_backward_compatibility_on_load: ext._handle_backward_compatibility_on_load() if len(ext.data) > 0: 
return ext - # If extension run not completed, or data has gone missing, # return None to indicate that the extension should be (re)computed. return None @@ -1874,15 +1884,18 @@ def load_run_info(self): if self.format == "binary_folder": extension_folder = self._get_binary_extension_folder() run_info_file = extension_folder / "run_info.json" - assert run_info_file.is_file(), f"No run_info file in extension {self.extension_name} folder" - with open(str(run_info_file), "r") as f: - run_info = json.load(f) + if run_info_file.is_file(): + with open(str(run_info_file), "r") as f: + run_info = json.load(f) + else: + warnings.warn(f"Found no run_info file for {self.extension_name}, extension should be re-computed.") + run_info = None elif self.format == "zarr": extension_group = self._get_zarr_extension_group(mode="r") - assert "run_info" in extension_group.attrs, f"No run_info file in extension {self.extension_name} folder" - run_info = extension_group.attrs["run_info"] - + run_info = extension_group.attrs.get("run_info", None) + if run_info is None: + warnings.warn(f"Found no run_info file for {self.extension_name}, extension should be re-computed.") self.run_info = run_info def load_params(self): @@ -1902,7 +1915,6 @@ def load_params(self): def load_data(self): ext_data = None - if self.format == "binary_folder": extension_folder = self._get_binary_extension_folder() for ext_data_file in extension_folder.iterdir(): @@ -1922,6 +1934,7 @@ def load_data(self): # and have a link to the old buffer on windows then it fails # ext_data = np.load(ext_data_file, mmap_mode="r") # so we go back to full loading + print(f"{ext_data_file} is numpy!") ext_data = np.load(ext_data_file) elif ext_data_file.suffix == ".csv": import pandas as pd @@ -1931,6 +1944,7 @@ def load_data(self): ext_data = pickle.load(ext_data_file.open("rb")) else: continue + self.data[ext_data_name] = ext_data elif self.format == "zarr": extension_group = self._get_zarr_extension_group(mode="r") @@ -1950,21 
+1964,20 @@ def load_data(self): else: # this load in memmory ext_data = np.array(ext_data_) + self.data[ext_data_name] = ext_data - if ext_data is not None: - self.data[ext_data_name] = ext_data - else: + if len(self.data) == 0: warnings.warn(f"Found no data for {self.extension_name}, extension should be re-computed.") def copy(self, new_sorting_analyzer, unit_ids=None): # alessio : please note that this also replace the old select_units!!! new_extension = self.__class__(new_sorting_analyzer) new_extension.params = self.params.copy() - new_extension.run_info = self.run_info.copy() # TODO: does copy() assume both extensions have been run? if unit_ids is None: new_extension.data = self.data else: new_extension.data = self._select_extension_data(unit_ids) + new_extension.run_info = self.run_info.copy() new_extension.save() return new_extension @@ -1979,10 +1992,10 @@ def merge( ): new_extension = self.__class__(new_sorting_analyzer) new_extension.params = self.params.copy() - new_extension.run_info = self.run_info.copy() # TODO: does merge() assume both extensions have been run? 
new_extension.data = self._merge_extension_data( merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask, verbose=verbose, **job_kwargs ) + new_extension.run_info = self.run_info.copy() new_extension.save() return new_extension @@ -1993,10 +2006,10 @@ def run(self, save=True, **kwargs): self._save_importing_provenance() self._save_run_info() - start = perf_counter() + t_start = perf_counter() self._run(**kwargs) - end = perf_counter() - self.run_info["runtime_s"] = np.round(end - start, 1) + t_end = perf_counter() + self.run_info["runtime_s"] = np.round(t_end - t_start, 1) if save and not self.sorting_analyzer.is_read_only(): self._save_data(**kwargs) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 689073d6bf..3f45487f4c 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -125,6 +125,39 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): ) +def test_load_without_runtime_info(tmp_path, dataset): + recording, sorting = dataset + + folder = tmp_path / "test_SortingAnalyzer_run_info" + + extensions = ["random_spikes", "templates"] + # binary_folder + sorting_analyzer = create_sorting_analyzer( + sorting, recording, format="binary_folder", folder=folder, sparse=False, sparsity=None + ) + sorting_analyzer.compute(extensions) + # remove run_info.json to mimic a previous version of spikeinterface + for ext in extensions: + (folder / "extensions" / ext / "run_info.json").unlink() + # should raise a warning for missing run_info + with pytest.warns(UserWarning): + sorting_analyzer = load_sorting_analyzer(folder, format="auto") + + # zarr + folder = tmp_path / "test_SortingAnalyzer_run_info.zarr" + sorting_analyzer = create_sorting_analyzer( + sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None + ) + sorting_analyzer.compute(extensions) + # remove run_info from attrs to mimic a 
previous version of spikeinterface + root = sorting_analyzer._get_zarr_root(mode="r+") + for ext in extensions: + del root["extensions"][ext].attrs["run_info"] + # should raise a warning for missing run_info + with pytest.warns(UserWarning): + sorting_analyzer = load_sorting_analyzer(folder, format="auto") + + def test_SortingAnalyzer_tmp_recording(dataset): recording, sorting = dataset recording_cached = recording.save(mode="memory") From 8240ee9281186ebb47a0cbdf7317d05b6d684a65 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 4 Sep 2024 19:17:57 +0200 Subject: [PATCH 28/52] Update src/spikeinterface/core/sortinganalyzer.py --- src/spikeinterface/core/sortinganalyzer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 06879d09a2..95c1a290ba 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -2021,9 +2021,7 @@ def save(self, **kwargs): self._save_params() self._save_importing_provenance() self._save_data(**kwargs) - self.run_info["run_completed"] = ( - True # extensions that go through compute_several_extensions() and then run_node_pipeline() never have ext_instance.run() called, so need to change run_completed here (or somewhere, at least) - ) + self.run_info["run_completed"] = True self._save_run_info() def _save_data(self, **kwargs): From df9efc9a39e8534daa81cb39efa9db3c1d51518b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 15:11:57 +0200 Subject: [PATCH 29/52] Minor typing fixes --- src/spikeinterface/core/core_tools.py | 2 +- src/spikeinterface/core/recording_tools.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 380562dbd5..b3a857d158 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -201,7 +201,7 @@ def 
is_dict_extractor(d: dict) -> bool: extractor_dict_element = namedtuple(typename="extractor_dict_element", field_names=["value", "name", "access_path"]) -def extractor_dict_iterator(extractor_dict: dict) -> Generator[extractor_dict_element]: +def extractor_dict_iterator(extractor_dict: dict) -> Generator[extractor_dict_element, None, None]: """ Iterator for recursive traversal of a dictionary. This function explores the dictionary recursively and yields the path to each value along with the value itself. diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index b3eda360e7..34be7153b7 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -69,7 +69,7 @@ def _init_binary_worker(recording, file_path_dict, dtype, byte_offest, cast_unsi def write_binary_recording( recording: "BaseRecording", file_paths: list[Path | str] | Path | str, - dtype: np.ndtype = None, + dtype: np.dtype = None, add_file_extension: bool = True, byte_offset: int = 0, auto_cast_uint: bool = True, From ea13bcb9996e4894e7d9ea1be49fe6a2c5dee6c8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 6 Sep 2024 19:16:27 +0200 Subject: [PATCH 30/52] Add protection for multi-channel metrics (thanks Chris) --- .../qualitymetrics/quality_metric_calculator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 0c7cf25237..cdf6151e95 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -164,7 +164,10 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job pc_metric_names = [k for k in metric_names if k in _possible_pc_metric_names] if len(pc_metric_names) > 0 and not self.params["skip_pc_metrics"]: if not 
sorting_analyzer.has_extension("principal_components"): - raise ValueError("waveform_principal_component must be provied") + raise ValueError( + "To compute principal components base metrics, the principal components " + "extension must be computed first." + ) pc_metrics = compute_pc_metrics( sorting_analyzer, unit_ids=non_empty_unit_ids, From 4e000ed041b11a9f2195691caf0bcb39bca4a500 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 6 Sep 2024 19:24:07 +0200 Subject: [PATCH 31/52] same for multi-channel --- .../postprocessing/template_metrics.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 9d21e56611..726ec49558 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -276,12 +276,16 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job sampling_frequency_up = sampling_frequency func = _metric_name_to_func[metric_name] - value = func( - template_upsampled, - channel_locations=channel_locations_sparse, - sampling_frequency=sampling_frequency_up, - **self.params["metrics_kwargs"], - ) + try: + value = func( + template_upsampled, + channel_locations=channel_locations_sparse, + sampling_frequency=sampling_frequency_up, + **self.params["metrics_kwargs"], + ) + except Exception as e: + warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") + value = np.nan template_metrics.at[index, metric_name] = value return template_metrics From 1fc89b4a128f4584fc560f963200b076734d2654 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 6 Sep 2024 19:33:33 +0200 Subject: [PATCH 32/52] Reset-times: segment must have either time vector or sampling frequency --- src/spikeinterface/core/baserecording.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/baserecording.py 
b/src/spikeinterface/core/baserecording.py index f7918be7b0..225f070d9d 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -506,6 +506,7 @@ def reset_times(self): rs = self._recording_segments[segment_index] rs.t_start = None rs.time_vector = None + rs.sampling_frequency = self.sampling_frequency def sample_index_to_time(self, sample_ind, segment_index=None): """ From 7ddf482056243e4f3198db163757251cd737bda5 Mon Sep 17 00:00:00 2001 From: Florent Pollet Date: Sat, 7 Sep 2024 19:05:27 -0400 Subject: [PATCH 33/52] fix: change naming convention --- src/spikeinterface/widgets/unit_summary.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 75a399fab5..f1f85b7dc3 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -33,11 +33,11 @@ class UnitSummaryWidget(BaseWidget): If SortingAnalyzer is already sparse, the argument is ignored widget_params : dict or None, default: None Parameters for the subwidgets in a nested dictionary - unitlocations_params: UnitLocationsWidget (see UnitLocationsWidget for details) - unitwaveforms_params: UnitWaveformsWidget (see UnitWaveformsWidget for details) - unitwaveformdensitymap_params : UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) - autocorrelograms_params : AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) - amplitudes_params : AmplitudesWidget (see AmplitudesWidget for details) + unit_locations: UnitLocationsWidget (see UnitLocationsWidget for details) + unit_waveforms: UnitWaveformsWidget (see UnitWaveformsWidget for details) + unit_waveform_density_map: UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) + autocorrelograms: AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) + amplitudes: AmplitudesWidget (see 
AmplitudesWidget for details) """ # possible_backends = {} @@ -82,11 +82,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sparsity = dp.sparsity widget_params = defaultdict(lambda: dict(), dp.widget_params) - unitlocationswidget_params = widget_params["unitlocations_params"] - unitwaveformswidget_params = widget_params["unitwaveforms_params"] - unitwaveformdensitymapwidget_params = widget_params["unitwaveformdensitymap_params"] - autocorrelogramswidget_params = widget_params["autocorrelograms_params"] - amplitudeswidget_params = widget_params["amplitudes_params"] + unitlocationswidget_params = widget_params["unit_locations"] + unitwaveformswidget_params = widget_params["unit_waveforms"] + unitwaveformdensitymapwidget_params = widget_params["unit_waveform_density_map"] + autocorrelogramswidget_params = widget_params["autocorrelograms"] + amplitudeswidget_params = widget_params["amplitudes"] # force the figure without axes if "figsize" not in backend_kwargs: From a3fef7055d3702481cee99f17821f910b84921c3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 9 Sep 2024 17:54:01 +0200 Subject: [PATCH 34/52] Reset times also sets t_start to None --- src/spikeinterface/core/baserecording.py | 9 +++++---- src/spikeinterface/core/tests/test_baserecording.py | 4 ++++ 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 225f070d9d..3e5e43b528 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -497,16 +497,17 @@ def set_times(self, times, segment_index=None, with_warning=True): def reset_times(self): """ - Reset times in-memory for all segments that have a time vector. + Reset time information in-memory for all segments that have a time vector. If the timestamps come from a file, the files won't be modified. but only the in-memory - attributes of the recording objects are deleted. 
+ attributes of the recording objects are deleted. Also `t_start` is set to None and the + segment's sampling frequency is set to the recording's sampling frequency. """ for segment_index in range(self.get_num_segments()): if self.has_time_vector(segment_index): rs = self._recording_segments[segment_index] - rs.t_start = None rs.time_vector = None - rs.sampling_frequency = self.sampling_frequency + rs.t_start = None + rs.sampling_frequency = self.sampling_frequency def sample_index_to_time(self, sample_ind, segment_index=None): """ diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 3758fc3b43..9c354510ac 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -292,7 +292,11 @@ def test_BaseRecording(create_cache_folder): # reset times rec.reset_times() for segm in range(num_seg): + time_info = rec.get_time_info(segment_index=segm) assert not rec.has_time_vector(segment_index=segm) + assert time_info["t_start"] is None + assert time_info["time_vector"] is None + assert time_info["sampling_frequency"] == rec.sampling_frequency # test 3d probe rec_3d = generate_recording(ndim=3, num_channels=30) From c9a5bce90b7a12d4de5f51aac09156f06d49f8f5 Mon Sep 17 00:00:00 2001 From: Yue Huang <62061407+jiumao2@users.noreply.github.com> Date: Tue, 10 Sep 2024 15:58:09 +0800 Subject: [PATCH 35/52] Update segmentutils.py Fix integer overflow of the total sample number when concatenating long recordings --- src/spikeinterface/core/segmentutils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index 039fa8fd60..db97aead00 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -157,7 +157,7 @@ def __init__(self, parent_segments, sampling_frequency, ignore_times=True): self.parent_segments = 
parent_segments self.all_length = [rec_seg.get_num_samples() for rec_seg in self.parent_segments] self.cumsum_length = [0] + [sum(self.all_length[: i + 1]) for i in range(len(self.all_length))] - self.total_length = int(np.sum(self.all_length)) + self.total_length = int(np.sum(self.all_length, dtype=np.int64)) def get_num_samples(self): return self.total_length @@ -450,7 +450,7 @@ def __init__(self, parent_segments, parent_num_samples, sampling_frequency): self.parent_segments = parent_segments self.parent_num_samples = parent_num_samples self.cumsum_length = np.cumsum([0] + self.parent_num_samples) - self.total_num_samples = np.sum(self.parent_num_samples) + self.total_num_samples = np.sum(self.parent_num_samples, dtype=np.int64) def get_num_samples(self): return self.total_num_samples From 5775c36a5309a7b1b48698940ca83551bc4053f7 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 10 Sep 2024 10:23:11 +0200 Subject: [PATCH 36/52] Update src/spikeinterface/core/recording_tools.py --- src/spikeinterface/core/recording_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 34be7153b7..0ec5449bae 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -69,7 +69,7 @@ def _init_binary_worker(recording, file_path_dict, dtype, byte_offest, cast_unsi def write_binary_recording( recording: "BaseRecording", file_paths: list[Path | str] | Path | str, - dtype: np.dtype = None, + dtype: np.typing.DTypeLike = None, add_file_extension: bool = True, byte_offset: int = 0, auto_cast_uint: bool = True, From 5769eff8f6697eac93780ef4ded75c969b22f0ae Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 10 Sep 2024 12:18:57 +0200 Subject: [PATCH 37/52] Update src/spikeinterface/core/sortinganalyzer.py --- src/spikeinterface/core/sortinganalyzer.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 95c1a290ba..970cd150fc 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1934,7 +1934,6 @@ def load_data(self): # and have a link to the old buffer on windows then it fails # ext_data = np.load(ext_data_file, mmap_mode="r") # so we go back to full loading - print(f"{ext_data_file} is numpy!") ext_data = np.load(ext_data_file) elif ext_data_file.suffix == ".csv": import pandas as pd From cc21f06492f3d84c476814e0b09f19b2ca84ceeb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 10 Sep 2024 12:22:41 +0200 Subject: [PATCH 38/52] Apply suggestions from code review --- src/spikeinterface/core/sortinganalyzer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 970cd150fc..57c9e5f37c 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1349,12 +1349,13 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job ) t_end = perf_counter() # for pipeline node extensions we can only track the runtime of the run_node_pipeline - runtime_s = np.round(t_end - t_start, 1) + runtime_s = t_end - t_start for r, result in enumerate(results): extension_name, variable_name = result_routage[r] extension_instances[extension_name].data[variable_name] = result extension_instances[extension_name].run_info["runtime_s"] = runtime_s + extension_instances[extension_name].run_info["run_completed"] = True for extension_name, extension_instance in extension_instances.items(): self.extensions[extension_name] = extension_instance @@ -2008,7 +2009,7 @@ def run(self, save=True, **kwargs): t_start = perf_counter() self._run(**kwargs) t_end = perf_counter() - self.run_info["runtime_s"] = np.round(t_end - t_start, 1) + self.run_info["runtime_s"] = t_end 
- t_start if save and not self.sorting_analyzer.is_read_only(): self._save_data(**kwargs) @@ -2020,7 +2021,6 @@ def save(self, **kwargs): self._save_params() self._save_importing_provenance() self._save_data(**kwargs) - self.run_info["run_completed"] = True self._save_run_info() def _save_data(self, **kwargs): From ccc8a3888dd5924cf15e5e56cb2fa7d1d8342aa5 Mon Sep 17 00:00:00 2001 From: Yue Huang <806628409@qq.com> Date: Tue, 10 Sep 2024 22:53:11 +0800 Subject: [PATCH 39/52] Update segmentutils.py --- src/spikeinterface/core/segmentutils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/segmentutils.py b/src/spikeinterface/core/segmentutils.py index db97aead00..e2ce266ca7 100644 --- a/src/spikeinterface/core/segmentutils.py +++ b/src/spikeinterface/core/segmentutils.py @@ -157,7 +157,7 @@ def __init__(self, parent_segments, sampling_frequency, ignore_times=True): self.parent_segments = parent_segments self.all_length = [rec_seg.get_num_samples() for rec_seg in self.parent_segments] self.cumsum_length = [0] + [sum(self.all_length[: i + 1]) for i in range(len(self.all_length))] - self.total_length = int(np.sum(self.all_length, dtype=np.int64)) + self.total_length = int(sum(self.all_length)) def get_num_samples(self): return self.total_length @@ -450,7 +450,7 @@ def __init__(self, parent_segments, parent_num_samples, sampling_frequency): self.parent_segments = parent_segments self.parent_num_samples = parent_num_samples self.cumsum_length = np.cumsum([0] + self.parent_num_samples) - self.total_num_samples = np.sum(self.parent_num_samples, dtype=np.int64) + self.total_num_samples = int(sum(self.parent_num_samples)) def get_num_samples(self): return self.total_num_samples From 30bd9b93ae88bc07a3e8ab13eadedb255440a083 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 10 Sep 2024 19:30:31 +0200 Subject: [PATCH 40/52] Fix bug with reset_times --- src/spikeinterface/core/baserecording.py | 2 +- 
src/spikeinterface/core/tests/test_baserecording.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 3e5e43b528..f05d5b29dc 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -503,8 +503,8 @@ def reset_times(self): segment's sampling frequency is set to the recording's sampling frequency. """ for segment_index in range(self.get_num_segments()): + rs = self._recording_segments[segment_index] if self.has_time_vector(segment_index): - rs = self._recording_segments[segment_index] rs.time_vector = None rs.t_start = None rs.sampling_frequency = self.sampling_frequency diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 9c354510ac..6b60efe2b6 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -298,6 +298,9 @@ def test_BaseRecording(create_cache_folder): assert time_info["time_vector"] is None assert time_info["sampling_frequency"] == rec.sampling_frequency + # resetting time again should be ok + rec.reset_times() + # test 3d probe rec_3d = generate_recording(ndim=3, num_channels=30) locations_3d = rec_3d.get_property("location") From e5e0dd206dce9b51065893b1ebb2d3ef2b1709ed Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 10 Sep 2024 16:06:25 -0400 Subject: [PATCH 41/52] Revert "add `is_filtered` to annotations" This reverts commit 9802c34f78081b63e5d953fb0601cd6e079ce7cf. 
--- src/spikeinterface/core/baserecording.py | 1 - src/spikeinterface/core/binaryrecordingextractor.py | 6 +++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 03db6bd9af..e65afabaca 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -549,7 +549,6 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): channel_ids=self.get_channel_ids(), time_axis=0, file_offset=0, - is_filtered=self.is_filtered(), gain_to_uV=self.get_channel_gains(), offset_to_uV=self.get_channel_offsets(), ) diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index 8f542647f1..a0e349728e 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -65,9 +65,13 @@ def __init__( gain_to_uV=None, offset_to_uV=None, is_filtered=None, + num_chan=None, ): - + # This assigns num_channels if num_channels is not None, otherwise num_chan is assigned + num_channels = num_channels or num_chan assert num_channels is not None, "You must provide num_channels or num_chan" + if num_chan is not None: + warnings.warn("`num_chan` is to be deprecated in version 0.100, please use `num_channels` instead") if channel_ids is None: channel_ids = list(range(num_channels)) From 66d7fcb31ae17e301003cc6df46e3317877e9b69 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 10 Sep 2024 16:12:11 -0400 Subject: [PATCH 42/52] include num_chan for backward compatibility --- src/spikeinterface/core/baserecording.py | 1 + src/spikeinterface/core/binaryrecordingextractor.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index e65afabaca..03db6bd9af 100644 --- 
a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -549,6 +549,7 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): channel_ids=self.get_channel_ids(), time_axis=0, file_offset=0, + is_filtered=self.is_filtered(), gain_to_uV=self.get_channel_gains(), offset_to_uV=self.get_channel_offsets(), ) diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index a0e349728e..36de79d111 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -68,10 +68,12 @@ def __init__( num_chan=None, ): # This assigns num_channels if num_channels is not None, otherwise num_chan is assigned + # num_chan needs to be be kept for backward compatibility but should not be used by the + # end user num_channels = num_channels or num_chan assert num_channels is not None, "You must provide num_channels or num_chan" if num_chan is not None: - warnings.warn("`num_chan` is to be deprecated in version 0.100, please use `num_channels` instead") + warnings.warn("`num_chan` is to be deprecated as of version 0.100, please use `num_channels` instead") if channel_ids is None: channel_ids = list(range(num_channels)) From ec0a63949262a535ff5b1e484c7ca6cf5f65c1ca Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 10 Sep 2024 16:12:34 -0400 Subject: [PATCH 43/52] Revert "fix test for deprecation" This reverts commit 309be48e5b38ad56c0ab2cf25c2e359f3058b7f2. 
--- src/spikeinterface/core/tests/test_binaryrecordingextractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py index 7d90c48947..ea5edc6e6e 100644 --- a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py +++ b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py @@ -94,7 +94,7 @@ def test_sequential_reading_of_small_traces(folder_with_binary_files): file_paths = [folder / "traces_cached_seg0.raw"] recording = BinaryRecordingExtractor( - num_channels=num_channels, + num_chan=num_channels, file_paths=file_paths, sampling_frequency=sampling_frequency, dtype=dtype, From 4b6b1b96f70edeaae197d2ca6175af11ab9b6d10 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 10 Sep 2024 16:14:15 -0400 Subject: [PATCH 44/52] add comment about num_chan --- src/spikeinterface/core/tests/test_binaryrecordingextractor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py index ea5edc6e6e..700910f9cd 100644 --- a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py +++ b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py @@ -93,6 +93,8 @@ def test_sequential_reading_of_small_traces(folder_with_binary_files): dtype = "float32" file_paths = [folder / "traces_cached_seg0.raw"] + # `num_chan` is kept for backward compatibility so including it in at least one test + # run is good to ensure that it is appropriately accepted as an argument recording = BinaryRecordingExtractor( + num_chan=num_channels, file_paths=file_paths, From 15c51f96ef593b73b35eb1691168b3464bb9128e Mon Sep 17 00:00:00 2001 From: Florent Pollet Date: Tue, 10 Sep 2024 20:25:41 -0400 Subject: [PATCH 45/52] feat: renaming to kwargs + clearer error msg --- 
src/spikeinterface/widgets/unit_summary.py | 46 ++++++++++++---------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index f1f85b7dc3..fd7dafd5d6 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -18,26 +18,27 @@ class UnitSummaryWidget(BaseWidget): """ Plot a unit summary. - If amplitudes are alreday computed they are displayed. + If amplitudes are alreday computed, they are displayed. Parameters ---------- - sorting_analyzer : SortingAnalyzer + sorting_analyzer: SortingAnalyzer The SortingAnalyzer object - unit_id : int or str + unit_id: int or str The unit id to plot the summary of - unit_colors : dict or None, default: None + unit_colors: dict or None, default: None If given, a dictionary with unit ids as keys and colors as values, - sparsity : ChannelSparsity or None, default: None + sparsity: ChannelSparsity or None, default: None Optional ChannelSparsity to apply. If SortingAnalyzer is already sparse, the argument is ignored - widget_params : dict or None, default: None + subwidget_kwargs: dict or None, default: None Parameters for the subwidgets in a nested dictionary unit_locations: UnitLocationsWidget (see UnitLocationsWidget for details) unit_waveforms: UnitWaveformsWidget (see UnitWaveformsWidget for details) unit_waveform_density_map: UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) autocorrelograms: AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) amplitudes: AmplitudesWidget (see AmplitudesWidget for details) + Please note that the unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary. 
""" # possible_backends = {} @@ -48,7 +49,7 @@ def __init__( unit_id, unit_colors=None, sparsity=None, - widget_params=None, + subwidget_kwargs=None, backend=None, **backend_kwargs, ): @@ -57,15 +58,18 @@ def __init__( if unit_colors is None: unit_colors = get_unit_colors(sorting_analyzer) - if widget_params is None: - widget_params = dict() + if subwidget_kwargs is None: + subwidget_kwargs = dict() + for kwargs in subwidget_kwargs.values(): + if "unit_colors" in kwargs: + raise ValueError("unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary") plot_data = dict( sorting_analyzer=sorting_analyzer, unit_id=unit_id, unit_colors=unit_colors, sparsity=sparsity, - widget_params=widget_params, + subwidget_kwargs=subwidget_kwargs, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -81,12 +85,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_colors = dp.unit_colors sparsity = dp.sparsity - widget_params = defaultdict(lambda: dict(), dp.widget_params) - unitlocationswidget_params = widget_params["unit_locations"] - unitwaveformswidget_params = widget_params["unit_waveforms"] - unitwaveformdensitymapwidget_params = widget_params["unit_waveform_density_map"] - autocorrelogramswidget_params = widget_params["autocorrelograms"] - amplitudeswidget_params = widget_params["amplitudes"] + subwidget_kwargs = defaultdict(lambda: dict(), dp.subwidget_kwargs) + unitlocationswidget_kwargs = subwidget_kwargs["unit_locations"] + unitwaveformswidget_kwargs = subwidget_kwargs["unit_waveforms"] + unitwaveformdensitymapwidget_kwargs = subwidget_kwargs["unit_waveform_density_map"] + autocorrelogramswidget_kwargs = subwidget_kwargs["autocorrelograms"] + amplitudeswidget_kwargs = subwidget_kwargs["amplitudes"] # force the figure without axes if "figsize" not in backend_kwargs: @@ -117,7 +121,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_legend=False, backend="matplotlib", ax=ax1, - 
**unitlocationswidget_params, + **unitlocationswidget_kwargs, ) unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") @@ -140,7 +144,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sparsity=sparsity, backend="matplotlib", ax=ax2, - **unitwaveformswidget_params, + **unitwaveformswidget_kwargs, ) ax2.set_title(None) @@ -154,7 +158,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): same_axis=False, backend="matplotlib", ax=ax3, - **unitwaveformdensitymapwidget_params, + **unitwaveformdensitymapwidget_kwargs, ) ax3.set_ylabel(None) @@ -166,7 +170,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_colors=unit_colors, backend="matplotlib", ax=ax4, - **autocorrelogramswidget_params, + **autocorrelogramswidget_kwargs, ) ax4.set_title(None) @@ -184,7 +188,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_histograms=True, backend="matplotlib", axes=axes, - **amplitudeswidget_params, + **amplitudeswidget_kwargs, ) fig.suptitle(f"unit_id: {dp.unit_id}") From 3a48d66a52af617fdbbac4bcd94595727dc1cffe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Sep 2024 00:26:04 +0000 Subject: [PATCH 46/52] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/widgets/unit_summary.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index fd7dafd5d6..652c0f841f 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -62,7 +62,9 @@ def __init__( subwidget_kwargs = dict() for kwargs in subwidget_kwargs.values(): if "unit_colors" in kwargs: - raise ValueError("unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary") + raise ValueError( + "unit_colors should not 
be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary" + ) plot_data = dict( sorting_analyzer=sorting_analyzer, From 94abde04386441e3d2994635d4ee545a5b414175 Mon Sep 17 00:00:00 2001 From: Florent Pollet Date: Tue, 10 Sep 2024 20:27:18 -0400 Subject: [PATCH 47/52] feat: default dict comment --- src/spikeinterface/widgets/unit_summary.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index fd7dafd5d6..6c39c902eb 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -85,6 +85,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_colors = dp.unit_colors sparsity = dp.sparsity + # defaultdict returns empty dict if key not found in subwidget_kwargs subwidget_kwargs = defaultdict(lambda: dict(), dp.subwidget_kwargs) unitlocationswidget_kwargs = subwidget_kwargs["unit_locations"] unitwaveformswidget_kwargs = subwidget_kwargs["unit_waveforms"] From 5d5afd2873eb19791a7366be88ae577961f6755e Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 11 Sep 2024 10:45:12 +0100 Subject: [PATCH 48/52] Put main stuff back --- .../tests/test_metrics_functions.py | 29 +++++++++++++++++++ .../tests/test_quality_metric_calculator.py | 10 ++++++- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 156bab84d8..0a936edb39 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -536,3 +536,32 @@ def test_calculate_sd_ratio(sorting_analyzer_simple): assert np.all(list(sd_ratio.keys()) == sorting_analyzer_simple.unit_ids) # @aurelien can you check this, this is not working anymore # assert np.allclose(list(sd_ratio.values()), 
1, atol=0.25, rtol=0) + + +if __name__ == "__main__": + + sorting_analyzer = _sorting_analyzer_simple() + print(sorting_analyzer) + + test_unit_structure_in_output(_small_sorting_analyzer()) + + # test_calculate_firing_rate_num_spikes(sorting_analyzer) + + test_calculate_snrs(sorting_analyzer) + # test_calculate_amplitude_cutoff(sorting_analyzer) + # test_calculate_presence_ratio(sorting_analyzer) + # test_calculate_amplitude_median(sorting_analyzer) + # test_calculate_sliding_rp_violations(sorting_analyzer) + # test_calculate_drift_metrics(sorting_analyzer) + # test_synchrony_metrics(sorting_analyzer) + # test_synchrony_metrics_unit_id_subset(sorting_analyzer) + # test_synchrony_metrics_no_unit_ids(sorting_analyzer) + # test_calculate_firing_range(sorting_analyzer) + # test_calculate_amplitude_cv_metrics(sorting_analyzer) + # test_calculate_sd_ratio(sorting_analyzer) + + # sorting_analyzer_violations = _sorting_analyzer_violations() + # print(sorting_analyzer_violations) + # test_calculate_isi_violations(sorting_analyzer_violations) + # test_calculate_sliding_rp_violations(sorting_analyzer_violations) + # test_calculate_rp_violations(sorting_analyzer_violations) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index fec5ceeb95..a6415c58e8 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -19,7 +19,6 @@ def test_compute_quality_metrics(sorting_analyzer_simple): sorting_analyzer = sorting_analyzer_simple - print(sorting_analyzer) # without PCs metrics = compute_quality_metrics( @@ -245,3 +244,12 @@ def test_empty_units(sorting_analyzer_simple): # for metric_name in metrics.columns: # # NaNs are skipped # assert np.allclose(metrics[metric_name].dropna(), metrics_par[metric_name].dropna()) + +if __name__ == "__main__": + + 
sorting_analyzer = get_sorting_analyzer() + print(sorting_analyzer) + + test_compute_quality_metrics(sorting_analyzer) + test_compute_quality_metrics_recordingless(sorting_analyzer) + test_empty_units(sorting_analyzer) From f3aac424c6da100aac0614b7f5382c0d797ebd17 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 11 Sep 2024 10:46:59 +0100 Subject: [PATCH 49/52] oups --- .../qualitymetrics/tests/test_metrics_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 0a936edb39..e7fc7ce209 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -547,7 +547,7 @@ def test_calculate_sd_ratio(sorting_analyzer_simple): # test_calculate_firing_rate_num_spikes(sorting_analyzer) - test_calculate_snrs(sorting_analyzer) + # test_calculate_snrs(sorting_analyzer) # test_calculate_amplitude_cutoff(sorting_analyzer) # test_calculate_presence_ratio(sorting_analyzer) # test_calculate_amplitude_median(sorting_analyzer) From 01485d91646b0234d0ade2df0b98db2d197ca4fc Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 11 Sep 2024 12:17:36 +0200 Subject: [PATCH 50/52] Add space before colons in docs --- src/spikeinterface/widgets/unit_summary.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 851dfa049a..755e60ccbf 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -22,22 +22,22 @@ class UnitSummaryWidget(BaseWidget): Parameters ---------- - sorting_analyzer: SortingAnalyzer + sorting_analyzer : SortingAnalyzer The SortingAnalyzer object - unit_id: int or str + unit_id : int or str The 
unit id to plot the summary of - unit_colors: dict or None, default: None + unit_colors : dict or None, default: None If given, a dictionary with unit ids as keys and colors as values, - sparsity: ChannelSparsity or None, default: None + sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply. If SortingAnalyzer is already sparse, the argument is ignored - subwidget_kwargs: dict or None, default: None + subwidget_kwargs : dict or None, default: None Parameters for the subwidgets in a nested dictionary - unit_locations: UnitLocationsWidget (see UnitLocationsWidget for details) - unit_waveforms: UnitWaveformsWidget (see UnitWaveformsWidget for details) - unit_waveform_density_map: UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) - autocorrelograms: AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) - amplitudes: AmplitudesWidget (see AmplitudesWidget for details) + unit_locations : UnitLocationsWidget (see UnitLocationsWidget for details) + unit_waveforms : UnitWaveformsWidget (see UnitWaveformsWidget for details) + unit_waveform_density_map : UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) + autocorrelograms : AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) + amplitudes : AmplitudesWidget (see AmplitudesWidget for details) Please note that the unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary. 
""" From b35873cd5a3fd7e37cf8178b9699a74cd18c08f8 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Wed, 11 Sep 2024 20:08:17 +0100 Subject: [PATCH 51/52] compute pcs for sortinganalyzer again --- src/spikeinterface/qualitymetrics/tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index 01fa16c8d7..676889094b 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -67,5 +67,6 @@ def sorting_analyzer_simple(): sorting_analyzer.compute("waveforms", **job_kwargs) sorting_analyzer.compute("templates") sorting_analyzer.compute("spike_amplitudes", **job_kwargs) + sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) return sorting_analyzer From 41f73ed5c68992a871531a1d832e4179c2e9e02d Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 12 Sep 2024 09:50:19 +0100 Subject: [PATCH 52/52] re-remove pcs --- src/spikeinterface/qualitymetrics/tests/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/qualitymetrics/tests/conftest.py b/src/spikeinterface/qualitymetrics/tests/conftest.py index 676889094b..01fa16c8d7 100644 --- a/src/spikeinterface/qualitymetrics/tests/conftest.py +++ b/src/spikeinterface/qualitymetrics/tests/conftest.py @@ -67,6 +67,5 @@ def sorting_analyzer_simple(): sorting_analyzer.compute("waveforms", **job_kwargs) sorting_analyzer.compute("templates") sorting_analyzer.compute("spike_amplitudes", **job_kwargs) - sorting_analyzer.compute("principal_components", n_components=5, mode="by_channel_local", **job_kwargs) return sorting_analyzer