diff --git a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py index 061fc55339..53c2a5c79e 100644 --- a/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py +++ b/src/spikeinterface/widgets/_legacy_mpl_widgets/__init__.py @@ -1,4 +1,3 @@ - # peak activity from .activity import plot_peak_activity_map, PeakActivityMapWidget diff --git a/src/spikeinterface/widgets/isi_distribution.py b/src/spikeinterface/widgets/isi_distribution.py index 2d92d1daf7..4256efd403 100644 --- a/src/spikeinterface/widgets/isi_distribution.py +++ b/src/spikeinterface/widgets/isi_distribution.py @@ -5,7 +5,6 @@ from .utils import get_unit_colors - class ISIDistributionWidget(BaseWidget): """ Plots spike train ISI distribution. @@ -20,13 +19,10 @@ class ISIDistributionWidget(BaseWidget): Bin size in ms window_ms: float Window size in ms - - """ - def __init__( - self, sorting, unit_ids=None, window_ms=100.0, bin_ms=1.0, backend=None, **backend_kwargs - ): + """ + def __init__(self, sorting, unit_ids=None, window_ms=100.0, bin_ms=1.0, backend=None, **backend_kwargs): if unit_ids is None: unit_ids = sorting.get_unit_ids() @@ -53,14 +49,14 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): sorting = dp.sorting num_segments = sorting.get_num_segments() fs = sorting.sampling_frequency - + for i, unit_id in enumerate(dp.unit_ids): ax = self.axes.flatten()[i] bins = np.arange(0, dp.window_ms, dp.bin_ms) bin_counts = None for segment_index in range(num_segments): - times_ms = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) / fs * 1000. + times_ms = sorting.get_unit_spike_train(unit_id=unit_id, segment_index=segment_index) / fs * 1000.0 isi = np.diff(times_ms) bin_counts_, bin_edges = np.histogram(isi, bins=bins, density=True) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index 4443ef7b03..bc3ab4272a 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -245,7 +245,6 @@ def test_isi_distribution(self): **self.backend_kwargs[backend], ) - def test_amplitudes(self): possible_backends = list(sw.AmplitudesWidget.get_possible_backends()) for backend in possible_backends: @@ -377,7 +376,6 @@ def test_plot_unit_probe_map(self): for backend in possible_backends: if backend not in self.skip_backends: sw.plot_unit_probe_map(self.we_dense) - if __name__ == "__main__": diff --git a/src/spikeinterface/widgets/unit_probe_map.py b/src/spikeinterface/widgets/unit_probe_map.py index 66b7ff3126..4068c1c530 100644 --- a/src/spikeinterface/widgets/unit_probe_map.py +++ b/src/spikeinterface/widgets/unit_probe_map.py @@ -4,6 +4,7 @@ # from probeinterface import ProbeGroup from .base import BaseWidget, to_attr + # from .utils import get_unit_colors from ..core.waveform_extractor import WaveformExtractor @@ -26,6 +27,7 @@ class UnitProbeMapWidget(BaseWidget): with_channel_ids: bool False default add channel ids text on the probe """ + def __init__( self, waveform_extractor, @@ -37,7 +39,6 @@ def __init__( backend=None, **backend_kwargs, ): - if unit_ids is None: unit_ids = waveform_extractor.sorting.unit_ids self.unit_ids = unit_ids @@ -45,7 +46,6 @@ def __init__( channel_ids = waveform_extractor.recording.channel_ids self.channel_ids = channel_ids - data_plot = dict( waveform_extractor=waveform_extractor, unit_ids=unit_ids, @@ -71,7 +71,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) - we = dp.waveform_extractor probe = we.get_probe()