diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index ad359a3e7e..245c828b9e 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -38,6 +38,7 @@ def get_potential_auto_merge( p_value=0.2, CC_threshold=0.1, k_nn=10, + knn_kwargs=None, **presence_distance_kwargs, ): """ @@ -111,6 +112,8 @@ def get_potential_auto_merge( Parameter to control how present two units should be simultaneously k_nn : int, default 5 The number of neighbors to consider for every spike in the recording + knn_kwargs : dict or None, default: None + Extra keyword arguments passed to NearestNeighbors in the "knn" step (e.g. n_jobs); None means no extra arguments extra_outputs : bool, default: False If True, an additional dictionary (`outs`) with processed data is returned steps : None or list of str, default: None @@ -202,14 +205,14 @@ def get_potential_auto_merge( assert step in all_steps, f"{step} is not a valid step" - # STEP 1 : + # STEP : remove units with too few spikes if step == "min_spikes": num_spikes = sorting.count_num_spikes_per_unit(outputs="array") to_remove = num_spikes < minimum_spikes pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False - # STEP 2 : remove contaminated auto corr + # STEP : remove contaminated auto corr elif step == "remove_contaminated": contaminations, nb_violations = compute_refrac_period_violations( sorting_analyzer, refractory_period_ms=refractory_period_ms, censored_period_ms=censored_period_ms @@ -220,7 +223,7 @@ def get_potential_auto_merge( pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False - # STEP 3 : unit positions are estimated roughly with channel + # STEP : unit positions are estimated roughly with channel elif step == "unit_positions" in steps: positions_ext = sorting_analyzer.get_extension("unit_locations") if positions_ext is not None: @@ -237,7 +240,7 @@ def get_potential_auto_merge( pair_mask = pair_mask & (unit_distances <= maximum_distance_um) outs["unit_distances"] = unit_distances - # STEP 4 
: potential auto merge by correlogram + # STEP : potential auto merge by correlogram elif step == "correlogram" in steps: correlograms_ext = sorting_analyzer.get_extension("correlograms") if correlograms_ext is not None: @@ -268,7 +271,7 @@ def get_potential_auto_merge( outs["correlogram_diff"] = correlogram_diff outs["win_sizes"] = win_sizes - # STEP 5 : check if potential merge with CC also have template similarity + # STEP : check if potential merge with CC also have template similarity elif step == "template_similarity" in steps: template_similarity_ext = sorting_analyzer.get_extension("template_similarity") if template_similarity_ext is not None: @@ -295,23 +298,26 @@ def get_potential_auto_merge( pair_mask = pair_mask & (templates_diff < template_diff_thresh) outs["templates_diff"] = templates_diff + # STEP : check the vicinity of the spikes elif step == "knn" in steps: - pair_mask = get_pairs_via_nntree(sorting_analyzer, k_nn, pair_mask) + if knn_kwargs is None: + knn_kwargs = dict() + pair_mask = get_pairs_via_nntree(sorting_analyzer, k_nn, pair_mask, **knn_kwargs) - # STEP 6 : [optional] check how the rates overlap in times + # STEP : check how the rates overlap in times elif step == "presence_distance" in steps: presence_distances = compute_presence_distance(sorting, pair_mask, **presence_distance_kwargs) pair_mask = pair_mask & (presence_distances > presence_distance_thresh) outs["presence_distances"] = presence_distances - # STEP 7 : [optional] check if the cross contamination is significant + # STEP : check if the cross contamination is significant elif step == "cross_contamination" in steps: refractory = (censored_period_ms, refractory_period_ms) CC, p_values = compute_cross_contaminations(sorting_analyzer, pair_mask, CC_threshold, refractory) pair_mask = pair_mask & (p_values > p_value) outs["cross_contaminations"] = CC, p_values - # STEP 8 : validate the potential merges with CC increase the contamination quality metrics + # STEP : validate the 
potential merges with CC increase the contamination quality metrics elif step == "check_increase_score" in steps: pair_mask, pairs_decreased_score = check_improve_contaminations_score( sorting_analyzer, @@ -333,7 +339,7 @@ def get_potential_auto_merge( return potential_merges -def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None): +def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None, **knn_kwargs): sorting = sorting_analyzer.sorting unit_ids = sorting.unit_ids @@ -354,7 +360,7 @@ def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None): all_spike_counts = sorting_analyzer.sorting.count_num_spikes_per_unit() all_spike_counts = np.array(list(all_spike_counts.keys())) - kdtree = NearestNeighbors(n_neighbors=k_nn, n_jobs=-1) + kdtree = NearestNeighbors(n_neighbors=k_nn, **knn_kwargs) kdtree.fit(data) for unit_ind in range(n): diff --git a/src/spikeinterface/sortingcomponents/merging/knn.py b/src/spikeinterface/sortingcomponents/merging/knn.py index cc56d1c7b7..5a4d81dd7b 100644 --- a/src/spikeinterface/sortingcomponents/merging/knn.py +++ b/src/spikeinterface/sortingcomponents/merging/knn.py @@ -33,7 +33,7 @@ class KNNMerging(BaseMergingEngine): "maximum_distance_um": 100, "refractory_period": (0.3, 1.0), "corr_diff_thresh": 0.2, - "k_nn" : 10 + "k_nn" : 5 }, }