
Commit

Fixes
yger committed Jul 2, 2024
1 parent 1f323f9 commit 016d7cc
Showing 2 changed files with 18 additions and 12 deletions.
28 changes: 17 additions & 11 deletions src/spikeinterface/curation/auto_merge.py
@@ -38,6 +38,7 @@ def get_potential_auto_merge(
p_value=0.2,
CC_threshold=0.1,
k_nn=10,
knn_kwargs=None,
**presence_distance_kwargs,
):
"""
@@ -111,6 +112,8 @@ def get_potential_auto_merge(
Parameter to control how present two units should be simultaneously
k_nn : int, default 5
The number of neighbors to consider for every spike in the recording
knn_kwargs : dict, default None
The dict of extra params to be passed to knn
extra_outputs : bool, default: False
If True, an additional dictionary (`outs`) with processed data is returned
steps : None or list of str, default: None
@@ -202,14 +205,14 @@ def get_potential_auto_merge(

assert step in all_steps, f"{step} is not a valid step"

# STEP 1 :
# STEP : remove units with too few spikes
if step == "min_spikes":
num_spikes = sorting.count_num_spikes_per_unit(outputs="array")
to_remove = num_spikes < minimum_spikes
pair_mask[to_remove, :] = False
pair_mask[:, to_remove] = False

# STEP 2 : remove contaminated auto corr
# STEP : remove contaminated auto corr
elif step == "remove_contaminated":
contaminations, nb_violations = compute_refrac_period_violations(
sorting_analyzer, refractory_period_ms=refractory_period_ms, censored_period_ms=censored_period_ms
@@ -220,7 +223,7 @@ def get_potential_auto_merge(
pair_mask[to_remove, :] = False
pair_mask[:, to_remove] = False

# STEP 3 : unit positions are estimated roughly with channel
# STEP : unit positions are estimated roughly with channel
elif step == "unit_positions" in steps:
positions_ext = sorting_analyzer.get_extension("unit_locations")
if positions_ext is not None:
@@ -237,7 +240,7 @@ def get_potential_auto_merge(
pair_mask = pair_mask & (unit_distances <= maximum_distance_um)
outs["unit_distances"] = unit_distances

# STEP 4 : potential auto merge by correlogram
# STEP : potential auto merge by correlogram
elif step == "correlogram" in steps:
correlograms_ext = sorting_analyzer.get_extension("correlograms")
if correlograms_ext is not None:
@@ -268,7 +271,7 @@ def get_potential_auto_merge(
outs["correlogram_diff"] = correlogram_diff
outs["win_sizes"] = win_sizes

# STEP 5 : check if potential merge with CC also have template similarity
# STEP : check if potential merge with CC also have template similarity
elif step == "template_similarity" in steps:
template_similarity_ext = sorting_analyzer.get_extension("template_similarity")
if template_similarity_ext is not None:
@@ -295,23 +298,26 @@ def get_potential_auto_merge(
pair_mask = pair_mask & (templates_diff < template_diff_thresh)
outs["templates_diff"] = templates_diff

# STEP : check the vicinity of the spikes
elif step == "knn" in steps:
pair_mask = get_pairs_via_nntree(sorting_analyzer, k_nn, pair_mask)
if knn_kwargs is None:
knn_kwargs = dict()
pair_mask = get_pairs_via_nntree(sorting_analyzer, k_nn, pair_mask, **knn_kwargs)

# STEP 6 : [optional] check how the rates overlap in times
# STEP : check how the rates overlap in times
elif step == "presence_distance" in steps:
presence_distances = compute_presence_distance(sorting, pair_mask, **presence_distance_kwargs)
pair_mask = pair_mask & (presence_distances > presence_distance_thresh)
outs["presence_distances"] = presence_distances

# STEP 7 : [optional] check if the cross contamination is significant
# STEP : check if the cross contamination is significant
elif step == "cross_contamination" in steps:
refractory = (censored_period_ms, refractory_period_ms)
CC, p_values = compute_cross_contaminations(sorting_analyzer, pair_mask, CC_threshold, refractory)
pair_mask = pair_mask & (p_values > p_value)
outs["cross_contaminations"] = CC, p_values

# STEP 8 : validate the potential merges with CC increase the contamination quality metrics
# STEP : validate the potential merges with CC increase the contamination quality metrics
elif step == "check_increase_score" in steps:
pair_mask, pairs_decreased_score = check_improve_contaminations_score(
sorting_analyzer,
@@ -333,7 +339,7 @@ def get_potential_auto_merge(
return potential_merges


def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None):
def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None, **knn_kwargs):

sorting = sorting_analyzer.sorting
unit_ids = sorting.unit_ids
@@ -354,7 +360,7 @@ def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None):
all_spike_counts = sorting_analyzer.sorting.count_num_spikes_per_unit()
all_spike_counts = np.array(list(all_spike_counts.keys()))

kdtree = NearestNeighbors(n_neighbors=k_nn, n_jobs=-1)
kdtree = NearestNeighbors(n_neighbors=k_nn, **knn_kwargs)
kdtree.fit(data)

for unit_ind in range(n):
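The behavioral change in get_pairs_via_nntree is that the hard-coded n_jobs=-1 is gone: whatever the caller puts in knn_kwargs is forwarded verbatim to scikit-learn's NearestNeighbors. The sketch below shows the resulting call in isolation; the random array stands in for the per-spike feature matrix the real function builds from the analyzer, and the particular kwargs are illustrative assumptions, not part of the commit.

import numpy as np
from sklearn.neighbors import NearestNeighbors

# Toy stand-in for the per-spike features assembled inside get_pairs_via_nntree.
data = np.random.randn(1000, 3)

# Extra options that previously had no way in; n_jobs=-1 restores the old parallel search.
knn_kwargs = dict(n_jobs=-1, algorithm="auto")

kdtree = NearestNeighbors(n_neighbors=5, **knn_kwargs)
kdtree.fit(data)
distances, indices = kdtree.kneighbors(data)  # k nearest spikes for every spike

Note that callers who relied on the previous parallel default now need to pass n_jobs=-1 explicitly through knn_kwargs.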
2 changes: 1 addition & 1 deletion src/spikeinterface/sortingcomponents/merging/knn.py
@@ -33,7 +33,7 @@ class KNNMerging(BaseMergingEngine):
"maximum_distance_um": 100,
"refractory_period": (0.3, 1.0),
"corr_diff_thresh": 0.2,
"k_nn" : 10
"k_nn" : 5
},
}

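At the API level, the commit exposes the same flexibility through get_potential_auto_merge: the new knn_kwargs dict travels from there down to the nearest-neighbor search. A minimal usage sketch, assuming a SortingAnalyzer with the required extensions already computed; the analyzer variable and the chosen step list are illustrative, not taken from this commit.

from spikeinterface.curation import get_potential_auto_merge

# sorting_analyzer is assumed to be a pre-computed SortingAnalyzer.
potential_merges = get_potential_auto_merge(
    sorting_analyzer,
    steps=["min_spikes", "remove_contaminated", "unit_positions", "knn"],
    k_nn=10,
    knn_kwargs=dict(n_jobs=-1, metric="minkowski"),
)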
