From cab6646e7edacf091ec1396744a1e122ae8499cf Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Thu, 27 Jun 2024 12:18:31 +0200 Subject: [PATCH] wip dredge_ap --- .../sortingcomponents/motion/decentralized.py | 4 +- .../sortingcomponents/motion/dredge.py | 139 ++++++++++++------ .../sortingcomponents/motion/motion_utils.py | 14 +- 3 files changed, 102 insertions(+), 55 deletions(-) diff --git a/src/spikeinterface/sortingcomponents/motion/decentralized.py b/src/spikeinterface/sortingcomponents/motion/decentralized.py index 9a76b94d20..baf0fae1f2 100644 --- a/src/spikeinterface/sortingcomponents/motion/decentralized.py +++ b/src/spikeinterface/sortingcomponents/motion/decentralized.py @@ -145,7 +145,7 @@ def run( conv_engine = "torch" if HAVE_TORCH else "numpy" dim = ["x", "y", "z"].index(direction) - contact_pos = recording.get_channel_locations()[:, dim] + contact_depth = recording.get_channel_locations()[:, dim] # spatial histogram bins spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um) @@ -153,7 +153,7 @@ def run( # get spatial windows non_rigid_windows, non_rigid_window_centers = get_windows( - rigid, contact_pos, spatial_bin_centers, win_margin_um, win_step_um, win_scale_um, win_shape + rigid, contact_depth, spatial_bin_centers, win_margin_um, win_step_um, win_scale_um, win_shape ) # make 2D histogram raster diff --git a/src/spikeinterface/sortingcomponents/motion/dredge.py b/src/spikeinterface/sortingcomponents/motion/dredge.py index 817f34f88f..779b236d77 100644 --- a/src/spikeinterface/sortingcomponents/motion/dredge.py +++ b/src/spikeinterface/sortingcomponents/motion/dredge.py @@ -28,7 +28,7 @@ import gc -from .motion_utils import Motion, get_windows, get_window_domains, scipy_conv1d +from .motion_utils import Motion, get_windows, get_window_domains, scipy_conv1d, make_2d_motion_histogram # to discuss @@ -70,16 +70,12 @@ def run( extra, **method_kwargs, ): - dim = ["x", "y", "z"].index(direction) - peak_amplitudes = peaks["amplitude"] - peak_depths = peak_locations[direction] - peak_times = recording.sample_index_to_time(peaks["sample_index"]) outs = dredge_ap( - peak_amplitudes, - peak_depths, - peak_times, + recording, + peaks, + peak_locations, direction=direction, rigid=rigid, win_shape=win_shape, @@ -100,9 +96,9 @@ def run( # @TODO : Charlie I started very small refactoring, I let you continue def dredge_ap( - amps, - depths_um, - times_s, + recording, + peaks, + peak_locations, direction="y", rigid=False, # nonrigid window construction arguments @@ -113,7 +109,7 @@ def dredge_ap( bin_um=1.0, bin_s=1.0, max_disp_um=None, - time_horizon_s=1000, + time_horizon_s=1000., mincorr=0.1, # weights arguments do_window_weights=True, @@ -149,6 +145,8 @@ def dredge_ap( Arguments --------- + recording : Recording + The recording amps : np.array of shape (n_spikes,) depths: np.array of shape (n_spikes,) times : np.array of shape (n_spikes,) @@ -207,10 +205,22 @@ def dredge_ap( windows if one wants to visualize them. Set `extra_outputs` to also save displacement and correlation matrices. """ + + dim = ["x", "y", "z"].index(direction) + # @charlie: I removed amps/depths_um/times_s from the signature + # preaks and peak_locations are more SI compatible + # the way to get then + amps = peak_amplitudes = peaks["amplitude"] + depths_um = peak_depths = peak_locations[direction] + times_s = peak_times = recording.sample_index_to_time(peaks["sample_index"]) + + + thomas_kw = thomas_kw if thomas_kw is not None else {} xcorr_kw = xcorr_kw if xcorr_kw is not None else {} if time_horizon_s: xcorr_kw["max_dt_bins"] = np.ceil(time_horizon_s / bin_s) + raster_kw = dict( amp_scale_fn=amp_scale_fn, post_transform=post_transform, @@ -234,33 +244,70 @@ def dredge_ap( # this will store return values other than the MotionEstimate extra = {} - # TODO charlie switch this to make_2d_motion_histogram after having putting more option - raster_res = spike_raster( - amps, - depths_um, - times_s, - **raster_kw, + + + # TODO charlie I switch this to make_2d_motion_histogram + # but we need to add all options from the original spike_raster() + + # raster_res = spike_raster( + # amps, + # depths_um, + # times_s, + # **raster_kw, + # ) + # if count_masked_correlation: + # raster, spatial_bin_edges_um, time_bin_edges_s, counts = raster_res + # else: + # raster, spatial_bin_edges_um, time_bin_edges_s = raster_res + + motion_histogram, time_bin_edges_s, spatial_bin_edges_um = make_2d_motion_histogram( + recording, + peaks, + peak_locations, + weight_with_amplitude=False, + direction="y", + bin_duration_s=1.0, + bin_um=2.0, + hist_margin_um=50, + spatial_bin_edges=None, + depth_smooth_um=None, + time_smooth_s=None, ) - if count_masked_correlation: - raster, spatial_bin_edges_um, time_bin_edges_s, counts = raster_res - else: - raster, spatial_bin_edges_um, time_bin_edges_s = raster_res + raster = motion_histogram.T + + + # TODO @charlie check that we are doing the same thing + # windows, window_centers = get_windows( + # np.c_[np.zeros_like(spatial_bin_edges_um), spatial_bin_edges_um], + # win_step_um, + # win_scale_um, + # spatial_bin_edges=spatial_bin_edges_um, + # margin_um=-win_scale_um / 2 if win_margin_um is None else win_margin_um, + # win_shape=win_shape, + # zero_threshold=1e-5, + # rigid=rigid, + # ) + + dim = ["x", "y", "z"].index(direction) + contact_depth = recording.get_channel_locations()[:, dim] + spatial_bin_centers = 0.5 * (spatial_bin_edges_um[1:] + spatial_bin_edges_um[:-1]) - # TODO Sam I did not yet change parameters here with the get_windows API windows, window_centers = get_windows( - # pseudo geom to fool spikeinterface - np.c_[np.zeros_like(spatial_bin_edges_um), spatial_bin_edges_um], + rigid, + contact_depth, + spatial_bin_centers, + win_margin_um, win_step_um, win_scale_um, - spatial_bin_edges=spatial_bin_edges_um, - margin_um=-win_scale_um / 2 if win_margin_um is None else win_margin_um, - win_shape=win_shape, + win_shape, zero_threshold=1e-5, - rigid=rigid, ) - if extra_outputs and count_masked_correlation: - extra["counts"] = counts + # TODO charlie : put back the count + # if extra_outputs and count_masked_correlation: + # extra["counts"] = counts + + # cross-correlate to get D and C if precomputed_D_C_maxdisp is None: Ds, Cs, max_disp_um = xcorr_windows( @@ -273,7 +320,8 @@ def dredge_ap( max_disp_um=max_disp_um, progress_bar=progress_bar, device=device, - masks=(counts > 0) if count_masked_correlation else None, + # TODO charlie : put back the count for the mask + # masks=(counts > 0) if count_masked_correlation else None, **xcorr_kw, ) else: @@ -319,9 +367,7 @@ def dredge_ap( extra["max_disp_um"] = max_disp_um time_bin_centers = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1]) - spatial_bin_centers_um = 0.5 * (spatial_bin_edges_um[1:] + spatial_bin_edges_um[:-1]) - - motion = Motion([displacement], [time_bin_centers], spatial_bin_centers_um, direction=direction) + motion = Motion([displacement.T], [time_bin_centers], window_centers, direction=direction) if extra_outputs: return motion, extra @@ -464,7 +510,7 @@ def dredge_online_lfp( """ dim = ["x", "y", "z"].index(direction) # contact pos is the only on the direction - contact_pos = lfp_recording.get_channel_locations()[:, dim] + contact_depth = lfp_recording.get_channel_locations()[:, dim] fs = lfp_recording.get_sampling_frequency() @@ -477,7 +523,7 @@ def dredge_online_lfp( thomas_kw = thomas_kw if thomas_kw is not None else {} full_xcorr_kw = dict( rigid=rigid, - bin_um=np.median(np.diff(contact_pos)), + bin_um=np.median(np.diff(contact_depth)), max_disp_um=max_disp_um, progress_bar=False, device=device, @@ -494,22 +540,22 @@ def dredge_online_lfp( # here we check that contact positons are unique on the direction - if contact_pos.size != np.unique(contact_pos).size: + if contact_depth.size != np.unique(contact_depth).size: raise ValueError( f"estimate motion with 'dredge_lfp' need channel_positions to be unique in the direction='{direction}'" ) - if np.any(np.diff(contact_pos) < 0): + if np.any(np.diff(contact_depth) < 0): raise ValueError( f"estimate motion with 'dredge_lfp' need channel_positions to be ordered direction='{direction}'" "please use spikeinterface.preprocessing.depth_order(recording)" ) # Important detail : in LFP bin center are contact position in the direction - spatial_bin_centers = contact_pos + spatial_bin_centers = contact_depth windows, window_centers = get_windows( rigid=rigid, - contact_pos=contact_pos, + contact_depth=contact_depth, spatial_bin_centers=spatial_bin_centers, win_margin_um=win_margin_um, win_step_um=win_step_um, @@ -529,7 +575,7 @@ def dredge_online_lfp( t0, t1 = 0, T_chunk traces0 = lfp_recording.get_traces(start_frame=t0, end_frame=t1) Ds0, Cs0, max_disp_um = xcorr_windows( - traces0.T, windows, contact_pos, win_scale_um, **full_xcorr_kw + traces0.T, windows, contact_depth, win_scale_um, **full_xcorr_kw ) full_xcorr_kw["max_disp_um"] = max_disp_um Ss0, mincorr0 = threshold_correlation_matrix( @@ -568,7 +614,7 @@ def dredge_online_lfp( Ds10, Cs10, _ = xcorr_windows( traces1.T, windows, - contact_pos, + contact_depth, win_scale_um, raster_b=traces0.T, **full_xcorr_kw, @@ -576,7 +622,7 @@ def dredge_online_lfp( # cross-correlation in current chunk Ds1, Cs1, _ = xcorr_windows( - traces1.T, windows, contact_pos, win_scale_um, **full_xcorr_kw + traces1.T, windows, contact_depth, win_scale_um, **full_xcorr_kw ) Ss1, mincorr1 = threshold_correlation_matrix( Cs1, @@ -954,6 +1000,7 @@ def xcorr_windows( slices = get_window_domains(windows) B, D = windows.shape D_, T0 = raster_a.shape + assert D == D_ # torch versions on device @@ -1310,7 +1357,7 @@ def weight_correlation_matrix( mincorr=0.0, mincorr_percentile=None, mincorr_percentile_nneighbs=20, - max_dt_s=None, + time_horizon_s=None, lambda_t=DEFAULT_LAMBDA_T, eps=DEFAULT_EPS, do_window_weights=True, @@ -1337,7 +1384,7 @@ def weight_correlation_matrix( mincorr=mincorr, mincorr_percentile=mincorr_percentile, mincorr_percentile_nneighbs=mincorr_percentile_nneighbs, - max_dt_s=max_dt_s, + time_horizon_s=time_horizon_s, bin_s=time_bin_edges[1] - time_bin_edges[0], T=T, in_place=in_place, diff --git a/src/spikeinterface/sortingcomponents/motion/motion_utils.py b/src/spikeinterface/sortingcomponents/motion/motion_utils.py index 44bf3eb3ad..3fb0f8505a 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_utils.py @@ -232,7 +232,7 @@ def copy(self): -def get_windows(rigid, contact_pos, spatial_bin_centers, win_margin_um, win_step_um, win_scale_um, win_shape, +def get_windows(rigid, contact_depth, spatial_bin_centers, win_margin_um, win_step_um, win_scale_um, win_shape, zero_threshold=None): """ Generate spatial windows (taper) for non-rigid motion. @@ -243,7 +243,7 @@ def get_windows(rigid, contact_pos, spatial_bin_centers, win_margin_um, win_step ---------- rigid : bool If True, returns a single rectangular window - contact_pos : np.ndarray + contact_depth : np.ndarray Position of electrodes of the corection direction shape=(num_channels, ) spatial_bin_centers : np.array The pre-computed spatial bin centers @@ -283,8 +283,8 @@ def get_windows(rigid, contact_pos, spatial_bin_centers, win_margin_um, win_step f"get_windows(): spatial windows are probably not overlaping because {win_scale_um=} and {win_step_um=}" ) - min_ = np.min(contact_pos) - win_margin_um - max_ = np.max(contact_pos) + win_margin_um + min_ = np.min(contact_depth) - win_margin_um + max_ = np.max(contact_depth) + win_margin_um num_windows = int((max_ - min_) // win_step_um) border = ((max_ - min_) % win_step_um) / 2 window_centers = np.arange(num_windows + 1) * win_step_um + min_ + border @@ -356,10 +356,10 @@ def get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um): # contact along one axis probe = recording.get_probe() dim = ["x", "y", "z"].index(direction) - contact_pos = probe.contact_positions[:, dim] + contact_depth = probe.contact_positions[:, dim] - min_ = np.min(contact_pos) - hist_margin_um - max_ = np.max(contact_pos) + hist_margin_um + min_ = np.min(contact_depth) - hist_margin_um + max_ = np.max(contact_depth) + hist_margin_um spatial_bins = np.arange(min_, max_ + bin_um, bin_um) return spatial_bins