From 7221004cfbfc0361879ec5fa59c5f929bf68709c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 27 Sep 2024 09:36:49 +0200 Subject: [PATCH 01/10] Expose zarr_kwargs at the analyzer level to zarr dataset options --- src/spikeinterface/core/sortinganalyzer.py | 55 +++++++++++------ .../core/tests/test_sortinganalyzer.py | 61 ++++++++++++++++--- src/spikeinterface/core/zarrextractors.py | 3 +- 3 files changed, 90 insertions(+), 29 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 4961db8524..5ffdc85e50 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -219,6 +219,9 @@ def __init__( # this is used to store temporary recording self._temporary_recording = None + # for zarr format, we store the kwargs to create zarr datasets (e.g., compression) + self._zarr_kwargs = {} + # extensions are not loaded at init self.extensions = dict() @@ -500,7 +503,7 @@ def _get_zarr_root(self, mode="r+"): return zarr_root @classmethod - def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes): + def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes, **zarr_kwargs): # used by create and save_as import zarr import numcodecs @@ -531,7 +534,8 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.Pickle()) else: warnings.warn( - "SortingAnalyzer with zarr : the Recording is not json serializable, the recording link will be lost for future load" + "SortingAnalyzer with zarr : the Recording is not json serializable, " + "the recording link will be lost for future load" ) # sorting provenance @@ -569,7 +573,6 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at # Alessio : we need to find a way to propagate compressor for all steps. # kwargs = dict(compressor=...) - zarr_kwargs = dict() add_sorting_to_zarr_group(sorting, zarr_root.create_group("sorting"), **zarr_kwargs) recording_info = zarr_root.create_group("extensions") @@ -645,6 +648,18 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None): return sorting_analyzer + def set_zarr_kwargs(self, **zarr_kwargs): + """ + Set the zarr kwargs for the zarr datasets. This can be used to specify custom compressors or filters. + Note that currently the zarr kwargs will be used for all zarr datasets. + + Parameters + ---------- + zarr_kwargs : keyword arguments + The zarr kwargs to set. + """ + self._zarr_kwargs = zarr_kwargs + def set_temporary_recording(self, recording: BaseRecording, check_dtype: bool = True): """ Sets a temporary recording object. This function can be useful to temporarily set @@ -683,7 +698,7 @@ def _save_or_select_or_merge( sparsity_overlap=0.75, verbose=False, new_unit_ids=None, - **job_kwargs, + **kwargs, ) -> "SortingAnalyzer": """ Internal method used by both `save_as()`, `copy()`, `select_units()`, and `merge_units()`. @@ -712,8 +727,8 @@ def _save_or_select_or_merge( The new unit ids for merged units. Required if `merge_unit_groups` is not None. verbose : bool, default: False If True, output is verbose. - job_kwargs : dict - Keyword arguments for parallelization. + kwargs : keyword arguments + Keyword arguments including job_kwargs and zarr_kwargs. 
Returns ------- @@ -727,6 +742,8 @@ def _save_or_select_or_merge( else: recording = None + zarr_kwargs, job_kwargs = split_job_kwargs(kwargs) + if self.sparsity is not None and unit_ids is None and merge_unit_groups is None: sparsity = self.sparsity elif self.sparsity is not None and unit_ids is not None and merge_unit_groups is None: @@ -807,10 +824,11 @@ def _save_or_select_or_merge( assert folder is not None, "For format='zarr' folder must be provided" folder = clean_zarr_folder_name(folder) SortingAnalyzer.create_zarr( - folder, sorting_provenance, recording, sparsity, self.return_scaled, self.rec_attributes + folder, sorting_provenance, recording, sparsity, self.return_scaled, self.rec_attributes, **zarr_kwargs ) new_sorting_analyzer = SortingAnalyzer.load_from_zarr(folder, recording=recording) new_sorting_analyzer.folder = folder + new_sorting_analyzer._zarr_kwargs = zarr_kwargs else: raise ValueError(f"SortingAnalyzer.save: unsupported format: {format}") @@ -848,7 +866,7 @@ def _save_or_select_or_merge( return new_sorting_analyzer - def save_as(self, format="memory", folder=None) -> "SortingAnalyzer": + def save_as(self, format="memory", folder=None, **zarr_kwargs) -> "SortingAnalyzer": """ Save SortingAnalyzer object into another format. Uselful for memory to zarr or memory to binary. @@ -863,10 +881,11 @@ def save_as(self, format="memory", folder=None) -> "SortingAnalyzer": The output folder if `format` is "zarr" or "binary_folder" format : "memory" | "binary_folder" | "zarr", default: "memory" The new backend format to use + zarr_kwargs : keyword arguments for zarr format """ if format == "zarr": folder = clean_zarr_folder_name(folder) - return self._save_or_select_or_merge(format=format, folder=folder) + return self._save_or_select_or_merge(format=format, folder=folder, **zarr_kwargs) def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyzer": """ @@ -2051,24 +2070,24 @@ def run(self, save=True, **kwargs): if save and not self.sorting_analyzer.is_read_only(): self._save_run_info() - self._save_data(**kwargs) + self._save_data() if self.format == "zarr": import zarr zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) - def save(self, **kwargs): + def save(self): self._save_params() self._save_importing_provenance() self._save_run_info() - self._save_data(**kwargs) + self._save_data() if self.format == "zarr": import zarr zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) - def _save_data(self, **kwargs): + def _save_data(self): if self.format == "memory": return @@ -2107,14 +2126,14 @@ def _save_data(self, **kwargs): except: raise Exception(f"Could not save {ext_data_name} as extension data") elif self.format == "zarr": - import zarr import numcodecs + zarr_kwargs = self.sorting_analyzer._zarr_kwargs extension_group = self._get_zarr_extension_group(mode="r+") - compressor = kwargs.get("compressor", None) - if compressor is None: - compressor = get_default_zarr_compressor() + # if compression is not externally given, we use the default + if "compressor" not in zarr_kwargs: + zarr_kwargs["compressor"] = get_default_zarr_compressor() for ext_data_name, ext_data in self.data.items(): if ext_data_name in extension_group: @@ -2124,7 +2143,7 @@ def _save_data(self, **kwargs): name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.JSON() ) elif isinstance(ext_data, np.ndarray): - extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor) + 
extension_group.create_dataset(name=ext_data_name, data=ext_data, **zarr_kwargs) elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame): df_group = extension_group.create_group(ext_data_name) # first we save the index diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 5c7e267cc6..53e28fe083 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -10,6 +10,7 @@ load_sorting_analyzer, get_available_analyzer_extensions, get_default_analyzer_extension_params, + get_default_zarr_compressor, ) from spikeinterface.core.sortinganalyzer import ( register_result_extension, @@ -99,16 +100,25 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): recording, sorting = dataset folder = tmp_path / "test_SortingAnalyzer_zarr.zarr" - if folder.exists(): - shutil.rmtree(folder) + default_compressor = get_default_zarr_compressor() sorting_analyzer = create_sorting_analyzer( - sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None + sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None, overwrite=True ) sorting_analyzer.compute(["random_spikes", "templates"]) sorting_analyzer = load_sorting_analyzer(folder, format="auto") _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) + # check that compression is applied + assert ( + sorting_analyzer._get_zarr_root()["extensions"]["random_spikes"]["random_spikes_indices"].compressor.codec_id + == default_compressor.codec_id + ) + assert ( + sorting_analyzer._get_zarr_root()["extensions"]["templates"]["average"].compressor.codec_id + == default_compressor.codec_id + ) + # test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041 # this bug requires that we have an info.json file so we calculate templates above select_units_sorting_analyer = sorting_analyzer.select_units(unit_ids=[1]) @@ -117,11 +127,44 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): assert len(remove_units_sorting_analyer.unit_ids) == len(sorting_analyzer.unit_ids) - 1 assert 1 not in remove_units_sorting_analyer.unit_ids - folder = tmp_path / "test_SortingAnalyzer_zarr.zarr" - if folder.exists(): - shutil.rmtree(folder) - sorting_analyzer = create_sorting_analyzer( - sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None, return_scaled=False + # test no compression + sorting_analyzer_no_compression = create_sorting_analyzer( + sorting, + recording, + format="zarr", + folder=folder, + sparse=False, + sparsity=None, + return_scaled=False, + overwrite=True, + ) + sorting_analyzer_no_compression.set_zarr_kwargs(compressor=None) + sorting_analyzer_no_compression.compute(["random_spikes", "templates"]) + assert ( + sorting_analyzer_no_compression._get_zarr_root()["extensions"]["random_spikes"][ + "random_spikes_indices" + ].compressor + is None + ) + assert sorting_analyzer_no_compression._get_zarr_root()["extensions"]["templates"]["average"].compressor is None + + # test a different compressor + from numcodecs import LZMA + + lzma_compressor = LZMA() + folder = tmp_path / "test_SortingAnalyzer_zarr_lzma.zarr" + sorting_analyzer_lzma = sorting_analyzer_no_compression.save_as( + format="zarr", folder=folder, compressor=lzma_compressor + ) + assert ( + sorting_analyzer_lzma._get_zarr_root()["extensions"]["random_spikes"][ + "random_spikes_indices" + ].compressor.codec_id + == LZMA.codec_id + ) + assert ( + 
sorting_analyzer_lzma._get_zarr_root()["extensions"]["templates"]["average"].compressor.codec_id + == LZMA.codec_id ) @@ -326,7 +369,7 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): else: folder = None sorting_analyzer5 = sorting_analyzer.merge_units( - merge_unit_groups=[[0, 1]], new_unit_ids=[50], format=format, folder=folder, mode="hard" + merge_unit_groups=[[0, 1]], new_unit_ids=[50], format=format, folder=folder, merging_mode="hard" ) # test compute with extension-specific params diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 17f1ac08b3..355553428e 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -329,8 +329,7 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G zarr_group.attrs["num_segments"] = int(num_segments) zarr_group.create_dataset(name="unit_ids", data=sorting.unit_ids, compressor=None) - if "compressor" not in kwargs: - compressor = get_default_zarr_compressor() + compressor = kwargs.get("compressor", get_default_zarr_compressor()) # save sub fields spikes_group = zarr_group.create_group(name="spikes") From 23413b388c97730cb6208341d042864a1995dcf9 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 27 Sep 2024 09:55:32 +0200 Subject: [PATCH 02/10] Update src/spikeinterface/core/sortinganalyzer.py --- src/spikeinterface/core/sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 5ffdc85e50..f7a8485502 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -828,7 +828,7 @@ def _save_or_select_or_merge( ) new_sorting_analyzer = SortingAnalyzer.load_from_zarr(folder, recording=recording) new_sorting_analyzer.folder = folder - new_sorting_analyzer._zarr_kwargs = zarr_kwargs + new_sorting_analyzer.set_zarr_kwargs(zarr_kwargs) else: raise ValueError(f"SortingAnalyzer.save: unsupported format: {format}") From fa97fd45689b29d5050652718938ae856132ff91 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 27 Sep 2024 09:59:53 +0200 Subject: [PATCH 03/10] Fix tests --- src/spikeinterface/core/sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index f7a8485502..16945008ae 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -828,7 +828,7 @@ def _save_or_select_or_merge( ) new_sorting_analyzer = SortingAnalyzer.load_from_zarr(folder, recording=recording) new_sorting_analyzer.folder = folder - new_sorting_analyzer.set_zarr_kwargs(zarr_kwargs) + new_sorting_analyzer.set_zarr_kwargs(**zarr_kwargs) else: raise ValueError(f"SortingAnalyzer.save: unsupported format: {format}") From a5d8c1db11182b31a3a98d2fe7cc41fe2ee9ca03 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 28 Sep 2024 16:54:39 +0200 Subject: [PATCH 04/10] Improve IBL recording extractor with PID --- .../extractors/iblextractors.py | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/extractors/iblextractors.py b/src/spikeinterface/extractors/iblextractors.py index 5dd549347d..317ea21cce 100644 --- a/src/spikeinterface/extractors/iblextractors.py +++ b/src/spikeinterface/extractors/iblextractors.py @@ -105,6 +105,8 @@ def 
get_stream_names(eid: str, cache_folder: Optional[Union[Path, str]] = None, An instance of the ONE API to use for data loading. If not provided, a default instance is created using the default parameters. If you need to use a specific instance, you can create it using the ONE API and pass it here. + stream_type : "ap" | "lf" | None, default: None + The stream type to load, required when pid is provided and stream_name is not. Returns ------- @@ -140,6 +142,7 @@ def __init__( remove_cached: bool = True, stream: bool = True, one: "one.api.OneAlyx" = None, + stream_type: str | None = None, ): try: from brainbox.io.one import SpikeSortingLoader @@ -154,20 +157,24 @@ def __init__( one = IblRecordingExtractor._get_default_one(cache_folder=cache_folder) if pid is not None: + assert stream_type is not None, "When providing a PID, you must also provide a stream type." eid, _ = one.pid2eid(pid) - - stream_names = IblRecordingExtractor.get_stream_names(eid=eid, cache_folder=cache_folder, one=one) - if len(stream_names) > 1: - assert ( - stream_name is not None - ), f"Multiple streams found for session. Please specify a stream name from {stream_names}." - assert stream_name in stream_names, ( - f"The `stream_name` '{stream_name}' is not available for this experiment {eid}! " - f"Please choose one of {stream_names}." - ) + pids, probes = one.eid2pid(eid) + pname = probes[pids.index(pid)] + stream_name = f"{pname}.{stream_type}" else: - stream_name = stream_names[0] - pname, stream_type = stream_name.split(".") + stream_names = IblRecordingExtractor.get_stream_names(eid=eid, cache_folder=cache_folder, one=one) + if len(stream_names) > 1: + assert ( + stream_name is not None + ), f"Multiple streams found for session. Please specify a stream name from {stream_names}." + assert stream_name in stream_names, ( + f"The `stream_name` '{stream_name}' is not available for this experiment {eid}! " + f"Please choose one of {stream_names}." 
+ ) + else: + stream_name = stream_names[0] + pname, stream_type = stream_name.split(".") self.ssl = SpikeSortingLoader(one=one, eid=eid, pid=pid, pname=pname) if pid is None: From 04ebe5ed6360aacca491a39e01002840c4af70fb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 30 Sep 2024 18:29:52 +0200 Subject: [PATCH 05/10] Use more general backend_kwargs --- src/spikeinterface/core/sortinganalyzer.py | 68 +++++++++++++------ .../core/tests/test_sortinganalyzer.py | 4 +- 2 files changed, 49 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 16945008ae..4ffdc8d95a 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -219,8 +219,9 @@ def __init__( # this is used to store temporary recording self._temporary_recording = None - # for zarr format, we store the kwargs to create zarr datasets (e.g., compression) - self._zarr_kwargs = {} + # backend-specific kwargs for different formats, which can be used to + # set some parameters for saving (e.g., compression) + self._backend_kwargs = {"binary_folder": {}, "zarr": {}} # extensions are not loaded at init self.extensions = dict() @@ -352,7 +353,9 @@ def create_memory(cls, sorting, recording, sparsity, return_scaled, rec_attribut return sorting_analyzer @classmethod - def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes): + def create_binary_folder( + cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes, **binary_format_kwargs + ): # used by create and save_as assert recording is not None, "To create a SortingAnalyzer you need to specify the recording" @@ -571,8 +574,6 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at # write sorting copy from .zarrextractors import add_sorting_to_zarr_group - # Alessio : we need to find a way to propagate compressor for all steps. - # kwargs = dict(compressor=...) add_sorting_to_zarr_group(sorting, zarr_root.create_group("sorting"), **zarr_kwargs) recording_info = zarr_root.create_group("extensions") @@ -648,17 +649,27 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None): return sorting_analyzer - def set_zarr_kwargs(self, **zarr_kwargs): + @property + def backend_kwargs(self): + """ + Returns the backend kwargs for the analyzer. + """ + return self._backend_kwargs.copy() + + @backend_kwargs.setter + def backend_kwargs(self, backend_kwargs): """ - Set the zarr kwargs for the zarr datasets. This can be used to specify custom compressors or filters. - Note that currently the zarr kwargs will be used for all zarr datasets. + Sets the backend kwargs for the analyzer. If the backend kwargs are not set, the default backend kwargs are used. Parameters ---------- - zarr_kwargs : keyword arguments + backend_kwargs : keyword arguments The zarr kwargs to set. """ - self._zarr_kwargs = zarr_kwargs + for key in backend_kwargs: + if key not in ("zarr", "binary_folder"): + raise ValueError(f"Unknown backend key: {key}. 
Available keys are 'zarr' and 'binary_folder'.") + self._backend_kwargs[key] = backend_kwargs[key] def set_temporary_recording(self, recording: BaseRecording, check_dtype: bool = True): """ @@ -698,7 +709,8 @@ def _save_or_select_or_merge( sparsity_overlap=0.75, verbose=False, new_unit_ids=None, - **kwargs, + backend_kwargs=None, + **job_kwargs, ) -> "SortingAnalyzer": """ Internal method used by both `save_as()`, `copy()`, `select_units()`, and `merge_units()`. @@ -727,8 +739,10 @@ def _save_or_select_or_merge( The new unit ids for merged units. Required if `merge_unit_groups` is not None. verbose : bool, default: False If True, output is verbose. - kwargs : keyword arguments - Keyword arguments including job_kwargs and zarr_kwargs. + backend_kwargs : dict | None, default: None + Keyword arguments for the backend specified by format. + job_kwargs : keyword arguments + Keyword arguments for the job parallelization. Returns ------- @@ -742,8 +756,6 @@ def _save_or_select_or_merge( else: recording = None - zarr_kwargs, job_kwargs = split_job_kwargs(kwargs) - if self.sparsity is not None and unit_ids is None and merge_unit_groups is None: sparsity = self.sparsity elif self.sparsity is not None and unit_ids is not None and merge_unit_groups is None: @@ -804,6 +816,8 @@ def _save_or_select_or_merge( # TODO: sam/pierre would create a curation field / curation.json with the applied merges. # What do you think? + backend_kwargs = {} if backend_kwargs is None else backend_kwargs + if format == "memory": # This make a copy of actual SortingAnalyzer new_sorting_analyzer = SortingAnalyzer.create_memory( @@ -814,8 +828,15 @@ def _save_or_select_or_merge( # create a new folder assert folder is not None, "For format='binary_folder' folder must be provided" folder = Path(folder) + binary_format_kwargs = backend_kwargs SortingAnalyzer.create_binary_folder( - folder, sorting_provenance, recording, sparsity, self.return_scaled, self.rec_attributes + folder, + sorting_provenance, + recording, + sparsity, + self.return_scaled, + self.rec_attributes, + **binary_format_kwargs, ) new_sorting_analyzer = SortingAnalyzer.load_from_binary_folder(folder, recording=recording) new_sorting_analyzer.folder = folder @@ -823,15 +844,18 @@ def _save_or_select_or_merge( elif format == "zarr": assert folder is not None, "For format='zarr' folder must be provided" folder = clean_zarr_folder_name(folder) + zarr_kwargs = backend_kwargs SortingAnalyzer.create_zarr( folder, sorting_provenance, recording, sparsity, self.return_scaled, self.rec_attributes, **zarr_kwargs ) new_sorting_analyzer = SortingAnalyzer.load_from_zarr(folder, recording=recording) new_sorting_analyzer.folder = folder - new_sorting_analyzer.set_zarr_kwargs(**zarr_kwargs) else: raise ValueError(f"SortingAnalyzer.save: unsupported format: {format}") + if format != "memory": + new_sorting_analyzer.backend_kwargs = {format: backend_kwargs} + # make a copy of extensions # note that the copy of extension handle itself the slicing of units when necessary and also the saveing sorted_extensions = _sort_extensions_by_dependency(self.extensions) @@ -866,7 +890,7 @@ def _save_or_select_or_merge( return new_sorting_analyzer - def save_as(self, format="memory", folder=None, **zarr_kwargs) -> "SortingAnalyzer": + def save_as(self, format="memory", folder=None, backend_kwargs=None) -> "SortingAnalyzer": """ Save SortingAnalyzer object into another format. Uselful for memory to zarr or memory to binary. 
@@ -881,11 +905,13 @@ def save_as(self, format="memory", folder=None, **zarr_kwargs) -> "SortingAnalyz The output folder if `format` is "zarr" or "binary_folder" format : "memory" | "binary_folder" | "zarr", default: "memory" The new backend format to use - zarr_kwargs : keyword arguments for zarr format + backend_kwargs : dict | None, default: None + Backend-specific kwargs for the specified format, which can be used to set some parameters for saving. + For example, if `format` is "zarr", one can set the compressor for the zarr datasets with `backend_kwargs={"compressor": some_compressor}`. """ if format == "zarr": folder = clean_zarr_folder_name(folder) - return self._save_or_select_or_merge(format=format, folder=folder, **zarr_kwargs) + return self._save_or_select_or_merge(format=format, folder=folder, backend_kwargs=backend_kwargs) def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyzer": """ @@ -2128,7 +2154,7 @@ def _save_data(self): elif self.format == "zarr": import numcodecs - zarr_kwargs = self.sorting_analyzer._zarr_kwargs + zarr_kwargs = self.sorting_analyzer.backend_kwargs.get("zarr", {}) extension_group = self._get_zarr_extension_group(mode="r+") # if compression is not externally given, we use the default diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 53e28fe083..f2aa7f459d 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -138,7 +138,7 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): return_scaled=False, overwrite=True, ) - sorting_analyzer_no_compression.set_zarr_kwargs(compressor=None) + sorting_analyzer_no_compression.backend_kwargs = {"zarr": dict(compressor=None)} sorting_analyzer_no_compression.compute(["random_spikes", "templates"]) assert ( sorting_analyzer_no_compression._get_zarr_root()["extensions"]["random_spikes"][ @@ -154,7 +154,7 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): lzma_compressor = LZMA() folder = tmp_path / "test_SortingAnalyzer_zarr_lzma.zarr" sorting_analyzer_lzma = sorting_analyzer_no_compression.save_as( - format="zarr", folder=folder, compressor=lzma_compressor + format="zarr", folder=folder, backend_kwargs=dict(compressor=lzma_compressor) ) assert ( sorting_analyzer_lzma._get_zarr_root()["extensions"]["random_spikes"][ From 022f924f1b4a7527e6c5a4b8d0ef4a68bf0e6a6c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 1 Oct 2024 11:05:01 +0200 Subject: [PATCH 06/10] Use backend_options for storage/saving_options --- src/spikeinterface/core/sortinganalyzer.py | 185 ++++++++++-------- .../core/tests/test_sortinganalyzer.py | 5 +- 2 files changed, 107 insertions(+), 83 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 4ffdc8d95a..10c5d8d475 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -11,6 +11,7 @@ import shutil import warnings import importlib +from copy import copy from packaging.version import parse from time import perf_counter @@ -45,6 +46,7 @@ def create_sorting_analyzer( sparsity=None, return_scaled=True, overwrite=False, + backend_options=None, **sparsity_kwargs, ) -> "SortingAnalyzer": """ @@ -80,7 +82,12 @@ def create_sorting_analyzer( This prevent return_scaled being differents from different extensions and having wrong snr for instance. 
overwrite: bool, default: False If True, overwrite the folder if it already exists. - + backend_options : dict | None, default: None + Keyword arguments for the backend specified by format. It can contain the: + - storage_options: dict | None (fsspec storage options) + - saving_options: dict | None (additional saving options for creating and saving datasets, + e.g. compression/filters for zarr) + sparsity_kwargs : keyword arguments Returns ------- @@ -144,13 +151,19 @@ def create_sorting_analyzer( return_scaled = False sorting_analyzer = SortingAnalyzer.create( - sorting, recording, format=format, folder=folder, sparsity=sparsity, return_scaled=return_scaled + sorting, + recording, + format=format, + folder=folder, + sparsity=sparsity, + return_scaled=return_scaled, + backend_options=backend_options, ) return sorting_analyzer -def load_sorting_analyzer(folder, load_extensions=True, format="auto", storage_options=None) -> "SortingAnalyzer": +def load_sorting_analyzer(folder, load_extensions=True, format="auto", backend_options=None) -> "SortingAnalyzer": """ Load a SortingAnalyzer object from disk. @@ -172,7 +185,7 @@ def load_sorting_analyzer(folder, load_extensions=True, format="auto", storage_o The loaded SortingAnalyzer """ - return SortingAnalyzer.load(folder, load_extensions=load_extensions, format=format, storage_options=storage_options) + return SortingAnalyzer.load(folder, load_extensions=load_extensions, format=format, backend_options=backend_options) class SortingAnalyzer: @@ -205,7 +218,7 @@ def __init__( format=None, sparsity=None, return_scaled=True, - storage_options=None, + backend_options=None, ): # very fast init because checks are done in load and create self.sorting = sorting @@ -215,13 +228,17 @@ def __init__( self.format = format self.sparsity = sparsity self.return_scaled = return_scaled - self.storage_options = storage_options + # this is used to store temporary recording self._temporary_recording = None # backend-specific kwargs for different formats, which can be used to # set some parameters for saving (e.g., compression) - self._backend_kwargs = {"binary_folder": {}, "zarr": {}} + # + # - storage_options: dict | None (fsspec storage options) + # - saving_options: dict | None + # (additional saving options for creating and saving datasets, e.g. 
compression/filters for zarr) + self._backend_options = {} if backend_options is None else backend_options # extensions are not loaded at init self.extensions = dict() @@ -257,6 +274,7 @@ def create( folder=None, sparsity=None, return_scaled=True, + backend_options=None, ): # some checks if sorting.sampling_frequency != recording.sampling_frequency: @@ -281,22 +299,34 @@ def create( if format == "memory": sorting_analyzer = cls.create_memory(sorting, recording, sparsity, return_scaled, rec_attributes=None) elif format == "binary_folder": - cls.create_binary_folder(folder, sorting, recording, sparsity, return_scaled, rec_attributes=None) - sorting_analyzer = cls.load_from_binary_folder(folder, recording=recording) - sorting_analyzer.folder = Path(folder) + sorting_analyzer = cls.create_binary_folder( + folder, + sorting, + recording, + sparsity, + return_scaled, + rec_attributes=None, + backend_options=backend_options, + ) elif format == "zarr": assert folder is not None, "For format='zarr' folder must be provided" folder = clean_zarr_folder_name(folder) - cls.create_zarr(folder, sorting, recording, sparsity, return_scaled, rec_attributes=None) - sorting_analyzer = cls.load_from_zarr(folder, recording=recording) - sorting_analyzer.folder = Path(folder) + sorting_analyzer = cls.create_zarr( + folder, + sorting, + recording, + sparsity, + return_scaled, + rec_attributes=None, + backend_options=backend_options, + ) else: raise ValueError("SortingAnalyzer.create: wrong format") return sorting_analyzer @classmethod - def load(cls, folder, recording=None, load_extensions=True, format="auto", storage_options=None): + def load(cls, folder, recording=None, load_extensions=True, format="auto", backend_options=None): """ Load folder or zarr. The recording can be given if the recording location has changed. 
@@ -310,10 +340,12 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto", stora format = "binary_folder" if format == "binary_folder": - sorting_analyzer = SortingAnalyzer.load_from_binary_folder(folder, recording=recording) + sorting_analyzer = SortingAnalyzer.load_from_binary_folder( + folder, recording=recording, backend_options=backend_options + ) elif format == "zarr": sorting_analyzer = SortingAnalyzer.load_from_zarr( - folder, recording=recording, storage_options=storage_options + folder, recording=recording, backend_options=backend_options ) if is_path_remote(str(folder)): @@ -353,9 +385,7 @@ def create_memory(cls, sorting, recording, sparsity, return_scaled, rec_attribut return sorting_analyzer @classmethod - def create_binary_folder( - cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes, **binary_format_kwargs - ): + def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes, backend_options): # used by create and save_as assert recording is not None, "To create a SortingAnalyzer you need to specify the recording" @@ -417,8 +447,10 @@ def create_binary_folder( with open(settings_file, mode="w") as f: json.dump(check_json(settings), f, indent=4) + return cls.load_from_binary_folder(folder, recording=recording, backend_options=backend_options) + @classmethod - def load_from_binary_folder(cls, folder, recording=None): + def load_from_binary_folder(cls, folder, recording=None, backend_options=None): folder = Path(folder) assert folder.is_dir(), f"This folder does not exists {folder}" @@ -489,34 +521,42 @@ def load_from_binary_folder(cls, folder, recording=None): format="binary_folder", sparsity=sparsity, return_scaled=return_scaled, + backend_options=backend_options, ) + sorting_analyzer.folder = folder return sorting_analyzer def _get_zarr_root(self, mode="r+"): import zarr - if is_path_remote(str(self.folder)): - mode = "r" + # if is_path_remote(str(self.folder)): + # mode = "r" + storage_options = self._backend_options.get("storage_options", {}) # we open_consolidated only if we are in read mode if mode in ("r+", "a"): - zarr_root = zarr.open(str(self.folder), mode=mode, storage_options=self.storage_options) + zarr_root = zarr.open(str(self.folder), mode=mode, storage_options=storage_options) else: - zarr_root = zarr.open_consolidated(self.folder, mode=mode, storage_options=self.storage_options) + zarr_root = zarr.open_consolidated(self.folder, mode=mode, storage_options=storage_options) return zarr_root @classmethod - def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes, **zarr_kwargs): + def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes, backend_options): # used by create and save_as import zarr import numcodecs + from .zarrextractors import add_sorting_to_zarr_group folder = clean_zarr_folder_name(folder) if folder.is_dir(): raise ValueError(f"Folder already exists {folder}") - zarr_root = zarr.open(folder, mode="w") + backend_options = {} if backend_options is None else backend_options + storage_options = backend_options.get("storage_options", {}) + saving_options = backend_options.get("saving_options", {}) + + zarr_root = zarr.open(folder, mode="w", storage_options=storage_options) info = dict(version=spikeinterface.__version__, dev_mode=spikeinterface.DEV_MODE, object="SortingAnalyzer") zarr_root.attrs["spikeinterface_info"] = check_json(info) @@ -569,21 +609,23 @@ def create_zarr(cls, folder, sorting, recording, 
sparsity, return_scaled, rec_at recording_info.attrs["probegroup"] = check_json(probegroup.to_dict()) if sparsity is not None: - zarr_root.create_dataset("sparsity_mask", data=sparsity.mask) - - # write sorting copy - from .zarrextractors import add_sorting_to_zarr_group + zarr_root.create_dataset("sparsity_mask", data=sparsity.mask, **saving_options) - add_sorting_to_zarr_group(sorting, zarr_root.create_group("sorting"), **zarr_kwargs) + add_sorting_to_zarr_group(sorting, zarr_root.create_group("sorting"), **saving_options) recording_info = zarr_root.create_group("extensions") zarr.consolidate_metadata(zarr_root.store) + return cls.load_from_zarr(folder, recording=recording, backend_options=backend_options) + @classmethod - def load_from_zarr(cls, folder, recording=None, storage_options=None): + def load_from_zarr(cls, folder, recording=None, backend_options=None): import zarr + backend_options = {} if backend_options is None else backend_options + storage_options = backend_options.get("storage_options", {}) + zarr_root = zarr.open_consolidated(str(folder), mode="r", storage_options=storage_options) si_info = zarr_root.attrs["spikeinterface_info"] @@ -644,33 +686,12 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None): format="zarr", sparsity=sparsity, return_scaled=return_scaled, - storage_options=storage_options, + backend_options=backend_options, ) + sorting_analyzer.folder = folder return sorting_analyzer - @property - def backend_kwargs(self): - """ - Returns the backend kwargs for the analyzer. - """ - return self._backend_kwargs.copy() - - @backend_kwargs.setter - def backend_kwargs(self, backend_kwargs): - """ - Sets the backend kwargs for the analyzer. If the backend kwargs are not set, the default backend kwargs are used. - - Parameters - ---------- - backend_kwargs : keyword arguments - The zarr kwargs to set. - """ - for key in backend_kwargs: - if key not in ("zarr", "binary_folder"): - raise ValueError(f"Unknown backend key: {key}. Available keys are 'zarr' and 'binary_folder'.") - self._backend_kwargs[key] = backend_kwargs[key] - def set_temporary_recording(self, recording: BaseRecording, check_dtype: bool = True): """ Sets a temporary recording object. This function can be useful to temporarily set @@ -709,7 +730,7 @@ def _save_or_select_or_merge( sparsity_overlap=0.75, verbose=False, new_unit_ids=None, - backend_kwargs=None, + backend_options=None, **job_kwargs, ) -> "SortingAnalyzer": """ @@ -739,8 +760,11 @@ def _save_or_select_or_merge( The new unit ids for merged units. Required if `merge_unit_groups` is not None. verbose : bool, default: False If True, output is verbose. - backend_kwargs : dict | None, default: None - Keyword arguments for the backend specified by format. + backend_options : dict | None, default: None + Keyword arguments for the backend specified by format. It can contain the: + - storage_options: dict | None (fsspec storage options) + - saving_options: dict | None (additional saving options for creating and saving datasets, + e.g. compression/filters for zarr) job_kwargs : keyword arguments Keyword arguments for the job parallelization. @@ -816,7 +840,7 @@ def _save_or_select_or_merge( # TODO: sam/pierre would create a curation field / curation.json with the applied merges. # What do you think? 
- backend_kwargs = {} if backend_kwargs is None else backend_kwargs + backend_options = {} if backend_options is None else backend_options if format == "memory": # This make a copy of actual SortingAnalyzer @@ -828,34 +852,31 @@ def _save_or_select_or_merge( # create a new folder assert folder is not None, "For format='binary_folder' folder must be provided" folder = Path(folder) - binary_format_kwargs = backend_kwargs - SortingAnalyzer.create_binary_folder( + new_sorting_analyzer = SortingAnalyzer.create_binary_folder( folder, sorting_provenance, recording, sparsity, self.return_scaled, self.rec_attributes, - **binary_format_kwargs, + backend_options=backend_options, ) - new_sorting_analyzer = SortingAnalyzer.load_from_binary_folder(folder, recording=recording) - new_sorting_analyzer.folder = folder elif format == "zarr": assert folder is not None, "For format='zarr' folder must be provided" folder = clean_zarr_folder_name(folder) - zarr_kwargs = backend_kwargs - SortingAnalyzer.create_zarr( - folder, sorting_provenance, recording, sparsity, self.return_scaled, self.rec_attributes, **zarr_kwargs + new_sorting_analyzer = SortingAnalyzer.create_zarr( + folder, + sorting_provenance, + recording, + sparsity, + self.return_scaled, + self.rec_attributes, + backend_options=backend_options, ) - new_sorting_analyzer = SortingAnalyzer.load_from_zarr(folder, recording=recording) - new_sorting_analyzer.folder = folder else: raise ValueError(f"SortingAnalyzer.save: unsupported format: {format}") - if format != "memory": - new_sorting_analyzer.backend_kwargs = {format: backend_kwargs} - # make a copy of extensions # note that the copy of extension handle itself the slicing of units when necessary and also the saveing sorted_extensions = _sort_extensions_by_dependency(self.extensions) @@ -890,7 +911,7 @@ def _save_or_select_or_merge( return new_sorting_analyzer - def save_as(self, format="memory", folder=None, backend_kwargs=None) -> "SortingAnalyzer": + def save_as(self, format="memory", folder=None, backend_options=None) -> "SortingAnalyzer": """ Save SortingAnalyzer object into another format. Uselful for memory to zarr or memory to binary. @@ -905,13 +926,15 @@ def save_as(self, format="memory", folder=None, backend_kwargs=None) -> "Sorting The output folder if `format` is "zarr" or "binary_folder" format : "memory" | "binary_folder" | "zarr", default: "memory" The new backend format to use - backend_kwargs : dict | None, default: None - Backend-specific kwargs for the specified format, which can be used to set some parameters for saving. - For example, if `format` is "zarr", one can set the compressor for the zarr datasets with `backend_kwargs={"compressor": some_compressor}`. + backend_options : dict | None, default: None + Keyword arguments for the backend specified by format. It can contain the: + - storage_options: dict | None (fsspec storage options) + - saving_options: dict | None (additional saving options for creating and saving datasets, + e.g. 
compression/filters for zarr) """ if format == "zarr": folder = clean_zarr_folder_name(folder) - return self._save_or_select_or_merge(format=format, folder=folder, backend_kwargs=backend_kwargs) + return self._save_or_select_or_merge(format=format, folder=folder, backend_options=backend_options) def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyzer": """ @@ -2154,12 +2177,12 @@ def _save_data(self): elif self.format == "zarr": import numcodecs - zarr_kwargs = self.sorting_analyzer.backend_kwargs.get("zarr", {}) + saving_options = self.sorting_analyzer._backend_options.get("saving_options", {}) extension_group = self._get_zarr_extension_group(mode="r+") # if compression is not externally given, we use the default - if "compressor" not in zarr_kwargs: - zarr_kwargs["compressor"] = get_default_zarr_compressor() + if "compressor" not in saving_options: + saving_options["compressor"] = get_default_zarr_compressor() for ext_data_name, ext_data in self.data.items(): if ext_data_name in extension_group: @@ -2169,7 +2192,7 @@ def _save_data(self): name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.JSON() ) elif isinstance(ext_data, np.ndarray): - extension_group.create_dataset(name=ext_data_name, data=ext_data, **zarr_kwargs) + extension_group.create_dataset(name=ext_data_name, data=ext_data, **saving_options) elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame): df_group = extension_group.create_group(ext_data_name) # first we save the index diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index f2aa7f459d..35ab18b5f2 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -137,8 +137,9 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): sparsity=None, return_scaled=False, overwrite=True, + backend_options={"saving_options": {"compressor": None}}, ) - sorting_analyzer_no_compression.backend_kwargs = {"zarr": dict(compressor=None)} + print(sorting_analyzer_no_compression._backend_options) sorting_analyzer_no_compression.compute(["random_spikes", "templates"]) assert ( sorting_analyzer_no_compression._get_zarr_root()["extensions"]["random_spikes"][ @@ -154,7 +155,7 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): lzma_compressor = LZMA() folder = tmp_path / "test_SortingAnalyzer_zarr_lzma.zarr" sorting_analyzer_lzma = sorting_analyzer_no_compression.save_as( - format="zarr", folder=folder, backend_kwargs=dict(compressor=lzma_compressor) + format="zarr", folder=folder, backend_options={"saving_options": {"compressor": lzma_compressor}} ) assert ( sorting_analyzer_lzma._get_zarr_root()["extensions"]["random_spikes"][ From 605b7b40e8d2d51e144222ef710dcc5aa5cc8852 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 1 Oct 2024 12:44:40 +0200 Subject: [PATCH 07/10] Fix saving analyzer directly to remote storage --- src/spikeinterface/core/sortinganalyzer.py | 64 ++++++++++++++-------- 1 file changed, 41 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 0ff028bd42..a50c391798 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -124,12 +124,14 @@ def create_sorting_analyzer( """ if format != "memory": if format == "zarr": - folder = clean_zarr_folder_name(folder) - if Path(folder).is_dir(): - if not overwrite: - raise ValueError(f"Folder already 
exists {folder}! Use overwrite=True to overwrite it.") - else: - shutil.rmtree(folder) + if not is_path_remote(folder): + folder = clean_zarr_folder_name(folder) + if not is_path_remote(folder): + if Path(folder).is_dir(): + if not overwrite: + raise ValueError(f"Folder already exists {folder}! Use overwrite=True to overwrite it.") + else: + shutil.rmtree(folder) # handle sparsity if sparsity is not None: @@ -249,6 +251,9 @@ def __repr__(self) -> str: nchan = self.get_num_channels() nunits = self.get_num_units() txt = f"{clsname}: {nchan} channels - {nunits} units - {nseg} segments - {self.format}" + if self.format != "memory": + if is_path_remote(str(self.folder)): + txt += f" (remote)" if self.is_sparse(): txt += " - sparse" if self.has_recording(): @@ -311,7 +316,8 @@ def create( ) elif format == "zarr": assert folder is not None, "For format='zarr' folder must be provided" - folder = clean_zarr_folder_name(folder) + if not is_path_remote(folder): + folder = clean_zarr_folder_name(folder) sorting_analyzer = cls.create_zarr( folder, sorting, @@ -349,12 +355,7 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto", backe folder, recording=recording, backend_options=backend_options ) - if is_path_remote(str(folder)): - sorting_analyzer.folder = folder - # in this case we only load extensions when needed - else: - sorting_analyzer.folder = Path(folder) - + if not is_path_remote(str(folder)): if load_extensions: sorting_analyzer.load_all_saved_extension() @@ -537,12 +538,16 @@ def load_from_binary_folder(cls, folder, recording=None, backend_options=None): def _get_zarr_root(self, mode="r+"): import zarr - # if is_path_remote(str(self.folder)): - # mode = "r" + assert mode in ("r+", "a", "r"), "mode must be 'r+', 'a' or 'r'" + storage_options = self._backend_options.get("storage_options", {}) # we open_consolidated only if we are in read mode if mode in ("r+", "a"): - zarr_root = zarr.open(str(self.folder), mode=mode, storage_options=storage_options) + try: + zarr_root = zarr.open(str(self.folder), mode=mode, storage_options=storage_options) + except Exception as e: + # this could happen in remote mode, and it's a way to check if the folder is still there + zarr_root = zarr.open_consolidated(self.folder, mode=mode, storage_options=storage_options) else: zarr_root = zarr.open_consolidated(self.folder, mode=mode, storage_options=storage_options) return zarr_root @@ -554,10 +559,14 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at import numcodecs from .zarrextractors import add_sorting_to_zarr_group - folder = clean_zarr_folder_name(folder) - - if folder.is_dir(): - raise ValueError(f"Folder already exists {folder}") + if is_path_remote(folder): + remote = True + else: + remote = False + if not remote: + folder = clean_zarr_folder_name(folder) + if folder.is_dir(): + raise ValueError(f"Folder already exists {folder}") backend_options = {} if backend_options is None else backend_options storage_options = backend_options.get("storage_options", {}) @@ -572,8 +581,9 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at zarr_root.attrs["settings"] = check_json(settings) # the recording + relative_to = folder if not remote else None if recording is not None: - rec_dict = recording.to_dict(relative_to=folder, recursive=True) + rec_dict = recording.to_dict(relative_to=relative_to, recursive=True) if recording.check_serializability("json"): # zarr_root.create_dataset("recording", data=rec_dict, 
object_codec=numcodecs.JSON())
                 zarr_rec = np.array([check_json(rec_dict)], dtype=object)
@@ -589,7 +599,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at
             warnings.warn("Recording not provided, instantiating SortingAnalyzer in recordingless mode.")
 
         # sorting provenance
-        sort_dict = sorting.to_dict(relative_to=folder, recursive=True)
+        sort_dict = sorting.to_dict(relative_to=relative_to, recursive=True)
         if sorting.check_serializability("json"):
             zarr_sort = np.array([check_json(sort_dict)], dtype=object)
             zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.JSON())
@@ -1106,7 +1116,15 @@ def copy(self):
     def is_read_only(self) -> bool:
         if self.format == "memory":
             return False
-        return not os.access(self.folder, os.W_OK)
+        elif self.format == "binary_folder":
+            return not os.access(self.folder, os.W_OK)
+        else:
+            if not is_path_remote(str(self.folder)):
+                return not os.access(self.folder, os.W_OK)
+            else:
+                # in this case we don't know if the file is read only so an error
+                # will be raised if we try to save/append
+                return False
 
     ## map attribute and property zone

From cffb2c9415501028740ea7ee75f9308e4f824198 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 1 Oct 2024 15:02:31 +0200
Subject: [PATCH 08/10] Only reset extension folder when save is True

---
 src/spikeinterface/core/sortinganalyzer.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index a50c391798..bb3e8d5564 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -260,7 +260,9 @@ def __repr__(self) -> str:
             txt += " - has recording"
         if self.has_temporary_recording():
             txt += " - has temporary recording"
-        ext_txt = f"Loaded {len(self.extensions)} extensions: " + ", ".join(self.extensions.keys())
+        ext_txt = f"Loaded {len(self.extensions)} extensions"
+        if len(self.extensions) > 0:
+            ext_txt += f": {', '.join(self.extensions.keys())}"
         txt += "\n" + ext_txt
         return txt
 
@@ -2297,7 +2299,8 @@ def set_params(self, save=True, **params):
         """
         # this ensures data is also deleted and corresponds to params
         # this also ensures the group is created
-        self._reset_extension_folder()
+        if save:
+            self._reset_extension_folder()
         params = self._set_params(**params)
         self.params = params

From e564f8b8229572d049c8107ad9d9d358c6c96724 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 2 Oct 2024 17:05:42 +0200
Subject: [PATCH 09/10] Fall back to anon=True for zarr extractors and analyzers when backend/storage options are not provided

---
 src/spikeinterface/core/sortinganalyzer.py | 40 ++++++++++++++++------
 src/spikeinterface/core/zarrextractors.py  | 32 ++++++++++++++---
 2 files changed, 57 insertions(+), 15 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 1f404c755d..14b4f73eaf 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -65,18 +65,18 @@ def create_sorting_analyzer(
     recording : Recording
         The recording object
     folder : str or Path or None, default: None
-        The folder where waveforms are cached
+        The folder where the analyzer is cached
     format : "memory | "binary_folder" | "zarr", default: "memory"
-        The mode to store waveforms. If "folder", waveforms are stored on disk in the specified folder.
+        The mode to store the analyzer. If "folder", the analyzer is stored on disk in the specified folder.
         The "folder" argument must be specified in case of mode "folder".
-        If "memory" is used, the waveforms are stored in RAM. Use this option carefully!
+        If "memory" is used, the analyzer is stored in RAM. Use this option carefully!
     sparse : bool, default: True
         If True, then a sparsity mask is computed using the `estimate_sparsity()` function using
         a few spikes to get an estimate of dense templates to create a ChannelSparsity object.
         Then, the sparsity will be propagated to all ResultExtension that handle sparsity (like waveforms, pca, ...)
         You can control `estimate_sparsity()` : all extra arguments are propagated to it (included job_kwargs)
     sparsity : ChannelSparsity or None, default: None
-        The sparsity used to compute waveforms. If this is given, `sparse` is ignored.
+        The sparsity used to compute extensions. If this is given, `sparse` is ignored.
     return_scaled : bool, default: True
         All extensions that play with traces will use this global return_scaled : "waveforms", "noise_levels", "templates".
         This prevent return_scaled being differents from different extensions and having wrong snr for instance.
@@ -98,7 +98,7 @@
     --------
 
     >>> import spikeinterface as si
 
-    >>> # Extract dense waveforms and save to disk with binary_folder format.
+    >>> # Create a dense analyzer and save it to disk with the binary_folder format.
     >>> sorting_analyzer = si.create_sorting_analyzer(sorting, recording, format="binary_folder", folder="/path/to_my/result")
 
     >>> # Can be reloaded
@@ -172,14 +172,19 @@
 
     Parameters
     ----------
     folder : str or Path
-        The folder / zarr folder where the waveform extractor is stored
+        The folder / zarr folder where the analyzer is stored. If the folder is a remote path stored in the cloud,
+        the backend_options can be used to specify credentials. If the remote path is not accessible
+        and backend_options is not provided, the function will try to load the object in anonymous mode (anon=True),
+        which makes it possible to load data from open buckets.
     load_extensions : bool, default: True
         Load all extensions or not.
     format : "auto" | "binary_folder" | "zarr"
         The format of the folder.
-    storage_options : dict | None, default: None
-        The storage options to specify credentials to remote zarr bucket.
-        For open buckets, it doesn't need to be specified.
+    backend_options : dict | None, default: None
+        The backend options for the specified format.
+        The dictionary can contain the following keys:
+        - storage_options: dict | None (fsspec storage options)
+        - saving_options: dict | None (additional saving options for creating and saving datasets)
 
     Returns
     -------
     sorting_analyzer : SortingAnalyzer
         The loaded SortingAnalyzer
 
     """
-    return SortingAnalyzer.load(folder, load_extensions=load_extensions, format=format, backend_options=backend_options)
+    if is_path_remote(folder) and backend_options is None:
+        try:
+            return SortingAnalyzer.load(
+                folder, load_extensions=load_extensions, format=format, backend_options=backend_options
+            )
+        except Exception as e:
+            backend_options = dict(storage_options=dict(anon=True))
+            return SortingAnalyzer.load(
+                folder, load_extensions=load_extensions, format=format, backend_options=backend_options
+            )
+    else:
+        return SortingAnalyzer.load(
+            folder, load_extensions=load_extensions, format=format, backend_options=backend_options
+        )
 
 
 class SortingAnalyzer:
@@ -2286,7 +2304,7 @@ def delete(self):
 
     def reset(self):
         """
-        Reset the waveform extension.
+        Reset the extension.
         Delete the sub folder and create a new empty one.
         """
         self._reset_extension_folder()

diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py
index 355553428e..26cb3cc6fc 100644
--- a/src/spikeinterface/core/zarrextractors.py
+++ b/src/spikeinterface/core/zarrextractors.py
@@ -12,6 +12,7 @@
 from .core_tools import define_function_from_class, check_json
 from .job_tools import split_job_kwargs
 from .recording_tools import determine_cast_unsigned
+from .core_tools import is_path_remote
 
 
 class ZarrRecordingExtractor(BaseRecording):
     """
     RecordingExtractor for a zarr format
 
     Parameters
     ----------
     folder_path : str or Path
-        Path to the zarr root folder
+        Path to the zarr root folder. This can be a local path or a remote path (s3:// or gcs://).
+        If the path is a remote path, the storage_options can be provided to specify credentials.
+        If the remote path is not accessible and storage_options is not provided,
+        the function will try to load the object in anonymous mode (anon=True),
+        which makes it possible to load data from open buckets.
     storage_options : dict or None
         Storage options for zarr `store`. E.g., if "s3://" or "gcs://" they can provide authentication methods, etc.
@@ -35,7 +40,14 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None)
 
         folder_path, folder_path_kwarg = resolve_zarr_path(folder_path)
 
-        self._root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
+        if is_path_remote(str(folder_path)) and storage_options is None:
+            try:
+                self._root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
+            except Exception as e:
+                storage_options = {"anon": True}
+                self._root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
+        else:
+            self._root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
 
         sampling_frequency = self._root.attrs.get("sampling_frequency", None)
         num_segments = self._root.attrs.get("num_segments", None)
@@ -150,7 +162,11 @@ class ZarrSortingExtractor(BaseSorting):
 
     Parameters
     ----------
     folder_path : str or Path
-        Path to the zarr root file
+        Path to the zarr root file. This can be a local path or a remote path (s3:// or gcs://).
+        If the path is a remote path, the storage_options can be provided to specify credentials.
+        If the remote path is not accessible and storage_options is not provided,
+        the function will try to load the object in anonymous mode (anon=True),
+        which makes it possible to load data from open buckets.
     storage_options : dict or None
         Storage options for zarr `store`. E.g., if "s3://" or "gcs://" they can provide authentication methods, etc.
     zarr_group : str or None, default: None
@@ -165,7 +181,15 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None,
 
         folder_path, folder_path_kwarg = resolve_zarr_path(folder_path)
 
-        zarr_root = self._root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
+        if is_path_remote(str(folder_path)) and storage_options is None:
+            try:
+                zarr_root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
+            except Exception as e:
+                storage_options = {"anon": True}
+                zarr_root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
+        else:
+            zarr_root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
+
         if zarr_group is None:
             self._root = zarr_root
         else:

From 580703f5e8382aeca58cc5d9ec4e300cbfc6f3e3 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 2 Oct 2024 17:50:23 +0200
Subject: [PATCH 10/10] Protect against uninitialized chunks and add anonymous zarr open

---
 src/spikeinterface/core/zarrextractors.py | 42 +++++++++++++----------
 1 file changed, 23 insertions(+), 19 deletions(-)

diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py
index 26cb3cc6fc..ff552dfb54 100644
--- a/src/spikeinterface/core/zarrextractors.py
+++ b/src/spikeinterface/core/zarrextractors.py
@@ -15,6 +15,18 @@
 from .core_tools import is_path_remote
 
 
+def anononymous_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: dict | None = None):
+    if is_path_remote(str(folder_path)) and storage_options is None:
+        try:
+            root = zarr.open(str(folder_path), mode=mode, storage_options=storage_options)
+        except Exception as e:
+            storage_options = {"anon": True}
+            root = zarr.open(str(folder_path), mode=mode, storage_options=storage_options)
+    else:
+        root = zarr.open(str(folder_path), mode=mode, storage_options=storage_options)
+    return root
+
+
 class ZarrRecordingExtractor(BaseRecording):
     """
     RecordingExtractor for a zarr format
@@ -40,14 +52,7 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None)
 
         folder_path, folder_path_kwarg = resolve_zarr_path(folder_path)
 
-        if is_path_remote(str(folder_path)) and storage_options is None:
-            try:
-                self._root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
-            except Exception as e:
-                storage_options = {"anon": True}
-                self._root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
-        else:
-            self._root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
+        self._root = anononymous_zarr_open(folder_path, mode="r", storage_options=storage_options)
 
         sampling_frequency = self._root.attrs.get("sampling_frequency", None)
         num_segments = self._root.attrs.get("num_segments", None)
@@ -93,7 +98,10 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None)
 
                 nbytes_segment = self._root[trace_name].nbytes
                 nbytes_stored_segment = self._root[trace_name].nbytes_stored
-                cr_by_segment[segment_index] = nbytes_segment / nbytes_stored_segment
+                if nbytes_stored_segment > 0:
+                    cr_by_segment[segment_index] = nbytes_segment / nbytes_stored_segment
+                else:
+                    cr_by_segment[segment_index] = np.nan
 
                 total_nbytes += nbytes_segment
                total_nbytes_stored += nbytes_stored_segment
@@ -117,7 +125,10 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None)
         if annotations is not None:
             self.annotate(**annotations)
         # annotate compression ratios
-        cr = total_nbytes / total_nbytes_stored
+        if total_nbytes_stored > 0:
+            cr = total_nbytes / total_nbytes_stored
+        else:
+            cr = np.nan
 
         self.annotate(compression_ratio=cr, compression_ratio_segments=cr_by_segment)
 
         self._kwargs = {"folder_path": folder_path_kwarg, "storage_options": storage_options}
@@ -181,14 +192,7 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None,
 
         folder_path, folder_path_kwarg = resolve_zarr_path(folder_path)
 
-        if is_path_remote(str(folder_path)) and storage_options is None:
-            try:
-                zarr_root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
-            except Exception as e:
-                storage_options = {"anon": True}
-                zarr_root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
-        else:
-            zarr_root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
+        zarr_root = anononymous_zarr_open(folder_path, mode="r", storage_options=storage_options)
 
         if zarr_group is None:
             self._root = zarr_root
         else:
@@ -267,7 +271,7 @@ def read_zarr(
     """
     # TODO @alessio : we should have something more explicit in our zarr format to tell which object it is.
    # for the future SortingAnalyzer we will have these 2 fields!!!
-    root = zarr.open(str(folder_path), mode="r", storage_options=storage_options)
+    root = anononymous_zarr_open(folder_path, mode="r", storage_options=storage_options)
    if "channel_ids" in root.keys():
        return read_zarr_recording(folder_path, storage_options=storage_options)
    elif "unit_ids" in root.keys():
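
Usage sketch (not part of the patches themselves): the snippet below mirrors how the
updated tests exercise the backend_options API introduced in PATCH 06. It assumes a
spikeinterface build that includes this series; the folder names are illustrative and
generate_ground_truth_recording is only used to produce a toy dataset.

    import spikeinterface as si
    from numcodecs import LZMA

    recording, sorting = si.generate_ground_truth_recording(durations=[10.0], num_units=5, seed=42)

    # Create a zarr-backed analyzer whose datasets are written uncompressed
    # (saving_options are forwarded to zarr dataset creation).
    analyzer = si.create_sorting_analyzer(
        sorting,
        recording,
        format="zarr",
        folder="analyzer_no_compression.zarr",
        sparse=False,
        overwrite=True,
        backend_options={"saving_options": {"compressor": None}},
    )
    analyzer.compute(["random_spikes", "templates"])

    # Re-save the same analyzer with a different compressor.
    analyzer_lzma = analyzer.save_as(
        format="zarr",
        folder="analyzer_lzma.zarr",
        backend_options={"saving_options": {"compressor": LZMA()}},
    )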
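
A second sketch for the anonymous-fallback loading added in PATCH 09/10. The bucket
URLs and credentials below are placeholders, not real datasets:

    import spikeinterface as si

    # For an open (public) bucket no credentials are needed: if the first attempt
    # fails, load_sorting_analyzer retries with storage_options={"anon": True}.
    analyzer = si.load_sorting_analyzer("s3://some-open-bucket/analyzer.zarr")

    # Credentials (or any other fsspec storage options) can still be passed explicitly.
    analyzer = si.load_sorting_analyzer(
        "s3://some-private-bucket/analyzer.zarr",
        backend_options={"storage_options": {"key": "<aws-key>", "secret": "<aws-secret>"}},
    )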
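
Finally, a sketch of the PID-based IBL loading from PATCH 04; the pid value is a
placeholder and a configured ONE API instance (or default credentials) is assumed:

    from spikeinterface.extractors import IblRecordingExtractor

    # When a probe insertion id (pid) is given, stream_type is now required and the
    # stream name is resolved internally via one.pid2eid() / one.eid2pid().
    recording = IblRecordingExtractor(pid="00000000-0000-0000-0000-000000000000", stream_type="ap")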