From c1e5e4ec1f889e3070a19734e4ea9169a872e280 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 10:59:16 +0200 Subject: [PATCH 01/17] extend estimate_sparsity methods and fix from_ptp --- src/spikeinterface/core/sparsity.py | 137 +++++++++--------- .../core/tests/test_sparsity.py | 29 +++- 2 files changed, 98 insertions(+), 68 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index a38562ea2c..57e1fa4769 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -338,7 +338,7 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p noise_levels = ext.data["noise_levels"] return_scaled = templates_or_sorting_analyzer.return_scaled elif isinstance(templates_or_sorting_analyzer, Templates): - assert noise_levels is not None + assert noise_levels is not None, "To compute sparsity from snr you need to provide noise_levels" return_scaled = templates_or_sorting_analyzer.is_scaled mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") @@ -353,17 +353,17 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p return cls(mask, unit_ids, channel_ids) @classmethod - def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): + def from_ptp(cls, templates_or_sorting_analyzer, threshold): """ Construct sparsity from a thresholds based on template peak-to-peak values. - Use the "threshold" argument to specify the SNR threshold. + Use the "threshold" argument to specify the peak-to-peak threshold. """ assert ( templates_or_sorting_analyzer.sparsity is None ), "To compute sparsity you need a dense SortingAnalyzer or Templates" - from .template_tools import get_template_amplitudes + from .template_tools import get_dense_templates_array from .sortinganalyzer import SortingAnalyzer from .template import Templates @@ -371,23 +371,17 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): channel_ids = templates_or_sorting_analyzer.channel_ids if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - ext = templates_or_sorting_analyzer.get_extension("noise_levels") - assert ext is not None, "To compute sparsity from snr you need to compute 'noise_levels' first" - noise_levels = ext.data["noise_levels"] return_scaled = templates_or_sorting_analyzer.return_scaled elif isinstance(templates_or_sorting_analyzer, Templates): - assert noise_levels is not None return_scaled = templates_or_sorting_analyzer.is_scaled - from .template_tools import get_dense_templates_array - mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") templates_array = get_dense_templates_array(templates_or_sorting_analyzer, return_scaled=return_scaled) templates_ptps = np.ptp(templates_array, axis=1) for unit_ind, unit_id in enumerate(unit_ids): - chan_inds = np.nonzero(templates_ptps[unit_ind] / noise_levels >= threshold) + chan_inds = np.nonzero(templates_ptps[unit_ind] >= threshold) mask[unit_ind, chan_inds] = True return cls(mask, unit_ids, channel_ids) @@ -455,15 +449,15 @@ def create_dense(cls, sorting_analyzer): def compute_sparsity( - templates_or_sorting_analyzer, - noise_levels=None, - method="radius", - peak_sign="neg", - num_channels=5, - radius_um=100.0, - threshold=5, - by_property=None, -): + templates_or_sorting_analyzer: Union[Templates, SortingAnalyzer], + noise_levels: np.ndarray | None = None, + method: "radius" | "best_channels" | "snr" | "ptp" | "energy" | "by_property" = "radius", + 
peak_sign: "neg" | "pos" | "both" = "neg", + num_channels: int | None = 5, + radius_um: float | None = 100.0, + threshold: float | None = 5, + by_property: str | None = None, +) -> ChannelSparsity: """ Get channel sparsity (subset of channels) for each template with several methods. @@ -500,11 +494,6 @@ def compute_sparsity( templates_or_sorting_analyzer, SortingAnalyzer ), f"compute_sparsity(method='{method}') need SortingAnalyzer" - if method in ("snr", "ptp") and isinstance(templates_or_sorting_analyzer, Templates): - assert ( - noise_levels is not None - ), f"compute_sparsity(..., method='{method}') with Templates need noise_levels as input" - if method == "best_channels": assert num_channels is not None, "For the 'best_channels' method, 'num_channels' needs to be given" sparsity = ChannelSparsity.from_best_channels(templates_or_sorting_analyzer, num_channels, peak_sign=peak_sign) @@ -521,7 +510,6 @@ def compute_sparsity( sparsity = ChannelSparsity.from_ptp( templates_or_sorting_analyzer, threshold, - noise_levels=noise_levels, ) elif method == "energy": assert threshold is not None, "For the 'energy' method, 'threshold' needs to be given" @@ -544,10 +532,12 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" = "radius", - peak_sign: str = "neg", + method: "radius" | "best_channels" | "ptp" | "by_property" = "radius", + peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, + threshold: float | None = 5, + by_property: str | None = None, **job_kwargs, ): """ @@ -567,16 +557,15 @@ def estimate_sparsity( The sorting recording : BaseRecording The recording - num_spikes_for_sparsity : int, default: 100 How many spikes per units to compute the sparsity ms_before : float, default: 1.0 Cut out in ms before spike time ms_after : float, default: 2.5 Cut out in ms after spike time - method : "radius" | "best_channels", default: "radius" + method : "radius" | "best_channels" | "ptp" | "by_property", default: "radius" Sparsity method propagated to the `compute_sparsity()` function. - Only "radius" or "best_channels" are implemented + "snr" and "energy" are not available here, because they require noise levels. peak_sign : "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels radius_um : float, default: 100.0 @@ -594,7 +583,10 @@ def estimate_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - assert method in ("radius", "best_channels"), "estimate_sparsity() handle only method='radius' or 'best_channel'" + assert method in ("radius", "best_channels", "ptp", "by_property"), ( + f"method={method} is not available for `estimate_sparsity()`. 
" + "Available methods are 'radius', 'best_channels', 'ptp', 'by_property'" + ) if recording.get_probes() == 1: # standard case @@ -605,43 +597,54 @@ def estimate_sparsity( chan_locs = recording.get_channel_locations() probe = recording.create_dummy_probe_from_locations(chan_locs) - nbefore = int(ms_before * recording.sampling_frequency / 1000.0) - nafter = int(ms_after * recording.sampling_frequency / 1000.0) - - num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] - random_spikes_indices = random_spikes_selection( - sorting, - num_samples, - method="uniform", - max_spikes_per_unit=num_spikes_for_sparsity, - margin_size=max(nbefore, nafter), - seed=2205, - ) - spikes = sorting.to_spike_vector() - spikes = spikes[random_spikes_indices] - - templates_array = estimate_templates_with_accumulator( - recording, - spikes, - sorting.unit_ids, - nbefore, - nafter, - return_scaled=False, - job_name="estimate_sparsity", - **job_kwargs, - ) - templates = Templates( - templates_array=templates_array, - sampling_frequency=recording.sampling_frequency, - nbefore=nbefore, - sparsity_mask=None, - channel_ids=recording.channel_ids, - unit_ids=sorting.unit_ids, - probe=probe, - ) + if method != "by_property": + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + nafter = int(ms_after * recording.sampling_frequency / 1000.0) + + num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] + random_spikes_indices = random_spikes_selection( + sorting, + num_samples, + method="uniform", + max_spikes_per_unit=num_spikes_for_sparsity, + margin_size=max(nbefore, nafter), + seed=2205, + ) + spikes = sorting.to_spike_vector() + spikes = spikes[random_spikes_indices] + + templates_array = estimate_templates_with_accumulator( + recording, + spikes, + sorting.unit_ids, + nbefore, + nafter, + return_scaled=False, + job_name="estimate_sparsity", + **job_kwargs, + ) + templates = Templates( + templates_array=templates_array, + sampling_frequency=recording.sampling_frequency, + nbefore=nbefore, + sparsity_mask=None, + channel_ids=recording.channel_ids, + unit_ids=sorting.unit_ids, + probe=probe, + ) + templates_or_analyzer = templates + else: + from .sortinganalyzer import create_sorting_analyzer + templates_or_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, sparse=False) sparsity = compute_sparsity( - templates, method=method, peak_sign=peak_sign, num_channels=num_channels, radius_um=radius_um + templates_or_analyzer, + method=method, + peak_sign=peak_sign, + num_channels=num_channels, + radius_um=radius_um, + threshold=threshold, + by_property=by_property, ) return sparsity diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index a192d90502..6ee023fc12 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -86,7 +86,7 @@ def test_sparsify_waveforms(): num_active_channels = len(non_zero_indices) assert waveforms_sparse.shape == (num_units, num_samples, num_active_channels) - # Test round-trip (note that this is loosy) + # Test round-trip (note that this is lossy) unit_id = unit_ids[unit_id] non_zero_indices = sparsity.unit_id_to_channel_indices[unit_id] waveforms_dense2 = sparsity.densify_waveforms(waveforms_sparse, unit_id=unit_id) @@ -195,6 +195,33 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 3) + # ptp : 
just run it + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="ptp", + threshold=3, + progress_bar=True, + n_jobs=1, + ) + + # by_property : just run it + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="by_property", + by_property="group", + progress_bar=True, + n_jobs=1, + ) + assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 5) + def test_compute_sparsity(): recording, sorting = get_dataset() From 30d7dbbc3998fb820b62a3029a4c36fffd48d71b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 12:38:03 +0200 Subject: [PATCH 02/17] Revert ptp changes --- src/spikeinterface/core/sparsity.py | 24 +++++++++++-------- .../core/tests/test_sparsity.py | 13 ---------- 2 files changed, 14 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 57e1fa4769..4de786cb37 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -353,10 +353,10 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p return cls(mask, unit_ids, channel_ids) @classmethod - def from_ptp(cls, templates_or_sorting_analyzer, threshold): + def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): """ - Construct sparsity from a thresholds based on template peak-to-peak values. - Use the "threshold" argument to specify the peak-to-peak threshold. + Construct sparsity from a thresholds based on template peak-to-peak relative values. + Use the "threshold" argument to specify the peak-to-peak threshold (with respect to noise_levels). """ assert ( @@ -371,8 +371,12 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold): channel_ids = templates_or_sorting_analyzer.channel_ids if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): + ext = templates_or_sorting_analyzer.get_extension("noise_levels") + assert ext is not None, "To compute sparsity from ptp you need to compute 'noise_levels' first" + noise_levels = ext.data["noise_levels"] return_scaled = templates_or_sorting_analyzer.return_scaled elif isinstance(templates_or_sorting_analyzer, Templates): + assert noise_levels is not None, "To compute sparsity from ptp you need to provide noise_levels" return_scaled = templates_or_sorting_analyzer.is_scaled mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") @@ -381,7 +385,7 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold): templates_ptps = np.ptp(templates_array, axis=1) for unit_ind, unit_id in enumerate(unit_ids): - chan_inds = np.nonzero(templates_ptps[unit_ind] >= threshold) + chan_inds = np.nonzero(templates_ptps[unit_ind] / noise_levels >= threshold) mask[unit_ind, chan_inds] = True return cls(mask, unit_ids, channel_ids) @@ -397,7 +401,7 @@ def from_energy(cls, sorting_analyzer, threshold): # noise_levels ext = sorting_analyzer.get_extension("noise_levels") - assert ext is not None, "To compute sparsity from ptp you need to compute 'noise_levels' first" + assert ext is not None, "To compute sparsity from energy you need to compute 'noise_levels' first" noise_levels = ext.data["noise_levels"] # waveforms @@ -532,7 +536,7 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" | "ptp" | "by_property" = "radius", + method: "radius" | "best_channels" | 
"by_property" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, @@ -563,9 +567,9 @@ def estimate_sparsity( Cut out in ms before spike time ms_after : float, default: 2.5 Cut out in ms after spike time - method : "radius" | "best_channels" | "ptp" | "by_property", default: "radius" + method : "radius" | "best_channels" | "by_property", default: "radius" Sparsity method propagated to the `compute_sparsity()` function. - "snr" and "energy" are not available here, because they require noise levels. + "snr", "ptp", and "energy" are not available here because they require noise levels. peak_sign : "neg" | "pos" | "both", default: "neg" Sign of the template to compute best channels radius_um : float, default: 100.0 @@ -583,9 +587,9 @@ def estimate_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - assert method in ("radius", "best_channels", "ptp", "by_property"), ( + assert method in ("radius", "best_channels", "by_property"), ( f"method={method} is not available for `estimate_sparsity()`. " - "Available methods are 'radius', 'best_channels', 'ptp', 'by_property'" + "Available methods are 'radius', 'best_channels', 'by_property'" ) if recording.get_probes() == 1: diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 6ee023fc12..b60c1c2eca 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -195,19 +195,6 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 3) - # ptp : just run it - sparsity = estimate_sparsity( - sorting, - recording, - num_spikes_for_sparsity=50, - ms_before=1.0, - ms_after=2.0, - method="ptp", - threshold=3, - progress_bar=True, - n_jobs=1, - ) - # by_property : just run it sparsity = estimate_sparsity( sorting, From fa8eb1f5b349846f13247537f8b3460b45ae9deb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 13:10:14 +0200 Subject: [PATCH 03/17] Expose snr_amplitude_mode for snr sparsity --- src/spikeinterface/core/sparsity.py | 43 +++++++++++-------- .../core/tests/test_sparsity.py | 19 +++++++- 2 files changed, 44 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 4de786cb37..6904886dcf 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -25,14 +25,16 @@ * "by_property" : sparsity is given by a property of the recording and sorting(e.g. "group"). Use the "by_property" argument to specify the property name. 
- peak_sign : str - Sign of the template to compute best channels ("neg", "pos", "both") + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute best channels num_channels : int Number of channels for "best_channels" method radius_um : float Radius in um for "radius" method threshold : float Threshold in SNR "threshold" method + snr_amplitude_mode : "extremum" | "at_index" | "peak_to_peak" + Mode to compute the amplitude of the templates for the "snr" method by_property : object Property name for "by_property" method """ @@ -316,7 +318,9 @@ def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): return cls(mask, templates_or_sorting_analyzer.unit_ids, templates_or_sorting_analyzer.channel_ids) @classmethod - def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, peak_sign="neg"): + def from_snr( + cls, templates_or_sorting_analyzer, threshold, snr_amplitude_mode="extremum", noise_levels=None, peak_sign="neg" + ): """ Construct sparsity from a thresholds based on template signal-to-noise ratio. Use the "threshold" argument to specify the SNR threshold. @@ -344,7 +348,7 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") peak_values = get_template_amplitudes( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode="extremum", return_scaled=return_scaled + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=snr_amplitude_mode, return_scaled=return_scaled ) for unit_ind, unit_id in enumerate(unit_ids): @@ -353,10 +357,10 @@ def from_snr(cls, templates_or_sorting_analyzer, threshold, noise_levels=None, p return cls(mask, unit_ids, channel_ids) @classmethod - def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): + def from_ptp(cls, templates_or_sorting_analyzer, threshold): """ - Construct sparsity from a thresholds based on template peak-to-peak relative values. - Use the "threshold" argument to specify the peak-to-peak threshold (with respect to noise_levels). + Construct sparsity from a thresholds based on template peak-to-peak values. + Use the "threshold" argument to specify the peak-to-peak threshold. 
""" assert ( @@ -371,12 +375,8 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): channel_ids = templates_or_sorting_analyzer.channel_ids if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - ext = templates_or_sorting_analyzer.get_extension("noise_levels") - assert ext is not None, "To compute sparsity from ptp you need to compute 'noise_levels' first" - noise_levels = ext.data["noise_levels"] return_scaled = templates_or_sorting_analyzer.return_scaled elif isinstance(templates_or_sorting_analyzer, Templates): - assert noise_levels is not None, "To compute sparsity from ptp you need to provide noise_levels" return_scaled = templates_or_sorting_analyzer.is_scaled mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") @@ -385,7 +385,7 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): templates_ptps = np.ptp(templates_array, axis=1) for unit_ind, unit_id in enumerate(unit_ids): - chan_inds = np.nonzero(templates_ptps[unit_ind] / noise_levels >= threshold) + chan_inds = np.nonzero(templates_ptps[unit_ind] >= threshold) mask[unit_ind, chan_inds] = True return cls(mask, unit_ids, channel_ids) @@ -453,7 +453,7 @@ def create_dense(cls, sorting_analyzer): def compute_sparsity( - templates_or_sorting_analyzer: Union[Templates, SortingAnalyzer], + templates_or_sorting_analyzer: Templates | SortingAnalyzer, noise_levels: np.ndarray | None = None, method: "radius" | "best_channels" | "snr" | "ptp" | "energy" | "by_property" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", @@ -461,6 +461,7 @@ def compute_sparsity( radius_um: float | None = 100.0, threshold: float | None = 5, by_property: str | None = None, + snr_amplitude_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", ) -> ChannelSparsity: """ Get channel sparsity (subset of channels) for each template with several methods. @@ -507,7 +508,11 @@ def compute_sparsity( elif method == "snr": assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" sparsity = ChannelSparsity.from_snr( - templates_or_sorting_analyzer, threshold, noise_levels=noise_levels, peak_sign=peak_sign + templates_or_sorting_analyzer, + threshold, + noise_levels=noise_levels, + peak_sign=peak_sign, + snr_amplitude_mode=snr_amplitude_mode, ) elif method == "ptp": assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" @@ -536,11 +541,12 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" | "by_property" = "radius", + method: "radius" | "best_channels" | "ptp" | "by_property" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, threshold: float | None = 5, + snr_amplitude_mode: "extremum" | "peak_to_peak" = "extremum", by_property: str | None = None, **job_kwargs, ): @@ -576,6 +582,8 @@ def estimate_sparsity( Used for "radius" method num_channels : int, default: 5 Used for "best_channels" method + snr_amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + Used for "snr" method to compute the amplitude of the templates. {} @@ -587,9 +595,9 @@ def estimate_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - assert method in ("radius", "best_channels", "by_property"), ( + assert method in ("radius", "best_channels", "ptp", "by_property"), ( f"method={method} is not available for `estimate_sparsity()`. 
" - "Available methods are 'radius', 'best_channels', 'by_property'" + "Available methods are 'radius', 'best_channels', 'ptp', 'by_property'" ) if recording.get_probes() == 1: @@ -649,6 +657,7 @@ def estimate_sparsity( radius_um=radius_um, threshold=threshold, by_property=by_property, + snr_amplitude_mode=snr_amplitude_mode, ) return sparsity diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index b60c1c2eca..16b3bbc996 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -195,7 +195,7 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 3) - # by_property : just run it + # by_property sparsity = estimate_sparsity( sorting, recording, @@ -209,6 +209,20 @@ def test_estimate_sparsity(): ) assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 5) + # ptp: just run it + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="ptp", + threshold=5, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) + def test_compute_sparsity(): recording, sorting = get_dataset() @@ -226,6 +240,9 @@ def test_compute_sparsity(): sparsity = compute_sparsity(sorting_analyzer, method="best_channels", num_channels=2, peak_sign="neg") sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0, peak_sign="neg") sparsity = compute_sparsity(sorting_analyzer, method="snr", threshold=5, peak_sign="neg") + sparsity = compute_sparsity( + sorting_analyzer, method="snr", threshold=5, peak_sign="neg", snr_amplitude_mode="peak_to_peak" + ) sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5) sparsity = compute_sparsity(sorting_analyzer, method="energy", threshold=5) sparsity = compute_sparsity(sorting_analyzer, method="by_property", by_property="group") From b5c56d64373b217dcf395e95e1a2b7b70a5f8dbe Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 13:35:45 +0200 Subject: [PATCH 04/17] Propagate sparsity change --- src/spikeinterface/sorters/internal/spyking_circus2.py | 2 +- .../benchmark/tests/test_benchmark_matching.py | 4 +++- .../sortingcomponents/clustering/random_projections.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 4701d76012..a6d212425d 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -25,7 +25,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "sparsity": {"method": "ptp", "threshold": 0.25}, + "sparsity": {"method": "snr", "snr_amplitude_mode": "peak_to_peak", "threshold": 0.25}, "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py index aa9b16bb97..d6d0440a02 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py @@ -27,7 +27,9 @@ def 
test_benchmark_matching(create_cache_folder): recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_scaled=False, **job_kwargs ) noise_levels = get_noise_levels(recording) - sparsity = compute_sparsity(gt_templates, noise_levels, method="ptp", threshold=0.25) + sparsity = compute_sparsity( + gt_templates, noise_levels, method="snr", snr_amplitude_mode="peak_to_peak", threshold=0.25 + ) gt_templates = gt_templates.to_sparse(sparsity) # create study diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index 77d47aec16..b56fd3e02b 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -45,7 +45,7 @@ class RandomProjectionClustering: }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, - "sparsity": {"method": "ptp", "threshold": 0.25}, + "sparsity": {"method": "snr", "snr_amplitude_mode": "peak_to_peak", "threshold": 0.25}, "radius_um": 30, "nb_projections": 10, "feature": "energy", From cc14ab2697331ca4b9ca35e972d959e2cbe07a03 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 5 Sep 2024 13:39:10 +0200 Subject: [PATCH 05/17] last one --- src/spikeinterface/sortingcomponents/clustering/circus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 2bacf36ac9..11a628bb53 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -50,7 +50,7 @@ class CircusClustering: }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, - "sparsity": {"method": "ptp", "threshold": 0.25}, + "sparsity": {"method": "snr", "snr_amplitude_mode": "peak_to_peak", "threshold": 0.25}, "recursive_kwargs": { "recursive": True, "recursive_depth": 3, From fa84e8c6869a475046e4b11c959a8bbac7c5d106 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 10 Sep 2024 16:52:54 +0200 Subject: [PATCH 06/17] Potentiate 'estimate_sparsity' and refactor from_property() constructor --- src/spikeinterface/core/sparsity.py | 261 +++++++++++++----- .../core/tests/test_sparsity.py | 34 ++- .../sorters/internal/spyking_circus2.py | 2 +- .../tests/test_benchmark_matching.py | 4 +- .../sortingcomponents/clustering/circus.py | 2 +- .../clustering/random_projections.py | 2 +- 6 files changed, 226 insertions(+), 79 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 6904886dcf..c72f89520e 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -13,30 +13,33 @@ _sparsity_doc = """ method : str * "best_channels" : N best channels with the largest amplitude. Use the "num_channels" argument to specify the - number of channels. - * "radius" : radius around the best channel. Use the "radius_um" argument to specify the radius in um + number of channels. + * "radius" : radius around the best channel. Use the "radius_um" argument to specify the radius in um. * "snr" : threshold based on template signal-to-noise ratio. Use the "threshold" argument - to specify the SNR threshold (in units of noise levels) + to specify the SNR threshold (in units of noise levels) and the "amplitude_mode" argument + to specify the mode to compute the amplitude of the templates. 
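As a concrete illustration of the "snr" method with the "amplitude_mode" argument (a sketch mirroring the benchmark test updated in this series; `recording` and the dense `templates` object are hypothetical inputs):

    from spikeinterface.core import compute_sparsity, get_noise_levels

    noise_levels = get_noise_levels(recording)
    # peak-to-peak template amplitude relative to per-channel noise
    sparsity = compute_sparsity(
        templates, noise_levels, method="snr", amplitude_mode="peak_to_peak", threshold=0.25
    )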
* "ptp" : threshold based on the peak-to-peak values on every channels. Use the "threshold" argument - to specify the ptp threshold (in units of noise levels) + to specify the ptp threshold (in units of amplitude). * "energy" : threshold based on the expected energy that should be present on the channels, - given their noise levels. Use the "threshold" argument to specify the SNR threshold + given their noise levels. Use the "threshold" argument to specify the energy threshold (in units of noise levels) - * "by_property" : sparsity is given by a property of the recording and sorting(e.g. "group"). - Use the "by_property" argument to specify the property name. + * "by_property" : sparsity is given by a property of the recording and sorting (e.g. "group"). + In this case the sparsity for each unit is given by the channels that have the same property + value as the unit. Use the "by_property" argument to specify the property name. peak_sign : "neg" | "pos" | "both" - Sign of the template to compute best channels + Sign of the template to compute best channels. num_channels : int - Number of channels for "best_channels" method + Number of channels for "best_channels" method. radius_um : float - Radius in um for "radius" method + Radius in um for "radius" method. threshold : float - Threshold in SNR "threshold" method - snr_amplitude_mode : "extremum" | "at_index" | "peak_to_peak" - Mode to compute the amplitude of the templates for the "snr" method + Threshold for "snr", "energy" (in units of noise levels) and "ptp" methods (in units of amplitude). + For the "snr" method, the template amplitude mode is controlled by the "amplitude_mode" argument. + amplitude_mode : "extremum" | "at_index" | "peak_to_peak" + Mode to compute the amplitude of the templates for the "snr" and "best_channels" methods. by_property : object - Property name for "by_property" method + Property name for "by_property" method. """ @@ -279,18 +282,35 @@ def from_dict(cls, dictionary: dict): ## Some convinient function to compute sparsity from several strategy @classmethod - def from_best_channels(cls, templates_or_sorting_analyzer, num_channels, peak_sign="neg"): + def from_best_channels( + cls, templates_or_sorting_analyzer, num_channels, peak_sign="neg", amplitude_mode="extremum" + ): """ Construct sparsity from N best channels with the largest amplitude. Use the "num_channels" argument to specify the number of channels. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + num_channels : int + Number of channels for "best_channels" method. + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute best channels. + amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + Mode to compute the amplitude of the templates. 
+ + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity """ from .template_tools import get_template_amplitudes - print(templates_or_sorting_analyzer) mask = np.zeros( (templates_or_sorting_analyzer.unit_ids.size, templates_or_sorting_analyzer.channel_ids.size), dtype="bool" ) - peak_values = get_template_amplitudes(templates_or_sorting_analyzer, peak_sign=peak_sign) + peak_values = get_template_amplitudes(templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode) for unit_ind, unit_id in enumerate(templates_or_sorting_analyzer.unit_ids): chan_inds = np.argsort(np.abs(peak_values[unit_id]))[::-1] chan_inds = chan_inds[:num_channels] @@ -301,7 +321,21 @@ def from_best_channels(cls, templates_or_sorting_analyzer, num_channels, peak_si def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): """ Construct sparsity from a radius around the best channel. - Use the "radius_um" argument to specify the radius in um + Use the "radius_um" argument to specify the radius in um. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + radius_um : float + Radius in um for "radius" method. + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute best channels. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. """ from .template_tools import get_template_extremum_channel @@ -319,11 +353,37 @@ def from_radius(cls, templates_or_sorting_analyzer, radius_um, peak_sign="neg"): @classmethod def from_snr( - cls, templates_or_sorting_analyzer, threshold, snr_amplitude_mode="extremum", noise_levels=None, peak_sign="neg" + cls, + templates_or_sorting_analyzer, + threshold, + amplitude_mode="extremum", + peak_sign="neg", + noise_levels=None, ): """ Construct sparsity from a thresholds based on template signal-to-noise ratio. Use the "threshold" argument to specify the SNR threshold. + + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + threshold : float + Threshold for "snr" method (in units of noise levels). + noise_levels : np.array | None, default: None + Noise levels required for the "snr" method. You can use the + `get_noise_levels()` function to compute them. + If the input is a `SortingAnalyzer`, the noise levels are automatically retrieved + if the `noise_levels` extension is present. + peak_sign : "neg" | "pos" | "both" + Sign of the template to compute amplitudes. + amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + Mode to compute the amplitude of the templates for the "snr" method. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. """ from .template_tools import get_template_amplitudes from .sortinganalyzer import SortingAnalyzer @@ -348,7 +408,7 @@ def from_snr( mask = np.zeros((unit_ids.size, channel_ids.size), dtype="bool") peak_values = get_template_amplitudes( - templates_or_sorting_analyzer, peak_sign=peak_sign, mode=snr_amplitude_mode, return_scaled=return_scaled + templates_or_sorting_analyzer, peak_sign=peak_sign, mode=amplitude_mode, return_scaled=return_scaled ) for unit_ind, unit_id in enumerate(unit_ids): @@ -361,6 +421,18 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold): """ Construct sparsity from a thresholds based on template peak-to-peak values. Use the "threshold" argument to specify the peak-to-peak threshold. 
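To make the `from_snr()` variant above concrete, a minimal sketch (assuming a hypothetical dense `templates` object and a `recording` to estimate noise from; with a SortingAnalyzer input the noise levels are taken from its "noise_levels" extension instead):

    from spikeinterface.core import ChannelSparsity, get_noise_levels

    noise_levels = get_noise_levels(recording)  # required for a Templates input
    sparsity = ChannelSparsity.from_snr(
        templates, threshold=5, amplitude_mode="extremum", peak_sign="neg", noise_levels=noise_levels
    )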
+ + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + threshold : float + Threshold for "ptp" method (in units of amplitude). + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. """ assert ( @@ -394,6 +466,19 @@ def from_energy(cls, sorting_analyzer, threshold): """ Construct sparsity from a threshold based on per channel energy ratio. Use the "threshold" argument to specify the SNR threshold. + This method requires the "waveforms" and "noise_levels" extensions to be computed. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + threshold : float + Threshold for "energy" method (in units of noise levels). + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. """ assert sorting_analyzer.sparsity is None, "To compute sparsity with energy you need a dense SortingAnalyzer" @@ -419,41 +504,61 @@ def from_energy(cls, sorting_analyzer, threshold): return cls(mask, sorting_analyzer.unit_ids, sorting_analyzer.channel_ids) @classmethod - def from_property(cls, sorting_analyzer, by_property): + def from_property(cls, sorting, recording, by_property): """ Construct sparsity witha property of the recording and sorting(e.g. "group"). Use the "by_property" argument to specify the property name. + + Parameters + ---------- + sorting : Sorting + A Sorting object. + recording : Recording + A Recording object. + by_property : object + Property name for "by_property" method. Both the recording and sorting must have this property set. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. """ # check consistency - assert ( - by_property in sorting_analyzer.recording.get_property_keys() - ), f"Property {by_property} is not a recording property" - assert ( - by_property in sorting_analyzer.sorting.get_property_keys() - ), f"Property {by_property} is not a sorting property" + assert by_property in recording.get_property_keys(), f"Property {by_property} is not a recording property" + assert by_property in sorting.get_property_keys(), f"Property {by_property} is not a sorting property" - mask = np.zeros((sorting_analyzer.unit_ids.size, sorting_analyzer.channel_ids.size), dtype="bool") - rec_by = sorting_analyzer.recording.split_by(by_property) - for unit_ind, unit_id in enumerate(sorting_analyzer.unit_ids): - unit_property = sorting_analyzer.sorting.get_property(by_property)[unit_ind] + mask = np.zeros((sorting.unit_ids.size, recording.channel_ids.size), dtype="bool") + rec_by = recording.split_by(by_property) + for unit_ind, unit_id in enumerate(sorting.unit_ids): + unit_property = sorting.get_property(by_property)[unit_ind] assert ( unit_property in rec_by.keys() ), f"Unit property {unit_property} cannot be found in the recording properties" - chan_inds = sorting_analyzer.recording.ids_to_indices(rec_by[unit_property].get_channel_ids()) + chan_inds = recording.ids_to_indices(rec_by[unit_property].get_channel_ids()) mask[unit_ind, chan_inds] = True - return cls(mask, sorting_analyzer.unit_ids, sorting_analyzer.channel_ids) + return cls(mask, sorting.unit_ids, recording.channel_ids) @classmethod def create_dense(cls, sorting_analyzer): """ Create a sparsity object with all selected channel for all units. + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + A SortingAnalyzer object. + + Returns + ------- + sparsity : ChannelSparsity + The full sparsity. 
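The refactored `from_property()` now takes the sorting and recording directly rather than a SortingAnalyzer; a hedged sketch (both hypothetical objects must carry the property, here "group"):

    from spikeinterface.core import ChannelSparsity

    # each unit keeps only the channels whose "group" value matches its own
    sparsity = ChannelSparsity.from_property(sorting, recording, by_property="group")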
""" mask = np.ones((sorting_analyzer.unit_ids.size, sorting_analyzer.channel_ids.size), dtype="bool") return cls(mask, sorting_analyzer.unit_ids, sorting_analyzer.channel_ids) def compute_sparsity( - templates_or_sorting_analyzer: Templates | SortingAnalyzer, + templates_or_sorting_analyzer: "Templates | SortingAnalyzer", noise_levels: np.ndarray | None = None, method: "radius" | "best_channels" | "snr" | "ptp" | "energy" | "by_property" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", @@ -461,10 +566,10 @@ def compute_sparsity( radius_um: float | None = 100.0, threshold: float | None = 5, by_property: str | None = None, - snr_amplitude_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", + amplitude_mode: "extremum" | "at_index" | "peak_to_peak" = "extremum", ) -> ChannelSparsity: """ - Get channel sparsity (subset of channels) for each template with several methods. + Compute channel sparsity from a `SortingAnalyzer` for each template with several methods. Parameters ---------- @@ -512,7 +617,7 @@ def compute_sparsity( threshold, noise_levels=noise_levels, peak_sign=peak_sign, - snr_amplitude_mode=snr_amplitude_mode, + amplitude_mode=amplitude_mode, ) elif method == "ptp": assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" @@ -525,7 +630,9 @@ def compute_sparsity( sparsity = ChannelSparsity.from_energy(templates_or_sorting_analyzer, threshold) elif method == "by_property": assert by_property is not None, "For the 'by_property' method, 'by_property' needs to be given" - sparsity = ChannelSparsity.from_property(templates_or_sorting_analyzer, by_property) + sparsity = ChannelSparsity.from_property( + templates_or_sorting_analyzer.sorting, templates_or_sorting_analyzer.recording, by_property + ) else: raise ValueError(f"compute_sparsity() method={method} does not exists") @@ -541,19 +648,20 @@ def estimate_sparsity( num_spikes_for_sparsity: int = 100, ms_before: float = 1.0, ms_after: float = 2.5, - method: "radius" | "best_channels" | "ptp" | "by_property" = "radius", + method: "radius" | "best_channels" | "ptp" | "snr" | "by_property" = "radius", peak_sign: "neg" | "pos" | "both" = "neg", radius_um: float = 100.0, num_channels: int = 5, threshold: float | None = 5, - snr_amplitude_mode: "extremum" | "peak_to_peak" = "extremum", + amplitude_mode: "extremum" | "peak_to_peak" = "extremum", by_property: str | None = None, + noise_levels: np.ndarray | list | None = None, **job_kwargs, ): """ - Estimate the sparsity without needing a SortingAnalyzer or Templates object - This is faster than `spikeinterface.waveforms_extractor.precompute_sparsity()` and it - traverses the recording to compute the average templates for each unit. + Estimate the sparsity without needing a SortingAnalyzer or Templates object. + In case the sparsity method needs templates, they are computed on-the-fly. + The same is done for noise levels, if needed by the method ("snr"). Contrary to the previous implementation: * all units are computed in one read of recording @@ -561,6 +669,8 @@ def estimate_sparsity( * it doesn't consume too much memory * it uses internally the `estimate_templates_with_accumulator()` which is fast and parallel + Note that the "energy" method is not supported because it requires a `SortingAnalyzer` object. 
+ Parameters ---------- sorting : BaseSorting @@ -573,18 +683,9 @@ def estimate_sparsity( Cut out in ms before spike time ms_after : float, default: 2.5 Cut out in ms after spike time - method : "radius" | "best_channels" | "by_property", default: "radius" - Sparsity method propagated to the `compute_sparsity()` function. - "snr", "ptp", and "energy" are not available here because they require noise levels. - peak_sign : "neg" | "pos" | "both", default: "neg" - Sign of the template to compute best channels - radius_um : float, default: 100.0 - Used for "radius" method - num_channels : int, default: 5 - Used for "best_channels" method - snr_amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" - Used for "snr" method to compute the amplitude of the templates. - + noise_levels : np.array | None, default: None + Noise levels required for the "snr" and "energy" methods. You can use the + `get_noise_levels()` function to compute them. {} Returns @@ -595,9 +696,9 @@ def estimate_sparsity( # Can't be done at module because this is a cyclic import, too bad from .template import Templates - assert method in ("radius", "best_channels", "ptp", "by_property"), ( + assert method in ("radius", "best_channels", "ptp", "snr", "by_property"), ( f"method={method} is not available for `estimate_sparsity()`. " - "Available methods are 'radius', 'best_channels', 'ptp', 'by_property'" + "Available methods are 'radius', 'best_channels', 'ptp', 'snr', 'energy', 'by_property'" ) if recording.get_probes() == 1: @@ -644,21 +745,39 @@ def estimate_sparsity( unit_ids=sorting.unit_ids, probe=probe, ) - templates_or_analyzer = templates + + if method == "best_channels": + assert num_channels is not None, "For the 'best_channels' method, 'num_channels' needs to be given" + sparsity = ChannelSparsity.from_best_channels( + templates, num_channels, peak_sign=peak_sign, amplitude_mode=amplitude_mode + ) + elif method == "radius": + assert radius_um is not None, "For the 'radius' method, 'radius_um' needs to be given" + sparsity = ChannelSparsity.from_radius(templates, radius_um, peak_sign=peak_sign) + elif method == "snr": + assert threshold is not None, "For the 'snr' method, 'threshold' needs to be given" + assert noise_levels is not None, ( + "For the 'snr' method, 'noise_levels' needs to be given. You can use the " + "`get_noise_levels()` function to compute them." 
+ ) + sparsity = ChannelSparsity.from_snr( + templates, + threshold, + noise_levels=noise_levels, + peak_sign=peak_sign, + amplitude_mode=amplitude_mode, + ) + elif method == "ptp": + assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given" + sparsity = ChannelSparsity.from_ptp( + templates, + threshold, + ) + else: + raise ValueError(f"compute_sparsity() method={method} does not exists") else: - from .sortinganalyzer import create_sorting_analyzer - - templates_or_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording, sparse=False) - sparsity = compute_sparsity( - templates_or_analyzer, - method=method, - peak_sign=peak_sign, - num_channels=num_channels, - radius_um=radius_um, - threshold=threshold, - by_property=by_property, - snr_amplitude_mode=snr_amplitude_mode, - ) + assert by_property is not None, "For the 'by_property' method, 'by_property' needs to be given" + sparsity = ChannelSparsity.from_property(sorting, recording, by_property) return sparsity diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py index 16b3bbc996..64517c106d 100644 --- a/src/spikeinterface/core/tests/test_sparsity.py +++ b/src/spikeinterface/core/tests/test_sparsity.py @@ -3,7 +3,7 @@ import numpy as np import json -from spikeinterface.core import ChannelSparsity, estimate_sparsity, compute_sparsity, Templates +from spikeinterface.core import ChannelSparsity, estimate_sparsity, compute_sparsity, get_noise_levels from spikeinterface.core.core_tools import check_json from spikeinterface.core import generate_ground_truth_recording from spikeinterface.core import create_sorting_analyzer @@ -223,6 +223,36 @@ def test_estimate_sparsity(): n_jobs=1, ) + # snr: fails without noise levels + with pytest.raises(AssertionError): + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="snr", + threshold=5, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) + # snr: works with noise levels + noise_levels = get_noise_levels(recording) + sparsity = estimate_sparsity( + sorting, + recording, + num_spikes_for_sparsity=50, + ms_before=1.0, + ms_after=2.0, + method="snr", + threshold=5, + noise_levels=noise_levels, + chunk_duration="1s", + progress_bar=True, + n_jobs=1, + ) + def test_compute_sparsity(): recording, sorting = get_dataset() @@ -241,7 +271,7 @@ def test_compute_sparsity(): sparsity = compute_sparsity(sorting_analyzer, method="radius", radius_um=50.0, peak_sign="neg") sparsity = compute_sparsity(sorting_analyzer, method="snr", threshold=5, peak_sign="neg") sparsity = compute_sparsity( - sorting_analyzer, method="snr", threshold=5, peak_sign="neg", snr_amplitude_mode="peak_to_peak" + sorting_analyzer, method="snr", threshold=5, peak_sign="neg", amplitude_mode="peak_to_peak" ) sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5) sparsity = compute_sparsity(sorting_analyzer, method="energy", threshold=5) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index a6d212425d..c3b3099535 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -25,7 +25,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): _default_params = { "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100}, - "sparsity": {"method": "snr", "snr_amplitude_mode": "peak_to_peak", "threshold": 0.25}, + 
"sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, "filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2}, "detection": {"peak_sign": "neg", "detect_threshold": 4}, "selection": { diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py index d6d0440a02..71a5f282a8 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py @@ -27,9 +27,7 @@ def test_benchmark_matching(create_cache_folder): recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_scaled=False, **job_kwargs ) noise_levels = get_noise_levels(recording) - sparsity = compute_sparsity( - gt_templates, noise_levels, method="snr", snr_amplitude_mode="peak_to_peak", threshold=0.25 - ) + sparsity = compute_sparsity(gt_templates, noise_levels, method="snr", amplitude_mode="peak_to_peak", threshold=0.25) gt_templates = gt_templates.to_sparse(sparsity) # create study diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 11a628bb53..b08ee4d9cb 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -50,7 +50,7 @@ class CircusClustering: }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, - "sparsity": {"method": "snr", "snr_amplitude_mode": "peak_to_peak", "threshold": 0.25}, + "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, "recursive_kwargs": { "recursive": True, "recursive_depth": 3, diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index b56fd3e02b..f7ca999d53 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -45,7 +45,7 @@ class RandomProjectionClustering: }, "cleaning_kwargs": {}, "waveforms": {"ms_before": 2, "ms_after": 2}, - "sparsity": {"method": "snr", "snr_amplitude_mode": "peak_to_peak", "threshold": 0.25}, + "sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25}, "radius_um": 30, "nb_projections": 10, "feature": "energy", From 15a4a11bad4e08bdf4d8ce5de67e05dd2a2a8fab Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 10 Sep 2024 18:14:34 +0200 Subject: [PATCH 07/17] minor docstring fix --- src/spikeinterface/core/sparsity.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index c72f89520e..471302d57e 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -661,7 +661,8 @@ def estimate_sparsity( """ Estimate the sparsity without needing a SortingAnalyzer or Templates object. In case the sparsity method needs templates, they are computed on-the-fly. - The same is done for noise levels, if needed by the method ("snr"). + For the "snr" method, `noise_levels` must passed with the `noise_levels` argument. + These can be computed with the `get_noise_levels()` function. 
Contrary to the previous implementation: * all units are computed in one read of recording From 9ffda3543aa794418af02830be70144de64c54f2 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 13 Sep 2024 12:33:48 +0200 Subject: [PATCH 08/17] Add from_amplitude() option to sparsity and deprecate ptp --- src/spikeinterface/core/sparsity.py | 105 ++++++++++++++---- .../core/tests/test_sparsity.py | 31 +++++- 2 files changed, 108 insertions(+), 28 deletions(-) diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py index 471302d57e..fd613e1fcf 100644 --- a/src/spikeinterface/core/sparsity.py +++ b/src/spikeinterface/core/sparsity.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +import warnings from .basesorting import BaseSorting @@ -18,14 +19,16 @@ * "snr" : threshold based on template signal-to-noise ratio. Use the "threshold" argument to specify the SNR threshold (in units of noise levels) and the "amplitude_mode" argument to specify the mode to compute the amplitude of the templates. - * "ptp" : threshold based on the peak-to-peak values on every channels. Use the "threshold" argument - to specify the ptp threshold (in units of amplitude). + * "amplitude" : threshold based on the amplitude values on every channels. Use the "threshold" argument + to specify the ptp threshold (in units of amplitude) and the "amplitude_mode" argument + to specify the mode to compute the amplitude of the templates. * "energy" : threshold based on the expected energy that should be present on the channels, given their noise levels. Use the "threshold" argument to specify the energy threshold (in units of noise levels) * "by_property" : sparsity is given by a property of the recording and sorting (e.g. "group"). In this case the sparsity for each unit is given by the channels that have the same property value as the unit. Use the "by_property" argument to specify the property name. + * "ptp: : deprecated, use the 'snr' method with the 'peak_to_peak' amplitude mode instead. peak_sign : "neg" | "pos" | "both" Sign of the template to compute best channels. @@ -37,7 +40,7 @@ Threshold for "snr", "energy" (in units of noise levels) and "ptp" methods (in units of amplitude). For the "snr" method, the template amplitude mode is controlled by the "amplitude_mode" argument. amplitude_mode : "extremum" | "at_index" | "peak_to_peak" - Mode to compute the amplitude of the templates for the "snr" and "best_channels" methods. + Mode to compute the amplitude of the templates for the "snr", "amplitude", and "best_channels" methods. by_property : object Property name for "by_property" method. """ @@ -417,7 +420,7 @@ def from_snr( return cls(mask, unit_ids, channel_ids) @classmethod - def from_ptp(cls, templates_or_sorting_analyzer, threshold): + def from_ptp(cls, templates_or_sorting_analyzer, threshold, noise_levels=None): """ Construct sparsity from a thresholds based on template peak-to-peak values. Use the "threshold" argument to specify the peak-to-peak threshold. @@ -434,30 +437,67 @@ def from_ptp(cls, templates_or_sorting_analyzer, threshold): sparsity : ChannelSparsity The estimated sparsity. """ + warnings.warn( + "The 'ptp' method is deprecated and will be removed in version 0.103.0. 
" + "Please use the 'snr' method with the 'peak_to_peak' amplitude mode instead.", + DeprecationWarning, + ) + return cls.from_snr( + templates_or_sorting_analyzer, threshold, amplitude_mode="peak_to_peak", noise_levels=noise_levels + ) - assert ( - templates_or_sorting_analyzer.sparsity is None - ), "To compute sparsity you need a dense SortingAnalyzer or Templates" + @classmethod + def from_amplitude(cls, templates_or_sorting_analyzer, threshold, amplitude_mode="extremum", peak_sign="neg"): + """ + Construct sparsity from a threshold based on template amplitude. + The amplitude is computed with the specified amplitude mode and it is assumed + that the amplitude is in uV. The input `Templates` or `SortingAnalyzer` object must + have scaled templates. - from .template_tools import get_dense_templates_array + Parameters + ---------- + templates_or_sorting_analyzer : Templates | SortingAnalyzer + A Templates or a SortingAnalyzer object. + threshold : float + Threshold for "amplitude" method (in uV). + amplitude_mode : "extremum" | "at_index" | "peak_to_peak", default: "extremum" + Mode to compute the amplitude of the templates. + + Returns + ------- + sparsity : ChannelSparsity + The estimated sparsity. + """ + from .template_tools import get_template_amplitudes from .sortinganalyzer import SortingAnalyzer from .template import Templates + assert ( + templates_or_sorting_analyzer.sparsity is None + ), "To compute sparsity you need a dense SortingAnalyzer or Templates" + unit_ids = templates_or_sorting_analyzer.unit_ids channel_ids = templates_or_sorting_analyzer.channel_ids if isinstance(templates_or_sorting_analyzer, SortingAnalyzer): - return_scaled = templates_or_sorting_analyzer.return_scaled + assert templates_or_sorting_analyzer.return_scaled, ( + "To compute sparsity from amplitude you need to have scaled templates. " + "You can set `return_scaled=True` when computing the templates." + ) elif isinstance(templates_or_sorting_analyzer, Templates): - return_scaled = templates_or_sorting_analyzer.is_scaled + assert templates_or_sorting_analyzer.is_scaled, ( + "To compute sparsity from amplitude you need to have scaled templates. " + "You can set `is_scaled=True` when creating the Templates object." 
@@ -560,7 +600,7 @@ def create_dense(cls, sorting_analyzer):
 def compute_sparsity(
     templates_or_sorting_analyzer: "Templates | SortingAnalyzer",
     noise_levels: np.ndarray | None = None,
-    method: "radius" | "best_channels" | "snr" | "ptp" | "energy" | "by_property" = "radius",
+    method: "radius" | "best_channels" | "snr" | "amplitude" | "energy" | "by_property" | "ptp" = "radius",
     peak_sign: "neg" | "pos" | "both" = "neg",
     num_channels: int | None = 5,
     radius_um: float | None = 100.0,
@@ -595,7 +635,7 @@ def compute_sparsity(
         # to keep backward compatibility
         templates_or_sorting_analyzer = templates_or_sorting_analyzer.sorting_analyzer

-    if method in ("best_channels", "radius", "snr", "ptp"):
+    if method in ("best_channels", "radius", "snr", "amplitude", "ptp"):
         assert isinstance(
             templates_or_sorting_analyzer, (Templates, SortingAnalyzer)
         ), f"compute_sparsity(method='{method}') need Templates or SortingAnalyzer"
@@ -619,11 +659,13 @@ def compute_sparsity(
             peak_sign=peak_sign,
             amplitude_mode=amplitude_mode,
         )
-    elif method == "ptp":
-        assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given"
-        sparsity = ChannelSparsity.from_ptp(
+    elif method == "amplitude":
+        assert threshold is not None, "For the 'amplitude' method, 'threshold' needs to be given"
+        sparsity = ChannelSparsity.from_amplitude(
             templates_or_sorting_analyzer,
             threshold,
+            amplitude_mode=amplitude_mode,
+            peak_sign=peak_sign,
         )
     elif method == "energy":
         assert threshold is not None, "For the 'energy' method, 'threshold' needs to be given"
@@ -633,6 +675,14 @@ def compute_sparsity(
         sparsity = ChannelSparsity.from_property(
             templates_or_sorting_analyzer.sorting, templates_or_sorting_analyzer.recording, by_property
         )
+    elif method == "ptp":
+        # TODO: remove after deprecation
+        assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given"
+        sparsity = ChannelSparsity.from_ptp(
+            templates_or_sorting_analyzer,
+            threshold,
+            noise_levels=noise_levels,
+        )
     else:
         raise ValueError(f"compute_sparsity() method={method} does not exist")

@@ -648,7 +698,7 @@ def estimate_sparsity(
     num_spikes_for_sparsity: int = 100,
     ms_before: float = 1.0,
     ms_after: float = 2.5,
-    method: "radius" | "best_channels" | "ptp" | "snr" | "by_property" = "radius",
+    method: "radius" | "best_channels" | "amplitude" | "snr" | "by_property" | "ptp" = "radius",
     peak_sign: "neg" | "pos" | "both" = "neg",
     radius_um: float = 100.0,
     num_channels: int = 5,
@@ -697,9 +747,9 @@ def estimate_sparsity(
     # Can't be done at module because this is a cyclic import, too bad
     from .template import Templates

-    assert method in ("radius", "best_channels", "ptp", "snr", "by_property"), (
+    assert method in ("radius", "best_channels", "snr", "amplitude", "by_property", "ptp"), (
         f"method={method} is not available for `estimate_sparsity()`. "
-        "Available methods are 'radius', 'best_channels', 'ptp', 'snr', 'energy', 'by_property'"
+        "Available methods are 'radius', 'best_channels', 'snr', 'amplitude', 'by_property', 'ptp' (deprecated)"
     )

     if recording.get_probes() == 1:
@@ -768,12 +818,19 @@ def estimate_sparsity(
                 peak_sign=peak_sign,
                 amplitude_mode=amplitude_mode,
             )
+        elif method == "amplitude":
+            assert threshold is not None, "For the 'amplitude' method, 'threshold' needs to be given"
+            sparsity = ChannelSparsity.from_amplitude(
+                templates, threshold, amplitude_mode=amplitude_mode, peak_sign=peak_sign
+            )
         elif method == "ptp":
+            # TODO: remove after deprecation
             assert threshold is not None, "For the 'ptp' method, 'threshold' needs to be given"
-            sparsity = ChannelSparsity.from_ptp(
-                templates,
-                threshold,
+            assert noise_levels is not None, (
+                "For the 'ptp' method, 'noise_levels' needs to be given. You can use the "
+                "`get_noise_levels()` function to compute them."
             )
+            sparsity = ChannelSparsity.from_ptp(templates, threshold, noise_levels=noise_levels)
         else:
             raise ValueError(f"compute_sparsity() method={method} does not exist")
     else:

diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py
index 64517c106d..ace869df8c 100644
--- a/src/spikeinterface/core/tests/test_sparsity.py
+++ b/src/spikeinterface/core/tests/test_sparsity.py
@@ -209,15 +209,16 @@ def test_estimate_sparsity():
     )
     assert np.array_equal(np.sum(sparsity.mask, axis=1), np.ones(num_units) * 5)

-    # ptp: just run it
+    # amplitude
     sparsity = estimate_sparsity(
         sorting,
         recording,
         num_spikes_for_sparsity=50,
         ms_before=1.0,
         ms_after=2.0,
-        method="ptp",
+        method="amplitude",
         threshold=5,
+        amplitude_mode="peak_to_peak",
         chunk_duration="1s",
         progress_bar=True,
         n_jobs=1,
     )
@@ -252,6 +253,23 @@ def test_estimate_sparsity():
         progress_bar=True,
         n_jobs=1,
     )
+    # ptp: just run it
+    with pytest.warns(DeprecationWarning):
+        sparsity = estimate_sparsity(
+            sorting,
+            recording,
+            num_spikes_for_sparsity=50,
+            ms_before=1.0,
+            ms_after=2.0,
+            method="ptp",
+            threshold=5,
+            noise_levels=noise_levels,
+            chunk_duration="1s",
+            progress_bar=True,
+            n_jobs=1,
+        )


 def test_compute_sparsity():
@@ -273,9 +291,11 @@ def test_compute_sparsity():
     sparsity = compute_sparsity(
         sorting_analyzer, method="snr", threshold=5, peak_sign="neg", amplitude_mode="peak_to_peak"
     )
-    sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5)
+    sparsity = compute_sparsity(sorting_analyzer, method="amplitude", threshold=5, amplitude_mode="peak_to_peak")
     sparsity = compute_sparsity(sorting_analyzer, method="energy", threshold=5)
     sparsity = compute_sparsity(sorting_analyzer, method="by_property", by_property="group")
+    with pytest.warns(DeprecationWarning):
+        sparsity = compute_sparsity(sorting_analyzer, method="ptp", threshold=5)

     # using object Templates
     templates = sorting_analyzer.get_extension("templates").get_data(outputs="Templates")
@@ -283,7 +303,10 @@ def test_compute_sparsity():
     sparsity = compute_sparsity(templates, method="best_channels", num_channels=2, peak_sign="neg")
     sparsity = compute_sparsity(templates, method="radius", radius_um=50.0, peak_sign="neg")
     sparsity = compute_sparsity(templates, method="snr", noise_levels=noise_levels, threshold=5, peak_sign="neg")
-    sparsity = compute_sparsity(templates, method="ptp", noise_levels=noise_levels, threshold=5)
+    sparsity = compute_sparsity(templates, method="amplitude", threshold=5, amplitude_mode="peak_to_peak")
+
+    with pytest.warns(DeprecationWarning):
+        sparsity = compute_sparsity(templates, method="ptp", noise_levels=noise_levels, threshold=5)


 if __name__ == "__main__":
From 3a13efe7607404a36337baa083b8dbe283a45bf0 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 13 Sep 2024 12:59:42 +0200
Subject: [PATCH 09/17] Check run info completed only if it exists
 (back-compatibility)

---
 src/spikeinterface/core/sortinganalyzer.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 3aa582da68..8980fb5559 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -2240,9 +2240,10 @@ def get_pipeline_nodes(self):
         return self._get_pipeline_nodes()

     def get_data(self, *args, **kwargs):
-        assert self.run_info[
-            "run_completed"
-        ], f"You must run the extension {self.extension_name} before retrieving data"
+        if self.run_info is not None:
+            assert self.run_info[
+                "run_completed"
+            ], f"You must run the extension {self.extension_name} before retrieving data"
         assert len(self.data) > 0, "Extension has been run but no data found."
         return self._get_data(*args, **kwargs)


From dcedbb33739663e0cb662d313e0224f3d64d04f8 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Sun, 15 Sep 2024 13:01:11 +0200
Subject: [PATCH 10/17] Simplify pandas save-load and convert dtypes

---
 pyproject.toml                                     |  3 --
 src/spikeinterface/core/sortinganalyzer.py         | 37 ++++++++++++------
 .../postprocessing/template_metrics.py             |  2 +-
 .../tests/common_extension_tests.py                | 23 +++++++++++-
 .../quality_metric_calculator.py                   |  2 +-
 5 files changed, 47 insertions(+), 20 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 8309ca89fe..b5894bf3a5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -94,7 +94,6 @@ preprocessing = [
 full = [
     "h5py",
     "pandas",
-    "xarray",
     "scipy",
     "scikit-learn",
     "networkx",
@@ -148,7 +147,6 @@ test = [
     "pytest-dependency",
     "pytest-cov",

-    "xarray",
     "huggingface_hub",

     # preprocessing
@@ -193,7 +191,6 @@ docs = [
     "pandas", # in the modules gallery comparison tutorial
     "hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous
     "numba", # For many postprocessing functions
-    "xarray", # For use of SortingAnalyzer zarr format
     "networkx",
     # Download data
     "pooch>=1.8.2",

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 8980fb5559..312f85a8ca 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -1970,12 +1970,14 @@ def load_data(self):
             if "dict" in ext_data_.attrs:
                 ext_data = ext_data_[0]
             elif "dataframe" in ext_data_.attrs:
-                import xarray
+                import pandas as pd

-                ext_data = xarray.open_zarr(
-                    ext_data_.store, group=f"{extension_group.name}/{ext_data_name}"
-                ).to_pandas()
-                ext_data.index.rename("", inplace=True)
+                index = ext_data_["index"]
+                ext_data = pd.DataFrame(index=index)
+                for col in ext_data_.keys():
+                    if col != "index":
+                        ext_data.loc[:, col] = ext_data_[col][:]
+                ext_data = ext_data.convert_dtypes()
             elif "object" in ext_data_.attrs:
                 ext_data = ext_data_[0]
             else:
@@ -2031,12 +2033,21 @@ def run(self, save=True, **kwargs):
         if save and not self.sorting_analyzer.is_read_only():
             self._save_run_info()
             self._save_data(**kwargs)
+            if self.format == "zarr":
+                import zarr
+
+                zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)

     def save(self, **kwargs):
         self._save_params()
         self._save_importing_provenance()
-        self._save_data(**kwargs)
         self._save_run_info()
+        self._save_data(**kwargs)
+
+        if self.format == "zarr":
+            import zarr
+
+            zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)

     def _save_data(self, **kwargs):
         if self.format == "memory":
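The pandas round-trip that replaces the xarray dependency can be reproduced in isolation. A self-contained sketch of the same column-per-dataset layout used by `_save_data()` and `load_data()` above (the store name and values are arbitrary):

    import numpy as np
    import pandas as pd
    import zarr

    df = pd.DataFrame({"snr": [5.1, 7.3], "num_spikes": [120, 340]}, index=[0, 1])

    # save: one zarr dataset per column, plus the index
    root = zarr.open("metrics.zarr", mode="w")
    group = root.create_group("metrics")
    group.create_dataset(name="index", data=df.index.to_numpy())
    for col in df.columns:
        group.create_dataset(name=col, data=df[col].to_numpy())
    group.attrs["dataframe"] = True

    # load: rebuild the DataFrame column by column
    loaded = pd.DataFrame(index=group["index"][:])
    for col in group.keys():
        if col != "index":
            loaded.loc[:, col] = group[col][:]
    loaded = loaded.convert_dtypes()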
@@ -2096,12 +2107,12 @@ def _save_data(self, **kwargs):
             elif isinstance(ext_data, np.ndarray):
                 extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor)
             elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame):
-                ext_data.to_xarray().to_zarr(
-                    store=extension_group.store,
-                    group=f"{extension_group.name}/{ext_data_name}",
-                    mode="a",
-                )
-                extension_group[ext_data_name].attrs["dataframe"] = True
+                df_group = extension_group.create_group(ext_data_name)
+                # first we save the index
+                df_group.create_dataset(name="index", data=ext_data.index.to_numpy())
+                for col in ext_data.columns:
+                    df_group.create_dataset(name=col, data=ext_data[col].to_numpy())
+                df_group.attrs["dataframe"] = True
             else:
                 # any object
                 try:
@@ -2111,8 +2122,6 @@ def _save_data(self, **kwargs):
                 except:
                     raise Exception(f"Could not save {ext_data_name} as extension data")
                 extension_group[ext_data_name].attrs["object"] = True
-                # we need to re-consolidate
-                zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)

     def _reset_extension_folder(self):
         """

diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py
index 45ba55dee4..ee5ac6103b 100644
--- a/src/spikeinterface/postprocessing/template_metrics.py
+++ b/src/spikeinterface/postprocessing/template_metrics.py
@@ -287,7 +287,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job
                 warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}")
                 value = np.nan
             template_metrics.at[index, metric_name] = value
-        return template_metrics
+        return template_metrics.convert_dtypes()

     def _run(self, verbose=False):
         self.data["metrics"] = self._compute_metrics(

diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py
index 3945e71881..1b0a94d635 100644
--- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py
+++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py
@@ -3,9 +3,10 @@
 import pytest
 import shutil
 import numpy as np
+import pandas as pd

 from spikeinterface.core import generate_ground_truth_recording
-from spikeinterface.core import create_sorting_analyzer
+from spikeinterface.core import create_sorting_analyzer, load_sorting_analyzer
 from spikeinterface.core import estimate_sparsity

@@ -138,6 +139,26 @@ def _check_one(self, sorting_analyzer, extension_class, params):
         merged = sorting_analyzer.merge_units(some_merges, format="memory", merging_mode="soft", sparsity_overlap=0.0)
         assert len(merged.unit_ids) == num_units_after_merge

+        # test roundtrip
+        if sorting_analyzer.format in ("binary_folder", "zarr"):
+            sorting_analyzer_loaded = load_sorting_analyzer(sorting_analyzer.folder)
+            ext_loaded = sorting_analyzer_loaded.get_extension(extension_class.extension_name)
+            for ext_data_name, ext_data_loaded in ext_loaded.data.items():
+                if isinstance(ext_data_loaded, np.ndarray):
+                    assert np.array_equal(ext.data[ext_data_name], ext_data_loaded)
+                elif isinstance(ext_data_loaded, pd.DataFrame):
+                    # skip nan values
+                    for col in ext_data_loaded.columns:
+                        np.testing.assert_array_almost_equal(
+                            ext.data[ext_data_name][col].dropna().to_numpy(),
+                            ext_data_loaded[col].dropna().to_numpy(),
+                            decimal=5,
+                        )
+                elif isinstance(ext_data_loaded, dict):
+                    assert ext.data[ext_data_name] == ext_data_loaded
+                else:
+                    continue
+
     def run_extension_tests(self, extension_class, params):
         """
         Convenience function to perform all checks on the extension
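The repeated `convert_dtypes()` calls are the crux of this commit: a NaN placeholder otherwise silently demotes integer metric columns to float64 (or object). A small stand-alone illustration, independent of spikeinterface:

    import numpy as np
    import pandas as pd

    metrics = pd.DataFrame(index=[0, 1, 2])
    metrics.loc[:, "num_spikes"] = [100, 250, np.nan]  # NaN forces float64 here
    print(metrics["num_spikes"].dtype)                 # float64

    metrics = metrics.convert_dtypes()
    print(metrics["num_spikes"].dtype)                 # Int64 (nullable); NaN becomes pd.NA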
diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
index cdf6151e95..3d7096651f 100644
--- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
@@ -185,7 +185,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job
         if len(empty_unit_ids) > 0:
             metrics.loc[empty_unit_ids] = np.nan

-        return metrics
+        return metrics.convert_dtypes()

     def _run(self, verbose=False, **job_kwargs):
         self.data["metrics"] = self._compute_metrics(


From 9000ce1980130fac15394b7cf68137fe239b0b13 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Sun, 15 Sep 2024 13:08:25 +0200
Subject: [PATCH 11/17] local import

---
 .../postprocessing/tests/common_extension_tests.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py
index 1b0a94d635..2207b98da6 100644
--- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py
+++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py
@@ -3,7 +3,6 @@
 import pytest
 import shutil
 import numpy as np
-import pandas as pd

 from spikeinterface.core import generate_ground_truth_recording
 from spikeinterface.core import create_sorting_analyzer, load_sorting_analyzer
@@ -117,6 +116,8 @@ def _check_one(self, sorting_analyzer, extension_class, params):
         with the passed parameters, and check the output is not empty, the extension
         exists and `select_units()` method works.
         """
+        import pandas as pd
+
         if extension_class.need_job_kwargs:
             job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True)
         else:
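Patch 12 below guards the loading of stores written by v0.101.0, which skipped metadata re-consolidation after extensions were computed. Consolidation itself is a plain zarr operation; a hedged, self-contained sketch (store name and group are arbitrary):

    import zarr

    root = zarr.open("analyzer.zarr", mode="a")
    root.create_group("extensions/quality_metrics")

    # rewrite the consolidated metadata so zarr.open_consolidated() sees the new group
    zarr.consolidate_metadata(root.store)
    consolidated = zarr.open_consolidated("analyzer.zarr", mode="r")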
From c1228d9269663118c065ae9e5f68cb1c0a8b16c7 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Sun, 15 Sep 2024 18:00:28 +0200
Subject: [PATCH 12/17] Add comment and re-consolidation step for 0.101.0
 datasets

---
 src/spikeinterface/core/sortinganalyzer.py         | 15 +++++++++++++++
 .../postprocessing/template_metrics.py             |  6 +++++-
 .../qualitymetrics/quality_metric_calculator.py    |  9 ++++++---
 3 files changed, 26 insertions(+), 4 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 312f85a8ca..a94b7aa3dc 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -11,6 +11,7 @@
 import shutil
 import warnings
 import importlib
+from packaging.version import parse
 from time import perf_counter

 import numpy as np
@@ -579,6 +580,20 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None):

         zarr_root = zarr.open_consolidated(str(folder), mode="r", storage_options=storage_options)

+        si_info = zarr_root.attrs["spikeinterface_info"]
+        if parse(si_info["version"]) < parse("0.101.1"):
+            # v0.101.0 did not have a consolidate metadata step after computing extensions.
+            # Here we try to consolidate the metadata and throw a warning if it fails.
+            try:
+                zarr_root_a = zarr.open(str(folder), mode="a", storage_options=storage_options)
+                zarr.consolidate_metadata(zarr_root_a.store)
+            except Exception as e:
+                warnings.warn(
+                    "The zarr store was not properly consolidated prior to v0.101.1. "
+                    "This may lead to unexpected behavior in loading extensions. "
+                    "Please consider re-saving the SortingAnalyzer object."
+                )
+
         # load internal sorting in memory
         sorting = NumpySorting.from_sorting(
             ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options),

diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py
index ee5ac6103b..aa50be4c13 100644
--- a/src/spikeinterface/postprocessing/template_metrics.py
+++ b/src/spikeinterface/postprocessing/template_metrics.py
@@ -287,7 +287,11 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job
                 warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}")
                 value = np.nan
             template_metrics.at[index, metric_name] = value
-        return template_metrics.convert_dtypes()
+
+        # we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns
+        # (in case of NaN values)
+        template_metrics = template_metrics.convert_dtypes()
+        return template_metrics

     def _run(self, verbose=False):
         self.data["metrics"] = self._compute_metrics(

diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
index 3d7096651f..b2804c2638 100644
--- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
@@ -108,6 +108,8 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job
         """
         Compute quality metrics.
         """
+        import pandas as pd
+
         metric_names = self.params["metric_names"]
         qm_params = self.params["qm_params"]
         # sparsity = self.params["sparsity"]
@@ -132,8 +134,6 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job
             non_empty_unit_ids = unit_ids
             empty_unit_ids = []

-        import pandas as pd
-
         metrics = pd.DataFrame(index=unit_ids)

         # simple metrics not based on PCs
@@ -185,7 +185,10 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job
         if len(empty_unit_ids) > 0:
             metrics.loc[empty_unit_ids] = np.nan

-        return metrics.convert_dtypes()
+        # we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns
+        # (in case of NaN values)
+        metrics = metrics.convert_dtypes()
+        return metrics

     def _run(self, verbose=False, **job_kwargs):
         self.data["metrics"] = self._compute_metrics(


From b1677fabd82f36d0ed51af8418a559661cbfa4e3 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Sun, 15 Sep 2024 18:43:27 +0200
Subject: [PATCH 13/17] Update src/spikeinterface/qualitymetrics/quality_metric_calculator.py

---
 src/spikeinterface/qualitymetrics/quality_metric_calculator.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
index 41ad40293a..3b6c6d3e50 100644
--- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
@@ -141,7 +141,6 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
         """
         import pandas as pd

-        metric_names = self.params["metric_names"]
         qm_params = self.params["qm_params"]
         # sparsity = self.params["sparsity"]
         seed = self.params["seed"]


From c33fbd57b4fc7e95edd8a3680ab40a5f291617e0 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 16 Sep 2024 10:00:34 +0200
Subject: [PATCH 14/17] Fix plot motion for multi-segment

---
 src/spikeinterface/widgets/motion.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py
index 81cda212b2..06c0305351 100644
--- a/src/spikeinterface/widgets/motion.py
+++ b/src/spikeinterface/widgets/motion.py
@@ -340,8 +340,11 @@ def __init__(
                 raise ValueError(
                     "plot drift map : the Motion object is multi-segment you must provide segment_index=XX"
                 )
+            assert recording.get_num_segments() == len(
+                motion.displacement
+            ), "The number of segments in the recording must be the same as the number of segments in the motion object"

-        times = recording.get_times() if recording is not None else None
+        times = recording.get_times(segment_index=segment_index) if recording is not None else None

         plot_data = dict(
             sampling_frequency=motion_info["parameters"]["sampling_frequency"],
From 9a7295948b83bef46736252f49ac6b164acee53d Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 16 Sep 2024 10:04:17 +0200
Subject: [PATCH 15/17] Update src/spikeinterface/core/sortinganalyzer.py

---
 src/spikeinterface/core/sortinganalyzer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index a94b7aa3dc..177188f21d 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -591,7 +591,7 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None):
                 warnings.warn(
                     "The zarr store was not properly consolidated prior to v0.101.1. "
                     "This may lead to unexpected behavior in loading extensions. "
-                    "Please consider re-saving the SortingAnalyzer object."
+                    "Please consider re-generating the SortingAnalyzer object."
                 )

         # load internal sorting in memory


From 3d41aca67dbfcd01e74eb3bda0ccf7fa59e4c56a Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 16 Sep 2024 11:25:54 +0200
Subject: [PATCH 16/17] Recording cannot be None

---
 src/spikeinterface/widgets/motion.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py
index 06c0305351..7c8389cae8 100644
--- a/src/spikeinterface/widgets/motion.py
+++ b/src/spikeinterface/widgets/motion.py
@@ -344,7 +344,7 @@ def __init__(
             motion.displacement
         ), "The number of segments in the recording must be the same as the number of segments in the motion object"

-        times = recording.get_times(segment_index=segment_index) if recording is not None else None
+        times = recording.get_times(segment_index=segment_index)

         plot_data = dict(
             sampling_frequency=motion_info["parameters"]["sampling_frequency"],


From 36251ef49b9c61582de2f84bd81d76d00be34b3e Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 16 Sep 2024 11:33:56 +0200
Subject: [PATCH 17/17] Let recording handle times

---
 src/spikeinterface/widgets/motion.py | 14 ++------------
 1 file changed, 2 insertions(+), 12 deletions(-)

diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py
index 7c8389cae8..42e9a20f3c 100644
--- a/src/spikeinterface/widgets/motion.py
+++ b/src/spikeinterface/widgets/motion.py
@@ -200,18 +200,11 @@ def __init__(
         if peak_amplitudes is not None:
             peak_amplitudes = peak_amplitudes[peak_mask]

-        if recording is not None:
-            sampling_frequency = recording.sampling_frequency
-            times = recording.get_times(segment_index=segment_index)
-        else:
-            times = None
-
         plot_data = dict(
             peaks=peaks,
             peak_locations=peak_locations,
             peak_amplitudes=peak_amplitudes,
             direction=direction,
-            times=times,
             sampling_frequency=sampling_frequency,
             segment_index=segment_index,
             depth_lim=depth_lim,
@@ -238,10 +231,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):

         self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

-        if dp.times is None:
+        if dp.recording is None:
             peak_times = dp.peaks["sample_index"] / dp.sampling_frequency
         else:
-            peak_times = dp.times[dp.peaks["sample_index"]]
+            peak_times = dp.recording.sample_index_to_time(dp.peaks["sample_index"], segment_index=dp.segment_index)

         peak_locs = dp.peak_locations[dp.direction]
         if dp.scatter_decimate is not None:
@@ -344,11 +337,8 @@ def __init__(
             motion.displacement
         ), "The number of segments in the recording must be the same as the number of segments in the motion object"

-        times = recording.get_times(segment_index=segment_index)
-
         plot_data = dict(
             sampling_frequency=motion_info["parameters"]["sampling_frequency"],
-            times=times,
             segment_index=segment_index,
             depth_lim=depth_lim,
             motion_lim=motion_lim,
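The final patch delegates the sample-to-time conversion to the recording itself; the equivalent stand-alone call looks like this (a sketch, reusing a synthetic multi-segment recording):

    import numpy as np
    from spikeinterface.core import generate_recording

    recording = generate_recording(durations=[5.0, 2.5], num_channels=4)

    sample_inds = np.array([0, 100, 1000])
    # honors the segment's time vector / t_start if present, otherwise divides by the sampling frequency
    times = recording.sample_index_to_time(sample_inds, segment_index=1)
    print(times)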