Dredge lfp and dredge ap #3062

Merged Jul 15, 2024 · 37 commits
400b0b3 initial copy/paste from dredge — samuelgarcia, Jun 19, 2024
ddf094b Move motion files into new motion subfolder — samuelgarcia, Jun 21, 2024
a6424f2 Reorganize dredge function with already existing — samuelgarcia, Jun 21, 2024
cabbbeb dredge test file — samuelgarcia, Jun 21, 2024
00bef34 fix — samuelgarcia, Jun 21, 2024
62c75f8 Put spatial_bin_centers in get_windows() — samuelgarcia, Jun 21, 2024
13c81ac small fix — samuelgarcia, Jun 21, 2024
1ac451c motion_estimation() : refactoring party. — samuelgarcia, Jun 26, 2024
532ea48 start porting dredge_ap — samuelgarcia, Jun 26, 2024
c2f5289 important comments — samuelgarcia, Jun 26, 2024
cab6646 wip dredge_ap — samuelgarcia, Jun 27, 2024
37014bb more refactoring and parameters change — samuelgarcia, Jun 27, 2024
e2e9bff import fix — samuelgarcia, Jun 27, 2024
20a0e90 Merge branch 'main' of github.com:SpikeInterface/spikeinterface into … — samuelgarcia, Jun 28, 2024
ea01729 fixing dredge_ap details — samuelgarcia, Jun 28, 2024
75decf6 still rafactor and fix estimate_motion() for dredge_ap — samuelgarcia, Jul 3, 2024
ae18211 move doc for dredge classes — samuelgarcia, Jul 3, 2024
21072ff Merge with main — samuelgarcia, Jul 4, 2024
64f3177 wip — samuelgarcia, Jul 5, 2024
41e6eda dredge_lfp doc — samuelgarcia, Jul 8, 2024
5c250e1 dredge_lfp doc — samuelgarcia, Jul 8, 2024
840e9c1 dredge_lfp doc — samuelgarcia, Jul 8, 2024
8ef01d9 Add tutorial motion. — samuelgarcia, Jul 8, 2024
35ab6b0 Merge branch 'main' into dredge_lfp — samuelgarcia, Jul 11, 2024
18b7cfe [pre-commit.ci] auto fixes from pre-commit.com hooks — pre-commit-ci[bot], Jul 11, 2024
de7df4b Fix tests — alejoe91, Jul 12, 2024
261b021 Cleanup how to — alejoe91, Jul 12, 2024
41d3903 Handle negative/0 windows, contact_depth->contact_depths, add verbosi… — alejoe91, Jul 12, 2024
27a7f12 typo — alejoe91, Jul 12, 2024
96789ba Merge branch 'main' into dredge_lfp — alejoe91, Jul 12, 2024
997cdba Use numpy.broadcast_to for conv_engine numpy in normxcorr1d — alejoe91, Jul 12, 2024
b3c954b Merge branch 'dredge_lfp' of github.com:samuelgarcia/spikeinterface i… — alejoe91, Jul 12, 2024
e8f4e76 formatting — alejoe91, Jul 12, 2024
7568260 Fix tdc2 and propagate to API — alejoe91, Jul 12, 2024
e318088 raise error if num_windows<1 and fix motion test — alejoe91, Jul 12, 2024
4c4e70f Fix final(?) torch import — alejoe91, Jul 12, 2024
abb6a7c Fix generation tests — alejoe91, Jul 12, 2024
163 changes: 163 additions & 0 deletions doc/how_to/drift_with_lfp.rst
@@ -0,0 +1,163 @@
Estimate drift using the LFP traces
===================================

Drift is a well-known issue for long-shank probes. Some datasets, especially from primates and humans,
can experience very fast motion due to breathing and heartbeats. In these cases, the standard motion
estimation methods that use detected spikes as a basis for motion inference will fail, because there
are not enough spikes to "follow" such fast drifts.

Charlie Windolf and colleagues from the Paninski Lab at Columbia have developed a method to estimate
the motion using the LFP signal: **DREDge** (more details about the method in the paper
`DREDge: robust motion correction for high-density extracellular recordings across species <https://doi.org/10.1101/2023.10.24.563768>`_).

This method is particularly suited for the open dataset recorded at Massachusetts General Hospital by Angelique Paulk and colleagues in humans (more details in the `paper <https://doi.org/10.1038/s41593-021-00997-0>`_). The dataset can be downloaded from `datadryad <https://datadryad.org/stash/dataset/doi:10.5061/dryad.d2547d840>`_ and contains recordings on human patients with a Neuropixels probe, some with very high and fast motion on the probe, which prevents accurate spike sorting without adequate motion correction.

The **DREDge** method has two options: **dredge_lfp** and **dredge_ap**, which have both been ported inside `SpikeInterface`.

Here we will demonstrate the **dredge_lfp** method to estimate the fast and high drift on this recording.

For each patient, the dataset contains two streams:

* a highpass "action potential" (AP) stream, sampled at 30kHz
* a lowpass "local field" (LF) stream, sampled at 2.5kHz

For this demonstration, we will use the LF stream.

.. code:: ipython3

%matplotlib inline
%load_ext autoreload
%autoreload 2

.. code:: ipython3

from pathlib import Path
import matplotlib.pyplot as plt

import spikeinterface.full as si
from spikeinterface.sortingcomponents.motion import estimate_motion

.. code:: ipython3

# the dataset has been locally downloaded
base_folder = Path("/mnt/data/sam/DataSpikeSorting/")
np_data_drift = base_folder / 'human_neuropixel/Pt02/'

Read the spikeglx file
~~~~~~~~~~~~~~~~~~~~~~

.. code:: ipython3

raw_rec = si.read_spikeglx(np_data_drift)
print(raw_rec)


.. parsed-literal::

SpikeGLXRecordingExtractor: 384 channels - 2.5kHz - 1 segments - 2,183,292 samples
873.32s (14.56 minutes) - int16 dtype - 1.56 GiB


Preprocessing
~~~~~~~~~~~~~

Contrary to the **dredge_ap** approach, which needs detected peaks and peak locations, the **dredge_lfp**
method estimates the motion directly from the traces.
Importantly, the method requires some additional pre-processing steps:

* ``bandpass_filter``: to "focus" the signal on a particular band
* ``phase_shift``: to compensate for the sampling misalignment
* ``resample``: to further reduce the sampling frequency of the signal and speed up the computation. The sampling frequency of the estimated motion will be the same as the resampling frequency. Here we choose 250Hz, which corresponds to a sampling interval of 4ms.
* ``directional_derivative``: this optional step applies a second-order derivative in the spatial dimension to enhance edges on the traces. This is not a general rule and needs to be tested case by case.
* ``average_across_direction``: Neuropixels 1.0 probes have two contacts per depth. This step averages them to obtain a unique virtual signal along the probe depth ("y" in ``spikeinterface``).

After applying this preprocessing chain, the motion can be estimated almost by eye on the traces plotted in map mode.
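
The edge-enhancing effect of the ``directional_derivative`` step can be sketched with plain numpy. This is only an illustration of the idea (a second-order spatial derivative on a step-like depth profile), not SpikeInterface's implementation:

```python
import numpy as np

# synthetic depth profile with a sharp boundary at 50 um
depth_um = np.arange(0, 100, 1.0)
profile = np.where(depth_um < 50, 0.0, 1.0)

# second-order spatial derivative: flat regions go to zero,
# while the boundary produces a localized response
first = np.gradient(profile, depth_um, edge_order=1)
second = np.gradient(first, depth_um, edge_order=1)

print(np.abs(second).max())
```

Tracking such enhanced edges over time is what makes the motion visible directly in the traces.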

.. code:: ipython3

lfprec = si.bandpass_filter(
raw_rec,
freq_min=0.5,
freq_max=250,
margin_ms=1500.,
filter_order=3,
dtype="float32",
add_reflect_padding=True,
)
lfprec = si.phase_shift(lfprec)
lfprec = si.resample(lfprec, resample_rate=250, margin_ms=1000)

lfprec = si.directional_derivative(lfprec, order=2, edge_order=1)
lfprec = si.average_across_direction(lfprec)

print(lfprec)


.. parsed-literal::

AverageAcrossDirectionRecording: 192 channels - 0.2kHz - 1 segments - 218,329 samples
873.32s (14.56 minutes) - float32 dtype - 159.91 MiB


.. code:: ipython3

%matplotlib inline
si.plot_traces(lfprec, backend="matplotlib", mode="map", clim=(-0.05, 0.05), time_range=(400, 420))



.. image:: drift_with_lfp_files/drift_with_lfp_8_1.png


Run the method
~~~~~~~~~~~~~~

``estimate_motion()`` is the generic function to estimate motion with multiple
methods in ``spikeinterface``.

This function returns a ``Motion`` object, and we can notice that its sampling interval is exactly
the same as that of the downsampled signal.

Here we use ``rigid=True``, which means that we have one unique signal to
describe the motion across the entire probe depth.
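
What ``rigid=True`` means can be sketched with numpy: a single displacement trace is shared across the whole depth, i.e. there is only one spatial bin. The ``(num_temporal_bins, num_spatial_bins)`` layout below is an illustrative assumption, not pulled from the ``Motion`` source:

```python
import numpy as np

# one displacement trace (um), sampled at the 4 ms motion interval
num_bins = 2500  # 10 s of motion at 4 ms per bin
temporal_bins_s = np.arange(num_bins) * 0.004
displacement_um = 20.0 * np.sin(2 * np.pi * 1.2 * temporal_bins_s)

# rigid motion: a single spatial bin covering the entire probe depth
rigid_displacement = displacement_um[:, np.newaxis]
print(rigid_displacement.shape)
```

With ``rigid=False``, the second dimension would instead hold one displacement trace per spatial window along the probe.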

.. code:: ipython3

motion = estimate_motion(lfprec, method='dredge_lfp', rigid=True, progress_bar=True)
motion


.. parsed-literal::

Online chunks [10.0s each]: 0%| | 0/87 [00:00<?, ?it/s]


.. parsed-literal::

Motion rigid - interval 0.004s - 1 segments


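The 0.004 s interval reported above is simply the inverse of the resampling frequency chosen in the preprocessing:

```python
# the motion sampling interval is set by the resampling frequency
resample_rate = 250.0  # Hz, as chosen with si.resample above
interval_s = 1.0 / resample_rate
print(interval_s)
```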

Plot the drift
~~~~~~~~~~~~~~

When plotting the drift, we can notice a very fast drift component which corresponds to the heartbeat.
The slower oscillations can be attributed to breathing.

We can appreciate how the estimated motion signal matches the processed LFP traces plotted above.

.. code:: ipython3

fig, ax = plt.subplots()
si.plot_motion(motion, mode='line', ax=ax)
ax.set_xlim(400, 420)
ax.set_ylim(800, 1300)


.. parsed-literal::

(800.0, 1300.0)


.. image:: drift_with_lfp_files/drift_with_lfp_12_1.png
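
To quantify the heartbeat- and breathing-related components, one could inspect the spectrum of the estimated displacement. Here is a sketch on a synthetic trace; the two frequencies and amplitudes are made up for illustration, not measured from this recording:

```python
import numpy as np

fs = 250.0  # sampling rate of the estimated motion (Hz)
t = np.arange(15000) / fs  # 60 s of motion
# synthetic displacement: slow "breathing" (~0.3 Hz) plus a faster,
# smaller "heartbeat" component (~1.2 Hz)
disp_um = 50.0 * np.sin(2 * np.pi * 0.3 * t) + 10.0 * np.sin(2 * np.pi * 1.2 * t)

freqs = np.fft.rfftfreq(t.size, d=1 / fs)
power = np.abs(np.fft.rfft(disp_um)) ** 2

# the dominant spectral peak is the slow breathing component
print(freqs[np.argmax(power)])
```

On real data one would run the same analysis on ``motion.displacement`` instead of a synthetic sine mixture.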
1 change: 1 addition & 0 deletions doc/how_to/index.rst
@@ -14,3 +14,4 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to.
process_by_channel_group
load_your_data_into_sorting
benchmark_with_hybrid_recordings
drift_with_lfp
2 changes: 1 addition & 1 deletion doc/modules/sortingcomponents.rst
@@ -193,7 +193,7 @@ Here is an example with non-rigid motion estimation:
from spikeinterface.sortingcomponents.motion_estimation import estimate_motion
motion, temporal_bins, spatial_bins,
extra_check = estimate_motion(recording=recording, peaks=peaks, peak_locations=peak_locations,
direction='y', bin_duration_s=10., bin_um=10., margin_um=0.,
direction='y', bin_s=10., bin_um=10., margin_um=0.,
method='decentralized_registration',
rigid=False, win_shape='gaussian', win_step_um=50., win_sigma_um=150.,
progress_bar=True, verbose=True)
112 changes: 112 additions & 0 deletions examples/how_to/drift_with_lfp.py
@@ -0,0 +1,112 @@
# ---
# jupyter:
# jupytext:
# cell_metadata_filter: -all
# formats: ipynb,py
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.16.2
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---

# # Estimate drift using the LFP traces
#
# Drift is a well-known issue for long-shank probes. Some datasets, especially from primates and humans, can experience very fast motion due to breathing and heartbeats. In these cases, the standard motion estimation methods that use detected spikes as a basis for motion inference will fail, because there are not enough spikes to "follow" such fast drifts.
#
# Charlie Windolf and colleagues from the Paninski Lab at Columbia have developed a method to estimate the motion using the LFP signal: **DREDge** (more details about the method in the paper [DREDge: robust motion correction for high-density extracellular recordings across species](https://doi.org/10.1101/2023.10.24.563768)).
#
# This method is particularly suited for the open dataset recorded at Massachusetts General Hospital by Angelique Paulk and colleagues in humans (more details in the [paper](https://doi.org/10.1038/s41593-021-00997-0)). The dataset can be downloaded from [datadryad](https://datadryad.org/stash/dataset/doi:10.5061/dryad.d2547d840) and contains recordings on human patients with a Neuropixels probe, some with very high and fast motion on the probe, which prevents accurate spike sorting without adequate motion correction.
#
# The **DREDge** method has two options: **dredge_lfp** and **dredge_ap**, which have both been ported inside `SpikeInterface`.
#
# Here we will demonstrate the **dredge_lfp** method to estimate the fast and high drift on this recording.
#
# For each patient, the dataset contains two streams:
#
# * a highpass "action potential" (AP) stream, sampled at 30kHz
# * a lowpass "local field" (LF) stream, sampled at 2.5kHz
#
# For this demonstration, we will use the LF stream.

# %matplotlib inline
# %load_ext autoreload
# %autoreload 2

# +
from pathlib import Path
import matplotlib.pyplot as plt

import spikeinterface.full as si
from spikeinterface.sortingcomponents.motion import estimate_motion
# -

# the dataset has been downloaded locally
base_folder = Path("/mnt/data/sam/DataSpikeSorting/")
np_data_drift = base_folder / "human_neuropixel" / "Pt02"

# ### Read the spikeglx file

raw_rec = si.read_spikeglx(np_data_drift)
print(raw_rec)

# ### Preprocessing
#
# Contrary to the **dredge_ap** approach, which needs detected peaks and peak locations, the **dredge_lfp** method estimates the motion directly from the traces.
# Importantly, the method requires some additional pre-processing steps:
#
# * `bandpass_filter`: to "focus" the signal on a particular band
# * `phase_shift`: to compensate for the sampling misalignment
# * `resample`: to further reduce the sampling frequency of the signal and speed up the computation. The sampling frequency of the estimated motion will be the same as the resampling frequency. Here we choose 250Hz, which corresponds to a sampling interval of 4ms.
# * `directional_derivative`: this optional step applies a second-order derivative in the spatial dimension to enhance edges on the traces. This is not a general rule and needs to be tested case by case.
# * `average_across_direction`: Neuropixels 1.0 probes have two contacts per depth. This step averages them to obtain a unique virtual signal along the probe depth ("y" in `spikeinterface`).
#
# After applying this preprocessing chain, the motion can be estimated almost by eye on the traces plotted in map mode.

# +
lfprec = si.bandpass_filter(
raw_rec,
freq_min=0.5,
freq_max=250,
margin_ms=1500.,
filter_order=3,
dtype="float32",
add_reflect_padding=True,
)
lfprec = si.phase_shift(lfprec)
lfprec = si.resample(lfprec, resample_rate=250, margin_ms=1000)

lfprec = si.directional_derivative(lfprec, order=2, edge_order=1)
lfprec = si.average_across_direction(lfprec)

print(lfprec)
# -

# %matplotlib inline
si.plot_traces(lfprec, backend="matplotlib", mode="map", clim=(-0.05, 0.05), time_range=(400, 420))

# ### Run the method
#
# `estimate_motion()` is the generic function to estimate motion with multiple methods in `spikeinterface`.
#
# This function returns a `Motion` object, and we can notice that its sampling interval is exactly the same as that of the downsampled signal.
#
# Here we use `rigid=True`, which means that we have one unique signal to describe the motion across the entire probe depth.

motion = estimate_motion(lfprec, method='dredge_lfp', rigid=True, progress_bar=True)
motion

# ### Plot the drift
#
# When plotting the drift, we can notice a very fast drift component which corresponds to the heartbeat. The slower oscillations can be attributed to breathing.
#
# We can appreciate how the estimated motion signal matches the processed LFP traces plotted above.

fig, ax = plt.subplots()
si.plot_motion(motion, mode='line', ax=ax)
ax.set_xlim(400, 420)
ax.set_ylim(800, 1300)
103 changes: 103 additions & 0 deletions examples/tutorials/sortingcomponents/plot_1_estimate_motion.py
@@ -0,0 +1,103 @@
"""
Motion estimation
=================

SpikeInterface offers a very flexible framework to handle drift as a
preprocessing step. If you want to know more, please read the
:ref:`motion_correction` section of the documentation.

Here is a short example with a simulated drifting recording.

"""

# %%
import matplotlib.pyplot as plt


from spikeinterface.generation import generate_drifting_recording
from spikeinterface.preprocessing import correct_motion
from spikeinterface.widgets import plot_motion, plot_motion_info, plot_probe_map

# %%
# First, let's simulate a drifting recording using the
# :code:`spikeinterface.generation` module.
#
# Here the simulated recording has a small zigzag motion along the 'y' axis of the probe.

static_recording, drifting_recording, sorting = generate_drifting_recording(
num_units=200,
duration=300.,
probe_name='Neuropixel-128',
generate_displacement_vector_kwargs=dict(
displacement_sampling_frequency=5.0,
drift_start_um=[0, 20],
drift_stop_um=[0, -20],
drift_step_um=1,
motion_list=[
dict(
drift_mode="zigzag",
non_rigid_gradient=None,
t_start_drift=60.0,
t_end_drift=None,
period_s=200,
),
],
),
seed=2205,
)
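
The zigzag drift configured above is essentially a triangle wave along 'y'. A minimal numpy sketch of such a displacement vector, using the same parameter values (the variable names here are illustrative, not the generator's internals):

```python
import numpy as np

# parameters mirroring the simulation above
fs = 5.0             # displacement sampling frequency (Hz)
duration_s = 300.0
period_s = 200.0
t_start_drift = 60.0
amplitude_um = 20.0  # from +20 um down to -20 um

t = np.arange(0, duration_s, 1 / fs)
phase = np.clip(t - t_start_drift, 0.0, None) / period_s
# triangle wave: flat at +amplitude before t_start_drift, then zigzagging
# down to -amplitude half a period later and back
zigzag_um = amplitude_um * (2.0 * np.abs(2.0 * (phase % 1.0) - 1.0) - 1.0)

print(zigzag_um[0])
```

This per-time-bin displacement is what the generator applies to the unit positions to create the drifting recording.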

plot_probe_map(drifting_recording)

# %%
# Here we will use the high-level function :code:`correct_motion()`.
#
# Internally, this function performs all the steps of motion correction:
#
# 1. **activity profile**: detect peaks and localize them along time and depth
# 2. **motion inference**: estimate the drift motion
# 3. **motion interpolation**: interpolate traces using the estimated motion
#
# Each step can use one of several methods, each with many parameters. This is why we can use a
# 'preset', which combines methods and related parameters.
#
# This function can take a while: peak detection and localization are slow processes
# that need to go through the entire traces.

recording_corrected, motion, motion_info = correct_motion(
drifting_recording, preset="nonrigid_fast_and_accurate",
output_motion=True, output_motion_info=True,
n_jobs=-1, progress_bar=True,
)

# %%
# The function returns a 'corrected' recording.
#
# The returned recording will interpolate motion-corrected traces
# when calling get_traces().

print(recording_corrected)

# %%
# Optionally the function also returns the `Motion` object itself
#

print(motion)

# %%
# This motion can be plotted; in our case the motion has been estimated as non-rigid,
# so we can use `mode='map'` to check the motion across depth.
#

plot_motion(motion, mode='line')
plot_motion(motion, mode='map')


# %%
# The dict `motion_info` can be used for more plotting.
# Here we can appreciate, on the two top axes, the raster of peak depths vs. times
# before and after correction.

fig = plt.figure()
plot_motion_info(motion_info, drifting_recording, amplitude_cmap="inferno", color_amplitude=True, figure=fig)
fig.axes[0].set_ylim(520, 620)
plt.show()
# %%