Skip to content

Commit

Permalink
add tests for select units for zarr
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 committed Jun 20, 2024
1 parent c504fc6 commit 861857f
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion src/spikeinterface/core/tests/test_sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,16 @@ def test_SortingAnalyzer_binary_folder(tmp_path, dataset):
sorting_analyzer = create_sorting_analyzer(
sorting, recording, format="binary_folder", folder=folder, sparse=False, sparsity=None
)

sorting_analyzer.compute(["random_spikes", "templates"])
sorting_analyzer = load_sorting_analyzer(folder, format="auto")
_check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path)

# test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041
# this bug requires that we have an info.json file so we calculate templates above
select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=[1])
assert len(select_units_sorting_analyer.unit_ids) == 1

folder = tmp_path / "test_SortingAnalyzer_binary_folder"
if folder.exists():
shutil.rmtree(folder)
Expand Down Expand Up @@ -97,9 +104,15 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset):
sorting_analyzer = create_sorting_analyzer(
sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None
)
sorting_analyzer.compute(["random_spikes", "templates"])
sorting_analyzer = load_sorting_analyzer(folder, format="auto")
_check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path)

# test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041
# this bug requires that we have an info.json file so we calculate templates above
select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=[1])
assert len(select_units_sorting_analyer.unit_ids) == 1

folder = tmp_path / "test_SortingAnalyzer_zarr.zarr"
if folder.exists():
shutil.rmtree(folder)
Expand Down Expand Up @@ -312,7 +325,7 @@ def test_extensions_sorting():

if __name__ == "__main__":
tmp_path = Path("test_SortingAnalyzer")
dataset = _get_dataset()
dataset = get_dataset()
test_SortingAnalyzer_memory(tmp_path, dataset)
test_SortingAnalyzer_binary_folder(tmp_path, dataset)
test_SortingAnalyzer_zarr(tmp_path, dataset)
Expand Down

0 comments on commit 861857f

Please sign in to comment.