Skip to content

Commit

Permalink
Construct ChainMap objects on demand.
Browse files Browse the repository at this point in the history
  • Loading branch information
shoyer committed Jun 30, 2024
1 parent c282e62 commit 6595e2f
Showing 1 changed file with 25 additions and 36 deletions.
61 changes: 25 additions & 36 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,9 @@ class DataTree(
_children: dict[str, DataTree]
_cache: dict[str, Any] # used by _CachedAccessor
_data_variables: dict[Hashable, Variable]
_coord_variables: ChainMap[Hashable, Variable]
_dims: ChainMap[Hashable, int]
_indexes: ChainMap[Hashable, Index]
_node_coord_variables: dict[Hashable, Variable]
_node_dims: dict[Hashable, int]
_node_indexes: dict[Hashable, Index]
_attrs: dict[Hashable, Any] | None
_encoding: dict[Hashable, Any] | None
_close: Callable[[], None] | None
Expand All @@ -423,9 +423,9 @@ class DataTree(
"_children",
"_cache", # used by _CachedAccessor
"_data_variables",
"_coord_variables",
"_dims",
"_indexes",
"_node_coord_variables",
"_node_dims",
"_node_indexes",
"_attrs",
"_encoding",
"_close",
Expand Down Expand Up @@ -475,9 +475,6 @@ def __init__(
self._parent = None

# set data attributes
self._coord_variables: ChainMap[Hashable, Variable] = ChainMap()
self._dims = ChainMap()
self._indexes = ChainMap()
self._set_node_data(_coerce_to_dataset(data))

# finalize tree attributes
Expand All @@ -487,9 +484,9 @@ def __init__(
def _set_node_data(self, ds: Dataset):
data_vars, coord_vars = _collect_data_and_coord_variables(ds)
self._data_variables = data_vars
self._coord_variables.maps[0] = coord_vars
self._dims.maps[0] = ds._dims
self._indexes.maps[0] = ds._indexes
self._node_coord_variables = coord_vars
self._node_dims = ds._dims
self._node_indexes = ds._indexes
self._encoding = ds._encoding
self._attrs = ds._attrs
self._close = ds._close
Expand All @@ -505,16 +502,19 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
parent_ds = parent._to_dataset_view(rebuild_dims=False)
_check_alignment(path, node_ds, parent_ds, self.children)

def _add_parent_maps(self: DataTree, parent: DataTree) -> None:
self._coord_variables.maps.extend(parent._coord_variables.maps)
self._dims.maps.extend(parent._dims.maps)
self._indexes.maps.extend(parent._indexes.maps)
for child in self._children.values():
child._add_parent_maps(self)
@property
def _coord_variables(self) -> ChainMap[Hashable, Variable]:
return ChainMap(
self._node_coord_variables, *(p._node_coord_variables for p in self.parents)
)

def _post_attach(self: DataTree, parent: DataTree, name: str) -> None:
super()._post_attach(parent, name)
self._add_parent_maps(parent)
@property
def _dims(self) -> ChainMap[Hashable, int]:
return ChainMap(self._node_dims, *(p._node_dims for p in self.parents))

@property
def _indexes(self) -> ChainMap[Hashable, Index]:
return ChainMap(self._node_indexes, *(p._node_indexes for p in self.parents))

@property
def parent(self: DataTree) -> DataTree | None:
Expand Down Expand Up @@ -587,27 +587,23 @@ def to_dataset(self, inherited: bool = True) -> Dataset:
--------
DataTree.ds
"""
coord_vars = (
self._coord_variables if inherited else self._coord_variables.maps[0]
)
coord_vars = self._coord_variables if inherited else self._node_coord_variables
variables = self._data_variables | coord_vars
dims = (
calculate_dimensions(variables) if inherited else dict(self._dims.maps[0])
)
dims = calculate_dimensions(variables) if inherited else dict(self._node_dims)
return Dataset._construct_direct(
variables,
set(coord_vars),
dims,
None if self._attrs is None else dict(self._attrs),
dict(self._indexes if inherited else self._indexes.maps[0]),
dict(self._indexes if inherited else self._node_indexes),
None if self._encoding is None else dict(self._encoding),
self._close,
)

@property
def has_data(self) -> bool:
"""Whether or not there are any variables in this node."""
return bool(self._data_variables or self._coord_variables.maps[0])
return bool(self._data_variables or self._node_coord_variables)

@property
def has_attrs(self) -> bool:
Expand Down Expand Up @@ -787,13 +783,6 @@ def _replace_node(

self._children = children

if self.parent is not None:
assert self.name is not None
self._post_attach(self.parent, self.name)
else:
for child in children.values():
child._add_parent_maps(self)

def copy(
self: DataTree,
deep: bool = False,
Expand Down

0 comments on commit 6595e2f

Please sign in to comment.