Skip to content

Commit

Permalink
Finalise and tidy up tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Jul 25, 2024
1 parent 5f012a2 commit 4dc065f
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 28 deletions.
4 changes: 2 additions & 2 deletions debugging/debugging_session_displacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@
"method": "by_amplitude_and_firing_rate",
"scalings": scale_,
},
generate_sorting_kwargs=dict(firing_rates=(149, 150), refractory_period_ms=4.0),
generate_sorting_kwargs=dict(firing_rates=(0, 200), refractory_period_ms=4.0),
generate_templates_kwargs=dict(unit_params=default_unit_params_range, ms_before=1.5, ms_after=3),
seed=44,
seed=None,
generate_unit_locations_kwargs=dict(
margin_um=0.0, # if this is say 20, then units go off the edge of the probe and are such low amplitude they are not picked up.
minimum_z=5.0,
Expand Down
58 changes: 48 additions & 10 deletions src/spikeinterface/generation/session_displacement_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@
)
import numpy as np
from spikeinterface.generation.noise_tools import generate_noise
from spikeinterface.core.generate import setup_inject_templates_recording
from spikeinterface.core.generate import setup_inject_templates_recording, _ensure_firing_rates
from spikeinterface.core import InjectTemplatesRecording


# TODO: add note on what is fixed / not fixed across sessions
# TODO: tests are failing because of mutable default arguments.
# will need to fix this before proceeding.


def generate_session_displacement_recordings(
num_units=250,
recording_durations=(10, 10, 10),
Expand Down Expand Up @@ -87,7 +92,8 @@ def generate_session_displacement_recordings(
"scalings" - a list of numpy arrays, one for each recording, with
each entry an array of length num_units holding the unit scalings.
e.g. for 3 recordings, 2 units: ((1, 1), (1, 1), (0.5, 0.5)).
generate_sorting_kwargs : dict
Only `firing_rates` and `refractory_period_ms` are expected if passed.
All other parameters are used as in from `generate_drifting_recording()`.
Returns
Expand All @@ -103,6 +109,19 @@ def generate_session_displacement_recordings(
"templates_array_moved" : list[np.array]
A list (length num records) of (num_units, num_samples, num_channels)
arrays of templates that have been shifted.
Notes
-----
It is important to consider what unit properties are maintained
across the session. Here, all `generate_template_kwargs` are fixed
across sessions, to be sure the unit properties do not change.
The firing rates passed to `generate_sorting` for each unit are
also fixed across sessions. When a seed is set, the exact spike times
will also be fixed across recordings. otherwise, when seed is `None`
the actual spike times will be different across recordings, although
all other unit properties will be maintained (except any location
shifting and template scaling applied).
"""
_check_generate_session_displacement_arguments(
num_units, recording_durations, recording_shifts, recording_amplitude_scalings
Expand All @@ -120,13 +139,18 @@ def generate_session_displacement_recordings(
)

# Fix generate template kwargs, so they are the same for every created recording.
# Also fix unit firing rates across recordings.
generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed)

fixed_firing_rates = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed)
generate_sorting_kwargs["firing_rates"] = fixed_firing_rates

# Start looping over parameters, creating recordings shifted
# and scaled as required
extra_outputs_dict = {
"unit_locations": [],
"templates_array_moved": [],
"firing_rates": [],
}
output_recordings = []
output_sortings = []
Expand Down Expand Up @@ -173,9 +197,16 @@ def generate_session_displacement_recordings(
**generate_templates_kwargs,
)

# TODO: these first amplitdues don't change per loop, but are usually not
# needed...
if recording_amplitude_scalings is not None:

first_rec_templates = (
templates_array_moved if rec_idx == 0 else extra_outputs_dict["templates_array_moved"][0]
)

_amplitude_scale_templates_in_place(
templates_array_moved, recording_amplitude_scalings, sorting_extra_outputs, rec_idx
first_rec_templates, templates_array_moved, recording_amplitude_scalings, sorting_extra_outputs, rec_idx
)

# Bring it all together in a `InjectTemplatesRecording` and
Expand Down Expand Up @@ -203,6 +234,7 @@ def generate_session_displacement_recordings(
output_sortings.append(sorting)
extra_outputs_dict["unit_locations"].append(unit_locations_moved)
extra_outputs_dict["templates_array_moved"].append(templates_array_moved)
extra_outputs_dict["firing_rates"].append(sorting_extra_outputs["firing_rates"][0])

if extra_outputs:
return output_recordings, output_sortings, extra_outputs_dict
Expand Down Expand Up @@ -255,7 +287,9 @@ def _get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_
return displacement_vector, displacement_unit_factor


def _amplitude_scale_templates_in_place(templates_array, recording_amplitude_scalings, sorting_extra_outputs, rec_idx):
def _amplitude_scale_templates_in_place(
first_rec_templates, moved_templates, recording_amplitude_scalings, sorting_extra_outputs, rec_idx
):
"""
Scale a set of templates given a set of scaling values. The scaling
values can be applied in the order passed, or instead in order of
Expand All @@ -264,9 +298,13 @@ def _amplitude_scale_templates_in_place(templates_array, recording_amplitude_sca
Parameters
----------
templates_array : np.array
A (num_units, num_samples, num_channels) array of
template waveforms for all units.
first_rec_templates : np.array
The (num_units, num_samples, num_channels) templates array from the
first recording. Scaling by amplitude scales based on the amplitudes in
the first session.
moved_templates : np.array
A (num_units, num_samples, num_channels) array moved templates to the
current recording, that will be scaled.
recording_amplitude_scalings : dict
see `generate_session_displacement_recordings()`.
sorting_extra_outputs : dict
Expand Down Expand Up @@ -294,12 +332,12 @@ def _amplitude_scale_templates_in_place(templates_array, recording_amplitude_sca
firing_rates_hz = sorting_extra_outputs["firing_rates"][0]

if method == "by_amplitude_and_firing_rate":
neg_ampl = np.min(np.min(templates_array, axis=2), axis=1)
neg_ampl = np.min(np.min(first_rec_templates, axis=2), axis=1)
assert np.all(neg_ampl < 0), "assumes all amplitudes are negative here."
score = firing_rates_hz * neg_ampl
else:
score = firing_rates_hz

assert np.all(score < 0), "assumes all amplitudes are negative here."
order_idx = np.argsort(score)
ordered_rec_scalings = recording_amplitude_scalings["scalings"][rec_idx][order_idx, np.newaxis, np.newaxis]

Expand All @@ -310,7 +348,7 @@ def _amplitude_scale_templates_in_place(templates_array, recording_amplitude_sca
else:
raise ValueError("`recording_amplitude_scalings['method']` not recognised.")

templates_array *= ordered_rec_scalings
moved_templates *= ordered_rec_scalings


def _check_generate_session_displacement_arguments(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class TestSessionDisplacementGenerator:
"""
This class tests the `generate_session_displacement_recordings` that
returns a recordings / sorting in which the units are shifted
across sessions. This is acheived by shifting the unit locations
across sessions. This is achieved by shifting the unit locations
in both (x, y) on the generated templates that are used in
`InjectTemplatesRecording()`.
"""
Expand Down Expand Up @@ -136,7 +136,7 @@ def test_recordings_length(self, options):
for rec, expected_rec_length in zip(output_recordings, options["kwargs"]["recording_durations"]):
assert rec.get_total_duration() == expected_rec_length

def test_spike_times_across_recordings(self, options):
def test_spike_times_and_firing_rates_across_recordings(self, options):
"""
Check the randomisation of spike times across recordings.
When a seed is set, this is passed to `generate_sorting`
Expand All @@ -146,14 +146,17 @@ def test_spike_times_across_recordings(self, options):
"""
options["kwargs"]["recording_durations"] = (10,) * options["num_recs"]

output_sortings_same = generate_session_displacement_recordings(**options["kwargs"])[1]
output_sortings_same, extra_outputs_same = generate_session_displacement_recordings(**options["kwargs"])[1:3]

options["kwargs"]["seed"] = None
output_sortings_different = generate_session_displacement_recordings(**options["kwargs"])[1]
output_sortings_different, extra_outputs_different = generate_session_displacement_recordings(
**options["kwargs"]
)[1:3]

for unit_idx in range(options["kwargs"]["num_units"]):
for rec_idx in range(1, options["num_recs"]):

# Exact spike times are not preserved when seed is None
assert np.array_equal(
output_sortings_same[0].get_unit_spike_train(unit_idx),
output_sortings_same[rec_idx].get_unit_spike_train(unit_idx),
Expand All @@ -162,6 +165,15 @@ def test_spike_times_across_recordings(self, options):
output_sortings_different[0].get_unit_spike_train(unit_idx),
output_sortings_different[rec_idx].get_unit_spike_train(unit_idx),
)
# Firing rates should always be preserved.
assert np.array_equal(
extra_outputs_same["firing_rates"][0][unit_idx],
extra_outputs_same["firing_rates"][rec_idx][unit_idx],
)
assert np.array_equal(
extra_outputs_different["firing_rates"][0][unit_idx],
extra_outputs_different["firing_rates"][rec_idx][unit_idx],
)

@pytest.mark.parametrize("dim_idx", [0, 1])
def test_x_y_shift_non_rigid(self, options, dim_idx):
Expand Down Expand Up @@ -271,32 +283,70 @@ def test_displacement_with_peak_detection(self, options):
assert np.isclose(new_pos, first_pos + y_shift, rtol=0, atol=options["y_bin_um"])

def test_amplitude_scalings(self, options):

"""
Test that the templates are scaled by the passed scaling factors
in the specified order. The order can be in the passed order,
in the order of highest-to-lowest firing unit, or in the order
of (amplitude * firing_rate) (highest to lowest unit).
"""
# Setup arguments to create an unshifted set of recordings
# where the templates are to be scaled with `true_scalings`
options["kwargs"]["recording_durations"] = (10, 10)
options["kwargs"]["recording_shifts"] = ((0, 0), (0, 0))
options["kwargs"]["num_units"] == 5,

true_scalings = np.array([0.1, 0.2, 0.3, 0.4, 0.5])

recording_amplitude_scalings = {
"method": "by_passed_order",
"scalings": (np.ones(5), np.array([0.1, 0.2, 0.3, 0.4, 0.5])),
"scalings": (np.ones(5), true_scalings),
}

_, output_sortings, extra_outputs = generate_session_displacement_recordings(
**options["kwargs"],
recording_amplitude_scalings=recording_amplitude_scalings,
)
breakpoint()
first, second = extra_outputs["templates_array_moved"] # TODO: own function
first_min = np.min(np.min(first, axis=2), axis=1)
second_min = np.min(np.min(second, axis=2), axis=1)
scales = second_min / first_min

assert np.allclose(scales, shifts)
# Check that the unit templates are scaled in the order
# the scalings were passed.
test_scalings = self._calculate_scalings_from_output(extra_outputs)
assert np.allclose(test_scalings, true_scalings)

# Now run, again applying the scalings in the order of
# unit firing rates (highest to lowest).
firing_rates = np.array([5, 4, 3, 2, 1])
generate_sorting_kwargs = dict(firing_rates=firing_rates, refractory_period_ms=4.0)
recording_amplitude_scalings["method"] = "by_firing_rate"
_, output_sortings, extra_outputs = generate_session_displacement_recordings(
**options["kwargs"],
recording_amplitude_scalings=recording_amplitude_scalings,
generate_sorting_kwargs=generate_sorting_kwargs,
)

test_scalings = self._calculate_scalings_from_output(extra_outputs)
assert np.allclose(test_scalings, true_scalings[np.argsort(firing_rates)])

# TODO: scale based on recording output
# check scaled by amplitude.
# Finally, run again applying the scalings in the order of
# unit amplitude * firing_rate
recording_amplitude_scalings["method"] = "by_amplitude_and_firing_rate" # TODO: method -> order
amplitudes = np.min(np.min(extra_outputs["templates_array_moved"][0], axis=2), axis=1)
firing_rate_by_amplitude = np.argsort(amplitudes * firing_rates)

breakpoint()
_, output_sortings, extra_outputs = generate_session_displacement_recordings(
**options["kwargs"],
recording_amplitude_scalings=recording_amplitude_scalings,
generate_sorting_kwargs=generate_sorting_kwargs,
)

test_scalings = self._calculate_scalings_from_output(extra_outputs)
assert np.allclose(test_scalings, true_scalings[firing_rate_by_amplitude])

def _calculate_scalings_from_output(self, extra_outputs):
first, second = extra_outputs["templates_array_moved"]
first_min = np.min(np.min(first, axis=2), axis=1)
second_min = np.min(np.min(second, axis=2), axis=1)
test_scalings = second_min / first_min
return test_scalings

def test_metadata(self, options):
"""
Expand Down Expand Up @@ -339,7 +389,7 @@ def test_same_as_generate_ground_truth_recording(self):
generate_probe_kwargs = None
generate_unit_locations_kwargs = dict()
generate_templates_kwargs = dict(ms_before=1.5, ms_after=3)
generate_sorting_kwargs = dict()
generate_sorting_kwargs = dict(firing_rates=1)
generate_noise_kwargs = dict()
seed = 42

Expand Down

0 comments on commit 4dc065f

Please sign in to comment.