
Commit

Merge pull request #3094 from jakeswann1/main
Reordering recording, sorting args
samuelgarcia authored Jun 28, 2024
2 parents 20cb6c8 + 63cda4b commit 6f87a9b
Showing 14 changed files with 29 additions and 28 deletions.
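
In short: every helper that previously took (recording, sorting) now takes (sorting, recording), so the sorting argument consistently comes first across the API. A minimal sketch of the new call order, assuming a SpikeInterface version that includes this commit; generate_ground_truth_recording is used here only to build a toy recording/sorting pair:

from spikeinterface.core import generate_ground_truth_recording, estimate_sparsity
from spikeinterface.core.waveform_tools import has_exceeding_spikes

# Toy data: a short simulated recording with a matching sorting.
recording, sorting = generate_ground_truth_recording(durations=[5.0], num_units=3)

# New order everywhere: sorting first, recording second.
print(has_exceeding_spikes(sorting, recording))
sparsity = estimate_sparsity(sorting, recording, num_spikes_for_sparsity=50)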
2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesorting.py
@@ -197,7 +197,7 @@ def register_recording(self, recording, check_spike_frames=True):
             self.get_num_segments() == recording.get_num_segments()
         ), "The recording has a different number of segments than the sorting!"
         if check_spike_frames:
-            if has_exceeding_spikes(recording, self):
+            if has_exceeding_spikes(self, recording):
                 warnings.warn(
                     "Some spikes exceed the recording's duration! "
                     "Removing these excess spikes with `spikeinterface.curation.remove_excess_spikes()` "
2 changes: 1 addition & 1 deletion src/spikeinterface/core/frameslicesorting.py
@@ -54,7 +54,7 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike
             assert (
                 start_frame <= parent_n_samples
             ), "`start_frame` should be smaller than the sortings' total number of samples."
-            if check_spike_frames and has_exceeding_spikes(parent_sorting._recording, parent_sorting):
+            if check_spike_frames and has_exceeding_spikes(parent_sorting, parent_sorting._recording):
                 raise ValueError(
                     "The sorting object has spikes whose times go beyond the recording duration."
                     "This could indicate a bug in the sorter. "
16 changes: 8 additions & 8 deletions src/spikeinterface/core/node_pipeline.py
@@ -152,29 +152,29 @@ class SpikeRetriever(PeakSource):
       * compute_spike_amplitudes()
       * compute_principal_components()
+    sorting : BaseSorting
+        The sorting object.
     recording : BaseRecording
         The recording object.
-    sorting: BaseSorting
-        The sorting object.
-    channel_from_template: bool, default: True
+    channel_from_template : bool, default: True
         If True, then the channel_index is inferred from the template and `extremum_channel_inds` must be provided.
         If False, the max channel is computed for each spike given a radius around the template max channel.
-    extremum_channel_inds: dict of int | None, default: None
+    extremum_channel_inds : dict of int | None, default: None
         The extremum channel index dict given from template.
-    radius_um: float, default: 50
+    radius_um : float, default: 50
         The radius to find the real max channel.
         Used only when channel_from_template=False
-    peak_sign: "neg" | "pos", default: "neg"
+    peak_sign : "neg" | "pos", default: "neg"
         Peak sign to find the max channel.
         Used only when channel_from_template=False
-    include_spikes_in_margin: bool, default False
+    include_spikes_in_margin : bool, default False
         If not None then spikes in margin are added and an extra filed in dtype is added
     """

     def __init__(
         self,
-        recording,
         sorting,
+        recording,
         channel_from_template=True,
         extremum_channel_inds=None,
         radius_um=50,
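
The two SpikeRetriever modes from the docstring above, with the new sorting-first order. A sketch mirroring the updated tests; recording, sorting and extremum_channel_inds (a unit_id -> channel index dict, e.g. from get_template_extremum_channel(..., outputs="index")) are assumed to already exist:

from spikeinterface.core.node_pipeline import SpikeRetriever

# Channel index taken from the template extremum channel.
retriever_from_template = SpikeRetriever(
    sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
)

# Channel index recomputed per spike within a radius around the template max channel.
retriever_per_spike = SpikeRetriever(
    sorting, recording,
    channel_from_template=False,
    extremum_channel_inds=extremum_channel_inds,
    radius_um=50,
    peak_sign="neg",
)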
2 changes: 1 addition & 1 deletion src/spikeinterface/core/sortinganalyzer.py
@@ -127,7 +127,7 @@ def create_sorting_analyzer(
             recording.channel_ids, sparsity.channel_ids
         ), "create_sorting_analyzer(): if external sparsity is given unit_ids must correspond"
     elif sparse:
-        sparsity = estimate_sparsity(recording, sorting, **sparsity_kwargs)
+        sparsity = estimate_sparsity(sorting, recording, **sparsity_kwargs)
     else:
         sparsity = None

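This one is an internal change only: callers of create_sorting_analyzer are unaffected, since sparsity estimation is still triggered the same way. A hedged sketch, assuming the sorting-first create_sorting_analyzer signature in this release and the toy objects from above:

from spikeinterface.core import create_sorting_analyzer

# sparse=True estimates sparsity internally; the flipped estimate_sparsity
# argument order above is invisible at this level.
analyzer = create_sorting_analyzer(sorting, recording, sparse=True)
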
7 changes: 4 additions & 3 deletions src/spikeinterface/core/sparsity.py
@@ -539,8 +539,8 @@ def compute_sparsity(


 def estimate_sparsity(
-    recording: BaseRecording,
     sorting: BaseSorting,
+    recording: BaseRecording,
     num_spikes_for_sparsity: int = 100,
     ms_before: float = 1.0,
     ms_after: float = 2.5,
@@ -563,10 +563,11 @@
     Parameters
     ----------
-    recording : BaseRecording
-        The recording
     sorting : BaseSorting
         The sorting
+    recording : BaseRecording
+        The recording
     num_spikes_for_sparsity : int, default: 100
         How many spikes per units to compute the sparsity
     ms_before : float, default: 1.0
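Usage after this change, mirroring the updated tests below; the keyword values are illustrative and recording/sorting are assumed to exist:

from spikeinterface.core import estimate_sparsity

sparsity = estimate_sparsity(
    sorting, recording,  # sorting now comes first
    num_spikes_for_sparsity=50,
    ms_before=1.0,
    ms_after=2.0,
    method="radius",
    radius_um=50,
)
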
4 changes: 2 additions & 2 deletions src/spikeinterface/core/tests/test_node_pipeline.py
@@ -87,12 +87,12 @@ def test_run_node_pipeline(cache_folder_creation):
     peak_retriever = PeakRetriever(recording, peaks)
     # channel index is from template
     spike_retriever_T = SpikeRetriever(
-        recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
+        sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channel_inds
     )
     # channel index is per spike
     spike_retriever_S = SpikeRetriever(
-        recording,
         sorting,
+        recording,
         channel_from_template=False,
         extremum_channel_inds=extremum_channel_inds,
         radius_um=50,
4 changes: 2 additions & 2 deletions src/spikeinterface/core/tests/test_sparsity.py
@@ -166,8 +166,8 @@ def test_estimate_sparsity():

     # small radius should give a very sparse = one channel per unit
     sparsity = estimate_sparsity(
-        recording,
         sorting,
+        recording,
         num_spikes_for_sparsity=50,
         ms_before=1.0,
         ms_after=2.0,
@@ -182,8 +182,8 @@

     # best_channel : the mask should exactly 3 channels per units
     sparsity = estimate_sparsity(
-        recording,
         sorting,
+        recording,
         num_spikes_for_sparsity=50,
         ms_before=1.0,
         ms_after=2.0,
6 changes: 3 additions & 3 deletions src/spikeinterface/core/waveform_tools.py
@@ -679,16 +679,16 @@ def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None
     return waveforms_by_units


-def has_exceeding_spikes(recording, sorting) -> bool:
+def has_exceeding_spikes(sorting, recording) -> bool:
     """
     Check if the sorting objects has spikes exceeding the recording number of samples, for all segments

     Parameters
     ----------
-    recording : BaseRecording
-        The recording object
     sorting : BaseSorting
         The sorting object
+    recording : BaseRecording
+        The recording object

     Returns
     -------
2 changes: 1 addition & 1 deletion src/spikeinterface/curation/remove_excess_spikes.py
@@ -102,7 +102,7 @@ def remove_excess_spikes(sorting, recording):
     sorting_without_excess_spikes : Sorting
         The sorting without any excess spikes.
     """
-    if has_exceeding_spikes(recording=recording, sorting=sorting):
+    if has_exceeding_spikes(sorting=sorting, recording=recording):
         return RemoveExcessSpikesSorting(sorting=sorting, recording=recording)
     else:
         return sorting
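
The curation entry point keeps its sorting-first signature; only the internal has_exceeding_spikes call was reordered to match. A short sketch, assuming recording and sorting objects as above:

from spikeinterface.curation import remove_excess_spikes
from spikeinterface.core.waveform_tools import has_exceeding_spikes

# Drop spikes that fall beyond the recording's duration, if any.
if has_exceeding_spikes(sorting, recording):
    sorting = remove_excess_spikes(sorting, recording)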
@@ -39,10 +39,10 @@ def test_remove_excess_spikes():
         labels.append(labels_segment)

     sorting = NumpySorting.from_times_labels(times, labels, sampling_frequency=sampling_frequency)
-    assert has_exceeding_spikes(recording, sorting)
+    assert has_exceeding_spikes(sorting, recording)

     sorting_corrected = remove_excess_spikes(sorting, recording)
-    assert not has_exceeding_spikes(recording, sorting_corrected)
+    assert not has_exceeding_spikes(sorting_corrected, recording)

     for u in sorting.unit_ids:
         for segment_index in range(sorting.get_num_segments()):
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -170,8 +170,8 @@ def _get_pipeline_nodes(self):
         sparsity_mask = sparsity.mask

         spike_retriever_node = SpikeRetriever(
-            recording,
             sorting,
+            recording,
             channel_from_template=True,
             extremum_channel_inds=extremum_channels_indices,
             include_spikes_in_margin=True,
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -95,7 +95,7 @@ def _get_pipeline_nodes(self):
         peak_shifts = get_template_extremum_channel_peak_shift(self.sorting_analyzer, peak_sign=peak_sign)

         spike_retriever_node = SpikeRetriever(
-            recording, sorting, channel_from_template=True, extremum_channel_inds=extremum_channels_indices
+            sorting, recording, channel_from_template=True, extremum_channel_inds=extremum_channels_indices
         )
         spike_amplitudes_node = SpikeAmplitudeNode(
             recording,
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/spike_locations.py
@@ -103,8 +103,8 @@ def _get_pipeline_nodes(self):
         )

         retriever = SpikeRetriever(
-            recording,
             sorting,
+            recording,
             channel_from_template=True,
             extremum_channel_inds=extremum_channels_indices,
         )
@@ -73,7 +73,7 @@ class instance is used for each. In this case, we have to set
         self.__class__.recording, self.__class__.sorting = get_dataset()

         self.__class__.sparsity = estimate_sparsity(
-            self.__class__.recording, self.__class__.sorting, method="radius", radius_um=20
+            self.__class__.sorting, self.__class__.recording, method="radius", radius_um=20
         )
         self.__class__.cache_folder = create_cache_folder

