diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 34d4f9a23ad..a30affcbe93 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -5,7 +5,7 @@ from collections.abc import Hashable from datetime import datetime, timedelta from functools import partial -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Callable, Union, overload import numpy as np import pandas as pd @@ -22,13 +22,19 @@ ) from xarray.core import indexing from xarray.core.common import contains_cftime_datetimes, is_np_datetime_like -from xarray.core.duck_array_ops import asarray +from xarray.core.duck_array_ops import asarray, ravel from xarray.core.formatting import first_n_items, format_timestamp, last_item from xarray.core.pdcompat import nanosecond_precision_timestamp from xarray.core.utils import emit_user_level_warning from xarray.core.variable import Variable -from xarray.namedarray.parallelcompat import T_ChunkedArray, get_chunked_array_type -from xarray.namedarray.pycompat import is_chunked_array +from xarray.namedarray._typing import ( + _chunkedarrayfunction_or_api, + chunkedduckarray, + duckarray, +) +from xarray.namedarray.parallelcompat import get_chunked_array_type + +# from xarray.namedarray.pycompat import is_chunked_array from xarray.namedarray.utils import is_duck_dask_array try: @@ -37,7 +43,7 @@ cftime = None if TYPE_CHECKING: - from xarray.core.types import CFCalendar, T_DuckArray + from xarray.core.types import CFCalendar T_Name = Union[Hashable, None] @@ -315,7 +321,7 @@ def decode_cf_datetime( cftime.num2date """ num_dates = np.asarray(num_dates) - flat_num_dates = num_dates.ravel() + flat_num_dates = ravel(num_dates) if calendar is None: calendar = "standard" @@ -369,7 +375,7 @@ def decode_cf_timedelta(num_timedeltas, units: str) -> np.ndarray: """ num_timedeltas = np.asarray(num_timedeltas) units = _netcdf_to_numpy_timeunit(units) - result = to_timedelta_unboxed(num_timedeltas.ravel(), unit=units) + result = to_timedelta_unboxed(ravel(num_timedeltas), unit=units) return result.reshape(num_timedeltas.shape) @@ -428,7 +434,7 @@ def infer_datetime_units(dates) -> str: 'hours', 'minutes' or 'seconds' (the first one that can evenly divide all unique time deltas in `dates`) """ - dates = np.asarray(dates).ravel() + dates = ravel(np.asarray(dates)) if np.asarray(dates).dtype == "datetime64[ns]": dates = to_datetime_unboxed(dates) dates = dates[pd.notnull(dates)] @@ -456,7 +462,7 @@ def infer_timedelta_units(deltas) -> str: {'days', 'hours', 'minutes' 'seconds'} (the first one that can evenly divide all unique time deltas in `deltas`) """ - deltas = to_timedelta_unboxed(np.asarray(deltas).ravel()) + deltas = to_timedelta_unboxed(ravel(np.asarray(deltas))) unique_timedeltas = np.unique(deltas[pd.notnull(deltas)]) return _infer_time_units_from_diff(unique_timedeltas) @@ -643,7 +649,7 @@ def encode_datetime(d): except TypeError: return np.nan if d is None else cftime.date2num(d, units, calendar) - return np.array([encode_datetime(d) for d in dates.ravel()]).reshape(dates.shape) + return np.array([encode_datetime(d) for d in ravel(dates)]).reshape(dates.shape) def cast_to_int_if_safe(num) -> np.ndarray: @@ -700,12 +706,26 @@ def _cast_to_dtype_if_safe(num: np.ndarray, dtype: np.dtype) -> np.ndarray: return cast_num +@overload +def encode_cf_datetime( + dates: chunkedduckarray, + units: str | None = None, + calendar: str | None = None, + dtype: np.dtype | None = None, +) -> tuple[chunkedduckarray, str, str]: ... 
+@overload +def encode_cf_datetime( + dates: duckarray, + units: str | None = None, + calendar: str | None = None, + dtype: np.dtype | None = None, +) -> tuple[duckarray, str, str]: ... def encode_cf_datetime( - dates: T_DuckArray, # type: ignore + dates: duckarray | chunkedduckarray, units: str | None = None, calendar: str | None = None, dtype: np.dtype | None = None, -) -> tuple[T_DuckArray, str, str]: +) -> tuple[duckarray | chunkedduckarray, str, str]: """Given an array of datetime objects, returns the tuple `(num, units, calendar)` suitable for a CF compliant time variable. @@ -716,19 +736,19 @@ def encode_cf_datetime( cftime.date2num """ dates = asarray(dates) - if is_chunked_array(dates): + if isinstance(dates, _chunkedarrayfunction_or_api): return _lazily_encode_cf_datetime(dates, units, calendar, dtype) else: return _eagerly_encode_cf_datetime(dates, units, calendar, dtype) def _eagerly_encode_cf_datetime( - dates: T_DuckArray, # type: ignore + dates: duckarray, units: str | None = None, calendar: str | None = None, dtype: np.dtype | None = None, allow_units_modification: bool = True, -) -> tuple[T_DuckArray, str, str]: +) -> tuple[duckarray, str, str]: dates = asarray(dates) data_units = infer_datetime_units(dates) @@ -753,7 +773,7 @@ def _eagerly_encode_cf_datetime( # Wrap the dates in a DatetimeIndex to do the subtraction to ensure # an OverflowError is raised if the ref_date is too far away from # dates to be encoded (GH 2272). - dates_as_index = pd.DatetimeIndex(dates.ravel()) + dates_as_index = pd.DatetimeIndex(ravel(dates)) time_deltas = dates_as_index - ref_date # retrieve needed units to faithfully encode to int64 @@ -806,11 +826,11 @@ def _eagerly_encode_cf_datetime( def _encode_cf_datetime_within_map_blocks( - dates: T_DuckArray, # type: ignore + dates: duckarray, units: str, calendar: str, dtype: np.dtype, -) -> T_DuckArray: +) -> duckarray: num, *_ = _eagerly_encode_cf_datetime( dates, units, calendar, dtype, allow_units_modification=False ) @@ -818,11 +838,11 @@ def _encode_cf_datetime_within_map_blocks( def _lazily_encode_cf_datetime( - dates: T_ChunkedArray, + dates: chunkedduckarray, units: str | None = None, calendar: str | None = None, dtype: np.dtype | None = None, -) -> tuple[T_ChunkedArray, str, str]: +) -> tuple[chunkedduckarray, str, str]: if calendar is None: # This will only trigger minor compute if dates is an object dtype array. calendar = infer_calendar_name(dates) @@ -855,31 +875,43 @@ def _lazily_encode_cf_datetime( return num, units, calendar +@overload def encode_cf_timedelta( - timedeltas: T_DuckArray, # type: ignore + timedeltas: chunkedduckarray, units: str | None = None, dtype: np.dtype | None = None, -) -> tuple[T_DuckArray, str]: +) -> tuple[chunkedduckarray, str]: ... +@overload +def encode_cf_timedelta( + timedeltas: duckarray, + units: str | None = None, + dtype: np.dtype | None = None, +) -> tuple[duckarray, str]: ... 
+def encode_cf_timedelta( + timedeltas: chunkedduckarray | duckarray, + units: str | None = None, + dtype: np.dtype | None = None, +) -> tuple[chunkedduckarray | duckarray, str]: timedeltas = asarray(timedeltas) - if is_chunked_array(timedeltas): + if isinstance(timedeltas, _chunkedarrayfunction_or_api): return _lazily_encode_cf_timedelta(timedeltas, units, dtype) else: return _eagerly_encode_cf_timedelta(timedeltas, units, dtype) def _eagerly_encode_cf_timedelta( - timedeltas: T_DuckArray, # type: ignore + timedeltas: duckarray, units: str | None = None, dtype: np.dtype | None = None, allow_units_modification: bool = True, -) -> tuple[T_DuckArray, str]: +) -> tuple[duckarray, str]: data_units = infer_timedelta_units(timedeltas) if units is None: units = data_units time_delta = _time_units_to_timedelta64(units) - time_deltas = pd.TimedeltaIndex(timedeltas.ravel()) + time_deltas = pd.TimedeltaIndex(ravel(timedeltas)) # retrieve needed units to faithfully encode to int64 needed_units = data_units @@ -920,10 +952,10 @@ def _eagerly_encode_cf_timedelta( def _encode_cf_timedelta_within_map_blocks( - timedeltas: T_DuckArray, # type:ignore + timedeltas: duckarray, units: str, dtype: np.dtype, -) -> T_DuckArray: +) -> duckarray: num, _ = _eagerly_encode_cf_timedelta( timedeltas, units, dtype, allow_units_modification=False ) @@ -931,8 +963,10 @@ def _encode_cf_timedelta_within_map_blocks( def _lazily_encode_cf_timedelta( - timedeltas: T_ChunkedArray, units: str | None = None, dtype: np.dtype | None = None -) -> tuple[T_ChunkedArray, str]: + timedeltas: chunkedduckarray, + units: str | None = None, + dtype: np.dtype | None = None, +) -> tuple[chunkedduckarray, str]: if units is None and dtype is None: units = "nanoseconds" dtype = np.dtype("int64")
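
The overloads above are a typing aid; at runtime both encode entry points still choose the eager or lazy path by checking whether the input satisfies the chunked-array protocol (`_chunkedarrayfunction_or_api`), which presumably is why the isinstance check replaces the previous `is_chunked_array` helper. Below is a minimal usage sketch of the intended dispatch, assuming dask is installed; the inferred units and calendar shown in comments are illustrative assumptions, not guaranteed output.

    # Sketch only: exercises encode_cf_datetime / encode_cf_timedelta as changed above.
    import numpy as np
    import pandas as pd

    from xarray.coding.times import encode_cf_datetime, encode_cf_timedelta

    dates = pd.date_range("2000-01-01", periods=4).values  # datetime64[ns] ndarray

    # Eager path: a plain duck array is routed to _eagerly_encode_cf_datetime and
    # comes back as a NumPy array of offsets plus the inferred units and calendar.
    num, units, calendar = encode_cf_datetime(dates)
    # num -> array([0, 1, 2, 3]), units -> "days since 2000-01-01 00:00:00",
    # calendar -> "proleptic_gregorian"  (illustrative values)

    # Lazy path: a chunked duck array matches _chunkedarrayfunction_or_api and is
    # routed to _lazily_encode_cf_datetime, so the result stays chunked/uncomputed.
    try:
        import dask.array as da

        lazy_num, lazy_units, lazy_calendar = encode_cf_datetime(
            da.from_array(dates, chunks=2),
            units="days since 2000-01-01",
            dtype=np.dtype("int64"),
        )
        assert lazy_num.chunks is not None  # still a dask array, nothing computed yet
    except ImportError:
        pass

    # encode_cf_timedelta dispatches the same way and returns (num, units).
    td_num, td_units = encode_cf_timedelta(pd.to_timedelta(["1h", "2h"]).values)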