Skip to content

Commit

Permalink
playing with nonrigid.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Dec 20, 2024
1 parent 45502ce commit cf16e06
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 111 deletions.
19 changes: 10 additions & 9 deletions debugging/playing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@

recordings_list, _ = generate_session_displacement_recordings(
num_units=20,
recording_durations=[400, 400, 400],
recording_shifts=((0, 0), (0, 200), (0, -125)),
non_rigid_gradient=0.005,
seed=52,
recording_durations=[1000, 1000, 1000],
recording_shifts=((0, 0), (0, -300), (0, 450)), # TODO: can see how well this is recaptured by comparing the displacements to the known displacement + gradient
non_rigid_gradient=0.2,
seed=54, # 52
)
if False:
import numpy as np


recordings_list = [
si.read_zarr(r"C:\Users\Joe\Downloads\M25_D18_2024-11-05_12-38-28_VR1.zarr\M25_D18_2024-11-05_12-38-28_VR1.zarr"),
si.read_zarr(r"C:\Users\Joe\Downloads\M25_D18_2024-11-05_12-08-47_OF1.zarr\M25_D18_2024-11-05_12-08-47_OF1.zarr"),
Expand Down Expand Up @@ -63,7 +62,7 @@
np.save("peak_locs_3.npy", peak_locations_list[2])

# if False:
peaks_list = [np.load("peaks_1.npy"), np.load("peaks_2.npy"), np.load("peaks_3.npy")]
peaks_list = [np.load("peaks_1.npy"), np.load("peaks_2.npy") , np.load("peaks_3.npy")]
peak_locations_list = [np.load("peak_locs_1.npy"), np.load("peak_locs_2.npy"), np.load("peak_locs_3.npy")]

# --------------------------------------------------------------------------------------
Expand All @@ -83,9 +82,9 @@
estimate_histogram_kwargs = session_alignment.get_estimate_histogram_kwargs()
estimate_histogram_kwargs["method"] = "chunked_median"
estimate_histogram_kwargs["histogram_type"] = "activity_1d" # TODO: investigate this case thoroughly
estimate_histogram_kwargs["bin_um"] = 0.5
estimate_histogram_kwargs["bin_um"] = 5
estimate_histogram_kwargs["log_scale"] = True
estimate_histogram_kwargs["weight_with_amplitude"] = False
estimate_histogram_kwargs["weight_with_amplitude"] = True

compute_alignment_kwargs = session_alignment.get_compute_alignment_kwargs()
compute_alignment_kwargs["num_shifts_block"] = 300
Expand All @@ -94,10 +93,12 @@
recordings_list,
peaks_list,
peak_locations_list,
alignment_order="to_session_2", # "to_session_X" or "to_middle"
alignment_order="to_session_1", # "to_session_X" or "to_middle"
non_rigid_window_kwargs=non_rigid_window_kwargs,
estimate_histogram_kwargs=estimate_histogram_kwargs,
)
si.plot_traces(recordings_list[0], mode="line", time_range=(0, 1))
plt.show()

# TODO: nonlinear is not working well 'to middle', investigate
# TODO: also finalise the estimation of bin number of nonrigid.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from spikeinterface.preprocessing.motion import run_peak_detection_pipeline_node
import copy
import scipy
import matplotlib.pyplot as plt

INTERP = "linear"

def get_estimate_histogram_kwargs() -> dict:
"""
Expand Down Expand Up @@ -855,8 +857,17 @@ def cross_correlate(sig1, sig2, thr=None):

return shift

def _correlate(signal1, signal2):

def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plot=True):
corr_value = (
np.corrcoef(signal1,
signal2)[0, 1]
)
if False:
corr_value = np.correlate(signal1 - np.mean(signal1), signal2 - np.mean(signal2)) / signal1.size
return corr_value

def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plot=True, round=0):
""" """
best_correlation = 0
best_displacements = np.zeros_like(signa11_blanked)
Expand All @@ -865,7 +876,7 @@ def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plo

xcorr = []

for scale in np.linspace(0.85, 1.15, 10):
for scale in np.r_[np.linspace(0.85, 1, 10), np.linspace(1, 1.15, 10)]: # TODO: double 1

nonzero = np.where(signa11_blanked > 0)[0]
if not np.any(nonzero):
Expand All @@ -874,75 +885,63 @@ def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plo
midpoint = nonzero[0] + np.ptp(nonzero) / 2
x_scale = (x - midpoint) * scale + midpoint

interp_f = scipy.interpolate.interp1d(
x_scale, signa11_blanked, fill_value=0.0, bounds_error=False
) # TODO: try cubic etc... or Kriging

scaled_func = interp_f(x)
# interp_f = scipy.interpolate.interp1d(
# x_scale, signa11_blanked, fill_value=0.0, bounds_error=False
# ) # TODO: try cubic etc... or Kriging

# plt.plot(signa11_blanked)
# plt.plot(scaled_func)
# plt.show()

# breakpoint()
# scaled_func = interp_f(x)

for sh in np.arange(-thr, thr): # TODO: we are off by one here

shift_signal1_blanked = alignment_utils.shift_array_fill_zeros(scaled_func, sh)
# shift_signal1_blanked = alignment_utils.shift_array_fill_zeros(scaled_func, sh)

x_shift = x_scale - sh # TODO: rename
x_shift = x_scale - sh

# is this pull back?
# interp_f = scipy.interpolate.interp1d(xs, shift_signal1_blanked, fill_value=0.0, bounds_error=False) # TODO: try cubic etc... or Kriging
interp_f = scipy.interpolate.interp1d(
x_shift, signa11_blanked, fill_value=0.0, bounds_error=False, kind=INTERP
)
shift_signal1_blanked = interp_f(x)

# scaled_func = interp_f(x_shift)
from scipy.ndimage import gaussian_filter

corr_value = (
np.correlate(
shift_signal1_blanked - np.mean(shift_signal1_blanked),
signal2_blanked - np.mean(signal2_blanked),
)
/ signa11_blanked.size
corr_value = _correlate(
gaussian_filter(shift_signal1_blanked, 1.5),
gaussian_filter(signal2_blanked, 1.5)
)

if np.isnan(corr_value) or corr_value < 0:
corr_value = 0

if corr_value > best_correlation:
best_displacements = x_shift
best_correlation = corr_value

if False and np.abs(sh) == 1:
print(corr_value)

plt.plot(shift_signal1_blanked)
plt.plot(signal2_blanked)
plt.show()
# plt.draw() # Draw the updated figure
# plt.pause(0.1) # Pause for 0.5 seconds before updating
# plt.clf()

# breakpoint()

# xcorr.append(np.max(np.r_[xcorr_scale]))

if False:
xcorr = np.r_[xcorr]
# shift = np.argmax(xcorr) - thr

print("MAX", np.max(xcorr))

if np.max(xcorr) < 0.0001:
shift = 0
else:
shift = np.argmax(xcorr) - thr
if False and plot and round == 1 and (corr_value > 0.3): # and plot and np.abs(sh) < 25:
print("3")
plt.plot(shift_signal1_blanked)
plt.plot(signal2_blanked)
plt.title(corr_value)
plt.show()
# plt.draw()
# plt.pause(0.1)
# plt.clf()
if plot:
print("DONE)")
plt.plot(signa11_blanked)
plt.plot(signal2_blanked)
plt.show()

print("output shift", shift)
interp_f = scipy.interpolate.interp1d(
best_displacements, signa11_blanked, fill_value=0.0, bounds_error=False, kind=INTERP
)
final = interp_f(x)
plt.plot(final)
plt.plot(signal2_blanked)
plt.show()

return best_displacements


# plt.plot(signal1)
# plt.plot(signal2)


def get_shifts(signal1, signal2, windows, plot=True):

import matplotlib.pyplot as plt
Expand All @@ -967,72 +966,71 @@ def get_shifts(signal1, signal2, windows, plot=True):
x = np.arange(signa11_blanked.size)
x_orig = x.copy()

for round in range(len(windows)):
num_points = len(windows)
max = signal1.size // 2
min = windows[0].size // 5

# if round == 0:
# shift = cross_correlate(signa11_blanked, signal2_blanked, thr=100) # for first rigid, do larger!
# else:
displacements = cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=200, plot=False)
k = -np.log(min / (max - min)) / (num_points - 1)
x_values = np.arange(num_points)
all_thr = (max - min) * np.exp(-1.2 * x_values) + min

# breakpoint()
# all_thr = np.linspace(max, min, len(windows))

interpf = scipy.interpolate.interp1d(
displacements, signa11_blanked, fill_value=0.0, bounds_error=False
) # TODO: move away from this indexing sceheme
signa11_blanked = interpf(x)
for round in range(num_points):

# cum_shifts.append(shift)
# print("shift", shift)
thr = all_thr[round] # TODO: optimise this somehow? go back and forth?
# if round < 2:
# thr = np.where(best_displacements == 0)[0].size // 2
# else:
# thr = windows[0].size // 5

# shift the signal1, or use indexing
print(f"ROUND: {round}, THR: {thr}")
displacements = cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=thr, plot=plot, round=round)

# signa11_blanked = shift_array_fill_zeros(signa11_blanked, shift) # INTERP HERE, KRIGING. but will accumulate interpolation errors...

# if plot:
# print("round", round)
# plt.plot(signa11_blanked)
# plt.plot(signal2_blanked)
# plt.show()
interpf = scipy.interpolate.interp1d(
displacements, signa11_blanked, fill_value=0.0, bounds_error=False, kind=INTERP
) # TODO: move away from this indexing sceheme
signa11_blanked = interpf(x)

window_corrs = np.empty(len(windows))
for i, idx in enumerate(windows):
window_corrs[i] = (
np.correlate(
signa11_blanked[idx] - np.mean(signa11_blanked[idx]),
signal2_blanked[idx] - np.mean(signal2_blanked[idx]),
window_corrs[i] = _correlate(signa11_blanked[idx], signal2_blanked[idx])

window_corrs[np.isnan(window_corrs)] = 0
if np.any(window_corrs):
max_window = np.argmax(np.abs(window_corrs)) # TODO: cutoff! TODO: note sure about the abs, very weird edge case...

if False:
small_shift = cross_correlate(
signa11_blanked[windows[max_window]],
signal2_blanked[windows[max_window]],
thr=windows[max_window].size // 2,
)
/ signa11_blanked[idx].size
)
signa11_blanked = alignment_utils.shift_array_fill_zeros(signa11_blanked, small_shift)
segment_shifts[max_window] = np.sum(cum_shifts) + small_shift

max_window = np.argmax(window_corrs) # TODO: cutoff!
if plot:
breakpoint()

if False:
small_shift = cross_correlate(
signa11_blanked[windows[max_window]],
signal2_blanked[windows[max_window]],
thr=windows[max_window].size // 2,
)
signa11_blanked = alignment_utils.shift_array_fill_zeros(signa11_blanked, small_shift)
segment_shifts[max_window] = np.sum(cum_shifts) + small_shift
best_displacements[windows[max_window]] = displacements[windows[max_window]]

best_displacements[windows[max_window]] = displacements[windows[max_window]]

x = displacements

signa11_blanked[windows[max_window]] = 0
signal2_blanked[windows[max_window]] = 0

# TODO: need to carry over displacements!
if plot:
print("FINAL")
plt.plot(signal1)
plt.plot(signal2)
plt.show()

print(best_displacements)
interpf = scipy.interpolate.interp1d(
best_displacements, signal1, fill_value=0.0, bounds_error=False
) # TODO: move away from this indexing sceheme
final = interpf(x_orig)

# plt.plot(final)
# plt.plot(signal2)
# plt.show()
interpf = scipy.interpolate.interp1d(best_displacements, signal1, fill_value=0.0, bounds_error=False, kind=INTERP)
final = interpf(x_orig)
plt.plot(final)
plt.plot(signal2)
plt.show()

return np.floor(best_displacements - x_orig)

Expand Down Expand Up @@ -1136,7 +1134,11 @@ def _compute_session_alignment(
for i in range(shifted_histograms.shape[0]):
for j in range(shifted_histograms.shape[0]):

shifts1 = get_shifts(shifted_histograms[i, :], shifted_histograms[j, :], windows1, plot=True)
plot_ = False # i == 0 and j == 1
print("I", i)
print("J", j)

shifts1 = get_shifts(shifted_histograms[i, :], shifted_histograms[j, :], windows1, plot=plot_)

# shifts2 = get_shifts(shifted_histograms[i, :], shifted_histograms[j, :], windows2)
# shifts = np.empty(shifts1.size + shifts1.size - 1)
Expand All @@ -1153,9 +1155,7 @@ def _compute_session_alignment(
# nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation(
# shifted_histograms, non_rigid_windows, **compute_alignment_kwargs
# )
non_rigid_shifts = nonrigid_session_offsets_matrix[
2, :, :
] # alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix)
non_rigid_shifts = nonrigid_session_offsets_matrix[0, :, :] # alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix)
non_rigid_window_centers = spatial_bin_centers
shifts = non_rigid_shifts

Expand Down

0 comments on commit cf16e06

Please sign in to comment.