Refactor pandas save load and convert dtypes #3412

Merged: 7 commits, Sep 16, 2024
Changes from 3 commits
3 changes: 0 additions & 3 deletions pyproject.toml
@@ -94,7 +94,6 @@ preprocessing = [
 full = [
     "h5py",
     "pandas",
-    "xarray",
     "scipy",
     "scikit-learn",
     "networkx",
@@ -148,7 +147,6 @@ test = [
     "pytest-dependency",
     "pytest-cov",
 
-    "xarray",
     "huggingface_hub",
 
     # preprocessing
@@ -193,7 +191,6 @@ docs = [
     "pandas", # in the modules gallery comparison tutorial
     "hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous
     "numba", # For many postprocessing functions
-    "xarray", # For use of SortingAnalyzer zarr format
     "networkx",
     # Download data
     "pooch>=1.8.2",
44 changes: 27 additions & 17 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -1970,12 +1970,14 @@ def load_data(self):
             if "dict" in ext_data_.attrs:
                 ext_data = ext_data_[0]
             elif "dataframe" in ext_data_.attrs:
-                import xarray
+                import pandas as pd
 
-                ext_data = xarray.open_zarr(
-                    ext_data_.store, group=f"{extension_group.name}/{ext_data_name}"
-                ).to_pandas()
-                ext_data.index.rename("", inplace=True)
+                index = ext_data_["index"]
+                ext_data = pd.DataFrame(index=index)
+                for col in ext_data_.keys():
+                    if col != "index":
+                        ext_data.loc[:, col] = ext_data_[col][:]
+                ext_data = ext_data.convert_dtypes()
             elif "object" in ext_data_.attrs:
                 ext_data = ext_data_[0]
             else:
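A quick note on the convert_dtypes() call introduced here: columns read back from zarr arrive as plain NumPy arrays, and convert_dtypes() promotes them to pandas' nullable extension dtypes. A toy illustration (data made up, not from this PR):

import numpy as np
import pandas as pd

# Toy data illustrating what convert_dtypes() does to columns rebuilt from numpy arrays
df = pd.DataFrame(
    {
        "unit_id": np.array([0, 1, 2]),                             # int64   -> Int64
        "snr": np.array([5.1, np.nan, 3.3]),                        # float64 -> Float64
        "label": np.array(["good", "mua", "good"], dtype=object),   # object  -> string
    }
)
print(df.convert_dtypes().dtypes)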
@@ -2031,12 +2033,21 @@ def run(self, save=True, **kwargs):
         if save and not self.sorting_analyzer.is_read_only():
             self._save_run_info()
             self._save_data(**kwargs)
+            if self.format == "zarr":
+                import zarr
+
+                zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)
 
     def save(self, **kwargs):
         self._save_params()
         self._save_importing_provenance()
-        self._save_data(**kwargs)
         self._save_run_info()
+        self._save_data(**kwargs)
+
+        if self.format == "zarr":
+            import zarr
+
+            zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)
 
     def _save_data(self, **kwargs):
         if self.format == "memory":
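Context for the consolidation calls added to run() and save(): readers that go through zarr.open_consolidated() only see what is recorded in the consolidated metadata, so it must be refreshed whenever an extension writes new groups. A minimal sketch, assuming a hypothetical local store path:

import zarr

# Hypothetical store; mirrors why run()/save() re-consolidate after writing extension data
root = zarr.open_group("analyzer.zarr", mode="a")
root.create_group("extensions/template_metrics")

zarr.consolidate_metadata(root.store)          # refresh the consolidated .zmetadata
ro = zarr.open_consolidated("analyzer.zarr")   # consolidated readers now see the new group
print(list(ro["extensions"].keys()))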
@@ -2096,12 +2107,12 @@ def _save_data(self, **kwargs):
             elif isinstance(ext_data, np.ndarray):
                 extension_group.create_dataset(name=ext_data_name, data=ext_data, compressor=compressor)
             elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame):
-                ext_data.to_xarray().to_zarr(
-                    store=extension_group.store,
-                    group=f"{extension_group.name}/{ext_data_name}",
-                    mode="a",
-                )
-                extension_group[ext_data_name].attrs["dataframe"] = True
+                df_group = extension_group.create_group(ext_data_name)
+                # first we save the index
+                df_group.create_dataset(name="index", data=ext_data.index.to_numpy())
+                for col in ext_data.columns:
+                    df_group.create_dataset(name=col, data=ext_data[col].to_numpy())
+                df_group.attrs["dataframe"] = True
             else:
                 # any object
                 try:
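Putting the save half above together with the load half earlier in the diff, the new storage scheme can be reproduced in isolation. A standalone sketch (file name and group path are made up) of a DataFrame round-tripping through a plain zarr group:

import numpy as np
import pandas as pd
import zarr

# A metrics-like frame with a NaN, as extensions typically produce
df = pd.DataFrame({"amplitude": [1.5, np.nan, 3.0], "num_spikes": [10, 20, 30]}, index=[0, 1, 2])

# Save: one dataset per column plus an "index" dataset, flagged with the "dataframe" attr
root = zarr.open_group("toy_analyzer.zarr", mode="w")
df_group = root.create_group("extensions/metrics/metrics")
df_group.create_dataset(name="index", data=df.index.to_numpy())
for col in df.columns:
    df_group.create_dataset(name=col, data=df[col].to_numpy())
df_group.attrs["dataframe"] = True

# Load: rebuild the frame from the index and column datasets, then restore dtypes
g = zarr.open_group("toy_analyzer.zarr", mode="r")["extensions/metrics/metrics"]
loaded = pd.DataFrame(index=g["index"][:])
for col in g.keys():
    if col != "index":
        loaded.loc[:, col] = g[col][:]
loaded = loaded.convert_dtypes()
print(loaded.dtypes)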
@@ -2111,8 +2122,6 @@ def _save_data(self, **kwargs):
                 except:
                     raise Exception(f"Could not save {ext_data_name} as extension data")
                 extension_group[ext_data_name].attrs["object"] = True
-        # we need to re-consolidate
-        zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store)
 
     def _reset_extension_folder(self):
         """
@@ -2240,9 +2249,10 @@ def get_pipeline_nodes(self):
         return self._get_pipeline_nodes()
 
     def get_data(self, *args, **kwargs):
-        assert self.run_info[
-            "run_completed"
-        ], f"You must run the extension {self.extension_name} before retrieving data"
+        if self.run_info is not None:
+            assert self.run_info[
+                "run_completed"
+            ], f"You must run the extension {self.extension_name} before retrieving data"
         assert len(self.data) > 0, "Extension has been run but no data found."
         return self._get_data(*args, **kwargs)
 
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/template_metrics.py
@@ -287,7 +287,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job
                 warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}")
                 value = np.nan
             template_metrics.at[index, metric_name] = value
-        return template_metrics
+        return template_metrics.convert_dtypes()
Collaborator

This is a weak recommendation, but maybe we put this on its own line with a comment. Just from reading this I have no clue why we need to do this, and doing it in the return line is even more confusing. So something like:

# see xx
template_metrics.convert_dtypes()
return template_metrics

Member Author

better?


     def _run(self, verbose=False):
         self.data["metrics"] = self._compute_metrics(
@@ -5,7 +5,7 @@
 import numpy as np
 
 from spikeinterface.core import generate_ground_truth_recording
-from spikeinterface.core import create_sorting_analyzer
+from spikeinterface.core import create_sorting_analyzer, load_sorting_analyzer
 from spikeinterface.core import estimate_sparsity
 
 
@@ -116,6 +116,8 @@ def _check_one(self, sorting_analyzer, extension_class, params):
         with the passed parameters, and check the output is not empty, the extension
         exists and `select_units()` method works.
         """
+        import pandas as pd
+
         if extension_class.need_job_kwargs:
             job_kwargs = dict(n_jobs=2, chunk_duration="1s", progress_bar=True)
         else:
@@ -138,6 +140,26 @@ def _check_one(self, sorting_analyzer, extension_class, params):
         merged = sorting_analyzer.merge_units(some_merges, format="memory", merging_mode="soft", sparsity_overlap=0.0)
         assert len(merged.unit_ids) == num_units_after_merge
 
+        # test roundtrip
+        if sorting_analyzer.format in ("binary_folder", "zarr"):
+            sorting_analyzer_loaded = load_sorting_analyzer(sorting_analyzer.folder)
+            ext_loaded = sorting_analyzer_loaded.get_extension(extension_class.extension_name)
+            for ext_data_name, ext_data_loaded in ext_loaded.data.items():
+                if isinstance(ext_data_loaded, np.ndarray):
+                    assert np.array_equal(ext.data[ext_data_name], ext_data_loaded)
+                elif isinstance(ext_data_loaded, pd.DataFrame):
+                    # skip nan values
+                    for col in ext_data_loaded.columns:
+                        np.testing.assert_array_almost_equal(
+                            ext.data[ext_data_name][col].dropna().to_numpy(),
+                            ext_data_loaded[col].dropna().to_numpy(),
+                            decimal=5,
+                        )
+                elif isinstance(ext_data_loaded, dict):
+                    assert ext.data[ext_data_name] == ext_data_loaded
+                else:
+                    continue
+
     def run_extension_tests(self, extension_class, params):
         """
         Convenience function to perform all checks on the extension
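On the per-column dropna comparison above: the in-memory frame may still hold plain float64 NaN while the reloaded one has nullable dtypes, so the test compares only the non-missing values. A possible alternative (an assumption, not what the PR does) is pandas' own frame comparison, which treats missing values in matching positions as equal; ext, ext_data_name and ext_data_loaded are the names from the test above:

import pandas as pd

# Hedged alternative to the loop above; check_like ignores column/index ordering,
# check_dtype tolerates float64 vs nullable Float64 after the zarr roundtrip
pd.testing.assert_frame_equal(
    ext.data[ext_data_name].convert_dtypes(),
    ext_data_loaded,
    check_dtype=False,
    check_like=True,
)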
@@ -185,7 +185,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, **job
         if len(empty_unit_ids) > 0:
             metrics.loc[empty_unit_ids] = np.nan
 
-        return metrics
+        return metrics.convert_dtypes()
Collaborator

Same here. From the code it is not clear why we need to convert dtypes, so I would prefer to divide this into a convert step and then only return the converted result. That way we can have a comment explaining why we need to convert.

Member Author

added comment and convert step


     def _run(self, verbose=False, **job_kwargs):
         self.data["metrics"] = self._compute_metrics(