From 8054b3aa82a52495c3b22dea6ba2629cb22d42ac Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 15 May 2024 19:16:10 +0200 Subject: [PATCH 1/5] Extend plot waveforms/templates to Templates object --- src/spikeinterface/widgets/unit_depths.py | 2 +- src/spikeinterface/widgets/unit_summary.py | 2 +- src/spikeinterface/widgets/unit_templates.py | 6 +- src/spikeinterface/widgets/unit_waveforms.py | 233 ++++++++++++------ .../widgets/unit_waveforms_density_map.py | 2 +- src/spikeinterface/widgets/utils.py | 11 +- 6 files changed, 177 insertions(+), 79 deletions(-) diff --git a/src/spikeinterface/widgets/unit_depths.py b/src/spikeinterface/widgets/unit_depths.py index c5fe3e05e8..c2e9c06863 100644 --- a/src/spikeinterface/widgets/unit_depths.py +++ b/src/spikeinterface/widgets/unit_depths.py @@ -35,7 +35,7 @@ def __init__( unit_ids = sorting_analyzer.sorting.unit_ids if unit_colors is None: - unit_colors = get_unit_colors(sorting_analyzer.sorting) + unit_colors = get_unit_colors(sorting_analyzer) colors = [unit_colors[unit_id] for unit_id in unit_ids] diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index ea6476784e..0b2a348edf 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -48,7 +48,7 @@ def __init__( sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) if unit_colors is None: - unit_colors = get_unit_colors(sorting_analyzer.sorting) + unit_colors = get_unit_colors(sorting_analyzer) plot_data = dict( sorting_analyzer=sorting_analyzer, diff --git a/src/spikeinterface/widgets/unit_templates.py b/src/spikeinterface/widgets/unit_templates.py index 1350bb71a5..eb9a90d1d1 100644 --- a/src/spikeinterface/widgets/unit_templates.py +++ b/src/spikeinterface/widgets/unit_templates.py @@ -1,5 +1,6 @@ from __future__ import annotations +from ..core import SortingAnalyzer from .unit_waveforms import UnitWaveformsWidget from .base import to_attr @@ -17,6 +18,9 @@ def plot_sortingview(self, data_plot, **backend_kwargs): dp = to_attr(data_plot) + sorting_analyzer = dp.sorting_analyzer_or_templates + assert isinstance(sorting_analyzer, SortingAnalyzer), "This widget requires a SortingAnalyzer as input" + assert len(dp.templates_shading) <= 4, "Only 2 ans 4 templates shading are supported in sortingview" # ensure serializable for sortingview @@ -50,7 +54,7 @@ def plot_sortingview(self, data_plot, **backend_kwargs): v_average_waveforms = vv.AverageWaveforms(average_waveforms=aw_items, channel_locations=locations) if not dp.hide_unit_selector: - v_units_table = generate_unit_table_view(dp.sorting_analyzer.sorting) + v_units_table = generate_unit_table_view(sorting_analyzer.sorting) self.view = vv.Box( direction="horizontal", diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index f701f9a868..2b3dc7ed34 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -6,7 +6,7 @@ from .base import BaseWidget, to_attr from .utils import get_unit_colors -from ..core import ChannelSparsity, SortingAnalyzer +from ..core import ChannelSparsity, SortingAnalyzer, Templates from ..core.basesorting import BaseSorting @@ -16,8 +16,9 @@ class UnitWaveformsWidget(BaseWidget): Parameters ---------- - sorting_analyzer : SortingAnalyzer - The SortingAnalyzer + sorting_analyzer_or_templates : SortingAnalyzer | Templates + The SortingAnalyzer or Templates object. + If Templates is given, the "plot_waveforms" argument is set to False channel_ids: list or None, default: None The channel ids to display unit_ids : list or None, default: None @@ -39,6 +40,8 @@ class UnitWaveformsWidget(BaseWidget): displayed per waveform, (matplotlib backend) scale : float, default: 1 Scale factor for the waveforms/templates (matplotlib backend) + widen_narrow_scale : float, default: 1 + Scale factor for the x-axis of the waveforms/templates (matplotlib backend) axis_equal : bool, default: False Equal aspect ratio for x and y axis, to visualize the array geometry to scale lw_waveforms : float, default: 1 @@ -64,6 +67,8 @@ class UnitWaveformsWidget(BaseWidget): are used for the lower bounds, and the second half for the upper bounds. Inner elements produce darker shadings. For sortingview backend only 2 or 4 elements are supported. + scalebar : bool, default: False + Display a scale bar on the waveforms plot (matplotlib backend) hide_unit_selector : bool, default: False For sortingview backend, if True the unit selector is not displayed same_axis : bool, default: False @@ -77,7 +82,7 @@ class UnitWaveformsWidget(BaseWidget): def __init__( self, - sorting_analyzer: SortingAnalyzer, + sorting_analyzer_or_templates: SortingAnalyzer | Templates, channel_ids=None, unit_ids=None, plot_waveforms=True, @@ -87,6 +92,7 @@ def __init__( sparsity=None, ncols=5, scale=1, + widen_narrow_scale=1, lw_waveforms=1, lw_templates=2, axis_equal=False, @@ -96,6 +102,7 @@ def __init__( same_axis=False, shade_templates=True, templates_percentile_shading=(1, 25, 75, 99), + scalebar=False, x_offset_units=False, alpha_waveforms=0.5, alpha_templates=1, @@ -104,29 +111,29 @@ def __init__( backend=None, **backend_kwargs, ): - - sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) - sorting: BaseSorting = sorting_analyzer.sorting + if not isinstance(sorting_analyzer_or_templates, Templates): + sorting_analyzer_or_templates = self.ensure_sorting_analyzer(sorting_analyzer_or_templates) + else: + plot_waveforms = False + shade_templates = False if unit_ids is None: - unit_ids = sorting.unit_ids + unit_ids = sorting_analyzer_or_templates.unit_ids if channel_ids is None: - channel_ids = sorting_analyzer.channel_ids + channel_ids = sorting_analyzer_or_templates.channel_ids if unit_colors is None: - unit_colors = get_unit_colors(sorting) - - channel_locations = sorting_analyzer.get_channel_locations()[ - sorting_analyzer.channel_ids_to_indices(channel_ids) - ] + unit_colors = get_unit_colors(sorting_analyzer_or_templates) + channel_indices = [list(sorting_analyzer_or_templates.channel_ids).index(ch) for ch in channel_ids] + channel_locations = sorting_analyzer_or_templates.get_channel_locations()[channel_indices] extra_sparsity = False - if sorting_analyzer.is_sparse(): + if sorting_analyzer_or_templates.sparsity is not None: if sparsity is None: - sparsity = sorting_analyzer.sparsity + sparsity = sorting_analyzer_or_templates.sparsity else: # assert provided sparsity is a subset of waveform sparsity - combined_mask = np.logical_or(sorting_analyzer.sparsity.mask, sparsity.mask) - assert np.all(np.sum(combined_mask, 1) - np.sum(sorting_analyzer.sparsity.mask, 1) == 0), ( + combined_mask = np.logical_or(sorting_analyzer_or_templates.sparsity.mask, sparsity.mask) + assert np.all(np.sum(combined_mask, 1) - np.sum(sorting_analyzer_or_templates.sparsity.mask, 1) == 0), ( "The provided 'sparsity' needs to include only the sparse channels " "used to extract waveforms (for example, by using a smaller 'radius_um')." ) @@ -134,41 +141,49 @@ def __init__( else: if sparsity is None: # in this case, we construct a dense sparsity - unit_id_to_channel_ids = {u: sorting_analyzer.channel_ids for u in sorting_analyzer.unit_ids} + unit_id_to_channel_ids = { + u: sorting_analyzer_or_templates.channel_ids for u in sorting_analyzer_or_templates.unit_ids + } sparsity = ChannelSparsity.from_unit_id_to_channel_ids( unit_id_to_channel_ids=unit_id_to_channel_ids, - unit_ids=sorting_analyzer.unit_ids, - channel_ids=sorting_analyzer.channel_ids, + unit_ids=sorting_analyzer_or_templates.unit_ids, + channel_ids=sorting_analyzer_or_templates.channel_ids, ) else: assert isinstance(sparsity, ChannelSparsity), "'sparsity' should be a ChannelSparsity object!" # get templates - self.templates_ext = sorting_analyzer.get_extension("templates") - assert self.templates_ext is not None, "plot_waveforms() need extension 'templates'" - templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average") - - if templates_percentile_shading is not None and not sorting_analyzer.has_extension("waveforms"): - warn( - "templates_percentile_shading can only be used if the 'waveforms' extension is available. " - "Settimg templates_percentile_shading to None." - ) - templates_percentile_shading = None - templates_shading = self._get_template_shadings(sorting_analyzer, unit_ids, templates_percentile_shading) - - xvectors, y_scale, y_offset, delta_x = get_waveforms_scales( - sorting_analyzer, templates, channel_locations, x_offset_units - ) + if isinstance(sorting_analyzer_or_templates, Templates): + templates = sorting_analyzer_or_templates.templates_array + nbefore = sorting_analyzer_or_templates.nbefore + self.templates_ext = None + templates_shading = None + else: + self.templates_ext = sorting_analyzer_or_templates.get_extension("templates") + assert self.templates_ext is not None, "plot_waveforms() need extension 'templates'" + templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average") + nbefore = self.templates_ext.nbefore + + if templates_percentile_shading is not None and not sorting_analyzer_or_templates.has_extension( + "waveforms" + ): + warn( + "templates_percentile_shading can only be used if the 'waveforms' extension is available. " + "Settimg templates_percentile_shading to None." + ) + templates_percentile_shading = None + templates_shading = self._get_template_shadings(unit_ids, templates_percentile_shading) wfs_by_ids = {} if plot_waveforms: - wf_ext = sorting_analyzer.get_extension("waveforms") + # this must be a sorting_analyzer + wf_ext = sorting_analyzer_or_templates.get_extension("waveforms") if wf_ext is None: raise ValueError("plot_waveforms() needs the extension 'waveforms'") for unit_id in unit_ids: - unit_index = list(sorting.unit_ids).index(unit_id) + unit_index = list(sorting_analyzer_or_templates.unit_ids).index(unit_id) if not extra_sparsity: - if sorting_analyzer.is_sparse(): + if sorting_analyzer_or_templates.is_sparse(): # wfs = we.get_waveforms(unit_id) wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) else: @@ -181,7 +196,7 @@ def __init__( # wfs = we.get_waveforms(unit_id) wfs = wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) # find additional slice to apply to sparse waveforms - (wfs_sparse_indices,) = np.nonzero(sorting_analyzer.sparsity.mask[unit_index]) + (wfs_sparse_indices,) = np.nonzero(sorting_analyzer_or_templates.sparsity.mask[unit_index]) (extra_sparse_indices,) = np.nonzero(sparsity.mask[unit_index]) (extra_slice,) = np.nonzero(np.isin(wfs_sparse_indices, extra_sparse_indices)) # apply extra sparsity @@ -189,14 +204,16 @@ def __init__( wfs_by_ids[unit_id] = wfs plot_data = dict( - sorting_analyzer=sorting_analyzer, - sampling_frequency=sorting_analyzer.sampling_frequency, + sorting_analyzer_or_templates=sorting_analyzer_or_templates, + sampling_frequency=sorting_analyzer_or_templates.sampling_frequency, + nbefore=nbefore, unit_ids=unit_ids, channel_ids=channel_ids, sparsity=sparsity, unit_colors=unit_colors, channel_locations=channel_locations, scale=scale, + widen_narrow_scale=widen_narrow_scale, templates=templates, templates_shading=templates_shading, do_shading=shade_templates, @@ -207,19 +224,16 @@ def __init__( unit_selected_waveforms=unit_selected_waveforms, axis_equal=axis_equal, max_spikes_per_unit=max_spikes_per_unit, - xvectors=xvectors, - y_scale=y_scale, - y_offset=y_offset, wfs_by_ids=wfs_by_ids, set_title=set_title, same_axis=same_axis, + scalebar=scalebar, templates_percentile_shading=templates_percentile_shading, x_offset_units=x_offset_units, lw_waveforms=lw_waveforms, lw_templates=lw_templates, alpha_waveforms=alpha_waveforms, alpha_templates=alpha_templates, - delta_x=delta_x, hide_unit_selector=hide_unit_selector, plot_legend=plot_legend, ) @@ -245,6 +259,10 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + xvectors, y_scale, y_offset, delta_x = get_waveforms_scales( + dp.templates, dp.channel_locations, dp.nbefore, dp.x_offset_units, dp.widen_narrow_scale + ) + for i, unit_id in enumerate(dp.unit_ids): if dp.same_axis: ax = self.ax @@ -253,7 +271,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): color = dp.unit_colors[unit_id] chan_inds = dp.sparsity.unit_id_to_channel_indices[unit_id] - xvectors_flat = dp.xvectors[:, chan_inds].T.flatten() + xvectors_flat = xvectors[:, chan_inds].T.flatten() # plot waveforms if dp.plot_waveforms: @@ -265,12 +283,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): random_idxs = np.random.permutation(len(wfs))[: dp.max_spikes_per_unit] wfs = wfs[random_idxs] - wfs = wfs * dp.y_scale + dp.y_offset[None, :, chan_inds] + wfs = wfs * y_scale + y_offset[None, :, chan_inds] wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1).T if dp.x_offset_units: # 0.7 is to match spacing in xvect - xvec = xvectors_flat + i * 0.7 * dp.delta_x + xvec = xvectors_flat + i * 0.7 * delta_x else: xvec = xvectors_flat @@ -278,14 +296,33 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if not dp.plot_templates: ax.get_lines()[-1].set_label(f"{unit_id}") + if not dp.plot_templates and dp.scalebar and not dp.same_axis: + # xscale + min_wfs = np.min(wfs_flat) + wfs_for_scale = dp.wfs_by_ids[unit_id] * y_scale + offset = 0.1 * (np.max(wfs_flat) - np.min(wfs_flat)) + xargmin = np.nanargmin(xvec) + xscale_bar = [xvec[xargmin], xvec[xargmin + dp.nbefore]] + ax.plot(xscale_bar, [min_wfs - offset, min_wfs - offset], color="k") + nbefore_time = int(dp.nbefore / dp.sampling_frequency * 1000) + ax.text( + xscale_bar[0] + xscale_bar[1] // 3, min_wfs - 1.5 * offset, f"{nbefore_time} ms", fontsize=8 + ) + + # yscale + length = int(np.ptp(wfs_flat) // 5) + length_uv = int(np.ptp(wfs_for_scale) // 5) + x_offset = xscale_bar[0] - np.ptp(xscale_bar) // 2 + ax.plot([xscale_bar[0], xscale_bar[0]], [min_wfs - offset, min_wfs - offset + length], color="k") + ax.text(x_offset, min_wfs - offset + length // 3, f"{length_uv} $\mu$V", fontsize=8, rotation=90) # plot template if dp.plot_templates: - template = dp.templates[i, :, :][:, chan_inds] * dp.scale * dp.y_scale + dp.y_offset[:, chan_inds] + template = dp.templates[i, :, :][:, chan_inds] * dp.scale * y_scale + y_offset[:, chan_inds] if dp.x_offset_units: # 0.7 is to match spacing in xvect - xvec = xvectors_flat + i * 0.7 * dp.delta_x + xvec = xvectors_flat + i * 0.7 * delta_x else: xvec = xvectors_flat # plot template shading if waveforms are not plotted @@ -297,12 +334,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): shading_alphas = np.linspace(lightest_gray_alpha, darkest_gray_alpha, n_shadings) for s in range(n_shadings): lower_bound = ( - dp.templates_shading[s][i, :, :][:, chan_inds] * dp.scale * dp.y_scale - + dp.y_offset[:, chan_inds] + dp.templates_shading[s][i, :, :][:, chan_inds] * dp.scale * y_scale + y_offset[:, chan_inds] ) upper_bound = ( - dp.templates_shading[n_percentiles - 1 - s][i, :, :][:, chan_inds] * dp.scale * dp.y_scale - + dp.y_offset[:, chan_inds] + dp.templates_shading[n_percentiles - 1 - s][i, :, :][:, chan_inds] * dp.scale * y_scale + + y_offset[:, chan_inds] ) ax.fill_between( xvec, @@ -332,6 +368,26 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): if dp.set_title: ax.set_title(f"template {template_label}") + if not dp.plot_waveforms and dp.scalebar and not dp.same_axis: + # xscale + template_for_scale = dp.templates[i, :, :][:, chan_inds] * dp.scale + min_wfs = np.min(template) + offset = 0.1 * (np.max(template) - np.min(template)) + xargmin = np.nanargmin(xvec) + xscale_bar = [xvec[xargmin], xvec[xargmin + dp.nbefore]] + ax.plot(xscale_bar, [min_wfs - offset, min_wfs - offset], color="k") + nbefore_time = int(dp.nbefore / dp.sampling_frequency * 1000) + ax.text( + xscale_bar[0] + xscale_bar[1] // 3, min_wfs - 1.5 * offset, f"{nbefore_time} ms", fontsize=8 + ) + + # yscale + length = int(np.ptp(template) // 5) + length_uv = int(np.ptp(template_for_scale) // 5) + x_offset = xscale_bar[0] - np.ptp(xscale_bar) // 2 + ax.plot([xscale_bar[0], xscale_bar[0]], [min_wfs - offset, min_wfs - offset + length], color="k") + ax.text(x_offset, min_wfs - offset + length // 3, f"{length_uv} $\mu$V", fontsize=8, rotation=90) + # plot channels if dp.plot_channels: # TODO enhance this @@ -348,14 +404,19 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): import matplotlib.pyplot as plt import ipywidgets.widgets as widgets from IPython.display import display - from .utils_ipywidgets import check_ipywidget_backend, UnitSelector, ScaleWidget + from .utils_ipywidgets import check_ipywidget_backend, UnitSelector, ScaleWidget, WidenNarrowWidget check_ipywidget_backend() self.next_data_plot = data_plot.copy() cm = 1 / 2.54 - self.sorting_analyzer = data_plot["sorting_analyzer"] + if isinstance(data_plot["sorting_analyzer_or_templates"], SortingAnalyzer): + self.sorting_analyzer = data_plot["sorting_analyzer_or_templates"] + self.templates = None + else: + self.sorting_analyzer = None + self.templates = data_plot["sorting_analyzer_or_templates"] width_cm = backend_kwargs["width_cm"] height_cm = backend_kwargs["height_cm"] @@ -375,6 +436,7 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.unit_selector = UnitSelector(data_plot["unit_ids"], layout=widgets.Layout(height="80%")) self.unit_selector.value = list(data_plot["unit_ids"])[:1] self.scaler = ScaleWidget(value=data_plot["scale"], layout=widgets.Layout(height="20%")) + self.widen_narrow = WidenNarrowWidget(value=1.0, layout=widgets.Layout(height="20%")) self.same_axis_button = widgets.Checkbox( value=False, @@ -400,10 +462,20 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): disabled=False, ) - footer = widgets.HBox( - [self.same_axis_button, self.plot_templates_button, self.template_shading_button, self.hide_axis_button] + self.scalebar = widgets.Checkbox( + value=False, + description="scalebar", + disabled=False, ) - left_sidebar = widgets.VBox([self.unit_selector, self.scaler]) + if self.sorting_analyzer is not None: + footer_list = [self.same_axis_button, self.template_shading_button, self.hide_axis_button, self.scalebar] + else: + footer_list = [self.same_axis_button, self.hide_axis_button, self.scalebar] + if data_plot["plot_waveforms"]: + footer_list.append(self.plot_templates_button) + + footer = widgets.HBox(footer_list) + left_sidebar = widgets.VBox([self.unit_selector, self.scaler, self.widen_narrow]) self.widget = widgets.AppLayout( center=self.fig_wf.canvas, @@ -418,13 +490,20 @@ def plot_ipywidgets(self, data_plot, **backend_kwargs): self.unit_selector.observe(self._update_plot, names="value", type="change") self.scaler.observe(self._update_plot, names="value", type="change") - for w in self.same_axis_button, self.plot_templates_button, self.template_shading_button, self.hide_axis_button: + self.widen_narrow.observe(self._update_plot, names="value", type="change") + for w in ( + self.same_axis_button, + self.plot_templates_button, + self.template_shading_button, + self.hide_axis_button, + self.scalebar, + ): w.observe(self._update_plot, names="value", type="change") if backend_kwargs["display"]: display(self.widget) - def _get_template_shadings(self, sorting_analyzer, unit_ids, templates_percentile_shading): + def _get_template_shadings(self, unit_ids, templates_percentile_shading): templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average") if templates_percentile_shading is None: @@ -460,30 +539,40 @@ def _update_plot(self, change): hide_axis = self.hide_axis_button.value do_shading = self.template_shading_button.value - wf_ext = self.sorting_analyzer.get_extension("waveforms") - templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average") + if self.sorting_analyzer is not None: + templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average") + templates_shadings = self._get_template_shadings(unit_ids, data_plot["templates_percentile_shading"]) + channel_locations = self.sorting_analyzer.get_channel_locations() + + else: + unit_indices = [list(self.templates.unit_ids).index(unit_id) for unit_id in unit_ids] + templates = self.templates.templates_array[unit_indices] + templates_shadings = None + channel_locations = self.templates.get_channel_locations() # matplotlib next_data_plot dict update at each call data_plot = self.next_data_plot data_plot["unit_ids"] = unit_ids data_plot["templates"] = templates - templates_shadings = self._get_template_shadings( - self.sorting_analyzer, unit_ids, data_plot["templates_percentile_shading"] - ) data_plot["templates_shading"] = templates_shadings data_plot["same_axis"] = same_axis data_plot["plot_templates"] = plot_templates data_plot["do_shading"] = do_shading data_plot["scale"] = self.scaler.value + data_plot["widen_narrow_scale"] = self.widen_narrow.value + + if same_axis: + self.scalebar.value = False + data_plot["scalebar"] = self.scalebar.value + if data_plot["plot_waveforms"]: + wf_ext = self.sorting_analyzer.get_extension("waveforms") data_plot["wfs_by_ids"] = { unit_id: wf_ext.get_waveforms_one_unit(unit_id, force_dense=False) for unit_id in unit_ids } # TODO option for plot_legend - backend_kwargs = {} - if same_axis: backend_kwargs["ax"] = self.fig_wf.add_subplot() data_plot["set_title"] = False @@ -502,7 +591,6 @@ def _update_plot(self, change): ax.axis("off") # update probe plot - channel_locations = self.sorting_analyzer.get_channel_locations() self.ax_probe.plot( channel_locations[:, 0], channel_locations[:, 1], ls="", marker="o", color="gray", markersize=2, alpha=0.5 ) @@ -529,7 +617,7 @@ def _update_plot(self, change): fig_probe.canvas.flush_events() -def get_waveforms_scales(sorting_analyzer, templates, channel_locations, x_offset_units=False): +def get_waveforms_scales(templates, channel_locations, nbefore, x_offset_units=False, widen_narrow_scale=1.0): """ Return scales and x_vector for templates plotting """ @@ -555,10 +643,9 @@ def get_waveforms_scales(sorting_analyzer, templates, channel_locations, x_offse y_offset = channel_locations[:, 1][None, :] - nbefore = sorting_analyzer.get_extension("templates").nbefore nsamples = templates.shape[1] - xvect = delta_x * (np.arange(nsamples) - nbefore) / nsamples * 0.7 + xvect = (delta_x * widen_narrow_scale) * (np.arange(nsamples) - nbefore) / nsamples * 0.7 if x_offset_units: ch_locs = channel_locations diff --git a/src/spikeinterface/widgets/unit_waveforms_density_map.py b/src/spikeinterface/widgets/unit_waveforms_density_map.py index 2e7ec883e6..6ef1a7a782 100644 --- a/src/spikeinterface/widgets/unit_waveforms_density_map.py +++ b/src/spikeinterface/widgets/unit_waveforms_density_map.py @@ -57,7 +57,7 @@ def __init__( unit_ids = sorting_analyzer.unit_ids if unit_colors is None: - unit_colors = get_unit_colors(sorting_analyzer.sorting) + unit_colors = get_unit_colors(sorting_analyzer) if use_max_channel: assert len(unit_ids) == 1, " UnitWaveformDensity : use_max_channel=True works only with one unit" diff --git a/src/spikeinterface/widgets/utils.py b/src/spikeinterface/widgets/utils.py index 29e6474ee9..3461fb179a 100644 --- a/src/spikeinterface/widgets/utils.py +++ b/src/spikeinterface/widgets/utils.py @@ -95,12 +95,19 @@ def get_some_colors(keys, color_engine="auto", map_name="gist_ncar", format="RGB return dict_colors -def get_unit_colors(sorting, color_engine="auto", map_name="gist_ncar", format="RGBA", shuffle=None, seed=None): +def get_unit_colors( + sorting_or_analyzer_or_templates, color_engine="auto", map_name="gist_ncar", format="RGBA", shuffle=None, seed=None +): """ Return a dict colors per units. """ colors = get_some_colors( - sorting.unit_ids, color_engine=color_engine, map_name=map_name, format=format, shuffle=shuffle, seed=seed + sorting_or_analyzer_or_templates.unit_ids, + color_engine=color_engine, + map_name=map_name, + format=format, + shuffle=shuffle, + seed=seed, ) return colors From f073d5a559e28e490cfc95ce51f34a81244351dd Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 15 May 2024 19:22:15 +0200 Subject: [PATCH 2/5] Add tests --- src/spikeinterface/widgets/tests/test_widgets.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 59ba29ef73..fd31bd31dc 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -1,4 +1,5 @@ import unittest +from numba.cuda import Out import pytest import os from pathlib import Path @@ -286,6 +287,16 @@ def test_plot_unit_templates(self): backend=backend, **self.backend_kwargs[backend], ) + # test with templates + templates_ext = self.sorting_analyzer_dense.get_extension("templates") + templates = templates_ext.get_data(outputs="Templates") + sw.plot_unit_templates( + templates, + sparsity=self.sparsity_strict, + unit_ids=unit_ids, + backend=backend, + **self.backend_kwargs[backend], + ) else: # sortingview doesn't support more than 2 shadings with self.assertRaises(AssertionError): From 80b09e54254c2a7dbb176ca16a9d4a27a9d0ba3e Mon Sep 17 00:00:00 2001 From: Garcia Samuel Date: Tue, 21 May 2024 15:58:14 +0200 Subject: [PATCH 3/5] Update src/spikeinterface/widgets/tests/test_widgets.py --- src/spikeinterface/widgets/tests/test_widgets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index fd31bd31dc..5366fb864f 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -1,5 +1,4 @@ import unittest -from numba.cuda import Out import pytest import os from pathlib import Path From f309c79dcc7ac7ea069f83482f0acfb28d7af83d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 21 May 2024 16:05:31 +0200 Subject: [PATCH 4/5] Add WidenNarrowWidget for templates --- .../widgets/utils_ipywidgets.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/spikeinterface/widgets/utils_ipywidgets.py b/src/spikeinterface/widgets/utils_ipywidgets.py index 6d91140500..0d81a8184f 100644 --- a/src/spikeinterface/widgets/utils_ipywidgets.py +++ b/src/spikeinterface/widgets/utils_ipywidgets.py @@ -329,6 +329,51 @@ def value_changed(self, change=None): self.update_label() +class WidenNarrowWidget(W.VBox): + value = traitlets.Float() + + def __init__(self, value=1.0, factor=1.2, **kwargs): + assert factor > 1.0 + self.factor = factor + + self.scale_label = W.Label("Widen/Narrow", layout=W.Layout(width="95%", justify_content="center")) + + self.right_selector = W.Button( + description="", + disabled=False, + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Increase horizontal scale", + icon="arrow-right", + # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + layout=W.Layout(width="60%", align_self="center"), + ) + + self.left_selector = W.Button( + description="", + disabled=False, + button_style="", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Decrease horizontal scale", + icon="arrow-left", + # layout=W.Layout(width=f"{0.8 * width_cm}cm", height=f"{0.4 * height_cm}cm"), + layout=W.Layout(width="60%", align_self="center"), + ) + + self.right_selector.on_click(self.left_clicked) + self.left_selector.on_click(self.right_clicked) + + self.value = value + super(W.VBox, self).__init__( + children=[self.scale_label, W.HBox([self.left_selector, self.right_selector])], + **kwargs, + ) + + def left_clicked(self, change=None): + self.value = self.value / self.factor + + def right_clicked(self, change=None): + self.value = self.value * self.factor + + class UnitSelector(W.VBox): value = traitlets.List() From 345f5edd9491532f634385aee0e0d3117cb220c6 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 21 May 2024 16:08:03 +0200 Subject: [PATCH 5/5] Fix bug with analyzer --- src/spikeinterface/widgets/unit_waveforms.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/widgets/unit_waveforms.py b/src/spikeinterface/widgets/unit_waveforms.py index 2b3dc7ed34..f6e16abaae 100644 --- a/src/spikeinterface/widgets/unit_waveforms.py +++ b/src/spikeinterface/widgets/unit_waveforms.py @@ -541,9 +541,10 @@ def _update_plot(self, change): if self.sorting_analyzer is not None: templates = self.templates_ext.get_templates(unit_ids=unit_ids, operator="average") - templates_shadings = self._get_template_shadings(unit_ids, data_plot["templates_percentile_shading"]) + templates_shadings = self._get_template_shadings( + unit_ids, self.next_data_plot["templates_percentile_shading"] + ) channel_locations = self.sorting_analyzer.get_channel_locations() - else: unit_indices = [list(self.templates.unit_ids).index(unit_id) for unit_id in unit_ids] templates = self.templates.templates_array[unit_indices]