Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix compute analyzer pipeline with tmp recording #3433

Merged
merged 3 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ def __repr__(self) -> str:
txt += " - sparse"
if self.has_recording():
txt += " - has recording"
if self.has_temporary_recording():
txt += " - has temporary recording"
ext_txt = f"Loaded {len(self.extensions)} extensions: " + ", ".join(self.extensions.keys())
txt += "\n" + ext_txt
return txt
Expand Down Expand Up @@ -350,7 +352,7 @@ def create_memory(cls, sorting, recording, sparsity, return_scaled, rec_attribut
def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes):
# used by create and save_as

assert recording is not None, "To create a SortingAnalyzer you need recording not None"
assert recording is not None, "To create a SortingAnalyzer you need to specify the recording"

folder = Path(folder)
if folder.is_dir():
Expand Down Expand Up @@ -1221,7 +1223,7 @@ def compute(self, input, save=True, extension_params=None, verbose=False, **kwar
extensions[ext_name] = ext_params
self.compute_several_extensions(extensions=extensions, save=save, verbose=verbose, **job_kwargs)
else:
raise ValueError("SortingAnalyzer.compute() need str, dict or list")
raise ValueError("SortingAnalyzer.compute() needs a str, dict or list")

def compute_one_extension(self, extension_name, save=True, verbose=False, **kwargs) -> "AnalyzerExtension":
"""
Expand Down Expand Up @@ -1355,7 +1357,9 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job

for extension_name, extension_params in extensions_with_pipeline.items():
extension_class = get_extension_class(extension_name)
assert self.has_recording(), f"Extension {extension_name} need the recording"
assert (
self.has_recording() or self.has_temporary_recording()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the main fix!

), f"Extension {extension_name} requires the recording"

for variable_name in extension_class.nodepipeline_variables:
result_routage.append((extension_name, variable_name))
Expand Down Expand Up @@ -1603,17 +1607,17 @@ def _sort_extensions_by_dependency(extensions):
def _get_children_dependencies(extension_name):
"""
Extension classes have a `depend_on` attribute to declare on which class they
depend. For instance "templates" depend on "waveforms". "waveforms depends on "random_spikes".
depend on. For instance "templates" depends on "waveforms". "waveforms" depends on "random_spikes".

This function is making the reverse way : get all children that depend of a
This function is going the opposite way: it finds all children that depend on a
particular extension.

This is recursive so this includes : children and so grand children and great grand children
The implementation is recursive so that the output includes children, grand children, great grand children, etc.

This function is usefull for deleting on recompute.
For instance recompute the "waveforms" need to delete "template"
This make sens if "ms_before" is change in "waveforms" because the template also depends
on this parameters.
This function is useful for deleting existing extensions on recompute.
For instance, recomputing the "waveforms" needs to delete the "templates", since the latter depends on the former.
For this particular example, if we change the "ms_before" parameter of the "waveforms", also the "templates" will
require recomputation as this parameter is inherited.
"""
names = []
children = _extension_children[extension_name]
Expand Down
17 changes: 8 additions & 9 deletions src/spikeinterface/postprocessing/principal_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,12 +359,12 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):

job_kwargs = fix_job_kwargs(job_kwargs)
p = self.params
we = self.sorting_analyzer
sorting = we.sorting
sorting_analyzer = self.sorting_analyzer
sorting = sorting_analyzer.sorting
assert (
we.has_recording()
), "To compute PCA projections for all spikes, the waveform extractor needs the recording"
recording = we.recording
sorting_analyzer.has_recording() or sorting_analyzer.has_temporary_recording()
), "To compute PCA projections for all spikes, the sorting analyzer needs the recording"
recording = sorting_analyzer.recording

# assert sorting.get_num_segments() == 1
assert p["mode"] in ("by_channel_local", "by_channel_global")
Expand All @@ -374,8 +374,9 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):

sparsity = self.sorting_analyzer.sparsity
if sparsity is None:
sparse_channels_indices = {unit_id: np.arange(we.get_num_channels()) for unit_id in we.unit_ids}
max_channels_per_template = we.get_num_channels()
num_channels = recording.get_num_channels()
sparse_channels_indices = {unit_id: np.arange(num_channels) for unit_id in sorting_analyzer.unit_ids}
max_channels_per_template = num_channels
else:
sparse_channels_indices = sparsity.unit_id_to_channel_indices
max_channels_per_template = max([chan_inds.size for chan_inds in sparse_channels_indices.values()])
Expand Down Expand Up @@ -449,9 +450,7 @@ def _fit_by_channel_local(self, n_jobs, progress_bar):
return pca_models

def _fit_by_channel_global(self, progress_bar):
# we = self.sorting_analyzer
p = self.params
# unit_ids = we.unit_ids
unit_ids = self.sorting_analyzer.unit_ids

# there is one unique PCA accross channels
Expand Down
Loading