Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add subwidget parameters for UnitSummaryWidget #3242

Merged
27 changes: 25 additions & 2 deletions src/spikeinterface/widgets/unit_summary.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from collections import defaultdict

import numpy as np

Expand Down Expand Up @@ -30,6 +31,13 @@ class UnitSummaryWidget(BaseWidget):
sparsity : ChannelSparsity or None, default: None
Optional ChannelSparsity to apply.
If SortingAnalyzer is already sparse, the argument is ignored
widget_params : dict or None, default: None
Parameters for the subwidgets in a nested dictionary
unitlocations_params: UnitLocationsWidget (see UnitLocationsWidget for details)
unitwaveforms_params: UnitWaveformsWidget (see UnitWaveformsWidget for details)
unitwaveformdensitymap_params : UnitWaveformDensityMapWidget (see UnitWaveformDensityMapWidget for details)
autocorrelograms_params : AutoCorrelogramsWidget (see AutoCorrelogramsWidget for details)
amplitudes_params : AmplitudesWidget (see AmplitudesWidget for details)
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
"""

# possible_backends = {}
Expand All @@ -40,21 +48,24 @@ def __init__(
unit_id,
unit_colors=None,
sparsity=None,
radius_um=100,
widget_params=None,
backend=None,
**backend_kwargs,
):

sorting_analyzer = self.ensure_sorting_analyzer(sorting_analyzer)

if unit_colors is None:
unit_colors = get_unit_colors(sorting_analyzer)

if widget_params is None:
widget_params = dict()

plot_data = dict(
sorting_analyzer=sorting_analyzer,
unit_id=unit_id,
unit_colors=unit_colors,
sparsity=sparsity,
widget_params=widget_params,
)

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
Expand All @@ -70,6 +81,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
unit_colors = dp.unit_colors
sparsity = dp.sparsity

widget_params = defaultdict(lambda: dict(), dp.widget_params)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cool! I did not know about that defaultdict. Might be adding a quick comment like # Returns an empty dict if the key is not in passed in dp.widget_params, but maybe this is quite widely known and doesn't need the comment

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just added a comment :)

unitlocationswidget_params = widget_params["unitlocations_params"]
unitwaveformswidget_params = widget_params["unitwaveforms_params"]
unitwaveformdensitymapwidget_params = widget_params["unitwaveformdensitymap_params"]
autocorrelogramswidget_params = widget_params["autocorrelograms_params"]
amplitudeswidget_params = widget_params["amplitudes_params"]

# force the figure without axes
if "figsize" not in backend_kwargs:
backend_kwargs["figsize"] = (18, 7)
Expand Down Expand Up @@ -99,6 +117,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
plot_legend=False,
backend="matplotlib",
ax=ax1,
**unitlocationswidget_params,
)

unit_locations = sorting_analyzer.get_extension("unit_locations").get_data(outputs="by_unit")
Expand All @@ -121,6 +140,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
sparsity=sparsity,
backend="matplotlib",
ax=ax2,
**unitwaveformswidget_params,
)

ax2.set_title(None)
Expand All @@ -134,6 +154,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
same_axis=False,
backend="matplotlib",
ax=ax3,
**unitwaveformdensitymapwidget_params,
)
ax3.set_ylabel(None)

Expand All @@ -145,6 +166,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
unit_colors=unit_colors,
backend="matplotlib",
ax=ax4,
**autocorrelogramswidget_params,
)

ax4.set_title(None)
Expand All @@ -162,6 +184,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
plot_histograms=True,
backend="matplotlib",
axes=axes,
**amplitudeswidget_params,
)

fig.suptitle(f"unit_id: {dp.unit_id}")