Skip to content

Commit

Permalink
(fix): dtype type handling
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-gold committed Dec 11, 2024
1 parent a405f03 commit 503b313
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1663,7 +1663,11 @@ class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
array: pd.Index
_dtype: np.dtype | pd.api.extensions.ExtensionDtype

def __init__(self, array: pd.Index, dtype: DTypeLike = None):
def __init__(
self,
array: pd.Index,
dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None,
):
from xarray.core.indexes import safe_cast_to_index

self.array = safe_cast_to_index(array)
Expand All @@ -1675,22 +1679,26 @@ def __init__(self, array: pd.Index, dtype: DTypeLike = None):
else:
self._dtype = get_valid_numpy_dtype(array)
elif pd.api.types.is_extension_array_dtype(dtype):
cast(pd.api.extensions.ExtensionDtype, dtype)
self._dtype = dtype
self._dtype = cast(pd.api.extensions.ExtensionDtype, dtype)
else:
self._dtype = np.dtype(dtype)
self._dtype = np.dtype(cast(DTypeLike, dtype))

@property
def dtype(self) -> np.dtype:
def dtype(self) -> np.dtype | pd.api.extensions.ExtensionDtype: # type: ignore[override]
return self._dtype

def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
self,
dtype: np.typing.DTypeLike | pd.api.extensions.ExtensionDtype | None = None,
/,
*,
copy: bool | None = None,
) -> np.ndarray:
if dtype is None:
dtype = self.dtype
if pd.api.types.is_extension_array_dtype(dtype):
dtype = get_valid_numpy_dtype(self.array)
dtype = cast(np.dtype, dtype)
array = self.array
if isinstance(array, pd.PeriodIndex):
with suppress(AttributeError):
Expand Down Expand Up @@ -1726,7 +1734,7 @@ def _convert_scalar(self, item) -> np.ndarray:
dtype = self.dtype
if pd.api.types.is_extension_array_dtype(dtype):
dtype = get_valid_numpy_dtype(self.array)
item = np.asarray(item, dtype=dtype)
item = np.asarray(item, dtype=cast(np.dtype, dtype))

# as for numpy.ndarray indexing, we always want the result to be
# a NumPy array.
Expand Down Expand Up @@ -1846,19 +1854,24 @@ class PandasMultiIndexingAdapter(PandasIndexingAdapter):
def __init__(
self,
array: pd.MultiIndex,
dtype: DTypeLike = None,
dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None,
level: str | None = None,
):
super().__init__(array, dtype)
self.level = level

def __array__(
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
self,
dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None,
/,
*,
copy: bool | None = None,
) -> np.ndarray:
if dtype is None:
dtype = self.dtype
if pd.api.types.is_extension_array_dtype(dtype):
dtype = get_valid_numpy_dtype(self.array)
dtype = cast(np.dtype, dtype)
if self.level is not None:
return np.asarray(
self.array.get_level_values(self.level).values, dtype=dtype
Expand Down

0 comments on commit 503b313

Please sign in to comment.