
Commit

Update how to and fix displacement_vectors
alejoe91 committed Jun 28, 2024
1 parent 5186f74 commit 8b16006
Showing 3 changed files with 108 additions and 79 deletions.
179 changes: 103 additions & 76 deletions examples/how_to/benchmark_with_hybrid_recordings.py
@@ -2,19 +2,18 @@
# jupyter:
# jupytext:
# cell_metadata_filter: -all
# formats: ipynb,py
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.16.2
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---

# # Benchmark spike sorting with hybrid recordings
#
# This example shows how to use the SpikeInterface hybrid recordings framework to benchmark spike sorting results.
@@ -38,84 +37,67 @@
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# -

# %matplotlib inline

si.set_global_job_kwargs(n_jobs=16)

# For this notebook, we will use a drifting recording similar to the one acquired by Nick Steinmetz and available [here](https://doi.org/10.6084/m9.figshare.14024495.v1), where a triangular motion was imposed on the recording by moving the probe up and down with a micro-manipulator.

# +
workdir = Path("/ssd980/working/hybrid/steinmetz_imposed_motion")
workdir.mkdir(exist_ok=True)

recording_np1_imposed = se.read_spikeglx("/hdd1/data/spikeglx/nick-steinmetz/dataset1/p1_g0_t0/")
recording_preproc = spre.highpass_filter(recording_np1_imposed)
recording_preproc = spre.common_reference(recording_preproc)
# -

# To visualize the drift, we can estimate the motion and plot it:

# to correct for drift, we need a float dtype
recording_preproc = spre.astype(recording_preproc, "float")
_, motion_info = spre.correct_motion(
    recording_preproc, preset="nonrigid_fast_and_accurate", n_jobs=4, progress_bar=True, output_motion_info=True
)


ax = sw.plot_drift_raster_map(
peaks=motion_info["peaks"],
peak_locations=motion_info["peak_locations"],
    recording=recording_preproc,
cmap="Greys_r",
scatter_decimate=10,
depth_lim=(-10, 3000)
)
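
# The returned `motion_info` bundles everything the motion estimation produced: the detected peaks,
# their estimated locations, and the estimated `Motion` object itself. A quick look at its contents
# (a hedged sketch; the attribute names below match recent SpikeInterface `Motion` objects, but are
# worth double-checking against your version):

# +
motion = motion_info["motion"]
peaks = motion_info["peaks"]
peak_locations = motion_info["peak_locations"]

print(f"Number of detected peaks: {len(peaks)}")
# the Motion object stores one displacement array per segment, sampled on temporal and spatial bins
print(f"Displacement shape (segment 0): {motion.displacement[0].shape}")
print(f"Temporal bins: {motion.temporal_bins_s[0].shape}, spatial bins: {motion.spatial_bins_um.shape}")
# -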

# ## Retrieve templates from database

# +
templates_info = sgen.fetch_templates_database_info()

print(f"Number of templates in database: {len(templates_info)}")
print(f"Template database columns: {templates_info.columns}")
# -

available_brain_areas = np.unique(templates_info.brain_area)
print(f"Available brain areas: {available_brain_areas}")

# Let's perform a query: templates from visual brain regions and at the "top" of the probe

target_area = ["VISa5", "VISa6a", "VISp5", "VISp6a", "VISrl6b"]
minimum_depth = 1500
templates_selected_info = templates_info.query(f"brain_area in {target_area} and depth_along_probe > {minimum_depth}")
len(templates_selected_info)
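
# Since `templates_info` (and any subset of it) is a regular pandas DataFrame, any pandas-style
# selection works on top of the query above. For example, a hedged sketch that keeps only the 20
# deepest templates among the selected ones (the number 20 is arbitrary, purely for illustration):

templates_deepest_info = templates_selected_info.sort_values("depth_along_probe", ascending=False).head(20)
print(len(templates_deepest_info))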

# We can now retrieve the selected templates as a `Templates` object:

templates_selected = sgen.query_templates_from_database(templates_selected_info, verbose=True)
print(templates_selected)

# While we selected templates from a target area and at certain depths, we can see that the template amplitudes are quite large. This will make spike sorting easy... we can further manipulate the `Templates` by rescaling, relocating, or making further selections with the `sgen.scale_template_to_range`, `sgen.relocate_templates`, and `sgen.select_templates` functions.
#
# In our case, let's rescale the amplitudes between 50 and 150 $\mu$V and relocate them towards the bottom half of the probe, where the activity looks interesting!

# +
min_amplitude = 50
max_amplitude = 150
templates_scaled = sgen.scale_template_to_range(
    templates=templates_selected,
    min_amplitude=min_amplitude,
    max_amplitude=max_amplitude
)

min_displacement = 1000
max_displacement = 4000
templates_relocated = sgen.relocate_templates(
templates=templates_scaled,
min_displacement=min_displacement,
max_displacement=max_displacement
)
# -

# Let's plot the selected templates:

@@ -140,73 +123,99 @@
w.figure.subplots_adjust(wspace=0.5, hspace=0.7)

# ## Constructing hybrid recordings
#
# We can now construct hybrid recordings with the selected templates.
#
# We will do this in two ways to show how important it is to account for drift when injecting hybrid spikes.
#
# - For the first recording, we will not pass the estimated motion (`recording_hybrid_ignore_drift`).
# - For the second recording, we will pass and account for the estimated motion (`recording_hybrid_with_drift`).

recording_hybrid_ignore_drift, sorting_hybrid = sgen.generate_hybrid_recording(
    recording=recording_preproc, templates=templates_relocated, seed=2308
)
recording_hybrid_ignore_drift

# Note that the `generate_hybrid_recording` is warning us that we might want to account for drift!

# by passing the `sorting_hybrid` object, we make sure that the injected spikes are the same
# this will take a bit more time because it's interpolating the templates to account for drift
recording_hybrid_with_drift, sorting_hybrid = sgen.generate_hybrid_recording(
    recording=recording_preproc,
    templates=templates_relocated,
    motion=motion_info["motion"],
    sorting=sorting_hybrid,
    seed=2308,
)
recording_hybrid_with_drift

# We can use the `SortingAnalyzer` to estimate spike locations and plot them:

# +
# construct analyzers and compute spike locations
analyzer_hybrid_ignore_drift = si.create_sorting_analyzer(sorting_hybrid, recording_hybrid_ignore_drift)
analyzer_hybrid_ignore_drift.compute(["random_spikes", "templates"])
analyzer_hybrid_ignore_drift.compute("spike_locations", method="grid_convolution")

analyzer_hybrid_with_drift = si.create_sorting_analyzer(sorting_hybrid, recording_hybrid_with_drift)
analyzer_hybrid_with_drift.compute(["random_spikes", "templates"])
analyzer_hybrid_with_drift.compute("spike_locations", method="grid_convolution")
# -

# Let's plot the added hybrid spikes using the drift maps:

fig, axs = plt.subplots(ncols=2, figsize=(10, 7), sharex=True, sharey=True)
_ = sw.plot_drift_raster_map(
    peaks=motion_info["peaks"],
    peak_locations=motion_info["peak_locations"],
    recording=recording_preproc,
    cmap="Greys_r",
    scatter_decimate=10,
    ax=axs[0],
)
_ = sw.plot_drift_raster_map(
    sorting_analyzer=analyzer_hybrid_ignore_drift,
    color_amplitude=False,
    color="r",
    scatter_decimate=10,
    ax=axs[0],
)
_ = sw.plot_drift_raster_map(
    peaks=motion_info["peaks"],
    peak_locations=motion_info["peak_locations"],
    recording=recording_preproc,
    cmap="Greys_r",
    scatter_decimate=10,
    ax=axs[1],
)
_ = sw.plot_drift_raster_map(
    sorting_analyzer=analyzer_hybrid_with_drift,
    color_amplitude=False,
    color="b",
    scatter_decimate=10,
    ax=axs[1],
)
axs[0].set_title("Hybrid spikes\nIgnoring drift")
axs[1].set_title("Hybrid spikes\nAccounting for drift")
axs[0].set_xlim(1000, 1500)
axs[0].set_ylim(500, 2500)

# We can clearly see that following the drift is essential to properly blend the hybrid spikes into the recording!
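#
# As a rough quantitative check (a hedged sketch, not part of the original how-to), we can compare
# the spread of the estimated spike depths per hybrid unit: spikes injected with the motion should
# drift along depth over time, while spikes injected ignoring the motion should stay put.
# `to_spike_vector()` and the `spike_locations` extension are standard SpikeInterface API, but the
# exact field names ("unit_index", "y") are worth double-checking.

# +
locs_ignore = analyzer_hybrid_ignore_drift.get_extension("spike_locations").get_data()
locs_with = analyzer_hybrid_with_drift.get_extension("spike_locations").get_data()
spikes = sorting_hybrid.to_spike_vector()

for unit_index, unit_id in enumerate(sorting_hybrid.unit_ids[:5]):
    mask = spikes["unit_index"] == unit_index
    spread_ignore = np.std(locs_ignore["y"][mask])
    spread_with = np.std(locs_with["y"][mask])
    print(f"Unit {unit_id}: depth spread {spread_ignore:.1f} um (ignoring drift) vs {spread_with:.1f} um (with drift)")
# -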

# ## Ground-truth study
#
# In this section we will use the hybrid recording to benchmark a few spike sorters:
#
# - `Kilosort2.5`
# - `Kilosort3`
# - `Kilosort4`
# - `Spyking-CIRCUS 2`

# +
# import shutil
# shutil.rmtree(study_folder)
# -

workdir = Path("/ssd980/working/hybrid/drift")
workdir.mkdir(exist_ok=True)

# to speed up computations, let's first dump the recording to binary
recording_hybrid_bin = recording_hybrid.save(folder=workdir / "hybrid_bin", overwrite=True)
recording_hybrid_bin = recording_hybrid_with_drift.save(
folder=workdir / "hybrid_bin",
overwrite=True
)

# +
datasets = {
@@ -249,10 +258,28 @@
else:
gtstudy = sc.GroundTruthStudy.create(study_folder=study_folder, datasets=datasets, cases=cases)
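
# the `datasets` and `cases` definitions are collapsed in this diff; a minimal sketch of what they
# typically look like for a `GroundTruthStudy` (sorter names and parameters below are placeholders,
# not necessarily the values used in this how-to)

# +
datasets_sketch = {
    "hybrid": (recording_hybrid_bin, sorting_hybrid),
}
cases_sketch = {
    "kilosort4": {
        "label": "KS4",
        "dataset": "hybrid",
        "run_sorter_params": {"sorter_name": "kilosort4"},
    },
    "spykingcircus2": {
        "label": "SC2",
        "dataset": "hybrid",
        "run_sorter_params": {"sorter_name": "spykingcircus2"},
    },
}
# -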

# run the spike sorting jobs
gtstudy.run_sorters(verbose=True, keep=False)

# run the comparisons
gtstudy.run_comparisons(exhaustive_gt=False)

w_run_times = sw.plot_study_run_times(gtstudy)
# ## Plot performances

w_perf = sw.plot_study_performances(gtstudy, figsize=(12, 7))
w_perf.axes[0, 0].legend(loc=4)
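
# To look at the numbers behind the figure, the study object can also return the performance
# metrics as a dataframe (a hedged sketch; the method name may differ slightly across
# SpikeInterface versions):

perf_by_unit = gtstudy.get_performance_by_unit()
print(perf_by_unit.head())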

# And the winner of the hybrid study challenge is...`SpyKING-CIRCUS 2` 🎉🎉🎉
#
# In this example, we showed how to:
#
# - Access and fetch templates from the SpikeInterface template database
# - Manipulate templates (scaling/relocating)
# - Construct hybrid recordings accounting for drifts
# - Use the `GroundTruthStudy` to benchmark different sorters
#
# The hybrid framework can be extended to target multiple recordings from different brain regions and species, and to create recordings of increasing complexity to challenge the existing sorters!
#
# In addition, hybrid studies can also be used to fine-tune spike sorting parameters on specific datasets.
#
# **Are you ready to try it on your data?**
6 changes: 4 additions & 2 deletions src/spikeinterface/generation/hybrid_tools.py
@@ -531,11 +531,13 @@ def generate_hybrid_recording(
displacement_vector = np.zeros((len(temporal_bins_segment), 2, num_units))
for unit_index in range(num_units):
    motion_for_unit = motion.get_displacement_at_time_and_depth(
        times_s=temporal_bins_segment,
        locations_um=unit_locations[unit_index],
        segment_index=segment_index,
        grid=True,
    )
    displacement_vector[:, motion.dim, unit_index] = motion_for_unit[motion.dim, :]
displacement_vectors.append(displacement_vector)
# since displacement is estimated by interpolation for each unit, the unit factor is an eye
displacement_unit_factor = np.eye(num_units)
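
To illustrate the comment above: `displacement_unit_factor` mixes the per-unit displacement vectors, so an identity matrix means each unit simply follows its own interpolated displacement. A small standalone numpy sketch of that mixing (shapes illustrative, not the library's internals):

import numpy as np

num_times, num_dims, num_units = 5, 2, 3
displacement_vectors = np.random.randn(num_times, num_dims, num_units)
displacement_unit_factor = np.eye(num_units)

# displacement applied to unit u: weighted sum of the per-unit vectors
mixed = np.einsum("tdk,ku->tdu", displacement_vectors, displacement_unit_factor)
assert np.allclose(mixed, displacement_vectors)  # identity factor: each unit keeps its own vector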

2 changes: 1 addition & 1 deletion src/spikeinterface/sortingcomponents/motion_utils.py
@@ -96,7 +96,7 @@ def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_index
else a 2d array with the 2 or 3 spatial dimensions indexed along axis=1.
segment_index: int, default: None
The index of the segment to evaluate. If None, and there is only one segment, then that segment is used.
grid : bool, default: False
If grid=False, the default, then times_s and locations_um should have the same one-dimensional
shape, and the returned displacement[i] is the displacement at time times_s[i] and location
locations_um[i].
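
A short usage sketch of the `grid` flag described above (hedged; `motion` stands for any estimated `Motion` object, for example `motion_info["motion"]` from the how-to earlier in this commit):

import numpy as np

times_s = np.linspace(0, 100, 11)                 # 11 evaluation times (s)
locations_um = np.array([500.0, 1500.0, 2500.0])  # 3 depths along the probe (um)

# grid=False would require times_s and locations_um to have matching shapes;
# grid=True evaluates the displacement for every (location, time) combination instead
displacement = motion.get_displacement_at_time_and_depth(
    times_s=times_s,
    locations_um=locations_um,
    grid=True,
)
print(displacement.shape)  # expected: one value per (location, time) pair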
