Skip to content

Commit

Permalink
Sam+Charlie's suggestions: estimate one motion vector per unit (plus …
Browse files Browse the repository at this point in the history
…other suggestions)
  • Loading branch information
alejoe91 committed Jun 28, 2024
1 parent 406ebcf commit 5186f74
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 46 deletions.
26 changes: 10 additions & 16 deletions src/spikeinterface/generation/drift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,21 +118,15 @@ class DriftingTemplates(Templates):
This is the same strategy used by MEArec.
"""

def __init__(self, **kwargs):
has_templates_moved = "templates_array_moved" in kwargs
has_displacements = "displacements" in kwargs
precomputed = has_templates_moved or has_displacements
if precomputed and not (has_templates_moved and has_displacements):
raise ValueError(
"Please pass both template_array_moved and displacements to DriftingTemplates "
"if you are using precomputed displaceed templates."
)
templates_array_moved = kwargs.pop("templates_array_moved", None)
displacements = kwargs.pop("displacements", None)

Templates.__init__(self, **kwargs)
def __init__(self, templates_array_moved=None, displacements=None, **static_kwargs):
Templates.__init__(self, **static_kwargs)
assert self.probe is not None, "DriftingTemplates need a Probe in the init"

if templates_array_moved is not None:
if displacements is None:
raise ValueError(
"Please pass both template_array_moved and displacements to DriftingTemplates "
"if you are using precomputed displaced templates."
)
self.templates_array_moved = templates_array_moved
self.displacements = displacements

Expand Down Expand Up @@ -192,11 +186,11 @@ def from_precomputed_templates(
templates_static = templates_array_moved[templates_array_moved.shape[0] // 2]
return cls(
templates_array=templates_static,
templates_array_moved=templates_array_moved,
displacements=displacements,
sampling_frequency=sampling_frequency,
nbefore=nbefore,
probe=probe,
templates_array_moved=templates_array_moved,
displacements=displacements,
)

def move_one_template(self, unit_index, displacement, **interpolation_kwargs):
Expand Down
56 changes: 39 additions & 17 deletions src/spikeinterface/generation/hybrid_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ def generate_hybrid_recording(
motion: Motion | None = None,
templates_in_uV: bool = True,
unit_locations: np.ndarray | None = None,
drift_step_um: float = 1.0,
upsample_factor: int | None = None,
upsample_vector: np.ndarray | None = None,
amplitude_std: float = 0.05,
Expand Down Expand Up @@ -361,7 +362,9 @@ def generate_hybrid_recording(
Cut out in ms after spike peak.
unit_locations : np.array, default: None
The locations at which the templates should be injected. If not provided, generated (see
generate_unit_location_kwargs)
generate_unit_location_kwargs).
drift_step_um : float, default: 1.0
The step in um to use for the drifting templates.
upsample_factor : None or int, default: None
A upsampling factor used only when templates are not provided.
upsample_vector : np.array or None
Expand Down Expand Up @@ -512,31 +515,46 @@ def generate_hybrid_recording(
stop = np.array([0, np.max(motion_array_concat)])
elif dim == 2:
raise NotImplementedError("3D motion not implemented yet")
displacements = make_linear_displacement(start, stop, num_step=int((stop - start)[dim]))
num_step = int((stop - start)[dim] / drift_step_um)
displacements = make_linear_displacement(start, stop, num_step=num_step)

# use templates_, because templates_array might have been scaled
drifting_templates = DriftingTemplates.from_static_templates(templates_)
drifting_templates.precompute_displacements(displacements)

# calculate displacement vectors for each segment
# calculate displacement vectors for each segment and unit
# for each unit, we interpolate the motion at its location
displacement_sampling_frequency = 1.0 / np.diff(motion.temporal_bins_s[0])[0]
displacement_vectors = []
spatial_bins_um = motion.spatial_bins_um
for segment_index in range(motion.num_segments):
temporal_bins_segment = motion.temporal_bins_s[segment_index]
displacement_segment = motion.displacement[segment_index]
displacement_vector = np.zeros((len(temporal_bins_segment), 2, len(spatial_bins_um)))

for count, i in enumerate(spatial_bins_um):
local_motion = displacement_segment[:, count]
displacement_vector[:, motion.dim, count] = local_motion
displacement_vectors.append(displacement_vector)

# calculate displacement unit factor
displacement_unit_factor = np.zeros((num_units, len(spatial_bins_um)))
for count in range(num_units):
a = 1 / np.abs((unit_locations[count, motion.dim] - spatial_bins_um))
displacement_unit_factor[count] = a / a.sum()
displacement_vector = np.zeros((len(temporal_bins_segment), 2, num_units))
for unit_index in range(num_units):
motion_for_unit = motion.get_displacement_at_time_and_depth(
times=temporal_bins_segment,
locations_um=unit_locations[unit_index],
segment_index=segment_index,
)
displacement_vector[:, motion.dim, unit_index] = motion_for_unit
# since displacement is estimated by interpolation for each unit, the unit factor is an eye
displacement_unit_factor = np.eye(num_units)

# spatial_bins_um = motion.spatial_bins_um
# for segment_index in range(motion.num_segments):
# temporal_bins_segment = motion.temporal_bins_s[segment_index]
# displacement_segment = motion.displacement[segment_index]
# displacement_vector = np.zeros((len(temporal_bins_segment), 2, len(spatial_bins_um)))

# for count, i in enumerate(spatial_bins_um):
# local_motion = displacement_segment[:, count]
# displacement_vector[:, motion.dim, count] = local_motion
# displacement_vectors.append(displacement_vector)

# # calculate displacement unit factor
# displacement_unit_factor = np.zeros((num_units, len(spatial_bins_um)))
# for count in range(num_units):
# a = 1 / np.abs((unit_locations[count, motion.dim] - spatial_bins_um))
# displacement_unit_factor[count] = a / a.sum()

hybrid_recording = InjectDriftingTemplatesRecording(
sorting=sorting,
Expand All @@ -550,6 +568,10 @@ def generate_hybrid_recording(
)

else:
warnings.warn(
"No Motion is provided! Please check that your recording is drift-free, otherwise the hybrid recording "
"will have stationary units over a drifting recording..."
)
hybrid_recording = InjectTemplatesRecording(
sorting,
templates_array,
Expand Down
15 changes: 7 additions & 8 deletions src/spikeinterface/generation/noise_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,24 @@ def generate_noise(
Parameters
----------
probe: Probe
probe : Probe
A probe object.
sampling_frequency: float
sampling_frequency : float
Sampling frequency
durations: list of float
durations : list of float
Durations
dtype: np.dtype
dtype : np.dtype
Dtype
noise_levels: float | np.array | tuple
noise_levels : float | np.array | tuple
If scalar same noises on all channels.
If array then per channels noise level.
If tuple, then this represent the range.
seed: None | int
seed : None | int
The seed for random generator.
Returns
-------
noise: NoiseGeneratorRecording
noise : NoiseGeneratorRecording
A lazy noise generator recording.
"""

Expand Down
6 changes: 2 additions & 4 deletions src/spikeinterface/generation/tests/test_hybrid_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,9 @@ def test_generate_hybrid_with_sorting():
assert sorting_hybrid.get_num_units() == len(hybrid.templates)


def test_generate_hybrid_motion(create_cache_folder):
cache_folder = create_cache_folder
def test_generate_hybrid_motion():
rec, _ = generate_ground_truth_recording(sampling_frequency=20000, durations=[10], seed=0)
correct_motion(rec, folder=cache_folder / "motion")
motion_info = load_motion_info(cache_folder / "motion")
_, motion_info = correct_motion(rec, output_motion_info=True)
motion = motion_info["motion"]
hybrid, sorting_hybrid = generate_hybrid_recording(rec, motion=motion, seed=0)
assert rec.get_num_channels() == hybrid.get_num_channels()
Expand Down
4 changes: 3 additions & 1 deletion src/spikeinterface/sortingcomponents/motion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,12 @@ def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_inde
Parameters
----------
times_s: np.array
The time points at which to evaluate the displacement.
locations_um: np.array
Either this is a one-dimensional array (a vector of positions along self.dimension), or
else a 2d array with the 2 or 3 spatial dimensions indexed along axis=1.
segment_index: int, optional
segment_index: int, default: None
The index of the segment to evaluate. If None, and there is only one segment, then that segment is used.
grid : bool
If grid=False, the default, then times_s and locations_um should have the same one-dimensional
shape, and the returned displacement[i] is the displacement at time times_s[i] and location
Expand Down

0 comments on commit 5186f74

Please sign in to comment.