diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py index 6e567eeee3..8b1512c2bd 100644 --- a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py +++ b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py @@ -18,6 +18,7 @@ INTERP = "linear" + def get_estimate_histogram_kwargs() -> dict: """ A dictionary controlling how the histogram for each session is @@ -857,16 +858,15 @@ def cross_correlate(sig1, sig2, thr=None): return shift + def _correlate(signal1, signal2): - corr_value = ( - np.corrcoef(signal1, - signal2)[0, 1] - ) + corr_value = np.corrcoef(signal1, signal2)[0, 1] if False: corr_value = np.correlate(signal1 - np.mean(signal1), signal2 - np.mean(signal2)) / signal1.size return corr_value + def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plot=True, round=0): """ """ best_correlation = 0 @@ -885,11 +885,11 @@ def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plo midpoint = nonzero[0] + np.ptp(nonzero) / 2 x_scale = (x - midpoint) * scale + midpoint - # interp_f = scipy.interpolate.interp1d( - # x_scale, signa11_blanked, fill_value=0.0, bounds_error=False - # ) # TODO: try cubic etc... or Kriging + # interp_f = scipy.interpolate.interp1d( + # x_scale, signa11_blanked, fill_value=0.0, bounds_error=False + # ) # TODO: try cubic etc... or Kriging - # scaled_func = interp_f(x) + # scaled_func = interp_f(x) for sh in np.arange(-thr, thr): # TODO: we are off by one here @@ -904,10 +904,7 @@ def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plo from scipy.ndimage import gaussian_filter - corr_value = _correlate( - gaussian_filter(shift_signal1_blanked, 1.5), - gaussian_filter(signal2_blanked, 1.5) - ) + corr_value = _correlate(gaussian_filter(shift_signal1_blanked, 1.5), gaussian_filter(signal2_blanked, 1.5)) if np.isnan(corr_value) or corr_value < 0: corr_value = 0 @@ -916,15 +913,15 @@ def cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=100, plo best_displacements = x_shift best_correlation = corr_value - if False and plot and round == 1 and (corr_value > 0.3): # and plot and np.abs(sh) < 25: + if False and plot and round == 1 and (corr_value > 0.3): # and plot and np.abs(sh) < 25: print("3") plt.plot(shift_signal1_blanked) plt.plot(signal2_blanked) plt.title(corr_value) plt.show() - # plt.draw() - # plt.pause(0.1) - # plt.clf() + # plt.draw() + # plt.pause(0.1) + # plt.clf() if plot: print("DONE)") plt.plot(signa11_blanked) @@ -974,15 +971,15 @@ def get_shifts(signal1, signal2, windows, plot=True): x_values = np.arange(num_points) all_thr = (max - min) * np.exp(-1.2 * x_values) + min -# all_thr = np.linspace(max, min, len(windows)) + # all_thr = np.linspace(max, min, len(windows)) for round in range(num_points): thr = all_thr[round] # TODO: optimise this somehow? go back and forth? - # if round < 2: - # thr = np.where(best_displacements == 0)[0].size // 2 - # else: - # thr = windows[0].size // 5 + # if round < 2: + # thr = np.where(best_displacements == 0)[0].size // 2 + # else: + # thr = windows[0].size // 5 print(f"ROUND: {round}, THR: {thr}") displacements = cross_correlate_with_scale(x, signa11_blanked, signal2_blanked, thr=thr, plot=plot, round=round) @@ -998,7 +995,9 @@ def get_shifts(signal1, signal2, windows, plot=True): window_corrs[np.isnan(window_corrs)] = 0 if np.any(window_corrs): - max_window = np.argmax(np.abs(window_corrs)) # TODO: cutoff! TODO: note sure about the abs, very weird edge case... + max_window = np.argmax( + np.abs(window_corrs) + ) # TODO: cutoff! TODO: note sure about the abs, very weird edge case... if False: small_shift = cross_correlate( @@ -1014,7 +1013,6 @@ def get_shifts(signal1, signal2, windows, plot=True): best_displacements[windows[max_window]] = displacements[windows[max_window]] - x = displacements signa11_blanked[windows[max_window]] = 0 @@ -1026,7 +1024,9 @@ def get_shifts(signal1, signal2, windows, plot=True): plt.plot(signal2) plt.show() - interpf = scipy.interpolate.interp1d(best_displacements, signal1, fill_value=0.0, bounds_error=False, kind=INTERP) + interpf = scipy.interpolate.interp1d( + best_displacements, signal1, fill_value=0.0, bounds_error=False, kind=INTERP + ) final = interpf(x_orig) plt.plot(final) plt.plot(signal2) @@ -1134,7 +1134,7 @@ def _compute_session_alignment( for i in range(shifted_histograms.shape[0]): for j in range(shifted_histograms.shape[0]): - plot_ = False # i == 0 and j == 1 + plot_ = False # i == 0 and j == 1 print("I", i) print("J", j) @@ -1155,7 +1155,9 @@ def _compute_session_alignment( # nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation( # shifted_histograms, non_rigid_windows, **compute_alignment_kwargs # ) - non_rigid_shifts = nonrigid_session_offsets_matrix[0, :, :] # alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix) + non_rigid_shifts = nonrigid_session_offsets_matrix[ + 0, :, : + ] # alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix) non_rigid_window_centers = spatial_bin_centers shifts = non_rigid_shifts