Skip to content

Commit

Permalink
Add shift_units_outside_probe.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Aug 28, 2024
1 parent c471c21 commit 05a460a
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 7 deletions.
124 changes: 117 additions & 7 deletions src/spikeinterface/generation/session_displacement_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def generate_session_displacement_recordings(
recording_shifts=((0, 0), (0, 25), (0, 50)),
non_rigid_gradient=None,
recording_amplitude_scalings=None,
shift_units_outside_probe=False,
sampling_frequency=30000.0,
probe_name="Neuropixel-128",
generate_probe_kwargs=None,
Expand Down Expand Up @@ -87,8 +88,16 @@ def generate_session_displacement_recordings(
"scalings" - a list of numpy arrays, one for each recording, with
each entry an array of length num_units holding the unit scalings.
e.g. for 3 recordings, 2 units: ((1, 1), (1, 1), (0.5, 0.5)).
shift_units_outside_probe : bool
By default (`False`), when units are shifted across sessions, new units are
not introduced into the recording (e.g. the region in which units
have been shifted out of is left at baseline level). In reality,
when the probe shifts new units from outside the original recorded
region are shifted into the recording. When `True`, new units
are shifted into the generated recording.
generate_sorting_kwargs : dict
Only `firing_rates` and `refractory_period_ms` are expected if passed.
All other parameters are used as in from `generate_drifting_recording()`.
Returns
Expand All @@ -105,7 +114,6 @@ def generate_session_displacement_recordings(
A list (length num records) of (num_units, num_samples, num_channels)
arrays of templates that have been shifted.
Notes
-----
It is important to consider what unit properties are maintained
Expand Down Expand Up @@ -141,12 +149,28 @@ def generate_session_displacement_recordings(

# Fix generate template kwargs, so they are the same for every created recording.
# Also fix unit firing rates across recordings.
generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed)
fixed_generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed)

fixed_firing_rates = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed)
generate_sorting_kwargs["firing_rates"] = fixed_firing_rates
fixed_generate_sorting_kwargs = copy.deepcopy(generate_sorting_kwargs)
fixed_generate_sorting_kwargs["firing_rates"] = fixed_firing_rates

if shift_units_outside_probe:
num_units, unit_locations, fixed_generate_templates_kwargs, fixed_generate_sorting_kwargs = (
_update_kwargs_for_extended_units(
num_units,
channel_locations,
unit_locations,
generate_unit_locations_kwargs,
generate_templates_kwargs,
generate_sorting_kwargs,
fixed_generate_templates_kwargs,
fixed_generate_sorting_kwargs,
seed,
)
)

# Start looping over parameters, creating recordings shifted
# Start looping over parameters, creating recordings shifted
# and scaled as required
extra_outputs_dict = {
"unit_locations": [],
Expand Down Expand Up @@ -174,7 +198,7 @@ def generate_session_displacement_recordings(
num_units=num_units,
sampling_frequency=sampling_frequency,
durations=[duration],
**generate_sorting_kwargs,
**fixed_generate_sorting_kwargs,
extra_outputs=True,
seed=seed,
)
Expand All @@ -195,7 +219,7 @@ def generate_session_displacement_recordings(
unit_locations_moved,
sampling_frequency=sampling_frequency,
seed=seed,
**generate_templates_kwargs,
**fixed_generate_templates_kwargs,
)

if recording_amplitude_scalings is not None:
Expand All @@ -210,7 +234,7 @@ def generate_session_displacement_recordings(

# Bring it all together in a `InjectTemplatesRecording` and
# propagate all relevant metadata to the recording.
ms_before = generate_templates_kwargs["ms_before"]
ms_before = fixed_generate_templates_kwargs["ms_before"]
nbefore = int(sampling_frequency * ms_before / 1000.0)

recording = InjectTemplatesRecording(
Expand Down Expand Up @@ -388,3 +412,89 @@ def _check_generate_session_displacement_arguments(
"The entry for each recording in `recording_amplitude_scalings` "
"must have the same length as the number of units."
)


def _update_kwargs_for_extended_units(
num_units,
channel_locations,
unit_locations,
generate_unit_locations_kwargs,
generate_templates_kwargs,
generate_sorting_kwargs,
fixed_generate_templates_kwargs,
fixed_generate_sorting_kwargs,
seed,
):
"""
In a real world situation, if the probe moves up / down
not only will previously recorded units be shifted, but
new units will be introduced into the recording.
This function extends the default num units, unit locations,
and template / sorting kwargs to extend the unit of units
one probe's height (y dimension) above and below the probe.
Now, when the unit locations are shifted, new units will be
introduced into the recording (from below or above).
It is important that the unit kwargs for the units are kept the
same across runs when seeded (i.e. whether `shift_units_outside_probe`
is `True` or `False`). To acheive this, the fixed unit kwargs
are extended with new units located above and below these fixed
units. The seeds are shifted slightly, so the introduced
units do not duplicate the existing units.
"""
seed_top = seed + 1 if seed is not None else None
seed_bottom = seed - 1 if seed is not None else None

# Set unit locations above and below the probe and extend
# the `unit_locations` array.
channel_locations_extend_top = channel_locations.copy()
channel_locations_extend_top[:, 1] -= np.max(channel_locations[:, 1])

extend_top_locations = generate_unit_locations(
num_units,
channel_locations_extend_top,
seed=seed_top,
**generate_unit_locations_kwargs,
)

channel_locations_extend_bottom = channel_locations.copy()
channel_locations_extend_bottom[:, 1] += np.max(channel_locations[:, 1])

extend_bottom_locations = generate_unit_locations(
num_units,
channel_locations_extend_bottom,
seed=seed_bottom,
**generate_unit_locations_kwargs,
)

unit_locations = np.r_[extend_bottom_locations, unit_locations, extend_top_locations]

# For the new units located above and below the probe, generate a set of
# firing rates and template kwargs.

# Extend the template kwargs
template_kwargs_top = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed_top)
template_kwargs_bottom = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed_bottom)

for key in fixed_generate_templates_kwargs["unit_params"].keys():
fixed_generate_templates_kwargs["unit_params"][key] = np.r_[
template_kwargs_top["unit_params"][key],
fixed_generate_templates_kwargs["unit_params"][key],
template_kwargs_bottom["unit_params"][key],
]

# Extend the firing rates
firing_rates_top = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed_top)
firing_rates_bottom = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed_bottom)

fixed_generate_sorting_kwargs["firing_rates"] = np.r_[
firing_rates_top, fixed_generate_sorting_kwargs["firing_rates"], firing_rates_bottom
]

# Update the number of units (3x as a
# new set above and below the existing units)
num_units *= 3

return num_units, unit_locations, fixed_generate_templates_kwargs, fixed_generate_sorting_kwargs
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,48 @@ def test_metadata(self, options):
)
assert output_sortings[i].name == "InterSessionDisplacementSorting"

def test_shift_units_outside_probe(self, options):
"""
When `shift_units_outside_probe` is `True`, a new set of
units above and below the probe (y dimension) are created,
such that they may be shifted into the recording.
Here, check that these new units are created when `shift_units_outside_probe`
is on and that the kwargs for the central set of units match those
as when `shift_units_outside_probe` is `False`.
"""
num_sessions = len(options["kwargs"]["recording_durations"])
_, _, baseline_outputs = generate_session_displacement_recordings(
**options["kwargs"],
)

_, _, outside_probe_outputs = generate_session_displacement_recordings(
**options["kwargs"], shift_units_outside_probe=True
)

num_units = options["kwargs"]["num_units"]
num_extended_units = num_units * 3

for ses_idx in range(num_sessions):

# There are 3x the number of units when new units are created
# (one new set above, and one new set below the probe).
for key in ["unit_locations", "templates_array_moved", "firing_rates"]:
assert outside_probe_outputs[key][ses_idx].shape[0] == num_extended_units

assert np.array_equal(
baseline_outputs[key][ses_idx], outside_probe_outputs[key][ses_idx][num_units:-num_units]
)

# The kwargs of the units in the central positions should be identical
# to those when `shift_units_outside_probe` is `False`.
lower_unit_pos = outside_probe_outputs["unit_locations"][ses_idx][-num_units:][:, 1]
upper_unit_pos = outside_probe_outputs["unit_locations"][ses_idx][:num_units][:, 1]
middle_unit_pos = baseline_outputs["unit_locations"][ses_idx][:, 1]

assert np.min(upper_unit_pos) > np.max(middle_unit_pos)
assert np.max(lower_unit_pos) < np.min(middle_unit_pos)

def test_same_as_generate_ground_truth_recording(self):
"""
It is expected that inter-session displacement randomly
Expand Down

0 comments on commit 05a460a

Please sign in to comment.