diff --git a/src/spikeinterface/generation/session_displacement_generator.py b/src/spikeinterface/generation/session_displacement_generator.py index 6cdd22aa2f..17e5028584 100644 --- a/src/spikeinterface/generation/session_displacement_generator.py +++ b/src/spikeinterface/generation/session_displacement_generator.py @@ -16,11 +16,6 @@ from spikeinterface.core import InjectTemplatesRecording -# TODO: test metadata -# TOOD: test new amplitude scalings -# TODO: test correct unit_locations are on the sortings (part of metadata) - - def generate_session_displacement_recordings( num_units=250, recording_durations=(10, 10, 10), @@ -131,7 +126,7 @@ def generate_session_displacement_recordings( # and scaled as required extra_outputs_dict = { "unit_locations": [], - "template_array_moved": [], + "templates_array_moved": [], } output_recordings = [] output_sortings = [] @@ -170,7 +165,7 @@ def generate_session_displacement_recordings( ) # Generate the (possibly shifted, scaled) unit templates - template_array_moved = generate_templates( + templates_array_moved = generate_templates( channel_locations, unit_locations_moved, sampling_frequency=sampling_frequency, @@ -179,9 +174,8 @@ def generate_session_displacement_recordings( ) if recording_amplitude_scalings is not None: - - template_array_moved = _amplitude_scale_templates_in_place( - template_array_moved, recording_amplitude_scalings, sorting_extra_outputs, rec_idx + _amplitude_scale_templates_in_place( + templates_array_moved, recording_amplitude_scalings, sorting_extra_outputs, rec_idx ) # Bring it all together in a `InjectTemplatesRecording` and @@ -191,7 +185,7 @@ def generate_session_displacement_recordings( recording = InjectTemplatesRecording( sorting=sorting, - templates=template_array_moved, + templates=templates_array_moved, nbefore=nbefore, amplitude_factor=None, parent_recording=noise, @@ -208,7 +202,7 @@ def generate_session_displacement_recordings( output_recordings.append(recording) output_sortings.append(sorting) extra_outputs_dict["unit_locations"].append(unit_locations_moved) - extra_outputs_dict["template_array_moved"].append(template_array_moved) + extra_outputs_dict["templates_array_moved"].append(templates_array_moved) if extra_outputs: return output_recordings, output_sortings, extra_outputs_dict @@ -344,12 +338,13 @@ def _check_generate_session_displacement_arguments( if not "method" in keys or not "scalings" in keys: raise ValueError("`recording_amplitude_scalings` must be a dict " "with keys `method` and `scalings`.") - allowed_methods = ["by_passed_value", "by_amplitude_and_firing_rate", "by_firing_rate"] + allowed_methods = ["by_passed_order", "by_amplitude_and_firing_rate", "by_firing_rate"] if not recording_amplitude_scalings["method"] in allowed_methods: raise ValueError(f"`recording_amplitude_scalings` must be one of {allowed_methods}") rec_scalings = recording_amplitude_scalings["scalings"] if not len(rec_scalings) == expected_num_recs: + breakpoint() raise ValueError("`recording_amplitude_scalings` 'scalings' " "must have one array per recording.") if not all([len(scale) == num_units for scale in rec_scalings]): diff --git a/src/spikeinterface/generation/tests/test_session_displacement_generator.py b/src/spikeinterface/generation/tests/test_session_displacement_generator.py index 8bf689df78..645b1a6cce 100644 --- a/src/spikeinterface/generation/tests/test_session_displacement_generator.py +++ b/src/spikeinterface/generation/tests/test_session_displacement_generator.py @@ -7,11 +7,6 @@ from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_localization import localize_peaks -# TODO: test templates_array_moved are the same with -# no shift, with both seed and no seed - -# rescale units per session - class TestSessionDisplacementGenerator: """ @@ -97,14 +92,14 @@ def test_x_y_rigid_shifts_are_properly_set(self, options): for unit_idx in range(num_units): start_pos = self._get_peak_chan_loc_in_um( - extra_outputs["template_array_moved"][0][unit_idx], + extra_outputs["templates_array_moved"][0][unit_idx], options["y_bin_um"], ) for rec_idx in range(1, options["num_recs"]): new_pos = self._get_peak_chan_loc_in_um( - extra_outputs["template_array_moved"][rec_idx][unit_idx], options["y_bin_um"] + extra_outputs["templates_array_moved"][rec_idx][unit_idx], options["y_bin_um"] ) y_shift = recording_shifts[rec_idx][1] @@ -120,7 +115,7 @@ def test_x_y_rigid_shifts_are_properly_set(self, options): for rec_idx in range(options["num_recs"]): assert np.array_equal( output_recordings[rec_idx].templates, - extra_outputs["template_array_moved"][rec_idx], + extra_outputs["templates_array_moved"][rec_idx], ) def _get_peak_chan_loc_in_um(self, template_array, y_bin_um): @@ -275,6 +270,56 @@ def test_displacement_with_peak_detection(self, options): assert np.isclose(new_pos, first_pos + y_shift, rtol=0, atol=options["y_bin_um"]) + def test_amplitude_scalings(self, options): + + options["kwargs"]["recording_durations"] = (10, 10) + options["kwargs"]["recording_shifts"] = ((0, 0), (0, 0)) + options["kwargs"]["num_units"] == 5, + + recording_amplitude_scalings = { + "method": "by_passed_order", + "scalings": (np.ones(5), np.array([0.1, 0.2, 0.3, 0.4, 0.5])), + } + + _, output_sortings, extra_outputs = generate_session_displacement_recordings( + **options["kwargs"], + recording_amplitude_scalings=recording_amplitude_scalings, + ) + breakpoint() + first, second = extra_outputs["templates_array_moved"] # TODO: own function + first_min = np.min(np.min(first, axis=2), axis=1) + second_min = np.min(np.min(second, axis=2), axis=1) + scales = second_min / first_min + + assert np.allclose(scales, shifts) + + # TODO: scale based on recording output + # check scaled by amplitude. + + breakpoint() + + def test_metadata(self, options): + """ + Check that metadata required to be set of generated recordings is present + on all output recordings. + """ + output_recordings, output_sortings, extra_outputs = generate_session_displacement_recordings( + **options["kwargs"], generate_noise_kwargs=dict(noise_levels=(1.0, 2.0), spatial_decay=1.0) + ) + num_chans = output_recordings[0].get_num_channels() + + for i in range(len(output_recordings)): + assert output_recordings[i].name == "InterSessionDisplacementRecording" + assert output_recordings[i]._annotations["is_filtered"] is True + assert output_recordings[i].has_probe() + assert np.array_equal(output_recordings[i].get_channel_gains(), np.ones(num_chans)) + assert np.array_equal(output_recordings[i].get_channel_offsets(), np.zeros(num_chans)) + + assert np.array_equal( + output_sortings[i].get_property("gt_unit_locations"), extra_outputs["unit_locations"][i] + ) + assert output_sortings[i].name == "InterSessionDisplacementSorting" + def test_same_as_generate_ground_truth_recording(self): """ It is expected that inter-session displacement randomly @@ -302,7 +347,7 @@ def test_same_as_generate_ground_truth_recording(self): no_shift_recording, _ = generate_session_displacement_recordings( num_units=num_units, recording_durations=[duration], - recording_shifts=((0, 0)), + recording_shifts=((0, 0),), sampling_frequency=sampling_frequency, probe_name=probe_name, generate_probe_kwargs=generate_probe_kwargs,