From 1580c2c47cca425d47e3a4c2777a625dadba0a8f Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 16 Jan 2024 10:26:08 +0000 Subject: [PATCH] Clean up Dims type annotation (#8606) --- .github/workflows/ci-additional.yaml | 4 ++-- xarray/core/computation.py | 12 +++++------- xarray/core/types.py | 7 ++++--- xarray/core/utils.py | 26 +++++++++++--------------- xarray/tests/test_interp.py | 4 ++-- xarray/tests/test_utils.py | 18 ++++++++---------- 6 files changed, 32 insertions(+), 39 deletions(-) diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index c11816bc658..9d693a8c03e 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -120,7 +120,7 @@ jobs: python xarray/util/print_versions.py - name: Install mypy run: | - python -m pip install "mypy<1.8" --force-reinstall + python -m pip install "mypy<1.9" --force-reinstall - name: Run mypy run: | @@ -174,7 +174,7 @@ jobs: python xarray/util/print_versions.py - name: Install mypy run: | - python -m pip install "mypy<1.8" --force-reinstall + python -m pip install "mypy<1.9" --force-reinstall - name: Run mypy run: | diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 553836961b0..dda72c0163b 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -24,7 +24,7 @@ from xarray.core.parallelcompat import get_chunked_array_type from xarray.core.pycompat import is_chunked_array, is_duck_dask_array from xarray.core.types import Dims, T_DataArray -from xarray.core.utils import is_dict_like, is_scalar +from xarray.core.utils import is_dict_like, is_scalar, parse_dims from xarray.core.variable import Variable from xarray.util.deprecation_helpers import deprecate_dims @@ -1875,16 +1875,14 @@ def dot( einsum_axes = "abcdefghijklmnopqrstuvwxyz" dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} - if dim is ...: - dim = all_dims - elif isinstance(dim, str): - dim = (dim,) - elif dim is None: - # find dimensions that occur more than one times + if dim is None: + # find dimensions that occur more than once dim_counts: Counter = Counter() for arr in arrays: dim_counts.update(arr.dims) dim = tuple(d for d, c in dim_counts.items() if c > 1) + else: + dim = parse_dims(dim, all_dims=tuple(all_dims)) dot_dims: set[Hashable] = set(dim) diff --git a/xarray/core/types.py b/xarray/core/types.py index 06ad65679d8..8c3164c52fa 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -2,7 +2,7 @@ import datetime import sys -from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence +from collections.abc import Collection, Hashable, Iterator, Mapping, Sequence from typing import ( TYPE_CHECKING, Any, @@ -182,8 +182,9 @@ def copy( DsCompatible = Union["Dataset", "DaCompatible"] GroupByCompatible = Union["Dataset", "DataArray"] -Dims = Union[str, Iterable[Hashable], "ellipsis", None] -OrderedDims = Union[str, Sequence[Union[Hashable, "ellipsis"]], "ellipsis", None] +# Don't change to Hashable | Collection[Hashable] +# Read: https://github.com/pydata/xarray/issues/6142 +Dims = Union[str, Collection[Hashable], "ellipsis", None] # FYI in some cases we don't allow `None`, which this doesn't take account of. T_ChunkDim: TypeAlias = Union[int, Literal["auto"], None, tuple[int, ...]] diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 00c84d4c10c..85f901167e2 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -57,7 +57,6 @@ Mapping, MutableMapping, MutableSet, - Sequence, ValuesView, ) from enum import Enum @@ -76,7 +75,7 @@ import pandas as pd if TYPE_CHECKING: - from xarray.core.types import Dims, ErrorOptionsWithWarn, OrderedDims, T_DuckArray + from xarray.core.types import Dims, ErrorOptionsWithWarn, T_DuckArray K = TypeVar("K") V = TypeVar("V") @@ -983,12 +982,9 @@ def drop_missing_dims( ) -T_None = TypeVar("T_None", None, "ellipsis") - - @overload def parse_dims( - dim: str | Iterable[Hashable] | T_None, + dim: Dims, all_dims: tuple[Hashable, ...], *, check_exists: bool = True, @@ -999,12 +995,12 @@ def parse_dims( @overload def parse_dims( - dim: str | Iterable[Hashable] | T_None, + dim: Dims, all_dims: tuple[Hashable, ...], *, check_exists: bool = True, replace_none: Literal[False], -) -> tuple[Hashable, ...] | T_None: +) -> tuple[Hashable, ...] | None | ellipsis: ... @@ -1051,7 +1047,7 @@ def parse_dims( @overload def parse_ordered_dims( - dim: str | Sequence[Hashable | ellipsis] | T_None, + dim: Dims, all_dims: tuple[Hashable, ...], *, check_exists: bool = True, @@ -1062,17 +1058,17 @@ def parse_ordered_dims( @overload def parse_ordered_dims( - dim: str | Sequence[Hashable | ellipsis] | T_None, + dim: Dims, all_dims: tuple[Hashable, ...], *, check_exists: bool = True, replace_none: Literal[False], -) -> tuple[Hashable, ...] | T_None: +) -> tuple[Hashable, ...] | None | ellipsis: ... def parse_ordered_dims( - dim: OrderedDims, + dim: Dims, all_dims: tuple[Hashable, ...], *, check_exists: bool = True, @@ -1126,9 +1122,9 @@ def parse_ordered_dims( ) -def _check_dims(dim: set[Hashable | ellipsis], all_dims: set[Hashable]) -> None: - wrong_dims = dim - all_dims - if wrong_dims and wrong_dims != {...}: +def _check_dims(dim: set[Hashable], all_dims: set[Hashable]) -> None: + wrong_dims = (dim - all_dims) - {...} + if wrong_dims: wrong_dims_str = ", ".join(f"'{d!s}'" for d in wrong_dims) raise ValueError( f"Dimension(s) {wrong_dims_str} do not exist. Expected one or more of {all_dims}" diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py index 275b8fdb780..de0020b4d00 100644 --- a/xarray/tests/test_interp.py +++ b/xarray/tests/test_interp.py @@ -838,8 +838,8 @@ def test_interpolate_chunk_1d( if chunked: dest[dim] = xr.DataArray(data=dest[dim], dims=[dim]) dest[dim] = dest[dim].chunk(2) - actual = da.interp(method=method, **dest, kwargs=kwargs) # type: ignore - expected = da.compute().interp(method=method, **dest, kwargs=kwargs) # type: ignore + actual = da.interp(method=method, **dest, kwargs=kwargs) + expected = da.compute().interp(method=method, **dest, kwargs=kwargs) assert_identical(actual, expected) diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py index 36f62fad71f..ec898c80344 100644 --- a/xarray/tests/test_utils.py +++ b/xarray/tests/test_utils.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Hashable, Iterable, Sequence +from collections.abc import Hashable import numpy as np import pandas as pd @@ -257,17 +257,18 @@ def test_infix_dims_errors(supplied, all_): pytest.param("a", ("a",), id="str"), pytest.param(["a", "b"], ("a", "b"), id="list_of_str"), pytest.param(["a", 1], ("a", 1), id="list_mixed"), + pytest.param(["a", ...], ("a", ...), id="list_with_ellipsis"), pytest.param(("a", "b"), ("a", "b"), id="tuple_of_str"), pytest.param(["a", ("b", "c")], ("a", ("b", "c")), id="list_with_tuple"), pytest.param((("b", "c"),), (("b", "c"),), id="tuple_of_tuple"), + pytest.param({"a", 1}, tuple({"a", 1}), id="non_sequence_collection"), + pytest.param((), (), id="empty_tuple"), + pytest.param(set(), (), id="empty_collection"), pytest.param(None, None, id="None"), pytest.param(..., ..., id="ellipsis"), ], ) -def test_parse_dims( - dim: str | Iterable[Hashable] | None, - expected: tuple[Hashable, ...], -) -> None: +def test_parse_dims(dim, expected): all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables actual = utils.parse_dims(dim, all_dims, replace_none=False) assert actual == expected @@ -297,7 +298,7 @@ def test_parse_dims_replace_none(dim: None | ellipsis) -> None: pytest.param(["x", 2], id="list_missing_all"), ], ) -def test_parse_dims_raises(dim: str | Iterable[Hashable]) -> None: +def test_parse_dims_raises(dim): all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables with pytest.raises(ValueError, match="'x'"): utils.parse_dims(dim, all_dims, check_exists=True) @@ -313,10 +314,7 @@ def test_parse_dims_raises(dim: str | Iterable[Hashable]) -> None: pytest.param(["a", ..., "b"], ("a", "c", "b"), id="list_with_middle_ellipsis"), ], ) -def test_parse_ordered_dims( - dim: str | Sequence[Hashable | ellipsis], - expected: tuple[Hashable, ...], -) -> None: +def test_parse_ordered_dims(dim, expected): all_dims = ("a", "b", "c") actual = utils.parse_ordered_dims(dim, all_dims) assert actual == expected