diff --git a/debugging/_test_session_alignment.py b/debugging/_test_session_alignment.py index e46bd5bfd5..4de6360210 100644 --- a/debugging/_test_session_alignment.py +++ b/debugging/_test_session_alignment.py @@ -241,7 +241,7 @@ "method": "chunked_supremum", # TODO: double check scaling "chunked_bin_size_s": "estimate", "log_scale": True, - "smooth_um": 10, + "depth_smooth_um": 10, } compute_alignment_kwargs = { "num_shifts_block": None, # TODO: can be in um so comaprable with window kwargs. diff --git a/debugging/alignment_utils.py b/debugging/alignment_utils.py index 6ce787f68b..05b068994a 100644 --- a/debugging/alignment_utils.py +++ b/debugging/alignment_utils.py @@ -20,7 +20,7 @@ def get_activity_histogram( - recording, peaks, peak_locations, spatial_bin_edges, log_scale, bin_s, smooth_um + recording, peaks, peak_locations, spatial_bin_edges, log_scale, bin_s, depth_smooth_um ): """ TODO: assumes 1-segment recording @@ -35,7 +35,7 @@ def get_activity_histogram( bin_um=None, hist_margin_um=None, spatial_bin_edges=spatial_bin_edges, - depth_smooth_um=smooth_um, + depth_depth_smooth_um=depth_smooth_um, ) assert np.array_equal(generated_spatial_bin_edges, spatial_bin_edges), "TODO: remove soon after testing" diff --git a/debugging/playing_inter-session-alignment.py b/debugging/playing_inter-session-alignment.py index 236d05e3dd..d6ad5cb14e 100644 --- a/debugging/playing_inter-session-alignment.py +++ b/debugging/playing_inter-session-alignment.py @@ -31,7 +31,7 @@ "method": "chunked_median", # TODO: double check scaling "chunked_bin_size_s": "estimate", "log_scale": False, - "smooth_um": 5, + "depth_smooth_um": 5, } compute_alignment_kwargs = { "num_shifts_block": None, # TODO: can be in um so comaprable with window kwargs. diff --git a/debugging/session_alignment.py b/debugging/session_alignment.py index 9490622d6f..795df8fa43 100644 --- a/debugging/session_alignment.py +++ b/debugging/session_alignment.py @@ -1,3 +1,8 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from spikeinterface.core.baserecording import BaseRecording + import numpy as np from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording from spikeinterface.sortingcomponents.motion.motion_utils import get_spatial_windows, Motion, get_spatial_bins @@ -9,21 +14,64 @@ import copy -# TODO: need to plot out the entire call tree and -# make sure it is optimal as its quite complex but -# there is quite a lot of weird stuff here. +# 1) add docstrings and type hints +# 2) add print statements to the entry function +# 3) look into the DREDGE stuff +# TODO: 1) think of a method to choose some reasonable defaults for bin size, nonrigid smoothing. +# TODO: 2) different alignment procedure -_estimate_histogram_kwargs = { + +def get_estimate_histogram_kwargs() -> dict: + """ + A dictionary controlling how the histogram for each session is + computed. The session histograms are estimated by chunking + the recording into time segments and computing histograms + for each chunk, then performing some summary statistic over + the chunked histograms. + + Returns + ------- + A dictionary with entries: + + "bin_um" : number of spatial histogram bins. As the estimated peak + locations are continuous (i.e. real numbers) this is not constrained + by the number of channels. + "method" : may be "chunked_mean", "chunked_median", "chunked_supremum", + "chunked_poisson". Determines the summary statistic used over + the histograms computed across a session. See `alignment_utils.py + for details on each method. + "chunked_bin_size_s" : The length in seconds to chunk the recording + for estimating the chunked histograms. + "log_scale" : if `True`, histograms are log transformed. + "depth_smooth_um" : if `None`, no smoothing is applied. See + `make_2d_motion_histogram`. + """ + return { "bin_um": 2, "method": "chunked_mean", "chunked_bin_size_s": "estimate", "log_scale": False, - "smooth_um": None, + "depth_smooth_um": None, } -_compute_alignment_kwargs = { +def get_compute_alignment_kwargs() -> dict: + """ + A dictionary with settings controlling how inter-session + alignment is estimated and computed given a set of + session activity histograms. + + All keys except for "non_rigid_window_kwargs" determine + how alignment is estimated, based on the kilosort ("kilosort_like" + in spikeinterface) motion correction method. See + `iterative_template_registration` for details. + + "non_rigid_window_kwargs" : if nonrigid alignment + is performed, this determines the nature of the + windows along the probe depth. See `get_spatial_windows`. + """ + return { "num_shifts_block": 5, "interpolate": False, "interp_factor": 10, @@ -43,11 +91,15 @@ } -_interpolate_motion_kwargs = { - "border_mode": "remove_channels", - "spatial_interpolation_method": "kriging", - "sigma_um": 20.0, - "p": 2 +def get_interpolate_motion_kwargs(): + """ + Settings to pass to `InterpolateMotionRecording" + """ + return { + "border_mode": "remove_channels", + "spatial_interpolation_method": "kriging", + "sigma_um": 20.0, + "p": 2 } # ----------------------------------------------------------------------------- @@ -57,18 +109,66 @@ # TODO: add some print statements for progress def align_sessions( - recordings_list, - peaks_list, - peak_locations_list, - alignment_order="to_middle", - rigid=True, - estimate_histogram_kwargs=_estimate_histogram_kwargs, - compute_alignment_kwargs=_compute_alignment_kwargs, - interpolate_motion_kwargs=_interpolate_motion_kwargs, -): + recordings_list: list[BaseRecording], + peaks_list: list[np.ndarray], + peak_locations_list: list[np.ndarray], + alignment_order: str = "to_middle", + rigid: bool = True, + estimate_histogram_kwargs: dict = get_estimate_histogram_kwargs(), + compute_alignment_kwargs: dict = get_compute_alignment_kwargs(), + interpolate_motion_kwargs: dict = get_interpolate_motion_kwargs(), +) -> tuple[list[BaseRecording], ]: """ - print what happens with the update! it is automaticlally added - to existing motion if it exists. + Estimate probe displacement across recording sessions and + return interpolated, displacement-corrected recording. Displacement + is only estimated along the "y" dimension. + + This assumes peaks and peak locations have already been computed. + See `compute_peaks_locations_for_session_alignment` for generating + `peaks_list` and `peak_locations_list` from a `recordings_list`. + + If a recording in `recordings_list` is already an `InterpolateMotionRecording`, + the displacement will be added to the existing shifts to avoid duplicate + interpolations. Note the returned, corrected recording is a copy + (recordings in `recording_list` are not edited in-place). + + Parameters + ---------- + + recordings_list : list[BaseRecording] + A list of recordings to be aligned. + peaks_list : list[np.ndarray] + A list of peaks detected from the recordings in `recordings_list`, + as returned from the `detect_peaks` function. Each entry in + `peaks_list` should be from the corresponding entry in `recordings_list`. + peak_locations_list : list[np.ndarray] + A list of peak locations, as computed by `localize_peaks`. Each entry + in `peak_locations_list` should be matched to the corresponding entry + in `peaks_list` and `recordings_list`. + alignment_order : str + "to_middle" will align all sessions to the mean position. + Alternatively, "to_session_N" where "N" is a session number + will align to the Nth session. + rigid : bool + If `True`, estimated displacement is rigid. If `False`, nonrigid + estimation is performed by performing rigid alignment on overlapping + subsets of the probes "y" dimension. + estimate_histogram_kwargs : dict + see `get_estimate_histogram_kwargs()` + compute_alignment_kwargs : dict + see `get_compute_alignment_kwargs()` + interpolate_motion_kwargs : dict + see `get_interpolate_motion_kwargs()` + + Returns + ------- + `corrected_recordings_list : list[BaseRecording] + + `motion_objects_list : list[TODO] + + extra_outputs_dict : dict + + """ estimate_histogram_kwargs = copy.deepcopy(estimate_histogram_kwargs) compute_alignment_kwargs = copy.deepcopy(compute_alignment_kwargs) @@ -142,7 +242,7 @@ def align_sessions_after_motion_correction( ) -def compute_peaks_for_session_alignment( +def compute_peaks_locations_for_session_alignment( recording_list, gather_mode, detect_kwargs, localize_peaks_kwargs, job_kwargs ): """ @@ -172,7 +272,7 @@ def _compute_session_histograms( bin_um, method, chunked_bin_size_s, - smooth_um, + depth_smooth_um, log_scale, ): """ @@ -191,7 +291,7 @@ def _compute_session_histograms( session_hist, temporal_bin_centers, histogram_info = _get_single_session_activity_histogram( recording, peaks, peak_locations, spatial_bin_edges, - method, log_scale, chunked_bin_size_s, smooth_um, + method, log_scale, chunked_bin_size_s, depth_smooth_um, ) temporal_bin_centers_list.append(temporal_bin_centers) session_histogram_list.append(session_hist) @@ -201,7 +301,7 @@ def _compute_session_histograms( def _get_single_session_activity_histogram( - recording, peaks, peak_locations, spatial_bin_edges, method, log_scale, chunked_bin_size_s, smooth_um + recording, peaks, peak_locations, spatial_bin_edges, method, log_scale, chunked_bin_size_s, depth_smooth_um ): """ """ @@ -214,7 +314,7 @@ def _get_single_session_activity_histogram( one_bin_histogram, _, _ = alignment_utils.get_activity_histogram( recording, peaks, peak_locations, spatial_bin_edges, - log_scale=False, bin_s=None, smooth_um=smooth_um + log_scale=False, bin_s=None, depth_smooth_um=depth_smooth_um ) if method == "entire_session": if log_scale: @@ -231,7 +331,7 @@ def _get_single_session_activity_histogram( ) chunked_histograms, chunked_temporal_bin_centers, _ = alignment_utils.get_activity_histogram( - recording, peaks, peak_locations, spatial_bin_edges, log_scale, bin_s=chunked_bin_size_s, smooth_um=smooth_um, + recording, peaks, peak_locations, spatial_bin_edges, log_scale, bin_s=chunked_bin_size_s, depth_smooth_um=depth_smooth_um, ) session_std = np.sum(np.std(chunked_histograms, axis=0)) / chunked_histograms.shape[1] @@ -373,7 +473,7 @@ def _correct_session_displacement( estimate_histogram_kwargs["method"], estimate_histogram_kwargs["log_scale"], estimate_histogram_kwargs["chunked_bin_size_s"], - estimate_histogram_kwargs["smooth_um"], + estimate_histogram_kwargs["depth_smooth_um"], ) corrected_session_histogram_list.append(session_hist)