Skip to content

Commit

Permalink
Merge pull request SpikeInterface#2842 from h-mayorquin/add_templates…
Browse files Browse the repository at this point in the history
…_is_scaled

Add `is_scaled` to `Templates` object
  • Loading branch information
alejoe91 authored May 14, 2024
2 parents ec5925c + 08aa901 commit cec72f4
Show file tree
Hide file tree
Showing 14 changed files with 48 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ def get_templates(self, unit_ids=None, operator="average", percentile=None, save
channel_ids=self.sorting_analyzer.channel_ids,
unit_ids=unit_ids,
probe=self.sorting_analyzer.get_probe(),
is_scaled=self.sorting_analyzer.return_scaled,
)
else:
raise ValueError("outputs must be numpy or Templates")
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1448,7 +1448,7 @@ def generate_templates(
mode="ellipsoid",
):
"""
Generate some templates from the given channel positions and neuron position.s
Generate some templates from the given channel positions and neuron positions.
The implementation is very naive : it generates a mono channel waveform using generate_single_fake_waveform()
and duplicates this same waveform on all channel given a simple decay law per unit.
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def __init__(
):
# very fast init because checks are done in load and create
self.sorting = sorting
# self.recorsding will be a property
# self.recording will be a property
self._recording = recording
self.rec_attributes = rec_attributes
self.format = format
Expand Down
8 changes: 7 additions & 1 deletion src/spikeinterface/core/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ class Templates:
Array of unit IDs. If `None`, defaults to an array of increasing integers.
probe: Probe, default: None
A `probeinterface.Probe` object
is_scaled : bool, optional default: True
If True, it means that the templates are in uV, otherwise they are in raw ADC values.
check_for_consistent_sparsity : bool, optional default: None
When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the
structure fo the sparsity_masl.
structure of the sparsity_mask. If False, this check is skipped.
The following attributes are available after construction:
Expand All @@ -58,6 +60,7 @@ class Templates:
templates_array: np.ndarray
sampling_frequency: float
nbefore: int
is_scaled: bool = True

sparsity_mask: np.ndarray = None
channel_ids: np.ndarray = None
Expand Down Expand Up @@ -193,6 +196,7 @@ def to_dict(self):
"unit_ids": self.unit_ids,
"sampling_frequency": self.sampling_frequency,
"nbefore": self.nbefore,
"is_scaled": self.is_scaled,
"probe": self.probe.to_dict() if self.probe is not None else None,
}

Expand All @@ -205,6 +209,7 @@ def from_dict(cls, data):
unit_ids=np.asarray(data["unit_ids"]),
sampling_frequency=data["sampling_frequency"],
nbefore=data["nbefore"],
is_scaled=data["is_scaled"],
probe=data["probe"] if data["probe"] is None else Probe.from_dict(data["probe"]),
)

Expand Down Expand Up @@ -238,6 +243,7 @@ def add_templates_to_zarr_group(self, zarr_group: "zarr.Group") -> None:

zarr_group.attrs["sampling_frequency"] = self.sampling_frequency
zarr_group.attrs["nbefore"] = self.nbefore
zarr_group.attrs["is_scaled"] = self.is_scaled

if self.sparsity_mask is not None:
zarr_group.create_dataset("sparsity_mask", data=self.sparsity_mask)
Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/core/tests/test_template_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def _get_templates_object_from_sorting_analyzer(sorting_analyzer):
sparsity_mask=None,
channel_ids=sorting_analyzer.channel_ids,
unit_ids=sorting_analyzer.unit_ids,
is_scaled=sorting_analyzer.return_scaled,
)
return templates

Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/generation/drifting_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ def generate_drifting_recording(
sampling_frequency=sampling_frequency,
nbefore=nbefore,
probe=probe,
is_scaled=True,
)

drifting_templates = DriftingTemplates.from_static(templates)
Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/generation/tests/test_drift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def make_some_templates():
sampling_frequency=sampling_frequency,
nbefore=nbefore,
probe=probe,
is_scaled=True,
)

return templates
Expand Down
16 changes: 9 additions & 7 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from operator import is_

from .si_based import ComponentsBasedSorter

Expand Down Expand Up @@ -250,13 +251,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
)

templates = Templates(
templates_array,
sampling_frequency,
nbefore,
None,
recording_w.channel_ids,
unit_ids,
recording_w.get_probe(),
templates_array=templates_array,
sampling_frequency=sampling_frequency,
nbefore=nbefore,
sparsity_mask=None,
channel_ids=recording_w.channel_ids,
unit_ids=unit_ids,
probe=recording_w.get_probe(),
is_scaled=False,
)

sparsity = compute_sparsity(templates, noise_levels, **params["sparsity"])
Expand Down
2 changes: 2 additions & 0 deletions src/spikeinterface/sorters/internal/tridesclous2.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
templates_array=templates_array,
sampling_frequency=sampling_frequency,
nbefore=nbefore,
sparsity_mask=None,
probe=recording_w.get_probe(),
is_scaled=False,
)
# TODO : try other methods for sparsity
# sparsity = compute_sparsity(templates_dense, method="radius", radius_um=120.)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def compute_gt_templates(recording, gt_sorting, ms_before=2.0, ms_after=3.0, ret
channel_ids=recording.channel_ids,
unit_ids=gt_sorting.unit_ids,
probe=recording.get_probe(),
is_scaled=return_scaled,
)
return gt_templates

Expand Down
9 changes: 8 additions & 1 deletion src/spikeinterface/sortingcomponents/clustering/circus.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,14 @@ def main_function(cls, recording, peaks, params):
)

templates = Templates(
templates_array, fs, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe()
templates_array=templates_array,
sampling_frequency=fs,
nbefore=nbefore,
sparsity_mask=None,
channel_ids=recording.channel_ids,
unit_ids=unit_ids,
probe=recording.get_probe(),
is_scaled=False,
)
if params["noise_levels"] is None:
params["noise_levels"] = get_noise_levels(recording, return_scaled=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,12 @@ def main_function(cls, recording, peaks, params):
**params["job_kwargs"],
)
templates = Templates(
templates_array=templates_array, sampling_frequency=fs, nbefore=nbefore, probe=recording.get_probe()
templates_array=templates_array,
sampling_frequency=fs,
nbefore=nbefore,
sparsity_mask=None,
probe=recording.get_probe(),
is_scaled=False,
)

labels, peak_labels = remove_duplicates_via_matching(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,14 @@ def main_function(cls, recording, peaks, params):
)

templates = Templates(
templates_array, fs, nbefore, None, recording.channel_ids, unit_ids, recording.get_probe()
templates_array=templates_array,
sampling_frequency=fs,
nbefore=nbefore,
sparsity_mask=None,
channel_ids=recording.channel_ids,
unit_ids=unit_ids,
probe=recording.get_probe(),
is_scaled=False,
)
if params["noise_levels"] is None:
params["noise_levels"] = get_noise_levels(recording, return_scaled=False)
Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/sortingcomponents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,5 @@ def remove_empty_templates(templates):
channel_ids=templates.channel_ids,
unit_ids=templates.unit_ids[not_empty],
probe=templates.probe,
is_scaled=templates.is_scaled,
)

0 comments on commit cec72f4

Please sign in to comment.