Skip to content

Commit

Permalink
DAS-2064 - Minor changes to datatree_mapping.py.
Browse files Browse the repository at this point in the history
  • Loading branch information
owenlittlejohns committed Apr 16, 2024
1 parent e680426 commit 4ab8e78
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 16 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ Bug fixes

Internal Changes
~~~~~~~~~~~~~~~~
- Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`)
By `Matt Savoie <https://github.com/flamingbear>`_ `Owen Littlejohns
<https://github.com/owenlittlejohns>` and `Tom Nicholas <https://github.com/TomNicholas>`_.


.. _whats-new.2024.03.0:
Expand Down
28 changes: 14 additions & 14 deletions xarray/core/datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,13 @@ def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> s
for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)):
path_a, path_b = node_a.path, node_b.path

if require_names_equal:
if node_a.name != node_b.name:
diff = dedent(
f"""\
if require_names_equal and node_a.name != node_b.name:
diff = dedent(
f"""\
Node '{path_a}' in the left object has name '{node_a.name}'
Node '{path_b}' in the right object has name '{node_b.name}'"""
)
return diff
)
return diff

if len(node_a.children) != len(node_b.children):
diff = dedent(
Expand Down Expand Up @@ -124,7 +123,7 @@ def map_over_subtree(func: Callable) -> Callable:
func : callable
Function to apply to datasets with signature:
`func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`.
`func(*args, **kwargs) -> Union[DataTree, Iterable[DataTree]]`.
(i.e. func must accept at least one Dataset and return at least one Dataset.)
Function will not be applied to any nodes without datasets.
Expand Down Expand Up @@ -258,19 +257,18 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]:
return _map_over_subtree


def _handle_errors_with_path_context(path):
def _handle_errors_with_path_context(path: str):
"""Wraps given function so that if it fails it also raises path to node on which it failed."""

def decorator(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
if sys.version_info >= (3, 11):
# Add the context information to the error message
e.add_note(
f"Raised whilst mapping function over node with path {path}"
)
# Add the context information to the error message
add_note(
e, f"Raised whilst mapping function over node with path {path}"
)
raise

return wrapper
Expand All @@ -286,7 +284,9 @@ def add_note(err: BaseException, msg: str) -> None:
err.add_note(msg)


def _check_single_set_return_values(path_to_node, obj):
def _check_single_set_return_values(
path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray]
):
"""Check types returned from single evaluation of func, and return number of return values received from func."""
if isinstance(obj, (Dataset, DataArray)):
return 1
Expand Down
3 changes: 1 addition & 2 deletions xarray/core/iterators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from collections import abc
from collections.abc import Iterator
from typing import Callable

Expand All @@ -9,7 +8,7 @@
"""These iterators are copied from anytree.iterators, with minor modifications."""


class LevelOrderIter(abc.Iterator):
class LevelOrderIter(Iterator):
"""Iterate over tree applying level-order strategy starting at `node`.
This is the iterator used by `DataTree` to traverse nodes.
Expand Down

0 comments on commit 4ab8e78

Please sign in to comment.