Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add subwidget parameters for UnitSummaryWidget #3242

Merged
Merged
44 changes: 37 additions & 7 deletions src/spikeinterface/widgets/unit_summary.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from collections import defaultdict

import numpy as np

Expand All @@ -17,19 +18,27 @@ class UnitSummaryWidget(BaseWidget):
"""
Plot a unit summary.

If amplitudes are alreday computed they are displayed.
If amplitudes are alreday computed, they are displayed.

Parameters
----------
sorting_analyzer : SortingAnalyzer
sorting_analyzer: SortingAnalyzer
florian6973 marked this conversation as resolved.
Show resolved Hide resolved
The SortingAnalyzer object
unit_id : int or str
unit_id: int or str
The unit id to plot the summary of
unit_colors : dict or None, default: None
unit_colors: dict or None, default: None
If given, a dictionary with unit ids as keys and colors as values,
sparsity : ChannelSparsity or None, default: None
sparsity: ChannelSparsity or None, default: None
Optional ChannelSparsity to apply.
If SortingAnalyzer is already sparse, the argument is ignored
subwidget_kwargs: dict or None, default: None
Parameters for the subwidgets in a nested dictionary
unit_locations: UnitLocationsWidget (see UnitLocationsWidget for details)
unit_waveforms: UnitWaveformsWidget (see UnitWaveformsWidget for details)
unit_waveform_density_map: UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details)
autocorrelograms: AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details)
amplitudes: AmplitudesWidget (see AmplitudesWidget for details)
Please note that the unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary.
"""

# possible_backends = {}
Expand All @@ -40,21 +49,29 @@ def __init__(
unit_id,
unit_colors=None,
sparsity=None,
radius_um=100,
subwidget_kwargs=None,
backend=None,
**backend_kwargs,
):

sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer)

if unit_colors is None:
unit_colors = get_unit_colors(sorting_analyzer)

if subwidget_kwargs is None:
subwidget_kwargs = dict()
for kwargs in subwidget_kwargs.values():
if "unit_colors" in kwargs:
raise ValueError(
"unit_colors should not be set in subwidget_kwargs, but directly as a parameter of plot_unit_summary"
)

plot_data = dict(
sorting_analyzer=sorting_analyzer,
unit_id=unit_id,
unit_colors=unit_colors,
sparsity=sparsity,
subwidget_kwargs=subwidget_kwargs,
)

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
Expand All @@ -70,6 +87,14 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
unit_colors = dp.unit_colors
sparsity = dp.sparsity

# defaultdict returns empty dict if key not found in subwidget_kwargs
subwidget_kwargs = defaultdict(lambda: dict(), dp.subwidget_kwargs)
unitlocationswidget_kwargs = subwidget_kwargs["unit_locations"]
unitwaveformswidget_kwargs = subwidget_kwargs["unit_waveforms"]
unitwaveformdensitymapwidget_kwargs = subwidget_kwargs["unit_waveform_density_map"]
autocorrelogramswidget_kwargs = subwidget_kwargs["autocorrelograms"]
amplitudeswidget_kwargs = subwidget_kwargs["amplitudes"]

# force the figure without axes
if "figsize" not in backend_kwargs:
backend_kwargs["figsize"] = (18, 7)
Expand Down Expand Up @@ -99,6 +124,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
plot_legend=False,
backend="matplotlib",
ax=ax1,
**unitlocationswidget_kwargs,
)

unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit")
Expand All @@ -121,6 +147,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
sparsity=sparsity,
backend="matplotlib",
ax=ax2,
**unitwaveformswidget_kwargs,
)

ax2.set_title(None)
Expand All @@ -134,6 +161,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
same_axis=False,
backend="matplotlib",
ax=ax3,
**unitwaveformdensitymapwidget_kwargs,
)
ax3.set_ylabel(None)

Expand All @@ -145,6 +173,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
unit_colors=unit_colors,
backend="matplotlib",
ax=ax4,
**autocorrelogramswidget_kwargs,
)

ax4.set_title(None)
Expand All @@ -162,6 +191,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
plot_histograms=True,
backend="matplotlib",
axes=axes,
**amplitudeswidget_kwargs,
)

fig.suptitle(f"unit_id: {dp.unit_id}")