Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 21, 2023
1 parent eaaffd0 commit 85c7755
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ def setUp(self):
def tearDown(self):
pass



def test_plot_unit_probe_map(self):
sw.plot_unit_probe_map(self._we, with_channel_ids=True)
sw.plot_unit_probe_map(self._we, animated=True)
Expand Down
8 changes: 2 additions & 6 deletions src/spikeinterface/widgets/agreement_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from .utils import get_unit_colors



class AgreementMatrixWidget(BaseWidget):
"""
Plots sorting comparison agreement matrix.
Expand All @@ -22,12 +21,11 @@ class AgreementMatrixWidget(BaseWidget):
If True counts are displayed as text
unit_ticks: bool
If True unit tick labels are displayed
"""

def __init__(
self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True,
backend=None, **backend_kwargs
self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, backend=None, **backend_kwargs
):
plot_data = dict(
sorting_comparison=sorting_comparison,
Expand Down Expand Up @@ -87,5 +85,3 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
N1 - 0.5,
-0.5,
)


10 changes: 3 additions & 7 deletions src/spikeinterface/widgets/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from .utils import get_unit_colors



class ConfusionMatrixWidget(BaseWidget):
"""
Plots sorting comparison confusion matrix.
Expand All @@ -18,13 +17,10 @@ class ConfusionMatrixWidget(BaseWidget):
If True counts are displayed as text
unit_ticks: bool
If True unit tick labels are displayed
"""

def __init__(
self, gt_comparison, count_text=True, unit_ticks=True,
backend=None, **backend_kwargs
):
def __init__(self, gt_comparison, count_text=True, unit_ticks=True, backend=None, **backend_kwargs):
plot_data = dict(
gt_comparison=gt_comparison,
count_text=count_text,
Expand Down Expand Up @@ -80,4 +76,4 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
self.ax.set_ylim(
N1 + 0.5,
-0.5,
)
)
5 changes: 1 addition & 4 deletions src/spikeinterface/widgets/probe_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from .utils import get_unit_colors



class ProbeMapWidget(BaseWidget):
"""
Plot the probe of a recording.
Expand All @@ -23,10 +22,8 @@ class ProbeMapWidget(BaseWidget):
"""

def __init__(
self, recording, channel_ids=None, with_channel_ids=False,
backend=None, **backend_or_plot_probe_kwargs
self, recording, channel_ids=None, with_channel_ids=False, backend=None, **backend_or_plot_probe_kwargs
):

# split backend_or_plot_probe_kwargs
backend_kwargs = dict()
plot_probe_kwargs = dict()
Expand Down
15 changes: 2 additions & 13 deletions src/spikeinterface/widgets/rasters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from .base import BaseWidget, to_attr, default_backend_kwargs



class RasterWidget(BaseWidget):
"""
Plots spike train rasters.
Expand All @@ -24,16 +23,13 @@ class RasterWidget(BaseWidget):
"""

def __init__(
self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k",
backend=None, **backend_kwargs
self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k", backend=None, **backend_kwargs
):


if segment_index is None:
if sorting.get_num_segments() != 1:
raise ValueError("You must provide segment_index=...")
segment_index = 0

if time_range is None:
frame_range = [0, sorting.to_spike_vector()[-1]["sample_index"]]
time_range = [f / sorting.sampling_frequency for f in frame_range]
Expand Down Expand Up @@ -86,10 +82,3 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
self.ax.set_yticklabels(units_ids)
self.ax.set_xlim(*dp.time_range)
self.ax.set_xlabel("time (s)")







3 changes: 1 addition & 2 deletions src/spikeinterface/widgets/tests/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,6 @@ def test_plot_rasters(self):
sw.plot_rasters(self.sorting)



if __name__ == "__main__":
# unittest.main()

Expand All @@ -371,7 +370,7 @@ def test_plot_rasters(self):
# mytest.test_template_metrics()
# mytest.test_amplitudes()
# mytest.test_plot_agreement_matrix()
# mytest.test_plot_confusion_matrix()
# mytest.test_plot_confusion_matrix()
# mytest.test_plot_probe_map()
mytest.test_plot_rasters()

Expand Down

0 comments on commit 85c7755

Please sign in to comment.