Skip to content

Commit

Permalink
Merge branch 'main' into numpy2-netcdf4
Browse files Browse the repository at this point in the history
  • Loading branch information
keewis authored Jul 2, 2024
2 parents 1f77517 + 6c2d8c3 commit 0176688
Show file tree
Hide file tree
Showing 14 changed files with 231 additions and 46 deletions.
113 changes: 112 additions & 1 deletion asv_bench/benchmarks/dataset_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import pandas as pd

import xarray as xr
from xarray.backends.api import open_datatree
from xarray.core.datatree import DataTree

from . import _skip_slow, parameterized, randint, randn, requires_dask

Expand All @@ -16,7 +18,6 @@
except ImportError:
pass


os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

_ENGINES = tuple(xr.backends.list_engines().keys() - {"store"})
Expand Down Expand Up @@ -469,6 +470,116 @@ def create_delayed_write():
return ds.to_netcdf("file.nc", engine="netcdf4", compute=False)


class IONestedDataTree:
"""
A few examples that benchmark reading/writing a heavily nested netCDF datatree with
xarray
"""

timeout = 300.0
repeat = 1
number = 5

def make_datatree(self, nchildren=10):
# multiple Dataset
self.ds = xr.Dataset()
self.nt = 1000
self.nx = 90
self.ny = 45
self.nchildren = nchildren

self.block_chunks = {
"time": self.nt / 4,
"lon": self.nx / 3,
"lat": self.ny / 3,
}

self.time_chunks = {"time": int(self.nt / 36)}

times = pd.date_range("1970-01-01", periods=self.nt, freq="D")
lons = xr.DataArray(
np.linspace(0, 360, self.nx),
dims=("lon",),
attrs={"units": "degrees east", "long_name": "longitude"},
)
lats = xr.DataArray(
np.linspace(-90, 90, self.ny),
dims=("lat",),
attrs={"units": "degrees north", "long_name": "latitude"},
)
self.ds["foo"] = xr.DataArray(
randn((self.nt, self.nx, self.ny), frac_nan=0.2),
coords={"lon": lons, "lat": lats, "time": times},
dims=("time", "lon", "lat"),
name="foo",
attrs={"units": "foo units", "description": "a description"},
)
self.ds["bar"] = xr.DataArray(
randn((self.nt, self.nx, self.ny), frac_nan=0.2),
coords={"lon": lons, "lat": lats, "time": times},
dims=("time", "lon", "lat"),
name="bar",
attrs={"units": "bar units", "description": "a description"},
)
self.ds["baz"] = xr.DataArray(
randn((self.nx, self.ny), frac_nan=0.2).astype(np.float32),
coords={"lon": lons, "lat": lats},
dims=("lon", "lat"),
name="baz",
attrs={"units": "baz units", "description": "a description"},
)

self.ds.attrs = {"history": "created for xarray benchmarking"}

self.oinds = {
"time": randint(0, self.nt, 120),
"lon": randint(0, self.nx, 20),
"lat": randint(0, self.ny, 10),
}
self.vinds = {
"time": xr.DataArray(randint(0, self.nt, 120), dims="x"),
"lon": xr.DataArray(randint(0, self.nx, 120), dims="x"),
"lat": slice(3, 20),
}
root = {f"group_{group}": self.ds for group in range(self.nchildren)}
nested_tree1 = {
f"group_{group}/subgroup_1": xr.Dataset() for group in range(self.nchildren)
}
nested_tree2 = {
f"group_{group}/subgroup_2": xr.DataArray(np.arange(1, 10)).to_dataset(
name="a"
)
for group in range(self.nchildren)
}
nested_tree3 = {
f"group_{group}/subgroup_2/sub-subgroup_1": self.ds
for group in range(self.nchildren)
}
dtree = root | nested_tree1 | nested_tree2 | nested_tree3
self.dtree = DataTree.from_dict(dtree)


class IOReadDataTreeNetCDF4(IONestedDataTree):
def setup(self):
# TODO: Lazily skipped in CI as it is very demanding and slow.
# Improve times and remove errors.
_skip_slow()

requires_dask()

self.make_datatree()
self.format = "NETCDF4"
self.filepath = "datatree.nc4.nc"
dtree = self.dtree
dtree.to_netcdf(filepath=self.filepath)

def time_load_datatree_netcdf4(self):
open_datatree(self.filepath, engine="netcdf4").load()

def time_open_datatree_netcdf4(self):
open_datatree(self.filepath, engine="netcdf4")


class IOWriteNetCDFDask:
timeout = 60
repeat = 1
Expand Down
7 changes: 6 additions & 1 deletion doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,16 @@ Bug fixes
~~~~~~~~~
- Make :py:func:`testing.assert_allclose` work with numpy 2.0 (:issue:`9165`, :pull:`9166`).
By `Pontus Lurcock <https://github.com/pont-us>`_.
- Allow diffing objects with array attributes on variables (:issue:`9153`, :pull:`9169`).
By `Justus Magin <https://github.com/keewis>`_.
- ``numpy>=2`` compatibility in the ``netcdf4`` backend (:pull:`9136`).
By `Justus Magin <https://github.com/keewis>`_ and `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
- Promote floating-point numeric datetimes before decoding (:issue:`9179`, :pull:`9182`).
By `Justus Magin <https://github.com/keewis>`_.

- Fiy static typing of tolerance arguments by allowing `str` type (:issue:`8892`, :pull:`9194`).
By `Michael Niklas <https://github.com/headtr1ck>`_.
- Dark themes are now properly detected for ``html[data-theme=dark]``-tags (:pull:`9200`).
By `Dieter Werthmüller <https://github.com/prisae>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
39 changes: 23 additions & 16 deletions properties/test_pandas_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pytest

import xarray as xr
from xarray.tests import has_pandas_3

pytest.importorskip("hypothesis")
import hypothesis.extra.numpy as npst # isort:skip
Expand All @@ -25,22 +24,34 @@

numeric_series = numeric_dtypes.flatmap(lambda dt: pdst.series(dtype=dt))


@st.composite
def dataframe_strategy(draw):
tz = draw(st.timezones())
dtype = pd.DatetimeTZDtype(unit="ns", tz=tz)

datetimes = st.datetimes(
min_value=pd.Timestamp("1677-09-21T00:12:43.145224193"),
max_value=pd.Timestamp("2262-04-11T23:47:16.854775807"),
timezones=st.just(tz),
)

df = pdst.data_frames(
[
pdst.column("datetime_col", elements=datetimes),
pdst.column("other_col", elements=st.integers()),
],
index=pdst.range_indexes(min_size=1, max_size=10),
)
return draw(df).astype({"datetime_col": dtype})


an_array = npst.arrays(
dtype=numeric_dtypes,
shape=npst.array_shapes(max_dims=2), # can only convert 1D/2D to pandas
)


datetime_with_tz_strategy = st.datetimes(timezones=st.timezones())
dataframe_strategy = pdst.data_frames(
[
pdst.column("datetime_col", elements=datetime_with_tz_strategy),
pdst.column("other_col", elements=st.integers()),
],
index=pdst.range_indexes(min_size=1, max_size=10),
)


@st.composite
def datasets_1d_vars(draw) -> xr.Dataset:
"""Generate datasets with only 1D variables
Expand Down Expand Up @@ -111,11 +122,7 @@ def test_roundtrip_pandas_dataframe(df) -> None:
xr.testing.assert_identical(arr, roundtripped.to_xarray())


@pytest.mark.skipif(
has_pandas_3,
reason="fails to roundtrip on pandas 3 (see https://github.com/pydata/xarray/issues/9098)",
)
@given(df=dataframe_strategy)
@given(df=dataframe_strategy())
def test_roundtrip_pandas_dataframe_datetime(df) -> None:
# Need to name the indexes, otherwise Xarray names them 'dim_0', 'dim_1'.
df.index.name = "rows"
Expand Down
17 changes: 17 additions & 0 deletions properties/test_properties.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest

pytest.importorskip("hypothesis")

from hypothesis import given

import xarray as xr
import xarray.testing.strategies as xrst


@given(attrs=xrst.simple_attrs)
def test_assert_identical(attrs):
v = xr.Variable(dims=(), data=0, attrs=attrs)
xr.testing.assert_identical(v, v.copy(deep=True))

ds = xr.Dataset(attrs=attrs)
xr.testing.assert_identical(ds, ds.copy(deep=True))
2 changes: 1 addition & 1 deletion xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def open_store(
stacklevel=stacklevel,
zarr_version=zarr_version,
)
group_paths = [str(group / node[1:]) for node in _iter_zarr_groups(zarr_group)]
group_paths = [node for node in _iter_zarr_groups(zarr_group, parent=group)]
return {
group: cls(
zarr_group.get(group),
Expand Down
6 changes: 3 additions & 3 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def __init__(
exclude_dims: str | Iterable[Hashable] = frozenset(),
exclude_vars: Iterable[Hashable] = frozenset(),
method: str | None = None,
tolerance: int | float | Iterable[int | float] | None = None,
tolerance: float | Iterable[float] | str | None = None,
copy: bool = True,
fill_value: Any = dtypes.NA,
sparse: bool = False,
Expand Down Expand Up @@ -965,7 +965,7 @@ def reindex(
obj: T_Alignable,
indexers: Mapping[Any, Any],
method: str | None = None,
tolerance: int | float | Iterable[int | float] | None = None,
tolerance: float | Iterable[float] | str | None = None,
copy: bool = True,
fill_value: Any = dtypes.NA,
sparse: bool = False,
Expand Down Expand Up @@ -1004,7 +1004,7 @@ def reindex_like(
obj: T_Alignable,
other: Dataset | DataArray,
method: str | None = None,
tolerance: int | float | Iterable[int | float] | None = None,
tolerance: float | Iterable[float] | str | None = None,
copy: bool = True,
fill_value: Any = dtypes.NA,
) -> T_Alignable:
Expand Down
8 changes: 4 additions & 4 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1909,7 +1909,7 @@ def reindex_like(
other: T_DataArrayOrSet,
*,
method: ReindexMethodOptions = None,
tolerance: int | float | Iterable[int | float] | None = None,
tolerance: float | Iterable[float] | str | None = None,
copy: bool = True,
fill_value=dtypes.NA,
) -> Self:
Expand All @@ -1936,7 +1936,7 @@ def reindex_like(
- backfill / bfill: propagate next valid index value backward
- nearest: use nearest valid index value
tolerance : optional
tolerance : float | Iterable[float] | str | None, default: None
Maximum distance between original and new labels for inexact
matches. The values of the index at the matching locations must
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
Expand Down Expand Up @@ -2096,7 +2096,7 @@ def reindex(
indexers: Mapping[Any, Any] | None = None,
*,
method: ReindexMethodOptions = None,
tolerance: float | Iterable[float] | None = None,
tolerance: float | Iterable[float] | str | None = None,
copy: bool = True,
fill_value=dtypes.NA,
**indexers_kwargs: Any,
Expand Down Expand Up @@ -2126,7 +2126,7 @@ def reindex(
- backfill / bfill: propagate next valid index value backward
- nearest: use nearest valid index value
tolerance : float | Iterable[float] | None, default: None
tolerance : float | Iterable[float] | str | None, default: None
Maximum distance between original and new labels for inexact
matches. The values of the index at the matching locations must
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
Expand Down
8 changes: 4 additions & 4 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3499,7 +3499,7 @@ def reindex_like(
self,
other: T_Xarray,
method: ReindexMethodOptions = None,
tolerance: int | float | Iterable[int | float] | None = None,
tolerance: float | Iterable[float] | str | None = None,
copy: bool = True,
fill_value: Any = xrdtypes.NA,
) -> Self:
Expand All @@ -3526,7 +3526,7 @@ def reindex_like(
- "backfill" / "bfill": propagate next valid index value backward
- "nearest": use nearest valid index value
tolerance : optional
tolerance : float | Iterable[float] | str | None, default: None
Maximum distance between original and new labels for inexact
matches. The values of the index at the matching locations must
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
Expand Down Expand Up @@ -3569,7 +3569,7 @@ def reindex(
self,
indexers: Mapping[Any, Any] | None = None,
method: ReindexMethodOptions = None,
tolerance: int | float | Iterable[int | float] | None = None,
tolerance: float | Iterable[float] | str | None = None,
copy: bool = True,
fill_value: Any = xrdtypes.NA,
**indexers_kwargs: Any,
Expand All @@ -3594,7 +3594,7 @@ def reindex(
- "backfill" / "bfill": propagate next valid index value backward
- "nearest": use nearest valid index value
tolerance : optional
tolerance : float | Iterable[float] | str | None, default: None
Maximum distance between original and new labels for inexact
matches. The values of the index at the matching locations must
satisfy the equation ``abs(index[indexer] - target) <= tolerance``.
Expand Down
18 changes: 12 additions & 6 deletions xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,12 @@ def _diff_mapping_repr(
a_indexes=None,
b_indexes=None,
):
def compare_attr(a, b):
if is_duck_array(a) or is_duck_array(b):
return array_equiv(a, b)
else:
return a == b

def extra_items_repr(extra_keys, mapping, ab_side, kwargs):
extra_repr = [
summarizer(k, mapping[k], col_width, **kwargs[k]) for k in extra_keys
Expand Down Expand Up @@ -801,11 +807,7 @@ def extra_items_repr(extra_keys, mapping, ab_side, kwargs):
is_variable = True
except AttributeError:
# compare attribute value
if is_duck_array(a_mapping[k]) or is_duck_array(b_mapping[k]):
compatible = array_equiv(a_mapping[k], b_mapping[k])
else:
compatible = a_mapping[k] == b_mapping[k]

compatible = compare_attr(a_mapping[k], b_mapping[k])
is_variable = False

if not compatible:
Expand All @@ -821,7 +823,11 @@ def extra_items_repr(extra_keys, mapping, ab_side, kwargs):

attrs_to_print = set(a_attrs) ^ set(b_attrs)
attrs_to_print.update(
{k for k in set(a_attrs) & set(b_attrs) if a_attrs[k] != b_attrs[k]}
{
k
for k in set(a_attrs) & set(b_attrs)
if not compare_attr(a_attrs[k], b_attrs[k])
}
)
for m in (a_mapping, b_mapping):
attr_s = "\n".join(
Expand Down
Loading

0 comments on commit 0176688

Please sign in to comment.