From f46d13e0810ea66193206e9c49dc7bb7cc388f7c Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 24 Sep 2024 12:01:36 +0200 Subject: [PATCH 01/25] Refactoring auto_merge --- src/spikeinterface/curation/auto_merge.py | 339 ++++++++++++++++------ 1 file changed, 251 insertions(+), 88 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 19336e5943..00c156094d 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -12,8 +12,6 @@ HAVE_NUMBA = False from ..core import SortingAnalyzer, Templates -from ..core.template_tools import get_template_extremum_channel -from ..postprocessing import compute_correlograms from ..qualitymetrics import compute_refrac_period_violations, compute_firing_rates from .mergeunitssorting import MergeUnitsSorting @@ -25,35 +23,43 @@ _required_extensions = { "unit_locations": ["unit_locations"], "correlogram": ["correlograms"], + "min_snr": ["noise_levels", "templates"], "template_similarity": ["template_similarity"], "knn": ["spike_locations", "spike_amplitudes"], } +_templates_needed = ["unit_locations", "min_snr", "template_similarity", "spike_locations", "spike_amplitudes"] -def get_potential_auto_merge( + +def auto_merges( sorting_analyzer: SortingAnalyzer, preset: str | None = "similarity_correlograms", resolve_graph: bool = False, - min_spikes: int = 100, - min_snr: float = 2, - max_distance_um: float = 150.0, - corr_diff_thresh: float = 0.16, - template_diff_thresh: float = 0.25, - contamination_thresh: float = 0.2, - presence_distance_thresh: float = 100, - p_value: float = 0.2, - cc_thresh: float = 0.1, - censored_period_ms: float = 0.3, - refractory_period_ms: float = 1.0, - sigma_smooth_ms: float = 0.6, - adaptative_window_thresh: float = 0.5, - censor_correlograms_ms: float = 0.15, - firing_contamination_balance: float = 2.5, - k_nn: int = 10, - knn_kwargs: dict | None = None, - presence_distance_kwargs: dict | None = None, + num_spikes_kwargs={"min_spikes": 100}, + snr_kwargs={"min_snr": 2}, + remove_contaminated_kwargs={"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + unit_locations_kwargs={"max_distance_um": 50}, + correlogram_kwargs={ + "corr_diff_thresh": 0.16, + "censor_correlograms_ms": 0.3, + "sigma_smooth_ms": 0.6, + "adaptative_window_thresh": 0.5, + }, + template_similarity_kwargs={"template_diff_thresh": 0.25}, + presence_distance_kwargs={"presence_distance_thresh": 100}, + knn_kwargs={"k_nn": 10}, + cross_contamination_kwargs={ + "cc_thresh": 0.1, + "p_value": 0.2, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3, + }, + quality_score_kwargs={"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + compute_needed_extensions: bool = True, extra_outputs: bool = False, steps: list[str] | None = None, + force_copy: bool = True, + **job_kwargs, ) -> list[tuple[int | str, int | str]] | Tuple[tuple[int | str, int | str], dict]: """ Algorithm to find and check potential merges between units. @@ -98,56 +104,21 @@ def get_potential_auto_merge( * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN. | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations", | "knn", "quality_score" - If `preset` is None, you can specify the steps manually with the `steps` parameter. resolve_graph : bool, default: False If True, the function resolves the potential unit pairs to be merged into multiple-unit merges. - min_spikes : int, default: 100 - Minimum number of spikes for each unit to consider a potential merge. - Enough spikes are needed to estimate the correlogram - min_snr : float, default 2 - Minimum Signal to Noise ratio for templates to be considered while merging - max_distance_um : float, default: 150 - Maximum distance between units for considering a merge - corr_diff_thresh : float, default: 0.16 - The threshold on the "correlogram distance metric" for considering a merge. - It needs to be between 0 and 1 - template_diff_thresh : float, default: 0.25 - The threshold on the "template distance metric" for considering a merge. - It needs to be between 0 and 1 - contamination_thresh : float, default: 0.2 - Threshold for not taking in account a unit when it is too contaminated. - presence_distance_thresh : float, default: 100 - Parameter to control how present two units should be simultaneously. - p_value : float, default: 0.2 - The p-value threshold for the cross-contamination test. - cc_thresh : float, default: 0.1 - The threshold on the cross-contamination for considering a merge. - censored_period_ms : float, default: 0.3 - Used to compute the refractory period violations aka "contamination". - refractory_period_ms : float, default: 1 - Used to compute the refractory period violations aka "contamination". - sigma_smooth_ms : float, default: 0.6 - Parameters to smooth the correlogram estimation. - adaptative_window_thresh : float, default: 0.5 - Parameter to detect the window size in correlogram estimation. - censor_correlograms_ms : float, default: 0.15 - The period to censor on the auto and cross-correlograms. - firing_contamination_balance : float, default: 2.5 - Parameter to control the balance between firing rate and contamination in computing unit "quality score". - k_nn : int, default 5 - The number of neighbors to consider for every spike in the recording. - knn_kwargs : dict, default None - The dict of extra params to be passed to knn. + compute_needed_extensions : bool, default : True + Should we force the computation of needed extensions? extra_outputs : bool, default: False If True, an additional dictionary (`outs`) with processed data is returned. steps : None or list of str, default: None Which steps to run, if no preset is used. Pontential steps : "num_spikes", "snr", "remove_contaminated", "unit_locations", "correlogram", "template_similarity", "presence_distance", "cross_contamination", "knn", "quality_score" - Please check steps explanations above! - presence_distance_kwargs : None|dict, default: None - A dictionary of kwargs to be passed to compute_presence_distance(). + Please check steps explanations above!$ + force_copy : boolean, default: True + When new extensions are computed, the default is to make a copy of the analyzer, to avoid overwriting + already computed extensions. False if you want to overwrite Returns ------- @@ -230,12 +201,24 @@ def get_potential_auto_merge( "knn", "quality_score", ] + if force_copy and compute_needed_extensions: + # To avoid erasing the extensions of the user + sorting_analyzer = sorting_analyzer.copy() for step in steps: if step in _required_extensions: for ext in _required_extensions[step]: - if not sorting_analyzer.has_extension(ext): - raise ValueError(f"{step} requires {ext} extension") + if compute_needed_extensions: + if step in _templates_needed: + template_ext = sorting_analyzer.get_extension("templates") + if template_ext is None: + sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) + params = eval(f"{step}_kwargs") + params = params.get(ext, dict()) + sorting_analyzer.compute(ext, **params, **job_kwargs) + else: + if not sorting_analyzer.has_extension(ext): + raise ValueError(f"{step} requires {ext} extension") n = unit_ids.size pair_mask = np.triu(np.arange(n)) > 0 @@ -248,33 +231,38 @@ def get_potential_auto_merge( # STEP : remove units with too few spikes if step == "num_spikes": num_spikes = sorting.count_num_spikes_per_unit(outputs="array") - to_remove = num_spikes < min_spikes + to_remove = num_spikes < num_spikes_kwargs["min_spikes"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False + outs["num_spikes"] = to_remove # STEP : remove units with too small SNR elif step == "snr": qm_ext = sorting_analyzer.get_extension("quality_metrics") if qm_ext is None: - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("quality_metrics", metric_names=["snr"]) + sorting_analyzer.compute(["noise_levels"], **job_kwargs) + sorting_analyzer.compute("quality_metrics", metric_names=["snr"], **job_kwargs) qm_ext = sorting_analyzer.get_extension("quality_metrics") snrs = qm_ext.get_data()["snr"].values - to_remove = snrs < min_snr + to_remove = snrs < snr_kwargs["min_snr"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False + outs["snr"] = to_remove # STEP : remove contaminated auto corr elif step == "remove_contaminated": contaminations, nb_violations = compute_refrac_period_violations( - sorting_analyzer, refractory_period_ms=refractory_period_ms, censored_period_ms=censored_period_ms + sorting_analyzer, + refractory_period_ms=remove_contaminated_kwargs["refractory_period_ms"], + censored_period_ms=remove_contaminated_kwargs["censored_period_ms"], ) nb_violations = np.array(list(nb_violations.values())) contaminations = np.array(list(contaminations.values())) - to_remove = contaminations > contamination_thresh + to_remove = contaminations > remove_contaminated_kwargs["contamination_thresh"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False + outs["remove_contaminated"] = to_remove # STEP : unit positions are estimated roughly with channel elif step == "unit_locations" in steps: @@ -282,21 +270,23 @@ def get_potential_auto_merge( unit_locations = location_ext.get_data()[:, :2] unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean") - pair_mask = pair_mask & (unit_distances <= max_distance_um) + pair_mask = pair_mask & (unit_distances <= unit_locations_kwargs["max_distance_um"]) outs["unit_distances"] = unit_distances # STEP : potential auto merge by correlogram elif step == "correlogram" in steps: correlograms_ext = sorting_analyzer.get_extension("correlograms") correlograms, bins = correlograms_ext.get_data() - mask = (bins[:-1] >= -censor_correlograms_ms) & (bins[:-1] < censor_correlograms_ms) + censor_ms = correlogram_kwargs["censor_correlograms_ms"] + sigma_smooth_ms = correlogram_kwargs["sigma_smooth_ms"] + mask = (bins[:-1] >= -censor_ms) & (bins[:-1] < censor_ms) correlograms[:, :, mask] = 0 correlograms_smoothed = smooth_correlogram(correlograms, bins, sigma_smooth_ms=sigma_smooth_ms) # find correlogram window for each units win_sizes = np.zeros(n, dtype=int) for unit_ind in range(n): auto_corr = correlograms_smoothed[unit_ind, unit_ind, :] - thresh = np.max(auto_corr) * adaptative_window_thresh + thresh = np.max(auto_corr) * correlogram_kwargs["adaptative_window_thresh"] win_size = get_unit_adaptive_window(auto_corr, thresh) win_sizes[unit_ind] = win_size correlogram_diff = compute_correlogram_diff( @@ -306,7 +296,7 @@ def get_potential_auto_merge( pair_mask=pair_mask, ) # print(correlogram_diff) - pair_mask = pair_mask & (correlogram_diff < corr_diff_thresh) + pair_mask = pair_mask & (correlogram_diff < correlogram_kwargs["corr_diff_thresh"]) outs["correlograms"] = correlograms outs["bins"] = bins outs["correlograms_smoothed"] = correlograms_smoothed @@ -318,18 +308,17 @@ def get_potential_auto_merge( template_similarity_ext = sorting_analyzer.get_extension("template_similarity") templates_similarity = template_similarity_ext.get_data() templates_diff = 1 - templates_similarity - pair_mask = pair_mask & (templates_diff < template_diff_thresh) + pair_mask = pair_mask & (templates_diff < template_similarity_kwargs["template_diff_thresh"]) outs["templates_diff"] = templates_diff # STEP : check the vicinity of the spikes elif step == "knn" in steps: - if knn_kwargs is None: - knn_kwargs = dict() - pair_mask = get_pairs_via_nntree(sorting_analyzer, k_nn, pair_mask, **knn_kwargs) + pair_mask = get_pairs_via_nntree(sorting_analyzer, **knn_kwargs, pair_mask=pair_mask) # STEP : check how the rates overlap in times elif step == "presence_distance" in steps: - presence_distance_kwargs = presence_distance_kwargs or dict() + presence_distance_kwargs = presence_distance_kwargs.copy() + presence_distance_thresh = presence_distance_kwargs.pop("presence_distance_thresh") num_samples = [ sorting_analyzer.get_num_samples(segment_index) for segment_index in range(sorting.get_num_segments()) ] @@ -341,11 +330,14 @@ def get_potential_auto_merge( # STEP : check if the cross contamination is significant elif step == "cross_contamination" in steps: - refractory = (censored_period_ms, refractory_period_ms) + refractory = ( + cross_contamination_kwargs["censored_period_ms"], + cross_contamination_kwargs["refractory_period_ms"], + ) CC, p_values = compute_cross_contaminations( - sorting_analyzer, pair_mask, cc_thresh, refractory, contaminations + sorting_analyzer, pair_mask, cross_contamination_kwargs["cc_thresh"], refractory, contaminations ) - pair_mask = pair_mask & (p_values > p_value) + pair_mask = pair_mask & (p_values > cross_contamination_kwargs["p_value"]) outs["cross_contaminations"] = CC, p_values # STEP : validate the potential merges with CC increase the contamination quality metrics @@ -354,9 +346,9 @@ def get_potential_auto_merge( sorting_analyzer, pair_mask, contaminations, - firing_contamination_balance, - refractory_period_ms, - censored_period_ms, + quality_score_kwargs["firing_contamination_balance"], + quality_score_kwargs["refractory_period_ms"], + quality_score_kwargs["censored_period_ms"], ) outs["pairs_decreased_score"] = pairs_decreased_score @@ -364,9 +356,6 @@ def get_potential_auto_merge( ind1, ind2 = np.nonzero(pair_mask) potential_merges = list(zip(unit_ids[ind1], unit_ids[ind2])) - # some methods return identities ie (1,1) which we can cleanup first. - potential_merges = [(ids[0], ids[1]) for ids in potential_merges if ids[0] != ids[1]] - if resolve_graph: potential_merges = resolve_merging_graph(sorting, potential_merges) @@ -376,6 +365,180 @@ def get_potential_auto_merge( return potential_merges +def get_potential_auto_merge( + sorting_analyzer: SortingAnalyzer, + preset: str | None = "similarity_correlograms", + resolve_graph: bool = False, + min_spikes: int = 100, + min_snr: float = 2, + max_distance_um: float = 150.0, + corr_diff_thresh: float = 0.16, + template_diff_thresh: float = 0.25, + contamination_thresh: float = 0.2, + presence_distance_thresh: float = 100, + p_value: float = 0.2, + cc_thresh: float = 0.1, + censored_period_ms: float = 0.3, + refractory_period_ms: float = 1.0, + sigma_smooth_ms: float = 0.6, + adaptative_window_thresh: float = 0.5, + censor_correlograms_ms: float = 0.15, + firing_contamination_balance: float = 2.5, + k_nn: int = 10, + knn_kwargs: dict | None = None, + presence_distance_kwargs: dict | None = None, + extra_outputs: bool = False, + steps: list[str] | None = None, +) -> list[tuple[int | str, int | str]] | Tuple[tuple[int | str, int | str], dict]: + """ + Algorithm to find and check potential merges between units. + + The merges are proposed based on a series of steps with different criteria: + + * "num_spikes": enough spikes are found in each unit for computing the correlogram (`min_spikes`) + * "snr": the SNR of the units is above a threshold (`min_snr`) + * "remove_contaminated": each unit is not contaminated (by checking auto-correlogram - `contamination_thresh`) + * "unit_locations": estimated unit locations are close enough (`max_distance_um`) + * "correlogram": the cross-correlograms of the two units are similar to each auto-corrleogram (`corr_diff_thresh`) + * "template_similarity": the templates of the two units are similar (`template_diff_thresh`) + * "presence_distance": the presence of the units is complementary in time (`presence_distance_thresh`) + * "cross_contamination": the cross-contamination is not significant (`cc_thresh` and `p_value`) + * "knn": the two units are close in the feature space + * "quality_score": the unit "quality score" is increased after the merge + + The "quality score" factors in the increase in firing rate (**f**) due to the merge and a possible increase in + contamination (**C**), wheighted by a factor **k** (`firing_contamination_balance`). + + .. math:: + + Q = f(1 - (k + 1)C) + + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer + preset : "similarity_correlograms" | "x_contaminations" | "temporal_splits" | "feature_neighbors" | None, default: "similarity_correlograms" + The preset to use for the auto-merge. Presets combine different steps into a recipe and focus on: + + * | "similarity_correlograms": mainly focused on template similarity and correlograms. + | It uses the following steps: "num_spikes", "remove_contaminated", "unit_locations", + | "template_similarity", "correlogram", "quality_score" + * | "x_contaminations": similar to "similarity_correlograms", but checks for cross-contamination instead of correlograms. + | It uses the following steps: "num_spikes", "remove_contaminated", "unit_locations", + | "template_similarity", "cross_contamination", "quality_score" + * | "temporal_splits": focused on finding temporal splits using presence distance. + | It uses the following steps: "num_spikes", "remove_contaminated", "unit_locations", + | "template_similarity", "presence_distance", "quality_score" + * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN. + | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations", + | "knn", "quality_score" + + If `preset` is None, you can specify the steps manually with the `steps` parameter. + resolve_graph : bool, default: False + If True, the function resolves the potential unit pairs to be merged into multiple-unit merges. + min_spikes : int, default: 100 + Minimum number of spikes for each unit to consider a potential merge. + Enough spikes are needed to estimate the correlogram + min_snr : float, default 2 + Minimum Signal to Noise ratio for templates to be considered while merging + max_distance_um : float, default: 150 + Maximum distance between units for considering a merge + corr_diff_thresh : float, default: 0.16 + The threshold on the "correlogram distance metric" for considering a merge. + It needs to be between 0 and 1 + template_diff_thresh : float, default: 0.25 + The threshold on the "template distance metric" for considering a merge. + It needs to be between 0 and 1 + contamination_thresh : float, default: 0.2 + Threshold for not taking in account a unit when it is too contaminated. + presence_distance_thresh : float, default: 100 + Parameter to control how present two units should be simultaneously. + p_value : float, default: 0.2 + The p-value threshold for the cross-contamination test. + cc_thresh : float, default: 0.1 + The threshold on the cross-contamination for considering a merge. + censored_period_ms : float, default: 0.3 + Used to compute the refractory period violations aka "contamination". + refractory_period_ms : float, default: 1 + Used to compute the refractory period violations aka "contamination". + sigma_smooth_ms : float, default: 0.6 + Parameters to smooth the correlogram estimation. + adaptative_window_thresh : float, default: 0.5 + Parameter to detect the window size in correlogram estimation. + censor_correlograms_ms : float, default: 0.15 + The period to censor on the auto and cross-correlograms. + firing_contamination_balance : float, default: 2.5 + Parameter to control the balance between firing rate and contamination in computing unit "quality score". + k_nn : int, default 5 + The number of neighbors to consider for every spike in the recording. + knn_kwargs : dict, default None + The dict of extra params to be passed to knn. + extra_outputs : bool, default: False + If True, an additional dictionary (`outs`) with processed data is returned. + steps : None or list of str, default: None + Which steps to run, if no preset is used. + Pontential steps : "num_spikes", "snr", "remove_contaminated", "unit_locations", "correlogram", + "template_similarity", "presence_distance", "cross_contamination", "knn", "quality_score" + Please check steps explanations above! + presence_distance_kwargs : None|dict, default: None + A dictionary of kwargs to be passed to compute_presence_distance(). + + Returns + ------- + potential_merges: + A list of tuples of 2 elements (if `resolve_graph`if false) or 2+ elements (if `resolve_graph` is true). + List of pairs that could be merged. + outs: + Returned only when extra_outputs=True + A dictionary that contains data for debugging and plotting. + + References + ---------- + This function is inspired and built upon similar functions from Lussac [Llobet]_, + done by Aurelien Wyngaard and Victor Llobet. + https://github.com/BarbourLab/lussac/blob/v1.0.0/postprocessing/merge_units.py + """ + presence_distance_kwargs = presence_distance_kwargs or dict() + knn_kwargs = knn_kwargs or dict() + return auto_merges( + sorting_analyzer, + preset, + resolve_graph, + num_spikes_kwargs={"min_spikes": min_spikes}, + snr_kwargs={"min_snr": min_snr}, + remove_contaminated_kwargs={ + "contamination_thresh": contamination_thresh, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + unit_locations_kwargs={"max_distance_um": max_distance_um}, + correlogram_kwargs={ + "corr_diff_thresh": corr_diff_thresh, + "censor_correlograms_ms": censor_correlograms_ms, + "sigma_smooth_ms": sigma_smooth_ms, + "adaptative_window_thresh": adaptative_window_thresh, + }, + template_similarity_kwargs={"template_diff_thresh": template_diff_thresh}, + presence_distance_kwargs={"presence_distance_thresh": presence_distance_thresh, **presence_distance_kwargs}, + knn_kwargs={"k_nn": k_nn, **knn_kwargs}, + cross_contamination_kwargs={ + "cc_thresh": cc_thresh, + "p_value": p_value, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + quality_score_kwargs={ + "firing_contamination_balance": firing_contamination_balance, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + compute_needed_extensions=False, + extra_outputs=extra_outputs, + steps=steps, + ) + + def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None, **knn_kwargs): sorting = sorting_analyzer.sorting From 484a5f4626c1cd160f40d08cb8f6980ea6f6b8b3 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 8 Oct 2024 10:38:45 +0200 Subject: [PATCH 02/25] WIP --- src/spikeinterface/curation/auto_merge.py | 151 +++++++++++----------- 1 file changed, 77 insertions(+), 74 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 00c156094d..7a101ad609 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -35,26 +35,31 @@ def auto_merges( sorting_analyzer: SortingAnalyzer, preset: str | None = "similarity_correlograms", resolve_graph: bool = False, - num_spikes_kwargs={"min_spikes": 100}, - snr_kwargs={"min_snr": 2}, - remove_contaminated_kwargs={"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, - unit_locations_kwargs={"max_distance_um": 50}, - correlogram_kwargs={ - "corr_diff_thresh": 0.16, - "censor_correlograms_ms": 0.3, - "sigma_smooth_ms": 0.6, - "adaptative_window_thresh": 0.5, - }, - template_similarity_kwargs={"template_diff_thresh": 0.25}, - presence_distance_kwargs={"presence_distance_thresh": 100}, - knn_kwargs={"k_nn": 10}, - cross_contamination_kwargs={ - "cc_thresh": 0.1, - "p_value": 0.2, - "refractory_period_ms": 1.0, - "censored_period_ms": 0.3, - }, - quality_score_kwargs={"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + steps_params: dict = {"num_spikes" : {"min_spikes": 100}, + "snr" : {"min_snr": 2}, + "remove_contaminated" : {"contamination_thresh": 0.2, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3}, + "unit_locations" : {"max_distance_um": 50}, + "correlogram" : { + "corr_diff_thresh": 0.16, + "censor_correlograms_ms": 0.3, + "sigma_smooth_ms": 0.6, + "adaptative_window_thresh": 0.5, + }, + "template_similarity" : {"template_diff_thresh": 0.25}, + "presence_distance" : {"presence_distance_thresh": 100}, + "knn" : {"k_nn": 10}, + "cross_contamination" : { + "cc_thresh": 0.1, + "p_value": 0.2, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3, + }, + "quality_score" : {"firing_contamination_balance": 2.5, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3}, + }, compute_needed_extensions: bool = True, extra_outputs: bool = False, steps: list[str] | None = None, @@ -115,7 +120,8 @@ def auto_merges( Which steps to run, if no preset is used. Pontential steps : "num_spikes", "snr", "remove_contaminated", "unit_locations", "correlogram", "template_similarity", "presence_distance", "cross_contamination", "knn", "quality_score" - Please check steps explanations above!$ + Please check steps explanations above! + steps_params : A dictionary whose keys are the steps, and keys are steps parameters. force_copy : boolean, default: True When new extensions are computed, the default is to make a copy of the analyzer, to avoid overwriting already computed extensions. False if you want to overwrite @@ -140,11 +146,6 @@ def auto_merges( sorting = sorting_analyzer.sorting unit_ids = sorting.unit_ids - # to get fast computation we will not analyse pairs when: - # * not enough spikes for one of theses - # * auto correlogram is contaminated - # * to far away one from each other - all_steps = [ "num_spikes", "snr", @@ -227,11 +228,13 @@ def auto_merges( for step in steps: assert step in all_steps, f"{step} is not a valid step" + params = steps_params.get(step, {}) # STEP : remove units with too few spikes if step == "num_spikes": + num_spikes = sorting.count_num_spikes_per_unit(outputs="array") - to_remove = num_spikes < num_spikes_kwargs["min_spikes"] + to_remove = num_spikes < params["min_spikes"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False outs["num_spikes"] = to_remove @@ -245,7 +248,7 @@ def auto_merges( qm_ext = sorting_analyzer.get_extension("quality_metrics") snrs = qm_ext.get_data()["snr"].values - to_remove = snrs < snr_kwargs["min_snr"] + to_remove = snrs < params["min_snr"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False outs["snr"] = to_remove @@ -254,12 +257,12 @@ def auto_merges( elif step == "remove_contaminated": contaminations, nb_violations = compute_refrac_period_violations( sorting_analyzer, - refractory_period_ms=remove_contaminated_kwargs["refractory_period_ms"], - censored_period_ms=remove_contaminated_kwargs["censored_period_ms"], + refractory_period_ms=params["refractory_period_ms"], + censored_period_ms=params["censored_period_ms"], ) nb_violations = np.array(list(nb_violations.values())) contaminations = np.array(list(contaminations.values())) - to_remove = contaminations > remove_contaminated_kwargs["contamination_thresh"] + to_remove = contaminations > params["contamination_thresh"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False outs["remove_contaminated"] = to_remove @@ -270,15 +273,15 @@ def auto_merges( unit_locations = location_ext.get_data()[:, :2] unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean") - pair_mask = pair_mask & (unit_distances <= unit_locations_kwargs["max_distance_um"]) + pair_mask = pair_mask & (unit_distances <= params["max_distance_um"]) outs["unit_distances"] = unit_distances # STEP : potential auto merge by correlogram elif step == "correlogram" in steps: correlograms_ext = sorting_analyzer.get_extension("correlograms") correlograms, bins = correlograms_ext.get_data() - censor_ms = correlogram_kwargs["censor_correlograms_ms"] - sigma_smooth_ms = correlogram_kwargs["sigma_smooth_ms"] + censor_ms = params["censor_correlograms_ms"] + sigma_smooth_ms = params["sigma_smooth_ms"] mask = (bins[:-1] >= -censor_ms) & (bins[:-1] < censor_ms) correlograms[:, :, mask] = 0 correlograms_smoothed = smooth_correlogram(correlograms, bins, sigma_smooth_ms=sigma_smooth_ms) @@ -286,7 +289,7 @@ def auto_merges( win_sizes = np.zeros(n, dtype=int) for unit_ind in range(n): auto_corr = correlograms_smoothed[unit_ind, unit_ind, :] - thresh = np.max(auto_corr) * correlogram_kwargs["adaptative_window_thresh"] + thresh = np.max(auto_corr) * params["adaptative_window_thresh"] win_size = get_unit_adaptive_window(auto_corr, thresh) win_sizes[unit_ind] = win_size correlogram_diff = compute_correlogram_diff( @@ -296,7 +299,7 @@ def auto_merges( pair_mask=pair_mask, ) # print(correlogram_diff) - pair_mask = pair_mask & (correlogram_diff < correlogram_kwargs["corr_diff_thresh"]) + pair_mask = pair_mask & (correlogram_diff < params["corr_diff_thresh"]) outs["correlograms"] = correlograms outs["bins"] = bins outs["correlograms_smoothed"] = correlograms_smoothed @@ -308,16 +311,16 @@ def auto_merges( template_similarity_ext = sorting_analyzer.get_extension("template_similarity") templates_similarity = template_similarity_ext.get_data() templates_diff = 1 - templates_similarity - pair_mask = pair_mask & (templates_diff < template_similarity_kwargs["template_diff_thresh"]) + pair_mask = pair_mask & (templates_diff < params["template_diff_thresh"]) outs["templates_diff"] = templates_diff # STEP : check the vicinity of the spikes elif step == "knn" in steps: - pair_mask = get_pairs_via_nntree(sorting_analyzer, **knn_kwargs, pair_mask=pair_mask) + pair_mask = get_pairs_via_nntree(sorting_analyzer, **params, pair_mask=pair_mask) # STEP : check how the rates overlap in times elif step == "presence_distance" in steps: - presence_distance_kwargs = presence_distance_kwargs.copy() + presence_distance_kwargs = params.copy() presence_distance_thresh = presence_distance_kwargs.pop("presence_distance_thresh") num_samples = [ sorting_analyzer.get_num_samples(segment_index) for segment_index in range(sorting.get_num_segments()) @@ -331,13 +334,13 @@ def auto_merges( # STEP : check if the cross contamination is significant elif step == "cross_contamination" in steps: refractory = ( - cross_contamination_kwargs["censored_period_ms"], - cross_contamination_kwargs["refractory_period_ms"], + params["censored_period_ms"], + params["refractory_period_ms"], ) CC, p_values = compute_cross_contaminations( - sorting_analyzer, pair_mask, cross_contamination_kwargs["cc_thresh"], refractory, contaminations + sorting_analyzer, pair_mask, params["cc_thresh"], refractory, contaminations ) - pair_mask = pair_mask & (p_values > cross_contamination_kwargs["p_value"]) + pair_mask = pair_mask & (p_values > params["p_value"]) outs["cross_contaminations"] = CC, p_values # STEP : validate the potential merges with CC increase the contamination quality metrics @@ -346,9 +349,9 @@ def auto_merges( sorting_analyzer, pair_mask, contaminations, - quality_score_kwargs["firing_contamination_balance"], - quality_score_kwargs["refractory_period_ms"], - quality_score_kwargs["censored_period_ms"], + params["firing_contamination_balance"], + params["refractory_period_ms"], + params["censored_period_ms"], ) outs["pairs_decreased_score"] = pairs_decreased_score @@ -505,34 +508,34 @@ def get_potential_auto_merge( sorting_analyzer, preset, resolve_graph, - num_spikes_kwargs={"min_spikes": min_spikes}, - snr_kwargs={"min_snr": min_snr}, - remove_contaminated_kwargs={ - "contamination_thresh": contamination_thresh, - "refractory_period_ms": refractory_period_ms, - "censored_period_ms": censored_period_ms, - }, - unit_locations_kwargs={"max_distance_um": max_distance_um}, - correlogram_kwargs={ - "corr_diff_thresh": corr_diff_thresh, - "censor_correlograms_ms": censor_correlograms_ms, - "sigma_smooth_ms": sigma_smooth_ms, - "adaptative_window_thresh": adaptative_window_thresh, - }, - template_similarity_kwargs={"template_diff_thresh": template_diff_thresh}, - presence_distance_kwargs={"presence_distance_thresh": presence_distance_thresh, **presence_distance_kwargs}, - knn_kwargs={"k_nn": k_nn, **knn_kwargs}, - cross_contamination_kwargs={ - "cc_thresh": cc_thresh, - "p_value": p_value, - "refractory_period_ms": refractory_period_ms, - "censored_period_ms": censored_period_ms, - }, - quality_score_kwargs={ - "firing_contamination_balance": firing_contamination_balance, - "refractory_period_ms": refractory_period_ms, - "censored_period_ms": censored_period_ms, - }, + step_params={"num_spikes" : {"min_spikes": min_spikes}, + "snr_kwargs" : {"min_snr": min_snr}, + "remove_contaminated_kwargs" : { + "contamination_thresh": contamination_thresh, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + "unit_locations" : {"max_distance_um": max_distance_um}, + "correlogram" : { + "corr_diff_thresh": corr_diff_thresh, + "censor_correlograms_ms": censor_correlograms_ms, + "sigma_smooth_ms": sigma_smooth_ms, + "adaptative_window_thresh": adaptative_window_thresh, + }, + "template_similarity": {"template_diff_thresh": template_diff_thresh}, + "presence_distance" : {"presence_distance_thresh": presence_distance_thresh, **presence_distance_kwargs}, + "knn" : {"k_nn": k_nn, **knn_kwargs}, + "cross_contamination" : { + "cc_thresh": cc_thresh, + "p_value": p_value, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + "quality_score" : { + "firing_contamination_balance": firing_contamination_balance, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }}, compute_needed_extensions=False, extra_outputs=extra_outputs, steps=steps, From 35ad317e619be60abbdd40f1da41a167171be1c9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Oct 2024 08:41:42 +0000 Subject: [PATCH 03/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/auto_merge.py | 107 +++++++++++----------- 1 file changed, 53 insertions(+), 54 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 7a101ad609..db3300f0d2 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -35,31 +35,28 @@ def auto_merges( sorting_analyzer: SortingAnalyzer, preset: str | None = "similarity_correlograms", resolve_graph: bool = False, - steps_params: dict = {"num_spikes" : {"min_spikes": 100}, - "snr" : {"min_snr": 2}, - "remove_contaminated" : {"contamination_thresh": 0.2, - "refractory_period_ms": 1.0, - "censored_period_ms": 0.3}, - "unit_locations" : {"max_distance_um": 50}, - "correlogram" : { - "corr_diff_thresh": 0.16, - "censor_correlograms_ms": 0.3, - "sigma_smooth_ms": 0.6, - "adaptative_window_thresh": 0.5, - }, - "template_similarity" : {"template_diff_thresh": 0.25}, - "presence_distance" : {"presence_distance_thresh": 100}, - "knn" : {"k_nn": 10}, - "cross_contamination" : { - "cc_thresh": 0.1, - "p_value": 0.2, - "refractory_period_ms": 1.0, - "censored_period_ms": 0.3, - }, - "quality_score" : {"firing_contamination_balance": 2.5, - "refractory_period_ms": 1.0, - "censored_period_ms": 0.3}, - }, + steps_params: dict = { + "num_spikes": {"min_spikes": 100}, + "snr": {"min_snr": 2}, + "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + "unit_locations": {"max_distance_um": 50}, + "correlogram": { + "corr_diff_thresh": 0.16, + "censor_correlograms_ms": 0.3, + "sigma_smooth_ms": 0.6, + "adaptative_window_thresh": 0.5, + }, + "template_similarity": {"template_diff_thresh": 0.25}, + "presence_distance": {"presence_distance_thresh": 100}, + "knn": {"k_nn": 10}, + "cross_contamination": { + "cc_thresh": 0.1, + "p_value": 0.2, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3, + }, + "quality_score": {"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + }, compute_needed_extensions: bool = True, extra_outputs: bool = False, steps: list[str] | None = None, @@ -232,7 +229,7 @@ def auto_merges( # STEP : remove units with too few spikes if step == "num_spikes": - + num_spikes = sorting.count_num_spikes_per_unit(outputs="array") to_remove = num_spikes < params["min_spikes"] pair_mask[to_remove, :] = False @@ -508,34 +505,36 @@ def get_potential_auto_merge( sorting_analyzer, preset, resolve_graph, - step_params={"num_spikes" : {"min_spikes": min_spikes}, - "snr_kwargs" : {"min_snr": min_snr}, - "remove_contaminated_kwargs" : { - "contamination_thresh": contamination_thresh, - "refractory_period_ms": refractory_period_ms, - "censored_period_ms": censored_period_ms, - }, - "unit_locations" : {"max_distance_um": max_distance_um}, - "correlogram" : { - "corr_diff_thresh": corr_diff_thresh, - "censor_correlograms_ms": censor_correlograms_ms, - "sigma_smooth_ms": sigma_smooth_ms, - "adaptative_window_thresh": adaptative_window_thresh, - }, - "template_similarity": {"template_diff_thresh": template_diff_thresh}, - "presence_distance" : {"presence_distance_thresh": presence_distance_thresh, **presence_distance_kwargs}, - "knn" : {"k_nn": k_nn, **knn_kwargs}, - "cross_contamination" : { - "cc_thresh": cc_thresh, - "p_value": p_value, - "refractory_period_ms": refractory_period_ms, - "censored_period_ms": censored_period_ms, - }, - "quality_score" : { - "firing_contamination_balance": firing_contamination_balance, - "refractory_period_ms": refractory_period_ms, - "censored_period_ms": censored_period_ms, - }}, + step_params={ + "num_spikes": {"min_spikes": min_spikes}, + "snr_kwargs": {"min_snr": min_snr}, + "remove_contaminated_kwargs": { + "contamination_thresh": contamination_thresh, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + "unit_locations": {"max_distance_um": max_distance_um}, + "correlogram": { + "corr_diff_thresh": corr_diff_thresh, + "censor_correlograms_ms": censor_correlograms_ms, + "sigma_smooth_ms": sigma_smooth_ms, + "adaptative_window_thresh": adaptative_window_thresh, + }, + "template_similarity": {"template_diff_thresh": template_diff_thresh}, + "presence_distance": {"presence_distance_thresh": presence_distance_thresh, **presence_distance_kwargs}, + "knn": {"k_nn": k_nn, **knn_kwargs}, + "cross_contamination": { + "cc_thresh": cc_thresh, + "p_value": p_value, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + "quality_score": { + "firing_contamination_balance": firing_contamination_balance, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + }, compute_needed_extensions=False, extra_outputs=extra_outputs, steps=steps, From 3bf9b4884de04c89ffd0e89c647a9c151c27ed96 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 15 Oct 2024 14:54:30 +0200 Subject: [PATCH 04/25] Fixing tests --- src/spikeinterface/curation/auto_merge.py | 46 ++++++++++++----------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index db3300f0d2..7a8404d076 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -38,7 +38,9 @@ def auto_merges( steps_params: dict = { "num_spikes": {"min_spikes": 100}, "snr": {"min_snr": 2}, - "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + "remove_contaminated": {"contamination_thresh": 0.2, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3}, "unit_locations": {"max_distance_um": 50}, "correlogram": { "corr_diff_thresh": 0.16, @@ -55,7 +57,9 @@ def auto_merges( "refractory_period_ms": 1.0, "censored_period_ms": 0.3, }, - "quality_score": {"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + "quality_score": {"firing_contamination_balance": 2.5, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3}, }, compute_needed_extensions: bool = True, extra_outputs: bool = False, @@ -203,21 +207,6 @@ def auto_merges( # To avoid erasing the extensions of the user sorting_analyzer = sorting_analyzer.copy() - for step in steps: - if step in _required_extensions: - for ext in _required_extensions[step]: - if compute_needed_extensions: - if step in _templates_needed: - template_ext = sorting_analyzer.get_extension("templates") - if template_ext is None: - sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - params = eval(f"{step}_kwargs") - params = params.get(ext, dict()) - sorting_analyzer.compute(ext, **params, **job_kwargs) - else: - if not sorting_analyzer.has_extension(ext): - raise ValueError(f"{step} requires {ext} extension") - n = unit_ids.size pair_mask = np.triu(np.arange(n)) > 0 outs = dict() @@ -225,7 +214,20 @@ def auto_merges( for step in steps: assert step in all_steps, f"{step} is not a valid step" - params = steps_params.get(step, {}) + + if step in _required_extensions: + for ext in _required_extensions[step]: + if compute_needed_extensions and step in _templates_needed: + template_ext = sorting_analyzer.get_extension("templates") + if template_ext is None: + sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) + print(f"Extension {ext} is computed with default params") + sorting_analyzer.compute(ext, **job_kwargs) + elif not compute_needed_extensions and not sorting_analyzer.has_extension(ext): + raise ValueError(f"{step} requires {ext} extension") + + + params = steps_params.get(step, dict()) # STEP : remove units with too few spikes if step == "num_spikes": @@ -240,7 +242,7 @@ def auto_merges( elif step == "snr": qm_ext = sorting_analyzer.get_extension("quality_metrics") if qm_ext is None: - sorting_analyzer.compute(["noise_levels"], **job_kwargs) + sorting_analyzer.compute("noise_levels", **job_kwargs) sorting_analyzer.compute("quality_metrics", metric_names=["snr"], **job_kwargs) qm_ext = sorting_analyzer.get_extension("quality_metrics") @@ -505,10 +507,10 @@ def get_potential_auto_merge( sorting_analyzer, preset, resolve_graph, - step_params={ + steps_params={ "num_spikes": {"min_spikes": min_spikes}, - "snr_kwargs": {"min_snr": min_snr}, - "remove_contaminated_kwargs": { + "snr": {"min_snr": min_snr}, + "remove_contaminated": { "contamination_thresh": contamination_thresh, "refractory_period_ms": refractory_period_ms, "censored_period_ms": censored_period_ms, From 3df19c2e11117e9b69be4416bdb1123637ce63e8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 15 Oct 2024 12:55:03 +0000 Subject: [PATCH 05/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/auto_merge.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 7a8404d076..d38b717bc8 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -38,9 +38,7 @@ def auto_merges( steps_params: dict = { "num_spikes": {"min_spikes": 100}, "snr": {"min_snr": 2}, - "remove_contaminated": {"contamination_thresh": 0.2, - "refractory_period_ms": 1.0, - "censored_period_ms": 0.3}, + "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, "unit_locations": {"max_distance_um": 50}, "correlogram": { "corr_diff_thresh": 0.16, @@ -57,9 +55,7 @@ def auto_merges( "refractory_period_ms": 1.0, "censored_period_ms": 0.3, }, - "quality_score": {"firing_contamination_balance": 2.5, - "refractory_period_ms": 1.0, - "censored_period_ms": 0.3}, + "quality_score": {"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, }, compute_needed_extensions: bool = True, extra_outputs: bool = False, @@ -226,7 +222,6 @@ def auto_merges( elif not compute_needed_extensions and not sorting_analyzer.has_extension(ext): raise ValueError(f"{step} requires {ext} extension") - params = steps_params.get(step, dict()) # STEP : remove units with too few spikes From 51edfece2f8ef041774bd2b27582021431e0f93d Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 15 Oct 2024 15:00:18 +0200 Subject: [PATCH 06/25] Fixing tests --- src/spikeinterface/curation/auto_merge.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 7a8404d076..4966db4247 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -23,12 +23,12 @@ _required_extensions = { "unit_locations": ["unit_locations"], "correlogram": ["correlograms"], - "min_snr": ["noise_levels", "templates"], + "snr": ["noise_levels", "templates"], "template_similarity": ["template_similarity"], "knn": ["spike_locations", "spike_amplitudes"], } -_templates_needed = ["unit_locations", "min_snr", "template_similarity", "spike_locations", "spike_amplitudes"] +_templates_needed = ["unit_locations", "snr", "template_similarity", "knn", "spike_amplitudes"] def auto_merges( @@ -242,7 +242,6 @@ def auto_merges( elif step == "snr": qm_ext = sorting_analyzer.get_extension("quality_metrics") if qm_ext is None: - sorting_analyzer.compute("noise_levels", **job_kwargs) sorting_analyzer.compute("quality_metrics", metric_names=["snr"], **job_kwargs) qm_ext = sorting_analyzer.get_extension("quality_metrics") @@ -537,7 +536,7 @@ def get_potential_auto_merge( "censored_period_ms": censored_period_ms, }, }, - compute_needed_extensions=False, + compute_needed_extensions=True, extra_outputs=extra_outputs, steps=steps, ) From c26b7199e086c9a3e48c99aa0495f540206e44a4 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 16 Oct 2024 14:27:14 +0200 Subject: [PATCH 07/25] Default params --- src/spikeinterface/curation/auto_merge.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 03c8c131a9..39a155ec09 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -31,11 +31,7 @@ _templates_needed = ["unit_locations", "snr", "template_similarity", "knn", "spike_amplitudes"] -def auto_merges( - sorting_analyzer: SortingAnalyzer, - preset: str | None = "similarity_correlograms", - resolve_graph: bool = False, - steps_params: dict = { +_default_step_params = { "num_spikes": {"min_spikes": 100}, "snr": {"min_snr": 2}, "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, @@ -56,7 +52,14 @@ def auto_merges( "censored_period_ms": 0.3, }, "quality_score": {"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, - }, + } + + +def auto_merges( + sorting_analyzer: SortingAnalyzer, + preset: str | None = "similarity_correlograms", + resolve_graph: bool = False, + steps_params: dict = None, compute_needed_extensions: bool = True, extra_outputs: bool = False, steps: list[str] | None = None, @@ -222,7 +225,9 @@ def auto_merges( elif not compute_needed_extensions and not sorting_analyzer.has_extension(ext): raise ValueError(f"{step} requires {ext} extension") - params = steps_params.get(step, dict()) + params = _default_step_params.get(step).copy() + if step in steps_params: + params.update(steps_params[step]) # STEP : remove units with too few spikes if step == "num_spikes": From 9692fb0fbaf294c323edc4bbaeb66d3347e2145c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 16 Oct 2024 12:31:36 +0000 Subject: [PATCH 08/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/auto_merge.py | 42 +++++++++++------------ 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 39a155ec09..e337b3d99d 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -32,27 +32,27 @@ _default_step_params = { - "num_spikes": {"min_spikes": 100}, - "snr": {"min_snr": 2}, - "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, - "unit_locations": {"max_distance_um": 50}, - "correlogram": { - "corr_diff_thresh": 0.16, - "censor_correlograms_ms": 0.3, - "sigma_smooth_ms": 0.6, - "adaptative_window_thresh": 0.5, - }, - "template_similarity": {"template_diff_thresh": 0.25}, - "presence_distance": {"presence_distance_thresh": 100}, - "knn": {"k_nn": 10}, - "cross_contamination": { - "cc_thresh": 0.1, - "p_value": 0.2, - "refractory_period_ms": 1.0, - "censored_period_ms": 0.3, - }, - "quality_score": {"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, - } + "num_spikes": {"min_spikes": 100}, + "snr": {"min_snr": 2}, + "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + "unit_locations": {"max_distance_um": 50}, + "correlogram": { + "corr_diff_thresh": 0.16, + "censor_correlograms_ms": 0.3, + "sigma_smooth_ms": 0.6, + "adaptative_window_thresh": 0.5, + }, + "template_similarity": {"template_diff_thresh": 0.25}, + "presence_distance": {"presence_distance_thresh": 100}, + "knn": {"k_nn": 10}, + "cross_contamination": { + "cc_thresh": 0.1, + "p_value": 0.2, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3, + }, + "quality_score": {"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, +} def auto_merges( From 3c277b3445dc05760d617c249fbc58a43b2d7ace Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 16 Oct 2024 14:35:24 +0200 Subject: [PATCH 09/25] Precomputing extensions --- src/spikeinterface/curation/auto_merge.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 39a155ec09..fcf5fd8fd9 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -216,12 +216,15 @@ def auto_merges( if step in _required_extensions: for ext in _required_extensions[step]: - if compute_needed_extensions and step in _templates_needed: - template_ext = sorting_analyzer.get_extension("templates") - if template_ext is None: - sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - print(f"Extension {ext} is computed with default params") - sorting_analyzer.compute(ext, **job_kwargs) + if compute_needed_extensions: + if step in _templates_needed: + template_ext = sorting_analyzer.get_extension("templates") + if template_ext is None: + sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) + res_ext = sorting_analyzer.get_extension(step) + if res_ext is None: + print(f"Extension {ext} is computed with default params. Precompute it with custom params if needed") + sorting_analyzer.compute(ext, **job_kwargs) elif not compute_needed_extensions and not sorting_analyzer.has_extension(ext): raise ValueError(f"{step} requires {ext} extension") From a3d1c2c4f49025e01bf17b13b76b931cd071e938 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 16 Oct 2024 12:35:58 +0000 Subject: [PATCH 10/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/auto_merge.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index ffc4fea78b..86f47af0eb 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -223,7 +223,9 @@ def auto_merges( sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) res_ext = sorting_analyzer.get_extension(step) if res_ext is None: - print(f"Extension {ext} is computed with default params. Precompute it with custom params if needed") + print( + f"Extension {ext} is computed with default params. Precompute it with custom params if needed" + ) sorting_analyzer.compute(ext, **job_kwargs) elif not compute_needed_extensions and not sorting_analyzer.has_extension(ext): raise ValueError(f"{step} requires {ext} extension") From 68b4b200907be01e63149dd673b49f1f02f9b821 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 23 Oct 2024 17:14:06 +0200 Subject: [PATCH 11/25] small updates on auto merge + renaming --- src/spikeinterface/curation/__init__.py | 2 +- src/spikeinterface/curation/auto_merge.py | 197 ++++++++++-------- .../curation/tests/test_auto_merge.py | 49 +++-- 3 files changed, 137 insertions(+), 111 deletions(-) diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 657b936fb9..579e47a553 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -3,7 +3,7 @@ from .remove_redundant import remove_redundant_units, find_redundant_units from .remove_duplicated_spikes import remove_duplicated_spikes from .remove_excess_spikes import remove_excess_spikes -from .auto_merge import get_potential_auto_merge +from .auto_merge import compute_merge_unit_groups, auto_merge, get_potential_auto_merge # manual sorting, diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 86f47af0eb..16147a6225 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -1,5 +1,7 @@ from __future__ import annotations +import warnings + from typing import Tuple import numpy as np import math @@ -17,19 +19,50 @@ from .mergeunitssorting import MergeUnitsSorting from .curation_tools import resolve_merging_graph - -_possible_presets = ["similarity_correlograms", "x_contaminations", "temporal_splits", "feature_neighbors"] +_compute_merge_persets = { + "similarity_correlograms":[ + "num_spikes", + "remove_contaminated", + "unit_locations", + "template_similarity", + "correlogram", + "quality_score", + ], + "temporal_splits":[ + "num_spikes", + "remove_contaminated", + "unit_locations", + "template_similarity", + "presence_distance", + "quality_score", + ], + "x_contaminations":[ + "num_spikes", + "remove_contaminated", + "unit_locations", + "template_similarity", + "cross_contamination", + "quality_score", + ], + "feature_neighbors":[ + "num_spikes", + "snr", + "remove_contaminated", + "unit_locations", + "knn", + "quality_score", + ] +} _required_extensions = { - "unit_locations": ["unit_locations"], + "unit_locations": ["templates", "unit_locations"], "correlogram": ["correlograms"], - "snr": ["noise_levels", "templates"], - "template_similarity": ["template_similarity"], - "knn": ["spike_locations", "spike_amplitudes"], + "snr": ["templates","noise_levels", "templates"], + "template_similarity": ["templates", "template_similarity"], + "knn": ["templates", "spike_locations", "spike_amplitudes"], + "spike_amplitudes" : ["templates"], } -_templates_needed = ["unit_locations", "snr", "template_similarity", "knn", "spike_amplitudes"] - _default_step_params = { "num_spikes": {"min_spikes": 100}, @@ -55,17 +88,18 @@ } -def auto_merges( + +def compute_merge_unit_groups( sorting_analyzer: SortingAnalyzer, preset: str | None = "similarity_correlograms", - resolve_graph: bool = False, + resolve_graph: bool = True, steps_params: dict = None, compute_needed_extensions: bool = True, extra_outputs: bool = False, steps: list[str] | None = None, force_copy: bool = True, **job_kwargs, -) -> list[tuple[int | str, int | str]] | Tuple[tuple[int | str, int | str], dict]: +) -> list[tuple[int | str, int | str]] | Tuple[list[tuple[int | str, int | str]], dict]: """ Algorithm to find and check potential merges between units. @@ -110,7 +144,7 @@ def auto_merges( | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations", | "knn", "quality_score" If `preset` is None, you can specify the steps manually with the `steps` parameter. - resolve_graph : bool, default: False + resolve_graph : bool, default: True If True, the function resolves the potential unit pairs to be merged into multiple-unit merges. compute_needed_extensions : bool, default : True Should we force the computation of needed extensions? @@ -128,9 +162,10 @@ def auto_merges( Returns ------- - potential_merges: - A list of tuples of 2 elements (if `resolve_graph`if false) or 2+ elements (if `resolve_graph` is true). - List of pairs that could be merged. + merge_unit_groups: + List of groups that need to be merge. + When `resolve_graph` is true (default) a list of tuples of 2+ elements + If `resolve_graph` is false then a list of tuple of 2 elements is returned instead. outs: Returned only when extra_outputs=True A dictionary that contains data for debugging and plotting. @@ -146,62 +181,17 @@ def auto_merges( sorting = sorting_analyzer.sorting unit_ids = sorting.unit_ids - all_steps = [ - "num_spikes", - "snr", - "remove_contaminated", - "unit_locations", - "correlogram", - "template_similarity", - "presence_distance", - "knn", - "cross_contamination", - "quality_score", - ] - if preset is not None and preset not in _possible_presets: - raise ValueError(f"preset must be one of {_possible_presets}") - - if steps is None: - if preset is None: - if steps is None: - raise ValueError("You need to specify a preset or steps for the auto-merge function") - elif preset == "similarity_correlograms": - steps = [ - "num_spikes", - "remove_contaminated", - "unit_locations", - "template_similarity", - "correlogram", - "quality_score", - ] - elif preset == "temporal_splits": - steps = [ - "num_spikes", - "remove_contaminated", - "unit_locations", - "template_similarity", - "presence_distance", - "quality_score", - ] - elif preset == "x_contaminations": - steps = [ - "num_spikes", - "remove_contaminated", - "unit_locations", - "template_similarity", - "cross_contamination", - "quality_score", - ] - elif preset == "feature_neighbors": - steps = [ - "num_spikes", - "snr", - "remove_contaminated", - "unit_locations", - "knn", - "quality_score", - ] + if preset is None and steps is None: + raise ValueError("You need to specify a preset or steps for the auto-merge function") + elif steps is not None: + # steps has presendance on presets + pass + elif preset is not None: + if preset not in _compute_merge_persets: + raise ValueError(f"preset must be one of {list(_compute_merge_persets.keys())}") + steps = _compute_merge_persets[preset] + if force_copy and compute_needed_extensions: # To avoid erasing the extensions of the user sorting_analyzer = sorting_analyzer.copy() @@ -212,26 +202,23 @@ def auto_merges( for step in steps: - assert step in all_steps, f"{step} is not a valid step" + assert step in _default_step_params, f"{step} is not a valid step" if step in _required_extensions: for ext in _required_extensions[step]: - if compute_needed_extensions: - if step in _templates_needed: - template_ext = sorting_analyzer.get_extension("templates") - if template_ext is None: - sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) - res_ext = sorting_analyzer.get_extension(step) - if res_ext is None: - print( - f"Extension {ext} is computed with default params. Precompute it with custom params if needed" - ) - sorting_analyzer.compute(ext, **job_kwargs) - elif not compute_needed_extensions and not sorting_analyzer.has_extension(ext): + if sorting_analyzer.has_extension(ext): + continue + if not compute_needed_extensions: raise ValueError(f"{step} requires {ext} extension") + + # special case for templates + if ext == "templates" and not sorting_analyzer.has_extension("random_spikes"): + sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) + else: + sorting_analyzer.compute(ext, **job_kwargs) params = _default_step_params.get(step).copy() - if step in steps_params: + if steps_params is not None and step in steps_params: params.update(steps_params[step]) # STEP : remove units with too few spikes @@ -360,15 +347,38 @@ def auto_merges( # FINAL STEP : create the final list from pair_mask boolean matrix ind1, ind2 = np.nonzero(pair_mask) - potential_merges = list(zip(unit_ids[ind1], unit_ids[ind2])) + merge_unit_groups = list(zip(unit_ids[ind1], unit_ids[ind2])) if resolve_graph: - potential_merges = resolve_merging_graph(sorting, potential_merges) + merge_unit_groups = resolve_merging_graph(sorting, merge_unit_groups) if extra_outputs: - return potential_merges, outs + return merge_unit_groups, outs else: - return potential_merges + return merge_unit_groups + +def auto_merge( + sorting_analyzer: SortingAnalyzer, + compute_merge_kwargs:dict = {}, + apply_merge_kwargs: dict = {}, + **job_kwargs + ) -> SortingAnalyzer: + """ + Compute merge unit groups and apply it on a SortingAnalyzer. + Internally uses `compute_merge_unit_groups()` + """ + merge_unit_groups = compute_merge_unit_groups( + sorting_analyzer, + extra_outputs=False, + **compute_merge_kwargs, + **job_kwargs + ) + + merged_analyzer = sorting_analyzer.merge_units( + merge_unit_groups, **apply_merge_kwargs, **job_kwargs + ) + return merged_analyzer + def get_potential_auto_merge( @@ -397,6 +407,9 @@ def get_potential_auto_merge( steps: list[str] | None = None, ) -> list[tuple[int | str, int | str]] | Tuple[tuple[int | str, int | str], dict]: """ + This function is deprecated. Use compute_merge_unit_groups() instead. + This will be removed in 0.103.0 + Algorithm to find and check potential merges between units. The merges are proposed based on a series of steps with different criteria: @@ -505,9 +518,15 @@ def get_potential_auto_merge( done by Aurelien Wyngaard and Victor Llobet. https://github.com/BarbourLab/lussac/blob/v1.0.0/postprocessing/merge_units.py """ + warnings.warn( + "get_potential_auto_merge() is deprecated. Use compute_merge_unit_groups() instead", + DeprecationWarning, + stacklevel=2, + ) + presence_distance_kwargs = presence_distance_kwargs or dict() knn_kwargs = knn_kwargs or dict() - return auto_merges( + return compute_merge_unit_groups( sorting_analyzer, preset, resolve_graph, diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index 33fd06d27a..ebd7bf1504 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -3,16 +3,16 @@ from spikeinterface.core import create_sorting_analyzer from spikeinterface.core.generate import inject_some_split_units -from spikeinterface.curation import get_potential_auto_merge +from spikeinterface.curation import compute_merge_unit_groups, auto_merge from spikeinterface.curation.tests.common import make_sorting_analyzer, sorting_analyzer_for_curation @pytest.mark.parametrize( - "preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms"] + "preset", ["x_contaminations", "feature_neighbors", "temporal_splits", "similarity_correlograms", None] ) -def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): +def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset): print(sorting_analyzer_for_curation) sorting = sorting_analyzer_for_curation.sorting @@ -47,32 +47,37 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): ) if preset is not None: - potential_merges, outs = get_potential_auto_merge( + # do not resolve graph for checking true pairs + merge_unit_groups, outs = compute_merge_unit_groups( sorting_analyzer, preset=preset, - min_spikes=1000, - max_distance_um=150.0, - contamination_thresh=0.2, - corr_diff_thresh=0.16, - template_diff_thresh=0.25, - censored_period_ms=0.0, - refractory_period_ms=4.0, - sigma_smooth_ms=0.6, - adaptative_window_thresh=0.5, - firing_contamination_balance=1.5, + resolve_graph=False, + # min_spikes=1000, + # max_distance_um=150.0, + # contamination_thresh=0.2, + # corr_diff_thresh=0.16, + # template_diff_thresh=0.25, + # censored_period_ms=0.0, + # refractory_period_ms=4.0, + # sigma_smooth_ms=0.6, + # adaptative_window_thresh=0.5, + # firing_contamination_balance=1.5, extra_outputs=True, + **job_kwargs ) if preset == "x_contaminations": - assert len(potential_merges) == num_unit_splited + assert len(merge_unit_groups) == num_unit_splited for true_pair in other_ids.values(): true_pair = tuple(true_pair) - assert true_pair in potential_merges + assert true_pair in merge_unit_groups else: # when preset is None you have to specify the steps with pytest.raises(ValueError): - potential_merges = get_potential_auto_merge(sorting_analyzer, preset=preset) - potential_merges = get_potential_auto_merge( - sorting_analyzer, preset=preset, steps=["min_spikes", "min_snr", "remove_contaminated", "unit_positions"] + merge_unit_groups = compute_merge_unit_groups(sorting_analyzer, preset=preset) + merge_unit_groups = compute_merge_unit_groups( + sorting_analyzer, preset=preset, + steps=["num_spikes", "snr", "remove_contaminated", "unit_locations"], + **job_kwargs ) # DEBUG @@ -93,7 +98,7 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): # m = correlograms.shape[2] // 2 - # for unit_id1, unit_id2 in potential_merges[:5]: + # for unit_id1, unit_id2 in merge_unit_groups[:5]: # unit_ind1 = sorting_with_split.id_to_index(unit_id1) # unit_ind2 = sorting_with_split.id_to_index(unit_id2) @@ -129,4 +134,6 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): if __name__ == "__main__": sorting_analyzer = make_sorting_analyzer(sparse=True) - test_get_auto_merge_list(sorting_analyzer) + # preset = "x_contaminations" + preset = None + test_compute_merge_unit_groups(sorting_analyzer, preset=preset) From 4476d4ccc6bde244561936b8ed22c9b7a0032113 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 23 Oct 2024 15:15:47 +0000 Subject: [PATCH 12/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/auto_merge.py | 44 +++++++------------ .../curation/tests/test_auto_merge.py | 7 +-- 2 files changed, 21 insertions(+), 30 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 16147a6225..ec5e8be20c 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -20,7 +20,7 @@ from .curation_tools import resolve_merging_graph _compute_merge_persets = { - "similarity_correlograms":[ + "similarity_correlograms": [ "num_spikes", "remove_contaminated", "unit_locations", @@ -28,7 +28,7 @@ "correlogram", "quality_score", ], - "temporal_splits":[ + "temporal_splits": [ "num_spikes", "remove_contaminated", "unit_locations", @@ -36,7 +36,7 @@ "presence_distance", "quality_score", ], - "x_contaminations":[ + "x_contaminations": [ "num_spikes", "remove_contaminated", "unit_locations", @@ -44,23 +44,23 @@ "cross_contamination", "quality_score", ], - "feature_neighbors":[ + "feature_neighbors": [ "num_spikes", "snr", "remove_contaminated", "unit_locations", "knn", "quality_score", - ] + ], } _required_extensions = { "unit_locations": ["templates", "unit_locations"], "correlogram": ["correlograms"], - "snr": ["templates","noise_levels", "templates"], + "snr": ["templates", "noise_levels", "templates"], "template_similarity": ["templates", "template_similarity"], "knn": ["templates", "spike_locations", "spike_amplitudes"], - "spike_amplitudes" : ["templates"], + "spike_amplitudes": ["templates"], } @@ -88,7 +88,6 @@ } - def compute_merge_unit_groups( sorting_analyzer: SortingAnalyzer, preset: str | None = "similarity_correlograms", @@ -181,7 +180,6 @@ def compute_merge_unit_groups( sorting = sorting_analyzer.sorting unit_ids = sorting.unit_ids - if preset is None and steps is None: raise ValueError("You need to specify a preset or steps for the auto-merge function") elif steps is not None: @@ -210,7 +208,7 @@ def compute_merge_unit_groups( continue if not compute_needed_extensions: raise ValueError(f"{step} requires {ext} extension") - + # special case for templates if ext == "templates" and not sorting_analyzer.has_extension("random_spikes"): sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) @@ -357,30 +355,22 @@ def compute_merge_unit_groups( else: return merge_unit_groups + def auto_merge( - sorting_analyzer: SortingAnalyzer, - compute_merge_kwargs:dict = {}, - apply_merge_kwargs: dict = {}, - **job_kwargs - ) -> SortingAnalyzer: + sorting_analyzer: SortingAnalyzer, compute_merge_kwargs: dict = {}, apply_merge_kwargs: dict = {}, **job_kwargs +) -> SortingAnalyzer: """ Compute merge unit groups and apply it on a SortingAnalyzer. Internally uses `compute_merge_unit_groups()` """ merge_unit_groups = compute_merge_unit_groups( - sorting_analyzer, - extra_outputs=False, - **compute_merge_kwargs, - **job_kwargs + sorting_analyzer, extra_outputs=False, **compute_merge_kwargs, **job_kwargs ) - merged_analyzer = sorting_analyzer.merge_units( - merge_unit_groups, **apply_merge_kwargs, **job_kwargs - ) + merged_analyzer = sorting_analyzer.merge_units(merge_unit_groups, **apply_merge_kwargs, **job_kwargs) return merged_analyzer - def get_potential_auto_merge( sorting_analyzer: SortingAnalyzer, preset: str | None = "similarity_correlograms", @@ -519,10 +509,10 @@ def get_potential_auto_merge( https://github.com/BarbourLab/lussac/blob/v1.0.0/postprocessing/merge_units.py """ warnings.warn( - "get_potential_auto_merge() is deprecated. Use compute_merge_unit_groups() instead", - DeprecationWarning, - stacklevel=2, - ) + "get_potential_auto_merge() is deprecated. Use compute_merge_unit_groups() instead", + DeprecationWarning, + stacklevel=2, + ) presence_distance_kwargs = presence_distance_kwargs or dict() knn_kwargs = knn_kwargs or dict() diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index ebd7bf1504..4c05f41a4c 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -63,7 +63,7 @@ def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset): # adaptative_window_thresh=0.5, # firing_contamination_balance=1.5, extra_outputs=True, - **job_kwargs + **job_kwargs, ) if preset == "x_contaminations": assert len(merge_unit_groups) == num_unit_splited @@ -75,9 +75,10 @@ def test_compute_merge_unit_groups(sorting_analyzer_for_curation, preset): with pytest.raises(ValueError): merge_unit_groups = compute_merge_unit_groups(sorting_analyzer, preset=preset) merge_unit_groups = compute_merge_unit_groups( - sorting_analyzer, preset=preset, + sorting_analyzer, + preset=preset, steps=["num_spikes", "snr", "remove_contaminated", "unit_locations"], - **job_kwargs + **job_kwargs, ) # DEBUG From f0f7f6c7165b76f07706254597c5e0730691789a Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Mon, 28 Oct 2024 10:16:27 +0100 Subject: [PATCH 13/25] Update src/spikeinterface/curation/auto_merge.py Co-authored-by: Alessio Buccino --- src/spikeinterface/curation/auto_merge.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index ec5e8be20c..73b69426f1 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -154,7 +154,8 @@ def compute_merge_unit_groups( Pontential steps : "num_spikes", "snr", "remove_contaminated", "unit_locations", "correlogram", "template_similarity", "presence_distance", "cross_contamination", "knn", "quality_score" Please check steps explanations above! - steps_params : A dictionary whose keys are the steps, and keys are steps parameters. + steps_params : dict + A dictionary whose keys are the steps, and keys are steps parameters. force_copy : boolean, default: True When new extensions are computed, the default is to make a copy of the analyzer, to avoid overwriting already computed extensions. False if you want to overwrite From 0dd48c424e437e9729af16f44101e881ba1d968e Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 28 Oct 2024 10:27:17 +0100 Subject: [PATCH 14/25] Typos and signatures --- src/spikeinterface/curation/auto_merge.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 73b69426f1..6680a70af4 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -19,7 +19,7 @@ from .mergeunitssorting import MergeUnitsSorting from .curation_tools import resolve_merging_graph -_compute_merge_persets = { +_compute_merge_presets = { "similarity_correlograms": [ "num_spikes", "remove_contaminated", @@ -146,7 +146,7 @@ def compute_merge_unit_groups( resolve_graph : bool, default: True If True, the function resolves the potential unit pairs to be merged into multiple-unit merges. compute_needed_extensions : bool, default : True - Should we force the computation of needed extensions? + Should we force the computation of needed extensions, if not already computed? extra_outputs : bool, default: False If True, an additional dictionary (`outs`) with processed data is returned. steps : None or list of str, default: None @@ -172,9 +172,11 @@ def compute_merge_unit_groups( References ---------- - This function is inspired and built upon similar functions from Lussac [Llobet]_, + This function used to be inspired and built upon similar functions from Lussac [Llobet]_, done by Aurelien Wyngaard and Victor Llobet. https://github.com/BarbourLab/lussac/blob/v1.0.0/postprocessing/merge_units.py + + However, it has been greatly consolidated and refined depending on the presets. """ import scipy @@ -187,11 +189,11 @@ def compute_merge_unit_groups( # steps has presendance on presets pass elif preset is not None: - if preset not in _compute_merge_persets: - raise ValueError(f"preset must be one of {list(_compute_merge_persets.keys())}") - steps = _compute_merge_persets[preset] + if preset not in _compute_merge_presets: + raise ValueError(f"preset must be one of {list(_compute_merge_presets.keys())}") + steps = _compute_merge_presets[preset] - if force_copy and compute_needed_extensions: + if force_copy: # To avoid erasing the extensions of the user sorting_analyzer = sorting_analyzer.copy() @@ -357,7 +359,7 @@ def compute_merge_unit_groups( return merge_unit_groups -def auto_merge( +def auto_merge_units( sorting_analyzer: SortingAnalyzer, compute_merge_kwargs: dict = {}, apply_merge_kwargs: dict = {}, **job_kwargs ) -> SortingAnalyzer: """ From a0587f6e04a210fe6bbde62e8b759176c69a47c3 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 28 Oct 2024 10:30:35 +0100 Subject: [PATCH 15/25] Cleaning requiered extensions --- src/spikeinterface/curation/auto_merge.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 6680a70af4..52dffc0378 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -57,10 +57,9 @@ _required_extensions = { "unit_locations": ["templates", "unit_locations"], "correlogram": ["correlograms"], - "snr": ["templates", "noise_levels", "templates"], + "snr": ["templates", "noise_levels"], "template_similarity": ["templates", "template_similarity"], - "knn": ["templates", "spike_locations", "spike_amplitudes"], - "spike_amplitudes": ["templates"], + "knn": ["templates", "spike_locations", "spike_amplitudes"] } From f22a4cc95a690ee5c0d89608a79c93d9207ca2be Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 28 Oct 2024 09:31:21 +0000 Subject: [PATCH 16/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/auto_merge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 52dffc0378..dfcd7bbb17 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -59,7 +59,7 @@ "correlogram": ["correlograms"], "snr": ["templates", "noise_levels"], "template_similarity": ["templates", "template_similarity"], - "knn": ["templates", "spike_locations", "spike_amplitudes"] + "knn": ["templates", "spike_locations", "spike_amplitudes"], } From 516acc9dda2c55bd5014f3ac4cee4350d3940607 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Mon, 28 Oct 2024 12:14:46 +0100 Subject: [PATCH 17/25] Names --- src/spikeinterface/curation/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/__init__.py b/src/spikeinterface/curation/__init__.py index 579e47a553..0302ffe5b7 100644 --- a/src/spikeinterface/curation/__init__.py +++ b/src/spikeinterface/curation/__init__.py @@ -3,7 +3,7 @@ from .remove_redundant import remove_redundant_units, find_redundant_units from .remove_duplicated_spikes import remove_duplicated_spikes from .remove_excess_spikes import remove_excess_spikes -from .auto_merge import compute_merge_unit_groups, auto_merge, get_potential_auto_merge +from .auto_merge import compute_merge_unit_groups, auto_merge_units, get_potential_auto_merge # manual sorting, From 78738ef679ebf8de5c4a16769aa879e51f68cf29 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 29 Oct 2024 14:41:51 +0100 Subject: [PATCH 18/25] WIP --- src/spikeinterface/curation/auto_merge.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index dfcd7bbb17..f7110f131d 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -258,7 +258,7 @@ def compute_merge_unit_groups( outs["remove_contaminated"] = to_remove # STEP : unit positions are estimated roughly with channel - elif step == "unit_locations" in steps: + elif step == "unit_locations": location_ext = sorting_analyzer.get_extension("unit_locations") unit_locations = location_ext.get_data()[:, :2] @@ -267,7 +267,7 @@ def compute_merge_unit_groups( outs["unit_distances"] = unit_distances # STEP : potential auto merge by correlogram - elif step == "correlogram" in steps: + elif step == "correlogram": correlograms_ext = sorting_analyzer.get_extension("correlograms") correlograms, bins = correlograms_ext.get_data() censor_ms = params["censor_correlograms_ms"] @@ -297,7 +297,7 @@ def compute_merge_unit_groups( outs["win_sizes"] = win_sizes # STEP : check if potential merge with CC also have template similarity - elif step == "template_similarity" in steps: + elif step == "template_similarity": template_similarity_ext = sorting_analyzer.get_extension("template_similarity") templates_similarity = template_similarity_ext.get_data() templates_diff = 1 - templates_similarity @@ -305,11 +305,11 @@ def compute_merge_unit_groups( outs["templates_diff"] = templates_diff # STEP : check the vicinity of the spikes - elif step == "knn" in steps: + elif step == "knn": pair_mask = get_pairs_via_nntree(sorting_analyzer, **params, pair_mask=pair_mask) # STEP : check how the rates overlap in times - elif step == "presence_distance" in steps: + elif step == "presence_distance": presence_distance_kwargs = params.copy() presence_distance_thresh = presence_distance_kwargs.pop("presence_distance_thresh") num_samples = [ @@ -322,7 +322,7 @@ def compute_merge_unit_groups( outs["presence_distances"] = presence_distances # STEP : check if the cross contamination is significant - elif step == "cross_contamination" in steps: + elif step == "cross_contamination": refractory = ( params["censored_period_ms"], params["refractory_period_ms"], @@ -334,7 +334,7 @@ def compute_merge_unit_groups( outs["cross_contaminations"] = CC, p_values # STEP : validate the potential merges with CC increase the contamination quality metrics - elif step == "quality_score" in steps: + elif step == "quality_score": pair_mask, pairs_decreased_score = check_improve_contaminations_score( sorting_analyzer, pair_mask, From 71e38e023ab660b28957c44d518477bfabf1782b Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 29 Oct 2024 15:17:13 +0100 Subject: [PATCH 19/25] Mix up with default params. Bringing back order --- src/spikeinterface/curation/auto_merge.py | 24 +++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index f7110f131d..12f7f9eac3 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -66,11 +66,13 @@ _default_step_params = { "num_spikes": {"min_spikes": 100}, "snr": {"min_snr": 2}, - "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, - "unit_locations": {"max_distance_um": 50}, + "remove_contaminated": {"contamination_thresh": 0.2, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3}, + "unit_locations": {"max_distance_um": 150}, "correlogram": { "corr_diff_thresh": 0.16, - "censor_correlograms_ms": 0.3, + "censor_correlograms_ms": 0.15, "sigma_smooth_ms": 0.6, "adaptative_window_thresh": 0.5, }, @@ -83,7 +85,9 @@ "refractory_period_ms": 1.0, "censored_period_ms": 0.3, }, - "quality_score": {"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + "quality_score": {"firing_contamination_balance": 1.5, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3}, } @@ -391,7 +395,7 @@ def get_potential_auto_merge( sigma_smooth_ms: float = 0.6, adaptative_window_thresh: float = 0.5, censor_correlograms_ms: float = 0.15, - firing_contamination_balance: float = 2.5, + firing_contamination_balance: float = 1.5, k_nn: int = 10, knn_kwargs: dict | None = None, presence_distance_kwargs: dict | None = None, @@ -479,7 +483,7 @@ def get_potential_auto_merge( Parameter to detect the window size in correlogram estimation. censor_correlograms_ms : float, default: 0.15 The period to censor on the auto and cross-correlograms. - firing_contamination_balance : float, default: 2.5 + firing_contamination_balance : float, default: 1.5 Parameter to control the balance between firing rate and contamination in computing unit "quality score". k_nn : int, default 5 The number of neighbors to consider for every spike in the recording. @@ -843,10 +847,10 @@ def check_improve_contaminations_score( f_new = compute_firing_rates(sorting_analyzer_new)[unit_id1] # old and new scores - k = firing_contamination_balance - score_1 = f_1 * (1 - (k + 1) * c_1) - score_2 = f_2 * (1 - (k + 1) * c_2) - score_new = f_new * (1 - (k + 1) * c_new) + k = 1 + firing_contamination_balance + score_1 = f_1 * (1 - k * c_1) + score_2 = f_2 * (1 - k * c_2) + score_new = f_new * (1 - k * c_new) if score_new < score_1 or score_new < score_2: # the score is not improved From 10d455cdf6db3038b59f484d7bc12d107cf8c578 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 29 Oct 2024 14:18:35 +0000 Subject: [PATCH 20/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/auto_merge.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 12f7f9eac3..085467fe9f 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -66,9 +66,7 @@ _default_step_params = { "num_spikes": {"min_spikes": 100}, "snr": {"min_snr": 2}, - "remove_contaminated": {"contamination_thresh": 0.2, - "refractory_period_ms": 1.0, - "censored_period_ms": 0.3}, + "remove_contaminated": {"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, "unit_locations": {"max_distance_um": 150}, "correlogram": { "corr_diff_thresh": 0.16, @@ -85,9 +83,7 @@ "refractory_period_ms": 1.0, "censored_period_ms": 0.3, }, - "quality_score": {"firing_contamination_balance": 1.5, - "refractory_period_ms": 1.0, - "censored_period_ms": 0.3}, + "quality_score": {"firing_contamination_balance": 1.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, } From 988723df9212bda349adc40aaa631ddb68f44123 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Tue, 29 Oct 2024 15:50:24 +0100 Subject: [PATCH 21/25] Triangular sup excluding self pairs --- src/spikeinterface/curation/auto_merge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 085467fe9f..994cc25d26 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -197,7 +197,7 @@ def compute_merge_unit_groups( sorting_analyzer = sorting_analyzer.copy() n = unit_ids.size - pair_mask = np.triu(np.arange(n)) > 0 + pair_mask = np.triu(np.arange(n), 1) > 0 outs = dict() for step in steps: From 95120e1391a041924879ab4236b1e431f892c020 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 29 Oct 2024 16:17:21 +0100 Subject: [PATCH 22/25] Update src/spikeinterface/curation/auto_merge.py Co-authored-by: Alessio Buccino --- src/spikeinterface/curation/auto_merge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 994cc25d26..8ac1ef0f95 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -185,7 +185,7 @@ def compute_merge_unit_groups( if preset is None and steps is None: raise ValueError("You need to specify a preset or steps for the auto-merge function") elif steps is not None: - # steps has presendance on presets + # steps has precedence on presets pass elif preset is not None: if preset not in _compute_merge_presets: From 21408543fc5589a06977997fb93567287b8cbbda Mon Sep 17 00:00:00 2001 From: Sebastien Date: Wed, 30 Oct 2024 10:52:50 +0100 Subject: [PATCH 23/25] Docs --- src/spikeinterface/curation/auto_merge.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 8ac1ef0f95..af4407b10e 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -121,6 +121,9 @@ def compute_merge_unit_groups( Q = f(1 - (k + 1)C) + IMPORTANT: internally, all computations are relying on extensions of the analyzer, that are computed + with default parameters if not present (i.e. correlograms, template_similarity, ...) If you want to + have a finer control on these values, please precompute the extensions before applying the auto_merge Parameters ---------- @@ -424,6 +427,9 @@ def get_potential_auto_merge( Q = f(1 - (k + 1)C) + IMPORTANT: internally, all computations are relying on extensions of the analyzer, that are computed + with default parameters if not present (i.e. correlograms, template_similarity, ...) If you want to + have a finer control on these values, please precompute the extensions before applying the auto_merge Parameters ---------- From df3d2dffda836b90a75ac8a68deb859a3b824b24 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 30 Oct 2024 09:54:41 +0000 Subject: [PATCH 24/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/curation/auto_merge.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index af4407b10e..eeeb5b2098 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -122,8 +122,8 @@ def compute_merge_unit_groups( Q = f(1 - (k + 1)C) IMPORTANT: internally, all computations are relying on extensions of the analyzer, that are computed - with default parameters if not present (i.e. correlograms, template_similarity, ...) If you want to - have a finer control on these values, please precompute the extensions before applying the auto_merge + with default parameters if not present (i.e. correlograms, template_similarity, ...) If you want to + have a finer control on these values, please precompute the extensions before applying the auto_merge Parameters ---------- @@ -428,8 +428,8 @@ def get_potential_auto_merge( Q = f(1 - (k + 1)C) IMPORTANT: internally, all computations are relying on extensions of the analyzer, that are computed - with default parameters if not present (i.e. correlograms, template_similarity, ...) If you want to - have a finer control on these values, please precompute the extensions before applying the auto_merge + with default parameters if not present (i.e. correlograms, template_similarity, ...) If you want to + have a finer control on these values, please precompute the extensions before applying the auto_merge Parameters ---------- From 22b90945c82e55c15e06c9c92ebd6b752889906a Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 30 Oct 2024 16:56:22 +0100 Subject: [PATCH 25/25] avoid copy when not necessary --- src/spikeinterface/curation/auto_merge.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index eeeb5b2098..4f4cff144e 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -195,7 +195,19 @@ def compute_merge_unit_groups( raise ValueError(f"preset must be one of {list(_compute_merge_presets.keys())}") steps = _compute_merge_presets[preset] - if force_copy: + # check at least one extension is needed + at_least_one_extension_to_compute = False + for step in steps: + assert step in _default_step_params, f"{step} is not a valid step" + if step in _required_extensions: + for ext in _required_extensions[step]: + if sorting_analyzer.has_extension(ext): + continue + if not compute_needed_extensions: + raise ValueError(f"{step} requires {ext} extension") + at_least_one_extension_to_compute = True + + if force_copy and at_least_one_extension_to_compute: # To avoid erasing the extensions of the user sorting_analyzer = sorting_analyzer.copy() @@ -205,14 +217,10 @@ def compute_merge_unit_groups( for step in steps: - assert step in _default_step_params, f"{step} is not a valid step" - if step in _required_extensions: for ext in _required_extensions[step]: if sorting_analyzer.has_extension(ext): continue - if not compute_needed_extensions: - raise ValueError(f"{step} requires {ext} extension") # special case for templates if ext == "templates" and not sorting_analyzer.has_extension("random_spikes"):