diff --git a/debugging/debugging_session_displacement.py b/debugging/debugging_session_displacement.py index 74140c21c7..8b79f1f1a6 100644 --- a/debugging/debugging_session_displacement.py +++ b/debugging/debugging_session_displacement.py @@ -50,9 +50,9 @@ "method": "by_amplitude_and_firing_rate", "scalings": scale_, }, - generate_sorting_kwargs=dict(firing_rates=(149, 150), refractory_period_ms=4.0), + generate_sorting_kwargs=dict(firing_rates=(0, 200), refractory_period_ms=4.0), generate_templates_kwargs=dict(unit_params=default_unit_params_range, ms_before=1.5, ms_after=3), - seed=44, + seed=None, generate_unit_locations_kwargs=dict( margin_um=0.0, # if this is say 20, then units go off the edge of the probe and are such low amplitude they are not picked up. minimum_z=5.0, diff --git a/src/spikeinterface/generation/session_displacement_generator.py b/src/spikeinterface/generation/session_displacement_generator.py index 17e5028584..0c17bc062e 100644 --- a/src/spikeinterface/generation/session_displacement_generator.py +++ b/src/spikeinterface/generation/session_displacement_generator.py @@ -12,10 +12,15 @@ ) import numpy as np from spikeinterface.generation.noise_tools import generate_noise -from spikeinterface.core.generate import setup_inject_templates_recording +from spikeinterface.core.generate import setup_inject_templates_recording, _ensure_firing_rates from spikeinterface.core import InjectTemplatesRecording +# TODO: add note on what is fixed / not fixed across sessions +# TODO: tests are failing because of mutable default arguments. +# will need to fix this before proceeding. + + def generate_session_displacement_recordings( num_units=250, recording_durations=(10, 10, 10), @@ -87,7 +92,8 @@ def generate_session_displacement_recordings( "scalings" - a list of numpy arrays, one for each recording, with each entry an array of length num_units holding the unit scalings. e.g. for 3 recordings, 2 units: ((1, 1), (1, 1), (0.5, 0.5)). - + generate_sorting_kwargs : dict + Only `firing_rates` and `refractory_period_ms` are expected if passed. All other parameters are used as in from `generate_drifting_recording()`. Returns @@ -103,6 +109,19 @@ def generate_session_displacement_recordings( "templates_array_moved" : list[np.array] A list (length num records) of (num_units, num_samples, num_channels) arrays of templates that have been shifted. + + + Notes + ----- + It is important to consider what unit properties are maintained + across the session. Here, all `generate_template_kwargs` are fixed + across sessions, to be sure the unit properties do not change. + The firing rates passed to `generate_sorting` for each unit are + also fixed across sessions. When a seed is set, the exact spike times + will also be fixed across recordings. otherwise, when seed is `None` + the actual spike times will be different across recordings, although + all other unit properties will be maintained (except any location + shifting and template scaling applied). """ _check_generate_session_displacement_arguments( num_units, recording_durations, recording_shifts, recording_amplitude_scalings @@ -120,13 +139,18 @@ def generate_session_displacement_recordings( ) # Fix generate template kwargs, so they are the same for every created recording. + # Also fix unit firing rates across recordings. generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed) + fixed_firing_rates = _ensure_firing_rates(generate_sorting_kwargs["firing_rates"], num_units, seed) + generate_sorting_kwargs["firing_rates"] = fixed_firing_rates + # Start looping over parameters, creating recordings shifted # and scaled as required extra_outputs_dict = { "unit_locations": [], "templates_array_moved": [], + "firing_rates": [], } output_recordings = [] output_sortings = [] @@ -173,9 +197,16 @@ def generate_session_displacement_recordings( **generate_templates_kwargs, ) + # TODO: these first amplitdues don't change per loop, but are usually not + # needed... if recording_amplitude_scalings is not None: + + first_rec_templates = ( + templates_array_moved if rec_idx == 0 else extra_outputs_dict["templates_array_moved"][0] + ) + _amplitude_scale_templates_in_place( - templates_array_moved, recording_amplitude_scalings, sorting_extra_outputs, rec_idx + first_rec_templates, templates_array_moved, recording_amplitude_scalings, sorting_extra_outputs, rec_idx ) # Bring it all together in a `InjectTemplatesRecording` and @@ -203,6 +234,7 @@ def generate_session_displacement_recordings( output_sortings.append(sorting) extra_outputs_dict["unit_locations"].append(unit_locations_moved) extra_outputs_dict["templates_array_moved"].append(templates_array_moved) + extra_outputs_dict["firing_rates"].append(sorting_extra_outputs["firing_rates"][0]) if extra_outputs: return output_recordings, output_sortings, extra_outputs_dict @@ -255,7 +287,9 @@ def _get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_ return displacement_vector, displacement_unit_factor -def _amplitude_scale_templates_in_place(templates_array, recording_amplitude_scalings, sorting_extra_outputs, rec_idx): +def _amplitude_scale_templates_in_place( + first_rec_templates, moved_templates, recording_amplitude_scalings, sorting_extra_outputs, rec_idx +): """ Scale a set of templates given a set of scaling values. The scaling values can be applied in the order passed, or instead in order of @@ -264,9 +298,13 @@ def _amplitude_scale_templates_in_place(templates_array, recording_amplitude_sca Parameters ---------- - templates_array : np.array - A (num_units, num_samples, num_channels) array of - template waveforms for all units. + first_rec_templates : np.array + The (num_units, num_samples, num_channels) templates array from the + first recording. Scaling by amplitude scales based on the amplitudes in + the first session. + moved_templates : np.array + A (num_units, num_samples, num_channels) array moved templates to the + current recording, that will be scaled. recording_amplitude_scalings : dict see `generate_session_displacement_recordings()`. sorting_extra_outputs : dict @@ -294,12 +332,12 @@ def _amplitude_scale_templates_in_place(templates_array, recording_amplitude_sca firing_rates_hz = sorting_extra_outputs["firing_rates"][0] if method == "by_amplitude_and_firing_rate": - neg_ampl = np.min(np.min(templates_array, axis=2), axis=1) + neg_ampl = np.min(np.min(first_rec_templates, axis=2), axis=1) + assert np.all(neg_ampl < 0), "assumes all amplitudes are negative here." score = firing_rates_hz * neg_ampl else: score = firing_rates_hz - assert np.all(score < 0), "assumes all amplitudes are negative here." order_idx = np.argsort(score) ordered_rec_scalings = recording_amplitude_scalings["scalings"][rec_idx][order_idx, np.newaxis, np.newaxis] @@ -310,7 +348,7 @@ def _amplitude_scale_templates_in_place(templates_array, recording_amplitude_sca else: raise ValueError("`recording_amplitude_scalings['method']` not recognised.") - templates_array *= ordered_rec_scalings + moved_templates *= ordered_rec_scalings def _check_generate_session_displacement_arguments( diff --git a/src/spikeinterface/generation/tests/test_session_displacement_generator.py b/src/spikeinterface/generation/tests/test_session_displacement_generator.py index 645b1a6cce..d00a071a1c 100644 --- a/src/spikeinterface/generation/tests/test_session_displacement_generator.py +++ b/src/spikeinterface/generation/tests/test_session_displacement_generator.py @@ -12,7 +12,7 @@ class TestSessionDisplacementGenerator: """ This class tests the `generate_session_displacement_recordings` that returns a recordings / sorting in which the units are shifted - across sessions. This is acheived by shifting the unit locations + across sessions. This is achieved by shifting the unit locations in both (x, y) on the generated templates that are used in `InjectTemplatesRecording()`. """ @@ -136,7 +136,7 @@ def test_recordings_length(self, options): for rec, expected_rec_length in zip(output_recordings, options["kwargs"]["recording_durations"]): assert rec.get_total_duration() == expected_rec_length - def test_spike_times_across_recordings(self, options): + def test_spike_times_and_firing_rates_across_recordings(self, options): """ Check the randomisation of spike times across recordings. When a seed is set, this is passed to `generate_sorting` @@ -146,14 +146,17 @@ def test_spike_times_across_recordings(self, options): """ options["kwargs"]["recording_durations"] = (10,) * options["num_recs"] - output_sortings_same = generate_session_displacement_recordings(**options["kwargs"])[1] + output_sortings_same, extra_outputs_same = generate_session_displacement_recordings(**options["kwargs"])[1:3] options["kwargs"]["seed"] = None - output_sortings_different = generate_session_displacement_recordings(**options["kwargs"])[1] + output_sortings_different, extra_outputs_different = generate_session_displacement_recordings( + **options["kwargs"] + )[1:3] for unit_idx in range(options["kwargs"]["num_units"]): for rec_idx in range(1, options["num_recs"]): + # Exact spike times are not preserved when seed is None assert np.array_equal( output_sortings_same[0].get_unit_spike_train(unit_idx), output_sortings_same[rec_idx].get_unit_spike_train(unit_idx), @@ -162,6 +165,15 @@ def test_spike_times_across_recordings(self, options): output_sortings_different[0].get_unit_spike_train(unit_idx), output_sortings_different[rec_idx].get_unit_spike_train(unit_idx), ) + # Firing rates should always be preserved. + assert np.array_equal( + extra_outputs_same["firing_rates"][0][unit_idx], + extra_outputs_same["firing_rates"][rec_idx][unit_idx], + ) + assert np.array_equal( + extra_outputs_different["firing_rates"][0][unit_idx], + extra_outputs_different["firing_rates"][rec_idx][unit_idx], + ) @pytest.mark.parametrize("dim_idx", [0, 1]) def test_x_y_shift_non_rigid(self, options, dim_idx): @@ -271,32 +283,70 @@ def test_displacement_with_peak_detection(self, options): assert np.isclose(new_pos, first_pos + y_shift, rtol=0, atol=options["y_bin_um"]) def test_amplitude_scalings(self, options): - + """ + Test that the templates are scaled by the passed scaling factors + in the specified order. The order can be in the passed order, + in the order of highest-to-lowest firing unit, or in the order + of (amplitude * firing_rate) (highest to lowest unit). + """ + # Setup arguments to create an unshifted set of recordings + # where the templates are to be scaled with `true_scalings` options["kwargs"]["recording_durations"] = (10, 10) options["kwargs"]["recording_shifts"] = ((0, 0), (0, 0)) options["kwargs"]["num_units"] == 5, + true_scalings = np.array([0.1, 0.2, 0.3, 0.4, 0.5]) + recording_amplitude_scalings = { "method": "by_passed_order", - "scalings": (np.ones(5), np.array([0.1, 0.2, 0.3, 0.4, 0.5])), + "scalings": (np.ones(5), true_scalings), } _, output_sortings, extra_outputs = generate_session_displacement_recordings( **options["kwargs"], recording_amplitude_scalings=recording_amplitude_scalings, ) - breakpoint() - first, second = extra_outputs["templates_array_moved"] # TODO: own function - first_min = np.min(np.min(first, axis=2), axis=1) - second_min = np.min(np.min(second, axis=2), axis=1) - scales = second_min / first_min - assert np.allclose(scales, shifts) + # Check that the unit templates are scaled in the order + # the scalings were passed. + test_scalings = self._calculate_scalings_from_output(extra_outputs) + assert np.allclose(test_scalings, true_scalings) + + # Now run, again applying the scalings in the order of + # unit firing rates (highest to lowest). + firing_rates = np.array([5, 4, 3, 2, 1]) + generate_sorting_kwargs = dict(firing_rates=firing_rates, refractory_period_ms=4.0) + recording_amplitude_scalings["method"] = "by_firing_rate" + _, output_sortings, extra_outputs = generate_session_displacement_recordings( + **options["kwargs"], + recording_amplitude_scalings=recording_amplitude_scalings, + generate_sorting_kwargs=generate_sorting_kwargs, + ) + + test_scalings = self._calculate_scalings_from_output(extra_outputs) + assert np.allclose(test_scalings, true_scalings[np.argsort(firing_rates)]) - # TODO: scale based on recording output - # check scaled by amplitude. + # Finally, run again applying the scalings in the order of + # unit amplitude * firing_rate + recording_amplitude_scalings["method"] = "by_amplitude_and_firing_rate" # TODO: method -> order + amplitudes = np.min(np.min(extra_outputs["templates_array_moved"][0], axis=2), axis=1) + firing_rate_by_amplitude = np.argsort(amplitudes * firing_rates) - breakpoint() + _, output_sortings, extra_outputs = generate_session_displacement_recordings( + **options["kwargs"], + recording_amplitude_scalings=recording_amplitude_scalings, + generate_sorting_kwargs=generate_sorting_kwargs, + ) + + test_scalings = self._calculate_scalings_from_output(extra_outputs) + assert np.allclose(test_scalings, true_scalings[firing_rate_by_amplitude]) + + def _calculate_scalings_from_output(self, extra_outputs): + first, second = extra_outputs["templates_array_moved"] + first_min = np.min(np.min(first, axis=2), axis=1) + second_min = np.min(np.min(second, axis=2), axis=1) + test_scalings = second_min / first_min + return test_scalings def test_metadata(self, options): """ @@ -339,7 +389,7 @@ def test_same_as_generate_ground_truth_recording(self): generate_probe_kwargs = None generate_unit_locations_kwargs = dict() generate_templates_kwargs = dict(ms_before=1.5, ms_after=3) - generate_sorting_kwargs = dict() + generate_sorting_kwargs = dict(firing_rates=1) generate_noise_kwargs = dict() seed = 42