From 5c351a4d25ac3fb992e30db3eeea3dcdf4cd3955 Mon Sep 17 00:00:00 2001
From: Jonah Pearl
Date: Wed, 28 Aug 2024 11:37:05 -0400
Subject: [PATCH 01/15] save run_info to check extensions for completion

---
 src/spikeinterface/core/sortinganalyzer.py | 90 +++++++++++++++++++---
 1 file changed, 79 insertions(+), 11 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index fa4547d272..5eaebb3189 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -18,6 +18,8 @@

 import spikeinterface

+from zarr.errors import ArrayNotFoundError
+
 from .baserecording import BaseRecording
 from .basesorting import BaseSorting

@@ -1719,8 +1721,12 @@ def __init__(self, sorting_analyzer):
         self._sorting_analyzer = weakref.ref(sorting_analyzer)

         self.params = None
+        self.run_info = self._default_run_info_dict()
         self.data = dict()

+    def _default_run_info_dict(self):
+        return dict(run_completed=False, data_loadable=False, runtime_s=None)
+
     #######
     # This 3 methods must be implemented in the subclass!!!
     # See DummyAnalyzerExtension in test_sortinganalyzer.py as a simple example
@@ -1832,11 +1838,35 @@ def _get_zarr_extension_group(self, mode="r+"):
     def load(cls, sorting_analyzer):
         ext = cls(sorting_analyzer)
         ext.load_params()
-        ext.load_data()
-        if cls.need_backward_compatibility_on_load:
-            ext._handle_backward_compatibility_on_load()
+        ext.load_run_info()
+        if ext.run_info["run_completed"] and ext.run_info["data_loadable"]:
+            ext.load_data()
+            if cls.need_backward_compatibility_on_load:
+                ext._handle_backward_compatibility_on_load()
+            return ext
+        elif ext.run_info["run_completed"] and not ext.run_info["data_loadable"]:
+            warnings.warn(
+                f"Extension {cls.extension_name} has been computed but the data is not loadable. "
+                "The extension should be re-computed."
+            )
+            return ext
+        else:
+            return None
+
+    def load_run_info(self):
+        if self.format == "binary_folder":
+            extension_folder = self._get_binary_extension_folder()
+            run_info_file = extension_folder / "run_info.json"
+            assert run_info_file.is_file(), f"No run_info file in extension {self.extension_name} folder"
+            with open(str(run_info_file), "r") as f:
+                run_info = json.load(f)
+
+        elif self.format == "zarr":
+            extension_group = self._get_zarr_extension_group(mode="r")
+            assert "run_info" in extension_group.attrs, f"No run_info file in extension {self.extension_name} folder"
+            run_info = extension_group.attrs["run_info"]

-        return ext
+        self.run_info = run_info

     def load_params(self):
         if self.format == "binary_folder":
@@ -1853,13 +1883,15 @@ def load_params(self):

         self.params = params

-    def load_data(self):
+    def load_data(self, keep=True):
+        ext_data = None
+
         if self.format == "binary_folder":
             extension_folder = self._get_binary_extension_folder()
             for ext_data_file in extension_folder.iterdir():
                 # patch for https://github.com/SpikeInterface/spikeinterface/issues/3041
                 # maybe add a check for version number from the info.json during loading only
-                if ext_data_file.name == "params.json" or ext_data_file.name == "info.json":
+                if ext_data_file.name == "params.json" or ext_data_file.name == "info.json" or ext_data_file.name == "run_info.json":
                     continue
                 ext_data_name = ext_data_file.stem
                 if ext_data_file.suffix == ".json":
@@ -1878,7 +1910,6 @@ def load_data(self):
                     ext_data = pickle.load(ext_data_file.open("rb"))
                 else:
                     continue
-                self.data[ext_data_name] = ext_data

         elif self.format == "zarr":
             extension_group = self._get_zarr_extension_group(mode="r")
@@ -1898,12 +1929,29 @@ def load_data(self):
                 else:
                     # this load in memmory
                     ext_data = np.array(ext_data_)
-                self.data[ext_data_name] = ext_data
+
+        if ext_data is None:
+            warnings.warn(f"Found no data for {self.extension_name}, extension should be re-computed.")
+
+        if keep:
+            self.data[ext_data_name] = ext_data
+
+    def _check_data_loadable(self):
+        try:
+            self.load_data(keep=False)
+            return True
+        except (
+            ValueError, IOError, EOFError, KeyError, UnicodeDecodeError,
+            json.JSONDecodeError, pickle.UnpicklingError, pd.errors.ParserError,
+            ArrayNotFoundError
+        ):
+            return False

     def copy(self, new_sorting_analyzer, unit_ids=None):
         # alessio : please note that this also replace the old select_units!!!
         new_extension = self.__class__(new_sorting_analyzer)
         new_extension.params = self.params.copy()
+        new_extension.run_info = self.run_info.copy()  # TODO: does copy() assume both extensions have been run?
         if unit_ids is None:
             new_extension.data = self.data
         else:
@@ -1922,6 +1970,7 @@ def merge(
     ):
         new_extension = self.__class__(new_sorting_analyzer)
         new_extension.params = self.params.copy()
+        new_extension.run_info = self.run_info.copy()  # TODO: does merge() assume both extensions have been run?
         new_extension.data = self._merge_extension_data(
             merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask, verbose=verbose, **job_kwargs
         )
@@ -1930,19 +1979,26 @@ def merge(

     def run(self, save=True, **kwargs):
         if save and not self.sorting_analyzer.is_read_only():
-            # this also reset the folder or zarr group
-            self._save_params()
+            # NB: this call to _save_params() also resets the folder or zarr group
+            self._save_params()
             self._save_importing_provenance()
+            self._save_run_info()

         self._run(**kwargs)
-
+
         if save and not self.sorting_analyzer.is_read_only():
             self._save_data(**kwargs)
+            self.run_info["data_loadable"] = self._check_data_loadable()  # maybe overkill?
+
+        self.run_info["run_completed"] = True
+        self._save_run_info()

     def save(self, **kwargs):
         self._save_params()
         self._save_importing_provenance()
         self._save_data(**kwargs)
+        self.run_info["data_loadable"] = self._check_data_loadable()
+        self._save_run_info()

     def _save_data(self, **kwargs):
         if self.format == "memory":
@@ -2041,6 +2097,7 @@ def reset(self):
         """
         self._reset_extension_folder()
         self.params = None
+        self.run_info = self._default_run_info_dict()
         self.data = dict()

     def set_params(self, save=True, **params):
@@ -2098,6 +2155,17 @@ def _save_importing_provenance(self):
             extension_group = self._get_zarr_extension_group(mode="r+")
             extension_group.attrs["info"] = info

+    def _save_run_info(self):
+        run_info = self.run_info.copy()
+
+        if self.format == "binary_folder":
+            extension_folder = self._get_binary_extension_folder()
+            run_info_file = extension_folder / "run_info.json"
+            run_info_file.write_text(json.dumps(run_info, indent=4), encoding="utf8")
+        elif self.format == "zarr":
+            extension_group = self._get_zarr_extension_group(mode="r+")
+            extension_group.attrs["run_info"] = run_info
+
     def get_pipeline_nodes(self):
         assert (
             self.use_nodepipeline
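PATCH 01 gives every extension a small persistent record: for the binary_folder format it is written to extensions/<extension_name>/run_info.json, and for zarr it lives in the extension group's "run_info" attribute. A minimal sketch of reading that record back outside of spikeinterface, assuming the binary_folder layout above (the analyzer folder name and the helper function are illustrative, not part of the library):

    import json
    from pathlib import Path

    # Illustrative analyzer folder; the layout mirrors _save_run_info() above.
    analyzer_folder = Path("my_analyzer_folder")

    def extension_is_complete(extension_name: str) -> bool:
        """Check the run_info record written by _save_run_info()."""
        run_info_file = analyzer_folder / "extensions" / extension_name / "run_info.json"
        if not run_info_file.is_file():
            return False
        run_info = json.loads(run_info_file.read_text(encoding="utf8"))
        # Keys written at this stage of the series: run_completed, data_loadable, runtime_s
        return run_info.get("run_completed", False) and run_info.get("data_loadable", False)

    print(extension_is_complete("templates"))  # False until the extension has been computed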
From 725b208afba299191419f0ef5bfd9eb379036730 Mon Sep 17 00:00:00 2001
From: Jonah Pearl
Date: Wed, 28 Aug 2024 12:55:58 -0400
Subject: [PATCH 02/15] save run time

---
 src/spikeinterface/core/sortinganalyzer.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 5eaebb3189..03a3653c52 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -11,6 +11,7 @@
 import shutil
 import warnings
 import importlib
+from time import time

 import numpy as np

@@ -1984,12 +1985,15 @@ def run(self, save=True, **kwargs):
             self._save_importing_provenance()
             self._save_run_info()

+        start = time()
         self._run(**kwargs)
-
+        end = time()
+
         if save and not self.sorting_analyzer.is_read_only():
             self._save_data(**kwargs)
             self.run_info["data_loadable"] = self._check_data_loadable()  # maybe overkill?
-
+
+        self.run_info["runtime_s"] = np.round(end - start, 1)
         self.run_info["run_completed"] = True
         self._save_run_info()
- + + self.run_info["runtime_s"] = np.round(end - start, 1) self.run_info["run_completed"] = True self._save_run_info() From d95a754c0924cfcb2f55307a052a1382ecf8ea92 Mon Sep 17 00:00:00 2001 From: Jonah Pearl Date: Wed, 28 Aug 2024 12:56:13 -0400 Subject: [PATCH 03/15] bug fixes for pipeline extensions --- src/spikeinterface/core/sortinganalyzer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 03a3653c52..7bdf2dc516 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -2002,6 +2002,8 @@ def save(self, **kwargs): self._save_importing_provenance() self._save_data(**kwargs) self.run_info["data_loadable"] = self._check_data_loadable() + if self.run_info["data_loadable"]: + self.run_info["run_completed"] = True # extensions that go through compute_several_extensions() and then run_node_pipeline() never have ext_instance.run() called, so need to check here (or at least somewhere) instead self._save_run_info() def _save_data(self, **kwargs): @@ -2122,6 +2124,7 @@ def set_params(self, save=True, **params): if save: self._save_params() self._save_importing_provenance() + self._save_run_info() def _save_params(self): params_to_save = self.params.copy() From 1699413fbb16872f0c3a40561e874a650d87648b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Aug 2024 17:27:44 +0000 Subject: [PATCH 04/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/sortinganalyzer.py | 34 +++++++++++++++------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 7bdf2dc516..3ca4617551 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1853,7 +1853,7 @@ def load(cls, sorting_analyzer): return ext else: return None - + def load_run_info(self): if self.format == "binary_folder": extension_folder = self._get_binary_extension_folder() @@ -1861,7 +1861,7 @@ def load_run_info(self): assert run_info_file.is_file(), f"No run_info file in extension {self.extension_name} folder" with open(str(run_info_file), "r") as f: run_info = json.load(f) - + elif self.format == "zarr": extension_group = self._get_zarr_extension_group(mode="r") assert "run_info" in extension_group.attrs, f"No run_info file in extension {self.extension_name} folder" @@ -1892,7 +1892,11 @@ def load_data(self, keep=True): for ext_data_file in extension_folder.iterdir(): # patch for https://github.com/SpikeInterface/spikeinterface/issues/3041 # maybe add a check for version number from the info.json during loading only - if ext_data_file.name == "params.json" or ext_data_file.name == "info.json" or ext_data_file.name == "run_info.json": + if ( + ext_data_file.name == "params.json" + or ext_data_file.name == "info.json" + or ext_data_file.name == "run_info.json" + ): continue ext_data_name = ext_data_file.stem if ext_data_file.suffix == ".json": @@ -1930,10 +1934,10 @@ def load_data(self, keep=True): else: # this load in memmory ext_data = np.array(ext_data_) - + if ext_data is None: warnings.warn(f"Found no data for {self.extension_name}, extension should be re-computed.") - + if keep: self.data[ext_data_name] = ext_data @@ -1942,10 +1946,16 @@ def _check_data_loadable(self): self.load_data(keep=False) 
From 869b01a44d758a35869aa91b1009270307850304 Mon Sep 17 00:00:00 2001
From: Jonah Pearl
Date: Thu, 29 Aug 2024 09:25:58 -0400
Subject: [PATCH 05/15] switch to perf counter

---
 src/spikeinterface/core/sortinganalyzer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 61c5515425..642f760241 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -2013,9 +2013,9 @@ def run(self, save=True, **kwargs):
             self._save_importing_provenance()
             self._save_run_info()

-        start = time()
+        start = perf_counter()
         self._run(**kwargs)
-        end = time()
+        end = perf_counter()

From 554f6e3765767fe7a045e910027a2606d61ad0c4 Mon Sep 17 00:00:00 2001
From: Jonah Pearl
Date: Thu, 29 Aug 2024 09:26:56 -0400
Subject: [PATCH 06/15] use perf counter

---
 src/spikeinterface/core/sortinganalyzer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 642f760241..9a54f0a627 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -11,7 +11,7 @@
 import shutil
 import warnings
 import importlib
-from time import time
+from time import perf_counter

 import numpy as np
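PATCH 05 and 06 replace time.time() with time.perf_counter() for the runtime measurement. The distinction matters: perf_counter() is a monotonic, high-resolution clock intended for measuring intervals, while time() reports wall-clock time and can jump if the system clock is adjusted mid-run. A minimal sketch of the timing idiom the series settles on (the timed() wrapper is illustrative):

    from time import perf_counter

    def timed(fn, *args, **kwargs):
        """Run fn and return (result, elapsed seconds) using a monotonic clock."""
        t_start = perf_counter()
        result = fn(*args, **kwargs)
        t_end = perf_counter()
        return result, t_end - t_start

    # Illustrative usage:
    result, runtime_s = timed(sum, range(1_000_000))
    print(f"runtime_s = {runtime_s:.3f}")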
From be023edc10bd26afef2fd26aa3847eadcc602cf9 Mon Sep 17 00:00:00 2001
From: Jonah Pearl
Date: Thu, 29 Aug 2024 09:27:16 -0400
Subject: [PATCH 07/15] remove data_loadable and _check_data_loadable

---
 src/spikeinterface/core/sortinganalyzer.py | 43 ++++------------
 1 file changed, 8 insertions(+), 35 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 9a54f0a627..251ce3249a 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -1744,7 +1744,7 @@ def __init__(self, sorting_analyzer):
         self.data = dict()

     def _default_run_info_dict(self):
-        return dict(run_completed=False, data_loadable=False, runtime_s=None)
+        return dict(run_completed=False, runtime_s=None)

     #######
     # This 3 methods must be implemented in the subclass!!!
@@ -1858,17 +1858,11 @@ def load(cls, sorting_analyzer):
         ext = cls(sorting_analyzer)
         ext.load_params()
         ext.load_run_info()
-        if ext.run_info["run_completed"] and ext.run_info["data_loadable"]:
+        if ext.run_info["run_completed"]:
             ext.load_data()
             if cls.need_backward_compatibility_on_load:
                 ext._handle_backward_compatibility_on_load()
             return ext
-        elif ext.run_info["run_completed"] and not ext.run_info["data_loadable"]:
-            warnings.warn(
-                f"Extension {cls.extension_name} has been computed but the data is not loadable. "
-                "The extension should be re-computed."
-            )
-            return ext
         else:
             return None

@@ -1902,7 +1896,7 @@ def load_params(self):

         self.params = params

-    def load_data(self, keep=True):
+    def load_data(self):
         ext_data = None

         if self.format == "binary_folder":
@@ -1956,25 +1950,7 @@ def load_data(self):
         if ext_data is None:
             warnings.warn(f"Found no data for {self.extension_name}, extension should be re-computed.")

-        if keep:
-            self.data[ext_data_name] = ext_data
-
-    def _check_data_loadable(self):
-        try:
-            self.load_data(keep=False)
-            return True
-        except (
-            ValueError,
-            IOError,
-            EOFError,
-            KeyError,
-            UnicodeDecodeError,
-            json.JSONDecodeError,
-            pickle.UnpicklingError,
-            pd.errors.ParserError,
-            ArrayNotFoundError,
-        ):
-            return False
+        self.data[ext_data_name] = ext_data

     def copy(self, new_sorting_analyzer, unit_ids=None):
@@ -2016,12 +1992,11 @@ def run(self, save=True, **kwargs):
         start = perf_counter()
         self._run(**kwargs)
         end = perf_counter()
+        self.run_info["runtime_s"] = np.round(end - start, 1)

         if save and not self.sorting_analyzer.is_read_only():
             self._save_data(**kwargs)
-            self.run_info["data_loadable"] = self._check_data_loadable()  # maybe overkill?
-
-        self.run_info["runtime_s"] = np.round(end - start, 1)
         self.run_info["run_completed"] = True
         self._save_run_info()

     def save(self, **kwargs):
         self._save_params()
         self._save_importing_provenance()
         self._save_data(**kwargs)
-        self.run_info["data_loadable"] = self._check_data_loadable()
-        if self.run_info["data_loadable"]:
-            self.run_info["run_completed"] = (
-                True  # extensions that go through compute_several_extensions() and then run_node_pipeline() never have ext_instance.run() called, so need to check here (or at least somewhere) instead
-            )
+        self.run_info["run_completed"] = (
+            True  # extensions that go through compute_several_extensions() and then run_node_pipeline() never have ext_instance.run() called, so need to change run_completed here (or somewhere, at least)
+        )
         self._save_run_info()
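After PATCH 07 the loading contract is simpler: load() returns a usable extension only when run_info["run_completed"] is true, and returns None otherwise so the caller knows to recompute. A toy sketch of that caller-side contract with stand-in classes (none of these names are spikeinterface API):

    class ToyExtension:
        """Stand-in extension with the post-PATCH-07 run_info shape."""

        def __init__(self):
            self.run_info = dict(run_completed=False, runtime_s=None)
            self.data = {}

        @classmethod
        def load(cls, saved_run_info, saved_data):
            ext = cls()
            ext.run_info = saved_run_info
            if ext.run_info["run_completed"]:
                ext.data = saved_data
                return ext
            return None  # signal to the caller: recompute

    # Caller-side contract: None means the extension must be (re)computed.
    ext = ToyExtension.load(dict(run_completed=False, runtime_s=None), {})
    if ext is None:
        print("extension not computed yet -> recompute it")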
From 943e398b3200be284d26752e338907f46cde22c6 Mon Sep 17 00:00:00 2001
From: Jonah Pearl
Date: Thu, 29 Aug 2024 09:51:15 -0400
Subject: [PATCH 08/15] edge case where data file is deleted

---
 src/spikeinterface/core/sortinganalyzer.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 251ce3249a..33ab79c45e 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -1947,11 +1947,11 @@ def load_data(self):
                     # this load in memmory
                     ext_data = np.array(ext_data_)

-        if ext_data is None:
+        if ext_data is not None:
+            self.data[ext_data_name] = ext_data
+        else:
             warnings.warn(f"Found no data for {self.extension_name}, extension should be re-computed.")

-        self.data[ext_data_name] = ext_data
-
     def copy(self, new_sorting_analyzer, unit_ids=None):
@@ -2183,7 +2183,8 @@ def get_pipeline_nodes(self):
         return self._get_pipeline_nodes()

     def get_data(self, *args, **kwargs):
-        assert len(self.data) > 0, f"You must run the extension {self.extension_name} before retrieving data"
+        assert self.run_info["run_completed"], f"You must run the extension {self.extension_name} before retrieving data"
+        assert len(self.data) > 0, "Extension has been run but no data found — data file might be missing. Re-compute the extension."
         return self._get_data(*args, **kwargs)

From f9d7c0491cfe5c9b55c31cf43038c3e8d9a77d20 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 29 Aug 2024 13:53:14 +0000
Subject: [PATCH 09/15] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---
 src/spikeinterface/core/sortinganalyzer.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 33ab79c45e..1bded4e6a8 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -1948,7 +1948,7 @@ def load_data(self):
                     ext_data = np.array(ext_data_)

         if ext_data is not None:
-            self.data[ext_data_name] = ext_data
+            self.data[ext_data_name] = ext_data
         else:
             warnings.warn(f"Found no data for {self.extension_name}, extension should be re-computed.")

@@ -2183,8 +2183,12 @@ def get_pipeline_nodes(self):
         return self._get_pipeline_nodes()

     def get_data(self, *args, **kwargs):
-        assert self.run_info["run_completed"], f"You must run the extension {self.extension_name} before retrieving data"
+        assert self.run_info[
+            "run_completed"
+        ], f"You must run the extension {self.extension_name} before retrieving data"
-        assert len(self.data) > 0, "Extension has been run but no data found — data file might be missing. Re-compute the extension."
+        assert (
+            len(self.data) > 0
+        ), "Extension has been run but no data found — data file might be missing. Re-compute the extension."
         return self._get_data(*args, **kwargs)
From 8e9acfc66bd22b2232914274c12319f43e197559 Mon Sep 17 00:00:00 2001
From: Jonah Pearl
Date: Thu, 29 Aug 2024 10:18:23 -0400
Subject: [PATCH 10/15] always return None if extension data is missing

---
 src/spikeinterface/core/sortinganalyzer.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 1bded4e6a8..9b2144ddcf 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -1862,9 +1862,12 @@ def load(cls, sorting_analyzer):
             ext.load_data()
             if cls.need_backward_compatibility_on_load:
                 ext._handle_backward_compatibility_on_load()
-            return ext
-        else:
-            return None
+            if len(ext.data) > 0:
+                return ext
+
+        # If extension run not completed, or data has gone missing,
+        # return None to indicate that the extension should be (re)computed.
+        return None

     def load_run_info(self):
@@ -2183,12 +2186,8 @@ def get_pipeline_nodes(self):
         return self._get_pipeline_nodes()

     def get_data(self, *args, **kwargs):
-        assert self.run_info[
-            "run_completed"
-        ], f"You must run the extension {self.extension_name} before retrieving data"
+        assert self.run_info["run_completed"], f"You must run the extension {self.extension_name} before retrieving data"
-        assert (
-            len(self.data) > 0
-        ), "Extension has been run but no data found — data file might be missing. Re-compute the extension."
+        assert len(self.data) > 0, "Extension has been run but no data found."
         return self._get_data(*args, **kwargs)

From f1a3d97d8daf55607f9ac0da909b03e4297fce9d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 29 Aug 2024 14:21:14 +0000
Subject: [PATCH 11/15] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---
 src/spikeinterface/core/sortinganalyzer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 9b2144ddcf..017760c002 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -2186,7 +2186,9 @@ def get_pipeline_nodes(self):
         return self._get_pipeline_nodes()

     def get_data(self, *args, **kwargs):
-        assert self.run_info["run_completed"], f"You must run the extension {self.extension_name} before retrieving data"
+        assert self.run_info[
+            "run_completed"
+        ], f"You must run the extension {self.extension_name} before retrieving data"
         assert len(self.data) > 0, "Extension has been run but no data found."
         return self._get_data(*args, **kwargs)
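With PATCH 10/11 applied, the two assertions in get_data() separate the failure modes: run_completed false means the extension was never run, while a completed run with empty data means the saved files went missing after the fact (the PATCH 08 edge case). A standalone sketch of that guard logic with illustrative inputs:

    def get_data_guarded(run_info, data, extension_name="templates"):
        # Mirrors the two get_data() guards above.
        assert run_info["run_completed"], f"You must run the extension {extension_name} before retrieving data"
        assert len(data) > 0, "Extension has been run but no data found."
        return data

    # Failure mode 1: the extension never ran.
    try:
        get_data_guarded(dict(run_completed=False), {})
    except AssertionError as err:
        print("never ran:", err)

    # Failure mode 2: the extension ran, but its data files were deleted afterwards.
    try:
        get_data_guarded(dict(run_completed=True), {})
    except AssertionError as err:
        print("data missing:", err)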
From dd53372d58506f045ce5b1742498a9fc1c8a20a8 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 4 Sep 2024 19:16:22 +0200
Subject: [PATCH 12/15] Fix error in load data, ensure backward compatibility,
 and add tests

---
 src/spikeinterface/core/sortinganalyzer.py | 51 ++++++++++++-------
 .../core/tests/test_sortinganalyzer.py     | 33 ++++++++++++
 2 files changed, 65 insertions(+), 19 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index b0c6162b05..06879d09a2 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -19,8 +19,6 @@

 import spikeinterface

-from zarr.errors import ArrayNotFoundError
-
 from .baserecording import BaseRecording
 from .basesorting import BaseSorting

@@ -1339,6 +1337,7 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job

         job_name = "Compute : " + " + ".join(extensions_with_pipeline.keys())

+        t_start = perf_counter()
         results = run_node_pipeline(
             self.recording,
             all_nodes,
             job_kwargs=job_kwargs,
             job_name=job_name,
             gather_mode="memory",
             squeeze_output=False,
             verbose=verbose,
         )
+        t_end = perf_counter()
+        # for pipeline node extensions we can only track the runtime of the run_node_pipeline
+        runtime_s = np.round(t_end - t_start, 1)

         for r, result in enumerate(results):
             extension_name, variable_name = result_routage[r]
             extension_instances[extension_name].data[variable_name] = result
+            extension_instances[extension_name].run_info["runtime_s"] = runtime_s

         for extension_name, extension_instance in extension_instances.items():
             self.extensions[extension_name] = extension_instance
@@ -1859,13 +1862,20 @@ def load(cls, sorting_analyzer):
         ext = cls(sorting_analyzer)
         ext.load_params()
         ext.load_run_info()
-        if ext.run_info["run_completed"]:
+        if ext.run_info is not None:
+            if ext.run_info["run_completed"]:
+                ext.load_data()
+                if cls.need_backward_compatibility_on_load:
+                    ext._handle_backward_compatibility_on_load()
+                if len(ext.data) > 0:
+                    return ext
+        else:
+            # this is for back-compatibility of old analyzers
             ext.load_data()
             if cls.need_backward_compatibility_on_load:
                 ext._handle_backward_compatibility_on_load()
             if len(ext.data) > 0:
                 return ext
-
         # If extension run not completed, or data has gone missing,
         # return None to indicate that the extension should be (re)computed.
         return None

@@ -1874,15 +1884,18 @@ def load_run_info(self):
         if self.format == "binary_folder":
             extension_folder = self._get_binary_extension_folder()
             run_info_file = extension_folder / "run_info.json"
-            assert run_info_file.is_file(), f"No run_info file in extension {self.extension_name} folder"
-            with open(str(run_info_file), "r") as f:
-                run_info = json.load(f)
+            if run_info_file.is_file():
+                with open(str(run_info_file), "r") as f:
+                    run_info = json.load(f)
+            else:
+                warnings.warn(f"Found no run_info file for {self.extension_name}, extension should be re-computed.")
+                run_info = None

         elif self.format == "zarr":
             extension_group = self._get_zarr_extension_group(mode="r")
-            assert "run_info" in extension_group.attrs, f"No run_info file in extension {self.extension_name} folder"
-            run_info = extension_group.attrs["run_info"]
+            run_info = extension_group.attrs.get("run_info", None)
+            if run_info is None:
+                warnings.warn(f"Found no run_info file for {self.extension_name}, extension should be re-computed.")
         self.run_info = run_info

     def load_params(self):
@@ -1902,7 +1915,6 @@ def load_params(self):

     def load_data(self):
         ext_data = None
-
         if self.format == "binary_folder":
             extension_folder = self._get_binary_extension_folder()
             for ext_data_file in extension_folder.iterdir():
@@ -1922,6 +1934,7 @@ def load_data(self):
                 # and have a link to the old buffer on windows then it fails
                 # ext_data = np.load(ext_data_file, mmap_mode="r")
                 # so we go back to full loading
+                print(f"{ext_data_file} is numpy!")
                 ext_data = np.load(ext_data_file)
             elif ext_data_file.suffix == ".csv":
                 import pandas as pd
@@ -1931,6 +1944,7 @@ def load_data(self):
                     ext_data = pickle.load(ext_data_file.open("rb"))
                 else:
                     continue
+                self.data[ext_data_name] = ext_data

         elif self.format == "zarr":
             extension_group = self._get_zarr_extension_group(mode="r")
@@ -1950,21 +1964,20 @@ def load_data(self):
                 else:
                     # this load in memmory
                     ext_data = np.array(ext_data_)
+                self.data[ext_data_name] = ext_data

-        if ext_data is not None:
-            self.data[ext_data_name] = ext_data
-        else:
+        if len(self.data) == 0:
             warnings.warn(f"Found no data for {self.extension_name}, extension should be re-computed.")

     def copy(self, new_sorting_analyzer, unit_ids=None):
         # alessio : please note that this also replace the old select_units!!!
         new_extension = self.__class__(new_sorting_analyzer)
         new_extension.params = self.params.copy()
-        new_extension.run_info = self.run_info.copy()  # TODO: does copy() assume both extensions have been run?
         if unit_ids is None:
             new_extension.data = self.data
         else:
             new_extension.data = self._select_extension_data(unit_ids)
+        new_extension.run_info = self.run_info.copy()
         new_extension.save()
         return new_extension

@@ -1979,10 +1992,10 @@ def merge(
     ):
         new_extension = self.__class__(new_sorting_analyzer)
         new_extension.params = self.params.copy()
-        new_extension.run_info = self.run_info.copy()  # TODO: does merge() assume both extensions have been run?
         new_extension.data = self._merge_extension_data(
             merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask, verbose=verbose, **job_kwargs
         )
+        new_extension.run_info = self.run_info.copy()
         new_extension.save()
         return new_extension

@@ -1993,10 +2006,10 @@ def run(self, save=True, **kwargs):
             self._save_importing_provenance()
             self._save_run_info()

-        start = perf_counter()
+        t_start = perf_counter()
         self._run(**kwargs)
-        end = perf_counter()
-        self.run_info["runtime_s"] = np.round(end - start, 1)
+        t_end = perf_counter()
+        self.run_info["runtime_s"] = np.round(t_end - t_start, 1)

         if save and not self.sorting_analyzer.is_read_only():
             self._save_data(**kwargs)
diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py
index 689073d6bf..3f45487f4c 100644
--- a/src/spikeinterface/core/tests/test_sortinganalyzer.py
+++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py
@@ -125,6 +125,39 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset):
     )


+def test_load_without_runtime_info(tmp_path, dataset):
+    recording, sorting = dataset
+
+    folder = tmp_path / "test_SortingAnalyzer_run_info"
+
+    extensions = ["random_spikes", "templates"]
+    # binary_folder
+    sorting_analyzer = create_sorting_analyzer(
+        sorting, recording, format="binary_folder", folder=folder, sparse=False, sparsity=None
+    )
+    sorting_analyzer.compute(extensions)
+    # remove run_info.json to mimic a previous version of spikeinterface
+    for ext in extensions:
+        (folder / "extensions" / ext / "run_info.json").unlink()
+    # should raise a warning for missing run_info
+    with pytest.warns(UserWarning):
+        sorting_analyzer = load_sorting_analyzer(folder, format="auto")
+
+    # zarr
+    folder = tmp_path / "test_SortingAnalyzer_run_info.zarr"
+    sorting_analyzer = create_sorting_analyzer(
+        sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None
+    )
+    sorting_analyzer.compute(extensions)
+    # remove run_info from attrs to mimic a previous version of spikeinterface
+    root = sorting_analyzer._get_zarr_root(mode="r+")
+    for ext in extensions:
+        del root["extensions"][ext].attrs["run_info"]
+    # should raise a warning for missing run_info
+    with pytest.warns(UserWarning):
+        sorting_analyzer = load_sorting_analyzer(folder, format="auto")
+
+
 def test_SortingAnalyzer_tmp_recording(dataset):
     recording, sorting = dataset
     recording_cached = recording.save(mode="memory")

From 8240ee9281186ebb47a0cbdf7317d05b6d684a65 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 4 Sep 2024 19:17:57 +0200
Subject: [PATCH 13/15] Update src/spikeinterface/core/sortinganalyzer.py

---
 src/spikeinterface/core/sortinganalyzer.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 06879d09a2..95c1a290ba 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -2021,9 +2021,7 @@ def save(self, **kwargs):
         self._save_params()
         self._save_importing_provenance()
         self._save_data(**kwargs)
-        self.run_info["run_completed"] = (
-            True  # extensions that go through compute_several_extensions() and then run_node_pipeline() never have ext_instance.run() called, so need to change run_completed here (or somewhere, at least)
-        )
+        self.run_info["run_completed"] = True
         self._save_run_info()
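PATCH 12 turns a missing run_info record into a warning instead of a hard failure, so analyzers written by spikeinterface versions predating this series still load. In user terms the expected flow is roughly the following sketch, assuming the public load_sorting_analyzer API exercised by the new test (the folder name is illustrative, and the warning text is the one emitted by load_run_info() above):

    import warnings

    from spikeinterface.core import load_sorting_analyzer

    # Suppose "old_analyzer" was written before run_info.json existed.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        analyzer = load_sorting_analyzer("old_analyzer", format="auto")

    # Each extension without a run_info record emits a UserWarning:
    # "Found no run_info file for <name>, extension should be re-computed."
    for w in caught:
        print(w.category.__name__, w.message)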
From 5769eff8f6697eac93780ef4ded75c969b22f0ae Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 10 Sep 2024 12:18:57 +0200
Subject: [PATCH 14/15] Update src/spikeinterface/core/sortinganalyzer.py

---
 src/spikeinterface/core/sortinganalyzer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 95c1a290ba..970cd150fc 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -1934,7 +1934,6 @@ def load_data(self):
                 # and have a link to the old buffer on windows then it fails
                 # ext_data = np.load(ext_data_file, mmap_mode="r")
                 # so we go back to full loading
-                print(f"{ext_data_file} is numpy!")
                 ext_data = np.load(ext_data_file)
             elif ext_data_file.suffix == ".csv":
                 import pandas as pd

From cc21f06492f3d84c476814e0b09f19b2ca84ceeb Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 10 Sep 2024 12:22:41 +0200
Subject: [PATCH 15/15] Apply suggestions from code review

---
 src/spikeinterface/core/sortinganalyzer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 970cd150fc..57c9e5f37c 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -1349,12 +1349,13 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job
         )
         t_end = perf_counter()
         # for pipeline node extensions we can only track the runtime of the run_node_pipeline
-        runtime_s = np.round(t_end - t_start, 1)
+        runtime_s = t_end - t_start

         for r, result in enumerate(results):
             extension_name, variable_name = result_routage[r]
             extension_instances[extension_name].data[variable_name] = result
             extension_instances[extension_name].run_info["runtime_s"] = runtime_s
+            extension_instances[extension_name].run_info["run_completed"] = True

         for extension_name, extension_instance in extension_instances.items():
             self.extensions[extension_name] = extension_instance
@@ -2008,7 +2009,7 @@ def run(self, save=True, **kwargs):
         t_start = perf_counter()
         self._run(**kwargs)
         t_end = perf_counter()
-        self.run_info["runtime_s"] = np.round(t_end - t_start, 1)
+        self.run_info["runtime_s"] = t_end - t_start

         if save and not self.sorting_analyzer.is_read_only():
             self._save_data(**kwargs)
@@ -2020,7 +2021,6 @@ def save(self, **kwargs):
         self._save_params()
         self._save_importing_provenance()
         self._save_data(**kwargs)
-        self.run_info["run_completed"] = True
         self._save_run_info()
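With the full series applied, every computed extension carries a populated run_info record: run_completed is set both by run() and by the compute_several_extensions() pipeline path, and runtime_s is stored unrounded. A short usage sketch of what can now be inspected, assuming the standard create_sorting_analyzer / get_extension API and the generate_ground_truth_recording helper used elsewhere in the test suite (parameter values are illustrative):

    from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording

    recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5, seed=0)
    analyzer = create_sorting_analyzer(sorting, recording, format="memory")
    analyzer.compute(["random_spikes", "templates"])

    ext = analyzer.get_extension("templates")
    print(ext.run_info["run_completed"])  # True
    print(ext.run_info["runtime_s"])      # elapsed seconds, unrounded after PATCH 15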