Skip to content

Commit

Permalink
Merge pull request #2037 from rkim48/fix-unit-id-matching
Browse files Browse the repository at this point in the history
Fix: Correct unit ID matching in sortingview curation
  • Loading branch information
alejoe91 authored Oct 2, 2023
2 parents 225d9d1 + c20ffda commit 8c35a3a
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 27 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,4 @@ test_folder/

# Mac OS
.DS_Store
test_data.json
49 changes: 29 additions & 20 deletions src/spikeinterface/curation/sortingview_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,37 +57,47 @@ def apply_sortingview_curation(
unit_ids_dtype = sorting.unit_ids.dtype

# STEP 1: merge groups
labels_dict = sortingview_curation_dict["labelsByUnit"]
if "mergeGroups" in sortingview_curation_dict and not skip_merge:
merge_groups = sortingview_curation_dict["mergeGroups"]
for mg in merge_groups:
for merge_group in merge_groups:
# Store labels of units that are about to be merged
labels_to_inherit = []
for unit in merge_group:
labels_to_inherit.extend(labels_dict.get(str(unit), []))
labels_to_inherit = list(set(labels_to_inherit)) # Remove duplicates

if verbose:
print(f"Merging {mg}")
print(f"Merging {merge_group}")
if unit_ids_dtype.kind in ("U", "S"):
# if unit dtype is str, set new id as "{unit1}-{unit2}"
new_unit_id = "-".join(mg)
new_unit_id = "-".join(merge_group)
curation_sorting.merge(merge_group, new_unit_id=new_unit_id)
else:
# in this case, the CurationSorting takes care of finding a new unused int
new_unit_id = None
curation_sorting.merge(mg, new_unit_id=new_unit_id)
curation_sorting.merge(merge_group, new_unit_id=None)
new_unit_id = curation_sorting.max_used_id # merged unit id
labels_dict[str(new_unit_id)] = labels_to_inherit

# STEP 2: gather and apply sortingview curation labels

# In sortingview, a unit is not required to have all labels.
# For example, the first 3 units could be labeled as "accept".
# In this case, the first 3 values of the property "accept" will be True, the rest False
labels_dict = sortingview_curation_dict["labelsByUnit"]
properties = {}
for _, labels in labels_dict.items():
for label in labels:
if label not in properties:
properties[label] = np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool)
for u_i, unit_id in enumerate(curation_sorting.current_sorting.unit_ids):
labels_unit = []
for unit_label, labels in labels_dict.items():
if unit_label in str(unit_id):
labels_unit.extend(labels)
for label in labels_unit:
properties[label][u_i] = True

# Initialize the properties dictionary
properties = {
label: np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool)
for labels in labels_dict.values()
for label in labels
}

# Populate the properties dictionary
for unit_index, unit_id in enumerate(curation_sorting.current_sorting.unit_ids):
unit_id_str = str(unit_id)
if unit_id_str in labels_dict:
for label in labels_dict[unit_id_str]:
properties[label][unit_index] = True

for prop_name, prop_values in properties.items():
curation_sorting.current_sorting.set_property(prop_name, prop_values)

Expand All @@ -103,5 +113,4 @@ def apply_sortingview_curation(
units_to_remove.extend(unit_ids[curation_sorting.current_sorting.get_property(exclude_label) == True])
units_to_remove = np.unique(units_to_remove)
curation_sorting.remove_units(units_to_remove)

return curation_sorting.current_sorting
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"labelsByUnit": {
"1": [
"accept"
],
"2": [
"artifact"
],
"12": [
"artifact"
]
},
"mergeGroups": [
[
2,
12
]
]
}
39 changes: 39 additions & 0 deletions src/spikeinterface/curation/tests/sv-sorting-curation-int.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
{
"labelsByUnit": {
"1": [
"mua"
],
"2": [
"mua"
],
"3": [
"reject"
],
"4": [
"noise"
],
"5": [
"accept"
],
"6": [
"accept"
],
"7": [
"accept"
]
},
"mergeGroups": [
[
1,
2
],
[
3,
4
],
[
5,
6
]
]
}
39 changes: 39 additions & 0 deletions src/spikeinterface/curation/tests/sv-sorting-curation-str.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
{
"labelsByUnit": {
"a": [
"mua"
],
"b": [
"mua"
],
"c": [
"reject"
],
"d": [
"noise"
],
"e": [
"accept"
],
"f": [
"accept"
],
"g": [
"accept"
]
},
"mergeGroups": [
[
"a",
"b"
],
[
"c",
"d"
],
[
"e",
"f"
]
]
}
147 changes: 140 additions & 7 deletions src/spikeinterface/curation/tests/test_sortingview_curation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import pytest
from pathlib import Path
import os
import json
import numpy as np

import spikeinterface as si
import spikeinterface.extractors as se
from spikeinterface.extractors import read_mearec
from spikeinterface import set_global_tmp_folder
from spikeinterface.postprocessing import (
Expand All @@ -19,7 +22,6 @@
cache_folder = Path("cache_folder") / "curation"

parent_folder = Path(__file__).parent

ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS"))
KACHERY_CLOUD_SET = bool(os.getenv("KACHERY_CLOUD_CLIENT_ID")) and bool(os.getenv("KACHERY_CLOUD_PRIVATE_KEY"))

Expand Down Expand Up @@ -50,15 +52,15 @@ def generate_sortingview_curation_dataset():

@pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available")
def test_gh_curation():
"""
Test curation using GitHub URI.
"""
local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5")
_, sorting = read_mearec(local_path)

# from GH
# curated link:
# https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary&s={%22sortingCuration%22:%22gh://alejoe91/spikeinterface/fix-codecov/spikeinterface/curation/tests/sv-sorting-curation.json%22}
gh_uri = "gh://SpikeInterface/spikeinterface/main/src/spikeinterface/curation/tests/sv-sorting-curation.json"
sorting_curated_gh = apply_sortingview_curation(sorting, uri_or_json=gh_uri, verbose=True)
print(f"From GH: {sorting_curated_gh}")

assert len(sorting_curated_gh.unit_ids) == 9
assert "#8-#9" in sorting_curated_gh.unit_ids
Expand All @@ -78,6 +80,9 @@ def test_gh_curation():

@pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available")
def test_sha1_curation():
"""
Test curation using SHA1 URI.
"""
local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5")
_, sorting = read_mearec(local_path)

Expand All @@ -86,14 +91,14 @@ def test_sha1_curation():
# https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary&s={%22sortingCuration%22:%22sha1://1182ba19671fcc7d3f8e0501b0f8c07fb9736c22%22}
sha1_uri = "sha1://1182ba19671fcc7d3f8e0501b0f8c07fb9736c22"
sorting_curated_sha1 = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, verbose=True)
print(f"From SHA: {sorting_curated_sha1}")
# print(f"From SHA: {sorting_curated_sha1}")

assert len(sorting_curated_sha1.unit_ids) == 9
assert "#8-#9" in sorting_curated_sha1.unit_ids
assert "accept" in sorting_curated_sha1.get_property_keys()
assert "mua" in sorting_curated_sha1.get_property_keys()
assert "artifact" in sorting_curated_sha1.get_property_keys()

unit_ids = sorting_curated_sha1.unit_ids
sorting_curated_sha1_accepted = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, include_labels=["accept"])
sorting_curated_sha1_mua = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, exclude_labels=["mua"])
sorting_curated_sha1_art_mua = apply_sortingview_curation(
Expand All @@ -105,13 +110,16 @@ def test_sha1_curation():


def test_json_curation():
"""
Test curation using a JSON file.
"""
local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5")
_, sorting = read_mearec(local_path)

# from curation.json
json_file = parent_folder / "sv-sorting-curation.json"
# print(f"Sorting: {sorting.get_unit_ids()}")
sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True)
print(f"From JSON: {sorting_curated_json}")

assert len(sorting_curated_json.unit_ids) == 9
assert "#8-#9" in sorting_curated_json.unit_ids
Expand All @@ -131,8 +139,133 @@ def test_json_curation():
assert len(sorting_curated_json_mua1.unit_ids) == 5


def test_false_positive_curation():
"""
Test curation for false positives.
"""
# https://spikeinterface.readthedocs.io/en/latest/modules_gallery/core/plot_2_sorting_extractor.html
sampling_frequency = 30000.0
duration = 20.0
num_timepoints = int(sampling_frequency * duration)
num_units = 20
num_spikes = 1000
times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes)))
labels = np.random.randint(1, num_units + 1, size=num_spikes)

sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency)
# print("Sorting: {}".format(sorting.get_unit_ids()))

json_file = parent_folder / "sv-sorting-curation-false-positive.json"
sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True)
# print("Curated:", sorting_curated_json.get_unit_ids())

# Assertions
assert sorting_curated_json.get_unit_property(unit_id=1, key="accept")
assert not sorting_curated_json.get_unit_property(unit_id=10, key="accept")
assert 21 in sorting_curated_json.unit_ids


def test_label_inheritance_int():
"""
Test curation for label inheritance for integer unit IDs.
"""
# Setup
sampling_frequency = 30000.0
duration = 20.0
num_timepoints = int(sampling_frequency * duration)
num_spikes = 1000
num_units = 7
times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes)))
labels = np.random.randint(1, 1 + num_units, size=num_spikes) # 7 units: 1 to 7

sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency)

json_file = parent_folder / "sv-sorting-curation-int.json"
sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file)

# Assertions for merged units
# print(f"Merge only: {sorting_merge.get_unit_ids()}")
assert sorting_merge.get_unit_property(unit_id=8, key="mua") # 8 = merged unit of 1 and 2
assert not sorting_merge.get_unit_property(unit_id=8, key="reject")
assert not sorting_merge.get_unit_property(unit_id=8, key="noise")
assert not sorting_merge.get_unit_property(unit_id=8, key="accept")

assert not sorting_merge.get_unit_property(unit_id=9, key="mua") # 9 = merged unit of 3 and 4
assert sorting_merge.get_unit_property(unit_id=9, key="reject")
assert sorting_merge.get_unit_property(unit_id=9, key="noise")
assert not sorting_merge.get_unit_property(unit_id=9, key="accept")

assert not sorting_merge.get_unit_property(unit_id=10, key="mua") # 10 = merged unit of 5 and 6
assert not sorting_merge.get_unit_property(unit_id=10, key="reject")
assert not sorting_merge.get_unit_property(unit_id=10, key="noise")
assert sorting_merge.get_unit_property(unit_id=10, key="accept")

# Assertions for exclude_labels
sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"])
# print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}")
assert 9 not in sorting_exclude_noise.get_unit_ids()

# Assertions for include_labels
sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"])
# print(f"Include accept: {sorting_include_accept.get_unit_ids()}")
assert 8 not in sorting_include_accept.get_unit_ids()
assert 9 not in sorting_include_accept.get_unit_ids()
assert 10 in sorting_include_accept.get_unit_ids()


def test_label_inheritance_str():
"""
Test curation for label inheritance for string unit IDs.
"""
sampling_frequency = 30000.0
duration = 20.0
num_timepoints = int(sampling_frequency * duration)
num_spikes = 1000
times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes)))
labels = np.random.choice(["a", "b", "c", "d", "e", "f", "g"], size=num_spikes)

sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency)
# print(f"Sorting: {sorting.get_unit_ids()}")

# Apply curation
json_file = parent_folder / "sv-sorting-curation-str.json"
sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True)

# Assertions for merged units
# print(f"Merge only: {sorting_merge.get_unit_ids()}")
assert sorting_merge.get_unit_property(unit_id="a-b", key="mua")
assert not sorting_merge.get_unit_property(unit_id="a-b", key="reject")
assert not sorting_merge.get_unit_property(unit_id="a-b", key="noise")
assert not sorting_merge.get_unit_property(unit_id="a-b", key="accept")

assert not sorting_merge.get_unit_property(unit_id="c-d", key="mua")
assert sorting_merge.get_unit_property(unit_id="c-d", key="reject")
assert sorting_merge.get_unit_property(unit_id="c-d", key="noise")
assert not sorting_merge.get_unit_property(unit_id="c-d", key="accept")

assert not sorting_merge.get_unit_property(unit_id="e-f", key="mua")
assert not sorting_merge.get_unit_property(unit_id="e-f", key="reject")
assert not sorting_merge.get_unit_property(unit_id="e-f", key="noise")
assert sorting_merge.get_unit_property(unit_id="e-f", key="accept")

# Assertions for exclude_labels
sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"])
# print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}")
assert "c-d" not in sorting_exclude_noise.get_unit_ids()

# Assertions for include_labels
sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"])
# print(f"Include accept: {sorting_include_accept.get_unit_ids()}")
assert "a-b" not in sorting_include_accept.get_unit_ids()
assert "c-d" not in sorting_include_accept.get_unit_ids()
assert "e-f" in sorting_include_accept.get_unit_ids()


if __name__ == "__main__":
# generate_sortingview_curation_dataset()
test_sha1_curation()
test_gh_curation()
test_json_curation()
test_false_positive_curation()
test_label_inheritance_int()
test_label_inheritance_str()

0 comments on commit 8c35a3a

Please sign in to comment.