diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 41b39bf2b1a..f809590d351 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,6 +34,9 @@ Bug fixes Internal Changes ~~~~~~~~~~~~~~~~ +- Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`) + By `Matt Savoie `_ `Owen Littlejohns + ` and `Tom Nicholas `_. .. _whats-new.2024.03.0: diff --git a/xarray/core/datatree_mapping.py b/xarray/core/datatree_mapping.py index 5df97fb0bc9..714921d2a90 100644 --- a/xarray/core/datatree_mapping.py +++ b/xarray/core/datatree_mapping.py @@ -83,14 +83,13 @@ def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> s for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)): path_a, path_b = node_a.path, node_b.path - if require_names_equal: - if node_a.name != node_b.name: - diff = dedent( - f"""\ + if require_names_equal and node_a.name != node_b.name: + diff = dedent( + f"""\ Node '{path_a}' in the left object has name '{node_a.name}' Node '{path_b}' in the right object has name '{node_b.name}'""" - ) - return diff + ) + return diff if len(node_a.children) != len(node_b.children): diff = dedent( @@ -124,7 +123,7 @@ def map_over_subtree(func: Callable) -> Callable: func : callable Function to apply to datasets with signature: - `func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`. + `func(*args, **kwargs) -> Union[DataTree, Iterable[DataTree]]`. (i.e. func must accept at least one Dataset and return at least one Dataset.) Function will not be applied to any nodes without datasets. @@ -258,7 +257,7 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]: return _map_over_subtree -def _handle_errors_with_path_context(path): +def _handle_errors_with_path_context(path: str): """Wraps given function so that if it fails it also raises path to node on which it failed.""" def decorator(func): @@ -266,11 +265,10 @@ def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except Exception as e: - if sys.version_info >= (3, 11): - # Add the context information to the error message - e.add_note( - f"Raised whilst mapping function over node with path {path}" - ) + # Add the context information to the error message + add_note( + e, f"Raised whilst mapping function over node with path {path}" + ) raise return wrapper @@ -286,7 +284,9 @@ def add_note(err: BaseException, msg: str) -> None: err.add_note(msg) -def _check_single_set_return_values(path_to_node, obj): +def _check_single_set_return_values( + path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray] +): """Check types returned from single evaluation of func, and return number of return values received from func.""" if isinstance(obj, (Dataset, DataArray)): return 1 diff --git a/xarray/core/iterators.py b/xarray/core/iterators.py index 5c0c0f652e8..dd5fa7ee97a 100644 --- a/xarray/core/iterators.py +++ b/xarray/core/iterators.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections import abc from collections.abc import Iterator from typing import Callable @@ -9,7 +8,7 @@ """These iterators are copied from anytree.iterators, with minor modifications.""" -class LevelOrderIter(abc.Iterator): +class LevelOrderIter(Iterator): """Iterate over tree applying level-order strategy starting at `node`. This is the iterator used by `DataTree` to traverse nodes.