diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py
index e5f4ce8b31..0d08922543 100644
--- a/src/spikeinterface/comparison/groundtruthstudy.py
+++ b/src/spikeinterface/comparison/groundtruthstudy.py
@@ -88,6 +88,7 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
         (study_folder / "sortings").mkdir()
         (study_folder / "sortings" / "run_logs").mkdir()
         (study_folder / "metrics").mkdir()
+        (study_folder / "comparisons").mkdir()

         for key, (rec, gt_sorting) in datasets.items():
             assert "/" not in key, "'/' cannot be in the key name!"
@@ -127,16 +128,17 @@ def scan_folder(self):
         with open(self.folder / "cases.pickle", "rb") as f:
             self.cases = pickle.load(f)

+        self.sortings = {k: None for k in self.cases}
         self.comparisons = {k: None for k in self.cases}
-
-        self.sortings = {}
         for key in self.cases:
             sorting_folder = self.folder / "sortings" / self.key_to_str(key)
             if sorting_folder.exists():
-                sorting = load_extractor(sorting_folder)
-            else:
-                sorting = None
-            self.sortings[key] = sorting
+                self.sortings[key] = load_extractor(sorting_folder)
+
+            comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle")
+            if comparison_file.exists():
+                with open(comparison_file, mode="rb") as f:
+                    self.comparisons[key] = pickle.load(f)

     def __repr__(self):
         t = f"{self.__class__.__name__} {self.folder.stem} \n"
@@ -155,6 +157,16 @@ def key_to_str(self, key):
         else:
             raise ValueError("Keys for cases must str or tuple")

+    def remove_sorting(self, key):
+        sorting_folder = self.folder / "sortings" / self.key_to_str(key)
+        log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json"
+        comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle")
+        if sorting_folder.exists():
+            shutil.rmtree(sorting_folder)
+        for f in (log_file, comparison_file):
+            if f.exists():
+                f.unlink()
+
     def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True, verbose=False):
         if case_keys is None:
             case_keys = self.cases.keys()
@@ -178,12 +190,7 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True
                     self.copy_sortings(case_keys=[key])
                     continue

-            if sorting_exists:
-                # delete older sorting + log before running sorters
-                shutil.rmtree(sorting_folder)
-                log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json"
-                if log_file.exists():
-                    log_file.unlink()
+            self.remove_sorting(key)

             if sorter_folder_exists:
                 shutil.rmtree(sorter_folder)
@@ -228,10 +235,7 @@ def copy_sortings(self, case_keys=None, force=True):
             if sorting is not None:
                 if sorting_folder.exists():
                     if force:
-                        # delete folder + log
-                        shutil.rmtree(sorting_folder)
-                        if log_file.exists():
-                            log_file.unlink()
+                        self.remove_sorting(key)
                     else:
                         continue

@@ -255,6 +259,10 @@ def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison
             comp = comparison_class(gt_sorting, sorting, **kwargs)
             self.comparisons[key] = comp

+            comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle")
+            with open(comparison_file, mode="wb") as f:
+                pickle.dump(comp, f)
+
     def get_run_times(self, case_keys=None):
         import pandas as pd

@@ -288,20 +296,16 @@ def extract_waveforms_gt(self, case_keys=None, **extract_kwargs):
             recording, gt_sorting = self.datasets[dataset_key]
             we = extract_waveforms(recording, gt_sorting, folder=wf_folder, **extract_kwargs)

-    def get_waveform_extractor(self, key):
-        # some recording are not dumpable to json and the waveforms extactor need it!
-        # so we load it with and put after
-        # this should be fixed in PR 2027 so remove this after
+    def get_waveform_extractor(self, case_key=None, dataset_key=None):
+        if case_key is not None:
+            dataset_key = self.cases[case_key]["dataset"]

-        dataset_key = self.cases[key]["dataset"]
         wf_folder = self.folder / "waveforms" / self.key_to_str(dataset_key)
-        we = load_waveforms(wf_folder, with_recording=False)
-        recording, _ = self.datasets[dataset_key]
-        we.set_recording(recording)
+        we = load_waveforms(wf_folder, with_recording=True)
         return we

     def get_templates(self, key, mode="average"):
-        we = self.get_waveform_extractor(key)
+        we = self.get_waveform_extractor(case_key=key)
         templates = we.get_all_templates(mode=mode)
         return templates
@@ -366,7 +370,7 @@ def get_performance_by_unit(self, case_keys=None):
             perf_by_unit.append(perf)

         perf_by_unit = pd.concat(perf_by_unit)
-        perf_by_unit = perf_by_unit.set_index(self.levels)
+        perf_by_unit = perf_by_unit.set_index(self.levels).sort_index()

         return perf_by_unit

     def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None):
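A sketch of the round trip the new `comparisons` subfolder enables; the folder name is a placeholder, and a prior `GroundTruthStudy.create(...)` plus sorter runs are assumed:

```python
from spikeinterface.comparison import GroundTruthStudy

study = GroundTruthStudy("my_study_folder")  # hypothetical folder made earlier with GroundTruthStudy.create(...)

# run_comparisons() now also pickles each comparison into <study_folder>/comparisons/
study.run_comparisons()

# A fresh instance finds the pickles in scan_folder(), so study.comparisons is
# repopulated from disk instead of every value being reset to None
study = GroundTruthStudy("my_study_folder")
print(study.comparisons)  # {case_key: GroundTruthComparison, ...}
```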
diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py
index 4c79383542..1ed51fb04f 100644
--- a/src/spikeinterface/sortingcomponents/clustering/merge.py
+++ b/src/spikeinterface/sortingcomponents/clustering/merge.py
@@ -20,6 +20,9 @@
 from .tools import aggregate_sparse_features, FeaturesLoader, compute_template_from_sparse


+DEBUG = False
+
+
 def merge_clusters(
     peaks,
     peak_labels,
@@ -81,7 +84,6 @@ def merge_clusters(
         **job_kwargs,
     )

-    DEBUG = False
     if DEBUG:
         import matplotlib.pyplot as plt

@@ -224,8 +226,6 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full"
     else:
         raise ValueError

-    # DEBUG = True
-    DEBUG = False
     if DEBUG:
         import matplotlib.pyplot as plt

@@ -233,8 +233,6 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full"
         nx.draw_networkx(sub_graph)
         plt.show()

-    # DEBUG = True
-    DEBUG = False
     if DEBUG:
         import matplotlib.pyplot as plt

@@ -551,15 +549,7 @@ def merge(
     else:
         final_shift = 0

-    # DEBUG = True
-    DEBUG = False
-
-    # if DEBUG and is_merge:
-    # if DEBUG and (overlap > 0.1 and overlap <0.3):
     if DEBUG:
-        # if DEBUG and not is_merge:
-        # if DEBUG and (overlap > 0.05 and overlap <0.25):
-        # if label0 == 49 and label1== 65:
         import matplotlib.pyplot as plt

         flatten_wfs0 = wfs0.swapaxes(1, 2).reshape(wfs0.shape[0], -1)
@@ -674,8 +664,6 @@ def merge(
         final_shift = 0
         merge_value = np.nan

-    # DEBUG = False
-    DEBUG = True
     if DEBUG and normed_diff < 0.2:
         # if DEBUG:
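The per-function `DEBUG` flags scattered through `merge.py` collapse into a single module-level constant, so all the matplotlib debug branches can now be switched in one place. A sketch of how a debugging session might toggle it, assuming the import path mirrors the file location:

```python
from spikeinterface.sortingcomponents.clustering import merge

# Every `if DEBUG:` branch in merge.py now reads this one module-level flag
merge.DEBUG = True
```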
diff --git a/src/spikeinterface/widgets/agreement_matrix.py b/src/spikeinterface/widgets/agreement_matrix.py
deleted file mode 100644
index ec6ea1c87c..0000000000
--- a/src/spikeinterface/widgets/agreement_matrix.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import numpy as np
-from warnings import warn
-
-from .base import BaseWidget, to_attr
-from .utils import get_unit_colors
-
-
-class AgreementMatrixWidget(BaseWidget):
-    """
-    Plots sorting comparison agreement matrix.
-
-    Parameters
-    ----------
-    sorting_comparison: GroundTruthComparison or SymmetricSortingComparison
-        The sorting comparison object.
-        Symetric or not.
-    ordered: bool
-        Order units with best agreement scores.
-        This enable to see agreement on a diagonal.
-    count_text: bool
-        If True counts are displayed as text
-    unit_ticks: bool
-        If True unit tick labels are displayed
-
-    """
-
-    def __init__(
-        self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, backend=None, **backend_kwargs
-    ):
-        plot_data = dict(
-            sorting_comparison=sorting_comparison,
-            ordered=ordered,
-            count_text=count_text,
-            unit_ticks=unit_ticks,
-        )
-        BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
-
-    def plot_matplotlib(self, data_plot, **backend_kwargs):
-        import matplotlib.pyplot as plt
-        from .utils_matplotlib import make_mpl_figure
-
-        dp = to_attr(data_plot)
-
-        self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
-
-        comp = dp.sorting_comparison
-
-        if dp.ordered:
-            scores = comp.get_ordered_agreement_scores()
-        else:
-            scores = comp.agreement_scores
-
-        N1 = scores.shape[0]
-        N2 = scores.shape[1]
-
-        unit_ids1 = scores.index.values
-        unit_ids2 = scores.columns.values
-
-        # Using matshow here just because it sets the ticks up nicely. imshow is faster.
-        self.ax.matshow(scores.values, cmap="Greens")
-
-        if dp.count_text:
-            for i, u1 in enumerate(unit_ids1):
-                u2 = comp.best_match_12[u1]
-                if u2 != -1:
-                    j = np.where(unit_ids2 == u2)[0][0]
-
-                    self.ax.text(j, i, "{:0.2f}".format(scores.at[u1, u2]), ha="center", va="center", color="white")
-
-        # Major ticks
-        self.ax.set_xticks(np.arange(0, N2))
-        self.ax.set_yticks(np.arange(0, N1))
-        self.ax.xaxis.tick_bottom()
-
-        # Labels for major ticks
-        if dp.unit_ticks:
-            self.ax.set_yticklabels(scores.index, fontsize=12)
-            self.ax.set_xticklabels(scores.columns, fontsize=12)
-
-        self.ax.set_xlabel(comp.name_list[1], fontsize=20)
-        self.ax.set_ylabel(comp.name_list[0], fontsize=20)
-
-        self.ax.set_xlim(-0.5, N2 - 0.5)
-        self.ax.set_ylim(
-            N1 - 0.5,
-            -0.5,
-        )
diff --git a/src/spikeinterface/widgets/confusion_matrix.py b/src/spikeinterface/widgets/comparison.py
similarity index 50%
rename from src/spikeinterface/widgets/confusion_matrix.py
rename to src/spikeinterface/widgets/comparison.py
index 8eb58f30b2..70f98df8b9 100644
--- a/src/spikeinterface/widgets/confusion_matrix.py
+++ b/src/spikeinterface/widgets/comparison.py
@@ -1,8 +1,6 @@
 import numpy as np
-from warnings import warn

 from .base import BaseWidget, to_attr
-from .utils import get_unit_colors


 class ConfusionMatrixWidget(BaseWidget):
@@ -77,3 +75,85 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
             N1 + 0.5,
             -0.5,
         )
+
+
+class AgreementMatrixWidget(BaseWidget):
+    """
+    Plots sorting comparison agreement matrix.
+
+    Parameters
+    ----------
+    sorting_comparison: GroundTruthComparison or SymmetricSortingComparison
+        The sorting comparison object.
+        Can be symmetric if a SymmetricSortingComparison is given
+    ordered: bool, default: True
+        Order units with best agreement scores.
+        If True, the agreement scores appear along the diagonal
+    count_text: bool, default: True
+        If True, counts are displayed as text
+    unit_ticks: bool, default: True
+        If True, unit tick labels are displayed
+
+    """
+
+    def __init__(
+        self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, backend=None, **backend_kwargs
+    ):
+        plot_data = dict(
+            sorting_comparison=sorting_comparison,
+            ordered=ordered,
+            count_text=count_text,
+            unit_ticks=unit_ticks,
+        )
+        BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+
+    def plot_matplotlib(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        from .utils_matplotlib import make_mpl_figure
+
+        dp = to_attr(data_plot)
+
+        self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
+
+        comp = dp.sorting_comparison
+
+        if dp.ordered:
+            scores = comp.get_ordered_agreement_scores()
+        else:
+            scores = comp.agreement_scores
+
+        N1 = scores.shape[0]
+        N2 = scores.shape[1]
+
+        unit_ids1 = scores.index.values
+        unit_ids2 = scores.columns.values
+
+        # Using matshow here just because it sets the ticks up nicely. imshow is faster.
+        self.ax.matshow(scores.values, cmap="Greens")
+
+        if dp.count_text:
+            for i, u1 in enumerate(unit_ids1):
+                u2 = comp.best_match_12[u1]
+                if u2 != -1:
+                    j = np.where(unit_ids2 == u2)[0][0]
+
+                    self.ax.text(j, i, "{:0.2f}".format(scores.at[u1, u2]), ha="center", va="center", color="white")
+
+        # Major ticks
+        self.ax.xaxis.tick_bottom()
+
+        # Labels for major ticks
+        if dp.unit_ticks:
+            self.ax.set_xticks(np.arange(0, N2))
+            self.ax.set_yticks(np.arange(0, N1))
+            self.ax.set_yticklabels(scores.index)
+            self.ax.set_xticklabels(scores.columns)
+
+        self.ax.set_xlabel(comp.name_list[1])
+        self.ax.set_ylabel(comp.name_list[0])
+
+        self.ax.set_xlim(-0.5, N2 - 0.5)
+        self.ax.set_ylim(
+            N1 - 0.5,
+            -0.5,
+        )
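Both widgets now live in `widgets/comparison.py` and keep their public plotting aliases. A minimal, self-contained sketch; the toy sortings from `generate_sorting` only stand in for a real ground truth and sorter output:

```python
import spikeinterface.widgets as sw
from spikeinterface.core import generate_sorting
from spikeinterface.comparison import compare_sorter_to_ground_truth

# Toy sortings standing in for a ground truth and a tested sorter output
gt_sorting = generate_sorting(num_units=10, durations=[60.0], seed=0)
tested_sorting = generate_sorting(num_units=10, durations=[60.0], seed=1)

comp = compare_sorter_to_ground_truth(gt_sorting, tested_sorting)

# Tick labels are now drawn only when unit_ticks is True
sw.plot_agreement_matrix(comp, ordered=True, count_text=True, backend="matplotlib")
sw.plot_confusion_matrix(comp, backend="matplotlib")
```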
diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py
index 6a27b78dec..6e4433ee60 100644
--- a/src/spikeinterface/widgets/gtstudy.py
+++ b/src/spikeinterface/widgets/gtstudy.py
@@ -1,11 +1,6 @@
 import numpy as np

 from .base import BaseWidget, to_attr
-from .utils import get_unit_colors
-
-from ..core import ChannelSparsity
-from ..core.waveform_extractor import WaveformExtractor
-from ..core.basesorting import BaseSorting


 class StudyRunTimesWidget(BaseWidget):
@@ -129,7 +124,6 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         self.ax.legend()


-# TODO : plot optionally average on some levels using group by
 class StudyPerformances(BaseWidget):
     """
     Plot performances over case for a study.
@@ -139,17 +133,23 @@ class StudyPerformances(BaseWidget):
     ----------
     study: GroundTruthStudy
         A study object.
-    mode: str
-        Which mode in "swarm"
+    mode: "ordered" | "snr" | "swarm", default: "ordered"
+        Which plot mode to use:
+
+        * "ordered": plot performance metrics vs unit indices ordered by decreasing accuracy (default)
+        * "snr": plot performance metrics vs snr
+        * "swarm": plot performance metrics as a swarm plot (see seaborn.swarmplot for details)
+    performance_names: list or tuple, default: ('accuracy', 'precision', 'recall')
+        Which performances to plot ("accuracy", "precision", "recall")
     case_keys: list or None
         A selection of cases to plot, if None, then all.
-
     """

     def __init__(
         self,
         study,
-        mode="swarm",
+        mode="ordered",
+        performance_names=("accuracy", "precision", "recall"),
         case_keys=None,
         backend=None,
         **backend_kwargs,
@@ -161,6 +161,7 @@ def __init__(
             study=study,
             perfs=study.get_performance_by_unit(case_keys=case_keys),
             mode=mode,
+            performance_names=performance_names,
             case_keys=case_keys,
         )

@@ -176,43 +177,75 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         dp = to_attr(data_plot)
         perfs = dp.perfs

+        study = dp.study
+        if dp.mode in ("ordered", "snr"):
+            backend_kwargs["num_axes"] = len(dp.performance_names)
         self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

-        if dp.mode == "swarm":
+        if dp.mode == "ordered":
+            for count, performance_name in enumerate(dp.performance_names):
+                ax = self.axes.flatten()[count]
+                for key in dp.case_keys:
+                    label = study.cases[key]["label"]
+
+                    val = perfs.xs(key).loc[:, performance_name].values
+                    val = np.sort(val)[::-1]
+                    ax.plot(val, label=label)
+                ax.set_title(performance_name)
+                if count == 0:
+                    ax.legend(loc="upper right")
+
+        elif dp.mode == "snr":
+            metric_name = dp.mode
+            for count, performance_name in enumerate(dp.performance_names):
+                ax = self.axes.flatten()[count]
+
+                max_metric = 0
+                for key in dp.case_keys:
+                    x = study.get_metrics(key).loc[:, metric_name].values
+                    y = perfs.xs(key).loc[:, performance_name].values
+                    label = study.cases[key]["label"]
+                    ax.scatter(x, y, s=10, label=label)
+                    max_metric = max(max_metric, np.max(x))
+                ax.set_title(performance_name)
+                ax.set_xlim(0, max_metric * 1.05)
+                ax.set_ylim(0, 1.05)
+                if count == 0:
+                    ax.legend(loc="lower right")
+
+        elif dp.mode == "swarm":
             levels = perfs.index.names

             df = pd.melt(
                 perfs.reset_index(),
                 id_vars=levels,
                 var_name="Metric",
                 value_name="Score",
-                value_vars=("accuracy", "precision", "recall"),
+                value_vars=dp.performance_names,
             )
             df["x"] = df.apply(lambda r: " ".join([r[col] for col in levels]), axis=1)
             sns.swarmplot(data=df, x="x", y="Score", hue="Metric", dodge=True)


-class StudyPerformancesVsMetrics(BaseWidget):
+class StudyAgreementMatrix(BaseWidget):
     """
-    Plot performances vs a metrics (snr for instance) over case for a study.
-
+    Plot agreement matrix.
     Parameters
     ----------
     study: GroundTruthStudy
         A study object.
-    mode: str
-        Which mode in "swarm"
     case_keys: list or None
         A selection of cases to plot, if None, then all.
-
+    ordered: bool
+        Order units with best agreement scores.
+        This enables seeing the agreement along the diagonal.
     """

     def __init__(
         self,
         study,
-        metric_name="snr",
-        performance_name="accuracy",
+        ordered=True,
         case_keys=None,
         backend=None,
         **backend_kwargs,
@@ -222,9 +255,8 @@ def __init__(

         plot_data = dict(
             study=study,
-            metric_name=metric_name,
-            performance_name=performance_name,
             case_keys=case_keys,
+            ordered=ordered,
         )

         BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
@@ -232,22 +264,79 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         import matplotlib.pyplot as plt
         from .utils_matplotlib import make_mpl_figure
-        from .utils import get_some_colors
+        from .comparison import AgreementMatrixWidget

         dp = to_attr(data_plot)
+        study = dp.study
+
+        backend_kwargs["num_axes"] = len(dp.case_keys)
         self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

-        study = dp.study
-        perfs = study.get_performance_by_unit(case_keys=dp.case_keys)
+        for count, key in enumerate(dp.case_keys):
+            ax = self.axes.flatten()[count]
+            comp = study.comparisons[key]
+            unit_ticks = len(comp.sorting1.unit_ids) <= 16
+            count_text = len(comp.sorting1.unit_ids) <= 16

-        max_metric = 0
-        for key in dp.case_keys:
-            x = study.get_metrics(key)[dp.metric_name].values
-            y = perfs.xs(key)[dp.performance_name].values
-            label = dp.study.cases[key]["label"]
-            self.ax.scatter(x, y, label=label)
-            max_metric = max(max_metric, np.max(x))
+            AgreementMatrixWidget(
+                comp, ordered=dp.ordered, count_text=count_text, unit_ticks=unit_ticks, backend="matplotlib", ax=ax
+            )
+            label = study.cases[key]["label"]
+            ax.set_xlabel(label)

-        self.ax.legend()
-        self.ax.set_xlim(0, max_metric * 1.05)
-        self.ax.set_ylim(0, 1.05)
+            if count > 0:
+                ax.set_ylabel(None)
+                ax.set_yticks([])
+            ax.set_xticks([])
+
+        # ax0 = self.axes.flatten()[0]
+        # for ax in self.axes.flatten()[1:]:
+        #     ax.sharey(ax0)
+
+
+class StudySummary(BaseWidget):
+    """
+    Plot a summary of a ground truth study.
+    Internally calls:
+        plot_study_run_times
+        plot_study_unit_counts
+        plot_study_performances
+        plot_study_agreement_matrix
+
+    Parameters
+    ----------
+    study: GroundTruthStudy
+        A study object.
+    case_keys: list or None, default: None
+        A selection of cases to plot, if None, then all.
+    """
+
+    def __init__(
+        self,
+        study,
+        case_keys=None,
+        backend=None,
+        **backend_kwargs,
+    ):
+        if case_keys is None:
+            case_keys = list(study.cases.keys())
+
+        plot_data = dict(
+            study=study,
+            case_keys=case_keys,
+        )
+
+        BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+
+    def plot_matplotlib(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        from .utils_matplotlib import make_mpl_figure
+
+        study = data_plot["study"]
+        case_keys = data_plot["case_keys"]
+
+        StudyPerformances(study=study, case_keys=case_keys, mode="ordered", backend="matplotlib", **backend_kwargs)
+        StudyPerformances(study=study, case_keys=case_keys, mode="snr", backend="matplotlib", **backend_kwargs)
+        StudyAgreementMatrix(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs)
+        StudyRunTimesWidget(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs)
+        StudyUnitCountsWidget(study=study, case_keys=case_keys, backend="matplotlib", **backend_kwargs)
diff --git a/src/spikeinterface/widgets/widget_list.py b/src/spikeinterface/widgets/widget_list.py
index 00d179127d..b95e92668a 100644
--- a/src/spikeinterface/widgets/widget_list.py
+++ b/src/spikeinterface/widgets/widget_list.py
@@ -2,11 +2,10 @@

 from .base import backend_kwargs_desc

-from .agreement_matrix import AgreementMatrixWidget
+
 from .all_amplitudes_distributions import AllAmplitudesDistributionsWidget
 from .amplitudes import AmplitudesWidget
 from .autocorrelograms import AutoCorrelogramsWidget
-from .confusion_matrix import ConfusionMatrixWidget
 from .crosscorrelograms import CrossCorrelogramsWidget
 from .isi_distribution import ISIDistributionWidget
 from .motion import MotionWidget
@@ -29,7 +28,8 @@
 from .unit_templates import UnitTemplatesWidget
 from .unit_waveforms_density_map import UnitWaveformDensityMapWidget
 from .unit_waveforms import UnitWaveformsWidget
-from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyPerformancesVsMetrics
+from .comparison import AgreementMatrixWidget, ConfusionMatrixWidget
+from .gtstudy import StudyRunTimesWidget, StudyUnitCountsWidget, StudyPerformances, StudyAgreementMatrix, StudySummary
 from .collision import ComparisonCollisionBySimilarityWidget, StudyComparisonCollisionBySimilarityWidget

 widget_list = [
@@ -66,7 +66,8 @@
     StudyRunTimesWidget,
     StudyUnitCountsWidget,
     StudyPerformances,
-    StudyPerformancesVsMetrics,
+    StudyAgreementMatrix,
+    StudySummary,
     StudyComparisonCollisionBySimilarityWidget,
 ]
@@ -136,7 +137,8 @@
 plot_study_run_times = StudyRunTimesWidget
 plot_study_unit_counts = StudyUnitCountsWidget
 plot_study_performances = StudyPerformances
-plot_study_performances_vs_metrics = StudyPerformancesVsMetrics
+plot_study_agreement_matrix = StudyAgreementMatrix
+plot_study_summary = StudySummary
 plot_study_comparison_collision_by_similarity = StudyComparisonCollisionBySimilarityWidget