Merge branch 'main' into grouper-public-api
dcherian authored Jul 8, 2024
2 parents 64f78cd + 04b38a0 commit 58de38f
Showing 26 changed files with 873 additions and 341 deletions.
113 changes: 112 additions & 1 deletion asv_bench/benchmarks/dataset_io.py
@@ -7,6 +7,8 @@
import pandas as pd

import xarray as xr
from xarray.backends.api import open_datatree
from xarray.core.datatree import DataTree

from . import _skip_slow, parameterized, randint, randn, requires_dask

@@ -16,7 +18,6 @@
except ImportError:
pass


os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

_ENGINES = tuple(xr.backends.list_engines().keys() - {"store"})
@@ -469,6 +470,116 @@ def create_delayed_write():
return ds.to_netcdf("file.nc", engine="netcdf4", compute=False)


class IONestedDataTree:
"""
A few examples that benchmark reading/writing a heavily nested netCDF datatree with
xarray
"""

timeout = 300.0
repeat = 1
number = 5

def make_datatree(self, nchildren=10):
# multiple Dataset
self.ds = xr.Dataset()
self.nt = 1000
self.nx = 90
self.ny = 45
self.nchildren = nchildren

self.block_chunks = {
"time": self.nt / 4,
"lon": self.nx / 3,
"lat": self.ny / 3,
}

self.time_chunks = {"time": int(self.nt / 36)}

times = pd.date_range("1970-01-01", periods=self.nt, freq="D")
lons = xr.DataArray(
np.linspace(0, 360, self.nx),
dims=("lon",),
attrs={"units": "degrees east", "long_name": "longitude"},
)
lats = xr.DataArray(
np.linspace(-90, 90, self.ny),
dims=("lat",),
attrs={"units": "degrees north", "long_name": "latitude"},
)
self.ds["foo"] = xr.DataArray(
randn((self.nt, self.nx, self.ny), frac_nan=0.2),
coords={"lon": lons, "lat": lats, "time": times},
dims=("time", "lon", "lat"),
name="foo",
attrs={"units": "foo units", "description": "a description"},
)
self.ds["bar"] = xr.DataArray(
randn((self.nt, self.nx, self.ny), frac_nan=0.2),
coords={"lon": lons, "lat": lats, "time": times},
dims=("time", "lon", "lat"),
name="bar",
attrs={"units": "bar units", "description": "a description"},
)
self.ds["baz"] = xr.DataArray(
randn((self.nx, self.ny), frac_nan=0.2).astype(np.float32),
coords={"lon": lons, "lat": lats},
dims=("lon", "lat"),
name="baz",
attrs={"units": "baz units", "description": "a description"},
)

self.ds.attrs = {"history": "created for xarray benchmarking"}

self.oinds = {
"time": randint(0, self.nt, 120),
"lon": randint(0, self.nx, 20),
"lat": randint(0, self.ny, 10),
}
self.vinds = {
"time": xr.DataArray(randint(0, self.nt, 120), dims="x"),
"lon": xr.DataArray(randint(0, self.nx, 120), dims="x"),
"lat": slice(3, 20),
}
root = {f"group_{group}": self.ds for group in range(self.nchildren)}
nested_tree1 = {
f"group_{group}/subgroup_1": xr.Dataset() for group in range(self.nchildren)
}
nested_tree2 = {
f"group_{group}/subgroup_2": xr.DataArray(np.arange(1, 10)).to_dataset(
name="a"
)
for group in range(self.nchildren)
}
nested_tree3 = {
f"group_{group}/subgroup_2/sub-subgroup_1": self.ds
for group in range(self.nchildren)
}
dtree = root | nested_tree1 | nested_tree2 | nested_tree3
self.dtree = DataTree.from_dict(dtree)
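
Editor's note (not part of the diff): ``DataTree.from_dict`` treats "/"-separated keys as group paths and creates intermediate groups implicitly, so the mapping above yields ``nchildren`` top-level groups, each carrying nested subgroups. A minimal, hypothetical sketch:

import numpy as np
import xarray as xr
from xarray.core.datatree import DataTree

# Keys are paths; "group_0/subgroup_2/sub-subgroup_1" creates the
# intermediate "subgroup_2" group automatically.
ds = xr.Dataset({"a": ("x", np.arange(3))})
tree = DataTree.from_dict(
    {
        "group_0": ds,
        "group_0/subgroup_1": xr.Dataset(),
        "group_0/subgroup_2/sub-subgroup_1": ds,
    }
)
print(tree)  # renders the nested group hierarchy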


class IOReadDataTreeNetCDF4(IONestedDataTree):
def setup(self):
# TODO: Lazily skipped in CI as it is very demanding and slow.
# Improve times and remove errors.
_skip_slow()

requires_dask()

self.make_datatree()
self.format = "NETCDF4"
self.filepath = "datatree.nc4.nc"
dtree = self.dtree
dtree.to_netcdf(filepath=self.filepath)

def time_load_datatree_netcdf4(self):
open_datatree(self.filepath, engine="netcdf4").load()

def time_open_datatree_netcdf4(self):
open_datatree(self.filepath, engine="netcdf4")
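
Editor's note (not from this commit): the two benchmarks separate the cost of building the lazy datatree from the cost of reading every array. A hedged sketch of the distinction, using the file written in ``setup()``:

from xarray.backends.api import open_datatree

dt = open_datatree("datatree.nc4.nc", engine="netcdf4")  # lazy: builds the tree, reads metadata only
dt.load()  # eager: loads every variable in every group into memory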


class IOWriteNetCDFDask:
timeout = 60
repeat = 1
14 changes: 13 additions & 1 deletion doc/whats-new.rst
@@ -24,6 +24,8 @@ New Features
~~~~~~~~~~~~
- Allow chunking for arrays with duplicated dimension names (:issue:`8759`, :pull:`9099`).
By `Martin Raspaud <https://github.com/mraspaud>`_.
- Extract the source url from fsspec objects (:issue:`9142`, :pull:`8923`).
By `Justus Magin <https://github.com/keewis>`_.

Breaking changes
~~~~~~~~~~~~~~~~
@@ -35,11 +37,21 @@ Deprecations

Bug fixes
~~~~~~~~~
- Don't convert custom indexes to ``pandas`` indexes when computing a diff (:pull:`9157`).
By `Justus Magin <https://github.com/keewis>`_.
- Make :py:func:`testing.assert_allclose` work with numpy 2.0 (:issue:`9165`, :pull:`9166`).
By `Pontus Lurcock <https://github.com/pont-us>`_.
- Allow diffing objects with array attributes on variables (:issue:`9153`, :pull:`9169`).
By `Justus Magin <https://github.com/keewis>`_.
- Promote floating-point numeric datetimes before decoding (:issue:`9179`, :pull:`9182`).
By `Justus Magin <https://github.com/keewis>`_.
- Fix static typing of tolerance arguments by allowing ``str`` type (:issue:`8892`, :pull:`9194`).
By `Michael Niklas <https://github.com/headtr1ck>`_.
- Dark themes are now properly detected for ``html[data-theme=dark]``-tags (:pull:`9200`).
By `Dieter Werthmüller <https://github.com/prisae>`_.
- Reductions no longer fail for ``np.complex_`` dtype arrays when numbagg is
installed.
By `Maximilian Roos <https://github.com/max-sixty>`_.

Documentation
~~~~~~~~~~~~~
39 changes: 23 additions & 16 deletions properties/test_pandas_roundtrip.py
@@ -9,7 +9,6 @@
import pytest

import xarray as xr
from xarray.tests import has_pandas_3

pytest.importorskip("hypothesis")
import hypothesis.extra.numpy as npst # isort:skip
@@ -25,22 +24,34 @@

numeric_series = numeric_dtypes.flatmap(lambda dt: pdst.series(dtype=dt))


@st.composite
def dataframe_strategy(draw):
tz = draw(st.timezones())
dtype = pd.DatetimeTZDtype(unit="ns", tz=tz)

datetimes = st.datetimes(
min_value=pd.Timestamp("1677-09-21T00:12:43.145224193"),
max_value=pd.Timestamp("2262-04-11T23:47:16.854775807"),
timezones=st.just(tz),
)

df = pdst.data_frames(
[
pdst.column("datetime_col", elements=datetimes),
pdst.column("other_col", elements=st.integers()),
],
index=pdst.range_indexes(min_size=1, max_size=10),
)
return draw(df).astype({"datetime_col": dtype})


an_array = npst.arrays(
dtype=numeric_dtypes,
shape=npst.array_shapes(max_dims=2), # can only convert 1D/2D to pandas
)


datetime_with_tz_strategy = st.datetimes(timezones=st.timezones())
dataframe_strategy = pdst.data_frames(
[
pdst.column("datetime_col", elements=datetime_with_tz_strategy),
pdst.column("other_col", elements=st.integers()),
],
index=pdst.range_indexes(min_size=1, max_size=10),
)
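
Side by side, the rewrite replaces the module-level strategy (which drew timezone-aware datetimes with no bounds) with a composite that draws a single timezone per dataframe and clamps values to what ``DatetimeTZDtype(unit="ns")`` can represent. A hedged check (not from this commit) that the bounds match pandas' nanosecond range:

import pandas as pd

# The strategy's min/max values are exactly pandas' representable
# nanosecond-resolution Timestamp range.
assert pd.Timestamp.min == pd.Timestamp("1677-09-21T00:12:43.145224193")
assert pd.Timestamp.max == pd.Timestamp("2262-04-11T23:47:16.854775807")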


@st.composite
def datasets_1d_vars(draw) -> xr.Dataset:
"""Generate datasets with only 1D variables
@@ -111,11 +122,7 @@ def test_roundtrip_pandas_dataframe(df) -> None:
xr.testing.assert_identical(arr, roundtripped.to_xarray())


@pytest.mark.skipif(
has_pandas_3,
reason="fails to roundtrip on pandas 3 (see https://github.com/pydata/xarray/issues/9098)",
)
@given(df=dataframe_strategy)
@given(df=dataframe_strategy())
def test_roundtrip_pandas_dataframe_datetime(df) -> None:
# Need to name the indexes, otherwise Xarray names them 'dim_0', 'dim_1'.
df.index.name = "rows"
17 changes: 17 additions & 0 deletions properties/test_properties.py
@@ -0,0 +1,17 @@
import pytest

pytest.importorskip("hypothesis")

from hypothesis import given

import xarray as xr
import xarray.testing.strategies as xrst


@given(attrs=xrst.simple_attrs)
def test_assert_identical(attrs):
v = xr.Variable(dims=(), data=0, attrs=attrs)
xr.testing.assert_identical(v, v.copy(deep=True))

ds = xr.Dataset(attrs=attrs)
xr.testing.assert_identical(ds, ds.copy(deep=True))
7 changes: 5 additions & 2 deletions xarray/backends/api.py
@@ -382,8 +382,11 @@ def _dataset_from_backend_dataset(
ds.set_close(backend_ds._close)

# Ensure source filename always stored in dataset object
if "source" not in ds.encoding and isinstance(filename_or_obj, (str, os.PathLike)):
ds.encoding["source"] = _normalize_path(filename_or_obj)
if "source" not in ds.encoding:
path = getattr(filename_or_obj, "path", filename_or_obj)

if isinstance(path, (str, os.PathLike)):
ds.encoding["source"] = _normalize_path(path)

return ds
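
Hedged sketch (not from this commit; assumes fsspec is installed): file objects opened through fsspec carry their original location in a ``path`` attribute, which is what the ``getattr`` above recovers so ``ds.encoding["source"]`` can be populated even when a file object rather than a path string was passed in:

import fsspec

of = fsspec.open("memory://example.nc")  # an OpenFile wrapper
print(of.path)  # the source location, e.g. "/example.nc" on the memory filesystem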

2 changes: 1 addition & 1 deletion xarray/backends/zarr.py
@@ -446,7 +446,7 @@ def open_store(
stacklevel=stacklevel,
zarr_version=zarr_version,
)
group_paths = [str(group / node[1:]) for node in _iter_zarr_groups(zarr_group)]
group_paths = [node for node in _iter_zarr_groups(zarr_group, parent=group)]
return {
group: cls(
zarr_group.get(group),
6 changes: 3 additions & 3 deletions xarray/core/alignment.py
@@ -137,7 +137,7 @@ def __init__(
exclude_dims: str | Iterable[Hashable] = frozenset(),
exclude_vars: Iterable[Hashable] = frozenset(),
method: str | None = None,
tolerance: int | float | Iterable[int | float] | None = None,
tolerance: float | Iterable[float] | str | None = None,
copy: bool = True,
fill_value: Any = dtypes.NA,
sparse: bool = False,
@@ -965,7 +965,7 @@ def reindex(
obj: T_Alignable,
indexers: Mapping[Any, Any],
method: str | None = None,
tolerance: int | float | Iterable[int | float] | None = None,
tolerance: float | Iterable[float] | str | None = None,
copy: bool = True,
fill_value: Any = dtypes.NA,
sparse: bool = False,
@@ -1004,7 +1004,7 @@ def reindex_like(
obj: T_Alignable,
other: Dataset | DataArray,
method: str | None = None,
tolerance: int | float | Iterable[int | float] | None = None,
tolerance: float | Iterable[float] | str | None = None,
copy: bool = True,
fill_value: Any = dtypes.NA,
) -> T_Alignable:
8 changes: 4 additions & 4 deletions xarray/core/dataarray.py
@@ -1910,7 +1910,7 @@ def reindex_like(
other: T_DataArrayOrSet,
*,
method: ReindexMethodOptions = None,
tolerance: int | float | Iterable[int | float] | None = None,
tolerance: float | Iterable[float] | str | None = None,
copy: bool = True,
fill_value=dtypes.NA,
) -> Self:
@@ -1937,7 +1937,7 @@ def reindex_like(
- backfill / bfill: propagate next valid index value backward
- nearest: use nearest valid index value
tolerance : optional
tolerance : float | Iterable[float] | str | None, default: None
Maximum distance between original and new labels for inexact
matches. The values of the index at the matching locations must
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
@@ -2097,7 +2097,7 @@ def reindex(
indexers: Mapping[Any, Any] | None = None,
*,
method: ReindexMethodOptions = None,
tolerance: float | Iterable[float] | None = None,
tolerance: float | Iterable[float] | str | None = None,
copy: bool = True,
fill_value=dtypes.NA,
**indexers_kwargs: Any,
@@ -2127,7 +2127,7 @@ def reindex(
- backfill / bfill: propagate next valid index value backward
- nearest: use nearest valid index value
tolerance : float | Iterable[float] | None, default: None
tolerance : float | Iterable[float] | str | None, default: None
Maximum distance between original and new labels for inexact
matches. The values of the index at the matching locations must
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
8 changes: 4 additions & 4 deletions xarray/core/dataset.py
@@ -3500,7 +3500,7 @@ def reindex_like(
self,
other: T_Xarray,
method: ReindexMethodOptions = None,
tolerance: int | float | Iterable[int | float] | None = None,
tolerance: float | Iterable[float] | str | None = None,
copy: bool = True,
fill_value: Any = xrdtypes.NA,
) -> Self:
@@ -3527,7 +3527,7 @@ def reindex_like(
- "backfill" / "bfill": propagate next valid index value backward
- "nearest": use nearest valid index value
tolerance : optional
tolerance : float | Iterable[float] | str | None, default: None
Maximum distance between original and new labels for inexact
matches. The values of the index at the matching locations must
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
@@ -3570,7 +3570,7 @@ def reindex(
self,
indexers: Mapping[Any, Any] | None = None,
method: ReindexMethodOptions = None,
tolerance: int | float | Iterable[int | float] | None = None,
tolerance: float | Iterable[float] | str | None = None,
copy: bool = True,
fill_value: Any = xrdtypes.NA,
**indexers_kwargs: Any,
@@ -3595,7 +3595,7 @@ def reindex(
- "backfill" / "bfill": propagate next valid index value backward
- "nearest": use nearest valid index value
tolerance : optional
tolerance : float | Iterable[float] | str | None, default: None
Maximum distance between original and new labels for inexact
matches. The values of the index at the matching locations must
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
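
The widened annotation tracks what pandas already accepts at runtime: for datetime indexes, ``tolerance`` may be a timedelta-like string. A hedged example (not from this commit):

import numpy as np
import pandas as pd
import xarray as xr

da = xr.DataArray(
    np.arange(4),
    coords={"time": pd.date_range("2024-01-01", periods=4, freq="D")},
    dims="time",
)
# Targets are offset by one hour; a string tolerance is parsed like a Timedelta.
target = pd.date_range("2024-01-01 01:00", periods=4, freq="D")
da.reindex(time=target, method="nearest", tolerance="2h")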
(Diff for the remaining 17 changed files not shown.)
