Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add attrs argument to DataTree constructor #9242

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion xarray/core/datatree.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
import itertools
import textwrap
from collections import ChainMap
Expand Down Expand Up @@ -409,6 +410,7 @@ class DataTree(
_attrs: dict[Hashable, Any] | None
_encoding: dict[Hashable, Any] | None
_close: Callable[[], None] | None
_attrs: dict[Hashable, Any] | None

__slots__ = (
"_name",
Expand All @@ -422,6 +424,7 @@ class DataTree(
"_attrs",
"_encoding",
"_close",
"_attrs",
)

def __init__(
Expand All @@ -430,6 +433,7 @@ def __init__(
parent: DataTree | None = None,
children: Mapping[str, DataTree] | None = None,
name: str | None = None,
attrs: Mapping[Any, Any] | None = None,
):
"""
Create a single node of a DataTree.
Expand All @@ -449,6 +453,8 @@ def __init__(
Any child nodes of this node. Default is None.
name : str, optional
Name for this node of the tree. Default is None.
attrs : dict-like, optional
Global attributes to save on this datatree.

Returns
-------
Expand All @@ -465,6 +471,7 @@ def __init__(
self._set_node_data(_coerce_to_dataset(data))
self.parent = parent
self.children = children
self._attrs = dict(attrs) if attrs else None

def _set_node_data(self, ds: Dataset):
data_vars, coord_vars = _collect_data_and_coord_variables(ds)
Expand Down Expand Up @@ -823,7 +830,8 @@ def _copy_node(
) -> DataTree:
"""Copy just one node of a tree"""
data = self.ds.copy(deep=deep)
new_node: DataTree = DataTree(data, name=self.name)
attrs = copy.deepcopy(self.attrs) if deep else self.attrs.copy()
new_node: DataTree = DataTree(data, name=self.name, attrs=attrs)
return new_node

def __copy__(self: DataTree) -> DataTree:
Expand Down
20 changes: 20 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,3 +1193,23 @@ def test_one_liner(self):
def test_none(self):
actual_doc = insert_doc_addendum(None, _MAPPED_DOCSTRING_ADDENDUM)
assert actual_doc is None


class TestDataTreeAttrs:
"""
Test passing ``attrs`` to the DataTree constructor.
"""

@pytest.fixture
def dataset(self):
"""Sample dataset fixture."""
ds = xr.Dataset({"a": ("x", [0, 3])})
return ds

def test_attrs_argument(self, dataset):
"""
Test passing attrs as argument to the constructor.
"""
attrs = {"foo": "bar"}
dt = DataTree(dataset, attrs=attrs)
assert dt.attrs == attrs
Loading