
Commit 56fff23

Merge branch 'refactor_GTStudy' of github.com:samuelgarcia/spikeinterface into refactor_GTStudy

samuelgarcia committed Oct 7, 2024
2 parents 7cbbdef + 4016136
Showing 14 changed files with 51 additions and 39 deletions.
2 changes: 1 addition & 1 deletion doc/modules/benchmark.rst
@@ -103,7 +103,7 @@ The all mechanism is based on an intrinsic organization into a "study_folder" wi
# some plots
m = comp.get_confusion_matrix()
w_comp = sw.plot_agreement_matrix(sorting_comparison=comp)
-# Collect synthetic dataframes and display
+# As shown previously, the performance is returned as a pandas dataframe.
# The spikeinterface.comparison.get_performance_by_unit() function,
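As context for the doc change above, a minimal usage sketch of the per-unit performance table mentioned in the new comment (hypothetical, not part of this commit; assumes an existing `study` object and the accuracy/precision/recall columns used elsewhere in this diff):

    # Sketch: per-unit performance comes back as a pandas dataframe.
    perf = study.get_performance_by_unit()
    print(perf[["accuracy", "precision", "recall"]].describe())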
4 changes: 0 additions & 4 deletions doc/modules/comparison.rst
@@ -412,7 +412,3 @@ sorting analyzers from day 1 (:code:`analyzer_day1`) to day 5 (:code:`analyzer_d
-# match all
-m_tcmp = sc.compare_multiple_templates(waveform_list=analyzer_list,
-name_list=["D1", "D2", "D3", "D4", "D5"])
2 changes: 1 addition & 1 deletion src/spikeinterface/benchmark/__init__.py
@@ -4,4 +4,4 @@
* some sorting components (clustering, motion, template matching)
"""

-from .benchmark_sorter import SorterStudy
+from .benchmark_sorter import SorterStudy
2 changes: 1 addition & 1 deletion src/spikeinterface/benchmark/benchmark_base.py
@@ -259,6 +259,7 @@ def get_run_times(self, case_keys=None):

def plot_run_times(self, case_keys=None):
from .benchmark_plot_tools import plot_run_times
+
return plot_run_times(self, case_keys=case_keys)

def compute_results(self, case_keys=None, verbose=False, **result_params):
@@ -445,4 +446,3 @@ def run(self):
def compute_result(self):
# run benchmark result
raise NotImplementedError

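The `compute_result` stub above spells out the `Benchmark` contract: concrete benchmarks implement `run()` and `compute_result()`. A minimal hypothetical subclass, as a sketch only (attribute names are illustrative; only the two abstract methods are taken from this diff):

    from spikeinterface.benchmark.benchmark_base import Benchmark

    class MyBenchmark(Benchmark):
        def __init__(self, params):
            self.params = params
            self.result = {}  # illustrative storage, not the library's layout

        def run(self):
            # run the method under test and keep its raw output
            self.result["raw_output"] = None

        def compute_result(self):
            # derive comparison metrics from the raw output
            self.result["metrics"] = {}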
4 changes: 3 additions & 1 deletion src/spikeinterface/benchmark/benchmark_clustering.py
@@ -163,16 +163,18 @@ def get_count_units(self, case_keys=None, well_detected_score=None, redundant_sc
# plotting by methods
def plot_unit_counts(self, **kwargs):
from .benchmark_plot_tools import plot_unit_counts

return plot_unit_counts(self, **kwargs)

def plot_agreement_matrix(self, **kwargs):
from .benchmark_plot_tools import plot_agreement_matrix

return plot_agreement_matrix(self, **kwargs)

def plot_performances_vs_snr(self, **kwargs):
from .benchmark_plot_tools import plot_performances_vs_snr
-return plot_performances_vs_snr(self, **kwargs)
+
+return plot_performances_vs_snr(self, **kwargs)

def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)):

3 changes: 2 additions & 1 deletion src/spikeinterface/benchmark/benchmark_matching.py
@@ -63,12 +63,13 @@ def create_benchmark(self, key):

def plot_agreement_matrix(self, **kwargs):
from .benchmark_plot_tools import plot_agreement_matrix

return plot_agreement_matrix(self, **kwargs)

def plot_performances_vs_snr(self, **kwargs):
from .benchmark_plot_tools import plot_performances_vs_snr
-return plot_performances_vs_snr(self, **kwargs)
+
+return plot_performances_vs_snr(self, **kwargs)

def plot_collisions(self, case_keys=None, figsize=None):
if case_keys is None:
1 change: 1 addition & 0 deletions src/spikeinterface/benchmark/benchmark_peak_selection.py
@@ -10,6 +10,7 @@

from .benchmark_base import Benchmark, BenchmarkStudy


class PeakSelectionBenchmark(Benchmark):

def __init__(self, recording, gt_sorting, params, indices, exhaustive_gt=True):
12 changes: 4 additions & 8 deletions src/spikeinterface/benchmark/benchmark_plot_tools.py
@@ -1,7 +1,6 @@
import numpy as np



def _simpleaxis(ax):
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
@@ -28,7 +27,6 @@ def plot_run_times(study, case_keys=None):
run_times = study.get_run_times(case_keys=case_keys)

colors = study.get_colors()


fig, ax = plt.subplots()
labels = []
@@ -58,7 +56,6 @@ def plot_unit_counts(study, case_keys=None):
if case_keys is None:
case_keys = list(study.cases.keys())


count_units = study.get_count_units(case_keys=case_keys)

fig, ax = plt.subplots()
@@ -95,6 +92,7 @@ def plot_unit_counts(study, case_keys=None):

return fig


def plot_performances(study, mode="ordered", performance_names=("accuracy", "precision", "recall"), case_keys=None):
"""
Plot performances over case for a study.
@@ -121,10 +119,9 @@ def plot_performances(study, mode="ordered", performance_names=("accuracy", "pre
if case_keys is None:
case_keys = list(study.cases.keys())

-perfs=study.get_performance_by_unit(case_keys=case_keys)
+perfs = study.get_performance_by_unit(case_keys=case_keys)
colors = study.get_colors()


if mode in ("ordered", "snr"):
num_axes = len(performance_names)
fig, axs = plt.subplots(ncols=num_axes)
@@ -195,7 +192,6 @@ def plot_agreement_matrix(study, ordered=True, case_keys=None):
if case_keys is None:
case_keys = list(study.cases.keys())


num_axes = len(case_keys)
fig, axs = plt.subplots(ncols=num_axes)

@@ -238,9 +234,9 @@ def plot_performances_vs_snr(study, case_keys=None, figsize=None, metrics=["accu
y = study.get_result(key)["gt_comparison"].get_performance()[k].values
ax.scatter(x, y, marker=".", label=label)
ax.set_title(k)

ax.set_ylim(0, 1.05)

if count == 2:
ax.legend()

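The helpers touched above are plain module-level functions that take the study as first argument; their signatures are visible in this diff. A short usage sketch (assumes a loaded `study`):

    from spikeinterface.benchmark.benchmark_plot_tools import (
        plot_run_times,
        plot_performances,
        plot_performances_vs_snr,
    )

    plot_run_times(study)                     # bar plot of run time per case
    plot_performances(study, mode="ordered")  # accuracy/precision/recall per case
    plot_performances_vs_snr(study, metrics=["accuracy"])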
8 changes: 4 additions & 4 deletions src/spikeinterface/benchmark/benchmark_sorter.py
@@ -2,7 +2,6 @@
This replaces the previous `GroundTruthStudy`
"""


import numpy as np
from ..core import NumpySorting
from .benchmark_base import Benchmark, BenchmarkStudy
@@ -40,6 +39,7 @@ def compute_result(self):
("gt_comparison", "pickle"),
]


class SorterStudy(BenchmarkStudy):
"""
This class is used to test several sorters in several situations.
@@ -121,15 +121,15 @@ def get_count_units(self, case_keys=None, well_detected_score=None, redundant_sc
# plotting as methods
def plot_unit_counts(self, **kwargs):
from .benchmark_plot_tools import plot_unit_counts

return plot_unit_counts(self, **kwargs)

def plot_performances(self, **kwargs):
from .benchmark_plot_tools import plot_performances

return plot_performances(self, **kwargs)

def plot_agreement_matrix(self, **kwargs):
from .benchmark_plot_tools import plot_agreement_matrix
-return plot_agreement_matrix(self, **kwargs)



+return plot_agreement_matrix(self, **kwargs)
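As the wrappers above show, `SorterStudy` also exposes these plots as methods that delegate to `benchmark_plot_tools`. A method-style sketch, equivalent to calling the module functions directly (assumes a loaded `study`):

    study.plot_unit_counts()
    study.plot_performances()
    study.plot_agreement_matrix()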
1 change: 1 addition & 0 deletions src/spikeinterface/benchmark/tests/test_benchmark_peak_selection.py
@@ -2,6 +2,7 @@

from pathlib import Path


@pytest.mark.skip()
def test_benchmark_peak_selection(create_cache_folder):
cache_folder = create_cache_folder
4 changes: 1 addition & 3 deletions src/spikeinterface/benchmark/tests/test_benchmark_sorter.py
@@ -15,6 +15,7 @@ def setup_module(tmp_path_factory):
create_a_study(study_folder)
return study_folder


def simple_preprocess(rec):
return bandpass_filter(rec)

@@ -75,14 +76,11 @@ def test_SorterStudy(setup_module):
# import matplotlib.pyplot as plt
# plt.show()


perf_by_unit = study.get_performance_by_unit()
# print(perf_by_unit)
count_units = study.get_count_units()
# print(count_units)




if __name__ == "__main__":
study_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_SorterStudy"
1 change: 1 addition & 0 deletions src/spikeinterface/comparison/collision.py
@@ -171,6 +171,7 @@ def compute_collision_by_similarity(self, similarity_matrix, unit_ids=None, good

return similarities, recall_scores, pair_names


# This is removed at the moment.
# We need to move this maybe one day in benchmark.
# please do not delete this
3 changes: 1 addition & 2 deletions src/spikeinterface/comparison/groundtruthstudy.py
@@ -1,5 +1,3 @@
-
-
_txt_error_message = """
GroundTruthStudy has been replaced by SorterStudy, which has a similar API but is not backward compatible with old study folders.
You can do:
@@ -13,6 +11,7 @@
...
"""


class GroundTruthStudy:
def __init__(self, study_folder):
raise RuntimeError(_txt_error_message)
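For users migrating, a hedged sketch of what loading might look like under the new class (the error message above elides the exact steps; `SorterStudy(study_folder)` mirrors the old `GroundTruthStudy(study_folder)` constructor shown here and is an assumption, not taken from this commit):

    from spikeinterface.benchmark import SorterStudy

    # Assumption: study_folder was created with the new SorterStudy workflow;
    # folders from the old GroundTruthStudy are explicitly not loadable.
    study = SorterStudy(study_folder)
    perf_by_unit = study.get_performance_by_unit()
    count_units = study.get_count_units()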
43 changes: 30 additions & 13 deletions src/spikeinterface/widgets/gtstudy.py
@@ -11,6 +11,7 @@

import warnings


class StudyRunTimesWidget(BaseWidget):
"""
Plot sorter run times for a SorterStudy.
@@ -25,12 +26,15 @@ class StudyRunTimesWidget(BaseWidget):
"""

def __init__(self, study, case_keys=None, backend=None, **backend_kwargs):
warnings.warn("plot_study_run_times is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead.")
warnings.warn(
"plot_study_run_times is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead."
)
plot_data = dict(study=study, case_keys=case_keys)
BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

def plot_matplotlib(self, data_plot, **backend_kwargs):
from spikeinterface.benchmark.benchmark_plot_tools import plot_run_times

plot_run_times(data_plot["study"], case_keys=data_plot["case_keys"])


@@ -48,12 +52,15 @@ class StudyUnitCountsWidget(BaseWidget):
"""

def __init__(self, study, case_keys=None, backend=None, **backend_kwargs):
warnings.warn("plot_study_unit_counts is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead.")
warnings.warn(
"plot_study_unit_counts is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead."
)
plot_data = dict(study=study, case_keys=case_keys)
BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

def plot_matplotlib(self, data_plot, **backend_kwargs):
from spikeinterface.benchmark.benchmark_plot_tools import plot_unit_counts

plot_unit_counts(data_plot["study"], case_keys=data_plot["case_keys"])


@@ -87,7 +94,9 @@ def __init__(
backend=None,
**backend_kwargs,
):
warnings.warn("plot_study_performances is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead.")
warnings.warn(
"plot_study_performances is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead."
)
plot_data = dict(
study=study,
mode=mode,
Expand All @@ -98,13 +107,15 @@ def __init__(

def plot_matplotlib(self, data_plot, **backend_kwargs):
from spikeinterface.benchmark.benchmark_plot_tools import plot_performances

plot_performances(
data_plot["study"],
mode=data_plot["mode"],
performance_names=data_plot["performance_names"],
-case_keys=data_plot["case_keys"]
+case_keys=data_plot["case_keys"],
)


class StudyAgreementMatrix(BaseWidget):
"""
Plot agreement matrix.
@@ -128,7 +139,9 @@ def __init__(
backend=None,
**backend_kwargs,
):
warnings.warn("plot_study_agreement_matrix is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead.")
warnings.warn(
"plot_study_agreement_matrix is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead."
)
plot_data = dict(
study=study,
case_keys=case_keys,
Expand All @@ -139,11 +152,8 @@ def __init__(

def plot_matplotlib(self, data_plot, **backend_kwargs):
from spikeinterface.benchmark.benchmark_plot_tools import plot_agreement_matrix
-plot_agreement_matrix(
-data_plot["study"],
-ordered=data_plot["ordered"],
-case_keys=data_plot["case_keys"]
-)
+
+plot_agreement_matrix(data_plot["study"], ordered=data_plot["ordered"], case_keys=data_plot["case_keys"])


class StudySummary(BaseWidget):
@@ -171,16 +181,23 @@ def __init__(
backend=None,
**backend_kwargs,
):
-
-warnings.warn("plot_study_summary is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead.")
-
+warnings.warn(
+"plot_study_summary is to be deprecated. Use spikeinterface.benchmark.benchmark_plot_tools instead."
+)
plot_data = dict(study=study, case_keys=case_keys)
BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

def plot_matplotlib(self, data_plot, **backend_kwargs):
study = data_plot["study"]
case_keys = data_plot["case_keys"]

-from spikeinterface.benchmark.benchmark_plot_tools import plot_agreement_matrix, plot_performances, plot_unit_counts, plot_run_times
+from spikeinterface.benchmark.benchmark_plot_tools import (
+plot_agreement_matrix,
+plot_performances,
+plot_unit_counts,
+plot_run_times,
+)

plot_performances(study=study, case_keys=case_keys, mode="ordered")
plot_performances(study=study, case_keys=case_keys, mode="snr")
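Taken together, the widget changes above all point the same way: each `plot_study_*` widget now warns and defers to `spikeinterface.benchmark.benchmark_plot_tools`. A before/after migration sketch (widget names come from the warning strings in this diff; the `sw` alias follows the doc snippet earlier in this commit):

    # before (deprecated widget calls)
    # sw.plot_study_run_times(study)
    # sw.plot_study_agreement_matrix(study)

    # after
    from spikeinterface.benchmark.benchmark_plot_tools import plot_run_times, plot_agreement_matrix
    plot_run_times(study)
    plot_agreement_matrix(study)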
