From 3d792951a6036849b5d82ea523bb6cc20e784a07 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 19 Sep 2023 17:09:32 +0200 Subject: [PATCH] port plot_probe_map() to new widgets API --- .../widgets/_legacy_mpl_widgets/__init__.py | 1 - .../widgets/_legacy_mpl_widgets/probemap.py | 77 ------------------- .../widgets/tests/test_widgets.py | 8 +- src/spikeinterface/widgets/widget_list.py | 3 + 4 files changed, 10 insertions(+), 79 deletions(-) delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index 6013512022..af1419fb11 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -1,7 +1,6 @@ # basics # from .timeseries import plot_timeseries, TracesWidget from .rasters import plot_rasters, RasterWidget -from .probemap import plot_probe_map, ProbeMapWidget # isi/ccg/acg from .isidistribution import plot_isi_distribution, ISIDistributionWidget diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py deleted file mode 100644 index 6e6578a4c4..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py +++ /dev/null @@ -1,77 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class ProbeMapWidget(BaseWidget): - """ - Plot the probe of a recording. - - Parameters - ---------- - recording: RecordingExtractor - The recording extractor object - channel_ids: list - The channel ids to display - with_channel_ids: bool False default - Add channel ids text on the probe - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - **plot_probe_kwargs: keyword arguments for probeinterface.plotting.plot_probe_group() function - - Returns - ------- - W: ProbeMapWidget - The output widget - """ - - def __init__(self, recording, channel_ids=None, with_channel_ids=False, figure=None, ax=None, **plot_probe_kwargs): - import matplotlib.pylab as plt - from probeinterface.plotting import plot_probe, get_auto_lims - - BaseWidget.__init__(self, figure, ax) - - if channel_ids is not None: - recording = recording.channel_slice(channel_ids) - self._recording = recording - self._probegroup = recording.get_probegroup() - self.with_channel_ids = with_channel_ids - self._plot_probe_kwargs = plot_probe_kwargs - - def plot(self): - self._do_plot() - - def _do_plot(self): - from probeinterface.plotting import get_auto_lims - - xlims, ylims, zlims = get_auto_lims(self._probegroup.probes[0]) - for i, probe in enumerate(self._probegroup.probes): - xlims2, ylims2, _ = get_auto_lims(probe) - xlims = min(xlims[0], xlims2[0]), max(xlims[1], xlims2[1]) - ylims = min(ylims[0], ylims2[0]), max(ylims[1], ylims2[1]) - - self._plot_probe_kwargs["title"] = False - pos = 0 - text_on_contact = None - for i, probe in enumerate(self._probegroup.probes): - n = probe.get_contact_count() - if self.with_channel_ids: - text_on_contact = self._recording.channel_ids[pos : pos + n] - pos += n - from probeinterface.plotting import plot_probe - - plot_probe(probe, ax=self.ax, text_on_contact=text_on_contact, **self._plot_probe_kwargs) - - self.ax.set_xlim(*xlims) - self.ax.set_ylim(*ylims) - - -def plot_probe_map(*args, **kwargs): - W = ProbeMapWidget(*args, **kwargs) - W.plot() - return W - - -plot_probe_map.__doc__ = ProbeMapWidget.__doc__ diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 0aa309f748..bc0ec68041 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -336,6 +336,11 @@ def test_plot_confusion_matrix(self): if backend not in self.skip_backends: sw.plot_confusion_matrix(self.gt_comp) + def test_plot_probe_map(self): + possible_backends = list(sw.ProbeMapWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_probe_map(self.recording, with_channel_ids=True, with_contact_id=True) if __name__ == "__main__": @@ -359,7 +364,8 @@ def test_plot_confusion_matrix(self): # mytest.test_template_metrics() # mytest.test_amplitudes() # mytest.test_plot_agreement_matrix() - mytest.test_plot_confusion_matrix() + # mytest.test_plot_confusion_matrix() + mytest.test_plot_probe_map() # plt.ion() plt.show() diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index d02aa7de7a..77db17029f 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -9,6 +9,7 @@ from .confusion_matrix import ConfusionMatrixWidget from .crosscorrelograms import CrossCorrelogramsWidget from .motion import MotionWidget +from .probe_map import ProbeMapWidget from .quality_metrics import QualityMetricsWidget from .sorting_summary import SortingSummaryWidget from .spike_locations import SpikeLocationsWidget @@ -32,6 +33,7 @@ ConfusionMatrixWidget, CrossCorrelogramsWidget, MotionWidget, + ProbeMapWidget, QualityMetricsWidget, SortingSummaryWidget, SpikeLocationsWidget, @@ -87,6 +89,7 @@ plot_confusion_matrix = ConfusionMatrixWidget plot_crosscorrelograms = CrossCorrelogramsWidget plot_motion = MotionWidget +plot_probe_map = ProbeMapWidget plot_quality_metrics = QualityMetricsWidget plot_sorting_summary = SortingSummaryWidget plot_spike_locations = SpikeLocationsWidget