Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add copy option in DataTree.from_dict #9193

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
14 changes: 14 additions & 0 deletions asv_bench/benchmarks/datatree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import xarray as xr
from xarray.core.datatree import DataTree

from . import parameterized


class Datatree:
    """Benchmarks for building a ``DataTree`` from a dict of existing nodes."""

    def setup(self):
        """Prepare a one-entry mapping whose value is already a ``DataTree``."""
        dataset = xr.Dataset({"a": 1})
        existing_tree = DataTree.from_dict({"run1": dataset})
        self.d = dict(run1=existing_tree)

    @parameterized(["fastpath"], [(False, True)])
    def time_from_dict(self, fastpath: bool):
        """Time ``DataTree.from_dict`` on the prepared mapping for each ``fastpath`` value."""
        DataTree.from_dict(self.d, fastpath=fastpath)
5 changes: 5 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ v2024.07.1 (unreleased)

New Features
~~~~~~~~~~~~
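- Add optional keyword-only parameter ``fastpath`` to :py:func:`xarray.core.datatree.DataTree.from_dict`.
By default (``fastpath=False``) the ``DataTree`` values in the dict are shallow copied;
pass ``fastpath=True`` to skip these copies and use the given nodes directly. (:pull:`9193`)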
By `Jimmy Westling <https://github.com/illviljan>`_.
- Allow chunking for arrays with duplicated dimension names (:issue:`8759`, :pull:`9099`).
By `Martin Raspaud <https://github.com/mraspaud>`_.


Breaking changes
Expand Down
15 changes: 11 additions & 4 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,15 +890,17 @@ def __getitem__(self: DataTree, key: str) -> DataTree | DataArray:
else:
raise ValueError(f"Invalid format for key: {key}")

def _set(self, key: str, val: DataTree | CoercibleValue) -> None:
def _set(
self, key: str, val: DataTree | CoercibleValue, *, fastpath: bool = False
) -> None:
"""
Set the child node or variable with the specified key to value.

Counterpart to the public .get method, and also only works on the immediate node, not other nodes in the tree.
"""
if isinstance(val, DataTree):
# create and assign a shallow copy here so as not to alter original name of node in grafted tree
new_node = val.copy(deep=False)
new_node = val.copy(deep=False) if not fastpath else val
new_node.name = key
new_node.parent = self
else:
Expand Down Expand Up @@ -1069,6 +1071,8 @@ def from_dict(
cls,
d: MutableMapping[str, Dataset | DataArray | DataTree | None],
name: str | None = None,
*,
fastpath: bool = False,
) -> DataTree:
"""
Create a datatree from a dictionary of data objects, organised by paths into the tree.
Expand All @@ -1084,6 +1088,8 @@ def from_dict(
To assign data to the root node of the tree use "/" as the path.
name : Hashable | None, optional
Name for the root node of the tree. Default is None.
fastpath : bool, optional
If True, bypass checks and use the DataTree values in the dict directly instead of shallow-copying them; the passed nodes are then modified in place (renamed and re-parented). Default is False.

Returns
-------
Expand All @@ -1097,7 +1103,7 @@ def from_dict(
# First create the root node
root_data = d.pop("/", None)
if isinstance(root_data, DataTree):
obj = root_data.copy()
obj = root_data.copy() if not fastpath else root_data
obj.orphan()
else:
obj = cls(name=name, data=root_data, parent=None, children=None)
Expand All @@ -1113,7 +1119,7 @@ def depth(item) -> int:
# Create and set new node
node_name = NodePath(path).name
if isinstance(data, DataTree):
new_node = data.copy()
new_node = data.copy() if not fastpath else data
new_node.orphan()
else:
new_node = cls(name=node_name, data=data)
Expand All @@ -1122,6 +1128,7 @@ def depth(item) -> int:
new_node,
allow_overwrite=False,
new_nodes_along_path=True,
fastpath=fastpath,
)

return obj
Expand Down
12 changes: 8 additions & 4 deletions xarray/core/treenode.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def _get_item(self: Tree, path: str | NodePath) -> Tree | T_DataArray:
current_node = current_node.get(part)
return current_node

def _set(self: Tree, key: str, val: Tree) -> None:
def _set(self: Tree, key: str, val: Tree, *, fastpath: bool = False) -> None:
"""
Set the child node with the specified key to value.

Expand All @@ -483,6 +483,8 @@ def _set_item(
item: Tree | T_DataArray,
new_nodes_along_path: bool = False,
allow_overwrite: bool = True,
*,
fastpath: bool = False,
) -> None:
"""
Set a new item in the tree, overwriting anything already present at that path.
Expand All @@ -500,6 +502,8 @@ def _set_item(
allow_overwrite : bool
Whether or not to overwrite any existing node at the location given
by path.
fastpath : bool, optional
Forwarded to ``_set``. If True, bypass checks and avoid making a shallow copy of ``item``. Default is False.

Raises
------
Expand Down Expand Up @@ -539,19 +543,19 @@ def _set_item(
elif new_nodes_along_path:
# Want child classes (i.e. DataTree) to populate tree with their own types
new_node = type(self)()
current_node._set(part, new_node)
current_node._set(part, new_node, fastpath=fastpath)
current_node = current_node.children[part]
else:
raise KeyError(f"Could not reach node at path {path}")

if name in current_node.children:
# Deal with anything already existing at this location
if allow_overwrite:
current_node._set(name, item)
current_node._set(name, item, fastpath=fastpath)
else:
raise KeyError(f"Already a node object at path {path}")
else:
current_node._set(name, item)
current_node._set(name, item, fastpath=fastpath)

def __delitem__(self: Tree, key: str):
"""Remove a child node from this tree object."""
Expand Down
7 changes: 7 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,13 @@ def test_roundtrip_unnamed_root(self, simple_datatree):
roundtrip = DataTree.from_dict(dt.to_dict())
assert roundtrip.equals(dt)

@pytest.mark.parametrize("fastpath", [False, True])
def test_fastpath(self, fastpath: bool) -> None:
    # A DataTree passed as a dict value must be reused as the very same
    # object when fastpath is enabled, and must be a different object
    # (a copy) when it is disabled.
    source = DataTree.from_dict({"run1": xr.Dataset({"a": 1})})
    tree = DataTree.from_dict({"run1": source}, fastpath=fastpath)
    if fastpath:
        assert tree["run1"] is source
    else:
        assert tree["run1"] is not source

def test_insertion_order(self):
# regression test for GH issue #9276
reversed = DataTree.from_dict(
Expand Down
Loading