diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py
index fd68df9dda..d9a567dedf 100644
--- a/src/spikeinterface/core/basesorting.py
+++ b/src/spikeinterface/core/basesorting.py
@@ -197,7 +197,7 @@ def register_recording(self, recording, check_spike_frames=True):
             self.get_num_segments() == recording.get_num_segments()
         ), "The recording has a different number of segments than the sorting!"
         if check_spike_frames:
-            if has_exceeding_spikes(recording, self):
+            if has_exceeding_spikes(self, recording):
                 warnings.warn(
                     "Some spikes exceed the recording's duration! "
                     "Removing these excess spikes with `spikeinterface.curation.remove_excess_spikes()` "
diff --git a/src/spikeinterface/core/frameslicesorting.py b/src/spikeinterface/core/frameslicesorting.py
index ffd8af5fd8..f3ec449ab0 100644
--- a/src/spikeinterface/core/frameslicesorting.py
+++ b/src/spikeinterface/core/frameslicesorting.py
@@ -54,7 +54,7 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike
             assert (
                 start_frame <= parent_n_samples
             ), "`start_frame` should be smaller than the sortings' total number of samples."
-            if check_spike_frames and has_exceeding_spikes(parent_sorting._recording, parent_sorting):
+            if check_spike_frames and has_exceeding_spikes(parent_sorting, parent_sorting._recording):
                 raise ValueError(
                     "The sorting object has spikes whose times go beyond the recording duration."
                     "This could indicate a bug in the sorter. "
diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py
index 1c0107d235..0722ede23f 100644
--- a/src/spikeinterface/core/node_pipeline.py
+++ b/src/spikeinterface/core/node_pipeline.py
@@ -152,29 +152,29 @@ class SpikeRetriever(PeakSource):
       * compute_spike_amplitudes()
       * compute_principal_components()
 
+    sorting : BaseSorting
+        The sorting object.
     recording : BaseRecording
         The recording object.
-    sorting: BaseSorting
-        The sorting object.
-    channel_from_template: bool, default: True
+    channel_from_template : bool, default: True
         If True, then the channel_index is inferred from the template and `extremum_channel_inds` must be provided.
         If False, the max channel is computed for each spike given a radius around the template max channel.
-    extremum_channel_inds: dict of int | None, default: None
+    extremum_channel_inds : dict of int | None, default: None
         The extremum channel index dict given from template.
-    radius_um: float, default: 50
+    radius_um : float, default: 50
         The radius to find the real max channel.
         Used only when channel_from_template=False
-    peak_sign: "neg" | "pos", default: "neg"
+    peak_sign : "neg" | "pos", default: "neg"
         Peak sign to find the max channel.
         Used only when channel_from_template=False
-    include_spikes_in_margin: bool, default False
+    include_spikes_in_margin : bool, default False
         If not None then spikes in margin are added and an extra filed in dtype is added
     """
 
     def __init__(
         self,
-        recording,
         sorting,
+        recording,
         channel_from_template=True,
         extremum_channel_inds=None,
         radius_um=50,
diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index e439ddf1ed..d790308b76 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -127,7 +127,7 @@ def create_sorting_analyzer(
             recording.channel_ids, sparsity.channel_ids
         ), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond"
     elif sparse:
-        sparsity = estimate_sparsity(recording, sorting, **sparsity_kwargs)
+        sparsity = estimate_sparsity(sorting, recording, **sparsity_kwargs)
     else:
         sparsity = None
 
diff --git a/src/spikeinterface/core/sparsity.py b/src/spikeinterface/core/sparsity.py
index 3dc8c050db..a38562ea2c 100644
--- a/src/spikeinterface/core/sparsity.py
+++ b/src/spikeinterface/core/sparsity.py
@@ -539,8 +539,8 @@ def compute_sparsity(
 
 
 def estimate_sparsity(
-    recording: BaseRecording,
     sorting: BaseSorting,
+    recording: BaseRecording,
     num_spikes_for_sparsity: int = 100,
     ms_before: float = 1.0,
     ms_after: float = 2.5,
@@ -563,10 +563,11 @@
 
     Parameters
     ----------
-    recording : BaseRecording
-        The recording
     sorting : BaseSorting
         The sorting
+    recording : BaseRecording
+        The recording
+
     num_spikes_for_sparsity : int, default: 100
         How many spikes per units to compute the sparsity
     ms_before : float, default: 1.0
diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py
index 03acc9fed1..8d788acbad 100644
--- a/src/spikeinterface/core/tests/test_node_pipeline.py
+++ b/src/spikeinterface/core/tests/test_node_pipeline.py
@@ -87,12 +87,12 @@ def test_run_node_pipeline(cache_folder_creation):
     peak_retriever = PeakRetriever(recording, peaks)
     # channel index is from template
     spike_retriever_T = SpikeRetriever(
-        recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
+        sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
     )
     # channel index is per spike
     spike_retriever_S = SpikeRetriever(
-        recording,
         sorting,
+        recording,
         channel_from_template=False,
         extremum_channel_inds=extremum_channel_inds,
         radius_um=50,
diff --git a/src/spikeinterface/core/tests/test_sparsity.py b/src/spikeinterface/core/tests/test_sparsity.py
index 98d033d8ea..a192d90502 100644
--- a/src/spikeinterface/core/tests/test_sparsity.py
+++ b/src/spikeinterface/core/tests/test_sparsity.py
@@ -166,8 +166,8 @@ def test_estimate_sparsity():
 
     # small radius should give a very sparse = one channel per unit
     sparsity = estimate_sparsity(
-        recording,
         sorting,
+        recording,
         num_spikes_for_sparsity=50,
         ms_before=1.0,
         ms_after=2.0,
@@ -182,8 +182,8 @@
 
     # best_channel : the mask should exactly 3 channels per units
     sparsity = estimate_sparsity(
-        recording,
         sorting,
+        recording,
         num_spikes_for_sparsity=50,
         ms_before=1.0,
         ms_after=2.0,
diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py
index befc49d034..98380e955f 100644
--- a/src/spikeinterface/core/waveform_tools.py
+++ b/src/spikeinterface/core/waveform_tools.py
@@ -679,16 +679,16 @@ def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None
     return waveforms_by_units
 
 
-def has_exceeding_spikes(recording, sorting) -> bool:
+def has_exceeding_spikes(sorting, recording) -> bool:
     """
     Check if the sorting objects has spikes exceeding the recording number of samples, for all segments
 
     Parameters
     ----------
-    recording : BaseRecording
-        The recording object
     sorting : BaseSorting
         The sorting object
+    recording : BaseRecording
+        The recording object
 
     Returns
     -------
diff --git a/src/spikeinterface/curation/remove_excess_spikes.py b/src/spikeinterface/curation/remove_excess_spikes.py
index 0ae7a59fc6..d1d6b7f3cb 100644
--- a/src/spikeinterface/curation/remove_excess_spikes.py
+++ b/src/spikeinterface/curation/remove_excess_spikes.py
@@ -102,7 +102,7 @@ def remove_excess_spikes(sorting, recording):
     sorting_without_excess_spikes : Sorting
         The sorting without any excess spikes.
     """
-    if has_exceeding_spikes(recording=recording, sorting=sorting):
+    if has_exceeding_spikes(sorting=sorting, recording=recording):
         return RemoveExcessSpikesSorting(sorting=sorting, recording=recording)
     else:
         return sorting
diff --git a/src/spikeinterface/curation/tests/test_remove_excess_spikes.py b/src/spikeinterface/curation/tests/test_remove_excess_spikes.py
index 69edbaba4c..141cc4c34e 100644
--- a/src/spikeinterface/curation/tests/test_remove_excess_spikes.py
+++ b/src/spikeinterface/curation/tests/test_remove_excess_spikes.py
@@ -39,10 +39,10 @@ def test_remove_excess_spikes():
         labels.append(labels_segment)
 
     sorting = NumpySorting.from_times_labels(times, labels, sampling_frequency=sampling_frequency)
-    assert has_exceeding_spikes(recording, sorting)
+    assert has_exceeding_spikes(sorting, recording)
 
     sorting_corrected = remove_excess_spikes(sorting, recording)
-    assert not has_exceeding_spikes(recording, sorting_corrected)
+    assert not has_exceeding_spikes(sorting_corrected, recording)
 
     for u in sorting.unit_ids:
         for segment_index in range(sorting.get_num_segments()):
diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py
index 2e544d086b..8ff9cc5666 100644
--- a/src/spikeinterface/postprocessing/amplitude_scalings.py
+++ b/src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -170,8 +170,8 @@ def _get_pipeline_nodes(self):
             sparsity_mask = sparsity.mask
 
         spike_retriever_node = SpikeRetriever(
-            recording,
             sorting,
+            recording,
             channel_from_template=True,
             extremum_channel_inds=extremum_channels_indices,
             include_spikes_in_margin=True,
diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py
index aebfd1fd78..72cbcb651f 100644
--- a/src/spikeinterface/postprocessing/spike_amplitudes.py
+++ b/src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -95,7 +95,7 @@ def _get_pipeline_nodes(self):
         peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_analyzer, peak_sign=peak_sign)
 
         spike_retriever_node = SpikeRetriever(
-            recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channels_indices
+            sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channels_indices
         )
         spike_amplitudes_node = SpikeAmplitudeNode(
             recording,
diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py
index 52a91342b6..23301292e5 100644
--- a/src/spikeinterface/postprocessing/spike_locations.py
+++ b/src/spikeinterface/postprocessing/spike_locations.py
@@ -103,8 +103,8 @@ def _get_pipeline_nodes(self):
         )
 
         retriever = SpikeRetriever(
-            recording,
             sorting,
+            recording,
             channel_from_template=True,
             extremum_channel_inds=extremum_channels_indices,
         )
diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py
index bb2f5aaafd..52dbaf23d4 100644
--- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py
+++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py
@@ -73,7 +73,7 @@ class instance is used for each. In this case, we have to set
         self.__class__.recording, self.__class__.sorting = get_dataset()
 
         self.__class__.sparsity = estimate_sparsity(
-            self.__class__.recording, self.__class__.sorting, method="radius", radius_um=20
+            self.__class__.sorting, self.__class__.recording, method="radius", radius_um=20
         )
 
        self.__class__.cache_folder = create_cache_folder
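
For quick reference, a minimal usage sketch of the new sorting-first argument order at the three public entry points touched by this diff. This sketch is not part of the patch: `recording`, `sorting`, and `extremum_channel_inds` are assumed to already exist (a BaseRecording, a BaseSorting, and a precomputed {unit_id: channel_index} dict), and the keyword values shown are illustrative only.

# Usage sketch for the reordered signatures (assumes `recording`, `sorting`,
# and `extremum_channel_inds` are already defined; not part of the patch).
from spikeinterface.core.sparsity import estimate_sparsity
from spikeinterface.core.waveform_tools import has_exceeding_spikes
from spikeinterface.core.node_pipeline import SpikeRetriever

# estimate_sparsity: sorting now comes before recording
sparsity = estimate_sparsity(sorting, recording, num_spikes_for_sparsity=50, method="radius", radius_um=50)

# has_exceeding_spikes: same sorting-first order
assert not has_exceeding_spikes(sorting, recording)

# SpikeRetriever: sorting-first order, channel index taken from the template
spike_retriever = SpikeRetriever(
    sorting,
    recording,
    channel_from_template=True,
    extremum_channel_inds=extremum_channel_inds,
)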