diff --git a/asv_bench/benchmarks/datatree.py b/asv_bench/benchmarks/datatree.py new file mode 100644 index 00000000000..3db0c708225 --- /dev/null +++ b/asv_bench/benchmarks/datatree.py @@ -0,0 +1,14 @@ +import xarray as xr +from xarray.core.datatree import DataTree + +from . import parameterized + + +class Datatree: + def setup(self): + run1 = DataTree.from_dict({"run1": xr.Dataset({"a": 1})}) + self.d = {"run1": run1} + + @parameterized(["fastpath"], [(False, True)]) + def time_from_dict(self, fastpath: bool): + DataTree.from_dict(self.d, fastpath=fastpath) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4cd34c4cf54..94014f9c489 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,6 +22,11 @@ v2024.07.1 (unreleased) New Features ~~~~~~~~~~~~ +- Add optional parameter `fastpath` to :py:func:`xarray.core.datatree.DataTree.from_dict`, + which by default shallow copies the values in the dict. (:pull:`9193`) + By `Jimmy Westling `_. +- Allow chunking for arrays with duplicated dimension names (:issue:`8759`, :pull:`9099`). + By `Martin Raspaud `_. Breaking changes diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 6289146308e..9f38bffddb3 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -890,7 +890,9 @@ def __getitem__(self: DataTree, key: str) -> DataTree | DataArray: else: raise ValueError(f"Invalid format for key: {key}") - def _set(self, key: str, val: DataTree | CoercibleValue) -> None: + def _set( + self, key: str, val: DataTree | CoercibleValue, *, fastpath: bool = False + ) -> None: """ Set the child node or variable with the specified key to value. @@ -898,7 +900,7 @@ def _set(self, key: str, val: DataTree | CoercibleValue) -> None: """ if isinstance(val, DataTree): # create and assign a shallow copy here so as not to alter original name of node in grafted tree - new_node = val.copy(deep=False) + new_node = val.copy(deep=False) if not fastpath else val new_node.name = key new_node.parent = self else: @@ -1069,6 +1071,8 @@ def from_dict( cls, d: MutableMapping[str, Dataset | DataArray | DataTree | None], name: str | None = None, + *, + fastpath: bool = False, ) -> DataTree: """ Create a datatree from a dictionary of data objects, organised by paths into the tree. @@ -1084,6 +1088,8 @@ def from_dict( To assign data to the root node of the tree use "/" as the path. name : Hashable | None, optional Name for the root node of the tree. Default is None. + fastpath : bool, optional + Whether to bypass checks and avoid shallow copies of the values in the dict. Default is False. Returns ------- @@ -1097,7 +1103,7 @@ def from_dict( # First create the root node root_data = d.pop("/", None) if isinstance(root_data, DataTree): - obj = root_data.copy() + obj = root_data.copy() if not fastpath else root_data obj.orphan() else: obj = cls(name=name, data=root_data, parent=None, children=None) @@ -1113,7 +1119,7 @@ def depth(item) -> int: # Create and set new node node_name = NodePath(path).name if isinstance(data, DataTree): - new_node = data.copy() + new_node = data.copy() if not fastpath else data new_node.orphan() else: new_node = cls(name=node_name, data=data) @@ -1122,6 +1128,7 @@ def depth(item) -> int: new_node, allow_overwrite=False, new_nodes_along_path=True, + fastpath=fastpath, ) return obj diff --git a/xarray/core/treenode.py b/xarray/core/treenode.py index 77e7ed23a51..1a8f13f8c88 100644 --- a/xarray/core/treenode.py +++ b/xarray/core/treenode.py @@ -468,7 +468,7 @@ def _get_item(self: Tree, path: str | NodePath) -> Tree | T_DataArray: current_node = current_node.get(part) return current_node - def _set(self: Tree, key: str, val: Tree) -> None: + def _set(self: Tree, key: str, val: Tree, *, fastpath: bool = False) -> None: """ Set the child node with the specified key to value. @@ -483,6 +483,8 @@ def _set_item( item: Tree | T_DataArray, new_nodes_along_path: bool = False, allow_overwrite: bool = True, + *, + fastpath: bool = False, ) -> None: """ Set a new item in the tree, overwriting anything already present at that path. @@ -500,6 +502,8 @@ def _set_item( allow_overwrite : bool Whether or not to overwrite any existing node at the location given by path. + fastpath : bool, optional + Whether to bypass checks and avoid shallow copies of the values in the dict. Default is False. Raises ------ @@ -539,7 +543,7 @@ def _set_item( elif new_nodes_along_path: # Want child classes (i.e. DataTree) to populate tree with their own types new_node = type(self)() - current_node._set(part, new_node) + current_node._set(part, new_node, fastpath=fastpath) current_node = current_node.children[part] else: raise KeyError(f"Could not reach node at path {path}") @@ -547,11 +551,11 @@ def _set_item( if name in current_node.children: # Deal with anything already existing at this location if allow_overwrite: - current_node._set(name, item) + current_node._set(name, item, fastpath=fastpath) else: raise KeyError(f"Already a node object at path {path}") else: - current_node._set(name, item) + current_node._set(name, item, fastpath=fastpath) def __delitem__(self: Tree, key: str): """Remove a child node from this tree object.""" diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index c875322b9c5..3f9335b3c77 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -561,6 +561,13 @@ def test_roundtrip_unnamed_root(self, simple_datatree): roundtrip = DataTree.from_dict(dt.to_dict()) assert roundtrip.equals(dt) + @pytest.mark.parametrize("fastpath", [False, True]) + def test_fastpath(self, fastpath: bool) -> None: + run1 = DataTree.from_dict({"run1": xr.Dataset({"a": 1})}) + dt = DataTree.from_dict({"run1": run1}, fastpath=fastpath) + is_exact = dt["run1"] is run1 + assert is_exact is fastpath + def test_insertion_order(self): # regression test for GH issue #9276 reversed = DataTree.from_dict(