Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reordering recording, sorting args #3094

Merged
merged 8 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def register_recording(self, recording, check_spike_frames=True):
self.get_num_segments() == recording.get_num_segments()
), "The recording has a different number of segments than the sorting!"
if check_spike_frames:
if has_exceeding_spikes(recording, self):
if has_exceeding_spikes(self, recording):
warnings.warn(
"Some spikes exceed the recording's duration! "
"Removing these excess spikes with `spikeinterface.curation.remove_excess_spikes()` "
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/frameslicesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike
assert (
start_frame <= parent_n_samples
), "`start_frame` should be smaller than the sortings' total number of samples."
if check_spike_frames and has_exceeding_spikes(parent_sorting._recording, parent_sorting):
if check_spike_frames and has_exceeding_spikes(parent_sorting, parent_sorting._recording):
raise ValueError(
"The sorting object has spikes whose times go beyond the recording duration."
"This could indicate a bug in the sorter. "
Expand Down
16 changes: 8 additions & 8 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,29 +152,29 @@ class SpikeRetriever(PeakSource):
* compute_spike_amplitudes()
* compute_principal_components()
sorting : BaseSorting
The sorting object.
recording : BaseRecording
The recording object.
sorting: BaseSorting
The sorting object.
channel_from_template: bool, default: True
channel_from_template : bool, default: True
If True, then the channel_index is inferred from the template and `extremum_channel_inds` must be provided.
If False, the max channel is computed for each spike given a radius around the template max channel.
extremum_channel_inds: dict of int | None, default: None
extremum_channel_inds : dict of int | None, default: None
The extremum channel index dict given from template.
radius_um: float, default: 50
radius_um : float, default: 50
The radius to find the real max channel.
Used only when channel_from_template=False
peak_sign: "neg" | "pos", default: "neg"
peak_sign : "neg" | "pos", default: "neg"
Peak sign to find the max channel.
Used only when channel_from_template=False
include_spikes_in_margin: bool, default False
include_spikes_in_margin : bool, default False
If not None then spikes in margin are added and an extra filed in dtype is added
"""

def __init__(
self,
recording,
sorting,
recording,
channel_from_template=True,
extremum_channel_inds=None,
radius_um=50,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def create_sorting_analyzer(
recording.channel_ids, sparsity.channel_ids
), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond"
elif sparse:
sparsity = estimate_sparsity(recording, sorting, **sparsity_kwargs)
sparsity = estimate_sparsity(sorting, recording, **sparsity_kwargs)
else:
sparsity = None

Expand Down
7 changes: 4 additions & 3 deletions src/spikeinterface/core/sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,8 +539,8 @@ def compute_sparsity(


def estimate_sparsity(
recording: BaseRecording,
sorting: BaseSorting,
recording: BaseRecording,
num_spikes_for_sparsity: int = 100,
ms_before: float = 1.0,
ms_after: float = 2.5,
Expand All @@ -563,10 +563,11 @@ def estimate_sparsity(
Parameters
----------
recording : BaseRecording
The recording
sorting : BaseSorting
The sorting
recording : BaseRecording
The recording
num_spikes_for_sparsity : int, default: 100
How many spikes per units to compute the sparsity
ms_before : float, default: 1.0
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/tests/test_node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@ def test_run_node_pipeline(cache_folder_creation):
peak_retriever = PeakRetriever(recording, peaks)
# channel index is from template
spike_retriever_T = SpikeRetriever(
recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
)
# channel index is per spike
spike_retriever_S = SpikeRetriever(
recording,
sorting,
recording,
channel_from_template=False,
extremum_channel_inds=extremum_channel_inds,
radius_um=50,
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/tests/test_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ def test_estimate_sparsity():

# small radius should give a very sparse = one channel per unit
sparsity = estimate_sparsity(
recording,
sorting,
recording,
num_spikes_for_sparsity=50,
ms_before=1.0,
ms_after=2.0,
Expand All @@ -182,8 +182,8 @@ def test_estimate_sparsity():

# best_channel : the mask should exactly 3 channels per units
sparsity = estimate_sparsity(
recording,
sorting,
recording,
num_spikes_for_sparsity=50,
ms_before=1.0,
ms_after=2.0,
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/core/waveform_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,16 +679,16 @@ def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None
return waveforms_by_units


def has_exceeding_spikes(recording, sorting) -> bool:
def has_exceeding_spikes(sorting, recording) -> bool:
"""
Check if the sorting objects has spikes exceeding the recording number of samples, for all segments
Parameters
----------
recording : BaseRecording
The recording object
sorting : BaseSorting
The sorting object
recording : BaseRecording
The recording object
Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/curation/remove_excess_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def remove_excess_spikes(sorting, recording):
sorting_without_excess_spikes : Sorting
The sorting without any excess spikes.
"""
if has_exceeding_spikes(recording=recording, sorting=sorting):
if has_exceeding_spikes(sorting=sorting, recording=recording):
return RemoveExcessSpikesSorting(sorting=sorting, recording=recording)
else:
return sorting
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def test_remove_excess_spikes():
labels.append(labels_segment)

sorting = NumpySorting.from_times_labels(times, labels, sampling_frequency=sampling_frequency)
assert has_exceeding_spikes(recording, sorting)
assert has_exceeding_spikes(sorting, recording)

sorting_corrected = remove_excess_spikes(sorting, recording)
assert not has_exceeding_spikes(recording, sorting_corrected)
assert not has_exceeding_spikes(sorting_corrected, recording)

for u in sorting.unit_ids:
for segment_index in range(sorting.get_num_segments()):
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/amplitude_scalings.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ def _get_pipeline_nodes(self):
sparsity_mask = sparsity.mask

spike_retriever_node = SpikeRetriever(
recording,
sorting,
recording,
channel_from_template=True,
extremum_channel_inds=extremum_channels_indices,
include_spikes_in_margin=True,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/spike_amplitudes.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _get_pipeline_nodes(self):
peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_analyzer, peak_sign=peak_sign)

spike_retriever_node = SpikeRetriever(
recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channels_indices
sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channels_indices
)
spike_amplitudes_node = SpikeAmplitudeNode(
recording,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/spike_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ def _get_pipeline_nodes(self):
)

retriever = SpikeRetriever(
recording,
sorting,
recording,
channel_from_template=True,
extremum_channel_inds=extremum_channels_indices,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class instance is used for each. In this case, we have to set
self.__class__.recording, self.__class__.sorting = get_dataset()

self.__class__.sparsity = estimate_sparsity(
self.__class__.recording, self.__class__.sorting, method="radius", radius_um=20
self.__class__.sorting, self.__class__.recording, method="radius", radius_um=20
)
self.__class__.cache_folder = create_cache_folder

Expand Down
Loading