diff --git a/src/spikeinterface/curation/auto_merge.py b/src/spikeinterface/curation/auto_merge.py index 19336e5943..19386c79e6 100644 --- a/src/spikeinterface/curation/auto_merge.py +++ b/src/spikeinterface/curation/auto_merge.py @@ -25,35 +25,43 @@ _required_extensions = { "unit_locations": ["unit_locations"], "correlogram": ["correlograms"], + "min_snr": ["noise_levels", "templates"], "template_similarity": ["template_similarity"], "knn": ["spike_locations", "spike_amplitudes"], } +_templates_needed = ["unit_locations", "min_snr", "template_similarity", "spike_locations", "spike_amplitudes"] -def get_potential_auto_merge( + +def auto_merges( sorting_analyzer: SortingAnalyzer, preset: str | None = "similarity_correlograms", resolve_graph: bool = False, - min_spikes: int = 100, - min_snr: float = 2, - max_distance_um: float = 150.0, - corr_diff_thresh: float = 0.16, - template_diff_thresh: float = 0.25, - contamination_thresh: float = 0.2, - presence_distance_thresh: float = 100, - p_value: float = 0.2, - cc_thresh: float = 0.1, - censored_period_ms: float = 0.3, - refractory_period_ms: float = 1.0, - sigma_smooth_ms: float = 0.6, - adaptative_window_thresh: float = 0.5, - censor_correlograms_ms: float = 0.15, - firing_contamination_balance: float = 2.5, - k_nn: int = 10, - knn_kwargs: dict | None = None, - presence_distance_kwargs: dict | None = None, + num_spikes_kwargs={"min_spikes": 100}, + snr_kwargs={"min_snr": 2}, + remove_contaminated_kwargs={"contamination_thresh": 0.2, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + unit_locations_kwargs={"max_distance_um": 50}, + correlogram_kwargs={ + "corr_diff_thresh": 0.16, + "censor_correlograms_ms": 0.3, + "sigma_smooth_ms": 0.6, + "adaptative_window_thresh": 0.5, + }, + template_similarity_kwargs={"template_diff_thresh": 0.25}, + presence_distance_kwargs={"presence_distance_thresh": 100}, + knn_kwargs={"k_nn": 10}, + cross_contamination_kwargs={ + "cc_thresh": 0.1, + "p_value": 0.2, + "refractory_period_ms": 1.0, + "censored_period_ms": 0.3, + }, + quality_score_kwargs={"firing_contamination_balance": 2.5, "refractory_period_ms": 1.0, "censored_period_ms": 0.3}, + compute_needed_extensions: bool = True, extra_outputs: bool = False, steps: list[str] | None = None, + force_copy: bool = True, + **job_kwargs, ) -> list[tuple[int | str, int | str]] | Tuple[tuple[int | str, int | str], dict]: """ Algorithm to find and check potential merges between units. @@ -98,56 +106,21 @@ def get_potential_auto_merge( * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN. | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations", | "knn", "quality_score" - If `preset` is None, you can specify the steps manually with the `steps` parameter. resolve_graph : bool, default: False If True, the function resolves the potential unit pairs to be merged into multiple-unit merges. - min_spikes : int, default: 100 - Minimum number of spikes for each unit to consider a potential merge. - Enough spikes are needed to estimate the correlogram - min_snr : float, default 2 - Minimum Signal to Noise ratio for templates to be considered while merging - max_distance_um : float, default: 150 - Maximum distance between units for considering a merge - corr_diff_thresh : float, default: 0.16 - The threshold on the "correlogram distance metric" for considering a merge. 
- It needs to be between 0 and 1 - template_diff_thresh : float, default: 0.25 - The threshold on the "template distance metric" for considering a merge. - It needs to be between 0 and 1 - contamination_thresh : float, default: 0.2 - Threshold for not taking in account a unit when it is too contaminated. - presence_distance_thresh : float, default: 100 - Parameter to control how present two units should be simultaneously. - p_value : float, default: 0.2 - The p-value threshold for the cross-contamination test. - cc_thresh : float, default: 0.1 - The threshold on the cross-contamination for considering a merge. - censored_period_ms : float, default: 0.3 - Used to compute the refractory period violations aka "contamination". - refractory_period_ms : float, default: 1 - Used to compute the refractory period violations aka "contamination". - sigma_smooth_ms : float, default: 0.6 - Parameters to smooth the correlogram estimation. - adaptative_window_thresh : float, default: 0.5 - Parameter to detect the window size in correlogram estimation. - censor_correlograms_ms : float, default: 0.15 - The period to censor on the auto and cross-correlograms. - firing_contamination_balance : float, default: 2.5 - Parameter to control the balance between firing rate and contamination in computing unit "quality score". - k_nn : int, default 5 - The number of neighbors to consider for every spike in the recording. - knn_kwargs : dict, default None - The dict of extra params to be passed to knn. + compute_needed_extensions : bool, default : True + Should we force the computation of needed extensions? extra_outputs : bool, default: False If True, an additional dictionary (`outs`) with processed data is returned. steps : None or list of str, default: None Which steps to run, if no preset is used. Pontential steps : "num_spikes", "snr", "remove_contaminated", "unit_locations", "correlogram", "template_similarity", "presence_distance", "cross_contamination", "knn", "quality_score" - Please check steps explanations above! - presence_distance_kwargs : None|dict, default: None - A dictionary of kwargs to be passed to compute_presence_distance(). + Please check steps explanations above!$ + force_copy : boolean, default: True + When new extensions are computed, the default is to make a copy of the analyzer, to avoid overwriting + already computed extensions. 
False if you want to overwrite Returns ------- @@ -230,12 +203,24 @@ def get_potential_auto_merge( "knn", "quality_score", ] + if force_copy and compute_needed_extensions: + # To avoid erasing the extensions of the user + sorting_analyzer = sorting_analyzer.copy() for step in steps: if step in _required_extensions: for ext in _required_extensions[step]: - if not sorting_analyzer.has_extension(ext): - raise ValueError(f"{step} requires {ext} extension") + if compute_needed_extensions: + if step in _templates_needed: + template_ext = sorting_analyzer.get_extension("templates") + if template_ext is None: + sorting_analyzer.compute(["random_spikes", "templates"], **job_kwargs) + params = eval(f"{step}_kwargs") + params = params.get(ext, dict()) + sorting_analyzer.compute(ext, **params, **job_kwargs) + else: + if not sorting_analyzer.has_extension(ext): + raise ValueError(f"{step} requires {ext} extension") n = unit_ids.size pair_mask = np.triu(np.arange(n)) > 0 @@ -248,33 +233,38 @@ def get_potential_auto_merge( # STEP : remove units with too few spikes if step == "num_spikes": num_spikes = sorting.count_num_spikes_per_unit(outputs="array") - to_remove = num_spikes < min_spikes + to_remove = num_spikes < num_spikes_kwargs["min_spikes"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False + outs["num_spikes"] = to_remove # STEP : remove units with too small SNR elif step == "snr": qm_ext = sorting_analyzer.get_extension("quality_metrics") if qm_ext is None: - sorting_analyzer.compute("noise_levels") - sorting_analyzer.compute("quality_metrics", metric_names=["snr"]) + sorting_analyzer.compute(["noise_levels"], **job_kwargs) + sorting_analyzer.compute("quality_metrics", metric_names=["snr"], **job_kwargs) qm_ext = sorting_analyzer.get_extension("quality_metrics") snrs = qm_ext.get_data()["snr"].values - to_remove = snrs < min_snr + to_remove = snrs < snr_kwargs["min_snr"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False + outs["snr"] = to_remove # STEP : remove contaminated auto corr elif step == "remove_contaminated": contaminations, nb_violations = compute_refrac_period_violations( - sorting_analyzer, refractory_period_ms=refractory_period_ms, censored_period_ms=censored_period_ms + sorting_analyzer, + refractory_period_ms=remove_contaminated_kwargs["refractory_period_ms"], + censored_period_ms=remove_contaminated_kwargs["censored_period_ms"], ) nb_violations = np.array(list(nb_violations.values())) contaminations = np.array(list(contaminations.values())) - to_remove = contaminations > contamination_thresh + to_remove = contaminations > remove_contaminated_kwargs["contamination_thresh"] pair_mask[to_remove, :] = False pair_mask[:, to_remove] = False + outs["remove_contaminated"] = to_remove # STEP : unit positions are estimated roughly with channel elif step == "unit_locations" in steps: @@ -282,21 +272,23 @@ def get_potential_auto_merge( unit_locations = location_ext.get_data()[:, :2] unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean") - pair_mask = pair_mask & (unit_distances <= max_distance_um) + pair_mask = pair_mask & (unit_distances <= unit_locations_kwargs["max_distance_um"]) outs["unit_distances"] = unit_distances # STEP : potential auto merge by correlogram elif step == "correlogram" in steps: correlograms_ext = sorting_analyzer.get_extension("correlograms") correlograms, bins = correlograms_ext.get_data() - mask = (bins[:-1] >= -censor_correlograms_ms) & (bins[:-1] < censor_correlograms_ms) + censor_ms = 
correlogram_kwargs["censor_correlograms_ms"] + sigma_smooth_ms = correlogram_kwargs["sigma_smooth_ms"] + mask = (bins[:-1] >= -censor_ms) & (bins[:-1] < censor_ms) correlograms[:, :, mask] = 0 correlograms_smoothed = smooth_correlogram(correlograms, bins, sigma_smooth_ms=sigma_smooth_ms) # find correlogram window for each units win_sizes = np.zeros(n, dtype=int) for unit_ind in range(n): auto_corr = correlograms_smoothed[unit_ind, unit_ind, :] - thresh = np.max(auto_corr) * adaptative_window_thresh + thresh = np.max(auto_corr) * correlogram_kwargs["adaptative_window_thresh"] win_size = get_unit_adaptive_window(auto_corr, thresh) win_sizes[unit_ind] = win_size correlogram_diff = compute_correlogram_diff( @@ -306,7 +298,7 @@ def get_potential_auto_merge( pair_mask=pair_mask, ) # print(correlogram_diff) - pair_mask = pair_mask & (correlogram_diff < corr_diff_thresh) + pair_mask = pair_mask & (correlogram_diff < correlogram_kwargs["corr_diff_thresh"]) outs["correlograms"] = correlograms outs["bins"] = bins outs["correlograms_smoothed"] = correlograms_smoothed @@ -318,18 +310,17 @@ def get_potential_auto_merge( template_similarity_ext = sorting_analyzer.get_extension("template_similarity") templates_similarity = template_similarity_ext.get_data() templates_diff = 1 - templates_similarity - pair_mask = pair_mask & (templates_diff < template_diff_thresh) + pair_mask = pair_mask & (templates_diff < template_similarity_kwargs["template_diff_thresh"]) outs["templates_diff"] = templates_diff # STEP : check the vicinity of the spikes elif step == "knn" in steps: - if knn_kwargs is None: - knn_kwargs = dict() - pair_mask = get_pairs_via_nntree(sorting_analyzer, k_nn, pair_mask, **knn_kwargs) + pair_mask = get_pairs_via_nntree(sorting_analyzer, **knn_kwargs, pair_mask=pair_mask) # STEP : check how the rates overlap in times elif step == "presence_distance" in steps: - presence_distance_kwargs = presence_distance_kwargs or dict() + presence_distance_kwargs = presence_distance_kwargs.copy() + presence_distance_thresh = presence_distance_kwargs.pop("presence_distance_thresh") num_samples = [ sorting_analyzer.get_num_samples(segment_index) for segment_index in range(sorting.get_num_segments()) ] @@ -341,11 +332,14 @@ def get_potential_auto_merge( # STEP : check if the cross contamination is significant elif step == "cross_contamination" in steps: - refractory = (censored_period_ms, refractory_period_ms) + refractory = ( + cross_contamination_kwargs["censored_period_ms"], + cross_contamination_kwargs["refractory_period_ms"], + ) CC, p_values = compute_cross_contaminations( - sorting_analyzer, pair_mask, cc_thresh, refractory, contaminations + sorting_analyzer, pair_mask, cross_contamination_kwargs["cc_thresh"], refractory, contaminations ) - pair_mask = pair_mask & (p_values > p_value) + pair_mask = pair_mask & (p_values > cross_contamination_kwargs["p_value"]) outs["cross_contaminations"] = CC, p_values # STEP : validate the potential merges with CC increase the contamination quality metrics @@ -354,9 +348,9 @@ def get_potential_auto_merge( sorting_analyzer, pair_mask, contaminations, - firing_contamination_balance, - refractory_period_ms, - censored_period_ms, + quality_score_kwargs["firing_contamination_balance"], + quality_score_kwargs["refractory_period_ms"], + quality_score_kwargs["censored_period_ms"], ) outs["pairs_decreased_score"] = pairs_decreased_score @@ -364,9 +358,6 @@ def get_potential_auto_merge( ind1, ind2 = np.nonzero(pair_mask) potential_merges = list(zip(unit_ids[ind1], unit_ids[ind2])) 
- # some methods return identities ie (1,1) which we can cleanup first. - potential_merges = [(ids[0], ids[1]) for ids in potential_merges if ids[0] != ids[1]] - if resolve_graph: potential_merges = resolve_merging_graph(sorting, potential_merges) @@ -376,6 +367,284 @@ def get_potential_auto_merge( return potential_merges +def get_potential_auto_merge( + sorting_analyzer: SortingAnalyzer, + preset: str | None = "similarity_correlograms", + resolve_graph: bool = False, + min_spikes: int = 100, + min_snr: float = 2, + max_distance_um: float = 150.0, + corr_diff_thresh: float = 0.16, + template_diff_thresh: float = 0.25, + contamination_thresh: float = 0.2, + presence_distance_thresh: float = 100, + p_value: float = 0.2, + cc_thresh: float = 0.1, + censored_period_ms: float = 0.3, + refractory_period_ms: float = 1.0, + sigma_smooth_ms: float = 0.6, + adaptative_window_thresh: float = 0.5, + censor_correlograms_ms: float = 0.15, + firing_contamination_balance: float = 2.5, + k_nn: int = 10, + knn_kwargs: dict | None = None, + presence_distance_kwargs: dict | None = None, + extra_outputs: bool = False, + steps: list[str] | None = None, +) -> list[tuple[int | str, int | str]] | Tuple[tuple[int | str, int | str], dict]: + """ + Algorithm to find and check potential merges between units. + + The merges are proposed based on a series of steps with different criteria: + + * "num_spikes": enough spikes are found in each unit for computing the correlogram (`min_spikes`) + * "snr": the SNR of the units is above a threshold (`min_snr`) + * "remove_contaminated": each unit is not contaminated (by checking auto-correlogram - `contamination_thresh`) + * "unit_locations": estimated unit locations are close enough (`max_distance_um`) + * "correlogram": the cross-correlograms of the two units are similar to each auto-corrleogram (`corr_diff_thresh`) + * "template_similarity": the templates of the two units are similar (`template_diff_thresh`) + * "presence_distance": the presence of the units is complementary in time (`presence_distance_thresh`) + * "cross_contamination": the cross-contamination is not significant (`cc_thresh` and `p_value`) + * "knn": the two units are close in the feature space + * "quality_score": the unit "quality score" is increased after the merge + + The "quality score" factors in the increase in firing rate (**f**) due to the merge and a possible increase in + contamination (**C**), wheighted by a factor **k** (`firing_contamination_balance`). + + .. math:: + + Q = f(1 - (k + 1)C) + + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer + preset : "similarity_correlograms" | "x_contaminations" | "temporal_splits" | "feature_neighbors" | None, default: "similarity_correlograms" + The preset to use for the auto-merge. Presets combine different steps into a recipe and focus on: + + * | "similarity_correlograms": mainly focused on template similarity and correlograms. + | It uses the following steps: "num_spikes", "remove_contaminated", "unit_locations", + | "template_similarity", "correlogram", "quality_score" + * | "x_contaminations": similar to "similarity_correlograms", but checks for cross-contamination instead of correlograms. + | It uses the following steps: "num_spikes", "remove_contaminated", "unit_locations", + | "template_similarity", "cross_contamination", "quality_score" + * | "temporal_splits": focused on finding temporal splits using presence distance. 
+ | It uses the following steps: "num_spikes", "remove_contaminated", "unit_locations", + | "template_similarity", "presence_distance", "quality_score" + * | "feature_neighbors": focused on finding unit pairs whose spikes are close in the feature space using kNN. + | It uses the following steps: "num_spikes", "snr", "remove_contaminated", "unit_locations", + | "knn", "quality_score" + + If `preset` is None, you can specify the steps manually with the `steps` parameter. + resolve_graph : bool, default: False + If True, the function resolves the potential unit pairs to be merged into multiple-unit merges. + min_spikes : int, default: 100 + Minimum number of spikes for each unit to consider a potential merge. + Enough spikes are needed to estimate the correlogram + min_snr : float, default 2 + Minimum Signal to Noise ratio for templates to be considered while merging + max_distance_um : float, default: 150 + Maximum distance between units for considering a merge + corr_diff_thresh : float, default: 0.16 + The threshold on the "correlogram distance metric" for considering a merge. + It needs to be between 0 and 1 + template_diff_thresh : float, default: 0.25 + The threshold on the "template distance metric" for considering a merge. + It needs to be between 0 and 1 + contamination_thresh : float, default: 0.2 + Threshold for not taking in account a unit when it is too contaminated. + presence_distance_thresh : float, default: 100 + Parameter to control how present two units should be simultaneously. + p_value : float, default: 0.2 + The p-value threshold for the cross-contamination test. + cc_thresh : float, default: 0.1 + The threshold on the cross-contamination for considering a merge. + censored_period_ms : float, default: 0.3 + Used to compute the refractory period violations aka "contamination". + refractory_period_ms : float, default: 1 + Used to compute the refractory period violations aka "contamination". + sigma_smooth_ms : float, default: 0.6 + Parameters to smooth the correlogram estimation. + adaptative_window_thresh : float, default: 0.5 + Parameter to detect the window size in correlogram estimation. + censor_correlograms_ms : float, default: 0.15 + The period to censor on the auto and cross-correlograms. + firing_contamination_balance : float, default: 2.5 + Parameter to control the balance between firing rate and contamination in computing unit "quality score". + k_nn : int, default 5 + The number of neighbors to consider for every spike in the recording. + knn_kwargs : dict, default None + The dict of extra params to be passed to knn. + extra_outputs : bool, default: False + If True, an additional dictionary (`outs`) with processed data is returned. + steps : None or list of str, default: None + Which steps to run, if no preset is used. + Pontential steps : "num_spikes", "snr", "remove_contaminated", "unit_locations", "correlogram", + "template_similarity", "presence_distance", "cross_contamination", "knn", "quality_score" + Please check steps explanations above! + presence_distance_kwargs : None|dict, default: None + A dictionary of kwargs to be passed to compute_presence_distance(). + + Returns + ------- + potential_merges: + A list of tuples of 2 elements (if `resolve_graph`if false) or 2+ elements (if `resolve_graph` is true). + List of pairs that could be merged. + outs: + Returned only when extra_outputs=True + A dictionary that contains data for debugging and plotting. 
+ + References + ---------- + This function is inspired and built upon similar functions from Lussac [Llobet]_, + done by Aurelien Wyngaard and Victor Llobet. + https://github.com/BarbourLab/lussac/blob/v1.0.0/postprocessing/merge_units.py + """ + presence_distance_kwargs = presence_distance_kwargs or dict() + knn_kwargs = knn_kwargs or dict() + return auto_merges( + sorting_analyzer, + preset, + resolve_graph, + num_spikes_kwargs={"min_spikes": min_spikes}, + snr_kwargs={"min_snr": min_snr}, + remove_contaminated_kwargs={ + "contamination_thresh": contamination_thresh, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + unit_locations_kwargs={"max_distance_um": max_distance_um}, + correlogram_kwargs={ + "corr_diff_thresh": corr_diff_thresh, + "censor_correlograms_ms": censor_correlograms_ms, + "sigma_smooth_ms": sigma_smooth_ms, + "adaptative_window_thresh": adaptative_window_thresh, + }, + template_similarity_kwargs={"template_diff_thresh": template_diff_thresh}, + presence_distance_kwargs={"presence_distance_thresh": presence_distance_thresh, **presence_distance_kwargs}, + knn_kwargs={"k_nn": k_nn, **knn_kwargs}, + cross_contamination_kwargs={ + "cc_thresh": cc_thresh, + "p_value": p_value, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + quality_score_kwargs={ + "firing_contamination_balance": firing_contamination_balance, + "refractory_period_ms": refractory_period_ms, + "censored_period_ms": censored_period_ms, + }, + compute_needed_extensions=False, + extra_outputs=extra_outputs, + steps=steps, + ) + + +def iterative_merges( + sorting_analyzer, + presets, + params=None, + merging_kwargs={"merging_mode": "soft", "sparsity_overlap": 0.5, "censor_ms": 3}, + compute_needed_extensions=True, + verbose=False, + extra_outputs=False, + **job_kwargs, +): + """ + Wrapper to conveniently be able to launch several presets for auto_merges in a row, as a list. Merges + are applied sequentially, one preset at a time, and extensions are not recomputed thanks to the merging units + + Parameters + ---------- + sorting_analyzer : SortingAnalyzer + The SortingAnalyzer + presets : list of presets for the auto_merges() functions. Presets can be in + "similarity_correlograms" | "x_contaminations" | "temporal_splits" | "feature_neighbors" + (see auto_merge for more details) + params : list of params that should be given to all presets. Should have the same length as presets + merging_kwargs : dict, the paramaters that should be used while merging units after each preset + compute_needed_extensions : bool, default True + During the preset, boolean to specify is extensions needed by the steps should be recomputed, + or used as they are if already present in the sorting_analyzer + extra_outputs : bool, default: False + If True, additional list of merges applied at every preset, and dictionary (`outs`) with processed data are returned. + + Returns + ------- + sorting_analyzer: + The new sorting analyzer where all the merges from all the presets have been applied + + merges, outs: + Returned only when extra_outputs=True + A list with all the merges performed at every steps, and dictionaries that contains data for debugging and plotting. 
+ """ + + if params is None: + params = [{}] * len(presets) + + assert len(presets) == len(params) + n_units = max(sorting_analyzer.unit_ids) + 1 + + if compute_needed_extensions: + sorting_analyzer = sorting_analyzer.copy() + + if extra_outputs: + all_merges = [] + all_outs = [] + + for i in range(len(presets)): + if extra_outputs: + merges, outs = auto_merges( + sorting_analyzer, + preset=presets[i], + resolve_graph=True, + compute_needed_extensions=bool(compute_needed_extensions * (i == 0)), + extra_outputs=extra_outputs, + force_copy=False, + **params[i], + **job_kwargs, + ) + + all_merges += [merges] + all_outs += [outs] + else: + merges = auto_merges( + sorting_analyzer, + preset=presets[i], + resolve_graph=True, + compute_needed_extensions=compute_needed_extensions * (i == 0), + extra_outputs=extra_outputs, + force_copy=False, + **params[i], + **job_kwargs, + ) + + if verbose: + n_merges = int(np.sum([len(i) for i in merges])) + print(f"{n_merges} merges have been made during pass", presets[i]) + + sorting_analyzer = sorting_analyzer.merge_units(merges, **merging_kwargs, **job_kwargs) + + if extra_outputs: + + final_merges = {} + for merge in all_merges: + for count, m in enumerate(merge): + new_list = m + for k in m: + if k in final_merges: + new_list.remove(k) + new_list += final_merges[k] + final_merges[count + n_units] = new_list + if len(final_merges.keys()) > 0: + n_units = max(final_merges.keys()) + 1 + + return sorting_analyzer, list(final_merges.values()), all_outs + else: + return sorting_analyzer + + def get_pairs_via_nntree(sorting_analyzer, k_nn=5, pair_mask=None, **knn_kwargs): sorting = sorting_analyzer.sorting diff --git a/src/spikeinterface/curation/tests/test_auto_merge.py b/src/spikeinterface/curation/tests/test_auto_merge.py index 33fd06d27a..2dffd685a5 100644 --- a/src/spikeinterface/curation/tests/test_auto_merge.py +++ b/src/spikeinterface/curation/tests/test_auto_merge.py @@ -14,7 +14,6 @@ ) def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): - print(sorting_analyzer_for_curation) sorting = sorting_analyzer_for_curation.sorting recording = sorting_analyzer_for_curation.recording num_unit_splited = 1 @@ -72,7 +71,7 @@ def test_get_auto_merge_list(sorting_analyzer_for_curation, preset): with pytest.raises(ValueError): potential_merges = get_potential_auto_merge(sorting_analyzer, preset=preset) potential_merges = get_potential_auto_merge( - sorting_analyzer, preset=preset, steps=["min_spikes", "min_snr", "remove_contaminated", "unit_positions"] + sorting_analyzer, preset=preset, steps=["num_spikes", "snr", "remove_contaminated", "unit_locations"] ) # DEBUG diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index 9d28340352..ea576a5842 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -598,3 +598,104 @@ def get_traces( def get_num_samples(self) -> int: return self.num_samples + + +def split_sorting_by_times( + sorting_analyzer, splitting_probability=0.5, partial_split_prob=0.95, unit_ids=None, min_snr=None, seed=None +): + sa = sorting_analyzer + sorting = sa.sorting + rng = np.random.RandomState(seed) + + sorting_split = sorting.select_units(sorting.unit_ids) + split_units = [] + original_units = [] + nb_splits = int(splitting_probability * len(sorting.unit_ids)) + if unit_ids is None: + select_from = sorting.unit_ids + if min_snr is not None: + if sa.get_extension("noise_levels") is None: + sa.compute("noise_levels") + if 
sa.get_extension("quality_metrics") is None: + sa.compute("quality_metrics", metric_names=["snr"]) + + snr = sa.get_extension("quality_metrics").get_data()["snr"].values + select_from = select_from[snr > min_snr] + + to_split_ids = rng.choice(select_from, nb_splits, replace=False) + else: + to_split_ids = unit_ids + + import spikeinterface.curation as scur + + for unit in to_split_ids: + num_spikes = len(sorting_split.get_unit_spike_train(unit)) + indices = np.zeros(num_spikes, dtype=int) + indices[: num_spikes // 2] = (rng.rand(num_spikes // 2) < partial_split_prob).astype(int) + indices[num_spikes // 2 :] = (rng.rand(num_spikes - num_spikes // 2) < 1 - partial_split_prob).astype(int) + sorting_split = scur.split_unit_sorting( + sorting_split, split_unit_id=unit, indices_list=indices, properties_policy="remove" + ) + split_units.append(sorting_split.unit_ids[-2:]) + original_units.append(unit) + return sorting_split, split_units + + +def split_sorting_by_amplitudes(sorting_analyzer, splitting_probability=0.5, unit_ids=None, min_snr=None, seed=None): + """ + Fonction used to split a sorting based on the amplitudes of the units. This + might be used for benchmarking meta merging step (see components) + """ + + sa = sorting_analyzer + if sa.get_extension("spike_amplitudes") is None: + sa.compute("spike_amplitudes") + + rng = np.random.RandomState(seed) + + from spikeinterface.core.numpyextractors import NumpySorting + from spikeinterface.core.template_tools import get_template_extremum_channel + + extremum_channel_inds = get_template_extremum_channel(sa, outputs="index") + spikes = sa.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) + new_spikes = spikes.copy() + amplitudes = sa.get_extension("spike_amplitudes").get_data() + nb_splits = int(splitting_probability * len(sa.sorting.unit_ids)) + + if unit_ids is None: + select_from = sa.sorting.unit_ids + if min_snr is not None: + if sa.get_extension("noise_levels") is None: + sa.compute("noise_levels") + if sa.get_extension("quality_metrics") is None: + sa.compute("quality_metrics", metric_names=["snr"]) + + snr = sa.get_extension("quality_metrics").get_data()["snr"].values + select_from = select_from[snr > min_snr] + to_split_ids = rng.choice(select_from, nb_splits, replace=False) + else: + to_split_ids = unit_ids + + max_index = np.max(spikes["unit_index"]) + new_unit_ids = list(sa.sorting.unit_ids.copy()) + splitted_pairs = [] + for unit_id in to_split_ids: + ind_mask = spikes["unit_index"] == sa.sorting.id_to_index(unit_id) + + m = amplitudes[ind_mask].mean() + s = amplitudes[ind_mask].std() + thresh = m + 0.2 * s + + amplitude_mask = amplitudes > thresh + mask = ind_mask & amplitude_mask + new_spikes["unit_index"][mask] = max_index + 1 + + amplitude_mask = (amplitudes > m) * (amplitudes < thresh) + mask = ind_mask & amplitude_mask + new_spikes["unit_index"][mask] = (max_index + 1) * rng.rand(np.sum(mask)) > 0.5 + max_index += 1 + new_unit_ids += [max(new_unit_ids) + 1] + splitted_pairs += [(unit_id, new_unit_ids[-1])] + + new_sorting = NumpySorting(new_spikes, sampling_frequency=sa.sampling_frequency, unit_ids=new_unit_ids) + return new_sorting, splitted_pairs diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index c3b3099535..a964514acf 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -14,10 +14,6 @@ from spikeinterface.sortingcomponents.tools import 
cache_preprocessing from spikeinterface.core.basesorting import minimum_spike_dtype from spikeinterface.core.sparsity import compute_sparsity -from spikeinterface.core.sortinganalyzer import create_sorting_analyzer -from spikeinterface.curation.auto_merge import get_potential_auto_merge -from spikeinterface.core.analyzer_extension_core import ComputeTemplates -from spikeinterface.core.sparsity import ChannelSparsity class Spykingcircus2Sorter(ComponentsBasedSorter): @@ -36,15 +32,8 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): "seed": 42, }, "apply_motion_correction": True, - "motion_correction": {"preset": "nonrigid_fast_and_accurate"}, - "merging": { - "similarity_kwargs": {"method": "cosine", "support": "union", "max_lag_ms": 0.2}, - "correlograms_kwargs": {}, - "auto_merge": { - "min_spikes": 10, - "corr_diff_thresh": 0.25, - }, - }, + "motion_correction": {"preset": "dredge_fast"}, + "merging": {"method": "lussac"}, "clustering": {"legacy": True}, "matching": {"method": "wobble"}, "apply_preprocessing": True, @@ -105,13 +94,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks from spikeinterface.sortingcomponents.matching import find_spikes_from_templates + from spikeinterface.sortingcomponents.merging import merge_spikes from spikeinterface.sortingcomponents.tools import remove_empty_templates from spikeinterface.sortingcomponents.tools import get_prototype_spike, check_probe_for_drift_correction - from spikeinterface.sortingcomponents.tools import get_prototype_spike + from spikeinterface.core.globals import set_global_job_kwargs, get_global_job_kwargs - job_kwargs = params["job_kwargs"] + job_kwargs_before = get_global_job_kwargs().copy() + job_kwargs = params["job_kwargs"].copy() job_kwargs = fix_job_kwargs(job_kwargs) - job_kwargs.update({"progress_bar": verbose}) + set_global_job_kwargs(**job_kwargs) recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) @@ -142,7 +133,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if verbose: print("Motion correction activated (probe geometry compatible)") motion_folder = sorter_output_folder / "motion" - params["motion_correction"].update({"folder": motion_folder}) + params["motion_correction"].update({"folder": motion_folder, "overwrite": True}) recording_f = correct_motion(recording_f, **params["motion_correction"]) else: motion_folder = None @@ -218,6 +209,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): clustering_params["ms_before"] = exclude_sweep_ms clustering_params["ms_after"] = exclude_sweep_ms clustering_params["tmp_folder"] = sorter_output_folder / "clustering" + clustering_params["verbose"] = verbose legacy = clustering_params.get("legacy", True) @@ -310,17 +302,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): shutil.rmtree(sorting_folder) merging_params = params["merging"].copy() + merging_method = merging_params.get("method", None) - if len(merging_params) > 0: - if params["motion_correction"] and motion_folder is not None: - from spikeinterface.preprocessing.motion import load_motion_info - - motion_info = load_motion_info(motion_folder) - motion = motion_info["motion"] - max_motion = max( - np.max(np.abs(motion.displacement[seg_index])) for seg_index in range(len(motion.displacement)) - ) - merging_params["max_distance_um"] = max(50, 2 * 
max_motion) + if merging_method is not None: # peak_sign = params['detection'].get('peak_sign', 'neg') # best_amplitudes = get_template_extremum_amplitude(templates, peak_sign=peak_sign) @@ -336,7 +320,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): sorting.save(folder=curation_folder) # np.save(fitting_folder / "amplitudes", guessed_amplitudes) - sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params) + sorting = merge_spikes(recording_w, sorting, templates=templates, verbose=verbose, **merging_params) if verbose: print(f"Final merging, keeping {len(sorting.unit_ids)} units") @@ -353,44 +337,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): shutil.rmtree(folder_to_delete) sorting = sorting.save(folder=sorting_folder) + set_global_job_kwargs(**job_kwargs_before) return sorting - - -def create_sorting_analyzer_with_templates(sorting, recording, templates, remove_empty=True): - sparsity = templates.sparsity - templates_array = templates.get_dense_templates().copy() - - if remove_empty: - non_empty_unit_ids = sorting.get_non_empty_unit_ids() - non_empty_sorting = sorting.remove_empty_units() - non_empty_unit_indices = sorting.ids_to_indices(non_empty_unit_ids) - templates_array = templates_array[non_empty_unit_indices] - sparsity_mask = sparsity.mask[non_empty_unit_indices, :] - sparsity = ChannelSparsity(sparsity_mask, non_empty_unit_ids, sparsity.channel_ids) - else: - non_empty_sorting = sorting - - sa = create_sorting_analyzer(non_empty_sorting, recording, format="memory", sparsity=sparsity) - sa.extensions["templates"] = ComputeTemplates(sa) - sa.extensions["templates"].params = {"ms_before": templates.ms_before, "ms_after": templates.ms_after} - sa.extensions["templates"].data["average"] = templates_array - return sa - - -def final_cleaning_circus(recording, sorting, templates, **merging_kwargs): - - from spikeinterface.core.sorting_tools import apply_merges_to_sorting - - sa = create_sorting_analyzer_with_templates(sorting, recording, templates) - - sa.compute("unit_locations", method="monopolar_triangulation") - similarity_kwargs = merging_kwargs.pop("similarity_kwargs", {}) - sa.compute("template_similarity", **similarity_kwargs) - correlograms_kwargs = merging_kwargs.pop("correlograms_kwargs", {}) - sa.compute("correlograms", **correlograms_kwargs) - auto_merge_kwargs = merging_kwargs.pop("auto_merge", {}) - merges = get_potential_auto_merge(sa, resolve_graph=True, **auto_merge_kwargs) - sorting = apply_merges_to_sorting(sa.sorting, merges) - - return sorting diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_merging.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_merging.py new file mode 100644 index 0000000000..8c472dc306 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_merging.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +from spikeinterface.sortingcomponents.merging import merge_spikes +from spikeinterface.comparison import compare_sorter_to_ground_truth +from spikeinterface.widgets import ( + plot_agreement_matrix, + plot_unit_templates, + plot_amplitudes, + plot_crosscorrelograms, +) + +import numpy as np +from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy + + +class MergingBenchmark(Benchmark): + + def __init__(self, recording, splitted_sorting, params, gt_sorting, splitted_cells=None): + self.recording = recording + self.splitted_sorting = splitted_sorting + self.method = 
params["method"] + self.gt_sorting = gt_sorting + self.splitted_cells = splitted_cells + self.method_kwargs = params["method_kwargs"] + self.result = {} + + def run(self, **job_kwargs): + self.result["sorting"], self.result["merges"], self.result["outs"] = merge_spikes( + self.recording, + self.splitted_sorting, + method=self.method, + verbose=True, + extra_outputs=True, + method_kwargs=self.method_kwargs, + ) + + def compute_result(self, **result_params): + sorting = self.result["sorting"] + comp = compare_sorter_to_ground_truth(self.gt_sorting, sorting, exhaustive_gt=True) + self.result["gt_comparison"] = comp + + _run_key_saved = [("sorting", "sorting"), ("merges", "pickle"), ("outs", "pickle")] + _result_key_saved = [("gt_comparison", "pickle")] + + +class MergingStudy(BenchmarkStudy): + + benchmark_class = MergingBenchmark + + def create_benchmark(self, key): + dataset_key = self.cases[key]["dataset"] + recording, gt_sorting = self.datasets[dataset_key] + params = self.cases[key]["params"] + init_kwargs = self.cases[key]["init_kwargs"] + benchmark = MergingBenchmark(recording, gt_sorting, params, **init_kwargs) + return benchmark + + def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): + import pandas as pd + + if case_keys is None: + case_keys = list(self.cases.keys()) + + if isinstance(case_keys[0], str): + index = pd.Index(case_keys, name=self.levels) + else: + index = pd.MultiIndex.from_tuples(case_keys, names=self.levels) + + columns = ["num_gt", "num_sorter", "num_well_detected"] + comp = self.get_result(case_keys[0])["gt_comparison"] + if comp.exhaustive_gt: + columns.extend(["num_false_positive", "num_redundant", "num_overmerged", "num_bad"]) + count_units = pd.DataFrame(index=index, columns=columns, dtype=int) + + for key in case_keys: + comp = self.get_result(key)["gt_comparison"] + assert comp is not None, "You need to do study.run_comparisons() first" + + gt_sorting = comp.sorting1 + sorting = comp.sorting2 + + count_units.loc[key, "num_gt"] = len(gt_sorting.get_unit_ids()) + count_units.loc[key, "num_sorter"] = len(sorting.get_unit_ids()) + count_units.loc[key, "num_well_detected"] = comp.count_well_detected_units(well_detected_score) + + if comp.exhaustive_gt: + count_units.loc[key, "num_redundant"] = comp.count_redundant_units(redundant_score) + count_units.loc[key, "num_overmerged"] = comp.count_overmerged_units(overmerged_score) + count_units.loc[key, "num_false_positive"] = comp.count_false_positive_units(redundant_score) + count_units.loc[key, "num_bad"] = comp.count_bad_units() + + return count_units + + def plot_agreements(self, case_keys=None, figsize=(15, 15)): + import matplotlib.pyplot as plt + + if case_keys is None: + case_keys = list(self.cases.keys()) + + fig, axs = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) + for count, key in enumerate(case_keys): + ax = axs[0, count] + ax.set_title(self.cases[key]["label"]) + plot_agreement_matrix(self.get_result(key)["gt_comparison"], ax=ax) + + return fig + + def plot_unit_counts(self, case_keys=None, figsize=None, **extra_kwargs): + from spikeinterface.widgets.widget_list import plot_study_unit_counts + + plot_study_unit_counts(self, case_keys, figsize=figsize, **extra_kwargs) + + def get_splitted_pairs(self, case_key): + return self.benchmarks[case_key].splitted_cells + + def get_splitted_pairs_index(self, case_key, pair): + for count, i in enumerate(self.benchmarks[case_key].splitted_cells): + if i == pair: + return count 
+ + def plot_splitted_amplitudes(self, case_key, pair_index=0, backend="ipywidgets"): + analyzer = self.get_sorting_analyzer(case_key) + if analyzer.get_extension("spike_amplitudes") is None: + analyzer.compute(["spike_amplitudes"]) + plot_amplitudes(analyzer, unit_ids=self.get_splitted_pairs(case_key)[pair_index], backend=backend) + + def plot_splitted_correlograms(self, case_key, pair_index=0, backend="ipywidgets"): + analyzer = self.get_sorting_analyzer(case_key) + if analyzer.get_extension("correlograms") is None: + analyzer.compute(["correlograms"]) + if analyzer.get_extension("template_similarity") is None: + analyzer.compute(["template_similarity"]) + plot_crosscorrelograms(analyzer, unit_ids=self.get_splitted_pairs(case_key)[pair_index]) + + def plot_splitted_templates(self, case_key, pair_index=0, backend="ipywidgets"): + analyzer = self.get_sorting_analyzer(case_key) + if analyzer.get_extension("spike_amplitudes") is None: + analyzer.compute(["spike_amplitudes"]) + plot_unit_templates(analyzer, unit_ids=self.get_splitted_pairs(case_key)[pair_index], backend=backend) + + def plot_potential_merges(self, case_key, min_snr=None, backend="ipywidgets"): + analyzer = self.get_sorting_analyzer(case_key) + mylist = self.get_splitted_pairs(case_key) + + if analyzer.get_extension("spike_amplitudes") is None: + analyzer.compute(["spike_amplitudes"]) + if analyzer.get_extension("correlograms") is None: + analyzer.compute(["correlograms"]) + + if min_snr is not None: + select_from = analyzer.sorting.unit_ids + if analyzer.get_extension("noise_levels") is None: + analyzer.compute("noise_levels") + if analyzer.get_extension("quality_metrics") is None: + analyzer.compute("quality_metrics", metric_names=["snr"]) + + snr = analyzer.get_extension("quality_metrics").get_data()["snr"].values + select_from = select_from[snr > min_snr] + mylist_selection = [] + for i in mylist: + if (i[0] in select_from) or (i[1] in select_from): + mylist_selection += [i] + mylist = mylist_selection + + from spikeinterface.widgets import plot_potential_merges + + plot_potential_merges(analyzer, mylist, backend=backend) diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_merging.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_merging.py new file mode 100644 index 0000000000..d3c6e37539 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_merging.py @@ -0,0 +1,73 @@ +import pytest +from pathlib import Path +import numpy as np + +import shutil + +from spikeinterface.sortingcomponents.benchmark.tests.common_benchmark_testing import make_dataset +from spikeinterface.sortingcomponents.benchmark.benchmark_merging import MergingStudy +from spikeinterface.generation.drift_tools import split_sorting_by_amplitudes, split_sorting_by_times + + +@pytest.mark.skip() +def test_benchmark_merging(create_cache_folder): + cache_folder = create_cache_folder + job_kwargs = dict(n_jobs=0.8, chunk_duration="1s") + + recording, gt_sorting, gt_analyzer = make_dataset() + + # create study + study_folder = cache_folder / "study_clustering" + # datasets = {"toy": (recording, gt_sorting)} + datasets = {"toy": gt_analyzer} + + gt_analyzer.compute(["random_spikes", "templates", "spike_amplitudes"]) + + splitted_sorting = {} + splitted_sorting["times"] = split_sorting_by_times(gt_analyzer) + splitted_sorting["amplitudes"] = split_sorting_by_amplitudes(gt_analyzer) + + cases = {} + for splits in ["times", "amplitudes"]: + for method in ["circus", "lussac"]: + 
cases[(method, splits)] = { + "label": f"{method}", + "dataset": "toy", + "init_kwargs": {"gt_sorting": gt_sorting, "splitted_cells": splitted_sorting[splits][1]}, + "params": {"method": method, "splitted_sorting": splitted_sorting[splits][0], "method_kwargs": {}}, + } + + if study_folder.exists(): + shutil.rmtree(study_folder) + study = MergingStudy.create(study_folder, datasets=datasets, cases=cases) + print(study) + + # this study needs analyzer + # study.create_sorting_analyzer_gt(**job_kwargs) + study.compute_metrics() + + study = MergingStudy(study_folder) + + # run and result + study.run(**job_kwargs) + study.compute_results() + + # load study to check persistency + study = MergingStudy(study_folder) + print(study) + + # plots + # study.plot_performances_vs_snr() + study.plot_agreements() + study.plot_unit_counts() + # study.plot_error_metrics() + # study.plot_metrics_vs_snr() + # study.plot_run_times() + # study.plot_metrics_vs_snr("cosine") + # study.homogeneity_score(ignore_noise=False) + # import matplotlib.pyplot as plt + # plt.show() + + +if __name__ == "__main__": + test_benchmark_merging() diff --git a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py index 08a1384333..65b2308996 100644 --- a/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/clustering_tools.py @@ -8,7 +8,6 @@ import numpy as np from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap -from spikeinterface.core import NumpySorting def _split_waveforms( diff --git a/src/spikeinterface/sortingcomponents/merging/__init__.py b/src/spikeinterface/sortingcomponents/merging/__init__.py new file mode 100644 index 0000000000..5c1b5498d7 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/merging/__init__.py @@ -0,0 +1,3 @@ +from .method_list import merging_methods + +from .main import merge_spikes diff --git a/src/spikeinterface/sortingcomponents/merging/circus.py b/src/spikeinterface/sortingcomponents/merging/circus.py new file mode 100644 index 0000000000..7866a82fe0 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/merging/circus.py @@ -0,0 +1,51 @@ +from __future__ import annotations +import numpy as np + +from .main import BaseMergingEngine +from spikeinterface.curation.auto_merge import iterative_merges + + +class CircusMerging(BaseMergingEngine): + """ + Meta merging inspired from the Lussac metric + """ + + default_params = { + "compute_needed_extensions": True, + "merging_kwargs": {"merging_mode": "soft", "sparsity_overlap": 0, "censor_ms": 3}, + "similarity_correlograms_kwargs": { + "unit_locations_kwargs": {"max_distance_um": 50, "unit_locations": {"method": "monopolar_triangulation"}}, + "template_similarity_kwargs": { + "template_diff_thresh": 0.25, + "template_similarity": {"method": "l2", "max_lag_ms": 0.1}, + }, + }, + "temporal_splits_kwargs": None, + } + + def __init__(self, sorting_analyzer, kwargs): + self.params = self.default_params.copy() + self.params.update(**kwargs) + self.analyzer = sorting_analyzer + + def run(self, extra_outputs=False, verbose=False, **job_kwargs): + presets = ["similarity_correlograms", "temporal_splits"] + similarity_kwargs = self.params["similarity_correlograms_kwargs"] or dict() + temporal_kwargs = self.params["temporal_splits_kwargs"] or dict() + params = [similarity_kwargs, temporal_kwargs] + + result = 
iterative_merges( + self.analyzer, + presets=presets, + params=params, + verbose=verbose, + extra_outputs=extra_outputs, + compute_needed_extensions=self.params["compute_needed_extensions"], + merging_kwargs=self.params["merging_kwargs"], + **job_kwargs, + ) + + if extra_outputs: + return result[0].sorting, result[1], result[2] + else: + return result.sorting diff --git a/src/spikeinterface/sortingcomponents/merging/lussac.py b/src/spikeinterface/sortingcomponents/merging/lussac.py new file mode 100644 index 0000000000..86471ecaed --- /dev/null +++ b/src/spikeinterface/sortingcomponents/merging/lussac.py @@ -0,0 +1,51 @@ +from __future__ import annotations +import numpy as np +import copy +from .main import BaseMergingEngine +from spikeinterface.curation.auto_merge import iterative_merges + + +class LussacMerging(BaseMergingEngine): + """ + Meta merging inspired from the Lussac metric + """ + + default_params = { + "compute_needed_extensions": True, + "merging_kwargs": {"merging_mode": "soft", "sparsity_overlap": 0, "censor_ms": 3}, + "template_diff_thresh": np.arange(0.05, 0.5, 0.05), + "x_contaminations_kwargs": { + "unit_locations_kwargs": {"max_distance_um": 50, "unit_locations": {"method": "monopolar_triangulation"}}, + "template_similarity_kwargs": {"template_similarity": {"method": "l2", "max_lag_ms": 0.1}}, + }, + } + + def __init__(self, sorting_analyzer, kwargs): + self.params = self.default_params.copy() + self.params.update(**kwargs) + self.analyzer = sorting_analyzer + self.iterations = self.params["template_diff_thresh"] + + def run(self, extra_outputs=False, verbose=False, **job_kwargs): + presets = ["x_contaminations"] * len(self.iterations) + params = [] + for thresh in self.iterations: + local_param = copy.deepcopy(self.params["x_contaminations_kwargs"]) + local_param["template_similarity_kwargs"].update({"template_diff_thresh": thresh}) + params += [local_param] + + result = iterative_merges( + self.analyzer, + presets=presets, + params=params, + verbose=verbose, + extra_outputs=extra_outputs, + compute_needed_extensions=self.params["compute_needed_extensions"], + merging_kwargs=self.params["merging_kwargs"], + **job_kwargs, + ) + + if extra_outputs: + return result[0].sorting, result[1], result[2] + else: + return result.sorting diff --git a/src/spikeinterface/sortingcomponents/merging/main.py b/src/spikeinterface/sortingcomponents/merging/main.py new file mode 100644 index 0000000000..4c58175df8 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/merging/main.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from threadpoolctl import threadpool_limits +import numpy as np +from spikeinterface.core.sortinganalyzer import create_sorting_analyzer +from spikeinterface.core.sparsity import ChannelSparsity +from spikeinterface.core.analyzer_extension_core import ComputeTemplates + + +def create_sorting_analyzer_with_templates(sorting, recording, templates, remove_empty=True): + sparsity = templates.sparsity + templates_array = templates.get_dense_templates().copy() + + if remove_empty: + non_empty_unit_ids = sorting.get_non_empty_unit_ids() + non_empty_sorting = sorting.remove_empty_units() + non_empty_unit_indices = sorting.ids_to_indices(non_empty_unit_ids) + templates_array = templates_array[non_empty_unit_indices] + sparsity_mask = sparsity.mask[non_empty_unit_indices, :] + sparsity = ChannelSparsity(sparsity_mask, non_empty_unit_ids, sparsity.channel_ids) + else: + non_empty_sorting = sorting + + sa = create_sorting_analyzer(non_empty_sorting, recording, 
format="memory", sparsity=sparsity) + sa.extensions["templates"] = ComputeTemplates(sa) + sa.extensions["templates"].params = {"ms_before": templates.ms_before, "ms_after": templates.ms_after} + sa.extensions["templates"].data["average"] = templates_array + return sa + + +def merge_spikes( + recording, + sorting, + method="circus", + templates=None, + remove_empty=True, + method_kwargs={}, + extra_outputs=False, + verbose=False, + **job_kwargs, +): + """Find spike from a recording from given templates. + + Parameters + ---------- + recording: RecordingExtractor + The recording extractor object + sorting: Sorting + The NumpySorting object + method: "circus" + Which method to use for merging spikes + method_kwargs: dict, optional + Keyword arguments for the chosen method + extra_outputs: bool + If True then method_kwargs is also returned + + Returns + ------- + new_sorting: NumpySorting + Sorting found after merging + method_kwargs: + Optionaly returns for debug purpose. + + """ + from .method_list import merging_methods + + assert method in merging_methods, f"The 'method' {method} is not valid. Use a method from {merging_methods}" + + method_class = merging_methods[method] + + if templates is None: + if remove_empty: + non_empty_sorting = sorting.remove_empty_units() + sorting_analyzer = create_sorting_analyzer(non_empty_sorting, recording) + else: + sorting_analyzer = create_sorting_analyzer_with_templates(sorting, recording, templates, remove_empty) + + method_instance = method_class(sorting_analyzer, method_kwargs) + + return method_instance.run(extra_outputs=extra_outputs, verbose=verbose, **job_kwargs) + + +# generic class for template engine +class BaseMergingEngine: + default_params = {} + + def __init__(self, sorting_analyzer, kwargs): + """This function runs before loops""" + # need to be implemented in subclass + raise NotImplementedError + + def run(self, **job_kwargs): + # need to be implemented in subclass + raise NotImplementedError diff --git a/src/spikeinterface/sortingcomponents/merging/method_list.py b/src/spikeinterface/sortingcomponents/merging/method_list.py new file mode 100644 index 0000000000..db1bb116e3 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/merging/method_list.py @@ -0,0 +1,5 @@ +from __future__ import annotations +from .circus import CircusMerging +from .lussac import LussacMerging + +merging_methods = {"circus": CircusMerging, "lussac": LussacMerging}