diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py b/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py index 2a1755b8cf..710ce67a14 100644 --- a/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py +++ b/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py @@ -120,7 +120,10 @@ def estimate_chunk_size(scaled_activity_histogram): t = lambda_hat_s / (e / confidence_z) ** 2 - print(f"estimated t: {t} for lambda {lambda_hat_s}") + print( + f"Chunked histogram window size of: {t}s estimated " + f"for firing rate (25% of histogram peak) of {lambda_hat_s}" + ) return 10 @@ -227,8 +230,7 @@ def get_chunked_hist_eigenvector(chunked_session_histograms): def get_chunked_gaussian_process_regression(chunked_session_histogram): - """ - """ + """ """ # TODO: this is currently a placeholder implementation where the # mean and variance over repeated samples is taken to run quickly. # It would be better to use sparse version with repeated measures @@ -255,7 +257,7 @@ def get_chunked_gaussian_process_regression(chunked_session_histogram): bias_mean = False if bias_mean: - #this is cool, bias the estimation towards the peak + # this is cool, bias the estimation towards the peak Y = Y + np.mean(Y, axis=0) - np.percentile(Y, 5, axis=0) # TODO: avoid copy, also fix dims in case of square # normalise X and set lengthscale to 1 bin @@ -272,15 +274,19 @@ def get_chunked_gaussian_process_regression(chunked_session_histogram): y_mean = np.mean(Y, axis=0) y_var = np.std(Y, axis=0) - Y_mean_scaled = (y_mean - mu_ystar) /std_ystar # standardise the normal way - Y_var_scaled = (1/std_ystar**2) * y_var # this is a variance so need to scale to the square (TODO: see overleaf notes) + Y_mean_scaled = (y_mean - mu_ystar) / std_ystar # standardise the normal way + Y_var_scaled = ( + 1 / std_ystar**2 + ) * y_var # this is a variance so need to scale to the square (TODO: see overleaf notes) kernel = GPy.kern.RBF(input_dim=1, lengthscale=lengthscale, variance=np.mean(Y_var_scaled)) # TODO: check this output_index2 = np.arange(num_bins) - Y_metadata2 = {'output_index': output_index2} + Y_metadata2 = {"output_index": output_index2} - likelihood = GPy.likelihoods.HeteroscedasticGaussian(Y_metadata2, variance=Y_var_scaled) # one variance per y, but should be repeated for the same x + likelihood = GPy.likelihoods.HeteroscedasticGaussian( + Y_metadata2, variance=Y_var_scaled + ) # one variance per y, but should be repeated for the same x gp = GPy.models.GPRegression(X_scaled.reshape(-1, 1), Y_mean_scaled.reshape(-1, 1), kernel, Y_metadata2) diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py index fc596d645e..330c485433 100644 --- a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py +++ b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py @@ -1,8 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING -from torch.onnx.symbolic_opset11 import chunk - if TYPE_CHECKING: from spikeinterface.core.baserecording import BaseRecording @@ -15,8 +13,6 @@ from spikeinterface.preprocessing.motion import run_peak_detection_pipeline_node import copy -import matplotlib.pyplot as plt - def get_estimate_histogram_kwargs() -> dict: """ @@ -37,9 +33,9 @@ def get_estimate_histogram_kwargs() -> dict: "chunked_poisson". Determines the summary statistic used over the histograms computed across a session. See `alignment_utils.py for details on each method. - "chunked_bin_size_s" : The length in seconds to chunk the recording - for estimating the chunked histograms. If set to "estimate", - the size is estimated from firing frequencies. + "chunked_bin_size_s" : The length in seconds (float) to chunk the recording + for estimating the chunked histograms. Can be set to "estimate" (str), + and the size is estimated from firing frequencies. "log_scale" : if `True`, histograms are log transformed. "depth_smooth_um" : if `None`, no smoothing is applied. See `make_2d_motion_histogram`. @@ -50,6 +46,7 @@ def get_estimate_histogram_kwargs() -> dict: "chunked_bin_size_s": "estimate", "log_scale": False, "depth_smooth_um": None, + "histogram_type": "activity_1d", } @@ -83,7 +80,12 @@ def get_compute_alignment_kwargs() -> dict: def get_non_rigid_window_kwargs(): """ - TODO: merge with motion correction + see get_spatial_windows() for parameters. + + TODO + ---- + merge with motion correction kwargs which are + defined in the function signature. """ return { "rigid": True, @@ -97,7 +99,8 @@ def get_non_rigid_window_kwargs(): def get_interpolate_motion_kwargs(): """ - Settings to pass to `InterpolateMotionRecording" + Settings to pass to `InterpolateMotionRecording`, + see that class for parameter descriptions. """ return {"border_mode": "remove_channels", "spatial_interpolation_method": "kriging", "sigma_um": 20.0, "p": 2} @@ -147,16 +150,12 @@ def align_sessions( "to_middle" will align all sessions to the mean position. Alternatively, "to_session_N" where "N" is a session number will align to the Nth session. - rigid : bool - If `True`, estimated displacement is rigid. If `False`, nonrigid - estimation is performed by performing rigid alignment on overlapping - subsets of the probes "y" dimension. + non_rigid_window_kwargs : dict + see `get_non_rigid_window_kwargs` estimate_histogram_kwargs : dict see `get_estimate_histogram_kwargs()` compute_alignment_kwargs : dict see `get_compute_alignment_kwargs()` - non_rigid_window_kwargs : dict - see `get_non_rigid_window_kwargs` interpolate_motion_kwargs : dict see `get_interpolate_motion_kwargs()` @@ -168,26 +167,51 @@ def align_sessions( an InterpolateMotionRecording` recording, the corrected output recording will be a copy of the input recording with the additional displacement correction added. - `motion_objects_list : list[Motion] - List of motion objects associated with each corrected - recording. In the case where the `recording` was an - `InterpolateMotionRecording`, no motion object is created - and the entry in `motion_objects_list` will be `None`. - TODO extra_outputs_dict : dict - A dictionary of outputs, including variables generated - during the displacement estiamtion and correction. - Also, includes an "corrected" field including - a list of corrected `peak_locations` and activity - histogram generated after correction. + Dictionary of features used in the alignment estimation and correction. + + shifts_array : np.ndarray + A (num_sessions x num_rigid_windows) array of shifts. + session_histogram_list : list[np.ndarray] + A list of histograms (one per session) used for the alignment. + spatial_bin_centers : np.ndarray + The spatial bin centers, shared between all recordings. + temporal_bin_centers_list : list[np.ndarray] + List of temporal bin centers. As alignment is based on a single + histogram per session, this contains only 1 value per recording, + which is the mid-timepoint of the recording. + non_rigid_window_centers : np.ndarray + Window centers of the probe segments used for non-rigid alignment. + If rigid alignment is performed, this is a single value (mid-probe). + non_rigid_windows : np.ndarray + A (num nonrigid windows, num spatial_bin_centers) binary array used to mask + the probe segments for non-rigid alignment. If rigid alignment is performed, + this a vector of ones with length (spatial_bin_centers,) + histogram_info_list :list[dict] + see `_get_single_session_activity_histogram()` for details. + motion_objects_list : + List of motion objects containing the shifts and spatial and temporal + bins for each recording. Note this contains only displacement + associated with the inter-session alignment, and so will differ from + the motion on corrected recording objects if the recording is + already an `InterpolateMotionRecording` object containing + within-session motion correction. + corrected : dict + Dictionary containing corrected-histogram + information. + corrected_peak_locations_list : + Displacement-corrected `peak_locations`. + corrected_session_histogram_list : + Corrected activity histogram (computed from the corrected peak locations). """ non_rigid_window_kwargs = copy.deepcopy(non_rigid_window_kwargs) estimate_histogram_kwargs = copy.deepcopy(estimate_histogram_kwargs) compute_alignment_kwargs = copy.deepcopy(compute_alignment_kwargs) interpolate_motion_kwargs = copy.deepcopy(interpolate_motion_kwargs) - _check_align_sesssions_inpus( + # Ensure list lengths match and all channel locations are the same across recordings. + _check_align_sesssions_inputs( recordings_list, peaks_list, peak_locations_list, alignment_order, estimate_histogram_kwargs ) @@ -199,7 +223,7 @@ def align_sessions( print("Aligning the activity histograms across sessions...") - contact_depths = recordings_list[0].get_channel_locations()[:, 1] # "y" dim. + contact_depths = recordings_list[0].get_channel_locations()[:, 1] shifts_array, non_rigid_windows, non_rigid_window_centers = _compute_session_alignment( session_histogram_list, @@ -220,7 +244,12 @@ def align_sessions( print("Creating corrected peak locations and histograms...") corrected_peak_locations_list, corrected_session_histogram_list = _correct_session_displacement( - corrected_recordings_list, peaks_list, peak_locations_list, spatial_bin_edges, estimate_histogram_kwargs + corrected_recordings_list, + peaks_list, + peak_locations_list, + motion_objects_list, + spatial_bin_edges, + estimate_histogram_kwargs, ) extra_outputs_dict = { @@ -242,12 +271,12 @@ def align_sessions( def align_sessions_after_motion_correction( recordings_list: list[BaseRecording], motion_info_list: list[dict], align_sessions_kwargs: dict | None -) -> tuple[list[BaseRecording], list[Motion], dict]: +) -> tuple[list[BaseRecording], dict]: """ Convenience function to run `align_sessions` to correct for inter-session displacement from the outputs of motion correction. - The estimated displacement will be added to the existing recording. + The estimated displacement will be added directly to the recording. Parameters ---------- @@ -256,8 +285,6 @@ def align_sessions_after_motion_correction( motion_info_list : list[dict] A list of `motion_info` objects, as output from `correct_motion`. Each entry should correspond to a recording in `recording_list`. - rigid: bool - align_sessions_kwargs : dict A dictionary of keyword arguments passed to `align_sessions`. @@ -270,7 +297,7 @@ def align_sessions_after_motion_correction( motion_kwargs_list = [info["parameters"]["estimate_motion_kwargs"] for info in motion_info_list] if not all(kwargs == motion_kwargs_list[0] for kwargs in motion_kwargs_list): raise ValueError( - "The motion correct settings used on the `recordings_list`" "must be identical for all recordings" + "The motion correct settings used on the `recordings_list` must be identical for all recordings" ) motion_window_kwargs = copy.deepcopy(motion_kwargs_list[0]) @@ -281,7 +308,7 @@ def align_sessions_after_motion_correction( align_sessions_kwargs = get_compute_alignment_kwargs() # If motion correction was nonrigid, we must use the same settings for - # inter-session alignment or we will not be able to add the nonrigid + # inter-session alignment, or we will not be able to add the nonrigid # shifts together. if ( "non_rigid_window_kwargs" in align_sessions_kwargs @@ -360,10 +387,10 @@ def _compute_session_histograms( histogram_type, # TODO think up better names bin_um: float, method: str, - chunked_bin_size_s: str, - depth_smooth_um: str, - log_scale: str, -) -> tuple[list[np.ndarray], list[np.ndarray], list[np.ndarray], np.ndarray, np.ndarray, list[dict]]: + chunked_bin_size_s: float | "estimate", + depth_smooth_um: float, + log_scale: bool, +) -> tuple[list[np.ndarray], list[np.ndarray], np.ndarray, np.ndarray, list[dict]]: """ Compute a 1d activity histogram for the session. As sessions may be long, the approach taken is to chunk @@ -403,7 +430,8 @@ def _compute_session_histograms( (e.g. chunked histograms). One per session. See `_get_single_session_activity_histogram()` for details. """ - # Get spatial windows and estimate the session histograms + # Get spatial windows (shared across all histograms) + # and estimate the session histograms temporal_bin_centers_list = [] spatial_bin_centers, spatial_bin_edges, _ = get_spatial_bins( @@ -447,7 +475,7 @@ def _get_single_session_activity_histogram( spatial_bin_edges: np.ndarray, method: str, log_scale: bool, - chunked_bin_size_s: float, + chunked_bin_size_s: float | "estimate", depth_smooth_um: float, ) -> tuple[np.ndarray, np.ndarray, dict]: """ @@ -457,7 +485,7 @@ def _get_single_session_activity_histogram( Note if `chunked_bin_size_is` is set to `"estimate"` the histogram for the entire session is first created to get a good - estimate of the firing rates. TODO: this is probably overkill. + estimate of the firing rates. The firing rates are used to use a time segment size that will allow a good estimation of the firing rate. @@ -511,7 +539,7 @@ def _get_single_session_activity_histogram( chunked_bin_size_s = alignment_utils.estimate_chunk_size(scaled_hist) chunked_bin_size_s = np.min([chunked_bin_size_s, recording.get_duration()]) - if histogram_type == "1Dy": # TODO: tidy this up + if histogram_type == "activity_1d": chunked_histograms, chunked_temporal_bin_centers, _ = alignment_utils.get_activity_histogram( recording, @@ -524,12 +552,12 @@ def _get_single_session_activity_histogram( scale_to_hz=True, ) - elif histogram_type in ["2Dy_amplitude", "2Dy_x"]: + elif histogram_type in ["activity_2d", "locations_2d"]: - if histogram_type == "2Dy_amplitude": + if histogram_type == "activity_2d": from spikeinterface.sortingcomponents.motion.motion_utils import make_3d_motion_histograms - chunked_histograms, chunked_temporal_bin_edges, _ = make_3d_motion_histograms( # TODO: compute centers + chunked_histograms, chunked_temporal_bin_edges, _ = make_3d_motion_histograms( recording, peaks, peak_locations, @@ -543,41 +571,41 @@ def _get_single_session_activity_histogram( ) else: - chunked_histogram, chunked_temporal_bin_edges = _get_peak_positions_as_histogram( - recording, peak_locations - ) # TODO: could add smoothing + chunked_histograms, chunked_temporal_bin_edges = _get_peak_positions_as_histogram( + recording, spatial_bin_edges, chunked_bin_size_s, peaks, peak_locations + ) chunked_temporal_bin_centers = alignment_utils.get_bin_centers(chunked_temporal_bin_edges) if method == "chunked_mean": - session_histogram, variation = alignment_utils.get_chunked_hist_mean(chunked_histograms) + session_histogram, hist_variability = alignment_utils.get_chunked_hist_mean(chunked_histograms) elif method == "chunked_median": - session_histogram, variation = alignment_utils.get_chunked_hist_median(chunked_histograms) + session_histogram, hist_variability = alignment_utils.get_chunked_hist_median(chunked_histograms) elif method == "chunked_supremum": - session_histogram, variation = alignment_utils.get_chunked_hist_supremum(chunked_histograms) + session_histogram, hist_variability = alignment_utils.get_chunked_hist_supremum(chunked_histograms) elif method == "chunked_poisson": - session_histogram, variation = alignment_utils.get_chunked_hist_poisson_estimate(chunked_histograms) + session_histogram, hist_variability = alignment_utils.get_chunked_hist_poisson_estimate(chunked_histograms) elif method == "first_eigenvector": - session_histogram, variation = alignment_utils.get_chunked_hist_eigenvector(chunked_histograms) + session_histogram, hist_variability = alignment_utils.get_chunked_hist_eigenvector(chunked_histograms) elif method == "chunked_gp": # TODO: better name - session_histogram, variation, gp_model = alignment_utils.get_chunked_gaussian_process_regression( + session_histogram, hist_variability, gp_model = alignment_utils.get_chunked_gaussian_process_regression( chunked_histograms ) - # each b in independent, I think this is fine irrespective of method used - session_variation = np.mean(variation) # think about meaning of this 1d vs. 2D + # Take the average variability across bins as a summary measure. + session_mean_variability = np.mean(hist_variability) histogram_info = { "chunked_histograms": chunked_histograms, "chunked_temporal_bin_centers": chunked_temporal_bin_centers, - "session_variation": session_variation, + "session_mean_variability": session_mean_variability, "chunked_bin_size_s": chunked_bin_size_s, - "session_histogram_variation": variation, + "session_histogram_variation": hist_variability, } if method == "chunked_gp": @@ -586,12 +614,10 @@ def _get_single_session_activity_histogram( return session_histogram, temporal_bin_centers, histogram_info -def _get_peak_positions_as_histogram( - recording, peak_locations -): +def _get_peak_positions_as_histogram(recording, spatial_bin_edges, chunked_bin_size_s, peaks, peak_locations): """ This is just a temp function to see how it goes... - + # TODO: could add smoothing """ min_x = np.min(peak_locations["x"]) @@ -614,7 +640,7 @@ def _get_peak_positions_as_histogram( chunked_histograms, _ = np.histogramdd(arr, (chunked_temporal_bin_edges, spatial_bin_edges, x_bins)) - return chunked_histogram, chunked_temporal_bin_edges + return chunked_histograms, chunked_temporal_bin_edges def _create_motion_recordings( @@ -653,18 +679,18 @@ def _create_motion_recordings( session_shift = shifts_array[ses_idx][np.newaxis, :] + motion = Motion([session_shift], [temporal_bin_centers_list[ses_idx]], non_rigid_window_centers, direction="y") + motion_objects_list.append(motion) + if isinstance(recording, InterpolateMotionRecording): - corrected_recording = _add_displacement_to_interpolate_recording( - recording, session_shift, non_rigid_window_centers + print( + "Recording is already an `InterpolateMotionRecording. " "Adding shifts directly the recording object." ) - motion_objects_list.append(None) + + corrected_recording = _add_displacement_to_interpolate_recording(recording, motion) else: - motion = Motion( - [session_shift], [temporal_bin_centers_list[ses_idx]], non_rigid_window_centers, direction="y" - ) corrected_recording = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs) - motion_objects_list.append(motion) corrected_recordings_list.append(corrected_recording) @@ -672,9 +698,8 @@ def _create_motion_recordings( def _add_displacement_to_interpolate_recording( - recording: BaseRecording, - shifts_to_add: np.ndarray, - new_non_rigid_window_centers: np.ndarray, + original_recording: BaseRecording, + session_displacement_motion: Motion, ): """ This function adds a shift to an InterpolateMotionRecording. @@ -707,7 +732,10 @@ def _add_displacement_to_interpolate_recording( # Everything is done in place, so keep a short variable # name reference to the new recordings `motion` object # and update it.okay - corrected_recording = copy.deepcopy(recording) + corrected_recording = copy.deepcopy(original_recording) + + shifts_to_add = session_displacement_motion.displacement[0] + new_non_rigid_window_centers = session_displacement_motion.spatial_bins_um motion_ref = corrected_recording._recording_segments[0].motion recording_bins = motion_ref.displacement[0].shape[1] @@ -744,6 +772,7 @@ def _correct_session_displacement( recordings_list: list[BaseRecording], peaks_list: list[np.ndarray], peak_locations_list: list[np.ndarray], + motion_objects_list: list[Motion], spatial_bin_edges: np.ndarray, estimate_histogram_kwargs: dict, ): @@ -765,18 +794,19 @@ def _correct_session_displacement( corrected_session_histogram_list : list[np.ndarray] A list of histograms calculated from the corrected peaks (one per session). """ - # Correct the peak locations corrected_peak_locations_list = [] - for recording, peaks, peak_locations in zip(recordings_list, peaks_list, peak_locations_list): + for recording, peaks, peak_locations, motion in zip( + recordings_list, peaks_list, peak_locations_list, motion_objects_list + ): + # Note this `motion` is not necessarily the same as the motion on the recording. If the recording + # is an `InterpolateMotionRecording`, it will contain correction for both motion and inter-session displacement. + # Here we want to correct only the motion associated with inter-session displacement. corrected_peak_locs = correct_motion_on_peaks( peaks, peak_locations, - recording._recording_segments[ - 0 - ].motion, # TODO: this is wrong, if the previous recording was a motion correction - # then this will add the original motion correct and the new one. We need to pass just the new shifts + motion, recording, ) corrected_peak_locations_list.append(corrected_peak_locs) @@ -915,6 +945,7 @@ def _estimate_rigid_alignment( **compute_alignment_kwargs, ) optimal_shift_indices = _get_shifts_from_session_matrix(alignment_order, rigid_session_offsets_matrix) + return optimal_shift_indices @@ -999,7 +1030,7 @@ def _get_shifts_from_session_matrix(alignment_order: str, session_offsets_matrix # ----------------------------------------------------------------------------- -def _check_align_sesssions_inpus( +def _check_align_sesssions_inputs( recordings_list: list[BaseRecording], peaks_list: list[np.ndarray], peak_locations_list: list[np.ndarray], @@ -1020,7 +1051,7 @@ def _check_align_sesssions_inpus( if not all(rec.get_num_segments() == 1 for rec in recordings_list): raise ValueError( - "Multi-segment recordings not supported. All recordings " "in `recordings_list` but have only 1 segment." + "Multi-segment recordings not supported. All recordings in `recordings_list` but have only 1 segment." ) channel_locs = [rec.get_channel_locations() for rec in recordings_list] @@ -1031,17 +1062,16 @@ def _check_align_sesssions_inpus( "performed using the same probe." ) - accepted_hist_methods = ["entire_session", "chunked_mean", "chunked_median", "chunked_supremum", "chunked_poisson"] - method = estimate_histogram_kwargs["method"] - if method not in [ + accepted_hist_methods = [ "entire_session", "chunked_mean", "chunked_median", "chunked_supremum", - "chunked_poisson", "first_eigenvector", "chunked_gp", - ]: + ] + method = estimate_histogram_kwargs["method"] + if method not in accepted_hist_methods: raise ValueError(f"`method` option must be one of: {accepted_hist_methods}") if alignment_order != "to_middle": @@ -1049,14 +1079,14 @@ def _check_align_sesssions_inpus( split_name = alignment_order.split("_") if not "_".join(split_name[:2]) == "to_session": raise ValueError( - "`alignment_order` must take the form 'to_sesion_X'" "where X is the session number to align to." + "`alignment_order` must take the form 'to_session_X' where X is the session number to align to." ) ses_num = int(split_name[-1]) if ses_num > num_sessions: raise ValueError( - f"`alignment_order` session {ses_num} is larger than" f"the number of sessions in `recordings_list`." + f"`alignment_order` session {ses_num} is larger than the number of sessions in `recordings_list`." ) if ses_num == 0: - raise ValueError("`alignment_order` required the session number, " "not session index.") + raise ValueError("`alignment_order` required the session number, not session index.")