
Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jun 23, 2024
1 parent 898bf8a commit 0838ff3
Showing 1 changed file with 12 additions and 11 deletions.
src/spikeinterface/core/analyzer_extension_core.py (12 additions, 11 deletions)
@@ -156,7 +156,7 @@ def _get_waveforms(self, sorting_analyzer=None, unit_ids=None, verbose=False, **

        recording = self.sorting_analyzer.recording
        sorting = self.sorting_analyzer.sorting

        some_spikes = sorting_analyzer.get_extension("random_spikes").get_random_spikes()

        if unit_ids is None:
@@ -246,27 +246,29 @@ def _merge_extension_data(self, units_to_merge, new_unit_ids, new_sorting_analyz
        if new_sorting_analyzer.sparsity is not None:
            sparsity_mask = new_sorting_analyzer.sparsity.mask
            num_chans = int(max(np.sum(sparsity_mask, axis=1)))
-           old_num_chans = self.data['waveforms'].shape[2]
+           old_num_chans = self.data["waveforms"].shape[2]
            if num_chans == old_num_chans:
-               new_data['waveforms'] = self.data['waveforms']
+               new_data["waveforms"] = self.data["waveforms"]
            else:
-               num_waveforms = len(self.data['waveforms'])
-               num_samples = self.data['waveforms'].shape[1]
+               num_waveforms = len(self.data["waveforms"])
+               num_samples = self.data["waveforms"].shape[1]

                some_spikes = new_sorting_analyzer.get_extension("random_spikes").get_random_spikes()
-               new_data['waveforms'] = np.zeros((num_waveforms, num_samples, num_chans), dtype=self.data['waveforms'].dtype)
+               new_data["waveforms"] = np.zeros(
+                   (num_waveforms, num_samples, num_chans), dtype=self.data["waveforms"].dtype
+               )
                keep_unit_indices = np.flatnonzero(~np.isin(new_sorting_analyzer.unit_ids, new_unit_ids))
                keep_spike_mask = np.isin(some_spikes["unit_index"], keep_unit_indices)
-               new_data['waveforms'][keep_spike_mask, :, :old_num_chans] = self.data['waveforms'][keep_spike_mask]
+               new_data["waveforms"][keep_spike_mask, :, :old_num_chans] = self.data["waveforms"][keep_spike_mask]

-               # We only recompute waveforms for new units that might have a new sparsity mask. Could be 
+               # We only recompute waveforms for new units that might have a new sparsity mask. Could be
                # slightly optimized by checking exactly which merged units have a different mask
                updated_unit_indices = np.flatnonzero(np.isin(new_sorting_analyzer.unit_ids, new_unit_ids))
                updated_spike_mask = np.isin(some_spikes["unit_index"], updated_unit_indices)
                new_waveforms = self._get_waveforms(new_sorting_analyzer, new_unit_ids, verbose, **job_kwargs)
-               new_data['waveforms'][updated_spike_mask] = new_waveforms
+               new_data["waveforms"][updated_spike_mask] = new_waveforms
        else:
-           new_data['waveforms'] = self.data['waveforms']
+           new_data["waveforms"] = self.data["waveforms"]

        return new_data

@@ -495,7 +497,6 @@ def _merge_extension_data(self, units_to_merge, new_unit_ids, new_sorting_analyz

        return new_data

-
    def _get_data(self, operator="average", percentile=None, outputs="numpy"):
        if operator != "percentile":
            key = operator
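
The merge hunk above widens the stored waveforms when the merged units' sparsity mask needs more channels: a larger zero-filled array is allocated, the spikes of untouched units are copied into the leading channels, and only the spikes of the merged units are recomputed. The following is a minimal standalone numpy sketch of that padding step, with made-up shapes, seeds, and masks purely for illustration; it is not the spikeinterface API.

import numpy as np

# Toy shapes standing in for the quantities in the hunk above (purely illustrative).
num_waveforms, num_samples, old_num_chans = 5, 20, 3
num_chans = 4  # the new sparsity mask needs one extra channel

rng = np.random.default_rng(0)
old_waveforms = rng.standard_normal((num_waveforms, num_samples, old_num_chans)).astype("float32")

# Pretend spikes 0-2 belong to untouched units and spikes 3-4 to the newly merged unit.
keep_spike_mask = np.array([True, True, True, False, False])
updated_spike_mask = ~keep_spike_mask

# Allocate the wider array and copy kept waveforms into the leading channels;
# the extra channel stays zero, like the padding step in the diff.
new_waveforms = np.zeros((num_waveforms, num_samples, num_chans), dtype=old_waveforms.dtype)
new_waveforms[keep_spike_mask, :, :old_num_chans] = old_waveforms[keep_spike_mask]

# In the real extension the merged units' waveforms are recomputed from the recording
# (self._get_waveforms); here we just fill dummy values of the new width.
new_waveforms[updated_spike_mask] = rng.standard_normal((int(updated_spike_mask.sum()), num_samples, num_chans))

print(new_waveforms.shape)  # (5, 20, 4)

As the inline comment in the diff notes, recomputation is restricted to spikes of the new unit ids; untouched units keep their original waveforms, merely re-embedded in the wider channel axis.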
