diff --git a/debugging/playing.py b/debugging/playing.py index 793cb3ef29..8a58035c63 100644 --- a/debugging/playing.py +++ b/debugging/playing.py @@ -9,6 +9,8 @@ import matplotlib.pyplot as plt import spikeinterface.full as si +import numpy as np + si.set_global_job_kwargs(n_jobs=10) @@ -19,8 +21,6 @@ # Load / generate some recordings # -------------------------------------------------------------------------------------- - - recordings_list, _ = generate_session_displacement_recordings( num_units=20, recording_durations=[400, 400, 400], @@ -48,21 +48,23 @@ # There is a function 'session_alignment.align_sessions_after_motion_correction() # you can use instead of the below. - peaks_list, peak_locations_list = session_alignment.compute_peaks_locations_for_session_alignment( - recordings_list, - detect_kwargs={"method": "locally_exclusive"}, - localize_peaks_kwargs={"method": "grid_convolution"}, - ) - if False: + peaks_list, peak_locations_list = session_alignment.compute_peaks_locations_for_session_alignment( + recordings_list, + detect_kwargs={"method": "locally_exclusive"}, + localize_peaks_kwargs={"method": "grid_convolution"}, + ) + np.save("peaks_1.npy", peaks_list[0]) np.save("peaks_2.npy", peaks_list[1]) + np.save("peaks_3.npy", peaks_list[2]) np.save("peak_locs_1.npy", peak_locations_list[0]) np.save("peak_locs_2.npy", peak_locations_list[1]) + np.save("peak_locs_3.npy", peak_locations_list[2]) - if False: - peaks_list = [np.load("peaks_1.npy"), np.load("peaks_2.npy")] - peak_locations_list = [np.load("peak_locs_1.npy"), np.load("peak_locs_2.npy")] + # if False: + peaks_list = [np.load("peaks_1.npy"), np.load("peaks_2.npy"), np.load("peaks_3.npy")] + peak_locations_list = [np.load("peak_locs_1.npy"), np.load("peak_locs_2.npy"), np.load("peak_locs_3.npy")] # -------------------------------------------------------------------------------------- # Do the estimation @@ -73,14 +75,20 @@ # See `session_alignment.py` for docs on these settings. non_rigid_window_kwargs = session_alignment.get_non_rigid_window_kwargs() - non_rigid_window_kwargs["rigid"] = False - # non_rigid_window_kwargs["win_shape"] = "rect" - # non_rigid_window_kwargs["win_step_um"] = 25 + non_rigid_window_kwargs["rigid_mode"] = "nonrigid" + non_rigid_window_kwargs["win_shape"] = "rect" + non_rigid_window_kwargs["win_step_um"] = 100.0 + non_rigid_window_kwargs["win_scale_um"] = 200.0 estimate_histogram_kwargs = session_alignment.get_estimate_histogram_kwargs() - estimate_histogram_kwargs["method"] = "chunked_mean" - estimate_histogram_kwargs["histogram_type"] = "activity_1d" - estimate_histogram_kwargs["bin_um"] = 5 + estimate_histogram_kwargs["method"] = "chunked_median" + estimate_histogram_kwargs["histogram_type"] = "activity_1d" # TODO: investigate this case thoroughly + estimate_histogram_kwargs["bin_um"] = 2 + estimate_histogram_kwargs["log_scale"] = True + estimate_histogram_kwargs["weight_with_amplitude"] = False + + compute_alignment_kwargs = session_alignment.get_compute_alignment_kwargs() + compute_alignment_kwargs["num_shifts_block"] = 300 corrected_recordings_list, extra_info = session_alignment.align_sessions( recordings_list, diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py b/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py index bac8a38b0e..a3ecac3db7 100644 --- a/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py +++ b/src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py @@ -21,6 +21,7 @@ def get_activity_histogram( bin_s: float | None, depth_smooth_um: float | None, scale_to_hz: bool = False, + weight_with_amplitude: bool = False, ): """ Generate a 2D activity histogram for the session. Wraps the underlying @@ -57,7 +58,7 @@ def get_activity_histogram( recording, peaks, peak_locations, - weight_with_amplitude=False, + weight_with_amplitude=weight_with_amplitude, direction="y", bin_s=(bin_s if bin_s is not None else recording.get_duration(segment_index=0)), bin_um=None, diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/plotting_session_alignment.py b/src/spikeinterface/preprocessing/inter_session_alignment/plotting_session_alignment.py index ac305c839d..0f29d1f508 100644 --- a/src/spikeinterface/preprocessing/inter_session_alignment/plotting_session_alignment.py +++ b/src/spikeinterface/preprocessing/inter_session_alignment/plotting_session_alignment.py @@ -311,9 +311,14 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): elif isinstance(spatial_bin_centers, np.ndarray): spatial_bin_centers = [spatial_bin_centers] * num_histograms - # TOOD: For 2D histogram, will need to subplot and just plot the individual histograms... + if dp.session_histogram_list[0].ndim == 2: + histogram_list = [np.sum(hist_, axis=1) for hist_ in dp.session_histogram_list] + print("2D histogram passed, will be summed across first (i.e. amplitude) axis.") + else: + histogram_list = dp.session_histogram_list + for i in range(num_histograms): - self.ax.plot(spatial_bin_centers[i], dp.session_histogram_list[i], color=colors[i], linewidth=linewidths[i]) + self.ax.plot(spatial_bin_centers[i], histogram_list[i], color=colors[i], linewidth=linewidths[i]) if legend is not None: self.ax.legend(legend) diff --git a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py index a583c5d672..750e4843d9 100644 --- a/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py +++ b/src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py @@ -47,6 +47,7 @@ def get_estimate_histogram_kwargs() -> dict: "log_scale": False, "depth_smooth_um": None, "histogram_type": "activity_1d", + "weight_with_amplitude": True, } @@ -66,7 +67,7 @@ def get_compute_alignment_kwargs() -> dict: windows along the probe depth. See `get_spatial_windows`. """ return { - "num_shifts_block": 100, # TODO: estimate this properly, make take as some factor of the window width? Also check if it is 2x the block xcorr in motion correction + "num_shifts_block": 50, # TODO: estimate this properly, make take as some factor of the window width? Also check if it is 2x the block xcorr in motion correction "interpolate": False, "interp_factor": 10, "kriging_sigma": 1, @@ -88,7 +89,7 @@ def get_non_rigid_window_kwargs(): defined in the function signature. """ return { - "rigid": True, + "rigid_mode": "rigid", # "rigid", "rigid_nonrigid", "nonrigid" "win_shape": "gaussian", "win_step_um": 50, "win_scale_um": 50, @@ -314,7 +315,7 @@ def align_sessions_after_motion_correction( # shifts together. if ( "non_rigid_window_kwargs" in align_sessions_kwargs - and align_sessions_kwargs["non_rigid_window_kwargs"]["rigid"] is False + and "nonrigid" in align_sessions_kwargs["non_rigid_window_kwargs"]["rigid_mode"] ): if motion_window_kwargs["rigid"] is False: @@ -392,6 +393,7 @@ def _compute_session_histograms( chunked_bin_size_s: float | "estimate", depth_smooth_um: float, log_scale: bool, + weight_with_amplitude: bool, ) -> tuple[list[np.ndarray], list[np.ndarray], np.ndarray, np.ndarray, list[dict]]: """ Compute a 1d activity histogram for the session. As @@ -455,6 +457,7 @@ def _compute_session_histograms( log_scale, chunked_bin_size_s, depth_smooth_um, + weight_with_amplitude, ) temporal_bin_centers_list.append(temporal_bin_centers) session_histogram_list.append(session_hist) @@ -479,6 +482,7 @@ def _get_single_session_activity_histogram( log_scale: bool, chunked_bin_size_s: float | "estimate", depth_smooth_um: float, + weight_with_amplitude: bool, ) -> tuple[np.ndarray, np.ndarray, dict]: """ Compute an activity histogram for a single session. @@ -534,6 +538,7 @@ def _get_single_session_activity_histogram( bin_s=None, depth_smooth_um=None, scale_to_hz=False, + weight_with_amplitude=weight_with_amplitude ) # It is important that the passed histogram is scaled to firing rate in Hz @@ -824,12 +829,69 @@ def _correct_session_displacement( estimate_histogram_kwargs["log_scale"], estimate_histogram_kwargs["chunked_bin_size_s"], estimate_histogram_kwargs["depth_smooth_um"], + estimate_histogram_kwargs["weight_with_amplitude"], ) corrected_session_histogram_list.append(session_hist) return corrected_peak_locations_list, corrected_session_histogram_list +def get_shifts(signal1, signal2, windows): + + import matplotlib.pyplot as plt + + signa11_blanked = signal1.copy() + signal2_blanked = signal2.copy() + + if (first_idx := windows[0][0]) != 0: + print("first idx", first_idx) + signa11_blanked[:first_idx] = 0 + signal2_blanked[:first_idx] = 0 + + if (last_idx := windows[-1][-1]) != signal1.size - 1: #double check + print("last idx", last_idx) + signa11_blanked[last_idx:] = 0 + signal2_blanked[last_idx:] = 0 + + segment_shifts = np.empty(windows.shape[0]) + cum_shifts = [] + + for round in range(windows.shape[0]): + + xcorr = np.correlate(signa11_blanked, signal2_blanked, mode="full") + + if np.max(xcorr) < 0.01: + shift = 0 + else: + shift = np.argmax(xcorr) - xcorr.size // 2 + cum_shifts.append(shift) + print(shift) + + # shift the signal1, or use indexing + signa11_blanked = alignment_utils.shift_array_fill_zeros(signa11_blanked, shift) + + # plt.plot(signa11_blanked) + # plt.plot(signal2_blanked) + # plt.show() + + window_corrs = np.empty(windows.shape[0]) + for i, idx in enumerate(windows): + + window_corrs[i] = np.correlate(signa11_blanked[idx] - np.mean(signa11_blanked[idx]), signal2_blanked[idx] - np.mean(signal2_blanked[idx])) + + max_window = np.argmax(window_corrs) + + segment_shifts[max_window] = np.sum(cum_shifts) + + # print(segment_shifts[max_window]) + + # TODO: this is interacting with the shift to make spikes! + signa11_blanked[windows[max_window]] = 0 + signal2_blanked[windows[max_window]] = 0 + + return segment_shifts + + def _compute_session_alignment( session_histogram_list: list[np.ndarray], contact_depths: np.ndarray, @@ -868,6 +930,9 @@ def _compute_session_alignment( akima_interp_nonrigid = compute_alignment_kwargs.pop("akima_interp_nonrigid") + rigid_mode = non_rigid_window_kwargs.pop("rigid_mode") # TODO: carefully check all popped kwargs + non_rigid_window_kwargs["rigid"] = rigid_mode == "rigid" + non_rigid_windows, non_rigid_window_centers = get_spatial_windows( contact_depths, spatial_bin_centers, **non_rigid_window_kwargs ) @@ -878,7 +943,7 @@ def _compute_session_alignment( compute_alignment_kwargs, ) - if non_rigid_window_kwargs["rigid"]: + if rigid_mode == "rigid": return rigid_shifts, non_rigid_windows, non_rigid_window_centers # For non-rigid, first shift the histograms according to the rigid shift @@ -888,7 +953,7 @@ def _compute_session_alignment( # for non-rigid, it makes sense to start without rigid alignment shifted_histograms = session_histogram_array.copy() - if False: + if rigid_mode == "rigid_nonrigid": # TOOD: add to docs shifted_histograms = np.zeros_like(session_histogram_array) for ses_idx, orig_histogram in enumerate(session_histogram_array): @@ -897,10 +962,45 @@ def _compute_session_alignment( ) shifted_histograms[ses_idx, :] = shifted_histogram + + nonrigid_session_offsets_matrix = np.empty((shifted_histograms.shape[0], shifted_histograms.shape[0])) + + windows = [] + for i in range(non_rigid_windows.shape[0]): + idxs = np.arange(non_rigid_windows.shape[1])[non_rigid_windows[i, :].astype(bool)] + windows.append(idxs) + # TODO: check assumptions these are always the same size + + windows = np.vstack(windows) + +# import matplotlib.pyplot as plt +# plt.plot(non_rigid_windows.T) +# plt.show() + + windows1 = windows[::2, :] + windows2 = windows[1::2, :] + + nonrigid_session_offsets_matrix = np.empty((shifted_histograms.shape[0], shifted_histograms.shape[0], non_rigid_windows.shape[0])) + + for i in range(shifted_histograms.shape[0]): + for j in range(shifted_histograms.shape[0]): + + shifts1 = get_shifts(shifted_histograms[i, :], shifted_histograms[j, :], windows1) + shifts2 = get_shifts(shifted_histograms[i, :], shifted_histograms[j, :], windows2) + shifts = np.empty(shifts1.size + shifts2.size) + # breakpoint() + shifts[::2] = shifts1 + shifts[1::2] = (shifts1[:-1] + shifts1[1:]) / 2# np.shifts2 + # breakpoint() + nonrigid_session_offsets_matrix[i, j, :] = shifts + + # TODO: there are gaps in between rect, rect seems weird, they are non-overlapping :S + + # breakpoint() # Then compute the nonrigid shifts - nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation( - shifted_histograms, non_rigid_windows, **compute_alignment_kwargs - ) + # nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation( + # shifted_histograms, non_rigid_windows, **compute_alignment_kwargs + # ) non_rigid_shifts = alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix) # Akima interpolate the nonrigid bins if required. @@ -911,8 +1011,12 @@ def _compute_session_alignment( shifts = interp_nonrigid_shifts # rigid_shifts + interp_nonrigid_shifts non_rigid_window_centers = spatial_bin_centers else: + # TODO: so check + add a test, the interpolator will handle this? shifts = non_rigid_shifts # rigid_shifts + non_rigid_shifts + if rigid_mode == "rigid_nonrigid": + shifts += rigid_shifts + return shifts, non_rigid_windows, non_rigid_window_centers