Skip to content

Commit

Permalink
New alignment option: join='strict'
Browse files Browse the repository at this point in the history
  • Loading branch information
etienneschalk committed Feb 3, 2024
1 parent c9ba2be commit bbe7d05
Show file tree
Hide file tree
Showing 10 changed files with 184 additions and 17 deletions.
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ New Features
- Allow negative frequency strings (e.g. ``"-1YE"``). These strings are for example used
in :py:func:`date_range`, and :py:func:`cftime_range` (:pull:`8651`).
By `Mathias Hauser <https://github.com/mathause>`_.
- Added a ``join="strict"`` mode for ``Aligner.align`` and related classes.
(:issue:`7132`, :issue:`8230`).
By `Etienne Schalk <https://github.com/etienneschalk>`_.

- Add :py:meth:`NamedArray.expand_dims`, :py:meth:`NamedArray.permute_dims` and :py:meth:`NamedArray.broadcast_to`
(:pull:`8380`) By `Anderson Banihirwe <https://github.com/andersy005>`_.

Expand Down
19 changes: 11 additions & 8 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,13 +791,15 @@ def open_dataarray(
def open_mfdataset(
paths: str | NestedSequence[str | os.PathLike],
chunks: T_Chunks | None = None,
concat_dim: str
| DataArray
| Index
| Sequence[str]
| Sequence[DataArray]
| Sequence[Index]
| None = None,
concat_dim: (
str
| DataArray
| Index
| Sequence[str]
| Sequence[DataArray]
| Sequence[Index]
| None
) = None,
compat: CompatOptions = "no_conflicts",
preprocess: Callable[[Dataset], Dataset] | None = None,
engine: T_Engine | None = None,
Expand Down Expand Up @@ -912,7 +914,8 @@ def open_mfdataset(
aligned are not equal
- "override": if indexes are of same size, rewrite indexes to be
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
- "strict": similar to "exact", but less permissive.
The alignment fails if dimensions' names differ.
attrs_file : str or path-like, optional
Path of the file used to read global attributes from.
By default global attributes are read from the first file provided,
Expand Down
43 changes: 35 additions & 8 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@
from collections import defaultdict
from collections.abc import Hashable, Iterable, Mapping
from contextlib import suppress
from typing import TYPE_CHECKING, Any, Callable, Final, Generic, TypeVar, cast, overload
from typing import (
TYPE_CHECKING,
Any,
Callable,
Final,
Generic,
TypeVar,
cast,
get_args,
overload,
)

import numpy as np
import pandas as pd
Expand All @@ -19,7 +29,7 @@
indexes_all_equal,
safe_cast_to_index,
)
from xarray.core.types import T_Alignable
from xarray.core.types import JoinOptions, T_Alignable
from xarray.core.utils import is_dict_like, is_full_slice
from xarray.core.variable import Variable, as_compatible_data, calculate_dimensions

Expand All @@ -28,7 +38,6 @@
from xarray.core.dataset import Dataset
from xarray.core.types import (
Alignable,
JoinOptions,
T_DataArray,
T_Dataset,
T_DuckArray,
Expand Down Expand Up @@ -145,7 +154,7 @@ def __init__(
self.objects = tuple(objects)
self.objects_matching_indexes = ()

if join not in ["inner", "outer", "override", "exact", "left", "right"]:
if join not in get_args(JoinOptions):
raise ValueError(f"invalid value for join: {join}")
self.join = join

Expand Down Expand Up @@ -264,13 +273,13 @@ def find_matching_indexes(self) -> None:
self.all_indexes = all_indexes
self.all_index_vars = all_index_vars

if self.join == "override":
if self.join in ("override", "strict"):
for dim_sizes in all_indexes_dim_sizes.values():
for dim, sizes in dim_sizes.items():
if len(sizes) > 1:
raise ValueError(
"cannot align objects with join='override' with matching indexes "
f"along dimension {dim!r} that don't have the same size"
f"cannot align objects with join={self.join!r} with matching indexes "
f"along dimension {dim!r} that don't have the same size ({sizes!r})"
)

def find_matching_unindexed_dims(self) -> None:
Expand Down Expand Up @@ -472,12 +481,27 @@ def assert_unindexed_dim_sizes_equal(self) -> None:
)
else:
add_err_msg = ""
# Same for indexed dims?
if len(sizes) > 1:
raise ValueError(
f"cannot reindex or align along dimension {dim!r} "
f"because of conflicting dimension sizes: {sizes!r}" + add_err_msg
)

def assert_equal_dimension_names(self) -> None:
    """Check that all objects share the same dimension names (strict mode only).

    This is a no-op unless ``self.join == "strict"``. In strict mode, the
    set of dimension names of every object in ``self.objects`` must be the
    same; the order in which an object lists its dimensions is irrelevant.

    Raises
    ------
    ValueError
        If ``join="strict"`` and at least two objects have different
        dimension names.
    """
    if self.join != "strict":
        return

    # Keep the ordered names per object: they are only used for the error message.
    dims_per_object = [tuple(obj.sizes) for obj in self.objects]

    # Compare names as sets so that objects with the same dimensions in a
    # different order are accepted. Using "> 1" (rather than "!= 1") also
    # means that an empty collection of objects is trivially consistent.
    distinct_dim_sets = {frozenset(dims) for dims in dims_per_object}
    if len(distinct_dim_sets) > 1:
        raise ValueError(
            "cannot align objects with join='strict' "
            "because given objects do not share the same dimension names "
            f"({dims_per_object!r}); "
            "try using join='exact' if you only care about equal indexes"
        )

def override_indexes(self) -> None:
objects = list(self.objects)

Expand Down Expand Up @@ -568,6 +592,7 @@ def align(self) -> None:
self.results = (obj.copy(deep=self.copy),)
return

self.assert_equal_dimension_names()
self.find_matching_indexes()
self.find_matching_unindexed_dims()
self.assert_no_index_conflict()
Expand All @@ -576,7 +601,7 @@ def align(self) -> None:

if self.join == "override":
self.override_indexes()
elif self.join == "exact" and not self.copy:
elif not self.copy and (self.join in ("exact", "strict")):
self.results = self.objects
else:
self.reindex_all()
Expand Down Expand Up @@ -716,6 +741,8 @@ def align(
- "override": if indexes are of same size, rewrite indexes to be
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
- "strict": similar to "exact", but less permissive.
The alignment fails if dimensions' names differ.
copy : bool, default: True
If ``copy=True``, data in the return values is always copied. If
Expand Down
5 changes: 5 additions & 0 deletions xarray/core/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,9 @@ def combine_nested(
- "override": if indexes are of same size, rewrite indexes to be
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
- "strict": similar to "exact", but less permissive.
The alignment fails if dimensions' names differ.
combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
"override"} or callable, default: "drop"
A callable or a string indicating how to combine attrs of the objects being
Expand Down Expand Up @@ -737,6 +740,8 @@ def combine_by_coords(
- "override": if indexes are of same size, rewrite indexes to be
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
- "strict": similar to "exact", but less permissive.
The alignment fails if dimensions' names differ.
combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
"override"} or callable, default: "no_conflicts"
Expand Down
3 changes: 3 additions & 0 deletions xarray/core/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ def concat(
- "override": if indexes are of same size, rewrite indexes to be
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
- "strict": similar to "exact", but less permissive.
The alignment fails if dimensions' names differ.
combine_attrs : {"drop", "identical", "no_conflicts", "drop_conflicts", \
"override"} or callable, default: "override"
A callable or a string indicating how to combine attrs of the objects being
Expand Down
2 changes: 2 additions & 0 deletions xarray/core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,8 @@ def merge(
- "override": if indexes are of same size, rewrite indexes to be
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
- "strict": similar to "exact", but less permissive.
The alignment fails if dimensions' names differ.
fill_value : scalar or dict-like, optional
Value to use for newly missing values. If a dict-like, maps
Expand Down
2 changes: 2 additions & 0 deletions xarray/core/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ class set_options:
- "override": if indexes are of same size, rewrite indexes to be
those of the first object with that dimension. Indexes for the same
dimension must have the same size in all objects.
- "strict": similar to "exact", but less permissive.
The alignment fails if dimensions' names differ.
cmap_divergent : str or matplotlib.colors.Colormap, default: "RdBu_r"
Colormap to use for divergent data plots. If string, must be
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def copy(
Literal["drop", "identical", "no_conflicts", "drop_conflicts", "override"],
Callable[..., Any],
]
JoinOptions = Literal["outer", "inner", "left", "right", "exact", "override"]
JoinOptions = Literal["outer", "inner", "left", "right", "exact", "override", "strict"]

Interp1dOptions = Literal[
"linear", "nearest", "zero", "slinear", "quadratic", "cubic", "polynomial"
Expand Down
101 changes: 101 additions & 0 deletions xarray/tests/test_concat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import re
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Callable

Expand Down Expand Up @@ -1261,3 +1262,103 @@ def test_concat_index_not_same_dim() -> None:
match=r"Cannot concatenate along dimension 'x' indexes with dimensions.*",
):
concat([ds1, ds2], dim="x")


def test_concat_join_coordinate_variables_non_asked_dims():
    """Compare join='outer' and join='strict' on an indexed dimension that is
    not the concatenation dimension and whose size differs between datasets.
    """
    ds1 = Dataset(
        coords={
            "x_center": ("x_center", [1, 2, 3]),
            "x_outer": ("x_outer", [0.5, 1.5, 2.5, 3.5]),
        },
    )

    ds2 = Dataset(
        coords={
            "x_center": ("x_center", [4, 5, 6]),
            "x_outer": ("x_outer", [4.5, 5.5, 6.5]),
        },
    )

    # Using join='outer'
    expected_wrongly_concatenated_xds = Dataset(
        coords={
            "x_center": ("x_center", [1, 2, 3, 4, 5, 6]),
            "x_outer": ("x_outer", [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5]),
        },
    )
    # Not using strict mode will allow the concatenation to surprisingly happen
    # even if `x_outer` sizes do not match
    actual_xds = concat(
        [ds1, ds2],
        dim="x_center",
        data_vars="different",
        coords="different",
        join="outer",
    )
    # NOTE: `all(actual == expected)` would iterate over the data variable
    # names of the resulting Dataset (there are none here) and therefore
    # always pass; compare the datasets, coordinates included, explicitly.
    assert actual_xds.equals(expected_wrongly_concatenated_xds)

    # Using join='strict'
    # A check similar to the one made on non-indexed dimensions regarding their sizes.
    with pytest.raises(
        ValueError,
        match=re.escape(
            r"cannot align objects with join='strict' with matching indexes "
            r"along dimension 'x_outer' that don't have the same size ({3, 4})"
        ),
    ):
        concat(
            [ds1, ds2],
            dim="x_center",
            data_vars="different",
            coords="different",
            join="strict",
        )


def test_concat_join_non_coordinate_variables():
    """Concatenation must fail for both join='strict' and join='outer' when a
    non-indexed dimension (no coordinate variable attached) has conflicting sizes.
    """
    first = Dataset(
        data_vars={
            "a": ("x_center", [1, 2, 3]),
            "b": ("x_outer", [0.5, 1.5, 2.5, 3.5]),
        },
    )
    second = Dataset(
        data_vars={
            "a": ("x_center", [4, 5, 6]),
            "b": ("x_outer", [4.5, 5.5, 6.5]),
        },
    )

    # The failure comes from the rule disallowing alignment of non-indexed
    # dimensions, so the very same error is expected whatever the join mode.
    expected_error = (
        r"cannot reindex or align along dimension 'x_outer' "
        r"because of conflicting dimension sizes: {3, 4}"
    )

    for join_mode in ("strict", "outer"):
        with pytest.raises(ValueError, match=expected_error):
            concat(
                [first, second],
                dim="x_center",
                data_vars="different",
                coords="different",
                join=join_mode,
            )
20 changes: 20 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3188,6 +3188,26 @@ def test_align_str_dtype(self) -> None:
assert_identical(expected_b, actual_b)
assert expected_b.x.dtype == actual_b.x.dtype

def test_align_exact_vs_strict(self) -> None:
    """join='exact' accepts objects with different dimension names,
    whereas join='strict' rejects them.
    """
    xda_1 = xr.DataArray([1], dims="x1")
    xda_2 = xr.DataArray([1], dims="x2")

    # join='exact' passes
    aligned_1, aligned_2 = xr.align(xda_1, xda_2, join="exact")
    assert_identical(aligned_1, xda_1)
    assert_identical(aligned_2, xda_2)

    # join='strict' fails because of non-matching dimensions' names
    with pytest.raises(
        ValueError,
        match=(
            r"cannot align objects with join='strict' "
            r"because given objects do not share the same dimension names "
            # `match` is a regular expression: brackets and parentheses must
            # be escaped, otherwise the reported names are not actually checked.
            r"\(\[\('x1',\), \('x2',\)\]\)"
        ),
    ):
        xr.align(xda_1, xda_2, join="strict")

def test_broadcast_arrays(self) -> None:
x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x")
y = DataArray([1, 2], coords=[("b", [3, 4])], name="y")
Expand Down

0 comments on commit bbe7d05

Please sign in to comment.