From d9f53d04f99f78ecc4fbdab343e7d8e1161faf07 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 23 Sep 2024 09:40:52 +0200
Subject: [PATCH] Fix compute analyzer pipeline with tmp recording

---
 src/spikeinterface/core/sortinganalyzer.py  |  6 ++++--
 .../postprocessing/principal_component.py   | 17 ++++++++---------
 2 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 177188f21d..0b4d959604 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -230,7 +230,7 @@ def __repr__(self) -> str:
         txt = f"{clsname}: {nchan} channels - {nunits} units - {nseg} segments - {self.format}"
         if self.is_sparse():
             txt += " - sparse"
-        if self.has_recording():
+        if self.has_recording() or self.has_temporary_recording():
             txt += " - has recording"
         ext_txt = f"Loaded {len(self.extensions)} extensions: " + ", ".join(self.extensions.keys())
         txt += "\n" + ext_txt
@@ -1355,7 +1355,9 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job_kwargs):
 
         for extension_name, extension_params in extensions_with_pipeline.items():
             extension_class = get_extension_class(extension_name)
-            assert self.has_recording(), f"Extension {extension_name} need the recording"
+            assert (
+                self.has_recording() or self.has_temporary_recording()
+            ), f"Extension {extension_name} needs the recording"
 
             for variable_name in extension_class.nodepipeline_variables:
                 result_routage.append((extension_name, variable_name))
diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py
index f1f89403c7..1871c11b85 100644
--- a/src/spikeinterface/postprocessing/principal_component.py
+++ b/src/spikeinterface/postprocessing/principal_component.py
@@ -359,12 +359,12 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):
         job_kwargs = fix_job_kwargs(job_kwargs)
 
         p = self.params
-        we = self.sorting_analyzer
-        sorting = we.sorting
+        sorting_analyzer = self.sorting_analyzer
+        sorting = sorting_analyzer.sorting
         assert (
-            we.has_recording()
-        ), "To compute PCA projections for all spikes, the waveform extractor needs the recording"
-        recording = we.recording
+            sorting_analyzer.has_recording() or sorting_analyzer.has_temporary_recording()
+        ), "To compute PCA projections for all spikes, the sorting analyzer needs the recording"
+        recording = sorting_analyzer.recording
 
         # assert sorting.get_num_segments() == 1
         assert p["mode"] in ("by_channel_local", "by_channel_global")
@@ -374,8 +374,9 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):
 
         sparsity = self.sorting_analyzer.sparsity
         if sparsity is None:
-            sparse_channels_indices = {unit_id: np.arange(we.get_num_channels()) for unit_id in we.unit_ids}
-            max_channels_per_template = we.get_num_channels()
+            num_channels = recording.get_num_channels()
+            sparse_channels_indices = {unit_id: np.arange(num_channels) for unit_id in sorting_analyzer.unit_ids}
+            max_channels_per_template = num_channels
         else:
             sparse_channels_indices = sparsity.unit_id_to_channel_indices
             max_channels_per_template = max([chan_inds.size for chan_inds in sparse_channels_indices.values()])
@@ -449,9 +450,7 @@ def _fit_by_channel_local(self, n_jobs, progress_bar):
         return pca_models
 
     def _fit_by_channel_global(self, progress_bar):
-        # we = self.sorting_analyzer
         p = self.params
-        # unit_ids = we.unit_ids
         unit_ids = self.sorting_analyzer.unit_ids
 
         # there is one unique PCA accross channels
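
Usage sketch (illustrative, not part of the commit): the snippet below walks the
path this patch fixes, computing node-pipeline extensions on an analyzer whose
recording is attached only temporarily. It assumes the public SpikeInterface
API (generate_ground_truth_recording, create_sorting_analyzer,
load_sorting_analyzer, set_temporary_recording); the folder name and parameters
are made up for the example.

    from spikeinterface.core import (
        create_sorting_analyzer,
        generate_ground_truth_recording,
        load_sorting_analyzer,
    )

    recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5)

    # Persist the analyzer; when it is reloaded later the recording may be
    # unavailable (e.g. the raw file was moved), so has_recording() is False.
    create_sorting_analyzer(sorting, recording, format="binary_folder", folder="analyzer_folder")

    analyzer = load_sorting_analyzer("analyzer_folder")
    if not analyzer.has_recording():
        # Attach the recording in memory only; has_temporary_recording() -> True
        analyzer.set_temporary_recording(recording)

    # With this patch, compute_several_extensions() accepts the temporary
    # recording for extensions that run through the node pipeline.
    analyzer.compute(["random_spikes", "waveforms", "templates", "spike_amplitudes"])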
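
The principal_component.py hunks apply the same relaxation to
run_for_all_spikes(), which reads the raw traces to project every spike rather
than only the sampled waveforms. A possible continuation of the sketch above
(output file name and job kwargs are illustrative):

    # Fit the PCA models first (uses the sampled waveforms), then project
    # all spikes; with this patch a temporary recording is sufficient.
    analyzer.compute("principal_components", n_components=3, mode="by_channel_local")
    pca_ext = analyzer.get_extension("principal_components")
    pca_ext.run_for_all_spikes(file_path="all_pc_projections.npy", n_jobs=2, progress_bar=True)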