Fix compute analyzer pipeline with tmp recording
alejoe91 committed Sep 23, 2024
1 parent f576da3 commit d9f53d0
Showing 2 changed files with 12 additions and 11 deletions.
6 changes: 4 additions & 2 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -230,7 +230,7 @@ def __repr__(self) -> str:
         txt = f"{clsname}: {nchan} channels - {nunits} units - {nseg} segments - {self.format}"
         if self.is_sparse():
             txt += " - sparse"
-        if self.has_recording():
+        if self.has_recording() or self.has_temporary_recording():
             txt += " - has recording"
         ext_txt = f"Loaded {len(self.extensions)} extensions: " + ", ".join(self.extensions.keys())
         txt += "\n" + ext_txt
@@ -1355,7 +1355,9 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job_kwargs):

         for extension_name, extension_params in extensions_with_pipeline.items():
             extension_class = get_extension_class(extension_name)
-            assert self.has_recording(), f"Extension {extension_name} need the recording"
+            assert (
+                self.has_recording() or self.has_temporary_recording()
+            ), f"Extension {extension_name} need the recording"

             for variable_name in extension_class.nodepipeline_variables:
                 result_routage.append((extension_name, variable_name))
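For context, the workflow these relaxed checks enable looks roughly like the sketch below: attaching a recording to an analyzer for the current session only. This is a minimal sketch assuming spikeinterface's public set_temporary_recording / has_temporary_recording API; the extension list is illustrative, and the `analyzer._recording = None` line pokes a private attribute purely to stand in for an analyzer reloaded from disk without its raw file.

import spikeinterface.full as si

# Toy data: a short ground-truth recording/sorting pair and an in-memory analyzer.
recording, sorting = si.generate_ground_truth_recording(durations=[10.0], num_units=5, seed=0)
analyzer = si.create_sorting_analyzer(sorting, recording, format="memory")

# Stand-in for an analyzer reloaded without its recording (e.g. the raw file moved):
# detach the recording so only a temporary one is available.
analyzer._recording = None

# Attach the recording for this session only; it is not persisted with the analyzer.
analyzer.set_temporary_recording(recording)
assert analyzer.has_temporary_recording()

# Node-pipeline extensions now pass the recording assert patched above,
# and the repr reports "- has recording".
analyzer.compute(["random_spikes", "waveforms", "templates", "spike_amplitudes"])
print(analyzer)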
17 changes: 8 additions & 9 deletions src/spikeinterface/postprocessing/principal_component.py
@@ -359,12 +359,12 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):

         job_kwargs = fix_job_kwargs(job_kwargs)
         p = self.params
-        we = self.sorting_analyzer
-        sorting = we.sorting
+        sorting_analyzer = self.sorting_analyzer
+        sorting = sorting_analyzer.sorting
         assert (
-            we.has_recording()
-        ), "To compute PCA projections for all spikes, the waveform extractor needs the recording"
-        recording = we.recording
+            sorting_analyzer.has_recording() or sorting_analyzer.has_temporary_recording()
+        ), "To compute PCA projections for all spikes, the sorting analyzer needs the recording"
+        recording = sorting_analyzer.recording

         # assert sorting.get_num_segments() == 1
         assert p["mode"] in ("by_channel_local", "by_channel_global")
@@ -374,8 +374,9 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):

         sparsity = self.sorting_analyzer.sparsity
         if sparsity is None:
-            sparse_channels_indices = {unit_id: np.arange(we.get_num_channels()) for unit_id in we.unit_ids}
-            max_channels_per_template = we.get_num_channels()
+            num_channels = recording.get_num_channels()
+            sparse_channels_indices = {unit_id: np.arange(num_channels) for unit_id in sorting_analyzer.unit_ids}
+            max_channels_per_template = num_channels
         else:
             sparse_channels_indices = sparsity.unit_id_to_channel_indices
             max_channels_per_template = max([chan_inds.size for chan_inds in sparse_channels_indices.values()])
@@ -449,9 +450,7 @@ def _fit_by_channel_local(self, n_jobs, progress_bar):
         return pca_models

     def _fit_by_channel_global(self, progress_bar):
-        # we = self.sorting_analyzer
         p = self.params
-        # unit_ids = we.unit_ids
         unit_ids = self.sorting_analyzer.unit_ids

         # there is one unique PCA accross channels
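And a sketch of the code path touched in this file: projecting every spike with run_for_all_spikes on an analyzer that only has a temporary recording, continuing the toy setup above. The output file name and job kwargs are illustrative; per the code above, the result is one (n_components, max_channels_per_template) block per spike saved to a .npy file.

import numpy as np

# PCA depends on random_spikes and waveforms, computed in the earlier sketch.
analyzer.compute("principal_components")
pca_ext = analyzer.get_extension("principal_components")

# Before this commit this call failed its assert with only a temporary
# recording attached; now it runs and writes the projections to disk.
pca_ext.run_for_all_spikes(file_path="all_pcs.npy", n_jobs=1)

all_pcs = np.load("all_pcs.npy")
print(all_pcs.shape)  # (num_spikes, n_components, max_channels_per_template)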
