Skip to content

Commit

Permalink
Tidying up, begin fixing alignment alg.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Dec 17, 2024
1 parent e91e9d4 commit 4ce7035
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 106 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
from signal import signal

from toolz import first
from torch.onnx.symbolic_opset11 import chunk

from spikeinterface import BaseRecording
import numpy as np
from spikeinterface.sortingcomponents.motion.motion_utils import make_2d_motion_histogram
Expand Down Expand Up @@ -62,9 +57,7 @@ def get_activity_histogram(
peak_locations,
weight_with_amplitude=False,
direction="y",
bin_s=(
bin_s if bin_s is not None else recording.get_duration(segment_index=0)
), # TODO: doube cehck is this already scaling?
bin_s=(bin_s if bin_s is not None else recording.get_duration(segment_index=0)),
bin_um=None,
hist_margin_um=None,
spatial_bin_edges=spatial_bin_edges,
Expand Down Expand Up @@ -95,15 +88,20 @@ def get_bin_centers(bin_edges):

def estimate_chunk_size(scaled_activity_histogram):
"""
Get an estimate of chunk size such that
the 80th percentile of the firing rate will be
estimated within 10% 90% of the time,
Estimate a chunk size based on the firing rate. Intuitively, we
want longer chunk size to better estimate low firing rates. The
estimation computes a summary of the firing rates for the session
by taking the value 25% of the max of the activity histogram.
Then, the chunk size that will accurately estimate this firing rate
within 90% accuracy, 90% of the time based on assumption of Poisson
firing (based on CLT) is computed.
I think a better way is to take the peaks above half width and find the min.
Or just to take the 50th percentile...? NO. Because all peaks might be similar heights
Parameters
----------
corrected based on assumption
of Poisson firing (based on CLT).
scaled_activity_histogram: np.ndarray
The activity histogram scaled to firing rate in Hz.
TODO
----
Expand Down Expand Up @@ -162,7 +160,7 @@ def get_chunked_hist_supremum(chunked_session_histograms):

min_hist = np.min(chunked_session_histograms, axis=0)

scaled_range = (max_hist - min_hist) / max_hist # TODO: no idea if this is a good idea or not
scaled_range = (max_hist - min_hist) / (max_hist + 1e-12)

return max_hist, scaled_range

Expand Down Expand Up @@ -201,28 +199,31 @@ def get_chunked_hist_eigenvector(chunked_session_histograms):
"""
TODO: a little messy with the 2D stuff. Will probably deprecate anyway.
"""
if chunked_session_histograms.shape[0] == 1: # TODO: handle elsewhere
if chunked_session_histograms.shape[0] == 1:
return chunked_session_histograms.squeeze(), None

is_2d = chunked_session_histograms.ndim == 3

if is_2d:
num_hist, num_spat_bin, num_amp_bin = chunked_histograms.shape
num_hist, num_spat_bin, num_amp_bin = chunked_session_histograms.shape
chunked_session_histograms = np.reshape(chunked_session_histograms, (num_hist, num_spat_bin * num_amp_bin))

A = chunked_session_histograms
S = (1 / A.shape[0]) * A.T @ A

U, S, Vh = np.linalg.svd(S) # TODO: this is already symmetric PSD so use eig
L, U = np.linalg.eigh(S)

first_eigenvector = U[:, 0] * np.sqrt(S[0])
first_eigenvector = np.abs(first_eigenvector) # sometimes the eigenvector can be negative
first_eigenvector = U[:, -1] * np.sqrt(L[-1])
first_eigenvector = np.abs(first_eigenvector) # sometimes the eigenvector is negative

# Project all vectors (histograms) onto the principal component,
# then take the standard deviation in each dimension (over bins)
v1 = first_eigenvector[:, np.newaxis]
reconstruct = (A @ v1) @ v1.T
v1_std = np.std(np.sqrt(reconstruct), axis=0, ddof=0) # TODO: double check sqrt works out
projection_onto_v1 = (A @ v1 @ v1.T) / (v1.T @ v1)

if is_2d:
v1_std = np.std(projection_onto_v1, axis=0)

if is_2d: # TODO: double check this
first_eigenvector = np.reshape(first_eigenvector, (num_spat_bin, num_amp_bin))
v1_std = np.reshape(v1_std, (num_spat_bin, num_amp_bin))

Expand Down Expand Up @@ -423,7 +424,9 @@ def compute_histogram_crosscorrelation(
windowed_histogram_i = session_histogram_list[i, :] * window
windowed_histogram_j = session_histogram_list[j, :] * window

xcorr = np.correlate(windowed_histogram_i, windowed_histogram_j, mode="full")
xcorr = np.correlate(
windowed_histogram_i, windowed_histogram_j, mode="full"
) # TODO: add weight option.

if num_shifts_block:
window_indices = np.arange(center_bin - num_shifts_block, center_bin + num_shifts_block)
Expand All @@ -435,6 +438,14 @@ def compute_histogram_crosscorrelation(

# Smooth the cross-correlations across the bins
if smoothing_sigma_bin:
breakpoint()
import matplotlib.pyplot as plt

plt.plot(xcorr_matrix[0, :])
X = gaussian_filter(xcorr_matrix, smoothing_sigma_bin, axes=1)
plt.plot(X[0, :])
plt.show()

xcorr_matrix = gaussian_filter(xcorr_matrix, smoothing_sigma_bin, axes=1)

# Smooth the cross-correlations across the windows
Expand Down Expand Up @@ -495,3 +506,79 @@ def shift_array_fill_zeros(array: np.ndarray, shift: int) -> np.ndarray:
cut_padded_array = padded_hist[abs_shift:] if shift >= 0 else padded_hist[:-abs_shift]

return cut_padded_array


def akima_interpolate_nonrigid_shifts(
non_rigid_shifts: np.ndarray,
non_rigid_window_centers: np.ndarray,
spatial_bin_centers: np.ndarray,
):
"""
Perform Akima spline interpolation on a set of non-rigid shifts.
The non-rigid shifts are per segment of the probe, each segment
containing a number of channels. Interpolating these non-rigid
shifts to the spatial bin centers gives a more accurate shift
per channel.
Parameters
----------
non_rigid_shifts : np.ndarray
non_rigid_window_centers : np.ndarray
spatial_bin_centers : np.ndarray
Returns
-------
interp_nonrigid_shifts : np.ndarray
An array (length num_spatial_bins) of shifts
interpolated from the non-rigid shifts.
TODO
----
requires scipy 14
"""
from scipy.interpolate import Akima1DInterpolator

x = non_rigid_window_centers
xs = spatial_bin_centers

num_sessions = non_rigid_shifts.shape[0]
num_bins = spatial_bin_centers.shape[0]

interp_nonrigid_shifts = np.zeros((num_sessions, num_bins))
for ses_idx in range(num_sessions):

y = non_rigid_shifts[ses_idx]
y_new = Akima1DInterpolator(x, y, method="akima", extrapolate=True)(xs)
interp_nonrigid_shifts[ses_idx, :] = y_new

return interp_nonrigid_shifts


def get_shifts_from_session_matrix(alignment_order: str, session_offsets_matrix: np.ndarray):
    """
    Given a matrix of displacements between all sessions, find the
    shifts (one per session) to bring the sessions into alignment.

    Parameters
    ----------
    alignment_order : str
        "to_middle" or "to_session_N" where "N" is the 1-based
        number of the session to align to.
    session_offsets_matrix : np.ndarray
        Array of displacements between all sessions, generated by
        `_compute_session_alignment()`. It is indexed with three axes
        here; the first axis is the one averaged over / selected from.
        Presumably (num_sessions, num_sessions, num_windows) -
        TODO confirm against the caller.

    Returns
    -------
    optimal_shift_indices : np.ndarray
        Array of shape `session_offsets_matrix.shape[1:]` holding the
        shifts to apply to each session in order to bring all sessions
        into alignment.

    Raises
    ------
    ValueError
        If the trailing part of `alignment_order` is not an integer,
        or is smaller than 1.
    """
    if alignment_order == "to_middle":
        # Negated mean over the first axis of the offsets array.
        return -np.mean(session_offsets_matrix, axis=0)

    # "to_session_N" uses 1-based session numbers; convert to a 0-based index.
    ses_idx = int(alignment_order.split("_")[-1]) - 1

    # Without this guard a session number below 1 yields a negative index,
    # which NumPy wraps around, silently aligning to the wrong session.
    if ses_idx < 0:
        raise ValueError(
            f"`alignment_order` must be 'to_middle' or 'to_session_N' with N >= 1, got '{alignment_order}'."
        )

    return -session_offsets_matrix[ses_idx, :, :]
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def align_sessions(
interpolate_motion_kwargs = copy.deepcopy(interpolate_motion_kwargs)

# Ensure list lengths match and all channel locations are the same across recordings.
_check_align_sesssions_inputs(
_check_align_sessions_inputs(
recordings_list, peaks_list, peak_locations_list, alignment_order, estimate_histogram_kwargs
)

Expand Down Expand Up @@ -894,11 +894,11 @@ def _compute_session_alignment(
nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation(
shifted_histograms, non_rigid_windows, **compute_alignment_kwargs
)
non_rigid_shifts = _get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix)
non_rigid_shifts = alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix)

# Akima interpolate the nonrigid bins if required.
if akima_interp_nonrigid:
interp_nonrigid_shifts = _akima_interpolate_nonrigid_shifts(
interp_nonrigid_shifts = alignment_utils.akima_interpolate_nonrigid_shifts(
non_rigid_shifts, non_rigid_window_centers, spatial_bin_centers
)
shifts = rigid_shifts + interp_nonrigid_shifts
Expand Down Expand Up @@ -944,83 +944,9 @@ def _estimate_rigid_alignment(
rigid_window,
**compute_alignment_kwargs,
)
optimal_shift_indices = _get_shifts_from_session_matrix(alignment_order, rigid_session_offsets_matrix)

return optimal_shift_indices


def _akima_interpolate_nonrigid_shifts(
non_rigid_shifts: np.ndarray,
non_rigid_window_centers: np.ndarray,
spatial_bin_centers: np.ndarray,
):
"""
Perform Akima spline interpolation on a set of non-rigid shifts.
The non-rigid shifts are per segment of the probe, each segment
containing a number of channels. Interpolating these non-rigid
shifts to the spatial bin centers gives a more accurate shift
per channel.
Parameters
----------
non_rigid_shifts : np.ndarray
non_rigid_window_centers : np.ndarray
spatial_bin_centers : np.ndarray
Returns
-------
interp_nonrigid_shifts : np.ndarray
An array (length num_spatial_bins) of shifts
interpolated from the non-rigid shifts.
TODO
----
requires scipy 14
"""
from scipy.interpolate import Akima1DInterpolator

x = non_rigid_window_centers
xs = spatial_bin_centers

num_sessions = non_rigid_shifts.shape[0]
num_bins = spatial_bin_centers.shape[0]

interp_nonrigid_shifts = np.zeros((num_sessions, num_bins))
for ses_idx in range(num_sessions):

y = non_rigid_shifts[ses_idx]
y_new = Akima1DInterpolator(x, y, method="akima", extrapolate=True)(xs)
interp_nonrigid_shifts[ses_idx, :] = y_new

return interp_nonrigid_shifts


def _get_shifts_from_session_matrix(alignment_order: str, session_offsets_matrix: np.ndarray):
"""
Given a matrix of displacements between all sessions, find the
shifts (one per session) to bring the sessions into alignment.
Parameters
----------
alignment_order : "to_middle" or "to_session_X" where
"N" is the number of the session to align to.
session_offsets_matrix : np.ndarray
The num_sessions x num_sessions symmetric matrix
of displacements between all sessions, generated by
`_compute_session_alignment()`.
Returns
-------
optimal_shift_indices : np.ndarray
A 1 x num_sessions array of shifts to apply to
each session in order to bring all sessions into
alignment.
"""
if alignment_order == "to_middle":
optimal_shift_indices = -np.mean(session_offsets_matrix, axis=0)
else:
ses_idx = int(alignment_order.split("_")[-1]) - 1
optimal_shift_indices = -session_offsets_matrix[ses_idx, :, :]
optimal_shift_indices = alignment_utils.get_shifts_from_session_matrix(
alignment_order, rigid_session_offsets_matrix
)

return optimal_shift_indices

Expand All @@ -1030,7 +956,7 @@ def _get_shifts_from_session_matrix(alignment_order: str, session_offsets_matrix
# -----------------------------------------------------------------------------


def _check_align_sesssions_inputs(
def _check_align_sessions_inputs(
recordings_list: list[BaseRecording],
peaks_list: list[np.ndarray],
peak_locations_list: list[np.ndarray],
Expand Down

0 comments on commit 4ce7035

Please sign in to comment.