diff --git a/debugging/playing.py b/debugging/playing.py index dba45ad5b9..69a60998ad 100644 --- a/debugging/playing.py +++ b/debugging/playing.py @@ -23,15 +23,14 @@ recordings_list, _ = generate_session_displacement_recordings( num_units=20, - recording_durations=[400, 400, 400], - recording_shifts=((0, 0), (0, 200), (0, -125)), - non_rigid_gradient=0.005, - seed=52, + recording_durations=[1000, 1000, 1000], + recording_shifts=((0, 0), (0, -300), (0, 450)), # TODO: can see how well this is recaptured by comparing the displacements to the known displacement + gradient + non_rigid_gradient=0.2, + seed=54, # 52 ) if False: import numpy as np - recordings_list = [ si.read_zarr(r"C:\Users\Joe\Downloads\M25_D18_2024-11-05_12-38-28_VR1.zarr\M25_D18_2024-11-05_12-38-28_VR1.zarr"), si.read_zarr(r"C:\Users\Joe\Downloads\M25_D18_2024-11-05_12-08-47_OF1.zarr\M25_D18_2024-11-05_12-08-47_OF1.zarr"), @@ -63,7 +62,7 @@ np.save("peak_locs_3.npy", peak_locations_list[2]) # if False: - peaks_list = [np.load("peaks_1.npy"), np.load("peaks_2.npy"), np.load("peaks_3.npy")] + peaks_list = [np.load("peaks_1.npy"), np.load("peaks_2.npy") , np.load("peaks_3.npy")] peak_locations_list = [np.load("peak_locs_1.npy"), np.load("peak_locs_2.npy"), np.load("peak_locs_3.npy")] # -------------------------------------------------------------------------------------- @@ -83,9 +82,9 @@ estimate_histogram_kwargs = session_alignment.get_estimate_histogram_kwargs() estimate_histogram_kwargs["method"] = "chunked_median" estimate_histogram_kwargs["histogram_type"] = "activity_1d" # TODO: investigate this case thoroughly - estimate_histogram_kwargs["bin_um"] = 0.5 + estimate_histogram_kwargs["bin_um"] = 5 estimate_histogram_kwargs["log_scale"] = True - estimate_histogram_kwargs["weight_with_amplitude"] = False + estimate_histogram_kwargs["weight_with_amplitude"] = True compute_alignment_kwargs = session_alignment.get_compute_alignment_kwargs() compute_alignment_kwargs["num_shifts_block"] = 300 @@ -94,10 +93,12 @@ recordings_list, peaks_list, peak_locations_list, - alignment_order="to_session_2", # "to_session_X" or "to_middle" + alignment_order="to_session_1", # "to_session_X" or "to_middle" non_rigid_window_kwargs=non_rigid_window_kwargs, estimate_histogram_kwargs=estimate_histogram_kwargs, ) + si.plot_traces(recordings_list[0], mode="line", time_range=(0, 1)) + plt.show() # TODO: nonlinear is not working well 'to middle', investigate # TODO: also finalise the estimation of bin number of nonrigid. diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py index 949a740bf0..6e567eeee3 100644 --- a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py +++ b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py @@ -14,7 +14,9 @@ from spikeinterface.preprocessing.motion import run_peak_detection_pipeline_node import copy import scipy +import matplotlib.pyplot as plt +INTERP = "linear" def get_estimate_histogram_kwargs() -> dict: """ @@ -855,8 +857,17 @@ def cross_correlate(sig1, sig2, thr=None): return shift +def _correlate(signal1, signal2): -def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plot=True): + corr_value = ( + np.corrcoef(signal1, + signal2)[0, 1] + ) + if False: + corr_value = np.correlate(signal1 - np.mean(signal1), signal2 - np.mean(signal2)) / signal1.size + return corr_value + +def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plot=True, round=0): """ """ best_correlation = 0 best_displacements = np.zeros_like(signa11_blanked) @@ -865,7 +876,7 @@ def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plo xcorr = [] - for scale in np.linspace(0.85, 1.15, 10): + for scale in np.r_[np.linspace(0.85, 1, 10), np.linspace(1, 1.15, 10)]: # TODO: double 1 nonzero = np.where(signa11_blanked > 0)[0] if not np.any(nonzero): @@ -874,75 +885,63 @@ def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plo midpoint = nonzero[0] + np.ptp(nonzero) / 2 x_scale = (x - midpoint) * scale + midpoint - interp_f = scipy.interpolate.interp1d( - x_scale, signa11_blanked, fill_value=0.0, bounds_error=False - ) # TODO: try cubic etc... or Kriging - - scaled_func = interp_f(x) + # interp_f = scipy.interpolate.interp1d( + # x_scale, signa11_blanked, fill_value=0.0, bounds_error=False + # ) # TODO: try cubic etc... or Kriging - # plt.plot(signa11_blanked) - # plt.plot(scaled_func) - # plt.show() - - # breakpoint() + # scaled_func = interp_f(x) for sh in np.arange(-thr, thr): # TODO: we are off by one here - shift_signal1_blanked = alignment_utils.shift_array_fill_zeros(scaled_func, sh) + # shift_signal1_blanked = alignment_utils.shift_array_fill_zeros(scaled_func, sh) - x_shift = x_scale - sh # TODO: rename + x_shift = x_scale - sh - # is this pull back? - # interp_f = scipy.interpolate.interp1d(xs, shift_signal1_blanked, fill_value=0.0, bounds_error=False) # TODO: try cubic etc... or Kriging + interp_f = scipy.interpolate.interp1d( + x_shift, signa11_blanked, fill_value=0.0, bounds_error=False, kind=INTERP + ) + shift_signal1_blanked = interp_f(x) - # scaled_func = interp_f(x_shift) + from scipy.ndimage import gaussian_filter - corr_value = ( - np.correlate( - shift_signal1_blanked - np.mean(shift_signal1_blanked), - signal2_blanked - np.mean(signal2_blanked), - ) - / signa11_blanked.size + corr_value = _correlate( + gaussian_filter(shift_signal1_blanked, 1.5), + gaussian_filter(signal2_blanked, 1.5) ) + if np.isnan(corr_value) or corr_value < 0: + corr_value = 0 + if corr_value > best_correlation: best_displacements = x_shift best_correlation = corr_value - if False and np.abs(sh) == 1: - print(corr_value) - - plt.plot(shift_signal1_blanked) - plt.plot(signal2_blanked) - plt.show() - # plt.draw() # Draw the updated figure - # plt.pause(0.1) # Pause for 0.5 seconds before updating - # plt.clf() - - # breakpoint() - - # xcorr.append(np.max(np.r_[xcorr_scale])) - - if False: - xcorr = np.r_[xcorr] - # shift = np.argmax(xcorr) - thr - - print("MAX", np.max(xcorr)) - - if np.max(xcorr) < 0.0001: - shift = 0 - else: - shift = np.argmax(xcorr) - thr + if False and plot and round == 1 and (corr_value > 0.3): # and plot and np.abs(sh) < 25: + print("3") + plt.plot(shift_signal1_blanked) + plt.plot(signal2_blanked) + plt.title(corr_value) + plt.show() + # plt.draw() + # plt.pause(0.1) + # plt.clf() + if plot: + print("DONE)") + plt.plot(signa11_blanked) + plt.plot(signal2_blanked) + plt.show() - print("output shift", shift) + interp_f = scipy.interpolate.interp1d( + best_displacements, signa11_blanked, fill_value=0.0, bounds_error=False, kind=INTERP + ) + final = interp_f(x) + plt.plot(final) + plt.plot(signal2_blanked) + plt.show() return best_displacements -# plt.plot(signal1) -# plt.plot(signal2) - - def get_shifts(signal1, signal2, windows, plot=True): import matplotlib.pyplot as plt @@ -967,72 +966,71 @@ def get_shifts(signal1, signal2, windows, plot=True): x = np.arange(signa11_blanked.size) x_orig = x.copy() - for round in range(len(windows)): + num_points = len(windows) + max = signal1.size // 2 + min = windows[0].size // 5 - # if round == 0: - # shift = cross_correlate(signa11_blanked, signal2_blanked, thr=100) # for first rigid, do larger! - # else: - displacements = cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=200, plot=False) + k = -np.log(min / (max - min)) / (num_points - 1) + x_values = np.arange(num_points) + all_thr = (max - min) * np.exp(-1.2 * x_values) + min - # breakpoint() +# all_thr = np.linspace(max, min, len(windows)) - interpf = scipy.interpolate.interp1d( - displacements, signa11_blanked, fill_value=0.0, bounds_error=False - ) # TODO: move away from this indexing sceheme - signa11_blanked = interpf(x) + for round in range(num_points): - # cum_shifts.append(shift) - # print("shift", shift) + thr = all_thr[round] # TODO: optimise this somehow? go back and forth? + # if round < 2: + # thr = np.where(best_displacements == 0)[0].size // 2 + # else: + # thr = windows[0].size // 5 - # shift the signal1, or use indexing + print(f"ROUND: {round}, THR: {thr}") + displacements = cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=thr, plot=plot, round=round) - # signa11_blanked = shift_array_fill_zeros(signa11_blanked, shift) # INTERP HERE, KRIGING. but will accumulate interpolation errors... - - # if plot: - # print("round", round) - # plt.plot(signa11_blanked) - # plt.plot(signal2_blanked) - # plt.show() + interpf = scipy.interpolate.interp1d( + displacements, signa11_blanked, fill_value=0.0, bounds_error=False, kind=INTERP + ) # TODO: move away from this indexing sceheme + signa11_blanked = interpf(x) window_corrs = np.empty(len(windows)) for i, idx in enumerate(windows): - window_corrs[i] = ( - np.correlate( - signa11_blanked[idx] - np.mean(signa11_blanked[idx]), - signal2_blanked[idx] - np.mean(signal2_blanked[idx]), + window_corrs[i] = _correlate(signa11_blanked[idx], signal2_blanked[idx]) + + window_corrs[np.isnan(window_corrs)] = 0 + if np.any(window_corrs): + max_window = np.argmax(np.abs(window_corrs)) # TODO: cutoff! TODO: note sure about the abs, very weird edge case... + + if False: + small_shift = cross_correlate( + signa11_blanked[windows[max_window]], + signal2_blanked[windows[max_window]], + thr=windows[max_window].size // 2, ) - / signa11_blanked[idx].size - ) + signa11_blanked = alignment_utils.shift_array_fill_zeros(signa11_blanked, small_shift) + segment_shifts[max_window] = np.sum(cum_shifts) + small_shift - max_window = np.argmax(window_corrs) # TODO: cutoff! + if plot: + breakpoint() - if False: - small_shift = cross_correlate( - signa11_blanked[windows[max_window]], - signal2_blanked[windows[max_window]], - thr=windows[max_window].size // 2, - ) - signa11_blanked = alignment_utils.shift_array_fill_zeros(signa11_blanked, small_shift) - segment_shifts[max_window] = np.sum(cum_shifts) + small_shift + best_displacements[windows[max_window]] = displacements[windows[max_window]] - best_displacements[windows[max_window]] = displacements[windows[max_window]] x = displacements signa11_blanked[windows[max_window]] = 0 signal2_blanked[windows[max_window]] = 0 - # TODO: need to carry over displacements! + if plot: + print("FINAL") + plt.plot(signal1) + plt.plot(signal2) + plt.show() - print(best_displacements) - interpf = scipy.interpolate.interp1d( - best_displacements, signal1, fill_value=0.0, bounds_error=False - ) # TODO: move away from this indexing sceheme - final = interpf(x_orig) - - # plt.plot(final) - # plt.plot(signal2) - # plt.show() + interpf = scipy.interpolate.interp1d(best_displacements, signal1, fill_value=0.0, bounds_error=False, kind=INTERP) + final = interpf(x_orig) + plt.plot(final) + plt.plot(signal2) + plt.show() return np.floor(best_displacements - x_orig) @@ -1136,7 +1134,11 @@ def _compute_session_alignment( for i in range(shifted_histograms.shape[0]): for j in range(shifted_histograms.shape[0]): - shifts1 = get_shifts(shifted_histograms[i, :], shifted_histograms[j, :], windows1, plot=True) + plot_ = False # i == 0 and j == 1 + print("I", i) + print("J", j) + + shifts1 = get_shifts(shifted_histograms[i, :], shifted_histograms[j, :], windows1, plot=plot_) # shifts2 = get_shifts(shifted_histograms[i, :], shifted_histograms[j, :], windows2) # shifts = np.empty(shifts1.size + shifts1.size - 1) @@ -1153,9 +1155,7 @@ def _compute_session_alignment( # nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation( # shifted_histograms, non_rigid_windows, **compute_alignment_kwargs # ) - non_rigid_shifts = nonrigid_session_offsets_matrix[ - 2, :, : - ] # alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix) + non_rigid_shifts = nonrigid_session_offsets_matrix[0, :, :] # alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix) non_rigid_window_centers = spatial_bin_centers shifts = non_rigid_shifts