Skip to content

Commit

Permalink
wip dredge_ap
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Jun 27, 2024
1 parent c2f5289 commit cab6646
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 55 deletions.
4 changes: 2 additions & 2 deletions src/spikeinterface/sortingcomponents/motion/decentralized.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,15 @@ def run(
conv_engine = "torch" if HAVE_TORCH else "numpy"

dim = ["x", "y", "z"].index(direction)
contact_pos = recording.get_channel_locations()[:, dim]
contact_depth = recording.get_channel_locations()[:, dim]

# spatial histogram bins
spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um)
spatial_bin_centers = 0.5 * (spatial_bin_edges[1:] + spatial_bin_edges[:-1])

# get spatial windows
non_rigid_windows, non_rigid_window_centers = get_windows(
rigid, contact_pos, spatial_bin_centers, win_margin_um, win_step_um, win_scale_um, win_shape
rigid, contact_depth, spatial_bin_centers, win_margin_um, win_step_um, win_scale_um, win_shape
)

# make 2D histogram raster
Expand Down
139 changes: 93 additions & 46 deletions src/spikeinterface/sortingcomponents/motion/dredge.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

import gc

from .motion_utils import Motion, get_windows, get_window_domains, scipy_conv1d
from .motion_utils import Motion, get_windows, get_window_domains, scipy_conv1d, make_2d_motion_histogram


# to discuss
Expand Down Expand Up @@ -70,16 +70,12 @@ def run(
extra,
**method_kwargs,
):
dim = ["x", "y", "z"].index(direction)
peak_amplitudes = peaks["amplitude"]
peak_depths = peak_locations[direction]
peak_times = recording.sample_index_to_time(peaks["sample_index"])


outs = dredge_ap(
peak_amplitudes,
peak_depths,
peak_times,
recording,
peaks,
peak_locations,
direction=direction,
rigid=rigid,
win_shape=win_shape,
Expand All @@ -100,9 +96,9 @@ def run(

# @TODO : Charlie I started very small refactoring, I let you continue
def dredge_ap(
amps,
depths_um,
times_s,
recording,
peaks,
peak_locations,
direction="y",
rigid=False,
# nonrigid window construction arguments
Expand All @@ -113,7 +109,7 @@ def dredge_ap(
bin_um=1.0,
bin_s=1.0,
max_disp_um=None,
time_horizon_s=1000,
time_horizon_s=1000.,
mincorr=0.1,
# weights arguments
do_window_weights=True,
Expand Down Expand Up @@ -149,6 +145,8 @@ def dredge_ap(
Arguments
---------
recording : Recording
The recording
amps : np.array of shape (n_spikes,)
depths: np.array of shape (n_spikes,)
times : np.array of shape (n_spikes,)
Expand Down Expand Up @@ -207,10 +205,22 @@ def dredge_ap(
windows if one wants to visualize them. Set `extra_outputs` to also save displacement
and correlation matrices.
"""

dim = ["x", "y", "z"].index(direction)
# @charlie: I removed amps/depths_um/times_s from the signature
# preaks and peak_locations are more SI compatible
# the way to get then
amps = peak_amplitudes = peaks["amplitude"]
depths_um = peak_depths = peak_locations[direction]
times_s = peak_times = recording.sample_index_to_time(peaks["sample_index"])



thomas_kw = thomas_kw if thomas_kw is not None else {}
xcorr_kw = xcorr_kw if xcorr_kw is not None else {}
if time_horizon_s:
xcorr_kw["max_dt_bins"] = np.ceil(time_horizon_s / bin_s)

raster_kw = dict(
amp_scale_fn=amp_scale_fn,
post_transform=post_transform,
Expand All @@ -234,33 +244,70 @@ def dredge_ap(
# this will store return values other than the MotionEstimate
extra = {}

# TODO charlie switch this to make_2d_motion_histogram after having putting more option
raster_res = spike_raster(
amps,
depths_um,
times_s,
**raster_kw,


# TODO charlie I switch this to make_2d_motion_histogram
# but we need to add all options from the original spike_raster()

# raster_res = spike_raster(
# amps,
# depths_um,
# times_s,
# **raster_kw,
# )
# if count_masked_correlation:
# raster, spatial_bin_edges_um, time_bin_edges_s, counts = raster_res
# else:
# raster, spatial_bin_edges_um, time_bin_edges_s = raster_res

motion_histogram, time_bin_edges_s, spatial_bin_edges_um = make_2d_motion_histogram(
recording,
peaks,
peak_locations,
weight_with_amplitude=False,
direction="y",
bin_duration_s=1.0,
bin_um=2.0,
hist_margin_um=50,
spatial_bin_edges=None,
depth_smooth_um=None,
time_smooth_s=None,
)
if count_masked_correlation:
raster, spatial_bin_edges_um, time_bin_edges_s, counts = raster_res
else:
raster, spatial_bin_edges_um, time_bin_edges_s = raster_res
raster = motion_histogram.T


# TODO @charlie check that we are doing the same thing
# windows, window_centers = get_windows(
# np.c_[np.zeros_like(spatial_bin_edges_um), spatial_bin_edges_um],
# win_step_um,
# win_scale_um,
# spatial_bin_edges=spatial_bin_edges_um,
# margin_um=-win_scale_um / 2 if win_margin_um is None else win_margin_um,
# win_shape=win_shape,
# zero_threshold=1e-5,
# rigid=rigid,
# )

dim = ["x", "y", "z"].index(direction)
contact_depth = recording.get_channel_locations()[:, dim]
spatial_bin_centers = 0.5 * (spatial_bin_edges_um[1:] + spatial_bin_edges_um[:-1])

# TODO Sam I did not yet change parameters here with the get_windows API
windows, window_centers = get_windows(
# pseudo geom to fool spikeinterface
np.c_[np.zeros_like(spatial_bin_edges_um), spatial_bin_edges_um],
rigid,
contact_depth,
spatial_bin_centers,
win_margin_um,
win_step_um,
win_scale_um,
spatial_bin_edges=spatial_bin_edges_um,
margin_um=-win_scale_um / 2 if win_margin_um is None else win_margin_um,
win_shape=win_shape,
win_shape,
zero_threshold=1e-5,
rigid=rigid,
)
if extra_outputs and count_masked_correlation:
extra["counts"] = counts

# TODO charlie : put back the count
# if extra_outputs and count_masked_correlation:
# extra["counts"] = counts


# cross-correlate to get D and C
if precomputed_D_C_maxdisp is None:
Ds, Cs, max_disp_um = xcorr_windows(
Expand All @@ -273,7 +320,8 @@ def dredge_ap(
max_disp_um=max_disp_um,
progress_bar=progress_bar,
device=device,
masks=(counts > 0) if count_masked_correlation else None,
# TODO charlie : put back the count for the mask
# masks=(counts > 0) if count_masked_correlation else None,
**xcorr_kw,
)
else:
Expand Down Expand Up @@ -319,9 +367,7 @@ def dredge_ap(
extra["max_disp_um"] = max_disp_um

time_bin_centers = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1])
spatial_bin_centers_um = 0.5 * (spatial_bin_edges_um[1:] + spatial_bin_edges_um[:-1])

motion = Motion([displacement], [time_bin_centers], spatial_bin_centers_um, direction=direction)
motion = Motion([displacement.T], [time_bin_centers], window_centers, direction=direction)

if extra_outputs:
return motion, extra
Expand Down Expand Up @@ -464,7 +510,7 @@ def dredge_online_lfp(
"""
dim = ["x", "y", "z"].index(direction)
# contact pos is the only on the direction
contact_pos = lfp_recording.get_channel_locations()[:, dim]
contact_depth = lfp_recording.get_channel_locations()[:, dim]


fs = lfp_recording.get_sampling_frequency()
Expand All @@ -477,7 +523,7 @@ def dredge_online_lfp(
thomas_kw = thomas_kw if thomas_kw is not None else {}
full_xcorr_kw = dict(
rigid=rigid,
bin_um=np.median(np.diff(contact_pos)),
bin_um=np.median(np.diff(contact_depth)),
max_disp_um=max_disp_um,
progress_bar=False,
device=device,
Expand All @@ -494,22 +540,22 @@ def dredge_online_lfp(


# here we check that contact positons are unique on the direction
if contact_pos.size != np.unique(contact_pos).size:
if contact_depth.size != np.unique(contact_depth).size:
raise ValueError(
f"estimate motion with 'dredge_lfp' need channel_positions to be unique in the direction='{direction}'"
)
if np.any(np.diff(contact_pos) < 0):
if np.any(np.diff(contact_depth) < 0):
raise ValueError(
f"estimate motion with 'dredge_lfp' need channel_positions to be ordered direction='{direction}'"
"please use spikeinterface.preprocessing.depth_order(recording)"
)

# Important detail : in LFP bin center are contact position in the direction
spatial_bin_centers = contact_pos
spatial_bin_centers = contact_depth

windows, window_centers = get_windows(
rigid=rigid,
contact_pos=contact_pos,
contact_depth=contact_depth,
spatial_bin_centers=spatial_bin_centers,
win_margin_um=win_margin_um,
win_step_um=win_step_um,
Expand All @@ -529,7 +575,7 @@ def dredge_online_lfp(
t0, t1 = 0, T_chunk
traces0 = lfp_recording.get_traces(start_frame=t0, end_frame=t1)
Ds0, Cs0, max_disp_um = xcorr_windows(
traces0.T, windows, contact_pos, win_scale_um, **full_xcorr_kw
traces0.T, windows, contact_depth, win_scale_um, **full_xcorr_kw
)
full_xcorr_kw["max_disp_um"] = max_disp_um
Ss0, mincorr0 = threshold_correlation_matrix(
Expand Down Expand Up @@ -568,15 +614,15 @@ def dredge_online_lfp(
Ds10, Cs10, _ = xcorr_windows(
traces1.T,
windows,
contact_pos,
contact_depth,
win_scale_um,
raster_b=traces0.T,
**full_xcorr_kw,
)

# cross-correlation in current chunk
Ds1, Cs1, _ = xcorr_windows(
traces1.T, windows, contact_pos, win_scale_um, **full_xcorr_kw
traces1.T, windows, contact_depth, win_scale_um, **full_xcorr_kw
)
Ss1, mincorr1 = threshold_correlation_matrix(
Cs1,
Expand Down Expand Up @@ -954,6 +1000,7 @@ def xcorr_windows(
slices = get_window_domains(windows)
B, D = windows.shape
D_, T0 = raster_a.shape

assert D == D_

# torch versions on device
Expand Down Expand Up @@ -1310,7 +1357,7 @@ def weight_correlation_matrix(
mincorr=0.0,
mincorr_percentile=None,
mincorr_percentile_nneighbs=20,
max_dt_s=None,
time_horizon_s=None,
lambda_t=DEFAULT_LAMBDA_T,
eps=DEFAULT_EPS,
do_window_weights=True,
Expand All @@ -1337,7 +1384,7 @@ def weight_correlation_matrix(
mincorr=mincorr,
mincorr_percentile=mincorr_percentile,
mincorr_percentile_nneighbs=mincorr_percentile_nneighbs,
max_dt_s=max_dt_s,
time_horizon_s=time_horizon_s,
bin_s=time_bin_edges[1] - time_bin_edges[0],
T=T,
in_place=in_place,
Expand Down
14 changes: 7 additions & 7 deletions src/spikeinterface/sortingcomponents/motion/motion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def copy(self):



def get_windows(rigid, contact_pos, spatial_bin_centers, win_margin_um, win_step_um, win_scale_um, win_shape,
def get_windows(rigid, contact_depth, spatial_bin_centers, win_margin_um, win_step_um, win_scale_um, win_shape,
zero_threshold=None):
"""
Generate spatial windows (taper) for non-rigid motion.
Expand All @@ -243,7 +243,7 @@ def get_windows(rigid, contact_pos, spatial_bin_centers, win_margin_um, win_step
----------
rigid : bool
If True, returns a single rectangular window
contact_pos : np.ndarray
contact_depth : np.ndarray
Position of electrodes of the corection direction shape=(num_channels, )
spatial_bin_centers : np.array
The pre-computed spatial bin centers
Expand Down Expand Up @@ -283,8 +283,8 @@ def get_windows(rigid, contact_pos, spatial_bin_centers, win_margin_um, win_step
f"get_windows(): spatial windows are probably not overlaping because {win_scale_um=} and {win_step_um=}"
)

min_ = np.min(contact_pos) - win_margin_um
max_ = np.max(contact_pos) + win_margin_um
min_ = np.min(contact_depth) - win_margin_um
max_ = np.max(contact_depth) + win_margin_um
num_windows = int((max_ - min_) // win_step_um)
border = ((max_ - min_) % win_step_um) / 2
window_centers = np.arange(num_windows + 1) * win_step_um + min_ + border
Expand Down Expand Up @@ -356,10 +356,10 @@ def get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um):
# contact along one axis
probe = recording.get_probe()
dim = ["x", "y", "z"].index(direction)
contact_pos = probe.contact_positions[:, dim]
contact_depth = probe.contact_positions[:, dim]

min_ = np.min(contact_pos) - hist_margin_um
max_ = np.max(contact_pos) + hist_margin_um
min_ = np.min(contact_depth) - hist_margin_um
max_ = np.max(contact_depth) + hist_margin_um
spatial_bins = np.arange(min_, max_ + bin_um, bin_um)

return spatial_bins
Expand Down

0 comments on commit cab6646

Please sign in to comment.