Skip to content

Commit

Permalink
Add input checks.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Jul 24, 2024
1 parent faa7e87 commit 3337037
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 25 deletions.
6 changes: 3 additions & 3 deletions debugging/debugging_session_displacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@
rec_list, _ = generate_session_displacement_recordings(
non_rigid_gradient=None, # 0.05, TODO: note this will set nonlinearity to both x and y (the same)
num_units=5,
rec_durations=(25, 25, 25), # TODO: checks on inputs
rec_shifts=(
recording_durations=(25, 25, 25), # TODO: checks on inputs
recording_shifts=(
(0, 0),
(0, 0),
(0, 0),
),
amplitude_scaling_options= {
recording_amplitude_scalings= {
"method": "by_amplitude_and_firing_rate",
"scalings": scale_,
},
Expand Down
71 changes: 49 additions & 22 deletions src/spikeinterface/generation/session_displacement_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

# TODO: test metadata
# TOOD: test new amplitude scalings
# TODO:
# TODO: test correct unit_locations are on the sortings (part of metadata)


def generate_session_displacement_recordings(
Expand Down Expand Up @@ -53,7 +53,7 @@ def generate_session_displacement_recordings(
seed=None,
):
""" """
_check_generate_session_displacement_inputs(
_check_generate_session_displacement_arguments(
num_units, recording_durations, recording_shifts, recording_amplitude_scalings
)

Expand All @@ -68,17 +68,18 @@ def generate_session_displacement_recordings(
**generate_unit_locations_kwargs,
)

# Fix generate template kwargs so they
# are the same for every created recording.
# Fix generate template kwargs, so they are the same for every created recording.
generate_templates_kwargs = fix_generate_templates_kwargs(generate_templates_kwargs, num_units, seed)

output_recordings = []
output_sortings = []

# Start looping over parameters, creating recordings shifted
# and scaled as required
extra_outputs_dict = {
"unit_locations": [],
"template_array_moved": [],
}
output_recordings = []
output_sortings = []

for rec_idx, (shift, duration) in enumerate(zip(recording_shifts, recording_durations)):

displacement_vector, displacement_unit_factor = get_inter_session_displacements(
Expand Down Expand Up @@ -179,11 +180,13 @@ def get_inter_session_displacements(shift, non_rigid_gradient, num_units, unit_l

def amplitude_scale_templates_in_place(templates_array, recording_amplitude_scalings, sorting_extra_outputs, rec_idx):
""" """
if recording_amplitude_scalings["method"] in ["by_amplitude_and_firing_rate", "by_firing_rate"]:
method = recording_amplitude_scalings["method"]

if method in ["by_amplitude_and_firing_rate", "by_firing_rate"]:

firing_rates_hz = sorting_extra_outputs["firing_rates"][0]

if recording_amplitude_scalings["method"] == "by_amplitude_and_firing_rate":
if method == "by_amplitude_and_firing_rate":
neg_ampl = np.min(np.min(templates_array, axis=2), axis=1)
score = firing_rates_hz * neg_ampl
else:
Expand All @@ -193,27 +196,51 @@ def amplitude_scale_templates_in_place(templates_array, recording_amplitude_scal
order_idx = np.argsort(score)
ordered_rec_scalings = recording_amplitude_scalings["scalings"][rec_idx][order_idx, np.newaxis, np.newaxis]

elif recording_amplitude_scalings["method"] == "by_passed_order":
elif method == "by_passed_order":

ordered_rec_scalings = recording_amplitude_scalings["scalings"][rec_idx][:, np.newaxis, np.newaxis]

else:
raise ValueError(
"`recording_amplitude_scalings` 'method' entry must be "
"'by_amplitude_and_firing_rate', 'by_firing_rate' or "
"'by_passed_order'."
)
raise ValueError("`recording_amplitude_scalings['method']` not recognised.")

templates_array *= ordered_rec_scalings


def _check_generate_session_displacement_inputs(
def _check_generate_session_displacement_arguments(
num_units, recording_durations, recording_shifts, recording_amplitude_scalings
):
breakpoint()
"""
Function to check the input arguments related to recording
shift and scale parameters are the correct size.
"""
expected_num_recs = len(recording_durations)

if len(recording_shifts) != expected_num_recs:
raise ValueError(
"`recording_shifts` and `recording_durations` must be "
"the same length, the number of recordings to generate."
)

shifts_are_2d = [len(shift) == 2 for shift in recording_shifts]
if not all(shifts_are_2d):
raise ValueError("Each record entry for `recording_shifts` must have " "two elements, the x and y shift.")

if recording_amplitude_scalings is not None:

# TODO: a lot of input checks
# # assert len is the same
# amplitude_scalings = get_unit_amplitude_scalings(
# templates_moved_array, recording_amplitude_scalings
# )
keys = recording_amplitude_scalings.keys()
if not "method" in keys or not "scalings" in keys:
raise ValueError("`recording_amplitude_scalings` must be a dict " "with keys `method` and `scalings`.")

allowed_methods = ["by_passed_value", "by_amplitude_and_firing_rate", "by_firing_rate"]
if not recording_amplitude_scalings["method"] in allowed_methods:
raise ValueError(f"`recording_amplitude_scalings` must be one of {allowed_methods}")

rec_scalings = recording_amplitude_scalings["scalings"]
if not len(rec_scalings) == expected_num_recs:
raise ValueError("`recording_amplitude_scalings` 'scalings' " "must have one array per recording.")

if not all([len(scale) == num_units for scale in rec_scalings]):
raise ValueError(
"The entry for each recording in `recording_amplitude_scalings` "
"must have the same length as the number of units."
)

0 comments on commit 3337037

Please sign in to comment.