From 8b16006a1706909f7e92657f9510f984db0ea8cd Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 28 Jun 2024 15:30:40 +0200 Subject: [PATCH] Update how to and fix displacement_vectors --- .../benchmark_with_hybrid_recordings.py | 179 ++++++++++-------- src/spikeinterface/generation/hybrid_tools.py | 6 +- .../sortingcomponents/motion_utils.py | 2 +- 3 files changed, 108 insertions(+), 79 deletions(-) diff --git a/examples/how_to/benchmark_with_hybrid_recordings.py b/examples/how_to/benchmark_with_hybrid_recordings.py index 3c4efe629c..a7eb2c467a 100644 --- a/examples/how_to/benchmark_with_hybrid_recordings.py +++ b/examples/how_to/benchmark_with_hybrid_recordings.py @@ -2,11 +2,11 @@ # jupyter: # jupytext: # cell_metadata_filter: -all -# formats: ipynb,py:percent +# formats: ipynb,py # text_representation: # extension: .py -# format_name: percent -# format_version: '1.3' +# format_name: light +# format_version: '1.5' # jupytext_version: 1.16.2 # kernelspec: # display_name: Python 3 (ipykernel) @@ -14,7 +14,6 @@ # name: python3 # --- -# %% # # Benchmark spike sorting with hybrid recordings # # This example shows how to use the SpikeInterface hybrid recordings framework to benchmark spike sorting results. @@ -38,84 +37,67 @@ import numpy as np import matplotlib.pyplot as plt from pathlib import Path - # - # %matplotlib inline si.set_global_job_kwargs(n_jobs=16) -# To make this notebook self-contained, we will simulate a drifting recording similar to the one acquired by Nick Steinmetz and available [here](https://doi.org/10.6084/m9.figshare.14024495.v1), where an triangular motion was imposed to the recording by moving the probe up and down with a micro-manipulator. +# For this notebook, we will use a drifting recording similar to the one acquired by Nick Steinmetz and available [here](https://doi.org/10.6084/m9.figshare.14024495.v1), where an triangular motion was imposed to the recording by moving the probe up and down with a micro-manipulator. -# + -generate_displacement_vector_kwargs = { - "motion_list": [ - { - "drift_mode": "zigzag", - "non_rigid_gradient": None, - "t_start_drift": 30.0, - "t_end_drift": 210, - "period_s": 60, - } - ], - "drift_start_um": [0, 60], - "drift_stop_um": [0, -60], -} - -# this generates a "static" and "drifting" recording version -static_recording, drifting_recording, sorting = sgen.generate_drifting_recording( - probe_name="Neuropixel-384", # ,'Neuropixel-384', - seed=23, - duration=240, - num_units=100, - generate_displacement_vector_kwargs=generate_displacement_vector_kwargs, -) +workdir = Path("/ssd980/working/hybrid/steinmetz_imposed_motion") +workdir.mkdir(exist_ok=True) -# we sort the channels by depth, to match the hybrid templates -drifting_recording = spre.depth_order(drifting_recording) -# - +recording_np1_imposed = se.read_spikeglx("/hdd1/data/spikeglx/nick-steinmetz/dataset1/p1_g0_t0/") +recording_preproc = spre.highpass_filter(recording_np1_imposed) +recording_preproc = spre.common_reference(recording_preproc) # To visualize the drift, we can estimate the motion and plot it: +# to correct for drift, we need a float dtype +recording_preproc = spre.astype(recording_preproc, "float") _, motion_info = spre.correct_motion( - drifting_recording, preset="nonrigid_fast_and_accurate", n_jobs=4, progress_bar=True, output_motion_info=True + recording_preproc, preset="nonrigid_fast_and_accurate", n_jobs=4, progress_bar=True, output_motion_info=True ) - -ax = sw.plot_drift_map( +ax = sw.plot_drift_raster_map( peaks=motion_info["peaks"], peak_locations=motion_info["peak_locations"], - recording=drifting_recording, + recording=recording_preproc, cmap="Greys_r", + scatter_decimate=10, + depth_lim=(-10, 3000) ) # ## Retrieve templates from database +# + templates_info = sgen.fetch_templates_database_info() -print(len(templates_info)) - -templates_info.head() +print(f"Number of templates in database: {len(templates_info)}") +print(f"Template database columns: {templates_info.columns}") +# - available_brain_areas = np.unique(templates_info.brain_area) print(f"Available brain areas: {available_brain_areas}") -# let's perform a query: templates from brain region VISp5 and at the "top" of the probe +# Let's perform a query: templates from visual brain regions and at the "top" of the probe + target_area = ["VISa5", "VISa6a", "VISp5", "VISp6a", "VISrl6b"] minimum_depth = 1500 templates_selected_info = templates_info.query(f"brain_area in {target_area} and depth_along_probe > {minimum_depth}") len(templates_selected_info) -# We can now retrieve the selected templates as a `Templates` object -# +# We can now retrieve the selected templates as a `Templates` object: templates_selected = sgen.query_templates_from_database(templates_selected_info, verbose=True) print(templates_selected) # While we selected templates from a target aread and at certain depths, we can see that the template amplitudes are quite large. This will make spike sorting easy... we can further manipulate the `Templates` by rescaling, relocating, or further selections with the `sgen.scale_template_to_range`, `sgen.relocate_templates`, and `sgen.select_templates` functions. # -# In our case, let's rescale the amplitudes between 50 and 150 $\mu$V and relocate them throughout the entire depth of the probe. +# In our case, let's rescale the amplitudes between 50 and 150 $\mu$V and relocate them towards the bottom half of the probe, where the activity looks interesting! +# + min_amplitude = 50 max_amplitude = 150 templates_scaled = sgen.scale_template_to_range( @@ -124,13 +106,14 @@ max_amplitude=max_amplitude ) -min_displacement = 200 +min_displacement = 1000 max_displacement = 4000 templates_relocated = sgen.relocate_templates( templates=templates_scaled, min_displacement=min_displacement, max_displacement=max_displacement ) +# - # Let's plot the selected templates: @@ -140,73 +123,99 @@ w.figure.subplots_adjust(wspace=0.5, hspace=0.7) # ## Constructing hybrid recordings +# +# We can construct now hybrid recordings with the selected templates. +# +# We will do this in two ways to show how important it is to account for drifts when injecting hybrid spikes. +# +# - For the first recording we will not pass the estimated motion (`recording_hybrid_ignore_drift`). +# - For the second recording, we will pass and account for the estimated motion (`recording_hybrid_with_drift`). -recording_hybrid_no_drift, sorting_hybrid = sgen.generate_hybrid_recording( - recording=drifting_recording, templates=templates_relocated, seed=2308 +recording_hybrid_ignore_drift, sorting_hybrid = sgen.generate_hybrid_recording( + recording=recording_preproc, templates=templates_relocated, seed=2308 ) -recording_hybrid_no_drift +recording_hybrid_ignore_drift + +# Note that the `generate_hybrid_recording` is warning us that we might want to account for drift! -recording_hybrid, sorting_hybrid = sgen.generate_hybrid_recording( - recording=drifting_recording, +# by passing the `sorting_hybrid` object, we make sure that injected spikes are the same +# this will take a bit more time because it's interpolating the templates to account for drifts +recording_hybrid_with_drift, sorting_hybrid = sgen.generate_hybrid_recording( + recording=recording_preproc, templates=templates_relocated, motion=motion_info["motion"], sorting=sorting_hybrid, seed=2308, ) -recording_hybrid +recording_hybrid_with_drift +# We can use the `SortingAnalyzer` to estimate spike locations and plot them: + +# + # construct analyzers and compute spike locations -analyzer_hybrid = si.create_sorting_analyzer(sorting_hybrid, recording_hybrid) -analyzer_hybrid.compute(["random_spikes", "templates"]) -analyzer_hybrid.compute("spike_locations", method="grid_convolution") +analyzer_hybrid_ignore_drift = si.create_sorting_analyzer(sorting_hybrid, recording_hybrid_ignore_drift) +analyzer_hybrid_ignore_drift.compute(["random_spikes", "templates"]) +analyzer_hybrid_ignore_drift.compute("spike_locations", method="grid_convolution") -# construct hybrid analyzer for spike locations -analyzer_hybrid_no_drift = si.create_sorting_analyzer(sorting_hybrid, recording_hybrid_no_drift) -analyzer_hybrid_no_drift.compute(["random_spikes", "templates"]) -analyzer_hybrid_no_drift.compute("spike_locations", method="grid_convolution") +analyzer_hybrid_with_drift = si.create_sorting_analyzer(sorting_hybrid, recording_hybrid_with_drift) +analyzer_hybrid_with_drift.compute(["random_spikes", "templates"]) +analyzer_hybrid_with_drift.compute("spike_locations", method="grid_convolution") +# - # Let's plot the added hybrid spikes using the drift maps: -fig, axs = plt.subplots(ncols=2, figsize=(10, 7)) -_ = sw.plot_drift_map( +fig, axs = plt.subplots(ncols=2, figsize=(10, 7), sharex=True, sharey=True) +_ = sw.plot_drift_raster_map( peaks=motion_info["peaks"], peak_locations=motion_info["peak_locations"], - recording=drifting_recording, + recording=recording_preproc, cmap="Greys_r", + scatter_decimate=10, ax=axs[0], ) -_ = sw.plot_drift_map(analyzer=analyzer_hybrid_no_drift, color_amplitude=False, color="r", ax=axs[0]) -_ = sw.plot_drift_map( +_ = sw.plot_drift_raster_map( + sorting_analyzer=analyzer_hybrid_ignore_drift, + color_amplitude=False, + color="r", + scatter_decimate=10, + ax=axs[0] +) +_ = sw.plot_drift_raster_map( peaks=motion_info["peaks"], peak_locations=motion_info["peak_locations"], - recording=drifting_recording, + recording=recording_preproc, cmap="Greys_r", + scatter_decimate=10, ax=axs[1], ) -_ = sw.plot_drift_map(analyzer=analyzer_hybrid, color_amplitude=False, color="b", ax=axs[1]) -axs[0].set_title("Ignoring drift") -axs[1].set_title("Accounting for drift") +_ = sw.plot_drift_raster_map( + sorting_analyzer=analyzer_hybrid_with_drift, + color_amplitude=False, + color="b", + scatter_decimate=10, + ax=axs[1] +) +axs[0].set_title("Hybrid spikes\nIgnoring drift") +axs[1].set_title("Hybrid spikes\nAccounting for drift") +axs[0].set_xlim(1000, 1500) +axs[0].set_ylim(500, 2500) # We can see that clearly following drift is essential in order to properly blend the hybrid spikes into the recording! # ## Ground-truth study # # In this section we will use the hybrid recording to benchmark a few spike sorters: +# # - `Kilosort2.5` # - `Kilosort3` # - `Kilosort4` # - `Spyking-CIRCUS 2` -# + -# import shutil -# shutil.rmtree(study_folder) -# - - -workdir = Path("/ssd980/working/hybrid/drift") -workdir.mkdir(exist_ok=True) - # to speed up computations, let's first dump the recording to binary -recording_hybrid_bin = recording_hybrid.save(folder=workdir / "hybrid_bin", overwrite=True) +recording_hybrid_bin = recording_hybrid_with_drift.save( + folder=workdir / "hybrid_bin", + overwrite=True +) # + datasets = { @@ -249,10 +258,28 @@ else: gtstudy = sc.GroundTruthStudy.create(study_folder=study_folder, datasets=datasets, cases=cases) +# run the spike sorting jobs gtstudy.run_sorters(verbose=True, keep=False) +# run the comparisons gtstudy.run_comparisons(exhaustive_gt=False) -w_run_times = sw.plot_study_run_times(gtstudy) +# ## Plot performances + w_perf = sw.plot_study_performances(gtstudy, figsize=(12, 7)) w_perf.axes[0, 0].legend(loc=4) + +# And the winner of the hybrid study challenge is...`SpyKING-CIRCUS 2` 🎉🎉🎉 +# +# In this example, we showed how to: +# +# - Access and fetch templates from the SpikeInterface template database +# - Manipulate templates (scaling/relocating) +# - Construct hybrid recordings accounting for drifts +# - Use the `GroundTruthStudy` to benchmark different sorters +# +# The hybrid framework can be extended to target multiple recordings from different brain regions and species and creating recordings of increasing complexity to challenge the existing sorters! +# +# In addition, hybrid studies can also be used to fine-tune spike sorting parameters on specific datasets. +# +# **Are you ready to try it on your data?** diff --git a/src/spikeinterface/generation/hybrid_tools.py b/src/spikeinterface/generation/hybrid_tools.py index 6a0c702001..9dc0cf2310 100644 --- a/src/spikeinterface/generation/hybrid_tools.py +++ b/src/spikeinterface/generation/hybrid_tools.py @@ -531,11 +531,13 @@ def generate_hybrid_recording( displacement_vector = np.zeros((len(temporal_bins_segment), 2, num_units)) for unit_index in range(num_units): motion_for_unit = motion.get_displacement_at_time_and_depth( - times=temporal_bins_segment, + times_s=temporal_bins_segment, locations_um=unit_locations[unit_index], segment_index=segment_index, + grid=True, ) - displacement_vector[:, motion.dim, unit_index] = motion_for_unit + displacement_vector[:, motion.dim, unit_index] = motion_for_unit[motion.dim, :] + displacement_vectors.append(displacement_vector) # since displacement is estimated by interpolation for each unit, the unit factor is an eye displacement_unit_factor = np.eye(num_units) diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py index 39991f4e52..a8de3f6d13 100644 --- a/src/spikeinterface/sortingcomponents/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion_utils.py @@ -96,7 +96,7 @@ def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_inde else a 2d array with the 2 or 3 spatial dimensions indexed along axis=1. segment_index: int, default: None The index of the segment to evaluate. If None, and there is only one segment, then that segment is used. - grid : bool + grid : bool, default: False If grid=False, the default, then times_s and locations_um should have the same one-dimensional shape, and the returned displacement[i] is the displacement at time times_s[i] and location locations_um[i].