Skip to content

Commit

Permalink
Update doc handle drift + better preset (#3232)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia authored Jul 19, 2024
1 parent 9e84a62 commit f4505e5
Show file tree
Hide file tree
Showing 35 changed files with 1,610 additions and 270 deletions.
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ spikeinterface.preprocessing
.. autofunction:: common_reference
.. autofunction:: correct_lsb
.. autofunction:: correct_motion
.. autofunction:: get_motion_presets
.. autofunction:: get_motion_parameters_preset
.. autofunction:: depth_order
.. autofunction:: detect_bad_channels
.. autofunction:: directional_derivative
Expand Down
1,474 changes: 1,353 additions & 121 deletions doc/how_to/handle_drift.rst

Large diffs are not rendered by default.

Binary file removed doc/how_to/handle_drift_files/handle_drift_13_0.png
Binary file not shown.
Binary file removed doc/how_to/handle_drift_files/handle_drift_13_1.png
Binary file not shown.
Binary file removed doc/how_to/handle_drift_files/handle_drift_13_2.png
Binary file not shown.
Binary file removed doc/how_to/handle_drift_files/handle_drift_15_0.png
Binary file not shown.
Binary file removed doc/how_to/handle_drift_files/handle_drift_15_1.png
Binary file not shown.
Binary file removed doc/how_to/handle_drift_files/handle_drift_15_2.png
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified doc/how_to/handle_drift_files/handle_drift_17_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
85 changes: 52 additions & 33 deletions examples/how_to/handle_drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
# jupyter:
# jupytext:
# cell_metadata_filter: -all
# formats: py,ipynb
# formats: py:light,ipynb
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.14.6
# jupytext_version: 1.16.2
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
Expand Down Expand Up @@ -55,6 +55,8 @@

import spikeinterface.full as si

from spikeinterface.preprocessing import get_motion_parameters_preset, get_motion_presets

# -

base_folder = Path("/mnt/data/sam/DataSpikeSorting/imposed_motion_nick")
Expand All @@ -70,6 +72,7 @@


def preprocess_chain(rec):
rec = rec.astype('float32')
rec = si.bandpass_filter(rec, freq_min=300.0, freq_max=6000.0)
rec = si.common_reference(rec, reference="global", operator="median")
return rec
Expand All @@ -79,33 +82,46 @@ def preprocess_chain(rec):

job_kwargs = dict(n_jobs=40, chunk_duration="1s", progress_bar=True)

# ### Run motion correction with one function!
#
#
# Correcting for drift is easy! You just need to run a single function.
# We will try this function with 3 presets.
# We will try this function with some presets.
#
# Internally a preset is a dictionary of dictionaries containing all parameters for every steps.
#
# Here we also save the motion correction results into a folder to be able to load them later.

# internally, we can explore a preset like this
# every parameter can be overwritten at runtime
from spikeinterface.preprocessing.motion import motion_options_preset
# ### preset and parameters
#
# Motion correction has some steps and eevry step can be controlled by a method and related parameters.
#
# A preset is a nested dict that contains theses methods/parameters.

preset_keys = get_motion_presets()
preset_keys

one_preset_params = get_motion_parameters_preset("kilosort_like")
one_preset_params

# ### Run motion correction with one function!
#
# Correcting for drift is easy! You just need to run a single function.
# We will try this function with some presets.
#
# Here we also save the motion correction results into a folder to be able to load them later.

motion_options_preset["kilosort_like"]
# lets try theses presets
some_presets = ("rigid_fast", "kilosort_like", "nonrigid_accurate", "nonrigid_fast_and_accurate", "dredge", "dredge_fast")

# lets try theses 3 presets
some_presets = ("rigid_fast", "kilosort_like", "nonrigid_accurate")
# some_presets = ('kilosort_like', )

# compute motion with 3 presets
# compute motion with theses presets
for preset in some_presets:
print("Computing with", preset)
folder = base_folder / "motion_folder_dataset1" / preset
if folder.exists():
shutil.rmtree(folder)
recording_corrected, motion_info = si.correct_motion(
rec, preset=preset, folder=folder, output_motion_info=True, **job_kwargs
recording_corrected, motion, motion_info = si.correct_motion(
rec, preset=preset, folder=folder, output_motion=True, output_motion_info=True, **job_kwargs
)

# ### Plot the results
Expand All @@ -127,11 +143,20 @@ def preprocess_chain(rec):
# The motion vector is computed for different depths.
# The corrected peak locations are flatter than the rigid case.
# The motion vector map is still be a bit noisy at some depths (e.g around 1000um).
# * The preset **nonrigid_accurate** seems to give the best results on this recording.
# The motion vector seems less noisy globally, but it is not "perfect" (see at the top of the probe 3200um to 3800um).
# Also note that in the first part of the recording before the imposed motion (0-600s) we clearly have a non-rigid motion:
# * The preset **dredge** is offcial DREDge re-implementation in spikeinterface.
# It give the best result : very fast and smooth motion estimation. Very few noise.
# This method also capture very well the non rigid motion gradient along the probe.
# The best method on the market at the moement.
# An enormous thanks to the dream team : Charlie Windolf, Julien Boussard, Erdem Varol, Liam Paninski.
# Note that in the first part of the recording before the imposed motion (0-600s) we clearly have a non-rigid motion:
# the upper part of the probe (2000-3000um) experience some drifts, but the lower part (0-1000um) is relatively stable.
# The method defined by this preset is able to capture this.
# * The preset **nonrigid_accurate** this is the ancestor of "dredge" before it was published.
# It seems to give the good results on this recording but with bit more noise.
# * The preset **dredge_fast** similar than dredge but faster (using grid_convolution).
# * The preset **nonrigid_fast_and_accurate** a variant of nonrigid_accurate but faster (using grid_convolution).
#
#

for preset in some_presets:
# load
Expand All @@ -140,8 +165,8 @@ def preprocess_chain(rec):

# and plot
fig = plt.figure(figsize=(14, 8))
si.plot_motion(
motion_info,
si.plot_motion_info(
motion_info, rec,
figure=fig,
depth_lim=(400, 600),
color_amplitude=True,
Expand Down Expand Up @@ -173,6 +198,8 @@ def preprocess_chain(rec):
folder = base_folder / "motion_folder_dataset1" / preset
motion_info = si.load_motion_info(folder)

motion = motion_info["motion"]

fig, axs = plt.subplots(ncols=2, figsize=(12, 8), sharey=True)

ax = axs[0]
Expand All @@ -190,24 +217,16 @@ def preprocess_chain(rec):

color_kargs = dict(alpha=0.2, s=2, c=c)

loc = motion_info["peak_locations"]
peak_locations = motion_info["peak_locations"]
# color='black',
ax.scatter(loc["x"][mask][sl], loc["y"][mask][sl], **color_kargs)

loc2 = correct_motion_on_peaks(
motion_info["peaks"],
motion_info["peak_locations"],
rec.sampling_frequency,
motion_info["motion"],
motion_info["temporal_bins"],
motion_info["spatial_bins"],
direction="y",
)
ax.scatter(peak_locations["x"][mask][sl], peak_locations["y"][mask][sl], **color_kargs)

peak_locations2 = correct_motion_on_peaks(peaks, peak_locations, motion,rec)

ax = axs[1]
si.plot_probe_map(rec, ax=ax)
# color='black',
ax.scatter(loc2["x"][mask][sl], loc2["y"][mask][sl], **color_kargs)
ax.scatter(peak_locations2["x"][mask][sl], peak_locations2["y"][mask][sl], **color_kargs)

ax.set_ylim(400, 600)
fig.suptitle(f"{preset=}")
Expand All @@ -228,7 +247,7 @@ def preprocess_chain(rec):
keys = run_times[0].keys()

bottom = np.zeros(len(run_times))
fig, ax = plt.subplots()
fig, ax = plt.subplots(figsize=(14, 6))
for k in keys:
rtimes = np.array([rt[k] for rt in run_times])
if np.any(rtimes > 0.0):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ test_extractors = [

test_preprocessing = [
"ibllib>=2.36.0", # for IBL
"torch",
]


Expand Down
11 changes: 8 additions & 3 deletions src/spikeinterface/generation/drift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,14 @@ def make_linear_displacement(start, stop, num_step=10):
displacements : np.array
The displacements with shape (num_step, 2)
"""
displacements = (stop[np.newaxis, :] - start[np.newaxis, :]) / (num_step - 1) * np.arange(num_step)[
:, np.newaxis
] + start[np.newaxis, :]
if num_step < 1:
raise ValueError("make_linear_displacement needs num_step > 0")
if num_step == 1:
displacements = ((start + stop) / 2)[np.newaxis, :]
else:
displacements = (stop[np.newaxis, :] - start[np.newaxis, :]) / (num_step - 1) * np.arange(num_step)[
:, np.newaxis
] + start[np.newaxis, :]
return displacements


Expand Down
1 change: 1 addition & 0 deletions src/spikeinterface/generation/hybrid_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ def generate_hybrid_recording(
elif dim == 2:
raise NotImplementedError("3D motion not implemented yet")
num_step = int((stop - start)[dim] / drift_step_um)
num_step = max(1, num_step)
displacements = make_linear_displacement(start, stop, num_step=num_step)

# use templates_, because templates_array might have been scaled
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/generation/tests/test_hybrid_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_estimate_templates(create_cache_folder):


if __name__ == "__main__":
test_generate_hybrid_no_motion()
# test_generate_hybrid_no_motion()
test_generate_hybrid_motion()
test_estimate_templates()
test_generate_hybrid_with_sorting()
# test_estimate_templates()
# test_generate_hybrid_with_sorting()
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .preprocessinglist import *

from .motion import correct_motion, load_motion_info, save_motion_info
from .motion import correct_motion, load_motion_info, save_motion_info, get_motion_parameters_preset, get_motion_presets

from .preprocessing_tools import get_spatial_interpolation_kernel
from .detect_bad_channels import detect_bad_channels
Expand Down
Loading

0 comments on commit f4505e5

Please sign in to comment.