diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 0b2a348edf..755e60ccbf 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -1,4 +1,5 @@ from __future__ import annotations +from collections import defaultdict import numpy as np @@ -17,7 +18,7 @@ class UnitSummaryWidget(BaseWidget): """ Plot a unit summary. - If amplitudes are alreday computed they are displayed. + If amplitudes are alreday computed, they are displayed. Parameters ---------- @@ -30,6 +31,14 @@ class UnitSummaryWidget(BaseWidget): sparsity : ChannelSparsity or None, default: None Optional ChannelSparsity to apply. If SortingAnalyzer is already sparse, the argument is ignored + subwidget_kwargs : dict or None, default: None + Parameters for the subwidgets in a nested dictionary + unit_locations : UnitLocationsWidget (see UnitLocationsWidget for details) + unit_waveforms : UnitWaveformsWidget (see UnitWaveformsWidget for details) + unit_waveform_density_map : UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details) + autocorrelograms : AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details) + amplitudes : AmplitudesWidget (see AmplitudesWidget for details) + Please note that the unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary. """ # possible_backends = {} @@ -40,21 +49,29 @@ def __init__( unit_id, unit_colors=None, sparsity=None, - radius_um=100, + subwidget_kwargs=None, backend=None, **backend_kwargs, ): - sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer) if unit_colors is None: unit_colors = get_unit_colors(sorting_analyzer) + if subwidget_kwargs is None: + subwidget_kwargs = dict() + for kwargs in subwidget_kwargs.values(): + if "unit_colors" in kwargs: + raise ValueError( + "unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary" + ) + plot_data = dict( sorting_analyzer=sorting_analyzer, unit_id=unit_id, unit_colors=unit_colors, sparsity=sparsity, + subwidget_kwargs=subwidget_kwargs, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -70,6 +87,14 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_colors = dp.unit_colors sparsity = dp.sparsity + # defaultdict returns empty dict if key not found in subwidget_kwargs + subwidget_kwargs = defaultdict(lambda: dict(), dp.subwidget_kwargs) + unitlocationswidget_kwargs = subwidget_kwargs["unit_locations"] + unitwaveformswidget_kwargs = subwidget_kwargs["unit_waveforms"] + unitwaveformdensitymapwidget_kwargs = subwidget_kwargs["unit_waveform_density_map"] + autocorrelogramswidget_kwargs = subwidget_kwargs["autocorrelograms"] + amplitudeswidget_kwargs = subwidget_kwargs["amplitudes"] + # force the figure without axes if "figsize" not in backend_kwargs: backend_kwargs["figsize"] = (18, 7) @@ -99,6 +124,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_legend=False, backend="matplotlib", ax=ax1, + **unitlocationswidget_kwargs, ) unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit") @@ -121,6 +147,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sparsity=sparsity, backend="matplotlib", ax=ax2, + **unitwaveformswidget_kwargs, ) ax2.set_title(None) @@ -134,6 +161,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): same_axis=False, backend="matplotlib", ax=ax3, + **unitwaveformdensitymapwidget_kwargs, ) ax3.set_ylabel(None) @@ -145,6 +173,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): unit_colors=unit_colors, backend="matplotlib", ax=ax4, + **autocorrelogramswidget_kwargs, ) ax4.set_title(None) @@ -162,6 +191,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): plot_histograms=True, backend="matplotlib", axes=axes, + **amplitudeswidget_kwargs, ) fig.suptitle(f"unit_id: {dp.unit_id}")