Skip to content

Commit

Permalink
Merge pull request #2179 from alejoe91/waveforms-percentiles
Browse files Browse the repository at this point in the history
Add 'percentile' to template modes and plot_unit_templates
  • Loading branch information
samuelgarcia authored Nov 23, 2023
2 parents aa67315 + 4eb2e33 commit 48a543e
Show file tree
Hide file tree
Showing 7 changed files with 244 additions and 52 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ widgets = [
"matplotlib",
"ipympl",
"ipywidgets",
"sortingview>=0.11.15",
"sortingview>=0.12.0",
]

qualitymetrics = [
Expand Down
13 changes: 13 additions & 0 deletions src/spikeinterface/core/tests/test_waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,19 @@ def test_unfiltered_extraction():
wfs_std = we.get_all_templates(mode="std")
assert wfs_std.shape == (num_units, num_samples, num_channels)

wf_prct = we.get_template(0, mode="percentile", percentile=10)
assert wf_prct.shape == (num_samples, num_channels)
wfs_prct = we.get_all_templates(mode="percentile", percentile=10)
assert wfs_prct.shape == (num_units, num_samples, num_channels)

# percentile mode should fail if percentile is None or not in [0, 100]
with pytest.raises(AssertionError):
wf_prct = we.get_template(0, mode="percentile")
with pytest.raises(AssertionError):
wfs_prct = we.get_all_templates(mode="percentile")
with pytest.raises(AssertionError):
wfs_prct = we.get_all_templates(mode="percentile", percentile=101)

wf_segment = we.get_template_segment(unit_id=0, segment_index=0)
assert wf_segment.shape == (num_samples, num_channels)
assert wf_segment.shape == (num_samples, num_channels)
Expand Down
64 changes: 45 additions & 19 deletions src/spikeinterface/core/waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .sparsity import ChannelSparsity, compute_sparsity, _sparsity_doc
from .waveform_tools import extract_waveforms_to_buffers, has_exceeding_spikes

_possible_template_modes = ("average", "std", "median")
_possible_template_modes = ("average", "std", "median", "percentile")


class WaveformExtractor:
Expand Down Expand Up @@ -1179,25 +1179,36 @@ def get_waveforms_segment(self, segment_index: int, unit_id, sparsity):
mask = index_ar["segment_index"] == segment_index
return wfs[mask, :, :]

def precompute_templates(self, modes=("average", "std")) -> None:
def precompute_templates(self, modes=("average", "std", "median", "percentile"), percentile=None) -> None:
"""
Precompute all template for different "modes":
Precompute all templates for different "modes":
* average
* std
* median
* percentile
The results is cache in memory as 3d ndarray (nunits, nsamples, nchans)
and also saved as npy file in the folder to avoid recomputation each time.
Parameters
----------
modes: list
The modes to compute the templates
percentile: float, default: None
Percentile to use for mode="percentile"
The results is cached in memory as a 3d ndarray (nunits, nsamples, nchans)
and also saved as an npy file in the folder to avoid recomputation each time.
"""
# TODO : run this in parralel

unit_ids = self.unit_ids
num_chans = self.get_num_channels()

mode_names = {}
for mode in modes:
mode_name = mode if mode != "percentile" else f"{mode}_{percentile}"
mode_names[mode] = mode_name
dtype = self._params["dtype"] if mode == "median" else np.float32
templates = np.zeros((len(unit_ids), self.nsamples, num_chans), dtype=dtype)
self._template_cache[mode] = templates
self._template_cache[mode_names[mode]] = templates

for unit_ind, unit_id in enumerate(unit_ids):
wfs = self.get_waveforms(unit_id, cache=False)
Expand All @@ -1214,57 +1225,68 @@ def precompute_templates(self, modes=("average", "std")) -> None:
arr = np.average(wfs, axis=0)
elif mode == "std":
arr = np.std(wfs, axis=0)
elif mode == "percentile":
assert percentile is not None, "percentile must be specified for mode='percentile'"
assert 0 <= percentile <= 100, "percentile must be between 0 and 100 inclusive"
arr = np.percentile(wfs, percentile, axis=0)
else:
raise ValueError("mode must in median/average/std")
self._template_cache[mode][unit_ind][:, mask] = arr
raise ValueError(f"'mode' must be in {_possible_template_modes}")
self._template_cache[mode_names[mode]][unit_ind][:, mask] = arr

for mode in modes:
templates = self._template_cache[mode]
templates = self._template_cache[mode_names[mode]]
if self.folder is not None:
template_file = self.folder / f"templates_{mode}.npy"
template_file = self.folder / f"templates_{mode_names[mode]}.npy"
np.save(template_file, templates)

def get_all_templates(self, unit_ids: Optional[Iterable] = None, mode="average"):
def get_all_templates(self, unit_ids: Optional[Iterable] = None, mode="average", percentile: float | None = None):
"""
Return templates (average waveform) for multiple units.
Return templates (average waveforms) for multiple units.
Parameters
----------
unit_ids: list or None
Unit ids to retrieve waveforms for
mode: "average" | "median" | "std", default: "average"
mode: "average" | "median" | "std" | "percentile", default: "average"
The mode to compute the templates
percentile: float, default: None
Percentile to use for mode="percentile"
Returns
-------
templates: np.array
The returned templates (num_units, num_samples, num_channels)
"""
if mode not in self._template_cache:
self.precompute_templates(modes=[mode])

templates = self._template_cache[mode]
self.precompute_templates(modes=[mode], percentile=percentile)
mode_name = mode if mode != "percentile" else f"{mode}_{percentile}"
templates = self._template_cache[mode_name]

if unit_ids is not None:
unit_indices = self.sorting.ids_to_indices(unit_ids)
templates = templates[unit_indices, :, :]

return np.array(templates)

def get_template(self, unit_id, mode="average", sparsity=None, force_dense: bool = False):
def get_template(
self, unit_id, mode="average", sparsity=None, force_dense: bool = False, percentile: float | None = None
):
"""
Return template (average waveform).
Parameters
----------
unit_id: int or str
Unit id to retrieve waveforms for
mode: "average" | "median" | "std", default: "average"
mode: "average" | "median" | "std" | "percentile", default: "average"
The mode to compute the template
sparsity: ChannelSparsity, default: None
Sparsity to apply to the waveforms (if WaveformExtractor is not sparse)
force_dense: bool (False)
force_dense: bool, default: False
Return a dense template even if the waveform extractor is sparse
percentile: float, default: None
Percentile to use for mode="percentile".
Values must be between 0 and 100 inclusive
Returns
-------
Expand Down Expand Up @@ -1304,6 +1326,10 @@ def get_template(self, unit_id, mode="average", sparsity=None, force_dense: bool
template = np.average(wfs, axis=0)
elif mode == "std":
template = np.std(wfs, axis=0)
elif mode == "percentile":
assert percentile is not None, "percentile must be specified for mode='percentile'"
assert 0 <= percentile <= 100, "percentile must be between 0 and 100 inclusive"
template = np.percentile(wfs, percentile, axis=0)

return np.array(template)

Expand Down
68 changes: 64 additions & 4 deletions src/spikeinterface/widgets/tests/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,13 @@ def test_plot_unit_waveforms(self):
)

def test_plot_unit_templates(self):
possible_backends = list(sw.UnitWaveformsWidget.get_possible_backends())
possible_backends = list(sw.UnitTemplatesWidget.get_possible_backends())
for backend in possible_backends:
if backend not in self.skip_backends:
sw.plot_unit_templates(self.we_dense, backend=backend, **self.backend_kwargs[backend])
print(f"Testing backend {backend}")
sw.plot_unit_templates(
self.we_dense, backend=backend, templates_percentile_shading=None, **self.backend_kwargs[backend]
)
unit_ids = self.sorting.unit_ids[:6]
sw.plot_unit_templates(
self.we_dense,
Expand All @@ -162,13 +165,70 @@ def test_plot_unit_templates(self):
backend=backend,
**self.backend_kwargs[backend],
)
# test different shadings
sw.plot_unit_templates(
self.we_sparse,
sparsity=self.sparsity_best,
unit_ids=unit_ids,
templates_percentile_shading=None,
backend=backend,
**self.backend_kwargs[backend],
)
sw.plot_unit_templates(
self.we_sparse,
sparsity=self.sparsity_best,
unit_ids=unit_ids,
templates_percentile_shading=None,
scale=10,
backend=backend,
**self.backend_kwargs[backend],
)
# test different shadings
sw.plot_unit_templates(
self.we_sparse,
sparsity=self.sparsity_best,
unit_ids=unit_ids,
backend=backend,
templates_percentile_shading=None,
shade_templates=False,
**self.backend_kwargs[backend],
)
sw.plot_unit_templates(
self.we_sparse,
sparsity=self.sparsity_best,
unit_ids=unit_ids,
templates_percentile_shading=5,
backend=backend,
**self.backend_kwargs[backend],
)
sw.plot_unit_templates(
self.we_sparse,
sparsity=self.sparsity_best,
unit_ids=unit_ids,
templates_percentile_shading=[10, 90],
backend=backend,
**self.backend_kwargs[backend],
)
if backend != "sortingview":
sw.plot_unit_templates(
self.we_sparse,
sparsity=self.sparsity_best,
unit_ids=unit_ids,
templates_percentile_shading=[1, 5, 25, 75, 95, 99],
backend=backend,
**self.backend_kwargs[backend],
)
else:
# sortingview doesn't support more than 2 shadings
with self.assertRaises(AssertionError):
sw.plot_unit_templates(
self.we_sparse,
sparsity=self.sparsity_best,
unit_ids=unit_ids,
templates_percentile_shading=[1, 5, 25, 75, 95, 99],
backend=backend,
**self.backend_kwargs[backend],
)

def test_plot_unit_waveforms_density_map(self):
possible_backends = list(sw.UnitWaveformDensityMapWidget.get_possible_backends())
Expand Down Expand Up @@ -423,7 +483,7 @@ def test_plot_multicomparison(self):
# mytest.test_plot_traces()
# mytest.test_plot_unit_waveforms()
# mytest.test_plot_unit_templates()
# mytest.test_plot_unit_templates()
mytest.test_plot_unit_templates()
# mytest.test_plot_unit_depths()
# mytest.test_plot_unit_templates()
# mytest.test_plot_unit_summary()
Expand All @@ -439,6 +499,6 @@ def test_plot_multicomparison(self):
# mytest.test_plot_rasters()
# mytest.test_plot_unit_probe_map()
# mytest.test_plot_unit_presence()
mytest.test_plot_multicomparison()
# mytest.test_plot_multicomparison()

plt.show()
13 changes: 10 additions & 3 deletions src/spikeinterface/widgets/unit_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ def plot_sortingview(self, data_plot, **backend_kwargs):

dp = to_attr(data_plot)

assert len(dp.templates_shading) <= 4, "Only 2 ans 4 templates shading are supported in sortingview"

# ensure serializable for sortingview
unit_id_to_channel_ids = dp.sparsity.unit_id_to_channel_ids
unit_id_to_channel_indices = dp.sparsity.unit_id_to_channel_indices
Expand All @@ -25,14 +27,19 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
for u_i, unit in enumerate(unit_ids):
templates_dict[unit] = {}
templates_dict[unit]["mean"] = dp.templates[u_i].T.astype("float32")[unit_id_to_channel_indices[unit]]
templates_dict[unit]["std"] = dp.template_stds[u_i].T.astype("float32")[unit_id_to_channel_indices[unit]]
if dp.do_shading:
templates_dict[unit]["shading"] = [
s[u_i].T.astype("float32")[unit_id_to_channel_indices[unit]] for s in dp.templates_shading
]
else:
templates_dict[unit]["shading"] = None

aw_items = [
vv.AverageWaveformItem(
unit_id=u,
channel_ids=list(unit_id_to_channel_ids[u]),
waveform=t["mean"].astype("float32"),
waveform_std_dev=t["std"].astype("float32"),
waveform=t["mean"],
waveform_percentiles=t["shading"],
)
for u, t in templates_dict.items()
]
Expand Down
Loading

0 comments on commit 48a543e

Please sign in to comment.