From 532ea488d7bd8bd22058d64e74a8c6c0027fca79 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 26 Jun 2024 22:55:48 +0200 Subject: [PATCH] start porting dredge_ap --- .../sortingcomponents/motion/decentralized.py | 18 +- .../sortingcomponents/motion/dredge.py | 451 +++++++++++++++++- .../motion/motion_estimation.py | 4 +- .../sortingcomponents/motion/motion_utils.py | 16 + 4 files changed, 466 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion/decentralized.py b/src/spikeinterface/sortingcomponents/motion/decentralized.py index 3b8f19cc3e..9a76b94d20 100644 --- a/src/spikeinterface/sortingcomponents/motion/decentralized.py +++ b/src/spikeinterface/sortingcomponents/motion/decentralized.py @@ -168,23 +168,9 @@ def run( bin_duration_s=bin_duration_s, spatial_bin_edges=spatial_bin_edges, weight_with_amplitude=weight_with_amplitude, + depth_smooth_um=histogram_depth_smooth_um, + time_smooth_s=histogram_time_smooth_s, ) - import scipy.signal - - if histogram_depth_smooth_um is not None: - bins = np.arange(motion_histogram.shape[1]) * bin_um - bins = bins - np.mean(bins) - smooth_kernel = np.exp(-(bins**2) / (2 * histogram_depth_smooth_um**2)) - smooth_kernel /= np.sum(smooth_kernel) - - motion_histogram = scipy.signal.fftconvolve(motion_histogram, smooth_kernel[None, :], mode="same", axes=1) - - if histogram_time_smooth_s is not None: - bins = np.arange(motion_histogram.shape[0]) * bin_duration_s - bins = bins - np.mean(bins) - smooth_kernel = np.exp(-(bins**2) / (2 * histogram_time_smooth_s**2)) - smooth_kernel /= np.sum(smooth_kernel) - motion_histogram = scipy.signal.fftconvolve(motion_histogram, smooth_kernel[:, None], mode="same", axes=0) if extra is not None: extra["motion_histogram"] = motion_histogram diff --git a/src/spikeinterface/sortingcomponents/motion/dredge.py b/src/spikeinterface/sortingcomponents/motion/dredge.py index 70dfc86107..60d93f355c 100644 --- a/src/spikeinterface/sortingcomponents/motion/dredge.py +++ b/src/spikeinterface/sortingcomponents/motion/dredge.py @@ -14,17 +14,318 @@ but here the original functions from Charlie, Julien and Erdem have been ported for an easier maintenance instead of making DREDge a dependency of spikeinterface. + +Some renaming has been done. Small details has been added. +But this code is very similar to the original code. +2 classes has been added : DredgeApRegistration and DredgeLfpRegistration +but the original function dredge_ap() and dredge_online_lfp() can be used directly. + """ import warnings from tqdm.auto import trange import numpy as np +import gc from .motion_utils import Motion, get_windows, get_window_domains, scipy_conv1d -# TODO add direction +# to discuss +# margin +# xcorr new function +# which dataset band usefull for ? +# dredge_ap +# use patient 2 + +# todo use gaussian_filter1d in historgam 2d +# put smotthing inside the histogram function +# put the log for weight inhitstogram + + +# simple class wrapper to be compliant with estimate_motion +class DredgeApRegistration: + """ + + """ + name = "dredge_ap" + need_peak_location = True + params_doc = """ + + """ + @classmethod + def run( + cls, + recording, + peaks, + peak_locations, + direction, + rigid, + win_shape, + win_step_um, + win_scale_um, + win_margin_um, + verbose, + progress_bar, + extra, + **method_kwargs, + ): + dim = ["x", "y", "z"].index(direction) + peak_amplitudes = peaks["amplitude"] + peak_depths = peak_locations[direction] + peak_times = recording.sample_index_to_time(peaks["sample_index"]) + + + outs = dredge_ap( + peak_amplitudes, + peak_depths, + peak_times, + direction=direction, + rigid=rigid, + win_shape=win_shape, + win_step_um=win_step_um, + win_scale_um=win_scale_um, + win_margin_um=win_margin_um, + extra_outputs=(extra is not None), + progress_bar=progress_bar, + **method_kwargs, + ) + + if extra is not None: + motion, extra_ = outs + extra.update(extra_) + else: + motion = outs + return motion + +# @TODO : Charlie I started very small refactoring, I let you continue +def dredge_ap( + amps, + depths_um, + times_s, + direction="y", + rigid=False, + # nonrigid window construction arguments + win_shape="gaussian", + win_step_um=400, + win_scale_um=450, + win_margin_um=None, + bin_um=1.0, + bin_s=1.0, + max_disp_um=None, + time_horizon_s=1000, + mincorr=0.1, + # weights arguments + do_window_weights=True, + weights_threshold_low=0.2, + weights_threshold_high=0.2, + mincorr_percentile=None, + mincorr_percentile_nneighbs=None, + # raster arguments + amp_scale_fn=None, + post_transform=np.log1p, + gaussian_smoothing_sigma_um=1, + gaussian_smoothing_sigma_s=1, + avg_in_bin=False, + count_masked_correlation=False, + count_bins=401, + count_bin_min=2, + # low-level keyword args + thomas_kw=None, + xcorr_kw=None, + # misc + device=None, + progress_bar=True, + extra_outputs=False, + precomputed_D_C_maxdisp=None, +): + """Estimate motion from spikes + + Spikes located at depths specified in `depths` along the probe, occurring at times in + seconds specified in `times` with amplitudes `amps` are used to create a 2d image of + the spiking activity. This image is cross-correlated with itself to produce a displacement + matrix (or several, one for each nonrigid window). This matrix is used to solve for a + motion estimate. + + Arguments + --------- + amps : np.array of shape (n_spikes,) + depths: np.array of shape (n_spikes,) + times : np.array of shape (n_spikes,) + The amplitudes, depths (microns) and times (seconds) of input + spike events. + direction : "x" | "y", default "y" + Dimension on which the motion is estimated. "y" is depth along the probe. + rigid : bool, default=False + If True, ignore the nonrigid window args (win_shape, win_step_um, win_scale_um, + win_margin_um) and do rigid registration (equivalent to one flat window, which + is how it's implemented). + win_shape : str, default="gaussian" + Nonrigid window shape + win_step_um : float + Spacing between nonrigid window centers in microns + win_scale_um : float + Controls the width of nonrigid windows centers + win_margin_um : float + Distance of nonrigid windows centers from the probe boundary (-1000 means there will + be no window center within 1000um of the edge of the probe) + bin_um: float + bin_s : float + The size of the bins along depth in microns and along time in seconds. + The returned object's .displacement array will respect these bins. + Increasing these can lead to more stable estimates and faster runtimes + at the cost of spatial and/or temporal resolution. + max_disp_um : float + Maximum possible displacement in microns. If you can guess a number which is larger + than the largest displacement possible in your recording across a span of `time_horizon_s` + seconds, setting this value to that number can stabilize the result and speed up + the algorithm (since it can do less cross-correlating). + By default, this is set to win-scale_um / 4, or 112.5 microns. Which can be a bit + large! + time_horizon_s : float + "Time horizon" parameter, in seconds. Time bins separated by more seconds than this + will not be cross-correlated. So, if your data has nonstationarities or changes which + could lead to bad cross-correlations at some timescale, it can help to input that + value here. If this is too small, it can make the motion estimation unstable. + mincorr : float, between 0 and 1 + Correlation threshold. Pairs of frames whose maximal cross correlation value is smaller + than this threshold will be ignored when solving for the global displacement estimate. + thomas_kw, xcorr_kw, raster_kw, weights_kw + These dictionaries allow setting parameters for fine control over the registration + device : str or torch.device + What torch device to run on? E.g., "cpu" or "cuda" or "cuda:1". + + Returns + ------- + motion_est : a motion_util.MotionEstimate object + This has a .displacement attribute which is the displacement estimate in a + (num_nonrigid_blocks, num_time_bins) array. It also has properties describing + the time and spatial bins, and methods for getting the displacement at a particular + time and depth. See the documentation of these classes in motion_util.py. + extra : dict + This has extra info about what happened during registration, including the nonrigid + windows if one wants to visualize them. Set `extra_outputs` to also save displacement + and correlation matrices. + """ + thomas_kw = thomas_kw if thomas_kw is not None else {} + xcorr_kw = xcorr_kw if xcorr_kw is not None else {} + if time_horizon_s: + xcorr_kw["max_dt_bins"] = np.ceil(time_horizon_s / bin_s) + raster_kw = dict( + amp_scale_fn=amp_scale_fn, + post_transform=post_transform, + gaussian_smoothing_sigma_um=gaussian_smoothing_sigma_um, + gaussian_smoothing_sigma_s=gaussian_smoothing_sigma_s, + bin_s=bin_s, + bin_um=bin_um, + avg_in_bin=avg_in_bin, + return_counts=count_masked_correlation, + count_bins=count_bins, + count_bin_min=count_bin_min, + ) + weights_kw = dict( + mincorr=mincorr, + time_horizon_s=time_horizon_s, + do_window_weights=do_window_weights, + weights_threshold_low=weights_threshold_low, + weights_threshold_high=weights_threshold_high, + ) + + # this will store return values other than the MotionEstimate + extra = {} + + # TODO charlie switch this to make_2d_motion_histogram after having putting more option + raster_res = spike_raster( + amps, + depths_um, + times_s, + **raster_kw, + ) + if count_masked_correlation: + raster, spatial_bin_edges_um, time_bin_edges_s, counts = raster_res + else: + raster, spatial_bin_edges_um, time_bin_edges_s = raster_res + windows, window_centers = get_windows( + # pseudo geom to fool spikeinterface + np.c_[np.zeros_like(spatial_bin_edges_um), spatial_bin_edges_um], + win_step_um, + win_scale_um, + spatial_bin_edges=spatial_bin_edges_um, + margin_um=-win_scale_um / 2 if win_margin_um is None else win_margin_um, + win_shape=win_shape, + zero_threshold=1e-5, + rigid=rigid, + ) + if extra_outputs and count_masked_correlation: + extra["counts"] = counts + + # cross-correlate to get D and C + if precomputed_D_C_maxdisp is None: + Ds, Cs, max_disp_um = xcorr_windows( + raster, + windows, + spatial_bin_edges_um, + win_scale_um, + rigid=rigid, + bin_um=bin_um, + max_disp_um=max_disp_um, + progress_bar=progress_bar, + device=device, + masks=(counts > 0) if count_masked_correlation else None, + **xcorr_kw, + ) + else: + Ds, Cs, max_disp_um = precomputed_D_C_maxdisp + + # turn Cs into weights + Us, wextra = weight_correlation_matrix( + Ds, + Cs, + windows, + raster, + spatial_bin_edges_um, + time_bin_edges_s, + raster_kw, + lambda_t=thomas_kw.get("lambda_t", DEFAULT_LAMBDA_T), + eps=thomas_kw.get("eps", DEFAULT_EPS), + progress_bar=progress_bar, + in_place=not extra_outputs, + **weights_kw, + ) + extra.update({k: wextra[k] for k in wextra if k not in ("S", "U")}) + if extra_outputs: + extra.update({k: wextra[k] for k in wextra if k in ("S", "U")}) + del wextra + if extra_outputs: + extra["D"] = Ds + extra["C"] = Cs + del Cs + + # @charlie : is this needed ? + gc.collect() + + # solve for P + # now we can do our tridiag solve + displacement, textra = thomas_solve(Ds, Us, progress_bar=progress_bar, **thomas_kw) + if extra_outputs: + extra.update(textra) + del textra + + if extra_outputs: + extra["windows"] = windows + extra["window_centers"] = window_centers + extra["max_disp_um"] = max_disp_um + + time_bin_centers = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1]) + spatial_bin_centers_um = 0.5 * (spatial_bin_edges_um[1:] + spatial_bin_edges_um[:-1]) + + motion = Motion([displacement], [time_bin_centers], spatial_bin_centers_um, direction=direction) + + if extra_outputs: + return motion, extra + else: + return motion + # simple class wrapper to be compliant with estimate_motion class DredgeLfpRegistration: @@ -72,10 +373,9 @@ def run( if extra is not None: motion, extra_ = outs extra.update(extra_) - else: motion = outs - + return motion @@ -117,6 +417,8 @@ def dredge_online_lfp( be the target resolution of the registration, so definitely use SpikeInterface to resample your recording to, say, 250Hz (or a value you like) rather than estimating motion at the original frequency (which may be high). + direction : "x" | "y", default "y" + Dimension on which the motion is estimated. "y" is depth along the probe. rigid : boolean, optional If True, window-related arguments are ignored and we do rigid registration win_shape, win_step_um, win_scale_um, win_margin_um : float @@ -310,7 +612,7 @@ def dredge_online_lfp( t0, t1 = t1, t2 traces0 = traces1 - motion = Motion([P_online.T], [lfp_recording.get_times(0)], window_centers, direction="y") + motion = Motion([P_online.T], [lfp_recording.get_times(0)], window_centers, direction=direction) if extra_outputs: return motion, extra @@ -941,4 +1243,143 @@ def normxcorr1d( corr /= npx.sqrt(var_x) corr /= npx.sqrt(var_template) - return corr \ No newline at end of file + return corr + + +def get_weights( + Ds, + Ss, + Sigma0inv_t, + windows, + raster, + dbe, + tbe, + raster_kw, + weights_threshold_low=0.0, + weights_threshold_high=np.inf, + progress_bar=False, +): + """Compute per-time-bin weighting for each nonrigid window""" + # determine window-weighted raster "heat" in each nonrigid window + # as a function of time + assert windows.shape[1] == dbe.size - 1 + weights = [] + p_inds = [] + for b in range((len(Ds))): + ilow, ihigh = np.flatnonzero(windows[b])[[0, -1]] + ihigh += 1 + window_sliced = windows[b, ilow:ihigh] + weights.append(window_sliced @ raster[ilow:ihigh]) + weights_orig = np.array(weights) + + scale_fn = raster_kw["post_transform"] or raster_kw["amp_scale_fn"] + if isinstance(weights_threshold_low, tuple): + nspikes_threshold_low, amp_threshold_low = weights_threshold_low + unif = np.full_like(windows[0], 1 / len(windows[0])) + weights_threshold_low = ( + scale_fn(amp_threshold_low) + * windows + @ (nspikes_threshold_low * unif) + ) + weights_threshold_low = weights_threshold_low[:, None] + if isinstance(weights_threshold_high, tuple): + nspikes_threshold_high, amp_threshold_high = weights_threshold_high + unif = np.full_like(windows[0], 1 / len(windows[0])) + weights_threshold_high = ( + scale_fn(amp_threshold_high) + * windows + @ (nspikes_threshold_high * unif) + ) + weights_threshold_high = weights_threshold_high[:, None] + weights_thresh = weights_orig.copy() + weights_thresh[weights_orig < weights_threshold_low] = 0 + weights_thresh[weights_orig > weights_threshold_high] = np.inf + + return weights, weights_thresh, p_inds + +def weight_correlation_matrix( + Ds, + Cs, + windows, + raster, + depth_bin_edges, + time_bin_edges, + raster_kw, + mincorr=0.0, + mincorr_percentile=None, + mincorr_percentile_nneighbs=20, + max_dt_s=None, + lambda_t=DEFAULT_LAMBDA_T, + eps=DEFAULT_EPS, + do_window_weights=True, + weights_threshold_low=0.0, + weights_threshold_high=np.inf, + progress_bar=True, + in_place=False, +): + """Transform the correlation matrix into the weights used in optimization.""" + extra = {} + + Ds = np.asarray(Ds) + Cs = np.asarray(Cs) + if Ds.ndim == 2: + Ds = Ds[None] + Cs = Cs[None] + B, T, T_ = Ds.shape + assert T == T_ + assert Ds.shape == Cs.shape + extra = {} + + Ss, mincorr = threshold_correlation_matrix( + Cs, + mincorr=mincorr, + mincorr_percentile=mincorr_percentile, + mincorr_percentile_nneighbs=mincorr_percentile_nneighbs, + max_dt_s=max_dt_s, + bin_s=time_bin_edges[1] - time_bin_edges[0], + T=T, + in_place=in_place, + ) + extra["S"] = Ss + extra["mincorr"] = mincorr + + if not do_window_weights: + return Ss, extra + + # get weights + L_t = lambda_t * laplacian(T, eps=max(1e-5, eps)) + weights_orig, weights_thresh, Pind = get_weights( + Ds, + Ss, + L_t, + windows, + raster, + depth_bin_edges, + time_bin_edges, + raster_kw, + weights_threshold_low=weights_threshold_low, + weights_threshold_high=weights_threshold_high, + progress_bar=progress_bar, + ) + extra["weights_orig"] = weights_orig + extra["weights_thresh"] = weights_thresh + extra["Pind"] = Pind + + # update noise model. we deliberately divide by zero and inf here. + Us = Ss if in_place else np.zeros_like(Ss) + with np.errstate(divide="ignore"): + # low mem impl of U = abs(1/(1/weights_thresh+1/weights_thresh'+1/S)) + np.reciprocal(Ss, out=Us) + invW = 1.0 / weights_thresh + Us += invW[:, :, None] + Us += invW[:, None, :] + np.reciprocal(Us, out=Us) + # handles possible -0s that cause issues elsewhere + np.abs(Us, out=Us) + # more readable equivalent: + # for b in range(B): + # invWbtt = invW[b, :, None] + invW[b, None, :] + # Us[b] = np.abs(1.0 / (invWbtt + 1.0 / Ss[b])) + extra["U"] = Us + + return Us, extra diff --git a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py index 7804238024..b6fa344def 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_estimation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_estimation.py @@ -11,7 +11,7 @@ from .motion_utils import Motion, get_windows, get_spatial_bin_edges from .decentralized import DecentralizedRegistration from .iterative_template import IterativeTemplateRegistration -from .dredge import DredgeLfpRegistration +from .dredge import DredgeLfpRegistration, DredgeApRegistration def estimate_motion( @@ -162,7 +162,7 @@ def estimate_motion( return motion -_methods_list = [DecentralizedRegistration, IterativeTemplateRegistration, DredgeLfpRegistration] +_methods_list = [DecentralizedRegistration, IterativeTemplateRegistration, DredgeLfpRegistration, DredgeApRegistration] estimate_motion_methods = {m.name: m for m in _methods_list} method_doc = make_multi_method_doc(_methods_list) estimate_motion.__doc__ = estimate_motion.__doc__.format(method_doc=method_doc) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index d9c0bedb24..44bf3eb3ad 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -376,6 +376,8 @@ def make_2d_motion_histogram( bin_um=2.0, hist_margin_um=50, spatial_bin_edges=None, + depth_smooth_um=None, + time_smooth_s=None, ): """ Generate 2d motion histogram in depth and time. @@ -401,6 +403,12 @@ def make_2d_motion_histogram( Ignored if spatial_bin_edges is given. spatial_bin_edges : np.array, default: None The pre-computed spatial bin edges + depth_smooth_um: None or float + Optional gaussian smoother on histogram on depth axis. + This is given as the sigma of the gaussian in micrometers. + time_smooth_s: None or float + Optional gaussian smoother on histogram on time axis. + This is given as the sigma of the gaussian in seconds. Returns ------- @@ -435,6 +443,14 @@ def make_2d_motion_histogram( bin_counts[bin_counts == 0] = 1 motion_histogram = motion_histogram / bin_counts + from scipy.ndimage import gaussian_filter1d + + if depth_smooth_um is not None: + motion_histogram = gaussian_filter1d(motion_histogram, depth_smooth_um / bin_um, axis=1, mode="constant") + + if time_smooth_s is not None: + motion_histogram = gaussian_filter1d(motion_histogram, time_smooth_s / bin_duration_s, axis=0, mode="constant") + return motion_histogram, temporal_bin_edges, spatial_bin_edges