From a5ed0b36fae0e33f6b6f310807e90258fed2962a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 6 Sep 2024 17:46:18 +0200 Subject: [PATCH 1/5] Ensure sorting analyzer in zarr are consolidated --- src/spikeinterface/core/sortinganalyzer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 817c453a97..0831391469 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -76,6 +76,9 @@ def create_sorting_analyzer( return_scaled : bool, default: True All extensions that play with traces will use this global return_scaled : "waveforms", "noise_levels", "templates". This prevent return_scaled being differents from different extensions and having wrong snr for instance. + overwrite: bool, default: False + If True, overwrite the folder if it already exists. + Returns ------- @@ -563,11 +566,13 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_at recording_info = zarr_root.create_group("extensions") + zarr.consolidate_metadata(zarr_root.store) + @classmethod def load_from_zarr(cls, folder, recording=None, storage_options=None): import zarr - zarr_root = zarr.open(str(folder), mode="r", storage_options=storage_options) + zarr_root = zarr.open_consolidated(str(folder), mode="r", storage_options=storage_options) # load internal sorting in memory sorting = NumpySorting.from_sorting( @@ -2002,7 +2007,7 @@ def _save_data(self, **kwargs): except: raise Exception(f"Could not save {ext_data_name} as extension data") elif self.format == "zarr": - + import zarr import numcodecs extension_group = self._get_zarr_extension_group(mode="r+") @@ -2036,6 +2041,8 @@ def _save_data(self, **kwargs): except: raise Exception(f"Could not save {ext_data_name} as extension data") extension_group[ext_data_name].attrs["object"] = True + # we need to re-consolidate + zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) def _reset_extension_folder(self): """ @@ -2051,7 +2058,7 @@ def _reset_extension_folder(self): import zarr zarr_root = zarr.open(self.folder, mode="r+") - extension_group = zarr_root["extensions"].create_group(self.extension_name, overwrite=True) + _ = zarr_root["extensions"].create_group(self.extension_name, overwrite=True) def reset(self): """ From c12ceb6a60af3ae4625101e9db6332205b71db0c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 9 Sep 2024 17:43:48 +0200 Subject: [PATCH 2/5] Use open_consolidated when possible --- src/spikeinterface/core/sortinganalyzer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 0831391469..3abf0e9b5e 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -488,7 +488,11 @@ def _get_zarr_root(self, mode="r+"): if is_path_remote(str(self.folder)): mode = "r" - zarr_root = zarr.open(self.folder, mode=mode, storage_options=self.storage_options) + # we open_consolidated only if we are in read mode + if mode in ("r+", "a"): + zarr_root = zarr.open(str(self.folder), mode=mode, storage_options=self.storage_options) + else: + zarr_root = zarr.open_consolidated(self.folder, mode=mode, storage_options=self.storage_options) return zarr_root @classmethod @@ -2057,8 +2061,9 @@ def _reset_extension_folder(self): elif self.format == "zarr": import zarr - zarr_root = zarr.open(self.folder, mode="r+") + zarr_root = self.sorting_analyzer._get_zarr_root(mode="r+") _ = zarr_root["extensions"].create_group(self.extension_name, overwrite=True) + zarr.consolidate_metadata(zarr_root.store) def reset(self): """ @@ -2074,7 +2079,7 @@ def set_params(self, save=True, **params): Set parameters for the extension and make it persistent in json. """ - # this ensure data is also deleted and corresponf to params + # this ensure data is also deleted and corresponds to params # this also ensure the group is created self._reset_extension_folder() From 8b5de9d65abb3152c6ad30ecb9599b27f17a903e Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 11 Sep 2024 09:52:33 +0200 Subject: [PATCH 3/5] propagate storage option --- src/spikeinterface/core/sortinganalyzer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 3abf0e9b5e..519d741bc1 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -622,6 +622,7 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None): format="zarr", sparsity=sparsity, return_scaled=return_scaled, + storage_options=storage_options ) return sorting_analyzer From c56d625d744e7474658d0acca1284f6eacbbe200 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Sep 2024 07:54:25 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sortinganalyzer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 519d741bc1..daa693d667 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -622,7 +622,7 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None): format="zarr", sparsity=sparsity, return_scaled=return_scaled, - storage_options=storage_options + storage_options=storage_options, ) return sorting_analyzer From ff07ac603bb8552ef40bd1f43962f201a72df17c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 11 Sep 2024 11:50:22 +0200 Subject: [PATCH 5/5] Fix run_info and consolidate metadata --- src/spikeinterface/core/sortinganalyzer.py | 35 +++++++++++++++---- .../tests/test_analyzer_extension_core.py | 5 +++ .../core/tests/test_sortinganalyzer.py | 3 ++ 3 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 4f5be665f4..424fab7c5e 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1472,7 +1472,7 @@ def delete_extension(self, extension_name) -> None: if self.format != "memory" and self.has_extension(extension_name): # need a reload to reset the folder ext = self.load_extension(extension_name) - ext.reset() + ext.delete() # remove from dict self.extensions.pop(extension_name, None) @@ -2014,19 +2014,17 @@ def run(self, save=True, **kwargs): # NB: this call to _save_params() also resets the folder or zarr group self._save_params() self._save_importing_provenance() - self._save_run_info() t_start = perf_counter() self._run(**kwargs) t_end = perf_counter() self.run_info["runtime_s"] = t_end - t_start + self.run_info["run_completed"] = True if save and not self.sorting_analyzer.is_read_only(): + self._save_run_info() self._save_data(**kwargs) - self.run_info["run_completed"] = True - self._save_run_info() - def save(self, **kwargs): self._save_params() self._save_importing_provenance() @@ -2126,6 +2124,32 @@ def _reset_extension_folder(self): _ = zarr_root["extensions"].create_group(self.extension_name, overwrite=True) zarr.consolidate_metadata(zarr_root.store) + def _delete_extension_folder(self): + """ + Delete the extension in a folder (binary or zarr). + """ + if self.format == "binary_folder": + extension_folder = self._get_binary_extension_folder() + if extension_folder.is_dir(): + shutil.rmtree(extension_folder) + + elif self.format == "zarr": + import zarr + + zarr_root = self.sorting_analyzer._get_zarr_root(mode="r+") + if self.extension_name in zarr_root["extensions"]: + del zarr_root["extensions"][self.extension_name] + zarr.consolidate_metadata(zarr_root.store) + + def delete(self): + """ + Delete the extension from the folder or zarr and from the dict. + """ + self._delete_extension_folder() + self.params = None + self.run_info = self._default_run_info_dict() + self.data = dict() + def reset(self): """ Reset the waveform extension. @@ -2154,7 +2178,6 @@ def set_params(self, save=True, **params): if save: self._save_params() self._save_importing_provenance() - self._save_run_info() def _save_params(self): params_to_save = self.params.copy() diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index b4d96a3391..626899ab6e 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -79,15 +79,20 @@ def _check_result_extension(sorting_analyzer, extension_name, cache_folder): ) def test_ComputeRandomSpikes(format, sparse, create_cache_folder): cache_folder = create_cache_folder + print("Creating analyzer") sorting_analyzer = get_sorting_analyzer(cache_folder, format=format, sparse=sparse) + print("Computing random spikes") ext = sorting_analyzer.compute("random_spikes", max_spikes_per_unit=10, seed=2205) indices = ext.data["random_spikes_indices"] assert indices.size == 10 * sorting_analyzer.sorting.unit_ids.size + print("Checking results") _check_result_extension(sorting_analyzer, "random_spikes", cache_folder) + print("Delering extension") sorting_analyzer.delete_extension("random_spikes") + print("Re-computing random spikes") ext = sorting_analyzer.compute("random_spikes", method="all") indices = ext.data["random_spikes_indices"] assert indices.size == len(sorting_analyzer.sorting.to_spike_vector()) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 3f45487f4c..77b8f2c5bf 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -126,6 +126,8 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): def test_load_without_runtime_info(tmp_path, dataset): + import zarr + recording, sorting = dataset folder = tmp_path / "test_SortingAnalyzer_run_info" @@ -153,6 +155,7 @@ def test_load_without_runtime_info(tmp_path, dataset): root = sorting_analyzer._get_zarr_root(mode="r+") for ext in extensions: del root["extensions"][ext].attrs["run_info"] + zarr.consolidate_metadata(root.store) # should raise a warning for missing run_info with pytest.warns(UserWarning): sorting_analyzer = load_sorting_analyzer(folder, format="auto")