diff --git a/.gitignore b/.gitignore index 3ee3cb8867..7838213bed 100644 --- a/.gitignore +++ b/.gitignore @@ -188,3 +188,4 @@ test_folder/ # Mac OS .DS_Store +test_data.json diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index 6adf9effd4..626ea79eb9 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -57,37 +57,47 @@ def apply_sortingview_curation( unit_ids_dtype = sorting.unit_ids.dtype # STEP 1: merge groups + labels_dict = sortingview_curation_dict["labelsByUnit"] if "mergeGroups" in sortingview_curation_dict and not skip_merge: merge_groups = sortingview_curation_dict["mergeGroups"] - for mg in merge_groups: + for merge_group in merge_groups: + # Store labels of units that are about to be merged + labels_to_inherit = [] + for unit in merge_group: + labels_to_inherit.extend(labels_dict.get(str(unit), [])) + labels_to_inherit = list(set(labels_to_inherit)) # Remove duplicates + if verbose: - print(f"Merging {mg}") + print(f"Merging {merge_group}") if unit_ids_dtype.kind in ("U", "S"): # if unit dtype is str, set new id as "{unit1}-{unit2}" - new_unit_id = "-".join(mg) + new_unit_id = "-".join(merge_group) + curation_sorting.merge(merge_group, new_unit_id=new_unit_id) else: # in this case, the CurationSorting takes care of finding a new unused int - new_unit_id = None - curation_sorting.merge(mg, new_unit_id=new_unit_id) + curation_sorting.merge(merge_group, new_unit_id=None) + new_unit_id = curation_sorting.max_used_id # merged unit id + labels_dict[str(new_unit_id)] = labels_to_inherit # STEP 2: gather and apply sortingview curation labels - # In sortingview, a unit is not required to have all labels. # For example, the first 3 units could be labeled as "accept". # In this case, the first 3 values of the property "accept" will be True, the rest False - labels_dict = sortingview_curation_dict["labelsByUnit"] - properties = {} - for _, labels in labels_dict.items(): - for label in labels: - if label not in properties: - properties[label] = np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) - for u_i, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): - labels_unit = [] - for unit_label, labels in labels_dict.items(): - if unit_label in str(unit_id): - labels_unit.extend(labels) - for label in labels_unit: - properties[label][u_i] = True + + # Initialize the properties dictionary + properties = { + label: np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) + for labels in labels_dict.values() + for label in labels + } + + # Populate the properties dictionary + for unit_index, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): + unit_id_str = str(unit_id) + if unit_id_str in labels_dict: + for label in labels_dict[unit_id_str]: + properties[label][unit_index] = True + for prop_name, prop_values in properties.items(): curation_sorting.current_sorting.set_property(prop_name, prop_values) @@ -103,5 +113,4 @@ def apply_sortingview_curation( units_to_remove.extend(unit_ids[curation_sorting.current_sorting.get_property(exclude_label) == True]) units_to_remove = np.unique(units_to_remove) curation_sorting.remove_units(units_to_remove) - return curation_sorting.current_sorting diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json b/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json new file mode 100644 index 0000000000..48881388bb --- /dev/null +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-false-positive.json @@ -0,0 +1,19 @@ +{ + "labelsByUnit": { + "1": [ + "accept" + ], + "2": [ + "artifact" + ], + "12": [ + "artifact" + ] + }, + "mergeGroups": [ + [ + 2, + 12 + ] + ] +} diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-int.json b/src/spikeinterface/curation/tests/sv-sorting-curation-int.json new file mode 100644 index 0000000000..2047c514ce --- /dev/null +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-int.json @@ -0,0 +1,39 @@ +{ + "labelsByUnit": { + "1": [ + "mua" + ], + "2": [ + "mua" + ], + "3": [ + "reject" + ], + "4": [ + "noise" + ], + "5": [ + "accept" + ], + "6": [ + "accept" + ], + "7": [ + "accept" + ] + }, + "mergeGroups": [ + [ + 1, + 2 + ], + [ + 3, + 4 + ], + [ + 5, + 6 + ] + ] +} diff --git a/src/spikeinterface/curation/tests/sv-sorting-curation-str.json b/src/spikeinterface/curation/tests/sv-sorting-curation-str.json new file mode 100644 index 0000000000..2585b5cc50 --- /dev/null +++ b/src/spikeinterface/curation/tests/sv-sorting-curation-str.json @@ -0,0 +1,39 @@ +{ + "labelsByUnit": { + "a": [ + "mua" + ], + "b": [ + "mua" + ], + "c": [ + "reject" + ], + "d": [ + "noise" + ], + "e": [ + "accept" + ], + "f": [ + "accept" + ], + "g": [ + "accept" + ] + }, + "mergeGroups": [ + [ + "a", + "b" + ], + [ + "c", + "d" + ], + [ + "e", + "f" + ] + ] +} diff --git a/src/spikeinterface/curation/tests/test_sortingview_curation.py b/src/spikeinterface/curation/tests/test_sortingview_curation.py index 9177cb5536..ce6c7dd5a6 100644 --- a/src/spikeinterface/curation/tests/test_sortingview_curation.py +++ b/src/spikeinterface/curation/tests/test_sortingview_curation.py @@ -1,8 +1,11 @@ import pytest from pathlib import Path import os +import json +import numpy as np import spikeinterface as si +import spikeinterface.extractors as se from spikeinterface.extractors import read_mearec from spikeinterface import set_global_tmp_folder from spikeinterface.postprocessing import ( @@ -19,7 +22,6 @@ cache_folder = Path("cache_folder") / "curation" parent_folder = Path(__file__).parent - ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) KACHERY_CLOUD_SET = bool(os.getenv("KACHERY_CLOUD_CLIENT_ID")) and bool(os.getenv("KACHERY_CLOUD_PRIVATE_KEY")) @@ -50,15 +52,15 @@ def generate_sortingview_curation_dataset(): @pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available") def test_gh_curation(): + """ + Test curation using GitHub URI. + """ local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") _, sorting = read_mearec(local_path) - - # from GH # curated link: # https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary&s={%22sortingCuration%22:%22gh://alejoe91/spikeinterface/fix-codecov/spikeinterface/curation/tests/sv-sorting-curation.json%22} gh_uri = "gh://SpikeInterface/spikeinterface/main/src/spikeinterface/curation/tests/sv-sorting-curation.json" sorting_curated_gh = apply_sortingview_curation(sorting, uri_or_json=gh_uri, verbose=True) - print(f"From GH: {sorting_curated_gh}") assert len(sorting_curated_gh.unit_ids) == 9 assert "#8-#9" in sorting_curated_gh.unit_ids @@ -78,6 +80,9 @@ def test_gh_curation(): @pytest.mark.skipif(ON_GITHUB and not KACHERY_CLOUD_SET, reason="Kachery cloud secrets not available") def test_sha1_curation(): + """ + Test curation using SHA1 URI. + """ local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") _, sorting = read_mearec(local_path) @@ -86,14 +91,14 @@ def test_sha1_curation(): # https://figurl.org/f?v=gs://figurl/spikesortingview-10&d=sha1://bd53f6b707f8121cadc901562a89b67aec81cc81&label=SpikeInterface%20-%20Sorting%20Summary&s={%22sortingCuration%22:%22sha1://1182ba19671fcc7d3f8e0501b0f8c07fb9736c22%22} sha1_uri = "sha1://1182ba19671fcc7d3f8e0501b0f8c07fb9736c22" sorting_curated_sha1 = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, verbose=True) - print(f"From SHA: {sorting_curated_sha1}") + # print(f"From SHA: {sorting_curated_sha1}") assert len(sorting_curated_sha1.unit_ids) == 9 assert "#8-#9" in sorting_curated_sha1.unit_ids assert "accept" in sorting_curated_sha1.get_property_keys() assert "mua" in sorting_curated_sha1.get_property_keys() assert "artifact" in sorting_curated_sha1.get_property_keys() - + unit_ids = sorting_curated_sha1.unit_ids sorting_curated_sha1_accepted = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, include_labels=["accept"]) sorting_curated_sha1_mua = apply_sortingview_curation(sorting, uri_or_json=sha1_uri, exclude_labels=["mua"]) sorting_curated_sha1_art_mua = apply_sortingview_curation( @@ -105,13 +110,16 @@ def test_sha1_curation(): def test_json_curation(): + """ + Test curation using a JSON file. + """ local_path = si.download_dataset(remote_path="mearec/mearec_test_10s.h5") _, sorting = read_mearec(local_path) # from curation.json json_file = parent_folder / "sv-sorting-curation.json" + # print(f"Sorting: {sorting.get_unit_ids()}") sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) - print(f"From JSON: {sorting_curated_json}") assert len(sorting_curated_json.unit_ids) == 9 assert "#8-#9" in sorting_curated_json.unit_ids @@ -131,8 +139,133 @@ def test_json_curation(): assert len(sorting_curated_json_mua1.unit_ids) == 5 +def test_false_positive_curation(): + """ + Test curation for false positives. + """ + # https://spikeinterface.readthedocs.io/en/latest/modules_gallery/core/plot_2_sorting_extractor.html + sampling_frequency = 30000.0 + duration = 20.0 + num_timepoints = int(sampling_frequency * duration) + num_units = 20 + num_spikes = 1000 + times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels = np.random.randint(1, num_units + 1, size=num_spikes) + + sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) + # print("Sorting: {}".format(sorting.get_unit_ids())) + + json_file = parent_folder / "sv-sorting-curation-false-positive.json" + sorting_curated_json = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) + # print("Curated:", sorting_curated_json.get_unit_ids()) + + # Assertions + assert sorting_curated_json.get_unit_property(unit_id=1, key="accept") + assert not sorting_curated_json.get_unit_property(unit_id=10, key="accept") + assert 21 in sorting_curated_json.unit_ids + + +def test_label_inheritance_int(): + """ + Test curation for label inheritance for integer unit IDs. + """ + # Setup + sampling_frequency = 30000.0 + duration = 20.0 + num_timepoints = int(sampling_frequency * duration) + num_spikes = 1000 + num_units = 7 + times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels = np.random.randint(1, 1 + num_units, size=num_spikes) # 7 units: 1 to 7 + + sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) + + json_file = parent_folder / "sv-sorting-curation-int.json" + sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file) + + # Assertions for merged units + # print(f"Merge only: {sorting_merge.get_unit_ids()}") + assert sorting_merge.get_unit_property(unit_id=8, key="mua") # 8 = merged unit of 1 and 2 + assert not sorting_merge.get_unit_property(unit_id=8, key="reject") + assert not sorting_merge.get_unit_property(unit_id=8, key="noise") + assert not sorting_merge.get_unit_property(unit_id=8, key="accept") + + assert not sorting_merge.get_unit_property(unit_id=9, key="mua") # 9 = merged unit of 3 and 4 + assert sorting_merge.get_unit_property(unit_id=9, key="reject") + assert sorting_merge.get_unit_property(unit_id=9, key="noise") + assert not sorting_merge.get_unit_property(unit_id=9, key="accept") + + assert not sorting_merge.get_unit_property(unit_id=10, key="mua") # 10 = merged unit of 5 and 6 + assert not sorting_merge.get_unit_property(unit_id=10, key="reject") + assert not sorting_merge.get_unit_property(unit_id=10, key="noise") + assert sorting_merge.get_unit_property(unit_id=10, key="accept") + + # Assertions for exclude_labels + sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"]) + # print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") + assert 9 not in sorting_exclude_noise.get_unit_ids() + + # Assertions for include_labels + sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"]) + # print(f"Include accept: {sorting_include_accept.get_unit_ids()}") + assert 8 not in sorting_include_accept.get_unit_ids() + assert 9 not in sorting_include_accept.get_unit_ids() + assert 10 in sorting_include_accept.get_unit_ids() + + +def test_label_inheritance_str(): + """ + Test curation for label inheritance for string unit IDs. + """ + sampling_frequency = 30000.0 + duration = 20.0 + num_timepoints = int(sampling_frequency * duration) + num_spikes = 1000 + times = np.int_(np.sort(np.random.uniform(0, num_timepoints, num_spikes))) + labels = np.random.choice(["a", "b", "c", "d", "e", "f", "g"], size=num_spikes) + + sorting = se.NumpySorting.from_times_labels(times, labels, sampling_frequency) + # print(f"Sorting: {sorting.get_unit_ids()}") + + # Apply curation + json_file = parent_folder / "sv-sorting-curation-str.json" + sorting_merge = apply_sortingview_curation(sorting, uri_or_json=json_file, verbose=True) + + # Assertions for merged units + # print(f"Merge only: {sorting_merge.get_unit_ids()}") + assert sorting_merge.get_unit_property(unit_id="a-b", key="mua") + assert not sorting_merge.get_unit_property(unit_id="a-b", key="reject") + assert not sorting_merge.get_unit_property(unit_id="a-b", key="noise") + assert not sorting_merge.get_unit_property(unit_id="a-b", key="accept") + + assert not sorting_merge.get_unit_property(unit_id="c-d", key="mua") + assert sorting_merge.get_unit_property(unit_id="c-d", key="reject") + assert sorting_merge.get_unit_property(unit_id="c-d", key="noise") + assert not sorting_merge.get_unit_property(unit_id="c-d", key="accept") + + assert not sorting_merge.get_unit_property(unit_id="e-f", key="mua") + assert not sorting_merge.get_unit_property(unit_id="e-f", key="reject") + assert not sorting_merge.get_unit_property(unit_id="e-f", key="noise") + assert sorting_merge.get_unit_property(unit_id="e-f", key="accept") + + # Assertions for exclude_labels + sorting_exclude_noise = apply_sortingview_curation(sorting, uri_or_json=json_file, exclude_labels=["noise"]) + # print(f"Exclude noise: {sorting_exclude_noise.get_unit_ids()}") + assert "c-d" not in sorting_exclude_noise.get_unit_ids() + + # Assertions for include_labels + sorting_include_accept = apply_sortingview_curation(sorting, uri_or_json=json_file, include_labels=["accept"]) + # print(f"Include accept: {sorting_include_accept.get_unit_ids()}") + assert "a-b" not in sorting_include_accept.get_unit_ids() + assert "c-d" not in sorting_include_accept.get_unit_ids() + assert "e-f" in sorting_include_accept.get_unit_ids() + + if __name__ == "__main__": # generate_sortingview_curation_dataset() test_sha1_curation() test_gh_curation() test_json_curation() + test_false_positive_curation() + test_label_inheritance_int() + test_label_inheritance_str()