From 5186f74a1bfdafc30749a9a2ec1f99c5a28ea4d0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 28 Jun 2024 11:09:05 +0200 Subject: [PATCH] Sam+Charlie's suggestions: estimate one motion vector per unit (plus other suggestions) --- src/spikeinterface/generation/drift_tools.py | 26 ++++----- src/spikeinterface/generation/hybrid_tools.py | 56 +++++++++++++------ src/spikeinterface/generation/noise_tools.py | 15 +++-- .../generation/tests/test_hybrid_tools.py | 6 +- .../sortingcomponents/motion_utils.py | 4 +- 5 files changed, 61 insertions(+), 46 deletions(-) diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index 52861ae44a..1f410f4330 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -118,21 +118,15 @@ class DriftingTemplates(Templates): This is the same strategy used by MEArec. """ - def __init__(self, **kwargs): - has_templates_moved = "templates_array_moved" in kwargs - has_displacements = "displacements" in kwargs - precomputed = has_templates_moved or has_displacements - if precomputed and not (has_templates_moved and has_displacements): - raise ValueError( - "Please pass both template_array_moved and displacements to DriftingTemplates " - "if you are using precomputed displaceed templates." - ) - templates_array_moved = kwargs.pop("templates_array_moved", None) - displacements = kwargs.pop("displacements", None) - - Templates.__init__(self, **kwargs) + def __init__(self, templates_array_moved=None, displacements=None, **static_kwargs): + Templates.__init__(self, **static_kwargs) assert self.probe is not None, "DriftingTemplates need a Probe in the init" - + if templates_array_moved is not None: + if displacements is None: + raise ValueError( + "Please pass both template_array_moved and displacements to DriftingTemplates " + "if you are using precomputed displaced templates." + ) self.templates_array_moved = templates_array_moved self.displacements = displacements @@ -192,11 +186,11 @@ def from_precomputed_templates( templates_static = templates_array_moved[templates_array_moved.shape[0] // 2] return cls( templates_array=templates_static, - templates_array_moved=templates_array_moved, - displacements=displacements, sampling_frequency=sampling_frequency, nbefore=nbefore, probe=probe, + templates_array_moved=templates_array_moved, + displacements=displacements, ) def move_one_template(self, unit_index, displacement, **interpolation_kwargs): diff --git a/src/spikeinterface/generation/hybrid_tools.py b/src/spikeinterface/generation/hybrid_tools.py index a8427ca06d..6a0c702001 100644 --- a/src/spikeinterface/generation/hybrid_tools.py +++ b/src/spikeinterface/generation/hybrid_tools.py @@ -321,6 +321,7 @@ def generate_hybrid_recording( motion: Motion | None = None, templates_in_uV: bool = True, unit_locations: np.ndarray | None = None, + drift_step_um: float = 1.0, upsample_factor: int | None = None, upsample_vector: np.ndarray | None = None, amplitude_std: float = 0.05, @@ -361,7 +362,9 @@ def generate_hybrid_recording( Cut out in ms after spike peak. unit_locations : np.array, default: None The locations at which the templates should be injected. If not provided, generated (see - generate_unit_location_kwargs) + generate_unit_location_kwargs). + drift_step_um : float, default: 1.0 + The step in um to use for the drifting templates. upsample_factor : None or int, default: None A upsampling factor used only when templates are not provided. upsample_vector : np.array or None @@ -512,31 +515,46 @@ def generate_hybrid_recording( stop = np.array([0, np.max(motion_array_concat)]) elif dim == 2: raise NotImplementedError("3D motion not implemented yet") - displacements = make_linear_displacement(start, stop, num_step=int((stop - start)[dim])) + num_step = int((stop - start)[dim] / drift_step_um) + displacements = make_linear_displacement(start, stop, num_step=num_step) # use templates_, because templates_array might have been scaled drifting_templates = DriftingTemplates.from_static_templates(templates_) drifting_templates.precompute_displacements(displacements) - # calculate displacement vectors for each segment + # calculate displacement vectors for each segment and unit + # for each unit, we interpolate the motion at its location displacement_sampling_frequency = 1.0 / np.diff(motion.temporal_bins_s[0])[0] displacement_vectors = [] - spatial_bins_um = motion.spatial_bins_um for segment_index in range(motion.num_segments): temporal_bins_segment = motion.temporal_bins_s[segment_index] - displacement_segment = motion.displacement[segment_index] - displacement_vector = np.zeros((len(temporal_bins_segment), 2, len(spatial_bins_um))) - - for count, i in enumerate(spatial_bins_um): - local_motion = displacement_segment[:, count] - displacement_vector[:, motion.dim, count] = local_motion - displacement_vectors.append(displacement_vector) - - # calculate displacement unit factor - displacement_unit_factor = np.zeros((num_units, len(spatial_bins_um))) - for count in range(num_units): - a = 1 / np.abs((unit_locations[count, motion.dim] - spatial_bins_um)) - displacement_unit_factor[count] = a / a.sum() + displacement_vector = np.zeros((len(temporal_bins_segment), 2, num_units)) + for unit_index in range(num_units): + motion_for_unit = motion.get_displacement_at_time_and_depth( + times=temporal_bins_segment, + locations_um=unit_locations[unit_index], + segment_index=segment_index, + ) + displacement_vector[:, motion.dim, unit_index] = motion_for_unit + # since displacement is estimated by interpolation for each unit, the unit factor is an eye + displacement_unit_factor = np.eye(num_units) + + # spatial_bins_um = motion.spatial_bins_um + # for segment_index in range(motion.num_segments): + # temporal_bins_segment = motion.temporal_bins_s[segment_index] + # displacement_segment = motion.displacement[segment_index] + # displacement_vector = np.zeros((len(temporal_bins_segment), 2, len(spatial_bins_um))) + + # for count, i in enumerate(spatial_bins_um): + # local_motion = displacement_segment[:, count] + # displacement_vector[:, motion.dim, count] = local_motion + # displacement_vectors.append(displacement_vector) + + # # calculate displacement unit factor + # displacement_unit_factor = np.zeros((num_units, len(spatial_bins_um))) + # for count in range(num_units): + # a = 1 / np.abs((unit_locations[count, motion.dim] - spatial_bins_um)) + # displacement_unit_factor[count] = a / a.sum() hybrid_recording = InjectDriftingTemplatesRecording( sorting=sorting, @@ -550,6 +568,10 @@ def generate_hybrid_recording( ) else: + warnings.warn( + "No Motion is provided! Please check that your recording is drift-free, otherwise the hybrid recording " + "will have stationary units over a drifting recording..." + ) hybrid_recording = InjectTemplatesRecording( sorting, templates_array, diff --git a/src/spikeinterface/generation/noise_tools.py b/src/spikeinterface/generation/noise_tools.py index 48555b3062..11f30e352f 100644 --- a/src/spikeinterface/generation/noise_tools.py +++ b/src/spikeinterface/generation/noise_tools.py @@ -10,25 +10,24 @@ def generate_noise( Parameters ---------- - probe: Probe + probe : Probe A probe object. - sampling_frequency: float + sampling_frequency : float Sampling frequency - durations: list of float + durations : list of float Durations - dtype: np.dtype + dtype : np.dtype Dtype - noise_levels: float | np.array | tuple + noise_levels : float | np.array | tuple If scalar same noises on all channels. If array then per channels noise level. If tuple, then this represent the range. - - seed: None | int + seed : None | int The seed for random generator. Returns ------- - noise: NoiseGeneratorRecording + noise : NoiseGeneratorRecording A lazy noise generator recording. """ diff --git a/src/spikeinterface/generation/tests/test_hybrid_tools.py b/src/spikeinterface/generation/tests/test_hybrid_tools.py index 7939d0b70e..d31a0ec81d 100644 --- a/src/spikeinterface/generation/tests/test_hybrid_tools.py +++ b/src/spikeinterface/generation/tests/test_hybrid_tools.py @@ -34,11 +34,9 @@ def test_generate_hybrid_with_sorting(): assert sorting_hybrid.get_num_units() == len(hybrid.templates) -def test_generate_hybrid_motion(create_cache_folder): - cache_folder = create_cache_folder +def test_generate_hybrid_motion(): rec, _ = generate_ground_truth_recording(sampling_frequency=20000, durations=[10], seed=0) - correct_motion(rec, folder=cache_folder / "motion") - motion_info = load_motion_info(cache_folder / "motion") + _, motion_info = correct_motion(rec, output_motion_info=True) motion = motion_info["motion"] hybrid, sorting_hybrid = generate_hybrid_recording(rec, motion=motion, seed=0) assert rec.get_num_channels() == hybrid.get_num_channels() diff --git a/src/spikeinterface/sortingcomponents/motion_utils.py b/src/spikeinterface/sortingcomponents/motion_utils.py index cb8241d310..39991f4e52 100644 --- a/src/spikeinterface/sortingcomponents/motion_utils.py +++ b/src/spikeinterface/sortingcomponents/motion_utils.py @@ -90,10 +90,12 @@ def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_inde Parameters ---------- times_s: np.array + The time points at which to evaluate the displacement. locations_um: np.array Either this is a one-dimensional array (a vector of positions along self.dimension), or else a 2d array with the 2 or 3 spatial dimensions indexed along axis=1. - segment_index: int, optional + segment_index: int, default: None + The index of the segment to evaluate. If None, and there is only one segment, then that segment is used. grid : bool If grid=False, the default, then times_s and locations_um should have the same one-dimensional shape, and the returned displacement[i] is the displacement at time times_s[i] and location