Skip to content

Commit

Permalink
Merge pull request #2856 from alejoe91/plot_templates
Browse files Browse the repository at this point in the history
Extend plot waveforms/templates to Templates object
  • Loading branch information
samuelgarcia authored May 22, 2024
2 parents c52be0e + 3b4f46d commit 29ad02b
Show file tree
Hide file tree
Showing 8 changed files with 233 additions and 79 deletions.
10 changes: 10 additions & 0 deletions src/spikeinterface/widgets/tests/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,16 @@ def test_plot_unit_templates(self):
backend=backend,
**self.backend_kwargs[backend],
)
# test with templates
templates_ext = self.sorting_analyzer_dense.get_extension("templates")
templates = templates_ext.get_data(outputs="Templates")
sw.plot_unit_templates(
templates,
sparsity=self.sparsity_strict,
unit_ids=unit_ids,
backend=backend,
**self.backend_kwargs[backend],
)
else:
# sortingview doesn't support more than 2 shadings
with self.assertRaises(AssertionError):
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/unit_depths.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
unit_ids = sorting_analyzer.sorting.unit_ids

if unit_colors is None:
unit_colors = get_unit_colors(sorting_analyzer.sorting)
unit_colors = get_unit_colors(sorting_analyzer)

colors = [unit_colors[unit_id] for unit_id in unit_ids]

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/unit_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer)

if unit_colors is None:
unit_colors = get_unit_colors(sorting_analyzer.sorting)
unit_colors = get_unit_colors(sorting_analyzer)

plot_data = dict(
sorting_analyzer=sorting_analyzer,
Expand Down
6 changes: 5 additions & 1 deletion src/spikeinterface/widgets/unit_templates.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from ..core import SortingAnalyzer
from .unit_waveforms import UnitWaveformsWidget
from .base import to_attr

Expand All @@ -17,6 +18,9 @@ def plot_sortingview(self, data_plot, **backend_kwargs):

dp = to_attr(data_plot)

sorting_analyzer = dp.sorting_analyzer_or_templates
assert isinstance(sorting_analyzer, SortingAnalyzer), "This widget requires a SortingAnalyzer as input"

assert len(dp.templates_shading) <= 4, "Only 2 ans 4 templates shading are supported in sortingview"

# ensure serializable for sortingview
Expand Down Expand Up @@ -50,7 +54,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs):
v_average_waveforms = vv.AverageWaveforms(average_waveforms=aw_items, channel_locations=locations)

if not dp.hide_unit_selector:
v_units_table = generate_unit_table_view(dp.sorting_analyzer.sorting)
v_units_table = generate_unit_table_view(sorting_analyzer.sorting)

self.view = vv.Box(
direction="horizontal",
Expand Down
234 changes: 161 additions & 73 deletions src/spikeinterface/widgets/unit_waveforms.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/unit_waveforms_density_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
unit_ids = sorting_analyzer.unit_ids

if unit_colors is None:
unit_colors = get_unit_colors(sorting_analyzer.sorting)
unit_colors = get_unit_colors(sorting_analyzer)

if use_max_channel:
assert len(unit_ids) == 1, " UnitWaveformDensity : use_max_channel=True works only with one unit"
Expand Down
11 changes: 9 additions & 2 deletions src/spikeinterface/widgets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,19 @@ def get_some_colors(keys, color_engine="auto", map_name="gist_ncar", format="RGB
return dict_colors


def get_unit_colors(sorting, color_engine="auto", map_name="gist_ncar", format="RGBA", shuffle=None, seed=None):
def get_unit_colors(
sorting_or_analyzer_or_templates, color_engine="auto", map_name="gist_ncar", format="RGBA", shuffle=None, seed=None
):
"""
Return a dict colors per units.
"""
colors = get_some_colors(
sorting.unit_ids, color_engine=color_engine, map_name=map_name, format=format, shuffle=shuffle, seed=seed
sorting_or_analyzer_or_templates.unit_ids,
color_engine=color_engine,
map_name=map_name,
format=format,
shuffle=shuffle,
seed=seed,
)
return colors

Expand Down
45 changes: 45 additions & 0 deletions src/spikeinterface/widgets/utils_ipywidgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,51 @@ def value_changed(self, change=None):
self.update_label()


class WidenNarrowWidget(W.VBox):
value = traitlets.Float()

def __init__(self, value=1.0, factor=1.2, **kwargs):
assert factor > 1.0
self.factor = factor

self.scale_label = W.Label("Widen/Narrow", layout=W.Layout(width="95%", justify_content="center"))

self.right_selector = W.Button(
description="",
disabled=False,
button_style="", # 'success', 'info', 'warning', 'danger' or ''
tooltip="Increase horizontal scale",
icon="arrow-right",
# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"),
layout=W.Layout(width="60%", align_self="center"),
)

self.left_selector = W.Button(
description="",
disabled=False,
button_style="", # 'success', 'info', 'warning', 'danger' or ''
tooltip="Decrease horizontal scale",
icon="arrow-left",
# layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"),
layout=W.Layout(width="60%", align_self="center"),
)

self.right_selector.on_click(self.left_clicked)
self.left_selector.on_click(self.right_clicked)

self.value = value
super(W.VBox, self).__init__(
children=[self.scale_label, W.HBox([self.left_selector, self.right_selector])],
**kwargs,
)

def left_clicked(self, change=None):
self.value = self.value / self.factor

def right_clicked(self, change=None):
self.value = self.value * self.factor


class UnitSelector(W.VBox):
value = traitlets.List()

Expand Down

0 comments on commit 29ad02b

Please sign in to comment.