Skip to content

Commit

Permalink
Add tests for unique names in channel slice and unit selection (#2258)
Browse files Browse the repository at this point in the history
* add tests for unique names in channel slice and unique selection

* add tests for unique names in channel slice and unique selection

* better names

* better names

* fix test once again

---------

Co-authored-by: Alessio Buccino <[email protected]>
  • Loading branch information
h-mayorquin and alejoe91 authored Nov 24, 2023
1 parent d1203ea commit a4e201c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 22 deletions.
9 changes: 9 additions & 0 deletions src/spikeinterface/core/tests/test_channelslicerecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import probeinterface

from spikeinterface.core import ChannelSliceRecording, BinaryRecordingExtractor
from spikeinterface.core.generate import generate_recording


def test_ChannelSliceRecording():
Expand Down Expand Up @@ -73,5 +74,13 @@ def test_ChannelSliceRecording():
assert np.all(traces3[:, 1] == 2)


def test_failure_with_non_unique_channel_ids():
durations = [1.0]
seed = 10
rec = generate_recording(num_channels=4, durations=durations, set_probe=False, seed=seed)
with pytest.raises(AssertionError):
rec_sliced = ChannelSliceRecording(rec, channel_ids=[0, 1], renamed_channel_ids=[0, 0])


if __name__ == "__main__":
test_ChannelSliceRecording()
33 changes: 11 additions & 22 deletions src/spikeinterface/core/tests/test_unitsselectionsorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,16 @@

from spikeinterface.core import UnitsSelectionSorting

from spikeinterface.core import NpzSortingExtractor, load_extractor
from spikeinterface.core.base import BaseExtractor
from spikeinterface.core.generate import generate_sorting

from spikeinterface.core import create_sorting_npz


if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "core"
else:
cache_folder = Path("cache_folder") / "core"


def test_unitsselectionsorting():
num_seg = 2
file_path = cache_folder / "test_BaseSorting.npz"

create_sorting_npz(num_seg, file_path)

sorting = NpzSortingExtractor(file_path)
print(sorting)
print(sorting.unit_ids)
def test_basic_functions():
sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0)

sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2])
print(sorting2.unit_ids)
assert np.array_equal(sorting2.unit_ids, [0, 2])

sorting3 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "b"])
print(sorting3.unit_ids)
assert np.array_equal(sorting3.unit_ids, ["a", "b"])

assert np.array_equal(
Expand All @@ -49,5 +31,12 @@ def test_unitsselectionsorting():
)


def test_failure_with_non_unique_unit_ids():
seed = 10
sorting = generate_sorting(num_units=3, durations=[0.100], sampling_frequency=30000.0, seed=seed)
with pytest.raises(AssertionError):
sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "a"])


if __name__ == "__main__":
test_unitsselectionsorting()
test_basic_functions()

0 comments on commit a4e201c

Please sign in to comment.