Merge pull request #2644 from alejoe91/prepare-0.100.4
Prepare 0.100.4
alejoe91 authored Mar 29, 2024
2 parents 22808ca + 7867b89 commit 7d0e1da
Showing 6 changed files with 111 additions and 28 deletions.
10 changes: 10 additions & 0 deletions doc/releases/0.100.4.rst
@@ -0,0 +1,10 @@
.. _release0.100.4:

SpikeInterface 0.100.4 release notes
------------------------------------

29th March 2024

Minor release with improved compression capability for Zarr

* Extend zarr compression options (#2643)
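The new options let a different compressor and filter chain be chosen per dataset (traces, times). A minimal sketch of the intended usage, adapted from the test added in this commit; the output path and codec choices are illustrative:

from numcodecs import Blosc, Delta

from spikeinterface.core import ZarrRecordingExtractor, generate_recording
from spikeinterface.core.zarrextractors import get_default_zarr_compressor

recording = generate_recording(durations=[2])
# attach an explicit time vector so a times dataset is written (as in the test)
recording.set_times(recording.get_times() + 100)

# the global compressor applies to every dataset unless a per-dataset entry overrides it
ZarrRecordingExtractor.write_recording(
    recording,
    "rec_custom.zarr",
    compressor=get_default_zarr_compressor(),
    compressor_by_dataset={"traces": Blosc(cname="zlib", clevel=3, shuffle=Blosc.NOSHUFFLE)},
    filters_by_dataset={"times": [Delta(dtype="float64")]},
)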
13 changes: 10 additions & 3 deletions doc/whatisnew.rst
@@ -8,6 +8,7 @@ Release notes
.. toctree::
:maxdepth: 1

releases/0.100.4.rst
releases/0.100.3.rst
releases/0.100.2.rst
releases/0.100.1.rst
@@ -37,20 +38,26 @@ Release notes
releases/0.9.1.rst


Version 0.100.4
===============

* Minor release with extended compression capability for Zarr


Version 0.100.3
==============
===============

* Minor release with bug fixes for Zarr compressor and NWB in container


Version 0.100.2
==============
===============

* Minor release with fix for running Kilosort4 with GPU support in container


Version 0.100.1
==============
===============

* Minor release with some bug fixes and Kilosort4 support

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "spikeinterface"
version = "0.100.3"
version = "0.100.4"
authors = [
{ name="Alessio Buccino", email="[email protected]" },
{ name="Samuel Garcia", email="[email protected]" },
17 changes: 16 additions & 1 deletion src/spikeinterface/core/base.py
@@ -967,9 +967,24 @@ def save_to_zarr(
For cloud storage locations, this should not be None (in case of default values, use an empty dict)
channel_chunk_size: int or None, default: None
Channels per chunk (only for BaseRecording)
compressor: numcodecs.Codec or None, default: None
Global compressor. If None, Blosc-zstd level 5 is used.
filters: list[numcodecs.Codec] or None, default: None
Global filters for zarr
compressor_by_dataset: dict or None, default: None
Optional compressor per dataset:
- traces
- times
If None, the global compressor is used
filters_by_dataset: dict or None, default: None
Optional filters per dataset:
- traces
- times
If None, the global filters are used
verbose: bool, default: True
If True, the output is verbose
**save_kwargs: Keyword arguments for saving to zarr
auto_cast_uint: bool, default: True
If True, unsigned integers are cast to signed integers to avoid issues with zarr (only for BaseRecording)
Returns
-------
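The docstring above spells out the new per-dataset keyword arguments. A short sketch of how they might be passed to save_to_zarr, assuming recording is a BaseRecording and that save_to_zarr accepts a folder argument like the other save helpers; the path and codec parameters are illustrative:

from numcodecs import Blosc, FixedScaleOffset

# per-dataset entries override the global compressor/filters;
# datasets without an entry fall back to the globals
recording.save_to_zarr(
    folder="my_recording.zarr",
    compressor=Blosc(cname="zstd", clevel=5, shuffle=Blosc.BITSHUFFLE),
    filters=None,
    compressor_by_dataset={"times": Blosc(cname="blosclz", clevel=8, shuffle=Blosc.AUTOSHUFFLE)},
    filters_by_dataset={"traces": [FixedScaleOffset(scale=5, offset=2, dtype=recording.get_dtype())]},
)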
65 changes: 50 additions & 15 deletions src/spikeinterface/core/tests/test_zarrextractors.py
@@ -1,39 +1,72 @@
import pytest
from pathlib import Path

import shutil

import zarr

from spikeinterface.core import (
ZarrRecordingExtractor,
ZarrSortingExtractor,
generate_recording,
generate_sorting,
load_extractor,
)
from spikeinterface.core.zarrextractors import add_sorting_to_zarr_group
from spikeinterface.core.zarrextractors import add_sorting_to_zarr_group, get_default_zarr_compressor


def test_zarr_compression_options(tmp_path):
from numcodecs import Blosc, Delta, FixedScaleOffset

recording = generate_recording(durations=[2])
recording.set_times(recording.get_times() + 100)

# store in root standard normal way
# default compressor
default_compressor = get_default_zarr_compressor()

# other compressor
other_compressor1 = Blosc(cname="zlib", clevel=3, shuffle=Blosc.NOSHUFFLE)
other_compressor2 = Blosc(cname="blosclz", clevel=8, shuffle=Blosc.AUTOSHUFFLE)

# timestamps compressors / filters
default_filters = None
other_filters1 = [FixedScaleOffset(scale=5, offset=2, dtype=recording.get_dtype())]
other_filters2 = [Delta(dtype="float64")]

# default
ZarrRecordingExtractor.write_recording(recording, tmp_path / "rec_default.zarr")
rec_default = ZarrRecordingExtractor(tmp_path / "rec_default.zarr")
assert rec_default._root["traces_seg0"].compressor == defaut_compressor
assert rec_default._root["traces_seg0"].filters == default_filters
assert rec_default._root["times_seg0"].compressor == defaut_compressor
assert rec_default._root["times_seg0"].filters == default_filters

if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "core"
else:
cache_folder = Path("cache_folder") / "core"
# now with other compressor
ZarrRecordingExtractor.write_recording(
recording,
tmp_path / "rec_other.zarr",
compressor=default_compressor,
filters=default_filters,
compressor_by_dataset={"traces": other_compressor1, "times": other_compressor2},
filters_by_dataset={"traces": other_filters1, "times": other_filters2},
)
rec_other = ZarrRecordingExtractor(tmp_path / "rec_other.zarr")
assert rec_other._root["traces_seg0"].compressor == other_compressor1
assert rec_other._root["traces_seg0"].filters == other_filters1
assert rec_other._root["times_seg0"].compressor == other_compressor2
assert rec_other._root["times_seg0"].filters == other_filters2


def test_ZarrSortingExtractor():
def test_ZarrSortingExtractor(tmp_path):
np_sorting = generate_sorting()

# store in root standard normal way
folder = cache_folder / "zarr_sorting"
if folder.is_dir():
shutil.rmtree(folder)
folder = tmp_path / "zarr_sorting"
ZarrSortingExtractor.write_sorting(np_sorting, folder)
sorting = ZarrSortingExtractor(folder)
sorting = load_extractor(sorting.to_dict())

# store the sorting in a sub group (for instance SortingResult)
folder = cache_folder / "zarr_sorting_sub_group"
if folder.is_dir():
shutil.rmtree(folder)
folder = tmp_path / "zarr_sorting_sub_group"
zarr_root = zarr.open(folder, mode="w")
zarr_sorting_group = zarr_root.create_group("sorting")
add_sorting_to_zarr_group(sorting, zarr_sorting_group)
@@ -43,4 +76,6 @@ def test_ZarrSortingExtractor():


if __name__ == "__main__":
test_ZarrSortingExtractor()
tmp_path = Path("tmp")
test_zarr_compression_options(tmp_path)
test_ZarrSortingExtractor(tmp_path)
32 changes: 24 additions & 8 deletions src/spikeinterface/core/zarrextractors.py
Expand Up @@ -255,7 +255,7 @@ def read_zarr(
The loaded extractor
"""
# TODO @alessio : we should have something more explicit in our zarr format to tell which object it is.
# for the future SortingResult we will have these 2 fields!!!
# for the future SortingAnalyzer we will have these 2 fields!!!
root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
if "channel_ids" in root.keys():
return read_zarr_recording(folder_path, storage_options=storage_options)
@@ -366,7 +366,9 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G


# Recording
def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.hierarchy.Group, **kwargs):
def add_recording_to_zarr_group(
recording: BaseRecording, zarr_group: zarr.hierarchy.Group, verbose=False, auto_cast_uint=True, dtype=None, **kwargs
):
zarr_kwargs, job_kwargs = split_job_kwargs(kwargs)

if recording.check_if_json_serializable():
Expand All @@ -380,15 +382,25 @@ def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.hiera
zarr_group.create_dataset(name="channel_ids", data=recording.get_channel_ids(), compressor=None)
dataset_paths = [f"traces_seg{i}" for i in range(recording.get_num_segments())]

zarr_kwargs["dtype"] = kwargs.get("dtype", None) or recording.get_dtype()
if "compressor" not in zarr_kwargs:
zarr_kwargs["compressor"] = get_default_zarr_compressor()
dtype = recording.get_dtype() if dtype is None else dtype
channel_chunk_size = zarr_kwargs.get("channel_chunk_size", None)
global_compressor = zarr_kwargs.pop("compressor", get_default_zarr_compressor())
compressor_by_dataset = zarr_kwargs.pop("compressor_by_dataset", {})
global_filters = zarr_kwargs.pop("filters", None)
filters_by_dataset = zarr_kwargs.pop("filters_by_dataset", {})

compressor_traces = compressor_by_dataset.get("traces", global_compressor)
filters_traces = filters_by_dataset.get("traces", global_filters)
add_traces_to_zarr(
recording=recording,
zarr_group=zarr_group,
dataset_paths=dataset_paths,
**zarr_kwargs,
compressor=compressor_traces,
filters=filters_traces,
dtype=dtype,
channel_chunk_size=channel_chunk_size,
auto_cast_uint=auto_cast_uint,
verbose=verbose,
**job_kwargs,
)

@@ -402,12 +414,16 @@ def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.hiera
for segment_index, rs in enumerate(recording._recording_segments):
d = rs.get_times_kwargs()
time_vector = d["time_vector"]

compressor_times = compressor_by_dataset.get("times", global_compressor)
filters_times = filters_by_dataset.get("times", global_filters)

if time_vector is not None:
_ = zarr_group.create_dataset(
name=f"times_seg{segment_index}",
data=time_vector,
filters=zarr_kwargs.get("filters", None),
compressor=zarr_kwargs["compressor"],
filters=filters_times,
compressor=compressor_times,
)
elif d["t_start"] is not None:
t_starts[segment_index] = d["t_start"]
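The same per-dataset resolution can also be exercised directly through add_recording_to_zarr_group when writing into an existing zarr group. A minimal sketch; the store path and codec are illustrative:

import zarr
from numcodecs import Blosc

from spikeinterface.core import generate_recording
from spikeinterface.core.zarrextractors import add_recording_to_zarr_group, get_default_zarr_compressor

recording = generate_recording(durations=[2])
root = zarr.open("rec_group.zarr", mode="w")

# "traces" uses its own codec; any other dataset (e.g. "times") falls back to the global compressor
add_recording_to_zarr_group(
    recording,
    root,
    compressor=get_default_zarr_compressor(),
    compressor_by_dataset={"traces": Blosc(cname="zlib", clevel=3, shuffle=Blosc.NOSHUFFLE)},
)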
