Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add copy option in DataTree.from_dict #9193

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
13 changes: 9 additions & 4 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,15 +882,17 @@ def __getitem__(self: DataTree, key: str) -> DataTree | DataArray:
else:
raise ValueError(f"Invalid format for key: {key}")

def _set(self, key: str, val: DataTree | CoercibleValue) -> None:
def _set(
self, key: str, val: DataTree | CoercibleValue, *, copy: bool = True
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""
Set the child node or variable with the specified key to value.

Counterpart to the public .get method, and also only works on the immediate node, not other nodes in the tree.
"""
if isinstance(val, DataTree):
# create and assign a shallow copy here so as not to alter original name of node in grafted tree
new_node = val.copy(deep=False)
new_node = val.copy(deep=False) if copy else val
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
new_node.name = key
new_node.parent = self
else:
Expand Down Expand Up @@ -1052,6 +1054,8 @@ def from_dict(
cls,
d: MutableMapping[str, Dataset | DataArray | DataTree | None],
name: str | None = None,
*,
copy: bool = True,
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
) -> DataTree:
"""
Create a datatree from a dictionary of data objects, organised by paths into the tree.
Expand Down Expand Up @@ -1080,7 +1084,7 @@ def from_dict(
# First create the root node
root_data = d.pop("/", None)
if isinstance(root_data, DataTree):
obj = root_data.copy()
obj = root_data.copy() if copy else root_data
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
obj.orphan()
else:
obj = cls(name=name, data=root_data, parent=None, children=None)
Expand All @@ -1091,7 +1095,7 @@ def from_dict(
# Create and set new node
node_name = NodePath(path).name
if isinstance(data, DataTree):
new_node = data.copy()
new_node = data.copy() if copy else data
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
new_node.orphan()
else:
new_node = cls(name=node_name, data=data)
Expand All @@ -1100,6 +1104,7 @@ def from_dict(
new_node,
allow_overwrite=False,
new_nodes_along_path=True,
copy=copy,
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
)

return obj
Expand Down
10 changes: 6 additions & 4 deletions xarray/core/treenode.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def _get_item(self: Tree, path: str | NodePath) -> Tree | T_DataArray:
current_node = current_node.get(part)
return current_node

def _set(self: Tree, key: str, val: Tree) -> None:
def _set(self: Tree, key: str, val: Tree, *, copy: bool = True) -> None:
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
"""
Set the child node with the specified key to value.

Expand All @@ -483,6 +483,8 @@ def _set_item(
item: Tree | T_DataArray,
new_nodes_along_path: bool = False,
allow_overwrite: bool = True,
*,
copy: bool = True,
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""
Set a new item in the tree, overwriting anything already present at that path.
Expand Down Expand Up @@ -539,19 +541,19 @@ def _set_item(
elif new_nodes_along_path:
# Want child classes (i.e. DataTree) to populate tree with their own types
new_node = type(self)()
current_node._set(part, new_node)
current_node._set(part, new_node, copy=copy)
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
current_node = current_node.children[part]
else:
raise KeyError(f"Could not reach node at path {path}")

if name in current_node.children:
# Deal with anything already existing at this location
if allow_overwrite:
current_node._set(name, item)
current_node._set(name, item, copy=copy)
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
else:
raise KeyError(f"Already a node object at path {path}")
else:
current_node._set(name, item)
current_node._set(name, item, copy=copy)
Illviljan marked this conversation as resolved.
Show resolved Hide resolved

def __delitem__(self: Tree, key: str):
"""Remove a child node from this tree object."""
Expand Down
7 changes: 7 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,13 @@ def test_roundtrip_unnamed_root(self, simple_datatree):
roundtrip = DataTree.from_dict(dt.to_dict())
assert roundtrip.equals(dt)

@pytest.mark.parametrize("copy", [True, False])
def test_copy(self, copy: bool) -> None:
run1 = DataTree.from_dict({"run1": xr.Dataset({"a": 1})})
dt = DataTree.from_dict({"run1": run1}, copy=copy)
is_exact = dt["run1"] is run1
assert is_exact is not copy
Illviljan marked this conversation as resolved.
Show resolved Hide resolved


class TestDatasetView:
def test_view_contents(self):
Expand Down
Loading