Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 23, 2024
1 parent 53a179f commit d3b11a6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 14 deletions.
20 changes: 11 additions & 9 deletions src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _get_waveforms(self, sorting_analyzer=None, unit_ids=None, verbose=False, **

recording = self.sorting_analyzer.recording
sorting = self.sorting_analyzer.sorting

if unit_ids is None:
unit_ids = sorting.unit_ids

Expand Down Expand Up @@ -238,22 +238,24 @@ def _select_extension_data(self, unit_ids):

def _merge_extension_data(self, units_to_merge, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
new_data = dict()

if new_sorting_analyzer.sparsity is not None:
sparsity_mask = new_sorting_analyzer.sparsity.mask
num_chans = int(max(np.sum(sparsity_mask, axis=1)))
old_num_chans = self.data['waveforms'].shape[2]
old_num_chans = self.data["waveforms"].shape[2]
if num_chans > old_num_chans:
new_data['waveforms'] = self.data['waveforms']
new_data["waveforms"] = self.data["waveforms"]
else:
num_waveforms = len(self.data['waveforms'])
num_samples = self.data['waveforms'][1]
new_data['waveforms'] = np.zeros((num_waveforms, num_samples, num_chans), dtype=self.data['waveforms'].dtype)
num_waveforms = len(self.data["waveforms"])
num_samples = self.data["waveforms"][1]
new_data["waveforms"] = np.zeros(
(num_waveforms, num_samples, num_chans), dtype=self.data["waveforms"].dtype
)
keep_unit_indices = np.flatnonzero(~np.isin(new_sorting_analyzer, new_unit_ids))
new_data['waveforms'][keep_unit_indices, :, :old_num_chans] = self.data['waveforms'][keep_unit_indices]
new_data["waveforms"][keep_unit_indices, :, :old_num_chans] = self.data["waveforms"][keep_unit_indices]
updated_unit_indices = np.flatnonzero(np.isin(new_sorting_analyzer, new_unit_ids))
new_waveforms = self._get_waveforms(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs)
new_data['waveforms'][updated_unit_indices] = new_waveforms
new_data["waveforms"][updated_unit_indices] = new_waveforms
return new_data

def get_waveforms_one_unit(
Expand Down
19 changes: 14 additions & 5 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,8 +725,8 @@ def _save_or_select_or_merge(
new_sorting_analyzer,
units_to_merge=units_to_merge,
new_unit_ids=unit_ids,
verbose=verbose,
**job_kwargs
verbose=verbose,
**job_kwargs,
)
else:
new_sorting_analyzer.extensions[extension_name] = extension.copy(
Expand Down Expand Up @@ -777,7 +777,9 @@ def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyz
# TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!!
return self._save_or_select_or_merge(format=format, folder=folder, unit_ids=unit_ids)

def merge_units(self, units_to_merge, new_unit_ids=None, format="memory", folder=None, verbose=False, **job_kwargs) -> "SortingAnalyzer":
def merge_units(
self, units_to_merge, new_unit_ids=None, format="memory", folder=None, verbose=False, **job_kwargs
) -> "SortingAnalyzer":
"""
This method is equivalent to `save_as()`but with a list of merges that have to be achieved.
Merges units by creating a new sorting analyzer object in a new folder with appropriate merges
Expand Down Expand Up @@ -816,7 +818,12 @@ def merge_units(self, units_to_merge, new_unit_ids=None, format="memory", folder
new_unit_ids = [i[0] for i in units_to_merge]

return self._save_or_select_or_merge(
format=format, folder=folder, units_to_merge=units_to_merge, unit_ids=new_unit_ids, verbose=verbose, **job_kwargs
format=format,
folder=folder,
units_to_merge=units_to_merge,
unit_ids=new_unit_ids,
verbose=verbose,
**job_kwargs,
)

def copy(self):
Expand Down Expand Up @@ -1752,7 +1759,9 @@ def merge(self, new_sorting_analyzer, units_to_merge=None, new_unit_ids=None, ve
if units_to_merge is None:
new_extension.data = self.data
else:
new_extension.data = self._merge_extension_data(units_to_merge, new_unit_ids, new_sorting_analyzer, verbose=verbose, **job_kwargs)
new_extension.data = self._merge_extension_data(
units_to_merge, new_unit_ids, new_sorting_analyzer, verbose=verbose, **job_kwargs
)
new_extension.save()
return new_extension

Expand Down

0 comments on commit d3b11a6

Please sign in to comment.