From 37014bbcb41344c1f7be1aeeb999a7ad75311e53 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 27 Jun 2024 14:15:48 +0200 Subject: [PATCH] more refactoring and parameters change --- .../benchmark/benchmark_motion_estimation.py | 4 +- .../benchmark/benchmark_tools.py | 2 +- .../sortingcomponents/motion/decentralized.py | 23 ++++++---- .../sortingcomponents/motion/dredge.py | 31 +++++++------ .../motion/iterative_template.py | 16 +++++-- .../motion/motion_cleaner.py | 2 +- .../motion/motion_estimation.py | 35 ++++----------- .../sortingcomponents/motion/motion_utils.py | 43 ++++++++++++++----- 8 files changed, 90 insertions(+), 66 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py index 55ef21de9d..ec7e1e24a8 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_motion_estimation.py @@ -9,13 +9,13 @@ from spikeinterface.core import get_noise_levels from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis -from spikeinterface.sortingcomponents.motion_estimation import estimate_motion +from spikeinterface.sortingcomponents.motion import estimate_motion from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_selection import select_peaks from spikeinterface.sortingcomponents.peak_localization import localize_peaks from spikeinterface.widgets import plot_probe_map -from spikeinterface.sortingcomponents.motion_utils import Motion +from spikeinterface.sortingcomponents.motion import Motion # import MEArec as mr diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py index e9f128993d..7dc3fad280 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_tools.py @@ -443,7 +443,7 @@ def load_folder(cls, folder): result[k] = load_extractor(folder / k) elif format == "Motion": - from spikeinterface.sortingcomponents.motion_utils import Motion + from spikeinterface.sortingcomponents.motion import Motion result[k] = Motion.load(folder / k) elif format == "zarr_templates": diff --git a/src/spikeinterface/sortingcomponents/motion/decentralized.py b/src/spikeinterface/sortingcomponents/motion/decentralized.py index baf0fae1f2..5815f77a48 100644 --- a/src/spikeinterface/sortingcomponents/motion/decentralized.py +++ b/src/spikeinterface/sortingcomponents/motion/decentralized.py @@ -10,7 +10,7 @@ except ImportError: HAVE_TORCH = False -from .motion_utils import Motion, get_windows, get_spatial_bin_edges, make_2d_motion_histogram, scipy_conv1d +from .motion_utils import Motion, get_spatial_windows, get_spatial_bin_edges, make_2d_motion_histogram, scipy_conv1d class DecentralizedRegistration: @@ -113,11 +113,11 @@ def run( verbose, progress_bar, extra, - bin_um=10.0, - hist_margin_um=0.0, - bin_duration_s=2.0, - histogram_depth_smooth_um=None, - histogram_time_smooth_s=None, + bin_um=1.0, + hist_margin_um=20.0, + bin_duration_s=1.0, + histogram_depth_smooth_um=1., + histogram_time_smooth_s=1., pairwise_displacement_method="conv", max_displacement_um=100.0, weight_scale="linear", @@ -152,8 +152,15 @@ def run( spatial_bin_centers = 0.5 * (spatial_bin_edges[1:] + spatial_bin_edges[:-1]) # get spatial windows - non_rigid_windows, non_rigid_window_centers = get_windows( - rigid, contact_depth, spatial_bin_centers, win_margin_um, win_step_um, win_scale_um, win_shape + non_rigid_windows, non_rigid_window_centers = get_spatial_windows( + contact_depth, + spatial_bin_centers, + rigid=rigid, + win_shape=win_shape, + win_step_um=win_step_um, + win_scale_um=win_scale_um, + win_margin_um=win_margin_um, + zero_threshold=None ) # make 2D histogram raster diff --git a/src/spikeinterface/sortingcomponents/motion/dredge.py b/src/spikeinterface/sortingcomponents/motion/dredge.py index 779b236d77..f2a46bdfcb 100644 --- a/src/spikeinterface/sortingcomponents/motion/dredge.py +++ b/src/spikeinterface/sortingcomponents/motion/dredge.py @@ -28,7 +28,7 @@ import gc -from .motion_utils import Motion, get_windows, get_window_domains, scipy_conv1d, make_2d_motion_histogram +from .motion_utils import Motion, get_spatial_windows, get_window_domains, scipy_conv1d, make_2d_motion_histogram # to discuss @@ -42,6 +42,8 @@ # put smotthing inside the histogram function # put the log for weight inhitstogram +# TODO maybe change everywhere bin_duration_s to bin_s + # simple class wrapper to be compliant with estimate_motion class DredgeApRegistration: @@ -221,6 +223,9 @@ def dredge_ap( if time_horizon_s: xcorr_kw["max_dt_bins"] = np.ceil(time_horizon_s / bin_s) + #TODO @charlie I think this is a bad to have the dict which is transported to every function + # this should be used only in histogram function but not in weight_correlation_matrix() + # only important kwargs should be explicitly reported raster_kw = dict( amp_scale_fn=amp_scale_fn, post_transform=post_transform, @@ -277,7 +282,7 @@ def dredge_ap( # TODO @charlie check that we are doing the same thing - # windows, window_centers = get_windows( + # windows, window_centers = get_spatial_windows( # np.c_[np.zeros_like(spatial_bin_edges_um), spatial_bin_edges_um], # win_step_um, # win_scale_um, @@ -292,16 +297,18 @@ def dredge_ap( contact_depth = recording.get_channel_locations()[:, dim] spatial_bin_centers = 0.5 * (spatial_bin_edges_um[1:] + spatial_bin_edges_um[:-1]) - windows, window_centers = get_windows( - rigid, + windows, window_centers = get_spatial_windows( contact_depth, spatial_bin_centers, - win_margin_um, - win_step_um, - win_scale_um, - win_shape, - zero_threshold=1e-5, - ) + rigid=rigid, + win_shape=win_shape, + win_step_um=win_step_um, + win_scale_um=win_scale_um, + win_margin_um=win_margin_um, + zero_threshold=1e-5 + ) + + # TODO charlie : put back the count # if extra_outputs and count_masked_correlation: @@ -553,10 +560,10 @@ def dredge_online_lfp( # Important detail : in LFP bin center are contact position in the direction spatial_bin_centers = contact_depth - windows, window_centers = get_windows( - rigid=rigid, + windows, window_centers = get_spatial_windows( contact_depth=contact_depth, spatial_bin_centers=spatial_bin_centers, + rigid=rigid, win_margin_um=win_margin_um, win_step_um=win_step_um, win_scale_um=win_scale_um, diff --git a/src/spikeinterface/sortingcomponents/motion/iterative_template.py b/src/spikeinterface/sortingcomponents/motion/iterative_template.py index ab6877adc3..e7c978e865 100644 --- a/src/spikeinterface/sortingcomponents/motion/iterative_template.py +++ b/src/spikeinterface/sortingcomponents/motion/iterative_template.py @@ -1,6 +1,6 @@ import numpy as np -from .motion_utils import Motion, get_windows, get_spatial_bin_edges, make_3d_motion_histograms +from .motion_utils import Motion, get_spatial_windows, get_spatial_bin_edges, make_3d_motion_histograms class IterativeTemplateRegistration: @@ -78,18 +78,26 @@ def run( ): dim = ["x", "y", "z"].index(direction) - contact_pos = recording.get_channel_locations()[:, dim] + contact_depth = recording.get_channel_locations()[:, dim] # spatial histogram bins spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um) spatial_bin_centers = 0.5 * (spatial_bin_edges[1:] + spatial_bin_edges[:-1]) # get spatial windows - non_rigid_windows, non_rigid_window_centers = get_windows( - rigid, contact_pos, spatial_bin_centers, win_margin_um, win_step_um, win_scale_um, win_shape + non_rigid_windows, non_rigid_window_centers = get_spatial_windows( + contact_depth=contact_depth, + spatial_bin_centers=spatial_bin_centers, + rigid=rigid, + win_margin_um=win_margin_um, + win_step_um=win_step_um, + win_scale_um=win_scale_um, + win_shape=win_shape, + zero_threshold=None, ) + # make a 3D histogram motion_histograms, temporal_hist_bin_edges, spatial_hist_bin_edges = make_3d_motion_histograms( recording, diff --git a/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py b/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py index 401210e079..de2c7df4cc 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_cleaner.py @@ -1,6 +1,6 @@ import numpy as np - +# TODO this need a full rewrite with motion object def clean_motion_vector(motion, temporal_bins, bin_duration_s, speed_threshold=30, sigma_smooth_s=None): """ diff --git a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py index b6fa344def..aab3a1b491 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py @@ -8,7 +8,7 @@ from spikeinterface.sortingcomponents.tools import make_multi_method_doc -from .motion_utils import Motion, get_windows, get_spatial_bin_edges +from .motion_utils import Motion, get_spatial_windows, get_spatial_bin_edges from .decentralized import DecentralizedRegistration from .iterative_template import IterativeTemplateRegistration from .dredge import DredgeLfpRegistration, DredgeApRegistration @@ -19,13 +19,11 @@ def estimate_motion( peaks=None, peak_locations=None, direction="y", - # bin_um=10.0, - # hist_margin_um=0.0, rigid=False, win_shape="gaussian", win_step_um=50.0, win_scale_um=150.0, - win_margin_um=0., + win_margin_um=None, method="decentralized", extra_outputs=False, progress_bar=False, @@ -36,7 +34,8 @@ def estimate_motion( """ Estimate motion for given peaks and after their localization. - Note that the way you detect peak locations (center of mass/monopolar triangulation) have an impact on the result. + Note that the way you detect peak locations (center of mass/monopolar triangulation) + have an impact on the result. Parameters ---------- @@ -70,14 +69,14 @@ def estimate_motion( The depth domain will be broken up into windows with shape controlled by win_shape, spaced by win_step_um at a margin of win_margin_um from the boundary, and with width controlled by win_scale_um. + When win_margin_um is None the margin is automatically set to -win_scale_um/2. + See get_spatial_windows. win_step_um : float, default: 50 See win_shape win_scale_um : float, default: 150 See win_shape - win_margin_um : float, default: 0. - See win_shape - - + win_margin_um : None | float, default: None + See win_shape extra_outputs: bool, default: False If True then return an extra dict that contains variables to check intermediate steps (motion_histogram, non_rigid_windows, pairwise_displacement) @@ -114,25 +113,7 @@ def estimate_motion( else: extra = None - # contact positions - # probe = recording.get_probe() - # dim = ["x", "y", "z"].index(direction) - # contact_pos = probe.contact_positions[:, dim] - - # # spatial histogram bins - # spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um) - # spatial_bin_centers = 0.5 * (spatial_bin_edges[1:] + spatial_bin_edges[:-1]) - - # # get spatial windows - # non_rigid_windows, non_rigid_window_centers = get_windows( - # rigid, contact_pos, spatial_bin_centers, win_margin_um, win_step_um, win_scale_um, win_shape - # ) - - # if extra_outputs: - # extra["non_rigid_windows"] = non_rigid_windows - # run method - motion = method_class.run( recording, peaks, diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index 3fb0f8505a..f0442996a6 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -232,29 +232,45 @@ def copy(self): -def get_windows(rigid, contact_depth, spatial_bin_centers, win_margin_um, win_step_um, win_scale_um, win_shape, - zero_threshold=None): +def get_spatial_windows( + contact_depth, + spatial_bin_centers, + rigid=False, + win_shape="gaussian", + win_step_um=50.0, + win_scale_um=150.0, + win_margin_um=None, + zero_threshold=None + ): """ Generate spatial windows (taper) for non-rigid motion. For rigid motion, this is equivalent to have one unique rectangular window that covers the entire probe. The windowing can be gaussian or rectangular. + Windows are centered between the min/max of contact_depth. + We can ensure window to not be to close from border with win_margin_um. + Parameters ---------- - rigid : bool - If True, returns a single rectangular window contact_depth : np.ndarray Position of electrodes of the corection direction shape=(num_channels, ) spatial_bin_centers : np.array The pre-computed spatial bin centers - win_margin_um : float - The margin to extend (if positive) or shrink (if negative) the probe dimension to compute windows.= + rigid : bool, default False + If True, returns a single rectangular window + win_shape : str, default "gaussian" + Shape of the window + "gaussian" | "rect" | "triangle" win_step_um : float The steps at which windows are defined - win_scale_um : float - Sigma of gaussian window (if win_shape is gaussian) - win_shape : float - "gaussian" | "rect" + win_scale_um : float, default 150. + Sigma of gaussian window if win_shape is gaussian + Width of the rectangle if win_shape is rect + win_margin_um : None | float, default None + The margin to extend (if positive) or shrink (if negative) the probe dimension to compute windows. + When None, then the margin is set to -win_scale_um./2 + zero_threshold: None | float + Lower value for thresholding to set zeros. Returns ------- @@ -280,9 +296,14 @@ def get_windows(rigid, contact_depth, spatial_bin_centers, win_margin_um, win_st else: if win_scale_um <= win_step_um/5.: warnings.warn( - f"get_windows(): spatial windows are probably not overlaping because {win_scale_um=} and {win_step_um=}" + f"get_spatial_windows(): spatial windows are probably not overlaping because {win_scale_um=} and {win_step_um=}" ) + # @charlie: I am pretty sure this is the best option + if win_margin_um is None: + # this ensure that first/last windows do not overflow outside the probe + win_margin_um = -win_scale_um / 2. + min_ = np.min(contact_depth) - win_margin_um max_ = np.max(contact_depth) + win_margin_um num_windows = int((max_ - min_) // win_step_um)