diff --git a/src/spikeinterface/widgets/comparison.py b/src/spikeinterface/widgets/comparison.py index c45b8bf1db..1b14275459 100644 --- a/src/spikeinterface/widgets/comparison.py +++ b/src/spikeinterface/widgets/comparison.py @@ -142,17 +142,17 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): self.ax.text(j, i, "{:0.2f}".format(scores.at[u1, u2]), ha="center", va="center", color="white") # Major ticks - self.ax.set_xticks(np.arange(0, N2)) - self.ax.set_yticks(np.arange(0, N1)) self.ax.xaxis.tick_bottom() # Labels for major ticks if dp.unit_ticks: - self.ax.set_yticklabels(scores.index, fontsize=12) - self.ax.set_xticklabels(scores.columns, fontsize=12) + self.ax.set_xticks(np.arange(0, N2)) + self.ax.set_yticks(np.arange(0, N1)) + self.ax.set_yticklabels(scores.index) + self.ax.set_xticklabels(scores.columns) - self.ax.set_xlabel(comp.name_list[1], fontsize=20) - self.ax.set_ylabel(comp.name_list[0], fontsize=20) + self.ax.set_xlabel(comp.name_list[1]) + self.ax.set_ylabel(comp.name_list[0]) self.ax.set_xlim(-0.5, N2 - 0.5) self.ax.set_ylim( diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index 9867bc6a36..77c557910f 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -190,7 +190,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax.plot(val, label=label) ax.set_title(performance_name) if count == 0: - ax.legend() + ax.legend(loc='upper right') elif dp.mode == "snr": @@ -203,13 +203,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): x = study.get_metrics(key).loc[:, metric_name].values y = perfs.xs(key).loc[:, performance_name].values label = study.cases[key]["label"] - ax.scatter(x, y, label=label) + ax.scatter(x, y, s=10, label=label) max_metric = max(max_metric, np.max(x)) ax.set_title(performance_name) ax.set_xlim(0, max_metric * 1.05) ax.set_ylim(0, 1.05) if count == 0: - ax.legend() + ax.legend(loc='lower right') elif dp.mode == "swarm": @@ -245,7 +245,7 @@ class StudyAgreementMatrix(BaseWidget): def __init__( self, study, - ordered=True, count_text=True, + ordered=True, case_keys=None, backend=None, **backend_kwargs, @@ -257,7 +257,6 @@ def __init__( study=study, case_keys=case_keys, ordered=ordered, - count_text=count_text, ) BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs) @@ -276,9 +275,22 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): for count, key in enumerate(dp.case_keys): ax = self.axes.flatten()[count] comp = study.comparisons[key] - AgreementMatrixWidget(comp, ordered=dp.ordered, count_text=dp.count_text, backend='matplotlib', ax=ax) + unit_ticks = len(comp.sorting1.unit_ids) <= 16 + count_text = len(comp.sorting1.unit_ids) <= 16 + + + AgreementMatrixWidget(comp, ordered=dp.ordered, count_text=count_text, unit_ticks=unit_ticks, backend='matplotlib', ax=ax) label = study.cases[key]["label"] - ax.set_title(label) + ax.set_xlabel(label) + + if count > 0: + ax.set_ylabel(None) + ax.set_yticks([]) + ax.set_xticks([]) + + # ax0 = self.axes.flatten()[0] + # for ax in self.axes.flatten()[1:]: + # ax.sharey(ax0) class StudySummary(BaseWidget):