Skip to content

Commit

Permalink
Begin typing and documentation.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Sep 12, 2024
1 parent 1724ba6 commit 8c4030a
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 33 deletions.
2 changes: 1 addition & 1 deletion debugging/_test_session_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@
"method": "chunked_supremum", # TODO: double check scaling
"chunked_bin_size_s": "estimate",
"log_scale": True,
"smooth_um": 10,
"depth_smooth_um": 10,
}
compute_alignment_kwargs = {
"num_shifts_block": None, # TODO: can be in um so comaprable with window kwargs.
Expand Down
4 changes: 2 additions & 2 deletions debugging/alignment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


def get_activity_histogram(
recording, peaks, peak_locations, spatial_bin_edges, log_scale, bin_s, smooth_um
recording, peaks, peak_locations, spatial_bin_edges, log_scale, bin_s, depth_smooth_um
):
"""
TODO: assumes 1-segment recording
Expand All @@ -35,7 +35,7 @@ def get_activity_histogram(
bin_um=None,
hist_margin_um=None,
spatial_bin_edges=spatial_bin_edges,
depth_smooth_um=smooth_um,
depth_depth_smooth_um=depth_smooth_um,
)
assert np.array_equal(generated_spatial_bin_edges, spatial_bin_edges), "TODO: remove soon after testing"

Expand Down
2 changes: 1 addition & 1 deletion debugging/playing_inter-session-alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"method": "chunked_median", # TODO: double check scaling
"chunked_bin_size_s": "estimate",
"log_scale": False,
"smooth_um": 5,
"depth_smooth_um": 5,
}
compute_alignment_kwargs = {
"num_shifts_block": None, # TODO: can be in um so comaprable with window kwargs.
Expand Down
158 changes: 129 additions & 29 deletions debugging/session_alignment.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from spikeinterface.core.baserecording import BaseRecording

import numpy as np
from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording
from spikeinterface.sortingcomponents.motion.motion_utils import get_spatial_windows, Motion, get_spatial_bins
Expand All @@ -9,21 +14,64 @@
import copy


# TODO: need to plot out the entire call tree and
# make sure it is optimal as its quite complex but
# there is quite a lot of weird stuff here.
# 1) add docstrings and type hints
# 2) add print statements to the entry function
# 3) look into the DREDGE stuff

# TODO: 1) think of a method to choose some reasonable defaults for bin size, nonrigid smoothing.
# TODO: 2) different alignment procedure

_estimate_histogram_kwargs = {

def get_estimate_histogram_kwargs() -> dict:
"""
A dictionary controlling how the histogram for each session is
computed. The session histograms are estimated by chunking
the recording into time segments and computing histograms
for each chunk, then performing some summary statistic over
the chunked histograms.
Returns
-------
A dictionary with entries:
"bin_um" : number of spatial histogram bins. As the estimated peak
locations are continuous (i.e. real numbers) this is not constrained
by the number of channels.
"method" : may be "chunked_mean", "chunked_median", "chunked_supremum",
"chunked_poisson". Determines the summary statistic used over
the histograms computed across a session. See `alignment_utils.py
for details on each method.
"chunked_bin_size_s" : The length in seconds to chunk the recording
for estimating the chunked histograms.
"log_scale" : if `True`, histograms are log transformed.
"depth_smooth_um" : if `None`, no smoothing is applied. See
`make_2d_motion_histogram`.
"""
return {
"bin_um": 2,
"method": "chunked_mean",
"chunked_bin_size_s": "estimate",
"log_scale": False,
"smooth_um": None,
"depth_smooth_um": None,
}


_compute_alignment_kwargs = {
def get_compute_alignment_kwargs() -> dict:
"""
A dictionary with settings controlling how inter-session
alignment is estimated and computed given a set of
session activity histograms.
All keys except for "non_rigid_window_kwargs" determine
how alignment is estimated, based on the kilosort ("kilosort_like"
in spikeinterface) motion correction method. See
`iterative_template_registration` for details.
"non_rigid_window_kwargs" : if nonrigid alignment
is performed, this determines the nature of the
windows along the probe depth. See `get_spatial_windows`.
"""
return {
"num_shifts_block": 5,
"interpolate": False,
"interp_factor": 10,
Expand All @@ -43,11 +91,15 @@
}


_interpolate_motion_kwargs = {
"border_mode": "remove_channels",
"spatial_interpolation_method": "kriging",
"sigma_um": 20.0,
"p": 2
def get_interpolate_motion_kwargs():
"""
Settings to pass to `InterpolateMotionRecording"
"""
return {
"border_mode": "remove_channels",
"spatial_interpolation_method": "kriging",
"sigma_um": 20.0,
"p": 2
}

# -----------------------------------------------------------------------------
Expand All @@ -57,18 +109,66 @@
# TODO: add some print statements for progress

def align_sessions(
recordings_list,
peaks_list,
peak_locations_list,
alignment_order="to_middle",
rigid=True,
estimate_histogram_kwargs=_estimate_histogram_kwargs,
compute_alignment_kwargs=_compute_alignment_kwargs,
interpolate_motion_kwargs=_interpolate_motion_kwargs,
):
recordings_list: list[BaseRecording],
peaks_list: list[np.ndarray],
peak_locations_list: list[np.ndarray],
alignment_order: str = "to_middle",
rigid: bool = True,
estimate_histogram_kwargs: dict = get_estimate_histogram_kwargs(),
compute_alignment_kwargs: dict = get_compute_alignment_kwargs(),
interpolate_motion_kwargs: dict = get_interpolate_motion_kwargs(),
) -> tuple[list[BaseRecording], ]:
"""
print what happens with the update! it is automaticlally added
to existing motion if it exists.
Estimate probe displacement across recording sessions and
return interpolated, displacement-corrected recording. Displacement
is only estimated along the "y" dimension.
This assumes peaks and peak locations have already been computed.
See `compute_peaks_locations_for_session_alignment` for generating
`peaks_list` and `peak_locations_list` from a `recordings_list`.
If a recording in `recordings_list` is already an `InterpolateMotionRecording`,
the displacement will be added to the existing shifts to avoid duplicate
interpolations. Note the returned, corrected recording is a copy
(recordings in `recording_list` are not edited in-place).
Parameters
----------
recordings_list : list[BaseRecording]
A list of recordings to be aligned.
peaks_list : list[np.ndarray]
A list of peaks detected from the recordings in `recordings_list`,
as returned from the `detect_peaks` function. Each entry in
`peaks_list` should be from the corresponding entry in `recordings_list`.
peak_locations_list : list[np.ndarray]
A list of peak locations, as computed by `localize_peaks`. Each entry
in `peak_locations_list` should be matched to the corresponding entry
in `peaks_list` and `recordings_list`.
alignment_order : str
"to_middle" will align all sessions to the mean position.
Alternatively, "to_session_N" where "N" is a session number
will align to the Nth session.
rigid : bool
If `True`, estimated displacement is rigid. If `False`, nonrigid
estimation is performed by performing rigid alignment on overlapping
subsets of the probes "y" dimension.
estimate_histogram_kwargs : dict
see `get_estimate_histogram_kwargs()`
compute_alignment_kwargs : dict
see `get_compute_alignment_kwargs()`
interpolate_motion_kwargs : dict
see `get_interpolate_motion_kwargs()`
Returns
-------
`corrected_recordings_list : list[BaseRecording]
`motion_objects_list : list[TODO]
extra_outputs_dict : dict
"""
estimate_histogram_kwargs = copy.deepcopy(estimate_histogram_kwargs)
compute_alignment_kwargs = copy.deepcopy(compute_alignment_kwargs)
Expand Down Expand Up @@ -142,7 +242,7 @@ def align_sessions_after_motion_correction(
)


def compute_peaks_for_session_alignment(
def compute_peaks_locations_for_session_alignment(
recording_list, gather_mode, detect_kwargs, localize_peaks_kwargs, job_kwargs
):
"""
Expand Down Expand Up @@ -172,7 +272,7 @@ def _compute_session_histograms(
bin_um,
method,
chunked_bin_size_s,
smooth_um,
depth_smooth_um,
log_scale,
):
"""
Expand All @@ -191,7 +291,7 @@ def _compute_session_histograms(

session_hist, temporal_bin_centers, histogram_info = _get_single_session_activity_histogram(
recording, peaks, peak_locations, spatial_bin_edges,
method, log_scale, chunked_bin_size_s, smooth_um,
method, log_scale, chunked_bin_size_s, depth_smooth_um,
)
temporal_bin_centers_list.append(temporal_bin_centers)
session_histogram_list.append(session_hist)
Expand All @@ -201,7 +301,7 @@ def _compute_session_histograms(


def _get_single_session_activity_histogram(
recording, peaks, peak_locations, spatial_bin_edges, method, log_scale, chunked_bin_size_s, smooth_um
recording, peaks, peak_locations, spatial_bin_edges, method, log_scale, chunked_bin_size_s, depth_smooth_um
):
"""
"""
Expand All @@ -214,7 +314,7 @@ def _get_single_session_activity_histogram(

one_bin_histogram, _, _ = alignment_utils.get_activity_histogram(
recording, peaks, peak_locations, spatial_bin_edges,
log_scale=False, bin_s=None, smooth_um=smooth_um
log_scale=False, bin_s=None, depth_smooth_um=depth_smooth_um
)
if method == "entire_session":
if log_scale:
Expand All @@ -231,7 +331,7 @@ def _get_single_session_activity_histogram(
)

chunked_histograms, chunked_temporal_bin_centers, _ = alignment_utils.get_activity_histogram(
recording, peaks, peak_locations, spatial_bin_edges, log_scale, bin_s=chunked_bin_size_s, smooth_um=smooth_um,
recording, peaks, peak_locations, spatial_bin_edges, log_scale, bin_s=chunked_bin_size_s, depth_smooth_um=depth_smooth_um,
)
session_std = np.sum(np.std(chunked_histograms, axis=0)) / chunked_histograms.shape[1]

Expand Down Expand Up @@ -373,7 +473,7 @@ def _correct_session_displacement(
estimate_histogram_kwargs["method"],
estimate_histogram_kwargs["log_scale"],
estimate_histogram_kwargs["chunked_bin_size_s"],
estimate_histogram_kwargs["smooth_um"],
estimate_histogram_kwargs["depth_smooth_um"],
)
corrected_session_histogram_list.append(session_hist)

Expand Down

0 comments on commit 8c4030a

Please sign in to comment.