From 46c4ada52b95a7deeed4babf5bb40a9e775047d4 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 19 Sep 2023 14:45:53 +0200 Subject: [PATCH 1/7] Port plot_agreement_matrix to new widgets API --- .../widgets/_legacy_mpl_widgets/__init__.py | 2 +- .../_legacy_mpl_widgets/agreementmatrix.py | 91 ------------------- .../widgets/tests/test_widgets.py | 10 +- src/spikeinterface/widgets/widget_list.py | 3 + 4 files changed, 13 insertions(+), 93 deletions(-) delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/agreementmatrix.py diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index c0dcd7ea6e..045b8acc8e 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -17,7 +17,7 @@ # comparison related from .confusionmatrix import plot_confusion_matrix, ConfusionMatrixWidget -from .agreementmatrix import plot_agreement_matrix, AgreementMatrixWidget + from .multicompgraph import ( plot_multicomp_graph, MultiCompGraphWidget, diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/agreementmatrix.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/agreementmatrix.py deleted file mode 100644 index 369746e99b..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/agreementmatrix.py +++ /dev/null @@ -1,91 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class AgreementMatrixWidget(BaseWidget): - """ - Plots sorting comparison confusion matrix. - - Parameters - ---------- - sorting_comparison: GroundTruthComparison or SymmetricSortingComparison - The sorting comparison object. - Symetric or not. - ordered: bool - Order units with best agreement scores. - This enable to see agreement on a diagonal. - count_text: bool - If True counts are displayed as text - unit_ticks: bool - If True unit tick labels are displayed - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - """ - - def __init__(self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, figure=None, ax=None): - from matplotlib import pyplot as plt - - BaseWidget.__init__(self, figure, ax) - self._sc = sorting_comparison - self._ordered = ordered - self._count_text = count_text - self._unit_ticks = unit_ticks - self.name = "ConfusionMatrix" - - def plot(self): - self._do_plot() - - def _do_plot(self): - # a dataframe - if self._ordered: - scores = self._sc.get_ordered_agreement_scores() - else: - scores = self._sc.agreement_scores - - N1 = scores.shape[0] - N2 = scores.shape[1] - - unit_ids1 = scores.index.values - unit_ids2 = scores.columns.values - - # Using matshow here just because it sets the ticks up nicely. imshow is faster. - self.ax.matshow(scores.values, cmap="Greens") - - if self._count_text: - for i, u1 in enumerate(unit_ids1): - u2 = self._sc.best_match_12[u1] - if u2 != -1: - j = np.where(unit_ids2 == u2)[0][0] - - self.ax.text(j, i, "{:0.2f}".format(scores.at[u1, u2]), ha="center", va="center", color="white") - - # Major ticks - self.ax.set_xticks(np.arange(0, N2)) - self.ax.set_yticks(np.arange(0, N1)) - self.ax.xaxis.tick_bottom() - - # Labels for major ticks - if self._unit_ticks: - self.ax.set_yticklabels(scores.index, fontsize=12) - self.ax.set_xticklabels(scores.columns, fontsize=12) - - self.ax.set_xlabel(self._sc.name_list[1], fontsize=20) - self.ax.set_ylabel(self._sc.name_list[0], fontsize=20) - - self.ax.set_xlim(-0.5, N2 - 0.5) - self.ax.set_ylim( - N1 - 0.5, - -0.5, - ) - - -def plot_agreement_matrix(*args, **kwargs): - W = AgreementMatrixWidget(*args, **kwargs) - W.plot() - return W - - -plot_agreement_matrix.__doc__ = AgreementMatrixWidget.__doc__ diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index a5f75ebf50..2f11e5ee3c 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -324,6 +324,13 @@ def test_sorting_summary(self): sw.plot_sorting_summary(self.we, backend=backend, **self.backend_kwargs[backend]) sw.plot_sorting_summary(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) + def test_plot_agreement_matrix(self): + possible_backends = list(sw.AgreementMatrixWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_agreement_matrix(self.gt_comp) + + if __name__ == "__main__": # unittest.main() @@ -344,7 +351,8 @@ def test_sorting_summary(self): # mytest.test_unit_locations() # mytest.test_quality_metrics() # mytest.test_template_metrics() - mytest.test_amplitudes() + # mytest.test_amplitudes() + mytest.test_plot_agreement_matrix() # plt.ion() plt.show() diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 9c89b3981e..22b33e38aa 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -2,6 +2,7 @@ from .base import backend_kwargs_desc +from .agreement_matrix import AgreementMatrixWidget from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget from .amplitudes import AmplitudesWidget from .autocorrelograms import AutoCorrelogramsWidget @@ -23,6 +24,7 @@ widget_list = [ + AgreementMatrixWidget, AllAmplitudesDistributionsWidget, AmplitudesWidget, AutoCorrelogramsWidget, @@ -76,6 +78,7 @@ # make function for all widgets +plot_agreement_matrix = AgreementMatrixWidget plot_all_amplitudes_distributions = AllAmplitudesDistributionsWidget plot_amplitudes = AmplitudesWidget plot_autocorrelograms = AutoCorrelogramsWidget From e49071e38394c039d70cbc083c8b5a2cbb785b1b Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 19 Sep 2023 14:53:01 +0200 Subject: [PATCH 2/7] Port plot_confusion_matrix to new API. --- .../widgets/_legacy_mpl_widgets/__init__.py | 3 - .../_legacy_mpl_widgets/confusionmatrix.py | 91 ------------------- .../widgets/tests/test_widgets.py | 9 +- src/spikeinterface/widgets/widget_list.py | 3 + 4 files changed, 11 insertions(+), 95 deletions(-) delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index 045b8acc8e..6013512022 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -15,9 +15,6 @@ # units on probe from .unitprobemap import plot_unit_probe_map, UnitProbeMapWidget -# comparison related -from .confusionmatrix import plot_confusion_matrix, ConfusionMatrixWidget - from .multicompgraph import ( plot_multicomp_graph, MultiCompGraphWidget, diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py deleted file mode 100644 index 942b613fbf..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py +++ /dev/null @@ -1,91 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class ConfusionMatrixWidget(BaseWidget): - """ - Plots sorting comparison confusion matrix. - - Parameters - ---------- - gt_comparison: GroundTruthComparison - The ground truth sorting comparison object - count_text: bool - If True counts are displayed as text - unit_ticks: bool - If True unit tick labels are displayed - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: ConfusionMatrixWidget - The output widget - """ - - def __init__(self, gt_comparison, count_text=True, unit_ticks=True, figure=None, ax=None): - from matplotlib import pyplot as plt - - BaseWidget.__init__(self, figure, ax) - self._gtcomp = gt_comparison - self._count_text = count_text - self._unit_ticks = unit_ticks - self.name = "ConfusionMatrix" - - def plot(self): - self._do_plot() - - def _do_plot(self): - # a dataframe - confusion_matrix = self._gtcomp.get_confusion_matrix() - - N1 = confusion_matrix.shape[0] - 1 - N2 = confusion_matrix.shape[1] - 1 - - # Using matshow here just because it sets the ticks up nicely. imshow is faster. - self.ax.matshow(confusion_matrix.values, cmap="Greens") - - if self._count_text: - for (i, j), z in np.ndenumerate(confusion_matrix.values): - if z != 0: - if z > np.max(confusion_matrix.values) / 2.0: - self.ax.text(j, i, "{:d}".format(z), ha="center", va="center", color="white") - else: - self.ax.text(j, i, "{:d}".format(z), ha="center", va="center", color="black") - - self.ax.axhline(int(N1 - 1) + 0.5, color="black") - self.ax.axvline(int(N2 - 1) + 0.5, color="black") - - # Major ticks - self.ax.set_xticks(np.arange(0, N2 + 1)) - self.ax.set_yticks(np.arange(0, N1 + 1)) - self.ax.xaxis.tick_bottom() - - # Labels for major ticks - if self._unit_ticks: - self.ax.set_yticklabels(confusion_matrix.index, fontsize=12) - self.ax.set_xticklabels(confusion_matrix.columns, fontsize=12) - else: - self.ax.set_xticklabels(np.append([""] * N2, "FN"), fontsize=10) - self.ax.set_yticklabels(np.append([""] * N1, "FP"), fontsize=10) - - self.ax.set_xlabel(self._gtcomp.name_list[1], fontsize=20) - self.ax.set_ylabel(self._gtcomp.name_list[0], fontsize=20) - - self.ax.set_xlim(-0.5, N2 + 0.5) - self.ax.set_ylim( - N1 + 0.5, - -0.5, - ) - - -def plot_confusion_matrix(*args, **kwargs): - W = ConfusionMatrixWidget(*args, **kwargs) - W.plot() - return W - - -plot_confusion_matrix.__doc__ = ConfusionMatrixWidget.__doc__ diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 2f11e5ee3c..0aa309f748 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -330,6 +330,12 @@ def test_plot_agreement_matrix(self): if backend not in self.skip_backends: sw.plot_agreement_matrix(self.gt_comp) + def test_plot_confusion_matrix(self): + possible_backends = list(sw.AgreementMatrixWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_confusion_matrix(self.gt_comp) + if __name__ == "__main__": @@ -352,7 +358,8 @@ def test_plot_agreement_matrix(self): # mytest.test_quality_metrics() # mytest.test_template_metrics() # mytest.test_amplitudes() - mytest.test_plot_agreement_matrix() + # mytest.test_plot_agreement_matrix() + mytest.test_plot_confusion_matrix() # plt.ion() plt.show() diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 22b33e38aa..d02aa7de7a 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -6,6 +6,7 @@ from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget from .amplitudes import AmplitudesWidget from .autocorrelograms import AutoCorrelogramsWidget +from .confusion_matrix import ConfusionMatrixWidget from .crosscorrelograms import CrossCorrelogramsWidget from .motion import MotionWidget from .quality_metrics import QualityMetricsWidget @@ -28,6 +29,7 @@ AllAmplitudesDistributionsWidget, AmplitudesWidget, AutoCorrelogramsWidget, + ConfusionMatrixWidget, CrossCorrelogramsWidget, MotionWidget, QualityMetricsWidget, @@ -82,6 +84,7 @@ plot_all_amplitudes_distributions = AllAmplitudesDistributionsWidget plot_amplitudes = AmplitudesWidget plot_autocorrelograms = AutoCorrelogramsWidget +plot_confusion_matrix = ConfusionMatrixWidget plot_crosscorrelograms = CrossCorrelogramsWidget plot_motion = MotionWidget plot_quality_metrics = QualityMetricsWidget From 3d792951a6036849b5d82ea523bb6cc20e784a07 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 19 Sep 2023 17:09:32 +0200 Subject: [PATCH 3/7] port plot_probe_map() to new widgets API --- .../widgets/_legacy_mpl_widgets/__init__.py | 1 - .../widgets/_legacy_mpl_widgets/probemap.py | 77 ------------------- .../widgets/tests/test_widgets.py | 8 +- src/spikeinterface/widgets/widget_list.py | 3 + 4 files changed, 10 insertions(+), 79 deletions(-) delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index 6013512022..af1419fb11 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -1,7 +1,6 @@ # basics # from .timeseries import plot_timeseries, TracesWidget from .rasters import plot_rasters, RasterWidget -from .probemap import plot_probe_map, ProbeMapWidget # isi/ccg/acg from .isidistribution import plot_isi_distribution, ISIDistributionWidget diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py deleted file mode 100644 index 6e6578a4c4..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py +++ /dev/null @@ -1,77 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class ProbeMapWidget(BaseWidget): - """ - Plot the probe of a recording. - - Parameters - ---------- - recording: RecordingExtractor - The recording extractor object - channel_ids: list - The channel ids to display - with_channel_ids: bool False default - Add channel ids text on the probe - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - **plot_probe_kwargs: keyword arguments for probeinterface.plotting.plot_probe_group() function - - Returns - ------- - W: ProbeMapWidget - The output widget - """ - - def __init__(self, recording, channel_ids=None, with_channel_ids=False, figure=None, ax=None, **plot_probe_kwargs): - import matplotlib.pylab as plt - from probeinterface.plotting import plot_probe, get_auto_lims - - BaseWidget.__init__(self, figure, ax) - - if channel_ids is not None: - recording = recording.channel_slice(channel_ids) - self._recording = recording - self._probegroup = recording.get_probegroup() - self.with_channel_ids = with_channel_ids - self._plot_probe_kwargs = plot_probe_kwargs - - def plot(self): - self._do_plot() - - def _do_plot(self): - from probeinterface.plotting import get_auto_lims - - xlims, ylims, zlims = get_auto_lims(self._probegroup.probes[0]) - for i, probe in enumerate(self._probegroup.probes): - xlims2, ylims2, _ = get_auto_lims(probe) - xlims = min(xlims[0], xlims2[0]), max(xlims[1], xlims2[1]) - ylims = min(ylims[0], ylims2[0]), max(ylims[1], ylims2[1]) - - self._plot_probe_kwargs["title"] = False - pos = 0 - text_on_contact = None - for i, probe in enumerate(self._probegroup.probes): - n = probe.get_contact_count() - if self.with_channel_ids: - text_on_contact = self._recording.channel_ids[pos : pos + n] - pos += n - from probeinterface.plotting import plot_probe - - plot_probe(probe, ax=self.ax, text_on_contact=text_on_contact, **self._plot_probe_kwargs) - - self.ax.set_xlim(*xlims) - self.ax.set_ylim(*ylims) - - -def plot_probe_map(*args, **kwargs): - W = ProbeMapWidget(*args, **kwargs) - W.plot() - return W - - -plot_probe_map.__doc__ = ProbeMapWidget.__doc__ diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 0aa309f748..bc0ec68041 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -336,6 +336,11 @@ def test_plot_confusion_matrix(self): if backend not in self.skip_backends: sw.plot_confusion_matrix(self.gt_comp) + def test_plot_probe_map(self): + possible_backends = list(sw.ProbeMapWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_probe_map(self.recording, with_channel_ids=True, with_contact_id=True) if __name__ == "__main__": @@ -359,7 +364,8 @@ def test_plot_confusion_matrix(self): # mytest.test_template_metrics() # mytest.test_amplitudes() # mytest.test_plot_agreement_matrix() - mytest.test_plot_confusion_matrix() + # mytest.test_plot_confusion_matrix() + mytest.test_plot_probe_map() # plt.ion() plt.show() diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index d02aa7de7a..77db17029f 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -9,6 +9,7 @@ from .confusion_matrix import ConfusionMatrixWidget from .crosscorrelograms import CrossCorrelogramsWidget from .motion import MotionWidget +from .probe_map import ProbeMapWidget from .quality_metrics import QualityMetricsWidget from .sorting_summary import SortingSummaryWidget from .spike_locations import SpikeLocationsWidget @@ -32,6 +33,7 @@ ConfusionMatrixWidget, CrossCorrelogramsWidget, MotionWidget, + ProbeMapWidget, QualityMetricsWidget, SortingSummaryWidget, SpikeLocationsWidget, @@ -87,6 +89,7 @@ plot_confusion_matrix = ConfusionMatrixWidget plot_crosscorrelograms = CrossCorrelogramsWidget plot_motion = MotionWidget +plot_probe_map = ProbeMapWidget plot_quality_metrics = QualityMetricsWidget plot_sorting_summary = SortingSummaryWidget plot_spike_locations = SpikeLocationsWidget From 45012894a558a59903e7b87f235d5f85f7637711 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 19 Sep 2023 18:41:27 +0200 Subject: [PATCH 4/7] Port plot_raster() to new API. --- .../widgets/_legacy_mpl_widgets/__init__.py | 4 - .../widgets/_legacy_mpl_widgets/rasters.py | 120 --------- .../tests/test_widgets_legacy.py | 48 +--- .../_legacy_mpl_widgets/timeseries_.py | 233 ------------------ .../widgets/tests/test_widgets.py | 10 +- src/spikeinterface/widgets/widget_list.py | 3 + 6 files changed, 13 insertions(+), 405 deletions(-) delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/rasters.py delete mode 100644 src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index af1419fb11..9593f14d1c 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -1,7 +1,3 @@ -# basics -# from .timeseries import plot_timeseries, TracesWidget -from .rasters import plot_rasters, RasterWidget - # isi/ccg/acg from .isidistribution import plot_isi_distribution, ISIDistributionWidget diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/rasters.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/rasters.py deleted file mode 100644 index d05373103e..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/rasters.py +++ /dev/null @@ -1,120 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class RasterWidget(BaseWidget): - """ - Plots spike train rasters. - - Parameters - ---------- - sorting: SortingExtractor - The sorting extractor object - segment_index: None or int - The segment index. - unit_ids: list - List of unit ids - time_range: list - List with start time and end time - color: matplotlib color - The color to be used - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: RasterWidget - The output widget - """ - - def __init__(self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k", figure=None, ax=None): - from matplotlib import pyplot as plt - - BaseWidget.__init__(self, figure, ax) - self._sorting = sorting - - if segment_index is None: - nseg = sorting.get_num_segments() - if nseg != 1: - raise ValueError("You must provide segment_index=...") - else: - segment_index = 0 - self.segment_index = segment_index - - self._unit_ids = unit_ids - self._figure = None - self._sampling_frequency = sorting.get_sampling_frequency() - self._color = color - self._max_frame = 0 - for unit_id in self._sorting.get_unit_ids(): - spike_train = self._sorting.get_unit_spike_train(unit_id, segment_index=self.segment_index) - if len(spike_train) > 0: - curr_max_frame = np.max(spike_train) - if curr_max_frame > self._max_frame: - self._max_frame = curr_max_frame - self._visible_trange = time_range - if self._visible_trange is None: - self._visible_trange = [0, self._max_frame] - else: - assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds" - self._visible_trange = [int(t * self._sampling_frequency) for t in time_range] - - self._visible_trange = self._fix_trange(self._visible_trange) - self.name = "Raster" - - def plot(self): - self._do_plot() - - def _do_plot(self): - units_ids = self._unit_ids - if units_ids is None: - units_ids = self._sorting.get_unit_ids() - import matplotlib.pyplot as plt - - with plt.rc_context({"axes.edgecolor": "gray"}): - for u_i, unit_id in enumerate(units_ids): - spiketrain = self._sorting.get_unit_spike_train( - unit_id, - start_frame=self._visible_trange[0], - end_frame=self._visible_trange[1], - segment_index=self.segment_index, - ) - spiketimes = spiketrain / float(self._sampling_frequency) - self.ax.plot( - spiketimes, - u_i * np.ones_like(spiketimes), - marker="|", - mew=1, - markersize=3, - ls="", - color=self._color, - ) - visible_start_frame = self._visible_trange[0] / self._sampling_frequency - visible_end_frame = self._visible_trange[1] / self._sampling_frequency - self.ax.set_yticks(np.arange(len(units_ids))) - self.ax.set_yticklabels(units_ids) - self.ax.set_xlim(visible_start_frame, visible_end_frame) - self.ax.set_xlabel("time (s)") - - def _fix_trange(self, trange): - if trange[1] > self._max_frame: - # trange[0] += max_t - trange[1] - trange[1] = self._max_frame - if trange[0] < 0: - # trange[1] += -trange[0] - trange[0] = 0 - # trange[0] = np.maximum(0, trange[0]) - # trange[1] = np.minimum(max_t, trange[1]) - return trange - - -def plot_rasters(*args, **kwargs): - W = RasterWidget(*args, **kwargs) - W.plot() - return W - - -plot_rasters.__doc__ = RasterWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py index 5004765251..defe10f0d4 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py @@ -43,43 +43,7 @@ def setUp(self): def tearDown(self): pass - # def test_timeseries(self): - # sw.plot_timeseries(self._rec, mode='auto') - # sw.plot_timeseries(self._rec, mode='line', show_channel_ids=True) - # sw.plot_timeseries(self._rec, mode='map', show_channel_ids=True) - # sw.plot_timeseries(self._rec, mode='map', show_channel_ids=True, order_channel_by_depth=True) - - def test_rasters(self): - sw.plot_rasters(self._sorting) - - def test_plot_probe_map(self): - sw.plot_probe_map(self._rec) - sw.plot_probe_map(self._rec, with_channel_ids=True) - - # TODO - # def test_spectrum(self): - # sw.plot_spectrum(self._rec) - - # TODO - # def test_spectrogram(self): - # sw.plot_spectrogram(self._rec, channel=0) - - # def test_unitwaveforms(self): - # w = sw.plot_unit_waveforms(self._we) - # unit_ids = self._sorting.unit_ids[:6] - # sw.plot_unit_waveforms(self._we, max_channels=5, unit_ids=unit_ids) - # sw.plot_unit_waveforms(self._we, radius_um=60, unit_ids=unit_ids) - - # def test_plot_unit_waveform_density_map(self): - # unit_ids = self._sorting.unit_ids[:3] - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, max_channels=4) - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, radius_um=50) - # - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, radius_um=25, same_axis=True) - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, max_channels=2, same_axis=True) - - # def test_unittemplates(self): - # sw.plot_unit_templates(self._we) + def test_plot_unit_probe_map(self): sw.plot_unit_probe_map(self._we, with_channel_ids=True) @@ -120,12 +84,6 @@ def test_plot_peak_activity_map(self): sw.plot_peak_activity_map(self._rec, with_channel_ids=True) sw.plot_peak_activity_map(self._rec, bin_duration_s=1.0) - def test_confusion(self): - sw.plot_confusion_matrix(self._gt_comp, count_text=True) - - def test_agreement(self): - sw.plot_agreement_matrix(self._gt_comp, count_text=True) - def test_multicomp_graph(self): msc = sc.compare_multiple_sorters([self._sorting, self._sorting, self._sorting]) sw.plot_multicomp_graph(msc, edge_cmap="viridis", node_cmap="rainbow", draw_labels=False) @@ -150,8 +108,6 @@ def test_sorting_performance(self): mytest.setUp() # ~ mytest.test_timeseries() - # ~ mytest.test_rasters() - mytest.test_plot_probe_map() # ~ mytest.test_unitwaveforms() # ~ mytest.test_plot_unit_waveform_density_map() # mytest.test_unittemplates() @@ -169,8 +125,6 @@ def test_sorting_performance(self): # ~ mytest.test_plot_drift_over_time() # ~ mytest.test_plot_peak_activity_map() - # mytest.test_confusion() - # mytest.test_agreement() # ~ mytest.test_multicomp_graph() #  mytest.test_sorting_performance() diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py deleted file mode 100644 index ab6fa2ace5..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py +++ /dev/null @@ -1,233 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt -from matplotlib.ticker import MaxNLocator -from .basewidget import BaseWidget - -import scipy.spatial - - -class TracesWidget(BaseWidget): - """ - Plots recording timeseries. - - Parameters - ---------- - recording: RecordingExtractor - The recording extractor object - segment_index: None or int - The segment index. - channel_ids: list - The channel ids to display. - order_channel_by_depth: boolean - Reorder channel by depth. - time_range: list - List with start time and end time - mode: 'line' or 'map' or 'auto' - 2 possible mode: - * 'line' : classical for low channel count - * 'map' : for high channel count use color heat map - * 'auto' : auto switch depending the channel count <32ch - cmap: str default 'RdBu' - matplotlib colormap used in mode 'map' - show_channel_ids: bool - Set yticks with channel ids - color_groups: bool - If True groups are plotted with different colors - color: matplotlib color, default: None - The color used to draw the traces. - clim: None or tupple - When mode='map' this control color lims - with_colorbar: bool default True - When mode='map' add colorbar - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: TracesWidget - The output widget - """ - - def __init__( - self, - recording, - segment_index=None, - channel_ids=None, - order_channel_by_depth=False, - time_range=None, - mode="auto", - cmap="RdBu", - show_channel_ids=False, - color_groups=False, - color=None, - clim=None, - with_colorbar=True, - figure=None, - ax=None, - **plot_kwargs, - ): - BaseWidget.__init__(self, figure, ax) - self.recording = recording - self._sampling_frequency = recording.get_sampling_frequency() - self.visible_channel_ids = channel_ids - self._plot_kwargs = plot_kwargs - - if segment_index is None: - nseg = recording.get_num_segments() - if nseg != 1: - raise ValueError("You must provide segment_index=...") - segment_index = 0 - self.segment_index = segment_index - - if self.visible_channel_ids is None: - self.visible_channel_ids = recording.get_channel_ids() - - if order_channel_by_depth: - locations = self.recording.get_channel_locations() - channel_inds = self.recording.ids_to_indices(self.visible_channel_ids) - locations = locations[channel_inds, :] - origin = np.array([np.max(locations[:, 0]), np.min(locations[:, 1])])[None, :] - dist = scipy.spatial.distance.cdist(locations, origin, metric="euclidean") - dist = dist[:, 0] - self.order = np.argsort(dist) - else: - self.order = None - - if channel_ids is None: - channel_ids = recording.get_channel_ids() - - fs = recording.get_sampling_frequency() - if time_range is None: - time_range = (0, 1.0) - time_range = np.array(time_range) - - assert mode in ("auto", "line", "map"), "Mode must be in auto/line/map" - if mode == "auto": - if len(channel_ids) <= 64: - mode = "line" - else: - mode = "map" - self.mode = mode - self.cmap = cmap - - self.show_channel_ids = show_channel_ids - - self._frame_range = (time_range * fs).astype("int64") - a_max = self.recording.get_num_frames(segment_index=self.segment_index) - self._frame_range = np.clip(self._frame_range, 0, a_max) - self._time_range = [e / fs for e in self._frame_range] - - self.clim = clim - self.with_colorbar = with_colorbar - - self._initialize_stats() - - # self._vspacing = self._mean_channel_std * 20 - self._vspacing = self._max_channel_amp * 1.5 - - if recording.get_channel_groups() is None: - color_groups = False - - self._color_groups = color_groups - self._color = color - if color_groups: - self._colors = [] - self._group_color_map = {} - all_groups = recording.get_channel_groups() - groups = np.unique(all_groups) - N = len(groups) - import colorsys - - HSV_tuples = [(x * 1.0 / N, 0.5, 0.5) for x in range(N)] - self._colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), HSV_tuples)) - color_idx = 0 - for group in groups: - self._group_color_map[group] = color_idx - color_idx += 1 - self.name = "TimeSeries" - - def plot(self): - self._do_plot() - - def _do_plot(self): - chunk0 = self.recording.get_traces( - segment_index=self.segment_index, - channel_ids=self.visible_channel_ids, - start_frame=self._frame_range[0], - end_frame=self._frame_range[1], - ) - if self.order is not None: - chunk0 = chunk0[:, self.order] - self.visible_channel_ids = np.array(self.visible_channel_ids)[self.order] - - ax = self.ax - - n = len(self.visible_channel_ids) - - if self.mode == "line": - ax.set_xlim( - self._frame_range[0] / self._sampling_frequency, self._frame_range[1] / self._sampling_frequency - ) - ax.set_ylim(-self._vspacing, self._vspacing * n) - ax.get_xaxis().set_major_locator(MaxNLocator(prune="both")) - ax.get_yaxis().set_ticks([]) - ax.set_xlabel("time (s)") - - self._plots = {} - self._plot_offsets = {} - offset0 = self._vspacing * (n - 1) - times = np.arange(self._frame_range[0], self._frame_range[1]) / self._sampling_frequency - for im, m in enumerate(self.visible_channel_ids): - self._plot_offsets[m] = offset0 - if self._color_groups: - group = self.recording.get_channel_groups(channel_ids=[m])[0] - group_color_idx = self._group_color_map[group] - color = self._colors[group_color_idx] - else: - color = self._color - self._plots[m] = ax.plot(times, self._plot_offsets[m] + chunk0[:, im], color=color, **self._plot_kwargs) - offset0 = offset0 - self._vspacing - - if self.show_channel_ids: - ax.set_yticks(np.arange(n) * self._vspacing) - ax.set_yticklabels([str(chan_id) for chan_id in self.visible_channel_ids[::-1]]) - - elif self.mode == "map": - extent = (self._time_range[0], self._time_range[1], 0, self.recording.get_num_channels()) - im = ax.imshow( - chunk0.T, interpolation="nearest", origin="upper", aspect="auto", extent=extent, cmap=self.cmap - ) - - if self.clim is None: - im.set_clim(-self._max_channel_amp, self._max_channel_amp) - else: - im.set_clim(*self.clim) - - if self.with_colorbar: - self.figure.colorbar(im, ax=ax) - - if self.show_channel_ids: - ax.set_yticks(np.arange(n) + 0.5) - ax.set_yticklabels([str(chan_id) for chan_id in self.visible_channel_ids[::-1]]) - - def _initialize_stats(self): - chunk0 = self.recording.get_traces( - segment_index=self.segment_index, - channel_ids=self.visible_channel_ids, - start_frame=self._frame_range[0], - end_frame=self._frame_range[1], - ) - - self._mean_channel_std = np.mean(np.std(chunk0, axis=0)) - self._max_channel_amp = np.max(np.max(np.abs(chunk0), axis=0)) - - -def plot_timeseries(*args, **kwargs): - W = TracesWidget(*args, **kwargs) - W.plot() - return W - - -plot_timeseries.__doc__ = TracesWidget.__doc__ diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index bc0ec68041..509194cb93 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -342,6 +342,13 @@ def test_plot_probe_map(self): if backend not in self.skip_backends: sw.plot_probe_map(self.recording, with_channel_ids=True, with_contact_id=True) + def test_plot_rasters(self): + possible_backends = list(sw.RasterWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_rasters(self.sorting) + + if __name__ == "__main__": # unittest.main() @@ -365,7 +372,8 @@ def test_plot_probe_map(self): # mytest.test_amplitudes() # mytest.test_plot_agreement_matrix() # mytest.test_plot_confusion_matrix() - mytest.test_plot_probe_map() + # mytest.test_plot_probe_map() + mytest.test_plot_rasters() # plt.ion() plt.show() diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 77db17029f..6ea2593432 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -11,6 +11,7 @@ from .motion import MotionWidget from .probe_map import ProbeMapWidget from .quality_metrics import QualityMetricsWidget +from .rasters import RasterWidget from .sorting_summary import SortingSummaryWidget from .spike_locations import SpikeLocationsWidget from .spikes_on_traces import SpikesOnTracesWidget @@ -35,6 +36,7 @@ MotionWidget, ProbeMapWidget, QualityMetricsWidget, + RasterWidget, SortingSummaryWidget, SpikeLocationsWidget, SpikesOnTracesWidget, @@ -91,6 +93,7 @@ plot_motion = MotionWidget plot_probe_map = ProbeMapWidget plot_quality_metrics = QualityMetricsWidget +plot_rasters = RasterWidget plot_sorting_summary = SortingSummaryWidget plot_spike_locations = SpikeLocationsWidget plot_spikes_on_traces = SpikesOnTracesWidget From 625ff5e35219d397215413bebdb4f64dac8f0707 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Tue, 19 Sep 2023 18:44:28 +0200 Subject: [PATCH 5/7] Oups. --- .../widgets/agreement_matrix.py | 91 ++++++++++++++++++ .../widgets/confusion_matrix.py | 83 ++++++++++++++++ src/spikeinterface/widgets/probe_map.py | 78 +++++++++++++++ src/spikeinterface/widgets/rasters.py | 95 +++++++++++++++++++ 4 files changed, 347 insertions(+) create mode 100644 src/spikeinterface/widgets/agreement_matrix.py create mode 100644 src/spikeinterface/widgets/confusion_matrix.py create mode 100644 src/spikeinterface/widgets/probe_map.py create mode 100644 src/spikeinterface/widgets/rasters.py diff --git a/src/spikeinterface/widgets/agreement_matrix.py b/src/spikeinterface/widgets/agreement_matrix.py new file mode 100644 index 0000000000..55f38f078b --- /dev/null +++ b/src/spikeinterface/widgets/agreement_matrix.py @@ -0,0 +1,91 @@ +import numpy as np +from warnings import warn + +from .base import BaseWidget, to_attr +from .utils import get_unit_colors + + + +class AgreementMatrixWidget(BaseWidget): + """ + Plot unit depths + + Parameters + ---------- + sorting_comparison: GroundTruthComparison or SymmetricSortingComparison + The sorting comparison object. + Symetric or not. + ordered: bool + Order units with best agreement scores. + This enable to see agreement on a diagonal. + count_text: bool + If True counts are displayed as text + unit_ticks: bool + If True unit tick labels are displayed + + """ + + def __init__( + self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, + backend=None, **backend_kwargs + ): + plot_data = dict( + sorting_comparison=sorting_comparison, + ordered=ordered, + count_text=count_text, + unit_ticks=unit_ticks, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + comp = dp.sorting_comparison + + if dp.ordered: + scores = comp.get_ordered_agreement_scores() + else: + scores = comp.agreement_scores + + N1 = scores.shape[0] + N2 = scores.shape[1] + + unit_ids1 = scores.index.values + unit_ids2 = scores.columns.values + + # Using matshow here just because it sets the ticks up nicely. imshow is faster. + self.ax.matshow(scores.values, cmap="Greens") + + if dp.count_text: + for i, u1 in enumerate(unit_ids1): + u2 = comp.best_match_12[u1] + if u2 != -1: + j = np.where(unit_ids2 == u2)[0][0] + + self.ax.text(j, i, "{:0.2f}".format(scores.at[u1, u2]), ha="center", va="center", color="white") + + # Major ticks + self.ax.set_xticks(np.arange(0, N2)) + self.ax.set_yticks(np.arange(0, N1)) + self.ax.xaxis.tick_bottom() + + # Labels for major ticks + if dp.unit_ticks: + self.ax.set_yticklabels(scores.index, fontsize=12) + self.ax.set_xticklabels(scores.columns, fontsize=12) + + self.ax.set_xlabel(comp.name_list[1], fontsize=20) + self.ax.set_ylabel(comp.name_list[0], fontsize=20) + + self.ax.set_xlim(-0.5, N2 - 0.5) + self.ax.set_ylim( + N1 - 0.5, + -0.5, + ) + + diff --git a/src/spikeinterface/widgets/confusion_matrix.py b/src/spikeinterface/widgets/confusion_matrix.py new file mode 100644 index 0000000000..da021092db --- /dev/null +++ b/src/spikeinterface/widgets/confusion_matrix.py @@ -0,0 +1,83 @@ +import numpy as np +from warnings import warn + +from .base import BaseWidget, to_attr +from .utils import get_unit_colors + + + +class ConfusionMatrixWidget(BaseWidget): + """ + Plot unit depths + + Parameters + ---------- + gt_comparison: GroundTruthComparison + The ground truth sorting comparison object + count_text: bool + If True counts are displayed as text + unit_ticks: bool + If True unit tick labels are displayed + + """ + + def __init__( + self, gt_comparison, count_text=True, unit_ticks=True, + backend=None, **backend_kwargs + ): + plot_data = dict( + gt_comparison=gt_comparison, + count_text=count_text, + unit_ticks=unit_ticks, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + comp = dp.gt_comparison + + confusion_matrix = comp.get_confusion_matrix() + N1 = confusion_matrix.shape[0] - 1 + N2 = confusion_matrix.shape[1] - 1 + + # Using matshow here just because it sets the ticks up nicely. imshow is faster. + self.ax.matshow(confusion_matrix.values, cmap="Greens") + + if dp.count_text: + for (i, j), z in np.ndenumerate(confusion_matrix.values): + if z != 0: + if z > np.max(confusion_matrix.values) / 2.0: + self.ax.text(j, i, "{:d}".format(z), ha="center", va="center", color="white") + else: + self.ax.text(j, i, "{:d}".format(z), ha="center", va="center", color="black") + + self.ax.axhline(int(N1 - 1) + 0.5, color="black") + self.ax.axvline(int(N2 - 1) + 0.5, color="black") + + # Major ticks + self.ax.set_xticks(np.arange(0, N2 + 1)) + self.ax.set_yticks(np.arange(0, N1 + 1)) + self.ax.xaxis.tick_bottom() + + # Labels for major ticks + if dp.unit_ticks: + self.ax.set_yticklabels(confusion_matrix.index, fontsize=12) + self.ax.set_xticklabels(confusion_matrix.columns, fontsize=12) + else: + self.ax.set_xticklabels(np.append([""] * N2, "FN"), fontsize=10) + self.ax.set_yticklabels(np.append([""] * N1, "FP"), fontsize=10) + + self.ax.set_xlabel(comp.name_list[1], fontsize=20) + self.ax.set_ylabel(comp.name_list[0], fontsize=20) + + self.ax.set_xlim(-0.5, N2 + 0.5) + self.ax.set_ylim( + N1 + 0.5, + -0.5, + ) \ No newline at end of file diff --git a/src/spikeinterface/widgets/probe_map.py b/src/spikeinterface/widgets/probe_map.py new file mode 100644 index 0000000000..193711a34f --- /dev/null +++ b/src/spikeinterface/widgets/probe_map.py @@ -0,0 +1,78 @@ +import numpy as np +from warnings import warn + +from .base import BaseWidget, to_attr, default_backend_kwargs +from .utils import get_unit_colors + + + +class ProbeMapWidget(BaseWidget): + """ + Plot the probe of a recording. + + Parameters + ---------- + recording: RecordingExtractor + The recording extractor object + channel_ids: list + The channel ids to display + with_channel_ids: bool False default + Add channel ids text on the probe + **plot_probe_kwargs: keyword arguments for probeinterface.plotting.plot_probe_group() function + + """ + + def __init__( + self, recording, channel_ids=None, with_channel_ids=False, + backend=None, **backend_or_plot_probe_kwargs + ): + + # split backend_or_plot_probe_kwargs + backend_kwargs = dict() + plot_probe_kwargs = dict() + backend = self.check_backend(backend) + for k, v in backend_or_plot_probe_kwargs.items(): + if k in default_backend_kwargs[backend]: + backend_kwargs[k] = v + else: + plot_probe_kwargs[k] = v + + plot_data = dict( + recording=recording, + channel_ids=channel_ids, + with_channel_ids=with_channel_ids, + plot_probe_kwargs=plot_probe_kwargs, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from probeinterface.plotting import get_auto_lims, plot_probe + + dp = to_attr(data_plot) + + plot_probe_kwargs = dp.plot_probe_kwargs + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + probegroup = dp.recording.get_probegroup() + + xlims, ylims, zlims = get_auto_lims(probegroup.probes[0]) + for i, probe in enumerate(probegroup.probes): + xlims2, ylims2, _ = get_auto_lims(probe) + xlims = min(xlims[0], xlims2[0]), max(xlims[1], xlims2[1]) + ylims = min(ylims[0], ylims2[0]), max(ylims[1], ylims2[1]) + + plot_probe_kwargs["title"] = False + pos = 0 + text_on_contact = None + for i, probe in enumerate(probegroup.probes): + n = probe.get_contact_count() + if dp.with_channel_ids: + text_on_contact = dp.recording.channel_ids[pos : pos + n] + pos += n + plot_probe(probe, ax=self.ax, text_on_contact=text_on_contact, **plot_probe_kwargs) + + self.ax.set_xlim(*xlims) + self.ax.set_ylim(*ylims) diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py new file mode 100644 index 0000000000..de855ebe45 --- /dev/null +++ b/src/spikeinterface/widgets/rasters.py @@ -0,0 +1,95 @@ +import numpy as np +from warnings import warn + +from .base import BaseWidget, to_attr, default_backend_kwargs + + + +class RasterWidget(BaseWidget): + """ + Plots spike train rasters. + + Parameters + ---------- + sorting: SortingExtractor + The sorting extractor object + segment_index: None or int + The segment index. + unit_ids: list + List of unit ids + time_range: list + List with start time and end time + color: matplotlib color + The color to be used + """ + + def __init__( + self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k", + backend=None, **backend_kwargs + ): + + + if segment_index is None: + if sorting.get_num_segments() != 1: + raise ValueError("You must provide segment_index=...") + segment_index = 0 + + if time_range is None: + frame_range = [0, sorting.to_spike_vector()[-1]["sample_index"]] + time_range = [f / sorting.sampling_frequency for f in frame_range] + else: + assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds" + frame_range = [int(t * sorting.sampling_frequency) for t in time_range] + + plot_data = dict( + sorting=sorting, + segment_index=segment_index, + unit_ids=unit_ids, + color=color, + frame_range=frame_range, + time_range=time_range, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + sorting = dp.sorting + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + units_ids = dp.unit_ids + if units_ids is None: + units_ids = sorting.unit_ids + + with plt.rc_context({"axes.edgecolor": "gray"}): + for unit_index, unit_id in enumerate(units_ids): + spiketrain = sorting.get_unit_spike_train( + unit_id, + start_frame=dp.frame_range[0], + end_frame=dp.frame_range[1], + segment_index=dp.segment_index, + ) + spiketimes = spiketrain / float(sorting.sampling_frequency) + self.ax.plot( + spiketimes, + unit_index * np.ones_like(spiketimes), + marker="|", + mew=1, + markersize=3, + ls="", + color=dp.color, + ) + self.ax.set_yticks(np.arange(len(units_ids))) + self.ax.set_yticklabels(units_ids) + self.ax.set_xlim(*dp.time_range) + self.ax.set_xlabel("time (s)") + + + + + + + From 84051d1515a444a3174a4642029ed02aa69d755e Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 21 Sep 2023 10:41:27 +0200 Subject: [PATCH 6/7] oups --- src/spikeinterface/widgets/agreement_matrix.py | 2 +- src/spikeinterface/widgets/confusion_matrix.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/widgets/agreement_matrix.py b/src/spikeinterface/widgets/agreement_matrix.py index 55f38f078b..22617f6be0 100644 --- a/src/spikeinterface/widgets/agreement_matrix.py +++ b/src/spikeinterface/widgets/agreement_matrix.py @@ -8,7 +8,7 @@ class AgreementMatrixWidget(BaseWidget): """ - Plot unit depths + Plots sorting comparison agreement matrix. Parameters ---------- diff --git a/src/spikeinterface/widgets/confusion_matrix.py b/src/spikeinterface/widgets/confusion_matrix.py index da021092db..b76283b421 100644 --- a/src/spikeinterface/widgets/confusion_matrix.py +++ b/src/spikeinterface/widgets/confusion_matrix.py @@ -8,7 +8,7 @@ class ConfusionMatrixWidget(BaseWidget): """ - Plot unit depths + Plots sorting comparison confusion matrix. Parameters ---------- From 85c7755f3a3c4a93117ecb7fb842309e00e22915 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 21 Sep 2023 08:43:30 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../tests/test_widgets_legacy.py | 2 -- src/spikeinterface/widgets/agreement_matrix.py | 8 ++------ src/spikeinterface/widgets/confusion_matrix.py | 10 +++------- src/spikeinterface/widgets/probe_map.py | 5 +---- src/spikeinterface/widgets/rasters.py | 15 ++------------- src/spikeinterface/widgets/tests/test_widgets.py | 3 +-- 6 files changed, 9 insertions(+), 34 deletions(-) diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py index defe10f0d4..39eb80e2e5 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py @@ -43,8 +43,6 @@ def setUp(self): def tearDown(self): pass - - def test_plot_unit_probe_map(self): sw.plot_unit_probe_map(self._we, with_channel_ids=True) sw.plot_unit_probe_map(self._we, animated=True) diff --git a/src/spikeinterface/widgets/agreement_matrix.py b/src/spikeinterface/widgets/agreement_matrix.py index 22617f6be0..ec6ea1c87c 100644 --- a/src/spikeinterface/widgets/agreement_matrix.py +++ b/src/spikeinterface/widgets/agreement_matrix.py @@ -5,7 +5,6 @@ from .utils import get_unit_colors - class AgreementMatrixWidget(BaseWidget): """ Plots sorting comparison agreement matrix. @@ -22,12 +21,11 @@ class AgreementMatrixWidget(BaseWidget): If True counts are displayed as text unit_ticks: bool If True unit tick labels are displayed - + """ def __init__( - self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, - backend=None, **backend_kwargs + self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, backend=None, **backend_kwargs ): plot_data = dict( sorting_comparison=sorting_comparison, @@ -87,5 +85,3 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): N1 - 0.5, -0.5, ) - - diff --git a/src/spikeinterface/widgets/confusion_matrix.py b/src/spikeinterface/widgets/confusion_matrix.py index b76283b421..8eb58f30b2 100644 --- a/src/spikeinterface/widgets/confusion_matrix.py +++ b/src/spikeinterface/widgets/confusion_matrix.py @@ -5,7 +5,6 @@ from .utils import get_unit_colors - class ConfusionMatrixWidget(BaseWidget): """ Plots sorting comparison confusion matrix. @@ -18,13 +17,10 @@ class ConfusionMatrixWidget(BaseWidget): If True counts are displayed as text unit_ticks: bool If True unit tick labels are displayed - + """ - def __init__( - self, gt_comparison, count_text=True, unit_ticks=True, - backend=None, **backend_kwargs - ): + def __init__(self, gt_comparison, count_text=True, unit_ticks=True, backend=None, **backend_kwargs): plot_data = dict( gt_comparison=gt_comparison, count_text=count_text, @@ -80,4 +76,4 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.ax.set_ylim( N1 + 0.5, -0.5, - ) \ No newline at end of file + ) diff --git a/src/spikeinterface/widgets/probe_map.py b/src/spikeinterface/widgets/probe_map.py index 193711a34f..7fb74abd7c 100644 --- a/src/spikeinterface/widgets/probe_map.py +++ b/src/spikeinterface/widgets/probe_map.py @@ -5,7 +5,6 @@ from .utils import get_unit_colors - class ProbeMapWidget(BaseWidget): """ Plot the probe of a recording. @@ -23,10 +22,8 @@ class ProbeMapWidget(BaseWidget): """ def __init__( - self, recording, channel_ids=None, with_channel_ids=False, - backend=None, **backend_or_plot_probe_kwargs + self, recording, channel_ids=None, with_channel_ids=False, backend=None, **backend_or_plot_probe_kwargs ): - # split backend_or_plot_probe_kwargs backend_kwargs = dict() plot_probe_kwargs = dict() diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py index de855ebe45..4a1d76279f 100644 --- a/src/spikeinterface/widgets/rasters.py +++ b/src/spikeinterface/widgets/rasters.py @@ -4,7 +4,6 @@ from .base import BaseWidget, to_attr, default_backend_kwargs - class RasterWidget(BaseWidget): """ Plots spike train rasters. @@ -24,16 +23,13 @@ class RasterWidget(BaseWidget): """ def __init__( - self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k", - backend=None, **backend_kwargs + self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k", backend=None, **backend_kwargs ): - - if segment_index is None: if sorting.get_num_segments() != 1: raise ValueError("You must provide segment_index=...") segment_index = 0 - + if time_range is None: frame_range = [0, sorting.to_spike_vector()[-1]["sample_index"]] time_range = [f / sorting.sampling_frequency for f in frame_range] @@ -86,10 +82,3 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.ax.set_yticklabels(units_ids) self.ax.set_xlim(*dp.time_range) self.ax.set_xlabel("time (s)") - - - - - - - diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 509194cb93..2c583391c3 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -349,7 +349,6 @@ def test_plot_rasters(self): sw.plot_rasters(self.sorting) - if __name__ == "__main__": # unittest.main() @@ -371,7 +370,7 @@ def test_plot_rasters(self): # mytest.test_template_metrics() # mytest.test_amplitudes() # mytest.test_plot_agreement_matrix() - # mytest.test_plot_confusion_matrix() + # mytest.test_plot_confusion_matrix() # mytest.test_plot_probe_map() mytest.test_plot_rasters()