Skip to content

Commit

Permalink
Add documentation.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Jul 24, 2024
1 parent 3337037 commit 9798c75
Showing 1 changed file with 125 additions and 12 deletions.
137 changes: 125 additions & 12 deletions src/spikeinterface/generation/session_displacement_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,63 @@ def generate_session_displacement_recordings(
extra_outputs=False,
seed=None,
):
""" """
"""
Generate a set of recordings simulating probe drift across recording
sessions.
Rigid drift can be added in the (x, y) direction in `recording_shifts`.
These drifts can be made non-rigid (scaled dependent on the unit location)
with the `non_rigid_gradient` parameter. Amplitude of units can be scaled
(e.g. template signal removed by scaling with zero) by specifying scaling
factors in `recording_amplitude_scalings`.
Parameters
----------
num_units : int
The number of units in the generated recordings.
recording_durations : list
An array of length (num_recordings,) specifying the
duration that each created recording should be.
recording_shifts : list
An array of length (num_recordings,) in which each element
is a 2-element array specifying the (x, y) shift for the recording.
Typically, the first recording will have shift (0, 0) so all further
recordings are shifted relative to it. e.g. to create two recordings,
the second shifted by 50 um in the x-direction and 250 um in the y
direction : ((0, 0), (50, 250)).
non_rigid_gradient : float
Factor which sets the level of non-rigidty in the displacement.
See `calculate_displacement_unit_factor` for details.
recording_amplitude_scalings : dict
A dict with keys:
"method" - order by which to apply the scalings.
"by_passed_order" - scalings are applied to the unit templates
in order passed
"by_firing_rate" - scalings are applied to the units in order of
maximum to minimum firing rate
"by_amplitude_and_firing_rate" - scalings are applied to the units
in order of amplitude * firing_rate (maximum to minimum)
"scalings" - a list of numpy arrays, one for each recording, with
each entry an array of length num_units holding the unit scalings.
e.g. for 3 recordings, 2 units: ((1, 1), (1, 1), (0.5, 0.5)).
All other parameters are used as in from `generate_drifting_recording()`.
Returns
-------
output_recordings : list
A list of recordings with units shifted (i.e. replicated probe shift).
output_sortings : list
A list of corresponding sorting objects.
extra_outputs_dict (options) : dict
When `extra_outputs` is `True`, a dict containing variables used
in the generation process.
"unit_locations" : A list (length num records) of shifted unit locations
"templates_array_moved" : list[np.array]
A list (length num records) of (num_units, num_samples, num_channels)
arrays of templates that have been shifted.
"""
_check_generate_session_displacement_arguments(
num_units, recording_durations, recording_shifts, recording_amplitude_scalings
)
Expand Down Expand Up @@ -82,7 +138,7 @@ def generate_session_displacement_recordings(

for rec_idx, (shift, duration) in enumerate(zip(recording_shifts, recording_durations)):

displacement_vector, displacement_unit_factor = get_inter_session_displacements(
displacement_vector, displacement_unit_factor = _get_inter_session_displacements(
shift,
non_rigid_gradient,
num_units,
Expand Down Expand Up @@ -114,7 +170,7 @@ def generate_session_displacement_recordings(
)

# Generate the (possibly shifted, scaled) unit templates
templates_moved_array = generate_templates(
template_array_moved = generate_templates(
channel_locations,
unit_locations_moved,
sampling_frequency=sampling_frequency,
Expand All @@ -124,8 +180,8 @@ def generate_session_displacement_recordings(

if recording_amplitude_scalings is not None:

templates_moved_array = amplitude_scale_templates_in_place(
templates_moved_array, recording_amplitude_scalings, sorting_extra_outputs, rec_idx
template_array_moved = _amplitude_scale_templates_in_place(
template_array_moved, recording_amplitude_scalings, sorting_extra_outputs, rec_idx
)

# Bring it all together in a `InjectTemplatesRecording` and
Expand All @@ -135,7 +191,7 @@ def generate_session_displacement_recordings(

recording = InjectTemplatesRecording(
sorting=sorting,
templates=templates_moved_array,
templates=template_array_moved,
nbefore=nbefore,
amplitude_factor=None,
parent_recording=noise,
Expand All @@ -152,19 +208,46 @@ def generate_session_displacement_recordings(
output_recordings.append(recording)
output_sortings.append(sorting)
extra_outputs_dict["unit_locations"].append(unit_locations_moved)
extra_outputs_dict["template_array_moved"].append(templates_moved_array)
extra_outputs_dict["template_array_moved"].append(template_array_moved)

if extra_outputs:
return output_recordings, output_sortings, extra_outputs_dict
else:
return output_recordings, output_sortings


def get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_locations):
""" """
def _get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_locations):
"""
Get the formatted `displacement_vector` and `displacement_unit_factor`
used to shift the `unit_locations`..
Parameters
---------
shift : np.array | list | tuple
A 2-element array with the shift in the (x, y) direction.
non_rigid_gradient : float
Factor which sets the level of non-rigidty in the displacement.
See `calculate_displacement_unit_factor` for details.
num_units : int
Number of units
unit_locations : np.array
(num_units, 3) array of unit locations (x, y, z).
Returns
-------
displacement_vector : np.array
A (:, 2) array of (x, y) of displacements
to add to (i.e. move) unit_locations.
e.g. array([[1, 2]])
displacement_unit_factor : np.array
A (num_units, :) array of scaling values to apply to the
displacement vector in order to add nonrigid shift to
the displacement. Note the same scaling is applied to the
x and y dimension.
"""
displacement_vector = np.atleast_2d(shift)

if non_rigid_gradient is None or shift == (0, 0):
if non_rigid_gradient is None or (shift[0] == 0 and shift[1] == 0):
displacement_unit_factor = np.ones((num_units, 1))
else:
displacement_unit_factor = calculate_displacement_unit_factor(
Expand All @@ -178,8 +261,38 @@ def get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_l
return displacement_vector, displacement_unit_factor


def amplitude_scale_templates_in_place(templates_array, recording_amplitude_scalings, sorting_extra_outputs, rec_idx):
""" """
def _amplitude_scale_templates_in_place(templates_array, recording_amplitude_scalings, sorting_extra_outputs, rec_idx):
"""
Scale a set of templates given a set of scaling values. The scaling
values can be applied in the order passed, or instead in order of
the unit firing range (max to min) or unit amplitude * firing rate (max to min).
This will chang the `templates_array` in place.
Parameters
----------
templates_array : np.array
A (num_units, num_samples, num_channels) array of
template waveforms for all units.
recording_amplitude_scalings : dict
see `generate_session_displacement_recordings()`.
sorting_extra_outputs : dict
Extra output of `generate_sorting` holding the firing frequency of all units.
The unit order is assumed to match the templates.
rec_idx : int
The index of the recording for which the templates are being scaled.
Notes
-----
This method is used in the context of inter-session displacement. Often,
units may drop out of the recording across sessions. This simulates this by
directly scaling the template (e.g. if scaling by 0, the template is completely
dropped out). The provided scalings can be applied in the order passed, or
in the order of unit firing rate or firing rate * amplitude. The idea is
that it may be desired to remove to downscale the most activate neurons
that contribute most significantly to activity histograms. Similarly,
if amplitude is included in activity histograms the amplitude may
also want to be considered when ordering the units to downscale.
"""
method = recording_amplitude_scalings["method"]

if method in ["by_amplitude_and_firing_rate", "by_firing_rate"]:
Expand Down

0 comments on commit 9798c75

Please sign in to comment.