diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 817c453a97..49a31738e3 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -11,6 +11,7 @@ import shutil import warnings import importlib +from time import perf_counter import numpy as np @@ -1336,6 +1337,7 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job job_name = "Compute : " + " + ".join(extensions_with_pipeline.keys()) + t_start = perf_counter() results = run_node_pipeline( self.recording, all_nodes, @@ -1345,10 +1347,15 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job squeeze_output=False, verbose=verbose, ) + t_end = perf_counter() + # for pipeline node extensions we can only track the runtime of the run_node_pipeline + runtime_s = t_end - t_start for r, result in enumerate(results): extension_name, variable_name = result_routage[r] extension_instances[extension_name].data[variable_name] = result + extension_instances[extension_name].run_info["runtime_s"] = runtime_s + extension_instances[extension_name].run_info["run_completed"] = True for extension_name, extension_instance in extension_instances.items(): self.extensions[extension_name] = extension_instance @@ -1738,8 +1745,12 @@ def __init__(self, sorting_analyzer): self._sorting_analyzer = weakref.ref(sorting_analyzer) self.params = None + self.run_info = self._default_run_info_dict() self.data = dict() + def _default_run_info_dict(self): + return dict(run_completed=False, runtime_s=None) + ####### # This 3 methods must be implemented in the subclass!!! # See DummyAnalyzerExtension in test_sortinganalyzer.py as a simple example @@ -1851,11 +1862,42 @@ def _get_zarr_extension_group(self, mode="r+"): def load(cls, sorting_analyzer): ext = cls(sorting_analyzer) ext.load_params() - ext.load_data() - if cls.need_backward_compatibility_on_load: - ext._handle_backward_compatibility_on_load() + ext.load_run_info() + if ext.run_info is not None: + if ext.run_info["run_completed"]: + ext.load_data() + if cls.need_backward_compatibility_on_load: + ext._handle_backward_compatibility_on_load() + if len(ext.data) > 0: + return ext + else: + # this is for back-compatibility of old analyzers + ext.load_data() + if cls.need_backward_compatibility_on_load: + ext._handle_backward_compatibility_on_load() + if len(ext.data) > 0: + return ext + # If extension run not completed, or data has gone missing, + # return None to indicate that the extension should be (re)computed. + return None + + def load_run_info(self): + if self.format == "binary_folder": + extension_folder = self._get_binary_extension_folder() + run_info_file = extension_folder / "run_info.json" + if run_info_file.is_file(): + with open(str(run_info_file), "r") as f: + run_info = json.load(f) + else: + warnings.warn(f"Found no run_info file for {self.extension_name}, extension should be re-computed.") + run_info = None - return ext + elif self.format == "zarr": + extension_group = self._get_zarr_extension_group(mode="r") + run_info = extension_group.attrs.get("run_info", None) + if run_info is None: + warnings.warn(f"Found no run_info file for {self.extension_name}, extension should be re-computed.") + self.run_info = run_info def load_params(self): if self.format == "binary_folder": @@ -1873,12 +1915,17 @@ def load_params(self): self.params = params def load_data(self): + ext_data = None if self.format == "binary_folder": extension_folder = self._get_binary_extension_folder() for ext_data_file in extension_folder.iterdir(): # patch for https://github.com/SpikeInterface/spikeinterface/issues/3041 # maybe add a check for version number from the info.json during loading only - if ext_data_file.name == "params.json" or ext_data_file.name == "info.json": + if ( + ext_data_file.name == "params.json" + or ext_data_file.name == "info.json" + or ext_data_file.name == "run_info.json" + ): continue ext_data_name = ext_data_file.stem if ext_data_file.suffix == ".json": @@ -1919,6 +1966,9 @@ def load_data(self): ext_data = np.array(ext_data_) self.data[ext_data_name] = ext_data + if len(self.data) == 0: + warnings.warn(f"Found no data for {self.extension_name}, extension should be re-computed.") + def copy(self, new_sorting_analyzer, unit_ids=None): # alessio : please note that this also replace the old select_units!!! new_extension = self.__class__(new_sorting_analyzer) @@ -1927,6 +1977,7 @@ def copy(self, new_sorting_analyzer, unit_ids=None): new_extension.data = self.data else: new_extension.data = self._select_extension_data(unit_ids) + new_extension.run_info = self.run_info.copy() new_extension.save() return new_extension @@ -1944,24 +1995,33 @@ def merge( new_extension.data = self._merge_extension_data( merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask, verbose=verbose, **job_kwargs ) + new_extension.run_info = self.run_info.copy() new_extension.save() return new_extension def run(self, save=True, **kwargs): if save and not self.sorting_analyzer.is_read_only(): - # this also reset the folder or zarr group + # NB: this call to _save_params() also resets the folder or zarr group self._save_params() self._save_importing_provenance() + self._save_run_info() + t_start = perf_counter() self._run(**kwargs) + t_end = perf_counter() + self.run_info["runtime_s"] = t_end - t_start if save and not self.sorting_analyzer.is_read_only(): self._save_data(**kwargs) + self.run_info["run_completed"] = True + self._save_run_info() + def save(self, **kwargs): self._save_params() self._save_importing_provenance() self._save_data(**kwargs) + self._save_run_info() def _save_data(self, **kwargs): if self.format == "memory": @@ -2060,6 +2120,7 @@ def reset(self): """ self._reset_extension_folder() self.params = None + self.run_info = self._default_run_info_dict() self.data = dict() def set_params(self, save=True, **params): @@ -2080,6 +2141,7 @@ def set_params(self, save=True, **params): if save: self._save_params() self._save_importing_provenance() + self._save_run_info() def _save_params(self): params_to_save = self.params.copy() @@ -2117,6 +2179,17 @@ def _save_importing_provenance(self): extension_group = self._get_zarr_extension_group(mode="r+") extension_group.attrs["info"] = info + def _save_run_info(self): + run_info = self.run_info.copy() + + if self.format == "binary_folder": + extension_folder = self._get_binary_extension_folder() + run_info_file = extension_folder / "run_info.json" + run_info_file.write_text(json.dumps(run_info, indent=4), encoding="utf8") + elif self.format == "zarr": + extension_group = self._get_zarr_extension_group(mode="r+") + extension_group.attrs["run_info"] = run_info + def get_pipeline_nodes(self): assert ( self.use_nodepipeline @@ -2124,7 +2197,10 @@ def get_pipeline_nodes(self): return self._get_pipeline_nodes() def get_data(self, *args, **kwargs): - assert len(self.data) > 0, f"You must run the extension {self.extension_name} before retrieving data" + assert self.run_info[ + "run_completed" + ], f"You must run the extension {self.extension_name} before retrieving data" + assert len(self.data) > 0, "Extension has been run but no data found." return self._get_data(*args, **kwargs) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index 689073d6bf..3f45487f4c 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -125,6 +125,39 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): ) +def test_load_without_runtime_info(tmp_path, dataset): + recording, sorting = dataset + + folder = tmp_path / "test_SortingAnalyzer_run_info" + + extensions = ["random_spikes", "templates"] + # binary_folder + sorting_analyzer = create_sorting_analyzer( + sorting, recording, format="binary_folder", folder=folder, sparse=False, sparsity=None + ) + sorting_analyzer.compute(extensions) + # remove run_info.json to mimic a previous version of spikeinterface + for ext in extensions: + (folder / "extensions" / ext / "run_info.json").unlink() + # should raise a warning for missing run_info + with pytest.warns(UserWarning): + sorting_analyzer = load_sorting_analyzer(folder, format="auto") + + # zarr + folder = tmp_path / "test_SortingAnalyzer_run_info.zarr" + sorting_analyzer = create_sorting_analyzer( + sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None + ) + sorting_analyzer.compute(extensions) + # remove run_info from attrs to mimic a previous version of spikeinterface + root = sorting_analyzer._get_zarr_root(mode="r+") + for ext in extensions: + del root["extensions"][ext].attrs["run_info"] + # should raise a warning for missing run_info + with pytest.warns(UserWarning): + sorting_analyzer = load_sorting_analyzer(folder, format="auto") + + def test_SortingAnalyzer_tmp_recording(dataset): recording, sorting = dataset recording_cached = recording.save(mode="memory")