Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small fixes in curation format and apply_curation() #3601

Merged
merged 3 commits into from
Jan 14, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions src/spikeinterface/curation/curation_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,16 @@ def validate_curation_dict(curation_dict):
if not removed_units_set.issubset(unit_set):
raise ValueError("Curation format: some removed units are not in the unit list")

for group in curation_dict["merge_unit_groups"]:
if len(group) < 2:
raise ValueError("Curation format: 'merge_unit_groups' must be list of list with at least 2 elements")

all_merging_groups = [set(group) for group in curation_dict["merge_unit_groups"]]
for gp_1, gp_2 in combinations(all_merging_groups, 2):
if len(gp_1.intersection(gp_2)) != 0:
raise ValueError("Some units belong to multiple merge groups")
raise ValueError("Curation format: some units belong to multiple merge groups")
if len(removed_units_set.intersection(merged_units_set)) != 0:
raise ValueError("Some units were merged and deleted")
raise ValueError("Curation format: some units were merged and deleted")

# Check the labels exclusivity
for lbl in curation_dict["manual_labels"]:
Expand Down Expand Up @@ -238,7 +242,7 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict):
all_values = np.zeros(sorting.unit_ids.size, dtype=values.dtype)
for unit_ind, unit_id in enumerate(sorting.unit_ids):
if unit_id not in new_unit_ids:
ind = curation_dict["unit_ids"].index(unit_id)
ind = list(curation_dict["unit_ids"]).index(unit_id)
all_values[unit_ind] = values[ind]
sorting.set_property(key, all_values)

Expand All @@ -253,7 +257,7 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict):
group_values.append(value)
if len(set(group_values)) == 1:
# all group has the same label or empty
sorting.set_property(key, values=group_values, ids=[new_unit_id])
sorting.set_property(key, values=group_values[:1], ids=[new_unit_id])
else:

for key in label_def["label_options"]:
Expand Down Expand Up @@ -339,18 +343,22 @@ def apply_curation(

elif isinstance(sorting_or_analyzer, SortingAnalyzer):
analyzer = sorting_or_analyzer
analyzer = analyzer.remove_units(curation_dict["removed_units"])
analyzer, new_unit_ids = analyzer.merge_units(
curation_dict["merge_unit_groups"],
censor_ms=censor_ms,
merging_mode=merging_mode,
sparsity_overlap=sparsity_overlap,
new_id_strategy=new_id_strategy,
return_new_unit_ids=True,
format="memory",
verbose=verbose,
**job_kwargs,
)
if len(curation_dict["removed_units"]) > 0:
analyzer = analyzer.remove_units(curation_dict["removed_units"])
if len(curation_dict["merge_unit_groups"]) > 0:
analyzer, new_unit_ids = analyzer.merge_units(
curation_dict["merge_unit_groups"],
censor_ms=censor_ms,
merging_mode=merging_mode,
sparsity_overlap=sparsity_overlap,
new_id_strategy=new_id_strategy,
return_new_unit_ids=True,
format="memory",
verbose=verbose,
**job_kwargs,
)
else:
new_unit_ids = []
apply_curation_labels(analyzer.sorting, new_unit_ids, curation_dict)
return analyzer
else:
Expand Down
Loading