From 33ead654e84511fb409218407852d7d146d16763 Mon Sep 17 00:00:00 2001 From: Tom Nicholas Date: Tue, 15 Oct 2024 02:10:26 -0600 Subject: [PATCH] map_over_subtree -> map_over_datasets (#9622) --- doc/api.rst | 4 +-- doc/user-guide/hierarchical-data.rst | 8 ++--- doc/whats-new.rst | 2 +- xarray/__init__.py | 4 +-- xarray/core/datatree.py | 14 ++++---- xarray/core/datatree_mapping.py | 10 +++--- xarray/core/datatree_ops.py | 14 ++++---- xarray/tests/test_datatree.py | 6 ++-- xarray/tests/test_datatree_mapping.py | 46 +++++++++++++-------------- 9 files changed, 54 insertions(+), 54 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 2e671f1de69..40e05035d11 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -687,7 +687,7 @@ For manipulating, traversing, navigating, or mapping over the tree structure. DataTree.relative_to DataTree.iter_lineage DataTree.find_common_ancestor - DataTree.map_over_subtree + DataTree.map_over_datasets DataTree.pipe DataTree.match DataTree.filter @@ -954,7 +954,7 @@ DataTree methods open_datatree open_groups - map_over_subtree + map_over_datasets DataTree.to_dict DataTree.to_netcdf DataTree.to_zarr diff --git a/doc/user-guide/hierarchical-data.rst b/doc/user-guide/hierarchical-data.rst index 4b3a7260567..cdd30f9a302 100644 --- a/doc/user-guide/hierarchical-data.rst +++ b/doc/user-guide/hierarchical-data.rst @@ -549,13 +549,13 @@ See that the same change (fast-forwarding by adding 10 years to the age of each Mapping Custom Functions Over Trees ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -You can map custom computation over each node in a tree using :py:meth:`xarray.DataTree.map_over_subtree`. +You can map custom computation over each node in a tree using :py:meth:`xarray.DataTree.map_over_datasets`. You can map any function, so long as it takes :py:class:`xarray.Dataset` objects as one (or more) of the input arguments, and returns one (or more) xarray datasets. .. note:: - Functions passed to :py:func:`~xarray.DataTree.map_over_subtree` cannot alter nodes in-place. + Functions passed to :py:func:`~xarray.DataTree.map_over_datasets` cannot alter nodes in-place. Instead they must return new :py:class:`xarray.Dataset` objects. For example, we can define a function to calculate the Root Mean Square of a timeseries @@ -569,11 +569,11 @@ Then calculate the RMS value of these signals: .. ipython:: python - voltages.map_over_subtree(rms) + voltages.map_over_datasets(rms) .. _multiple trees: -We can also use the :py:meth:`~xarray.map_over_subtree` decorator to promote a function which accepts datasets into one which +We can also use the :py:meth:`~xarray.map_over_datasets` decorator to promote a function which accepts datasets into one which accepts datatrees. Operating on Multiple Trees diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b374721c8ee..deb8cc9bdc3 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -23,7 +23,7 @@ New Features ~~~~~~~~~~~~ - ``DataTree`` related functionality is now exposed in the main ``xarray`` public API. This includes: ``xarray.DataTree``, ``xarray.open_datatree``, ``xarray.open_groups``, - ``xarray.map_over_subtree``, ``xarray.register_datatree_accessor`` and + ``xarray.map_over_datasets``, ``xarray.register_datatree_accessor`` and ``xarray.testing.assert_isomorphic``. By `Owen Littlejohns `_, `Eni Awowale `_, diff --git a/xarray/__init__.py b/xarray/__init__.py index e3b7ec469e9..1e1bfe9a770 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -34,7 +34,7 @@ from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset from xarray.core.datatree import DataTree -from xarray.core.datatree_mapping import TreeIsomorphismError, map_over_subtree +from xarray.core.datatree_mapping import TreeIsomorphismError, map_over_datasets from xarray.core.extensions import ( register_dataarray_accessor, register_dataset_accessor, @@ -86,7 +86,7 @@ "load_dataarray", "load_dataset", "map_blocks", - "map_over_subtree", + "map_over_datasets", "merge", "ones_like", "open_dataarray", diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index ec550cf4136..e503b5c0741 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -23,7 +23,7 @@ from xarray.core.datatree_mapping import ( TreeIsomorphismError, check_isomorphic, - map_over_subtree, + map_over_datasets, ) from xarray.core.formatting import datatree_repr, dims_and_coords_repr from xarray.core.formatting_html import ( @@ -254,14 +254,14 @@ def _constructor( def __setitem__(self, key, val) -> None: raise AttributeError( "Mutation of the DatasetView is not allowed, please use `.__setitem__` on the wrapping DataTree node, " - "or use `dt.to_dataset()` if you want a mutable dataset. If calling this from within `map_over_subtree`," + "or use `dt.to_dataset()` if you want a mutable dataset. If calling this from within `map_over_datasets`," "use `.copy()` first to get a mutable version of the input dataset." ) def update(self, other) -> NoReturn: raise AttributeError( "Mutation of the DatasetView is not allowed, please use `.update` on the wrapping DataTree node, " - "or use `dt.to_dataset()` if you want a mutable dataset. If calling this from within `map_over_subtree`," + "or use `dt.to_dataset()` if you want a mutable dataset. If calling this from within `map_over_datasets`," "use `.copy()` first to get a mutable version of the input dataset." ) @@ -1330,7 +1330,7 @@ def filter(self: DataTree, filterfunc: Callable[[DataTree], bool]) -> DataTree: -------- match pipe - map_over_subtree + map_over_datasets """ filtered_nodes = { node.path: node.dataset for node in self.subtree if filterfunc(node) @@ -1356,7 +1356,7 @@ def match(self, pattern: str) -> DataTree: -------- filter pipe - map_over_subtree + map_over_datasets Examples -------- @@ -1383,7 +1383,7 @@ def match(self, pattern: str) -> DataTree: } return DataTree.from_dict(matching_nodes, name=self.root.name) - def map_over_subtree( + def map_over_datasets( self, func: Callable, *args: Iterable[Any], @@ -1417,7 +1417,7 @@ def map_over_subtree( # TODO this signature means that func has no way to know which node it is being called upon - change? # TODO fix this typing error - return map_over_subtree(func)(self, *args, **kwargs) + return map_over_datasets(func)(self, *args, **kwargs) def pipe( self, func: Callable | tuple[Callable, str], *args: Any, **kwargs: Any diff --git a/xarray/core/datatree_mapping.py b/xarray/core/datatree_mapping.py index 78abf42601d..fc26556b0e4 100644 --- a/xarray/core/datatree_mapping.py +++ b/xarray/core/datatree_mapping.py @@ -72,7 +72,7 @@ def check_isomorphic( raise TreeIsomorphismError("DataTree objects are not isomorphic:\n" + diff) -def map_over_subtree(func: Callable) -> Callable: +def map_over_datasets(func: Callable) -> Callable: """ Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees. @@ -112,8 +112,8 @@ def map_over_subtree(func: Callable) -> Callable: See also -------- - DataTree.map_over_subtree - DataTree.map_over_subtree_inplace + DataTree.map_over_datasets + DataTree.map_over_datasets_inplace DataTree.subtree """ @@ -122,7 +122,7 @@ def map_over_subtree(func: Callable) -> Callable: # TODO inspect function to work out immediately if the wrong number of arguments were passed for it? @functools.wraps(func) - def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]: + def _map_over_datasets(*args, **kwargs) -> DataTree | tuple[DataTree, ...]: """Internal function which maps func over every node in tree, returning a tree of the results.""" from xarray.core.datatree import DataTree @@ -227,7 +227,7 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]: else: return tuple(result_trees) - return _map_over_subtree + return _map_over_datasets def _handle_errors_with_path_context(path: str): diff --git a/xarray/core/datatree_ops.py b/xarray/core/datatree_ops.py index 693b0a402b9..69c1b9e9082 100644 --- a/xarray/core/datatree_ops.py +++ b/xarray/core/datatree_ops.py @@ -4,7 +4,7 @@ import textwrap from xarray.core.dataset import Dataset -from xarray.core.datatree_mapping import map_over_subtree +from xarray.core.datatree_mapping import map_over_datasets """ Module which specifies the subset of xarray.Dataset's API which we wish to copy onto DataTree. @@ -17,7 +17,7 @@ _MAPPED_DOCSTRING_ADDENDUM = ( "This method was copied from :py:class:`xarray.Dataset`, but has been altered to " "call the method on the Datasets stored in every node of the subtree. " - "See the `map_over_subtree` function for more details." + "See the `map_over_datasets` function for more details." ) # TODO equals, broadcast_equals etc. @@ -174,7 +174,7 @@ def _wrap_then_attach_to_cls( target_cls_dict, source_cls, methods_to_set, wrap_func=None ): """ - Attach given methods on a class, and optionally wrap each method first. (i.e. with map_over_subtree). + Attach given methods on a class, and optionally wrap each method first. (i.e. with map_over_datasets). Result is like having written this in the classes' definition: ``` @@ -206,7 +206,7 @@ def method_name(self, *args, **kwargs): ) target_cls_dict[method_name] = wrapped_method - if wrap_func is map_over_subtree: + if wrap_func is map_over_datasets: # Add a paragraph to the method's docstring explaining how it's been mapped orig_method_docstring = orig_method.__doc__ @@ -277,7 +277,7 @@ class MappedDatasetMethodsMixin: target_cls_dict=vars(), source_cls=Dataset, methods_to_set=_ALL_DATASET_METHODS_TO_MAP, - wrap_func=map_over_subtree, + wrap_func=map_over_datasets, ) @@ -291,7 +291,7 @@ class MappedDataWithCoords: target_cls_dict=vars(), source_cls=Dataset, methods_to_set=_DATA_WITH_COORDS_METHODS_TO_MAP, - wrap_func=map_over_subtree, + wrap_func=map_over_datasets, ) @@ -305,5 +305,5 @@ class DataTreeArithmeticMixin: target_cls_dict=vars(), source_cls=Dataset, methods_to_set=_ARITHMETIC_METHODS_TO_MAP, - wrap_func=map_over_subtree, + wrap_func=map_over_datasets, ) diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index ec5ce4e40c3..69c6566f88c 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -1805,7 +1805,7 @@ def test_tree(self, create_test_datatree): class TestDocInsertion: - """Tests map_over_subtree docstring injection.""" + """Tests map_over_datasets docstring injection.""" def test_standard_doc(self): @@ -1839,7 +1839,7 @@ def test_standard_doc(self): .. note:: This method was copied from :py:class:`xarray.Dataset`, but has been altered to call the method on the Datasets stored in every - node of the subtree. See the `map_over_subtree` function for more + node of the subtree. See the `map_over_datasets` function for more details. Normally, it should not be necessary to call this method in user code, @@ -1870,7 +1870,7 @@ def test_one_liner(self): This method was copied from :py:class:`xarray.Dataset`, but has been altered to call the method on the Datasets stored in every node of the subtree. See - the `map_over_subtree` function for more details.""" + the `map_over_datasets` function for more details.""" ) actual_doc = insert_doc_addendum(mixin_doc, _MAPPED_DOCSTRING_ADDENDUM) diff --git a/xarray/tests/test_datatree_mapping.py b/xarray/tests/test_datatree_mapping.py index 9ce882bf67a..766df76a259 100644 --- a/xarray/tests/test_datatree_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -5,7 +5,7 @@ from xarray.core.datatree_mapping import ( TreeIsomorphismError, check_isomorphic, - map_over_subtree, + map_over_datasets, ) from xarray.testing import assert_equal, assert_identical @@ -92,7 +92,7 @@ def test_checking_from_root(self, create_test_datatree): class TestMapOverSubTree: def test_no_trees_passed(self): - @map_over_subtree + @map_over_datasets def times_ten(ds): return 10.0 * ds @@ -104,7 +104,7 @@ def test_not_isomorphic(self, create_test_datatree): dt2 = create_test_datatree() dt2["set1/set2/extra"] = xr.DataTree(name="extra") - @map_over_subtree + @map_over_datasets def times_ten(ds1, ds2): return ds1 * ds2 @@ -115,7 +115,7 @@ def test_no_trees_returned(self, create_test_datatree): dt1 = create_test_datatree() dt2 = create_test_datatree() - @map_over_subtree + @map_over_datasets def bad_func(ds1, ds2): return None @@ -125,7 +125,7 @@ def bad_func(ds1, ds2): def test_single_dt_arg(self, create_test_datatree): dt = create_test_datatree() - @map_over_subtree + @map_over_datasets def times_ten(ds): return 10.0 * ds @@ -136,7 +136,7 @@ def times_ten(ds): def test_single_dt_arg_plus_args_and_kwargs(self, create_test_datatree): dt = create_test_datatree() - @map_over_subtree + @map_over_datasets def multiply_then_add(ds, times, add=0.0): return (times * ds) + add @@ -148,7 +148,7 @@ def test_multiple_dt_args(self, create_test_datatree): dt1 = create_test_datatree() dt2 = create_test_datatree() - @map_over_subtree + @map_over_datasets def add(ds1, ds2): return ds1 + ds2 @@ -160,7 +160,7 @@ def test_dt_as_kwarg(self, create_test_datatree): dt1 = create_test_datatree() dt2 = create_test_datatree() - @map_over_subtree + @map_over_datasets def add(ds1, value=0.0): return ds1 + value @@ -171,7 +171,7 @@ def add(ds1, value=0.0): def test_return_multiple_dts(self, create_test_datatree): dt = create_test_datatree() - @map_over_subtree + @map_over_datasets def minmax(ds): return ds.min(), ds.max() @@ -184,7 +184,7 @@ def minmax(ds): def test_return_wrong_type(self, simple_datatree): dt1 = simple_datatree - @map_over_subtree + @map_over_datasets def bad_func(ds1): return "string" @@ -194,7 +194,7 @@ def bad_func(ds1): def test_return_tuple_of_wrong_types(self, simple_datatree): dt1 = simple_datatree - @map_over_subtree + @map_over_datasets def bad_func(ds1): return xr.Dataset(), "string" @@ -205,7 +205,7 @@ def bad_func(ds1): def test_return_inconsistent_number_of_results(self, simple_datatree): dt1 = simple_datatree - @map_over_subtree + @map_over_datasets def bad_func(ds): # Datasets in simple_datatree have different numbers of dims # TODO need to instead return different numbers of Dataset objects for this test to catch the intended error @@ -217,7 +217,7 @@ def bad_func(ds): def test_wrong_number_of_arguments_for_func(self, simple_datatree): dt = simple_datatree - @map_over_subtree + @map_over_datasets def times_ten(ds): return 10.0 * ds @@ -229,7 +229,7 @@ def times_ten(ds): def test_map_single_dataset_against_whole_tree(self, create_test_datatree): dt = create_test_datatree() - @map_over_subtree + @map_over_datasets def nodewise_merge(node_ds, fixed_ds): return xr.merge([node_ds, fixed_ds]) @@ -250,7 +250,7 @@ def multiply_then_add(ds, times, add=0.0): return times * ds + add expected = create_test_datatree(modify=lambda ds: (10.0 * ds) + 2.0) - result_tree = dt.map_over_subtree(multiply_then_add, 10.0, add=2.0) + result_tree = dt.map_over_datasets(multiply_then_add, 10.0, add=2.0) assert_equal(result_tree, expected) def test_discard_ancestry(self, create_test_datatree): @@ -258,7 +258,7 @@ def test_discard_ancestry(self, create_test_datatree): dt = create_test_datatree() subtree = dt["set1"] - @map_over_subtree + @map_over_datasets def times_ten(ds): return 10.0 * ds @@ -276,7 +276,7 @@ def check_for_data(ds): assert len(ds.variables) != 0 return ds - dt.map_over_subtree(check_for_data) + dt.map_over_datasets(check_for_data) def test_keep_attrs_on_empty_nodes(self, create_test_datatree): # GH278 @@ -286,7 +286,7 @@ def test_keep_attrs_on_empty_nodes(self, create_test_datatree): def empty_func(ds): return ds - result = dt.map_over_subtree(empty_func) + result = dt.map_over_datasets(empty_func) assert result["set1/set2"].attrs == dt["set1/set2"].attrs @pytest.mark.xfail( @@ -304,13 +304,13 @@ def fail_on_specific_node(ds): with pytest.raises( ValueError, match="Raised whilst mapping function over node /set1" ): - dt.map_over_subtree(fail_on_specific_node) + dt.map_over_datasets(fail_on_specific_node) def test_inherited_coordinates_with_index(self): root = xr.Dataset(coords={"x": [1, 2]}) child = xr.Dataset({"foo": ("x", [0, 1])}) # no coordinates tree = xr.DataTree.from_dict({"/": root, "/child": child}) - actual = tree.map_over_subtree(lambda ds: ds) # identity + actual = tree.map_over_datasets(lambda ds: ds) # identity assert isinstance(actual, xr.DataTree) assert_identical(tree, actual) @@ -338,7 +338,7 @@ def test_construct_using_type(self): def weighted_mean(ds): return ds.weighted(ds.area).mean(["x", "y"]) - dt.map_over_subtree(weighted_mean) + dt.map_over_datasets(weighted_mean) def test_alter_inplace_forbidden(self): simpsons = xr.DataTree.from_dict( @@ -359,10 +359,10 @@ def fast_forward(ds: xr.Dataset, years: float) -> xr.Dataset: return ds with pytest.raises(AttributeError): - simpsons.map_over_subtree(fast_forward, years=10) + simpsons.map_over_datasets(fast_forward, years=10) @pytest.mark.xfail class TestMapOverSubTreeInplace: - def test_map_over_subtree_inplace(self): + def test_map_over_datasets_inplace(self): raise NotImplementedError