Skip to content

Commit

Permalink
With additional windowing.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Dec 24, 2024
1 parent e01bb3a commit e77d06b
Show file tree
Hide file tree
Showing 11 changed files with 150 additions and 83 deletions.
Binary file modified debugging/histogram1.npy
Binary file not shown.
Binary file modified debugging/histogram2.npy
Binary file not shown.
Binary file modified debugging/histogram3.npy
Binary file not shown.
Binary file modified debugging/peak_locs_1.npy
Binary file not shown.
Binary file modified debugging/peak_locs_2.npy
Binary file not shown.
Binary file modified debugging/peak_locs_3.npy
Binary file not shown.
Binary file modified debugging/peaks_1.npy
Binary file not shown.
Binary file modified debugging/peaks_2.npy
Binary file not shown.
Binary file modified debugging/peaks_3.npy
Binary file not shown.
4 changes: 2 additions & 2 deletions debugging/playing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
# --------------------------------------------------------------------------------------

recordings_list, _ = generate_session_displacement_recordings(
num_units=120,
num_units=20,
recording_durations=[400, 400, 400],
recording_shifts=((0, 0), (0, -300), (0, 200)), # TODO: can see how well this is recaptured by comparing the displacements to the known displacement + gradient
recording_shifts=((0, 0), (0, -200), (0, 100)), # TODO: can see how well this is recaptured by comparing the displacements to the known displacement + gradient
non_rigid_gradient=None, # 0.1,
seed=2, # 52
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -942,60 +942,90 @@ def cross_correlate_with_scale(x, signal1_blanked, signal2_blanked, thr=100, plo
return best_displacements


def cross_correlate_with_scaled_fixed(x_orig, new_positions, fixed_windows, signal1_blanked, signal2_blanked, thr, round_, plot):
def _interp(signal, current_coords, orig_coords):
interp_f = scipy.interpolate.interp1d(
current_coords, signal, fill_value=0.0, bounds_error=False, kind=INTERP
)
return interp_f(orig_coords)

# Be extremely careful, if the fixed_indices is not correct, everything will be messed
# up and it will be very hard to detect!
# the windowing will just adjust blanked_mask and fixed_windows...
# For windowing, it would be easier just to pass the sinals directly and get displacements back
# 1) project 2) window 3) correlate and add shifts
def cross_correlate_with_scaled_fixed(x_orig, new_positions, blanked_mask, fixed_windows, histogram_array_blanked, i, j, thr, round_, plot):
"""
"""
best_correlation = 0
best_positions = np.zeros_like(signal1_blanked)
best_positions = np.zeros(histogram_array_blanked.shape[1])

for scale in np.r_[np.linspace(0.85, 1, 10), np.linspace(1, 1.15, 10)]: # TODO: double 1
histogram_array_blanked = histogram_array_blanked.copy()
for i_ in range(histogram_array_blanked.shape[0]):
histogram_array_blanked[i_, fixed_windows[i_, :]] = 0

nonzero = np.where(signal1_blanked > 0)[0]
if not np.any(nonzero):
continue
# We are comparing i to j
signal2 = _interp(histogram_array_blanked[j, :], new_positions[j, :], x_orig)

midpoint = nonzero[0] + np.ptp(nonzero) / 2
x_scale = (x_orig - midpoint) * scale + midpoint
# just need to do this to find the scaling midpoint...
signal1_for_midpoint = _interp(histogram_array_blanked[i, :], new_positions[i, :], x_orig)
signal1_for_midpoint[blanked_mask] = 0
nonzero = np.where(signal1_for_midpoint > 0)[0]
if not np.any(nonzero):
return new_positions[i, :] # no change
midpoint = nonzero[0] + np.ptp(nonzero) / 2

for sh in np.arange(-thr, thr): # TODO: we are off by one here
best_s1 = None

x_shift = x_scale - sh
for scale in np.r_[np.linspace(0.85, 1, 10), np.linspace(1, 1.15, 10)]: # TODO: double 1

x_shift_ = x_orig.copy() # TODO
x_shift_[~fixed_windows] = x_shift[~fixed_windows]
x_shift = x_shift_
x_scale = (new_positions[i, :] - midpoint) * scale + midpoint

interp_f = scipy.interpolate.interp1d(
x_shift, signal1_blanked, fill_value=0.0, bounds_error=False, kind=INTERP
)
shift_signal1_blanked = interp_f(x_orig)
for sh in np.arange(-thr, thr): # TODO: we are off by one here (note we go 2x thr here!)

putative_new_x = x_scale - sh

# x_shift_ = new_positions[i, :].copy() # TODO
# x_shift_[~fixed_windows[i, :]] = putative_new_x[~fixed_windows[i, :]]
# putative_new_x = x_shift_

shift_signal1 = _interp(histogram_array_blanked[i, :], putative_new_x, x_orig)

from scipy.ndimage import gaussian_filter

s1 = shift_signal1.copy()
s2 = signal2.copy()
s1[blanked_mask] = 0
s2[blanked_mask] = 0

corr_value = _correlate(
gaussian_filter(shift_signal1_blanked, 0.5), # TODO: need to adapt to kinetics of the data
gaussian_filter(signal2_blanked, 0.5)
s1, # shift_signal1[~blanked_mask], # gaussian_filter(histogram_array_blanked[i, :], 0.5), # TODO: need to adapt to kinetics of the data
s2, # signal2[~blanked_mask], # gaussian_filter(signal2_blanked, 0.5)
)

corr_value *= 1 - np.abs(sh - 0) / thr
percent_diff = np.exp(-(np.abs(1 - np.sum(s1) / np.sum(histogram_array_blanked[i, :]))) ** 2 / 1.2 ** 2) ** 12
corr_value *= percent_diff # heavily penalise interpolation errors

# corr_value *= 1 - np.abs(sh - 0) / thr

if np.isnan(corr_value) or corr_value < 0:
corr_value = 0

if corr_value > best_correlation:
best_positions = x_shift
best_positions = putative_new_x
best_correlation = corr_value
best_s1 = s1

# plt.plot(shift_signal1_blanked)
# plt.plot(signal2_blanked)
# plt.title(corr_value)
# plt.draw()
# plt.pause(0.1)
# plt.clf()

new_positions = new_positions + (best_positions - x_orig)
if round_ > 0:
plt.plot(s1)
plt.plot(s2)
plt.title(corr_value)
plt.draw()
plt.pause(0.1)
plt.clf()

return new_positions, best_correlation
new_pos = new_positions[i, :].copy()
new_pos[~fixed_windows[i, :]] = best_positions[~fixed_windows[i, :]]
return new_pos


def cross_correlate_combined_loss(x_orig, new_positions, fixed_windows, orig_blank_histograms, interp_blanked_histograms, thr, round_, plot):
Expand Down Expand Up @@ -1092,8 +1122,8 @@ def cross_correlate_combined_loss(x_orig, new_positions, fixed_windows, orig_bla

def get_threshold_array(num_bins, windows):
num_points = len(windows)
max = num_bins
min = windows[0].size // 2
max = num_bins // 2
min = windows[0].size // 5

k = -np.log(min / (max - min)) / (num_points - 1)
x_values = np.arange(num_points)
Expand All @@ -1118,26 +1148,11 @@ def get_shifts_union(histogram_array, windows, plot=True):

loss = 0

for round in range(len(windows)):

# print("ROUND", round)

# thr = all_thr[round]
blanked_mask = np.zeros(histogram_array.shape[1]).astype(bool)

# shift_matrix = np.zeros((histogram_array.shape[0], histogram_array.shape[0], histogram_array.shape[1]))
# correlations = np.zeros((histogram_array.shape[0], histogram_array.shape[0]))

# histogram_array_blanked_interp = np.zeros_like(histogram_array_blanked)
# for i in range(histogram_array_blanked.shape[0]):
# interpf = scipy.interpolate.interp1d(
# new_positions[i, :], histogram_array_blanked[i, :], fill_value=0.0, bounds_error=False,
# kind=INTERP
# )
# histogram_array_blanked_interp[i, :] = interpf(x_orig)
for round in range(len(windows)):

# print("BEFORE")
# plt.plot(histogram_array_blanked_interp.T)
# plt.show()
thr = all_thr[round]

# find contigious window ids
diffs = np.diff(windows_to_run)
Expand All @@ -1146,28 +1161,74 @@ def get_shifts_union(histogram_array, windows, plot=True):

for block in all_blocks:

print("BLOCK", block)

window_indexes = []
block_bools = np.ones(histogram_array.shape[1]).astype(bool)
for block_idx in block:
block_bools[windows[block_idx]] = False
window_indexes.append(windows[block_idx])
window_indexes = np.hstack(window_indexes)

if round< 100: # TODO: maybe some function of num windows?
shift_matrix = np.zeros((histogram_array.shape[0], histogram_array.shape[0], histogram_array.shape[1]))


y = np.zeros_like(histogram_array_blanked)
for i in range(histogram_array_blanked.shape[0]):
interpf = scipy.interpolate.interp1d(
new_positions[i, :], histogram_array_blanked[i, :], fill_value=0.0,
bounds_error=False,
kind=INTERP
)
y[i, :] = interpf(x_orig)

print("BEFORE")
plt.plot(y.T)
plt.show()


if round == 0: # TODO: maybe some function of num windows?
for i in range(histogram_array.shape[0]):
for j in range(histogram_array.shape[0]):

fixed_windows_round = np.logical_or(fixed_windows[i, :], block_bools)
fixed_windows_orig = np.ones_like(histogram_array).astype(bool)

histogram_array_blanked_interp_i = histogram_array_blanked_interp[i, :].copy()
histogram_array_blanked_interp_i[fixed_windows_round] = 0
# TODO: DIRECT COPY!!!
if window_indexes[0] == 0:
window_min = np.min([new_positions[i, :], new_positions[j, :]]) - 1
else:
window_min = window_indexes[0]

histogram_array_blanked_interp_j = histogram_array_blanked_interp[j, :].copy()
histogram_array_blanked_interp_j[np.logical_or(fixed_windows[j, :], block_bools)] = 0
if window_indexes[-1] == x_orig[-1]:
window_max = np.max([new_positions[i, :], new_positions[j, :]]) + 1
else:
window_max = window_indexes[-1]

shift_matrix[i, j, :], correlations[i, j] = cross_correlate_with_scaled_fixed(
x_orig, new_positions[i, :], fixed_windows_round, histogram_array_blanked_interp_i, histogram_array_blanked_interp_j, thr=thr, round_=round, plot=plot # , plot=False, round=round
)
fixed_indices = np.where(np.logical_and(new_positions[i, :] >= window_min,
new_positions[i, :] <= window_max))
fixed_windows_orig[i, fixed_indices] = False # TODO: CAREUFULLY CHECK MAPPING

fixed_indices = np.where(np.logical_and(new_positions[j, :] >= window_min,
new_positions[j, :] <= window_max))
fixed_windows_orig[j, fixed_indices] = False # TODO: CAREUFULLY CHECK MAPPING
# DIRECT COPY END

# from the window, and from the block
fixed_windows_round = np.logical_or(fixed_windows, fixed_windows_orig) # this is in orig space, different for all.

blanked_mask_round = np.logical_or(blanked_mask, block_bools) # this is interp space, same for all

shift_matrix[i, j, :] = cross_correlate_with_scaled_fixed(
x_orig, new_positions, blanked_mask, fixed_windows, histogram_array_blanked, i, j, thr=thr, round_=round, plot=plot
)

this_round_new_positions = np.mean(shift_matrix, axis=1) # TODO: FIX! TODO: these are not displacements






else:
# Not bad for evne no blanking!
fixed_windows_round = block_bools #np.logical_or(fixed_windows, block_bools)
Expand Down Expand Up @@ -1218,21 +1279,24 @@ def get_shifts_union(histogram_array, windows, plot=True):

this_round_new_positions = cross_correlate_combined_loss(x_orig, new_positions, fixed_windows, histogram_array_blanked_new, histogram_array_blanked_interp_new, thr, round, plot=True)



histogram_array_interp = np.zeros_like(histogram_array_blanked)
for i in range(histogram_array_blanked.shape[0]):
interpf = scipy.interpolate.interp1d(
this_round_new_positions[i, :], histogram_array_blanked[i, :], fill_value=0.0, bounds_error=False,
kind=INTERP
)
histogram_array_interp[i, :] = interpf(x_orig)
histogram_array_interp[i, blanked_mask] = 0


print("INTERPED")
plt.plot(histogram_array_interp.T)
plt.show()

window_corrs = np.empty(len(windows)) # okay need to increase but shouldn't fail for one window
for i, idx in enumerate(windows):
window_corrs[i] = np.sum(np.triu(np.cov(histogram_array_interp[:, idx]), k=1)) # det doesn't work very well, too small
window_corrs[i] = np.sum(np.triu(np.corrcoef(histogram_array_interp[:, idx]), k=1)) # det doesn't work very well, too small
# plt.plot(histogram_array_interp[:, idx].T)
# plt.show()
window_corrs[np.isnan(window_corrs)] = 0
window_corrs = np.abs(window_corrs)

Expand All @@ -1253,29 +1317,32 @@ def get_shifts_union(histogram_array, windows, plot=True):
window_max = windows[max_window][-1]

fixed_indices = np.where(np.logical_and(this_round_new_positions[i, :] >= window_min, this_round_new_positions[i, :] <= window_max))
fixed_windows[i, fixed_indices] = True
# this is in original space, the new_positions are also in original space (x -> new_positions)
histogram_array_blanked[i, fixed_indices] = 0 # this is in interpolated space
window_corrs[max_window] = 0
windows_to_run = np.delete(windows_to_run, np.where(windows_to_run == max_window)[0])
fixed_windows[i, fixed_indices] = True # TODO: CAREUFULLY CHECK MAPPING

blanked_mask[windows[max_window]] = True
window_corrs[max_window] = 0
windows_to_run = np.delete(windows_to_run, np.where(windows_to_run == max_window)[0])

# if round == 1 or not np.any(window_corrs > 0.1): # TODO: definately keep a running track of the xcorr and quit when it gets worse or doesn't improve. See how this example does across the rounds
# break

final = np.zeros_like(histogram_array_blanked)
for i in range(histogram_array_blanked.shape[0]):
interpf = scipy.interpolate.interp1d(
this_round_new_positions[i], histogram_array[i, :], fill_value=0.0, bounds_error=False, kind=INTERP
)
final[i, :] = interpf(x_orig)
# final = np.zeros_like(histogram_array_blanked)
# for i in range(histogram_array_blanked.shape[0]):
# interpf = scipy.interpolate.interp1d(
# this_round_new_positions[i], histogram_array[i, :], fill_value=0.0, bounds_error=False, kind=INTERP
# )
# final[i, :] = interpf(x_orig)

loss_ = 0 # okay need to increase but shouldn't fail for one window
loss_ += np.sum(
np.triu(np.cov(histogram_array_interp), k=1)
)
# loss_ = 0 # okay need to increase but shouldn't fail for one window
# loss_ += np.sum(
# np.triu(np.cov(histogram_array_interp), k=1)
# )

new_positions = this_round_new_positions # TODO

if not np.any(window_corrs > 0.01): # TODO: KEY <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< HADNLE THIS FIRST
break

# if round == 1:
# break
# if round == 0:
Expand All @@ -1287,9 +1354,9 @@ def get_shifts_union(histogram_array, windows, plot=True):
# new_positions = this_round_new_positions # TODO
# loss = loss_

print("FINAL")
plt.plot(final.T)
plt.show()
# print("FINAL")
# plt.plot(final.T)
# plt.show()

# going to have to check the improvement in fit for every round and
# if the round does not add much to the loss, then don't make the
Expand Down

0 comments on commit e77d06b

Please sign in to comment.