Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 20, 2024
1 parent cf16e06 commit 70db8e2
Showing 1 changed file with 28 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

INTERP = "linear"


def get_estimate_histogram_kwargs() -> dict:
"""
A dictionary controlling how the histogram for each session is
Expand Down Expand Up @@ -857,16 +858,15 @@ def cross_correlate(sig1, sig2, thr=None):

return shift


def _correlate(signal1, signal2):

corr_value = (
np.corrcoef(signal1,
signal2)[0, 1]
)
corr_value = np.corrcoef(signal1, signal2)[0, 1]
if False:
corr_value = np.correlate(signal1 - np.mean(signal1), signal2 - np.mean(signal2)) / signal1.size
return corr_value


def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plot=True, round=0):
""" """
best_correlation = 0
Expand All @@ -885,11 +885,11 @@ def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plo
midpoint = nonzero[0] + np.ptp(nonzero) / 2
x_scale = (x - midpoint) * scale + midpoint

# interp_f = scipy.interpolate.interp1d(
# x_scale, signa11_blanked, fill_value=0.0, bounds_error=False
# ) # TODO: try cubic etc... or Kriging
# interp_f = scipy.interpolate.interp1d(
# x_scale, signa11_blanked, fill_value=0.0, bounds_error=False
# ) # TODO: try cubic etc... or Kriging

# scaled_func = interp_f(x)
# scaled_func = interp_f(x)

for sh in np.arange(-thr, thr): # TODO: we are off by one here

Expand All @@ -904,10 +904,7 @@ def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plo

from scipy.ndimage import gaussian_filter

corr_value = _correlate(
gaussian_filter(shift_signal1_blanked, 1.5),
gaussian_filter(signal2_blanked, 1.5)
)
corr_value = _correlate(gaussian_filter(shift_signal1_blanked, 1.5), gaussian_filter(signal2_blanked, 1.5))

if np.isnan(corr_value) or corr_value < 0:
corr_value = 0
Expand All @@ -916,15 +913,15 @@ def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plo
best_displacements = x_shift
best_correlation = corr_value

if False and plot and round == 1 and (corr_value > 0.3): # and plot and np.abs(sh) < 25:
if False and plot and round == 1 and (corr_value > 0.3): # and plot and np.abs(sh) < 25:
print("3")
plt.plot(shift_signal1_blanked)
plt.plot(signal2_blanked)
plt.title(corr_value)
plt.show()
# plt.draw()
# plt.pause(0.1)
# plt.clf()
# plt.draw()
# plt.pause(0.1)
# plt.clf()
if plot:
print("DONE)")
plt.plot(signa11_blanked)
Expand Down Expand Up @@ -974,15 +971,15 @@ def get_shifts(signal1, signal2, windows, plot=True):
x_values = np.arange(num_points)
all_thr = (max - min) * np.exp(-1.2 * x_values) + min

# all_thr = np.linspace(max, min, len(windows))
# all_thr = np.linspace(max, min, len(windows))

for round in range(num_points):

thr = all_thr[round] # TODO: optimise this somehow? go back and forth?
# if round < 2:
# thr = np.where(best_displacements == 0)[0].size // 2
# else:
# thr = windows[0].size // 5
# if round < 2:
# thr = np.where(best_displacements == 0)[0].size // 2
# else:
# thr = windows[0].size // 5

print(f"ROUND: {round}, THR: {thr}")
displacements = cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=thr, plot=plot, round=round)
Expand All @@ -998,7 +995,9 @@ def get_shifts(signal1, signal2, windows, plot=True):

window_corrs[np.isnan(window_corrs)] = 0
if np.any(window_corrs):
max_window = np.argmax(np.abs(window_corrs)) # TODO: cutoff! TODO: note sure about the abs, very weird edge case...
max_window = np.argmax(
np.abs(window_corrs)
) # TODO: cutoff! TODO: note sure about the abs, very weird edge case...

if False:
small_shift = cross_correlate(
Expand All @@ -1014,7 +1013,6 @@ def get_shifts(signal1, signal2, windows, plot=True):

best_displacements[windows[max_window]] = displacements[windows[max_window]]


x = displacements

signa11_blanked[windows[max_window]] = 0
Expand All @@ -1026,7 +1024,9 @@ def get_shifts(signal1, signal2, windows, plot=True):
plt.plot(signal2)
plt.show()

interpf = scipy.interpolate.interp1d(best_displacements, signal1, fill_value=0.0, bounds_error=False, kind=INTERP)
interpf = scipy.interpolate.interp1d(
best_displacements, signal1, fill_value=0.0, bounds_error=False, kind=INTERP
)
final = interpf(x_orig)
plt.plot(final)
plt.plot(signal2)
Expand Down Expand Up @@ -1134,7 +1134,7 @@ def _compute_session_alignment(
for i in range(shifted_histograms.shape[0]):
for j in range(shifted_histograms.shape[0]):

plot_ = False # i == 0 and j == 1
plot_ = False # i == 0 and j == 1
print("I", i)
print("J", j)

Expand All @@ -1155,7 +1155,9 @@ def _compute_session_alignment(
# nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation(
# shifted_histograms, non_rigid_windows, **compute_alignment_kwargs
# )
non_rigid_shifts = nonrigid_session_offsets_matrix[0, :, :] # alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix)
non_rigid_shifts = nonrigid_session_offsets_matrix[
0, :, :
] # alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix)
non_rigid_window_centers = spatial_bin_centers
shifts = non_rigid_shifts

Expand Down

0 comments on commit 70db8e2

Please sign in to comment.