diff --git a/doc/modules/motion_correction.rst b/doc/modules/motion_correction.rst index 8be2456caa..af81cb42d1 100644 --- a/doc/modules/motion_correction.rst +++ b/doc/modules/motion_correction.rst @@ -163,21 +163,19 @@ The high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` is inte max_distance_um=150.0, **job_kwargs) # Step 2: motion inference - motion, temporal_bins, spatial_bins = estimate_motion(recording=rec, - peaks=peaks, - peak_locations=peak_locations, - method="decentralized", - direction="y", - bin_duration_s=2.0, - bin_um=5.0, - win_step_um=50.0, - win_sigma_um=150.0) + motion = estimate_motion(recording=rec, + peaks=peaks, + peak_locations=peak_locations, + method="decentralized", + direction="y", + bin_duration_s=2.0, + bin_um=5.0, + win_step_um=50.0, + win_sigma_um=150.0) # Step 3: motion interpolation # this step is lazy rec_corrected = interpolate_motion(recording=rec, motion=motion, - temporal_bins=temporal_bins, - spatial_bins=spatial_bins, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=30.) @@ -220,8 +218,6 @@ different preprocessing chains: one for motion correction and one for spike sort rec_corrected2 = interpolate_motion( recording=rec2, motion=motion_info['motion'], - temporal_bins=motion_info['temporal_bins'], - spatial_bins=motion_info['spatial_bins'], **motion_info['parameters']['interpolate_motion_kwargs']) sorting = run_sorter(sorter_name="montainsort5", recording=rec_corrected2) diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index cefe4d4d7a..8023bd4367 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -1,14 +1,14 @@ from __future__ import annotations -import time -from pathlib import Path - import numpy as np import json +from pathlib import Path +import time from spikeinterface.core import get_noise_levels, fix_job_kwargs from spikeinterface.core.job_tools import _shared_job_kwargs_doc from spikeinterface.core.core_tools import SIJsonEncoder +from spikeinterface.core.job_tools import _shared_job_kwargs_doc motion_options_preset = { # This preset should be the most acccurate @@ -68,7 +68,7 @@ weight_with_amplitude=False, ), "interpolate_motion_kwargs": dict( - direction=1, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 + border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 ), }, "nonrigid_fast_and_accurate": { @@ -127,7 +127,7 @@ weight_with_amplitude=False, ), "interpolate_motion_kwargs": dict( - direction=1, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 + border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 ), }, # This preset is a super fast rigid estimation with center of mass @@ -152,7 +152,7 @@ rigid=True, ), "interpolate_motion_kwargs": dict( - direction=1, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 + border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 ), }, # This preset try to mimic kilosort2.5 motion estimator @@ -186,7 +186,7 @@ win_shape="rect", ), "interpolate_motion_kwargs": dict( - direction=1, border_mode="force_extrapolate", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 + border_mode="force_extrapolate", spatial_interpolation_method="kriging", sigma_um=20.0, p=2 ), }, # empty preset @@ -276,9 +276,8 @@ def correct_motion( 
recording_corrected : Recording The motion corrected recording motion_info : dict - Optional output if `output_motion_info=True` + Optional output if `output_motion_info=True`. The key "motion" holds the Motion object. """ - # local import are important because "sortingcomponents" is not important by default from spikeinterface.sortingcomponents.peak_detection import detect_peaks, detect_peak_methods from spikeinterface.sortingcomponents.peak_selection import select_peaks @@ -377,21 +376,15 @@ def correct_motion( np.save(folder / "peak_locations.npy", peak_locations) t0 = time.perf_counter() - motion, temporal_bins, spatial_bins = estimate_motion(recording, peaks, peak_locations, **estimate_motion_kwargs) + motion = estimate_motion(recording, peaks, peak_locations, **estimate_motion_kwargs) t1 = time.perf_counter() run_times["estimate_motion"] = t1 - t0 - recording_corrected = InterpolateMotionRecording( - recording, motion, temporal_bins, spatial_bins, **interpolate_motion_kwargs - ) + recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs) if folder is not None: (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8") - - np.save(folder / "temporal_bins.npy", temporal_bins) - np.save(folder / "motion.npy", motion) - if spatial_bins is not None: - np.save(folder / "spatial_bins.npy", spatial_bins) + motion.save(folder / "motion") if output_motion_info: motion_info = dict( @@ -399,8 +392,6 @@ def correct_motion( run_times=run_times, peaks=peaks, peak_locations=peak_locations, - temporal_bins=temporal_bins, - spatial_bins=spatial_bins, motion=motion, ) return recording_corrected, motion_info @@ -419,6 +410,8 @@ def correct_motion( def load_motion_info(folder): + from spikeinterface.sortingcomponents.motion_utils import Motion + folder = Path(folder) motion_info = {} @@ -429,11 +422,13 @@ def load_motion_info(folder): with open(folder / "run_times.json") as f: motion_info["run_times"] = json.load(f) - array_names = ("peaks", "peak_locations", "temporal_bins", "spatial_bins", "motion") + array_names = ("peaks", "peak_locations") for name in array_names: if (folder / f"{name}.npy").exists(): motion_info[name] = np.load(folder / f"{name}.npy") else: motion_info[name] = None + motion_info["motion"] = Motion.load(folder / "motion") + return motion_info diff --git a/src/spikeinterface/preprocessing/preprocessing_tools.py b/src/spikeinterface/preprocessing/preprocessing_tools.py index c0b80c349b..942478fd71 100644 --- a/src/spikeinterface/preprocessing/preprocessing_tools.py +++ b/src/spikeinterface/preprocessing/preprocessing_tools.py @@ -80,7 +80,7 @@ def get_spatial_interpolation_kernel( elif method == "idw": distances = scipy.spatial.distance.cdist(source_location, target_location, metric="euclidean") - interpolation_kernel = np.zeros((source_location.shape[0], target_location.shape[0]), dtype="float64") + interpolation_kernel = np.zeros((source_location.shape[0], target_location.shape[0]), dtype=dtype) for c in range(target_location.shape[0]): ind_sorted = np.argsort(distances[:, c]) chan_closest = ind_sorted[:num_closest] @@ -97,7 +97,7 @@ def get_spatial_interpolation_kernel( elif method == "nearest": distances = scipy.spatial.distance.cdist(source_location, target_location, metric="euclidean") - interpolation_kernel = np.zeros((source_location.shape[0], target_location.shape[0]), dtype="float64") + interpolation_kernel = np.zeros((source_location.shape[0], target_location.shape[0]), dtype=dtype) for c in 
range(target_location.shape[0]): ind_closest = np.argmin(distances[:, c]) interpolation_kernel[ind_closest, c] = 1.0 diff --git a/src/spikeinterface/preprocessing/tests/test_motion.py b/src/spikeinterface/preprocessing/tests/test_motion.py index e79fda1ad8..a298b41d8f 100644 --- a/src/spikeinterface/preprocessing/tests/test_motion.py +++ b/src/spikeinterface/preprocessing/tests/test_motion.py @@ -1,14 +1,11 @@ -import pytest -from pathlib import Path - import shutil +from pathlib import Path +import numpy as np +import pytest from spikeinterface.core import generate_recording - from spikeinterface.preprocessing import correct_motion, load_motion_info -import numpy as np - def test_estimate_and_correct_motion(create_cache_folder): cache_folder = create_cache_folder @@ -18,6 +15,7 @@ def test_estimate_and_correct_motion(create_cache_folder): folder = cache_folder / "estimate_and_correct_motion" if folder.exists(): shutil.rmtree(folder) + rec_corrected = correct_motion(rec, folder=folder) print(rec_corrected) @@ -26,5 +24,5 @@ def test_estimate_and_correct_motion(create_cache_folder): if __name__ == "__main__": - print(correct_motion.__doc__) - # test_estimate_and_correct_motion() + # print(correct_motion.__doc__) + test_estimate_and_correct_motion() diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index effc04d898..b5df0f1059 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -313,7 +313,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): from spikeinterface.preprocessing.motion import load_motion_info motion_info = load_motion_info(motion_folder) - merging_params["maximum_distance_um"] = max(50, 2 * np.abs(motion_info["motion"]).max()) + motion = motion_info["motion"] + max_motion = max( + np.max(np.abs(motion.displacement[seg_index])) for seg_index in range(len(motion.displacement)) + ) + merging_params["maximum_distance_um"] = max(50, 2 * max_motion) # peak_sign = params['detection'].get('peak_sign', 'neg') # best_amplitudes = get_template_extremum_amplitude(templates, peak_sign=peak_sign) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index 3212f95e7f..55ef21de9d 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -1,20 +1,22 @@ from __future__ import annotations -import time +import json from pathlib import Path +import pickle +import time import numpy as np from spikeinterface.core import get_noise_levels +from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis +from spikeinterface.sortingcomponents.motion_estimation import estimate_motion from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.peak_localization import localize_peaks -from spikeinterface.sortingcomponents.motion_estimation import estimate_motion -from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis - - from spikeinterface.widgets import plot_probe_map +from spikeinterface.sortingcomponents.motion_utils import Motion + # import MEArec as mr # TODO : plot_peaks @@ 
-28,8 +30,8 @@ def get_gt_motion_from_unit_displacement( unit_displacements, displacement_sampling_frequency, unit_locations, - temporal_bins, - spatial_bins, + temporal_bins_s, + spatial_bins_um, direction_dim=1, ): import scipy.interpolate @@ -37,20 +39,24 @@ def get_gt_motion_from_unit_displacement( unit_displacements = unit_displacements[:, :, direction_dim] times = np.arange(unit_displacements.shape[0]) / displacement_sampling_frequency f = scipy.interpolate.interp1d(times, unit_displacements, axis=0) - unit_displacements = f(temporal_bins) + unit_displacements = f(temporal_bins_s.clip(times[0], times[-1])) # spatial interpolataion of units discplacement - if spatial_bins.shape[0] == 1: + if spatial_bins_um.shape[0] == 1: # rigid - gt_motion = np.mean(unit_displacements, axis=1)[:, None] + gt_displacement = np.mean(unit_displacements, axis=1)[:, None] else: # non rigid - gt_motion = np.zeros((temporal_bins.size, spatial_bins.size)) - for t in range(temporal_bins.shape[0]): + gt_displacement = np.zeros((temporal_bins_s.size, spatial_bins_um.size)) + for t in range(temporal_bins_s.shape[0]): f = scipy.interpolate.interp1d( unit_locations[:, direction_dim], unit_displacements[t, :], fill_value="extrapolate" ) - gt_motion[t, :] = f(spatial_bins) + gt_displacement[t, :] = f(spatial_bins_um) + + gt_motion = Motion( + gt_displacement, temporal_bins_s, spatial_bins_um, direction="xyz"[direction_dim], interpolation_method="linear" + ) return gt_motion @@ -92,9 +98,7 @@ def run(self, **job_kwargs): t2 = time.perf_counter() peak_locations = localize_peaks(self.recording, selected_peaks, **p["localize_kwargs"], **job_kwargs) t3 = time.perf_counter() - motion, temporal_bins, spatial_bins = estimate_motion( - self.recording, selected_peaks, peak_locations, **p["estimate_motion_kwargs"] - ) + motion = estimate_motion(self.recording, selected_peaks, peak_locations, **p["estimate_motion_kwargs"]) t4 = time.perf_counter() step_run_times = dict( @@ -106,43 +110,37 @@ def run(self, **job_kwargs): self.result["step_run_times"] = step_run_times self.result["raw_motion"] = motion - self.result["temporal_bins"] = temporal_bins - self.result["spatial_bins"] = spatial_bins def compute_result(self, **result_params): raw_motion = self.result["raw_motion"] - temporal_bins = self.result["temporal_bins"] - spatial_bins = self.result["spatial_bins"] gt_motion = get_gt_motion_from_unit_displacement( self.unit_displacements, self.displacement_sampling_frequency, self.unit_locations, - temporal_bins, - spatial_bins, + raw_motion.temporal_bins_s[0], + raw_motion.spatial_bins_um, direction_dim=self.direction_dim, ) # align globally gt_motion and motion to avoid offsets motion = raw_motion.copy() - motion += np.median(gt_motion - motion) + motion.displacement[0] += np.median(gt_motion.displacement[0] - motion.displacement[0]) self.result["gt_motion"] = gt_motion self.result["motion"] = motion _run_key_saved = [ - ("raw_motion", "npy"), - ("temporal_bins", "npy"), - ("spatial_bins", "npy"), + ("raw_motion", "Motion"), ("step_run_times", "pickle"), ] _result_key_saved = [ ( "gt_motion", - "npy", + "Motion", ), ( "motion", - "npy", + "Motion", ), ] @@ -189,20 +187,20 @@ def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_p # dirft ax = ax1 = fig.add_subplot(gs[2:7]) ax1.sharey(ax0) - temporal_bins = bench.result["temporal_bins"] - spatial_bins = bench.result["spatial_bins"] + # temporal_bins_s = bench.result["temporal_bins_s"] + # spatial_bins_um = bench.result["spatial_bins_um"] gt_motion = 
bench.result["gt_motion"] motion = bench.result["motion"] # for i in range(self.gt_unit_positions.shape[1]): - # ax.plot(temporal_bins, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") + # ax.plot(temporal_bins_s, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") - for i in range(gt_motion.shape[1]): - depth = spatial_bins[i] + for i in range(gt_motion.displacement[0].shape[1]): + depth = motion.spatial_bins_um[i] if gt_drift: - ax.plot(temporal_bins, gt_motion[:, i] + depth, color="green", lw=4) + ax.plot(motion.temporal_bins_s[0], gt_motion.displacement[0][:, i] + depth, color="green", lw=4) if tested_drift: - ax.plot(temporal_bins, motion[:, i] + depth, color="cyan", lw=2) + ax.plot(motion.temporal_bins_s[0], motion.displacement[0][:, i] + depth, color="cyan", lw=2) ax.set_xlabel("time (s)") _simpleaxis(ax) @@ -241,14 +239,14 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): gt_motion = bench.result["gt_motion"] motion = bench.result["motion"] - temporal_bins = bench.result["temporal_bins"] - spatial_bins = bench.result["spatial_bins"] + # temporal_bins_s = bench.result["temporal_bins_s"] + # spatial_bins_um = bench.result["spatial_bins_um"] fig = plt.figure(figsize=figsize) gs = fig.add_gridspec(2, 2) - errors = gt_motion - motion + errors = gt_motion.displacement[0] - motion.displacement[0] channel_positions = bench.recording.get_channel_locations() probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() @@ -259,7 +257,12 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): aspect="auto", interpolation="nearest", origin="lower", - extent=(temporal_bins[0], temporal_bins[-1], spatial_bins[0], spatial_bins[-1]), + extent=( + motion.temporal_bins_s[0][0], + motion.temporal_bins_s[0][-1], + motion.spatial_bins_um[0], + motion.spatial_bins_um[-1], + ), ) plt.colorbar(im, ax=ax, label="error") ax.set_ylabel("depth (um)") @@ -270,7 +273,7 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): ax = fig.add_subplot(gs[1, 0]) mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) - ax.plot(temporal_bins, mean_error) + ax.plot(motion.temporal_bins_s[0], mean_error) ax.set_xlabel("time (s)") ax.set_ylabel("error") _simpleaxis(ax) @@ -279,7 +282,7 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): ax = fig.add_subplot(gs[1, 1]) depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) - ax.plot(spatial_bins, depth_error) + ax.plot(motion.spatial_bins_um, depth_error) ax.axvline(probe_y_min, color="k", ls="--", alpha=0.5) ax.axvline(probe_y_max, color="k", ls="--", alpha=0.5) ax.set_xlabel("depth (um)") @@ -289,6 +292,7 @@ def plot_errors(self, case_keys=None, figsize=None, lim=None): ax.set_ylim(0, lim) def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)): + import matplotlib.pyplot as plt if case_keys is None: case_keys = list(self.cases.keys()) @@ -304,17 +308,17 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) gt_motion = bench.result["gt_motion"] motion = bench.result["motion"] - temporal_bins = bench.result["temporal_bins"] - spatial_bins = bench.result["spatial_bins"] + # temporal_bins_s = bench.result["temporal_bins_s"] + # spatial_bins_um = bench.result["spatial_bins_um"] # c = colors[count] if colors is not None else None c = colors[key] - errors = gt_motion - motion + errors = gt_motion.displacement[0] - motion.displacement[0] mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) depth_error = np.sqrt(np.mean((errors) ** 2, 
axis=0)) - axes[0].plot(temporal_bins, mean_error, lw=1, label=label, color=c) + axes[0].plot(motion.temporal_bins_s[0], mean_error, lw=1, label=label, color=c) parts = axes[1].violinplot(mean_error, [count], showmeans=True) if c is not None: for pc in parts["bodies"]: @@ -324,7 +328,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) if k != "bodies": # for line in parts[k]: parts[k].set_color(c) - axes[2].plot(spatial_bins, depth_error, label=label, color=c) + axes[2].plot(motion.spatial_bins_um, depth_error, label=label, color=c) ax0 = ax = axes[0] ax.set_xlabel("Time [s]") @@ -361,8 +365,8 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # "peaks", # "selected_peaks", # "motion", -# "temporal_bins", -# "spatial_bins", +# "temporal_bins_s", +# "spatial_bins_um", # "peak_locations", # "gt_motion", # ) @@ -438,7 +442,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # self.recording, self.selected_peaks, **self.localize_kwargs, **self.job_kwargs # ) # t3 = time.perf_counter() -# self.motion, self.temporal_bins, self.spatial_bins = estimate_motion( +# self.motion, self.temporal_bins_s, self.spatial_bins_um = estimate_motion( # self.recording, self.selected_peaks, self.peak_locations, **self.estimate_motion_kwargs # ) @@ -463,7 +467,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # def run_estimate_motion(self): # # usefull to re run only the motion estimate with peak localization # t3 = time.perf_counter() -# self.motion, self.temporal_bins, self.spatial_bins = estimate_motion( +# self.motion, self.temporal_bins_s, self.spatial_bins_um = estimate_motion( # self.recording, self.selected_peaks, self.peak_locations, **self.estimate_motion_kwargs # ) # t4 = time.perf_counter() @@ -479,7 +483,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # self.save_to_folder() # def compute_gt_motion(self): -# self.gt_unit_positions, _ = mr.extract_units_drift_vector(self.mearec_filename, time_vector=self.temporal_bins) +# self.gt_unit_positions, _ = mr.extract_units_drift_vector(self.mearec_filename, time_vector=self.temporal_bins_s) # template_locations = np.array(mr.load_recordings(self.mearec_filename).template_locations) # assert len(template_locations.shape) == 3 @@ -489,18 +493,18 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # unit_motions = self.gt_unit_positions - unit_mid_positions # # unit_positions = np.mean(self.gt_unit_positions, axis=0) -# if self.spatial_bins is None: +# if self.spatial_bins_um is None: # self.gt_motion = np.mean(unit_motions, axis=1)[:, None] # channel_positions = self.recording.get_channel_locations() # probe_y_min, probe_y_max = channel_positions[:, 1].min(), channel_positions[:, 1].max() # center = (probe_y_min + probe_y_max) // 2 -# self.spatial_bins = np.array([center]) +# self.spatial_bins_um = np.array([center]) # else: # # time, units # self.gt_motion = np.zeros_like(self.motion) # for t in range(self.gt_unit_positions.shape[0]): # f = scipy.interpolate.interp1d(unit_mid_positions, unit_motions[t, :], fill_value="extrapolate") -# self.gt_motion[t, :] = f(self.spatial_bins) +# self.gt_motion[t, :] = f(self.spatial_bins_um) # def plot_true_drift(self, scaling_probe=1.5, figsize=(15, 10), axes=None): # if axes is None: @@ -534,11 +538,11 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # ax = axes[1] # for i in 
range(self.gt_unit_positions.shape[1]): -# ax.plot(self.temporal_bins, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") +# ax.plot(self.temporal_bins_s, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5") # for i in range(self.gt_motion.shape[1]): -# depth = self.spatial_bins[i] -# ax.plot(self.temporal_bins, self.gt_motion[:, i] + depth, color="green", lw=4) +# depth = self.spatial_bins_um[i] +# ax.plot(self.temporal_bins_s, self.gt_motion[:, i] + depth, color="green", lw=4) # # ax.set_ylim(ymin, ymax) # ax.set_xlabel("time (s)") @@ -617,15 +621,15 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # ax.axhline(probe_y_max, color="k", ls="--", alpha=0.5) # if show_drift: -# if self.spatial_bins is None: +# if self.spatial_bins_um is None: # center = (probe_y_min + probe_y_max) // 2 -# ax.plot(self.temporal_bins, self.gt_motion[:, 0] + center, color="green", lw=1.5) -# ax.plot(self.temporal_bins, self.motion[:, 0] + center, color="orange", lw=1.5) +# ax.plot(self.temporal_bins_s, self.gt_motion[:, 0] + center, color="green", lw=1.5) +# ax.plot(self.temporal_bins_s, self.motion[:, 0] + center, color="orange", lw=1.5) # else: # for i in range(self.gt_motion.shape[1]): -# depth = self.spatial_bins[i] -# ax.plot(self.temporal_bins, self.gt_motion[:, i] + depth, color="green", lw=1.5) -# ax.plot(self.temporal_bins, self.motion[:, i] + depth, color="orange", lw=1.5) +# depth = self.spatial_bins_um[i] +# ax.plot(self.temporal_bins_s, self.gt_motion[:, i] + depth, color="green", lw=1.5) +# ax.plot(self.temporal_bins_s, self.motion[:, i] + depth, color="orange", lw=1.5) # if show_histogram: # ax2 = fig.add_subplot(gs[3]) @@ -669,10 +673,9 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # peak_locations_corrected = correct_motion_on_peaks( # self.selected_peaks, # self.peak_locations, -# self.recording.sampling_frequency, # self.motion, -# self.temporal_bins, -# self.spatial_bins, +# self.temporal_bins_s, +# self.spatial_bins_um, # direction="y", # ) # if axes is None: @@ -734,18 +737,18 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # colors = plt.colormaps["jet"].resampled(n) # for i in range(0, n, step): # ax = axs[0] -# ax.plot(self.temporal_bins, self.gt_motion[:, i], lw=1.5, ls="--", color=colors(i)) +# ax.plot(self.temporal_bins_s, self.gt_motion[:, i], lw=1.5, ls="--", color=colors(i)) # ax.plot( -# self.temporal_bins, +# self.temporal_bins_s, # self.motion[:, i], # lw=1.5, # ls="-", # color=colors(i), -# label=f"{self.spatial_bins[i]:0.1f}", +# label=f"{self.spatial_bins_um[i]:0.1f}", # ) # ax = axs[1] -# ax.plot(self.temporal_bins, self.motion[:, i] - self.gt_motion[:, i], lw=1.5, ls="-", color=colors(i)) +# ax.plot(self.temporal_bins_s, self.motion[:, i] - self.gt_motion[:, i], lw=1.5, ls="-", color=colors(i)) # ax = axs[0] # ax.set_title(self.title) @@ -774,7 +777,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # aspect="auto", # interpolation="nearest", # origin="lower", -# extent=(self.temporal_bins[0], self.temporal_bins[-1], self.spatial_bins[0], self.spatial_bins[-1]), +# extent=(self.temporal_bins_s[0], self.temporal_bins_s[-1], self.spatial_bins_um[0], self.spatial_bins_um[-1]), # ) # plt.colorbar(im, ax=ax, label="error") # ax.set_ylabel("depth (um)") @@ -785,7 +788,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # ax = fig.add_subplot(gs[1, 0]) # mean_error = np.sqrt(np.mean((errors) 
** 2, axis=1)) -# ax.plot(self.temporal_bins, mean_error) +# ax.plot(self.temporal_bins_s, mean_error) # ax.set_xlabel("time (s)") # ax.set_ylabel("error") # _simpleaxis(ax) @@ -794,7 +797,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # ax = fig.add_subplot(gs[1, 1]) # depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) -# ax.plot(self.spatial_bins, depth_error) +# ax.plot(self.spatial_bins_um, depth_error) # ax.axvline(probe_y_min, color="k", ls="--", alpha=0.5) # ax.axvline(probe_y_max, color="k", ls="--", alpha=0.5) # ax.set_xlabel("depth (um)") @@ -816,7 +819,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # mean_error = np.sqrt(np.mean((errors) ** 2, axis=1)) # depth_error = np.sqrt(np.mean((errors) ** 2, axis=0)) -# axes[0].plot(benchmark.temporal_bins, mean_error, lw=1, label=benchmark.title, color=c) +# axes[0].plot(benchmark.temporal_bins_s, mean_error, lw=1, label=benchmark.title, color=c) # parts = axes[1].violinplot(mean_error, [count], showmeans=True) # if c is not None: # for pc in parts["bodies"]: @@ -826,7 +829,7 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # if k != "bodies": # # for line in parts[k]: # parts[k].set_color(c) -# axes[2].plot(benchmark.spatial_bins, depth_error, label=benchmark.title, color=c) +# axes[2].plot(benchmark.spatial_bins_um, depth_error, label=benchmark.title, color=c) # ax0 = ax = axes[0] # ax.set_xlabel("Time [s]") @@ -875,10 +878,10 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # interpolation="nearest", # origin="lower", # extent=( -# benchmark.temporal_bins[0], -# benchmark.temporal_bins[-1], -# benchmark.spatial_bins[0], -# benchmark.spatial_bins[-1], +# benchmark.temporal_bins_s[0], +# benchmark.temporal_bins_s[-1], +# benchmark.spatial_bins_um[0], +# benchmark.spatial_bins_um[-1], # ), # ) # fig.colorbar(im, ax=ax, label="error") @@ -896,11 +899,11 @@ def plot_summary_errors(self, case_keys=None, show_legend=True, figsize=(15, 5)) # def plot_motions_several_benchmarks(benchmarks): # fig, ax = plt.subplots(figsize=(15, 5)) -# ax.plot(list(benchmarks)[0].temporal_bins, list(benchmarks)[0].gt_motion[:, 0], lw=2, c="k", label="real motion") +# ax.plot(list(benchmarks)[0].temporal_bins_s, list(benchmarks)[0].gt_motion[:, 0], lw=2, c="k", label="real motion") # for count, benchmark in enumerate(benchmarks): -# ax.plot(benchmark.temporal_bins, benchmark.motion.mean(1), lw=1, c=f"C{count}", label=benchmark.title) +# ax.plot(benchmark.temporal_bins_s, benchmark.motion.mean(1), lw=1, c=f"C{count}", label=benchmark.title) # ax.fill_between( -# benchmark.temporal_bins, +# benchmark.temporal_bins_s, # benchmark.motion.mean(1) - benchmark.motion.std(1), # benchmark.motion.mean(1) + benchmark.motion.std(1), # color=f"C{count}", diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py index a515424648..a6ff05fc55 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_interpolation.py @@ -44,9 +44,7 @@ def run(self, **job_kwargs): recording = self.drifting_recording elif self.params["recording_source"] == "corrected": correct_motion_kwargs = self.params["correct_motion_kwargs"] - recording = InterpolateMotionRecording( - self.drifting_recording, self.motion, self.temporal_bins, 
self.spatial_bins, **correct_motion_kwargs - ) + recording = InterpolateMotionRecording(self.drifting_recording, self.motion, **correct_motion_kwargs) else: raise ValueError("recording_source") diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index b2cf56eb9c..e9f128993d 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -406,6 +406,8 @@ def _save_keys(self, saved_keys, folder): pickle.dump(self.result[k], f) elif format == "sorting": self.result[k].save(folder=folder / k, format="numpy_folder", overwrite=True) + elif format == "Motion": + self.result[k].save(folder=folder / k) elif format == "zarr_templates": self.result[k].to_zarr(folder / k) elif format == "sorting_analyzer": @@ -440,6 +442,10 @@ def load_folder(cls, folder): from spikeinterface.core import load_extractor result[k] = load_extractor(folder / k) + elif format == "Motion": + from spikeinterface.sortingcomponents.motion_utils import Motion + + result[k] = Motion.load(folder / k) elif format == "zarr_templates": from spikeinterface.core.template import Templates diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py index 696531b221..526cc2e92f 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_estimation.py @@ -65,8 +65,10 @@ def test_benchmark_motion_estimaton(create_cache_folder): # plots study.plot_true_drift() + study.plot_drift() study.plot_errors() study.plot_summary_errors() + import matplotlib.pyplot as plt plt.show() diff --git a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py index 4b7264a9de..6d80d027f2 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/tests/test_benchmark_motion_interpolation.py @@ -48,8 +48,11 @@ def test_benchmark_motion_interpolation(create_cache_folder): spatial_bins, direction_dim=1, ) + # print(gt_motion) + + # import matplotlib.pyplot as plt # fig, ax = plt.subplots() - # ax.imshow(gt_motion.T) + # ax.imshow(gt_motion.displacement[0].T) # plt.show() cases = {} @@ -130,6 +133,8 @@ def test_benchmark_motion_interpolation(create_cache_folder): study.plot_sorting_accuracy(mode="depth", mode_best_merge=False) study.plot_sorting_accuracy(mode="depth", mode_best_merge=True) + import matplotlib.pyplot as plt + plt.show() diff --git a/src/spikeinterface/sortingcomponents/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion_estimation.py index 9eb5415316..3134d68681 100644 --- a/src/spikeinterface/sortingcomponents/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion_estimation.py @@ -1,7 +1,11 @@ from __future__ import annotations -import numpy as np from tqdm.auto import tqdm, trange +import numpy as np + + +from .motion_utils import Motion +from .tools import make_multi_method_doc try: import torch @@ -11,8 +15,6 @@ except ImportError: HAVE_TORCH = False -from .tools import make_multi_method_doc - def estimate_motion( recording, @@ -55,7 +57,7 @@ def 
estimate_motion( **histogram section** direction: "x" | "y" | "z", default: "y" - Dimension on which the motion is estimated + Dimension on which the motion is estimated. "y" is depth along the probe. bin_duration_s: float, default: 10 Bin duration in second bin_um: float, default: 10 @@ -105,19 +107,8 @@ def estimate_motion( Returns ------- - motion: numpy array 2d - Motion estimate in um. - Shape (temporal bins, spatial bins) - motion.shape[0] = temporal_bins.shape[0] - motion.shape[1] = 1 (rigid) or spatial_bins.shape[1] (non rigid) - If upsample_to_histogram_bin, motion.shape[1] corresponds to spatial - bins given by bin_um. - temporal_bins: numpy.array 1d - temporal bins (bin center) - spatial_bins: numpy.array 1d - Windows center. - spatial_bins.shape[0] == motion.shape[1] - If rigid then spatial_bins.shape[0] == 1 + motion: Motion object + The motion object. extra_check: dict Optional output if `output_extra_check=True` This dict contain histogram, pairwise_displacement usefull for ploting. @@ -148,7 +139,7 @@ def estimate_motion( # run method method_class = estimate_motion_methods[method] - motion, temporal_bins = method_class.run( + motion_array, temporal_bins = method_class.run( recording, peaks, peak_locations, @@ -164,27 +155,30 @@ def estimate_motion( ) # replace nan by zeros - motion[np.isnan(motion)] = 0 + np.nan_to_num(motion_array, copy=False) if post_clean: - motion = clean_motion_vector( - motion, temporal_bins, bin_duration_s, speed_threshold=speed_threshold, sigma_smooth_s=sigma_smooth_s + motion_array = clean_motion_vector( + motion_array, temporal_bins, bin_duration_s, speed_threshold=speed_threshold, sigma_smooth_s=sigma_smooth_s ) if upsample_to_histogram_bin is None: upsample_to_histogram_bin = not rigid if upsample_to_histogram_bin: - extra_check["motion"] = motion + extra_check["motion_array"] = motion_array extra_check["non_rigid_window_centers"] = non_rigid_window_centers non_rigid_windows = np.array(non_rigid_windows) non_rigid_windows /= non_rigid_windows.sum(axis=0, keepdims=True) non_rigid_window_centers = spatial_bin_edges[:-1] + bin_um / 2 - motion = motion @ non_rigid_windows + motion_array = motion_array @ non_rigid_windows + + # TODO handle multi segment + motion = Motion([motion_array], [temporal_bins], non_rigid_window_centers, direction=direction) if output_extra_check: - return motion, temporal_bins, non_rigid_window_centers, extra_check + return motion, extra_check else: - return motion, temporal_bins, non_rigid_window_centers + return motion class DecentralizedRegistration: @@ -342,7 +336,7 @@ def run( extra_check["spatial_hist_bin_edges"] = spatial_hist_bin_edges # temporal bins are bin center - temporal_bins = temporal_hist_bin_edges[:-1] + bin_duration_s // 2.0 + temporal_bins = 0.5 * (temporal_hist_bin_edges[1:] + temporal_hist_bin_edges[:-1]) motion = np.zeros((temporal_bins.size, len(non_rigid_windows)), dtype=np.float64) windows_iter = non_rigid_windows @@ -690,16 +684,15 @@ def make_2d_motion_histogram( spatial_bin_edges 1d array with spatial bin edges """ - fs = recording.get_sampling_frequency() - num_samples = recording.get_num_samples(segment_index=0) - bin_sample_size = int(bin_duration_s * fs) - sample_bin_edges = np.arange(0, num_samples + bin_sample_size, bin_sample_size) - temporal_bin_edges = sample_bin_edges / fs + n_samples = recording.get_num_samples() + mint_s = recording.sample_index_to_time(0) + maxt_s = recording.sample_index_to_time(n_samples) + temporal_bin_edges = np.arange(mint_s, maxt_s + bin_duration_s, 
bin_duration_s) if spatial_bin_edges is None: spatial_bin_edges = get_spatial_bin_edges(recording, direction, margin_um, bin_um) arr = np.zeros((peaks.size, 2), dtype="float64") - arr[:, 0] = peaks["sample_index"] + arr[:, 0] = recording.sample_index_to_time(peaks["sample_index"]) arr[:, 1] = peak_locations[direction] if weight_with_amplitude: @@ -707,11 +700,11 @@ def make_2d_motion_histogram( else: weights = None - motion_histogram, edges = np.histogramdd(arr, bins=(sample_bin_edges, spatial_bin_edges), weights=weights) + motion_histogram, edges = np.histogramdd(arr, bins=(temporal_bin_edges, spatial_bin_edges), weights=weights) # average amplitude in each bin if weight_with_amplitude: - bin_counts, _ = np.histogramdd(arr, bins=(sample_bin_edges, spatial_bin_edges)) + bin_counts, _ = np.histogramdd(arr, bins=(temporal_bin_edges, spatial_bin_edges)) bin_counts[bin_counts == 0] = 1 motion_histogram = motion_histogram / bin_counts @@ -766,11 +759,10 @@ def make_3d_motion_histograms( spatial_bin_edges 1d array with spatial bin edges """ - fs = recording.get_sampling_frequency() - num_samples = recording.get_num_samples(segment_index=0) - bin_sample_size = int(bin_duration_s * fs) - sample_bin_edges = np.arange(0, num_samples + bin_sample_size, bin_sample_size) - temporal_bin_edges = sample_bin_edges / fs + n_samples = recording.get_num_samples() + mint_s = recording.sample_index_to_time(0) + maxt_s = recording.sample_index_to_time(n_samples) + temporal_bin_edges = np.arange(mint_s, maxt_s + bin_duration_s, bin_duration_s) if spatial_bin_edges is None: spatial_bin_edges = get_spatial_bin_edges(recording, direction, margin_um, bin_um) @@ -785,14 +777,14 @@ def make_3d_motion_histograms( ) arr = np.zeros((peaks.size, 3), dtype="float64") - arr[:, 0] = peaks["sample_index"] + arr[:, 0] = recording.sample_index_to_time(peaks["sample_index"]) arr[:, 1] = peak_locations[direction] arr[:, 2] = abs_peaks_log_norm motion_histograms, edges = np.histogramdd( arr, bins=( - sample_bin_edges, + temporal_bin_edges, spatial_bin_edges, amplitude_bin_edges, ), @@ -825,7 +817,6 @@ def compute_pairwise_displacement( """ Compute pairwise displacement """ - from scipy import sparse from scipy import linalg assert conv_engine in ("torch", "numpy"), f"'conv_engine' must be 'torch' or 'numpy'" diff --git a/src/spikeinterface/sortingcomponents/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion_interpolation.py index 5e3733b363..32bb7634e9 100644 --- a/src/spikeinterface/sortingcomponents/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion_interpolation.py @@ -1,40 +1,26 @@ from __future__ import annotations import numpy as np - - from spikeinterface.core.core_tools import define_function_from_class -from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment from spikeinterface.preprocessing import get_spatial_interpolation_kernel +from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment +from spikeinterface.preprocessing.filter import fix_dtype -def correct_motion_on_peaks( - peaks, - peak_locations, - sampling_frequency, - motion, - temporal_bins, - spatial_bins, - direction="y", -): +def correct_motion_on_peaks(peaks, peak_locations, motion, recording): """ Given the output of estimate_motion(), apply inverse motion on peak locations. 
Parameters ---------- - peaks: np.array + peaks : np.array peaks vector - peak_locations: np.array + peak_locations : np.array peaks location vector - sampling_frequency: np.array - sampling_frequency of the recording - motion: np.array 2D - motion.shape[0] equal temporal_bins.shape[0] - motion.shape[1] equal 1 when "rigid" motion equal temporal_bins.shape[0] when "non-rigid" - temporal_bins: np.array - Temporal bins in second. - spatial_bins: np.array - Bins for non-rigid motion. If spatial_bins.sahpe[0] == 1 then rigid motion is used. + motion : Motion + The motion object. + recording : Recording + The recording object. This is used to convert sample indices to times. Returns ------- @@ -42,21 +28,17 @@ def correct_motion_on_peaks( Motion-corrected peak locations """ corrected_peak_locations = peak_locations.copy() - import scipy.interpolate - - spike_times = peaks["sample_index"] / sampling_frequency - if spatial_bins.shape[0] == 1: - # rigid motion interpolation 1D - f = scipy.interpolate.interp1d(temporal_bins, motion[:, 0], bounds_error=False, fill_value="extrapolate") - shift = f(spike_times) - corrected_peak_locations[direction] -= shift - else: - # non rigid motion = interpolation 2D - f = scipy.interpolate.RegularGridInterpolator( - (temporal_bins, spatial_bins), motion, method="linear", bounds_error=False, fill_value=None + + for segment_index in range(motion.num_segments): + times_s = recording.sample_index_to_time(peaks["sample_index"], segment_index=segment_index) + i0, i1 = np.searchsorted(peaks["segment_index"], [segment_index, segment_index + 1]) + + spike_times = times_s[i0:i1] + spike_locs = peak_locations[motion.direction][i0:i1] + spike_displacement = motion.get_displacement_at_time_and_depth( + spike_times, spike_locs, segment_index=segment_index ) - shift = f(np.c_[spike_times, peak_locations[direction]]) - corrected_peak_locations[direction] -= shift + corrected_peak_locations[i0:i1][motion.direction] -= spike_displacement return corrected_peak_locations @@ -66,12 +48,12 @@ def interpolate_motion_on_traces( times, channel_locations, motion, - temporal_bins, - spatial_bins, - direction=1, + segment_index=None, channel_inds=None, + interpolation_time_bin_centers_s=None, spatial_interpolation_method="kriging", spatial_interpolation_kwargs={}, + dtype=None, ): """ Apply inverse motion with spatial interpolation on traces. @@ -82,20 +64,19 @@ def interpolate_motion_on_traces( ---------- traces : np.array Trace snippet (num_samples, num_channels) + times : np.array + Sample times in seconds for the frames of the traces snippet channel_location: np.array 2d Channel location with shape (n, 2) or (n, 3) - motion: np.array 2D - motion.shape[0] equal temporal_bins.shape[0] - motion.shape[1] equal 1 when "rigid" motion - equal temporal_bins.shape[0] when "none rigid" - temporal_bins: np.array - Temporal bins in second. - spatial_bins: None or np.array - Bins for non-rigid motion. If None, rigid motion is used - direction: int in (0, 1, 2) - Dimension of shift in channel_locations. + motion: Motion + The motion object. + segment_index: int or None + The segment index. channel_inds: None or list If not None, interpolate only a subset of channels. + interpolation_time_bin_centers_s : None or np.array + Manually specify the time bins which the interpolation happens + in for this segment. If None, these are the motion estimate's time bins. 
     spatial_interpolation_method: "idw" | "kriging", default: "kriging"
         The spatial interpolation method used to interpolate the channel locations:
             * idw : Inverse Distance Weighing
             * kriging : kilosort2.5 like
@@ -105,40 +86,62 @@
     Returns
     -------
-    channel_motions: np.array
-        Shift over time by channel
-        Shape (times.shape[0], channel_location.shape[0])
+    traces_corrected: np.array
+        Motion-corrected trace snippet, (num_samples, num_channels)
     """
     # assert HAVE_NUMBA
     assert times.shape[0] == traces.shape[0]
+    if dtype is None:
+        dtype = traces.dtype
+    if dtype.kind != "f":
+        raise ValueError(f"Can't interpolate_motion with dtype {dtype}.")
+    if traces.dtype != dtype:
+        traces = traces.astype(dtype)
+
+    if segment_index is None:
+        if motion.num_segments == 1:
+            segment_index = 0
+        else:
+            raise ValueError("Several segments need segment_index to be specified")
+
     if channel_inds is None:
         traces_corrected = np.zeros(traces.shape, dtype=traces.dtype)
     else:
         channel_inds = np.asarray(channel_inds)
         traces_corrected = np.zeros((traces.shape[0], channel_inds.size), dtype=traces.dtype)

-    # regroup times by closet temporal_bins
-    bin_inds = _get_closest_ind(temporal_bins, times)
+    total_num_chans = channel_locations.shape[0]

-    # inperpolation kernel will be the same per temporal bin
-    for bin_ind in np.unique(bin_inds):
-        # Step 1 : channel motion
-        if spatial_bins.shape[0] == 1:
-            # rigid motion : same motion for all channels
-            channel_motions = motion[bin_ind, 0]
-        else:
-            # non rigid : interpolation channel motion for this temporal bin
-            import scipy.interpolate
+    # -- determine the blocks of frames that will land in the same interpolation time bin
+    time_bins = interpolation_time_bin_centers_s
+    if time_bins is None:
+        time_bins = motion.temporal_bins_s[segment_index]
+    bin_s = time_bins[1] - time_bins[0]
+    bins_start = time_bins[0] - 0.5 * bin_s
+    # nearest bin center for each frame
+    bin_inds = (times - bins_start) // bin_s
+    bin_inds = bin_inds.astype(int)
+    # the time bins may not cover the whole set of times in the recording,
+    # so we need to clip these indices to the valid range
+    np.clip(bin_inds, 0, time_bins.size - 1, out=bin_inds)

-            f = scipy.interpolate.interp1d(
-                spatial_bins, motion[bin_ind, :], kind="linear", axis=0, bounds_error=False, fill_value="extrapolate"
-            )
-            locs = channel_locations[:, direction]
-            channel_motions = f(locs)
+    # -- what are the possibilities here anyway?
+    bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1)
+
+    # interpolation kernel will be the same per temporal bin
+    interp_times = np.empty(total_num_chans)
+    current_start_index = 0
+    for bin_ind in bins_here:
+        bin_time = time_bins[bin_ind]
+        interp_times.fill(bin_time)
+        channel_motions = motion.get_displacement_at_time_and_depth(
+            interp_times,
+            channel_locations[:, motion.dim],
+            segment_index=segment_index,
+        )
         channel_locations_moved = channel_locations.copy()
-        channel_locations_moved[:, direction] += channel_motions
-        # channel_locations_moved[:, direction] -= channel_motions
+        channel_locations_moved[:, motion.dim] += channel_motions
         if channel_inds is not None:
             channel_locations_moved = channel_locations_moved[channel_inds]
@@ -146,24 +149,35 @@
         drift_kernel = get_spatial_interpolation_kernel(
             channel_locations,
             channel_locations_moved,
-            dtype="float32",
+            dtype=dtype,
             method=spatial_interpolation_method,
             **spatial_interpolation_kwargs,
         )

-        i0 = np.searchsorted(bin_inds, bin_ind, side="left")
-        i1 = np.searchsorted(bin_inds, bin_ind, side="right")
+        # keep this for DEBUG
+        # import matplotlib.pyplot as plt
+        # fig, ax = plt.subplots()
+        # ax.matshow(drift_kernel)
+        # ax.set_title(f"bin_ind {bin_ind} - {bin_time}s - {spatial_interpolation_method}")
+        # plt.show()
+
+        # quickly find the end of this bin, which is also the start of the next
+        next_start_index = current_start_index + np.searchsorted(
+            bin_inds[current_start_index:], bin_ind + 1, side="left"
+        )
+        in_bin = slice(current_start_index, next_start_index)

         # here we use a simple np.matmul even if drift_kernel can be super sparse.
         # because the speed for a sparse matmul is not so good when we disable multi threading (due to multiprocessing
         # in ChunkRecordingExecutor)
-        traces_corrected[i0:i1] = traces[i0:i1] @ drift_kernel
+        np.matmul(traces[in_bin], drift_kernel, out=traces_corrected[in_bin])
+        current_start_index = next_start_index

     return traces_corrected


 # if HAVE_NUMBA:
-# # @numba.jit(parallel=False)
+# # @numba.jit(parallel=False)
 # @numba.jit(parallel=True)
 # def my_sparse_dot(data_in, data_out, sparse_chans, weights):
 #     """
@@ -178,7 +192,7 @@
 #     num_samples = data_in.shape[0]
 #     num_chan_out = data_out.shape[1]
 #     num_sparse = sparse_chans.shape[1]
-#     # for sample_index in range(num_samples):
+#     # for sample_index in range(num_samples):
 #     for sample_index in numba.prange(num_samples):
 #         for out_chan in range(num_chan_out):
 #             v = 0
@@ -205,24 +219,25 @@ def _get_closest_ind(array, values):


 class InterpolateMotionRecording(BasePreprocessor):
     """
-    Recording that corrects motion on-the-fly given a motion vector estimation (rigid or non-rigid).
-    This internally applies a spatial interpolation on the original traces after reversing the motion.
-    `estimate_motion()` must be called before this to estimate the motion vector.
+    Interpolate the input recording's traces to correct for motion, according to the
+    motion estimate object `motion`. The interpolation is carried out "lazily" / on the fly
+    by applying a spatial interpolation on the original traces to estimate their values
+    at the positions of the probe's channels after being shifted inversely to the motion.
+
+    To get a Motion object, use `estimate_motion()`.
+
+    By default, each frame is spatially interpolated by the motion at the nearest motion
+    estimation time bin -- in other words, the temporal resolution of the motion correction
+    is the same as the motion estimation's.
However, this behavior can be changed by setting + `interpolation_time_bin_centers_s` or `interpolation_time_bin_size_s` below. In that case, + the motion estimate will be interpolated to match the interpolation time bins. Parameters ---------- recording: Recording The parent recording. - motion: np.array 2D - The motion signal obtained with `estimate_motion()` - motion.shape[0] must correspond to temporal_bins.shape[0] - motion.shape[1] is 1 when "rigid" motion and spatial_bins.shape[0] when "non-rigid" - temporal_bins: np.array - Temporal bins in second. - spatial_bins: None or np.array - Bins for non-rigid motion. If None, rigid motion is used - direction: 0 | 1 | 2, default: 1 - Dimension along which channel_locations are shifted (0 - x, 1 - y, 2 - z) + motion: Motion + The motion object spatial_interpolation_method: "kriging" | "idw" | "nearest", default: "kriging" The spatial interpolation method used to interpolate the channel locations. See `spikeinterface.preprocessing.get_spatial_interpolation_kernel()` for more details. @@ -239,10 +254,22 @@ class InterpolateMotionRecording(BasePreprocessor): Number of closest channels used by "idw" method for interpolation. border_mode: "remove_channels" | "force_extrapolate" | "force_zeros", default: "remove_channels" Control how channels are handled on border: - * "remove_channels": remove channels on the border, the recording has less channels * "force_extrapolate": keep all channel and force extrapolation (can lead to strange signal) * "force_zeros": keep all channel but set zeros when outside (force_extrapolate=False) + interpolation_time_bin_centers_s: np.array or list of np.array, optional + Spatially interpolate each frame according to the displacement estimate at its closest + bin center in this array. If not supplied, this is set to the motion estimate's time bin + centers. If it's supplied, the motion estimate is interpolated to these bin centers. + If you have a multi-segment recording, pass a list of these, one per segment. + interpolation_time_bin_size_s: float, optional + Similar to the previous argument: interpolation_time_bin_centers_s will be constructed + by bins spaced by interpolation_time_bin_size_s. This is ignored if interpolation_time_bin_centers_s + is supplied. + dtype : str or np.dtype, optional + Interpolation needs to convert to a floating dtype. If dtype is supplied, that will be used. + If the input recording is already floating and dtype=None, then its dtype is used by default. + If the input recording is integer, then float32 is used by default. 
Returns ------- @@ -250,60 +277,52 @@ class InterpolateMotionRecording(BasePreprocessor): Recording after motion correction """ - name = "correct_motion" + name = "interpolate_motion" def __init__( self, recording, motion, - temporal_bins, - spatial_bins, - direction=1, border_mode="remove_channels", spatial_interpolation_method="kriging", sigma_um=20.0, p=1, num_closest=3, + interpolation_time_bin_centers_s=None, + interpolation_time_bin_size_s=None, + dtype=None, + **spatial_interpolation_kwargs, ): - assert recording.get_num_segments() == 1, "correct_motion() is only available for single-segment recordings" - - # force as arrays - temporal_bins = np.asarray(temporal_bins) - motion = np.asarray(motion) - spatial_bins = np.asarray(spatial_bins) + # assert recording.get_num_segments() == 1, "correct_motion() is only available for single-segment recordings" channel_locations = recording.get_channel_locations() - assert channel_locations.ndim >= direction, ( - f"'direction' {direction} not available. " f"Channel locations have {channel_locations.ndim} dimensions." + assert channel_locations.ndim >= motion.dim, ( + f"'direction' {motion.direction} not available. " + f"Channel locations have {channel_locations.ndim} dimensions." + ) + spatial_interpolation_kwargs = dict( + sigma_um=sigma_um, p=p, num_closest=num_closest, **spatial_interpolation_kwargs ) - spatial_interpolation_kwargs = dict(sigma_um=sigma_um, p=p, num_closest=num_closest) if border_mode == "remove_channels": - locs = channel_locations[:, direction] - l0, l1 = np.min(channel_locations[:, direction]), np.max(channel_locations[:, direction]) + locs = channel_locations[:, motion.dim] + l0, l1 = np.min(locs), np.max(locs) - # compute max and min motion (with interpolation) - # and check if channels are inside + # check if channels stay inside the probe extents for all segments channel_inside = np.ones(locs.shape[0], dtype="bool") - for operator in (np.max, np.min): - if spatial_bins.shape[0] == 1: - best_motions = operator(motion[:, 0]) - else: - # non rigid : interpolation channel motion for this temporal bin - import scipy.spatial - import scipy.interpolate - - f = scipy.interpolate.interp1d( - spatial_bins, - operator(motion[:, :], axis=0), - kind="linear", - axis=0, - bounds_error=False, - fill_value="extrapolate", - ) - best_motions = f(locs) - channel_inside &= ((locs + best_motions) >= l0) & ((locs + best_motions) <= l1) - - (channel_inds,) = np.nonzero(channel_inside) + for segment_index in range(recording.get_num_segments()): + # evaluate the positions of all channels over all time bins + channel_displacements = motion.get_displacement_at_time_and_depth( + times_s=motion.temporal_bins_s[segment_index], + locations_um=locs, + grid=True, + ) + channel_locations_moved = locs[:, None] + channel_displacements + # check if these remain inside of the probe + seg_inside = channel_locations_moved.clip(l0, l1) == channel_locations_moved + seg_inside = seg_inside.all(axis=1) + channel_inside &= seg_inside + + channel_inds = np.flatnonzero(channel_inside) channel_ids = recording.channel_ids[channel_inds] spatial_interpolation_kwargs["force_extrapolate"] = False elif border_mode == "force_extrapolate": @@ -317,7 +336,14 @@ def __init__( else: raise ValueError("Wrong border_mode") - BasePreprocessor.__init__(self, recording, channel_ids=channel_ids) + if dtype is None: + if recording.dtype.kind == "f": + dtype = recording.dtype + else: + raise ValueError(f"Can't interpolate traces of recording with non-floating 
dtype={recording.dtype=}.") + + dtype_ = fix_dtype(recording, dtype) + BasePreprocessor.__init__(self, recording, channel_ids=channel_ids, dtype=dtype_) if border_mode == "remove_channels": # change the wiring of the probe @@ -327,32 +353,48 @@ def __init__( contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") self.set_property("contact_vector", contact_vector) - for parent_segment in recording._recording_segments: + # handle manual interpolation_time_bin_centers_s + # the case where interpolation_time_bin_size_s is set is handled per-segment below + if interpolation_time_bin_centers_s is None: + if interpolation_time_bin_size_s is None: + interpolation_time_bin_centers_s = motion.temporal_bins_s + + for segment_index, parent_segment in enumerate(recording._recording_segments): + # finish the per-segment part of the time bin logic + if interpolation_time_bin_centers_s is None: + # in this case, interpolation_time_bin_size_s is set. + s_end = parent_segment.get_num_samples() + t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end])) + halfbin = interpolation_time_bin_size_s / 2.0 + segment_interpolation_time_bins_s = np.arange(t_start + halfbin, t_end, interpolation_time_bin_size_s) + else: + segment_interpolation_time_bins_s = interpolation_time_bin_centers_s[segment_index] + rec_segment = InterpolateMotionRecordingSegment( parent_segment, channel_locations, motion, - temporal_bins, - spatial_bins, - direction, spatial_interpolation_method, spatial_interpolation_kwargs, channel_inds, + segment_index, + segment_interpolation_time_bins_s, + dtype=dtype_, ) self.add_recording_segment(rec_segment) self._kwargs = dict( recording=recording, motion=motion, - temporal_bins=temporal_bins, - spatial_bins=spatial_bins, - direction=direction, border_mode=border_mode, spatial_interpolation_method=spatial_interpolation_method, sigma_um=sigma_um, p=p, num_closest=num_closest, + interpolation_time_bin_centers_s=interpolation_time_bin_centers_s, + dtype=dtype_.str, ) + self._kwargs.update(spatial_interpolation_kwargs) class InterpolateMotionRecordingSegment(BasePreprocessorSegment): @@ -361,61 +403,51 @@ def __init__( parent_recording_segment, channel_locations, motion, - temporal_bins, - spatial_bins, - direction, spatial_interpolation_method, spatial_interpolation_kwargs, channel_inds, + segment_index, + interpolation_time_bin_centers_s, + dtype="float32", ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.channel_locations = channel_locations - self.motion = motion - self.temporal_bins = temporal_bins - self.spatial_bins = spatial_bins - self.direction = direction self.spatial_interpolation_method = spatial_interpolation_method self.spatial_interpolation_kwargs = spatial_interpolation_kwargs self.channel_inds = channel_inds + self.segment_index = segment_index + self.interpolation_time_bin_centers_s = interpolation_time_bin_centers_s + self.dtype = dtype + self.motion = motion def get_traces(self, start_frame, end_frame, channel_indices): if self.time_vector is not None: - raise NotImplementedError( - "time_vector for InterpolateMotionRecording do not work because temporal_bins start from 0" - ) - # times = np.asarray(self.time_vector[start_frame:end_frame]) + raise NotImplementedError("InterpolateMotionRecording does not yet support recordings with time_vectors.") if start_frame is None: start_frame = 0 if end_frame is None: end_frame = self.get_num_samples() - times = np.arange(end_frame - start_frame, dtype="float64") - times /= 
self.sampling_frequency
-        t0 = start_frame / self.sampling_frequency
-        # if self.t_start is not None:
-        #     t0 = t0 + self.t_start
-        times += t0
-
+        times = self.parent_recording_segment.sample_index_to_time(np.arange(start_frame, end_frame))
         traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices=slice(None))
-
-        trace2 = interpolate_motion_on_traces(
+        traces = traces.astype(self.dtype)
+        traces = interpolate_motion_on_traces(
             traces,
             times,
             self.channel_locations,
             self.motion,
-            self.temporal_bins,
-            self.spatial_bins,
-            direction=self.direction,
+            segment_index=self.segment_index,
             channel_inds=self.channel_inds,
             spatial_interpolation_method=self.spatial_interpolation_method,
             spatial_interpolation_kwargs=self.spatial_interpolation_kwargs,
+            interpolation_time_bin_centers_s=self.interpolation_time_bin_centers_s,
         )

         if channel_indices is not None:
-            trace2 = trace2[:, channel_indices]
+            traces = traces[:, channel_indices]

-        return trace2
+        return traces


-interpolate_motion = define_function_from_class(source_class=InterpolateMotionRecording, name="correct_motion")
+interpolate_motion = define_function_from_class(source_class=InterpolateMotionRecording, name="interpolate_motion")
diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py
new file mode 100644
index 0000000000..26d4b35b1a
--- /dev/null
+++ b/src/spikeinterface/sortingcomponents/motion_utils.py
@@ -0,0 +1,230 @@
+import json
+from pathlib import Path
+
+import numpy as np
+import spikeinterface
+from spikeinterface.core.core_tools import check_json
+
+
+class Motion:
+    """
+    Motion of the tissue relative to the probe.
+
+    Parameters
+    ----------
+    displacement : numpy array 2d or list of 2d arrays
+        Motion estimate in um.
+        If a list, its length is the number of segments.
+        For each segment:
+            * shape (temporal bins, spatial bins)
+            * displacement.shape[0] == temporal_bins_s.shape[0]
+            * displacement.shape[1] == 1 (rigid) or spatial_bins_um.shape[0] (non-rigid)
+    temporal_bins_s : numpy.array 1d or list of 1d arrays
+        Temporal bin centers, in seconds.
+    spatial_bins_um : numpy.array 1d
+        Spatial window centers, in um.
+        spatial_bins_um.shape[0] == displacement.shape[1]
+        If rigid then spatial_bins_um.shape[0] == 1
+    direction : str, default: 'y'
+        Direction of the motion.
+    interpolation_method : str, default: 'linear'
+        How to interpolate the displacement between bin centers. See the docs
+        for scipy.interpolate.RegularGridInterpolator for options.
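+
+    Examples
+    --------
+    A minimal, purely illustrative sketch of a rigid single-segment Motion
+    (the values are arbitrary):
+
+    >>> import numpy as np
+    >>> temporal_bins_s = np.arange(0.5, 10.0, 1.0)
+    >>> displacement = np.linspace(-5.0, 5.0, temporal_bins_s.size)[:, np.newaxis]  # one spatial bin -> rigid
+    >>> motion = Motion(displacement, temporal_bins_s, np.array([0.0]), direction="y")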
+ """ + + def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y", interpolation_method="linear"): + if isinstance(displacement, np.ndarray): + self.displacement = [displacement] + assert isinstance(temporal_bins_s, np.ndarray) + self.temporal_bins_s = [temporal_bins_s] + else: + assert isinstance(displacement, (list, tuple)) + self.displacement = displacement + self.temporal_bins_s = temporal_bins_s + + assert isinstance(spatial_bins_um, np.ndarray) + self.spatial_bins_um = spatial_bins_um + + self.num_segments = len(self.displacement) + self.interpolators = None + self.interpolation_method = interpolation_method + + self.direction = direction + self.dim = ["x", "y", "z"].index(direction) + self.check_properties() + + def check_properties(self): + assert all(d.ndim == 2 for d in self.displacement) + assert all(t.ndim == 1 for t in self.temporal_bins_s) + assert all(self.spatial_bins_um.shape == (d.shape[1],) for d in self.displacement) + + def __repr__(self): + nbins = self.spatial_bins_um.shape[0] + if nbins == 1: + rigid_txt = "rigid" + else: + rigid_txt = f"non-rigid - {nbins} spatial bins" + + interval_s = self.temporal_bins_s[0][1] - self.temporal_bins_s[0][0] + txt = f"Motion {rigid_txt} - interval {interval_s}s - {self.num_segments} segments" + return txt + + def make_interpolators(self): + from scipy.interpolate import RegularGridInterpolator + + self.interpolators = [ + RegularGridInterpolator( + (self.temporal_bins_s[j], self.spatial_bins_um), self.displacement[j], method=self.interpolation_method + ) + for j in range(self.num_segments) + ] + self.temporal_bounds = [(t[0], t[-1]) for t in self.temporal_bins_s] + self.spatial_bounds = (self.spatial_bins_um.min(), self.spatial_bins_um.max()) + + def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_index=None, grid=False): + """Evaluate the motion estimate at times and positions + + Evaluate the motion estimate, returning the (linearly interpolated) estimated displacement + at the given times and locations. + + Parameters + ---------- + times_s: np.array + locations_um: np.array + Either this is a one-dimensional array (a vector of positions along self.dimension), or + else a 2d array with the 2 or 3 spatial dimensions indexed along axis=1. + segment_index: int, optional + grid : bool + If grid=False, the default, then times_s and locations_um should have the same one-dimensional + shape, and the returned displacement[i] is the displacement at time times_s[i] and location + locations_um[i]. + If grid=True, times_s and locations_um determine a grid of positions to evaluate the displacement. + Then the returned displacement[i,j] is the displacement at depth locations_um[i] and time times_s[j]. + + Returns + ------- + displacement : np.array + A displacement per input location, of shape times_s.shape if grid=False and (locations_um.size, times_s.size) + if grid=True. 
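+
+        For example (an illustrative sketch, assuming ``motion`` is a single-segment Motion):
+
+        >>> disp = motion.get_displacement_at_time_and_depth(np.array([1.0, 2.0]), np.array([50.0, 120.0]))
+        >>> disp.shape
+        (2,)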
+        """
+        if self.interpolators is None:
+            self.make_interpolators()
+
+        if segment_index is None:
+            if self.num_segments == 1:
+                segment_index = 0
+            else:
+                raise ValueError("Motion has several segments, please provide segment_index")
+
+        times_s = np.asarray(times_s)
+        locations_um = np.asarray(locations_um)
+
+        if locations_um.ndim == 1:
+            # already a vector of positions along self.direction
+            pass
+        elif locations_um.ndim == 2:
+            locations_um = locations_um[:, self.dim]
+        else:
+            raise ValueError("locations_um must be a 1d or 2d array")
+
+        times_s = times_s.clip(*self.temporal_bounds[segment_index])
+        locations_um = locations_um.clip(*self.spatial_bounds)
+
+        if grid:
+            # construct a grid over which to evaluate the displacement
+            locations_um, times_s = np.meshgrid(locations_um, times_s, indexing="ij")
+            out_shape = times_s.shape
+            locations_um = locations_um.ravel()
+            times_s = times_s.ravel()
+        else:
+            # usual case: input is a point cloud
+            assert locations_um.shape == times_s.shape
+            assert times_s.ndim == 1
+            out_shape = times_s.shape
+
+        points = np.column_stack((times_s, locations_um))
+        displacement = self.interpolators[segment_index](points)
+        # reshape to grid domain shape if necessary
+        displacement = displacement.reshape(out_shape)
+
+        return displacement
+
+    def to_dict(self):
+        return dict(
+            displacement=self.displacement,
+            temporal_bins_s=self.temporal_bins_s,
+            spatial_bins_um=self.spatial_bins_um,
+            direction=self.direction,
+            interpolation_method=self.interpolation_method,
+        )
+
+    def save(self, folder):
+        folder = Path(folder)
+        folder.mkdir(exist_ok=False, parents=True)
+
+        info_file = folder / "spikeinterface_info.json"
+        info = dict(
+            version=spikeinterface.__version__,
+            dev_mode=spikeinterface.DEV_MODE,
+            object="Motion",
+            num_segments=self.num_segments,
+            direction=self.direction,
+            interpolation_method=self.interpolation_method,
+        )
+        with open(info_file, mode="w") as f:
+            json.dump(check_json(info), f, indent=4)
+
+        np.save(folder / "spatial_bins_um.npy", self.spatial_bins_um)
+
+        for segment_index in range(self.num_segments):
+            np.save(folder / f"displacement_seg{segment_index}.npy", self.displacement[segment_index])
+            np.save(folder / f"temporal_bins_s_seg{segment_index}.npy", self.temporal_bins_s[segment_index])
+
+    @classmethod
+    def load(cls, folder):
+        folder = Path(folder)
+
+        info_file = folder / "spikeinterface_info.json"
+        err_msg = f"Motion.load(folder): the folder {folder} does not contain a Motion object."
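+        # expected folder layout (as written by Motion.save() above):
+        #   spikeinterface_info.json, spatial_bins_um.npy,
+        #   and per segment: displacement_seg{i}.npy / temporal_bins_s_seg{i}.npy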
+ if not info_file.exists(): + raise IOError(err_msg) + + with open(info_file, "r") as f: + info = json.load(f) + if "object" not in info or info["object"] != "Motion": + raise IOError(err_msg) + + direction = info["direction"] + interpolation_method = info["interpolation_method"] + spatial_bins_um = np.load(folder / "spatial_bins_um.npy") + displacement = [] + temporal_bins_s = [] + for segment_index in range(info["num_segments"]): + displacement.append(np.load(folder / f"displacement_seg{segment_index}.npy")) + temporal_bins_s.append(np.load(folder / f"temporal_bins_s_seg{segment_index}.npy")) + + return cls( + displacement, + temporal_bins_s, + spatial_bins_um, + direction=direction, + interpolation_method=interpolation_method, + ) + + def __eq__(self, other): + for segment_index in range(self.num_segments): + if not np.allclose(self.displacement[segment_index], other.displacement[segment_index]): + return False + if not np.allclose(self.temporal_bins_s[segment_index], other.temporal_bins_s[segment_index]): + return False + + if not np.allclose(self.spatial_bins_um, other.spatial_bins_um): + return False + + return True + + def copy(self): + return Motion( + self.displacement.copy(), + self.temporal_bins_s.copy(), + self.spatial_bins_um.copy(), + interpolation_method=self.interpolation_method, + ) diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py index 597eee7a99..af62ba52ec 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_estimation.py @@ -1,17 +1,11 @@ -import pytest - -import shutil +from pathlib import Path import numpy as np - -from spikeinterface.sortingcomponents.peak_detection import detect_peaks -from spikeinterface.sortingcomponents.motion_estimation import estimate_motion - -from spikeinterface.sortingcomponents.motion_interpolation import InterpolateMotionRecording +import pytest from spikeinterface.core.node_pipeline import ExtractDenseWaveforms - +from spikeinterface.sortingcomponents.motion_estimation import estimate_motion +from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_localization import LocalizeCenterOfMass - from spikeinterface.sortingcomponents.tests.common import make_dataset @@ -159,34 +153,27 @@ def test_estimate_motion(setup_module): ) kwargs.update(cases_kwargs) - job_kwargs = dict(progress_bar=False) - - motion, temporal_bins, spatial_bins, extra_check = estimate_motion( - recording, peaks, peak_locations, **kwargs, **job_kwargs - ) - + motion, extra_check = estimate_motion(recording, peaks, peak_locations, **kwargs) motions[name] = motion - assert temporal_bins.shape[0] == motion.shape[0] - assert spatial_bins.shape[0] == motion.shape[1] - if cases_kwargs["rigid"]: - assert motion.shape[1] == 1 + assert motion.displacement[0].shape[1] == 1 else: - assert motion.shape[1] > 1 + assert motion.displacement[0].shape[1] > 1 - # Test saving to disk - corrected_rec = InterpolateMotionRecording( - recording, motion, temporal_bins, spatial_bins, border_mode="force_extrapolate" - ) - rec_folder = cache_folder / (name.replace("/", "").replace(" ", "_") + "_recording") - if rec_folder.exists(): - shutil.rmtree(rec_folder) - corrected_rec.save(folder=rec_folder) + # # Test saving to disk + # corrected_rec = InterpolateMotionRecording( + # recording, motion, temporal_bins, spatial_bins, 
border_mode="force_extrapolate" + # ) + # rec_folder = cache_folder / (name.replace("/", "").replace(" ", "_") + "_recording") + # if rec_folder.exists(): + # shutil.rmtree(rec_folder) + # corrected_rec.save(folder=rec_folder) if DEBUG: fig, ax = plt.subplots() - ax.plot(temporal_bins, motion) + seg_index = 0 + ax.plot(motion.temporal_bins_s[0], motion.displacement[seg_index]) # motion_histogram = extra_check['motion_histogram'] # spatial_hist_bins = extra_check['spatial_hist_bin_edges'] @@ -206,33 +193,25 @@ def test_estimate_motion(setup_module): plt.show() # same params with differents engine should be the same - motion0, motion1 = motions["rigid / decentralized / torch"], motions["rigid / decentralized / numpy"] - assert (motion0 == motion1).all() + motion0 = motions["rigid / decentralized / torch"] + motion1 = motions["rigid / decentralized / numpy"] + assert motion0 == motion1 - motion0, motion1 = ( - motions["rigid / decentralized / torch / time_horizon_s"], - motions["rigid / decentralized / numpy / time_horizon_s"], - ) - # TODO : later torch and numpy used to be the same - # assert np.testing.assert_almost_equal(motion0, motion1) + motion0 = motions["rigid / decentralized / torch / time_horizon_s"] + motion1 = motions["rigid / decentralized / numpy / time_horizon_s"] + np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) - motion0, motion1 = motions["non-rigid / decentralized / torch"], motions["non-rigid / decentralized / numpy"] - # TODO : later torch and numpy used to be the same - # assert np.testing.assert_almost_equal(motion0, motion1) + motion0 = motions["non-rigid / decentralized / torch"] + motion1 = motions["non-rigid / decentralized / numpy"] + np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) - motion0, motion1 = ( - motions["non-rigid / decentralized / torch / time_horizon_s"], - motions["non-rigid / decentralized / numpy / time_horizon_s"], - ) - # TODO : later torch and numpy used to be the same - # assert np.testing.assert_almost_equal(motion0, motion1) + motion0 = motions["non-rigid / decentralized / torch / time_horizon_s"] + motion1 = motions["non-rigid / decentralized / numpy / time_horizon_s"] + np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) - motion0, motion1 = ( - motions["non-rigid / decentralized / torch / spatial_prior"], - motions["non-rigid / decentralized / numpy / spatial_prior"], - ) - # TODO : later torch and numpy used to be the same - # assert np.testing.assert_almost_equal(motion0, motion1) + motion0 = motions["non-rigid / decentralized / torch / spatial_prior"] + motion1 = motions["non-rigid / decentralized / numpy / spatial_prior"] + np.testing.assert_array_almost_equal(motion0.displacement, motion1.displacement) if __name__ == "__main__": diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py index de22ee010d..cb26560272 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_interpolation.py @@ -1,30 +1,39 @@ -import numpy as np +from pathlib import Path +import numpy as np +import pytest +import spikeinterface.core as sc +from spikeinterface import download_dataset from spikeinterface.sortingcomponents.motion_interpolation import ( + InterpolateMotionRecording, correct_motion_on_peaks, + interpolate_motion, interpolate_motion_on_traces, - 
InterpolateMotionRecording, ) - +from spikeinterface.sortingcomponents.motion_utils import Motion from spikeinterface.sortingcomponents.tests.common import make_dataset def make_fake_motion(rec): - # make a fake motion vector + # make a fake motion object duration = rec.get_total_duration() locs = rec.get_channel_locations() temporal_bins = np.arange(0.5, duration - 0.49, 0.5) spatial_bins = np.arange(locs[:, 1].min(), locs[:, 1].max(), 100) - motion = np.zeros((temporal_bins.size, spatial_bins.size)) - motion[:, :] = np.linspace(-30, 30, temporal_bins.size)[:, None] + displacement = np.zeros((temporal_bins.size, spatial_bins.size)) + displacement[:, :] = np.linspace(-30, 30, temporal_bins.size)[:, None] - return motion, temporal_bins, spatial_bins + motion = Motion([displacement], [temporal_bins], spatial_bins, direction="y") + + return motion def test_correct_motion_on_peaks(): rec, sorting = make_dataset() peaks = sorting.to_spike_vector() - motion, temporal_bins, spatial_bins = make_fake_motion(rec) + print(peaks.dtype) + motion = make_fake_motion(rec) + # print(motion) # fake locations peak_locations = np.zeros((peaks.size), dtype=[("x", "float32"), ("y", "float")]) @@ -32,26 +41,25 @@ def test_correct_motion_on_peaks(): corrected_peak_locations = correct_motion_on_peaks( peaks, peak_locations, - rec.sampling_frequency, motion, - temporal_bins, - spatial_bins, - direction="y", + rec, ) # print(corrected_peak_locations) assert np.any(corrected_peak_locations["y"] != 0) # import matplotlib.pyplot as plt # fig, ax = plt.subplots() - # ax.plot(times[peaks['sample_index']], corrected_peak_locations['y']) - # ax.plot(temporal_bins, motion[:, 1]) + # segment_index = 0 + # times = rec.get_times(segment_index=segment_index) + # ax.scatter(times[peaks['sample_index']], corrected_peak_locations['y']) + # ax.plot(motion.temporal_bins_s[segment_index], motion.displacement[segment_index][:, 1]) # plt.show() def test_interpolate_motion_on_traces(): rec, sorting = make_dataset() - motion, temporal_bins, spatial_bins = make_fake_motion(rec) + motion = make_fake_motion(rec) channel_locations = rec.get_channel_locations() @@ -64,28 +72,60 @@ def test_interpolate_motion_on_traces(): times, channel_locations, motion, - temporal_bins, - spatial_bins, - direction=1, channel_inds=None, spatial_interpolation_method=method, - spatial_interpolation_kwargs={}, + # spatial_interpolation_kwargs={}, + spatial_interpolation_kwargs={"force_extrapolate": True}, ) assert traces.shape == traces_corrected.shape assert traces.dtype == traces_corrected.dtype +def test_interpolation_simple(): + # a recording where a 1 moves at 1 chan per second. 30 chans 10 frames. 
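+    # (concretely, in the code below: nt = nc0 = 10 frames/channels and nc1 = 19 channels after padding)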
+ # there will be 9 chans of drift, so we add 9 chans of padding to the bottom + nt = nc0 = 10 # these need to be the same for this test + nc1 = nc0 + nc0 - 1 + traces = np.zeros((nt, nc1), dtype="float32") + traces[:, :nc0] = np.eye(nc0) + rec = sc.NumpyRecording(traces, sampling_frequency=1) + rec.set_dummy_probe_from_locations(np.c_[np.zeros(nc1), np.arange(nc1)]) + + true_motion = Motion(np.arange(nt)[:, None], 0.5 + np.arange(nt), np.zeros(1)) + rec_corrected = interpolate_motion(rec, true_motion, spatial_interpolation_method="nearest") + traces_corrected = rec_corrected.get_traces() + assert traces_corrected.shape == (nc0, nc0) + assert np.array_equal(traces_corrected[:, 0], np.ones(nt)) + assert np.array_equal(traces_corrected[:, 1:], np.zeros((nt, nc0 - 1))) + + # let's try a new version where we interpolate too slowly + rec_corrected = interpolate_motion( + rec, true_motion, spatial_interpolation_method="nearest", num_closest=2, interpolation_time_bin_size_s=2 + ) + traces_corrected = rec_corrected.get_traces() + assert traces_corrected.shape == (nc0, nc0) + # what happens with nearest here? + # well... due to rounding towards the nearest even number, the motion (which at + # these time bin centers is 0.5, 2.5, 4.5, ...) flips the signal's nearest + # neighbor back and forth between the first and second channels + assert np.all(traces_corrected[::2, 0] == 1) + assert np.all(traces_corrected[1::2, 0] == 0) + assert np.all(traces_corrected[1::2, 1] == 1) + assert np.all(traces_corrected[::2, 1] == 0) + assert np.all(traces_corrected[:, 2:] == 0) + + def test_InterpolateMotionRecording(): rec, sorting = make_dataset() - motion, temporal_bins, spatial_bins = make_fake_motion(rec) + motion = make_fake_motion(rec) - rec2 = InterpolateMotionRecording(rec, motion, temporal_bins, spatial_bins, border_mode="force_extrapolate") + rec2 = InterpolateMotionRecording(rec, motion, border_mode="force_extrapolate") assert rec2.channel_ids.size == 32 - rec2 = InterpolateMotionRecording(rec, motion, temporal_bins, spatial_bins, border_mode="force_zeros") + rec2 = InterpolateMotionRecording(rec, motion, border_mode="force_zeros") assert rec2.channel_ids.size == 32 - rec2 = InterpolateMotionRecording(rec, motion, temporal_bins, spatial_bins, border_mode="remove_channels") + rec2 = InterpolateMotionRecording(rec, motion, border_mode="remove_channels") assert rec2.channel_ids.size == 24 for ch_id in (0, 1, 14, 15, 16, 17, 30, 31): assert ch_id not in rec2.channel_ids @@ -106,6 +146,7 @@ def test_InterpolateMotionRecording(): if __name__ == "__main__": - test_correct_motion_on_peaks() - test_interpolate_motion_on_traces() + # test_correct_motion_on_peaks() + # test_interpolate_motion_on_traces() + test_interpolation_simple() test_InterpolateMotionRecording() diff --git a/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py new file mode 100644 index 0000000000..0b67be39c0 --- /dev/null +++ b/src/spikeinterface/sortingcomponents/tests/test_motion_utils.py @@ -0,0 +1,86 @@ +import pickle +import shutil +from pathlib import Path + +import numpy as np +import pytest +from spikeinterface.sortingcomponents.motion_utils import Motion +from spikeinterface.generation import make_one_displacement_vector + +if hasattr(pytest, "global_test_folder"): + cache_folder = pytest.global_test_folder / "sortingcomponents" +else: + cache_folder = Path("cache_folder") / "sortingcomponents" + + +def make_fake_motion(): + displacement_sampling_frequency 
= 5.0
+    spatial_bins_um = np.array([100.0, 200.0, 300.0, 400.0])
+
+    displacement_vector = make_one_displacement_vector(
+        drift_mode="zigzag",
+        duration=50.0,
+        amplitude_factor=1.0,
+        displacement_sampling_frequency=displacement_sampling_frequency,
+        period_s=25.0,
+    )
+    temporal_bins_s = np.arange(displacement_vector.size) / displacement_sampling_frequency
+    displacement = np.zeros((temporal_bins_s.size, spatial_bins_um.size))
+
+    n = spatial_bins_um.size
+    for i in range(n):
+        displacement[:, i] = displacement_vector * ((i + 1) / n)
+
+    motion = Motion(displacement, temporal_bins_s, spatial_bins_um, direction="y")
+
+    return motion
+
+
+def test_Motion():
+
+    temporal_bins_s = np.arange(0.0, 10.0, 1.0)
+    spatial_bins_um = np.array([100.0, 200.0])
+
+    displacement = np.zeros((temporal_bins_s.shape[0], spatial_bins_um.shape[0]))
+    displacement[:, :] = np.linspace(-20, 20, temporal_bins_s.shape[0])[:, np.newaxis]
+
+    motion = Motion(displacement, temporal_bins_s, spatial_bins_um, direction="y")
+    assert motion.interpolators is None
+
+    # serialize with pickle before interpolation fit
+    motion2 = pickle.loads(pickle.dumps(motion))
+    assert motion2.interpolators is None
+    # serialize with pickle after interpolation fit
+    motion2.make_interpolators()
+    assert motion2.interpolators is not None
+    motion2 = pickle.loads(pickle.dumps(motion2))
+    assert motion2.interpolators is not None
+
+    # to/from dict
+    motion2 = Motion(**motion.to_dict())
+    assert motion == motion2
+    assert motion2.interpolators is None
+
+    # do interpolate
+    displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11], [120.0, 80.0, 150.0])
+    # print(displacement)
+    assert displacement.shape[0] == 3
+    # check clip
+    assert displacement[2] == 20.0
+
+    # interpolate grid
+    displacement = motion.get_displacement_at_time_and_depth([2, 4.4, 11, 15, 19], [150.0, 80.0], grid=True)
+    assert displacement.shape == (2, 5)
+    assert displacement[0, 2] == 20.0
+
+    # save/load to folder
+    folder = cache_folder / "motion_saved"
+    if folder.exists():
+        shutil.rmtree(folder)
+    motion.save(folder)
+    motion2 = Motion.load(folder)
+    assert motion == motion2
+
+
+if __name__ == "__main__":
+    test_Motion()
diff --git a/src/spikeinterface/widgets/motion.py b/src/spikeinterface/widgets/motion.py
index 110555dd6a..fc0c91423d 100644
--- a/src/spikeinterface/widgets/motion.py
+++ b/src/spikeinterface/widgets/motion.py
@@ -7,14 +7,113 @@
 
 class MotionWidget(BaseWidget):
     """
-    Plot unit depths
+    Plot the Motion object.
+
+    Parameters
+    ----------
+    motion : Motion
+        The motion object to plot.
+    segment_index : None | int, default: None
+        For a multi-segment Motion object, the segment index must be specified.
+    mode : "auto" | "line" | "map", default: "line"
+        How to plot the motion: as lines or as a 2d map.
+        "auto" chooses automatically based on the number of spatial bins.
+    motion_lim : float | None, default: None
+        Limit (in um) for the displacement color scale in "map" mode.
+        If None, it is inferred from the displacement.
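+
+    Examples
+    --------
+    A minimal sketch (``motion`` is assumed to be a Motion object, for example the one
+    stored under the "motion" key of the motion_info dict returned by correct_motion()):
+
+    >>> from spikeinterface.widgets import plot_motion
+    >>> plot_motion(motion, mode="line")
+    >>> plot_motion(motion, mode="map")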
+    """
+
+    def __init__(
+        self,
+        motion,
+        segment_index=None,
+        mode="line",
+        motion_lim=None,
+        backend=None,
+        **backend_kwargs,
+    ):
+        if isinstance(motion, dict):
+            raise ValueError(
+                "The API has changed: plot_motion() now takes a Motion object. "
+                "For a motion_info dict, use plot_motion_info(motion_info) instead."
+            )
+
+        if segment_index is None:
+            if len(motion.displacement) == 1:
+                segment_index = 0
+            else:
+                raise ValueError("plot_motion(): the Motion object has several segments, please provide segment_index")
+
+        plot_data = dict(
+            motion=motion,
+            segment_index=segment_index,
+            mode=mode,
+            motion_lim=motion_lim,
+        )
+
+        BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)
+
+    def plot_matplotlib(self, data_plot, **backend_kwargs):
+        import matplotlib.pyplot as plt
+        from .utils_matplotlib import make_mpl_figure
+        from matplotlib.colors import Normalize
+
+        dp = to_attr(data_plot)
+
+        motion = data_plot["motion"]
+        segment_index = data_plot["segment_index"]
+
+        assert backend_kwargs["axes"] is None
+
+        self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
+
+        displacement = motion.displacement[dp.segment_index]
+        temporal_bins_s = motion.temporal_bins_s[dp.segment_index]
+        depth = motion.spatial_bins_um
+
+        if dp.motion_lim is None:
+            motion_lim = np.max(np.abs(displacement)) * 1.05
+        else:
+            motion_lim = dp.motion_lim
+
+        ax = self.ax
+        fig = self.figure
+        if dp.mode == "line":
+            ax.plot(temporal_bins_s, displacement, alpha=0.2, color="black")
+            ax.plot(temporal_bins_s, np.mean(displacement, axis=1), color="C0")
+            ax.set_xlabel("Times [s]")
+            ax.set_ylabel("motion [um]")
+        elif dp.mode == "map":
+            im = ax.imshow(
+                displacement.T,
+                interpolation="nearest",
+                aspect="auto",
+                origin="lower",
+                extent=(temporal_bins_s[0], temporal_bins_s[-1], depth[0], depth[-1]),
+                cmap="PiYG",
+            )
+            im.set_clim(-motion_lim, motion_lim)
+
+            cbar = fig.colorbar(im)
+            cbar.ax.set_ylabel("motion [um]")
+            ax.set_xlabel("Times [s]")
+            ax.set_ylabel("Depth [um]")
+
+
+class MotionInfoWidget(BaseWidget):
+    """
+    Plot the motion information from the motion_info dict returned by correct_motion().
+    This widget plots:
+      * the motion itself
+      * the peak depth vs. time before correction
+      * the peak depth vs. time after correction
 
     Parameters
     ----------
     motion_info : dict
         The motion info return by correct_motion() or load back with load_motion_info()
+    segment_index : int | None, default: None
+        The segment index to display. For a multi-segment Motion object it must be specified.
     recording : RecordingExtractor, default: None
         The recording extractor object (only used to get "real" times)
     sampling_frequency : float, default: None
         The sampling frequency (needed if recording is None)
     depth_lim : tuple or None, default: None
@@ -36,6 +135,7 @@ class MotionWidget(BaseWidget):
     def __init__(
         self,
         motion_info,
+        segment_index=None,
         recording=None,
         depth_lim=None,
         motion_lim=None,
@@ -47,11 +147,20 @@ def __init__(
         backend=None,
         **backend_kwargs,
     ):
+
+        motion = motion_info["motion"]
+        if segment_index is None:
+            if len(motion.displacement) == 1:
+                segment_index = 0
+            else:
+                raise ValueError(
+                    "plot_motion_info(): the Motion object has several segments, please provide segment_index"
+                )
+
         times = recording.get_times() if recording is not None else None
 
         plot_data = dict(
             sampling_frequency=motion_info["parameters"]["sampling_frequency"],
             times=times,
+            segment_index=segment_index,
             depth_lim=depth_lim,
             motion_lim=motion_lim,
             color_amplitude=color_amplitude,
@@ -59,6 +168,7 @@ def __init__(
             amplitude_cmap=amplitude_cmap,
             amplitude_clim=amplitude_clim,
             amplitude_alpha=amplitude_alpha,
+            recording=recording,
             **motion_info,
         )
 
@@ -73,16 +183,29 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
 
         dp = to_attr(data_plot)
 
-        assert backend_kwargs["axes"] is None
-        assert backend_kwargs["ax"] is None
+        assert backend_kwargs["axes"] is None, "axes argument is not allowed in MotionInfoWidget"
+        assert backend_kwargs["ax"] is None, "ax argument is not allowed in MotionInfoWidget"
 
         self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
         fig = self.figure
         fig.clear()
 
-        is_rigid = dp.motion.shape[1] == 1
 
-        gs = fig.add_gridspec(2, 2, wspace=0.3, hspace=0.3)
+        motion = dp.motion
+
+        displacement = motion.displacement[dp.segment_index]
+        temporal_bins_s = motion.temporal_bins_s[dp.segment_index]
+        spatial_bins_um = motion.spatial_bins_um
+
+        if dp.motion_lim is None:
+            motion_lim = np.max(np.abs(displacement)) * 1.05
+        else:
+            motion_lim = dp.motion_lim
+
+        is_rigid = displacement.shape[1] == 1
+
+        gs = fig.add_gridspec(2, 2, wspace=0.3, hspace=0.5)
         ax0 = fig.add_subplot(gs[0, 0])
         ax1 = fig.add_subplot(gs[0, 1])
         ax2 = fig.add_subplot(gs[1, 0])
@@ -91,31 +214,23 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
         ax1.sharex(ax0)
         ax1.sharey(ax0)
 
-        if dp.motion_lim is None:
-            motion_lim = np.max(np.abs(dp.motion)) * 1.05
-        else:
-            motion_lim = dp.motion_lim
-
         if dp.times is None:
-            temporal_bins_plot = dp.temporal_bins
+            # temporal_bins_plot = dp.temporal_bins
             x = dp.peaks["sample_index"] / dp.sampling_frequency
         else:
             # use real times and adjust temporal bins with t_start
-            temporal_bins_plot = dp.temporal_bins + dp.times[0]
+            # temporal_bins_plot = dp.temporal_bins + dp.times[0]
             x = dp.times[dp.peaks["sample_index"]]
 
         corrected_location = correct_motion_on_peaks(
             dp.peaks,
             dp.peak_locations,
-            dp.sampling_frequency,
             dp.motion,
-            dp.temporal_bins,
-            dp.spatial_bins,
-            direction="y",
+            dp.recording,
         )
 
-        y = dp.peak_locations["y"]
-        y2 = corrected_location["y"]
+        y = dp.peak_locations[motion.direction]
+        y2 = corrected_location[motion.direction]
         if dp.scatter_decimate is not None:
             x = x[:: dp.scatter_decimate]
             y = y[:: dp.scatter_decimate]
@@ -149,37 +264,38 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
             ax0.set_ylim(*dp.depth_lim)
         ax0.set_title("Peak depth")
         ax0.set_xlabel("Times [s]")
-        ax0.set_ylabel("Depth [um]")
+        ax0.set_ylabel("Depth [$\\mu$m]")
 
         ax1.scatter(x, y2, s=1, **color_kwargs)
         ax1.set_xlabel("Times [s]")
-        ax1.set_ylabel("Depth [um]")
+        ax1.set_ylabel("Depth [$\\mu$m]")
         ax1.set_title("Corrected peak depth")
 
-
ax2.plot(temporal_bins_plot, dp.motion, alpha=0.2, color="black") - ax2.plot(temporal_bins_plot, np.mean(dp.motion, axis=1), color="C0") + ax2.plot(temporal_bins_s, displacement, alpha=0.2, color="black") + ax2.plot(temporal_bins_s, np.mean(displacement, axis=1), color="C0") ax2.set_ylim(-motion_lim, motion_lim) - ax2.set_ylabel("Motion [um]") + ax2.set_ylabel("Motion [$\\mu$m]") + ax2.set_xlabel("Times [s]") ax2.set_title("Motion vectors") axes = [ax0, ax1, ax2] if not is_rigid: im = ax3.imshow( - dp.motion.T, + displacement.T, aspect="auto", origin="lower", extent=( - temporal_bins_plot[0], - temporal_bins_plot[-1], - dp.spatial_bins[0], - dp.spatial_bins[-1], + temporal_bins_s[0], + temporal_bins_s[-1], + spatial_bins_um[0], + spatial_bins_um[-1], ), ) im.set_clim(-motion_lim, motion_lim) cbar = fig.colorbar(im) - cbar.ax.set_xlabel("motion [um]") + cbar.ax.set_ylabel("Motion [$\\mu$m]") ax3.set_xlabel("Times [s]") - ax3.set_ylabel("Depth [um]") + ax3.set_ylabel("Depth [$\\mu$m]") ax3.set_title("Motion vectors") axes.append(ax3) self.axes = np.array(axes) diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py index fdc937dc25..e841a1c93b 100644 --- a/src/spikeinterface/widgets/tests/test_widgets.py +++ b/src/spikeinterface/widgets/tests/test_widgets.py @@ -2,6 +2,8 @@ import pytest import os +import numpy as np + if __name__ != "__main__": try: import matplotlib @@ -578,6 +580,38 @@ def test_plot_multicomparison(self): _, axes = plt.subplots(len(mcmp.object_list), 1) sw.plot_multicomparison_agreement_by_sorter(mcmp, axes=axes) + def test_plot_motion(self): + from spikeinterface.sortingcomponents.tests.test_motion_utils import make_fake_motion + + motion = make_fake_motion() + + possible_backends = list(sw.MotionWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_motion(motion, backend=backend, mode="line") + sw.plot_motion(motion, backend=backend, mode="map") + + def test_plot_motion_info(self): + from spikeinterface.sortingcomponents.tests.test_motion_utils import make_fake_motion + + motion = make_fake_motion() + rng = np.random.default_rng(seed=2205) + peak_locations = np.zeros(self.peaks.size, dtype=[("x", "float64"), ("y", "float64")]) + peak_locations["y"] = rng.uniform(motion.spatial_bins_um[0], motion.spatial_bins_um[-1], size=self.peaks.size) + + motion_info = dict( + motion=motion, + parameters=dict(sampling_frequency=30000.0), + run_times=dict(), + peaks=self.peaks, + peak_locations=peak_locations, + ) + + possible_backends = list(sw.MotionWidget.get_possible_backends()) + for backend in possible_backends: + if backend not in self.skip_backends: + sw.plot_motion_info(motion_info, recording=self.recording, backend=backend) + if __name__ == "__main__": # unittest.main() @@ -592,7 +626,7 @@ def test_plot_multicomparison(self): # mytest.test_plot_traces() # mytest.test_plot_spikes_on_traces() # mytest.test_plot_unit_waveforms() - mytest.test_plot_spikes_on_traces() + # mytest.test_plot_spikes_on_traces() # mytest.test_plot_unit_depths() # mytest.test_plot_autocorrelograms() # mytest.test_plot_crosscorrelograms() @@ -612,6 +646,8 @@ def test_plot_multicomparison(self): # mytest.test_plot_peak_activity() # mytest.test_plot_multicomparison() # mytest.test_plot_sorting_summary() + # mytest.test_plot_motion() + mytest.test_plot_motion_info() plt.show() # TestWidgets.tearDownClass() diff --git a/src/spikeinterface/widgets/widget_list.py 
b/src/spikeinterface/widgets/widget_list.py index b65fe97a3c..d6df59b0f3 100644 --- a/src/spikeinterface/widgets/widget_list.py +++ b/src/spikeinterface/widgets/widget_list.py @@ -10,7 +10,7 @@ from .autocorrelograms import AutoCorrelogramsWidget from .crosscorrelograms import CrossCorrelogramsWidget from .isi_distribution import ISIDistributionWidget -from .motion import MotionWidget +from .motion import MotionWidget, MotionInfoWidget from .multicomparison import MultiCompGraphWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget from .peak_activity import PeakActivityMapWidget from .potential_merges import PotentialMergesWidget @@ -45,6 +45,7 @@ CrossCorrelogramsWidget, ISIDistributionWidget, MotionWidget, + MotionInfoWidget, MultiCompGlobalAgreementWidget, MultiCompAgreementBySorterWidget, MultiCompGraphWidget, @@ -117,6 +118,7 @@ plot_crosscorrelograms = CrossCorrelogramsWidget plot_isi_distribution = ISIDistributionWidget plot_motion = MotionWidget +plot_motion_info = MotionInfoWidget plot_multicomparison_agreement = MultiCompGlobalAgreementWidget plot_multicomparison_agreement_by_sorter = MultiCompAgreementBySorterWidget plot_multicomparison_graph = MultiCompGraphWidget
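
For reference, the pieces introduced above compose as in the following minimal sketch (illustrative only; the folder name is arbitrary and the Motion values are toy data):

    import numpy as np
    from spikeinterface.sortingcomponents.motion_utils import Motion
    from spikeinterface.widgets import plot_motion

    # build a toy rigid, single-segment Motion (same pattern as test_Motion above)
    temporal_bins_s = np.arange(0.0, 10.0, 1.0)
    displacement = np.linspace(-20.0, 20.0, temporal_bins_s.size)[:, np.newaxis]
    motion = Motion(displacement, temporal_bins_s, np.array([0.0]), direction="y")

    # round-trip to disk with the new save/load helpers
    motion.save("motion_saved")
    motion2 = Motion.load("motion_saved")
    assert motion == motion2

    # plot the Motion object directly with the reworked widget
    plot_motion(motion, mode="map")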