Skip to content

Commit

Permalink
more refactoring and parameters change
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Jun 27, 2024
1 parent cab6646 commit 37014bb
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@

from spikeinterface.core import get_noise_levels
from spikeinterface.sortingcomponents.benchmark.benchmark_tools import Benchmark, BenchmarkStudy, _simpleaxis
from spikeinterface.sortingcomponents.motion_estimation import estimate_motion
from spikeinterface.sortingcomponents.motion import estimate_motion
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_selection import select_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
from spikeinterface.widgets import plot_probe_map

from spikeinterface.sortingcomponents.motion_utils import Motion
from spikeinterface.sortingcomponents.motion import Motion

# import MEArec as mr

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def load_folder(cls, folder):

result[k] = load_extractor(folder / k)
elif format == "Motion":
from spikeinterface.sortingcomponents.motion_utils import Motion
from spikeinterface.sortingcomponents.motion import Motion

result[k] = Motion.load(folder / k)
elif format == "zarr_templates":
Expand Down
23 changes: 15 additions & 8 deletions src/spikeinterface/sortingcomponents/motion/decentralized.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
except ImportError:
HAVE_TORCH = False

from .motion_utils import Motion, get_windows, get_spatial_bin_edges, make_2d_motion_histogram, scipy_conv1d
from .motion_utils import Motion, get_spatial_windows, get_spatial_bin_edges, make_2d_motion_histogram, scipy_conv1d


class DecentralizedRegistration:
Expand Down Expand Up @@ -113,11 +113,11 @@ def run(
verbose,
progress_bar,
extra,
bin_um=10.0,
hist_margin_um=0.0,
bin_duration_s=2.0,
histogram_depth_smooth_um=None,
histogram_time_smooth_s=None,
bin_um=1.0,
hist_margin_um=20.0,
bin_duration_s=1.0,
histogram_depth_smooth_um=1.,
histogram_time_smooth_s=1.,
pairwise_displacement_method="conv",
max_displacement_um=100.0,
weight_scale="linear",
Expand Down Expand Up @@ -152,8 +152,15 @@ def run(
spatial_bin_centers = 0.5 * (spatial_bin_edges[1:] + spatial_bin_edges[:-1])

# get spatial windows
non_rigid_windows, non_rigid_window_centers = get_windows(
rigid, contact_depth, spatial_bin_centers, win_margin_um, win_step_um, win_scale_um, win_shape
non_rigid_windows, non_rigid_window_centers = get_spatial_windows(
contact_depth,
spatial_bin_centers,
rigid=rigid,
win_shape=win_shape,
win_step_um=win_step_um,
win_scale_um=win_scale_um,
win_margin_um=win_margin_um,
zero_threshold=None
)

# make 2D histogram raster
Expand Down
31 changes: 19 additions & 12 deletions src/spikeinterface/sortingcomponents/motion/dredge.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

import gc

from .motion_utils import Motion, get_windows, get_window_domains, scipy_conv1d, make_2d_motion_histogram
from .motion_utils import Motion, get_spatial_windows, get_window_domains, scipy_conv1d, make_2d_motion_histogram


# to discuss
Expand All @@ -42,6 +42,8 @@
# put smotthing inside the histogram function
# put the log for weight inhitstogram

# TODO maybe change everywhere bin_duration_s to bin_s


# simple class wrapper to be compliant with estimate_motion
class DredgeApRegistration:
Expand Down Expand Up @@ -221,6 +223,9 @@ def dredge_ap(
if time_horizon_s:
xcorr_kw["max_dt_bins"] = np.ceil(time_horizon_s / bin_s)

#TODO @charlie I think this is a bad to have the dict which is transported to every function
# this should be used only in histogram function but not in weight_correlation_matrix()
# only important kwargs should be explicitly reported
raster_kw = dict(
amp_scale_fn=amp_scale_fn,
post_transform=post_transform,
Expand Down Expand Up @@ -277,7 +282,7 @@ def dredge_ap(


# TODO @charlie check that we are doing the same thing
# windows, window_centers = get_windows(
# windows, window_centers = get_spatial_windows(
# np.c_[np.zeros_like(spatial_bin_edges_um), spatial_bin_edges_um],
# win_step_um,
# win_scale_um,
Expand All @@ -292,16 +297,18 @@ def dredge_ap(
contact_depth = recording.get_channel_locations()[:, dim]
spatial_bin_centers = 0.5 * (spatial_bin_edges_um[1:] + spatial_bin_edges_um[:-1])

windows, window_centers = get_windows(
rigid,
windows, window_centers = get_spatial_windows(
contact_depth,
spatial_bin_centers,
win_margin_um,
win_step_um,
win_scale_um,
win_shape,
zero_threshold=1e-5,
)
rigid=rigid,
win_shape=win_shape,
win_step_um=win_step_um,
win_scale_um=win_scale_um,
win_margin_um=win_margin_um,
zero_threshold=1e-5
)



# TODO charlie : put back the count
# if extra_outputs and count_masked_correlation:
Expand Down Expand Up @@ -553,10 +560,10 @@ def dredge_online_lfp(
# Important detail : in LFP bin center are contact position in the direction
spatial_bin_centers = contact_depth

windows, window_centers = get_windows(
rigid=rigid,
windows, window_centers = get_spatial_windows(
contact_depth=contact_depth,
spatial_bin_centers=spatial_bin_centers,
rigid=rigid,
win_margin_um=win_margin_um,
win_step_um=win_step_um,
win_scale_um=win_scale_um,
Expand Down
16 changes: 12 additions & 4 deletions src/spikeinterface/sortingcomponents/motion/iterative_template.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from .motion_utils import Motion, get_windows, get_spatial_bin_edges, make_3d_motion_histograms
from .motion_utils import Motion, get_spatial_windows, get_spatial_bin_edges, make_3d_motion_histograms


class IterativeTemplateRegistration:
Expand Down Expand Up @@ -78,18 +78,26 @@ def run(
):

dim = ["x", "y", "z"].index(direction)
contact_pos = recording.get_channel_locations()[:, dim]
contact_depth = recording.get_channel_locations()[:, dim]

# spatial histogram bins
spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um)
spatial_bin_centers = 0.5 * (spatial_bin_edges[1:] + spatial_bin_edges[:-1])

# get spatial windows
non_rigid_windows, non_rigid_window_centers = get_windows(
rigid, contact_pos, spatial_bin_centers, win_margin_um, win_step_um, win_scale_um, win_shape
non_rigid_windows, non_rigid_window_centers = get_spatial_windows(
contact_depth=contact_depth,
spatial_bin_centers=spatial_bin_centers,
rigid=rigid,
win_margin_um=win_margin_um,
win_step_um=win_step_um,
win_scale_um=win_scale_um,
win_shape=win_shape,
zero_threshold=None,
)



# make a 3D histogram
motion_histograms, temporal_hist_bin_edges, spatial_hist_bin_edges = make_3d_motion_histograms(
recording,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np


# TODO this need a full rewrite with motion object

def clean_motion_vector(motion, temporal_bins, bin_duration_s, speed_threshold=30, sigma_smooth_s=None):
"""
Expand Down
35 changes: 8 additions & 27 deletions src/spikeinterface/sortingcomponents/motion/motion_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from spikeinterface.sortingcomponents.tools import make_multi_method_doc


from .motion_utils import Motion, get_windows, get_spatial_bin_edges
from .motion_utils import Motion, get_spatial_windows, get_spatial_bin_edges
from .decentralized import DecentralizedRegistration
from .iterative_template import IterativeTemplateRegistration
from .dredge import DredgeLfpRegistration, DredgeApRegistration
Expand All @@ -19,13 +19,11 @@ def estimate_motion(
peaks=None,
peak_locations=None,
direction="y",
# bin_um=10.0,
# hist_margin_um=0.0,
rigid=False,
win_shape="gaussian",
win_step_um=50.0,
win_scale_um=150.0,
win_margin_um=0.,
win_margin_um=None,
method="decentralized",
extra_outputs=False,
progress_bar=False,
Expand All @@ -36,7 +34,8 @@ def estimate_motion(
"""
Estimate motion for given peaks and after their localization.
Note that the way you detect peak locations (center of mass/monopolar triangulation) have an impact on the result.
Note that the way you detect peak locations (center of mass/monopolar triangulation)
have an impact on the result.
Parameters
----------
Expand Down Expand Up @@ -70,14 +69,14 @@ def estimate_motion(
The depth domain will be broken up into windows with shape controlled by win_shape,
spaced by win_step_um at a margin of win_margin_um from the boundary, and with
width controlled by win_scale_um.
When win_margin_um is None the margin is automatically set to -win_scale_um/2.
See get_spatial_windows.
win_step_um : float, default: 50
See win_shape
win_scale_um : float, default: 150
See win_shape
win_margin_um : float, default: 0.
See win_shape
win_margin_um : None | float, default: None
See win_shape
extra_outputs: bool, default: False
If True then return an extra dict that contains variables
to check intermediate steps (motion_histogram, non_rigid_windows, pairwise_displacement)
Expand Down Expand Up @@ -114,25 +113,7 @@ def estimate_motion(
else:
extra = None

# contact positions
# probe = recording.get_probe()
# dim = ["x", "y", "z"].index(direction)
# contact_pos = probe.contact_positions[:, dim]

# # spatial histogram bins
# spatial_bin_edges = get_spatial_bin_edges(recording, direction, hist_margin_um, bin_um)
# spatial_bin_centers = 0.5 * (spatial_bin_edges[1:] + spatial_bin_edges[:-1])

# # get spatial windows
# non_rigid_windows, non_rigid_window_centers = get_windows(
# rigid, contact_pos, spatial_bin_centers, win_margin_um, win_step_um, win_scale_um, win_shape
# )

# if extra_outputs:
# extra["non_rigid_windows"] = non_rigid_windows

# run method

motion = method_class.run(
recording,
peaks,
Expand Down
43 changes: 32 additions & 11 deletions src/spikeinterface/sortingcomponents/motion/motion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,29 +232,45 @@ def copy(self):



def get_windows(rigid, contact_depth, spatial_bin_centers, win_margin_um, win_step_um, win_scale_um, win_shape,
zero_threshold=None):
def get_spatial_windows(
contact_depth,
spatial_bin_centers,
rigid=False,
win_shape="gaussian",
win_step_um=50.0,
win_scale_um=150.0,
win_margin_um=None,
zero_threshold=None
):
"""
Generate spatial windows (taper) for non-rigid motion.
For rigid motion, this is equivalent to have one unique rectangular window that covers the entire probe.
The windowing can be gaussian or rectangular.
Windows are centered between the min/max of contact_depth.
We can ensure window to not be to close from border with win_margin_um.
Parameters
----------
rigid : bool
If True, returns a single rectangular window
contact_depth : np.ndarray
Position of electrodes of the corection direction shape=(num_channels, )
spatial_bin_centers : np.array
The pre-computed spatial bin centers
win_margin_um : float
The margin to extend (if positive) or shrink (if negative) the probe dimension to compute windows.=
rigid : bool, default False
If True, returns a single rectangular window
win_shape : str, default "gaussian"
Shape of the window
"gaussian" | "rect" | "triangle"
win_step_um : float
The steps at which windows are defined
win_scale_um : float
Sigma of gaussian window (if win_shape is gaussian)
win_shape : float
"gaussian" | "rect"
win_scale_um : float, default 150.
Sigma of gaussian window if win_shape is gaussian
Width of the rectangle if win_shape is rect
win_margin_um : None | float, default None
The margin to extend (if positive) or shrink (if negative) the probe dimension to compute windows.
When None, then the margin is set to -win_scale_um./2
zero_threshold: None | float
Lower value for thresholding to set zeros.
Returns
-------
Expand All @@ -280,9 +296,14 @@ def get_windows(rigid, contact_depth, spatial_bin_centers, win_margin_um, win_st
else:
if win_scale_um <= win_step_um/5.:
warnings.warn(
f"get_windows(): spatial windows are probably not overlaping because {win_scale_um=} and {win_step_um=}"
f"get_spatial_windows(): spatial windows are probably not overlaping because {win_scale_um=} and {win_step_um=}"
)

# @charlie: I am pretty sure this is the best option
if win_margin_um is None:
# this ensure that first/last windows do not overflow outside the probe
win_margin_um = -win_scale_um / 2.

min_ = np.min(contact_depth) - win_margin_um
max_ = np.max(contact_depth) + win_margin_um
num_windows = int((max_ - min_) // win_step_um)
Expand Down

0 comments on commit 37014bb

Please sign in to comment.