Commit
Merge branch 'main' into expose_attempts_in_plexon2
h-mayorquin authored Oct 15, 2024
2 parents 9175fa1 + 65d4b1e commit 94ab456
Showing 18 changed files with 1,083 additions and 486 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -1,12 +1,12 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 24.8.0
rev: 24.10.0
hooks:
- id: black
files: ^src/
74 changes: 9 additions & 65 deletions src/spikeinterface/benchmark/benchmark_matching.py
@@ -33,7 +33,7 @@ def run(self, **job_kwargs):
sorting["unit_index"] = spikes["cluster_index"]
sorting["segment_index"] = spikes["segment_index"]
sorting = NumpySorting(sorting, self.recording.sampling_frequency, unit_ids)
self.result = {"sorting": sorting}
self.result = {"sorting": sorting, "spikes": spikes}
self.result["templates"] = self.templates

def compute_result(self, with_collision=False, **result_params):
@@ -45,6 +45,7 @@ def compute_result(self, with_collision=False, **result_params):

_run_key_saved = [
("sorting", "sorting"),
("spikes", "npy"),
("templates", "zarr_templates"),
]
_result_key_saved = [("gt_collision", "pickle"), ("gt_comparison", "pickle")]
@@ -71,9 +72,15 @@ def plot_performances_vs_snr(self, **kwargs):

return plot_performances_vs_snr(self, **kwargs)

def plot_performances_comparison(self, **kwargs):
from .benchmark_plot_tools import plot_performances_comparison

return plot_performances_comparison(self, **kwargs)

def plot_collisions(self, case_keys=None, figsize=None):
if case_keys is None:
case_keys = list(self.cases.keys())
import matplotlib.pyplot as plt

fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)

@@ -90,70 +97,6 @@ def plot_collisions(self, case_keys=None, figsize=None):

return fig

def plot_comparison_matching(
self,
case_keys=None,
performance_names=["accuracy", "recall", "precision"],
colors=["g", "b", "r"],
ylim=(-0.1, 1.1),
figsize=None,
):

if case_keys is None:
case_keys = list(self.cases.keys())

num_methods = len(case_keys)
import pylab as plt

fig, axs = plt.subplots(ncols=num_methods, nrows=num_methods, figsize=(10, 10))
for i, key1 in enumerate(case_keys):
for j, key2 in enumerate(case_keys):
if len(axs.shape) > 1:
ax = axs[i, j]
else:
ax = axs[j]
comp1 = self.get_result(key1)["gt_comparison"]
comp2 = self.get_result(key2)["gt_comparison"]
if i <= j:
for performance, color in zip(performance_names, colors):
perf1 = comp1.get_performance()[performance]
perf2 = comp2.get_performance()[performance]
ax.plot(perf2, perf1, ".", label=performance, color=color)

ax.plot([0, 1], [0, 1], "k--", alpha=0.5)
ax.set_ylim(ylim)
ax.set_xlim(ylim)
ax.spines[["right", "top"]].set_visible(False)
ax.set_aspect("equal")

label1 = self.cases[key1]["label"]
label2 = self.cases[key2]["label"]
if j == i:
ax.set_ylabel(f"{label1}")
else:
ax.set_yticks([])
if i == j:
ax.set_xlabel(f"{label2}")
else:
ax.set_xticks([])
if i == num_methods - 1 and j == num_methods - 1:
patches = []
import matplotlib.patches as mpatches

for color, name in zip(colors, performance_names):
patches.append(mpatches.Patch(color=color, label=name))
ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0.0)
else:
ax.spines["bottom"].set_visible(False)
ax.spines["left"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_xticks([])
ax.set_yticks([])
plt.tight_layout(h_pad=0, w_pad=0)

return fig

def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None):
import pandas as pd

@@ -196,6 +139,7 @@ def plot_unit_counts(self, case_keys=None, figsize=None):
plot_study_unit_counts(self, case_keys, figsize=figsize)

def plot_unit_losses(self, before, after, metric=["precision"], figsize=None):
import matplotlib.pyplot as plt

fig, axs = plt.subplots(ncols=1, nrows=len(metric), figsize=figsize, squeeze=False)

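With the `spikes` array now kept in the run result (and persisted as `.npy` via `_run_key_saved`), both the matched sorting and the raw spike vector can be read back per case. A minimal sketch, assuming an already-computed matching-benchmark study exposing the `cases` and `get_result` accessors used above:

```python
# Hypothetical sketch: reading back what run() now stores for each case.
# `study` is assumed to be an existing matching-benchmark study with results computed.
for key in study.cases.keys():
    result = study.get_result(key)
    sorting = result["sorting"]  # NumpySorting built from the matched spikes
    spikes = result["spikes"]    # structured spike vector (sample/cluster/segment indices)
    print(key, sorting.get_num_units(), "units,", len(spikes), "matched spikes")
```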
64 changes: 63 additions & 1 deletion src/spikeinterface/benchmark/benchmark_plot_tools.py
@@ -235,9 +235,71 @@ def plot_performances_vs_snr(study, case_keys=None, figsize=None, metrics=["accu
ax.scatter(x, y, marker=".", label=label)
ax.set_title(k)

ax.set_ylim(0, 1.05)
ax.set_ylim(-0.05, 1.05)

if count == 2:
ax.legend()

return fig


def plot_performances_comparison(
study,
case_keys=None,
figsize=None,
metrics=["accuracy", "recall", "precision"],
colors=["g", "b", "r"],
ylim=(-0.1, 1.1),
):
import matplotlib.pyplot as plt

if case_keys is None:
case_keys = list(study.cases.keys())

num_methods = len(case_keys)
assert num_methods >= 2, "plot_performances_comparison needs at least 2 cases!"

fig, axs = plt.subplots(ncols=num_methods - 1, nrows=num_methods - 1, figsize=(10, 10), squeeze=False)
for i, key1 in enumerate(case_keys):
for j, key2 in enumerate(case_keys):

if i < j:
ax = axs[i, j - 1]

comp1 = study.get_result(key1)["gt_comparison"]
comp2 = study.get_result(key2)["gt_comparison"]

for performance, color in zip(metrics, colors):
perf1 = comp1.get_performance()[performance]
perf2 = comp2.get_performance()[performance]
ax.scatter(perf2, perf1, marker=".", label=performance, color=color)

ax.plot([0, 1], [0, 1], "k--", alpha=0.5)
ax.set_ylim(ylim)
ax.set_xlim(ylim)
ax.spines[["right", "top"]].set_visible(False)
ax.set_aspect("equal")

label1 = study.cases[key1]["label"]
label2 = study.cases[key2]["label"]

if i == j - 1:
ax.set_xlabel(label2)
ax.set_ylabel(label1)

else:
if j >= 1 and i < num_methods - 1:
ax = axs[i, j - 1]
ax.spines[["right", "top", "left", "bottom"]].set_visible(False)
ax.set_xticks([])
ax.set_yticks([])

ax = axs[num_methods - 2, 0]
patches = []
from matplotlib.patches import Patch

for color, name in zip(colors, metrics):
patches.append(Patch(color=color, label=name))
ax.legend(handles=patches)
fig.tight_layout()
return fig
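As a usage illustration (not part of the diff), the new tool can be called directly or through the `plot_performances_comparison` method added to the matching benchmark above; `study` is an assumed, already-computed benchmark study with at least two cases:

```python
# Hypothetical usage sketch for the new pairwise comparison plot.
import matplotlib.pyplot as plt

from spikeinterface.benchmark.benchmark_plot_tools import plot_performances_comparison

# Pairwise scatter of per-unit accuracy/recall/precision between every pair of cases.
fig = plot_performances_comparison(
    study,
    metrics=["accuracy", "recall", "precision"],
    colors=["g", "b", "r"],
)
fig.savefig("performances_comparison.png")
plt.show()
```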
6 changes: 4 additions & 2 deletions src/spikeinterface/benchmark/tests/test_benchmark_sorter.py
@@ -64,10 +64,10 @@ def test_SorterStudy(setup_module):
print(study)

# this runs the sorters
# study.run()
study.run()

# this runs comparisons
# study.compute_results()
study.compute_results()
print(study)

# this is from the base class
@@ -84,5 +84,7 @@

if __name__ == "__main__":
study_folder = Path(__file__).resolve().parents[4] / "cache_folder" / "benchmarks" / "test_SorterStudy"
if study_folder.exists():
shutil.rmtree(study_folder)
create_a_study(study_folder)
test_SorterStudy(study_folder)
2 changes: 2 additions & 0 deletions src/spikeinterface/extractors/neoextractors/__init__.py
@@ -9,6 +9,7 @@
from .mearec import MEArecRecordingExtractor, MEArecSortingExtractor, read_mearec
from .mcsraw import MCSRawRecordingExtractor, read_mcsraw
from .neuralynx import NeuralynxRecordingExtractor, NeuralynxSortingExtractor, read_neuralynx, read_neuralynx_sorting
from .neuronexus import NeuroNexusRecordingExtractor, read_neuronexus
from .neuroscope import (
NeuroScopeRecordingExtractor,
NeuroScopeSortingExtractor,
@@ -54,6 +55,7 @@
MCSRawRecordingExtractor,
NeuralynxRecordingExtractor,
NeuroScopeRecordingExtractor,
NeuroNexusRecordingExtractor,
NixRecordingExtractor,
OpenEphysBinaryRecordingExtractor,
OpenEphysLegacyRecordingExtractor,
66 changes: 66 additions & 0 deletions src/spikeinterface/extractors/neoextractors/neuronexus.py
@@ -0,0 +1,66 @@
from __future__ import annotations

from pathlib import Path

from spikeinterface.core.core_tools import define_function_from_class

from .neobaseextractor import NeoBaseRecordingExtractor, NeoBaseSortingExtractor


class NeuroNexusRecordingExtractor(NeoBaseRecordingExtractor):
"""
Class for reading data from NeuroNexus Allego.
Based on :py:class:`neo.rawio.NeuronexusRawIO`

Parameters
----------
file_path : str | Path
The file path to the metadata .xdat.json file of an Allego session
stream_id : str | None, default: None
If there are several streams, specify the stream id you want to load.
stream_name : str | None, default: None
If there are several streams, specify the stream name you want to load.
all_annotations : bool, default: False
Load exhaustively all annotations from neo.
use_names_as_ids : bool, default: False
Determines the format of the channel IDs used by the extractor. If set to True, the channel IDs will be the
names from NeoRawIO. If set to False, the channel IDs will be the ids provided by NeoRawIO.
In Neuronexus the ids provided by NeoRawIO are the hardware channel ids stored as `ntv_chan_name` within
the metada and the names are the `chan_names`
"""

NeoRawIOClass = "NeuroNexusRawIO"

def __init__(
self,
file_path: str | Path,
stream_id: str | None = None,
stream_name: str | None = None,
all_annotations: bool = False,
use_names_as_ids: bool = False,
):
neo_kwargs = self.map_to_neo_kwargs(file_path)
NeoBaseRecordingExtractor.__init__(
self,
stream_id=stream_id,
stream_name=stream_name,
all_annotations=all_annotations,
use_names_as_ids=use_names_as_ids,
**neo_kwargs,
)

self._kwargs.update(dict(file_path=str(Path(file_path).resolve())))

@classmethod
def map_to_neo_kwargs(cls, file_path):

neo_kwargs = {"filename": str(file_path)}

return neo_kwargs


read_neuronexus = define_function_from_class(source_class=NeuroNexusRecordingExtractor, name="read_neuronexus")
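A short, hedged usage sketch for the new reader follows; the file path is a placeholder, the stream id mirrors the one used in the test below, and the top-level re-export from `spikeinterface.extractors` is assumed to follow the same pattern as the other Neo readers:

```python
# Hypothetical example: reading a NeuroNexus Allego session.
# Point file_path at the .xdat.json metadata file of the session (placeholder below).
from spikeinterface.extractors import read_neuronexus

recording = read_neuronexus("allego_2__uid0701-13-04-49.xdat.json", stream_id="0")
print(recording)  # channel count, sampling frequency, duration, dtype
traces = recording.get_traces(start_frame=0, end_frame=30_000)  # first chunk of samples
```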
7 changes: 5 additions & 2 deletions src/spikeinterface/extractors/tests/common_tests.py
@@ -52,8 +52,11 @@ def test_open(self):
num_samples = rec.get_num_samples(segment_index=segment_index)

full_traces = rec.get_traces(segment_index=segment_index)
assert full_traces.shape == (num_samples, num_chans)
assert full_traces.dtype == dtype
assert full_traces.shape == (
num_samples,
num_chans,
), f"{full_traces.shape} != {(num_samples, num_chans)}"
assert full_traces.dtype == dtype, f"{full_traces.dtype} != {dtype=}"

traces_sample_first = rec.get_traces(segment_index=segment_index, start_frame=0, end_frame=1)
assert traces_sample_first.shape == (1, num_chans)
8 changes: 8 additions & 0 deletions src/spikeinterface/extractors/tests/test_neoextractors.py
@@ -181,6 +181,14 @@ class NeuroScopeSortingTest(SortingCommonTestSuite, unittest.TestCase):
]


class NeuroNexusRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
ExtractorClass = NeuroNexusRecordingExtractor
downloads = ["neuronexus"]
entities = [
("neuronexus/allego_1/allego_2__uid0701-13-04-49.xdat.json", {"stream_id": "0"}),
]


class PlexonRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
ExtractorClass = PlexonRecordingExtractor
downloads = ["plexon"]