From 4ce7035c44c07c11bb4a66d811bb436567e2a34b Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 16 Dec 2024 20:46:15 +0000 Subject: [PATCH] Tidying up, begin fixing alignment alg. --- .../alignment_utils.py | 137 ++++++++++++++---- .../session_alignment.py | 88 +---------- 2 files changed, 119 insertions(+), 106 deletions(-) diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py b/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py index 710ce67a14..69f196e6fa 100644 --- a/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py +++ b/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py @@ -1,8 +1,3 @@ -from signal import signal - -from toolz import first -from torch.onnx.symbolic_opset11 import chunk - from spikeinterface import BaseRecording import numpy as np from spikeinterface.sortingcomponents.motion.motion_utils import make_2d_motion_histogram @@ -62,9 +57,7 @@ def get_activity_histogram( peak_locations, weight_with_amplitude=False, direction="y", - bin_s=( - bin_s if bin_s is not None else recording.get_duration(segment_index=0) - ), # TODO: doube cehck is this already scaling? + bin_s=(bin_s if bin_s is not None else recording.get_duration(segment_index=0)), bin_um=None, hist_margin_um=None, spatial_bin_edges=spatial_bin_edges, @@ -95,15 +88,20 @@ def get_bin_centers(bin_edges): def estimate_chunk_size(scaled_activity_histogram): """ - Get an estimate of chunk size such that - the 80th percentile of the firing rate will be - estimated within 10% 90% of the time, + Estimate a chunk size based on the firing rate. Intuitively, we + want longer chunk size to better estimate low firing rates. The + estimation computes a summary of the the firing rates for the session + by taking the value 25% of the max of the activity histogram. + + Then, the chunk size that will accurately estimate this firing rate + within 90% accuracy, 90% of the time based on assumption of Poisson + firing (based on CLT) is computed. - I think a better way is to take the peaks above half width and find the min. - Or just to take the 50th percentile...? NO. Because all peaks might be similar heights + Parameters + ---------- - corrected based on assumption - of Poisson firing (based on CLT). + scaled_activity_histogram: np.ndarray + The activity histogram scaled to firing rate in Hz. TODO ---- @@ -162,7 +160,7 @@ def get_chunked_hist_supremum(chunked_session_histograms): min_hist = np.min(chunked_session_histograms, axis=0) - scaled_range = (max_hist - min_hist) / max_hist # TODO: no idea if this is a good idea or not + scaled_range = (max_hist - min_hist) / (max_hist + 1e-12) return max_hist, scaled_range @@ -201,28 +199,31 @@ def get_chunked_hist_eigenvector(chunked_session_histograms): """ TODO: a little messy with the 2D stuff. Will probably deprecate anyway. """ - if chunked_session_histograms.shape[0] == 1: # TODO: handle elsewhere + if chunked_session_histograms.shape[0] == 1: return chunked_session_histograms.squeeze(), None is_2d = chunked_session_histograms.ndim == 3 if is_2d: - num_hist, num_spat_bin, num_amp_bin = chunked_histograms.shape + num_hist, num_spat_bin, num_amp_bin = chunked_session_histograms.shape chunked_session_histograms = np.reshape(chunked_session_histograms, (num_hist, num_spat_bin * num_amp_bin)) A = chunked_session_histograms S = (1 / A.shape[0]) * A.T @ A - U, S, Vh = np.linalg.svd(S) # TODO: this is already symmetric PSD so use eig + L, U = np.linalg.eigh(S) - first_eigenvector = U[:, 0] * np.sqrt(S[0]) - first_eigenvector = np.abs(first_eigenvector) # sometimes the eigenvector can be negative + first_eigenvector = U[:, -1] * np.sqrt(L[-1]) + first_eigenvector = np.abs(first_eigenvector) # sometimes the eigenvector is negative + # Project all vectors (histograms) onto the principal component, + # then take the standard deviation in each dimension (over bins) v1 = first_eigenvector[:, np.newaxis] - reconstruct = (A @ v1) @ v1.T - v1_std = np.std(np.sqrt(reconstruct), axis=0, ddof=0) # TODO: double check sqrt works out + projection_onto_v1 = (A @ v1 @ v1.T) / (v1.T @ v1) - if is_2d: + v1_std = np.std(projection_onto_v1, axis=0) + + if is_2d: # TODO: double check this first_eigenvector = np.reshape(first_eigenvector, (num_spat_bin, num_amp_bin)) v1_std = np.reshape(v1_std, (num_spat_bin, num_amp_bin)) @@ -423,7 +424,9 @@ def compute_histogram_crosscorrelation( windowed_histogram_i = session_histogram_list[i, :] * window windowed_histogram_j = session_histogram_list[j, :] * window - xcorr = np.correlate(windowed_histogram_i, windowed_histogram_j, mode="full") + xcorr = np.correlate( + windowed_histogram_i, windowed_histogram_j, mode="full" + ) # TODO: add weight option. if num_shifts_block: window_indices = np.arange(center_bin - num_shifts_block, center_bin + num_shifts_block) @@ -435,6 +438,14 @@ def compute_histogram_crosscorrelation( # Smooth the cross-correlations across the bins if smoothing_sigma_bin: + breakpoint() + import matplotlib.pyplot as plt + + plt.plot(xcorr_matrix[0, :]) + X = gaussian_filter(xcorr_matrix, smoothing_sigma_bin, axes=1) + plt.plot(X[0, :]) + plt.show() + xcorr_matrix = gaussian_filter(xcorr_matrix, smoothing_sigma_bin, axes=1) # Smooth the cross-correlations across the windows @@ -495,3 +506,79 @@ def shift_array_fill_zeros(array: np.ndarray, shift: int) -> np.ndarray: cut_padded_array = padded_hist[abs_shift:] if shift >= 0 else padded_hist[:-abs_shift] return cut_padded_array + + +def akima_interpolate_nonrigid_shifts( + non_rigid_shifts: np.ndarray, + non_rigid_window_centers: np.ndarray, + spatial_bin_centers: np.ndarray, +): + """ + Perform Akima spline interpolation on a set of non-rigid shifts. + The non-rigid shifts are per segment of the probe, each segment + containing a number of channels. Interpolating these non-rigid + shifts to the spatial bin centers gives a more accurate shift + per channel. + + Parameters + ---------- + non_rigid_shifts : np.ndarray + non_rigid_window_centers : np.ndarray + spatial_bin_centers : np.ndarray + + Returns + ------- + interp_nonrigid_shifts : np.ndarray + An array (length num_spatial_bins) of shifts + interpolated from the non-rigid shifts. + + TODO + ---- + requires scipy 14 + """ + from scipy.interpolate import Akima1DInterpolator + + x = non_rigid_window_centers + xs = spatial_bin_centers + + num_sessions = non_rigid_shifts.shape[0] + num_bins = spatial_bin_centers.shape[0] + + interp_nonrigid_shifts = np.zeros((num_sessions, num_bins)) + for ses_idx in range(num_sessions): + + y = non_rigid_shifts[ses_idx] + y_new = Akima1DInterpolator(x, y, method="akima", extrapolate=True)(xs) + interp_nonrigid_shifts[ses_idx, :] = y_new + + return interp_nonrigid_shifts + + +def get_shifts_from_session_matrix(alignment_order: str, session_offsets_matrix: np.ndarray): + """ + Given a matrix of displacements between all sessions, find the + shifts (one per session) to bring the sessions into alignment. + + Parameters + ---------- + alignment_order : "to_middle" or "to_session_X" where + "N" is the number of the session to align to. + session_offsets_matrix : np.ndarray + The num_sessions x num_sessions symmetric matrix + of displacements between all sessions, generated by + `_compute_session_alignment()`. + + Returns + ------- + optimal_shift_indices : np.ndarray + A 1 x num_sessions array of shifts to apply to + each session in order to bring all sessions into + alignment. + """ + if alignment_order == "to_middle": + optimal_shift_indices = -np.mean(session_offsets_matrix, axis=0) + else: + ses_idx = int(alignment_order.split("_")[-1]) - 1 + optimal_shift_indices = -session_offsets_matrix[ses_idx, :, :] + + return optimal_shift_indices diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py index 330c485433..3ee7eac8d1 100644 --- a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py +++ b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py @@ -211,7 +211,7 @@ def align_sessions( interpolate_motion_kwargs = copy.deepcopy(interpolate_motion_kwargs) # Ensure list lengths match and all channel locations are the same across recordings. - _check_align_sesssions_inputs( + _check_align_sessions_inputs( recordings_list, peaks_list, peak_locations_list, alignment_order, estimate_histogram_kwargs ) @@ -894,11 +894,11 @@ def _compute_session_alignment( nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation( shifted_histograms, non_rigid_windows, **compute_alignment_kwargs ) - non_rigid_shifts = _get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix) + non_rigid_shifts = alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix) # Akima interpolate the nonrigid bins if required. if akima_interp_nonrigid: - interp_nonrigid_shifts = _akima_interpolate_nonrigid_shifts( + interp_nonrigid_shifts = alignment_utils.akima_interpolate_nonrigid_shifts( non_rigid_shifts, non_rigid_window_centers, spatial_bin_centers ) shifts = rigid_shifts + interp_nonrigid_shifts @@ -944,83 +944,9 @@ def _estimate_rigid_alignment( rigid_window, **compute_alignment_kwargs, ) - optimal_shift_indices = _get_shifts_from_session_matrix(alignment_order, rigid_session_offsets_matrix) - - return optimal_shift_indices - - -def _akima_interpolate_nonrigid_shifts( - non_rigid_shifts: np.ndarray, - non_rigid_window_centers: np.ndarray, - spatial_bin_centers: np.ndarray, -): - """ - Perform Akima spline interpolation on a set of non-rigid shifts. - The non-rigid shifts are per segment of the probe, each segment - containing a number of channels. Interpolating these non-rigid - shifts to the spatial bin centers gives a more accurate shift - per channel. - - Parameters - ---------- - non_rigid_shifts : np.ndarray - non_rigid_window_centers : np.ndarray - spatial_bin_centers : np.ndarray - - Returns - ------- - interp_nonrigid_shifts : np.ndarray - An array (length num_spatial_bins) of shifts - interpolated from the non-rigid shifts. - - TODO - ---- - requires scipy 14 - """ - from scipy.interpolate import Akima1DInterpolator - - x = non_rigid_window_centers - xs = spatial_bin_centers - - num_sessions = non_rigid_shifts.shape[0] - num_bins = spatial_bin_centers.shape[0] - - interp_nonrigid_shifts = np.zeros((num_sessions, num_bins)) - for ses_idx in range(num_sessions): - - y = non_rigid_shifts[ses_idx] - y_new = Akima1DInterpolator(x, y, method="akima", extrapolate=True)(xs) - interp_nonrigid_shifts[ses_idx, :] = y_new - - return interp_nonrigid_shifts - - -def _get_shifts_from_session_matrix(alignment_order: str, session_offsets_matrix: np.ndarray): - """ - Given a matrix of displacements between all sessions, find the - shifts (one per session) to bring the sessions into alignment. - - Parameters - ---------- - alignment_order : "to_middle" or "to_session_X" where - "N" is the number of the session to align to. - session_offsets_matrix : np.ndarray - The num_sessions x num_sessions symmetric matrix - of displacements between all sessions, generated by - `_compute_session_alignment()`. - - Returns - ------- - optimal_shift_indices : np.ndarray - A 1 x num_sessions array of shifts to apply to - each session in order to bring all sessions into - alignment. - """ - if alignment_order == "to_middle": - optimal_shift_indices = -np.mean(session_offsets_matrix, axis=0) - else: - ses_idx = int(alignment_order.split("_")[-1]) - 1 - optimal_shift_indices = -session_offsets_matrix[ses_idx, :, :] + optimal_shift_indices = alignment_utils.get_shifts_from_session_matrix( + alignment_order, rigid_session_offsets_matrix + ) return optimal_shift_indices @@ -1030,7 +956,7 @@ def _get_shifts_from_session_matrix(alignment_order: str, session_offsets_matrix # ----------------------------------------------------------------------------- -def _check_align_sesssions_inputs( +def _check_align_sessions_inputs( recordings_list: list[BaseRecording], peaks_list: list[np.ndarray], peak_locations_list: list[np.ndarray],