From 678536c19f7ceb89421296064f45936a20938567 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 1 Jul 2024 22:38:56 +0200 Subject: [PATCH 01/23] remove pandas from ignore missing imports --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2ada0c1c171..0fcbec312ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,7 +119,6 @@ module = [ "netCDF4.*", "netcdftime.*", "opt_einsum.*", - "pandas.*", "pint.*", "pooch.*", "pyarrow.*", From 20066c49df9ecd4e33f810533953e33e69c4d9b1 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 1 Jul 2024 22:39:46 +0200 Subject: [PATCH 02/23] add Any to dim arg of concat as placeholder --- xarray/core/concat.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index b1cca586992..5e6af84c3b6 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -32,10 +32,11 @@ T_DataVars = Union[ConcatOptions, Iterable[Hashable]] +# TODO: replace dim: Any by 1D array_likes @overload def concat( objs: Iterable[T_Dataset], - dim: Hashable | T_Variable | T_DataArray | pd.Index, + dim: Hashable | T_Variable | T_DataArray | pd.Index | Any, data_vars: T_DataVars = "all", coords: ConcatOptions | list[Hashable] = "different", compat: CompatOptions = "equals", @@ -50,7 +51,7 @@ def concat( @overload def concat( objs: Iterable[T_DataArray], - dim: Hashable | T_Variable | T_DataArray | pd.Index, + dim: Hashable | T_Variable | T_DataArray | pd.Index | Any, data_vars: T_DataVars = "all", coords: ConcatOptions | list[Hashable] = "different", compat: CompatOptions = "equals", From 9053b6076023bb738c3e3df7f336295b166345d9 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 1 Jul 2024 22:40:55 +0200 Subject: [PATCH 03/23] allow sequence of np.ndarrays as coords in dataArray constructor --- xarray/core/dataarray.py | 4 +++- xarray/core/utils.py | 19 ++++++++++--------- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index b67f8089eb2..6264970704b 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -413,7 +413,9 @@ class DataArray( def __init__( self, data: Any = dtypes.NA, - coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None, + coords: ( + Sequence[Sequence | pd.Index | DataArray | np.ndarray] | Mapping | None + ) = None, dims: str | Iterable[Hashable] | None = None, name: Hashable | None = None, attrs: Mapping | None = None, diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 5cb52cbd25c..26de63f6e32 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -117,26 +117,27 @@ def wrapper(*args, **kwargs): return wrapper -def get_valid_numpy_dtype(array: np.ndarray | pd.Index): +def get_valid_numpy_dtype(array: np.ndarray | pd.Index) -> np.dtype: """Return a numpy compatible dtype from either a numpy array or a pandas.Index. - Used for wrapping a pandas.Index as an xarray,Variable. + Used for wrapping a pandas.Index as an xarray.Variable. """ if isinstance(array, pd.PeriodIndex): - dtype = np.dtype("O") - elif hasattr(array, "categories"): + return np.dtype("O") + + if hasattr(array, "categories"): # category isn't a real numpy dtype dtype = array.categories.dtype if not is_valid_numpy_dtype(dtype): dtype = np.dtype("O") - elif not is_valid_numpy_dtype(array.dtype): - dtype = np.dtype("O") - else: - dtype = array.dtype + return dtype + + if not is_valid_numpy_dtype(array.dtype): + return np.dtype("O") - return dtype + return array.dtype def maybe_coerce_to_str(index, original_coords): From 27bd0ab834160b8f283a8b35774ccf847d37c228 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 1 Jul 2024 22:41:34 +0200 Subject: [PATCH 04/23] fix several typing issues in tests --- xarray/tests/test_backends.py | 8 ++- xarray/tests/test_cftimeindex.py | 36 +++++------ xarray/tests/test_coding_times.py | 100 +++++++++++++++--------------- xarray/tests/test_conventions.py | 2 +- xarray/tests/test_dataarray.py | 9 ++- xarray/tests/test_formatting.py | 4 +- xarray/tests/test_indexes.py | 25 ++++---- 7 files changed, 94 insertions(+), 90 deletions(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 15485dc178a..ec3885bb078 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -568,7 +568,7 @@ def test_roundtrip_cftime_datetime_data(self) -> None: assert actual.t.encoding["calendar"] == expected_calendar def test_roundtrip_timedelta_data(self) -> None: - time_deltas = pd.to_timedelta(["1h", "2h", "NaT"]) + time_deltas = pd.to_timedelta(["1h", "2h", "NaT"]) # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 expected = Dataset({"td": ("td", time_deltas), "td0": time_deltas[0]}) with self.roundtrip(expected) as actual: assert_identical(expected, actual) @@ -5598,7 +5598,9 @@ def test_h5netcdf_entrypoint(tmp_path: Path) -> None: @requires_netCDF4 @pytest.mark.parametrize("str_type", (str, np.str_)) -def test_write_file_from_np_str(str_type, tmpdir) -> None: +def test_write_file_from_np_str( + str_type: type[str] | type[np.str_], tmpdir: str +) -> None: # https://github.com/pydata/xarray/pull/5264 scenarios = [str_type(v) for v in ["scenario_a", "scenario_b", "scenario_c"]] years = range(2015, 2100 + 1) @@ -5609,7 +5611,7 @@ def test_write_file_from_np_str(str_type, tmpdir) -> None: ) tdf.index.name = "scenario" tdf.columns.name = "year" - tdf = tdf.stack() + tdf = cast(pd.DataFrame, tdf.stack()) tdf.name = "tas" txr = tdf.to_xarray() diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index f6eb15fa373..2c74581fcac 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -679,11 +679,11 @@ def test_indexing_in_series_loc(series, index, scalar_args, range_args): @requires_cftime def test_indexing_in_series_iloc(series, index): - expected = 1 - assert series.iloc[0] == expected + expected1 = 1 + assert series.iloc[0] == expected1 - expected = pd.Series([1, 2], index=index[:2]) - assert series.iloc[:2].equals(expected) + expected2 = pd.Series([1, 2], index=index[:2]) + assert series.iloc[:2].equals(expected2) @requires_cftime @@ -696,27 +696,27 @@ def test_series_dropna(index): @requires_cftime def test_indexing_in_dataframe_loc(df, index, scalar_args, range_args): - expected = pd.Series([1], name=index[0]) + expected_s = pd.Series([1], name=index[0]) for arg in scalar_args: - result = df.loc[arg] - assert result.equals(expected) + result_s = df.loc[arg] + assert result_s.equals(expected_s) - expected = pd.DataFrame([1, 2], index=index[:2]) + expected_df = pd.DataFrame([1, 2], index=index[:2]) for arg in range_args: - result = df.loc[arg] - assert result.equals(expected) + result_df = df.loc[arg] + assert result_df.equals(expected_df) @requires_cftime def test_indexing_in_dataframe_iloc(df, index): - expected = pd.Series([1], name=index[0]) - result = df.iloc[0] - assert result.equals(expected) - assert result.equals(expected) - - expected = pd.DataFrame([1, 2], index=index[:2]) - result = df.iloc[:2] - assert result.equals(expected) + expected_s = pd.Series([1], name=index[0]) + result_s = df.iloc[0] + assert result_s.equals(expected_s) + assert result_s.equals(expected_s) + + expected_df = pd.DataFrame([1, 2], index=index[:2]) + result_df = df.iloc[:2] + assert result_df.equals(expected_df) @requires_cftime diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 393f8400c46..87a6dc5c0a4 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -20,6 +20,7 @@ decode_cf, ) from xarray.coding.times import ( + CFDatetimeCoder, _encode_datetime_with_cftime, _numpy_to_netcdf_timeunit, _should_cftime_be_used, @@ -28,6 +29,9 @@ decode_cf_timedelta, encode_cf_datetime, encode_cf_timedelta, + format_cftime_datetime, + infer_datetime_units, + infer_timedelta_units, to_timedelta_unboxed, ) from xarray.coding.variables import SerializationWarning @@ -130,7 +134,7 @@ def test_cf_datetime(num_dates, units, calendar) -> None: with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Unable to decode time axis") - actual = coding.times.decode_cf_datetime(num_dates, units, calendar) + actual = decode_cf_datetime(num_dates, units, calendar) abs_diff = np.asarray(abs(actual - expected)).ravel() abs_diff = pd.to_timedelta(abs_diff.tolist()).to_numpy() @@ -139,17 +143,15 @@ def test_cf_datetime(num_dates, units, calendar) -> None: # we could do this check with near microsecond accuracy: # https://github.com/Unidata/netcdf4-python/issues/355 assert (abs_diff <= np.timedelta64(1, "s")).all() - encoded, _, _ = coding.times.encode_cf_datetime(actual, units, calendar) + encoded1, _, _ = encode_cf_datetime(actual, units, calendar) + assert_array_equal(num_dates, np.around(encoded1, 1)) - assert_array_equal(num_dates, np.around(encoded, 1)) if hasattr(num_dates, "ndim") and num_dates.ndim == 1 and "1000" not in units: # verify that wrapping with a pandas.Index works # note that it *does not* currently work to put # non-datetime64 compatible dates into a pandas.Index - encoded, _, _ = coding.times.encode_cf_datetime( - pd.Index(actual), units, calendar - ) - assert_array_equal(num_dates, np.around(encoded, 1)) + encoded2, _, _ = encode_cf_datetime(pd.Index(actual), units, calendar) + assert_array_equal(num_dates, np.around(encoded2, 1)) @requires_cftime @@ -169,7 +171,7 @@ def test_decode_cf_datetime_overflow() -> None: for i, day in enumerate(days): with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Unable to decode time axis") - result = coding.times.decode_cf_datetime(day, units) + result = decode_cf_datetime(day, units) assert result == expected[i] @@ -178,7 +180,7 @@ def test_decode_cf_datetime_non_standard_units() -> None: # netCDFs from madis.noaa.gov use this format for their time units # they cannot be parsed by cftime, but pd.Timestamp works units = "hours since 1-1-1970" - actual = coding.times.decode_cf_datetime(np.arange(100), units) + actual = decode_cf_datetime(np.arange(100), units) assert_array_equal(actual, expected) @@ -193,7 +195,7 @@ def test_decode_cf_datetime_non_iso_strings() -> None: (np.arange(100), "hours since 2000-01-01 0:00"), ] for num_dates, units in cases: - actual = coding.times.decode_cf_datetime(num_dates, units) + actual = decode_cf_datetime(num_dates, units) abs_diff = abs(actual - expected.values) # once we no longer support versions of netCDF4 older than 1.1.5, # we could do this check with near microsecond accuracy: @@ -212,7 +214,7 @@ def test_decode_standard_calendar_inside_timestamp_range(calendar) -> None: expected = times.values expected_dtype = np.dtype("M8[ns]") - actual = coding.times.decode_cf_datetime(time, units, calendar=calendar) + actual = decode_cf_datetime(time, units, calendar=calendar) assert actual.dtype == expected_dtype abs_diff = abs(actual - expected) # once we no longer support versions of netCDF4 older than 1.1.5, @@ -235,9 +237,7 @@ def test_decode_non_standard_calendar_inside_timestamp_range(calendar) -> None: ) expected_dtype = np.dtype("O") - actual = coding.times.decode_cf_datetime( - non_standard_time, units, calendar=calendar - ) + actual = decode_cf_datetime(non_standard_time, units, calendar=calendar) assert actual.dtype == expected_dtype abs_diff = abs(actual - expected) # once we no longer support versions of netCDF4 older than 1.1.5, @@ -264,7 +264,7 @@ def test_decode_dates_outside_timestamp_range(calendar) -> None: with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Unable to decode time axis") - actual = coding.times.decode_cf_datetime(time, units, calendar=calendar) + actual = decode_cf_datetime(time, units, calendar=calendar) assert all(isinstance(value, expected_date_type) for value in actual) abs_diff = abs(actual - expected) # once we no longer support versions of netCDF4 older than 1.1.5, @@ -282,7 +282,7 @@ def test_decode_standard_calendar_single_element_inside_timestamp_range( for num_time in [735368, [735368], [[735368]]]: with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Unable to decode time axis") - actual = coding.times.decode_cf_datetime(num_time, units, calendar=calendar) + actual = decode_cf_datetime(num_time, units, calendar=calendar) assert actual.dtype == np.dtype("M8[ns]") @@ -295,7 +295,7 @@ def test_decode_non_standard_calendar_single_element_inside_timestamp_range( for num_time in [735368, [735368], [[735368]]]: with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Unable to decode time axis") - actual = coding.times.decode_cf_datetime(num_time, units, calendar=calendar) + actual = decode_cf_datetime(num_time, units, calendar=calendar) assert actual.dtype == np.dtype("O") @@ -309,9 +309,7 @@ def test_decode_single_element_outside_timestamp_range(calendar) -> None: for num_time in [days, [days], [[days]]]: with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Unable to decode time axis") - actual = coding.times.decode_cf_datetime( - num_time, units, calendar=calendar - ) + actual = decode_cf_datetime(num_time, units, calendar=calendar) expected = cftime.num2date( days, units, calendar, only_use_cftime_datetimes=True @@ -338,7 +336,7 @@ def test_decode_standard_calendar_multidim_time_inside_timestamp_range( expected1 = times1.values expected2 = times2.values - actual = coding.times.decode_cf_datetime(mdim_time, units, calendar=calendar) + actual = decode_cf_datetime(mdim_time, units, calendar=calendar) assert actual.dtype == np.dtype("M8[ns]") abs_diff1 = abs(actual[:, 0] - expected1) @@ -379,7 +377,7 @@ def test_decode_nonstandard_calendar_multidim_time_inside_timestamp_range( expected_dtype = np.dtype("O") - actual = coding.times.decode_cf_datetime(mdim_time, units, calendar=calendar) + actual = decode_cf_datetime(mdim_time, units, calendar=calendar) assert actual.dtype == expected_dtype abs_diff1 = abs(actual[:, 0] - expected1) @@ -412,7 +410,7 @@ def test_decode_multidim_time_outside_timestamp_range(calendar) -> None: with warnings.catch_warnings(): warnings.filterwarnings("ignore", "Unable to decode time axis") - actual = coding.times.decode_cf_datetime(mdim_time, units, calendar=calendar) + actual = decode_cf_datetime(mdim_time, units, calendar=calendar) assert actual.dtype == np.dtype("O") @@ -435,7 +433,7 @@ def test_decode_non_standard_calendar_single_element(calendar, num_time) -> None units = "days since 0001-01-01" - actual = coding.times.decode_cf_datetime(num_time, units, calendar=calendar) + actual = decode_cf_datetime(num_time, units, calendar=calendar) expected = np.asarray( cftime.num2date(num_time, units, calendar, only_use_cftime_datetimes=True) @@ -460,9 +458,7 @@ def test_decode_360_day_calendar() -> None: with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - actual = coding.times.decode_cf_datetime( - num_times, units, calendar=calendar - ) + actual = decode_cf_datetime(num_times, units, calendar=calendar) assert len(w) == 0 assert actual.dtype == np.dtype("O") @@ -476,8 +472,8 @@ def test_decode_abbreviation() -> None: val = np.array([1586628000000.0]) units = "msecs since 1970-01-01T00:00:00Z" - actual = coding.times.decode_cf_datetime(val, units) - expected = coding.times.cftime_to_nptime(cftime.num2date(val, units)) + actual = decode_cf_datetime(val, units) + expected = cftime_to_nptime(cftime.num2date(val, units)) assert_array_equal(actual, expected) @@ -498,7 +494,7 @@ def test_decode_abbreviation() -> None: def test_cf_datetime_nan(num_dates, units, expected_list) -> None: with warnings.catch_warnings(): warnings.filterwarnings("ignore", "All-NaN") - actual = coding.times.decode_cf_datetime(num_dates, units) + actual = decode_cf_datetime(num_dates, units) # use pandas because numpy will deprecate timezone-aware conversions expected = pd.to_datetime(expected_list).to_numpy(dtype="datetime64[ns]") assert_array_equal(expected, actual) @@ -510,7 +506,7 @@ def test_decoded_cf_datetime_array_2d() -> None: variable = Variable( ("x", "y"), np.array([[0, 1], [2, 3]]), {"units": "days since 2000-01-01"} ) - result = coding.times.CFDatetimeCoder().decode(variable) + result = CFDatetimeCoder().decode(variable) assert result.dtype == "datetime64[ns]" expected = pd.date_range("2000-01-01", periods=4).values.reshape(2, 2) assert_array_equal(np.asarray(result), expected) @@ -531,7 +527,7 @@ def test_decoded_cf_datetime_array_2d() -> None: def test_infer_datetime_units(freq, units) -> None: dates = pd.date_range("2000", periods=2, freq=freq) expected = f"{units} since 2000-01-01 00:00:00" - assert expected == coding.times.infer_datetime_units(dates) + assert expected == infer_datetime_units(dates) @pytest.mark.parametrize( @@ -549,7 +545,7 @@ def test_infer_datetime_units(freq, units) -> None: ], ) def test_infer_datetime_units_with_NaT(dates, expected) -> None: - assert expected == coding.times.infer_datetime_units(dates) + assert expected == infer_datetime_units(dates) _CFTIME_DATETIME_UNITS_TESTS = [ @@ -573,7 +569,7 @@ def test_infer_datetime_units_with_NaT(dates, expected) -> None: def test_infer_cftime_datetime_units(calendar, date_args, expected) -> None: date_type = _all_cftime_date_types()[calendar] dates = [date_type(*args) for args in date_args] - assert expected == coding.times.infer_datetime_units(dates) + assert expected == infer_datetime_units(dates) @pytest.mark.filterwarnings("ignore:Timedeltas can't be serialized faithfully") @@ -600,18 +596,18 @@ def test_cf_timedelta(timedeltas, units, numbers) -> None: numbers = np.array(numbers) expected = numbers - actual, _ = coding.times.encode_cf_timedelta(timedeltas, units) + actual, _ = encode_cf_timedelta(timedeltas, units) assert_array_equal(expected, actual) assert expected.dtype == actual.dtype if units is not None: expected = timedeltas - actual = coding.times.decode_cf_timedelta(numbers, units) + actual = decode_cf_timedelta(numbers, units) assert_array_equal(expected, actual) assert expected.dtype == actual.dtype expected = np.timedelta64("NaT", "ns") - actual = coding.times.decode_cf_timedelta(np.array(np.nan), "days") + actual = decode_cf_timedelta(np.array(np.nan), "days") assert_array_equal(expected, actual) @@ -622,7 +618,7 @@ def test_cf_timedelta_2d() -> None: timedeltas = np.atleast_2d(to_timedelta_unboxed(["1D", "2D", "3D"])) expected = timedeltas - actual = coding.times.decode_cf_timedelta(numbers, units) + actual = decode_cf_timedelta(numbers, units) assert_array_equal(expected, actual) assert expected.dtype == actual.dtype @@ -630,14 +626,14 @@ def test_cf_timedelta_2d() -> None: @pytest.mark.parametrize( ["deltas", "expected"], [ - (pd.to_timedelta(["1 day", "2 days"]), "days"), - (pd.to_timedelta(["1h", "1 day 1 hour"]), "hours"), - (pd.to_timedelta(["1m", "2m", np.nan]), "minutes"), - (pd.to_timedelta(["1m3s", "1m4s"]), "seconds"), + (pd.to_timedelta(["1 day", "2 days"]), "days"), # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 + (pd.to_timedelta(["1h", "1 day 1 hour"]), "hours"), # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 + (pd.to_timedelta(["1m", "2m", np.nan]), "minutes"), # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 + (pd.to_timedelta(["1m3s", "1m4s"]), "seconds"), # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 ], ) def test_infer_timedelta_units(deltas, expected) -> None: - assert expected == coding.times.infer_timedelta_units(deltas) + assert expected == infer_timedelta_units(deltas) @requires_cftime @@ -653,7 +649,7 @@ def test_infer_timedelta_units(deltas, expected) -> None: def test_format_cftime_datetime(date_args, expected) -> None: date_types = _all_cftime_date_types() for date_type in date_types.values(): - result = coding.times.format_cftime_datetime(date_type(*date_args)) + result = format_cftime_datetime(date_type(*date_args)) assert result == expected @@ -1043,7 +1039,7 @@ def test_encode_cf_datetime_defaults_to_correct_dtype( pytest.skip("Nanosecond frequency is not valid for cftime dates.") times = date_range("2000", periods=3, freq=freq) units = f"{encoding_units} since 2000-01-01" - encoded, _units, _ = coding.times.encode_cf_datetime(times, units) + encoded, _units, _ = encode_cf_datetime(times, units) numpy_timeunit = coding.times._netcdf_to_numpy_timeunit(encoding_units) encoding_units_as_timedelta = np.timedelta64(1, numpy_timeunit) @@ -1202,7 +1198,7 @@ def test_decode_float_datetime(): def test_scalar_unit() -> None: # test that a scalar units (often NaN when using to_netcdf) does not raise an error variable = Variable(("x", "y"), np.array([[0, 1], [2, 3]]), {"units": np.nan}) - result = coding.times.CFDatetimeCoder().decode(variable) + result = CFDatetimeCoder().decode(variable) assert np.isnan(result.attrs["units"]) @@ -1436,8 +1432,8 @@ def test_roundtrip_float_times() -> None: def test_encode_cf_datetime_datetime64_via_dask(freq, units, dtype) -> None: import dask.array - times = pd.date_range(start="1700", freq=freq, periods=3) - times = dask.array.from_array(times, chunks=1) + times_pd = pd.date_range(start="1700", freq=freq, periods=3) + times = dask.array.from_array(times_pd, chunks=1) encoded_times, encoding_units, encoding_calendar = encode_cf_datetime( times, units, None, dtype ) @@ -1560,11 +1556,13 @@ def test_encode_cf_datetime_casting_overflow_error(use_cftime, use_dask, dtype) @pytest.mark.parametrize( ("units", "dtype"), [("days", np.dtype("int32")), (None, None)] ) -def test_encode_cf_timedelta_via_dask(units, dtype) -> None: +def test_encode_cf_timedelta_via_dask( + units: str | None, dtype: np.dtype | None +) -> None: import dask.array - times = pd.timedelta_range(start="0D", freq="D", periods=3) - times = dask.array.from_array(times, chunks=1) + times_pd = pd.timedelta_range(start="0D", freq="D", periods=3) + times = dask.array.from_array(times_pd, chunks=1) encoded_times, encoding_units = encode_cf_timedelta(times, units, dtype) assert is_duck_dask_array(encoded_times) diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index fdfea3c3fe8..dc0b270dc51 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -119,7 +119,7 @@ def test_incompatible_attributes(self) -> None: Variable( ["t"], pd.date_range("2000-01-01", periods=3), {"units": "foobar"} ), - Variable(["t"], pd.to_timedelta(["1 day"]), {"units": "foobar"}), + Variable(["t"], pd.to_timedelta(["1 day"]), {"units": "foobar"}), # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 Variable(["t"], [0, 1, 2], {"add_offset": 0}, {"add_offset": 2}), Variable(["t"], [0, 1, 2], {"_FillValue": 0}, {"_FillValue": 2}), ] diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 659c7c168a5..af3134cbdaa 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -341,12 +341,15 @@ def test_constructor(self) -> None: assert_identical(expected, actual) # list coords, w dims - coords1 = [["a", "b"], [-1, -2, -3]] + coords1: list[Any] = [["a", "b"], [-1, -2, -3]] actual = DataArray(data, coords1, ["x", "y"]) assert_identical(expected, actual) # pd.Index coords, w dims - coords2 = [pd.Index(["a", "b"], name="A"), pd.Index([-1, -2, -3], name="B")] + coords2: list[pd.Index] = [ + pd.Index(["a", "b"], name="A"), + pd.Index([-1, -2, -3], name="B"), + ] actual = DataArray(data, coords2, ["x", "y"]) assert_identical(expected, actual) @@ -3539,7 +3542,7 @@ def test_nbytes_does_not_load_data(self) -> None: def test_to_and_from_empty_series(self) -> None: # GH697 - expected = pd.Series([], dtype=np.float64) + expected: pd.Series[Any] = pd.Series([], dtype=np.float64) da = DataArray.from_series(expected) assert len(da) == 0 actual = da.to_series() diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 6c49ab456f6..3786b3876b9 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -107,9 +107,9 @@ def test_format_items(self) -> None: np.arange(4) * np.timedelta64(500, "ms"), "00:00:00 00:00:00.500000 00:00:01 00:00:01.500000", ), - (pd.to_timedelta(["NaT", "0s", "1s", "NaT"]), "NaT 00:00:00 00:00:01 NaT"), + (pd.to_timedelta(["NaT", "0s", "1s", "NaT"]), "NaT 00:00:00 00:00:01 NaT"), # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 ( - pd.to_timedelta(["1 day 1 hour", "1 day", "0 hours"]), + pd.to_timedelta(["1 day 1 hour", "1 day", "0 hours"]), # type: ignore[arg-type] #https://github.com/pandas-dev/pandas-stubs/issues/956 "1 days 01:00:00 1 days 00:00:00 0 days 00:00:00", ), ([1, 2, 3], "1 2 3"), diff --git a/xarray/tests/test_indexes.py b/xarray/tests/test_indexes.py index 5ebdfd5da6e..48e254b037b 100644 --- a/xarray/tests/test_indexes.py +++ b/xarray/tests/test_indexes.py @@ -410,13 +410,15 @@ def test_stack(self) -> None: "y": xr.Variable("y", pd.Index([1, 3, 2])), } - index = PandasMultiIndex.stack(prod_vars, "z") + index_xr = PandasMultiIndex.stack(prod_vars, "z") - assert index.dim == "z" + assert index_xr.dim == "z" + index_pd = index_xr.index + assert isinstance(index_pd, pd.MultiIndex) # TODO: change to tuple when pandas 3 is minimum - assert list(index.index.names) == ["x", "y"] + assert list(index_pd.names) == ["x", "y"] np.testing.assert_array_equal( - index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]] + index_pd.codes, [[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]] ) with pytest.raises( @@ -433,13 +435,15 @@ def test_stack_non_unique(self) -> None: "y": xr.Variable("y", pd.Index([1, 1, 2])), } - index = PandasMultiIndex.stack(prod_vars, "z") + index_xr = PandasMultiIndex.stack(prod_vars, "z") + index_pd = index_xr.index + assert isinstance(index_pd, pd.MultiIndex) np.testing.assert_array_equal( - index.index.codes, [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]] + index_pd.codes, [[0, 0, 0, 1, 1, 1], [0, 0, 1, 0, 0, 1]] ) - np.testing.assert_array_equal(index.index.levels[0], ["b", "a"]) - np.testing.assert_array_equal(index.index.levels[1], [1, 2]) + np.testing.assert_array_equal(index_pd.levels[0], ["b", "a"]) + np.testing.assert_array_equal(index_pd.levels[1], [1, 2]) def test_unstack(self) -> None: pd_midx = pd.MultiIndex.from_product( @@ -600,10 +604,7 @@ def indexes( _, variables = indexes_and_vars - if isinstance(x_idx, Index): - index_type = Index - else: - index_type = pd.Index + index_type = Index if isinstance(x_idx, Index) else pd.Index return Indexes(indexes, variables, index_type=index_type) From 0dcb2a14c09876263fda879e4da51ed1b066bd6d Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 1 Jul 2024 23:15:24 +0200 Subject: [PATCH 05/23] fix more types --- xarray/core/dataarray.py | 22 ++++++--- xarray/core/missing.py | 4 +- xarray/tests/test_dataarray.py | 82 ++++++++++++++++++---------------- 3 files changed, 62 insertions(+), 46 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 6264970704b..c29d8b7b5a4 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -135,7 +135,11 @@ def _check_coords_dims(shape, coords, dim): def _infer_coords_and_dims( shape: tuple[int, ...], - coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None, + coords: ( + Sequence[Sequence | pd.Index | DataArray | Variable | np.ndarray] + | Mapping + | None + ), dims: str | Iterable[Hashable] | None, ) -> tuple[Mapping[Hashable, Any], tuple[Hashable, ...]]: """All the logic for creating a new DataArray""" @@ -199,7 +203,11 @@ def _infer_coords_and_dims( def _check_data_shape( data: Any, - coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None, + coords: ( + Sequence[Sequence | pd.Index | DataArray | Variable | np.ndarray] + | Mapping + | None + ), dims: str | Iterable[Hashable] | None, ) -> Any: if data is dtypes.NA: @@ -414,7 +422,9 @@ def __init__( self, data: Any = dtypes.NA, coords: ( - Sequence[Sequence | pd.Index | DataArray | np.ndarray] | Mapping | None + Sequence[Sequence | pd.Index | DataArray | Variable | np.ndarray] + | Mapping + | None ) = None, dims: str | Iterable[Hashable] | None = None, name: Hashable | None = None, @@ -3006,7 +3016,7 @@ def to_unstacked_dataset(self, dim: Hashable, level: int | Hashable = 0) -> Data if not isinstance(idx, pd.MultiIndex): raise ValueError(f"'{dim}' is not a stacked coordinate") - level_number = idx._get_level_number(level) + level_number = idx._get_level_number(level) # type: ignore[attr-defined] variables = idx.levels[level_number] variable_dim = idx.names[level_number] @@ -3840,7 +3850,7 @@ def to_pandas(self) -> Self | pd.Series | pd.DataFrame: "pandas objects. Requires 2 or fewer dimensions." ) indexes = [self.get_index(dim) for dim in self.dims] - return constructor(self.values, *indexes) + return constructor(self.values, *indexes) # type: ignore[operator] def to_dataframe( self, name: Hashable | None = None, dim_order: Sequence[Hashable] | None = None @@ -6843,7 +6853,7 @@ def groupby_bins( _validate_groupby_squeeze(squeeze) grouper = BinGrouper( - bins=bins, + bins=bins, # type: ignore[arg-type] # TODO: fix this arg or BinGrouper cut_kwargs={ "right": right, "labels": labels, diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 45abc70c0d3..bfbad72649a 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -315,7 +315,9 @@ def interp_na( use_coordinate: bool | str = True, method: InterpOptions = "linear", limit: int | None = None, - max_gap: int | float | str | pd.Timedelta | np.timedelta64 | dt.timedelta = None, + max_gap: ( + int | float | str | pd.Timedelta | np.timedelta64 | dt.timedelta | None + ) = None, keep_attrs: bool | None = None, **kwargs, ): diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index af3134cbdaa..eae87e03e41 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2491,7 +2491,7 @@ def test_stack_unstack(self) -> None: # test GH3000 a = orig[:0, :1].stack(new_dim=("x", "y")).indexes["new_dim"] b = pd.MultiIndex( - levels=[pd.Index([], np.int64), pd.Index([0], np.int64)], + levels=[pd.Index([], dtype=np.int64), pd.Index([0], dtype=np.int64)], codes=[[], []], names=["x", "y"], ) @@ -3329,28 +3329,28 @@ def test_broadcast_coordinates(self) -> None: def test_to_pandas(self) -> None: # 0d - actual = DataArray(42).to_pandas() + actual_xr = DataArray(42).to_pandas() expected = np.array(42) - assert_array_equal(actual, expected) + assert_array_equal(actual_xr, expected) # 1d values = np.random.randn(3) index = pd.Index(["a", "b", "c"], name="x") da = DataArray(values, coords=[index]) - actual = da.to_pandas() - assert_array_equal(actual.values, values) - assert_array_equal(actual.index, index) - assert_array_equal(actual.index.name, "x") + actual_s = da.to_pandas() + assert_array_equal(np.asarray(actual_s.values), values) + assert_array_equal(actual_s.index, index) + assert_array_equal(actual_s.index.name, "x") # 2d values = np.random.randn(3, 2) da = DataArray( values, coords=[("x", ["a", "b", "c"]), ("y", [0, 1])], name="foo" ) - actual = da.to_pandas() - assert_array_equal(actual.values, values) - assert_array_equal(actual.index, ["a", "b", "c"]) - assert_array_equal(actual.columns, [0, 1]) + actual_df = da.to_pandas() + assert_array_equal(np.asarray(actual_df.values), values) + assert_array_equal(actual_df.index, ["a", "b", "c"]) + assert_array_equal(actual_df.columns, [0, 1]) # roundtrips for shape in [(3,), (3, 4)]: @@ -3367,24 +3367,24 @@ def test_to_dataframe(self) -> None: arr_np = np.random.randn(3, 4) arr = DataArray(arr_np, [("B", [1, 2, 3]), ("A", list("cdef"))], name="foo") - expected = arr.to_series() - actual = arr.to_dataframe()["foo"] - assert_array_equal(expected.values, actual.values) - assert_array_equal(expected.name, actual.name) - assert_array_equal(expected.index.values, actual.index.values) + expected_s = arr.to_series() + actual_s = arr.to_dataframe()["foo"] + assert_array_equal(np.asarray(expected_s.values), np.asarray(actual_s.values)) + assert_array_equal(np.asarray(expected_s.name), np.asarray(actual_s.name)) + assert_array_equal(expected_s.index.values, actual_s.index.values) - actual = arr.to_dataframe(dim_order=["A", "B"])["foo"] - assert_array_equal(arr_np.transpose().reshape(-1), actual.values) + actual_s = arr.to_dataframe(dim_order=["A", "B"])["foo"] + assert_array_equal(arr_np.transpose().reshape(-1), np.asarray(actual_s.values)) # regression test for coords with different dimensions arr.coords["C"] = ("B", [-1, -2, -3]) - expected = arr.to_series().to_frame() - expected["C"] = [-1] * 4 + [-2] * 4 + [-3] * 4 - expected = expected[["C", "foo"]] - actual = arr.to_dataframe() - assert_array_equal(expected.values, actual.values) - assert_array_equal(expected.columns.values, actual.columns.values) - assert_array_equal(expected.index.values, actual.index.values) + expected_df = arr.to_series().to_frame() + expected_df["C"] = [-1] * 4 + [-2] * 4 + [-3] * 4 + expected_df = expected_df[["C", "foo"]] + actual_df = arr.to_dataframe() + assert_array_equal(np.asarray(expected_df.values), np.asarray(actual_df.values)) + assert_array_equal(expected_df.columns.values, actual_df.columns.values) + assert_array_equal(expected_df.index.values, actual_df.index.values) with pytest.raises(ValueError, match="does not match the set of dimensions"): arr.to_dataframe(dim_order=["B", "A", "C"]) @@ -3405,11 +3405,13 @@ def test_to_dataframe_multiindex(self) -> None: arr = DataArray(arr_np, [("MI", mindex), ("C", [5, 6, 7])], name="foo") actual = arr.to_dataframe() - assert_array_equal(actual["foo"].values, arr_np.flatten()) - assert_array_equal(actual.index.names, list("ABC")) - assert_array_equal(actual.index.levels[0], [1, 2]) - assert_array_equal(actual.index.levels[1], ["a", "b"]) - assert_array_equal(actual.index.levels[2], [5, 6, 7]) + index_pd = actual.index + assert isinstance(index_pd, pd.MultiIndex) + assert_array_equal(np.asarray(actual["foo"].values), arr_np.flatten()) + assert_array_equal(index_pd.names, list("ABC")) + assert_array_equal(index_pd.levels[0], [1, 2]) + assert_array_equal(index_pd.levels[1], ["a", "b"]) + assert_array_equal(index_pd.levels[2], [5, 6, 7]) def test_to_dataframe_0length(self) -> None: # regression test for #3008 @@ -3429,10 +3431,10 @@ def test_to_dataframe_0length(self) -> None: def test_to_dask_dataframe(self) -> None: arr_np = np.arange(3 * 4).reshape(3, 4) arr = DataArray(arr_np, [("B", [1, 2, 3]), ("A", list("cdef"))], name="foo") - expected = arr.to_series() + expected_s = arr.to_series() actual = arr.to_dask_dataframe()["foo"] - assert_array_equal(actual.values, expected.values) + assert_array_equal(actual.values, np.asarray(expected_s.values)) actual = arr.to_dask_dataframe(dim_order=["A", "B"])["foo"] assert_array_equal(arr_np.transpose().reshape(-1), actual.values) @@ -3440,13 +3442,15 @@ def test_to_dask_dataframe(self) -> None: # regression test for coords with different dimensions arr.coords["C"] = ("B", [-1, -2, -3]) - expected = arr.to_series().to_frame() - expected["C"] = [-1] * 4 + [-2] * 4 + [-3] * 4 - expected = expected[["C", "foo"]] + expected_df = arr.to_series().to_frame() + expected_df["C"] = [-1] * 4 + [-2] * 4 + [-3] * 4 + expected_df = expected_df[["C", "foo"]] actual = arr.to_dask_dataframe()[["C", "foo"]] - assert_array_equal(expected.values, actual.values) - assert_array_equal(expected.columns.values, actual.columns.values) + assert_array_equal(expected_df.values, np.asarray(actual.values)) + assert_array_equal( + expected_df.columns.values, np.asarray(actual.columns.values) + ) with pytest.raises(ValueError, match="does not match the set of dimensions"): arr.to_dask_dataframe(dim_order=["B", "A", "C"]) @@ -3462,8 +3466,8 @@ def test_to_pandas_name_matches_coordinate(self) -> None: # coordinate with same name as array arr = DataArray([1, 2, 3], dims="x", name="x") series = arr.to_series() - assert_array_equal([1, 2, 3], series.values) - assert_array_equal([0, 1, 2], series.index.values) + assert_array_equal([1, 2, 3], list(series.values)) + assert_array_equal([0, 1, 2], list(series.index.values)) assert "x" == series.name assert "x" == series.index.name From a2b5b4cc35f585d69e011b7670d728d6fe37eef0 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Tue, 2 Jul 2024 22:34:43 +0200 Subject: [PATCH 06/23] more fixes --- xarray/core/extension_array.py | 12 +++--- xarray/core/indexes.py | 71 ++++++++++++++++++-------------- xarray/core/indexing.py | 13 ++++-- xarray/core/types.py | 2 +- xarray/core/utils.py | 28 +++++-------- xarray/core/variable.py | 2 +- xarray/namedarray/daskmanager.py | 8 ++-- 7 files changed, 73 insertions(+), 63 deletions(-) diff --git a/xarray/core/extension_array.py b/xarray/core/extension_array.py index c8b4fa88409..b0361ef0f0f 100644 --- a/xarray/core/extension_array.py +++ b/xarray/core/extension_array.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Callable, Generic +from typing import Callable, Generic, cast import numpy as np import pandas as pd @@ -45,7 +45,7 @@ def __extension_duck_array__stack(arr: T_ExtensionArray, axis: int): def __extension_duck_array__concatenate( arrays: Sequence[T_ExtensionArray], axis: int = 0, out=None ) -> T_ExtensionArray: - return type(arrays[0])._concat_same_type(arrays) + return type(arrays[0])._concat_same_type(arrays) # type: ignore[attr-defined] @implements(np.where) @@ -57,9 +57,9 @@ def __extension_duck_array__where( and isinstance(y, pd.Categorical) and x.dtype != y.dtype ): - x = x.add_categories(set(y.categories).difference(set(x.categories))) - y = y.add_categories(set(x.categories).difference(set(y.categories))) - return pd.Series(x).where(condition, pd.Series(y)).array + x = x.add_categories(set(y.categories).difference(set(x.categories))) # type: ignore[assignment] + y = y.add_categories(set(x.categories).difference(set(y.categories))) # type: ignore[assignment] + return cast(T_ExtensionArray, pd.Series(x).where(condition, pd.Series(y)).array) class PandasExtensionArray(Generic[T_ExtensionArray]): @@ -116,7 +116,7 @@ def __getitem__(self, key) -> PandasExtensionArray[T_ExtensionArray]: if is_extension_array_dtype(item): return type(self)(item) if np.isscalar(item): - return type(self)(type(self.array)([item])) + return type(self)(type(self.array)([item])) # type: ignore[call-arg] # only subclasses with proper __init__ allowed return item def __setitem__(self, key, val): diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index f25c0ecf936..9d8a68edbf3 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -451,7 +451,7 @@ def safe_cast_to_index(array: Any) -> pd.Index: elif isinstance(array, PandasIndexingAdapter): index = array.array else: - kwargs: dict[str, str] = {} + kwargs: dict[str, Any] = {} if hasattr(array, "dtype"): if array.dtype.kind == "O": kwargs["dtype"] = "object" @@ -551,7 +551,7 @@ def as_scalar(value: np.ndarray): return value[()] if value.dtype.kind in "mM" else value.item() -def get_indexer_nd(index, labels, method=None, tolerance=None): +def get_indexer_nd(index: pd.Index, labels, method=None, tolerance=None) -> np.ndarray: """Wrapper around :meth:`pandas.Index.get_indexer` supporting n-dimensional labels """ @@ -740,7 +740,7 @@ def isel( # scalar indexer: drop index return None - return self._replace(self.index[indxr]) + return self._replace(self.index[indxr]) # type: ignore[index] def sel( self, labels: dict[Any, Any], method=None, tolerance=None @@ -898,36 +898,45 @@ def _check_dim_compat(variables: Mapping[Any, Variable], all_dims: str = "equal" ) -def remove_unused_levels_categories(index: pd.Index) -> pd.Index: +T_PDIndex = TypeVar("T_PDIndex", bound=pd.Index) + + +def remove_unused_levels_categories(index: T_PDIndex) -> T_PDIndex: """ Remove unused levels from MultiIndex and unused categories from CategoricalIndex """ if isinstance(index, pd.MultiIndex): - index = index.remove_unused_levels() + new_index = cast(pd.MultiIndex, index.remove_unused_levels()) # if it contains CategoricalIndex, we need to remove unused categories # manually. See https://github.com/pandas-dev/pandas/issues/30846 - if any(isinstance(lev, pd.CategoricalIndex) for lev in index.levels): + if any(isinstance(lev, pd.CategoricalIndex) for lev in new_index.levels): levels = [] - for i, level in enumerate(index.levels): + for i, level in enumerate(new_index.levels): if isinstance(level, pd.CategoricalIndex): - level = level[index.codes[i]].remove_unused_categories() + level = level[new_index.codes[i]].remove_unused_categories() else: - level = level[index.codes[i]] + level = level[new_index.codes[i]] levels.append(level) # TODO: calling from_array() reorders MultiIndex levels. It would # be best to avoid this, if possible, e.g., by using # MultiIndex.remove_unused_levels() (which does not reorder) on the # part of the MultiIndex that is not categorical, or by fixing this # upstream in pandas. - index = pd.MultiIndex.from_arrays(levels, names=index.names) - elif isinstance(index, pd.CategoricalIndex): - index = index.remove_unused_categories() + new_index = pd.MultiIndex.from_arrays(levels, names=new_index.names) + return cast(T_PDIndex, new_index) + + if isinstance(index, pd.CategoricalIndex): + return index.remove_unused_categories() # type: ignore[attr-defined] + return index class PandasMultiIndex(PandasIndex): """Wrap a pandas.MultiIndex as an xarray compatible index.""" + index: pd.MultiIndex + dim: Hashable + coord_dtype: Any level_coords_dtype: dict[str, Any] __slots__ = ("index", "dim", "coord_dtype", "level_coords_dtype") @@ -1063,8 +1072,8 @@ def from_variables_maybe_expand( The index and its corresponding coordinates may be created along a new dimension. """ names: list[Hashable] = [] - codes: list[list[int]] = [] - levels: list[list[int]] = [] + codes: list[Iterable[int]] = [] + levels: list[Iterable[Any]] = [] level_variables: dict[Any, Variable] = {} _check_dim_compat({**current_variables, **variables}) @@ -1134,7 +1143,7 @@ def reorder_levels( its corresponding coordinates. """ - index = self.index.reorder_levels(level_variables.keys()) + index = cast(pd.MultiIndex, self.index.reorder_levels(level_variables.keys())) level_coords_dtype = {k: self.level_coords_dtype[k] for k in index.names} return self._replace(index, level_coords_dtype=level_coords_dtype) @@ -1147,13 +1156,13 @@ def create_variables( variables = {} index_vars: IndexVars = {} - for name in (self.dim,) + self.index.names: + for name in (self.dim,) + tuple(self.index.names): if name == self.dim: level = None dtype = None else: level = name - dtype = self.level_coords_dtype[name] + dtype = self.level_coords_dtype[name] # type: ignore[index] # TODO: are Hashables ok? var = variables.get(name, None) if var is not None: @@ -1163,7 +1172,7 @@ def create_variables( attrs = {} encoding = {} - data = PandasMultiIndexingAdapter(self.index, dtype=dtype, level=level) + data = PandasMultiIndexingAdapter(self.index, dtype=dtype, level=level) # type: ignore[arg-type] # TODO: are Hashables ok? index_vars[name] = IndexVariable( self.dim, data, @@ -1186,6 +1195,8 @@ def sel(self, labels, method=None, tolerance=None) -> IndexSelResult: new_index = None scalar_coord_values = {} + indexer: int | slice | np.ndarray | Variable | DataArray + # label(s) given for multi-index level(s) if all([lbl in self.index.names for lbl in labels]): label_values = {} @@ -1212,7 +1223,7 @@ def sel(self, labels, method=None, tolerance=None) -> IndexSelResult: ) scalar_coord_values.update(label_values) # GH2619. Raise a KeyError if nothing is chosen - if indexer.dtype.kind == "b" and indexer.sum() == 0: + if indexer.dtype.kind == "b" and indexer.sum() == 0: # type: ignore[union-attr] raise KeyError(f"{labels} not found") # assume one label value given for the multi-index "array" (dimension) @@ -1600,9 +1611,7 @@ def group_by_index( """Returns a list of unique indexes and their corresponding coordinates.""" index_coords = [] - - for i in self._id_index: - index = self._id_index[i] + for i, index in self._id_index.items(): coords = {k: self._variables[k] for k in self._id_coord_names[i]} index_coords.append((index, coords)) @@ -1640,26 +1649,28 @@ def copy_indexes( in this dict. """ - new_indexes = {} - new_index_vars = {} + new_indexes: dict[Hashable, T_PandasOrXarrayIndex] = {} + new_index_vars: dict[Hashable, Variable] = {} - idx: T_PandasOrXarrayIndex + xr_idx: Index + new_idx: T_PandasOrXarrayIndex for idx, coords in self.group_by_index(): if isinstance(idx, pd.Index): convert_new_idx = True dim = next(iter(coords.values())).dims[0] if isinstance(idx, pd.MultiIndex): - idx = PandasMultiIndex(idx, dim) + xr_idx = PandasMultiIndex(idx, dim) else: - idx = PandasIndex(idx, dim) + xr_idx = PandasIndex(idx, dim) else: convert_new_idx = False + xr_idx = idx - new_idx = idx._copy(deep=deep, memo=memo) - idx_vars = idx.create_variables(coords) + new_idx = xr_idx._copy(deep=deep, memo=memo) # type: ignore[assignment] + idx_vars = xr_idx.create_variables(coords) if convert_new_idx: - new_idx = cast(PandasIndex, new_idx).index + new_idx = new_idx.index # type: ignore[attr-defined] new_indexes.update({k: new_idx for k in coords}) new_index_vars.update(idx_vars) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 06e7efdbb48..19937270268 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -34,6 +34,7 @@ from numpy.typing import DTypeLike from xarray.core.indexes import Index + from xarray.core.types import Self from xarray.core.variable import Variable from xarray.namedarray._typing import _Shape, duckarray from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint @@ -1656,6 +1657,9 @@ class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin): __slots__ = ("array", "_dtype") + array: pd.Index + _dtype: np.dtype + def __init__(self, array: pd.Index, dtype: DTypeLike = None): from xarray.core.indexes import safe_cast_to_index @@ -1792,7 +1796,7 @@ def transpose(self, order) -> pd.Index: def __repr__(self) -> str: return f"{type(self).__name__}(array={self.array!r}, dtype={self.dtype!r})" - def copy(self, deep: bool = True) -> PandasIndexingAdapter: + def copy(self, deep: bool = True) -> Self: # Not the same as just writing `self.array.copy(deep=deep)`, as # shallow copies of the underlying numpy.ndarrays become deep ones # upon pickling @@ -1810,11 +1814,14 @@ class PandasMultiIndexingAdapter(PandasIndexingAdapter): This allows creating one instance for each multi-index level while preserving indexing efficiency (memoized + might reuse another instance with the same multi-index). - """ __slots__ = ("array", "_dtype", "level", "adapter") + array: pd.MultiIndex + _dtype: np.dtype + level: str | None + def __init__( self, array: pd.MultiIndex, @@ -1910,7 +1917,7 @@ def _repr_html_(self) -> str: array_repr = short_array_repr(self._get_array_subset()) return f"
{escape(array_repr)}
" - def copy(self, deep: bool = True) -> PandasMultiIndexingAdapter: + def copy(self, deep: bool = True) -> Self: # see PandasIndexingAdapter.copy array = self.array.copy(deep=True) if deep else self.array return type(self)(array, self._dtype, self.level) diff --git a/xarray/core/types.py b/xarray/core/types.py index 588d15dc35d..63e3522fb63 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -45,7 +45,7 @@ try: from dask.array import Array as DaskArray except ImportError: - DaskArray = np.ndarray # type: ignore + DaskArray = np.ndarray try: from cubed import Array as CubedArray diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 26de63f6e32..c2859632360 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -57,19 +57,12 @@ Mapping, MutableMapping, MutableSet, + Sequence, ValuesView, ) from enum import Enum from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Generic, - Literal, - TypeVar, - overload, -) +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, overload import numpy as np import pandas as pd @@ -137,7 +130,7 @@ def get_valid_numpy_dtype(array: np.ndarray | pd.Index) -> np.dtype: if not is_valid_numpy_dtype(array.dtype): return np.dtype("O") - return array.dtype + return array.dtype # type: ignore[return-value] def maybe_coerce_to_str(index, original_coords): @@ -184,18 +177,17 @@ def equivalent(first: T, second: T) -> bool: if isinstance(first, np.ndarray) or isinstance(second, np.ndarray): return duck_array_ops.array_equiv(first, second) if isinstance(first, list) or isinstance(second, list): - return list_equiv(first, second) - return (first == second) or (pd.isnull(first) and pd.isnull(second)) + return list_equiv(first, second) # type: ignore[arg-type] + return (first == second) or (pd.isnull(first) and pd.isnull(second)) # type: ignore[call-overload] -def list_equiv(first, second): - equiv = True +def list_equiv(first: Sequence[T], second: Sequence[T]) -> bool: if len(first) != len(second): return False - else: - for f, s in zip(first, second): - equiv = equiv and equivalent(f, s) - return equiv + for f, s in zip(first, second): + if not equivalent(f, s): + return False + return True def peek_at(iterable: Iterable[T]) -> tuple[T, Iterator[T]]: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index f0685882595..efc67dab2b8 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -298,7 +298,7 @@ def as_compatible_data( # we don't want nested self-described arrays if isinstance(data, (pd.Series, pd.DataFrame)): - data = data.values + data = data.values # type: ignore[assignment] if isinstance(data, np.ma.MaskedArray): mask = np.ma.getmaskarray(data) diff --git a/xarray/namedarray/daskmanager.py b/xarray/namedarray/daskmanager.py index 14744d2de6b..963d12fd865 100644 --- a/xarray/namedarray/daskmanager.py +++ b/xarray/namedarray/daskmanager.py @@ -21,13 +21,13 @@ try: from dask.array import Array as DaskArray except ImportError: - DaskArray = np.ndarray[Any, Any] # type: ignore[assignment, misc] + DaskArray = np.ndarray[Any, Any] dask_available = module_available("dask") -class DaskManager(ChunkManagerEntrypoint["DaskArray"]): # type: ignore[type-var] +class DaskManager(ChunkManagerEntrypoint["DaskArray"]): array_cls: type[DaskArray] available: bool = dask_available @@ -91,7 +91,7 @@ def array_api(self) -> Any: return da - def reduction( # type: ignore[override] + def reduction( self, arr: T_ChunkedArray, func: Callable[..., Any], @@ -113,7 +113,7 @@ def reduction( # type: ignore[override] keepdims=keepdims, ) # type: ignore[no-untyped-call] - def scan( # type: ignore[override] + def scan( self, func: Callable[..., Any], binop: Callable[..., Any], From 6a75746ec0684ef669f4fdcfc8aef2b5db81e14e Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Wed, 3 Jul 2024 22:27:19 +0200 Subject: [PATCH 07/23] more typing... --- xarray/coding/cftime_offsets.py | 22 ++++++------ xarray/coding/cftimeindex.py | 61 ++++++++++++++++++++------------- xarray/coding/times.py | 34 ++++++++++-------- xarray/core/concat.py | 42 +++++++++++++---------- xarray/core/types.py | 1 + xarray/core/variable.py | 2 +- 6 files changed, 94 insertions(+), 68 deletions(-) diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index c2712569782..75493ecdce8 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -759,16 +759,16 @@ def _emit_freq_deprecation_warning(deprecated_freq): emit_user_level_warning(message, FutureWarning) -def to_offset(freq, warn=True): +def to_offset(freq: BaseCFTimeOffset | str, warn: bool = True) -> BaseCFTimeOffset: """Convert a frequency string to the appropriate subclass of BaseCFTimeOffset.""" if isinstance(freq, BaseCFTimeOffset): return freq - else: - try: - freq_data = re.match(_PATTERN, freq).groupdict() - except AttributeError: - raise ValueError("Invalid frequency string provided") + + match = re.match(_PATTERN, freq) + if match is None: + raise ValueError("Invalid frequency string provided") + freq_data = match.groupdict() freq = freq_data["freq"] if warn and freq in _DEPRECATED_FREQUENICES: @@ -909,7 +909,9 @@ def _translate_closed_to_inclusive(closed): return inclusive -def _infer_inclusive(closed, inclusive): +def _infer_inclusive( + closed: NoDefault | SideOptions, inclusive: InclusiveOptions | None +) -> InclusiveOptions: """Follows code added in pandas #43504.""" if closed is not no_default and inclusive is not None: raise ValueError( @@ -917,9 +919,9 @@ def _infer_inclusive(closed, inclusive): "passed if argument `inclusive` is not None." ) if closed is not no_default: - inclusive = _translate_closed_to_inclusive(closed) - elif inclusive is None: - inclusive = "both" + return _translate_closed_to_inclusive(closed) + if inclusive is None: + return "both" return inclusive diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index 6898809e3b0..0747074ac46 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -45,6 +45,7 @@ import re import warnings from datetime import timedelta +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -64,6 +65,10 @@ except ImportError: cftime = None +if TYPE_CHECKING: + from xarray.coding.cftime_offsets import BaseCFTimeOffset + from xarray.core.types import Self + # constants for cftimeindex.repr CFTIME_REPR_LENGTH = 19 @@ -495,7 +500,7 @@ def get_value(self, series, key): else: return series.iloc[self.get_loc(key)] - def __contains__(self, key): + def __contains__(self, key: Any) -> bool: """Adapted from pandas.tseries.base.DatetimeIndexOpsMixin.__contains__""" try: @@ -503,16 +508,20 @@ def __contains__(self, key): return ( is_scalar(result) or type(result) == slice - or (isinstance(result, np.ndarray) and result.size) + or (isinstance(result, np.ndarray) and result.size > 0) ) except (KeyError, TypeError, ValueError): return False - def contains(self, key): + def contains(self, key: Any) -> bool: """Needed for .loc based partial-string indexing""" return self.__contains__(key) - def shift(self, n: int | float, freq: str | timedelta): + def shift( # type: ignore[override] # freq is typed Any, we are more precise + self, + periods: int | float, + freq: str | timedelta | BaseCFTimeOffset | None = None, + ) -> Self: """Shift the CFTimeIndex a multiple of the given frequency. See the documentation for :py:func:`~xarray.cftime_range` for a @@ -520,9 +529,9 @@ def shift(self, n: int | float, freq: str | timedelta): Parameters ---------- - n : int, float if freq of days or below + periods : int, float if freq of days or below Periods to shift by - freq : str or datetime.timedelta + freq : str, datetime.timedelta or BaseCFTimeOffset A frequency string or datetime.timedelta object to shift by Returns @@ -546,33 +555,40 @@ def shift(self, n: int | float, freq: str | timedelta): CFTimeIndex([2000-02-01 12:00:00], dtype='object', length=1, calendar='standard', freq=None) """ + if freq is None: + # None type is required to be compatible with base pd.Index class + raise TypeError( + f"`freq` argument cannot be None for {type(self).__name__}.shift" + ) + if isinstance(freq, timedelta): - return self + n * freq - elif isinstance(freq, str): + return self + periods * freq + + if isinstance(freq, (str, BaseCFTimeOffset)): from xarray.coding.cftime_offsets import to_offset - return self + n * to_offset(freq) - else: - raise TypeError( - f"'freq' must be of type str or datetime.timedelta, got {freq}." - ) + return self + periods * to_offset(freq) + + raise TypeError( + f"'freq' must be of type str or datetime.timedelta, got {freq}." + ) - def __add__(self, other): + def __add__(self, other) -> Self: if isinstance(other, pd.TimedeltaIndex): other = other.to_pytimedelta() - return CFTimeIndex(np.array(self) + other) + return type(self)(np.array(self) + other) - def __radd__(self, other): + def __radd__(self, other) -> Self: if isinstance(other, pd.TimedeltaIndex): other = other.to_pytimedelta() - return CFTimeIndex(other + np.array(self)) + return type(self)(other + np.array(self)) def __sub__(self, other): if _contains_datetime_timedeltas(other): - return CFTimeIndex(np.array(self) - other) - elif isinstance(other, pd.TimedeltaIndex): - return CFTimeIndex(np.array(self) - other.to_pytimedelta()) - elif _contains_cftime_datetimes(np.array(other)): + return type(self)(np.array(self) - other) + if isinstance(other, pd.TimedeltaIndex): + return type(self)(np.array(self) - other.to_pytimedelta()) + if _contains_cftime_datetimes(np.array(other)): try: return pd.TimedeltaIndex(np.array(self) - np.array(other)) except OUT_OF_BOUNDS_TIMEDELTA_ERRORS: @@ -580,8 +596,7 @@ def __sub__(self, other): "The time difference exceeds the range of values " "that can be expressed at the nanosecond resolution." ) - else: - return NotImplemented + return NotImplemented def __rsub__(self, other): try: diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 34d4f9a23ad..0804956fb57 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -5,7 +5,7 @@ from collections.abc import Hashable from datetime import datetime, timedelta from functools import partial -from typing import TYPE_CHECKING, Callable, Union +from typing import TYPE_CHECKING, Callable, Literal, Union, cast import numpy as np import pandas as pd @@ -37,7 +37,7 @@ cftime = None if TYPE_CHECKING: - from xarray.core.types import CFCalendar, T_DuckArray + from xarray.core.types import CFCalendar, NPDatetimeUnitOptions, T_DuckArray T_Name = Union[Hashable, None] @@ -111,22 +111,25 @@ def _is_numpy_compatible_time_range(times): return True -def _netcdf_to_numpy_timeunit(units: str) -> str: +def _netcdf_to_numpy_timeunit(units: str) -> NPDatetimeUnitOptions: units = units.lower() if not units.endswith("s"): units = f"{units}s" - return { - "nanoseconds": "ns", - "microseconds": "us", - "milliseconds": "ms", - "seconds": "s", - "minutes": "m", - "hours": "h", - "days": "D", - }[units] + return cast( + NPDatetimeUnitOptions, + { + "nanoseconds": "ns", + "microseconds": "us", + "milliseconds": "ms", + "seconds": "s", + "minutes": "m", + "hours": "h", + "days": "D", + }[units], + ) -def _numpy_to_netcdf_timeunit(units: str) -> str: +def _numpy_to_netcdf_timeunit(units: NPDatetimeUnitOptions) -> str: return { "ns": "nanoseconds", "us": "microseconds", @@ -252,12 +255,12 @@ def _decode_datetime_with_pandas( "pandas." ) - time_units, ref_date = _unpack_netcdf_time_units(units) + time_units, ref_date_str = _unpack_netcdf_time_units(units) time_units = _netcdf_to_numpy_timeunit(time_units) try: # TODO: the strict enforcement of nanosecond precision Timestamps can be # relaxed when addressing GitHub issue #7493. - ref_date = nanosecond_precision_timestamp(ref_date) + ref_date = nanosecond_precision_timestamp(ref_date_str) except ValueError: # ValueError is raised by pd.Timestamp for non-ISO timestamp # strings, in which case we fall back to using cftime @@ -471,6 +474,7 @@ def cftime_to_nptime(times, raise_on_invalid: bool = True) -> np.ndarray: # TODO: the strict enforcement of nanosecond precision datetime values can # be relaxed when addressing GitHub issue #7493. new = np.empty(times.shape, dtype="M8[ns]") + dt: pd.Timestamp | Literal["NaT"] for i, t in np.ndenumerate(times): try: # Use pandas.Timestamp in place of datetime.datetime, because diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 5e6af84c3b6..687278083a8 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -304,7 +304,7 @@ def _calc_concat_dim_index( dim: Hashable | None - if isinstance(dim_or_data, str): + if utils.hashable(dim_or_data): dim = dim_or_data index = None else: @@ -475,7 +475,7 @@ def _parse_datasets( def _dataset_concat( - datasets: list[T_Dataset], + datasets: Iterable[T_Dataset], dim: str | T_Variable | T_DataArray | pd.Index, data_vars: T_DataVars, coords: str | list[str], @@ -506,12 +506,14 @@ def _dataset_concat( else: dim_var = None - dim, index = _calc_concat_dim_index(dim) + dim_name, index = _calc_concat_dim_index(dim) # Make sure we're working on a copy (we'll be loading variables) datasets = [ds.copy() for ds in datasets] datasets = list( - align(*datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value) + align( + *datasets, join=join, copy=False, exclude=[dim_name], fill_value=fill_value + ) ) dim_coords, dims_sizes, coord_names, data_names, vars_order = _parse_datasets( @@ -525,19 +527,21 @@ def _dataset_concat( f"{both_data_and_coords!r} is a coordinate in some datasets but not others." ) # we don't want the concat dimension in the result dataset yet - dim_coords.pop(dim, None) - dims_sizes.pop(dim, None) + dim_coords.pop(dim_name, None) + dims_sizes.pop(dim_name, None) # case where concat dimension is a coordinate or data_var but not a dimension - if (dim in coord_names or dim in data_names) and dim not in dim_names: + if ( + dim_name in coord_names or dim_name in data_names + ) and dim_name not in dim_names: datasets = [ - ds.expand_dims(dim, create_index_for_new_dim=create_index_for_new_dim) + ds.expand_dims(dim_name, create_index_for_new_dim=create_index_for_new_dim) for ds in datasets ] # determine which variables to concatenate concat_over, equals, concat_dim_lengths = _calc_concat_over( - datasets, dim, dim_names, data_vars, coords, compat + datasets, dim_name, dim_names, data_vars, coords, compat ) # determine which variables to merge, and then merge them according to compat @@ -577,8 +581,8 @@ def ensure_common_dims(vars, concat_dim_lengths): # dimensions and the same shape for all of them except along the # concat dimension common_dims = tuple(utils.OrderedSet(d for v in vars for d in v.dims)) - if dim not in common_dims: - common_dims = (dim,) + common_dims + if dim_name not in common_dims: + common_dims = (dim_name,) + common_dims for var, dim_len in zip(vars, concat_dim_lengths): if var.dims != common_dims: common_shape = tuple(dims_sizes.get(d, dim_len) for d in common_dims) @@ -597,9 +601,9 @@ def get_indexes(name): elif name == dim: var = ds._variables[name] if not var.dims: - data = var.set_dims(dim).values + data = var.set_dims(dim_name).values if create_index_for_new_dim: - yield PandasIndex(data, dim, coord_dtype=var.dtype) + yield PandasIndex(data, dim_name, coord_dtype=var.dtype) # create concatenation index, needed for later reindexing file_start_indexes = np.append(0, np.cumsum(concat_dim_lengths)) @@ -645,7 +649,7 @@ def get_indexes(name): f"{name!r} must have either an index or no index in all datasets, " f"found {len(indexes)}/{len(datasets)} datasets with an index." ) - combined_idx = indexes[0].concat(indexes, dim, positions) + combined_idx = indexes[0].concat(indexes, dim_name, positions) if name in datasets[0]._indexes: idx_vars = datasets[0].xindexes.get_all_coords(name) else: @@ -661,14 +665,14 @@ def get_indexes(name): result_vars[k] = v else: combined_var = concat_vars( - vars, dim, positions, combine_attrs=combine_attrs + vars, dim_name, positions, combine_attrs=combine_attrs ) # reindex if variable is not present in all datasets if len(variable_index) < concat_index_size: combined_var = reindex_variables( variables={name: combined_var}, dim_pos_indexers={ - dim: pd.Index(variable_index).get_indexer(concat_index) + dim_name: pd.Index(variable_index).get_indexer(concat_index) }, fill_value=fill_value, )[name] @@ -694,12 +698,12 @@ def get_indexes(name): if index is not None: if dim_var is not None: - index_vars = index.create_variables({dim: dim_var}) + index_vars = index.create_variables({dim_name: dim_var}) else: index_vars = index.create_variables() - coord_vars[dim] = index_vars[dim] - result_indexes[dim] = index + coord_vars[dim_name] = index_vars[dim_name] + result_indexes[dim_name] = index coords_obj = Coordinates(coord_vars, indexes=result_indexes) diff --git a/xarray/core/types.py b/xarray/core/types.py index 63e3522fb63..b1d0424ac11 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -219,6 +219,7 @@ def copy( DatetimeUnitOptions = Literal[ "Y", "M", "W", "D", "h", "m", "s", "ms", "us", "μs", "ns", "ps", "fs", "as", None ] +NPDatetimeUnitOptions = Literal["D", "h", "m", "s", "ms", "us", "ns"] QueryEngineOptions = Literal["python", "numexpr", None] QueryParserOptions = Literal["pandas", "python"] diff --git a/xarray/core/variable.py b/xarray/core/variable.py index efc67dab2b8..377dafa6f79 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1504,7 +1504,7 @@ def _unstack_once( # Potentially we could replace `len(other_dims)` with just `-1` other_dims = [d for d in self.dims if d != dim] new_shape = tuple(list(reordered.shape[: len(other_dims)]) + new_dim_sizes) - new_dims = reordered.dims[: len(other_dims)] + new_dim_names + new_dims = reordered.dims[: len(other_dims)] + tuple(new_dim_names) create_template: Callable if fill_value is dtypes.NA: From ce657e4056963ea0d476ebc8abc926f0eec2f150 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Wed, 3 Jul 2024 22:54:45 +0200 Subject: [PATCH 08/23] we are getting there? --- xarray/coding/cftime_offsets.py | 2 +- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 21 ++++++++++++++------- xarray/core/resample_cftime.py | 31 ++++++++++++++++++------------- 4 files changed, 34 insertions(+), 22 deletions(-) diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index 75493ecdce8..064a77831fb 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -935,7 +935,7 @@ def cftime_range( closed: NoDefault | SideOptions = no_default, inclusive: None | InclusiveOptions = None, calendar="standard", -): +) -> CFTimeIndex: """Return a fixed frequency CFTimeIndex. Parameters diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index c29d8b7b5a4..2df11c9a3de 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -977,7 +977,7 @@ def indexes(self) -> Indexes: return self.xindexes.to_pandas_indexes() @property - def xindexes(self) -> Indexes: + def xindexes(self) -> Indexes[Index]: """Mapping of :py:class:`~xarray.indexes.Index` objects used for label based indexing. """ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 50cfc7b0c29..62864807519 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -4156,15 +4156,15 @@ def interp_like( kwargs = {} # pick only dimension coordinates with a single index - coords = {} + coords: dict[Hashable, Variable] = {} other_indexes = other.xindexes for dim in self.dims: other_dim_coords = other_indexes.get_all_coords(dim, errors="ignore") if len(other_dim_coords) == 1: coords[dim] = other_dim_coords[dim] - numeric_coords: dict[Hashable, pd.Index] = {} - object_coords: dict[Hashable, pd.Index] = {} + numeric_coords: dict[Hashable, Variable] = {} + object_coords: dict[Hashable, Variable] = {} for k, v in coords.items(): if v.dtype.kind in "uifcMm": numeric_coords[k] = v @@ -6539,7 +6539,13 @@ def interpolate_na( limit: int | None = None, use_coordinate: bool | Hashable = True, max_gap: ( - int | float | str | pd.Timedelta | np.timedelta64 | datetime.timedelta + int + | float + | str + | pd.Timedelta + | np.timedelta64 + | datetime.timedelta + | None ) = None, **kwargs: Any, ) -> Self: @@ -6573,7 +6579,8 @@ def interpolate_na( or None for no limit. This filling is done regardless of the size of the gap in the data. To only interpolate over gaps less than a given length, see ``max_gap``. - max_gap : int, float, str, pandas.Timedelta, numpy.timedelta64, datetime.timedelta, default: None + max_gap : int, float, str, pandas.Timedelta, numpy.timedelta64, datetime.timedelta \ + or None, default: None Maximum size of gap, a continuous sequence of NaNs, that will be filled. Use None for no limit. When interpolating along a datetime64 dimension and ``use_coordinate=True``, ``max_gap`` can be one of the following: @@ -9715,7 +9722,7 @@ def eval( c (x) float64 40B 0.0 1.25 2.5 3.75 5.0 """ - return pd.eval( + return pd.eval( # type: ignore[return-value] statement, resolvers=[self], target=self, @@ -10394,7 +10401,7 @@ def groupby_bins( _validate_groupby_squeeze(squeeze) grouper = BinGrouper( - bins=bins, + bins=bins, # type: ignore[arg-type] #TODO: fix this here or in BinGrouper? cut_kwargs={ "right": right, "labels": labels, diff --git a/xarray/core/resample_cftime.py b/xarray/core/resample_cftime.py index 216bd8fca6b..fcdeb6d8021 100644 --- a/xarray/core/resample_cftime.py +++ b/xarray/core/resample_cftime.py @@ -66,6 +66,13 @@ class CFTimeGrouper: single method, the only one required for resampling in xarray. It cannot be used in a call to groupby like a pandas.Grouper object can.""" + freq: BaseCFTimeOffset + closed: SideOptions + label: SideOptions + loffset: str | datetime.timedelta | BaseCFTimeOffset | None + origin: str | CFTimeDatetime + offset: datetime.timedelta | None + def __init__( self, freq: str | BaseCFTimeOffset, @@ -73,11 +80,8 @@ def __init__( label: SideOptions | None = None, loffset: str | datetime.timedelta | BaseCFTimeOffset | None = None, origin: str | CFTimeDatetime = "start_day", - offset: str | datetime.timedelta | None = None, + offset: str | datetime.timedelta | BaseCFTimeOffset | None = None, ): - self.offset: datetime.timedelta | None - self.closed: SideOptions - self.label: SideOptions self.freq = to_offset(freq) self.loffset = loffset self.origin = origin @@ -120,10 +124,10 @@ def __init__( if offset is not None: try: self.offset = _convert_offset_to_timedelta(offset) - except (ValueError, AttributeError) as error: + except (AttributeError, TypeError) as error: raise ValueError( f"offset must be a datetime.timedelta object or an offset string " - f"that can be converted to a timedelta. Got {offset} instead." + f"that can be converted to a timedelta. Got {type(offset)} instead." ) from error else: self.offset = None @@ -250,12 +254,12 @@ def _get_time_bins( def _adjust_bin_edges( - datetime_bins: np.ndarray, + datetime_bins: CFTimeIndex, freq: BaseCFTimeOffset, closed: SideOptions, index: CFTimeIndex, - labels: np.ndarray, -): + labels: CFTimeIndex, +) -> tuple[CFTimeIndex, CFTimeIndex]: """This is required for determining the bin edges resampling with month end, quarter end, and year end frequencies. @@ -499,10 +503,11 @@ def _convert_offset_to_timedelta( ) -> datetime.timedelta: if isinstance(offset, datetime.timedelta): return offset - elif isinstance(offset, (str, Tick)): - return to_offset(offset).as_timedelta() - else: - raise ValueError + if isinstance(offset, (str, Tick)): + timedelta_cftime_offset = to_offset(offset) + assert isinstance(timedelta_cftime_offset, Tick) + return timedelta_cftime_offset.as_timedelta() + raise TypeError(f"Expected timedelta, str or Tick, got {type(offset)}") def _ceil_via_cftimeindex(date: CFTimeDatetime, freq: str | BaseCFTimeOffset): From 55891562af257c1417f24e2e45dae4d1ae532495 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Thu, 4 Jul 2024 23:27:12 +0200 Subject: [PATCH 09/23] who might have guessed it... more typing --- xarray/core/groupers.py | 46 +++++++++++++---------- xarray/core/types.py | 7 ++-- xarray/tests/test_cftimeindex.py | 6 +-- xarray/tests/test_cftimeindex_resample.py | 5 ++- xarray/tests/test_coding_times.py | 7 ++-- xarray/tests/test_dataarray.py | 2 +- xarray/tests/test_dataset.py | 41 +++++++++++--------- 7 files changed, 65 insertions(+), 49 deletions(-) diff --git a/xarray/core/groupers.py b/xarray/core/groupers.py index 075afd9f62f..f76bd22a2f6 100644 --- a/xarray/core/groupers.py +++ b/xarray/core/groupers.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping, Sequence from dataclasses import dataclass, field -from typing import Any +from typing import TYPE_CHECKING, Any, cast import numpy as np import pandas as pd @@ -25,6 +25,9 @@ from xarray.core.utils import emit_user_level_warning from xarray.core.variable import Variable +if TYPE_CHECKING: + pass + __all__ = [ "EncodedGroups", "Grouper", @@ -160,6 +163,7 @@ def _factorize_dummy(self) -> EncodedGroups: # equivalent to: group_indices = group_indices.reshape(-1, 1) group_indices: T_GroupIndices = [slice(i, i + 1) for i in range(size)] size_range = np.arange(size) + full_index: pd.Index if isinstance(self.group, _DummyGroup): codes = self.group.to_dataarray().copy(data=size_range) unique_coord = self.group @@ -275,7 +279,7 @@ def _init_properties(self, group: T_Group) -> None: if isinstance(group_as_index, CFTimeIndex): from xarray.core.resample_cftime import CFTimeGrouper - index_grouper = CFTimeGrouper( + self.index_grouper = CFTimeGrouper( freq=self.freq, closed=self.closed, label=self.label, @@ -284,7 +288,7 @@ def _init_properties(self, group: T_Group) -> None: loffset=self.loffset, ) else: - index_grouper = pd.Grouper( + self.index_grouper = pd.Grouper( # TODO remove once requiring pandas >= 2.2 freq=_new_to_legacy_freq(self.freq), closed=self.closed, @@ -292,7 +296,6 @@ def _init_properties(self, group: T_Group) -> None: origin=self.origin, offset=offset, ) - self.index_grouper = index_grouper self.group_as_index = group_as_index def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]: @@ -305,22 +308,25 @@ def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]: return full_index, first_items, codes def first_items(self) -> tuple[pd.Series, np.ndarray]: - from xarray import CFTimeIndex + from xarray.coding.cftimeindex import CFTimeIndex + from xarray.core.resample_cftime import CFTimeGrouper - if isinstance(self.group_as_index, CFTimeIndex): - return self.index_grouper.first_items(self.group_as_index) - else: - s = pd.Series(np.arange(self.group_as_index.size), self.group_as_index) - grouped = s.groupby(self.index_grouper) - first_items = grouped.first() - counts = grouped.count() - # This way we generate codes for the final output index: full_index. - # So for _flox_reduce we avoid one reindex and copy by avoiding - # _maybe_restore_empty_groups - codes = np.repeat(np.arange(len(first_items)), counts) - if self.loffset is not None: - _apply_loffset(self.loffset, first_items) - return first_items, codes + if isinstance(self.index_grouper, CFTimeGrouper): + return self.index_grouper.first_items( + cast(CFTimeIndex, self.group_as_index) + ) + + s = pd.Series(np.arange(self.group_as_index.size), self.group_as_index) + grouped = s.groupby(self.index_grouper) + first_items = grouped.first() + counts = grouped.count() + # This way we generate codes for the final output index: full_index. + # So for _flox_reduce we avoid one reindex and copy by avoiding + # _maybe_restore_empty_groups + codes = np.repeat(np.arange(len(first_items)), counts) + if self.loffset is not None: + _apply_loffset(self.loffset, first_items) + return first_items, codes def factorize(self, group) -> EncodedGroups: self._init_properties(group) @@ -369,7 +375,7 @@ def _apply_loffset( ) if isinstance(loffset, str): - loffset = pd.tseries.frequencies.to_offset(loffset) + loffset = pd.tseries.frequencies.to_offset(loffset) # type: ignore[assignment] needs_offset = ( isinstance(loffset, (pd.DateOffset, datetime.timedelta)) diff --git a/xarray/core/types.py b/xarray/core/types.py index b1d0424ac11..20d94de5116 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -28,10 +28,11 @@ else: Self: Any = None -if TYPE_CHECKING: - from numpy._typing import _SupportsDType - from numpy.typing import ArrayLike +from numpy._typing import _SupportsDType +from numpy.typing import ArrayLike + +if TYPE_CHECKING: from xarray.backends.common import BackendEntrypoint from xarray.core.alignment import Aligner from xarray.core.common import AbstractArray, DataWithCoords diff --git a/xarray/tests/test_cftimeindex.py b/xarray/tests/test_cftimeindex.py index 2c74581fcac..116487e2bcf 100644 --- a/xarray/tests/test_cftimeindex.py +++ b/xarray/tests/test_cftimeindex.py @@ -957,17 +957,17 @@ def test_cftimeindex_shift(index, freq) -> None: @requires_cftime -def test_cftimeindex_shift_invalid_n() -> None: +def test_cftimeindex_shift_invalid_periods() -> None: index = xr.cftime_range("2000", periods=3) with pytest.raises(TypeError): - index.shift("a", "D") + index.shift("a", "D") # type: ignore[arg-type] @requires_cftime def test_cftimeindex_shift_invalid_freq() -> None: index = xr.cftime_range("2000", periods=3) with pytest.raises(TypeError): - index.shift(1, 1) + index.shift(1, 1) # type: ignore[arg-type] @requires_cftime diff --git a/xarray/tests/test_cftimeindex_resample.py b/xarray/tests/test_cftimeindex_resample.py index 98d4377706c..361807378d1 100644 --- a/xarray/tests/test_cftimeindex_resample.py +++ b/xarray/tests/test_cftimeindex_resample.py @@ -10,6 +10,7 @@ import xarray as xr from xarray.coding.cftime_offsets import _new_to_legacy_freq +from xarray.coding.cftimeindex import CFTimeIndex from xarray.core.pdcompat import _convert_base_to_offset from xarray.core.resample_cftime import CFTimeGrouper @@ -204,7 +205,9 @@ def test_calendars(calendar: str) -> None: .mean() ) # TODO (benbovy - flexible indexes): update when CFTimeIndex is a xarray Index subclass - da_cftime["time"] = da_cftime.xindexes["time"].to_pandas_index().to_datetimeindex() + new_pd_index = da_cftime.xindexes["time"].to_pandas_index() + assert isinstance(new_pd_index, CFTimeIndex) # shouldn't that be a pd.Index? + da_cftime["time"] = new_pd_index.to_datetimeindex() xr.testing.assert_identical(da_cftime, da_datetime) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 87a6dc5c0a4..37e63f8d4c2 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -3,6 +3,7 @@ import warnings from datetime import timedelta from itertools import product +from typing import Literal import numpy as np import pandas as pd @@ -1236,7 +1237,7 @@ def test_contains_cftime_lazy() -> None: ) def test_roundtrip_datetime64_nanosecond_precision( timestr: str, - timeunit: str, + timeunit: Literal["ns", "us"], dtype: np.typing.DTypeLike, fill_value: int | float | None, use_encoding: bool, @@ -1483,8 +1484,8 @@ def test_encode_cf_datetime_cftime_datetime_via_dask(units, dtype) -> None: import dask.array calendar = "standard" - times = cftime_range(start="1700", freq="D", periods=3, calendar=calendar) - times = dask.array.from_array(times, chunks=1) + times_idx = cftime_range(start="1700", freq="D", periods=3, calendar=calendar) + times = dask.array.from_array(times_idx, chunks=1) encoded_times, encoding_units, encoding_calendar = encode_cf_datetime( times, units, None, dtype ) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index eae87e03e41..1401be7b9b5 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -427,7 +427,7 @@ def test_constructor_invalid(self) -> None: DataArray(np.random.rand(4, 4), [("x", self.mindex), ("level_1", range(4))]) def test_constructor_from_self_described(self) -> None: - data = [[-0.1, 21], [0, 2]] + data: list[list[float]] = [[-0.1, 21], [0, 2]] expected = DataArray( data, coords={"x": ["a", "b"], "y": [-1, -2]}, diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index f6829861776..2e21fbb1dae 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -39,6 +39,7 @@ from xarray.core.common import duck_array_ops, full_like from xarray.core.coordinates import Coordinates, DatasetCoordinates from xarray.core.indexes import Index, PandasIndex +from xarray.core.types import ArrayLike from xarray.core.utils import is_scalar from xarray.namedarray.pycompat import array_type, integer_types from xarray.testing import _assert_internal_invariants @@ -2044,9 +2045,9 @@ def test_to_pandas(self) -> None: y = np.random.randn(10) t = list("abcdefghij") ds = Dataset({"a": ("t", x), "b": ("t", y), "t": ("t", t)}) - actual = ds.to_pandas() - expected = ds.to_dataframe() - assert expected.equals(actual), (expected, actual) + actual_df = ds.to_pandas() + expected_df = ds.to_dataframe() + assert expected_df.equals(actual_df), (expected_df, actual_df) # 2D -> error x2d = np.random.randn(10, 10) @@ -3618,6 +3619,7 @@ def test_reset_index_drop_convert( def test_reorder_levels(self) -> None: ds = create_test_multiindex() mindex = ds["x"].to_index() + assert isinstance(mindex, pd.MultiIndex) midx = mindex.reorder_levels(["level_2", "level_1"]) midx_coords = Coordinates.from_pandas_multiindex(midx, "x") expected = Dataset({}, coords=midx_coords) @@ -3943,7 +3945,9 @@ def test_to_stacked_array_dtype_dims(self) -> None: D = xr.Dataset({"a": a, "b": b}) sample_dims = ["x"] y = D.to_stacked_array("features", sample_dims) - assert y.xindexes["features"].to_pandas_index().levels[1].dtype == D.y.dtype + mindex = y.xindexes["features"].to_pandas_index() + assert isinstance(mindex, pd.MultiIndex) + assert mindex.levels[1].dtype == D.y.dtype assert y.dims == ("x", "features") def test_to_stacked_array_to_unstacked_dataset(self) -> None: @@ -4114,9 +4118,9 @@ def test_virtual_variables_default_coords(self) -> None: def test_virtual_variables_time(self) -> None: # access virtual variables data = create_test_data() - assert_array_equal( - data["time.month"].values, data.variables["time"].to_index().month - ) + index = data.variables["time"].to_index() + assert isinstance(index, pd.DatetimeIndex) + assert_array_equal(data["time.month"].values, index.month) assert_array_equal(data["time.season"].values, "DJF") # test virtual variable math assert_array_equal(data["time.dayofyear"] + 1, 2 + np.arange(20)) @@ -4757,20 +4761,20 @@ def test_to_and_from_dataframe(self) -> None: # check pathological cases df = pd.DataFrame([1]) - actual = Dataset.from_dataframe(df) - expected = Dataset({0: ("index", [1])}, {"index": [0]}) - assert_identical(expected, actual) + actual_ds = Dataset.from_dataframe(df) + expected_ds = Dataset({0: ("index", [1])}, {"index": [0]}) + assert_identical(expected_ds, actual_ds) df = pd.DataFrame() - actual = Dataset.from_dataframe(df) - expected = Dataset(coords={"index": []}) - assert_identical(expected, actual) + actual_ds = Dataset.from_dataframe(df) + expected_ds = Dataset(coords={"index": []}) + assert_identical(expected_ds, actual_ds) # GH697 df = pd.DataFrame({"A": []}) - actual = Dataset.from_dataframe(df) - expected = Dataset({"A": DataArray([], dims=("index",))}, {"index": []}) - assert_identical(expected, actual) + actual_ds = Dataset.from_dataframe(df) + expected_ds = Dataset({"A": DataArray([], dims=("index",))}, {"index": []}) + assert_identical(expected_ds, actual_ds) # regression test for GH278 # use int64 to ensure consistent results for the pandas .equals method @@ -4809,7 +4813,7 @@ def test_from_dataframe_categorical_index(self) -> None: def test_from_dataframe_categorical_index_string_categories(self) -> None: cat = pd.CategoricalIndex( pd.Categorical.from_codes( - np.array([1, 1, 0, 2]), + np.array([1, 1, 0, 2], dtype=np.int64), # type: ignore[arg-type] categories=pd.Index(["foo", "bar", "baz"], dtype="string"), ) ) @@ -4894,7 +4898,7 @@ def test_from_dataframe_unsorted_levels(self) -> None: def test_from_dataframe_non_unique_columns(self) -> None: # regression test for GH449 df = pd.DataFrame(np.zeros((2, 2))) - df.columns = ["foo", "foo"] + df.columns = ["foo", "foo"] # type: ignore[assignment] with pytest.raises(ValueError, match=r"non-unique columns"): Dataset.from_dataframe(df) @@ -7183,6 +7187,7 @@ def test_cumulative_integrate(dask) -> None: @pytest.mark.parametrize("which_datetime", ["np", "cftime"]) def test_trapezoid_datetime(dask, which_datetime) -> None: rs = np.random.RandomState(42) + coord: ArrayLike if which_datetime == "np": coord = np.array( [ From e45d28b6708ddeb5dcf70bfd65a16b99a5752958 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sun, 7 Jul 2024 22:15:21 +0200 Subject: [PATCH 10/23] continue fixing typing issues --- xarray/tests/test_cftimeindex_resample.py | 4 +- xarray/tests/test_dataset.py | 4 +- xarray/tests/test_groupby.py | 12 ++--- xarray/tests/test_variable.py | 53 +++++++++++++---------- 4 files changed, 41 insertions(+), 32 deletions(-) diff --git a/xarray/tests/test_cftimeindex_resample.py b/xarray/tests/test_cftimeindex_resample.py index 361807378d1..08f71ada634 100644 --- a/xarray/tests/test_cftimeindex_resample.py +++ b/xarray/tests/test_cftimeindex_resample.py @@ -282,7 +282,9 @@ def test_resample_loffset_cftimeindex(loffset) -> None: result = da_cftimeindex.resample(time="24h", loffset=loffset).mean() expected = da_datetimeindex.resample(time="24h", loffset=loffset).mean() - result["time"] = result.xindexes["time"].to_pandas_index().to_datetimeindex() + index = result.xindexes["time"].to_pandas_index() + assert isinstance(index, CFTimeIndex) + result["time"] = index.to_datetimeindex() xr.testing.assert_identical(result, expected) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 2e21fbb1dae..52de0a265c3 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -581,7 +581,7 @@ def test_constructor_pandas_single(self) -> None: pandas_obj = a.to_pandas() ds_based_on_pandas = Dataset(pandas_obj) # type: ignore # TODO: improve typing of __init__ for dim in ds_based_on_pandas.data_vars: - assert_array_equal(ds_based_on_pandas[dim], pandas_obj[dim]) + assert_array_equal(ds_based_on_pandas[dim], pandas_obj[str(dim)]) def test_constructor_compat(self) -> None: data = {"x": DataArray(0, coords={"y": 1}), "y": ("z", [1, 1, 1])} @@ -1695,7 +1695,7 @@ def test_sel_categorical_error(self) -> None: with pytest.raises(ValueError): ds.sel(ind="bar", method="nearest") with pytest.raises(ValueError): - ds.sel(ind="bar", tolerance="nearest") + ds.sel(ind="bar", tolerance="nearest") # type: ignore[arg-type] def test_categorical_index(self) -> None: cat = pd.CategoricalIndex( diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index f0a0fd14d9d..5446ea0a540 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -734,7 +734,7 @@ def test_groupby_bins_timeseries() -> None: expected = xr.DataArray( 96 * np.ones((14,)), dims=["time_bins"], - coords={"time_bins": pd.cut(time_bins, time_bins).categories}, + coords={"time_bins": pd.cut(time_bins, time_bins).categories}, # type: ignore[arg-type] ).to_dataset(name="val") assert_identical(actual, expected) @@ -868,7 +868,7 @@ def test_groupby_dataset_errors() -> None: with pytest.raises(ValueError, match=r"length does not match"): data.groupby(data["dim1"][:3]) with pytest.raises(TypeError, match=r"`group` must be"): - data.groupby(data.coords["dim1"].to_index()) + data.groupby(data.coords["dim1"].to_index()) # type: ignore[arg-type] def test_groupby_dataset_reduce() -> None: @@ -1481,7 +1481,7 @@ def test_groupby_math(self) -> None: with pytest.raises(TypeError, match=r"only support binary ops"): grouped + grouped # type: ignore[type-var] with pytest.raises(TypeError, match=r"in-place operations"): - array += grouped # type: ignore[arg-type] + array += grouped def test_groupby_math_not_aligned(self) -> None: array = DataArray( @@ -1624,7 +1624,7 @@ def test_groupby_bins( bins = [0, 1.5, 5] df = array.to_dataframe() - df["dim_0_bins"] = pd.cut(array["dim_0"], bins, **cut_kwargs) + df["dim_0_bins"] = pd.cut(array["dim_0"], bins, **cut_kwargs) # type: ignore[call-overload] expected_df = df.groupby("dim_0_bins", observed=True).sum() # TODO: can't convert df with IntervalIndex to Xarray @@ -1690,7 +1690,7 @@ def test_groupby_bins_empty(self) -> None: array = DataArray(np.arange(4), [("x", range(4))]) # one of these bins will be empty bins = [0, 4, 5] - bin_coords = pd.cut(array["x"], bins).categories + bin_coords = pd.cut(array["x"], bins).categories # type: ignore[call-overload] actual = array.groupby_bins("x", bins).sum() expected = DataArray([6, np.nan], dims="x_bins", coords={"x_bins": bin_coords}) assert_identical(expected, actual) @@ -1701,7 +1701,7 @@ def test_groupby_bins_empty(self) -> None: def test_groupby_bins_multidim(self) -> None: array = self.make_groupby_multidim_example_array() bins = [0, 15, 20] - bin_coords = pd.cut(array["lat"].values.flat, bins).categories + bin_coords = pd.cut(array["lat"].values.flat, bins).categories # type: ignore[call-overload] expected = DataArray([16, 40], dims="lat_bins", coords={"lat_bins": bin_coords}) actual = array.groupby_bins("lat", bins).map(lambda x: x.sum()) assert_identical(expected, actual) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 081bf09484a..60c173a9e52 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -2649,7 +2649,7 @@ def test_tz_datetime(self) -> None: tz = pytz.timezone("America/New_York") times_ns = pd.date_range("2000", periods=1, tz=tz) - times_s = times_ns.astype(pd.DatetimeTZDtype("s", tz)) + times_s = times_ns.astype(pd.DatetimeTZDtype("s", tz)) # type: ignore[arg-type] with warnings.catch_warnings(): warnings.simplefilter("ignore") actual: T_DuckArray = as_compatible_data(times_s) @@ -2661,7 +2661,7 @@ def test_tz_datetime(self) -> None: warnings.simplefilter("ignore") actual2: T_DuckArray = as_compatible_data(series) - np.testing.assert_array_equal(actual2, series.values) + np.testing.assert_array_equal(actual2, np.asarray(series.values)) assert actual2.dtype == np.dtype("datetime64[ns]") def test_full_like(self) -> None: @@ -2978,26 +2978,35 @@ def test_datetime_conversion_warning(values, warns) -> None: ) -def test_pandas_two_only_datetime_conversion_warnings() -> None: - # Note these tests rely on pandas features that are only present in pandas - # 2.0.0 and above, and so for now cannot be parametrized. - cases = [ - (pd.date_range("2000", periods=1), "datetime64[s]"), - (pd.Series(pd.date_range("2000", periods=1)), "datetime64[s]"), - ( - pd.date_range("2000", periods=1, tz=pytz.timezone("America/New_York")), - pd.DatetimeTZDtype("s", pytz.timezone("America/New_York")), +tz_ny = pytz.timezone("America/New_York") + + +@pytest.mark.parametrize( + ["data", "dtype"], + [ + pytest.param(pd.date_range("2000", periods=1), "datetime64[s]", id="index-sec"), + pytest.param( + pd.Series(pd.date_range("2000", periods=1)), + "datetime64[s]", + id="series-sec", ), - ( - pd.Series( - pd.date_range("2000", periods=1, tz=pytz.timezone("America/New_York")) - ), - pd.DatetimeTZDtype("s", pytz.timezone("America/New_York")), + pytest.param( + pd.date_range("2000", periods=1, tz=tz_ny), + pd.DatetimeTZDtype("s", tz_ny), # type: ignore[arg-type] + id="index-timezone", ), - ] - for data, dtype in cases: - with pytest.warns(UserWarning, match="non-nanosecond precision datetime"): - var = Variable(["time"], data.astype(dtype)) + pytest.param( + pd.Series(pd.date_range("2000", periods=1, tz=tz_ny)), + pd.DatetimeTZDtype("s", tz_ny), # type: ignore[arg-type] + id="series-timezone", + ), + ], +) +def test_pandas_two_only_datetime_conversion_warnings( + data: pd.DatetimeIndex | pd.Series, dtype: str | pd.DatetimeTZDtype +) -> None: + with pytest.warns(UserWarning, match="non-nanosecond precision datetime"): + var = Variable(["time"], data.astype(dtype)) # type: ignore[arg-type] if var.dtype.kind == "M": assert var.dtype == np.dtype("datetime64[ns]") @@ -3006,9 +3015,7 @@ def test_pandas_two_only_datetime_conversion_warnings() -> None: # the case that the variable is backed by a timezone-aware # DatetimeIndex, and thus is hidden within the PandasIndexingAdapter class. assert isinstance(var._data, PandasIndexingAdapter) - assert var._data.array.dtype == pd.DatetimeTZDtype( - "ns", pytz.timezone("America/New_York") - ) + assert var._data.array.dtype == pd.DatetimeTZDtype("ns", tz_ny) @pytest.mark.parametrize( From faa8bccaadf200c8886a30d247a07f209d49d547 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sun, 7 Jul 2024 22:56:12 +0200 Subject: [PATCH 11/23] fix some typed_ops --- xarray/core/_typed_ops.py | 141 ++++++++++++++++++------------------ xarray/util/generate_ops.py | 22 +++--- 2 files changed, 78 insertions(+), 85 deletions(-) diff --git a/xarray/core/_typed_ops.py b/xarray/core/_typed_ops.py index c1748e322c2..d30a68b3dc4 100644 --- a/xarray/core/_typed_ops.py +++ b/xarray/core/_typed_ops.py @@ -5,21 +5,18 @@ from __future__ import annotations import operator -from typing import TYPE_CHECKING, Any, Callable, overload +from typing import Any, Callable, overload from xarray.core import nputils, ops from xarray.core.types import ( DaCompatible, DsCompatible, - GroupByCompatible, Self, - T_DataArray, T_Xarray, VarCompatible, ) - -if TYPE_CHECKING: - from xarray.core.dataset import Dataset +from xarray.core.types import T_DataArray as T_DA +from xarray.core.types import T_Dataset as T_DS class DatasetOpsMixin: @@ -455,165 +452,165 @@ def _binary_op( raise NotImplementedError @overload - def __add__(self, other: T_DataArray) -> T_DataArray: ... + def __add__(self, other: T_DA) -> T_DA: ... @overload def __add__(self, other: VarCompatible) -> Self: ... - def __add__(self, other: VarCompatible) -> Self | T_DataArray: + def __add__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.add) @overload - def __sub__(self, other: T_DataArray) -> T_DataArray: ... + def __sub__(self, other: T_DA) -> T_DA: ... @overload def __sub__(self, other: VarCompatible) -> Self: ... - def __sub__(self, other: VarCompatible) -> Self | T_DataArray: + def __sub__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.sub) @overload - def __mul__(self, other: T_DataArray) -> T_DataArray: ... + def __mul__(self, other: T_DA) -> T_DA: ... @overload def __mul__(self, other: VarCompatible) -> Self: ... - def __mul__(self, other: VarCompatible) -> Self | T_DataArray: + def __mul__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.mul) @overload - def __pow__(self, other: T_DataArray) -> T_DataArray: ... + def __pow__(self, other: T_DA) -> T_DA: ... @overload def __pow__(self, other: VarCompatible) -> Self: ... - def __pow__(self, other: VarCompatible) -> Self | T_DataArray: + def __pow__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.pow) @overload - def __truediv__(self, other: T_DataArray) -> T_DataArray: ... + def __truediv__(self, other: T_DA) -> T_DA: ... @overload def __truediv__(self, other: VarCompatible) -> Self: ... - def __truediv__(self, other: VarCompatible) -> Self | T_DataArray: + def __truediv__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.truediv) @overload - def __floordiv__(self, other: T_DataArray) -> T_DataArray: ... + def __floordiv__(self, other: T_DA) -> T_DA: ... @overload def __floordiv__(self, other: VarCompatible) -> Self: ... - def __floordiv__(self, other: VarCompatible) -> Self | T_DataArray: + def __floordiv__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.floordiv) @overload - def __mod__(self, other: T_DataArray) -> T_DataArray: ... + def __mod__(self, other: T_DA) -> T_DA: ... @overload def __mod__(self, other: VarCompatible) -> Self: ... - def __mod__(self, other: VarCompatible) -> Self | T_DataArray: + def __mod__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.mod) @overload - def __and__(self, other: T_DataArray) -> T_DataArray: ... + def __and__(self, other: T_DA) -> T_DA: ... @overload def __and__(self, other: VarCompatible) -> Self: ... - def __and__(self, other: VarCompatible) -> Self | T_DataArray: + def __and__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.and_) @overload - def __xor__(self, other: T_DataArray) -> T_DataArray: ... + def __xor__(self, other: T_DA) -> T_DA: ... @overload def __xor__(self, other: VarCompatible) -> Self: ... - def __xor__(self, other: VarCompatible) -> Self | T_DataArray: + def __xor__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.xor) @overload - def __or__(self, other: T_DataArray) -> T_DataArray: ... + def __or__(self, other: T_DA) -> T_DA: ... @overload def __or__(self, other: VarCompatible) -> Self: ... - def __or__(self, other: VarCompatible) -> Self | T_DataArray: + def __or__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.or_) @overload - def __lshift__(self, other: T_DataArray) -> T_DataArray: ... + def __lshift__(self, other: T_DA) -> T_DA: ... @overload def __lshift__(self, other: VarCompatible) -> Self: ... - def __lshift__(self, other: VarCompatible) -> Self | T_DataArray: + def __lshift__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.lshift) @overload - def __rshift__(self, other: T_DataArray) -> T_DataArray: ... + def __rshift__(self, other: T_DA) -> T_DA: ... @overload def __rshift__(self, other: VarCompatible) -> Self: ... - def __rshift__(self, other: VarCompatible) -> Self | T_DataArray: + def __rshift__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.rshift) @overload - def __lt__(self, other: T_DataArray) -> T_DataArray: ... + def __lt__(self, other: T_DA) -> T_DA: ... @overload def __lt__(self, other: VarCompatible) -> Self: ... - def __lt__(self, other: VarCompatible) -> Self | T_DataArray: + def __lt__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.lt) @overload - def __le__(self, other: T_DataArray) -> T_DataArray: ... + def __le__(self, other: T_DA) -> T_DA: ... @overload def __le__(self, other: VarCompatible) -> Self: ... - def __le__(self, other: VarCompatible) -> Self | T_DataArray: + def __le__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.le) @overload - def __gt__(self, other: T_DataArray) -> T_DataArray: ... + def __gt__(self, other: T_DA) -> T_DA: ... @overload def __gt__(self, other: VarCompatible) -> Self: ... - def __gt__(self, other: VarCompatible) -> Self | T_DataArray: + def __gt__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.gt) @overload - def __ge__(self, other: T_DataArray) -> T_DataArray: ... + def __ge__(self, other: T_DA) -> T_DA: ... @overload def __ge__(self, other: VarCompatible) -> Self: ... - def __ge__(self, other: VarCompatible) -> Self | T_DataArray: + def __ge__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, operator.ge) @overload # type:ignore[override] - def __eq__(self, other: T_DataArray) -> T_DataArray: ... + def __eq__(self, other: T_DA) -> T_DA: ... @overload def __eq__(self, other: VarCompatible) -> Self: ... - def __eq__(self, other: VarCompatible) -> Self | T_DataArray: + def __eq__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, nputils.array_eq) @overload # type:ignore[override] - def __ne__(self, other: T_DataArray) -> T_DataArray: ... + def __ne__(self, other: T_DA) -> T_DA: ... @overload def __ne__(self, other: VarCompatible) -> Self: ... - def __ne__(self, other: VarCompatible) -> Self | T_DataArray: + def __ne__(self, other: VarCompatible) -> Self | T_DA: return self._binary_op(other, nputils.array_ne) # When __eq__ is defined but __hash__ is not, then an object is unhashable, @@ -770,96 +767,96 @@ class DatasetGroupByOpsMixin: __slots__ = () def _binary_op( - self, other: GroupByCompatible, f: Callable, reflexive: bool = False - ) -> Dataset: + self, other: T_DS | T_DA, f: Callable, reflexive: bool = False + ) -> T_DS: raise NotImplementedError - def __add__(self, other: GroupByCompatible) -> Dataset: + def __add__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.add) - def __sub__(self, other: GroupByCompatible) -> Dataset: + def __sub__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.sub) - def __mul__(self, other: GroupByCompatible) -> Dataset: + def __mul__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.mul) - def __pow__(self, other: GroupByCompatible) -> Dataset: + def __pow__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.pow) - def __truediv__(self, other: GroupByCompatible) -> Dataset: + def __truediv__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other: GroupByCompatible) -> Dataset: + def __floordiv__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.floordiv) - def __mod__(self, other: GroupByCompatible) -> Dataset: + def __mod__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.mod) - def __and__(self, other: GroupByCompatible) -> Dataset: + def __and__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.and_) - def __xor__(self, other: GroupByCompatible) -> Dataset: + def __xor__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.xor) - def __or__(self, other: GroupByCompatible) -> Dataset: + def __or__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.or_) - def __lshift__(self, other: GroupByCompatible) -> Dataset: + def __lshift__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.lshift) - def __rshift__(self, other: GroupByCompatible) -> Dataset: + def __rshift__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.rshift) - def __lt__(self, other: GroupByCompatible) -> Dataset: + def __lt__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.lt) - def __le__(self, other: GroupByCompatible) -> Dataset: + def __le__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.le) - def __gt__(self, other: GroupByCompatible) -> Dataset: + def __gt__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.gt) - def __ge__(self, other: GroupByCompatible) -> Dataset: + def __ge__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.ge) - def __eq__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override] + def __eq__(self, other: T_DS | T_DA) -> T_DS: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other: GroupByCompatible) -> Dataset: # type:ignore[override] + def __ne__(self, other: T_DS | T_DA) -> T_DS: # type:ignore[override] return self._binary_op(other, nputils.array_ne) # When __eq__ is defined but __hash__ is not, then an object is unhashable, # and it should be declared as follows: __hash__: None # type:ignore[assignment] - def __radd__(self, other: GroupByCompatible) -> Dataset: + def __radd__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other: GroupByCompatible) -> Dataset: + def __rsub__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other: GroupByCompatible) -> Dataset: + def __rmul__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other: GroupByCompatible) -> Dataset: + def __rpow__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other: GroupByCompatible) -> Dataset: + def __rtruediv__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other: GroupByCompatible) -> Dataset: + def __rfloordiv__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other: GroupByCompatible) -> Dataset: + def __rmod__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other: GroupByCompatible) -> Dataset: + def __rand__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other: GroupByCompatible) -> Dataset: + def __rxor__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other: GroupByCompatible) -> Dataset: + def __ror__(self, other: T_DS | T_DA) -> T_DS: return self._binary_op(other, operator.or_, reflexive=True) __add__.__doc__ = operator.add.__doc__ diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py index ee4dd68b3ba..9718a994831 100644 --- a/xarray/util/generate_ops.py +++ b/xarray/util/generate_ops.py @@ -88,12 +88,10 @@ def {method}(self, other: {other_type}) -> {return_type}:{type_ignore} return self._binary_op(other, {func})""" template_binop_overload = """ @overload{overload_type_ignore} - def {method}(self, other: {overload_type}) -> {overload_type}: - ... + def {method}(self, other: {overload_type}) -> {overload_type}: ... @overload - def {method}(self, other: {other_type}) -> {return_type}: - ... + def {method}(self, other: {other_type}) -> {return_type}: ... def {method}(self, other: {other_type}) -> {return_type} | {overload_type}:{type_ignore} return self._binary_op(other, {func})""" @@ -129,7 +127,7 @@ def {method}(self, *args: Any, **kwargs: Any) -> Self: # The type ignores might not be necessary anymore at some point. # # We require a "hack" to tell type checkers that e.g. Variable + DataArray = DataArray -# In reality this returns NotImplementes, but this is not a valid type in python 3.9. +# In reality this returns NotImplemented, but this is not a valid type in python 3.9. # Therefore, we return DataArray. In reality this would call DataArray.__add__(Variable) # TODO: change once python 3.10 is the minimum. # @@ -224,12 +222,12 @@ def unops() -> list[OpsType]: binops(other_type="DaCompatible") + inplace(other_type="DaCompatible") + unops() ) ops_info["VariableOpsMixin"] = ( - binops_overload(other_type="VarCompatible", overload_type="T_DataArray") + binops_overload(other_type="VarCompatible", overload_type="T_DA") + inplace(other_type="VarCompatible", type_ignore="misc") + unops() ) ops_info["DatasetGroupByOpsMixin"] = binops( - other_type="GroupByCompatible", return_type="Dataset" + other_type="T_DS | T_DA", return_type="T_DS" ) ops_info["DataArrayGroupByOpsMixin"] = binops( other_type="T_Xarray", return_type="T_Xarray" @@ -237,26 +235,24 @@ def unops() -> list[OpsType]: MODULE_PREAMBLE = '''\ """Mixin classes with arithmetic operators.""" + # This file was generated using xarray.util.generate_ops. Do not edit manually. from __future__ import annotations import operator -from typing import TYPE_CHECKING, Any, Callable, overload +from typing import Any, Callable, overload from xarray.core import nputils, ops from xarray.core.types import ( DaCompatible, DsCompatible, - GroupByCompatible, Self, - T_DataArray, T_Xarray, VarCompatible, ) - -if TYPE_CHECKING: - from xarray.core.dataset import Dataset''' +from xarray.core.types import T_DataArray as T_DA +from xarray.core.types import T_Dataset as T_DS''' CLASS_PREAMBLE = """{newline} From 2b678b1183993fb046e3b1fd509140639c4ab355 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sun, 7 Jul 2024 23:09:47 +0200 Subject: [PATCH 12/23] fix last non-typed-ops errors --- xarray/tests/test_plot.py | 8 +++++--- xarray/tests/test_rolling.py | 22 ++++++++++++++-------- xarray/util/generate_ops.py | 4 ++++ 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index b302ad3af93..a0bd6b939d5 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -6,7 +6,7 @@ from collections.abc import Generator, Hashable from copy import copy from datetime import date, timedelta -from typing import Any, Callable, Literal +from typing import Any, Callable, Literal, cast import numpy as np import pandas as pd @@ -527,7 +527,7 @@ def test__infer_interval_breaks(self) -> None: [-0.5, 0.5, 5.0, 9.5, 10.5], _infer_interval_breaks([0, 1, 9, 10]) ) assert_array_equal( - pd.date_range("20000101", periods=4) - np.timedelta64(12, "h"), + pd.date_range("20000101", periods=4) - np.timedelta64(12, "h"), # type: ignore[operator] _infer_interval_breaks(pd.date_range("20000101", periods=3)), ) @@ -1047,7 +1047,9 @@ def test_list_levels(self) -> None: assert cmap_params["cmap"].N == 5 assert cmap_params["norm"].N == 6 - for wrap_levels in [list, np.array, pd.Index, DataArray]: + for wrap_levels in cast( + list[Callable[[Any], dict[Any, Any]]], [list, np.array, pd.Index, DataArray] + ): cmap_params = _determine_cmap_params(data, levels=wrap_levels(orig_levels)) assert_array_equal(cmap_params["levels"], orig_levels) diff --git a/xarray/tests/test_rolling.py b/xarray/tests/test_rolling.py index 89f6ebba2c3..79869e63ae7 100644 --- a/xarray/tests/test_rolling.py +++ b/xarray/tests/test_rolling.py @@ -206,9 +206,9 @@ def test_rolling_pandas_compat( index=window, center=center, min_periods=min_periods ).reduce(np.nanmean) - np.testing.assert_allclose(s_rolling.values, da_rolling.values) + np.testing.assert_allclose(np.asarray(s_rolling.values), da_rolling.values) np.testing.assert_allclose(s_rolling.index, da_rolling["index"]) - np.testing.assert_allclose(s_rolling.values, da_rolling_np.values) + np.testing.assert_allclose(np.asarray(s_rolling.values), da_rolling_np.values) np.testing.assert_allclose(s_rolling.index, da_rolling_np["index"]) @pytest.mark.parametrize("center", (True, False)) @@ -221,12 +221,14 @@ def test_rolling_construct(self, center: bool, window: int) -> None: da_rolling = da.rolling(index=window, center=center, min_periods=1) da_rolling_mean = da_rolling.construct("window").mean("window") - np.testing.assert_allclose(s_rolling.values, da_rolling_mean.values) + np.testing.assert_allclose(np.asarray(s_rolling.values), da_rolling_mean.values) np.testing.assert_allclose(s_rolling.index, da_rolling_mean["index"]) # with stride da_rolling_mean = da_rolling.construct("window", stride=2).mean("window") - np.testing.assert_allclose(s_rolling.values[::2], da_rolling_mean.values) + np.testing.assert_allclose( + np.asarray(s_rolling.values[::2]), da_rolling_mean.values + ) np.testing.assert_allclose(s_rolling.index[::2], da_rolling_mean["index"]) # with fill_value @@ -649,7 +651,9 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None: index=window, center=center, min_periods=min_periods ).mean() - np.testing.assert_allclose(df_rolling["x"].values, ds_rolling["x"].values) + np.testing.assert_allclose( + np.asarray(df_rolling["x"].values), ds_rolling["x"].values + ) np.testing.assert_allclose(df_rolling.index, ds_rolling["index"]) @pytest.mark.parametrize("center", (True, False)) @@ -668,7 +672,9 @@ def test_rolling_construct(self, center: bool, window: int) -> None: ds_rolling = ds.rolling(index=window, center=center) ds_rolling_mean = ds_rolling.construct("window").mean("window") - np.testing.assert_allclose(df_rolling["x"].values, ds_rolling_mean["x"].values) + np.testing.assert_allclose( + np.asarray(df_rolling["x"].values), ds_rolling_mean["x"].values + ) np.testing.assert_allclose(df_rolling.index, ds_rolling_mean["index"]) # with fill_value @@ -695,7 +701,7 @@ def test_rolling_construct_stride(self, center: bool, window: int) -> None: ds_rolling = ds.rolling(index=window, center=center) ds_rolling_mean = ds_rolling.construct("w", stride=2).mean("w") np.testing.assert_allclose( - df_rolling_mean["x"][::2].values, ds_rolling_mean["x"].values + np.asarray(df_rolling_mean["x"][::2].values), ds_rolling_mean["x"].values ) np.testing.assert_allclose(df_rolling_mean.index[::2], ds_rolling_mean["index"]) @@ -704,7 +710,7 @@ def test_rolling_construct_stride(self, center: bool, window: int) -> None: ds2_rolling = ds2.rolling(index=window, center=center) ds2_rolling_mean = ds2_rolling.construct("w", stride=2).mean("w") np.testing.assert_allclose( - df_rolling_mean["x"][::2].values, ds2_rolling_mean["x"].values + np.asarray(df_rolling_mean["x"][::2].values), ds2_rolling_mean["x"].values ) # Mixed coordinates, indexes and 2D coordinates diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py index 9718a994831..15b0c9dcd66 100644 --- a/xarray/util/generate_ops.py +++ b/xarray/util/generate_ops.py @@ -214,6 +214,10 @@ def unops() -> list[OpsType]: ] +# We use short names T_DA and T_DS to keep below 88 lines so +# ruff does not reformat everything. When reformatting, the +# type-ignores end up in the wrong line :/ + ops_info = {} ops_info["DatasetOpsMixin"] = ( binops(other_type="DsCompatible") + inplace(other_type="DsCompatible") + unops() From ab4dbf55888adc39386cd66eca84377153b8fb90 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 12 Jul 2024 21:19:09 +0200 Subject: [PATCH 13/23] update typed ops --- xarray/core/_typed_ops.py | 69 +++++++++++++++++++------------------ xarray/util/generate_ops.py | 11 +++--- 2 files changed, 43 insertions(+), 37 deletions(-) diff --git a/xarray/core/_typed_ops.py b/xarray/core/_typed_ops.py index d30a68b3dc4..61aa1846bd0 100644 --- a/xarray/core/_typed_ops.py +++ b/xarray/core/_typed_ops.py @@ -5,7 +5,7 @@ from __future__ import annotations import operator -from typing import Any, Callable, overload +from typing import TYPE_CHECKING, Any, Callable, overload from xarray.core import nputils, ops from xarray.core.types import ( @@ -15,8 +15,11 @@ T_Xarray, VarCompatible, ) -from xarray.core.types import T_DataArray as T_DA -from xarray.core.types import T_Dataset as T_DS + +if TYPE_CHECKING: + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + from xarray.core.types import T_DataArray as T_DA class DatasetOpsMixin: @@ -767,96 +770,96 @@ class DatasetGroupByOpsMixin: __slots__ = () def _binary_op( - self, other: T_DS | T_DA, f: Callable, reflexive: bool = False - ) -> T_DS: + self, other: Dataset | DataArray, f: Callable, reflexive: bool = False + ) -> Dataset: raise NotImplementedError - def __add__(self, other: T_DS | T_DA) -> T_DS: + def __add__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.add) - def __sub__(self, other: T_DS | T_DA) -> T_DS: + def __sub__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.sub) - def __mul__(self, other: T_DS | T_DA) -> T_DS: + def __mul__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.mul) - def __pow__(self, other: T_DS | T_DA) -> T_DS: + def __pow__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.pow) - def __truediv__(self, other: T_DS | T_DA) -> T_DS: + def __truediv__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.truediv) - def __floordiv__(self, other: T_DS | T_DA) -> T_DS: + def __floordiv__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.floordiv) - def __mod__(self, other: T_DS | T_DA) -> T_DS: + def __mod__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.mod) - def __and__(self, other: T_DS | T_DA) -> T_DS: + def __and__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.and_) - def __xor__(self, other: T_DS | T_DA) -> T_DS: + def __xor__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.xor) - def __or__(self, other: T_DS | T_DA) -> T_DS: + def __or__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.or_) - def __lshift__(self, other: T_DS | T_DA) -> T_DS: + def __lshift__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.lshift) - def __rshift__(self, other: T_DS | T_DA) -> T_DS: + def __rshift__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.rshift) - def __lt__(self, other: T_DS | T_DA) -> T_DS: + def __lt__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.lt) - def __le__(self, other: T_DS | T_DA) -> T_DS: + def __le__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.le) - def __gt__(self, other: T_DS | T_DA) -> T_DS: + def __gt__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.gt) - def __ge__(self, other: T_DS | T_DA) -> T_DS: + def __ge__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.ge) - def __eq__(self, other: T_DS | T_DA) -> T_DS: # type:ignore[override] + def __eq__(self, other: Dataset | DataArray) -> Dataset: # type:ignore[override] return self._binary_op(other, nputils.array_eq) - def __ne__(self, other: T_DS | T_DA) -> T_DS: # type:ignore[override] + def __ne__(self, other: Dataset | DataArray) -> Dataset: # type:ignore[override] return self._binary_op(other, nputils.array_ne) # When __eq__ is defined but __hash__ is not, then an object is unhashable, # and it should be declared as follows: __hash__: None # type:ignore[assignment] - def __radd__(self, other: T_DS | T_DA) -> T_DS: + def __radd__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.add, reflexive=True) - def __rsub__(self, other: T_DS | T_DA) -> T_DS: + def __rsub__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.sub, reflexive=True) - def __rmul__(self, other: T_DS | T_DA) -> T_DS: + def __rmul__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.mul, reflexive=True) - def __rpow__(self, other: T_DS | T_DA) -> T_DS: + def __rpow__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.pow, reflexive=True) - def __rtruediv__(self, other: T_DS | T_DA) -> T_DS: + def __rtruediv__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.truediv, reflexive=True) - def __rfloordiv__(self, other: T_DS | T_DA) -> T_DS: + def __rfloordiv__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.floordiv, reflexive=True) - def __rmod__(self, other: T_DS | T_DA) -> T_DS: + def __rmod__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.mod, reflexive=True) - def __rand__(self, other: T_DS | T_DA) -> T_DS: + def __rand__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.and_, reflexive=True) - def __rxor__(self, other: T_DS | T_DA) -> T_DS: + def __rxor__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.xor, reflexive=True) - def __ror__(self, other: T_DS | T_DA) -> T_DS: + def __ror__(self, other: Dataset | DataArray) -> Dataset: return self._binary_op(other, operator.or_, reflexive=True) __add__.__doc__ = operator.add.__doc__ diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py index 15b0c9dcd66..a9f66cdc614 100644 --- a/xarray/util/generate_ops.py +++ b/xarray/util/generate_ops.py @@ -231,7 +231,7 @@ def unops() -> list[OpsType]: + unops() ) ops_info["DatasetGroupByOpsMixin"] = binops( - other_type="T_DS | T_DA", return_type="T_DS" + other_type="Dataset | DataArray", return_type="Dataset" ) ops_info["DataArrayGroupByOpsMixin"] = binops( other_type="T_Xarray", return_type="T_Xarray" @@ -245,7 +245,7 @@ def unops() -> list[OpsType]: from __future__ import annotations import operator -from typing import Any, Callable, overload +from typing import TYPE_CHECKING, Any, Callable, overload from xarray.core import nputils, ops from xarray.core.types import ( @@ -255,8 +255,11 @@ def unops() -> list[OpsType]: T_Xarray, VarCompatible, ) -from xarray.core.types import T_DataArray as T_DA -from xarray.core.types import T_Dataset as T_DS''' + +if TYPE_CHECKING: + from xarray.core.dataarray import DataArray + from xarray.core.dataset import Dataset + from xarray.core.types import T_DataArray as T_DA''' CLASS_PREAMBLE = """{newline} From 27db395ef7f19b881f33a4823b8742b26dfcd9fa Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 12 Jul 2024 21:20:13 +0200 Subject: [PATCH 14/23] remove useless DaskArray type in scalar or array type --- xarray/core/types.py | 2 +- xarray/tests/test_groupby.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/types.py b/xarray/core/types.py index 20d94de5116..786ab5973b1 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -178,7 +178,7 @@ def copy( T_ExtensionArray = TypeVar("T_ExtensionArray", bound=pd.api.extensions.ExtensionArray) -ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"] +ScalarOrArray = Union["ArrayLike", np.generic] VarCompatible = Union["Variable", "ScalarOrArray"] DaCompatible = Union["DataArray", "VarCompatible"] DsCompatible = Union["Dataset", "DaCompatible"] diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 5446ea0a540..469e5a3b1f2 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -1481,7 +1481,7 @@ def test_groupby_math(self) -> None: with pytest.raises(TypeError, match=r"only support binary ops"): grouped + grouped # type: ignore[type-var] with pytest.raises(TypeError, match=r"in-place operations"): - array += grouped + array += grouped # type: ignore[arg-type] def test_groupby_math_not_aligned(self) -> None: array = DataArray( From 59771078279d72c6351ee4e3e52370e37b04f30d Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 12 Jul 2024 21:27:25 +0200 Subject: [PATCH 15/23] fix missing import in type_checking --- xarray/coding/times.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 7be26f95697..badb9259b06 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -5,7 +5,7 @@ from collections.abc import Hashable from datetime import datetime, timedelta from functools import partial -from typing import TYPE_CHECKING, Callable, Literal, Union, cast +from typing import Callable, Literal, Union, cast import numpy as np import pandas as pd @@ -36,10 +36,9 @@ except ImportError: cftime = None -if TYPE_CHECKING: - from xarray.core.types import CFCalendar, NPDatetimeUnitOptions, T_DuckArray +from xarray.core.types import CFCalendar, NPDatetimeUnitOptions, T_DuckArray - T_Name = Union[Hashable, None] +T_Name = Union[Hashable, None] # standard calendars recognized by cftime _STANDARD_CALENDARS = {"standard", "gregorian", "proleptic_gregorian"} From 64765c870d2d84e05f12a5bd45f9da8c1cf9d4d1 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 12 Jul 2024 21:43:56 +0200 Subject: [PATCH 16/23] fix import --- xarray/coding/cftimeindex.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index 0747074ac46..cd902257902 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -555,6 +555,8 @@ def shift( # type: ignore[override] # freq is typed Any, we are more precise CFTimeIndex([2000-02-01 12:00:00], dtype='object', length=1, calendar='standard', freq=None) """ + from xarray.coding.cftime_offsets import BaseCFTimeOffset + if freq is None: # None type is required to be compatible with base pd.Index class raise TypeError( @@ -570,7 +572,7 @@ def shift( # type: ignore[override] # freq is typed Any, we are more precise return self + periods * to_offset(freq) raise TypeError( - f"'freq' must be of type str or datetime.timedelta, got {freq}." + f"'freq' must be of type str or datetime.timedelta, got {type(freq)}." ) def __add__(self, other) -> Self: From 5b1d98cde85eccf4cb966f394584945127354aed Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 12 Jul 2024 23:00:02 +0200 Subject: [PATCH 17/23] improve cftime offsets typing --- xarray/coding/cftime_offsets.py | 133 ++++++++++++---------- xarray/tests/test_cftimeindex_resample.py | 4 +- 2 files changed, 73 insertions(+), 64 deletions(-) diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index 064a77831fb..90478dffe2d 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -43,9 +43,10 @@ from __future__ import annotations import re +from collections.abc import Mapping from datetime import datetime, timedelta from functools import partial -from typing import TYPE_CHECKING, ClassVar +from typing import TYPE_CHECKING, ClassVar, Literal import numpy as np import pandas as pd @@ -74,7 +75,10 @@ if TYPE_CHECKING: - from xarray.core.types import InclusiveOptions, SideOptions + from xarray.core.types import InclusiveOptions, Self, SideOptions, TypeAlias + + +DayOption: TypeAlias = Literal["start", "end"] def _nanosecond_precision_timestamp(*args, **kwargs): @@ -108,10 +112,11 @@ def get_date_type(calendar, use_cftime=True): class BaseCFTimeOffset: - _freq: ClassVar[str | None] = None - _day_option: ClassVar[str | None] = None + _freq: ClassVar[str] + _day_option: ClassVar[DayOption] + n: int - def __init__(self, n: int = 1): + def __init__(self, n: int = 1) -> None: if not isinstance(n, int): raise TypeError( "The provided multiple 'n' must be an integer. " @@ -119,13 +124,15 @@ def __init__(self, n: int = 1): ) self.n = n - def rule_code(self): + def rule_code(self) -> str: return self._freq - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, BaseCFTimeOffset): + return NotImplemented return self.n == other.n and self.rule_code() == other.rule_code() - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self == other def __add__(self, other): @@ -142,12 +149,12 @@ def __sub__(self, other): else: return NotImplemented - def __mul__(self, other): + def __mul__(self, other: int) -> Self: if not isinstance(other, int): return NotImplemented return type(self)(n=other * self.n) - def __neg__(self): + def __neg__(self) -> Self: return self * -1 def __rmul__(self, other): @@ -161,10 +168,10 @@ def __rsub__(self, other): raise TypeError("Cannot subtract cftime offsets of differing types") return -self + other - def __apply__(self): + def __apply__(self, other): return NotImplemented - def onOffset(self, date): + def onOffset(self, date) -> bool: """Check if the given date is in the set of possible dates created using a length-one version of this offset class.""" test_date = (self + date) - self @@ -197,22 +204,21 @@ def _get_offset_day(self, other): class Tick(BaseCFTimeOffset): # analogous https://github.com/pandas-dev/pandas/blob/ccb25ab1d24c4fb9691270706a59c8d319750870/pandas/_libs/tslibs/offsets.pyx#L806 - def _next_higher_resolution(self): + def _next_higher_resolution(self) -> Tick: self_type = type(self) - if self_type not in [Day, Hour, Minute, Second, Millisecond]: - raise ValueError("Could not convert to integer offset at any resolution") - if type(self) is Day: + if self_type is Day: return Hour(self.n * 24) - if type(self) is Hour: + if self_type is Hour: return Minute(self.n * 60) - if type(self) is Minute: + if self_type is Minute: return Second(self.n * 60) - if type(self) is Second: + if self_type is Second: return Millisecond(self.n * 1000) - if type(self) is Millisecond: + if self_type is Millisecond: return Microsecond(self.n * 1000) + raise ValueError("Could not convert to integer offset at any resolution") - def __mul__(self, other): + def __mul__(self, other: int | float) -> Tick: if not isinstance(other, (int, float)): return NotImplemented if isinstance(other, float): @@ -227,12 +233,12 @@ def __mul__(self, other): return new_self * other return type(self)(n=other * self.n) - def as_timedelta(self): + def as_timedelta(self) -> timedelta: """All Tick subclasses must implement an as_timedelta method.""" raise NotImplementedError -def _get_day_of_month(other, day_option): +def _get_day_of_month(other, day_option: DayOption) -> int: """Find the day in `other`'s month that satisfies a BaseCFTimeOffset's onOffset policy, as described by the `day_option` argument. @@ -251,14 +257,13 @@ def _get_day_of_month(other, day_option): if day_option == "start": return 1 - elif day_option == "end": + if day_option == "end": return _days_in_month(other) - elif day_option is None: + if day_option is None: # Note: unlike `_shift_month`, _get_day_of_month does not # allow day_option = None raise NotImplementedError() - else: - raise ValueError(day_option) + raise ValueError(day_option) def _days_in_month(date): @@ -293,7 +298,7 @@ def _adjust_n_years(other, n, month, reference_day): return n -def _shift_month(date, months, day_option="start"): +def _shift_month(date, months, day_option: DayOption = "start"): """Shift the date to a month start or end a given number of months away.""" if cftime is None: raise ModuleNotFoundError("No module named 'cftime'") @@ -316,7 +321,9 @@ def _shift_month(date, months, day_option="start"): return date.replace(year=year, month=month, day=day) -def roll_qtrday(other, n, month, day_option, modby=3): +def roll_qtrday( + other, n: int, month: int, day_option: DayOption, modby: int = 3 +) -> int: """Possibly increment or decrement the number of periods to shift based on rollforward/rollbackward conventions. @@ -357,7 +364,7 @@ def roll_qtrday(other, n, month, day_option, modby=3): return n -def _validate_month(month, default_month): +def _validate_month(month: int | None, default_month: int) -> int: result_month = default_month if month is None else month if not isinstance(result_month, int): raise TypeError( @@ -381,7 +388,7 @@ def __apply__(self, other): n = _adjust_n_months(other.day, self.n, 1) return _shift_month(other, n, "start") - def onOffset(self, date): + def onOffset(self, date) -> bool: """Check if the given date is in the set of possible dates created using a length-one version of this offset class.""" return date.day == 1 @@ -394,7 +401,7 @@ def __apply__(self, other): n = _adjust_n_months(other.day, self.n, _days_in_month(other)) return _shift_month(other, n, "end") - def onOffset(self, date): + def onOffset(self, date) -> bool: """Check if the given date is in the set of possible dates created using a length-one version of this offset class.""" return date.day == _days_in_month(date) @@ -419,10 +426,10 @@ def onOffset(self, date): class QuarterOffset(BaseCFTimeOffset): """Quarter representation copied off of pandas/tseries/offsets.py""" - _freq: ClassVar[str] _default_month: ClassVar[int] + month: int - def __init__(self, n=1, month=None): + def __init__(self, n: int = 1, month: int | None = None) -> None: BaseCFTimeOffset.__init__(self, n) self.month = _validate_month(month, self._default_month) @@ -439,29 +446,28 @@ def __apply__(self, other): months = qtrs * 3 - months_since return _shift_month(other, months, self._day_option) - def onOffset(self, date): + def onOffset(self, date) -> bool: """Check if the given date is in the set of possible dates created using a length-one version of this offset class.""" mod_month = (date.month - self.month) % 3 return mod_month == 0 and date.day == self._get_offset_day(date) - def __sub__(self, other): + def __sub__(self, other: Self) -> Self: if cftime is None: raise ModuleNotFoundError("No module named 'cftime'") if isinstance(other, cftime.datetime): raise TypeError("Cannot subtract cftime.datetime from offset.") - elif type(other) == type(self) and other.month == self.month: + if type(other) == type(self) and other.month == self.month: return type(self)(self.n - other.n, month=self.month) - else: - return NotImplemented + return NotImplemented def __mul__(self, other): if isinstance(other, float): return NotImplemented return type(self)(n=other * self.n, month=self.month) - def rule_code(self): + def rule_code(self) -> str: return f"{self._freq}-{_MONTH_ABBREVIATIONS[self.month]}" def __str__(self): @@ -519,11 +525,10 @@ def rollback(self, date): class YearOffset(BaseCFTimeOffset): - _freq: ClassVar[str] - _day_option: ClassVar[str] _default_month: ClassVar[int] + month: int - def __init__(self, n=1, month=None): + def __init__(self, n: int = 1, month: int | None = None) -> None: BaseCFTimeOffset.__init__(self, n) self.month = _validate_month(month, self._default_month) @@ -549,10 +554,10 @@ def __mul__(self, other): return NotImplemented return type(self)(n=other * self.n, month=self.month) - def rule_code(self): + def rule_code(self) -> str: return f"{self._freq}-{_MONTH_ABBREVIATIONS[self.month]}" - def __str__(self): + def __str__(self) -> str: return f"<{type(self).__name__}: n={self.n}, month={self.month}>" @@ -561,7 +566,7 @@ class YearBegin(YearOffset): _day_option = "start" _default_month = 1 - def onOffset(self, date): + def onOffset(self, date) -> bool: """Check if the given date is in the set of possible dates created using a length-one version of this offset class.""" return date.day == 1 and date.month == self.month @@ -586,7 +591,7 @@ class YearEnd(YearOffset): _day_option = "end" _default_month = 12 - def onOffset(self, date): + def onOffset(self, date) -> bool: """Check if the given date is in the set of possible dates created using a length-one version of this offset class.""" return date.day == _days_in_month(date) and date.month == self.month @@ -609,7 +614,7 @@ def rollback(self, date): class Day(Tick): _freq = "D" - def as_timedelta(self): + def as_timedelta(self) -> timedelta: return timedelta(days=self.n) def __apply__(self, other): @@ -619,7 +624,7 @@ def __apply__(self, other): class Hour(Tick): _freq = "h" - def as_timedelta(self): + def as_timedelta(self) -> timedelta: return timedelta(hours=self.n) def __apply__(self, other): @@ -629,7 +634,7 @@ def __apply__(self, other): class Minute(Tick): _freq = "min" - def as_timedelta(self): + def as_timedelta(self) -> timedelta: return timedelta(minutes=self.n) def __apply__(self, other): @@ -639,7 +644,7 @@ def __apply__(self, other): class Second(Tick): _freq = "s" - def as_timedelta(self): + def as_timedelta(self) -> timedelta: return timedelta(seconds=self.n) def __apply__(self, other): @@ -649,7 +654,7 @@ def __apply__(self, other): class Millisecond(Tick): _freq = "ms" - def as_timedelta(self): + def as_timedelta(self) -> timedelta: return timedelta(milliseconds=self.n) def __apply__(self, other): @@ -659,30 +664,32 @@ def __apply__(self, other): class Microsecond(Tick): _freq = "us" - def as_timedelta(self): + def as_timedelta(self) -> timedelta: return timedelta(microseconds=self.n) def __apply__(self, other): return other + self.as_timedelta() -def _generate_anchored_offsets(base_freq, offset): - offsets = {} +def _generate_anchored_offsets( + base_freq: str, offset: type[YearOffset | QuarterOffset] +) -> dict[str, type[BaseCFTimeOffset]]: + offsets: dict[str, type[BaseCFTimeOffset]] = {} for month, abbreviation in _MONTH_ABBREVIATIONS.items(): anchored_freq = f"{base_freq}-{abbreviation}" - offsets[anchored_freq] = partial(offset, month=month) + offsets[anchored_freq] = partial(offset, month=month) # type: ignore[assignment] return offsets -_FREQUENCIES = { +_FREQUENCIES: Mapping[str, type[BaseCFTimeOffset]] = { "A": YearEnd, "AS": YearBegin, "Y": YearEnd, "YE": YearEnd, "YS": YearBegin, - "Q": partial(QuarterEnd, month=12), - "QE": partial(QuarterEnd, month=12), - "QS": partial(QuarterBegin, month=1), + "Q": partial(QuarterEnd, month=12), # type: ignore[dict-item] + "QE": partial(QuarterEnd, month=12), # type: ignore[dict-item] + "QS": partial(QuarterBegin, month=1), # type: ignore[dict-item] "M": MonthEnd, "ME": MonthEnd, "MS": MonthBegin, @@ -717,7 +724,9 @@ def _generate_anchored_offsets(base_freq, offset): CFTIME_TICKS = (Day, Hour, Minute, Second) -def _generate_anchored_deprecated_frequencies(deprecated, recommended): +def _generate_anchored_deprecated_frequencies( + deprecated: str, recommended: str +) -> dict[str, str]: pairs = {} for abbreviation in _MONTH_ABBREVIATIONS.values(): anchored_deprecated = f"{deprecated}-{abbreviation}" @@ -726,7 +735,7 @@ def _generate_anchored_deprecated_frequencies(deprecated, recommended): return pairs -_DEPRECATED_FREQUENICES = { +_DEPRECATED_FREQUENICES: dict[str, str] = { "A": "YE", "Y": "YE", "AS": "YS", diff --git a/xarray/tests/test_cftimeindex_resample.py b/xarray/tests/test_cftimeindex_resample.py index 08f71ada634..3dda7a5f1eb 100644 --- a/xarray/tests/test_cftimeindex_resample.py +++ b/xarray/tests/test_cftimeindex_resample.py @@ -251,11 +251,11 @@ def test_base_and_offset_error(): @pytest.mark.parametrize("offset", ["foo", "5MS", 10]) -def test_invalid_offset_error(offset) -> None: +def test_invalid_offset_error(offset: str | int) -> None: cftime_index = xr.cftime_range("2000", periods=5) da_cftime = da(cftime_index) with pytest.raises(ValueError, match="offset must be"): - da_cftime.resample(time="2D", offset=offset) + da_cftime.resample(time="2D", offset=offset) # type: ignore[arg-type] def test_timedelta_offset() -> None: From 16bee67280aab8b37f263eaa7e6cac2d46525134 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 12 Jul 2024 23:22:01 +0200 Subject: [PATCH 18/23] fix classvars --- xarray/coding/cftime_offsets.py | 6 +++--- xarray/tests/test_cftime_offsets.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index 90478dffe2d..9dbc60ef0f3 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -112,8 +112,8 @@ def get_date_type(calendar, use_cftime=True): class BaseCFTimeOffset: - _freq: ClassVar[str] - _day_option: ClassVar[DayOption] + _freq: ClassVar[str | None] = None + _day_option: ClassVar[DayOption | None] = None n: int def __init__(self, n: int = 1) -> None: @@ -124,7 +124,7 @@ def __init__(self, n: int = 1) -> None: ) self.n = n - def rule_code(self) -> str: + def rule_code(self) -> str | None: return self._freq def __eq__(self, other: object) -> bool: diff --git a/xarray/tests/test_cftime_offsets.py b/xarray/tests/test_cftime_offsets.py index eabb7d2f4d6..78aa49c7f83 100644 --- a/xarray/tests/test_cftime_offsets.py +++ b/xarray/tests/test_cftime_offsets.py @@ -511,7 +511,7 @@ def test_Microsecond_multiplied_float_error(): ], ids=_id_func, ) -def test_neg(offset, expected): +def test_neg(offset: BaseCFTimeOffset, expected: BaseCFTimeOffset) -> None: assert -offset == expected From 602c3cf3caa94883e1e4fa6957770947ee5595fa Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 12 Jul 2024 23:34:03 +0200 Subject: [PATCH 19/23] fix some checks --- xarray/core/resample_cftime.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/xarray/core/resample_cftime.py b/xarray/core/resample_cftime.py index fcdeb6d8021..a048e85b4d4 100644 --- a/xarray/core/resample_cftime.py +++ b/xarray/core/resample_cftime.py @@ -124,10 +124,10 @@ def __init__( if offset is not None: try: self.offset = _convert_offset_to_timedelta(offset) - except (AttributeError, TypeError) as error: + except (ValueError, TypeError) as error: raise ValueError( f"offset must be a datetime.timedelta object or an offset string " - f"that can be converted to a timedelta. Got {type(offset)} instead." + f"that can be converted to a timedelta. Got {type(offset)} instead." ) from error else: self.offset = None @@ -505,8 +505,8 @@ def _convert_offset_to_timedelta( return offset if isinstance(offset, (str, Tick)): timedelta_cftime_offset = to_offset(offset) - assert isinstance(timedelta_cftime_offset, Tick) - return timedelta_cftime_offset.as_timedelta() + if isinstance(timedelta_cftime_offset, Tick): + return timedelta_cftime_offset.as_timedelta() raise TypeError(f"Expected timedelta, str or Tick, got {type(offset)}") From 07a52aef59eb204916384bdeb090acc4f7f41d1e Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sat, 13 Jul 2024 23:10:14 +0200 Subject: [PATCH 20/23] fix a broken test --- xarray/tests/test_concat.py | 2 +- xarray/tests/test_dataset.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 0c570de3b52..f6d17d6a08a 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -515,7 +515,7 @@ def test_concat_merge_variables_present_in_some_datasets(self, data) -> None: assert_identical(expected, actual) @pytest.mark.parametrize("data", [False], indirect=["data"]) - def test_concat_2(self, data) -> None: + def test_concat_2(self, data: Dataset) -> None: dim = "dim2" datasets = [g.squeeze(dim) for _, g in data.groupby(dim, squeeze=False)] concat_over = [k for k, v in data.coords.items() if dim in v.dims and k != dim] diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 1ba45616ab0..4db005ca3fb 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -581,7 +581,8 @@ def test_constructor_pandas_single(self) -> None: pandas_obj = a.to_pandas() ds_based_on_pandas = Dataset(pandas_obj) # type: ignore # TODO: improve typing of __init__ for dim in ds_based_on_pandas.data_vars: - assert_array_equal(ds_based_on_pandas[dim], pandas_obj[str(dim)]) + assert isinstance(dim, int) + assert_array_equal(ds_based_on_pandas[dim], pandas_obj[dim]) def test_constructor_compat(self) -> None: data = {"x": DataArray(0, coords={"y": 1}), "y": ("z", [1, 1, 1])} From a5eaa82d9ffe69b065cc7b45bddec873c65cc99c Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 15 Jul 2024 21:10:28 +0200 Subject: [PATCH 21/23] improve typing of test_concat --- xarray/tests/test_concat.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index f6d17d6a08a..8b2a7ec5d28 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -1,7 +1,7 @@ from __future__ import annotations from copy import deepcopy -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Literal import numpy as np import pandas as pd @@ -474,7 +474,7 @@ def data(self, request) -> Dataset: "dim3" ) - def rectify_dim_order(self, data, dataset) -> Dataset: + def rectify_dim_order(self, data: Dataset, dataset) -> Dataset: # return a new dataset with all variable dimensions transposed into # the order in which they are found in `data` return Dataset( @@ -487,11 +487,13 @@ def rectify_dim_order(self, data, dataset) -> Dataset: @pytest.mark.parametrize( "dim,data", [["dim1", True], ["dim2", False]], indirect=["data"] ) - def test_concat_simple(self, data, dim, coords) -> None: + def test_concat_simple(self, data: Dataset, dim, coords) -> None: datasets = [g for _, g in data.groupby(dim, squeeze=False)] assert_identical(data, concat(datasets, dim, coords=coords)) - def test_concat_merge_variables_present_in_some_datasets(self, data) -> None: + def test_concat_merge_variables_present_in_some_datasets( + self, data: Dataset + ) -> None: # coordinates present in some datasets but not others ds1 = Dataset(data_vars={"a": ("y", [0.1])}, coords={"x": 0.1}) ds2 = Dataset(data_vars={"a": ("y", [0.2])}, coords={"z": 0.2}) @@ -524,7 +526,9 @@ def test_concat_2(self, data: Dataset) -> None: @pytest.mark.parametrize("coords", ["different", "minimal", "all"]) @pytest.mark.parametrize("dim", ["dim1", "dim2"]) - def test_concat_coords_kwarg(self, data, dim, coords) -> None: + def test_concat_coords_kwarg( + self, data: Dataset, dim: str, coords: Literal["all", "minimal", "different"] + ) -> None: data = data.copy(deep=True) # make sure the coords argument behaves as expected data.coords["extra"] = ("dim4", np.arange(3)) @@ -538,7 +542,7 @@ def test_concat_coords_kwarg(self, data, dim, coords) -> None: else: assert_equal(data["extra"], actual["extra"]) - def test_concat(self, data) -> None: + def test_concat(self, data: Dataset) -> None: split_data = [ data.isel(dim1=slice(3)), data.isel(dim1=3), @@ -546,7 +550,7 @@ def test_concat(self, data) -> None: ] assert_identical(data, concat(split_data, "dim1")) - def test_concat_dim_precedence(self, data) -> None: + def test_concat_dim_precedence(self, data: Dataset) -> None: # verify that the dim argument takes precedence over # concatenating dataset variables of the same name dim = (2 * data["dim1"]).rename("dim1") From 2f12f888bcc3245cc7f4e2656f6ddf5a09f2012a Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 15 Jul 2024 21:20:12 +0200 Subject: [PATCH 22/23] fix broken concat --- xarray/core/concat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 687278083a8..15292bdb34b 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -598,7 +598,7 @@ def get_indexes(name): for ds in datasets: if name in ds._indexes: yield ds._indexes[name] - elif name == dim: + elif name == dim_name: var = ds._variables[name] if not var.dims: data = var.set_dims(dim_name).values From 4bfec5e7b49e3c2763635c84e8b9478a6643da7c Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Mon, 15 Jul 2024 22:25:16 +0200 Subject: [PATCH 23/23] add whats-new --- doc/whats-new.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 6a8e898c93c..e7ff97edaba 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -79,6 +79,8 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ +- Enable typing checks of pandas (:pull:`9213`). + By `Michael Niklas `_. .. _whats-new.2024.06.0: