diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py index 31a938829d..895a8733c7 100644 --- a/src/spikeinterface/widgets/motion.py +++ b/src/spikeinterface/widgets/motion.py @@ -232,21 +232,19 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): dp = to_attr(data_plot) - assert backend_kwargs["axes"] is None, "axes argument is not allowed in MotionWidget" + assert backend_kwargs["axes"] is None, "axes argument is not allowed in DriftRasterMapWidget. Use ax instead." self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - fig = self.figure if dp.times is None: - x = dp.peaks["sample_index"] / dp.sampling_frequency + peak_times = dp.peaks["sample_index"] / dp.sampling_frequency else: - x = dp.times[dp.peaks["sample_index"]] + peak_times = dp.times[dp.peaks["sample_index"]] - y = dp.peak_locations[dp.direction] + peak_locs = dp.peak_locations[dp.direction] if dp.scatter_decimate is not None: - x = x[:: dp.scatter_decimate] - y = y[:: dp.scatter_decimate] - y2 = y2[:: dp.scatter_decimate] + peak_times = peak_times[:: dp.scatter_decimate] + peak_locs = peak_locs[:: dp.scatter_decimate] if dp.color_amplitude: amps = dp.peak_amplitudes @@ -271,7 +269,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): else: color_kwargs = dict(color=dp.color, c=None, alpha=dp.alpha) - self.ax.scatter(x, y, s=1, **color_kwargs) + self.ax.scatter(peak_times, peak_locs, s=1, **color_kwargs) if dp.depth_lim is not None: self.ax.set_ylim(*dp.depth_lim) self.ax.set_title("Peak depth") diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 0eef8539cc..7887ecda66 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -613,7 +613,7 @@ def test_drift_raster_map(self): color_amplitude=False, ) # with analyzer - sw.plot_drift_raster_map(sorting_analyzer=analyzer, color_amplitude=True) + sw.plot_drift_raster_map(sorting_analyzer=analyzer, color_amplitude=True, scatter_decimate=2) def test_plot_motion_info(self): motion_info = self.motion_info