From 333703769f0c5b5e0b640b497033a193d268e376 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Wed, 24 Jul 2024 20:23:00 +0100 Subject: [PATCH] Add input checks. --- debugging/debugging_session_displacement.py | 6 +- .../session_displacement_generator.py | 71 +++++++++++++------ 2 files changed, 52 insertions(+), 25 deletions(-) diff --git a/debugging/debugging_session_displacement.py b/debugging/debugging_session_displacement.py index 79ba64cbc5..74140c21c7 100644 --- a/debugging/debugging_session_displacement.py +++ b/debugging/debugging_session_displacement.py @@ -40,13 +40,13 @@ rec_list, _ = generate_session_displacement_recordings( non_rigid_gradient=None, # 0.05, TODO: note this will set nonlinearity to both x and y (the same) num_units=5, - rec_durations=(25, 25, 25), # TODO: checks on inputs - rec_shifts=( + recording_durations=(25, 25, 25), # TODO: checks on inputs + recording_shifts=( (0, 0), (0, 0), (0, 0), ), - amplitude_scaling_options= { + recording_amplitude_scalings= { "method": "by_amplitude_and_firing_rate", "scalings": scale_, }, diff --git a/src/spikeinterface/generation/session_displacement_generator.py b/src/spikeinterface/generation/session_displacement_generator.py index 33203c0747..9bb30b994a 100644 --- a/src/spikeinterface/generation/session_displacement_generator.py +++ b/src/spikeinterface/generation/session_displacement_generator.py @@ -18,7 +18,7 @@ # TODO: test metadata # TOOD: test new amplitude scalings -# TODO: +# TODO: test correct unit_locations are on the sortings (part of metadata) def generate_session_displacement_recordings( @@ -53,7 +53,7 @@ def generate_session_displacement_recordings( seed=None, ): """ """ - _check_generate_session_displacement_inputs( + _check_generate_session_displacement_arguments( num_units, recording_durations, recording_shifts, recording_amplitude_scalings ) @@ -68,17 +68,18 @@ def generate_session_displacement_recordings( **generate_unit_locations_kwargs, ) - # Fix generate template kwargs so they - # are the same for every created recording. + # Fix generate template kwargs, so they are the same for every created recording. generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed) - output_recordings = [] - output_sortings = [] - + # Start looping over parameters, creating recordings shifted + # and scaled as required extra_outputs_dict = { "unit_locations": [], "template_array_moved": [], } + output_recordings = [] + output_sortings = [] + for rec_idx, (shift, duration) in enumerate(zip(recording_shifts, recording_durations)): displacement_vector, displacement_unit_factor = get_inter_session_displacements( @@ -179,11 +180,13 @@ def get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_l def amplitude_scale_templates_in_place(templates_array, recording_amplitude_scalings, sorting_extra_outputs, rec_idx): """ """ - if recording_amplitude_scalings["method"] in ["by_amplitude_and_firing_rate", "by_firing_rate"]: + method = recording_amplitude_scalings["method"] + + if method in ["by_amplitude_and_firing_rate", "by_firing_rate"]: firing_rates_hz = sorting_extra_outputs["firing_rates"][0] - if recording_amplitude_scalings["method"] == "by_amplitude_and_firing_rate": + if method == "by_amplitude_and_firing_rate": neg_ampl = np.min(np.min(templates_array, axis=2), axis=1) score = firing_rates_hz * neg_ampl else: @@ -193,27 +196,51 @@ def amplitude_scale_templates_in_place(templates_array, recording_amplitude_scal order_idx = np.argsort(score) ordered_rec_scalings = recording_amplitude_scalings["scalings"][rec_idx][order_idx, np.newaxis, np.newaxis] - elif recording_amplitude_scalings["method"] == "by_passed_order": + elif method == "by_passed_order": ordered_rec_scalings = recording_amplitude_scalings["scalings"][rec_idx][:, np.newaxis, np.newaxis] + else: - raise ValueError( - "`recording_amplitude_scalings` 'method' entry must be " - "'by_amplitude_and_firing_rate', 'by_firing_rate' or " - "'by_passed_order'." - ) + raise ValueError("`recording_amplitude_scalings['method']` not recognised.") templates_array *= ordered_rec_scalings -def _check_generate_session_displacement_inputs( +def _check_generate_session_displacement_arguments( num_units, recording_durations, recording_shifts, recording_amplitude_scalings ): - breakpoint() + """ + Function to check the input arguments related to recording + shift and scale parameters are the correct size. + """ + expected_num_recs = len(recording_durations) + + if len(recording_shifts) != expected_num_recs: + raise ValueError( + "`recording_shifts` and `recording_durations` must be " + "the same length, the number of recordings to generate." + ) + + shifts_are_2d = [len(shift) == 2 for shift in recording_shifts] + if not all(shifts_are_2d): + raise ValueError("Each record entry for `recording_shifts` must have " "two elements, the x and y shift.") + if recording_amplitude_scalings is not None: -# TODO: a lot of input checks -# # assert len is the same -# amplitude_scalings = get_unit_amplitude_scalings( -# templates_moved_array, recording_amplitude_scalings -# ) + keys = recording_amplitude_scalings.keys() + if not "method" in keys or not "scalings" in keys: + raise ValueError("`recording_amplitude_scalings` must be a dict " "with keys `method` and `scalings`.") + + allowed_methods = ["by_passed_value", "by_amplitude_and_firing_rate", "by_firing_rate"] + if not recording_amplitude_scalings["method"] in allowed_methods: + raise ValueError(f"`recording_amplitude_scalings` must be one of {allowed_methods}") + + rec_scalings = recording_amplitude_scalings["scalings"] + if not len(rec_scalings) == expected_num_recs: + raise ValueError("`recording_amplitude_scalings` 'scalings' " "must have one array per recording.") + + if not all([len(scale) == num_units for scale in rec_scalings]): + raise ValueError( + "The entry for each recording in `recording_amplitude_scalings` " + "must have the same length as the number of units." + )