diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index c0dcd7ea6e..9593f14d1c 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -1,8 +1,3 @@ -# basics -# from .timeseries import plot_timeseries, TracesWidget -from .rasters import plot_rasters, RasterWidget -from .probemap import plot_probe_map, ProbeMapWidget - # isi/ccg/acg from .isidistribution import plot_isi_distribution, ISIDistributionWidget @@ -15,9 +10,6 @@ # units on probe from .unitprobemap import plot_unit_probe_map, UnitProbeMapWidget -# comparison related -from .confusionmatrix import plot_confusion_matrix, ConfusionMatrixWidget -from .agreementmatrix import plot_agreement_matrix, AgreementMatrixWidget from .multicompgraph import ( plot_multicomp_graph, MultiCompGraphWidget, diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py deleted file mode 100644 index 6e6578a4c4..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/probemap.py +++ /dev/null @@ -1,77 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class ProbeMapWidget(BaseWidget): - """ - Plot the probe of a recording. - - Parameters - ---------- - recording: RecordingExtractor - The recording extractor object - channel_ids: list - The channel ids to display - with_channel_ids: bool False default - Add channel ids text on the probe - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - **plot_probe_kwargs: keyword arguments for probeinterface.plotting.plot_probe_group() function - - Returns - ------- - W: ProbeMapWidget - The output widget - """ - - def __init__(self, recording, channel_ids=None, with_channel_ids=False, figure=None, ax=None, **plot_probe_kwargs): - import matplotlib.pylab as plt - from probeinterface.plotting import plot_probe, get_auto_lims - - BaseWidget.__init__(self, figure, ax) - - if channel_ids is not None: - recording = recording.channel_slice(channel_ids) - self._recording = recording - self._probegroup = recording.get_probegroup() - self.with_channel_ids = with_channel_ids - self._plot_probe_kwargs = plot_probe_kwargs - - def plot(self): - self._do_plot() - - def _do_plot(self): - from probeinterface.plotting import get_auto_lims - - xlims, ylims, zlims = get_auto_lims(self._probegroup.probes[0]) - for i, probe in enumerate(self._probegroup.probes): - xlims2, ylims2, _ = get_auto_lims(probe) - xlims = min(xlims[0], xlims2[0]), max(xlims[1], xlims2[1]) - ylims = min(ylims[0], ylims2[0]), max(ylims[1], ylims2[1]) - - self._plot_probe_kwargs["title"] = False - pos = 0 - text_on_contact = None - for i, probe in enumerate(self._probegroup.probes): - n = probe.get_contact_count() - if self.with_channel_ids: - text_on_contact = self._recording.channel_ids[pos : pos + n] - pos += n - from probeinterface.plotting import plot_probe - - plot_probe(probe, ax=self.ax, text_on_contact=text_on_contact, **self._plot_probe_kwargs) - - self.ax.set_xlim(*xlims) - self.ax.set_ylim(*ylims) - - -def plot_probe_map(*args, **kwargs): - W = ProbeMapWidget(*args, **kwargs) - W.plot() - return W - - -plot_probe_map.__doc__ = ProbeMapWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/rasters.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/rasters.py deleted file mode 100644 index d05373103e..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/rasters.py +++ /dev/null @@ -1,120 +0,0 @@ -import numpy as np - -from .basewidget import BaseWidget - - -class RasterWidget(BaseWidget): - """ - Plots spike train rasters. - - Parameters - ---------- - sorting: SortingExtractor - The sorting extractor object - segment_index: None or int - The segment index. - unit_ids: list - List of unit ids - time_range: list - List with start time and end time - color: matplotlib color - The color to be used - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: RasterWidget - The output widget - """ - - def __init__(self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k", figure=None, ax=None): - from matplotlib import pyplot as plt - - BaseWidget.__init__(self, figure, ax) - self._sorting = sorting - - if segment_index is None: - nseg = sorting.get_num_segments() - if nseg != 1: - raise ValueError("You must provide segment_index=...") - else: - segment_index = 0 - self.segment_index = segment_index - - self._unit_ids = unit_ids - self._figure = None - self._sampling_frequency = sorting.get_sampling_frequency() - self._color = color - self._max_frame = 0 - for unit_id in self._sorting.get_unit_ids(): - spike_train = self._sorting.get_unit_spike_train(unit_id, segment_index=self.segment_index) - if len(spike_train) > 0: - curr_max_frame = np.max(spike_train) - if curr_max_frame > self._max_frame: - self._max_frame = curr_max_frame - self._visible_trange = time_range - if self._visible_trange is None: - self._visible_trange = [0, self._max_frame] - else: - assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds" - self._visible_trange = [int(t * self._sampling_frequency) for t in time_range] - - self._visible_trange = self._fix_trange(self._visible_trange) - self.name = "Raster" - - def plot(self): - self._do_plot() - - def _do_plot(self): - units_ids = self._unit_ids - if units_ids is None: - units_ids = self._sorting.get_unit_ids() - import matplotlib.pyplot as plt - - with plt.rc_context({"axes.edgecolor": "gray"}): - for u_i, unit_id in enumerate(units_ids): - spiketrain = self._sorting.get_unit_spike_train( - unit_id, - start_frame=self._visible_trange[0], - end_frame=self._visible_trange[1], - segment_index=self.segment_index, - ) - spiketimes = spiketrain / float(self._sampling_frequency) - self.ax.plot( - spiketimes, - u_i * np.ones_like(spiketimes), - marker="|", - mew=1, - markersize=3, - ls="", - color=self._color, - ) - visible_start_frame = self._visible_trange[0] / self._sampling_frequency - visible_end_frame = self._visible_trange[1] / self._sampling_frequency - self.ax.set_yticks(np.arange(len(units_ids))) - self.ax.set_yticklabels(units_ids) - self.ax.set_xlim(visible_start_frame, visible_end_frame) - self.ax.set_xlabel("time (s)") - - def _fix_trange(self, trange): - if trange[1] > self._max_frame: - # trange[0] += max_t - trange[1] - trange[1] = self._max_frame - if trange[0] < 0: - # trange[1] += -trange[0] - trange[0] = 0 - # trange[0] = np.maximum(0, trange[0]) - # trange[1] = np.minimum(max_t, trange[1]) - return trange - - -def plot_rasters(*args, **kwargs): - W = RasterWidget(*args, **kwargs) - W.plot() - return W - - -plot_rasters.__doc__ = RasterWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py index 5004765251..39eb80e2e5 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/tests/test_widgets_legacy.py @@ -43,44 +43,6 @@ def setUp(self): def tearDown(self): pass - # def test_timeseries(self): - # sw.plot_timeseries(self._rec, mode='auto') - # sw.plot_timeseries(self._rec, mode='line', show_channel_ids=True) - # sw.plot_timeseries(self._rec, mode='map', show_channel_ids=True) - # sw.plot_timeseries(self._rec, mode='map', show_channel_ids=True, order_channel_by_depth=True) - - def test_rasters(self): - sw.plot_rasters(self._sorting) - - def test_plot_probe_map(self): - sw.plot_probe_map(self._rec) - sw.plot_probe_map(self._rec, with_channel_ids=True) - - # TODO - # def test_spectrum(self): - # sw.plot_spectrum(self._rec) - - # TODO - # def test_spectrogram(self): - # sw.plot_spectrogram(self._rec, channel=0) - - # def test_unitwaveforms(self): - # w = sw.plot_unit_waveforms(self._we) - # unit_ids = self._sorting.unit_ids[:6] - # sw.plot_unit_waveforms(self._we, max_channels=5, unit_ids=unit_ids) - # sw.plot_unit_waveforms(self._we, radius_um=60, unit_ids=unit_ids) - - # def test_plot_unit_waveform_density_map(self): - # unit_ids = self._sorting.unit_ids[:3] - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, max_channels=4) - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, radius_um=50) - # - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, radius_um=25, same_axis=True) - # sw.plot_unit_waveform_density_map(self._we, unit_ids=unit_ids, max_channels=2, same_axis=True) - - # def test_unittemplates(self): - # sw.plot_unit_templates(self._we) - def test_plot_unit_probe_map(self): sw.plot_unit_probe_map(self._we, with_channel_ids=True) sw.plot_unit_probe_map(self._we, animated=True) @@ -120,12 +82,6 @@ def test_plot_peak_activity_map(self): sw.plot_peak_activity_map(self._rec, with_channel_ids=True) sw.plot_peak_activity_map(self._rec, bin_duration_s=1.0) - def test_confusion(self): - sw.plot_confusion_matrix(self._gt_comp, count_text=True) - - def test_agreement(self): - sw.plot_agreement_matrix(self._gt_comp, count_text=True) - def test_multicomp_graph(self): msc = sc.compare_multiple_sorters([self._sorting, self._sorting, self._sorting]) sw.plot_multicomp_graph(msc, edge_cmap="viridis", node_cmap="rainbow", draw_labels=False) @@ -150,8 +106,6 @@ def test_sorting_performance(self): mytest.setUp() # ~ mytest.test_timeseries() - # ~ mytest.test_rasters() - mytest.test_plot_probe_map() # ~ mytest.test_unitwaveforms() # ~ mytest.test_plot_unit_waveform_density_map() # mytest.test_unittemplates() @@ -169,8 +123,6 @@ def test_sorting_performance(self): # ~ mytest.test_plot_drift_over_time() # ~ mytest.test_plot_peak_activity_map() - # mytest.test_confusion() - # mytest.test_agreement() # ~ mytest.test_multicomp_graph() #  mytest.test_sorting_performance() diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py deleted file mode 100644 index ab6fa2ace5..0000000000 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/timeseries_.py +++ /dev/null @@ -1,233 +0,0 @@ -import numpy as np -from matplotlib import pyplot as plt -from matplotlib.ticker import MaxNLocator -from .basewidget import BaseWidget - -import scipy.spatial - - -class TracesWidget(BaseWidget): - """ - Plots recording timeseries. - - Parameters - ---------- - recording: RecordingExtractor - The recording extractor object - segment_index: None or int - The segment index. - channel_ids: list - The channel ids to display. - order_channel_by_depth: boolean - Reorder channel by depth. - time_range: list - List with start time and end time - mode: 'line' or 'map' or 'auto' - 2 possible mode: - * 'line' : classical for low channel count - * 'map' : for high channel count use color heat map - * 'auto' : auto switch depending the channel count <32ch - cmap: str default 'RdBu' - matplotlib colormap used in mode 'map' - show_channel_ids: bool - Set yticks with channel ids - color_groups: bool - If True groups are plotted with different colors - color: matplotlib color, default: None - The color used to draw the traces. - clim: None or tupple - When mode='map' this control color lims - with_colorbar: bool default True - When mode='map' add colorbar - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: TracesWidget - The output widget - """ - - def __init__( - self, - recording, - segment_index=None, - channel_ids=None, - order_channel_by_depth=False, - time_range=None, - mode="auto", - cmap="RdBu", - show_channel_ids=False, - color_groups=False, - color=None, - clim=None, - with_colorbar=True, - figure=None, - ax=None, - **plot_kwargs, - ): - BaseWidget.__init__(self, figure, ax) - self.recording = recording - self._sampling_frequency = recording.get_sampling_frequency() - self.visible_channel_ids = channel_ids - self._plot_kwargs = plot_kwargs - - if segment_index is None: - nseg = recording.get_num_segments() - if nseg != 1: - raise ValueError("You must provide segment_index=...") - segment_index = 0 - self.segment_index = segment_index - - if self.visible_channel_ids is None: - self.visible_channel_ids = recording.get_channel_ids() - - if order_channel_by_depth: - locations = self.recording.get_channel_locations() - channel_inds = self.recording.ids_to_indices(self.visible_channel_ids) - locations = locations[channel_inds, :] - origin = np.array([np.max(locations[:, 0]), np.min(locations[:, 1])])[None, :] - dist = scipy.spatial.distance.cdist(locations, origin, metric="euclidean") - dist = dist[:, 0] - self.order = np.argsort(dist) - else: - self.order = None - - if channel_ids is None: - channel_ids = recording.get_channel_ids() - - fs = recording.get_sampling_frequency() - if time_range is None: - time_range = (0, 1.0) - time_range = np.array(time_range) - - assert mode in ("auto", "line", "map"), "Mode must be in auto/line/map" - if mode == "auto": - if len(channel_ids) <= 64: - mode = "line" - else: - mode = "map" - self.mode = mode - self.cmap = cmap - - self.show_channel_ids = show_channel_ids - - self._frame_range = (time_range * fs).astype("int64") - a_max = self.recording.get_num_frames(segment_index=self.segment_index) - self._frame_range = np.clip(self._frame_range, 0, a_max) - self._time_range = [e / fs for e in self._frame_range] - - self.clim = clim - self.with_colorbar = with_colorbar - - self._initialize_stats() - - # self._vspacing = self._mean_channel_std * 20 - self._vspacing = self._max_channel_amp * 1.5 - - if recording.get_channel_groups() is None: - color_groups = False - - self._color_groups = color_groups - self._color = color - if color_groups: - self._colors = [] - self._group_color_map = {} - all_groups = recording.get_channel_groups() - groups = np.unique(all_groups) - N = len(groups) - import colorsys - - HSV_tuples = [(x * 1.0 / N, 0.5, 0.5) for x in range(N)] - self._colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), HSV_tuples)) - color_idx = 0 - for group in groups: - self._group_color_map[group] = color_idx - color_idx += 1 - self.name = "TimeSeries" - - def plot(self): - self._do_plot() - - def _do_plot(self): - chunk0 = self.recording.get_traces( - segment_index=self.segment_index, - channel_ids=self.visible_channel_ids, - start_frame=self._frame_range[0], - end_frame=self._frame_range[1], - ) - if self.order is not None: - chunk0 = chunk0[:, self.order] - self.visible_channel_ids = np.array(self.visible_channel_ids)[self.order] - - ax = self.ax - - n = len(self.visible_channel_ids) - - if self.mode == "line": - ax.set_xlim( - self._frame_range[0] / self._sampling_frequency, self._frame_range[1] / self._sampling_frequency - ) - ax.set_ylim(-self._vspacing, self._vspacing * n) - ax.get_xaxis().set_major_locator(MaxNLocator(prune="both")) - ax.get_yaxis().set_ticks([]) - ax.set_xlabel("time (s)") - - self._plots = {} - self._plot_offsets = {} - offset0 = self._vspacing * (n - 1) - times = np.arange(self._frame_range[0], self._frame_range[1]) / self._sampling_frequency - for im, m in enumerate(self.visible_channel_ids): - self._plot_offsets[m] = offset0 - if self._color_groups: - group = self.recording.get_channel_groups(channel_ids=[m])[0] - group_color_idx = self._group_color_map[group] - color = self._colors[group_color_idx] - else: - color = self._color - self._plots[m] = ax.plot(times, self._plot_offsets[m] + chunk0[:, im], color=color, **self._plot_kwargs) - offset0 = offset0 - self._vspacing - - if self.show_channel_ids: - ax.set_yticks(np.arange(n) * self._vspacing) - ax.set_yticklabels([str(chan_id) for chan_id in self.visible_channel_ids[::-1]]) - - elif self.mode == "map": - extent = (self._time_range[0], self._time_range[1], 0, self.recording.get_num_channels()) - im = ax.imshow( - chunk0.T, interpolation="nearest", origin="upper", aspect="auto", extent=extent, cmap=self.cmap - ) - - if self.clim is None: - im.set_clim(-self._max_channel_amp, self._max_channel_amp) - else: - im.set_clim(*self.clim) - - if self.with_colorbar: - self.figure.colorbar(im, ax=ax) - - if self.show_channel_ids: - ax.set_yticks(np.arange(n) + 0.5) - ax.set_yticklabels([str(chan_id) for chan_id in self.visible_channel_ids[::-1]]) - - def _initialize_stats(self): - chunk0 = self.recording.get_traces( - segment_index=self.segment_index, - channel_ids=self.visible_channel_ids, - start_frame=self._frame_range[0], - end_frame=self._frame_range[1], - ) - - self._mean_channel_std = np.mean(np.std(chunk0, axis=0)) - self._max_channel_amp = np.max(np.max(np.abs(chunk0), axis=0)) - - -def plot_timeseries(*args, **kwargs): - W = TracesWidget(*args, **kwargs) - W.plot() - return W - - -plot_timeseries.__doc__ = TracesWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/agreementmatrix.py b/src/spikeinterface/widgets/agreement_matrix.py similarity index 53% rename from src/spikeinterface/widgets/_legacy_mpl_widgets/agreementmatrix.py rename to src/spikeinterface/widgets/agreement_matrix.py index 369746e99b..ec6ea1c87c 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/agreementmatrix.py +++ b/src/spikeinterface/widgets/agreement_matrix.py @@ -1,11 +1,13 @@ import numpy as np +from warnings import warn -from .basewidget import BaseWidget +from .base import BaseWidget, to_attr +from .utils import get_unit_colors class AgreementMatrixWidget(BaseWidget): """ - Plots sorting comparison confusion matrix. + Plots sorting comparison agreement matrix. Parameters ---------- @@ -19,31 +21,34 @@ class AgreementMatrixWidget(BaseWidget): If True counts are displayed as text unit_ticks: bool If True unit tick labels are displayed - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created + """ - def __init__(self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, figure=None, ax=None): - from matplotlib import pyplot as plt + def __init__( + self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, backend=None, **backend_kwargs + ): + plot_data = dict( + sorting_comparison=sorting_comparison, + ordered=ordered, + count_text=count_text, + unit_ticks=unit_ticks, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) - BaseWidget.__init__(self, figure, ax) - self._sc = sorting_comparison - self._ordered = ordered - self._count_text = count_text - self._unit_ticks = unit_ticks - self.name = "ConfusionMatrix" + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - def plot(self): - self._do_plot() + comp = dp.sorting_comparison - def _do_plot(self): - # a dataframe - if self._ordered: - scores = self._sc.get_ordered_agreement_scores() + if dp.ordered: + scores = comp.get_ordered_agreement_scores() else: - scores = self._sc.agreement_scores + scores = comp.agreement_scores N1 = scores.shape[0] N2 = scores.shape[1] @@ -54,9 +59,9 @@ def _do_plot(self): # Using matshow here just because it sets the ticks up nicely. imshow is faster. self.ax.matshow(scores.values, cmap="Greens") - if self._count_text: + if dp.count_text: for i, u1 in enumerate(unit_ids1): - u2 = self._sc.best_match_12[u1] + u2 = comp.best_match_12[u1] if u2 != -1: j = np.where(unit_ids2 == u2)[0][0] @@ -68,24 +73,15 @@ def _do_plot(self): self.ax.xaxis.tick_bottom() # Labels for major ticks - if self._unit_ticks: + if dp.unit_ticks: self.ax.set_yticklabels(scores.index, fontsize=12) self.ax.set_xticklabels(scores.columns, fontsize=12) - self.ax.set_xlabel(self._sc.name_list[1], fontsize=20) - self.ax.set_ylabel(self._sc.name_list[0], fontsize=20) + self.ax.set_xlabel(comp.name_list[1], fontsize=20) + self.ax.set_ylabel(comp.name_list[0], fontsize=20) self.ax.set_xlim(-0.5, N2 - 0.5) self.ax.set_ylim( N1 - 0.5, -0.5, ) - - -def plot_agreement_matrix(*args, **kwargs): - W = AgreementMatrixWidget(*args, **kwargs) - W.plot() - return W - - -plot_agreement_matrix.__doc__ = AgreementMatrixWidget.__doc__ diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py b/src/spikeinterface/widgets/confusion_matrix.py similarity index 62% rename from src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py rename to src/spikeinterface/widgets/confusion_matrix.py index 942b613fbf..8eb58f30b2 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/confusionmatrix.py +++ b/src/spikeinterface/widgets/confusion_matrix.py @@ -1,6 +1,8 @@ import numpy as np +from warnings import warn -from .basewidget import BaseWidget +from .base import BaseWidget, to_attr +from .utils import get_unit_colors class ConfusionMatrixWidget(BaseWidget): @@ -15,40 +17,35 @@ class ConfusionMatrixWidget(BaseWidget): If True counts are displayed as text unit_ticks: bool If True unit tick labels are displayed - figure: matplotlib figure - The figure to be used. If not given a figure is created - ax: matplotlib axis - The axis to be used. If not given an axis is created - - Returns - ------- - W: ConfusionMatrixWidget - The output widget + """ - def __init__(self, gt_comparison, count_text=True, unit_ticks=True, figure=None, ax=None): - from matplotlib import pyplot as plt + def __init__(self, gt_comparison, count_text=True, unit_ticks=True, backend=None, **backend_kwargs): + plot_data = dict( + gt_comparison=gt_comparison, + count_text=count_text, + unit_ticks=unit_ticks, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure - BaseWidget.__init__(self, figure, ax) - self._gtcomp = gt_comparison - self._count_text = count_text - self._unit_ticks = unit_ticks - self.name = "ConfusionMatrix" + dp = to_attr(data_plot) - def plot(self): - self._do_plot() + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - def _do_plot(self): - # a dataframe - confusion_matrix = self._gtcomp.get_confusion_matrix() + comp = dp.gt_comparison + confusion_matrix = comp.get_confusion_matrix() N1 = confusion_matrix.shape[0] - 1 N2 = confusion_matrix.shape[1] - 1 # Using matshow here just because it sets the ticks up nicely. imshow is faster. self.ax.matshow(confusion_matrix.values, cmap="Greens") - if self._count_text: + if dp.count_text: for (i, j), z in np.ndenumerate(confusion_matrix.values): if z != 0: if z > np.max(confusion_matrix.values) / 2.0: @@ -65,27 +62,18 @@ def _do_plot(self): self.ax.xaxis.tick_bottom() # Labels for major ticks - if self._unit_ticks: + if dp.unit_ticks: self.ax.set_yticklabels(confusion_matrix.index, fontsize=12) self.ax.set_xticklabels(confusion_matrix.columns, fontsize=12) else: self.ax.set_xticklabels(np.append([""] * N2, "FN"), fontsize=10) self.ax.set_yticklabels(np.append([""] * N1, "FP"), fontsize=10) - self.ax.set_xlabel(self._gtcomp.name_list[1], fontsize=20) - self.ax.set_ylabel(self._gtcomp.name_list[0], fontsize=20) + self.ax.set_xlabel(comp.name_list[1], fontsize=20) + self.ax.set_ylabel(comp.name_list[0], fontsize=20) self.ax.set_xlim(-0.5, N2 + 0.5) self.ax.set_ylim( N1 + 0.5, -0.5, ) - - -def plot_confusion_matrix(*args, **kwargs): - W = ConfusionMatrixWidget(*args, **kwargs) - W.plot() - return W - - -plot_confusion_matrix.__doc__ = ConfusionMatrixWidget.__doc__ diff --git a/src/spikeinterface/widgets/probe_map.py b/src/spikeinterface/widgets/probe_map.py new file mode 100644 index 0000000000..7fb74abd7c --- /dev/null +++ b/src/spikeinterface/widgets/probe_map.py @@ -0,0 +1,75 @@ +import numpy as np +from warnings import warn + +from .base import BaseWidget, to_attr, default_backend_kwargs +from .utils import get_unit_colors + + +class ProbeMapWidget(BaseWidget): + """ + Plot the probe of a recording. + + Parameters + ---------- + recording: RecordingExtractor + The recording extractor object + channel_ids: list + The channel ids to display + with_channel_ids: bool False default + Add channel ids text on the probe + **plot_probe_kwargs: keyword arguments for probeinterface.plotting.plot_probe_group() function + + """ + + def __init__( + self, recording, channel_ids=None, with_channel_ids=False, backend=None, **backend_or_plot_probe_kwargs + ): + # split backend_or_plot_probe_kwargs + backend_kwargs = dict() + plot_probe_kwargs = dict() + backend = self.check_backend(backend) + for k, v in backend_or_plot_probe_kwargs.items(): + if k in default_backend_kwargs[backend]: + backend_kwargs[k] = v + else: + plot_probe_kwargs[k] = v + + plot_data = dict( + recording=recording, + channel_ids=channel_ids, + with_channel_ids=with_channel_ids, + plot_probe_kwargs=plot_probe_kwargs, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + from probeinterface.plotting import get_auto_lims, plot_probe + + dp = to_attr(data_plot) + + plot_probe_kwargs = dp.plot_probe_kwargs + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + probegroup = dp.recording.get_probegroup() + + xlims, ylims, zlims = get_auto_lims(probegroup.probes[0]) + for i, probe in enumerate(probegroup.probes): + xlims2, ylims2, _ = get_auto_lims(probe) + xlims = min(xlims[0], xlims2[0]), max(xlims[1], xlims2[1]) + ylims = min(ylims[0], ylims2[0]), max(ylims[1], ylims2[1]) + + plot_probe_kwargs["title"] = False + pos = 0 + text_on_contact = None + for i, probe in enumerate(probegroup.probes): + n = probe.get_contact_count() + if dp.with_channel_ids: + text_on_contact = dp.recording.channel_ids[pos : pos + n] + pos += n + plot_probe(probe, ax=self.ax, text_on_contact=text_on_contact, **plot_probe_kwargs) + + self.ax.set_xlim(*xlims) + self.ax.set_ylim(*ylims) diff --git a/src/spikeinterface/widgets/rasters.py b/src/spikeinterface/widgets/rasters.py new file mode 100644 index 0000000000..4a1d76279f --- /dev/null +++ b/src/spikeinterface/widgets/rasters.py @@ -0,0 +1,84 @@ +import numpy as np +from warnings import warn + +from .base import BaseWidget, to_attr, default_backend_kwargs + + +class RasterWidget(BaseWidget): + """ + Plots spike train rasters. + + Parameters + ---------- + sorting: SortingExtractor + The sorting extractor object + segment_index: None or int + The segment index. + unit_ids: list + List of unit ids + time_range: list + List with start time and end time + color: matplotlib color + The color to be used + """ + + def __init__( + self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k", backend=None, **backend_kwargs + ): + if segment_index is None: + if sorting.get_num_segments() != 1: + raise ValueError("You must provide segment_index=...") + segment_index = 0 + + if time_range is None: + frame_range = [0, sorting.to_spike_vector()[-1]["sample_index"]] + time_range = [f / sorting.sampling_frequency for f in frame_range] + else: + assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds" + frame_range = [int(t * sorting.sampling_frequency) for t in time_range] + + plot_data = dict( + sorting=sorting, + segment_index=segment_index, + unit_ids=unit_ids, + color=color, + frame_range=frame_range, + time_range=time_range, + ) + BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) + + def plot_matplotlib(self, data_plot, **backend_kwargs): + import matplotlib.pyplot as plt + from .utils_matplotlib import make_mpl_figure + + dp = to_attr(data_plot) + sorting = dp.sorting + + self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) + + units_ids = dp.unit_ids + if units_ids is None: + units_ids = sorting.unit_ids + + with plt.rc_context({"axes.edgecolor": "gray"}): + for unit_index, unit_id in enumerate(units_ids): + spiketrain = sorting.get_unit_spike_train( + unit_id, + start_frame=dp.frame_range[0], + end_frame=dp.frame_range[1], + segment_index=dp.segment_index, + ) + spiketimes = spiketrain / float(sorting.sampling_frequency) + self.ax.plot( + spiketimes, + unit_index * np.ones_like(spiketimes), + marker="|", + mew=1, + markersize=3, + ls="", + color=dp.color, + ) + self.ax.set_yticks(np.arange(len(units_ids))) + self.ax.set_yticklabels(units_ids) + self.ax.set_xlim(*dp.time_range) + self.ax.set_xlabel("time (s)") diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 7386167d0b..f44878927d 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -324,6 +324,30 @@ def test_sorting_summary(self): sw.plot_sorting_summary(self.we, backend=backend, **self.backend_kwargs[backend]) sw.plot_sorting_summary(self.we_sparse, backend=backend, **self.backend_kwargs[backend]) + def test_plot_agreement_matrix(self): + possible_backends = list(sw.AgreementMatrixWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_agreement_matrix(self.gt_comp) + + def test_plot_confusion_matrix(self): + possible_backends = list(sw.AgreementMatrixWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_confusion_matrix(self.gt_comp) + + def test_plot_probe_map(self): + possible_backends = list(sw.ProbeMapWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_probe_map(self.recording, with_channel_ids=True, with_contact_id=True) + + def test_plot_rasters(self): + possible_backends = list(sw.RasterWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_rasters(self.sorting) + if __name__ == "__main__": # unittest.main() @@ -345,6 +369,10 @@ def test_sorting_summary(self): # mytest.test_quality_metrics() # mytest.test_template_metrics() # mytest.test_amplitudes() + # mytest.test_plot_agreement_matrix() + # mytest.test_plot_confusion_matrix() + # mytest.test_plot_probe_map() + mytest.test_plot_rasters() # plt.ion() plt.show() diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py index 9c89b3981e..6ea2593432 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -2,12 +2,16 @@ from .base import backend_kwargs_desc +from .agreement_matrix import AgreementMatrixWidget from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget from .amplitudes import AmplitudesWidget from .autocorrelograms import AutoCorrelogramsWidget +from .confusion_matrix import ConfusionMatrixWidget from .crosscorrelograms import CrossCorrelogramsWidget from .motion import MotionWidget +from .probe_map import ProbeMapWidget from .quality_metrics import QualityMetricsWidget +from .rasters import RasterWidget from .sorting_summary import SortingSummaryWidget from .spike_locations import SpikeLocationsWidget from .spikes_on_traces import SpikesOnTracesWidget @@ -23,12 +27,16 @@ widget_list = [ + AgreementMatrixWidget, AllAmplitudesDistributionsWidget, AmplitudesWidget, AutoCorrelogramsWidget, + ConfusionMatrixWidget, CrossCorrelogramsWidget, MotionWidget, + ProbeMapWidget, QualityMetricsWidget, + RasterWidget, SortingSummaryWidget, SpikeLocationsWidget, SpikesOnTracesWidget, @@ -76,12 +84,16 @@ # make function for all widgets +plot_agreement_matrix = AgreementMatrixWidget plot_all_amplitudes_distributions = AllAmplitudesDistributionsWidget plot_amplitudes = AmplitudesWidget plot_autocorrelograms = AutoCorrelogramsWidget +plot_confusion_matrix = ConfusionMatrixWidget plot_crosscorrelograms = CrossCorrelogramsWidget plot_motion = MotionWidget +plot_probe_map = ProbeMapWidget plot_quality_metrics = QualityMetricsWidget +plot_rasters = RasterWidget plot_sorting_summary = SortingSummaryWidget plot_spike_locations = SpikeLocationsWidget plot_spikes_on_traces = SpikesOnTracesWidget