From b004af5174a4b0e32519df792a4f625d5548a9f0 Mon Sep 17 00:00:00 2001 From: Tom Nicholas Date: Sun, 14 Apr 2024 07:14:42 -0400 Subject: [PATCH 1/7] Correct save_mfdataset docstring (#8934) * correct docstring * get type hint right --- xarray/backends/api.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 2f73c38341b..2589ff196f9 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -1416,8 +1416,6 @@ def save_mfdataset( these locations will be overwritten. format : {"NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", \ "NETCDF3_CLASSIC"}, optional - **kwargs : additional arguments are passed along to ``to_netcdf`` - File format for the resulting netCDF file: * NETCDF4: Data is stored in an HDF5 file, using netCDF4 API @@ -1449,10 +1447,11 @@ def save_mfdataset( compute : bool If true compute immediately, otherwise return a ``dask.delayed.Delayed`` object that can be computed later. + **kwargs : dict, optional + Additional arguments are passed along to ``to_netcdf``. Examples -------- - Save a dataset into one netCDF per year of data: >>> ds = xr.Dataset( From 2b2de81ba55a6b302fc4a5d174d7c2b2f05ca9d4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 15 Apr 2024 12:16:37 -0700 Subject: [PATCH 2/7] Bump codecov/codecov-action from 4.2.0 to 4.3.0 in the actions group (#8943) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/ci-additional.yaml | 8 ++++---- .github/workflows/ci.yaml | 2 +- .github/workflows/upstream-dev-ci.yaml | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index bb39f1875e9..0fcba9d5499 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -127,7 +127,7 @@ jobs: python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report xarray/ - name: Upload mypy coverage to Codecov - uses: codecov/codecov-action@v4.2.0 + uses: codecov/codecov-action@v4.3.0 with: file: mypy_report/cobertura.xml flags: mypy @@ -181,7 +181,7 @@ jobs: python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report xarray/ - name: Upload mypy coverage to Codecov - uses: codecov/codecov-action@v4.2.0 + uses: codecov/codecov-action@v4.3.0 with: file: mypy_report/cobertura.xml flags: mypy39 @@ -242,7 +242,7 @@ jobs: python -m pyright xarray/ - name: Upload pyright coverage to Codecov - uses: codecov/codecov-action@v4.2.0 + uses: codecov/codecov-action@v4.3.0 with: file: pyright_report/cobertura.xml flags: pyright @@ -301,7 +301,7 @@ jobs: python -m pyright xarray/ - name: Upload pyright coverage to Codecov - uses: codecov/codecov-action@v4.2.0 + uses: codecov/codecov-action@v4.3.0 with: file: pyright_report/cobertura.xml flags: pyright39 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 9724c18436e..7e097fa5c45 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -151,7 +151,7 @@ jobs: path: pytest.xml - name: Upload code coverage to Codecov - uses: codecov/codecov-action@v4.2.0 + uses: codecov/codecov-action@v4.3.0 with: file: ./coverage.xml flags: unittests diff --git a/.github/workflows/upstream-dev-ci.yaml b/.github/workflows/upstream-dev-ci.yaml index 9d43c575b1a..4f3d199dd2d 100644 --- a/.github/workflows/upstream-dev-ci.yaml +++ b/.github/workflows/upstream-dev-ci.yaml @@ -143,7 +143,7 @@ jobs: run: | python -m mypy --install-types --non-interactive --cobertura-xml-report mypy_report - name: Upload mypy coverage to Codecov - uses: codecov/codecov-action@v4.2.0 + uses: codecov/codecov-action@v4.3.0 with: file: mypy_report/cobertura.xml flags: mypy From 239309f881ba0d7e02280147bc443e6e286e6a63 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Tue, 16 Apr 2024 10:53:41 -0400 Subject: [PATCH 3/7] Convert 360_day calendars by choosing random dates to drop or add (#8603) * Convert 360 calendar randomly * add note to whats new * add pull number to whats new entry * run pre-commit * Change test to use recommended freq * Apply suggestions from code review Co-authored-by: Spencer Clark * Fix merge - remove rng arg --------- Co-authored-by: Spencer Clark --- doc/whats-new.rst | 2 ++ xarray/coding/calendar_ops.py | 53 ++++++++++++++++++++++++++----- xarray/tests/test_calendar_ops.py | 39 +++++++++++++++++++++++ 3 files changed, 86 insertions(+), 8 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 41b39bf2b1a..24dcb8c9687 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,6 +22,8 @@ v2024.04.0 (unreleased) New Features ~~~~~~~~~~~~ +- New "random" method for converting to and from 360_day calendars (:pull:`8603`). + By `Pascal Bourgault `_. Breaking changes diff --git a/xarray/coding/calendar_ops.py b/xarray/coding/calendar_ops.py index dc2f95b832e..c4fe9e1f4ae 100644 --- a/xarray/coding/calendar_ops.py +++ b/xarray/coding/calendar_ops.py @@ -64,7 +64,7 @@ def convert_calendar( The target calendar name. dim : str Name of the time coordinate in the input DataArray or Dataset. - align_on : {None, 'date', 'year'} + align_on : {None, 'date', 'year', 'random'} Must be specified when either the source or target is a `"360_day"` calendar; ignored otherwise. See Notes. missing : any, optional @@ -143,6 +143,16 @@ def convert_calendar( will be dropped as there are no equivalent dates in a standard calendar. This option is best used with data on a frequency coarser than daily. + + "random" + Similar to "year", each day of year of the source is mapped to another day of year + of the target. However, instead of having always the same missing days according + the source and target years, here 5 days are chosen randomly, one for each fifth + of the year. However, February 29th is always missing when converting to a leap year, + or its value is dropped when converting from a leap year. This is similar to the method + used in the LOCA dataset (see Pierce, Cayan, and Thrasher (2014). doi:10.1175/JHM-D-14-0082.1). + + This option is best used on daily data. """ from xarray.core.dataarray import DataArray @@ -174,14 +184,20 @@ def convert_calendar( out = obj.copy() - if align_on == "year": + if align_on in ["year", "random"]: # Special case for conversion involving 360_day calendar - # Instead of translating dates directly, this tries to keep the position within a year similar. - - new_doy = time.groupby(f"{dim}.year").map( - _interpolate_day_of_year, target_calendar=calendar, use_cftime=use_cftime - ) - + if align_on == "year": + # Instead of translating dates directly, this tries to keep the position within a year similar. + new_doy = time.groupby(f"{dim}.year").map( + _interpolate_day_of_year, + target_calendar=calendar, + use_cftime=use_cftime, + ) + elif align_on == "random": + # The 5 days to remove are randomly chosen, one for each of the five 72-days periods of the year. + new_doy = time.groupby(f"{dim}.year").map( + _random_day_of_year, target_calendar=calendar, use_cftime=use_cftime + ) # Convert the source datetimes, but override the day of year with our new day of years. out[dim] = DataArray( [ @@ -229,6 +245,27 @@ def _interpolate_day_of_year(time, target_calendar, use_cftime): ).astype(int) +def _random_day_of_year(time, target_calendar, use_cftime): + """Return a day of year in the new calendar. + + Removes Feb 29th and five other days chosen randomly within five sections of 72 days. + """ + year = int(time.dt.year[0]) + source_calendar = time.dt.calendar + new_doy = np.arange(360) + 1 + rm_idx = np.random.default_rng().integers(0, 72, 5) + 72 * np.arange(5) + if source_calendar == "360_day": + for idx in rm_idx: + new_doy[idx + 1 :] = new_doy[idx + 1 :] + 1 + if _days_in_year(year, target_calendar, use_cftime) == 366: + new_doy[new_doy >= 60] = new_doy[new_doy >= 60] + 1 + elif target_calendar == "360_day": + new_doy = np.insert(new_doy, rm_idx - np.arange(5), -1) + if _days_in_year(year, source_calendar, use_cftime) == 366: + new_doy = np.insert(new_doy, 60, -1) + return new_doy[time.dt.dayofyear - 1] + + def _convert_to_new_calendar_with_new_day_of_year( date, day_of_year, calendar, use_cftime ): diff --git a/xarray/tests/test_calendar_ops.py b/xarray/tests/test_calendar_ops.py index d2792034876..5a57083fd22 100644 --- a/xarray/tests/test_calendar_ops.py +++ b/xarray/tests/test_calendar_ops.py @@ -106,6 +106,45 @@ def test_convert_calendar_360_days(source, target, freq, align_on): assert conv.size == 359 if freq == "D" else 359 * 4 +def test_convert_calendar_360_days_random(): + da_std = DataArray( + np.linspace(0, 1, 366), + dims=("time",), + coords={ + "time": date_range( + "2004-01-01", + "2004-12-31", + freq="D", + calendar="standard", + use_cftime=False, + ) + }, + ) + da_360 = DataArray( + np.linspace(0, 1, 360), + dims=("time",), + coords={ + "time": date_range("2004-01-01", "2004-12-30", freq="D", calendar="360_day") + }, + ) + + conv = convert_calendar(da_std, "360_day", align_on="random") + conv2 = convert_calendar(da_std, "360_day", align_on="random") + assert (conv != conv2).any() + + conv = convert_calendar(da_360, "standard", use_cftime=False, align_on="random") + assert np.datetime64("2004-02-29") not in conv.time + conv2 = convert_calendar(da_360, "standard", use_cftime=False, align_on="random") + assert (conv2 != conv).any() + + # Ensure that added days are evenly distributed in the 5 fifths of each year + conv = convert_calendar(da_360, "noleap", align_on="random", missing=np.NaN) + conv = conv.where(conv.isnull(), drop=True) + nandoys = conv.time.dt.dayofyear[:366] + assert all(nandoys < np.array([74, 147, 220, 293, 366])) + assert all(nandoys > np.array([0, 73, 146, 219, 292])) + + @requires_cftime @pytest.mark.parametrize( "source,target,freq", From 9a37053114d9950d6ca09a1ac0ded40eca521778 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Wed, 17 Apr 2024 09:39:22 -0700 Subject: [PATCH 4/7] Add mypy to dev dependencies (#8947) --- pyproject.toml | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bdbfd9b52ab..15a3c99eec2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ accel = ["scipy", "bottleneck", "numbagg", "flox", "opt_einsum"] complete = ["xarray[accel,io,parallel,viz,dev]"] dev = [ "hypothesis", + "mypy", "pre-commit", "pytest", "pytest-cov", @@ -86,8 +87,8 @@ exclude_lines = ["pragma: no cover", "if TYPE_CHECKING"] [tool.mypy] enable_error_code = "redundant-self" exclude = [ - 'xarray/util/generate_.*\.py', - 'xarray/datatree_/.*\.py', + 'xarray/util/generate_.*\.py', + 'xarray/datatree_/.*\.py', ] files = "xarray" show_error_codes = true @@ -98,8 +99,8 @@ warn_unused_ignores = true # Ignore mypy errors for modules imported from datatree_. [[tool.mypy.overrides]] -module = "xarray.datatree_.*" ignore_errors = true +module = "xarray.datatree_.*" # Much of the numerical computing stack doesn't have type annotations yet. [[tool.mypy.overrides]] @@ -255,6 +256,9 @@ target-version = "py39" # E402: module level import not at top of file # E501: line too long - let black worry about that # E731: do not assign a lambda expression, use a def +extend-safe-fixes = [ + "TID252", # absolute imports +] ignore = [ "E402", "E501", @@ -268,9 +272,6 @@ select = [ "I", # isort "UP", # Pyupgrade ] -extend-safe-fixes = [ - "TID252", # absolute imports -] [tool.ruff.lint.per-file-ignores] # don't enforce absolute imports From 60f3e741d463840de8409fb4c6cec41de7f7ce05 Mon Sep 17 00:00:00 2001 From: owenlittlejohns Date: Wed, 17 Apr 2024 13:59:34 -0600 Subject: [PATCH 5/7] Migrate datatree mapping.py (#8948) * DAS-2064: rename/relocate mapping.py -> xarray.core.datatree_mapping.py DAS-2064: fix circular import issue. * DAS-2064 - Minor changes to datatree_mapping.py. --------- Co-authored-by: Matt Savoie --- doc/whats-new.rst | 3 ++ xarray/core/datatree.py | 10 +++--- .../mapping.py => core/datatree_mapping.py} | 33 +++++++++---------- xarray/core/iterators.py | 3 +- xarray/datatree_/datatree/__init__.py | 3 -- xarray/datatree_/datatree/formatting.py | 2 +- xarray/datatree_/datatree/ops.py | 2 +- .../test_datatree_mapping.py} | 12 ++++--- 8 files changed, 35 insertions(+), 33 deletions(-) rename xarray/{datatree_/datatree/mapping.py => core/datatree_mapping.py} (94%) rename xarray/{datatree_/datatree/tests/test_mapping.py => tests/test_datatree_mapping.py} (98%) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 24dcb8c9687..1c67ba1ee7f 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -36,6 +36,9 @@ Bug fixes Internal Changes ~~~~~~~~~~~~~~~~ +- Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`) + By `Matt Savoie `_ `Owen Littlejohns + ` and `Tom Nicholas `_. .. _whats-new.2024.03.0: diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 1b06d87c9fb..e24dfb504b1 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -18,6 +18,11 @@ from xarray.core.coordinates import DatasetCoordinates from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset, DataVariables +from xarray.core.datatree_mapping import ( + TreeIsomorphismError, + check_isomorphic, + map_over_subtree, +) from xarray.core.indexes import Index, Indexes from xarray.core.merge import dataset_update_method from xarray.core.options import OPTIONS as XR_OPTS @@ -36,11 +41,6 @@ from xarray.datatree_.datatree.formatting_html import ( datatree_repr as datatree_repr_html, ) -from xarray.datatree_.datatree.mapping import ( - TreeIsomorphismError, - check_isomorphic, - map_over_subtree, -) from xarray.datatree_.datatree.ops import ( DataTreeArithmeticMixin, MappedDatasetMethodsMixin, diff --git a/xarray/datatree_/datatree/mapping.py b/xarray/core/datatree_mapping.py similarity index 94% rename from xarray/datatree_/datatree/mapping.py rename to xarray/core/datatree_mapping.py index 9546905e1ac..714921d2a90 100644 --- a/xarray/datatree_/datatree/mapping.py +++ b/xarray/core/datatree_mapping.py @@ -4,10 +4,9 @@ import sys from itertools import repeat from textwrap import dedent -from typing import TYPE_CHECKING, Callable, Tuple +from typing import TYPE_CHECKING, Callable from xarray import DataArray, Dataset - from xarray.core.iterators import LevelOrderIter from xarray.core.treenode import NodePath, TreeNode @@ -84,14 +83,13 @@ def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> s for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)): path_a, path_b = node_a.path, node_b.path - if require_names_equal: - if node_a.name != node_b.name: - diff = dedent( - f"""\ + if require_names_equal and node_a.name != node_b.name: + diff = dedent( + f"""\ Node '{path_a}' in the left object has name '{node_a.name}' Node '{path_b}' in the right object has name '{node_b.name}'""" - ) - return diff + ) + return diff if len(node_a.children) != len(node_b.children): diff = dedent( @@ -125,7 +123,7 @@ def map_over_subtree(func: Callable) -> Callable: func : callable Function to apply to datasets with signature: - `func(*args, **kwargs) -> Union[Dataset, Iterable[Dataset]]`. + `func(*args, **kwargs) -> Union[DataTree, Iterable[DataTree]]`. (i.e. func must accept at least one Dataset and return at least one Dataset.) Function will not be applied to any nodes without datasets. @@ -154,7 +152,7 @@ def map_over_subtree(func: Callable) -> Callable: # TODO inspect function to work out immediately if the wrong number of arguments were passed for it? @functools.wraps(func) - def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]: + def _map_over_subtree(*args, **kwargs) -> DataTree | tuple[DataTree, ...]: """Internal function which maps func over every node in tree, returning a tree of the results.""" from xarray.core.datatree import DataTree @@ -259,7 +257,7 @@ def _map_over_subtree(*args, **kwargs) -> DataTree | Tuple[DataTree, ...]: return _map_over_subtree -def _handle_errors_with_path_context(path): +def _handle_errors_with_path_context(path: str): """Wraps given function so that if it fails it also raises path to node on which it failed.""" def decorator(func): @@ -267,11 +265,10 @@ def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except Exception as e: - if sys.version_info >= (3, 11): - # Add the context information to the error message - e.add_note( - f"Raised whilst mapping function over node with path {path}" - ) + # Add the context information to the error message + add_note( + e, f"Raised whilst mapping function over node with path {path}" + ) raise return wrapper @@ -287,7 +284,9 @@ def add_note(err: BaseException, msg: str) -> None: err.add_note(msg) -def _check_single_set_return_values(path_to_node, obj): +def _check_single_set_return_values( + path_to_node: str, obj: Dataset | DataArray | tuple[Dataset | DataArray] +): """Check types returned from single evaluation of func, and return number of return values received from func.""" if isinstance(obj, (Dataset, DataArray)): return 1 diff --git a/xarray/core/iterators.py b/xarray/core/iterators.py index 5c0c0f652e8..dd5fa7ee97a 100644 --- a/xarray/core/iterators.py +++ b/xarray/core/iterators.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections import abc from collections.abc import Iterator from typing import Callable @@ -9,7 +8,7 @@ """These iterators are copied from anytree.iterators, with minor modifications.""" -class LevelOrderIter(abc.Iterator): +class LevelOrderIter(Iterator): """Iterate over tree applying level-order strategy starting at `node`. This is the iterator used by `DataTree` to traverse nodes. diff --git a/xarray/datatree_/datatree/__init__.py b/xarray/datatree_/datatree/__init__.py index f2603b64641..3159d612913 100644 --- a/xarray/datatree_/datatree/__init__.py +++ b/xarray/datatree_/datatree/__init__.py @@ -1,11 +1,8 @@ # import public API -from .mapping import TreeIsomorphismError, map_over_subtree from xarray.core.treenode import InvalidTreeError, NotFoundInTreeError __all__ = ( - "TreeIsomorphismError", "InvalidTreeError", "NotFoundInTreeError", - "map_over_subtree", ) diff --git a/xarray/datatree_/datatree/formatting.py b/xarray/datatree_/datatree/formatting.py index 9ebee72d4ef..fdd23933ae6 100644 --- a/xarray/datatree_/datatree/formatting.py +++ b/xarray/datatree_/datatree/formatting.py @@ -2,7 +2,7 @@ from xarray.core.formatting import _compat_to_str, diff_dataset_repr -from xarray.datatree_.datatree.mapping import diff_treestructure +from xarray.core.datatree_mapping import diff_treestructure from xarray.datatree_.datatree.render import RenderTree if TYPE_CHECKING: diff --git a/xarray/datatree_/datatree/ops.py b/xarray/datatree_/datatree/ops.py index d6ac4f83e7c..83b9d1b275a 100644 --- a/xarray/datatree_/datatree/ops.py +++ b/xarray/datatree_/datatree/ops.py @@ -2,7 +2,7 @@ from xarray import Dataset -from .mapping import map_over_subtree +from xarray.core.datatree_mapping import map_over_subtree """ Module which specifies the subset of xarray.Dataset's API which we wish to copy onto DataTree. diff --git a/xarray/datatree_/datatree/tests/test_mapping.py b/xarray/tests/test_datatree_mapping.py similarity index 98% rename from xarray/datatree_/datatree/tests/test_mapping.py rename to xarray/tests/test_datatree_mapping.py index c6cd04887c0..16ca726759d 100644 --- a/xarray/datatree_/datatree/tests/test_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -1,9 +1,13 @@ import numpy as np import pytest -import xarray as xr +import xarray as xr from xarray.core.datatree import DataTree -from xarray.datatree_.datatree.mapping import TreeIsomorphismError, check_isomorphic, map_over_subtree +from xarray.core.datatree_mapping import ( + TreeIsomorphismError, + check_isomorphic, + map_over_subtree, +) from xarray.datatree_.datatree.testing import assert_equal empty = xr.Dataset() @@ -12,7 +16,7 @@ class TestCheckTreesIsomorphic: def test_not_a_tree(self): with pytest.raises(TypeError, match="not a tree"): - check_isomorphic("s", 1) + check_isomorphic("s", 1) # type: ignore[arg-type] def test_different_widths(self): dt1 = DataTree.from_dict(d={"a": empty}) @@ -69,7 +73,7 @@ def test_not_isomorphic_complex_tree(self, create_test_datatree): def test_checking_from_root(self, create_test_datatree): dt1 = create_test_datatree() dt2 = create_test_datatree() - real_root = DataTree(name="real root") + real_root: DataTree = DataTree(name="real root") dt2.name = "not_real_root" dt2.parent = real_root with pytest.raises(TreeIsomorphismError): From 9eb180b7b3df869c21518a17c6fa9eb456551d21 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Thu, 18 Apr 2024 14:52:02 +0200 Subject: [PATCH 6/7] (feat): Support for `pandas` `ExtensionArray` (#8723) * (feat): first pass supporting extension arrays * (feat): categorical tests + functionality * (feat): use multiple dispatch for unimplemented ops * (feat): implement (not really) broadcasting * (chore): add more `groupby` tests * (fix): fix more groupby incompatibility * (bug): fix unused categories * (chore): refactor dispatched methods + tests * (fix): shared type should check for extension arrays first and then fall back to numpy * (refactor): tests moved * (chore): more higher level tests * (feat): to/from dataframe * (chore): check for plum import * (fix): `__setitem__`/`__getitem__` * (chore): disallow stacking * (fix): `pyproject.toml` * (fix): `as_shared_type` fix * (chore): add variable tests * (fix): dask + categoricals * (chore): notes/docs * (chore): remove old testing file * (chore): remove ocmmented out code * (fix): import plum dispatch * (refactor): use `is_extension_array_dtype` as much as possible * (refactor): `extension_array`->`array` + move to `indexing` * (refactor): change order of classes * (chore): add small pyarrow test * (fix): fix some mypy issues * (fix): don't register unregisterable method * (fix): appease mypy * (fix): more sensible default implemetations allow most use without `plum` * (fix): handling `pyarrow` tests * (fix): actually do import correctly * (fix): `reduce` condition * (fix): column ordering for dataframes * (refactor): remove encoding business * (refactor): raise error for dask + extension array * (fix): only wrap `ExtensionDuckArray` that has a `.array` which is a pandas extension array * (fix): use duck array equality method, not pandas * (refactor): bye plum! * (fix): `and` to `or` for casting to `ExtensionDuckArray` * (fix): check for class, not type * (fix): only support native endianness * (refactor): no need for superfluous checks in `_maybe_wrap_data` * (chore): clean up docs to no longer reference `plum` * (fix): no longer allow `ExtensionDuckArray` to wrap `ExtensionDuckArray` * (refactor): move `implements` logic to `indexing` * (refactor): `indexing.py` -> `extension_array.py` * (refactor): `ExtensionDuckArray` -> `PandasExtensionArray` * (fix): add writeable property * (fix): don't check writeable for `PandasExtensionArray` * (fix): move check eariler * (refactor): correct guard clause * (chore): remove unnecessary `AttributeError` * (feat): singleton wrapped as array * (feat): remove shared dtype casting * (feat): loop once over `dataframe.items` * (feat): add `__len__` attribute * (fix): ensure constructor recieves `pd.Categorical` * Update xarray/core/extension_array.py Co-authored-by: Deepak Cherian * Update xarray/core/extension_array.py Co-authored-by: Deepak Cherian * (fix): drop condition for categorical corrected * Apply suggestions from code review * (chore): test `chunk` behavior * Update xarray/core/variable.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * (fix): bring back error * (chore): add test for dropping cat for mean * Update whats-new.rst * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Deepak Cherian Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- doc/whats-new.rst | 6 +- properties/test_pandas_roundtrip.py | 4 +- pyproject.toml | 1 + xarray/core/dataset.py | 60 +++++++++--- xarray/core/duck_array_ops.py | 19 +++- xarray/core/extension_array.py | 136 ++++++++++++++++++++++++++++ xarray/core/types.py | 3 + xarray/core/variable.py | 10 ++ xarray/testing/strategies.py | 10 +- xarray/tests/__init__.py | 18 +++- xarray/tests/test_concat.py | 27 +++++- xarray/tests/test_dataset.py | 42 ++++++--- xarray/tests/test_duck_array_ops.py | 108 ++++++++++++++++++++++ xarray/tests/test_groupby.py | 12 ++- xarray/tests/test_merge.py | 2 +- xarray/tests/test_variable.py | 19 ++++ 16 files changed, 434 insertions(+), 43 deletions(-) create mode 100644 xarray/core/extension_array.py diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1c67ba1ee7f..4fe51708646 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -24,7 +24,11 @@ New Features ~~~~~~~~~~~~ - New "random" method for converting to and from 360_day calendars (:pull:`8603`). By `Pascal Bourgault `_. - +- Xarray now makes a best attempt not to coerce :py:class:`pandas.api.extensions.ExtensionArray` to a numpy array + by supporting 1D `ExtensionArray` objects internally where possible. Thus, `Dataset`s initialized with a `pd.Catgeorical`, + for example, will retain the object. However, one cannot do operations that are not possible on the `ExtensionArray` + then, such as broadcasting. + By `Ilan Gold `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/properties/test_pandas_roundtrip.py b/properties/test_pandas_roundtrip.py index 5c0f14976e6..3d87fcce1d9 100644 --- a/properties/test_pandas_roundtrip.py +++ b/properties/test_pandas_roundtrip.py @@ -17,7 +17,9 @@ from hypothesis import given # isort:skip numeric_dtypes = st.one_of( - npst.unsigned_integer_dtypes(), npst.integer_dtypes(), npst.floating_dtypes() + npst.unsigned_integer_dtypes(endianness="="), + npst.integer_dtypes(endianness="="), + npst.floating_dtypes(endianness="="), ) numeric_series = numeric_dtypes.flatmap(lambda dt: pdst.series(dtype=dt)) diff --git a/pyproject.toml b/pyproject.toml index 15a3c99eec2..8cbd395b2a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,6 +130,7 @@ module = [ "opt_einsum.*", "pandas.*", "pooch.*", + "pyarrow.*", "pydap.*", "pytest.*", "scipy.*", diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4c80a47209e..96f3be00995 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -24,6 +24,7 @@ from typing import IO, TYPE_CHECKING, Any, Callable, Generic, Literal, cast, overload import numpy as np +from pandas.api.types import is_extension_array_dtype # remove once numpy 2.0 is the oldest supported version try: @@ -6852,10 +6853,13 @@ def reduce( if ( # Some reduction functions (e.g. std, var) need to run on variables # that don't have the reduce dims: PR5393 - not reduce_dims - or not numeric_only - or np.issubdtype(var.dtype, np.number) - or (var.dtype == np.bool_) + not is_extension_array_dtype(var.dtype) + and ( + not reduce_dims + or not numeric_only + or np.issubdtype(var.dtype, np.number) + or (var.dtype == np.bool_) + ) ): # prefer to aggregate over axis=None rather than # axis=(0, 1) if they will be equivalent, because @@ -7168,13 +7172,37 @@ def to_pandas(self) -> pd.Series | pd.DataFrame: ) def _to_dataframe(self, ordered_dims: Mapping[Any, int]): - columns = [k for k in self.variables if k not in self.dims] + columns_in_order = [k for k in self.variables if k not in self.dims] + non_extension_array_columns = [ + k + for k in columns_in_order + if not is_extension_array_dtype(self.variables[k].data) + ] + extension_array_columns = [ + k + for k in columns_in_order + if is_extension_array_dtype(self.variables[k].data) + ] data = [ self._variables[k].set_dims(ordered_dims).values.reshape(-1) - for k in columns + for k in non_extension_array_columns ] index = self.coords.to_index([*ordered_dims]) - return pd.DataFrame(dict(zip(columns, data)), index=index) + broadcasted_df = pd.DataFrame( + dict(zip(non_extension_array_columns, data)), index=index + ) + for extension_array_column in extension_array_columns: + extension_array = self.variables[extension_array_column].data.array + index = self[self.variables[extension_array_column].dims[0]].data + extension_array_df = pd.DataFrame( + {extension_array_column: extension_array}, + index=self[self.variables[extension_array_column].dims[0]].data, + ) + extension_array_df.index.name = self.variables[extension_array_column].dims[ + 0 + ] + broadcasted_df = broadcasted_df.join(extension_array_df) + return broadcasted_df[columns_in_order] def to_dataframe(self, dim_order: Sequence[Hashable] | None = None) -> pd.DataFrame: """Convert this dataset into a pandas.DataFrame. @@ -7321,11 +7349,13 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: "cannot convert a DataFrame with a non-unique MultiIndex into xarray" ) - # Cast to a NumPy array first, in case the Series is a pandas Extension - # array (which doesn't have a valid NumPy dtype) - # TODO: allow users to control how this casting happens, e.g., by - # forwarding arguments to pandas.Series.to_numpy? - arrays = [(k, np.asarray(v)) for k, v in dataframe.items()] + arrays = [] + extension_arrays = [] + for k, v in dataframe.items(): + if not is_extension_array_dtype(v): + arrays.append((k, np.asarray(v))) + else: + extension_arrays.append((k, v)) indexes: dict[Hashable, Index] = {} index_vars: dict[Hashable, Variable] = {} @@ -7339,6 +7369,8 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: xr_idx = PandasIndex(lev, dim) indexes[dim] = xr_idx index_vars.update(xr_idx.create_variables()) + arrays += [(k, np.asarray(v)) for k, v in extension_arrays] + extension_arrays = [] else: index_name = idx.name if idx.name is not None else "index" dims = (index_name,) @@ -7352,7 +7384,9 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> Self: obj._set_sparse_data_from_dataframe(idx, arrays, dims) else: obj._set_numpy_data_from_dataframe(idx, arrays, dims) - return obj + for name, extension_array in extension_arrays: + obj[name] = (dims, extension_array) + return obj[dataframe.columns] if len(dataframe.columns) else obj def to_dask_dataframe( self, dim_order: Sequence[Hashable] | None = None, set_index: bool = False diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index ef497e78ebf..d95dfa566cc 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -32,6 +32,7 @@ from numpy import concatenate as _concatenate from numpy.lib.stride_tricks import sliding_window_view # noqa from packaging.version import Version +from pandas.api.types import is_extension_array_dtype from xarray.core import dask_array_ops, dtypes, nputils from xarray.core.options import OPTIONS @@ -156,7 +157,7 @@ def isnull(data): return full_like(data, dtype=bool, fill_value=False) else: # at this point, array should have dtype=object - if isinstance(data, np.ndarray): + if isinstance(data, np.ndarray) or is_extension_array_dtype(data): return pandas_isnull(data) else: # Not reachable yet, but intended for use with other duck array @@ -221,9 +222,19 @@ def asarray(data, xp=np): def as_shared_dtype(scalars_or_arrays, xp=np): """Cast a arrays to a shared dtype using xarray's type promotion rules.""" - array_type_cupy = array_type("cupy") - if array_type_cupy and any( - isinstance(x, array_type_cupy) for x in scalars_or_arrays + if any(is_extension_array_dtype(x) for x in scalars_or_arrays): + extension_array_types = [ + x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x) + ] + if len(extension_array_types) == len(scalars_or_arrays) and all( + isinstance(x, type(extension_array_types[0])) for x in extension_array_types + ): + return scalars_or_arrays + raise ValueError( + f"Cannot cast arrays to shared type, found array types {[x.dtype for x in scalars_or_arrays]}" + ) + elif array_type_cupy := array_type("cupy") and any( # noqa: F841 + isinstance(x, array_type_cupy) for x in scalars_or_arrays # noqa: F821 ): import cupy as cp diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py new file mode 100644 index 00000000000..6521e425615 --- /dev/null +++ b/xarray/core/extension_array.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import Callable, Generic + +import numpy as np +import pandas as pd +from pandas.api.types import is_extension_array_dtype + +from xarray.core.types import DTypeLikeSave, T_ExtensionArray + +HANDLED_EXTENSION_ARRAY_FUNCTIONS: dict[Callable, Callable] = {} + + +def implements(numpy_function): + """Register an __array_function__ implementation for MyArray objects.""" + + def decorator(func): + HANDLED_EXTENSION_ARRAY_FUNCTIONS[numpy_function] = func + return func + + return decorator + + +@implements(np.issubdtype) +def __extension_duck_array__issubdtype( + extension_array_dtype: T_ExtensionArray, other_dtype: DTypeLikeSave +) -> bool: + return False # never want a function to think a pandas extension dtype is a subtype of numpy + + +@implements(np.broadcast_to) +def __extension_duck_array__broadcast(arr: T_ExtensionArray, shape: tuple): + if shape[0] == len(arr) and len(shape) == 1: + return arr + raise NotImplementedError("Cannot broadcast 1d-only pandas categorical array.") + + +@implements(np.stack) +def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): + raise NotImplementedError("Cannot stack 1d-only pandas categorical array.") + + +@implements(np.concatenate) +def __extension_duck_array__concatenate( + arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None +) -> T_ExtensionArray: + return type(arrays[0])._concat_same_type(arrays) + + +@implements(np.where) +def __extension_duck_array__where( + condition: np.ndarray, x: T_ExtensionArray, y: T_ExtensionArray +) -> T_ExtensionArray: + if ( + isinstance(x, pd.Categorical) + and isinstance(y, pd.Categorical) + and x.dtype != y.dtype + ): + x = x.add_categories(set(y.categories).difference(set(x.categories))) + y = y.add_categories(set(x.categories).difference(set(y.categories))) + return pd.Series(x).where(condition, pd.Series(y)).array + + +class PandasExtensionArray(Generic[T_ExtensionArray]): + array: T_ExtensionArray + + def __init__(self, array: T_ExtensionArray): + """NEP-18 compliant wrapper for pandas extension arrays. + + Parameters + ---------- + array : T_ExtensionArray + The array to be wrapped upon e.g,. :py:class:`xarray.Variable` creation. + ``` + """ + if not isinstance(array, pd.api.extensions.ExtensionArray): + raise TypeError(f"{array} is not an pandas ExtensionArray.") + self.array = array + + def __array_function__(self, func, types, args, kwargs): + def replace_duck_with_extension_array(args) -> list: + args_as_list = list(args) + for index, value in enumerate(args_as_list): + if isinstance(value, PandasExtensionArray): + args_as_list[index] = value.array + elif isinstance( + value, tuple + ): # should handle more than just tuple? iterable? + args_as_list[index] = tuple( + replace_duck_with_extension_array(value) + ) + elif isinstance(value, list): + args_as_list[index] = replace_duck_with_extension_array(value) + return args_as_list + + args = tuple(replace_duck_with_extension_array(args)) + if func not in HANDLED_EXTENSION_ARRAY_FUNCTIONS: + return func(*args, **kwargs) + res = HANDLED_EXTENSION_ARRAY_FUNCTIONS[func](*args, **kwargs) + if is_extension_array_dtype(res): + return type(self)[type(res)](res) + return res + + def __array_ufunc__(ufunc, method, *inputs, **kwargs): + return ufunc(*inputs, **kwargs) + + def __repr__(self): + return f"{type(self)}(array={repr(self.array)})" + + def __getattr__(self, attr: str) -> object: + return getattr(self.array, attr) + + def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: + item = self.array[key] + if is_extension_array_dtype(item): + return type(self)(item) + if np.isscalar(item): + return type(self)(type(self.array)([item])) + return item + + def __setitem__(self, key, val): + self.array[key] = val + + def __eq__(self, other): + if np.isscalar(other): + other = type(self)(type(self.array)([other])) + if isinstance(other, PandasExtensionArray): + return self.array == other.array + return self.array == other + + def __ne__(self, other): + return ~(self == other) + + def __len__(self): + return len(self.array) diff --git a/xarray/core/types.py b/xarray/core/types.py index 410cf3de00b..8f58e54d8cf 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -167,6 +167,9 @@ def copy( # hopefully in the future we can narrow this down more: T_DuckArray = TypeVar("T_DuckArray", bound=Any, covariant=True) +# For typing pandas extension arrays. +T_ExtensionArray = TypeVar("T_ExtensionArray", bound=pd.api.extensions.ExtensionArray) + ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"] VarCompatible = Union["Variable", "ScalarOrArray"] diff --git a/xarray/core/variable.py b/xarray/core/variable.py index e89cf95411c..2229eaa2d24 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -13,11 +13,13 @@ import numpy as np import pandas as pd from numpy.typing import ArrayLike +from pandas.api.types import is_extension_array_dtype import xarray as xr # only for Dataset and DataArray from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils from xarray.core.arithmetic import VariableArithmetic from xarray.core.common import AbstractArray +from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexing import ( BasicIndexer, OuterIndexer, @@ -47,6 +49,7 @@ NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( indexing.ExplicitlyIndexed, pd.Index, + pd.api.extensions.ExtensionArray, ) # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) @@ -184,6 +187,8 @@ def _maybe_wrap_data(data): """ if isinstance(data, pd.Index): return PandasIndexingAdapter(data) + if isinstance(data, pd.api.extensions.ExtensionArray): + return PandasExtensionArray[type(data)](data) return data @@ -2570,6 +2575,11 @@ def chunk( # type: ignore[override] dask.array.from_array """ + if is_extension_array_dtype(self): + raise ValueError( + f"{self} was found to be a Pandas ExtensionArray. Please convert to numpy first." + ) + if from_array_kwargs is None: from_array_kwargs = {} diff --git a/xarray/testing/strategies.py b/xarray/testing/strategies.py index d2503dfd535..449d0c793cc 100644 --- a/xarray/testing/strategies.py +++ b/xarray/testing/strategies.py @@ -45,7 +45,7 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]: Generates only those numpy dtypes which xarray can handle. Use instead of hypothesis.extra.numpy.scalar_dtypes in order to exclude weirder dtypes such as unicode, byte_string, array, or nested dtypes. - Also excludes datetimes, which dodges bugs with pandas non-nanosecond datetime overflows. + Also excludes datetimes, which dodges bugs with pandas non-nanosecond datetime overflows. Checks only native endianness. Requires the hypothesis package to be installed. @@ -56,10 +56,10 @@ def supported_dtypes() -> st.SearchStrategy[np.dtype]: # TODO should this be exposed publicly? # We should at least decide what the set of numpy dtypes that xarray officially supports is. return ( - npst.integer_dtypes() - | npst.unsigned_integer_dtypes() - | npst.floating_dtypes() - | npst.complex_number_dtypes() + npst.integer_dtypes(endianness="=") + | npst.unsigned_integer_dtypes(endianness="=") + | npst.floating_dtypes(endianness="=") + | npst.complex_number_dtypes(endianness="=") # | npst.datetime64_dtypes() # | npst.timedelta64_dtypes() # | npst.unicode_string_dtypes() diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 5007db9eeb2..3ce788dfb7f 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -18,6 +18,7 @@ from xarray import Dataset from xarray.core import utils from xarray.core.duck_array_ops import allclose_or_equiv # noqa: F401 +from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexing import ExplicitlyIndexed from xarray.core.options import set_options from xarray.core.variable import IndexVariable @@ -52,7 +53,9 @@ def assert_writeable(ds): readonly = [ name for name, var in ds.variables.items() - if not isinstance(var, IndexVariable) and not var.data.flags.writeable + if not isinstance(var, IndexVariable) + and not isinstance(var.data, PandasExtensionArray) + and not var.data.flags.writeable ] assert not readonly, readonly @@ -112,6 +115,7 @@ def _importorskip( has_fsspec, requires_fsspec = _importorskip("fsspec") has_iris, requires_iris = _importorskip("iris") has_numbagg, requires_numbagg = _importorskip("numbagg", "0.4.0") +has_pyarrow, requires_pyarrow = _importorskip("pyarrow") with warnings.catch_warnings(): warnings.filterwarnings( "ignore", @@ -307,6 +311,7 @@ def create_test_data( seed: int | None = None, add_attrs: bool = True, dim_sizes: tuple[int, int, int] = _DEFAULT_TEST_DIM_SIZES, + use_extension_array: bool = False, ) -> Dataset: rs = np.random.RandomState(seed) _vars = { @@ -329,7 +334,16 @@ def create_test_data( obj[v] = (dims, data) if add_attrs: obj[v].attrs = {"foo": "variable"} - + if use_extension_array: + obj["var4"] = ( + "dim1", + pd.Categorical( + np.random.choice( + list(string.ascii_lowercase[: np.random.randint(5)]), + size=dim_sizes[0], + ) + ), + ) if dim_sizes == _DEFAULT_TEST_DIM_SIZES: numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64") else: diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 0cf4cc03a09..1ddb5a569bd 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -152,6 +152,21 @@ def test_concat_missing_var() -> None: assert_identical(actual, expected) +def test_concat_categorical() -> None: + data1 = create_test_data(use_extension_array=True) + data2 = create_test_data(use_extension_array=True) + concatenated = concat([data1, data2], dim="dim1") + assert ( + concatenated["var4"] + == type(data2["var4"].variable.data.array)._concat_same_type( + [ + data1["var4"].variable.data.array, + data2["var4"].variable.data.array, + ] + ) + ).all() + + def test_concat_missing_multiple_consecutive_var() -> None: datasets = create_concat_datasets(3, seed=123) expected = concat(datasets, dim="day") @@ -451,8 +466,11 @@ def test_concat_fill_missing_variables( class TestConcatDataset: @pytest.fixture - def data(self) -> Dataset: - return create_test_data().drop_dims("dim3") + def data(self, request) -> Dataset: + use_extension_array = request.param if hasattr(request, "param") else False + return create_test_data(use_extension_array=use_extension_array).drop_dims( + "dim3" + ) def rectify_dim_order(self, data, dataset) -> Dataset: # return a new dataset with all variable dimensions transposed into @@ -464,7 +482,9 @@ def rectify_dim_order(self, data, dataset) -> Dataset: ) @pytest.mark.parametrize("coords", ["different", "minimal"]) - @pytest.mark.parametrize("dim", ["dim1", "dim2"]) + @pytest.mark.parametrize( + "dim,data", [["dim1", True], ["dim2", False]], indirect=["data"] + ) def test_concat_simple(self, data, dim, coords) -> None: datasets = [g for _, g in data.groupby(dim, squeeze=False)] assert_identical(data, concat(datasets, dim, coords=coords)) @@ -492,6 +512,7 @@ def test_concat_merge_variables_present_in_some_datasets(self, data) -> None: expected = data.copy().assign(foo=(["dim1", "bar"], foo)) assert_identical(expected, actual) + @pytest.mark.parametrize("data", [False], indirect=["data"]) def test_concat_2(self, data) -> None: dim = "dim2" datasets = [g.squeeze(dim) for _, g in data.groupby(dim, squeeze=False)] diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index e2a64964775..a948fafc815 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4614,10 +4614,12 @@ def test_to_and_from_dataframe(self) -> None: x = np.random.randn(10) y = np.random.randn(10) t = list("abcdefghij") - ds = Dataset({"a": ("t", x), "b": ("t", y), "t": ("t", t)}) + cat = pd.Categorical(["a", "b"] * 5) + ds = Dataset({"a": ("t", x), "b": ("t", y), "t": ("t", t), "cat": ("t", cat)}) expected = pd.DataFrame( np.array([x, y]).T, columns=["a", "b"], index=pd.Index(t, name="t") ) + expected["cat"] = cat actual = ds.to_dataframe() # use the .equals method to check all DataFrame metadata assert expected.equals(actual), (expected, actual) @@ -4628,23 +4630,31 @@ def test_to_and_from_dataframe(self) -> None: # check roundtrip assert_identical(ds, Dataset.from_dataframe(actual)) - + assert isinstance(ds["cat"].variable.data.dtype, pd.CategoricalDtype) # test a case with a MultiIndex w = np.random.randn(2, 3) - ds = Dataset({"w": (("x", "y"), w)}) + cat = pd.Categorical(["a", "a", "c"]) + ds = Dataset({"w": (("x", "y"), w), "cat": ("y", cat)}) ds["y"] = ("y", list("abc")) exp_index = pd.MultiIndex.from_arrays( [[0, 0, 0, 1, 1, 1], ["a", "b", "c", "a", "b", "c"]], names=["x", "y"] ) - expected = pd.DataFrame(w.reshape(-1), columns=["w"], index=exp_index) + expected = pd.DataFrame( + {"w": w.reshape(-1), "cat": pd.Categorical(["a", "a", "c", "a", "a", "c"])}, + index=exp_index, + ) actual = ds.to_dataframe() assert expected.equals(actual) # check roundtrip + # from_dataframe attempts to broadcast across because it doesn't know better, so cat must be converted + ds["cat"] = (("x", "y"), np.stack((ds["cat"].to_numpy(), ds["cat"].to_numpy()))) assert_identical(ds.assign_coords(x=[0, 1]), Dataset.from_dataframe(actual)) # Check multiindex reordering new_order = ["x", "y"] + # revert broadcasting fix above for 1d arrays + ds["cat"] = ("y", cat) actual = ds.to_dataframe(dim_order=new_order) assert expected.equals(actual) @@ -4653,7 +4663,11 @@ def test_to_and_from_dataframe(self) -> None: [["a", "a", "b", "b", "c", "c"], [0, 1, 0, 1, 0, 1]], names=["y", "x"] ) expected = pd.DataFrame( - w.transpose().reshape(-1), columns=["w"], index=exp_index + { + "w": w.transpose().reshape(-1), + "cat": pd.Categorical(["a", "a", "a", "a", "c", "c"]), + }, + index=exp_index, ) actual = ds.to_dataframe(dim_order=new_order) assert expected.equals(actual) @@ -4706,7 +4720,7 @@ def test_to_and_from_dataframe(self) -> None: expected = pd.DataFrame([[]], index=idx) assert expected.equals(actual), (expected, actual) - def test_from_dataframe_categorical(self) -> None: + def test_from_dataframe_categorical_index(self) -> None: cat = pd.CategoricalDtype( categories=["foo", "bar", "baz", "qux", "quux", "corge"] ) @@ -4721,7 +4735,7 @@ def test_from_dataframe_categorical(self) -> None: assert len(ds["i1"]) == 2 assert len(ds["i2"]) == 2 - def test_from_dataframe_categorical_string_categories(self) -> None: + def test_from_dataframe_categorical_index_string_categories(self) -> None: cat = pd.CategoricalIndex( pd.Categorical.from_codes( np.array([1, 1, 0, 2]), @@ -5449,18 +5463,22 @@ def test_reduce_cumsum_test_dims(self, reduct, expected, func) -> None: assert list(actual) == expected def test_reduce_non_numeric(self) -> None: - data1 = create_test_data(seed=44) + data1 = create_test_data(seed=44, use_extension_array=True) data2 = create_test_data(seed=44) - add_vars = {"var4": ["dim1", "dim2"], "var5": ["dim1"]} + add_vars = {"var5": ["dim1", "dim2"], "var6": ["dim1"]} for v, dims in sorted(add_vars.items()): size = tuple(data1.sizes[d] for d in dims) data = np.random.randint(0, 100, size=size).astype(np.str_) data1[v] = (dims, data, {"foo": "variable"}) - - assert "var4" not in data1.mean() and "var5" not in data1.mean() + # var4 is extension array categorical and should be dropped + assert ( + "var4" not in data1.mean() + and "var5" not in data1.mean() + and "var6" not in data1.mean() + ) assert_equal(data1.mean(), data2.mean()) assert_equal(data1.mean(dim="dim1"), data2.mean(dim="dim1")) - assert "var4" not in data1.mean(dim="dim2") and "var5" in data1.mean(dim="dim2") + assert "var5" not in data1.mean(dim="dim2") and "var6" in data1.mean(dim="dim2") @pytest.mark.filterwarnings( "ignore:Once the behaviour of DataArray:DeprecationWarning" diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index df1ab1f40f9..26821c69495 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -27,6 +27,7 @@ timedelta_to_numeric, where, ) +from xarray.core.extension_array import PandasExtensionArray from xarray.namedarray.pycompat import array_type from xarray.testing import assert_allclose, assert_equal, assert_identical from xarray.tests import ( @@ -38,11 +39,55 @@ requires_bottleneck, requires_cftime, requires_dask, + requires_pyarrow, ) dask_array_type = array_type("dask") +@pytest.fixture +def categorical1(): + return pd.Categorical(["cat1", "cat2", "cat2", "cat1", "cat2"]) + + +@pytest.fixture +def categorical2(): + return pd.Categorical(["cat2", "cat1", "cat2", "cat3", "cat1"]) + + +try: + import pyarrow as pa + + @pytest.fixture + def arrow1(): + return pd.arrays.ArrowExtensionArray( + pa.array([{"x": 1, "y": True}, {"x": 2, "y": False}]) + ) + + @pytest.fixture + def arrow2(): + return pd.arrays.ArrowExtensionArray( + pa.array([{"x": 3, "y": False}, {"x": 4, "y": True}]) + ) + +except ImportError: + pass + + +@pytest.fixture +def int1(): + return pd.arrays.IntegerArray( + np.array([1, 2, 3, 4, 5]), np.array([True, False, False, True, True]) + ) + + +@pytest.fixture +def int2(): + return pd.arrays.IntegerArray( + np.array([6, 7, 8, 9, 10]), np.array([True, True, False, True, False]) + ) + + class TestOps: @pytest.fixture(autouse=True) def setUp(self): @@ -119,6 +164,51 @@ def test_where_type_promotion(self): assert result.dtype == np.float32 assert_array_equal(result, np.array([1, np.nan], dtype=np.float32)) + def test_where_extension_duck_array(self, categorical1, categorical2): + where_res = where( + np.array([True, False, True, False, False]), + PandasExtensionArray(categorical1), + PandasExtensionArray(categorical2), + ) + assert isinstance(where_res, PandasExtensionArray) + assert ( + where_res == pd.Categorical(["cat1", "cat1", "cat2", "cat3", "cat1"]) + ).all() + + def test_concatenate_extension_duck_array(self, categorical1, categorical2): + concate_res = concatenate( + [PandasExtensionArray(categorical1), PandasExtensionArray(categorical2)] + ) + assert isinstance(concate_res, PandasExtensionArray) + assert ( + concate_res + == type(categorical1)._concat_same_type((categorical1, categorical2)) + ).all() + + @requires_pyarrow + def test_duck_extension_array_pyarrow_concatenate(self, arrow1, arrow2): + concatenated = concatenate( + (PandasExtensionArray(arrow1), PandasExtensionArray(arrow2)) + ) + assert concatenated[2]["x"] == 3 + assert concatenated[3]["y"] + + def test___getitem__extension_duck_array(self, categorical1): + extension_duck_array = PandasExtensionArray(categorical1) + assert (extension_duck_array[0:2] == categorical1[0:2]).all() + assert isinstance(extension_duck_array[0:2], PandasExtensionArray) + assert extension_duck_array[0] == categorical1[0] + assert isinstance(extension_duck_array[0], PandasExtensionArray) + mask = [True, False, True, False, True] + assert (extension_duck_array[mask] == categorical1[mask]).all() + + def test__setitem__extension_duck_array(self, categorical1): + extension_duck_array = PandasExtensionArray(categorical1) + extension_duck_array[2] = "cat1" # already existing category + assert extension_duck_array[2] == "cat1" + with pytest.raises(TypeError, match="Cannot setitem on a Categorical"): + extension_duck_array[2] = "cat4" # new category + def test_stack_type_promotion(self): result = stack([1, "b"]) assert_array_equal(result, np.array([1, "b"], dtype=object)) @@ -932,3 +1022,21 @@ def test_push_dask(): dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)), axis=0, n=n ) np.testing.assert_equal(actual, expected) + + +def test_duck_extension_array_equality(categorical1, int1): + int_duck_array = PandasExtensionArray(int1) + categorical_duck_array = PandasExtensionArray(categorical1) + assert (int_duck_array != categorical_duck_array).all() + assert (categorical_duck_array == categorical1).all() + assert (int_duck_array[0:2] == int1[0:2]).all() + + +def test_duck_extension_array_repr(int1): + int_duck_array = PandasExtensionArray(int1) + assert repr(int1) in repr(int_duck_array) + + +def test_duck_extension_array_attr(int1): + int_duck_array = PandasExtensionArray(int1) + assert (~int_duck_array.fillna(10)).all() diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index df9c40ca6f4..afe4d669628 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -35,6 +35,7 @@ def dataset() -> xr.Dataset: { "foo": (("x", "y", "z"), np.random.randn(3, 4, 2)), "baz": ("x", ["e", "f", "g"]), + "cat": ("y", pd.Categorical(["cat1", "cat2", "cat2", "cat1"])), }, {"x": ("x", ["a", "b", "c"], {"name": "x"}), "y": [1, 2, 3, 4], "z": [1, 2]}, ) @@ -79,6 +80,7 @@ def test_groupby_dims_property(dataset, recwarn) -> None: ) assert len(recwarn) == 0 + dataset = dataset.drop_vars(["cat"]) stacked = dataset.stack({"xy": ("x", "y")}) assert tuple(stacked.groupby("xy", squeeze=False).dims) == tuple( stacked.isel(xy=[0]).dims @@ -91,7 +93,7 @@ def test_groupby_sizes_property(dataset) -> None: assert dataset.groupby("x").sizes == dataset.isel(x=1).sizes with pytest.warns(UserWarning, match="The `squeeze` kwarg"): assert dataset.groupby("y").sizes == dataset.isel(y=1).sizes - + dataset = dataset.drop("cat") stacked = dataset.stack({"xy": ("x", "y")}) with pytest.warns(UserWarning, match="The `squeeze` kwarg"): assert stacked.groupby("xy").sizes == stacked.isel(xy=0).sizes @@ -760,6 +762,8 @@ def test_groupby_getitem(dataset) -> None: assert_identical(dataset.foo.sel(x="a"), dataset.foo.groupby("x")["a"]) with pytest.warns(UserWarning, match="The `squeeze` kwarg"): assert_identical(dataset.foo.sel(z=1), dataset.foo.groupby("z")[1]) + with pytest.warns(UserWarning, match="The `squeeze` kwarg"): + assert_identical(dataset.cat.sel(y=1), dataset.cat.groupby("y")[1]) assert_identical(dataset.sel(x=["a"]), dataset.groupby("x", squeeze=False)["a"]) assert_identical(dataset.sel(z=[1]), dataset.groupby("z", squeeze=False)[1]) @@ -769,6 +773,12 @@ def test_groupby_getitem(dataset) -> None: ) assert_identical(dataset.foo.sel(z=[1]), dataset.foo.groupby("z", squeeze=False)[1]) + assert_identical(dataset.cat.sel(y=[1]), dataset.cat.groupby("y", squeeze=False)[1]) + with pytest.raises( + NotImplementedError, match="Cannot broadcast 1d-only pandas categorical array." + ): + dataset.groupby("boo", squeeze=False) + dataset = dataset.drop_vars(["cat"]) actual = ( dataset.groupby("boo", squeeze=False)["f"].unstack().transpose("x", "y", "z") ) diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py index c6597d5abb0..52935e9714e 100644 --- a/xarray/tests/test_merge.py +++ b/xarray/tests/test_merge.py @@ -37,7 +37,7 @@ def test_merge_arrays(self): assert_identical(actual, expected) def test_merge_datasets(self): - data = create_test_data(add_attrs=False) + data = create_test_data(add_attrs=False, use_extension_array=True) actual = xr.merge([data[["var1"]], data[["var2"]]]) expected = data[["var1", "var2"]] diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index d9289aa6674..8a9345e74d4 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1576,6 +1576,20 @@ def test_transpose_0d(self): actual = variable.transpose() assert_identical(actual, variable) + def test_pandas_cateogrical_dtype(self): + data = pd.Categorical(np.arange(10, dtype="int64")) + v = self.cls("x", data) + print(v) # should not error + assert pd.api.types.is_extension_array_dtype(v.dtype) + + def test_pandas_cateogrical_no_chunk(self): + data = pd.Categorical(np.arange(10, dtype="int64")) + v = self.cls("x", data) + with pytest.raises( + ValueError, match=r".*was found to be a Pandas ExtensionArray.*" + ): + v.chunk((5,)) + def test_squeeze(self): v = Variable(["x", "y"], [[1]]) assert_identical(Variable([], 1), v.squeeze()) @@ -2373,6 +2387,11 @@ def test_multiindex(self): def test_pad(self, mode, xr_arg, np_arg): super().test_pad(mode, xr_arg, np_arg) + def test_pandas_cateogrical_dtype(self): + data = pd.Categorical(np.arange(10, dtype="int64")) + with pytest.raises(ValueError, match="was found to be a Pandas ExtensionArray"): + self.cls("x", data) + @requires_sparse class TestVariableWithSparse: From 5f24079ddaeb0bb16a039d091cbd48710aca4915 Mon Sep 17 00:00:00 2001 From: Eni <51421921+eni-awowale@users.noreply.github.com> Date: Thu, 18 Apr 2024 17:59:44 -0400 Subject: [PATCH 7/7] Migrate formatting_html.py into xarray core (#8930) * DAS-2066: formatting_html.py merge * Revert "DAS-2066: formatting_html.py merge" oopsies This reverts commit d68ae3c823bbbe26c918605a388464b5ba2309c2. * Adding datatree functions to core * Deleted formatting_html.py and test_formatting_html.py in datatree_ * renamed all of the relevant functions so they represent datatree * added comment for mypy * added dict[any, any] as type hint for children * added dict[Any, Any] as type hint for children * updated typhint to dict[str, DataTree] and increased pixel min width to 700 * replaced Any with DataTree * added type checking * DAS-2066 - Incorporate CSS changes for DataTree display. * DAS-2066 - Fix tests. --------- Co-authored-by: Owen Littlejohns Co-authored-by: Matt Savoie --- doc/whats-new.rst | 3 + xarray/core/datatree.py | 6 +- xarray/core/formatting_html.py | 131 ++++++++++++ xarray/datatree_/datatree/formatting_html.py | 135 ------------ .../datatree/tests/test_formatting_html.py | 198 ------------------ xarray/tests/test_formatting_html.py | 195 +++++++++++++++++ 6 files changed, 332 insertions(+), 336 deletions(-) delete mode 100644 xarray/datatree_/datatree/formatting_html.py delete mode 100644 xarray/datatree_/datatree/tests/test_formatting_html.py diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4fe51708646..2332f7f236b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -40,6 +40,9 @@ Bug fixes Internal Changes ~~~~~~~~~~~~~~~~ +- Migrates ``formatting_html`` functionality for `DataTree` into ``xarray/core`` (:pull: `8930`) + By `Eni Awowale `_, `Julia Signell `_ + and `Tom Nicholas `_. - Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`) By `Matt Savoie `_ `Owen Littlejohns ` and `Tom Nicholas `_. diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index e24dfb504b1..57fd7222898 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -23,6 +23,9 @@ check_isomorphic, map_over_subtree, ) +from xarray.core.formatting_html import ( + datatree_repr as datatree_repr_html, +) from xarray.core.indexes import Index, Indexes from xarray.core.merge import dataset_update_method from xarray.core.options import OPTIONS as XR_OPTS @@ -38,9 +41,6 @@ from xarray.core.variable import Variable from xarray.datatree_.datatree.common import TreeAttrAccessMixin from xarray.datatree_.datatree.formatting import datatree_repr -from xarray.datatree_.datatree.formatting_html import ( - datatree_repr as datatree_repr_html, -) from xarray.datatree_.datatree.ops import ( DataTreeArithmeticMixin, MappedDatasetMethodsMixin, diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index 2c76b182207..9bf5befbe3f 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -2,9 +2,11 @@ import uuid from collections import OrderedDict +from collections.abc import Mapping from functools import lru_cache, partial from html import escape from importlib.resources import files +from typing import TYPE_CHECKING from xarray.core.formatting import ( inline_index_repr, @@ -18,6 +20,9 @@ ("xarray.static.css", "style.css"), ) +if TYPE_CHECKING: + from xarray.core.datatree import DataTree + @lru_cache(None) def _load_static_files(): @@ -341,3 +346,129 @@ def dataset_repr(ds) -> str: ] return _obj_repr(ds, header_components, sections) + + +def summarize_datatree_children(children: Mapping[str, DataTree]) -> str: + N_CHILDREN = len(children) - 1 + + # Get result from datatree_node_repr and wrap it + lines_callback = lambda n, c, end: _wrap_datatree_repr( + datatree_node_repr(n, c), end=end + ) + + children_html = "".join( + ( + lines_callback(n, c, end=False) # Long lines + if i < N_CHILDREN + else lines_callback(n, c, end=True) + ) # Short lines + for i, (n, c) in enumerate(children.items()) + ) + + return "".join( + [ + "
", + children_html, + "
", + ] + ) + + +children_section = partial( + _mapping_section, + name="Groups", + details_func=summarize_datatree_children, + max_items_collapse=1, + expand_option_name="display_expand_groups", +) + + +def datatree_node_repr(group_title: str, dt: DataTree) -> str: + header_components = [f"
{escape(group_title)}
"] + + ds = dt.ds + + sections = [ + children_section(dt.children), + dim_section(ds), + coord_section(ds.coords), + datavar_section(ds.data_vars), + attr_section(ds.attrs), + ] + + return _obj_repr(ds, header_components, sections) + + +def _wrap_datatree_repr(r: str, end: bool = False) -> str: + """ + Wrap HTML representation with a tee to the left of it. + + Enclosing HTML tag is a
with :code:`display: inline-grid` style. + + Turns: + [ title ] + | details | + |_____________| + + into (A): + |─ [ title ] + | | details | + | |_____________| + + or (B): + └─ [ title ] + | details | + |_____________| + + Parameters + ---------- + r: str + HTML representation to wrap. + end: bool + Specify if the line on the left should continue or end. + + Default is True. + + Returns + ------- + str + Wrapped HTML representation. + + Tee color is set to the variable :code:`--xr-border-color`. + """ + # height of line + end = bool(end) + height = "100%" if end is False else "1.2em" + return "".join( + [ + "
", + "
", + "
", + "
", + "
", + "
", + r, + "
", + "
", + ] + ) + + +def datatree_repr(dt: DataTree) -> str: + obj_type = f"datatree.{type(dt).__name__}" + return datatree_node_repr(obj_type, dt) diff --git a/xarray/datatree_/datatree/formatting_html.py b/xarray/datatree_/datatree/formatting_html.py deleted file mode 100644 index 547b567a396..00000000000 --- a/xarray/datatree_/datatree/formatting_html.py +++ /dev/null @@ -1,135 +0,0 @@ -from functools import partial -from html import escape -from typing import Any, Mapping - -from xarray.core.formatting_html import ( - _mapping_section, - _obj_repr, - attr_section, - coord_section, - datavar_section, - dim_section, -) - - -def summarize_children(children: Mapping[str, Any]) -> str: - N_CHILDREN = len(children) - 1 - - # Get result from node_repr and wrap it - lines_callback = lambda n, c, end: _wrap_repr(node_repr(n, c), end=end) - - children_html = "".join( - lines_callback(n, c, end=False) # Long lines - if i < N_CHILDREN - else lines_callback(n, c, end=True) # Short lines - for i, (n, c) in enumerate(children.items()) - ) - - return "".join( - [ - "
", - children_html, - "
", - ] - ) - - -children_section = partial( - _mapping_section, - name="Groups", - details_func=summarize_children, - max_items_collapse=1, - expand_option_name="display_expand_groups", -) - - -def node_repr(group_title: str, dt: Any) -> str: - header_components = [f"
{escape(group_title)}
"] - - ds = dt.ds - - sections = [ - children_section(dt.children), - dim_section(ds), - coord_section(ds.coords), - datavar_section(ds.data_vars), - attr_section(ds.attrs), - ] - - return _obj_repr(ds, header_components, sections) - - -def _wrap_repr(r: str, end: bool = False) -> str: - """ - Wrap HTML representation with a tee to the left of it. - - Enclosing HTML tag is a
with :code:`display: inline-grid` style. - - Turns: - [ title ] - | details | - |_____________| - - into (A): - |─ [ title ] - | | details | - | |_____________| - - or (B): - └─ [ title ] - | details | - |_____________| - - Parameters - ---------- - r: str - HTML representation to wrap. - end: bool - Specify if the line on the left should continue or end. - - Default is True. - - Returns - ------- - str - Wrapped HTML representation. - - Tee color is set to the variable :code:`--xr-border-color`. - """ - # height of line - end = bool(end) - height = "100%" if end is False else "1.2em" - return "".join( - [ - "
", - "
", - "
", - "
", - "
", - "
", - "
    ", - r, - "
" "
", - "
", - ] - ) - - -def datatree_repr(dt: Any) -> str: - obj_type = f"datatree.{type(dt).__name__}" - return node_repr(obj_type, dt) diff --git a/xarray/datatree_/datatree/tests/test_formatting_html.py b/xarray/datatree_/datatree/tests/test_formatting_html.py deleted file mode 100644 index 98cdf02bff4..00000000000 --- a/xarray/datatree_/datatree/tests/test_formatting_html.py +++ /dev/null @@ -1,198 +0,0 @@ -import pytest -import xarray as xr - -from xarray.core.datatree import DataTree -from xarray.datatree_.datatree import formatting_html - - -@pytest.fixture(scope="module", params=["some html", "some other html"]) -def repr(request): - return request.param - - -class Test_summarize_children: - """ - Unit tests for summarize_children. - """ - - func = staticmethod(formatting_html.summarize_children) - - @pytest.fixture(scope="class") - def childfree_tree_factory(self): - """ - Fixture for a child-free DataTree factory. - """ - from random import randint - - def _childfree_tree_factory(): - return DataTree( - data=xr.Dataset({"z": ("y", [randint(1, 100) for _ in range(3)])}) - ) - - return _childfree_tree_factory - - @pytest.fixture(scope="class") - def childfree_tree(self, childfree_tree_factory): - """ - Fixture for a child-free DataTree. - """ - return childfree_tree_factory() - - @pytest.fixture(scope="function") - def mock_node_repr(self, monkeypatch): - """ - Apply mocking for node_repr. - """ - - def mock(group_title, dt): - """ - Mock with a simple result - """ - return group_title + " " + str(id(dt)) - - monkeypatch.setattr(formatting_html, "node_repr", mock) - - @pytest.fixture(scope="function") - def mock_wrap_repr(self, monkeypatch): - """ - Apply mocking for _wrap_repr. - """ - - def mock(r, *, end, **kwargs): - """ - Mock by appending "end" or "not end". - """ - return r + " " + ("end" if end else "not end") + "//" - - monkeypatch.setattr(formatting_html, "_wrap_repr", mock) - - def test_empty_mapping(self): - """ - Test with an empty mapping of children. - """ - children = {} - assert self.func(children) == ( - "
" "
" - ) - - def test_one_child(self, childfree_tree, mock_wrap_repr, mock_node_repr): - """ - Test with one child. - - Uses a mock of _wrap_repr and node_repr to essentially mock - the inline lambda function "lines_callback". - """ - # Create mapping of children - children = {"a": childfree_tree} - - # Expect first line to be produced from the first child, and - # wrapped as the last child - first_line = f"a {id(children['a'])} end//" - - assert self.func(children) == ( - "
" - f"{first_line}" - "
" - ) - - def test_two_children(self, childfree_tree_factory, mock_wrap_repr, mock_node_repr): - """ - Test with two level deep children. - - Uses a mock of _wrap_repr and node_repr to essentially mock - the inline lambda function "lines_callback". - """ - - # Create mapping of children - children = {"a": childfree_tree_factory(), "b": childfree_tree_factory()} - - # Expect first line to be produced from the first child, and - # wrapped as _not_ the last child - first_line = f"a {id(children['a'])} not end//" - - # Expect second line to be produced from the second child, and - # wrapped as the last child - second_line = f"b {id(children['b'])} end//" - - assert self.func(children) == ( - "
" - f"{first_line}" - f"{second_line}" - "
" - ) - - -class Test__wrap_repr: - """ - Unit tests for _wrap_repr. - """ - - func = staticmethod(formatting_html._wrap_repr) - - def test_end(self, repr): - """ - Test with end=True. - """ - r = self.func(repr, end=True) - assert r == ( - "
" - "
" - "
" - "
" - "
" - "
" - "
    " - f"{repr}" - "
" - "
" - "
" - ) - - def test_not_end(self, repr): - """ - Test with end=False. - """ - r = self.func(repr, end=False) - assert r == ( - "
" - "
" - "
" - "
" - "
" - "
" - "
    " - f"{repr}" - "
" - "
" - "
" - ) diff --git a/xarray/tests/test_formatting_html.py b/xarray/tests/test_formatting_html.py index 6540406e914..ada7f75b21b 100644 --- a/xarray/tests/test_formatting_html.py +++ b/xarray/tests/test_formatting_html.py @@ -7,6 +7,7 @@ import xarray as xr from xarray.core import formatting_html as fh from xarray.core.coordinates import Coordinates +from xarray.core.datatree import DataTree @pytest.fixture @@ -196,3 +197,197 @@ def test_nonstr_variable_repr_html() -> None: html = v._repr_html_().strip() assert "
22 :
bar
" in html assert "
  • 10: 3
  • " in html + + +@pytest.fixture(scope="module", params=["some html", "some other html"]) +def repr(request): + return request.param + + +class Test_summarize_datatree_children: + """ + Unit tests for summarize_datatree_children. + """ + + func = staticmethod(fh.summarize_datatree_children) + + @pytest.fixture(scope="class") + def childfree_tree_factory(self): + """ + Fixture for a child-free DataTree factory. + """ + from random import randint + + def _childfree_tree_factory(): + return DataTree( + data=xr.Dataset({"z": ("y", [randint(1, 100) for _ in range(3)])}) + ) + + return _childfree_tree_factory + + @pytest.fixture(scope="class") + def childfree_tree(self, childfree_tree_factory): + """ + Fixture for a child-free DataTree. + """ + return childfree_tree_factory() + + @pytest.fixture(scope="function") + def mock_datatree_node_repr(self, monkeypatch): + """ + Apply mocking for datatree_node_repr. + """ + + def mock(group_title, dt): + """ + Mock with a simple result + """ + return group_title + " " + str(id(dt)) + + monkeypatch.setattr(fh, "datatree_node_repr", mock) + + @pytest.fixture(scope="function") + def mock_wrap_datatree_repr(self, monkeypatch): + """ + Apply mocking for _wrap_datatree_repr. + """ + + def mock(r, *, end, **kwargs): + """ + Mock by appending "end" or "not end". + """ + return r + " " + ("end" if end else "not end") + "//" + + monkeypatch.setattr(fh, "_wrap_datatree_repr", mock) + + def test_empty_mapping(self): + """ + Test with an empty mapping of children. + """ + children: dict[str, DataTree] = {} + assert self.func(children) == ( + "
    " + "
    " + ) + + def test_one_child( + self, childfree_tree, mock_wrap_datatree_repr, mock_datatree_node_repr + ): + """ + Test with one child. + + Uses a mock of _wrap_datatree_repr and _datatree_node_repr to essentially mock + the inline lambda function "lines_callback". + """ + # Create mapping of children + children = {"a": childfree_tree} + + # Expect first line to be produced from the first child, and + # wrapped as the last child + first_line = f"a {id(children['a'])} end//" + + assert self.func(children) == ( + "
    " + f"{first_line}" + "
    " + ) + + def test_two_children( + self, childfree_tree_factory, mock_wrap_datatree_repr, mock_datatree_node_repr + ): + """ + Test with two level deep children. + + Uses a mock of _wrap_datatree_repr and datatree_node_repr to essentially mock + the inline lambda function "lines_callback". + """ + + # Create mapping of children + children = {"a": childfree_tree_factory(), "b": childfree_tree_factory()} + + # Expect first line to be produced from the first child, and + # wrapped as _not_ the last child + first_line = f"a {id(children['a'])} not end//" + + # Expect second line to be produced from the second child, and + # wrapped as the last child + second_line = f"b {id(children['b'])} end//" + + assert self.func(children) == ( + "
    " + f"{first_line}" + f"{second_line}" + "
    " + ) + + +class Test__wrap_datatree_repr: + """ + Unit tests for _wrap_datatree_repr. + """ + + func = staticmethod(fh._wrap_datatree_repr) + + def test_end(self, repr): + """ + Test with end=True. + """ + r = self.func(repr, end=True) + assert r == ( + "
    " + "
    " + "
    " + "
    " + "
    " + "
    " + f"{repr}" + "
    " + "
    " + ) + + def test_not_end(self, repr): + """ + Test with end=False. + """ + r = self.func(repr, end=False) + assert r == ( + "
    " + "
    " + "
    " + "
    " + "
    " + "
    " + f"{repr}" + "
    " + "
    " + )