Skip to content

Commit

Permalink
Continue! working on tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Jan 17, 2025
1 parent 9da4cbf commit 01adece
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -163,17 +163,17 @@ def get_chunked_hist_median(chunked_session_histograms):

# TODO: a good test here is to give zero shift for even and off numbered hist and check the output is zero!
def compute_histogram_crosscorrelation(
session_histogram_list: list[np.ndarray],
session_histogram_list: np.ndarray,
non_rigid_windows: np.ndarray,
num_shifts: int,
interpolate: bool,
interp_factor: int,
kriging_sigma: float,
kriging_p: float,
kriging_d: float,
smoothing_sigma_bin: float,
smoothing_sigma_window: float,
):
smoothing_sigma_bin: None | float,
smoothing_sigma_window: None | float,
) -> tuple[np.ndarray, np.ndarray]:
"""
Given a list of session activity histograms, cross-correlate
all histograms returning the peak correlation shift (in indices)
Expand All @@ -185,7 +185,8 @@ def compute_histogram_crosscorrelation(
Parameters
----------
session_histogram_list : list[np.ndarray]
session_histogram_list : list[np.ndarray] TODO: change name!!
(num_sessions, num_bins) array of session activity histograms.
non_rigid_windows : np.ndarray
A (num windows x num_bins) binary of weights by which to window
the activity histogram for non-rigid-registration. For example, if
Expand Down Expand Up @@ -258,23 +259,21 @@ def compute_histogram_crosscorrelation(
"""
import matplotlib.pyplot as plt

num_sessions = len(session_histogram_list)
num_sessions = session_histogram_list.shape[0]
num_bins = session_histogram_list.shape[1] # all hists are same length
num_windows = non_rigid_windows.shape[0]

shift_matrix = np.zeros((num_sessions, num_sessions, num_windows))

center_bin = np.floor((num_bins * 2 - 1) / 2).astype(int)

# Create the (num windows, num_bins) matrix for this pair of sessions
num_iter = num_bins * 2 - 1 if not num_shifts else num_shifts * 2
shifts_array = np.arange(-(num_iter // 2), num_iter // 2 + 1)

for i in range(num_sessions):
for j in range(i, num_sessions):

# Create the (num windows, num_bins) matrix for this pair of sessions
num_iter = (
num_bins * 2 - 1
if not num_shifts
else num_shifts * 2 # num_shift_block with iterative alignment is 2x, the same, make note!
)
xcorr_matrix = np.zeros((non_rigid_windows.shape[0], num_iter))

# For each window, window the session histograms (`window` is binary)
Expand All @@ -292,12 +291,12 @@ def compute_histogram_crosscorrelation(
window_i = windowed_histogram_i - np.mean(windowed_histogram_i, axis=1)[:, np.newaxis]
window_j = windowed_histogram_j - np.mean(windowed_histogram_j, axis=1)[:, np.newaxis]

xcorr = np.zeros(num_iter)
for idx, shift in enumerate(range(-num_iter // 2, num_iter // 2)):
xcorr = np.zeros(num_iter + 1)

for idx, shift in enumerate(shifts_array):
shifted_i = shift_array_fill_zeros(window_i, shift)

xcorr[idx] = np.correlate(shifted_i.flatten(), window_j.flatten())

else:
# For a 1D histogram, compute the full cross-correlation and
# window the desired shifts ( this is faster than manual looping).
Expand All @@ -315,11 +314,6 @@ def compute_histogram_crosscorrelation(

xcorr_matrix[win_idx, :] = xcorr

if num_shifts:
shift_center_bin = num_shifts
else:
shift_center_bin = center_bin

# Smooth the cross-correlations across the bins
if smoothing_sigma_bin:
xcorr_matrix = gaussian_filter(xcorr_matrix, smoothing_sigma_bin, axes=1)
Expand All @@ -328,36 +322,67 @@ def compute_histogram_crosscorrelation(
if num_windows > 1 and smoothing_sigma_window:
xcorr_matrix = gaussian_filter(xcorr_matrix, smoothing_sigma_window, axes=0)

shifts_array = np.arange(-(num_iter // 2), num_iter // 2 + 1) # TODO: double check
# Upsample the cross-correlation
if interpolate:
shifts = np.arange(xcorr_matrix.shape[1])
shifts_upsampled = np.linspace(shifts[0], shifts[-1], shifts.size * interp_factor)

# shifts = np.arange(xcorr_matrix.shape[1])
shifts_upsampled = np.linspace(shifts_array[0], shifts_array[-1], shifts_array.size * interp_factor)

K = kriging_kernel(
np.c_[np.ones_like(shifts), shifts],
np.c_[np.ones_like(shifts_array), shifts_array],
np.c_[np.ones_like(shifts_upsampled), shifts_upsampled],
kriging_sigma,
kriging_p,
kriging_d,
)
xcorr_matrix = np.matmul(xcorr_matrix, K, axes=[(-2, -1), (-2, -1), (-2, -1)])

xcorr_peak = np.argmax(xcorr_matrix, axis=1) / interp_factor
# breakpoint()

xcorr_matrix_old = np.matmul(xcorr_matrix, K, axes=[(-2, -1), (-2, -1), (-2, -1)])
xcorr_matrix_ = np.zeros(
(xcorr_matrix.shape[0], shifts_upsampled.size)
) # TODO: check in nonlinear case
for i_ in range(xcorr_matrix.shape[0]):
xcorr_matrix_[i_, :] = np.matmul(xcorr_matrix[i_, :], K)

# breakpoint()

plt.plot(shifts_array, xcorr_matrix.T)
plt.show
plt.plot(shifts_upsampled, xcorr_matrix_.T)
plt.show()

xcorr_matrix = xcorr_matrix_

# plt.plot(xcorr_matrix.T)
# plt.plot(xcorr_matrix_old.T)
# plt.show()
#

xcorr_peak = np.argmax(xcorr_matrix, axis=1)
shift = shifts_upsampled[xcorr_peak]

# breakpoint()

else:
xcorr_peak = np.argmax(xcorr_matrix, axis=1)
shift = shifts_array[xcorr_peak]

# Caclulate and save the shift for session i to j
shift = xcorr_peak - shift_center_bin
# x=i;y=j
# breakpoint()
shift_matrix[i, j, :] = shift

breakpoint()

# As xcorr shifts are symmetric, the shift matrix is skew symmetric, so fill
# the (empty) lower triangular with the negative (already computed) upper triangular to save computation
for k in range(shift_matrix.shape[2]):
lower_i, lower_j = np.tril_indices_from(shift_matrix[:, :, k], k=-1)
upper_i, upper_j = np.triu_indices_from(shift_matrix[:, :, k], k=1)
shift_matrix[lower_i, lower_j, k] = shift_matrix[upper_i, upper_j, k] * -1

return shift_matrix
return shift_matrix, xcorr_matrix


def shift_array_fill_zeros(array: np.ndarray, shift: int) -> np.ndarray:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_estimate_histogram_kwargs() -> dict:
"bin_um": 2,
"method": "chunked_mean",
"chunked_bin_size_s": "estimate",
"log_scale": False,
"log_scale": True,
"depth_smooth_um": None,
"histogram_type": "activity_1d",
"weight_with_amplitude": False,
Expand Down Expand Up @@ -881,8 +881,6 @@ def _compute_session_alignment(
return rigid_shifts, non_rigid_windows, non_rigid_window_centers

# For non-rigid, first shift the histograms according to the rigid shift
shifted_histograms = session_histogram_array.copy()

shifted_histograms = np.zeros_like(session_histogram_array)
for ses_idx, orig_histogram in enumerate(session_histogram_array):

Expand All @@ -892,7 +890,7 @@ def _compute_session_alignment(
shifted_histograms[ses_idx, :] = shifted_histogram

# Then compute the nonrigid shifts
nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation(
nonrigid_session_offsets_matrix, _ = alignment_utils.compute_histogram_crosscorrelation(
shifted_histograms, non_rigid_windows, num_shifts=num_shifts_block, **compute_alignment_kwargs
)
non_rigid_shifts = alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix)
Expand Down Expand Up @@ -940,7 +938,7 @@ def _estimate_rigid_alignment(

rigid_window = np.ones(session_histogram_array.shape[1])[np.newaxis, :]

rigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation(
rigid_session_offsets_matrix, _ = alignment_utils.compute_histogram_crosscorrelation(
session_histogram_array,
rigid_window,
num_shifts=num_shifts,
Expand Down
113 changes: 82 additions & 31 deletions src/spikeinterface/preprocessing/tests/test_inter_session_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ def test_recording_1(self):
############################################################################

# TEST 1D AND 2D HERE
@pytest.mark.parametrize("histogram_type", ["activity_2d"]) # "activity_1d"
# TODO: test shift blocks...
@pytest.mark.parametrize("histogram_type", ["activity_1d", "activity_2d"]) # "activity_1d" "activity_2d"
def test_align_sessions_finds_correct_shifts(self, test_recording_1, histogram_type):
""" """
recordings_list, _, peaks_list, peak_locations_list = test_recording_1
Expand All @@ -57,8 +58,9 @@ def test_align_sessions_finds_correct_shifts(self, test_recording_1, histogram_t
compute_alignment_kwargs["smoothing_sigma_window"] = None

estimate_histogram_kwargs = session_alignment.get_estimate_histogram_kwargs()
estimate_histogram_kwargs["bin_um"] = 0.5
estimate_histogram_kwargs["bin_um"] = 2
estimate_histogram_kwargs["histogram_type"] = histogram_type
estimate_histogram_kwargs["log_scale"] = True

for mode, expected in zip(
["to_session_1", "to_session_2", "to_session_3", "to_middle"],
Expand All @@ -78,32 +80,7 @@ def test_align_sessions_finds_correct_shifts(self, test_recording_1, histogram_t
estimate_histogram_kwargs=estimate_histogram_kwargs,
)

# assert np.allclose(expected, extra_info["shifts_array"].squeeze(), rtol=0, atol=1.5)

from spikeinterface.widgets import plot_session_alignment, plot_activity_histogram_2d
import matplotlib.pyplot as plt

corr_peaks_list, corr_peak_loc_list = session_alignment.compute_peaks_locations_for_session_alignment(
corrected_recordings_list,
detect_kwargs={"method": "locally_exclusive"},
localize_peaks_kwargs={"method": "grid_convolution"},
)

plot_session_alignment(
corrected_recordings_list,
corr_peaks_list,
corr_peak_loc_list,
extra_info["spatial_bin_centers"],
**extra_info["corrected"],
)
plt.show()

plot_activity_histogram_2d(
extra_info["session_histogram_list"],
extra_info["spatial_bin_centers"],
extra_info["corrected"]["session_histogram_list"],
)
plt.show()
assert np.allclose(expected, extra_info["shifts_array"].squeeze(), rtol=0, atol=0.02)

# plot_session_alignment
# recordings_list: list[BaseRecording],
Expand All @@ -124,8 +101,8 @@ def test_align_sessions_finds_correct_shifts(self, test_recording_1, histogram_t

rows, cols = np.triu_indices(len(new_histograms), k=1)
assert np.all(
np.abs(np.corrcoef(new_histograms)[rows, cols])
- np.abs(np.corrcoef(extra_info["session_histogram_list"])[rows, cols])
np.abs(np.corrcoef([hist.flatten() for hist in new_histograms])[rows, cols])
- np.abs(np.corrcoef([hist.flatten() for hist in extra_info["session_histogram_list"]])[rows, cols])
>= 0
)

Expand Down Expand Up @@ -766,7 +743,81 @@ def estimate_chunk_size(self):
def test_akima_interpolate_nonrigid_shifts(self):
pass

def test_compute_histogram_crosscorrelation(self):
# TODO:
@pytest.mark.parametrize("shifts", [3]) # -2 #test and off and even shift
def test_compute_histogram_crosscorrelation(self, shifts):

even_hist = np.array([0, 0, 1, 1, 0, 1, 0, 1])
odd_hist = np.array([1, 0, 1, 1, 1, 0])

even_hist_shift = alignment_utils.shift_array_fill_zeros(even_hist, shifts)
odd_hist_shift = alignment_utils.shift_array_fill_zeros(odd_hist, shifts)

session_histogram_list = np.vstack([even_hist, even_hist_shift])

# Ut oh, is interpolate broken?
interpolate = True # or False
interp_factor = 50

shifts_matrix, xcorr_matrix_unsmoothed = alignment_utils.compute_histogram_crosscorrelation(
session_histogram_list,
non_rigid_windows=np.ones((1, even_hist.size)), # TODO: test non rigid!
num_shifts=None, # TODO: test num shifts!
interpolate=interpolate,
interp_factor=interp_factor,
kriging_sigma=0.5,
kriging_p=2,
kriging_d=2,
smoothing_sigma_bin=None,
smoothing_sigma_window=None,
)
breakpoint()
assert alignment_utils.get_shifts_from_session_matrix("to_session_1", shifts_matrix)[-1] == -shifts

num_shifts = even_hist.size * 2 - 1
if interpolate:
assert xcorr_matrix_unsmoothed.shape[1] == num_shifts * interp_factor
else:
assert xcorr_matrix_unsmoothed.shape[1] == num_shifts

shifts_matrix_smoothed_bin, xcorr_matrix_smoothed_bin = alignment_utils.compute_histogram_crosscorrelation(
session_histogram_list,
non_rigid_windows=np.ones((1, even_hist.size)), # TODO: test non rigid!
num_shifts=None, # TODO: test num shifts!
interpolate=interpolate,
interp_factor=interp_factor,
kriging_sigma=1,
kriging_p=1,
kriging_d=1,
smoothing_sigma_bin=0.5,
smoothing_sigma_window=None,
)

shifts_matrix_smoothed_window, xcorr_matrix_smoothed_window = (
alignment_utils.compute_histogram_crosscorrelation(
session_histogram_list,
non_rigid_windows=np.ones((1, even_hist.size)), # TODO: test non rigid!
num_shifts=None, # TODO: test num shifts!
interpolate=interpolate,
interp_factor=interp_factor,
kriging_sigma=1,
kriging_p=1,
kriging_d=1,
smoothing_sigma_bin=None,
smoothing_sigma_window=0.5,
)
)

# make a histogram (odd and even length)
# shift it (odd and even shift)
# check smoothing across bins and time
# check interpolate
# thats it!

def test_compute_histogram_crosscorrelation_gaussian_filter_kwargs(self): ## TODO: incorporate these above
pass

def test_compute_histogram_crosscorrelation_kriging_kwargs(self):
pass

###########################################################################
Expand Down

0 comments on commit 01adece

Please sign in to comment.