-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reimplement Datatree typed ops #9619
Changes from 18 commits
812d207
1b9c089
2d9cef4
4bb8902
8e7c1da
909ae0e
03ce2c5
12112f1
b823290
7147bb3
368d456
2c740b4
1eae418
bb019da
20199cb
7634b6c
2814801
1fcbe02
9e2dfad
399c6e1
304eb19
7dbc817
c7d8060
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from __future__ import annotations | ||
|
||
import functools | ||
import itertools | ||
import textwrap | ||
from collections import ChainMap | ||
|
@@ -15,6 +16,7 @@ | |
|
||
from xarray.core import utils | ||
from xarray.core._aggregations import DataTreeAggregations | ||
from xarray.core._typed_ops import DataTreeOpsMixin | ||
from xarray.core.alignment import align | ||
from xarray.core.common import TreeAttrAccessMixin | ||
from xarray.core.coordinates import Coordinates, DataTreeCoordinates | ||
|
@@ -60,6 +62,7 @@ | |
from xarray.core.merge import CoercibleMapping, CoercibleValue | ||
from xarray.core.types import ( | ||
Dims, | ||
DtCompatible, | ||
ErrorOptions, | ||
ErrorOptionsWithWarn, | ||
NetcdfWriteModes, | ||
|
@@ -403,6 +406,7 @@ def map( # type: ignore[override] | |
class DataTree( | ||
NamedNode["DataTree"], | ||
DataTreeAggregations, | ||
DataTreeOpsMixin, | ||
TreeAttrAccessMixin, | ||
Mapping[str, "DataArray | DataTree"], | ||
): | ||
|
@@ -1486,6 +1490,50 @@ def groups(self): | |
"""Return all groups in the tree, given as a tuple of path-like strings.""" | ||
return tuple(node.path for node in self.subtree) | ||
|
||
def _unary_op(self, f, *args, **kwargs) -> DataTree: | ||
# TODO do we need to any additional work to avoid duplication etc.? (Similar to aggregations) | ||
return self.map_over_subtree(f, *args, **kwargs) # type: ignore[return-value] | ||
|
||
def _binary_op(self, other, f, reflexive=False, join=None) -> DataTree: | ||
from xarray.core.dataset import Dataset | ||
from xarray.core.groupby import GroupBy | ||
|
||
if isinstance(other, GroupBy): | ||
# TODO should we be trying to make this work? | ||
raise NotImplementedError | ||
|
||
ds_binop = functools.partial( | ||
Dataset._binary_op, | ||
f=f, | ||
reflexive=reflexive, | ||
join=join, | ||
) | ||
return map_over_subtree(ds_binop)(self, other) | ||
|
||
def _inplace_binary_op(self, other, f) -> Self: | ||
from xarray.core.groupby import GroupBy | ||
|
||
if isinstance(other, GroupBy): | ||
raise TypeError( | ||
"in-place operations between a DataTree and " | ||
"a grouped object are not permitted" | ||
) | ||
|
||
# TODO requires an implementation of map_over_subtree_inplace | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I will just do in-place ops in a follow-up PR There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suspect this actually needs a different implementation, in order to handle error recovery properly. I would suggest handling this like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense. I've raised #9629 to track that, so I'll remove this pseudocode from this PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done in 304eb19 |
||
# | ||
# ds_inplace_binop = functools.partial( | ||
# Dataset._inplace_binary_op, | ||
# f=f, | ||
# ) | ||
# | ||
# return map_over_subtree_inplace(ds_inplace_binop)(self, other) | ||
raise NotImplementedError() | ||
|
||
# TODO: dirty workaround for mypy 1.5 error with inherited DatasetOpsMixin vs. Mapping | ||
# related to https://github.com/python/mypy/issues/9319? | ||
def __eq__(self, other: DtCompatible) -> Self: # type: ignore[override] | ||
return super().__eq__(other) | ||
|
||
def to_netcdf( | ||
self, | ||
filepath, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NotImplemented
is a sentinel value that tells Python that an arithmetic operator is not implemented, and allows the other argument to try implementing it. If all special methods returnNotImplemented
, then Python raises an informativeTypeError
.Should we also explicitly exclude
Dataset
here, or are the "mapped over all nodes" semantics of DataTree + Dataset arithmetic obvious enough? #9365There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah I did not realise that part. So
NotImplemented
is fine here for now.I think it's fine to allow
Dataset
- I did that deliberately. It's also tested now.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done in 9e2dfad