Skip to content

Commit

Permalink
add DataArray.sum() test over DataTree by dim and variable, and modif…
Browse files Browse the repository at this point in the history
…y datatree object creation function to make this easier/meaningful, and exersize pre-commit hooks
  • Loading branch information
acrosby committed Jul 14, 2024
1 parent d173411 commit 43b7669
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 30 deletions.
42 changes: 21 additions & 21 deletions xarray/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,9 @@ def create_test_multidataset_withoutroot_datatree():
│ Data variables:
│ x (yi, xi) int64 24kB 0 0 0 0 0 0 0 0 0 ... 49 49 49 49 49 49 49 49
│ y (yi, xi) int64 24kB 0 1 2 3 4 5 6 7 8 ... 52 53 54 55 56 57 58 59
│ u (t, yi, xi) float64 720kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
│ v (t, yi, xi) float64 720kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
│ p (t, yi, xi) float64 720kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
│ u (t, yi, xi) float64 720kB 1.0 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0
│ v (t, yi, xi) float64 720kB 1.0 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0
│ p (t, yi, xi) float64 720kB 1.0 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0
├── Group: /ovr1
│ Dimensions: (yi: 51, xi: 51, t: 10)
│ Coordinates:
Expand All @@ -226,9 +226,9 @@ def create_test_multidataset_withoutroot_datatree():
│ Data variables:
│ x (yi, xi) float64 21kB 150.0 150.0 150.0 150.0 ... 200.0 200.0 200.0
│ y (yi, xi) int64 21kB 0 1 2 3 4 5 6 7 8 ... 43 44 45 46 47 48 49 50
│ u (t, yi, xi) float64 208kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
│ v (t, yi, xi) float64 208kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
│ p (t, yi, xi) float64 208kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
│ u (t, yi, xi) float64 208kB 1.0 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0
│ v (t, yi, xi) float64 208kB 1.0 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0
│ p (t, yi, xi) float64 208kB 1.0 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0
│ cx (t) int64 80B 10 11 12 13 14 15 16 17 18 19
│ cy (t) int64 80B 10 11 12 13 14 15 16 17 18 19
└── Group: /ovr2
Expand All @@ -239,9 +239,9 @@ def create_test_multidataset_withoutroot_datatree():
Data variables:
x (yi, xi) int64 21kB 0 0 0 0 0 0 0 0 0 ... 50 50 50 50 50 50 50 50
y (yi, xi) int64 21kB 0 1 2 3 4 5 6 7 8 ... 43 44 45 46 47 48 49 50
u (t, yi, xi) float64 416kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
v (t, yi, xi) float64 416kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
p (t, yi, xi) float64 416kB 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
u (t, yi, xi) float64 416kB 1.0 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0
v (t, yi, xi) float64 416kB 1.0 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0
p (t, yi, xi) float64 416kB 1.0 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 1.0
cx (t) int64 160B 10 11 12 13 14 15 16 17 ... 22 23 24 25 26 27 28 29
cy (t) int64 160B 10 11 12 13 14 15 16 17 ... 22 23 24 25 26 27 28 29
Expand All @@ -265,9 +265,9 @@ def _create_test_multidataset_withoutroot_datatree(
"x": (["yi", "xi"], np.meshgrid(np.arange(60), np.arange(50))[1]),
"y": (["yi", "xi"], np.meshgrid(np.arange(60), np.arange(50))[0]),
"t": ("t", np.arange(30)),
"u": (["t", "yi", "xi"], np.zeros((30, 50, 60))),
"v": (["t", "yi", "xi"], np.zeros((30, 50, 60))),
"p": (["t", "yi", "xi"], np.zeros((30, 50, 60))),
"u": (["t", "yi", "xi"], np.ones((30, 50, 60))),
"v": (["t", "yi", "xi"], np.ones((30, 50, 60))),
"p": (["t", "yi", "xi"], np.ones((30, 50, 60))),
}
)
)
Expand All @@ -280,9 +280,9 @@ def _create_test_multidataset_withoutroot_datatree(
),
"y": (["yi", "xi"], np.meshgrid(np.arange(51), np.arange(51))[0]),
"t": ("t", np.arange(10, 20)),
"u": (["t", "yi", "xi"], np.zeros((10, 51, 51))),
"v": (["t", "yi", "xi"], np.zeros((10, 51, 51))),
"p": (["t", "yi", "xi"], np.zeros((10, 51, 51))),
"u": (["t", "yi", "xi"], np.ones((10, 51, 51))),
"v": (["t", "yi", "xi"], np.ones((10, 51, 51))),
"p": (["t", "yi", "xi"], np.ones((10, 51, 51))),
"cx": ("t", np.arange(10, 20)),
"cy": ("t", np.arange(10, 20)),
}
Expand All @@ -294,9 +294,9 @@ def _create_test_multidataset_withoutroot_datatree(
"x": (["yi", "xi"], np.meshgrid(np.arange(51), np.arange(51))[1]),
"y": (["yi", "xi"], np.meshgrid(np.arange(51), np.arange(51))[0]),
"t": ("t", np.arange(10, 30)),
"u": (["t", "yi", "xi"], np.zeros((20, 51, 51))),
"v": (["t", "yi", "xi"], np.zeros((20, 51, 51))),
"p": (["t", "yi", "xi"], np.zeros((20, 51, 51))),
"u": (["t", "yi", "xi"], np.ones((20, 51, 51))),
"v": (["t", "yi", "xi"], np.ones((20, 51, 51))),
"p": (["t", "yi", "xi"], np.ones((20, 51, 51))),
"cx": ("t", np.arange(10, 30)),
"cy": ("t", np.arange(10, 30)),
}
Expand All @@ -306,9 +306,9 @@ def _create_test_multidataset_withoutroot_datatree(

# Avoid using __init__ so we can independently test it
root: DataTree = DataTree(data=root_data)
main: DataTree = DataTree(name="Main", parent=root, data=main_data)
ovr1: DataTree = DataTree(name="ovr1", parent=root, data=ovr1_data)
ovr2: DataTree = DataTree(name="ovr2", parent=root, data=ovr2_data)
DataTree(name="Main", parent=root, data=main_data)
DataTree(name="ovr1", parent=root, data=ovr1_data)
DataTree(name="ovr2", parent=root, data=ovr2_data)

return root

Expand Down
36 changes: 27 additions & 9 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,9 +936,9 @@ def test_allow_inconsistent_noninherited_coords_dims_vars(
np.meshgrid(np.arange(60), np.arange(50))[0],
),
"t": ("t", np.arange(30)),
"u": (["t", "yi", "xi"], np.zeros((30, 50, 60))),
"v": (["t", "yi", "xi"], np.zeros((30, 50, 60))),
"p": (["t", "yi", "xi"], np.zeros((30, 50, 60))),
"u": (["t", "yi", "xi"], np.ones((30, 50, 60))),
"v": (["t", "yi", "xi"], np.ones((30, 50, 60))),
"p": (["t", "yi", "xi"], np.ones((30, 50, 60))),
}
),
"/ovr1": xr.Dataset(
Expand All @@ -952,9 +952,9 @@ def test_allow_inconsistent_noninherited_coords_dims_vars(
np.meshgrid(np.arange(51), np.arange(51))[0],
),
"t": ("t", np.arange(10, 20)),
"u": (["t", "yi", "xi"], np.zeros((10, 51, 51))),
"v": (["t", "yi", "xi"], np.zeros((10, 51, 51))),
"p": (["t", "yi", "xi"], np.zeros((10, 51, 51))),
"u": (["t", "yi", "xi"], np.ones((10, 51, 51))),
"v": (["t", "yi", "xi"], np.ones((10, 51, 51))),
"p": (["t", "yi", "xi"], np.ones((10, 51, 51))),
"cx": ("t", np.arange(10, 20)),
"cy": ("t", np.arange(10, 20)),
}
Expand All @@ -970,9 +970,9 @@ def test_allow_inconsistent_noninherited_coords_dims_vars(
np.meshgrid(np.arange(51), np.arange(51))[0],
),
"t": ("t", np.arange(10, 30)),
"u": (["t", "yi", "xi"], np.zeros((20, 51, 51))),
"v": (["t", "yi", "xi"], np.zeros((20, 51, 51))),
"p": (["t", "yi", "xi"], np.zeros((20, 51, 51))),
"u": (["t", "yi", "xi"], np.ones((20, 51, 51))),
"v": (["t", "yi", "xi"], np.ones((20, 51, 51))),
"p": (["t", "yi", "xi"], np.ones((20, 51, 51))),
"cx": ("t", np.arange(10, 30)),
"cy": ("t", np.arange(10, 30)),
}
Expand Down Expand Up @@ -1138,6 +1138,24 @@ def test_cum_method(self):
result = dt.cumsum()
assert_equal(result, expected)

def test_dim_sum_over_duplicated_names_in_tree(
self, create_test_multidataset_withoutroot_datatree
):
dt = create_test_multidataset_withoutroot_datatree()
dims = ("t", "xi", "yi")
for dname in dims:
dtsum = dt.sum(dim=dname)
subgroups = dt.groups[1:]
for gname in subgroups:
subds = dt[gname]
sumds = dtsum[gname]
assert dname in subds.dims
assert dname not in sumds.dims
for vname in subds.variables:
if vname != dname:
if dname in subds[vname].dims:
assert (subds[vname].sum(dim=dname) == sumds[vname]).all()


class TestOps:
def test_binary_op_on_int(self):
Expand Down

0 comments on commit 43b7669

Please sign in to comment.