Skip to content

Commit

Permalink
Merge pull request #3347 from jonahpearl/analyzer_extension_exit_status
Browse files Browse the repository at this point in the history
Analyzer extension exit status
  • Loading branch information
samuelgarcia authored Sep 11, 2024
2 parents 096d91a + cc21f06 commit 469187a
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 7 deletions.
90 changes: 83 additions & 7 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import shutil
import warnings
import importlib
from time import perf_counter

import numpy as np

Expand Down Expand Up @@ -1336,6 +1337,7 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job

job_name = "Compute : " + " + ".join(extensions_with_pipeline.keys())

t_start = perf_counter()
results = run_node_pipeline(
self.recording,
all_nodes,
Expand All @@ -1345,10 +1347,15 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job
squeeze_output=False,
verbose=verbose,
)
t_end = perf_counter()
# for pipeline node extensions we can only track the runtime of the run_node_pipeline
runtime_s = t_end - t_start

for r, result in enumerate(results):
extension_name, variable_name = result_routage[r]
extension_instances[extension_name].data[variable_name] = result
extension_instances[extension_name].run_info["runtime_s"] = runtime_s
extension_instances[extension_name].run_info["run_completed"] = True

for extension_name, extension_instance in extension_instances.items():
self.extensions[extension_name] = extension_instance
Expand Down Expand Up @@ -1738,8 +1745,12 @@ def __init__(self, sorting_analyzer):
self._sorting_analyzer = weakref.ref(sorting_analyzer)

self.params = None
self.run_info = self._default_run_info_dict()
self.data = dict()

def _default_run_info_dict(self):
return dict(run_completed=False, runtime_s=None)

#######
# This 3 methods must be implemented in the subclass!!!
# See DummyAnalyzerExtension in test_sortinganalyzer.py as a simple example
Expand Down Expand Up @@ -1851,11 +1862,42 @@ def _get_zarr_extension_group(self, mode="r+"):
def load(cls, sorting_analyzer):
    """Instantiate the extension and load its params, run_info and data from storage.

    Returns the loaded extension instance, or None when the stored run did not
    complete or the data has gone missing, signalling that the extension
    should be (re)computed.
    """
    ext = cls(sorting_analyzer)
    ext.load_params()
    ext.load_run_info()
    if ext.run_info is not None:
        # only trust the stored data if the recorded run actually completed
        if ext.run_info["run_completed"]:
            ext.load_data()
            if cls.need_backward_compatibility_on_load:
                ext._handle_backward_compatibility_on_load()
            if len(ext.data) > 0:
                return ext
    else:
        # this is for back-compatibility of old analyzers that persisted no run_info
        ext.load_data()
        if cls.need_backward_compatibility_on_load:
            ext._handle_backward_compatibility_on_load()
        if len(ext.data) > 0:
            return ext
    # If extension run not completed, or data has gone missing,
    # return None to indicate that the extension should be (re)computed.
    return None

def load_run_info(self):
    """Load the persisted run_info (completion flag and runtime) into ``self.run_info``.

    Older analyzers did not persist run_info; in that case a warning is emitted
    and ``self.run_info`` is set to None so that callers can trigger a re-compute.
    """
    # default covers both the "no persisted run_info" case and any format
    # (e.g. "memory") that has no storage branch below
    run_info = None
    if self.format == "binary_folder":
        extension_folder = self._get_binary_extension_folder()
        run_info_file = extension_folder / "run_info.json"
        if run_info_file.is_file():
            with open(str(run_info_file), "r") as f:
                run_info = json.load(f)
        else:
            warnings.warn(f"Found no run_info file for {self.extension_name}, extension should be re-computed.")
    elif self.format == "zarr":
        extension_group = self._get_zarr_extension_group(mode="r")
        run_info = extension_group.attrs.get("run_info", None)
        if run_info is None:
            warnings.warn(f"Found no run_info file for {self.extension_name}, extension should be re-computed.")
    self.run_info = run_info

def load_params(self):
if self.format == "binary_folder":
Expand All @@ -1873,12 +1915,17 @@ def load_params(self):
self.params = params

def load_data(self):
ext_data = None
if self.format == "binary_folder":
extension_folder = self._get_binary_extension_folder()
for ext_data_file in extension_folder.iterdir():
# patch for https://github.com/SpikeInterface/spikeinterface/issues/3041
# maybe add a check for version number from the info.json during loading only
if ext_data_file.name == "params.json" or ext_data_file.name == "info.json":
if (
ext_data_file.name == "params.json"
or ext_data_file.name == "info.json"
or ext_data_file.name == "run_info.json"
):
continue
ext_data_name = ext_data_file.stem
if ext_data_file.suffix == ".json":
Expand Down Expand Up @@ -1919,6 +1966,9 @@ def load_data(self):
ext_data = np.array(ext_data_)
self.data[ext_data_name] = ext_data

if len(self.data) == 0:
warnings.warn(f"Found no data for {self.extension_name}, extension should be re-computed.")

def copy(self, new_sorting_analyzer, unit_ids=None):
# alessio : please note that this also replace the old select_units!!!
new_extension = self.__class__(new_sorting_analyzer)
Expand All @@ -1927,6 +1977,7 @@ def copy(self, new_sorting_analyzer, unit_ids=None):
new_extension.data = self.data
else:
new_extension.data = self._select_extension_data(unit_ids)
new_extension.run_info = self.run_info.copy()
new_extension.save()
return new_extension

Expand All @@ -1944,24 +1995,33 @@ def merge(
new_extension.data = self._merge_extension_data(
merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask, verbose=verbose, **job_kwargs
)
new_extension.run_info = self.run_info.copy()
new_extension.save()
return new_extension

def run(self, save=True, **kwargs):
    """Execute the extension computation, timing it and recording completion.

    Parameters
    ----------
    save : bool, default: True
        If True and the analyzer is not read-only, persist params, provenance,
        run_info and the computed data.
    **kwargs
        Forwarded to ``_run`` and ``_save_data``.
    """
    if save and not self.sorting_analyzer.is_read_only():
        # NB: this call to _save_params() also resets the folder or zarr group
        self._save_params()
        self._save_importing_provenance()
        self._save_run_info()

    # time only the core computation so run_info reflects the actual runtime
    t_start = perf_counter()
    self._run(**kwargs)
    t_end = perf_counter()
    self.run_info["runtime_s"] = t_end - t_start

    if save and not self.sorting_analyzer.is_read_only():
        self._save_data(**kwargs)

    # mark completion only after _run (and the optional save) finished
    self.run_info["run_completed"] = True
    self._save_run_info()

def save(self, **kwargs):
    """Persist this extension (params, import provenance, data and run info) to storage."""
    self._save_params()
    self._save_importing_provenance()
    self._save_data(**kwargs)
    self._save_run_info()

def _save_data(self, **kwargs):
if self.format == "memory":
Expand Down Expand Up @@ -2060,6 +2120,7 @@ def reset(self):
"""
self._reset_extension_folder()
self.params = None
self.run_info = self._default_run_info_dict()
self.data = dict()

def set_params(self, save=True, **params):
Expand All @@ -2080,6 +2141,7 @@ def set_params(self, save=True, **params):
if save:
self._save_params()
self._save_importing_provenance()
self._save_run_info()

def _save_params(self):
params_to_save = self.params.copy()
Expand Down Expand Up @@ -2117,14 +2179,28 @@ def _save_importing_provenance(self):
extension_group = self._get_zarr_extension_group(mode="r+")
extension_group.attrs["info"] = info

def _save_run_info(self):
run_info = self.run_info.copy()

if self.format == "binary_folder":
extension_folder = self._get_binary_extension_folder()
run_info_file = extension_folder / "run_info.json"
run_info_file.write_text(json.dumps(run_info, indent=4), encoding="utf8")
elif self.format == "zarr":
extension_group = self._get_zarr_extension_group(mode="r+")
extension_group.attrs["run_info"] = run_info

def get_pipeline_nodes(self):
    """Return the node-pipeline nodes for this extension (requires use_nodepipeline=True)."""
    assert self.use_nodepipeline, "AnalyzerExtension.get_pipeline_nodes() must be called only when use_nodepipeline=True"
    return self._get_pipeline_nodes()

def get_data(self, *args, **kwargs):
    """Return the computed extension data via ``_get_data``.

    Raises
    ------
    AssertionError
        If the extension run has not completed, or if the run completed but
        no data was loaded/computed.
    """
    assert self.run_info[
        "run_completed"
    ], f"You must run the extension {self.extension_name} before retrieving data"
    assert len(self.data) > 0, "Extension has been run but no data found."
    return self._get_data(*args, **kwargs)


Expand Down
33 changes: 33 additions & 0 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,39 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset):
)


def test_load_without_runtime_info(tmp_path, dataset):
    """Loading an analyzer whose extensions lack run_info must warn (back-compat)."""
    recording, sorting = dataset

    folder = tmp_path / "test_SortingAnalyzer_run_info"

    computed_extensions = ["random_spikes", "templates"]

    # --- binary_folder format ---
    analyzer = create_sorting_analyzer(
        sorting, recording, format="binary_folder", folder=folder, sparse=False, sparsity=None
    )
    analyzer.compute(computed_extensions)
    # drop run_info.json so the folder looks like one written by an older spikeinterface
    for extension_name in computed_extensions:
        (folder / "extensions" / extension_name / "run_info.json").unlink()
    # loading must warn about the missing run_info
    with pytest.warns(UserWarning):
        analyzer = load_sorting_analyzer(folder, format="auto")

    # --- zarr format ---
    folder = tmp_path / "test_SortingAnalyzer_run_info.zarr"
    analyzer = create_sorting_analyzer(
        sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None
    )
    analyzer.compute(computed_extensions)
    # strip run_info from the zarr attrs to mimic an older spikeinterface
    zarr_root = analyzer._get_zarr_root(mode="r+")
    for extension_name in computed_extensions:
        del zarr_root["extensions"][extension_name].attrs["run_info"]
    # loading must warn about the missing run_info
    with pytest.warns(UserWarning):
        analyzer = load_sorting_analyzer(folder, format="auto")


def test_SortingAnalyzer_tmp_recording(dataset):
recording, sorting = dataset
recording_cached = recording.save(mode="memory")
Expand Down

0 comments on commit 469187a

Please sign in to comment.