Skip to content

Commit

Permalink
Merge pull request #2915 from cwindolf/motion_object
Browse files Browse the repository at this point in the history
Internal motion API, aka `Motion` object
  • Loading branch information
alejoe91 authored Jun 19, 2024
2 parents f550802 + 1c6dcf4 commit a233697
Show file tree
Hide file tree
Showing 19 changed files with 965 additions and 445 deletions.
22 changes: 9 additions & 13 deletions doc/modules/motion_correction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -163,21 +163,19 @@ The high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` is inte
max_distance_um=150.0, **job_kwargs)
# Step 2: motion inference
motion, temporal_bins, spatial_bins = estimate_motion(recording=rec,
peaks=peaks,
peak_locations=peak_locations,
method="decentralized",
direction="y",
bin_duration_s=2.0,
bin_um=5.0,
win_step_um=50.0,
win_sigma_um=150.0)
motion = estimate_motion(recording=rec,
peaks=peaks,
peak_locations=peak_locations,
method="decentralized",
direction="y",
bin_duration_s=2.0,
bin_um=5.0,
win_step_um=50.0,
win_sigma_um=150.0)
# Step 3: motion interpolation
# this step is lazy
rec_corrected = interpolate_motion(recording=rec, motion=motion,
temporal_bins=temporal_bins,
spatial_bins=spatial_bins,
border_mode="remove_channels",
spatial_interpolation_method="kriging",
sigma_um=30.)
Expand Down Expand Up @@ -220,8 +218,6 @@ different preprocessing chains: one for motion correction and one for spike sort
rec_corrected2 = interpolate_motion(
recording=rec2,
motion=motion_info['motion'],
temporal_bins=motion_info['temporal_bins'],
spatial_bins=motion_info['spatial_bins'],
**motion_info['parameters']['interpolate_motion_kwargs'])
sorting = run_sorter(sorter_name="montainsort5", recording=rec_corrected2)
Expand Down
37 changes: 16 additions & 21 deletions src/spikeinterface/preprocessing/motion.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

import time
from pathlib import Path

import numpy as np
import json
from pathlib import Path
import time

from spikeinterface.core import get_noise_levels, fix_job_kwargs
from spikeinterface.core.job_tools import _shared_job_kwargs_doc
from spikeinterface.core.core_tools import SIJsonEncoder
from spikeinterface.core.job_tools import _shared_job_kwargs_doc

motion_options_preset = {
# This preset should be the most acccurate
Expand Down Expand Up @@ -68,7 +68,7 @@
weight_with_amplitude=False,
),
"interpolate_motion_kwargs": dict(
direction=1, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2
border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2
),
},
"nonrigid_fast_and_accurate": {
Expand Down Expand Up @@ -127,7 +127,7 @@
weight_with_amplitude=False,
),
"interpolate_motion_kwargs": dict(
direction=1, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2
border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2
),
},
# This preset is a super fast rigid estimation with center of mass
Expand All @@ -152,7 +152,7 @@
rigid=True,
),
"interpolate_motion_kwargs": dict(
direction=1, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2
border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2
),
},
# This preset try to mimic kilosort2.5 motion estimator
Expand Down Expand Up @@ -186,7 +186,7 @@
win_shape="rect",
),
"interpolate_motion_kwargs": dict(
direction=1, border_mode="force_extrapolate", spatial_interpolation_method="kriging", sigma_um=20.0, p=2
border_mode="force_extrapolate", spatial_interpolation_method="kriging", sigma_um=20.0, p=2
),
},
# empty preset
Expand Down Expand Up @@ -276,9 +276,8 @@ def correct_motion(
recording_corrected : Recording
The motion corrected recording
motion_info : dict
Optional output if `output_motion_info=True`
Optional output if `output_motion_info=True`. The key "motion" holds the Motion object.
"""

# local import are important because "sortingcomponents" is not important by default
from spikeinterface.sortingcomponents.peak_detection import detect_peaks, detect_peak_methods
from spikeinterface.sortingcomponents.peak_selection import select_peaks
Expand Down Expand Up @@ -377,30 +376,22 @@ def correct_motion(
np.save(folder / "peak_locations.npy", peak_locations)

t0 = time.perf_counter()
motion, temporal_bins, spatial_bins = estimate_motion(recording, peaks, peak_locations, **estimate_motion_kwargs)
motion = estimate_motion(recording, peaks, peak_locations, **estimate_motion_kwargs)
t1 = time.perf_counter()
run_times["estimate_motion"] = t1 - t0

recording_corrected = InterpolateMotionRecording(
recording, motion, temporal_bins, spatial_bins, **interpolate_motion_kwargs
)
recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs)

if folder is not None:
(folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8")

np.save(folder / "temporal_bins.npy", temporal_bins)
np.save(folder / "motion.npy", motion)
if spatial_bins is not None:
np.save(folder / "spatial_bins.npy", spatial_bins)
motion.save(folder / "motion")

if output_motion_info:
motion_info = dict(
parameters=parameters,
run_times=run_times,
peaks=peaks,
peak_locations=peak_locations,
temporal_bins=temporal_bins,
spatial_bins=spatial_bins,
motion=motion,
)
return recording_corrected, motion_info
Expand All @@ -419,6 +410,8 @@ def correct_motion(


def load_motion_info(folder):
from spikeinterface.sortingcomponents.motion_utils import Motion

folder = Path(folder)

motion_info = {}
Expand All @@ -429,11 +422,13 @@ def load_motion_info(folder):
with open(folder / "run_times.json") as f:
motion_info["run_times"] = json.load(f)

array_names = ("peaks", "peak_locations", "temporal_bins", "spatial_bins", "motion")
array_names = ("peaks", "peak_locations")
for name in array_names:
if (folder / f"{name}.npy").exists():
motion_info[name] = np.load(folder / f"{name}.npy")
else:
motion_info[name] = None

motion_info["motion"] = Motion.load(folder / "motion")

return motion_info
4 changes: 2 additions & 2 deletions src/spikeinterface/preprocessing/preprocessing_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def get_spatial_interpolation_kernel(

elif method == "idw":
distances = scipy.spatial.distance.cdist(source_location, target_location, metric="euclidean")
interpolation_kernel = np.zeros((source_location.shape[0], target_location.shape[0]), dtype="float64")
interpolation_kernel = np.zeros((source_location.shape[0], target_location.shape[0]), dtype=dtype)
for c in range(target_location.shape[0]):
ind_sorted = np.argsort(distances[:, c])
chan_closest = ind_sorted[:num_closest]
Expand All @@ -97,7 +97,7 @@ def get_spatial_interpolation_kernel(

elif method == "nearest":
distances = scipy.spatial.distance.cdist(source_location, target_location, metric="euclidean")
interpolation_kernel = np.zeros((source_location.shape[0], target_location.shape[0]), dtype="float64")
interpolation_kernel = np.zeros((source_location.shape[0], target_location.shape[0]), dtype=dtype)
for c in range(target_location.shape[0]):
ind_closest = np.argmin(distances[:, c])
interpolation_kernel[ind_closest, c] = 1.0
Expand Down
14 changes: 6 additions & 8 deletions src/spikeinterface/preprocessing/tests/test_motion.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import pytest
from pathlib import Path

import shutil
from pathlib import Path

import numpy as np
import pytest
from spikeinterface.core import generate_recording

from spikeinterface.preprocessing import correct_motion, load_motion_info

import numpy as np


def test_estimate_and_correct_motion(create_cache_folder):
cache_folder = create_cache_folder
Expand All @@ -18,6 +15,7 @@ def test_estimate_and_correct_motion(create_cache_folder):
folder = cache_folder / "estimate_and_correct_motion"
if folder.exists():
shutil.rmtree(folder)

rec_corrected = correct_motion(rec, folder=folder)
print(rec_corrected)

Expand All @@ -26,5 +24,5 @@ def test_estimate_and_correct_motion(create_cache_folder):


if __name__ == "__main__":
print(correct_motion.__doc__)
# test_estimate_and_correct_motion()
# print(correct_motion.__doc__)
test_estimate_and_correct_motion()
6 changes: 5 additions & 1 deletion src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
from spikeinterface.preprocessing.motion import load_motion_info

motion_info = load_motion_info(motion_folder)
merging_params["maximum_distance_um"] = max(50, 2 * np.abs(motion_info["motion"]).max())
motion = motion_info["motion"]
max_motion = max(
np.max(np.abs(motion.displacement[seg_index])) for seg_index in range(len(motion.displacement))
)
merging_params["maximum_distance_um"] = max(50, 2 * max_motion)

# peak_sign = params['detection'].get('peak_sign', 'neg')
# best_amplitudes = get_template_extremum_amplitude(templates, peak_sign=peak_sign)
Expand Down
Loading

0 comments on commit a233697

Please sign in to comment.