From 40429b4c325b904835d8b2507e4e9667454df2be Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 29 Mar 2024 11:26:25 +0100
Subject: [PATCH 1/4] Extended zarr compression

---
 src/spikeinterface/core/base.py               | 17 ++++-
 .../core/tests/test_zarrextractors.py         | 65 ++++++++++++++-----
 src/spikeinterface/core/zarrextractors.py     | 30 +++++++--
 3 files changed, 89 insertions(+), 23 deletions(-)

diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index 80341811b9..d25f1bf97b 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -967,9 +967,24 @@ def save_to_zarr(
             For cloud storage locations, this should not be None (in case of default values, use an empty dict)
         channel_chunk_size: int or None, default: None
             Channels per chunk (only for BaseRecording)
+        compressor: numcodecs.Codec or None, default: None
+            Global compressor. If None, Blosc-zstd level 5 is used.
+        filters: list[numcodecs.Codec] or None, default: None
+            Global filters for zarr
+        compressor_by_dataset: dict or None, default: None
+            Optional compressor per dataset.:
+            - traces
+            - times
+            If None, the global compressor is used
+        filters_by_dataset: dict or None, default: None
+            Optional filters per dataset:
+            - traces
+            - times
+            If None, the global filters are used
         verbose: bool, default: True
             If True, the output is verbose
-        **save_kwargs: Keyword arguments for saving to zarr
+        auto_cast_uint: bool, default: True
+            If True, unsigned integers are cast to signed integers to avoid issues with zarr (only for BaseRecording)

         Returns
         -------
diff --git a/src/spikeinterface/core/tests/test_zarrextractors.py b/src/spikeinterface/core/tests/test_zarrextractors.py
index 72247cb42a..2fc1f42ec5 100644
--- a/src/spikeinterface/core/tests/test_zarrextractors.py
+++ b/src/spikeinterface/core/tests/test_zarrextractors.py
@@ -1,39 +1,72 @@
 import pytest
 from pathlib import Path
-import shutil
-
 import zarr

 from spikeinterface.core import (
     ZarrRecordingExtractor,
     ZarrSortingExtractor,
+    generate_recording,
     generate_sorting,
     load_extractor,
 )
-from spikeinterface.core.zarrextractors import add_sorting_to_zarr_group
+from spikeinterface.core.zarrextractors import add_sorting_to_zarr_group, get_default_zarr_compressor
+
+
+def test_zarr_compression_options(tmp_path):
+    from numcodecs import Blosc, Delta, FixedScaleOffset
+
+    recording = generate_recording(durations=[2])
+    recording.set_times(recording.get_times() + 100)
+
+    # prepare compressors and filters
+    # default compressor
+    default_compressor = get_default_zarr_compressor()
+
+    # other compressors
+    other_compressor1 = Blosc(cname="zlib", clevel=3, shuffle=Blosc.NOSHUFFLE)
+    other_compressor2 = Blosc(cname="blosclz", clevel=8, shuffle=Blosc.AUTOSHUFFLE)
+
+    # timestamps compressors / filters
+    default_filters = None
+    other_filters1 = [FixedScaleOffset(scale=5, offset=2, dtype=recording.get_dtype())]
+    other_filters2 = [Delta(dtype="float64")]
+
+    # default
+    ZarrRecordingExtractor.write_recording(recording, tmp_path / "rec_default.zarr")
+    rec_default = ZarrRecordingExtractor(tmp_path / "rec_default.zarr")
+    assert rec_default._root["traces_seg0"].compressor == default_compressor
+    assert rec_default._root["traces_seg0"].filters == default_filters
+    assert rec_default._root["times_seg0"].compressor == default_compressor
+    assert rec_default._root["times_seg0"].filters == default_filters
-if hasattr(pytest, "global_test_folder"):
-    cache_folder = pytest.global_test_folder / "core"
-else:
-    cache_folder = Path("cache_folder") / "core"
+
+    # now with other compressors
+    ZarrRecordingExtractor.write_recording(
+        recording,
+        tmp_path / "rec_other.zarr",
+        compressor=default_compressor,
+        filters=default_filters,
+        compressor_by_dataset={"traces": other_compressor1, "times": other_compressor2},
+        filters_by_dataset={"traces": other_filters1, "times": other_filters2},
+    )
+    rec_other = ZarrRecordingExtractor(tmp_path / "rec_other.zarr")
+    assert rec_other._root["traces_seg0"].compressor == other_compressor1
+    assert rec_other._root["traces_seg0"].filters == other_filters1
+    assert rec_other._root["times_seg0"].compressor == other_compressor2
+    assert rec_other._root["times_seg0"].filters == other_filters2


-def test_ZarrSortingExtractor():
+def test_ZarrSortingExtractor(tmp_path):
     np_sorting = generate_sorting()

     # store in root standard normal way
-    folder = cache_folder / "zarr_sorting"
-    if folder.is_dir():
-        shutil.rmtree(folder)
+    folder = tmp_path / "zarr_sorting"
     ZarrSortingExtractor.write_sorting(np_sorting, folder)
     sorting = ZarrSortingExtractor(folder)
     sorting = load_extractor(sorting.to_dict())

     # store the sorting in a sub group (for instance SortingResult)
-    folder = cache_folder / "zarr_sorting_sub_group"
-    if folder.is_dir():
-        shutil.rmtree(folder)
+    folder = tmp_path / "zarr_sorting_sub_group"
     zarr_root = zarr.open(folder, mode="w")
     zarr_sorting_group = zarr_root.create_group("sorting")
     add_sorting_to_zarr_group(sorting, zarr_sorting_group)
@@ -43,4 +76,6 @@


 if __name__ == "__main__":
-    test_ZarrSortingExtractor()
+    tmp_path = Path("tmp")
+    test_zarr_compression_options(tmp_path)
+    test_ZarrSortingExtractor(tmp_path)
diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py
index a8a23b5863..47e2ea2849 100644
--- a/src/spikeinterface/core/zarrextractors.py
+++ b/src/spikeinterface/core/zarrextractors.py
@@ -366,7 +366,9 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G

 # Recording
-def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.hierarchy.Group, **kwargs):
+def add_recording_to_zarr_group(
+    recording: BaseRecording, zarr_group: zarr.hierarchy.Group, verbose=False, auto_cast_uint=True, **kwargs
+):
     zarr_kwargs, job_kwargs = split_job_kwargs(kwargs)

     if recording.check_if_json_serializable():
@@ -380,15 +382,25 @@ def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.hiera
         zarr_group.create_dataset(name="channel_ids", data=recording.get_channel_ids(), compressor=None)

     dataset_paths = [f"traces_seg{i}" for i in range(recording.get_num_segments())]

-    zarr_kwargs["dtype"] = kwargs.get("dtype", None) or recording.get_dtype()
-    if "compressor" not in zarr_kwargs:
-        zarr_kwargs["compressor"] = get_default_zarr_compressor()
+    dtype = zarr_kwargs.get("dtype", None) or recording.get_dtype()
+    channel_chunk_size = zarr_kwargs.get("channel_chunk_size", None)
+    global_compressor = zarr_kwargs.pop("compressor", get_default_zarr_compressor())
+    compressor_by_dataset = zarr_kwargs.pop("compressor_by_dataset", {})
+    global_filters = zarr_kwargs.pop("filters", None)
+    filters_by_dataset = zarr_kwargs.pop("filters_by_dataset", {})
+    compressor_traces = compressor_by_dataset.get("traces", global_compressor)
+    filters_traces = filters_by_dataset.get("traces", global_filters)
     add_traces_to_zarr(
         recording=recording,
         zarr_group=zarr_group,
         dataset_paths=dataset_paths,
-        **zarr_kwargs,
+        compressor=compressor_traces,
+        filters=filters_traces,
+        dtype=dtype,
+        channel_chunk_size=channel_chunk_size,
+        auto_cast_uint=auto_cast_uint,
+        verbose=verbose,
         **job_kwargs,
     )
@@ -402,12 +414,16 @@ def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.hiera
     for segment_index, rs in enumerate(recording._recording_segments):
         d = rs.get_times_kwargs()
         time_vector = d["time_vector"]
+
+        compressor_times = compressor_by_dataset.get("times", global_compressor)
+        filters_times = filters_by_dataset.get("times", global_filters)
+
         if time_vector is not None:
             _ = zarr_group.create_dataset(
                 name=f"times_seg{segment_index}",
                 data=time_vector,
-                filters=zarr_kwargs.get("filters", None),
-                compressor=zarr_kwargs["compressor"],
+                filters=filters_times,
+                compressor=compressor_times,
             )
         elif d["t_start"] is not None:
             t_starts[segment_index] = d["t_start"]

From 05b5fdaa9a496c23d82045dd4a76db18f0438e2a Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 29 Mar 2024 11:54:48 +0100
Subject: [PATCH 2/4] Port #2643

---
 src/spikeinterface/core/zarrextractors.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py
index 47e2ea2849..4a0c5f8eef 100644
--- a/src/spikeinterface/core/zarrextractors.py
+++ b/src/spikeinterface/core/zarrextractors.py
@@ -255,7 +255,7 @@ def read_zarr(
         The loaded extractor
     """
     # TODO @alessio : we should have something more explicit in our zarr format to tell which object it is.
-    # for the futur SortingResult we will have this 2 fields!!!
+    # for the future SortingAnalyzer we will have these 2 fields!!!
     root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
     if "channel_ids" in root.keys():
         return read_zarr_recording(folder_path, storage_options=storage_options)
@@ -367,7 +367,7 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G
 # Recording
 def add_recording_to_zarr_group(
-    recording: BaseRecording, zarr_group: zarr.hierarchy.Group, verbose=False, auto_cast_uint=True, **kwargs
+    recording: BaseRecording, zarr_group: zarr.hierarchy.Group, verbose=False, auto_cast_uint=True, dtype=None, **kwargs
 ):
     zarr_kwargs, job_kwargs = split_job_kwargs(kwargs)

     if recording.check_if_json_serializable():
@@ -382,7 +382,7 @@ def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.hiera
         zarr_group.create_dataset(name="channel_ids", data=recording.get_channel_ids(), compressor=None)

     dataset_paths = [f"traces_seg{i}" for i in range(recording.get_num_segments())]

-    dtype = zarr_kwargs.get("dtype", None) or recording.get_dtype()
+    dtype = recording.get_dtype() if dtype is None else dtype
     channel_chunk_size = zarr_kwargs.get("channel_chunk_size", None)
     global_compressor = zarr_kwargs.pop("compressor", get_default_zarr_compressor())
     compressor_by_dataset = zarr_kwargs.pop("compressor_by_dataset", {})
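Taken together, PATCH 1 and PATCH 2 let a caller choose a compressor and filter chain per dataset instead of one global codec for everything. Below is a minimal sketch of how the new options could be exercised through the `save_to_zarr` API documented in PATCH 1; the recording, the `rec_custom.zarr` folder name, and the codec choices are illustrative assumptions, not part of the patches:

```python
from numcodecs import Blosc, Delta
from spikeinterface.core import generate_recording

# short synthetic recording with an explicit time vector, so that both
# "traces_seg0" and "times_seg0" datasets get written to the zarr store
recording = generate_recording(durations=[2])
recording.set_times(recording.get_times() + 100)

recording.save_to_zarr(
    folder="rec_custom.zarr",  # hypothetical output folder
    compressor_by_dataset={
        # heavier compression for the bulky traces
        "traces": Blosc(cname="zstd", clevel=7, shuffle=Blosc.BITSHUFFLE),
    },
    filters_by_dataset={
        # timestamps increase monotonically, so a Delta filter is a good fit
        "times": [Delta(dtype="float64")],
    },
)
```

Any dataset without an entry in the `*_by_dataset` dicts falls back to the global `compressor`/`filters` arguments (Blosc-zstd level 5 and no filters by default, per the docstring added above).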
From 27d1676abf5ce7b037792303008b7b40715f7c9b Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 29 Mar 2024 12:00:00 +0100
Subject: [PATCH 3/4] Add release notes

---
 doc/releases/0.100.4.rst | 10 ++++++++++
 doc/whatisnew.rst        | 13 ++++++++++---
 pyproject.toml           |  2 +-
 3 files changed, 21 insertions(+), 4 deletions(-)
 create mode 100644 doc/releases/0.100.4.rst

diff --git a/doc/releases/0.100.4.rst b/doc/releases/0.100.4.rst
new file mode 100644
index 0000000000..bea358053d
--- /dev/null
+++ b/doc/releases/0.100.4.rst
@@ -0,0 +1,10 @@
+.. _release0.100.4:
+
+SpikeInterface 0.100.4 release notes
+------------------------------------
+
+29th March 2024
+
+Minor release with improved compression capability for Zarr
+
+* Extend zarr compression options (#2643)
diff --git a/doc/whatisnew.rst b/doc/whatisnew.rst
index 3063db51f5..3c9f2b44c7 100644
--- a/doc/whatisnew.rst
+++ b/doc/whatisnew.rst
@@ -8,6 +8,7 @@ Release notes
 .. toctree::
    :maxdepth: 1

+   releases/0.100.4.rst
    releases/0.100.3.rst
    releases/0.100.2.rst
    releases/0.100.1.rst
@@ -37,20 +38,26 @@ Release notes
    releases/0.9.1.rst


+Version 0.100.4
+===============
+
+* Minor release with extended compression capability for Zarr
+
+
 Version 0.100.3
-==============
+===============

 * Minor release with bug fixes for Zarr compressor and NWB in container


 Version 0.100.2
-==============
+===============

 * Minor release with fix for running Kilosort4 with GPU support in container


 Version 0.100.1
-==============
+===============

 * Minor release with some bug fixes and Kilosort4 support
diff --git a/pyproject.toml b/pyproject.toml
index 2a2a072fc8..1c9bc56ac7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "spikeinterface"
-version = "0.100.3"
+version = "0.100.4"
 authors = [
   { name="Alessio Buccino", email="alessiop.buccino@gmail.com" },
   { name="Samuel Garcia", email="sam.garcia.die@gmail.com" },

From 7867b8947c6360d5c3cdc330397050d842891197 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 29 Mar 2024 12:30:42 +0100
Subject: [PATCH 4/4] Remove extra .

---
 src/spikeinterface/core/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index d25f1bf97b..ff9f841e8e 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -972,7 +972,7 @@ def save_to_zarr(
         filters: list[numcodecs.Codec] or None, default: None
             Global filters for zarr
         compressor_by_dataset: dict or None, default: None
-            Optional compressor per dataset.:
+            Optional compressor per dataset:
             - traces
             - times
             If None, the global compressor is used
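To verify what was actually written, the stored datasets can be inspected directly with zarr, mirroring the assertions in `test_zarr_compression_options` from PATCH 1. This sketch assumes the hypothetical `rec_custom.zarr` store from the example after PATCH 2:

```python
import zarr

# open the store read-only and check the codec attached to each dataset;
# names follow the "traces_seg{i}" / "times_seg{i}" convention of the extractor
root = zarr.open("rec_custom.zarr", mode="r")
print(root["traces_seg0"].compressor)  # the per-dataset Blosc-zstd codec
print(root["traces_seg0"].filters)     # None: global default (no filters)
print(root["times_seg0"].compressor)   # global default: Blosc-zstd level 5
print(root["times_seg0"].filters)      # the Delta filter set above
```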