Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor pandas save load and convert dtypes #3412

Merged
merged 7 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ preprocessing = [
full = [
"h5py",
"pandas",
"xarray",
"scipy",
"scikit-learn",
"networkx",
Expand Down Expand Up @@ -148,7 +147,6 @@ test = [
"pytest-dependency",
"pytest-cov",

"xarray",
"huggingface_hub",

# preprocessing
Expand Down Expand Up @@ -193,7 +191,6 @@ docs = [
"pandas", # in the modules gallery comparison tutorial
"hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous
"numba", # For many postprocessing functions
"xarray", # For use of SortingAnalyzer zarr format
"networkx",
# Download data
"pooch>=1.8.2",
Expand Down
59 changes: 42 additions & 17 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import shutil
import warnings
import importlib
from packaging.version import parse
from time import perf_counter

import numpy as np
Expand Down Expand Up @@ -579,6 +580,20 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None):

zarr_root = zarr.open_consolidated(str(folder), mode="r", storage_options=storage_options)

si_info = zarr_root.attrs["spikeinterface_info"]
if parse(si_info["version"]) < parse("0.101.1"):
# v0.101.0 did not have a consolidate metadata step after computing extensions.
# Here we try to consolidate the metadata and throw a warning if it fails.
try:
zarr_root_a = zarr.open(str(folder), mode="a", storage_options=storage_options)
zarr.consolidate_metadata(zarr_root_a.store)
except Exception as e:
warnings.warn(
"The zarr store was not properly consolidated prior to v0.101.1. "
"This may lead to unexpected behavior in loading extensions. "
"Please consider re-generating the SortingAnalyzer object."
)

# load internal sorting in memory
sorting = NumpySorting.from_sorting(
ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options),
Expand Down Expand Up @@ -1970,12 +1985,14 @@ def load_data(self):
if "dict" in ext_data_.attrs:
ext_data = ext_data_[0]
elif "dataframe" in ext_data_.attrs:
import xarray
import pandas as pd

ext_data = xarray.open_zarr(
ext_data_.store, group=f"{extension_group.name}/{ext_data_name}"
).to_pandas()
ext_data.index.rename("", inplace=True)
index = ext_data_["index"]
ext_data = pd.DataFrame(index=index)
for col in ext_data_.keys():
if col != "index":
ext_data.loc[:, col] = ext_data_[col][:]
ext_data = ext_data.convert_dtypes()
elif "object" in ext_data_.attrs:
ext_data = ext_data_[0]
else:
Expand Down Expand Up @@ -2031,12 +2048,21 @@ def run(self, save=True, **kwargs):
if save and not self.sorting_analyzer.is_read_only():
self._save_run_info()
self._save_data(**kwargs)
if self.format == "zarr":
import zarr

zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)

def save(self, **kwargs):
self._save_params()
self._save_importing_provenance()
self._save_data(**kwargs)
self._save_run_info()
self._save_data(**kwargs)

if self.format == "zarr":
import zarr

zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)

def _save_data(self, **kwargs):
if self.format == "memory":
Expand Down Expand Up @@ -2096,12 +2122,12 @@ def _save_data(self, **kwargs):
elif isinstance(ext_data, np.ndarray):
extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor)
elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame):
ext_data.to_xarray().to_zarr(
store=extension_group.store,
group=f"{extension_group.name}/{ext_data_name}",
mode="a",
)
extension_group[ext_data_name].attrs["dataframe"] = True
df_group = extension_group.create_group(ext_data_name)
# first we save the index
df_group.create_dataset(name="index", data=ext_data.index.to_numpy())
for col in ext_data.columns:
df_group.create_dataset(name=col, data=ext_data[col].to_numpy())
df_group.attrs["dataframe"] = True
else:
# any object
try:
Expand All @@ -2111,8 +2137,6 @@ def _save_data(self, **kwargs):
except:
raise Exception(f"Could not save {ext_data_name} as extension data")
extension_group[ext_data_name].attrs["object"] = True
# we need to re-consolidate
zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)

def _reset_extension_folder(self):
"""
Expand Down Expand Up @@ -2240,9 +2264,10 @@ def get_pipeline_nodes(self):
return self._get_pipeline_nodes()

def get_data(self, *args, **kwargs):
assert self.run_info[
"run_completed"
], f"You must run the extension {self.extension_name} before retrieving data"
if self.run_info is not None:
assert self.run_info[
"run_completed"
], f"You must run the extension {self.extension_name} before retrieving data"
assert len(self.data) > 0, "Extension has been run but no data found."
return self._get_data(*args, **kwargs)

Expand Down
4 changes: 4 additions & 0 deletions src/spikeinterface/postprocessing/template_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,10 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}")
value = np.nan
template_metrics.at[index, metric_name] = value

# we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns
# (in case of NaN values)
template_metrics = template_metrics.convert_dtypes()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks I think that's great!

return template_metrics

def _run(self, verbose=False):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.core import create_sorting_analyzer
from spikeinterface.core import create_sorting_analyzer, load_sorting_analyzer
from spikeinterface.core import estimate_sparsity


Expand Down Expand Up @@ -116,6 +116,8 @@ def _check_one(self, sorting_analyzer, extension_class, params):
with the passed parameters, and check the output is not empty, the extension
exists and `select_units()` method works.
"""
import pandas as pd

if extension_class.need_job_kwargs:
job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True)
else:
Expand All @@ -138,6 +140,26 @@ def _check_one(self, sorting_analyzer, extension_class, params):
merged = sorting_analyzer.merge_units(some_merges, format="memory", merging_mode="soft", sparsity_overlap=0.0)
assert len(merged.unit_ids) == num_units_after_merge

# test roundtrip
if sorting_analyzer.format in ("binary_folder", "zarr"):
sorting_analyzer_loaded = load_sorting_analyzer(sorting_analyzer.folder)
ext_loaded = sorting_analyzer_loaded.get_extension(extension_class.extension_name)
for ext_data_name, ext_data_loaded in ext_loaded.data.items():
if isinstance(ext_data_loaded, np.ndarray):
assert np.array_equal(ext.data[ext_data_name], ext_data_loaded)
elif isinstance(ext_data_loaded, pd.DataFrame):
# skip nan values
for col in ext_data_loaded.columns:
np.testing.assert_array_almost_equal(
ext.data[ext_data_name][col].dropna().to_numpy(),
ext_data_loaded[col].dropna().to_numpy(),
decimal=5,
)
elif isinstance(ext_data_loaded, dict):
assert ext.data[ext_data_name] == ext_data_loaded
else:
continue

def run_extension_tests(self, extension_class, params):
"""
Convenience function to perform all checks on the extension
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
"""
Compute quality metrics.
"""
import pandas as pd

qm_params = self.params["qm_params"]
# sparsity = self.params["sparsity"]
Expand All @@ -163,8 +164,6 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
non_empty_unit_ids = unit_ids
empty_unit_ids = []

import pandas as pd

metrics = pd.DataFrame(index=unit_ids)

# simple metrics not based on PCs
Expand Down Expand Up @@ -216,6 +215,9 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
if len(empty_unit_ids) > 0:
metrics.loc[empty_unit_ids] = np.nan

# we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns
# (in case of NaN values)
metrics = metrics.convert_dtypes()
return metrics

def _run(self, verbose=False, **job_kwargs):
Expand Down
Loading