diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index fa2f0944d1..6682252349 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -167,6 +167,8 @@ def remove_sorting(self, key): sorting_folder = self.folder / "sortings" / self.key_to_str(key) log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json" comparison_file = self.folder / "comparisons" / self.key_to_str(key) + self.sortings[key] = None + self.comparisons[key] = None if sorting_folder.exists(): shutil.rmtree(sorting_folder) for f in (log_file, comparison_file): @@ -381,6 +383,7 @@ def get_performance_by_unit(self, case_keys=None): perf_by_unit = pd.concat(perf_by_unit) perf_by_unit = perf_by_unit.set_index(self.levels) + perf_by_unit = perf_by_unit.sort_index() return perf_by_unit def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None): diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 78e9a82cf0..ee6cf5268d 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -255,7 +255,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin): return (local_peaks,) -def sorting_to_peaks(sorting, extremum_channel_inds, dtype): +def sorting_to_peaks(sorting, extremum_channel_inds, dtype=spike_peak_dtype): spikes = sorting.to_spike_vector() peaks = np.zeros(spikes.size, dtype=dtype) peaks["sample_index"] = spikes["sample_index"] diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 332e63fafa..3ce20a0209 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -330,7 +330,8 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scale json.dump(check_json(info), f, indent=4) # save a copy of the sorting - NumpyFolderSorting.write_sorting(sorting, folder / "sorting") + # NumpyFolderSorting.write_sorting(sorting, folder / "sorting") + sorting.save(folder=folder / "sorting") # save recording and sorting provenance if recording.check_serializability("json"): diff --git a/src/spikeinterface/generation/drifting_generator.py b/src/spikeinterface/generation/drifting_generator.py index 201bbcaf88..8a658cd97d 100644 --- a/src/spikeinterface/generation/drifting_generator.py +++ b/src/spikeinterface/generation/drifting_generator.py @@ -97,7 +97,7 @@ def make_one_displacement_vector( start_drift_index = int(t_start_drift * displacement_sampling_frequency) end_drift_index = int(t_end_drift * displacement_sampling_frequency) - num_samples = int(displacement_sampling_frequency * duration) + num_samples = int(np.ceil(displacement_sampling_frequency * duration)) displacement_vector = np.zeros(num_samples, dtype="float32") if drift_mode == "zigzag": @@ -286,6 +286,7 @@ def generate_drifting_recording( ), generate_sorting_kwargs=dict(firing_rates=(2.0, 8.0), refractory_period_ms=4.0), generate_noise_kwargs=dict(noise_levels=(12.0, 15.0), spatial_decay=25.0), + extra_outputs=False, seed=None, ): """ @@ -314,6 +315,8 @@ def generate_drifting_recording( Parameters given to generate_sorting(). generate_noise_kwargs: dict Parameters given to generate_noise(). + extra_outputs: bool, default False + Return optionaly a dict with more variables. seed: None ot int A unique seed for all steps. 
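# A minimal usage sketch for the new `extra_outputs` flag documented just above
# (num_units/duration/seed are illustrative; only the flag and the returned
# extra_infos dict come from this patch):
from spikeinterface.generation import generate_drifting_recording

static_rec, drifting_rec, sorting, extra_infos = generate_drifting_recording(
    num_units=10, duration=60.0, extra_outputs=True, seed=2205
)
# extra_infos carries the drift ground truth used by the motion benchmarks:
# displacement_vectors, displacement_sampling_frequency, unit_locations,
# displacement_unit_factor and unit_displacements (shape (num_times, num_units, 2)).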
@@ -326,7 +329,14 @@ generate_drifting_recording( sorting: Sorting The ground trith soring object. Same for both recordings. - + extra_infos: + If extra_outputs=True, then also return a dict that contains various information like: + * displacement_vectors + * displacement_sampling_frequency + * unit_locations + * displacement_unit_factor + * unit_displacements + This can be helpful for motion benchmarks. """ rng = np.random.default_rng(seed=seed) @@ -356,6 +366,14 @@ generate_drifting_recording( generate_displacement_vector(duration, unit_locations[:, :2], seed=seed, **generate_displacement_vector_kwargs) ) + # unit_displacements is the sum of all displacements (times, units, direction_x_y) + unit_displacements = np.zeros((displacement_vectors.shape[0], num_units, 2)) + for direction in (0, 1): + # x and y + for i in range(displacement_vectors.shape[2]): + m = displacement_vectors[:, direction, i][:, np.newaxis] * displacement_unit_factor[:, i][np.newaxis, :] + unit_displacements[:, :, direction] += m + # unit_params need to be fixed before the displacement steps generate_templates_kwargs = generate_templates_kwargs.copy() unit_params = _ensure_unit_params(generate_templates_kwargs.get("unit_params", {}), num_units, seed) @@ -400,6 +418,8 @@ generate_drifting_recording( seed=seed, ) + sorting.set_property("gt_unit_locations", unit_locations) + ## Important precompute displacement do not work on border and so do not work for tetrode # here we bypass the interpolation and regenrate templates at severals positions. ## drifting_templates.precompute_displacements(displacements_steps) @@ -437,4 +457,14 @@ generate_drifting_recording( amplitude_factor=None, ) - return static_recording, drifting_recording, sorting + if extra_outputs: + extra_infos = dict( + displacement_vectors=displacement_vectors, + displacement_sampling_frequency=displacement_sampling_frequency, + unit_locations=unit_locations, + displacement_unit_factor=displacement_unit_factor, + unit_displacements=unit_displacements, + ) + return static_recording, drifting_recording, sorting, extra_infos + else: + return static_recording, drifting_recording, sorting diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index e7bb1027e3..c2b9f4cfc7 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -33,6 +33,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter): _default_params = { "apply_preprocessing": True, + "apply_motion_correction": False, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, "waveforms": { "ms_before": 0.5, @@ -52,10 +53,12 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "ms_before": 2.0, "ms_after": 3.0, "max_spikes_per_unit": 400, + "sparsity_threshold": 2.0, # "peak_shift_ms": 0.2, }, # "matching": {"method": "tridesclous", "method_kwargs": {"peak_shift_ms": 0.2, "radius_um": 100.0}}, - "matching": {"method": "circus-omp-svd", "method_kwargs": {}}, + # "matching": {"method": "circus-omp-svd", "method_kwargs": {}}, + "matching": {"method": "wobble", "method_kwargs": {}}, "job_kwargs": {"n_jobs": -1}, "save_array": True, } @@ -102,6 +105,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.sortingcomponents.clustering.split import split_clusters from spikeinterface.sortingcomponents.clustering.merge import merge_clusters from spikeinterface.sortingcomponents.clustering.tools import
compute_template_from_sparse + from spikeinterface.sortingcomponents.clustering.main import find_cluster_from_peaks + from spikeinterface.sortingcomponents.tools import remove_empty_templates from sklearn.decomposition import TruncatedSVD @@ -115,10 +120,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # preprocessing if params["apply_preprocessing"]: recording = bandpass_filter(recording_raw, **params["filtering"]) - # TODO what is the best about zscore>common_reference or the reverse recording = common_reference(recording) recording = zscore(recording, dtype="float32") - # recording = whiten(recording, dtype="float32") + recording = whiten(recording, dtype="float32") # used only if "folder" or "zarr" cache_folder = sorter_output_folder / "cache_preprocessing" @@ -148,152 +152,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if verbose: print("We kept %d peaks for clustering" % len(peaks)) - ms_before = params["waveforms"]["ms_before"] - ms_after = params["waveforms"]["ms_after"] + clustering_kwargs = {} + clustering_kwargs["folder"] = sorter_output_folder + clustering_kwargs["waveforms"] = params["waveforms"].copy() + clustering_kwargs["clustering"] = params["clustering"].copy() - # SVD for time compression - few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000) - few_wfs = extract_waveform_at_max_channel( - recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs + labels_set, post_clean_label, extra_out = find_cluster_from_peaks( + recording, peaks, method="tdc_clustering", method_kwargs=clustering_kwargs, extra_outputs=True, **job_kwargs ) - - wfs = few_wfs[:, :, 0] - tsvd = TruncatedSVD(params["svd"]["n_components"]) - tsvd.fit(wfs) - - model_folder = sorter_output_folder / "tsvd_model" - - model_folder.mkdir(exist_ok=True) - with open(model_folder / "pca_model.pkl", "wb") as f: - pickle.dump(tsvd, f) - - model_params = { - "ms_before": ms_before, - "ms_after": ms_after, - "sampling_frequency": float(sampling_frequency), - } - with open(model_folder / "params.json", "w") as f: - json.dump(model_params, f) - - # features - - features_folder = sorter_output_folder / "features" - node0 = PeakRetriever(recording, peaks) - - radius_um = params["waveforms"]["radius_um"] - node1 = ExtractSparseWaveforms( - recording, - parents=[node0], - return_output=True, - ms_before=ms_before, - ms_after=ms_after, - radius_um=radius_um, - ) - - model_folder_path = sorter_output_folder / "tsvd_model" - - node2 = TemporalPCAProjection( - recording, parents=[node0, node1], return_output=True, model_folder_path=model_folder_path - ) - - pipeline_nodes = [node0, node1, node2] - - output = run_node_pipeline( - recording, - pipeline_nodes, - job_kwargs, - gather_mode="npy", - gather_kwargs=dict(exist_ok=True), - folder=features_folder, - names=["sparse_wfs", "sparse_tsvd"], - ) - - # TODO make this generic in GatherNPY ??? 
- sparse_mask = node1.neighbours_mask - np.save(features_folder / "sparse_mask.npy", sparse_mask) - np.save(features_folder / "peaks.npy", peaks) - - # Clustering: channel index > split > merge - split_radius_um = params["clustering"]["split_radius_um"] - neighbours_mask = get_channel_distances(recording) < split_radius_um - - original_labels = peaks["channel_index"] - - min_cluster_size = 50 - - post_split_label, split_count = split_clusters( - original_labels, - recording, - features_folder, - method="local_feature_clustering", - method_kwargs=dict( - # clusterer="hdbscan", - clusterer="isocut5", - feature_name="sparse_tsvd", - # feature_name="sparse_wfs", - neighbours_mask=neighbours_mask, - waveforms_sparse_mask=sparse_mask, - min_size_split=min_cluster_size, - clusterer_kwargs={"min_cluster_size": min_cluster_size}, - n_pca_features=3, - ), - recursive=True, - recursive_depth=3, - returns_split_count=True, - **job_kwargs, - ) - - merge_radius_um = params["clustering"]["merge_radius_um"] - threshold_diff = params["clustering"]["threshold_diff"] - - post_merge_label, peak_shifts = merge_clusters( - peaks, - post_split_label, - recording, - features_folder, - radius_um=merge_radius_um, - # method="project_distribution", - # method_kwargs=dict( - # waveforms_sparse_mask=sparse_mask, - # feature_name="sparse_wfs", - # projection="centroid", - # criteria="distrib_overlap", - # threshold_overlap=0.3, - # min_cluster_size=min_cluster_size + 1, - # num_shift=5, - # ), - method="normalized_template_diff", - method_kwargs=dict( - waveforms_sparse_mask=sparse_mask, - threshold_diff=threshold_diff, - min_cluster_size=min_cluster_size + 1, - num_shift=5, - ), - **job_kwargs, - ) - - # sparse_wfs = np.load(features_folder / "sparse_wfs.npy", mmap_mode="r") - + peak_shifts = extra_out["peak_shifts"] new_peaks = peaks.copy() new_peaks["sample_index"] -= peak_shifts - # clean very small cluster before peeler - post_clean_label = post_merge_label.copy() - - minimum_cluster_size = 25 - labels_set, count = np.unique(post_clean_label, return_counts=True) - to_remove = labels_set[count < minimum_cluster_size] - mask = np.isin(post_clean_label, to_remove) - post_clean_label[mask] = -1 - - # final label sets - labels_set = np.unique(post_clean_label) - labels_set = labels_set[labels_set >= 0] - mask = post_clean_label >= 0 sorting_pre_peeler = NumpySorting.from_times_labels( new_peaks["sample_index"][mask], - post_merge_label[mask], + post_clean_label[mask], sampling_frequency, unit_ids=labels_set, ) @@ -303,6 +177,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): nbefore = int(params["templates"]["ms_before"] * sampling_frequency / 1000.0) nafter = int(params["templates"]["ms_after"] * sampling_frequency / 1000.0) + sparsity_threshold = params["templates"]["sparsity_threshold"] templates_array = estimate_templates_with_accumulator( recording_w, sorting_pre_peeler.to_spike_vector(), @@ -320,8 +195,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) # TODO : try other methods for sparsity # sparsity = compute_sparsity(templates_dense, method="radius", radius_um=120.) 
- sparsity = compute_sparsity(templates_dense, noise_levels=noise_levels, threshold=1.0) + sparsity = compute_sparsity(templates_dense, noise_levels=noise_levels, threshold=sparsity_threshold) templates = templates_dense.to_sparse(sparsity) + templates = remove_empty_templates(templates) # snrs = compute_snrs(we, peak_sign=params["detection"]["peak_sign"], peak_mode="extremum") # print(snrs) @@ -350,12 +226,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): # ) # ) - if matching_method == "circus-omp-svd": - job_kwargs = job_kwargs.copy() - for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: - if value in job_kwargs: - job_kwargs.pop(value) - job_kwargs["chunk_duration"] = "100ms" + # if matching_method == "circus-omp-svd": + # job_kwargs = job_kwargs.copy() + # for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]: + # if value in job_kwargs: + # job_kwargs.pop(value) + # job_kwargs["chunk_duration"] = "100ms" spikes = find_spikes_from_templates( recording_w, method=matching_method, method_kwargs=matching_params, **job_kwargs @@ -366,9 +242,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): np.save(sorter_output_folder / "noise_levels.npy", noise_levels) np.save(sorter_output_folder / "all_peaks.npy", all_peaks) - np.save(sorter_output_folder / "post_split_label.npy", post_split_label) - np.save(sorter_output_folder / "split_count.npy", split_count) - np.save(sorter_output_folder / "post_merge_label.npy", post_merge_label) + # np.save(sorter_output_folder / "post_split_label.npy", post_split_label) + # np.save(sorter_output_folder / "split_count.npy", split_count) + # np.save(sorter_output_folder / "post_merge_label.npy", post_merge_label) np.save(sorter_output_folder / "spikes.npy", spikes) final_spikes = np.zeros(spikes.size, dtype=minimum_spike_dtype) diff --git a/src/spikeinterface/sorters/tests/test_launcher.py b/src/spikeinterface/sorters/tests/test_launcher.py index 8019f4e620..fa6b6986ba 100644 --- a/src/spikeinterface/sorters/tests/test_launcher.py +++ b/src/spikeinterface/sorters/tests/test_launcher.py @@ -176,11 +176,11 @@ def test_run_sorter_by_property(): setup_module() job_list = get_job_list() - test_run_sorter_jobs_loop(job_list) + # test_run_sorter_jobs_loop(job_list) # test_run_sorter_jobs_joblib(job_list) # test_run_sorter_jobs_processpoolexecutor(job_list) # test_run_sorter_jobs_multiprocessing(job_list) # test_run_sorter_jobs_dask(job_list) # test_run_sorter_jobs_slurm(job_list) - # test_run_sorter_by_property() + test_run_sorter_by_property() diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index 7c66fecb44..ebddd2bd58 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -56,8 +56,16 @@ def compute_result(self, **result_params): data = spikes[self.indices][~self.noise] # data["unit_index"] = self.result["peak_labels"][~self.noise] - positions = self.gt_sorting.get_property("gt_unit_locations") - self.result["sliced_gt_sorting"].set_property("gt_unit_locations", positions) + gt_unit_locations = self.gt_sorting.get_property("gt_unit_locations") + if gt_unit_locations is None: + print("'gt_unit_locations' is not a property of the sorting so compute it") + gt_analyzer = create_sorting_analyzer(self.gt_sorting, self.recording, format="memory", sparse=True) + 
gt_analyzer.compute(["random_spikes", "templates"]) + ext = gt_analyzer.compute("unit_locations", method="monopolar_triangulation") + gt_unit_locations = ext.get_data() + self.gt_sorting.set_property("gt_unit_locations", gt_unit_locations) + + self.result["sliced_gt_sorting"].set_property("gt_unit_locations", gt_unit_locations) self.result["clustering"] = NumpySorting.from_times_labels( data["sample_index"], self.result["peak_labels"][~self.noise], self.recording.sampling_frequency @@ -340,13 +348,13 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs to_plot += [scores.at[real, found]] axs[0, count].scatter(depth_matched, snr_matched, c=to_plot, label="matched") axs[0, count].scatter(depth_missed, snr_missed, c=np.zeros(len(snr_missed)), label="missed") - axs[0, count].set_xlabel("snr") - axs[0, count].set_ylabel(metric) + axs[0, count].set_xlabel("depth") + axs[0, count].set_ylabel("snr") label = self.cases[key]["label"] axs[0, count].set_title(label) - axs[0, count].legend() + # axs[0, count].legend() - def plot_unit_losses(self, before, after, metric="agreement", figsize=None): + def plot_unit_losses(self, case_before, case_after, metric="agreement", figsize=None): fig, axs = plt.subplots(ncols=1, nrows=3, figsize=figsize) @@ -354,16 +362,20 @@ def plot_unit_losses(self, before, after, metric="agreement", figsize=None): ax = axs[count] - label = self.cases[after]["label"] + # label = self.cases[case_after]["label"] - positions = self.get_result(before)["gt_comparison"].sorting1.get_property("gt_unit_locations") + # positions = self.get_result(case_before)["gt_comparison"].sorting1.get_property("gt_unit_locations") - analyzer = self.get_sorting_analyzer(before) + dataset_key = self.cases[case_before]["dataset"] + rec, gt_sorting1 = self.datasets[dataset_key] + positions = gt_sorting1.get_property("gt_unit_locations") + + analyzer = self.get_sorting_analyzer(case_before) metrics_before = analyzer.get_extension("quality_metrics").get_data() x = metrics_before["snr"].values - y_before = self.get_result(before)["gt_comparison"].get_performance()[k].values - y_after = self.get_result(after)["gt_comparison"].get_performance()[k].values + y_before = self.get_result(case_before)["gt_comparison"].get_performance()[k].values + y_after = self.get_result(case_after)["gt_comparison"].get_performance()[k].values if count < 2: ax.set_xticks([], []) elif count == 2: @@ -431,3 +443,71 @@ def plot_comparison_clustering( ax.set_yticks([]) plt.tight_layout(h_pad=0, w_pad=0) + + def plot_some_over_merged(self, case_keys=None, overmerged_score=0.05, max_units=5, figsize=None): + if case_keys is None: + case_keys = list(self.cases.keys()) + + for count, key in enumerate(case_keys): + label = self.cases[key]["label"] + comp = self.get_result(key)["gt_comparison"] + + unit_index = np.flatnonzero(np.sum(comp.agreement_scores.values > overmerged_score, axis=0) > 1) + overmerged_ids = comp.sorting2.unit_ids[unit_index] + + n = min(len(overmerged_ids), max_units) + if n > 0: + fig, axs = plt.subplots(nrows=n, figsize=figsize) + for i, unit_id in enumerate(overmerged_ids[:n]): + gt_unit_indices = np.flatnonzero(comp.agreement_scores.loc[:, unit_id].values > overmerged_score) + gt_unit_ids = comp.sorting1.unit_ids[gt_unit_indices] + ax = axs[i] + ax.set_title(f"unit {unit_id} - GTids {gt_unit_ids}") + + analyzer = self.get_sorting_analyzer(key) + + wf_template = analyzer.get_extension("templates") + templates = wf_template.get_templates(unit_ids=gt_unit_ids) + if analyzer.sparsity is 
not None: + chan_mask = np.any(analyzer.sparsity.mask[gt_unit_indices, :], axis=0) + templates = templates[:, :, chan_mask] + ax.plot(templates.swapaxes(1, 2).reshape(templates.shape[0], -1).T) + ax.set_xticks([]) + + fig.suptitle(label) + else: + print(key, "no overmerged") + + def plot_some_over_splited(self, case_keys=None, oversplit_score=0.05, max_units=5, figsize=None): + if case_keys is None: + case_keys = list(self.cases.keys()) + + for count, key in enumerate(case_keys): + label = self.cases[key]["label"] + comp = self.get_result(key)["gt_comparison"] + + gt_unit_indices = np.flatnonzero(np.sum(comp.agreement_scores.values > oversplit_score, axis=1) > 1) + oversplit_ids = comp.sorting1.unit_ids[gt_unit_indices] + + n = min(len(oversplit_ids), max_units) + if n > 0: + fig, axs = plt.subplots(nrows=n, figsize=figsize) + for i, unit_id in enumerate(oversplit_ids[:n]): + unit_indices = np.flatnonzero(comp.agreement_scores.loc[unit_id, :].values > oversplit_score) + unit_ids = comp.sorting2.unit_ids[unit_indices] + ax = axs[i] + ax.set_title(f"Gt unit {unit_id} - unit_ids: {unit_ids}") + + templates = self.get_result(key)["clustering_templates"] + + template_arrays = templates.get_dense_templates()[unit_indices, :, :] + if templates.sparsity is not None: + chan_mask = np.any(templates.sparsity.mask[gt_unit_indices, :], axis=0) + template_arrays = template_arrays[:, :, chan_mask] + + ax.plot(template_arrays.swapaxes(1, 2).reshape(template_arrays.shape[0], -1).T) + ax.set_xticks([]) + + fig.suptitle(label) + else: + print(key, "no over splited") diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index 12f0ff7a4a..30175288a3 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -28,43 +28,7 @@ # TODO : read from mearec -def get_unit_disclacement(displacement_vectors, displacement_unit_factor, direction_dim=1): - """ - Get final displacement vector unit per units. - - See drifting_tools for shapes. - - - Parameters - ---------- - - displacement_vectors: list of numpy array - The lenght of the list is the number of segment. - Per segment, the drift vector is a numpy array with shape (num_times, 2, num_motions) - num_motions is generally = 1 but can be > 1 in case of combining several drift vectors - displacement_unit_factor: numpy array or None, default: None - A array containing the factor per unit of the drift. - This is used to create non rigid with a factor gradient of depending on units position. - shape (num_units, num_motions) - If None then all unit have the same factor (1) and the drift is rigid. 
- - Returns - ------- - unit_displacements: numpy array - shape (num_times, num_units) - - - """ - num_units = displacement_unit_factor.shape[0] - unit_displacements = np.zeros((displacement_vectors.shape[0], num_units)) - for i in range(displacement_vectors.shape[2]): - m = displacement_vectors[:, direction_dim, i][:, np.newaxis] * displacement_unit_factor[:, i][np.newaxis, :] - unit_displacements[:, :] += m - - return unit_displacements - - -def get_gt_motion_from_unit_discplacement( +def get_gt_motion_from_unit_displacement( unit_displacements, displacement_sampling_frequency, unit_locations, @@ -73,6 +37,7 @@ def get_gt_motion_from_unit_discplacement( direction_dim=1, ): + unit_displacements = unit_displacements[:, :, direction_dim] times = np.arange(unit_displacements.shape[0]) / displacement_sampling_frequency f = scipy.interpolate.interp1d(times, unit_displacements, axis=0) unit_displacements = f(temporal_bins) @@ -152,23 +117,14 @@ def compute_result(self, **result_params): temporal_bins = self.result["temporal_bins"] spatial_bins = self.result["spatial_bins"] - # time interpolatation of unit displacements - times = np.arange(self.unit_displacements.shape[0]) / self.displacement_sampling_frequency - f = scipy.interpolate.interp1d(times, self.unit_displacements, axis=0) - unit_displacements = f(temporal_bins) - - # spatial interpolataion of units discplacement - if spatial_bins.shape[0] == 1: - # rigid - gt_motion = np.mean(unit_displacements, axis=1)[:, None] - else: - # non rigid - gt_motion = np.zeros_like(raw_motion) - for t in range(temporal_bins.shape[0]): - f = scipy.interpolate.interp1d( - self.unit_locations[:, self.direction_dim], unit_displacements[t, :], fill_value="extrapolate" - ) - gt_motion[t, :] = f(spatial_bins) + gt_motion = get_gt_motion_from_unit_displacement( + self.unit_displacements, + self.displacement_sampling_frequency, + self.unit_locations, + temporal_bins, + spatial_bins, + direction_dim=self.direction_dim, + ) # align globally gt_motion and motion to avoid offsets motion = raw_motion.copy() @@ -207,6 +163,9 @@ def create_benchmark(self, key): return benchmark def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): + self.plot_drift(case_keys=case_keys, tested_drift=False, scaling_probe=scaling_probe, figsize=figsize) + + def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_probe=1.0, figsize=(8, 6)): if case_keys is None: case_keys = list(self.cases.keys()) @@ -235,13 +194,18 @@ def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)): temporal_bins = bench.result["temporal_bins"] spatial_bins = bench.result["spatial_bins"] gt_motion = bench.result["gt_motion"] + motion = bench.result["motion"] # for i in range(self.gt_unit_positions.shape[1]): # ax.plot(temporal_bins, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") for i in range(gt_motion.shape[1]): depth = spatial_bins[i] - ax.plot(temporal_bins, gt_motion[:, i] + depth, color="green", lw=4) + if gt_drift: + ax.plot(temporal_bins, gt_motion[:, i] + depth, color="green", lw=4) + if tested_drift: + ax.plot(temporal_bins, motion[:, i] + depth, color="cyan", lw=2) + ax.set_xlabel("time (s)") _simpleaxis(ax) ax.set_yticks([]) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index 811673e525..5e2ded5ecc 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ 
b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -10,6 +10,7 @@ import os +from spikeinterface.core import SortingAnalyzer from spikeinterface.core.core_tools import check_json from spikeinterface import load_extractor, split_job_kwargs, create_sorting_analyzer, load_sorting_analyzer @@ -38,6 +39,7 @@ class BenchmarkStudy: def __init__(self, study_folder): self.folder = Path(study_folder) self.datasets = {} + self.analyzers = {} self.cases = {} self.benchmarks = {} self.scan_folder() @@ -69,23 +71,54 @@ def create(cls, study_folder, datasets={}, cases={}, levels=None): study_folder = Path(study_folder) study_folder.mkdir(exist_ok=False, parents=True) - (study_folder / "datasets").mkdir() - (study_folder / "datasets" / "recordings").mkdir() - (study_folder / "datasets" / "gt_sortings").mkdir() + # (study_folder / "datasets").mkdir() + # (study_folder / "datasets" / "recordings").mkdir() + # (study_folder / "datasets" / "gt_sortings").mkdir() (study_folder / "run_logs").mkdir() - (study_folder / "metrics").mkdir() + # (study_folder / "metrics").mkdir() (study_folder / "results").mkdir() + (study_folder / "sorting_analyzer").mkdir() - for key, (rec, gt_sorting) in datasets.items(): + analyzers_path = {} + # for key, (rec, gt_sorting) in datasets.items(): + for key, data in datasets.items(): assert "/" not in key, "'/' cannot be in the key name!" assert "\\" not in key, "'\\' cannot be in the key name!" + local_analyzer_folder = study_folder / "sorting_analyzer" / key + + if isinstance(data, tuple): + # old case : rec + sorting + rec, gt_sorting = data + analyzer = create_sorting_analyzer( + gt_sorting, rec, sparse=True, format="binary_folder", folder=local_analyzer_folder ) + analyzer.compute("random_spikes") + analyzer.compute("templates") + analyzer.compute("noise_levels") + else: + # new case : analyzer + assert isinstance(data, SortingAnalyzer) + analyzer = data + if data.format == "memory": + # then make a local copy in the folder + analyzer = data.save_as(format="binary_folder", folder=local_analyzer_folder) + else: + analyzer = data + + rec, gt_sorting = analyzer.recording, analyzer.sorting + + analyzers_path[key] = str(analyzer.folder.resolve()) + # recordings are pickled - rec.dump_to_pickle(study_folder / f"datasets/recordings/{key}.pickle") + # rec.dump_to_pickle(study_folder / f"datasets/recordings/{key}.pickle") # sortings are pickled + saved as NumpyFolderSorting - gt_sorting.dump_to_pickle(study_folder / f"datasets/gt_sortings/{key}.pickle") - gt_sorting.save(format="numpy_folder", folder=study_folder / f"datasets/gt_sortings/{key}") + # gt_sorting.dump_to_pickle(study_folder / f"datasets/gt_sortings/{key}.pickle") + # gt_sorting.save(format="numpy_folder", folder=study_folder / f"datasets/gt_sortings/{key}") + + # analyzer path (local or external) + (study_folder / "analyzers_path.json").write_text(json.dumps(analyzers_path, indent=4), encoding="utf8") info = {} info["levels"] = levels @@ -100,19 +133,29 @@ def create_benchmark(self): raise NotImplementedError def scan_folder(self): - if not (self.folder / "datasets").exists(): - raise ValueError(f"This is folder is not a GroundTruthStudy : {self.folder.absolute()}") + if not (self.folder / "sorting_analyzer").exists(): + raise ValueError(f"This folder is not a BenchmarkStudy : {self.folder.absolute()}") with open(self.folder / "info.json", "r") as f: self.info = json.load(f) + with open(self.folder / "analyzers_path.json", "r") as f: + self.analyzers_path = json.load(f) + self.levels =
self.info["levels"] - for rec_file in (self.folder / "datasets" / "recordings").glob("*.pickle"): - key = rec_file.stem - rec = load_extractor(rec_file) - gt_sorting = load_extractor(self.folder / f"datasets" / "gt_sortings" / key) - self.datasets[key] = (rec, gt_sorting) + for key, folder in self.analyzers_path.items(): + analyzer = load_sorting_analyzer(folder) + self.analyzers[key] = analyzer + # the sorting is in memory here we take the saved one because comparisons need to pickle it later + sorting = load_extractor(analyzer.folder / "sorting") + self.datasets[key] = analyzer.recording, sorting + + # for rec_file in (self.folder / "datasets" / "recordings").glob("*.pickle"): + # key = rec_file.stem + # rec = load_extractor(rec_file) + # gt_sorting = load_extractor(self.folder / f"datasets" / "gt_sortings" / key) + # self.datasets[key] = (rec, gt_sorting) with open(self.folder / "cases.pickle", "rb") as f: self.cases = pickle.load(f) @@ -216,32 +259,34 @@ def compute_results(self, case_keys=None, verbose=False, **result_params): benchmark.save_result(self.folder / "results" / self.key_to_str(key)) def create_sorting_analyzer_gt(self, case_keys=None, return_scaled=True, random_params={}, **job_kwargs): - if case_keys is None: - case_keys = self.cases.keys() - - base_folder = self.folder / "sorting_analyzer" - base_folder.mkdir(exist_ok=True) - - dataset_keys = [self.cases[key]["dataset"] for key in case_keys] - dataset_keys = set(dataset_keys) - for dataset_key in dataset_keys: - # the waveforms depend on the dataset key - folder = base_folder / self.key_to_str(dataset_key) - recording, gt_sorting = self.datasets[dataset_key] - sorting_analyzer = create_sorting_analyzer( - gt_sorting, recording, format="binary_folder", folder=folder, return_scaled=return_scaled - ) - sorting_analyzer.compute("random_spikes", **random_params) - sorting_analyzer.compute("templates", **job_kwargs) - sorting_analyzer.compute("noise_levels") + print("###### Study.create_sorting_analyzer_gt() is not used anymore!!!!!!") + # if case_keys is None: + # case_keys = self.cases.keys() + + # base_folder = self.folder / "sorting_analyzer" + # base_folder.mkdir(exist_ok=True) + + # dataset_keys = [self.cases[key]["dataset"] for key in case_keys] + # dataset_keys = set(dataset_keys) + # for dataset_key in dataset_keys: + # # the waveforms depend on the dataset key + # folder = base_folder / self.key_to_str(dataset_key) + # recording, gt_sorting = self.datasets[dataset_key] + # sorting_analyzer = create_sorting_analyzer( + # gt_sorting, recording, format="binary_folder", folder=folder, return_scaled=return_scaled + # ) + # sorting_analyzer.compute("random_spikes", **random_params) + # sorting_analyzer.compute("templates", **job_kwargs) + # sorting_analyzer.compute("noise_levels") def get_sorting_analyzer(self, case_key=None, dataset_key=None): if case_key is not None: dataset_key = self.cases[case_key]["dataset"] + return self.analyzers[dataset_key] - folder = self.folder / "sorting_analyzer" / self.key_to_str(dataset_key) - sorting_analyzer = load_sorting_analyzer(folder) - return sorting_analyzer + # folder = self.folder / "sorting_analyzer" / self.key_to_str(dataset_key) + # sorting_analyzer = load_sorting_analyzer(folder) + # return sorting_analyzer def get_templates(self, key, operator="average"): sorting_analyzer = self.get_sorting_analyzer(case_key=key) @@ -256,34 +301,47 @@ def compute_metrics(self, case_keys=None, metric_names=["snr", "firing_rate"], f for key in case_keys: dataset_key = 
self.cases[key]["dataset"] if dataset_key in done: - # some case can share the same waveform extractor + # some cases can share the same analyzer continue done.append(dataset_key) - filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" - if filename.exists(): - if force: - os.remove(filename) - else: - continue + # filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" + # if filename.exists(): + # if force: + # os.remove(filename) + # else: + # continue sorting_analyzer = self.get_sorting_analyzer(key) - qm_ext = sorting_analyzer.compute("quality_metrics", metric_names=metric_names) + qm_ext = sorting_analyzer.get_extension("quality_metrics") + if qm_ext is None or force: + qm_ext = sorting_analyzer.compute("quality_metrics", metric_names=metric_names) + + # TODO remove this metrics CSV file!!!! metrics = qm_ext.get_data() - metrics.to_csv(filename, sep="\t", index=True) + # metrics.to_csv(filename, sep="\t", index=True) def get_metrics(self, key): import pandas as pd dataset_key = self.cases[key]["dataset"] - filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" - if not filename.exists(): - return - metrics = pd.read_csv(filename, sep="\t", index_col=0) - dataset_key = self.cases[key]["dataset"] - recording, gt_sorting = self.datasets[dataset_key] - metrics.index = gt_sorting.unit_ids + analyzer = self.get_sorting_analyzer(key) + ext = analyzer.get_extension("quality_metrics") + if ext is None: + # TODO auto-compute ???? + return None + + metrics = ext.get_data() return metrics + # filename = self.folder / "metrics" / f"{self.key_to_str(dataset_key)}.csv" + # if not filename.exists(): + # return + # metrics = pd.read_csv(filename, sep="\t", index_col=0) + # dataset_key = self.cases[key]["dataset"] + # recording, gt_sorting = self.datasets[dataset_key] + # metrics.index = gt_sorting.unit_ids + # return metrics + def get_units_snr(self, key): """ """ return self.get_metrics(key)["snr"] diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py b/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py index 1fa8959719..3401e36dd0 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/common_benchmark_testing.py @@ -12,18 +12,11 @@ from spikeinterface.core import ( generate_ground_truth_recording, - generate_templates, estimate_templates, Templates, - generate_sorting, - NoiseGeneratorRecording, + create_sorting_analyzer, ) -from spikeinterface.core.generate import generate_unit_locations -from spikeinterface.generation import DriftingTemplates, make_linear_displacement, InjectDriftingTemplatesRecording - - -from probeinterface import generate_multi_columns_probe - +from spikeinterface.generation import generate_drifting_recording ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) @@ -51,7 +44,14 @@ def make_dataset(): noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"), seed=2205, ) - return recording, gt_sorting + + gt_analyzer = create_sorting_analyzer(gt_sorting, recording, sparse=True, format="memory") + gt_analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=500) + # analyzer.compute("waveforms") + gt_analyzer.compute("templates") + gt_analyzer.compute("noise_levels") + + return recording, gt_sorting, gt_analyzer def compute_gt_templates(recording, gt_sorting, ms_before=2.0, ms_after=3.0, return_scaled=False, **job_kwargs): @@ -83,153 +83,64 @@ def 
compute_gt_templates(recording, gt_sorting, ms_before=2.0, ms_after=3.0, ret def make_drifting_dataset(): - num_units = 15 - duration = 125.5 - sampling_frequency = 30000.0 - ms_before = 1.0 - ms_after = 3.0 - displacement_sampling_frequency = 5.0 - - probe = generate_multi_columns_probe( - num_columns=3, - num_contact_per_column=12, - xpitch=15, - ypitch=15, - contact_shapes="square", - contact_shape_params={"width": 10}, - ) - probe.set_device_channel_indices(np.arange(probe.contact_ids.size)) - - channel_locations = probe.contact_positions - - unit_locations = generate_unit_locations( - num_units, - channel_locations, - margin_um=20.0, - minimum_z=5.0, - maximum_z=40.0, - minimum_distance=20.0, - max_iteration=100, - distance_strict=False, - seed=None, - ) - - nbefore = int(sampling_frequency * ms_before / 1000.0) - - generate_kwargs = dict( - sampling_frequency=sampling_frequency, - ms_before=ms_before, - ms_after=ms_after, - seed=2205, - unit_params=dict( - alpha=(100.0, 500.0), - depolarization_ms=(0.09, 0.16), - repolarization_ms=np.ones(num_units) * 0.8, + static_recording, drifting_recording, sorting, extra_infos = generate_drifting_recording( + num_units=15, + duration=125.5, + sampling_frequency=30000.0, + probe_name=None, + generate_probe_kwargs=dict( + num_columns=3, + num_contact_per_column=12, + xpitch=15, + ypitch=15, + contact_shapes="square", + contact_shape_params={"width": 10}, ), - ) - templates_array = generate_templates(channel_locations, unit_locations, **generate_kwargs) - - templates = Templates( - templates_array=templates_array, - sampling_frequency=sampling_frequency, - nbefore=nbefore, - probe=probe, - ) - - drifting_templates = DriftingTemplates.from_static(templates) - channel_locations = probe.contact_positions - - start = np.array([0, -15.0]) - stop = np.array([0, 12]) - displacements = make_linear_displacement(start, stop, num_step=29) - - sorting = generate_sorting( - num_units=num_units, - sampling_frequency=sampling_frequency, - durations=[ - duration, - ], - firing_rates=25.0, - ) - sorting - - times = np.arange(0, duration, 1 / displacement_sampling_frequency) - times - - # 2 rythm - mid = (start + stop) / 2 - freq0 = 0.1 - displacement_vector0 = np.sin(2 * np.pi * freq0 * times)[:, np.newaxis] * (start - stop) + mid - # freq1 = 0.01 - # displacement_vector1 = 0.2 * np.sin(2 * np.pi * freq1 *times)[:, np.newaxis] * (start - stop) + mid - - # print() - - displacement_vectors = displacement_vector0[:, :, np.newaxis] - - # TODO gradient - num_motion = displacement_vectors.shape[2] - displacement_unit_factor = np.zeros((num_units, num_motion)) - displacement_unit_factor[:, 0] = 1 - - drifting_templates.precompute_displacements(displacements) - - direction = 1 - unit_displacements = np.zeros((displacement_vectors.shape[0], num_units)) - for i in range(displacement_vectors.shape[2]): - m = displacement_vectors[:, direction, i][:, np.newaxis] * displacement_unit_factor[:, i][np.newaxis, :] - unit_displacements[:, :] += m - - noise = NoiseGeneratorRecording( - num_channels=probe.contact_ids.size, - sampling_frequency=sampling_frequency, - durations=[duration], - noise_levels=1.0, - dtype="float32", - ) - - drifting_rec = InjectDriftingTemplatesRecording( - sorting=sorting, - parent_recording=noise, - drifting_templates=drifting_templates, - displacement_vectors=[displacement_vectors], - displacement_sampling_frequency=displacement_sampling_frequency, - displacement_unit_factor=displacement_unit_factor, - num_samples=[int(duration * sampling_frequency)], - 
amplitude_factor=None, + generate_unit_locations_kwargs=dict( + margin_um=20.0, + minimum_z=5.0, + maximum_z=40.0, + minimum_distance=18.0, + max_iteration=100, + distance_strict=False, + ), + generate_displacement_vector_kwargs=dict( + displacement_sampling_frequency=5.0, + drift_start_um=[0, 15], + drift_stop_um=[0, -15], + drift_step_um=1, + motion_list=[ + dict( + drift_mode="zigzag", + non_rigid_gradient=None, + t_start_drift=20.0, + t_end_drift=None, + period_s=50, + ), + ], + ), + generate_templates_kwargs=dict( + ms_before=1.5, + ms_after=3.0, + mode="ellipsoid", + unit_params=dict( + alpha=(150.0, 500.0), + spatial_decay=(10, 45), + ), + ), + generate_sorting_kwargs=dict(firing_rates=25.0, refractory_period_ms=4.0), + generate_noise_kwargs=dict(noise_levels=(12.0, 15.0), spatial_decay=25.0), + extra_outputs=True, + seed=None, ) - static_rec = InjectDriftingTemplatesRecording( + return dict( + drifting_rec=drifting_recording, + static_rec=static_recording, sorting=sorting, - parent_recording=noise, - drifting_templates=drifting_templates, - displacement_vectors=[displacement_vectors], - displacement_sampling_frequency=displacement_sampling_frequency, - displacement_unit_factor=np.zeros_like(displacement_unit_factor), - num_samples=[int(duration * sampling_frequency)], - amplitude_factor=None, - ) - - my_dict = _variable_from_namespace( - [ - drifting_rec, - static_rec, - sorting, - displacement_vectors, - displacement_sampling_frequency, - unit_locations, - displacement_unit_factor, - unit_displacements, - ], - locals(), + displacement_vectors=extra_infos["displacement_vectors"], + displacement_sampling_frequency=extra_infos["displacement_sampling_frequency"], + unit_locations=extra_infos["unit_locations"], + displacement_unit_factor=extra_infos["displacement_unit_factor"], + unit_displacements=extra_infos["unit_displacements"], ) - return my_dict - - -def _variable_from_namespace(objs, namespace): - d = dict() - for obj in objs: - for name in namespace: - if namespace[name] is obj: - d[name] = obj - return d diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py index d299d2492c..d9d07370cb 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_clustering.py @@ -17,28 +17,29 @@ def test_benchmark_clustering(): job_kwargs = dict(n_jobs=0.8, chunk_duration="1s") - recording, gt_sorting = make_dataset() + recording, gt_sorting, gt_analyzer = make_dataset() num_spikes = gt_sorting.to_spike_vector().size spike_indices = np.arange(0, num_spikes, 5) # create study study_folder = cache_folder / "study_clustering" - datasets = {"toy": (recording, gt_sorting)} + # datasets = {"toy": (recording, gt_sorting)} + datasets = {"toy": gt_analyzer} peaks = {} - for dataset in datasets.keys(): + for dataset, gt_analyzer in datasets.items(): - recording, gt_sorting = datasets[dataset] + # recording, gt_sorting = datasets[dataset] - sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False) - sorting_analyzer.compute(["random_spikes", "templates"]) - extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index") - spikes = gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) + # sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False) + # 
sorting_analyzer.compute(["random_spikes", "templates"]) + extremum_channel_inds = get_template_extremum_channel(gt_analyzer, outputs="index") + spikes = gt_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds) peaks[dataset] = spikes cases = {} - for method in ["random_projections", "circus"]: + for method in ["random_projections", "circus", "tdc_clustering"]: cases[method] = { "label": f"{method} on toy", "dataset": "toy", @@ -52,7 +53,7 @@ def test_benchmark_clustering(): print(study) # this study needs analyzer - study.create_sorting_analyzer_gt(**job_kwargs) + # study.create_sorting_analyzer_gt(**job_kwargs) study.compute_metrics() study = ClusteringStudy(study_folder) @@ -72,8 +73,7 @@ def test_benchmark_clustering(): study.plot_error_metrics() study.plot_metrics_vs_snr() study.plot_run_times() - # @pierre : This one has a bug - # study.plot_metrics_vs_snr('cosine') + study.plot_metrics_vs_snr("cosine") study.homogeneity_score(ignore_noise=False) plt.show() diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py index 805f5d8327..1aae51c9ef 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_matching.py @@ -25,7 +25,7 @@ def test_benchmark_matching(): job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms") - recording, gt_sorting = make_dataset() + recording, gt_sorting, gt_analyzer = make_dataset() # templates sparse gt_templates = compute_gt_templates( @@ -38,6 +38,8 @@ def test_benchmark_matching(): # create study study_folder = cache_folder / "study_matching" datasets = {"toy": (recording, gt_sorting)} + # datasets = {"toy": gt_analyzer} + cases = {} for engine in [ "wobble", @@ -54,7 +56,7 @@ def test_benchmark_matching(): print(study) # this study needs analyzer - study.create_sorting_analyzer_gt(**job_kwargs) + # study.create_sorting_analyzer_gt(**job_kwargs) study.compute_metrics() # run and result diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py index 924b9ef385..06a3fa9140 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py @@ -17,8 +17,8 @@ from spikeinterface.sortingcomponents.benchmark.benchmark_motion_interpolation import MotionInterpolationStudy from spikeinterface.sortingcomponents.benchmark.benchmark_motion_estimation import ( - get_unit_disclacement, - get_gt_motion_from_unit_discplacement, + # get_unit_displacement, + get_gt_motion_from_unit_displacement, ) @@ -36,15 +36,16 @@ def test_benchmark_motion_interpolation(): duration = data["drifting_rec"].get_duration() channel_locations = data["drifting_rec"].get_channel_locations() - unit_displacements = get_unit_disclacement( - data["displacement_vectors"], data["displacement_unit_factor"], direction_dim=1 - ) + # unit_displacements = get_unit_displacement( + # data["displacement_vectors"], data["displacement_unit_factor"], direction_dim=1 + # ) + unit_displacements = data["unit_displacements"] bin_s = 1 temporal_bins = np.arange(0, duration, bin_s) spatial_bins = np.linspace(np.min(channel_locations[:, 1]), np.max(channel_locations[:, 1]), 10) - print(spatial_bins) - 
gt_motion = get_gt_motion_from_unit_discplacement( + # print(spatial_bins) + gt_motion = get_gt_motion_from_unit_displacement( unit_displacements, data["displacement_sampling_frequency"], data["unit_locations"], diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_detection.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_detection.py index aebedd3b8c..fd09575193 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_detection.py @@ -19,7 +19,8 @@ def test_benchmark_peak_detection(): job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms") - recording, gt_sorting = make_dataset() + # recording, gt_sorting = make_dataset() + recording, gt_sorting, gt_analyzer = make_dataset() # create study study_folder = cache_folder / "study_peak_detection" diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py index 6244055c51..fb3ecc61aa 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_peak_localization.py @@ -15,7 +15,8 @@ def test_benchmark_peak_localization(): job_kwargs = dict(n_jobs=0.8, chunk_duration="100ms") - recording, gt_sorting = make_dataset() + # recording, gt_sorting = make_dataset() + recording, gt_sorting, gt_analyzer = make_dataset() # create study study_folder = cache_folder / "study_peak_localization" diff --git a/src/spikeinterface/sortingcomponents/clustering/circus.py b/src/spikeinterface/sortingcomponents/clustering/circus.py index 38e6169ee8..4135bd4b6e 100644 --- a/src/spikeinterface/sortingcomponents/clustering/circus.py +++ b/src/spikeinterface/sortingcomponents/clustering/circus.py @@ -230,7 +230,7 @@ def main_function(cls, recording, peaks, params): ) if params["noise_levels"] is None: params["noise_levels"] = get_noise_levels(recording, return_scaled=False) - sparsity = compute_sparsity(templates, params["noise_levels"], **params["sparsity"]) + sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"]) templates = templates.to_sparse(sparsity) templates = remove_empty_templates(templates) diff --git a/src/spikeinterface/sortingcomponents/clustering/main.py b/src/spikeinterface/sortingcomponents/clustering/main.py index 4cb0db7db6..7381875557 100644 --- a/src/spikeinterface/sortingcomponents/clustering/main.py +++ b/src/spikeinterface/sortingcomponents/clustering/main.py @@ -26,7 +26,7 @@ def find_cluster_from_peaks(recording, peaks, method="stupid", method_kwargs={}, Returns ------- - labels: ndarray of int + labels_set: ndarray of int possible clusters list peak_labels: array of int peak_labels.shape[0] == peaks.shape[0] @@ -41,12 +41,15 @@ def find_cluster_from_peaks(recording, peaks, method="stupid", method_kwargs={}, params = method_class._default_params.copy() params.update(**method_kwargs) - labels, peak_labels = method_class.main_function(recording, peaks, params) + outputs = method_class.main_function(recording, peaks, params) if extra_outputs: - raise NotImplementedError - - return labels, peak_labels + return outputs + else: + if len(outputs) > 2: + outputs = outputs[:2] + labels_set, peak_labels = outputs + return labels_set, peak_labels find_cluster_from_peaks.__doc__ = 
find_cluster_from_peaks.__doc__.format(_shared_job_kwargs_doc) diff --git a/src/spikeinterface/sortingcomponents/clustering/method_list.py b/src/spikeinterface/sortingcomponents/clustering/method_list.py index f037ce7511..d763c516f9 100644 --- a/src/spikeinterface/sortingcomponents/clustering/method_list.py +++ b/src/spikeinterface/sortingcomponents/clustering/method_list.py @@ -9,6 +9,7 @@ from .position_and_features import PositionAndFeaturesClustering from .random_projections import RandomProjectionClustering from .circus import CircusClustering +from .tdc import TdcClustering clustering_methods = { "dummy": DummyClustering, @@ -20,4 +21,5 @@ "position_and_features": PositionAndFeaturesClustering, "random_projections": RandomProjectionClustering, "circus": CircusClustering, + "tdc_clustering": TdcClustering, } diff --git a/src/spikeinterface/sortingcomponents/clustering/tdc.py b/src/spikeinterface/sortingcomponents/clustering/tdc.py new file mode 100644 index 0000000000..c9a7da4536 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/clustering/tdc.py @@ -0,0 +1,228 @@ +from pathlib import Path + +import numpy as np +import json +import pickle +import random +import string +import shutil + +from spikeinterface.core import ( + get_noise_levels, + NumpySorting, + get_channel_distances, + estimate_templates_with_accumulator, + Templates, + compute_sparsity, + get_global_tmp_folder, +) + +from spikeinterface.sortingcomponents.matching import find_spikes_from_templates +from spikeinterface.core.node_pipeline import ( + run_node_pipeline, + ExtractDenseWaveforms, + ExtractSparseWaveforms, + PeakRetriever, +) + +from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel, cache_preprocessing +from spikeinterface.sortingcomponents.peak_detection import detect_peaks, DetectPeakLocallyExclusive +from spikeinterface.sortingcomponents.peak_selection import select_peaks +from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass, LocalizeGridConvolution +from spikeinterface.sortingcomponents.waveforms.temporal_pca import TemporalPCAProjection + +from spikeinterface.sortingcomponents.clustering.split import split_clusters +from spikeinterface.sortingcomponents.clustering.merge import merge_clusters +from spikeinterface.sortingcomponents.clustering.tools import compute_template_from_sparse + +from sklearn.decomposition import TruncatedSVD + +import hdbscan + + +class TdcClustering: + """ + Here the implementation of clustering used by tridesclous2 + """ + + _default_params = { + "folder": None, + "waveforms": { + "ms_before": 0.5, + "ms_after": 1.5, + "radius_um": 120.0, + }, + "svd": {"n_components": 6}, + "clustering": { + "split_radius_um": 40.0, + "merge_radius_um": 40.0, + "threshold_diff": 1.5, + }, + "job_kwargs": {}, + } + + @classmethod + def main_function(cls, recording, peaks, params): + job_kwargs = params["job_kwargs"] + + if params["folder"] is None: + randname = "".join(random.choices(string.ascii_uppercase + string.digits, k=6)) + clustering_folder = get_global_tmp_folder() / f"tdcclustering_{randname}" + clustering_folder.mkdir(parents=True, exist_ok=True) + need_folder_rm = True + else: + clustering_folder = Path(params["folder"]) + need_folder_rm = False + + sampling_frequency = recording.sampling_frequency + + ms_before = params["waveforms"]["ms_before"] + ms_after = params["waveforms"]["ms_after"] + + # SVD for time compression + few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000) + few_wfs = 
extract_waveform_at_max_channel( + recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs + ) + + wfs = few_wfs[:, :, 0] + tsvd = TruncatedSVD(params["svd"]["n_components"]) + tsvd.fit(wfs) + + model_folder = clustering_folder / "tsvd_model" + + model_folder.mkdir(exist_ok=True) + with open(model_folder / "pca_model.pkl", "wb") as f: + pickle.dump(tsvd, f) + + model_params = { + "ms_before": ms_before, + "ms_after": ms_after, + "sampling_frequency": float(sampling_frequency), + } + with open(model_folder / "params.json", "w") as f: + json.dump(model_params, f) + + # features + + features_folder = clustering_folder / "features" + node0 = PeakRetriever(recording, peaks) + + radius_um = params["waveforms"]["radius_um"] + node1 = ExtractSparseWaveforms( + recording, + parents=[node0], + return_output=True, + ms_before=ms_before, + ms_after=ms_after, + radius_um=radius_um, + ) + + model_folder_path = clustering_folder / "tsvd_model" + + node2 = TemporalPCAProjection( + recording, parents=[node0, node1], return_output=True, model_folder_path=model_folder_path + ) + + pipeline_nodes = [node0, node1, node2] + + run_node_pipeline( + recording, + pipeline_nodes, + job_kwargs, + gather_mode="npy", + gather_kwargs=dict(exist_ok=True), + folder=features_folder, + names=["sparse_wfs", "sparse_tsvd"], + ) + # TODO make this generic in GatherNPY ??? + sparse_mask = node1.neighbours_mask + np.save(features_folder / "sparse_mask.npy", sparse_mask) + np.save(features_folder / "peaks.npy", peaks) + + # to be able to delete feature folder + del pipeline_nodes, node0, node1, node2 + + # Clustering: channel index > split > merge + split_radius_um = params["clustering"]["split_radius_um"] + neighbours_mask = get_channel_distances(recording) < split_radius_um + + original_labels = peaks["channel_index"] + + min_cluster_size = 50 + # min_cluster_size = 10 + + post_split_label, split_count = split_clusters( + original_labels, + recording, + features_folder, + method="local_feature_clustering", + method_kwargs=dict( + clusterer="hdbscan", + # clusterer="isocut5", + feature_name="sparse_tsvd", + # feature_name="sparse_wfs", + neighbours_mask=neighbours_mask, + waveforms_sparse_mask=sparse_mask, + min_size_split=min_cluster_size, + clusterer_kwargs={"min_cluster_size": min_cluster_size}, + n_pca_features=3, + scale_n_pca_by_depth=True, + ), + recursive=True, + recursive_depth=3, + returns_split_count=True, + **job_kwargs, + ) + + merge_radius_um = params["clustering"]["merge_radius_um"] + threshold_diff = params["clustering"]["threshold_diff"] + + post_merge_label, peak_shifts = merge_clusters( + peaks, + post_split_label, + recording, + features_folder, + radius_um=merge_radius_um, + # method="project_distribution", + # method_kwargs=dict( + # waveforms_sparse_mask=sparse_mask, + # feature_name="sparse_wfs", + # projection="centroid", + # criteria="distrib_overlap", + # threshold_overlap=0.3, + # min_cluster_size=min_cluster_size + 1, + # num_shift=5, + # ), + method="normalized_template_diff", + method_kwargs=dict( + waveforms_sparse_mask=sparse_mask, + threshold_diff=threshold_diff, + min_cluster_size=min_cluster_size + 1, + num_shift=5, + ), + **job_kwargs, + ) + + # sparse_wfs = np.load(features_folder / "sparse_wfs.npy", mmap_mode="r") + + # new_peaks = peaks.copy() + # new_peaks["sample_index"] -= peak_shifts + + # clean very small cluster before peeler + post_clean_label = post_merge_label.copy() + + minimum_cluster_size = 25 + labels_set, count = np.unique(post_clean_label, 
return_counts=True) + to_remove = labels_set[count < minimum_cluster_size] + mask = np.isin(post_clean_label, to_remove) + post_clean_label[mask] = -1 + + labels_set = np.unique(post_clean_label) + labels_set = labels_set[labels_set >= 0] + + if need_folder_rm: + shutil.rmtree(clustering_folder) + + extra_out = {"peak_shifts": peak_shifts} + return labels_set, post_clean_label, extra_out diff --git a/src/spikeinterface/sortingcomponents/matching/wobble.py b/src/spikeinterface/sortingcomponents/matching/wobble.py index 8196df4dec..a8ce32dc43 100644 --- a/src/spikeinterface/sortingcomponents/matching/wobble.py +++ b/src/spikeinterface/sortingcomponents/matching/wobble.py @@ -926,7 +926,7 @@ def convolve_templates(compressed_templates, jitter_factor, approx_rank, jittere spatial_filters = spatial_overlapped[:approx_rank, visible_overlapped_channels].T spatially_filtered_template = np.matmul(visible_template, spatial_filters) scaled_filtered_template = spatially_filtered_template * singular_overlapped - for i in range(approx_rank): + for i in range(min(approx_rank, scaled_filtered_template.shape[1])): pconv[j, :] += np.convolve(scaled_filtered_template[:, i], temporal_overlapped[:, i], "full") pairwise_convolution.append(pconv) return pairwise_convolution diff --git a/src/spikeinterface/sortingcomponents/tests/test_clustering.py b/src/spikeinterface/sortingcomponents/tests/test_clustering.py index 3092becc94..427d120c5d 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_clustering.py +++ b/src/spikeinterface/sortingcomponents/tests/test_clustering.py @@ -77,6 +77,7 @@ def test_find_cluster_from_peaks(clustering_method, recording, peaks, peak_locat peaks = run_peaks(recording, job_kwargs) peak_locations = run_peak_locations(recording, peaks, job_kwargs) # method = "position_and_pca" - method = "circus" + # method = "circus" + method = "tdc_clustering" test_find_cluster_from_peaks(method, recording, peaks, peak_locations) diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 66e5e87119..06dfd994f3 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -114,6 +114,10 @@ def cache_preprocessing(recording, mode="memory", memory_limit=0.5, delete_cache recording = recording.save_to_folder(**extra_kwargs) elif mode == "zarr": recording = recording.save_to_zarr(**extra_kwargs) + elif mode == "no-cache": + recording = recording + else: + raise ValueError(f"cache_preprocessing() wrong mode={mode}") return recording diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index 1cfbb81367..3512f31e84 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -190,13 +190,12 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax = self.axes.flatten()[count] for key in dp.case_keys: label = study.cases[key]["label"] - val = perfs.xs(key).loc[:, performance_name].values val = np.sort(val)[::-1] ax.plot(val, label=label) ax.set_title(performance_name) - if count == 0: - ax.legend(loc="upper right") + if count == len(dp.performance_names) - 1: + ax.legend(bbox_to_anchor=(0.05, 0.05), loc="lower left", framealpha=0.8) elif dp.mode == "snr": metric_name = dp.mode
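# The sketches below illustrate the main API changes in this patch; any value or
# helper not present in the diff is an assumption, not a prescription.

# Building a ground-truth motion array from the per-unit displacements returned by
# generate_drifting_recording(..., extra_outputs=True), using the renamed
# get_gt_motion_from_unit_displacement() (bin sizes are illustrative and mirror
# test_benchmark_motion_interpolation.py):
import numpy as np

from spikeinterface.generation import generate_drifting_recording
from spikeinterface.sortingcomponents.benchmark.benchmark_motion_estimation import (
    get_gt_motion_from_unit_displacement,
)

_, drifting_rec, _, extra_infos = generate_drifting_recording(
    num_units=10, duration=60.0, extra_outputs=True, seed=2205
)
duration = drifting_rec.get_duration()
channel_locations = drifting_rec.get_channel_locations()
temporal_bins = np.arange(0, duration, 1.0)  # 1 s bins (assumption)
spatial_bins = np.linspace(channel_locations[:, 1].min(), channel_locations[:, 1].max(), 10)

gt_motion = get_gt_motion_from_unit_displacement(
    extra_infos["unit_displacements"],              # (num_times, num_units, 2), summed over drift vectors
    extra_infos["displacement_sampling_frequency"],
    extra_infos["unit_locations"],
    temporal_bins,
    spatial_bins,
    direction_dim=1,                                # drift along y
)
# gt_motion has one row per temporal bin and one column per spatial bin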
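# BenchmarkStudy.create() now accepts a SortingAnalyzer per dataset (the old
# (recording, gt_sorting) tuple is still supported and converted internally);
# each analyzer is stored under <study_folder>/sorting_analyzer/<key> and listed
# in analyzers_path.json, which replaces create_sorting_analyzer_gt(). The
# extensions computed below follow make_dataset() in common_benchmark_testing.py:
from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer

recording, gt_sorting = generate_ground_truth_recording(durations=[60.0], num_units=10, seed=2205)

gt_analyzer = create_sorting_analyzer(gt_sorting, recording, sparse=True, format="memory")
gt_analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=500)
gt_analyzer.compute("templates")
gt_analyzer.compute("noise_levels")

datasets = {"toy": gt_analyzer}               # new style: one analyzer per key
# datasets = {"toy": (recording, gt_sorting)} # old style, still accepted
# A concrete study subclass (e.g. ClusteringStudy) is then built with
# StudyClass.create(study_folder, datasets=datasets, cases=cases, levels=...).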
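# The clustering stage factored out of tridesclous2 into
# sortingcomponents/clustering/tdc.py is registered as "tdc_clustering", and
# find_cluster_from_peaks(..., extra_outputs=True) now forwards the method's
# extra outputs (here a dict with "peak_shifts"). The call mirrors
# tridesclous2.py; the toy recording and detection parameters are illustrative:
from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.clustering.main import find_cluster_from_peaks

recording, _ = generate_ground_truth_recording(durations=[60.0], num_units=10, seed=2205)
job_kwargs = dict(n_jobs=-1, chunk_duration="1s", progress_bar=False)

peaks = detect_peaks(recording, method="locally_exclusive", detect_threshold=5.0, **job_kwargs)

labels_set, peak_labels, extra_out = find_cluster_from_peaks(
    recording,
    peaks,
    method="tdc_clustering",
    method_kwargs=dict(folder=None),  # folder=None -> temporary folder, removed at the end
    extra_outputs=True,
    **job_kwargs,
)
aligned_sample_index = peaks["sample_index"] - extra_out["peak_shifts"]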
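# sorting_to_peaks() now defaults its dtype to spike_peak_dtype, so a peaks-like
# vector can be built from a sorting without passing a dtype explicitly. The
# extremum channels come from a ground-truth analyzer, as in the benchmark tests:
from spikeinterface.core import (
    generate_ground_truth_recording,
    create_sorting_analyzer,
    get_template_extremum_channel,
)
from spikeinterface.core.node_pipeline import sorting_to_peaks

recording, gt_sorting = generate_ground_truth_recording(durations=[30.0], num_units=5, seed=2205)
analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False)
analyzer.compute(["random_spikes", "templates"])

extremum_channel_inds = get_template_extremum_channel(analyzer, outputs="index")
peaks = sorting_to_peaks(gt_sorting, extremum_channel_inds)  # dtype defaults to spike_peak_dtype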
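# cache_preprocessing() gains a "no-cache" mode that returns the lazy recording
# unchanged, and unknown modes now raise a ValueError (toy recording for illustration):
from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.sortingcomponents.tools import cache_preprocessing

recording, _ = generate_ground_truth_recording(durations=[10.0], num_units=5, seed=2205)

rec_lazy = cache_preprocessing(recording, mode="no-cache")                    # keep the lazy chain
rec_cached = cache_preprocessing(recording, mode="memory", memory_limit=0.5)  # previous default behavior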