diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py
index 3a4561822e..e5691603ac 100644
--- a/src/spikeinterface/core/analyzer_extension_core.py
+++ b/src/spikeinterface/core/analyzer_extension_core.py
@@ -499,6 +499,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
                 channel_ids=self.sorting_analyzer.channel_ids,
                 unit_ids=unit_ids,
                 probe=self.sorting_analyzer.get_probe(),
+                is_scaled=self.sorting_analyzer.return_scaled,
             )
         else:
             raise ValueError("outputs must be numpy or Templates")
diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py
index ec76fcbaa9..05c1ebc7ed 100644
--- a/src/spikeinterface/core/generate.py
+++ b/src/spikeinterface/core/generate.py
@@ -1448,7 +1448,7 @@ def generate_templates(
     mode="ellipsoid",
 ):
     """
-    Generate some templates from the given channel positions and neuron position.s
+    Generate some templates from the given channel positions and neuron positions.

     The implementation is very naive : it generates a mono channel waveform using generate_single_fake_waveform()
     and duplicates this same waveform on all channel given a simple decay law per unit.
diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 3ce20a0209..d57df2b5ae 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -195,7 +195,7 @@ def __init__(
     ):
         # very fast init because checks are done in load and create
         self.sorting = sorting
-        # self.recorsding will be a property
+        # self.recording will be a property
         self._recording = recording
         self.rec_attributes = rec_attributes
         self.format = format
diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py
index 51688709b2..4eb82be2d6 100644
--- a/src/spikeinterface/core/template.py
+++ b/src/spikeinterface/core/template.py
@@ -30,9 +30,11 @@ class Templates:
         Array of unit IDs. If `None`, defaults to an array of increasing integers.
     probe: Probe, default: None
         A `probeinterface.Probe` object
+    is_scaled : bool, optional, default: True
+        If True, the templates are in uV; otherwise they are in raw ADC values.
     check_for_consistent_sparsity : bool, optional default: None
         When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the
-        structure fo the sparsity_masl.
+        structure of the sparsity_mask. If False, this check is skipped.

     The following attributes are available after construction:

@@ -58,6 +60,7 @@ class Templates:
     templates_array: np.ndarray
     sampling_frequency: float
     nbefore: int
+    is_scaled: bool = True
     sparsity_mask: np.ndarray = None
     channel_ids: np.ndarray = None
@@ -193,6 +196,7 @@ def to_dict(self):
             "unit_ids": self.unit_ids,
             "sampling_frequency": self.sampling_frequency,
             "nbefore": self.nbefore,
+            "is_scaled": self.is_scaled,
             "probe": self.probe.to_dict() if self.probe is not None else None,
         }
@@ -205,6 +209,7 @@ def from_dict(cls, data):
             unit_ids=np.asarray(data["unit_ids"]),
             sampling_frequency=data["sampling_frequency"],
             nbefore=data["nbefore"],
+            is_scaled=data["is_scaled"],
             probe=data["probe"] if data["probe"] is None else Probe.from_dict(data["probe"]),
         )
@@ -238,6 +243,7 @@ def add_templates_to_zarr_group(self, zarr_group: "zarr.Group") -> None:
         zarr_group.attrs["sampling_frequency"] = self.sampling_frequency
         zarr_group.attrs["nbefore"] = self.nbefore
+        zarr_group.attrs["is_scaled"] = self.is_scaled

         if self.sparsity_mask is not None:
             zarr_group.create_dataset("sparsity_mask", data=self.sparsity_mask)
diff --git a/src/spikeinterface/core/tests/test_template_tools.py b/src/spikeinterface/core/tests/test_template_tools.py
index f79c830db6..6ef8267742 100644
--- a/src/spikeinterface/core/tests/test_template_tools.py
+++ b/src/spikeinterface/core/tests/test_template_tools.py
@@ -47,6 +47,7 @@ def _get_templates_object_from_sorting_analyzer(sorting_analyzer):
         sparsity_mask=None,
         channel_ids=sorting_analyzer.channel_ids,
         unit_ids=sorting_analyzer.unit_ids,
+        is_scaled=sorting_analyzer.return_scaled,
     )
     return templates
diff --git a/src/spikeinterface/generation/drifting_generator.py b/src/spikeinterface/generation/drifting_generator.py
index 8a658cd97d..7f617c3ade 100644
--- a/src/spikeinterface/generation/drifting_generator.py
+++ b/src/spikeinterface/generation/drifting_generator.py
@@ -404,6 +404,7 @@ def generate_drifting_recording(
         sampling_frequency=sampling_frequency,
         nbefore=nbefore,
         probe=probe,
+        is_scaled=True,
     )
     drifting_templates = DriftingTemplates.from_static(templates)
diff --git a/src/spikeinterface/generation/tests/test_drift_tools.py b/src/spikeinterface/generation/tests/test_drift_tools.py
index ab03b30d82..e64e64ffda 100644
--- a/src/spikeinterface/generation/tests/test_drift_tools.py
+++ b/src/spikeinterface/generation/tests/test_drift_tools.py
@@ -73,6 +73,7 @@ def make_some_templates():
         sampling_frequency=sampling_frequency,
         nbefore=nbefore,
         probe=probe,
+        is_scaled=True,
     )
     return templates
diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index ba6870eef2..6575aba15e 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -250,13 +250,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             )

             templates = Templates(
-                templates_array,
-                sampling_frequency,
-                nbefore,
-                None,
-                recording_w.channel_ids,
-                unit_ids,
-                recording_w.get_probe(),
+                templates_array=templates_array,
+                sampling_frequency=sampling_frequency,
+                nbefore=nbefore,
+                sparsity_mask=None,
+                channel_ids=recording_w.channel_ids,
+                unit_ids=unit_ids,
+                probe=recording_w.get_probe(),
+                is_scaled=False,
             )

             sparsity = compute_sparsity(templates, noise_levels, **params["sparsity"])
diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py
index c2b9f4cfc7..e07924b196 100644
--- a/src/spikeinterface/sorters/internal/tridesclous2.py
+++ b/src/spikeinterface/sorters/internal/tridesclous2.py
@@ -191,7 +191,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
             templates_array=templates_array,
             sampling_frequency=sampling_frequency,
             nbefore=nbefore,
+            sparsity_mask=None,
             probe=recording_w.get_probe(),
+            is_scaled=False,
         )
         # TODO : try other methods for sparsity
         # sparsity = compute_sparsity(templates_dense, method="radius", radius_um=120.)
diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py b/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py
index 3401e36dd0..313f19537e 100644
--- a/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py
+++ b/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py
@@ -77,6 +77,7 @@ def compute_gt_templates(recording, gt_sorting, ms_before=2.0, ms_after=3.0, ret
         channel_ids=recording.channel_ids,
         unit_ids=gt_sorting.unit_ids,
         probe=recording.get_probe(),
+        is_scaled=return_scaled,
     )
     return gt_templates
diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py
index 4135bd4b6e..ce7a78e4c6 100644
--- a/src/spikeinterface/sortingcomponents/clustering/circus.py
+++ b/src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -226,7 +226,14 @@ def main_function(cls, recording, peaks, params):
         )

         templates = Templates(
-            templates_array, fs, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe()
+            templates_array=templates_array,
+            sampling_frequency=fs,
+            nbefore=nbefore,
+            sparsity_mask=None,
+            channel_ids=recording.channel_ids,
+            unit_ids=unit_ids,
+            probe=recording.get_probe(),
+            is_scaled=False,
         )
         if params["noise_levels"] is None:
             params["noise_levels"] = get_noise_levels(recording, return_scaled=False)
diff --git a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py
index d24af3c175..a07a6140e1 100644
--- a/src/spikeinterface/sortingcomponents/clustering/position_and_features.py
+++ b/src/spikeinterface/sortingcomponents/clustering/position_and_features.py
@@ -184,7 +184,12 @@ def main_function(cls, recording, peaks, params):
             **params["job_kwargs"],
         )
         templates = Templates(
-            templates_array=templates_array, sampling_frequency=fs, nbefore=nbefore, probe=recording.get_probe()
+            templates_array=templates_array,
+            sampling_frequency=fs,
+            nbefore=nbefore,
+            sparsity_mask=None,
+            probe=recording.get_probe(),
+            is_scaled=False,
         )

         labels, peak_labels = remove_duplicates_via_matching(
diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
index efd63be55f..6c1ad75383 100644
--- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py
+++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py
@@ -137,7 +137,14 @@ def main_function(cls, recording, peaks, params):
         )

         templates = Templates(
-            templates_array, fs, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe()
+            templates_array=templates_array,
+            sampling_frequency=fs,
+            nbefore=nbefore,
+            sparsity_mask=None,
+            channel_ids=recording.channel_ids,
+            unit_ids=unit_ids,
+            probe=recording.get_probe(),
+            is_scaled=False,
         )
         if params["noise_levels"] is None:
             params["noise_levels"] = get_noise_levels(recording, return_scaled=False)
diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py
index 06dfd994f3..cf0d22c0c8 100644
--- a/src/spikeinterface/sortingcomponents/tools.py
+++ b/src/spikeinterface/sortingcomponents/tools.py
@@ -137,4 +137,5 @@ def remove_empty_templates(templates):
         channel_ids=templates.channel_ids,
         unit_ids=templates.unit_ids[not_empty],
         probe=templates.probe,
+        is_scaled=templates.is_scaled,
     )
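
Usage note (not part of the patch): the sketch below illustrates how the new is_scaled flag travels with a Templates object and through the dict round trip shown in the template.py hunks above. Call sites that average unscaled traces (the internal sorters and clustering code) pass is_scaled=False, while the generators producing templates directly in uV pass is_scaled=True. The array shapes, values, and the choice of constructing without a probe are made up purely for illustration.

    import numpy as np
    from spikeinterface.core import Templates

    # Toy dense templates: 4 units x 100 samples x 8 channels (values are arbitrary).
    rng = np.random.default_rng(seed=0)
    templates_array = rng.normal(size=(4, 100, 8)).astype("float32")

    templates = Templates(
        templates_array=templates_array,
        sampling_frequency=30_000.0,
        nbefore=30,
        is_scaled=False,  # templates kept in raw ADC values, as in the internal sorters above
    )

    # The flag is carried through the dict representation used for serialization.
    restored = Templates.from_dict(templates.to_dict())
    assert restored.is_scaled is False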