Skip to content

Commit

Permalink
Merge pull request #2128 from samuelgarcia/gt_study
Browse files Browse the repository at this point in the history
Some improvements to the Study class and related widgets
  • Loading branch information
alejoe91 authored Oct 26, 2023
2 parents ef095c2 + bf5b0a6 commit 109b5b3
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 170 deletions.
56 changes: 30 additions & 26 deletions src/spikeinterface/comparison/groundtruthstudy.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None):
(study_folder / "sortings").mkdir()
(study_folder / "sortings" / "run_logs").mkdir()
(study_folder / "metrics").mkdir()
(study_folder / "comparisons").mkdir()

for key, (rec, gt_sorting) in datasets.items():
assert "/" not in key, "'/' cannot be in the key name!"
Expand Down Expand Up @@ -127,16 +128,17 @@ def scan_folder(self):
with open(self.folder / "cases.pickle", "rb") as f:
self.cases = pickle.load(f)

self.sortings = {k: None for k in self.cases}
self.comparisons = {k: None for k in self.cases}

self.sortings = {}
for key in self.cases:
sorting_folder = self.folder / "sortings" / self.key_to_str(key)
if sorting_folder.exists():
sorting = load_extractor(sorting_folder)
else:
sorting = None
self.sortings[key] = sorting
self.sortings[key] = load_extractor(sorting_folder)

comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle")
if comparison_file.exists():
with open(comparison_file, mode="rb") as f:
self.comparisons[key] = pickle.load(f)

def __repr__(self):
t = f"{self.__class__.__name__} {self.folder.stem} \n"
Expand All @@ -155,6 +157,16 @@ def key_to_str(self, key):
else:
raise ValueError("Keys for cases must str or tuple")

def remove_sorting(self, key):
    """Remove all on-disk artifacts of a case: the sorting folder, its run log, and its cached comparison."""
    sorting_folder = self.folder / "sortings" / self.key_to_str(key)
    log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json"
    # Comparisons are persisted as "<key>.pickle" (see run_comparisons / scan_folder),
    # so the extension must be included here or the stale file is never deleted.
    comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle")
    if sorting_folder.exists():
        shutil.rmtree(sorting_folder)
    for f in (log_file, comparison_file):
        if f.exists():
            f.unlink()

def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True, verbose=False):
if case_keys is None:
case_keys = self.cases.keys()
Expand All @@ -178,12 +190,7 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True
self.copy_sortings(case_keys=[key])
continue

if sorting_exists:
# delete older sorting + log before running sorters
shutil.rmtree(sorting_folder)
log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json"
if log_file.exists():
log_file.unlink()
self.remove_sorting(key)

if sorter_folder_exists:
shutil.rmtree(sorter_folder)
Expand Down Expand Up @@ -228,10 +235,7 @@ def copy_sortings(self, case_keys=None, force=True):
if sorting is not None:
if sorting_folder.exists():
if force:
# delete folder + log
shutil.rmtree(sorting_folder)
if log_file.exists():
log_file.unlink()
self.remove_sorting(key)
else:
continue

Expand All @@ -255,6 +259,10 @@ def run_comparisons(self, case_keys=None, comparison_class=GroundTruthComparison
comp = comparison_class(gt_sorting, sorting, **kwargs)
self.comparisons[key] = comp

comparison_file = self.folder / "comparisons" / (self.key_to_str(key) + ".pickle")
with open(comparison_file, mode="wb") as f:
pickle.dump(comp, f)

def get_run_times(self, case_keys=None):
import pandas as pd

Expand Down Expand Up @@ -288,20 +296,16 @@ def extract_waveforms_gt(self, case_keys=None, **extract_kwargs):
recording, gt_sorting = self.datasets[dataset_key]
we = extract_waveforms(recording, gt_sorting, folder=wf_folder, **extract_kwargs)

def get_waveform_extractor(self, key):
# some recording are not dumpable to json and the waveforms extactor need it!
# so we load it with and put after
# this should be fixed in PR 2027 so remove this after
def get_waveform_extractor(self, case_key=None, dataset_key=None):
if case_key is not None:
dataset_key = self.cases[case_key]["dataset"]

dataset_key = self.cases[key]["dataset"]
wf_folder = self.folder / "waveforms" / self.key_to_str(dataset_key)
we = load_waveforms(wf_folder, with_recording=False)
recording, _ = self.datasets[dataset_key]
we.set_recording(recording)
we = load_waveforms(wf_folder, with_recording=True)
return we

def get_templates(self, key, mode="average"):
we = self.get_waveform_extractor(key)
we = self.get_waveform_extractor(case_key=key)
templates = we.get_all_templates(mode=mode)
return templates

Expand Down Expand Up @@ -366,7 +370,7 @@ def get_performance_by_unit(self, case_keys=None):
perf_by_unit.append(perf)

perf_by_unit = pd.concat(perf_by_unit)
perf_by_unit = perf_by_unit.set_index(self.levels)
perf_by_unit = perf_by_unit.set_index(self.levels).sort_index()
return perf_by_unit

def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None):
Expand Down
18 changes: 3 additions & 15 deletions src/spikeinterface/sortingcomponents/clustering/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from .tools import aggregate_sparse_features, FeaturesLoader, compute_template_from_sparse


DEBUG = False


def merge_clusters(
peaks,
peak_labels,
Expand Down Expand Up @@ -81,7 +84,6 @@ def merge_clusters(
**job_kwargs,
)

DEBUG = False
if DEBUG:
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -224,17 +226,13 @@ def agglomerate_pairs(labels_set, pair_mask, pair_values, connection_mode="full"
else:
raise ValueError

# DEBUG = True
DEBUG = False
if DEBUG:
import matplotlib.pyplot as plt

fig = plt.figure()
nx.draw_networkx(sub_graph)
plt.show()

# DEBUG = True
DEBUG = False
if DEBUG:
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -551,15 +549,7 @@ def merge(
else:
final_shift = 0

# DEBUG = True
DEBUG = False

# if DEBUG and is_merge:
# if DEBUG and (overlap > 0.1 and overlap <0.3):
if DEBUG:
# if DEBUG and not is_merge:
# if DEBUG and (overlap > 0.05 and overlap <0.25):
# if label0 == 49 and label1== 65:
import matplotlib.pyplot as plt

flatten_wfs0 = wfs0.swapaxes(1, 2).reshape(wfs0.shape[0], -1)
Expand Down Expand Up @@ -674,8 +664,6 @@ def merge(
final_shift = 0
merge_value = np.nan

# DEBUG = False
DEBUG = True
if DEBUG and normed_diff < 0.2:
# if DEBUG:

Expand Down
87 changes: 0 additions & 87 deletions src/spikeinterface/widgets/agreement_matrix.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import numpy as np
from warnings import warn

from .base import BaseWidget, to_attr
from .utils import get_unit_colors


class ConfusionMatrixWidget(BaseWidget):
Expand Down Expand Up @@ -77,3 +75,85 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
N1 + 0.5,
-0.5,
)


class AgreementMatrixWidget(BaseWidget):
    """
    Plot the agreement matrix of a sorting comparison.

    Parameters
    ----------
    sorting_comparison: GroundTruthComparison or SymmetricSortingComparison
        The sorting comparison object.
        Can optionally be symmetric if given a SymmetricSortingComparison
    ordered: bool, default: True
        Order units with best agreement scores.
        If True, agreement scores can be seen along a diagonal
    count_text: bool, default: True
        If True counts are displayed as text
    unit_ticks: bool, default: True
        If True unit tick labels are displayed
    """

    def __init__(
        self, sorting_comparison, ordered=True, count_text=True, unit_ticks=True, backend=None, **backend_kwargs
    ):
        # Bundle the inputs; actual rendering happens in the backend-specific method.
        data = dict(
            sorting_comparison=sorting_comparison,
            ordered=ordered,
            count_text=count_text,
            unit_ticks=unit_ticks,
        )
        BaseWidget.__init__(self, data, backend=backend, **backend_kwargs)

    def plot_matplotlib(self, data_plot, **backend_kwargs):
        import matplotlib.pyplot as plt
        from .utils_matplotlib import make_mpl_figure

        dp = to_attr(data_plot)

        self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

        comp = dp.sorting_comparison

        # With ordered=True, best matches line up along the diagonal.
        scores = comp.get_ordered_agreement_scores() if dp.ordered else comp.agreement_scores

        n_rows, n_cols = scores.shape
        row_ids = scores.index.values
        col_ids = scores.columns.values

        # matshow sets the ticks up nicely; imshow would be faster but needs manual ticks.
        self.ax.matshow(scores.values, cmap="Greens")

        if dp.count_text:
            for row, u1 in enumerate(row_ids):
                u2 = comp.best_match_12[u1]
                # -1 marks "no match" for this unit; skip the annotation then.
                if u2 != -1:
                    col = np.where(col_ids == u2)[0][0]
                    self.ax.text(
                        col, row, "{:0.2f}".format(scores.at[u1, u2]), ha="center", va="center", color="white"
                    )

        # Move the major ticks to the bottom edge.
        self.ax.xaxis.tick_bottom()

        # Label each row/column with its unit id.
        if dp.unit_ticks:
            self.ax.set_xticks(np.arange(0, n_cols))
            self.ax.set_yticks(np.arange(0, n_rows))
            self.ax.set_yticklabels(scores.index)
            self.ax.set_xticklabels(scores.columns)

        self.ax.set_xlabel(comp.name_list[1])
        self.ax.set_ylabel(comp.name_list[0])

        self.ax.set_xlim(-0.5, n_cols - 0.5)
        self.ax.set_ylim(
            n_rows - 0.5,
            -0.5,
        )
Loading

0 comments on commit 109b5b3

Please sign in to comment.