Commit

Merge branch 'main' of github.com:SpikeInterface/spikeinterface into prepare_release
alejoe91 committed Sep 16, 2024
2 parents f2d1f07 + a36615a commit 719f53f
Showing 12 changed files with 489 additions and 169 deletions.
3 changes: 0 additions & 3 deletions pyproject.toml
@@ -94,7 +94,6 @@ preprocessing = [
 full = [
     "h5py",
     "pandas",
-    "xarray",
     "scipy",
     "scikit-learn",
     "networkx",
@@ -148,7 +147,6 @@ test = [
"pytest-dependency",
"pytest-cov",

"xarray",
"huggingface_hub",

# preprocessing
@@ -193,7 +191,6 @@ docs = [
"pandas", # in the modules gallery comparison tutorial
"hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous
"numba", # For many postprocessing functions
"xarray", # For use of SortingAnalyzer zarr format
"networkx",
# Download data
"pooch>=1.8.2",
59 changes: 42 additions & 17 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -11,6 +11,7 @@
 import shutil
 import warnings
 import importlib
+from packaging.version import parse
 from time import perf_counter

 import numpy as np
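
The new packaging.version import supports the version gate added to load_from_zarr below: parsed versions compare release components numerically, while raw strings compare lexicographically and mis-order multi-digit components. A standalone sketch (not part of the diff):

from packaging.version import parse

# Lexicographic string comparison mis-orders multi-digit components ...
assert "0.101.10" < "0.101.9"
# ... while parsed versions compare numerically, as the new check relies on.
assert parse("0.101.10") > parse("0.101.9")
assert parse("0.101.0") < parse("0.101.1")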
@@ -579,6 +580,20 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None):

         zarr_root = zarr.open_consolidated(str(folder), mode="r", storage_options=storage_options)

+        si_info = zarr_root.attrs["spikeinterface_info"]
+        if parse(si_info["version"]) < parse("0.101.1"):
+            # v0.101.0 did not have a consolidate metadata step after computing extensions.
+            # Here we try to consolidate the metadata and throw a warning if it fails.
+            try:
+                zarr_root_a = zarr.open(str(folder), mode="a", storage_options=storage_options)
+                zarr.consolidate_metadata(zarr_root_a.store)
+            except Exception as e:
+                warnings.warn(
+                    "The zarr store was not properly consolidated prior to v0.101.1. "
+                    "This may lead to unexpected behavior in loading extensions. "
+                    "Please consider re-generating the SortingAnalyzer object."
+                )
+
         # load internal sorting in memory
         sorting = NumpySorting.from_sorting(
             ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options),
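
For context on the compatibility block above: zarr.consolidate_metadata copies the metadata of every group and array into a single .zmetadata key, which zarr.open_consolidated then reads in one shot instead of listing the store. Stores written by v0.101.0 skipped this step after computing extensions, so a consolidated reader could miss extension groups. A standalone sketch of the failure mode and the retrofit, assuming zarr v2 (the API in use here); the path and names are hypothetical:

import numpy as np
import zarr

root = zarr.open("consolidate_demo.zarr", mode="w")
root.create_dataset("spikes", data=np.arange(10))
zarr.consolidate_metadata(root.store)

# A later write in append mode is not reflected in .zmetadata ...
root_a = zarr.open("consolidate_demo.zarr", mode="a")
root_a.create_group("extensions")

# ... so it stays invisible to consolidated readers until we re-consolidate,
# which is what the block above retrofits for v0.101.0 stores.
zarr.consolidate_metadata(root_a.store)
root_c = zarr.open_consolidated("consolidate_demo.zarr", mode="r")
assert "extensions" in root_c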
@@ -1970,12 +1985,14 @@ def load_data(self):
if "dict" in ext_data_.attrs:
ext_data = ext_data_[0]
elif "dataframe" in ext_data_.attrs:
import xarray
import pandas as pd

ext_data = xarray.open_zarr(
ext_data_.store, group=f"{extension_group.name}/{ext_data_name}"
).to_pandas()
ext_data.index.rename("", inplace=True)
index = ext_data_["index"]
ext_data = pd.DataFrame(index=index)
for col in ext_data_.keys():
if col != "index":
ext_data.loc[:, col] = ext_data_[col][:]
ext_data = ext_data.convert_dtypes()
elif "object" in ext_data_.attrs:
ext_data = ext_data_[0]
else:
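
In the new load path above, each column comes back as a plain numpy array, so the final convert_dtypes() call is what restores pandas' richer dtypes (e.g. nullable integers), which raw arrays cannot carry. A minimal standalone illustration with hypothetical column names:

import numpy as np
import pandas as pd

df = pd.DataFrame(index=np.array([0, 1, 2]))
df.loc[:, "snr"] = np.array([5.1, 7.3, 2.8])
df.loc[:, "num_spikes"] = np.array([120, 45, 300])

print(df.dtypes)                   # float64, int64  (plain numpy dtypes)
print(df.convert_dtypes().dtypes)  # Float64, Int64  (pandas nullable dtypes)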
@@ -2031,12 +2048,21 @@ def run(self, save=True, **kwargs):
         if save and not self.sorting_analyzer.is_read_only():
             self._save_run_info()
             self._save_data(**kwargs)
+            if self.format == "zarr":
+                import zarr
+
+                zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)

     def save(self, **kwargs):
         self._save_params()
         self._save_importing_provenance()
-        self._save_data(**kwargs)
         self._save_run_info()
+        self._save_data(**kwargs)
+
+        if self.format == "zarr":
+            import zarr
+
+            zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)

     def _save_data(self, **kwargs):
         if self.format == "memory":
@@ -2096,12 +2122,12 @@ def _save_data(self, **kwargs):
                 elif isinstance(ext_data, np.ndarray):
                     extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor)
                 elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame):
-                    ext_data.to_xarray().to_zarr(
-                        store=extension_group.store,
-                        group=f"{extension_group.name}/{ext_data_name}",
-                        mode="a",
-                    )
-                    extension_group[ext_data_name].attrs["dataframe"] = True
+                    df_group = extension_group.create_group(ext_data_name)
+                    # first we save the index
+                    df_group.create_dataset(name="index", data=ext_data.index.to_numpy())
+                    for col in ext_data.columns:
+                        df_group.create_dataset(name=col, data=ext_data[col].to_numpy())
+                    df_group.attrs["dataframe"] = True
                 else:
                     # any object
                     try:
@@ -2111,8 +2137,6 @@
                 except:
                     raise Exception(f"Could not save {ext_data_name} as extension data")
                 extension_group[ext_data_name].attrs["object"] = True
-            # we need to re-consolidate
-            zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)

     def _reset_extension_folder(self):
         """
@@ -2240,9 +2264,10 @@ def get_pipeline_nodes(self):
         return self._get_pipeline_nodes()

     def get_data(self, *args, **kwargs):
-        assert self.run_info[
-            "run_completed"
-        ], f"You must run the extension {self.extension_name} before retrieving data"
+        if self.run_info is not None:
+            assert self.run_info[
+                "run_completed"
+            ], f"You must run the extension {self.extension_name} before retrieving data"
         assert len(self.data) > 0, "Extension has been run but no data found."
         return self._get_data(*args, **kwargs)
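
The new guard makes run_info optional: presumably, extensions saved before run_info was tracked load it as None, and get_data() now falls through to the data check instead of raising on the missing record. A hypothetical minimal stand-in (not the real class) showing the effect:

class _Ext:
    def __init__(self, run_info, data):
        self.run_info = run_info
        self.data = data

    def get_data(self):
        if self.run_info is not None:
            assert self.run_info["run_completed"], "You must run the extension before retrieving data"
        assert len(self.data) > 0, "Extension has been run but no data found."
        return self.data

# A legacy extension loaded with run_info=None but valid data no longer raises:
legacy = _Ext(run_info=None, data={"metrics": [1, 2, 3]})
print(legacy.get_data())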
