-
Notifications
You must be signed in to change notification settings - Fork 191
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4501289
commit 625ff5e
Showing
4 changed files
with
347 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import numpy as np | ||
from warnings import warn | ||
|
||
from .base import BaseWidget, to_attr | ||
from .utils import get_unit_colors | ||
|
||
|
||
|
||
class AgreementMatrixWidget(BaseWidget): | ||
""" | ||
Plot unit depths | ||
Parameters | ||
---------- | ||
sorting_comparison: GroundTruthComparison or SymmetricSortingComparison | ||
The sorting comparison object. | ||
Symetric or not. | ||
ordered: bool | ||
Order units with best agreement scores. | ||
This enable to see agreement on a diagonal. | ||
count_text: bool | ||
If True counts are displayed as text | ||
unit_ticks: bool | ||
If True unit tick labels are displayed | ||
""" | ||
|
||
def __init__( | ||
self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, | ||
backend=None, **backend_kwargs | ||
): | ||
plot_data = dict( | ||
sorting_comparison=sorting_comparison, | ||
ordered=ordered, | ||
count_text=count_text, | ||
unit_ticks=unit_ticks, | ||
) | ||
BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) | ||
|
||
def plot_matplotlib(self, data_plot, **backend_kwargs): | ||
import matplotlib.pyplot as plt | ||
from .utils_matplotlib import make_mpl_figure | ||
|
||
dp = to_attr(data_plot) | ||
|
||
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) | ||
|
||
comp = dp.sorting_comparison | ||
|
||
if dp.ordered: | ||
scores = comp.get_ordered_agreement_scores() | ||
else: | ||
scores = comp.agreement_scores | ||
|
||
N1 = scores.shape[0] | ||
N2 = scores.shape[1] | ||
|
||
unit_ids1 = scores.index.values | ||
unit_ids2 = scores.columns.values | ||
|
||
# Using matshow here just because it sets the ticks up nicely. imshow is faster. | ||
self.ax.matshow(scores.values, cmap="Greens") | ||
|
||
if dp.count_text: | ||
for i, u1 in enumerate(unit_ids1): | ||
u2 = comp.best_match_12[u1] | ||
if u2 != -1: | ||
j = np.where(unit_ids2 == u2)[0][0] | ||
|
||
self.ax.text(j, i, "{:0.2f}".format(scores.at[u1, u2]), ha="center", va="center", color="white") | ||
|
||
# Major ticks | ||
self.ax.set_xticks(np.arange(0, N2)) | ||
self.ax.set_yticks(np.arange(0, N1)) | ||
self.ax.xaxis.tick_bottom() | ||
|
||
# Labels for major ticks | ||
if dp.unit_ticks: | ||
self.ax.set_yticklabels(scores.index, fontsize=12) | ||
self.ax.set_xticklabels(scores.columns, fontsize=12) | ||
|
||
self.ax.set_xlabel(comp.name_list[1], fontsize=20) | ||
self.ax.set_ylabel(comp.name_list[0], fontsize=20) | ||
|
||
self.ax.set_xlim(-0.5, N2 - 0.5) | ||
self.ax.set_ylim( | ||
N1 - 0.5, | ||
-0.5, | ||
) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import numpy as np | ||
from warnings import warn | ||
|
||
from .base import BaseWidget, to_attr | ||
from .utils import get_unit_colors | ||
|
||
|
||
|
||
class ConfusionMatrixWidget(BaseWidget): | ||
""" | ||
Plot unit depths | ||
Parameters | ||
---------- | ||
gt_comparison: GroundTruthComparison | ||
The ground truth sorting comparison object | ||
count_text: bool | ||
If True counts are displayed as text | ||
unit_ticks: bool | ||
If True unit tick labels are displayed | ||
""" | ||
|
||
def __init__( | ||
self, gt_comparison, count_text=True, unit_ticks=True, | ||
backend=None, **backend_kwargs | ||
): | ||
plot_data = dict( | ||
gt_comparison=gt_comparison, | ||
count_text=count_text, | ||
unit_ticks=unit_ticks, | ||
) | ||
BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) | ||
|
||
def plot_matplotlib(self, data_plot, **backend_kwargs): | ||
import matplotlib.pyplot as plt | ||
from .utils_matplotlib import make_mpl_figure | ||
|
||
dp = to_attr(data_plot) | ||
|
||
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) | ||
|
||
comp = dp.gt_comparison | ||
|
||
confusion_matrix = comp.get_confusion_matrix() | ||
N1 = confusion_matrix.shape[0] - 1 | ||
N2 = confusion_matrix.shape[1] - 1 | ||
|
||
# Using matshow here just because it sets the ticks up nicely. imshow is faster. | ||
self.ax.matshow(confusion_matrix.values, cmap="Greens") | ||
|
||
if dp.count_text: | ||
for (i, j), z in np.ndenumerate(confusion_matrix.values): | ||
if z != 0: | ||
if z > np.max(confusion_matrix.values) / 2.0: | ||
self.ax.text(j, i, "{:d}".format(z), ha="center", va="center", color="white") | ||
else: | ||
self.ax.text(j, i, "{:d}".format(z), ha="center", va="center", color="black") | ||
|
||
self.ax.axhline(int(N1 - 1) + 0.5, color="black") | ||
self.ax.axvline(int(N2 - 1) + 0.5, color="black") | ||
|
||
# Major ticks | ||
self.ax.set_xticks(np.arange(0, N2 + 1)) | ||
self.ax.set_yticks(np.arange(0, N1 + 1)) | ||
self.ax.xaxis.tick_bottom() | ||
|
||
# Labels for major ticks | ||
if dp.unit_ticks: | ||
self.ax.set_yticklabels(confusion_matrix.index, fontsize=12) | ||
self.ax.set_xticklabels(confusion_matrix.columns, fontsize=12) | ||
else: | ||
self.ax.set_xticklabels(np.append([""] * N2, "FN"), fontsize=10) | ||
self.ax.set_yticklabels(np.append([""] * N1, "FP"), fontsize=10) | ||
|
||
self.ax.set_xlabel(comp.name_list[1], fontsize=20) | ||
self.ax.set_ylabel(comp.name_list[0], fontsize=20) | ||
|
||
self.ax.set_xlim(-0.5, N2 + 0.5) | ||
self.ax.set_ylim( | ||
N1 + 0.5, | ||
-0.5, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import numpy as np | ||
from warnings import warn | ||
|
||
from .base import BaseWidget, to_attr, default_backend_kwargs | ||
from .utils import get_unit_colors | ||
|
||
|
||
|
||
class ProbeMapWidget(BaseWidget): | ||
""" | ||
Plot the probe of a recording. | ||
Parameters | ||
---------- | ||
recording: RecordingExtractor | ||
The recording extractor object | ||
channel_ids: list | ||
The channel ids to display | ||
with_channel_ids: bool False default | ||
Add channel ids text on the probe | ||
**plot_probe_kwargs: keyword arguments for probeinterface.plotting.plot_probe_group() function | ||
""" | ||
|
||
def __init__( | ||
self, recording, channel_ids=None, with_channel_ids=False, | ||
backend=None, **backend_or_plot_probe_kwargs | ||
): | ||
|
||
# split backend_or_plot_probe_kwargs | ||
backend_kwargs = dict() | ||
plot_probe_kwargs = dict() | ||
backend = self.check_backend(backend) | ||
for k, v in backend_or_plot_probe_kwargs.items(): | ||
if k in default_backend_kwargs[backend]: | ||
backend_kwargs[k] = v | ||
else: | ||
plot_probe_kwargs[k] = v | ||
|
||
plot_data = dict( | ||
recording=recording, | ||
channel_ids=channel_ids, | ||
with_channel_ids=with_channel_ids, | ||
plot_probe_kwargs=plot_probe_kwargs, | ||
) | ||
BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) | ||
|
||
def plot_matplotlib(self, data_plot, **backend_kwargs): | ||
import matplotlib.pyplot as plt | ||
from .utils_matplotlib import make_mpl_figure | ||
from probeinterface.plotting import get_auto_lims, plot_probe | ||
|
||
dp = to_attr(data_plot) | ||
|
||
plot_probe_kwargs = dp.plot_probe_kwargs | ||
|
||
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) | ||
|
||
probegroup = dp.recording.get_probegroup() | ||
|
||
xlims, ylims, zlims = get_auto_lims(probegroup.probes[0]) | ||
for i, probe in enumerate(probegroup.probes): | ||
xlims2, ylims2, _ = get_auto_lims(probe) | ||
xlims = min(xlims[0], xlims2[0]), max(xlims[1], xlims2[1]) | ||
ylims = min(ylims[0], ylims2[0]), max(ylims[1], ylims2[1]) | ||
|
||
plot_probe_kwargs["title"] = False | ||
pos = 0 | ||
text_on_contact = None | ||
for i, probe in enumerate(probegroup.probes): | ||
n = probe.get_contact_count() | ||
if dp.with_channel_ids: | ||
text_on_contact = dp.recording.channel_ids[pos : pos + n] | ||
pos += n | ||
plot_probe(probe, ax=self.ax, text_on_contact=text_on_contact, **plot_probe_kwargs) | ||
|
||
self.ax.set_xlim(*xlims) | ||
self.ax.set_ylim(*ylims) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import numpy as np | ||
from warnings import warn | ||
|
||
from .base import BaseWidget, to_attr, default_backend_kwargs | ||
|
||
|
||
|
||
class RasterWidget(BaseWidget): | ||
""" | ||
Plots spike train rasters. | ||
Parameters | ||
---------- | ||
sorting: SortingExtractor | ||
The sorting extractor object | ||
segment_index: None or int | ||
The segment index. | ||
unit_ids: list | ||
List of unit ids | ||
time_range: list | ||
List with start time and end time | ||
color: matplotlib color | ||
The color to be used | ||
""" | ||
|
||
def __init__( | ||
self, sorting, segment_index=None, unit_ids=None, time_range=None, color="k", | ||
backend=None, **backend_kwargs | ||
): | ||
|
||
|
||
if segment_index is None: | ||
if sorting.get_num_segments() != 1: | ||
raise ValueError("You must provide segment_index=...") | ||
segment_index = 0 | ||
|
||
if time_range is None: | ||
frame_range = [0, sorting.to_spike_vector()[-1]["sample_index"]] | ||
time_range = [f / sorting.sampling_frequency for f in frame_range] | ||
else: | ||
assert len(time_range) == 2, "'time_range' should be a list with start and end time in seconds" | ||
frame_range = [int(t * sorting.sampling_frequency) for t in time_range] | ||
|
||
plot_data = dict( | ||
sorting=sorting, | ||
segment_index=segment_index, | ||
unit_ids=unit_ids, | ||
color=color, | ||
frame_range=frame_range, | ||
time_range=time_range, | ||
) | ||
BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) | ||
|
||
def plot_matplotlib(self, data_plot, **backend_kwargs): | ||
import matplotlib.pyplot as plt | ||
from .utils_matplotlib import make_mpl_figure | ||
|
||
dp = to_attr(data_plot) | ||
sorting = dp.sorting | ||
|
||
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs) | ||
|
||
units_ids = dp.unit_ids | ||
if units_ids is None: | ||
units_ids = sorting.unit_ids | ||
|
||
with plt.rc_context({"axes.edgecolor": "gray"}): | ||
for unit_index, unit_id in enumerate(units_ids): | ||
spiketrain = sorting.get_unit_spike_train( | ||
unit_id, | ||
start_frame=dp.frame_range[0], | ||
end_frame=dp.frame_range[1], | ||
segment_index=dp.segment_index, | ||
) | ||
spiketimes = spiketrain / float(sorting.sampling_frequency) | ||
self.ax.plot( | ||
spiketimes, | ||
unit_index * np.ones_like(spiketimes), | ||
marker="|", | ||
mew=1, | ||
markersize=3, | ||
ls="", | ||
color=dp.color, | ||
) | ||
self.ax.set_yticks(np.arange(len(units_ids))) | ||
self.ax.set_yticklabels(units_ids) | ||
self.ax.set_xlim(*dp.time_range) | ||
self.ax.set_xlabel("time (s)") | ||
|
||
|
||
|
||
|
||
|
||
|
||
|