From 503b313c76e16a78f500e809dad40b5b895c0927 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 11 Dec 2024 10:55:03 +0100 Subject: [PATCH] (fix): `dtype` type handling --- xarray/core/indexing.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 775b65b5144..722ac6db4a8 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1663,7 +1663,11 @@ class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin): array: pd.Index _dtype: np.dtype | pd.api.extensions.ExtensionDtype - def __init__(self, array: pd.Index, dtype: DTypeLike = None): + def __init__( + self, + array: pd.Index, + dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None, + ): from xarray.core.indexes import safe_cast_to_index self.array = safe_cast_to_index(array) @@ -1675,22 +1679,26 @@ def __init__(self, array: pd.Index, dtype: DTypeLike = None): else: self._dtype = get_valid_numpy_dtype(array) elif pd.api.types.is_extension_array_dtype(dtype): - cast(pd.api.extensions.ExtensionDtype, dtype) - self._dtype = dtype + self._dtype = cast(pd.api.extensions.ExtensionDtype, dtype) else: - self._dtype = np.dtype(dtype) + self._dtype = np.dtype(cast(DTypeLike, dtype)) @property - def dtype(self) -> np.dtype: + def dtype(self) -> np.dtype | pd.api.extensions.ExtensionDtype: # type: ignore[override] return self._dtype def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, + dtype: np.typing.DTypeLike | pd.api.extensions.ExtensionDtype | None = None, + /, + *, + copy: bool | None = None, ) -> np.ndarray: if dtype is None: dtype = self.dtype if pd.api.types.is_extension_array_dtype(dtype): dtype = get_valid_numpy_dtype(self.array) + dtype = cast(np.dtype, dtype) array = self.array if isinstance(array, pd.PeriodIndex): with suppress(AttributeError): @@ -1726,7 +1734,7 @@ def _convert_scalar(self, item) -> np.ndarray: dtype = self.dtype if pd.api.types.is_extension_array_dtype(dtype): dtype = get_valid_numpy_dtype(self.array) - item = np.asarray(item, dtype=dtype) + item = np.asarray(item, dtype=cast(np.dtype, dtype)) # as for numpy.ndarray indexing, we always want the result to be # a NumPy array. @@ -1846,19 +1854,24 @@ class PandasMultiIndexingAdapter(PandasIndexingAdapter): def __init__( self, array: pd.MultiIndex, - dtype: DTypeLike = None, + dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None, level: str | None = None, ): super().__init__(array, dtype) self.level = level def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, + dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None, + /, + *, + copy: bool | None = None, ) -> np.ndarray: if dtype is None: dtype = self.dtype if pd.api.types.is_extension_array_dtype(dtype): dtype = get_valid_numpy_dtype(self.array) + dtype = cast(np.dtype, dtype) if self.level is not None: return np.asarray( self.array.get_level_values(self.level).values, dtype=dtype