From 95dbece896ef8a481f7efb25e39d70a89acafbca Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 3 Jun 2024 14:20:17 -0700 Subject: [PATCH 01/21] Inheritance of data coordinates --- xarray/core/datatree.py | 313 ++++++++++++++-------------------- xarray/tests/test_datatree.py | 77 ++++++++- 2 files changed, 205 insertions(+), 185 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 5737cdcb686..cb55ff89af5 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1,6 +1,5 @@ from __future__ import annotations -import copy import itertools from collections.abc import Hashable, Iterable, Iterator, Mapping, MutableMapping from html import escape @@ -15,6 +14,7 @@ ) from xarray.core import utils +from xarray.core.alignment import align from xarray.core.coordinates import DatasetCoordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset, DataVariables @@ -91,14 +91,35 @@ def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: return ds -def _check_for_name_collisions( - children: Iterable[str], variables: Iterable[Hashable] +def _collect_data_and_coord_variables( + data: Dataset, +) -> tuple[dict[Hashable, Variable], dict[Hashable, Variable]]: + data_variables = {} + coord_variables = {} + for k, v in data.variables.items(): + if k in data._coord_names: + coord_variables[k] = v + else: + data_variables[k] = v + return data_variables, coord_variables + + +def _check_alignment( + node_ds: Dataset, + parent_ds: Dataset | None, + children: Mapping[Hashable, DataTree], ) -> None: - colliding_names = set(children).intersection(set(variables)) - if colliding_names: - raise KeyError( - f"Some names would collide between variables and children: {list(colliding_names)}" - ) + if parent_ds is not None: + try: + align(node_ds, parent_ds, join="exact") + except ValueError as e: + raise ValueError( + "inconsistent alignment between node and parent datasets:\n" + f"{node_ds}\nvs\n{parent_ds}" + ) from e + + for child in children.values(): + _check_alignment(child.ds, node_ds, child.children) class DatasetView(Dataset): @@ -116,7 +137,7 @@ class DatasetView(Dataset): __slots__ = ( "_attrs", - "_cache", + "_cache", # is this used? "_coord_names", "_dims", "_encoding", @@ -335,27 +356,35 @@ class DataTree( _name: str | None _parent: DataTree | None _children: dict[str, DataTree] - _attrs: dict[Hashable, Any] | None - _cache: dict[str, Any] + _cache: dict[str, Any] # is this used? + _node_data_variables: dict[Hashable, Variable] + _node_coord_variables: dict[Hashable, Variable] + _node_dims: dict[Hashable, int] + _node_indexes: dict[Hashable, Index] + _variables: dict[Hashable, Variable] _coord_names: set[Hashable] _dims: dict[Hashable, int] + _indexes: dict[Hashable, Index] + _attrs: dict[Hashable, Any] | None _encoding: dict[Hashable, Any] | None _close: Callable[[], None] | None - _indexes: dict[Hashable, Index] - _variables: dict[Hashable, Variable] __slots__ = ( "_name", "_parent", "_children", - "_attrs", "_cache", + "_node_data_variables", + "_node_coord_variables", + "_node_dims", + "_node_indexes", + "_variables", "_coord_names", "_dims", + "_indexes", + "_attrs", "_encoding", "_close", - "_indexes", - "_variables", ) def __init__( @@ -368,14 +397,15 @@ def __init__( """ Create a single node of a DataTree. - The node may optionally contain data in the form of data and coordinate variables, stored in the same way as - data is stored in an xarray.Dataset. + The node may optionally contain data in the form of data and coordinate + variables, stored in the same way as data is stored in an + xarray.Dataset. Parameters ---------- data : Dataset, DataArray, or None, optional - Data to store under the .ds attribute of this node. DataArrays will be promoted to Datasets. - Default is None. + Data to store under the .ds attribute of this node. DataArrays will + be promoted to Datasets. Default is None. parent : DataTree, optional Parent node to this node. Default is None. children : Mapping[str, DataTree], optional @@ -391,31 +421,63 @@ def __init__( -------- DataTree.from_dict """ - - # validate input if children is None: children = {} - ds = _coerce_to_dataset(data) - _check_for_name_collisions(children, ds.variables) super().__init__(name=name) - # set data attributes - self._replace( - inplace=True, - variables=ds._variables, - coord_names=ds._coord_names, - dims=ds._dims, - indexes=ds._indexes, - attrs=ds._attrs, - encoding=ds._encoding, - ) - self._close = ds._close + # set tree attributes + self._children = {} + self._parent = None + self._set_node_data(data) - # set tree attributes (must happen after variables set to avoid initialization errors) - self.children = children + # finalize tree attributes + self.children = children # must set first self.parent = parent + def _set_node_data(self, data: Dataset | DataArray | None) -> None: + ds = _coerce_to_dataset(data) + data_vars, coord_vars = _collect_data_and_coord_variables(ds) + + # set node data attributes + self._data_variables = data_vars + self._node_coord_variables = coord_vars + self._node_dims = ds._dims + self._node_indexes = ds._indexes + self._attrs = ds._attrs + self._encoding = ds._encoding + self._close = ds._close + + # setup inherited node attributes (finalized by _post_attach) + self._variables = dict(data_vars) + self._variables.update(coord_vars) + self._coord_names = set(coord_vars) + self._dims = dict(ds._dims) + self._indexes = dict(ds._indexes) + + def _pre_attach(self: DataTree, parent: DataTree) -> None: + super()._pre_attach(parent) + if self.name in parent.ds.variables: + raise KeyError( + f"parent {parent.name} already contains a variable named {self.name}" + ) + _check_alignment(self.ds, parent.ds, self.children) + + def _post_attach_recursively(self: DataTree, parent: DataTree) -> None: + for k in parent._coord_names: + if k not in self._variables: + self._variables[k] = parent._variables[k] + self._coord_names.add(k) + self._dims.update(parent._dims) + self._indexes.update(parent._indexes) + + for child in self._children.values(): + child._post_attach_recursively(self) + + def _post_attach(self: DataTree, parent: DataTree) -> None: + super()._post_attach(parent) + self._post_attach_recursively(parent) + @property def parent(self: DataTree) -> DataTree | None: """Parent of this node.""" @@ -442,33 +504,8 @@ def ds(self) -> DatasetView: @ds.setter def ds(self, data: Dataset | DataArray | None = None) -> None: - # Known mypy issue for setters with different type to property: - # https://github.com/python/mypy/issues/3004 ds = _coerce_to_dataset(data) - - _check_for_name_collisions(self.children, ds.variables) - - self._replace( - inplace=True, - variables=ds._variables, - coord_names=ds._coord_names, - dims=ds._dims, - indexes=ds._indexes, - attrs=ds._attrs, - encoding=ds._encoding, - ) - self._close = ds._close - - def _pre_attach(self: DataTree, parent: DataTree) -> None: - """ - Method which superclass calls before setting parent, here used to prevent having two - children with duplicate names (or a data variable with the same name as a child). - """ - super()._pre_attach(parent) - if self.name in list(parent.ds.variables): - raise KeyError( - f"parent {parent.name} already contains a data variable named {self.name}" - ) + self._replace_node(ds) def to_dataset(self) -> Dataset: """ @@ -478,6 +515,7 @@ def to_dataset(self) -> Dataset: -------- DataTree.ds """ + # TODO: copy these container objects? return Dataset._construct_direct( self._variables, self._coord_names, @@ -644,122 +682,31 @@ def _repr_html_(self): return f"
{escape(repr(self))}
" return datatree_repr_html(self) - @classmethod - def _construct_direct( - cls, - variables: dict[Any, Variable], - coord_names: set[Hashable], - dims: dict[Any, int] | None = None, - attrs: dict | None = None, - indexes: dict[Any, Index] | None = None, - encoding: dict | None = None, - name: str | None = None, - parent: DataTree | None = None, - children: dict[str, DataTree] | None = None, - close: Callable[[], None] | None = None, - ) -> DataTree: - """Shortcut around __init__ for internal use when we want to skip costly validation.""" - - # data attributes - if dims is None: - dims = calculate_dimensions(variables) - if indexes is None: - indexes = {} - if children is None: - children = dict() - - obj: DataTree = object.__new__(cls) - obj._variables = variables - obj._coord_names = coord_names - obj._dims = dims - obj._indexes = indexes - obj._attrs = attrs - obj._close = close - obj._encoding = encoding + def _replace_node( + self: DataTree, + data: Dataset | Default = _default, + children: Mapping[str, DataTree] | Default = _default, + ) -> None: + if data is _default: + data = self.ds + if children is _default: + children = self.children - # tree attributes - obj._name = name - obj._children = children - obj._parent = parent + for child_name, child in children.items(): + if child_name in data.variables: + raise ValueError(f"node already contains a variable named {child_name}") - return obj + parent_ds = self.parent.ds if self.parent is not None else None + _check_alignment(data, parent_ds, children) - def _replace( - self: DataTree, - variables: dict[Hashable, Variable] | None = None, - coord_names: set[Hashable] | None = None, - dims: dict[Any, int] | None = None, - attrs: dict[Hashable, Any] | None | Default = _default, - indexes: dict[Hashable, Index] | None = None, - encoding: dict | None | Default = _default, - name: str | None | Default = _default, - parent: DataTree | None | Default = _default, - children: dict[str, DataTree] | None = None, - inplace: bool = False, - ) -> DataTree: - """ - Fastpath constructor for internal use. + self._children = children + self._set_node_data(data) - Returns an object with optionally replaced attributes. - - Explicitly passed arguments are *not* copied when placed on the new - datatree. It is up to the caller to ensure that they have the right type - and are not used elsewhere. - """ - # TODO Adding new children inplace using this method will cause bugs. - # You will end up with an inconsistency between the name of the child node and the key the child is stored under. - # Use ._set() instead for now - if inplace: - if variables is not None: - self._variables = variables - if coord_names is not None: - self._coord_names = coord_names - if dims is not None: - self._dims = dims - if attrs is not _default: - self._attrs = attrs - if indexes is not None: - self._indexes = indexes - if encoding is not _default: - self._encoding = encoding - if name is not _default: - self._name = name - if parent is not _default: - self._parent = parent - if children is not None: - self._children = children - obj = self + if self.parent is not None: + self._post_attach(self.parent) else: - if variables is None: - variables = self._variables.copy() - if coord_names is None: - coord_names = self._coord_names.copy() - if dims is None: - dims = self._dims.copy() - if attrs is _default: - attrs = copy.copy(self._attrs) - if indexes is None: - indexes = self._indexes.copy() - if encoding is _default: - encoding = copy.copy(self._encoding) - if name is _default: - name = self._name # no need to copy str objects or None - if parent is _default: - parent = copy.copy(self._parent) - if children is _default: - children = copy.copy(self._children) - obj = self._construct_direct( - variables, - coord_names, - dims, - attrs, - indexes, - encoding, - name, - parent, - children, - ) - return obj + for child in children.values(): + child._post_attach_recursively(self) def copy( self: DataTree, @@ -811,9 +758,8 @@ def _copy_node( deep: bool = False, ) -> DataTree: """Copy just one node of a tree""" - new_node: DataTree = DataTree() - new_node.name = self.name - new_node.ds = self.to_dataset().copy(deep=deep) # type: ignore[assignment] + data = self.ds.copy(deep=deep) + new_node = DataTree(data, name=self.name) return new_node def __copy__(self: DataTree) -> DataTree: @@ -961,11 +907,12 @@ def update( raise TypeError(f"Type {type(v)} cannot be assigned to a DataTree") vars_merge_result = dataset_update_method(self.to_dataset(), new_variables) + data = Dataset._construct_direct(**vars_merge_result._asdict()) + # TODO are there any subtleties with preserving order of children like this? merged_children = {**self.children, **new_children} - self._replace( - inplace=True, children=merged_children, **vars_merge_result._asdict() - ) + + self._replace_node(data, children=merged_children) def assign( self, items: Mapping[Any, Any] | None = None, **items_kwargs: Any @@ -1040,10 +987,12 @@ def drop_nodes( if extra: raise KeyError(f"Cannot drop all nodes - nodes {extra} not present") + result = self.copy() children_to_keep = { - name: child for name, child in self.children.items() if name not in names + name: child for name, child in result.children.items() if name not in names } - return self._replace(children=children_to_keep) + result._replace_node(children=children_to_keep) + return result @classmethod def from_dict( diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 58fec20d4c6..8bbe1dc91c5 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -152,19 +152,19 @@ def test_is_hollow(self): class TestVariablesChildrenNameCollisions: def test_parent_already_has_variable_with_childs_name(self): dt: DataTree = DataTree(data=xr.Dataset({"a": [0], "b": 1})) - with pytest.raises(KeyError, match="already contains a data variable named a"): + with pytest.raises(KeyError, match="already contains a variable named a"): DataTree(name="a", data=None, parent=dt) def test_assign_when_already_child_with_variables_name(self): dt: DataTree = DataTree(data=None) DataTree(name="a", data=None, parent=dt) - with pytest.raises(KeyError, match="names would collide"): + with pytest.raises(ValueError, match="node already contains a variable"): dt.ds = xr.Dataset({"a": 0}) # type: ignore[assignment] dt.ds = xr.Dataset() # type: ignore[assignment] new_ds = dt.to_dataset().assign(a=xr.DataArray(0)) - with pytest.raises(KeyError, match="names would collide"): + with pytest.raises(ValueError, match="node already contains a variable"): dt.ds = new_ds # type: ignore[assignment] @@ -623,6 +623,77 @@ def test_operation_with_attrs_but_no_data(self): dt.sel(dim_0=0) +class TestInheritance: + def test_inherited_dims(self): + dt = DataTree.from_dict( + { + "/": xr.Dataset({"d": (("x",), [1, 2])}), + "/b": xr.Dataset({"e": (("y",), [3])}), + "/c": xr.Dataset({"f": (("y",), [3, 4, 5])}), + } + ) + assert dt.sizes == {"x": 2} + assert dt.b.sizes == {"x": 2, "y": 1} + assert dt.c.sizes == {"x": 2, "y": 3} + + def test_inherited_coords_index(self): + dt = DataTree.from_dict( + { + "/": xr.Dataset({"d": (("x",), [1, 2])}, coords={"x": [2, 3]}), + "/b": xr.Dataset({"e": (("y",), [3])}), + } + ) + assert "x" in dt["/b"].indexes + assert "x" in dt["/b"].coords + xr.testing.assert_identical(dt["/x"], dt["/b/x"]) + + def test_inherited_coords_override(self): + dt = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": 1, "y": 2}), + "/b": xr.Dataset(coords={"x": 4, "z": 3}), + } + ) + assert dt.coords.keys() == {"x", "y"} + root_coords = {"x": 1, "y": 2} + sub_coords = {"x": 4, "y": 2, "z": 3} + xr.testing.assert_equal(dt["/x"], xr.DataArray(1, coords=root_coords)) + xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords=root_coords)) + assert dt["/b"].coords.keys() == {"x", "y", "z"} + xr.testing.assert_equal(dt["/b/x"], xr.DataArray(4, coords=sub_coords)) + xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords)) + xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords)) + + def test_inconsistent_dims(self): + with pytest.raises( + ValueError, match="inconsistent alignment between node and parent datasets" + ): + DataTree.from_dict( + { + "/": xr.Dataset({"a": (("x",), [1, 2])}), + "/b": xr.Dataset({"c": (("x",), [3])}), + } + ) + + dt = DataTree() + dt["/a"] = xr.DataArray([1, 2], dims=["x"]) + with pytest.raises( + ValueError, match="cannot reindex or align along dimension 'x'" + ): + dt["/b/c"] = xr.DataArray([3], dims=["x"]) + + def test_inconsistent_indexes(self): + with pytest.raises( + ValueError, match="inconsistent alignment between node and parent datasets" + ): + DataTree.from_dict( + { + "/": xr.Dataset({"a": (("x",), [1])}, coords={"x": [1]}), + "/b": xr.Dataset({"c": (("x",), [2])}, coords={"x": [2]}), + } + ) + + class TestRestructuring: def test_drop_nodes(self): sue = DataTree.from_dict({"Mary": None, "Kate": None, "Ashley": None}) From abf0574f2c65854eb8d492aa76c9044fb9c67ad2 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 12 Jun 2024 17:11:49 -0600 Subject: [PATCH 02/21] Simplify __init__ --- xarray/core/datatree.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index cb55ff89af5..c66441b03b5 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -449,8 +449,7 @@ def _set_node_data(self, data: Dataset | DataArray | None) -> None: self._close = ds._close # setup inherited node attributes (finalized by _post_attach) - self._variables = dict(data_vars) - self._variables.update(coord_vars) + self._variables = data_vars | coord_vars self._coord_names = set(coord_vars) self._dims = dict(ds._dims) self._indexes = dict(ds._indexes) From 54364d0bec04861c39d0887173e17cad26cf63df Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 19 Jun 2024 09:46:07 -0700 Subject: [PATCH 03/21] Include path name in alignment errors --- xarray/core/datatree.py | 13 ++++++++----- xarray/tests/test_datatree.py | 4 ++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 05bad3ca3ab..9bb044caf1d 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -107,6 +107,7 @@ def _collect_data_and_coord_variables( def _check_alignment( + path: str, node_ds: Dataset, parent_ds: Dataset | None, children: Mapping[Hashable, DataTree], @@ -116,12 +117,12 @@ def _check_alignment( align(node_ds, parent_ds, join="exact") except ValueError as e: raise ValueError( - "inconsistent alignment between node and parent datasets:\n" - f"{node_ds}\nvs\n{parent_ds}" + f"group {path!r} is not aligned with its parent:\n" + f"Group: {node_ds}\nvs\nParent: {parent_ds}" ) from e for child in children.values(): - _check_alignment(child.ds, node_ds, child.children) + _check_alignment(child.path, child.ds, node_ds, child.children) class DatasetView(Dataset): @@ -462,7 +463,9 @@ def _pre_attach(self: DataTree, parent: DataTree) -> None: raise KeyError( f"parent {parent.name} already contains a variable named {self.name}" ) - _check_alignment(self.ds, parent.ds, self.children) + name = self.name if self.name is not None else "" + path = parent.path.rstrip("/") + "/" + name + _check_alignment(path, self.ds, parent.ds, self.children) def _post_attach_recursively(self: DataTree, parent: DataTree) -> None: for k in parent._coord_names: @@ -698,7 +701,7 @@ def _replace_node( raise ValueError(f"node already contains a variable named {child_name}") parent_ds = self.parent.ds if self.parent is not None else None - _check_alignment(data, parent_ds, children) + _check_alignment(self.path, data, parent_ds, children) self._children = children self._set_node_data(data) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 8bbe1dc91c5..f3c125e0d51 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -666,7 +666,7 @@ def test_inherited_coords_override(self): def test_inconsistent_dims(self): with pytest.raises( - ValueError, match="inconsistent alignment between node and parent datasets" + ValueError, match="group '/b' is not aligned with its parent" ): DataTree.from_dict( { @@ -684,7 +684,7 @@ def test_inconsistent_dims(self): def test_inconsistent_indexes(self): with pytest.raises( - ValueError, match="inconsistent alignment between node and parent datasets" + ValueError, match="group '/b' is not aligned with its parent" ): DataTree.from_dict( { From f72fcd5b9fd058b71e496ae1e741dc55295cadde Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 19 Jun 2024 18:32:04 -0700 Subject: [PATCH 04/21] Fix some mypy errors --- xarray/core/datatree.py | 8 ++++---- xarray/tests/test_datatree.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 9bb044caf1d..3fc7c616a55 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -110,7 +110,7 @@ def _check_alignment( path: str, node_ds: Dataset, parent_ds: Dataset | None, - children: Mapping[Hashable, DataTree], + children: Mapping[str, DataTree], ) -> None: if parent_ds is not None: try: @@ -689,12 +689,12 @@ def _repr_html_(self): def _replace_node( self: DataTree, data: Dataset | Default = _default, - children: Mapping[str, DataTree] | Default = _default, + children: dict[str, DataTree] | Default = _default, ) -> None: if data is _default: data = self.ds if children is _default: - children = self.children + children = self._children for child_name, child in children.items(): if child_name in data.variables: @@ -763,7 +763,7 @@ def _copy_node( ) -> DataTree: """Copy just one node of a tree""" data = self.ds.copy(deep=deep) - new_node = DataTree(data, name=self.name) + new_node: DataTree = DataTree(data, name=self.name) return new_node def __copy__(self: DataTree) -> DataTree: diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index f3c125e0d51..ecf1996f8be 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -675,7 +675,7 @@ def test_inconsistent_dims(self): } ) - dt = DataTree() + dt: DataTree = DataTree() dt["/a"] = xr.DataArray([1, 2], dims=["x"]) with pytest.raises( ValueError, match="cannot reindex or align along dimension 'x'" From c4722e05dba80ee424eb37058e7d31e39153790a Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 20 Jun 2024 10:17:45 -0700 Subject: [PATCH 05/21] mypy fix --- xarray/core/treenode.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 6f51e1ffa38..f3935c4d2f8 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -567,6 +567,9 @@ def same_tree(self, other: Tree) -> bool: return self.root is other.root +AnyNamedNode = TypeVar("AnyNamedNode", bound="NamedNode") + + class NamedNode(TreeNode, Generic[Tree]): """ A TreeNode which knows its own name. @@ -606,7 +609,7 @@ def __repr__(self, level=0): def __str__(self) -> str: return f"NamedNode('{self.name}')" if self.name else "NamedNode()" - def _post_attach(self: NamedNode, parent: NamedNode) -> None: + def _post_attach(self: AnyNamedNode, parent: AnyNamedNode) -> None: """Ensures child has name attribute corresponding to key under which it has been stored.""" key = next(k for k, v in parent.children.items() if v is self) self.name = key From c08e810019554c7e2556cb760f7789c9800674d6 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 20 Jun 2024 10:37:16 -0700 Subject: [PATCH 06/21] simplify DataTree data model --- xarray/core/datatree.py | 54 +++++++++-------------------------------- 1 file changed, 12 insertions(+), 42 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 3fc7c616a55..00bd5b03f2b 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -83,7 +83,7 @@ def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: if isinstance(data, DataArray): ds = data.to_dataset() elif isinstance(data, Dataset): - ds = data + ds = data.copy(deep=False) elif data is None: ds = Dataset() else: @@ -93,19 +93,6 @@ def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: return ds -def _collect_data_and_coord_variables( - data: Dataset, -) -> tuple[dict[Hashable, Variable], dict[Hashable, Variable]]: - data_variables = {} - coord_variables = {} - for k, v in data.variables.items(): - if k in data._coord_names: - coord_variables[k] = v - else: - data_variables[k] = v - return data_variables, coord_variables - - def _check_alignment( path: str, node_ds: Dataset, @@ -140,7 +127,7 @@ class DatasetView(Dataset): __slots__ = ( "_attrs", - "_cache", # is this used? + "_cache", # used by _CachedAccessor "_coord_names", "_dims", "_encoding", @@ -359,11 +346,7 @@ class DataTree( _name: str | None _parent: DataTree | None _children: dict[str, DataTree] - _cache: dict[str, Any] # is this used? - _node_data_variables: dict[Hashable, Variable] - _node_coord_variables: dict[Hashable, Variable] - _node_dims: dict[Hashable, int] - _node_indexes: dict[Hashable, Index] + _cache: dict[str, Any] # used by _CachedAccessor _variables: dict[Hashable, Variable] _coord_names: set[Hashable] _dims: dict[Hashable, int] @@ -376,11 +359,7 @@ class DataTree( "_name", "_parent", "_children", - "_cache", - "_node_data_variables", - "_node_coord_variables", - "_node_dims", - "_node_indexes", + "_cache", # used by _CachedAccessor "_variables", "_coord_names", "_dims", @@ -432,31 +411,22 @@ def __init__( # set tree attributes self._children = {} self._parent = None - self._set_node_data(data) + self._set_node_data(_coerce_to_dataset(data)) # finalize tree attributes self.children = children # must set first self.parent = parent - def _set_node_data(self, data: Dataset | DataArray | None) -> None: - ds = _coerce_to_dataset(data) - data_vars, coord_vars = _collect_data_and_coord_variables(ds) - - # set node data attributes - self._data_variables = data_vars - self._node_coord_variables = coord_vars - self._node_dims = ds._dims - self._node_indexes = ds._indexes - self._attrs = ds._attrs + def _set_node_data(self, ds: Dataset) -> None: + # these node data attributes are finalized by _post_attach + self._variables = ds._variables + self._coord_names = ds._coord_names + self._dims = ds._dims + self._indexes = ds._indexes self._encoding = ds._encoding + self._attrs = ds._attrs self._close = ds._close - # setup inherited node attributes (finalized by _post_attach) - self._variables = data_vars | coord_vars - self._coord_names = set(coord_vars) - self._dims = dict(ds._dims) - self._indexes = dict(ds._indexes) - def _pre_attach(self: DataTree, parent: DataTree) -> None: super()._pre_attach(parent) if self.name in parent.ds.variables: From 2351b9726dd214872acef6af149a27173d5218bb Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 20 Jun 2024 18:37:44 -0700 Subject: [PATCH 07/21] Add to_dataset(local=True) --- xarray/core/datatree.py | 44 ++++++++++++++++++++++++----------- xarray/tests/test_datatree.py | 14 +++++++++++ 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 00bd5b03f2b..294fe1db0a4 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -347,6 +347,10 @@ class DataTree( _parent: DataTree | None _children: dict[str, DataTree] _cache: dict[str, Any] # used by _CachedAccessor + _local_variables: dict[Hashable, Variable] + _local_coord_names: set[Hashable] + _local_dims: dict[Hashable, int] + _local_indexes: dict[Hashable, Index] _variables: dict[Hashable, Variable] _coord_names: set[Hashable] _dims: dict[Hashable, int] @@ -360,6 +364,10 @@ class DataTree( "_parent", "_children", "_cache", # used by _CachedAccessor + "_local_variables", + "_local_coord_names", + "_local_dims", + "_local_indexes", "_variables", "_coord_names", "_dims", @@ -418,11 +426,16 @@ def __init__( self.parent = parent def _set_node_data(self, ds: Dataset) -> None: - # these node data attributes are finalized by _post_attach - self._variables = ds._variables - self._coord_names = ds._coord_names - self._dims = ds._dims - self._indexes = ds._indexes + # local data attributes for to_dataset(local=True) + self._local_variables = ds._variables + self._local_coord_names = ds._coord_names + self._local_dims = ds._dims + self._local_indexes = ds._indexes + # these data attributes with inheritance are finalized by _post_attach + self._variables = dict(ds._variables) + self._coord_names = set(ds._coord_names) + self._dims = dict(ds._dims) + self._indexes = dict(ds._indexes) self._encoding = ds._encoding self._attrs = ds._attrs self._close = ds._close @@ -481,22 +494,27 @@ def ds(self, data: Dataset | DataArray | None = None) -> None: ds = _coerce_to_dataset(data) self._replace_node(ds) - def to_dataset(self) -> Dataset: + def to_dataset(self, local: bool = False) -> Dataset: """ Return the data in this node as a new xarray.Dataset object. + Parameters + ---------- + local : bool, optional + If True, only include coordinates, indexes and dimensions defined + at the level of this DataTree node, excluding inherited coordinates. + See Also -------- DataTree.ds """ - # TODO: copy these container objects? return Dataset._construct_direct( - self._variables, - self._coord_names, - self._dims, - self._attrs, - self._indexes, - self._encoding, + dict(self._local_variables if local else self._variables), + set(self._local_coord_names if local else self._coord_names), + dict(self._local_dims if local else self._dims), + None if self._attrs is None else dict(self._attrs), + self._local_indexes if local else self._indexes, + None if self._encoding is None else dict(self._encoding), self._close, ) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index ecf1996f8be..207886bb7ab 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -149,6 +149,20 @@ def test_is_hollow(self): assert not eve.is_hollow +class TestToDataset: + def test_to_dataset(self): + base = xr.Dataset(coords={"a": 1}) + sub = xr.Dataset(coords={"b": 2}) + dt = DataTree.from_dict({"/": base, "/sub": sub}) + + assert_identical(dt.to_dataset(local=True), base) + assert_identical(dt["sub"].to_dataset(local=True), sub) + + sub2 = xr.Dataset(coords={"a": 1, "b": 2}) + assert_identical(dt.to_dataset(local=False), base) + assert_identical(dt["sub"].to_dataset(local=False), sub2) + + class TestVariablesChildrenNameCollisions: def test_parent_already_has_variable_with_childs_name(self): dt: DataTree = DataTree(data=xr.Dataset({"a": [0], "b": 1})) From 03fb99100560411335e8ca2813f76495938583b0 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Fri, 21 Jun 2024 14:07:03 -0700 Subject: [PATCH 08/21] Fix mypy failure in tests --- xarray/tests/test_datatree.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 207886bb7ab..946c4b4bd07 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1,3 +1,4 @@ +import typing from copy import copy, deepcopy from textwrap import dedent @@ -153,14 +154,15 @@ class TestToDataset: def test_to_dataset(self): base = xr.Dataset(coords={"a": 1}) sub = xr.Dataset(coords={"b": 2}) - dt = DataTree.from_dict({"/": base, "/sub": sub}) + tree = DataTree.from_dict({"/": base, "/sub": sub}) + subtree = typing.cast(DataTree, tree["sub"]) - assert_identical(dt.to_dataset(local=True), base) - assert_identical(dt["sub"].to_dataset(local=True), sub) + assert_identical(tree.to_dataset(local=True), base) + assert_identical(subtree.to_dataset(local=True), sub) - sub2 = xr.Dataset(coords={"a": 1, "b": 2}) - assert_identical(dt.to_dataset(local=False), base) - assert_identical(dt["sub"].to_dataset(local=False), sub2) + sub_and_base = xr.Dataset(coords={"a": 1, "b": 2}) + assert_identical(tree.to_dataset(local=False), base) + assert_identical(subtree.to_dataset(local=False), sub_and_base) class TestVariablesChildrenNameCollisions: From a611fb1fac38a326ea45f34ad1bbdaaf2976b5fe Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 24 Jun 2024 23:22:55 -0700 Subject: [PATCH 09/21] Fix to_zarr for inherited coords --- xarray/core/datatree_io.py | 2 +- xarray/tests/test_backends_datatree.py | 19 ++++++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py index 1473e624d9e..eeffa032d58 100644 --- a/xarray/core/datatree_io.py +++ b/xarray/core/datatree_io.py @@ -151,7 +151,7 @@ def _datatree_to_zarr( ) for node in dt.subtree: - ds = node.ds + ds = node.to_dataset(local=True) group_path = node.path if ds is None: _create_empty_zarr_group(store, group_path, mode) diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index 4e819eec0b5..84d3ff090f0 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -1,10 +1,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast import pytest +import xarray as xr from xarray.backends.api import open_datatree +from xarray.core.datatree import DataTree from xarray.testing import assert_equal from xarray.tests import ( requires_h5netcdf, @@ -119,3 +121,18 @@ def test_to_zarr_default_write_mode(self, tmpdir, simple_datatree): # with default settings, to_zarr should not overwrite an existing dir with pytest.raises(zarr.errors.ContainsGroupError): simple_datatree.to_zarr(tmpdir) + + def test_to_zarr_inherited_coords(self, tmpdir): + original_dt = DataTree.from_dict( + { + "/": xr.Dataset({"a": (("x",), [1, 2])}, coords={"x": [3, 4]}), + "/sub": xr.Dataset({"b": (("x",), [5, 6])}), + } + ) + filepath = tmpdir / "test.zarr" + original_dt.to_zarr(filepath) + + roundtrip_dt = open_datatree(filepath, engine="zarr") + assert_equal(original_dt, roundtrip_dt) + subtree = cast(DataTree, roundtrip_dt["/sub"]) + assert "x" not in subtree.to_dataset(local=True).coords From cef6cfada0cf2ed9391bcfb0e684ce5a277d86fc Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 25 Jun 2024 16:14:43 -0700 Subject: [PATCH 10/21] Fix to_netcdf for heirarchical coords --- xarray/core/datatree_io.py | 2 +- xarray/tests/test_backends_datatree.py | 23 +++++++++++++++++++---- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py index eeffa032d58..dbc80627cdd 100644 --- a/xarray/core/datatree_io.py +++ b/xarray/core/datatree_io.py @@ -85,7 +85,7 @@ def _datatree_to_netcdf( unlimited_dims = {} for node in dt.subtree: - ds = node.ds + ds = node.to_dataset(local=True) group_path = node.path if ds is None: _create_empty_netcdf_group(filepath, group_path, mode, engine) diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index 84d3ff090f0..b8eb22312e2 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -15,11 +15,11 @@ ) if TYPE_CHECKING: - from xarray.backends.api import T_NetcdfEngine + from xarray.core.datatree_io import T_DataTreeNetcdfEngine class DatatreeIOBase: - engine: T_NetcdfEngine | None = None + engine: T_DataTreeNetcdfEngine | None = None def test_to_netcdf(self, tmpdir, simple_datatree): filepath = tmpdir / "test.nc" @@ -29,6 +29,21 @@ def test_to_netcdf(self, tmpdir, simple_datatree): roundtrip_dt = open_datatree(filepath, engine=self.engine) assert_equal(original_dt, roundtrip_dt) + def test_to_netcdf_inherited_coords(self, tmpdir): + filepath = tmpdir / "test.nc" + original_dt = DataTree.from_dict( + { + "/": xr.Dataset({"a": (("x",), [1, 2])}, coords={"x": [3, 4]}), + "/sub": xr.Dataset({"b": (("x",), [5, 6])}), + } + ) + original_dt.to_netcdf(filepath, engine=self.engine) + + roundtrip_dt = open_datatree(filepath, engine=self.engine) + assert_equal(original_dt, roundtrip_dt) + subtree = cast(DataTree, roundtrip_dt["/sub"]) + assert "x" not in subtree.to_dataset(local=True).coords + def test_netcdf_encoding(self, tmpdir, simple_datatree): filepath = tmpdir / "test.nc" original_dt = simple_datatree @@ -50,12 +65,12 @@ def test_netcdf_encoding(self, tmpdir, simple_datatree): @requires_netCDF4 class TestNetCDF4DatatreeIO(DatatreeIOBase): - engine: T_NetcdfEngine | None = "netcdf4" + engine: T_DataTreeNetcdfEngine | None = "netcdf4" @requires_h5netcdf class TestH5NetCDFDatatreeIO(DatatreeIOBase): - engine: T_NetcdfEngine | None = "h5netcdf" + engine: T_DataTreeNetcdfEngine | None = "h5netcdf" @requires_zarr From 51b2c4314fe9d7cdba21f1ccb68e1be89d61e4a7 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 26 Jun 2024 08:11:34 -0700 Subject: [PATCH 11/21] Add ChainSet --- xarray/core/utils.py | 44 ++++++++++++++++++++++++++++++++++++++ xarray/tests/test_utils.py | 30 +++++++++++++++++++++++++- 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 5cb52cbd25c..7d7c3c91e80 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -563,6 +563,50 @@ def __repr__(self) -> str: return f"{type(self).__name__}({list(self)!r})" +class ChainSet(MutableSet[T]): + """A chained set, based on the design of collections.ChainMap.""" + + sets: list[set[T]] + __slots__ = ("sets",) + + def __init__(self, *sets): + self.sets = list(sets) or [set()] + + # Required methods for MutableSet + + def __contains__(self, value: Hashable) -> bool: + return any(value in s for s in self.sets) + + def __iter__(self) -> Iterator[T]: + return iter(set().union(*self.sets)) + + def __len__(self) -> int: + return len(set().union(*self.sets)) + + def add(self, value: T) -> None: + self.sets[0].add(value) + + def discard(self, value: T) -> None: + for s in self.sets: + s.discard(value) + + # Additional methods + + def new_child(self, s: set[T] | None = None) -> ChainSet[T]: + """New ChainSet with a new set followed by all previous sets.""" + if s is None: + s = set() + return self.__class__(s, *self.sets) + + @property + def parents(self) -> ChainSet[T]: + """New ChainSet from sets[1:].""" + return self.__class__(*self.sets[1:]) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({", ".join(map(repr, self.sets))})' + + class NdimSizeLenMixin: """Mixin class that extends a class that defines a ``shape`` property to one that also defines ``ndim``, ``size`` and ``__len__``. diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 50061c774a8..1ce554902e0 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -7,7 +7,12 @@ import pytest from xarray.core import duck_array_ops, utils -from xarray.core.utils import either_dict_or_kwargs, infix_dims, iterate_nested +from xarray.core.utils import ( + ChainSet, + either_dict_or_kwargs, + infix_dims, + iterate_nested, +) from xarray.tests import assert_array_equal, requires_dask @@ -356,3 +361,26 @@ def f(): return utils.find_stack_level(test_mode=True) assert f() == 3 + + +def test_chain_set(): + chain_set = ChainSet({1, 2}, {2, 3}) + assert chain_set.sets == [{1, 2}, {2, 3}] + assert 1 in chain_set + assert 3 in chain_set + assert 4 not in chain_set + assert set(chain_set) == {1, 2, 3} + assert len(chain_set) == 3 + chain_set.add(0) + assert chain_set == {0, 1, 2, 3} + assert chain_set.sets[0] == {0, 1, 2} + assert chain_set.sets[1] == {2, 3} # unchanged + chain_set.discard(0) + assert chain_set == {1, 2, 3} + empty_child = chain_set.new_child() + assert empty_child.sets == [set(), {1, 2}, {2, 3}] + filled_child = chain_set.new_child({0}) + assert filled_child.sets == [{0}, {1, 2}, {2, 3}] + child_parent = filled_child.parents + assert child_parent == chain_set + assert repr(chain_set) == "ChainSet({1, 2}, {2, 3})" From dcb103b4dc19d46f81eb915e6baaac61e378e14b Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 26 Jun 2024 20:08:16 -0700 Subject: [PATCH 12/21] Revise internal data model; remove ChainSet --- xarray/core/datatree.py | 218 +++++++++++++++++++++------------- xarray/core/treenode.py | 17 +-- xarray/core/utils.py | 44 ------- xarray/tests/test_datatree.py | 71 ++++++++++- xarray/tests/test_utils.py | 24 ---- 5 files changed, 211 insertions(+), 163 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 294fe1db0a4..017779137e1 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1,6 +1,8 @@ from __future__ import annotations import itertools +import textwrap +from collections import ChainMap from collections.abc import Hashable, Iterable, Iterator, Mapping, MutableMapping from html import escape from typing import ( @@ -79,6 +81,19 @@ T_Path = Union[str, NodePath] +def _collect_data_and_coord_variables( + data: Dataset, +) -> tuple[dict[Hashable, Variable], dict[Hashable, Variable]]: + data_variables = {} + coord_variables = {} + for k, v in data.variables.items(): + if k in data._coord_names: + coord_variables[k] = v + else: + data_variables[k] = v + return data_variables, coord_variables + + def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: if isinstance(data, DataArray): ds = data.to_dataset() @@ -93,6 +108,23 @@ def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: return ds +def _join_path(root: str, name: str) -> str: + return root.rstrip("/") + "/" + name + + +def _inherited_dataset(ds: Dataset, parent: Dataset) -> Dataset: + parent_coord_variables = {k: parent._variables[k] for k in parent._coord_names} + return Dataset._construct_direct( + variables=parent_coord_variables | ds._variables, + coord_names=parent._coord_names | ds._coord_names, + dims=parent._dims | ds._dims, + attrs=ds._attrs, + indexes=parent._indexes | ds._indexes, + encoding=ds._encoding, + close=ds._close, + ) + + def _check_alignment( path: str, node_ds: Dataset, @@ -103,13 +135,23 @@ def _check_alignment( try: align(node_ds, parent_ds, join="exact") except ValueError as e: + node_repr = textwrap.indent(repr(node_ds), prefix=" ") + parent_repr = textwrap.indent(repr(parent_ds), prefix=" ") raise ValueError( f"group {path!r} is not aligned with its parent:\n" - f"Group: {node_ds}\nvs\nParent: {parent_ds}" + f"Group:\n{node_repr}\nParent:\n{parent_repr}" ) from e - for child in children.values(): - _check_alignment(child.path, child.ds, node_ds, child.children) + if children: + if parent_ds is not None: + base_ds = _inherited_dataset(node_ds, parent_ds) + else: + base_ds = node_ds + + for child_name, child in children.items(): + child_path = _join_path(path, child_name) + child_ds = child.to_dataset(local=True) + _check_alignment(child_path, child_ds, base_ds, child.children) class DatasetView(Dataset): @@ -145,21 +187,25 @@ def __init__( raise AttributeError("DatasetView objects are not to be initialized directly") @classmethod - def _from_node( + def _from_dataset_state( cls, - wrapping_node: DataTree, + variables: dict[Any, Variable], + coord_names: set[Hashable], + dims: dict[Any, int], + attrs: dict | None, + indexes: dict[Any, Index], + encoding: dict | None, + close: Callable[[], None] | None, ) -> DatasetView: """Constructor, using dataset attributes from wrapping node""" - obj: DatasetView = object.__new__(cls) - obj._variables = wrapping_node._variables - obj._coord_names = wrapping_node._coord_names - obj._dims = wrapping_node._dims - obj._indexes = wrapping_node._indexes - obj._attrs = wrapping_node._attrs - obj._close = wrapping_node._close - obj._encoding = wrapping_node._encoding - + obj._variables = variables + obj._coord_names = coord_names + obj._dims = dims + obj._indexes = indexes + obj._attrs = attrs + obj._close = close + obj._encoding = encoding return obj def __setitem__(self, key, val) -> None: @@ -347,14 +393,10 @@ class DataTree( _parent: DataTree | None _children: dict[str, DataTree] _cache: dict[str, Any] # used by _CachedAccessor - _local_variables: dict[Hashable, Variable] - _local_coord_names: set[Hashable] - _local_dims: dict[Hashable, int] - _local_indexes: dict[Hashable, Index] - _variables: dict[Hashable, Variable] - _coord_names: set[Hashable] - _dims: dict[Hashable, int] - _indexes: dict[Hashable, Index] + _data_variables: dict[Hashable, Variable] + _coord_variables: ChainMap[Hashable, Variable] + _dims: ChainMap[Hashable, int] + _indexes: ChainMap[Hashable, Index] _attrs: dict[Hashable, Any] | None _encoding: dict[Hashable, Any] | None _close: Callable[[], None] | None @@ -364,12 +406,8 @@ class DataTree( "_parent", "_children", "_cache", # used by _CachedAccessor - "_local_variables", - "_local_coord_names", - "_local_dims", - "_local_indexes", - "_variables", - "_coord_names", + "_data_variables", + "_coord_variables", "_dims", "_indexes", "_attrs", @@ -419,51 +457,47 @@ def __init__( # set tree attributes self._children = {} self._parent = None + + # set data attributes + self._coord_variables: ChainMap[Hashable, Variable] = ChainMap() + self._dims = ChainMap() + self._indexes = ChainMap() self._set_node_data(_coerce_to_dataset(data)) # finalize tree attributes self.children = children # must set first self.parent = parent - def _set_node_data(self, ds: Dataset) -> None: - # local data attributes for to_dataset(local=True) - self._local_variables = ds._variables - self._local_coord_names = ds._coord_names - self._local_dims = ds._dims - self._local_indexes = ds._indexes - # these data attributes with inheritance are finalized by _post_attach - self._variables = dict(ds._variables) - self._coord_names = set(ds._coord_names) - self._dims = dict(ds._dims) - self._indexes = dict(ds._indexes) + def _set_node_data(self, ds: Dataset): + data_vars, coord_vars = _collect_data_and_coord_variables(ds) + self._data_variables = data_vars + self._coord_variables.maps[0] = coord_vars + self._dims.maps[0] = ds._dims + self._indexes.maps[0] = ds._indexes self._encoding = ds._encoding self._attrs = ds._attrs self._close = ds._close - def _pre_attach(self: DataTree, parent: DataTree) -> None: - super()._pre_attach(parent) - if self.name in parent.ds.variables: + def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: + super()._pre_attach(parent, name) + if name in parent.ds.variables: raise KeyError( - f"parent {parent.name} already contains a variable named {self.name}" + f"parent {parent.name} already contains a variable named {name}" ) - name = self.name if self.name is not None else "" - path = parent.path.rstrip("/") + "/" + name - _check_alignment(path, self.ds, parent.ds, self.children) - - def _post_attach_recursively(self: DataTree, parent: DataTree) -> None: - for k in parent._coord_names: - if k not in self._variables: - self._variables[k] = parent._variables[k] - self._coord_names.add(k) - self._dims.update(parent._dims) - self._indexes.update(parent._indexes) - + path = _join_path(parent.path, name) + node_ds = self.to_dataset(local=True) + _check_alignment(path, node_ds, parent.ds, self.children) + + def _add_parent_maps(self: DataTree, parent: DataTree) -> None: + self._coord_variables.maps.extend(parent._coord_variables.maps) + self._dims.maps.extend(parent._dims.maps) + self._indexes.maps.extend(parent._indexes.maps) for child in self._children.values(): - child._post_attach_recursively(self) + child._add_parent_maps(self) - def _post_attach(self: DataTree, parent: DataTree) -> None: - super()._post_attach(parent) - self._post_attach_recursively(parent) + def _post_attach(self: DataTree, parent: DataTree, name: str) -> None: + super()._post_attach(parent, name) + self._add_parent_maps(parent) @property def parent(self: DataTree) -> DataTree | None: @@ -481,13 +515,22 @@ def ds(self) -> DatasetView: """ An immutable Dataset-like view onto the data in this node. - For a mutable Dataset containing the same data as in this node, use `.to_dataset()` instead. + For a mutable Dataset containing the same data as in this node, use + `.to_dataset()` instead. See Also -------- DataTree.to_dataset """ - return DatasetView._from_node(self) + return DatasetView._from_dataset_state( + variables=self._data_variables | self._coord_variables, + coord_names=set(self._coord_variables), + dims=dict(self._dims), # always includes inherited dimensions + attrs=self._attrs, + indexes=dict(self._indexes), + encoding=self._encoding, + close=self._close, + ) @ds.setter def ds(self, data: Dataset | DataArray | None = None) -> None: @@ -501,27 +544,30 @@ def to_dataset(self, local: bool = False) -> Dataset: Parameters ---------- local : bool, optional - If True, only include coordinates, indexes and dimensions defined - at the level of this DataTree node, excluding inherited coordinates. + If True, only include coordinates and indexes defined at the level + of this DataTree node, excluding inherited coordinates. See Also -------- DataTree.ds """ + coord_vars = self._coord_variables.maps[0] if local else self._coord_variables + variables = self._data_variables | coord_vars + dims = dict(self._dims.maps[0]) if local else calculate_dimensions(variables) return Dataset._construct_direct( - dict(self._local_variables if local else self._variables), - set(self._local_coord_names if local else self._coord_names), - dict(self._local_dims if local else self._dims), + variables, + set(coord_vars), + dims, None if self._attrs is None else dict(self._attrs), - self._local_indexes if local else self._indexes, + dict(self._indexes.maps[0] if local else self._indexes), None if self._encoding is None else dict(self._encoding), self._close, ) @property - def has_data(self): - """Whether or not there are any data variables in this node.""" - return len(self._variables) > 0 + def has_data(self) -> bool: + """Whether or not there are any variables in this node.""" + return bool(self._data_variables or self._coord_variables.maps[0]) @property def has_attrs(self) -> bool: @@ -546,7 +592,7 @@ def variables(self) -> Mapping[Hashable, Variable]: Dataset invariants. It contains all variable objects constituting this DataTree node, including both data variables and coordinates. """ - return Frozen(self._variables) + return Frozen(self._data_variables | self._coord_variables) @property def attrs(self) -> dict[Hashable, Any]: @@ -607,7 +653,7 @@ def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]: def _item_sources(self) -> Iterable[Mapping[Any, Any]]: """Places to look-up items for key-completion""" yield self.data_vars - yield HybridMappingProxy(keys=self._coord_names, mapping=self.coords) + yield HybridMappingProxy(keys=self._coord_variables, mapping=self.coords) # virtual coordinates yield HybridMappingProxy(keys=self.dims, mapping=self) @@ -649,10 +695,10 @@ def __contains__(self, key: object) -> bool: return key in self.variables or key in self.children def __bool__(self) -> bool: - return bool(self.ds.data_vars) or bool(self.children) + return bool(self._data_variables) or bool(self._children) def __iter__(self) -> Iterator[Hashable]: - return itertools.chain(self.ds.data_vars, self.children) + return itertools.chain(self._data_variables, self._children) def __array__(self, dtype=None, copy=None): raise TypeError( @@ -679,26 +725,30 @@ def _replace_node( data: Dataset | Default = _default, children: dict[str, DataTree] | Default = _default, ) -> None: - if data is _default: - data = self.ds + + ds = self.to_dataset(local=True) if data is _default else data + if children is _default: children = self._children - for child_name, child in children.items(): - if child_name in data.variables: + for child_name in children: + if child_name in ds.variables: raise ValueError(f"node already contains a variable named {child_name}") parent_ds = self.parent.ds if self.parent is not None else None - _check_alignment(self.path, data, parent_ds, children) + _check_alignment(self.path, ds, parent_ds, children) + + if data is not _default: + self._set_node_data(ds) self._children = children - self._set_node_data(data) if self.parent is not None: - self._post_attach(self.parent) + assert self.name is not None + self._post_attach(self.parent, self.name) else: for child in children.values(): - child._post_attach_recursively(self) + child._add_parent_maps(self) def copy( self: DataTree, @@ -1076,7 +1126,9 @@ def indexes(self) -> Indexes[pd.Index]: @property def xindexes(self) -> Indexes[Index]: """Mapping of xarray Index objects used for label based indexing.""" - return Indexes(self._indexes, {k: self._variables[k] for k in self._indexes}) + return Indexes( + self._indexes, {k: self._coord_variables[k] for k in self._indexes} + ) @property def coords(self) -> DatasetCoordinates: diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index f3935c4d2f8..3017b59d978 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -138,14 +138,14 @@ def _attach(self, parent: Tree | None, child_name: str | None = None) -> None: "To directly set parent, child needs a name, but child is unnamed" ) - self._pre_attach(parent) + self._pre_attach(parent, child_name) parentchildren = parent._children assert not any( child is self for child in parentchildren ), "Tree is corrupt." parentchildren[child_name] = self self._parent = parent - self._post_attach(parent) + self._post_attach(parent, child_name) else: self._parent = None @@ -415,11 +415,11 @@ def _post_detach(self: Tree, parent: Tree) -> None: """Method call after detaching from `parent`.""" pass - def _pre_attach(self: Tree, parent: Tree) -> None: + def _pre_attach(self: Tree, parent: Tree, name: str) -> None: """Method call before attaching to `parent`.""" pass - def _post_attach(self: Tree, parent: Tree) -> None: + def _post_attach(self: Tree, parent: Tree, name: str) -> None: """Method call after attaching to `parent`.""" pass @@ -609,10 +609,13 @@ def __repr__(self, level=0): def __str__(self) -> str: return f"NamedNode('{self.name}')" if self.name else "NamedNode()" - def _post_attach(self: AnyNamedNode, parent: AnyNamedNode) -> None: + def _get_name_in_parent(self: AnyNamedNode, parent: AnyNamedNode) -> str: + return next(k for k, v in parent.children.items() if v is self) + + def _post_attach(self: AnyNamedNode, parent: AnyNamedNode, name: str) -> None: """Ensures child has name attribute corresponding to key under which it has been stored.""" - key = next(k for k, v in parent.children.items() if v is self) - self.name = key + # self.name = self._get_name_in_parent(parent + self.name = name @property def path(self) -> str: diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 7d7c3c91e80..5cb52cbd25c 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -563,50 +563,6 @@ def __repr__(self) -> str: return f"{type(self).__name__}({list(self)!r})" -class ChainSet(MutableSet[T]): - """A chained set, based on the design of collections.ChainMap.""" - - sets: list[set[T]] - __slots__ = ("sets",) - - def __init__(self, *sets): - self.sets = list(sets) or [set()] - - # Required methods for MutableSet - - def __contains__(self, value: Hashable) -> bool: - return any(value in s for s in self.sets) - - def __iter__(self) -> Iterator[T]: - return iter(set().union(*self.sets)) - - def __len__(self) -> int: - return len(set().union(*self.sets)) - - def add(self, value: T) -> None: - self.sets[0].add(value) - - def discard(self, value: T) -> None: - for s in self.sets: - s.discard(value) - - # Additional methods - - def new_child(self, s: set[T] | None = None) -> ChainSet[T]: - """New ChainSet with a new set followed by all previous sets.""" - if s is None: - s = set() - return self.__class__(s, *self.sets) - - @property - def parents(self) -> ChainSet[T]: - """New ChainSet from sets[1:].""" - return self.__class__(*self.sets[1:]) - - def __repr__(self) -> str: - return f'{self.__class__.__name__}({", ".join(map(repr, self.sets))})' - - class NdimSizeLenMixin: """Mixin class that extends a class that defines a ``shape`` property to one that also defines ``ndim``, ``size`` and ``__len__``. diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 946c4b4bd07..b8db8272a33 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -577,7 +577,9 @@ def test_methods(self): def test_arithmetic(self, create_test_datatree): dt = create_test_datatree() - expected = create_test_datatree(modify=lambda ds: 10.0 * ds)["set1"] + expected = create_test_datatree(modify=lambda ds: 10.0 * ds)[ + "set1" + ].to_dataset() result = 10.0 * dt["set1"].ds assert result.identical(expected) @@ -651,6 +653,10 @@ def test_inherited_dims(self): assert dt.sizes == {"x": 2} assert dt.b.sizes == {"x": 2, "y": 1} assert dt.c.sizes == {"x": 2, "y": 3} + # .ds should also include inherit dims (used for alignment checking) + assert dt.b.ds.sizes == {"x": 2, "y": 1} + # .to_dataset() should not + assert dt.b.to_dataset().sizes == {"y": 1} def test_inherited_coords_index(self): dt = DataTree.from_dict( @@ -694,21 +700,76 @@ def test_inconsistent_dims(self): dt: DataTree = DataTree() dt["/a"] = xr.DataArray([1, 2], dims=["x"]) with pytest.raises( - ValueError, match="cannot reindex or align along dimension 'x'" + ValueError, match="group '/b' is not aligned with its parent" ): dt["/b/c"] = xr.DataArray([3], dims=["x"]) - def test_inconsistent_indexes(self): + b: DataTree = DataTree(data=xr.Dataset({"c": (("x",), [3])})) + with pytest.raises( + ValueError, match="group '/b' is not aligned with its parent" + ): + DataTree( + data=xr.Dataset({"a": (("x",), [1, 2])}), + children={"b": b}, + ) + + def test_inconsistent_child_indexes(self): + with pytest.raises( + ValueError, match="group '/b' is not aligned with its parent" + ): + DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": [1]}), + "/b": xr.Dataset(coords={"x": [2]}), + } + ) + + # TODO: figure out how to set coordinates only on a node via mutation + + b: DataTree = DataTree(xr.Dataset(coords={"x": [2]})) with pytest.raises( ValueError, match="group '/b' is not aligned with its parent" + ): + DataTree(data=xr.Dataset(coords={"x": [1]}), children={"b": b}) + + def test_inconsistent_grandchild_indexes(self): + with pytest.raises( + ValueError, match="group '/b/c' is not aligned with its parent" ): DataTree.from_dict( { - "/": xr.Dataset({"a": (("x",), [1])}, coords={"x": [1]}), - "/b": xr.Dataset({"c": (("x",), [2])}, coords={"x": [2]}), + "/": xr.Dataset(coords={"x": [1]}), + "/b/c": xr.Dataset(coords={"x": [2]}), } ) + # TODO: figure out how to set coordinates only on a node via mutation + + c: DataTree = DataTree(xr.Dataset(coords={"x": [2]})) + b: DataTree = DataTree(children={"c": c}) + with pytest.raises( + ValueError, match="group '/b/c' is not aligned with its parent" + ): + DataTree(data=xr.Dataset(coords={"x": [1]}), children={"b": b}) + + def test_inconsistent_grandchild_dims(self): + with pytest.raises( + ValueError, match="group '/b/c' is not aligned with its parent" + ): + DataTree.from_dict( + { + "/": xr.Dataset({"a": (("x",), [1, 2])}), + "/b/c": xr.Dataset({"d": (("x",), [3])}), + } + ) + + dt: DataTree = DataTree() + dt["/a"] = xr.DataArray([1, 2], dims=["x"]) + with pytest.raises( + ValueError, match="group '/b/c' is not aligned with its parent" + ): + dt["/b/c/d"] = xr.DataArray([3], dims=["x"]) + class TestRestructuring: def test_drop_nodes(self): diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 1ce554902e0..ecec3ca507b 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -8,7 +8,6 @@ from xarray.core import duck_array_ops, utils from xarray.core.utils import ( - ChainSet, either_dict_or_kwargs, infix_dims, iterate_nested, @@ -361,26 +360,3 @@ def f(): return utils.find_stack_level(test_mode=True) assert f() == 3 - - -def test_chain_set(): - chain_set = ChainSet({1, 2}, {2, 3}) - assert chain_set.sets == [{1, 2}, {2, 3}] - assert 1 in chain_set - assert 3 in chain_set - assert 4 not in chain_set - assert set(chain_set) == {1, 2, 3} - assert len(chain_set) == 3 - chain_set.add(0) - assert chain_set == {0, 1, 2, 3} - assert chain_set.sets[0] == {0, 1, 2} - assert chain_set.sets[1] == {2, 3} # unchanged - chain_set.discard(0) - assert chain_set == {1, 2, 3} - empty_child = chain_set.new_child() - assert empty_child.sets == [set(), {1, 2}, {2, 3}] - filled_child = chain_set.new_child({0}) - assert filled_child.sets == [{0}, {1, 2}, {2, 3}] - child_parent = filled_child.parents - assert child_parent == chain_set - assert repr(chain_set) == "ChainSet({1, 2}, {2, 3})" From 34c1fb2aee0834f6e04d58104308d26ccb9e3028 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 27 Jun 2024 16:57:48 -0700 Subject: [PATCH 13/21] add another way to construct inherited indexes --- xarray/tests/test_datatree.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index d329fd0dd4e..90c65d82c8c 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -791,7 +791,13 @@ def test_inconsistent_child_indexes(self): } ) - # TODO: figure out how to set coordinates only on a node via mutation + dt: DataTree = DataTree() + dt.ds = xr.Dataset(coords={"x": [1]}) # type: ignore + dt["/b"] = DataTree() + with pytest.raises( + ValueError, match="group '/b' is not aligned with its parent" + ): + dt["/b"].ds = xr.Dataset(coords={"x": [2]}) b: DataTree = DataTree(xr.Dataset(coords={"x": [2]})) with pytest.raises( @@ -810,7 +816,13 @@ def test_inconsistent_grandchild_indexes(self): } ) - # TODO: figure out how to set coordinates only on a node via mutation + dt: DataTree = DataTree() + dt.ds = xr.Dataset(coords={"x": [1]}) # type: ignore + dt["/b/c"] = DataTree() + with pytest.raises( + ValueError, match="group '/b/c' is not aligned with its parent" + ): + dt["/b/c"].ds = xr.Dataset(coords={"x": [2]}) c: DataTree = DataTree(xr.Dataset(coords={"x": [2]})) b: DataTree = DataTree(children={"c": c}) From 7767634762e89a6cb8c45ef4da04120ac3c0f60f Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 27 Jun 2024 18:03:10 -0700 Subject: [PATCH 14/21] Finish refactoring error message --- xarray/core/datatree.py | 92 ++++++++++++---- xarray/core/datatree_io.py | 4 +- xarray/core/formatting.py | 2 +- xarray/tests/test_backends_datatree.py | 4 +- xarray/tests/test_datatree.py | 146 ++++++++++++++++++------- 5 files changed, 178 insertions(+), 70 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 82961c2d933..22719694fd7 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -113,9 +113,8 @@ def _join_path(root: str, name: str) -> str: def _inherited_dataset(ds: Dataset, parent: Dataset) -> Dataset: - parent_coord_variables = {k: parent._variables[k] for k in parent._coord_names} return Dataset._construct_direct( - variables=parent_coord_variables | ds._variables, + variables=parent._variables | ds._variables, coord_names=parent._coord_names | ds._coord_names, dims=parent._dims | ds._dims, attrs=ds._attrs, @@ -125,6 +124,21 @@ def _inherited_dataset(ds: Dataset, parent: Dataset) -> Dataset: ) +def _indented_without_header(text: str) -> str: + return textwrap.indent("\n".join(text.split("\n")[1:]), prefix=" ") + + +def _drop_data_vars_and_attrs_sections(text: str) -> str: + lines = text.split("\n") + outputs = [] + match = "Data variables:" + for line in lines: + if line[: len(match)] == match: + break + outputs.append(line) + return "\n".join(outputs) + + def _check_alignment( path: str, node_ds: Dataset, @@ -135,8 +149,10 @@ def _check_alignment( try: align(node_ds, parent_ds, join="exact") except ValueError as e: - node_repr = textwrap.indent(repr(node_ds), prefix=" ") - parent_repr = textwrap.indent(repr(parent_ds), prefix=" ") + node_repr = _indented_without_header(repr(node_ds)) + parent_repr = _indented_without_header( + _drop_data_vars_and_attrs_sections(repr(parent_ds)) + ) raise ValueError( f"group {path!r} is not aligned with its parent:\n" f"Group:\n{node_repr}\nParent:\n{parent_repr}" @@ -150,7 +166,7 @@ def _check_alignment( for child_name, child in children.items(): child_path = _join_path(path, child_name) - child_ds = child.to_dataset(local=True) + child_ds = child.to_dataset(inherited=False) _check_alignment(child_path, child_ds, base_ds, child.children) @@ -485,8 +501,9 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: f"parent {parent.name} already contains a variable named {name}" ) path = _join_path(parent.path, name) - node_ds = self.to_dataset(local=True) - _check_alignment(path, node_ds, parent.ds, self.children) + node_ds = self.to_dataset(inherited=False) + parent_ds = parent._to_dataset_view(rebuild_dims=False) + _check_alignment(path, node_ds, parent_ds, self.children) def _add_parent_maps(self: DataTree, parent: DataTree) -> None: self._coord_variables.maps.extend(parent._coord_variables.maps) @@ -510,6 +527,33 @@ def parent(self: DataTree, new_parent: DataTree) -> None: raise ValueError("Cannot set an unnamed node as a child of another node") self._set_parent(new_parent, self.name) + def _to_dataset_view(self, rebuild_dims: bool) -> DatasetView: + variables = self._data_variables | self._coord_variables + if rebuild_dims: + dims = calculate_dimensions(variables) + else: + # Note: rebuild_dims=False can create technically invalid Dataset + # objects because it may not contain all dimensions on its direct + # member variables, e.g., consider: + # tree = DataTree.from_dict( + # { + # "/": xr.Dataset({"a": (("x",), [1, 2])}), # x has size 2 + # "/b/c": xr.Dataset({"d": (("x",), [3])}), # x has size1 + # } + # ) + # However, they are fine for internal use cases, for align() or + # building a repr(). + dims = dict(self._dims) + return DatasetView._from_dataset_state( + variables=variables, + coord_names=set(self._coord_variables), + dims=dims, + attrs=self._attrs, + indexes=dict(self._indexes), + encoding=self._encoding, + close=None, + ) + @property def ds(self) -> DatasetView: """ @@ -522,44 +566,40 @@ def ds(self) -> DatasetView: -------- DataTree.to_dataset """ - return DatasetView._from_dataset_state( - variables=self._data_variables | self._coord_variables, - coord_names=set(self._coord_variables), - dims=dict(self._dims), # always includes inherited dimensions - attrs=self._attrs, - indexes=dict(self._indexes), - encoding=self._encoding, - close=self._close, - ) + return self._to_dataset_view(rebuild_dims=True) @ds.setter def ds(self, data: Dataset | DataArray | None = None) -> None: ds = _coerce_to_dataset(data) self._replace_node(ds) - def to_dataset(self, local: bool = False) -> Dataset: + def to_dataset(self, inherited: bool = True) -> Dataset: """ Return the data in this node as a new xarray.Dataset object. Parameters ---------- - local : bool, optional - If True, only include coordinates and indexes defined at the level + inherited : bool, optional + If False, only include coordinates and indexes defined at the level of this DataTree node, excluding inherited coordinates. See Also -------- DataTree.ds """ - coord_vars = self._coord_variables.maps[0] if local else self._coord_variables + coord_vars = ( + self._coord_variables if inherited else self._coord_variables.maps[0] + ) variables = self._data_variables | coord_vars - dims = dict(self._dims.maps[0]) if local else calculate_dimensions(variables) + dims = ( + calculate_dimensions(variables) if inherited else dict(self._dims.maps[0]) + ) return Dataset._construct_direct( variables, set(coord_vars), dims, None if self._attrs is None else dict(self._attrs), - dict(self._indexes.maps[0] if local else self._indexes), + dict(self._indexes if inherited else self._indexes.maps[0]), None if self._encoding is None else dict(self._encoding), self._close, ) @@ -726,7 +766,7 @@ def _replace_node( children: dict[str, DataTree] | Default = _default, ) -> None: - ds = self.to_dataset(local=True) if data is _default else data + ds = self.to_dataset(inherited=False) if data is _default else data if children is _default: children = self._children @@ -735,7 +775,11 @@ def _replace_node( if child_name in ds.variables: raise ValueError(f"node already contains a variable named {child_name}") - parent_ds = self.parent.ds if self.parent is not None else None + parent_ds = ( + self.parent._to_dataset_view(rebuild_dims=False) + if self.parent is not None + else None + ) _check_alignment(self.path, ds, parent_ds, children) if data is not _default: diff --git a/xarray/core/datatree_io.py b/xarray/core/datatree_io.py index dbc80627cdd..36665a0d153 100644 --- a/xarray/core/datatree_io.py +++ b/xarray/core/datatree_io.py @@ -85,7 +85,7 @@ def _datatree_to_netcdf( unlimited_dims = {} for node in dt.subtree: - ds = node.to_dataset(local=True) + ds = node.to_dataset(inherited=False) group_path = node.path if ds is None: _create_empty_netcdf_group(filepath, group_path, mode, engine) @@ -151,7 +151,7 @@ def _datatree_to_zarr( ) for node in dt.subtree: - ds = node.to_dataset(local=True) + ds = node.to_dataset(inherited=False) group_path = node.path if ds is None: _create_empty_zarr_group(store, group_path, mode) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index c15df34b5b1..0a1ee49ca26 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -1024,7 +1024,7 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat): def _single_node_repr(node: DataTree) -> str: """Information about this node, not including its relationships to other nodes.""" if node.has_data or node.has_attrs: - ds_info = "\n" + repr(node.ds) + ds_info = "\n" + repr(node._to_dataset_view(rebuild_dims=False)) else: ds_info = "" return f"Group: {node.path}{ds_info}" diff --git a/xarray/tests/test_backends_datatree.py b/xarray/tests/test_backends_datatree.py index b8eb22312e2..b4c4f481359 100644 --- a/xarray/tests/test_backends_datatree.py +++ b/xarray/tests/test_backends_datatree.py @@ -42,7 +42,7 @@ def test_to_netcdf_inherited_coords(self, tmpdir): roundtrip_dt = open_datatree(filepath, engine=self.engine) assert_equal(original_dt, roundtrip_dt) subtree = cast(DataTree, roundtrip_dt["/sub"]) - assert "x" not in subtree.to_dataset(local=True).coords + assert "x" not in subtree.to_dataset(inherited=False).coords def test_netcdf_encoding(self, tmpdir, simple_datatree): filepath = tmpdir / "test.nc" @@ -150,4 +150,4 @@ def test_to_zarr_inherited_coords(self, tmpdir): roundtrip_dt = open_datatree(filepath, engine="zarr") assert_equal(original_dt, roundtrip_dt) subtree = cast(DataTree, roundtrip_dt["/sub"]) - assert "x" not in subtree.to_dataset(local=True).coords + assert "x" not in subtree.to_dataset(inherited=False).coords diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 90c65d82c8c..766c9aac251 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1,3 +1,4 @@ +import re import typing from copy import copy, deepcopy from textwrap import dedent @@ -157,12 +158,12 @@ def test_to_dataset(self): tree = DataTree.from_dict({"/": base, "/sub": sub}) subtree = typing.cast(DataTree, tree["sub"]) - assert_identical(tree.to_dataset(local=True), base) - assert_identical(subtree.to_dataset(local=True), sub) + assert_identical(tree.to_dataset(inherited=False), base) + assert_identical(subtree.to_dataset(inherited=False), sub) sub_and_base = xr.Dataset(coords={"a": 1, "b": 2}) - assert_identical(tree.to_dataset(local=False), base) - assert_identical(subtree.to_dataset(local=False), sub_and_base) + assert_identical(tree.to_dataset(inherited=True), base) + assert_identical(subtree.to_dataset(inherited=True), sub_and_base) class TestVariablesChildrenNameCollisions: @@ -718,12 +719,13 @@ def test_inherited_dims(self): } ) assert dt.sizes == {"x": 2} + # nodes should include inherited dimensions assert dt.b.sizes == {"x": 2, "y": 1} assert dt.c.sizes == {"x": 2, "y": 3} - # .ds should also include inherit dims (used for alignment checking) - assert dt.b.ds.sizes == {"x": 2, "y": 1} - # .to_dataset() should not - assert dt.b.to_dataset().sizes == {"y": 1} + # dataset objects created from nodes should not + assert dt.b.ds.sizes == {"y": 1} + assert dt.b.to_dataset(inherited=True).sizes == {"y": 1} + assert dt.b.to_dataset(inherited=False).sizes == {"y": 1} def test_inherited_coords_index(self): dt = DataTree.from_dict( @@ -754,9 +756,27 @@ def test_inherited_coords_override(self): xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords)) def test_inconsistent_dims(self): - with pytest.raises( - ValueError, match="group '/b' is not aligned with its parent" - ): + expected_msg = ( + "^" + + re.escape( + dedent( + """ + group '/b' is not aligned with its parent: + Group: + Dimensions: (x: 1) + Dimensions without coordinates: x + Data variables: + c (x) int64 8B 3 + Parent: + Dimensions: (x: 2) + Dimensions without coordinates: x + """ + ).strip() + ) + + "$" + ) + + with pytest.raises(ValueError, match=expected_msg): DataTree.from_dict( { "/": xr.Dataset({"a": (("x",), [1, 2])}), @@ -766,24 +786,40 @@ def test_inconsistent_dims(self): dt: DataTree = DataTree() dt["/a"] = xr.DataArray([1, 2], dims=["x"]) - with pytest.raises( - ValueError, match="group '/b' is not aligned with its parent" - ): + with pytest.raises(ValueError, match=expected_msg): dt["/b/c"] = xr.DataArray([3], dims=["x"]) b: DataTree = DataTree(data=xr.Dataset({"c": (("x",), [3])})) - with pytest.raises( - ValueError, match="group '/b' is not aligned with its parent" - ): + with pytest.raises(ValueError, match=expected_msg): DataTree( data=xr.Dataset({"a": (("x",), [1, 2])}), children={"b": b}, ) def test_inconsistent_child_indexes(self): - with pytest.raises( - ValueError, match="group '/b' is not aligned with its parent" - ): + expected_msg = ( + "^" + + re.escape( + dedent( + """ + group '/b' is not aligned with its parent: + Group: + Dimensions: (x: 1) + Coordinates: + * x (x) int64 8B 2 + Data variables: + *empty* + Parent: + Dimensions: (x: 1) + Coordinates: + * x (x) int64 8B 1 + """ + ).strip() + ) + + "$" + ) + + with pytest.raises(ValueError, match=expected_msg): DataTree.from_dict( { "/": xr.Dataset(coords={"x": [1]}), @@ -794,21 +830,37 @@ def test_inconsistent_child_indexes(self): dt: DataTree = DataTree() dt.ds = xr.Dataset(coords={"x": [1]}) # type: ignore dt["/b"] = DataTree() - with pytest.raises( - ValueError, match="group '/b' is not aligned with its parent" - ): + with pytest.raises(ValueError, match=expected_msg): dt["/b"].ds = xr.Dataset(coords={"x": [2]}) b: DataTree = DataTree(xr.Dataset(coords={"x": [2]})) - with pytest.raises( - ValueError, match="group '/b' is not aligned with its parent" - ): + with pytest.raises(ValueError, match=expected_msg): DataTree(data=xr.Dataset(coords={"x": [1]}), children={"b": b}) def test_inconsistent_grandchild_indexes(self): - with pytest.raises( - ValueError, match="group '/b/c' is not aligned with its parent" - ): + expected_msg = ( + "^" + + re.escape( + dedent( + """ + group '/b/c' is not aligned with its parent: + Group: + Dimensions: (x: 1) + Coordinates: + * x (x) int64 8B 2 + Data variables: + *empty* + Parent: + Dimensions: (x: 1) + Coordinates: + * x (x) int64 8B 1 + """ + ).strip() + ) + + "$" + ) + + with pytest.raises(ValueError, match=expected_msg): DataTree.from_dict( { "/": xr.Dataset(coords={"x": [1]}), @@ -819,22 +871,36 @@ def test_inconsistent_grandchild_indexes(self): dt: DataTree = DataTree() dt.ds = xr.Dataset(coords={"x": [1]}) # type: ignore dt["/b/c"] = DataTree() - with pytest.raises( - ValueError, match="group '/b/c' is not aligned with its parent" - ): + with pytest.raises(ValueError, match=expected_msg): dt["/b/c"].ds = xr.Dataset(coords={"x": [2]}) c: DataTree = DataTree(xr.Dataset(coords={"x": [2]})) b: DataTree = DataTree(children={"c": c}) - with pytest.raises( - ValueError, match="group '/b/c' is not aligned with its parent" - ): + with pytest.raises(ValueError, match=expected_msg): DataTree(data=xr.Dataset(coords={"x": [1]}), children={"b": b}) def test_inconsistent_grandchild_dims(self): - with pytest.raises( - ValueError, match="group '/b/c' is not aligned with its parent" - ): + expected_msg = ( + "^" + + re.escape( + dedent( + """ + group '/b/c' is not aligned with its parent: + Group: + Dimensions: (x: 1) + Dimensions without coordinates: x + Data variables: + d (x) int64 8B 3 + Parent: + Dimensions: (x: 2) + Dimensions without coordinates: x + """ + ).strip() + ) + + "$" + ) + + with pytest.raises(ValueError, match=expected_msg): DataTree.from_dict( { "/": xr.Dataset({"a": (("x",), [1, 2])}), @@ -844,9 +910,7 @@ def test_inconsistent_grandchild_dims(self): dt: DataTree = DataTree() dt["/a"] = xr.DataArray([1, 2], dims=["x"]) - with pytest.raises( - ValueError, match="group '/b/c' is not aligned with its parent" - ): + with pytest.raises(ValueError, match=expected_msg): dt["/b/c/d"] = xr.DataArray([3], dims=["x"]) From c282e62346eb4873268d514c4976e937d1a5b6d9 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 27 Jun 2024 18:07:33 -0700 Subject: [PATCH 15/21] include inherited dimensions in HTML repr, too --- xarray/core/formatting_html.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index 9bf5befbe3f..24b290031eb 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -386,7 +386,7 @@ def summarize_datatree_children(children: Mapping[str, DataTree]) -> str: def datatree_node_repr(group_title: str, dt: DataTree) -> str: header_components = [f"
{escape(group_title)}
"] - ds = dt.ds + ds = dt._to_dataset_view(rebuild_dims=False) sections = [ children_section(dt.children), From 6595e2fa1047b5ac41498549783f5e33f33c9d07 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 30 Jun 2024 15:45:30 -0700 Subject: [PATCH 16/21] Construct ChainMap objects on demand. --- xarray/core/datatree.py | 61 +++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 36 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 22719694fd7..59d126d4ebd 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -410,9 +410,9 @@ class DataTree( _children: dict[str, DataTree] _cache: dict[str, Any] # used by _CachedAccessor _data_variables: dict[Hashable, Variable] - _coord_variables: ChainMap[Hashable, Variable] - _dims: ChainMap[Hashable, int] - _indexes: ChainMap[Hashable, Index] + _node_coord_variables: dict[Hashable, Variable] + _node_dims: dict[Hashable, int] + _node_indexes: dict[Hashable, Index] _attrs: dict[Hashable, Any] | None _encoding: dict[Hashable, Any] | None _close: Callable[[], None] | None @@ -423,9 +423,9 @@ class DataTree( "_children", "_cache", # used by _CachedAccessor "_data_variables", - "_coord_variables", - "_dims", - "_indexes", + "_node_coord_variables", + "_node_dims", + "_node_indexes", "_attrs", "_encoding", "_close", @@ -475,9 +475,6 @@ def __init__( self._parent = None # set data attributes - self._coord_variables: ChainMap[Hashable, Variable] = ChainMap() - self._dims = ChainMap() - self._indexes = ChainMap() self._set_node_data(_coerce_to_dataset(data)) # finalize tree attributes @@ -487,9 +484,9 @@ def __init__( def _set_node_data(self, ds: Dataset): data_vars, coord_vars = _collect_data_and_coord_variables(ds) self._data_variables = data_vars - self._coord_variables.maps[0] = coord_vars - self._dims.maps[0] = ds._dims - self._indexes.maps[0] = ds._indexes + self._node_coord_variables = coord_vars + self._node_dims = ds._dims + self._node_indexes = ds._indexes self._encoding = ds._encoding self._attrs = ds._attrs self._close = ds._close @@ -505,16 +502,19 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: parent_ds = parent._to_dataset_view(rebuild_dims=False) _check_alignment(path, node_ds, parent_ds, self.children) - def _add_parent_maps(self: DataTree, parent: DataTree) -> None: - self._coord_variables.maps.extend(parent._coord_variables.maps) - self._dims.maps.extend(parent._dims.maps) - self._indexes.maps.extend(parent._indexes.maps) - for child in self._children.values(): - child._add_parent_maps(self) + @property + def _coord_variables(self) -> ChainMap[Hashable, Variable]: + return ChainMap( + self._node_coord_variables, *(p._node_coord_variables for p in self.parents) + ) - def _post_attach(self: DataTree, parent: DataTree, name: str) -> None: - super()._post_attach(parent, name) - self._add_parent_maps(parent) + @property + def _dims(self) -> ChainMap[Hashable, int]: + return ChainMap(self._node_dims, *(p._node_dims for p in self.parents)) + + @property + def _indexes(self) -> ChainMap[Hashable, Index]: + return ChainMap(self._node_indexes, *(p._node_indexes for p in self.parents)) @property def parent(self: DataTree) -> DataTree | None: @@ -587,19 +587,15 @@ def to_dataset(self, inherited: bool = True) -> Dataset: -------- DataTree.ds """ - coord_vars = ( - self._coord_variables if inherited else self._coord_variables.maps[0] - ) + coord_vars = self._coord_variables if inherited else self._node_coord_variables variables = self._data_variables | coord_vars - dims = ( - calculate_dimensions(variables) if inherited else dict(self._dims.maps[0]) - ) + dims = calculate_dimensions(variables) if inherited else dict(self._node_dims) return Dataset._construct_direct( variables, set(coord_vars), dims, None if self._attrs is None else dict(self._attrs), - dict(self._indexes if inherited else self._indexes.maps[0]), + dict(self._indexes if inherited else self._node_indexes), None if self._encoding is None else dict(self._encoding), self._close, ) @@ -607,7 +603,7 @@ def to_dataset(self, inherited: bool = True) -> Dataset: @property def has_data(self) -> bool: """Whether or not there are any variables in this node.""" - return bool(self._data_variables or self._coord_variables.maps[0]) + return bool(self._data_variables or self._node_coord_variables) @property def has_attrs(self) -> bool: @@ -787,13 +783,6 @@ def _replace_node( self._children = children - if self.parent is not None: - assert self.name is not None - self._post_attach(self.parent, self.name) - else: - for child in children.values(): - child._add_parent_maps(self) - def copy( self: DataTree, deep: bool = False, From 88e5f4a0f0ebcbcbf88f493e5daf999b4271546b Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 30 Jun 2024 16:01:34 -0700 Subject: [PATCH 17/21] slightly better error message with mis-aligned data trees --- xarray/core/datatree.py | 4 +- xarray/tests/test_datatree.py | 133 +++++++++++++++------------------- 2 files changed, 59 insertions(+), 78 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 59d126d4ebd..b541d8d1e4e 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -154,8 +154,8 @@ def _check_alignment( _drop_data_vars_and_attrs_sections(repr(parent_ds)) ) raise ValueError( - f"group {path!r} is not aligned with its parent:\n" - f"Group:\n{node_repr}\nParent:\n{parent_repr}" + f"group {path!r} is not aligned with its parents:\n" + f"Group:\n{node_repr}\nFrom parents:\n{parent_repr}" ) from e if children: diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 766c9aac251..c7825f485de 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -709,6 +709,11 @@ def test_repr(self): assert result == expected +def _exact_match(message: str) -> str: + return re.escape(dedent(message).strip()) + return "^" + re.escape(dedent(message.rstrip())) + "$" + + class TestInheritance: def test_inherited_dims(self): dt = DataTree.from_dict( @@ -756,24 +761,18 @@ def test_inherited_coords_override(self): xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords)) def test_inconsistent_dims(self): - expected_msg = ( - "^" - + re.escape( - dedent( - """ - group '/b' is not aligned with its parent: - Group: - Dimensions: (x: 1) - Dimensions without coordinates: x - Data variables: - c (x) int64 8B 3 - Parent: - Dimensions: (x: 2) - Dimensions without coordinates: x - """ - ).strip() - ) - + "$" + expected_msg = _exact_match( + """ + group '/b' is not aligned with its parents: + Group: + Dimensions: (x: 1) + Dimensions without coordinates: x + Data variables: + c (x) int64 8B 3 + From parents: + Dimensions: (x: 2) + Dimensions without coordinates: x + """ ) with pytest.raises(ValueError, match=expected_msg): @@ -797,26 +796,20 @@ def test_inconsistent_dims(self): ) def test_inconsistent_child_indexes(self): - expected_msg = ( - "^" - + re.escape( - dedent( - """ - group '/b' is not aligned with its parent: - Group: - Dimensions: (x: 1) - Coordinates: - * x (x) int64 8B 2 - Data variables: - *empty* - Parent: - Dimensions: (x: 1) - Coordinates: - * x (x) int64 8B 1 - """ - ).strip() - ) - + "$" + expected_msg = _exact_match( + """ + group '/b' is not aligned with its parents: + Group: + Dimensions: (x: 1) + Coordinates: + * x (x) int64 8B 2 + Data variables: + *empty* + From parents: + Dimensions: (x: 1) + Coordinates: + * x (x) int64 8B 1 + """ ) with pytest.raises(ValueError, match=expected_msg): @@ -838,26 +831,20 @@ def test_inconsistent_child_indexes(self): DataTree(data=xr.Dataset(coords={"x": [1]}), children={"b": b}) def test_inconsistent_grandchild_indexes(self): - expected_msg = ( - "^" - + re.escape( - dedent( - """ - group '/b/c' is not aligned with its parent: - Group: - Dimensions: (x: 1) - Coordinates: - * x (x) int64 8B 2 - Data variables: - *empty* - Parent: - Dimensions: (x: 1) - Coordinates: - * x (x) int64 8B 1 - """ - ).strip() - ) - + "$" + expected_msg = _exact_match( + """ + group '/b/c' is not aligned with its parents: + Group: + Dimensions: (x: 1) + Coordinates: + * x (x) int64 8B 2 + Data variables: + *empty* + From parents: + Dimensions: (x: 1) + Coordinates: + * x (x) int64 8B 1 + """ ) with pytest.raises(ValueError, match=expected_msg): @@ -880,24 +867,18 @@ def test_inconsistent_grandchild_indexes(self): DataTree(data=xr.Dataset(coords={"x": [1]}), children={"b": b}) def test_inconsistent_grandchild_dims(self): - expected_msg = ( - "^" - + re.escape( - dedent( - """ - group '/b/c' is not aligned with its parent: - Group: - Dimensions: (x: 1) - Dimensions without coordinates: x - Data variables: - d (x) int64 8B 3 - Parent: - Dimensions: (x: 2) - Dimensions without coordinates: x - """ - ).strip() - ) - + "$" + expected_msg = _exact_match( + """ + group '/b/c' is not aligned with its parents: + Group: + Dimensions: (x: 1) + Dimensions without coordinates: x + Data variables: + d (x) int64 8B 3 + From parents: + Dimensions: (x: 2) + Dimensions without coordinates: x + """ ) with pytest.raises(ValueError, match=expected_msg): From 1f6c7b4dd7e4c780650120a9b969470dff3d8dc8 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 1 Jul 2024 08:55:11 -0700 Subject: [PATCH 18/21] mypy fix --- xarray/core/datatree.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index b541d8d1e4e..41bbb7158fa 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -528,7 +528,8 @@ def parent(self: DataTree, new_parent: DataTree) -> None: self._set_parent(new_parent, self.name) def _to_dataset_view(self, rebuild_dims: bool) -> DatasetView: - variables = self._data_variables | self._coord_variables + variables = dict(self._data_variables) + variables |= self._coord_variables if rebuild_dims: dims = calculate_dimensions(variables) else: @@ -588,7 +589,8 @@ def to_dataset(self, inherited: bool = True) -> Dataset: DataTree.ds """ coord_vars = self._coord_variables if inherited else self._node_coord_variables - variables = self._data_variables | coord_vars + variables = dict(self._data_variables) + variables |= coord_vars dims = calculate_dimensions(variables) if inherited else dict(self._node_dims) return Dataset._construct_direct( variables, From 7bb5d869c687c341f004fc4dde521c87aecef8e2 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 1 Jul 2024 10:34:35 -0700 Subject: [PATCH 19/21] use float64 instead of float32 for windows --- xarray/tests/test_datatree.py | 56 +++++++++++++++++------------------ 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index c7825f485de..f2b58fa2489 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -768,7 +768,7 @@ def test_inconsistent_dims(self): Dimensions: (x: 1) Dimensions without coordinates: x Data variables: - c (x) int64 8B 3 + c (x) float64 8B 3.0 From parents: Dimensions: (x: 2) Dimensions without coordinates: x @@ -778,20 +778,20 @@ def test_inconsistent_dims(self): with pytest.raises(ValueError, match=expected_msg): DataTree.from_dict( { - "/": xr.Dataset({"a": (("x",), [1, 2])}), - "/b": xr.Dataset({"c": (("x",), [3])}), + "/": xr.Dataset({"a": (("x",), [1.0, 2.0])}), + "/b": xr.Dataset({"c": (("x",), [3.0])}), } ) dt: DataTree = DataTree() - dt["/a"] = xr.DataArray([1, 2], dims=["x"]) + dt["/a"] = xr.DataArray([1.0, 2.0], dims=["x"]) with pytest.raises(ValueError, match=expected_msg): - dt["/b/c"] = xr.DataArray([3], dims=["x"]) + dt["/b/c"] = xr.DataArray([3.0], dims=["x"]) - b: DataTree = DataTree(data=xr.Dataset({"c": (("x",), [3])})) + b: DataTree = DataTree(data=xr.Dataset({"c": (("x",), [3.0])})) with pytest.raises(ValueError, match=expected_msg): DataTree( - data=xr.Dataset({"a": (("x",), [1, 2])}), + data=xr.Dataset({"a": (("x",), [1.0, 2.0])}), children={"b": b}, ) @@ -802,33 +802,33 @@ def test_inconsistent_child_indexes(self): Group: Dimensions: (x: 1) Coordinates: - * x (x) int64 8B 2 + * x (x) float64 8B 2.0 Data variables: *empty* From parents: Dimensions: (x: 1) Coordinates: - * x (x) int64 8B 1 + * x (x) float64 8B 1.0 """ ) with pytest.raises(ValueError, match=expected_msg): DataTree.from_dict( { - "/": xr.Dataset(coords={"x": [1]}), - "/b": xr.Dataset(coords={"x": [2]}), + "/": xr.Dataset(coords={"x": [1.0]}), + "/b": xr.Dataset(coords={"x": [2.0]}), } ) dt: DataTree = DataTree() - dt.ds = xr.Dataset(coords={"x": [1]}) # type: ignore + dt.ds = xr.Dataset(coords={"x": [1.0]}) # type: ignore dt["/b"] = DataTree() with pytest.raises(ValueError, match=expected_msg): - dt["/b"].ds = xr.Dataset(coords={"x": [2]}) + dt["/b"].ds = xr.Dataset(coords={"x": [2.0]}) - b: DataTree = DataTree(xr.Dataset(coords={"x": [2]})) + b: DataTree = DataTree(xr.Dataset(coords={"x": [2.0]})) with pytest.raises(ValueError, match=expected_msg): - DataTree(data=xr.Dataset(coords={"x": [1]}), children={"b": b}) + DataTree(data=xr.Dataset(coords={"x": [1.0]}), children={"b": b}) def test_inconsistent_grandchild_indexes(self): expected_msg = _exact_match( @@ -837,34 +837,34 @@ def test_inconsistent_grandchild_indexes(self): Group: Dimensions: (x: 1) Coordinates: - * x (x) int64 8B 2 + * x (x) float64 8B 2.0 Data variables: *empty* From parents: Dimensions: (x: 1) Coordinates: - * x (x) int64 8B 1 + * x (x) float64 8B 1.0 """ ) with pytest.raises(ValueError, match=expected_msg): DataTree.from_dict( { - "/": xr.Dataset(coords={"x": [1]}), - "/b/c": xr.Dataset(coords={"x": [2]}), + "/": xr.Dataset(coords={"x": [1.0]}), + "/b/c": xr.Dataset(coords={"x": [2.0]}), } ) dt: DataTree = DataTree() - dt.ds = xr.Dataset(coords={"x": [1]}) # type: ignore + dt.ds = xr.Dataset(coords={"x": [1.0]}) # type: ignore dt["/b/c"] = DataTree() with pytest.raises(ValueError, match=expected_msg): - dt["/b/c"].ds = xr.Dataset(coords={"x": [2]}) + dt["/b/c"].ds = xr.Dataset(coords={"x": [2.0]}) - c: DataTree = DataTree(xr.Dataset(coords={"x": [2]})) + c: DataTree = DataTree(xr.Dataset(coords={"x": [2.0]})) b: DataTree = DataTree(children={"c": c}) with pytest.raises(ValueError, match=expected_msg): - DataTree(data=xr.Dataset(coords={"x": [1]}), children={"b": b}) + DataTree(data=xr.Dataset(coords={"x": [1.0]}), children={"b": b}) def test_inconsistent_grandchild_dims(self): expected_msg = _exact_match( @@ -874,7 +874,7 @@ def test_inconsistent_grandchild_dims(self): Dimensions: (x: 1) Dimensions without coordinates: x Data variables: - d (x) int64 8B 3 + d (x) float64 8B 3.0 From parents: Dimensions: (x: 2) Dimensions without coordinates: x @@ -884,15 +884,15 @@ def test_inconsistent_grandchild_dims(self): with pytest.raises(ValueError, match=expected_msg): DataTree.from_dict( { - "/": xr.Dataset({"a": (("x",), [1, 2])}), - "/b/c": xr.Dataset({"d": (("x",), [3])}), + "/": xr.Dataset({"a": (("x",), [1.0, 2.0])}), + "/b/c": xr.Dataset({"d": (("x",), [3.0])}), } ) dt: DataTree = DataTree() - dt["/a"] = xr.DataArray([1, 2], dims=["x"]) + dt["/a"] = xr.DataArray([1.0, 2.0], dims=["x"]) with pytest.raises(ValueError, match=expected_msg): - dt["/b/c/d"] = xr.DataArray([3], dims=["x"]) + dt["/b/c/d"] = xr.DataArray([3.0], dims=["x"]) class TestRestructuring: From 57e598e3c99971b1079d6e979957bdc5d99efd32 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 1 Jul 2024 15:39:59 -0700 Subject: [PATCH 20/21] clean-up per review --- xarray/core/datatree.py | 47 +++++++++++++-------------------------- xarray/core/formatting.py | 21 +++++++++++++++++ xarray/core/treenode.py | 4 ---- 3 files changed, 37 insertions(+), 35 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 41bbb7158fa..cdf28f3be4c 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -33,7 +33,7 @@ MappedDataWithCoords, ) from xarray.core.datatree_render import RenderDataTree -from xarray.core.formatting import datatree_repr +from xarray.core.formatting import datatree_repr, dims_and_coords_repr from xarray.core.formatting_html import ( datatree_repr as datatree_repr_html, ) @@ -109,7 +109,7 @@ def _coerce_to_dataset(data: Dataset | DataArray | None) -> Dataset: def _join_path(root: str, name: str) -> str: - return root.rstrip("/") + "/" + name + return str(NodePath(root) / name) def _inherited_dataset(ds: Dataset, parent: Dataset) -> Dataset: @@ -124,19 +124,12 @@ def _inherited_dataset(ds: Dataset, parent: Dataset) -> Dataset: ) -def _indented_without_header(text: str) -> str: - return textwrap.indent("\n".join(text.split("\n")[1:]), prefix=" ") +def _without_header(text: str) -> str: + return "\n".join(text.split("\n")[1:]) -def _drop_data_vars_and_attrs_sections(text: str) -> str: - lines = text.split("\n") - outputs = [] - match = "Data variables:" - for line in lines: - if line[: len(match)] == match: - break - outputs.append(line) - return "\n".join(outputs) +def _indented(text: str) -> str: + return textwrap.indent(text, prefix=" ") def _check_alignment( @@ -149,10 +142,8 @@ def _check_alignment( try: align(node_ds, parent_ds, join="exact") except ValueError as e: - node_repr = _indented_without_header(repr(node_ds)) - parent_repr = _indented_without_header( - _drop_data_vars_and_attrs_sections(repr(parent_ds)) - ) + node_repr = _indented(_without_header(repr(node_ds))) + parent_repr = _indented(dims_and_coords_repr(parent_ds)) raise ValueError( f"group {path!r} is not aligned with its parents:\n" f"Group:\n{node_repr}\nFrom parents:\n{parent_repr}" @@ -165,7 +156,7 @@ def _check_alignment( base_ds = node_ds for child_name, child in children.items(): - child_path = _join_path(path, child_name) + child_path = str(NodePath(path) / child_name) child_ds = child.to_dataset(inherited=False) _check_alignment(child_path, child_ds, base_ds, child.children) @@ -203,7 +194,7 @@ def __init__( raise AttributeError("DatasetView objects are not to be initialized directly") @classmethod - def _from_dataset_state( + def _constructor( cls, variables: dict[Any, Variable], coord_names: set[Hashable], @@ -213,7 +204,9 @@ def _from_dataset_state( encoding: dict | None, close: Callable[[], None] | None, ) -> DatasetView: - """Constructor, using dataset attributes from wrapping node""" + """Private constructor, from Dataset attributes.""" + # We override Dataset._construct_direct below, so we need a new + # constructor for creating DatasetView objects. obj: DatasetView = object.__new__(cls) obj._variables = variables obj._coord_names = coord_names @@ -469,17 +462,9 @@ def __init__( children = {} super().__init__(name=name) - - # set tree attributes - self._children = {} - self._parent = None - - # set data attributes self._set_node_data(_coerce_to_dataset(data)) - - # finalize tree attributes - self.children = children # must set first self.parent = parent + self.children = children def _set_node_data(self, ds: Dataset): data_vars, coord_vars = _collect_data_and_coord_variables(ds) @@ -497,7 +482,7 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: raise KeyError( f"parent {parent.name} already contains a variable named {name}" ) - path = _join_path(parent.path, name) + path = str(NodePath(parent.path) / name) node_ds = self.to_dataset(inherited=False) parent_ds = parent._to_dataset_view(rebuild_dims=False) _check_alignment(path, node_ds, parent_ds, self.children) @@ -545,7 +530,7 @@ def _to_dataset_view(self, rebuild_dims: bool) -> DatasetView: # However, they are fine for internal use cases, for align() or # building a repr(). dims = dict(self._dims) - return DatasetView._from_dataset_state( + return DatasetView._constructor( variables=variables, coord_names=set(self._coord_variables), dims=dims, diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 77fd0200525..6dca4eba8e8 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -748,6 +748,27 @@ def dataset_repr(ds): return "\n".join(summary) +def dims_and_coords_repr(ds) -> str: + """Partial Dataset repr for use inside DataTree inheritance errors.""" + summary = [] + + col_width = _calculate_col_width(ds.coords) + max_rows = OPTIONS["display_max_rows"] + + dims_start = pretty_print("Dimensions:", col_width) + dims_values = dim_summary_limited(ds, col_width=col_width + 1, max_rows=max_rows) + summary.append(f"{dims_start}({dims_values})") + + if ds.coords: + summary.append(coords_repr(ds.coords, col_width=col_width, max_rows=max_rows)) + + unindexed_dims_str = unindexed_dims_repr(ds.dims, ds.coords, max_rows=max_rows) + if unindexed_dims_str: + summary.append(unindexed_dims_str) + + return "\n".join(summary) + + def diff_dim_summary(a, b): if a.sizes != b.sizes: return f"Differing dimensions:\n ({dim_summary(a)}) != ({dim_summary(b)})" diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 3017b59d978..77e7ed23a51 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -609,12 +609,8 @@ def __repr__(self, level=0): def __str__(self) -> str: return f"NamedNode('{self.name}')" if self.name else "NamedNode()" - def _get_name_in_parent(self: AnyNamedNode, parent: AnyNamedNode) -> str: - return next(k for k, v in parent.children.items() if v is self) - def _post_attach(self: AnyNamedNode, parent: AnyNamedNode, name: str) -> None: """Ensures child has name attribute corresponding to key under which it has been stored.""" - # self.name = self._get_name_in_parent(parent self.name = name @property From 6cfd8ff7da5cbadb14709d9a70cf13419e076c29 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 2 Jul 2024 16:55:53 -0700 Subject: [PATCH 21/21] Add note about inheritance to .ds docs --- xarray/core/datatree.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index cdf28f3be4c..38f8f8cd495 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -545,6 +545,8 @@ def ds(self) -> DatasetView: """ An immutable Dataset-like view onto the data in this node. + Includes inherited coordinates and indexes from parent nodes. + For a mutable Dataset containing the same data as in this node, use `.to_dataset()` instead.