From 6595e2fa1047b5ac41498549783f5e33f33c9d07 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 30 Jun 2024 15:45:30 -0700 Subject: [PATCH] Construct ChainMap objects on demand. --- xarray/core/datatree.py | 61 +++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 36 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 22719694fd7..59d126d4ebd 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -410,9 +410,9 @@ class DataTree( _children: dict[str, DataTree] _cache: dict[str, Any] # used by _CachedAccessor _data_variables: dict[Hashable, Variable] - _coord_variables: ChainMap[Hashable, Variable] - _dims: ChainMap[Hashable, int] - _indexes: ChainMap[Hashable, Index] + _node_coord_variables: dict[Hashable, Variable] + _node_dims: dict[Hashable, int] + _node_indexes: dict[Hashable, Index] _attrs: dict[Hashable, Any] | None _encoding: dict[Hashable, Any] | None _close: Callable[[], None] | None @@ -423,9 +423,9 @@ class DataTree( "_children", "_cache", # used by _CachedAccessor "_data_variables", - "_coord_variables", - "_dims", - "_indexes", + "_node_coord_variables", + "_node_dims", + "_node_indexes", "_attrs", "_encoding", "_close", @@ -475,9 +475,6 @@ def __init__( self._parent = None # set data attributes - self._coord_variables: ChainMap[Hashable, Variable] = ChainMap() - self._dims = ChainMap() - self._indexes = ChainMap() self._set_node_data(_coerce_to_dataset(data)) # finalize tree attributes @@ -487,9 +484,9 @@ def __init__( def _set_node_data(self, ds: Dataset): data_vars, coord_vars = _collect_data_and_coord_variables(ds) self._data_variables = data_vars - self._coord_variables.maps[0] = coord_vars - self._dims.maps[0] = ds._dims - self._indexes.maps[0] = ds._indexes + self._node_coord_variables = coord_vars + self._node_dims = ds._dims + self._node_indexes = ds._indexes self._encoding = ds._encoding self._attrs = ds._attrs self._close = ds._close @@ -505,16 +502,19 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: parent_ds = parent._to_dataset_view(rebuild_dims=False) _check_alignment(path, node_ds, parent_ds, self.children) - def _add_parent_maps(self: DataTree, parent: DataTree) -> None: - self._coord_variables.maps.extend(parent._coord_variables.maps) - self._dims.maps.extend(parent._dims.maps) - self._indexes.maps.extend(parent._indexes.maps) - for child in self._children.values(): - child._add_parent_maps(self) + @property + def _coord_variables(self) -> ChainMap[Hashable, Variable]: + return ChainMap( + self._node_coord_variables, *(p._node_coord_variables for p in self.parents) + ) - def _post_attach(self: DataTree, parent: DataTree, name: str) -> None: - super()._post_attach(parent, name) - self._add_parent_maps(parent) + @property + def _dims(self) -> ChainMap[Hashable, int]: + return ChainMap(self._node_dims, *(p._node_dims for p in self.parents)) + + @property + def _indexes(self) -> ChainMap[Hashable, Index]: + return ChainMap(self._node_indexes, *(p._node_indexes for p in self.parents)) @property def parent(self: DataTree) -> DataTree | None: @@ -587,19 +587,15 @@ def to_dataset(self, inherited: bool = True) -> Dataset: -------- DataTree.ds """ - coord_vars = ( - self._coord_variables if inherited else self._coord_variables.maps[0] - ) + coord_vars = self._coord_variables if inherited else self._node_coord_variables variables = self._data_variables | coord_vars - dims = ( - calculate_dimensions(variables) if inherited else dict(self._dims.maps[0]) - ) + dims = calculate_dimensions(variables) if inherited else dict(self._node_dims) return Dataset._construct_direct( variables, set(coord_vars), dims, None if self._attrs is None else dict(self._attrs), - dict(self._indexes if inherited else self._indexes.maps[0]), + dict(self._indexes if inherited else self._node_indexes), None if self._encoding is None else dict(self._encoding), self._close, ) @@ -607,7 +603,7 @@ def to_dataset(self, inherited: bool = True) -> Dataset: @property def has_data(self) -> bool: """Whether or not there are any variables in this node.""" - return bool(self._data_variables or self._coord_variables.maps[0]) + return bool(self._data_variables or self._node_coord_variables) @property def has_attrs(self) -> bool: @@ -787,13 +783,6 @@ def _replace_node( self._children = children - if self.parent is not None: - assert self.name is not None - self._post_attach(self.parent, self.name) - else: - for child in children.values(): - child._add_parent_maps(self) - def copy( self: DataTree, deep: bool = False,