From 3dcfa31955d334d0d05f171f21e0af6bebfb36e1 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 14 Mar 2024 22:47:30 -0600 Subject: [PATCH] Return a dataclass from Grouper.factorize (#8777) * Return dataclass from factorize * cleanup --- xarray/core/groupby.py | 112 +++++++++++++++++++++++++++-------------- 1 file changed, 73 insertions(+), 39 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index d34b94e9f33..3fbfb74d985 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -60,9 +60,6 @@ GroupKey = Any GroupIndex = Union[int, slice, list[int]] T_GroupIndices = list[GroupIndex] - T_FactorizeOut = tuple[ - DataArray, T_GroupIndices, Union[pd.Index, "_DummyGroup"], pd.Index - ] def check_reduce_dims(reduce_dims, dimensions): @@ -98,7 +95,7 @@ def _maybe_squeeze_indices( def unique_value_groups( ar, sort: bool = True -) -> tuple[np.ndarray | pd.Index, T_GroupIndices, np.ndarray]: +) -> tuple[np.ndarray | pd.Index, np.ndarray]: """Group an array by its unique values. Parameters @@ -119,11 +116,11 @@ def unique_value_groups( inverse, values = pd.factorize(ar, sort=sort) if isinstance(values, pd.MultiIndex): values.names = ar.names - groups = _codes_to_groups(inverse, len(values)) - return values, groups, inverse + return values, inverse -def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndices: +def _codes_to_group_indices(inverse: np.ndarray, N: int) -> T_GroupIndices: + assert inverse.ndim == 1 groups: T_GroupIndices = [[] for _ in range(N)] for n, g in enumerate(inverse): if g >= 0: @@ -356,7 +353,7 @@ def can_squeeze(self) -> bool: return False @abstractmethod - def factorize(self, group) -> T_FactorizeOut: + def factorize(self, group) -> EncodedGroups: """ Takes the group, and creates intermediates necessary for GroupBy. These intermediates are @@ -378,6 +375,27 @@ class Resampler(Grouper): pass +@dataclass +class EncodedGroups: + """ + Dataclass for storing intermediate values for GroupBy operation. + Returned by factorize method on Grouper objects. + + Parameters + ---------- + codes: integer codes for each group + full_index: pandas Index for the group coordinate + group_indices: optional, List of indices of array elements belonging + to each group. Inferred if not provided. + unique_coord: Unique group values present in dataset. Inferred if not provided + """ + + codes: DataArray + full_index: pd.Index + group_indices: T_GroupIndices | None = field(default=None) + unique_coord: IndexVariable | _DummyGroup | None = field(default=None) + + @dataclass class ResolvedGrouper(Generic[T_DataWithCoords]): """ @@ -397,11 +415,11 @@ class ResolvedGrouper(Generic[T_DataWithCoords]): group: T_Group obj: T_DataWithCoords - # Defined by factorize: + # returned by factorize: codes: DataArray = field(init=False) + full_index: pd.Index = field(init=False) group_indices: T_GroupIndices = field(init=False) unique_coord: IndexVariable | _DummyGroup = field(init=False) - full_index: pd.Index = field(init=False) # _ensure_1d: group1d: T_Group = field(init=False) @@ -445,12 +463,26 @@ def dims(self): return self.group1d.dims def factorize(self) -> None: - ( - self.codes, - self.group_indices, - self.unique_coord, - self.full_index, - ) = self.grouper.factorize(self.group1d) + encoded = self.grouper.factorize(self.group1d) + + self.codes = encoded.codes + self.full_index = encoded.full_index + + if encoded.group_indices is not None: + self.group_indices = encoded.group_indices + else: + self.group_indices = [ + g + for g in _codes_to_group_indices(self.codes.data, len(self.full_index)) + if g + ] + if encoded.unique_coord is None: + unique_values = self.full_index[np.unique(encoded.codes)] + self.unique_coord = IndexVariable( + self.group.name, unique_values, attrs=self.group.attrs + ) + else: + self.unique_coord = encoded.unique_coord @dataclass @@ -477,7 +509,7 @@ def can_squeeze(self) -> bool: is_dimension = self.group.dims == (self.group.name,) return is_dimension and self.is_unique_and_monotonic - def factorize(self, group1d) -> T_FactorizeOut: + def factorize(self, group1d) -> EncodedGroups: self.group = group1d if self.can_squeeze: @@ -485,26 +517,25 @@ def factorize(self, group1d) -> T_FactorizeOut: else: return self._factorize_unique() - def _factorize_unique(self) -> T_FactorizeOut: + def _factorize_unique(self) -> EncodedGroups: # look through group to find the unique values sort = not isinstance(self.group_as_index, pd.MultiIndex) - unique_values, group_indices, codes_ = unique_value_groups( - self.group_as_index, sort=sort - ) - if len(group_indices) == 0: + unique_values, codes_ = unique_value_groups(self.group_as_index, sort=sort) + if (codes_ == -1).all(): raise ValueError( "Failed to group data. Are you grouping by a variable that is all NaN?" ) codes = self.group.copy(data=codes_) - group_indices = group_indices unique_coord = IndexVariable( self.group.name, unique_values, attrs=self.group.attrs ) full_index = unique_coord - return codes, group_indices, unique_coord, full_index + return EncodedGroups( + codes=codes, full_index=full_index, unique_coord=unique_coord + ) - def _factorize_dummy(self) -> T_FactorizeOut: + def _factorize_dummy(self) -> EncodedGroups: size = self.group.size # no need to factorize # use slices to do views instead of fancy indexing @@ -519,8 +550,12 @@ def _factorize_dummy(self) -> T_FactorizeOut: full_index = IndexVariable( self.group.name, unique_coord.values, self.group.attrs ) - - return codes, group_indices, unique_coord, full_index + return EncodedGroups( + codes=codes, + group_indices=group_indices, + full_index=full_index, + unique_coord=unique_coord, + ) @dataclass @@ -536,7 +571,7 @@ def __post_init__(self) -> None: if duck_array_ops.isnull(self.bins).all(): raise ValueError("All bin edges are NaN.") - def factorize(self, group) -> T_FactorizeOut: + def factorize(self, group) -> EncodedGroups: from xarray.core.dataarray import DataArray data = group.data @@ -554,20 +589,14 @@ def factorize(self, group) -> T_FactorizeOut: full_index = binned.categories uniques = np.sort(pd.unique(binned_codes)) unique_values = full_index[uniques[uniques != -1]] - group_indices = [ - g for g in _codes_to_groups(binned_codes, len(full_index)) if g - ] - - if len(group_indices) == 0: - raise ValueError( - f"None of the data falls within bins with edges {self.bins!r}" - ) codes = DataArray( binned_codes, getattr(group, "coords", None), name=new_dim_name ) unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs) - return codes, group_indices, unique_coord, full_index + return EncodedGroups( + codes=codes, full_index=full_index, unique_coord=unique_coord + ) @dataclass @@ -672,7 +701,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]: _apply_loffset(self.loffset, first_items) return first_items, codes - def factorize(self, group) -> T_FactorizeOut: + def factorize(self, group) -> EncodedGroups: self._init_properties(group) full_index, first_items, codes_ = self._get_index_and_items() sbins = first_items.values.astype(np.int64) @@ -684,7 +713,12 @@ def factorize(self, group) -> T_FactorizeOut: unique_coord = IndexVariable(group.name, first_items.index, group.attrs) codes = group.copy(data=codes_) - return codes, group_indices, unique_coord, full_index + return EncodedGroups( + codes=codes, + group_indices=group_indices, + full_index=full_index, + unique_coord=unique_coord, + ) def _validate_groupby_squeeze(squeeze: bool | None) -> None: