From 9798c75d7490633e13f34324b31ce5837edd775e Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 24 Jul 2024 21:25:43 +0100 Subject: [PATCH] Add documentation. --- .../session_displacement_generator.py | 137 ++++++++++++++++-- 1 file changed, 125 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/generation/session_displacement_generator.py b/src/spikeinterface/generation/session_displacement_generator.py index 9bb30b994a..6cdd22aa2f 100644 --- a/src/spikeinterface/generation/session_displacement_generator.py +++ b/src/spikeinterface/generation/session_displacement_generator.py @@ -52,7 +52,63 @@ def generate_session_displacement_recordings( extra_outputs=False, seed=None, ): - """ """ + """ + Generate a set of recordings simulating probe drift across recording + sessions. + + Rigid drift can be added in the (x, y) direction in `recording_shifts`. + These drifts can be made non-rigid (scaled dependent on the unit location) + with the `non_rigid_gradient` parameter. Amplitude of units can be scaled + (e.g. template signal removed by scaling with zero) by specifying scaling + factors in `recording_amplitude_scalings`. + + Parameters + ---------- + + num_units : int + The number of units in the generated recordings. + recording_durations : list + An array of length (num_recordings,) specifying the + duration that each created recording should be. + recording_shifts : list + An array of length (num_recordings,) in which each element + is a 2-element array specifying the (x, y) shift for the recording. + Typically, the first recording will have shift (0, 0) so all further + recordings are shifted relative to it. e.g. to create two recordings, + the second shifted by 50 um in the x-direction and 250 um in the y + direction : ((0, 0), (50, 250)). + non_rigid_gradient : float + Factor which sets the level of non-rigidty in the displacement. + See `calculate_displacement_unit_factor` for details. + recording_amplitude_scalings : dict + A dict with keys: + "method" - order by which to apply the scalings. + "by_passed_order" - scalings are applied to the unit templates + in order passed + "by_firing_rate" - scalings are applied to the units in order of + maximum to minimum firing rate + "by_amplitude_and_firing_rate" - scalings are applied to the units + in order of amplitude * firing_rate (maximum to minimum) + "scalings" - a list of numpy arrays, one for each recording, with + each entry an array of length num_units holding the unit scalings. + e.g. for 3 recordings, 2 units: ((1, 1), (1, 1), (0.5, 0.5)). + + All other parameters are used as in from `generate_drifting_recording()`. + + Returns + ------- + output_recordings : list + A list of recordings with units shifted (i.e. replicated probe shift). + output_sortings : list + A list of corresponding sorting objects. + extra_outputs_dict (options) : dict + When `extra_outputs` is `True`, a dict containing variables used + in the generation process. + "unit_locations" : A list (length num records) of shifted unit locations + "templates_array_moved" : list[np.array] + A list (length num records) of (num_units, num_samples, num_channels) + arrays of templates that have been shifted. + """ _check_generate_session_displacement_arguments( num_units, recording_durations, recording_shifts, recording_amplitude_scalings ) @@ -82,7 +138,7 @@ def generate_session_displacement_recordings( for rec_idx, (shift, duration) in enumerate(zip(recording_shifts, recording_durations)): - displacement_vector, displacement_unit_factor = get_inter_session_displacements( + displacement_vector, displacement_unit_factor = _get_inter_session_displacements( shift, non_rigid_gradient, num_units, @@ -114,7 +170,7 @@ def generate_session_displacement_recordings( ) # Generate the (possibly shifted, scaled) unit templates - templates_moved_array = generate_templates( + template_array_moved = generate_templates( channel_locations, unit_locations_moved, sampling_frequency=sampling_frequency, @@ -124,8 +180,8 @@ def generate_session_displacement_recordings( if recording_amplitude_scalings is not None: - templates_moved_array = amplitude_scale_templates_in_place( - templates_moved_array, recording_amplitude_scalings, sorting_extra_outputs, rec_idx + template_array_moved = _amplitude_scale_templates_in_place( + template_array_moved, recording_amplitude_scalings, sorting_extra_outputs, rec_idx ) # Bring it all together in a `InjectTemplatesRecording` and @@ -135,7 +191,7 @@ def generate_session_displacement_recordings( recording = InjectTemplatesRecording( sorting=sorting, - templates=templates_moved_array, + templates=template_array_moved, nbefore=nbefore, amplitude_factor=None, parent_recording=noise, @@ -152,7 +208,7 @@ def generate_session_displacement_recordings( output_recordings.append(recording) output_sortings.append(sorting) extra_outputs_dict["unit_locations"].append(unit_locations_moved) - extra_outputs_dict["template_array_moved"].append(templates_moved_array) + extra_outputs_dict["template_array_moved"].append(template_array_moved) if extra_outputs: return output_recordings, output_sortings, extra_outputs_dict @@ -160,11 +216,38 @@ def generate_session_displacement_recordings( return output_recordings, output_sortings -def get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_locations): - """ """ +def _get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_locations): + """ + Get the formatted `displacement_vector` and `displacement_unit_factor` + used to shift the `unit_locations`.. + + Parameters + --------- + shift : np.array | list | tuple + A 2-element array with the shift in the (x, y) direction. + non_rigid_gradient : float + Factor which sets the level of non-rigidty in the displacement. + See `calculate_displacement_unit_factor` for details. + num_units : int + Number of units + unit_locations : np.array + (num_units, 3) array of unit locations (x, y, z). + + Returns + ------- + displacement_vector : np.array + A (:, 2) array of (x, y) of displacements + to add to (i.e. move) unit_locations. + e.g. array([[1, 2]]) + displacement_unit_factor : np.array + A (num_units, :) array of scaling values to apply to the + displacement vector in order to add nonrigid shift to + the displacement. Note the same scaling is applied to the + x and y dimension. + """ displacement_vector = np.atleast_2d(shift) - if non_rigid_gradient is None or shift == (0, 0): + if non_rigid_gradient is None or (shift[0] == 0 and shift[1] == 0): displacement_unit_factor = np.ones((num_units, 1)) else: displacement_unit_factor = calculate_displacement_unit_factor( @@ -178,8 +261,38 @@ def get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_l return displacement_vector, displacement_unit_factor -def amplitude_scale_templates_in_place(templates_array, recording_amplitude_scalings, sorting_extra_outputs, rec_idx): - """ """ +def _amplitude_scale_templates_in_place(templates_array, recording_amplitude_scalings, sorting_extra_outputs, rec_idx): + """ + Scale a set of templates given a set of scaling values. The scaling + values can be applied in the order passed, or instead in order of + the unit firing range (max to min) or unit amplitude * firing rate (max to min). + This will chang the `templates_array` in place. + + Parameters + ---------- + templates_array : np.array + A (num_units, num_samples, num_channels) array of + template waveforms for all units. + recording_amplitude_scalings : dict + see `generate_session_displacement_recordings()`. + sorting_extra_outputs : dict + Extra output of `generate_sorting` holding the firing frequency of all units. + The unit order is assumed to match the templates. + rec_idx : int + The index of the recording for which the templates are being scaled. + + Notes + ----- + This method is used in the context of inter-session displacement. Often, + units may drop out of the recording across sessions. This simulates this by + directly scaling the template (e.g. if scaling by 0, the template is completely + dropped out). The provided scalings can be applied in the order passed, or + in the order of unit firing rate or firing rate * amplitude. The idea is + that it may be desired to remove to downscale the most activate neurons + that contribute most significantly to activity histograms. Similarly, + if amplitude is included in activity histograms the amplitude may + also want to be considered when ordering the units to downscale. + """ method = recording_amplitude_scalings["method"] if method in ["by_amplitude_and_firing_rate", "by_firing_rate"]: