Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 19, 2024
1 parent 501df84 commit f05b96c
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,9 @@ def _select_extension_data(self, unit_ids):
new_data[key] = arr[keep_unit_indices, :, :]

return new_data

def _merge_extension_data(self, merges, former_unit_ids):

new_unit_ids = self.sorting_analyzer._get_ids_after_merging(merges)
new_data = dict()
for key, arr in self.data.items():
Expand Down Expand Up @@ -608,7 +608,7 @@ def _set_params(self, num_chunks_per_segment=20, chunk_size=10000, seed=None):
def _select_extension_data(self, unit_ids):
# this do not depend on units
return self.data

def _merge_extension_data(self, merges):
# this do not depend on units
return self.data
Expand Down
4 changes: 3 additions & 1 deletion src/spikeinterface/core/sorting_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def random_spikes_selection(

return random_spikes_indices


def apply_merges_to_sorting(sorting, merges, censor_ms=None):
"""
Function to apply a resolved representation of the merges to a sorting object. If censor_ms is not None,
Expand Down Expand Up @@ -245,6 +246,7 @@ def apply_merges_to_sorting(sorting, merges, censor_ms=None):
)

from spikeinterface.core import NumpySorting

times_list = []
labels_list = []
for segment_index in range(sorting.get_num_segments()):
Expand All @@ -259,4 +261,4 @@ def apply_merges_to_sorting(sorting, merges, censor_ms=None):
labels_list += [sorting.unit_ids[labels]]

sorting = NumpySorting.from_times_labels(times_list, labels_list, sorting.sampling_frequency)
return sorting
return sorting
5 changes: 4 additions & 1 deletion src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,9 @@ def _get_ids_after_merging(self, merges):
new_unit_ids.discard(id)
return list(new_unit_ids)

def _save_or_select_or_merge(self, format="binary_folder", folder=None, unit_ids=None, merges=None) -> "SortingAnalyzer":
def _save_or_select_or_merge(
self, format="binary_folder", folder=None, unit_ids=None, merges=None
) -> "SortingAnalyzer":
"""
Internal used by both save_as(), copy() and select_units() which are more or less the same.
"""
Expand Down Expand Up @@ -655,6 +657,7 @@ def _save_or_select_or_merge(self, format="binary_folder", folder=None, unit_ids
sorting_provenance = sorting_provenance.select_units(unit_ids)
elif merges is not None:
from spikeinterface.core.sorting_tools import apply_merges_to_sorting

sorting_provenance = apply_merges_to_sorting(sorting_provenance, merges)

if format == "memory":
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/amplitude_scalings.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _select_extension_data(self, unit_ids):
if self.params["handle_collisions"]:
new_data["collision_mask"] = self.data["collision_mask"][keep_spike_mask]
return new_data

def _merge_extension_data(self, merges):
# keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids))

Expand Down

0 comments on commit f05b96c

Please sign in to comment.