Skip to content

Commit

Permalink
Start finalising tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Jul 24, 2024
1 parent 9798c75 commit 5f012a2
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 22 deletions.
21 changes: 8 additions & 13 deletions src/spikeinterface/generation/session_displacement_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@
from spikeinterface.core import InjectTemplatesRecording


# TODO: test metadata
# TODO: test new amplitude scalings
# TODO: test correct unit_locations are on the sortings (part of metadata)


def generate_session_displacement_recordings(
num_units=250,
recording_durations=(10, 10, 10),
Expand Down Expand Up @@ -131,7 +126,7 @@ def generate_session_displacement_recordings(
# and scaled as required
extra_outputs_dict = {
"unit_locations": [],
"template_array_moved": [],
"templates_array_moved": [],
}
output_recordings = []
output_sortings = []
Expand Down Expand Up @@ -170,7 +165,7 @@ def generate_session_displacement_recordings(
)

# Generate the (possibly shifted, scaled) unit templates
template_array_moved = generate_templates(
templates_array_moved = generate_templates(
channel_locations,
unit_locations_moved,
sampling_frequency=sampling_frequency,
Expand All @@ -179,9 +174,8 @@ def generate_session_displacement_recordings(
)

if recording_amplitude_scalings is not None:

template_array_moved = _amplitude_scale_templates_in_place(
template_array_moved, recording_amplitude_scalings, sorting_extra_outputs, rec_idx
_amplitude_scale_templates_in_place(
templates_array_moved, recording_amplitude_scalings, sorting_extra_outputs, rec_idx
)

# Bring it all together in a `InjectTemplatesRecording` and
Expand All @@ -191,7 +185,7 @@ def generate_session_displacement_recordings(

recording = InjectTemplatesRecording(
sorting=sorting,
templates=template_array_moved,
templates=templates_array_moved,
nbefore=nbefore,
amplitude_factor=None,
parent_recording=noise,
Expand All @@ -208,7 +202,7 @@ def generate_session_displacement_recordings(
output_recordings.append(recording)
output_sortings.append(sorting)
extra_outputs_dict["unit_locations"].append(unit_locations_moved)
extra_outputs_dict["template_array_moved"].append(template_array_moved)
extra_outputs_dict["templates_array_moved"].append(templates_array_moved)

if extra_outputs:
return output_recordings, output_sortings, extra_outputs_dict
Expand Down Expand Up @@ -344,12 +338,13 @@ def _check_generate_session_displacement_arguments(
if not "method" in keys or not "scalings" in keys:
raise ValueError("`recording_amplitude_scalings` must be a dict " "with keys `method` and `scalings`.")

allowed_methods = ["by_passed_value", "by_amplitude_and_firing_rate", "by_firing_rate"]
allowed_methods = ["by_passed_order", "by_amplitude_and_firing_rate", "by_firing_rate"]
if not recording_amplitude_scalings["method"] in allowed_methods:
raise ValueError(f"`recording_amplitude_scalings` must be one of {allowed_methods}")

rec_scalings = recording_amplitude_scalings["scalings"]
if not len(rec_scalings) == expected_num_recs:
breakpoint()
raise ValueError("`recording_amplitude_scalings` 'scalings' " "must have one array per recording.")

if not all([len(scale) == num_units for scale in rec_scalings]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks

# TODO: test templates_array_moved are the same with
# no shift, with both seed and no seed

# rescale units per session


class TestSessionDisplacementGenerator:
"""
Expand Down Expand Up @@ -97,14 +92,14 @@ def test_x_y_rigid_shifts_are_properly_set(self, options):
for unit_idx in range(num_units):

start_pos = self._get_peak_chan_loc_in_um(
extra_outputs["template_array_moved"][0][unit_idx],
extra_outputs["templates_array_moved"][0][unit_idx],
options["y_bin_um"],
)

for rec_idx in range(1, options["num_recs"]):

new_pos = self._get_peak_chan_loc_in_um(
extra_outputs["template_array_moved"][rec_idx][unit_idx], options["y_bin_um"]
extra_outputs["templates_array_moved"][rec_idx][unit_idx], options["y_bin_um"]
)

y_shift = recording_shifts[rec_idx][1]
Expand All @@ -120,7 +115,7 @@ def test_x_y_rigid_shifts_are_properly_set(self, options):
for rec_idx in range(options["num_recs"]):
assert np.array_equal(
output_recordings[rec_idx].templates,
extra_outputs["template_array_moved"][rec_idx],
extra_outputs["templates_array_moved"][rec_idx],
)

def _get_peak_chan_loc_in_um(self, template_array, y_bin_um):
Expand Down Expand Up @@ -275,6 +270,56 @@ def test_displacement_with_peak_detection(self, options):

assert np.isclose(new_pos, first_pos + y_shift, rtol=0, atol=options["y_bin_um"])

def test_amplitude_scalings(self, options):

options["kwargs"]["recording_durations"] = (10, 10)
options["kwargs"]["recording_shifts"] = ((0, 0), (0, 0))
options["kwargs"]["num_units"] == 5,

recording_amplitude_scalings = {
"method": "by_passed_order",
"scalings": (np.ones(5), np.array([0.1, 0.2, 0.3, 0.4, 0.5])),
}

_, output_sortings, extra_outputs = generate_session_displacement_recordings(
**options["kwargs"],
recording_amplitude_scalings=recording_amplitude_scalings,
)
breakpoint()
first, second = extra_outputs["templates_array_moved"] # TODO: own function
first_min = np.min(np.min(first, axis=2), axis=1)
second_min = np.min(np.min(second, axis=2), axis=1)
scales = second_min / first_min

assert np.allclose(scales, shifts)

# TODO: scale based on recording output
# check scaled by amplitude.

breakpoint()

def test_metadata(self, options):
"""
Check that metadata required to be set of generated recordings is present
on all output recordings.
"""
output_recordings, output_sortings, extra_outputs = generate_session_displacement_recordings(
**options["kwargs"], generate_noise_kwargs=dict(noise_levels=(1.0, 2.0), spatial_decay=1.0)
)
num_chans = output_recordings[0].get_num_channels()

for i in range(len(output_recordings)):
assert output_recordings[i].name == "InterSessionDisplacementRecording"
assert output_recordings[i]._annotations["is_filtered"] is True
assert output_recordings[i].has_probe()
assert np.array_equal(output_recordings[i].get_channel_gains(), np.ones(num_chans))
assert np.array_equal(output_recordings[i].get_channel_offsets(), np.zeros(num_chans))

assert np.array_equal(
output_sortings[i].get_property("gt_unit_locations"), extra_outputs["unit_locations"][i]
)
assert output_sortings[i].name == "InterSessionDisplacementSorting"

def test_same_as_generate_ground_truth_recording(self):
"""
It is expected that inter-session displacement randomly
Expand Down Expand Up @@ -302,7 +347,7 @@ def test_same_as_generate_ground_truth_recording(self):
no_shift_recording, _ = generate_session_displacement_recordings(
num_units=num_units,
recording_durations=[duration],
recording_shifts=((0, 0)),
recording_shifts=((0, 0),),
sampling_frequency=sampling_frequency,
probe_name=probe_name,
generate_probe_kwargs=generate_probe_kwargs,
Expand Down

0 comments on commit 5f012a2

Please sign in to comment.