Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hierarchical coordinates in DataTree #9063

Merged
merged 25 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
95dbece
Inheritance of data coordinates
shoyer Jun 3, 2024
abf0574
Simplify __init__
shoyer Jun 12, 2024
dfed099
Merge branch 'main' into datatree-hierarchy
shoyer Jun 19, 2024
54364d0
Include path name in alignment errors
shoyer Jun 19, 2024
f72fcd5
Fix some mypy errors
shoyer Jun 20, 2024
c4722e0
mypy fix
shoyer Jun 20, 2024
c08e810
simplify DataTree data model
shoyer Jun 20, 2024
2351b97
Add to_dataset(local=True)
shoyer Jun 21, 2024
03fb991
Fix mypy failure in tests
shoyer Jun 21, 2024
a611fb1
Fix to_zarr for inherited coords
shoyer Jun 25, 2024
cef6cfa
Fix to_netcdf for heirarchical coords
shoyer Jun 25, 2024
51b2c43
Add ChainSet
shoyer Jun 26, 2024
dcb103b
Revise internal data model; remove ChainSet
shoyer Jun 27, 2024
b221420
fix repr tests for inherited coords
shoyer Jun 27, 2024
34c1fb2
add another way to construct inherited indexes
shoyer Jun 27, 2024
7767634
Finish refactoring error message
shoyer Jun 28, 2024
5b038ba
Merge branch 'main' into datatree-hierarchy
shoyer Jun 28, 2024
c282e62
include inherited dimensions in HTML repr, too
shoyer Jun 28, 2024
6595e2f
Construct ChainMap objects on demand.
shoyer Jun 30, 2024
ebeec84
Merge branch 'main' into datatree-hierarchy
shoyer Jun 30, 2024
88e5f4a
slightly better error message with mis-aligned data trees
shoyer Jun 30, 2024
1f6c7b4
mypy fix
shoyer Jul 1, 2024
7bb5d86
use float64 instead of float32 for windows
shoyer Jul 1, 2024
57e598e
clean-up per review
shoyer Jul 1, 2024
6cfd8ff
Add note about inheritance to .ds docs
shoyer Jul 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
441 changes: 226 additions & 215 deletions xarray/core/datatree.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions xarray/core/datatree_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _datatree_to_netcdf(
unlimited_dims = {}

for node in dt.subtree:
ds = node.ds
ds = node.to_dataset(inherited=False)
group_path = node.path
if ds is None:
_create_empty_netcdf_group(filepath, group_path, mode, engine)
Expand Down Expand Up @@ -151,7 +151,7 @@ def _datatree_to_zarr(
)

for node in dt.subtree:
ds = node.ds
ds = node.to_dataset(inherited=False)
group_path = node.path
if ds is None:
_create_empty_zarr_group(store, group_path, mode)
Expand Down
23 changes: 22 additions & 1 deletion xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,27 @@ def dataset_repr(ds):
return "\n".join(summary)


def dims_and_coords_repr(ds) -> str:
"""Partial Dataset repr for use inside DataTree inheritance errors."""
summary = []

col_width = _calculate_col_width(ds.coords)
max_rows = OPTIONS["display_max_rows"]

dims_start = pretty_print("Dimensions:", col_width)
dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows)
summary.append(f"{dims_start}({dims_values})")

if ds.coords:
summary.append(coords_repr(ds.coords, col_width=col_width, max_rows=max_rows))

unindexed_dims_str = unindexed_dims_repr(ds.dims, ds.coords, max_rows=max_rows)
if unindexed_dims_str:
summary.append(unindexed_dims_str)

return "\n".join(summary)


def diff_dim_summary(a, b):
if a.sizes != b.sizes:
return f"Differing dimensions:\n ({dim_summary(a)}) != ({dim_summary(b)})"
Expand Down Expand Up @@ -1030,7 +1051,7 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat):
def _single_node_repr(node: DataTree) -> str:
"""Information about this node, not including its relationships to other nodes."""
if node.has_data or node.has_attrs:
ds_info = "\n" + repr(node.ds)
ds_info = "\n" + repr(node._to_dataset_view(rebuild_dims=False))
else:
ds_info = ""
return f"Group: {node.path}{ds_info}"
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/formatting_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def summarize_datatree_children(children: Mapping[str, DataTree]) -> str:
def datatree_node_repr(group_title: str, dt: DataTree) -> str:
header_components = [f"<div class='xr-obj-type'>{escape(group_title)}</div>"]

ds = dt.ds
ds = dt._to_dataset_view(rebuild_dims=False)

sections = [
children_section(dt.children),
Expand Down
16 changes: 9 additions & 7 deletions xarray/core/treenode.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,14 @@ def _attach(self, parent: Tree | None, child_name: str | None = None) -> None:
"To directly set parent, child needs a name, but child is unnamed"
)

self._pre_attach(parent)
self._pre_attach(parent, child_name)
parentchildren = parent._children
assert not any(
child is self for child in parentchildren
), "Tree is corrupt."
parentchildren[child_name] = self
self._parent = parent
self._post_attach(parent)
self._post_attach(parent, child_name)
else:
self._parent = None

Expand Down Expand Up @@ -415,11 +415,11 @@ def _post_detach(self: Tree, parent: Tree) -> None:
"""Method call after detaching from `parent`."""
pass

def _pre_attach(self: Tree, parent: Tree) -> None:
def _pre_attach(self: Tree, parent: Tree, name: str) -> None:
shoyer marked this conversation as resolved.
Show resolved Hide resolved
"""Method call before attaching to `parent`."""
pass

def _post_attach(self: Tree, parent: Tree) -> None:
def _post_attach(self: Tree, parent: Tree, name: str) -> None:
"""Method call after attaching to `parent`."""
pass

Expand Down Expand Up @@ -567,6 +567,9 @@ def same_tree(self, other: Tree) -> bool:
return self.root is other.root


AnyNamedNode = TypeVar("AnyNamedNode", bound="NamedNode")


class NamedNode(TreeNode, Generic[Tree]):
"""
A TreeNode which knows its own name.
Expand Down Expand Up @@ -606,10 +609,9 @@ def __repr__(self, level=0):
def __str__(self) -> str:
return f"NamedNode('{self.name}')" if self.name else "NamedNode()"

def _post_attach(self: NamedNode, parent: NamedNode) -> None:
def _post_attach(self: AnyNamedNode, parent: AnyNamedNode, name: str) -> None:
"""Ensures child has name attribute corresponding to key under which it has been stored."""
key = next(k for k, v in parent.children.items() if v is self)
self.name = key
self.name = name

@property
def path(self) -> str:
Expand Down
42 changes: 37 additions & 5 deletions xarray/tests/test_backends_datatree.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import pytest

import xarray as xr
from xarray.backends.api import open_datatree
from xarray.core.datatree import DataTree
from xarray.testing import assert_equal
from xarray.tests import (
requires_h5netcdf,
Expand All @@ -13,11 +15,11 @@
)

if TYPE_CHECKING:
from xarray.backends.api import T_NetcdfEngine
from xarray.core.datatree_io import T_DataTreeNetcdfEngine


class DatatreeIOBase:
engine: T_NetcdfEngine | None = None
engine: T_DataTreeNetcdfEngine | None = None

def test_to_netcdf(self, tmpdir, simple_datatree):
filepath = tmpdir / "test.nc"
Expand All @@ -27,6 +29,21 @@ def test_to_netcdf(self, tmpdir, simple_datatree):
roundtrip_dt = open_datatree(filepath, engine=self.engine)
assert_equal(original_dt, roundtrip_dt)

def test_to_netcdf_inherited_coords(self, tmpdir):
filepath = tmpdir / "test.nc"
original_dt = DataTree.from_dict(
{
"/": xr.Dataset({"a": (("x",), [1, 2])}, coords={"x": [3, 4]}),
"/sub": xr.Dataset({"b": (("x",), [5, 6])}),
}
)
original_dt.to_netcdf(filepath, engine=self.engine)

roundtrip_dt = open_datatree(filepath, engine=self.engine)
assert_equal(original_dt, roundtrip_dt)
subtree = cast(DataTree, roundtrip_dt["/sub"])
assert "x" not in subtree.to_dataset(inherited=False).coords

def test_netcdf_encoding(self, tmpdir, simple_datatree):
filepath = tmpdir / "test.nc"
original_dt = simple_datatree
Expand All @@ -48,12 +65,12 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree):

@requires_netCDF4
class TestNetCDF4DatatreeIO(DatatreeIOBase):
engine: T_NetcdfEngine | None = "netcdf4"
engine: T_DataTreeNetcdfEngine | None = "netcdf4"


@requires_h5netcdf
class TestH5NetCDFDatatreeIO(DatatreeIOBase):
engine: T_NetcdfEngine | None = "h5netcdf"
engine: T_DataTreeNetcdfEngine | None = "h5netcdf"


@requires_zarr
Expand Down Expand Up @@ -119,3 +136,18 @@ def test_to_zarr_default_write_mode(self, tmpdir, simple_datatree):
# with default settings, to_zarr should not overwrite an existing dir
with pytest.raises(zarr.errors.ContainsGroupError):
simple_datatree.to_zarr(tmpdir)

def test_to_zarr_inherited_coords(self, tmpdir):
original_dt = DataTree.from_dict(
{
"/": xr.Dataset({"a": (("x",), [1, 2])}, coords={"x": [3, 4]}),
"/sub": xr.Dataset({"b": (("x",), [5, 6])}),
}
)
filepath = tmpdir / "test.zarr"
original_dt.to_zarr(filepath)

roundtrip_dt = open_datatree(filepath, engine="zarr")
assert_equal(original_dt, roundtrip_dt)
subtree = cast(DataTree, roundtrip_dt["/sub"])
assert "x" not in subtree.to_dataset(inherited=False).coords
Loading
Loading