diff --git a/debugging/histogram1.npy b/debugging/histogram1.npy index fd25be6c66..85fae37143 100644 Binary files a/debugging/histogram1.npy and b/debugging/histogram1.npy differ diff --git a/debugging/histogram2.npy b/debugging/histogram2.npy index dc68746a30..af6fa99ca8 100644 Binary files a/debugging/histogram2.npy and b/debugging/histogram2.npy differ diff --git a/debugging/histogram3.npy b/debugging/histogram3.npy index 99c2577603..8d003239b1 100644 Binary files a/debugging/histogram3.npy and b/debugging/histogram3.npy differ diff --git a/debugging/peak_locs_1.npy b/debugging/peak_locs_1.npy index dd54121e7e..67d2aaa0fd 100644 Binary files a/debugging/peak_locs_1.npy and b/debugging/peak_locs_1.npy differ diff --git a/debugging/peak_locs_2.npy b/debugging/peak_locs_2.npy index fb40287531..622cdbe65a 100644 Binary files a/debugging/peak_locs_2.npy and b/debugging/peak_locs_2.npy differ diff --git a/debugging/peak_locs_3.npy b/debugging/peak_locs_3.npy index d33aed1e93..9ff848da0f 100644 Binary files a/debugging/peak_locs_3.npy and b/debugging/peak_locs_3.npy differ diff --git a/debugging/peaks_1.npy b/debugging/peaks_1.npy index f14f3ac390..36e777c303 100644 Binary files a/debugging/peaks_1.npy and b/debugging/peaks_1.npy differ diff --git a/debugging/peaks_2.npy b/debugging/peaks_2.npy index 7e788dbb6c..d314d20050 100644 Binary files a/debugging/peaks_2.npy and b/debugging/peaks_2.npy differ diff --git a/debugging/peaks_3.npy b/debugging/peaks_3.npy index 12bae3e8c6..426cc94d4a 100644 Binary files a/debugging/peaks_3.npy and b/debugging/peaks_3.npy differ diff --git a/debugging/playing.py b/debugging/playing.py index 98545a1e73..cdec4d6481 100644 --- a/debugging/playing.py +++ b/debugging/playing.py @@ -30,9 +30,9 @@ # -------------------------------------------------------------------------------------- recordings_list, _ = generate_session_displacement_recordings( - num_units=120, + num_units=20, recording_durations=[400, 400, 400], - recording_shifts=((0, 0), (0, -300), (0, 200)), # TODO: can see how well this is recaptured by comparing the displacements to the known displacement + gradient + recording_shifts=((0, 0), (0, -200), (0, 100)), # TODO: can see how well this is recaptured by comparing the displacements to the known displacement + gradient non_rigid_gradient=None, # 0.1, seed=2, # 52 ) diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py index fa0f088dda..9fd7f31395 100644 --- a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py +++ b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py @@ -942,60 +942,90 @@ def cross_correlate_with_scale(x, signal1_blanked, signal2_blanked, thr=100, plo return best_displacements -def cross_correlate_with_scaled_fixed(x_orig, new_positions, fixed_windows, signal1_blanked, signal2_blanked, thr, round_, plot): +def _interp(signal, current_coords, orig_coords): + interp_f = scipy.interpolate.interp1d( + current_coords, signal, fill_value=0.0, bounds_error=False, kind=INTERP + ) + return interp_f(orig_coords) + +# Be extremely careful, if the fixed_indices is not correct, everything will be messed +# up and it will be very hard to detect! +# the windowing will just adjust blanked_mask and fixed_windows... +# For windowing, it would be easier just to pass the sinals directly and get displacements back +# 1) project 2) window 3) correlate and add shifts +def cross_correlate_with_scaled_fixed(x_orig, new_positions, blanked_mask, fixed_windows, histogram_array_blanked, i, j, thr, round_, plot): """ """ best_correlation = 0 - best_positions = np.zeros_like(signal1_blanked) + best_positions = np.zeros(histogram_array_blanked.shape[1]) - for scale in np.r_[np.linspace(0.85, 1, 10), np.linspace(1, 1.15, 10)]: # TODO: double 1 + histogram_array_blanked = histogram_array_blanked.copy() + for i_ in range(histogram_array_blanked.shape[0]): + histogram_array_blanked[i_, fixed_windows[i_, :]] = 0 - nonzero = np.where(signal1_blanked > 0)[0] - if not np.any(nonzero): - continue + # We are comparing i to j + signal2 = _interp(histogram_array_blanked[j, :], new_positions[j, :], x_orig) - midpoint = nonzero[0] + np.ptp(nonzero) / 2 - x_scale = (x_orig - midpoint) * scale + midpoint + # just need to do this to find the scaling midpoint... + signal1_for_midpoint = _interp(histogram_array_blanked[i, :], new_positions[i, :], x_orig) + signal1_for_midpoint[blanked_mask] = 0 + nonzero = np.where(signal1_for_midpoint > 0)[0] + if not np.any(nonzero): + return new_positions[i, :] # no change + midpoint = nonzero[0] + np.ptp(nonzero) / 2 - for sh in np.arange(-thr, thr): # TODO: we are off by one here + best_s1 = None - x_shift = x_scale - sh + for scale in np.r_[np.linspace(0.85, 1, 10), np.linspace(1, 1.15, 10)]: # TODO: double 1 - x_shift_ = x_orig.copy() # TODO - x_shift_[~fixed_windows] = x_shift[~fixed_windows] - x_shift = x_shift_ + x_scale = (new_positions[i, :] - midpoint) * scale + midpoint - interp_f = scipy.interpolate.interp1d( - x_shift, signal1_blanked, fill_value=0.0, bounds_error=False, kind=INTERP - ) - shift_signal1_blanked = interp_f(x_orig) + for sh in np.arange(-thr, thr): # TODO: we are off by one here (note we go 2x thr here!) + + putative_new_x = x_scale - sh + + # x_shift_ = new_positions[i, :].copy() # TODO + # x_shift_[~fixed_windows[i, :]] = putative_new_x[~fixed_windows[i, :]] + # putative_new_x = x_shift_ + + shift_signal1 = _interp(histogram_array_blanked[i, :], putative_new_x, x_orig) from scipy.ndimage import gaussian_filter + s1 = shift_signal1.copy() + s2 = signal2.copy() + s1[blanked_mask] = 0 + s2[blanked_mask] = 0 + corr_value = _correlate( - gaussian_filter(shift_signal1_blanked, 0.5), # TODO: need to adapt to kinetics of the data - gaussian_filter(signal2_blanked, 0.5) + s1, # shift_signal1[~blanked_mask], # gaussian_filter(histogram_array_blanked[i, :], 0.5), # TODO: need to adapt to kinetics of the data + s2, # signal2[~blanked_mask], # gaussian_filter(signal2_blanked, 0.5) ) - corr_value *= 1 - np.abs(sh - 0) / thr + percent_diff = np.exp(-(np.abs(1 - np.sum(s1) / np.sum(histogram_array_blanked[i, :]))) ** 2 / 1.2 ** 2) ** 12 + corr_value *= percent_diff # heavily penalise interpolation errors + + # corr_value *= 1 - np.abs(sh - 0) / thr if np.isnan(corr_value) or corr_value < 0: corr_value = 0 if corr_value > best_correlation: - best_positions = x_shift + best_positions = putative_new_x best_correlation = corr_value + best_s1 = s1 - # plt.plot(shift_signal1_blanked) - # plt.plot(signal2_blanked) - # plt.title(corr_value) - # plt.draw() - # plt.pause(0.1) - # plt.clf() - - new_positions = new_positions + (best_positions - x_orig) + if round_ > 0: + plt.plot(s1) + plt.plot(s2) + plt.title(corr_value) + plt.draw() + plt.pause(0.1) + plt.clf() - return new_positions, best_correlation + new_pos = new_positions[i, :].copy() + new_pos[~fixed_windows[i, :]] = best_positions[~fixed_windows[i, :]] + return new_pos def cross_correlate_combined_loss(x_orig, new_positions, fixed_windows, orig_blank_histograms, interp_blanked_histograms, thr, round_, plot): @@ -1092,8 +1122,8 @@ def cross_correlate_combined_loss(x_orig, new_positions, fixed_windows, orig_bla def get_threshold_array(num_bins, windows): num_points = len(windows) - max = num_bins - min = windows[0].size // 2 + max = num_bins // 2 + min = windows[0].size // 5 k = -np.log(min / (max - min)) / (num_points - 1) x_values = np.arange(num_points) @@ -1118,26 +1148,11 @@ def get_shifts_union(histogram_array, windows, plot=True): loss = 0 - for round in range(len(windows)): - -# print("ROUND", round) - -# thr = all_thr[round] + blanked_mask = np.zeros(histogram_array.shape[1]).astype(bool) -# shift_matrix = np.zeros((histogram_array.shape[0], histogram_array.shape[0], histogram_array.shape[1])) -# correlations = np.zeros((histogram_array.shape[0], histogram_array.shape[0])) - -# histogram_array_blanked_interp = np.zeros_like(histogram_array_blanked) -# for i in range(histogram_array_blanked.shape[0]): -# interpf = scipy.interpolate.interp1d( -# new_positions[i, :], histogram_array_blanked[i, :], fill_value=0.0, bounds_error=False, -# kind=INTERP -# ) -# histogram_array_blanked_interp[i, :] = interpf(x_orig) + for round in range(len(windows)): -# print("BEFORE") -# plt.plot(histogram_array_blanked_interp.T) -# plt.show() + thr = all_thr[round] # find contigious window ids diffs = np.diff(windows_to_run) @@ -1146,28 +1161,74 @@ def get_shifts_union(histogram_array, windows, plot=True): for block in all_blocks: + print("BLOCK", block) + + window_indexes = [] block_bools = np.ones(histogram_array.shape[1]).astype(bool) for block_idx in block: block_bools[windows[block_idx]] = False + window_indexes.append(windows[block_idx]) + window_indexes = np.hstack(window_indexes) + + if round< 100: # TODO: maybe some function of num windows? + shift_matrix = np.zeros((histogram_array.shape[0], histogram_array.shape[0], histogram_array.shape[1])) + + + y = np.zeros_like(histogram_array_blanked) + for i in range(histogram_array_blanked.shape[0]): + interpf = scipy.interpolate.interp1d( + new_positions[i, :], histogram_array_blanked[i, :], fill_value=0.0, + bounds_error=False, + kind=INTERP + ) + y[i, :] = interpf(x_orig) + + print("BEFORE") + plt.plot(y.T) + plt.show() + - if round == 0: # TODO: maybe some function of num windows? for i in range(histogram_array.shape[0]): for j in range(histogram_array.shape[0]): - fixed_windows_round = np.logical_or(fixed_windows[i, :], block_bools) + fixed_windows_orig = np.ones_like(histogram_array).astype(bool) - histogram_array_blanked_interp_i = histogram_array_blanked_interp[i, :].copy() - histogram_array_blanked_interp_i[fixed_windows_round] = 0 + # TODO: DIRECT COPY!!! + if window_indexes[0] == 0: + window_min = np.min([new_positions[i, :], new_positions[j, :]]) - 1 + else: + window_min = window_indexes[0] - histogram_array_blanked_interp_j = histogram_array_blanked_interp[j, :].copy() - histogram_array_blanked_interp_j[np.logical_or(fixed_windows[j, :], block_bools)] = 0 + if window_indexes[-1] == x_orig[-1]: + window_max = np.max([new_positions[i, :], new_positions[j, :]]) + 1 + else: + window_max = window_indexes[-1] - shift_matrix[i, j, :], correlations[i, j] = cross_correlate_with_scaled_fixed( - x_orig, new_positions[i, :], fixed_windows_round, histogram_array_blanked_interp_i, histogram_array_blanked_interp_j, thr=thr, round_=round, plot=plot # , plot=False, round=round - ) + fixed_indices = np.where(np.logical_and(new_positions[i, :] >= window_min, + new_positions[i, :] <= window_max)) + fixed_windows_orig[i, fixed_indices] = False # TODO: CAREUFULLY CHECK MAPPING + + fixed_indices = np.where(np.logical_and(new_positions[j, :] >= window_min, + new_positions[j, :] <= window_max)) + fixed_windows_orig[j, fixed_indices] = False # TODO: CAREUFULLY CHECK MAPPING + # DIRECT COPY END + + # from the window, and from the block + fixed_windows_round = np.logical_or(fixed_windows, fixed_windows_orig) # this is in orig space, different for all. + blanked_mask_round = np.logical_or(blanked_mask, block_bools) # this is interp space, same for all + + shift_matrix[i, j, :] = cross_correlate_with_scaled_fixed( + x_orig, new_positions, blanked_mask, fixed_windows, histogram_array_blanked, i, j, thr=thr, round_=round, plot=plot + ) this_round_new_positions = np.mean(shift_matrix, axis=1) # TODO: FIX! TODO: these are not displacements + + + + + + else: # Not bad for evne no blanking! fixed_windows_round = block_bools #np.logical_or(fixed_windows, block_bools) @@ -1218,6 +1279,8 @@ def get_shifts_union(histogram_array, windows, plot=True): this_round_new_positions = cross_correlate_combined_loss(x_orig, new_positions, fixed_windows, histogram_array_blanked_new, histogram_array_blanked_interp_new, thr, round, plot=True) + + histogram_array_interp = np.zeros_like(histogram_array_blanked) for i in range(histogram_array_blanked.shape[0]): interpf = scipy.interpolate.interp1d( @@ -1225,14 +1288,15 @@ def get_shifts_union(histogram_array, windows, plot=True): kind=INTERP ) histogram_array_interp[i, :] = interpf(x_orig) + histogram_array_interp[i, blanked_mask] = 0 + - print("INTERPED") - plt.plot(histogram_array_interp.T) - plt.show() window_corrs = np.empty(len(windows)) # okay need to increase but shouldn't fail for one window for i, idx in enumerate(windows): - window_corrs[i] = np.sum(np.triu(np.cov(histogram_array_interp[:, idx]), k=1)) # det doesn't work very well, too small + window_corrs[i] = np.sum(np.triu(np.corrcoef(histogram_array_interp[:, idx]), k=1)) # det doesn't work very well, too small + # plt.plot(histogram_array_interp[:, idx].T) + # plt.show() window_corrs[np.isnan(window_corrs)] = 0 window_corrs = np.abs(window_corrs) @@ -1253,29 +1317,32 @@ def get_shifts_union(histogram_array, windows, plot=True): window_max = windows[max_window][-1] fixed_indices = np.where(np.logical_and(this_round_new_positions[i, :] >= window_min, this_round_new_positions[i, :] <= window_max)) - fixed_windows[i, fixed_indices] = True - # this is in original space, the new_positions are also in original space (x -> new_positions) - histogram_array_blanked[i, fixed_indices] = 0 # this is in interpolated space - window_corrs[max_window] = 0 - windows_to_run = np.delete(windows_to_run, np.where(windows_to_run == max_window)[0]) + fixed_windows[i, fixed_indices] = True # TODO: CAREUFULLY CHECK MAPPING + + blanked_mask[windows[max_window]] = True + window_corrs[max_window] = 0 + windows_to_run = np.delete(windows_to_run, np.where(windows_to_run == max_window)[0]) # if round == 1 or not np.any(window_corrs > 0.1): # TODO: definately keep a running track of the xcorr and quit when it gets worse or doesn't improve. See how this example does across the rounds # break - final = np.zeros_like(histogram_array_blanked) - for i in range(histogram_array_blanked.shape[0]): - interpf = scipy.interpolate.interp1d( - this_round_new_positions[i], histogram_array[i, :], fill_value=0.0, bounds_error=False, kind=INTERP - ) - final[i, :] = interpf(x_orig) + # final = np.zeros_like(histogram_array_blanked) + # for i in range(histogram_array_blanked.shape[0]): + # interpf = scipy.interpolate.interp1d( + # this_round_new_positions[i], histogram_array[i, :], fill_value=0.0, bounds_error=False, kind=INTERP + # ) + # final[i, :] = interpf(x_orig) - loss_ = 0 # okay need to increase but shouldn't fail for one window - loss_ += np.sum( - np.triu(np.cov(histogram_array_interp), k=1) - ) + # loss_ = 0 # okay need to increase but shouldn't fail for one window + # loss_ += np.sum( + # np.triu(np.cov(histogram_array_interp), k=1) + # ) new_positions = this_round_new_positions # TODO + if not np.any(window_corrs > 0.01): # TODO: KEY <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< HADNLE THIS FIRST + break + # if round == 1: # break # if round == 0: @@ -1287,9 +1354,9 @@ def get_shifts_union(histogram_array, windows, plot=True): # new_positions = this_round_new_positions # TODO # loss = loss_ - print("FINAL") - plt.plot(final.T) - plt.show() +# print("FINAL") + # plt.plot(final.T) + # plt.show() # going to have to check the improvement in fit for every round and # if the round does not add much to the loss, then don't make the