From f1bc33b0e66e825c6eaafb4dfcfc4851ffbf9297 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Tue, 12 Sep 2023 15:21:24 +0200 Subject: [PATCH 01/37] ENH: make a light refactoring Reuse instead of duplicating function. --- xarray/conventions.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/xarray/conventions.py b/xarray/conventions.py index 5a6675d60c1..7522f3b99fc 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -48,10 +48,6 @@ T_DatasetOrAbstractstore = Union[Dataset, AbstractDataStore] -def _var_as_tuple(var: Variable) -> T_VarTuple: - return var.dims, var.data, var.attrs.copy(), var.encoding.copy() - - def _infer_dtype(array, name: T_Name = None) -> np.dtype: """Given an object array with no missing values, infer its dtype from its first element @@ -106,7 +102,7 @@ def _copy_with_dtype(data, dtype: np.typing.DTypeLike): def ensure_dtype_not_object(var: Variable, name: T_Name = None) -> Variable: # TODO: move this from conventions to backends? (it's not CF related) if var.dtype.kind == "O": - dims, data, attrs, encoding = _var_as_tuple(var) + dims, data, attrs, encoding = variables.unpack_for_encoding(var) # leave vlen dtypes unchanged if strings.check_vlen_dtype(data.dtype) is not None: From 4da893815361146982d7f6485603cef2e32684cc Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Wed, 13 Sep 2023 12:18:48 +0200 Subject: [PATCH 02/37] dirty commit --- xarray/backends/api.py | 2 +- xarray/backends/netCDF4_.py | 92 +++++++++++++++++++++++++++++++++++-- 2 files changed, 89 insertions(+), 5 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 7be7541a79b..7538e47088b 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -172,7 +172,7 @@ def _validate_attrs(dataset, invalid_netcdf=False): `invalid_netcdf=True`. 
""" - valid_types = (str, Number, np.ndarray, np.number, list, tuple) + valid_types = (str, Number, np.ndarray, np.number, list, tuple, dict) if invalid_netcdf: valid_types += (np.bool_,) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index b5c3413e7f8..2f624b0b52f 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -49,7 +49,9 @@ # string used by netCDF4. _endian_lookup = {"=": "native", ">": "big", "<": "little", "|": "native"} - +# according to https://github.com/DennisHeimbigner/netcdf-c/blob/ef94285ac13b011613bb5e905d49b63d2a3bb076/libsrc4/nc4type.c#L486 +DEFAULT_HDF_ENUM_FILL_VALUE = 0 +DEFAULT_UNDEFINED_ENUM_MEANING = "_UNDEFINED" NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK]) @@ -124,6 +126,31 @@ def _getitem(self, key): return array + def replace_mask(self, mask, replacing_value): + ds = self.datastore._acquire(needs_lock=True) + variable = ds.variables[self.variable_name] + variable[mask] = replacing_value + + + +class NetCDF4EnumedArrayWrapper(NetCDF4ArrayWrapper): + __slots__ = () + + def get_array(self, needs_lock=True): + ds = self.datastore._acquire(needs_lock) + variable = ds.variables[self.variable_name] + with suppress(AttributeError): + variable.set_auto_chartostring(False) + return variable + + def unmask(self): + ds = self.datastore._acquire(needs_lock=True) + variable = ds.variables[self.variable_name] + variable.set_auto_maskandscale(False) + + + + def _encode_nc4_variable(var): for coder in [ coding.strings.EncodedStringCoder(allows_unicode=True), @@ -408,10 +435,44 @@ def _acquire(self, needs_lock=True): def ds(self): return self._acquire() - def open_store_variable(self, name, var): + def open_store_variable(self, name: str, var): + import netCDF4 + dimensions = var.dimensions - data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) attributes = {k: var.getncattr(k) for k in var.ncattrs()} + + enum_meaning = None + enum_name = None + if isinstance(var.datatype, 
netCDF4.EnumType): + # Workaround for poorly generated variables: + # When creating a variable typed by an enum + # va_va = self._acquire()[name][:] # get masked array + # old_fill_value = va_va.fill_value + # mask = va_va.mask + enum_meaning = var.datatype.enum_dict + enum_name = var.datatype.name + # Add a meaning to fill_value value if missing + # fill_value = list(var.datatype.enum_dict.values())[0] + # fill_value = attributes.get("_FillValue", DEFAULT_HDF_ENUM_FILL_VALUE) + # attributes["_FillValue"] = fill_value + # masked_data = indexing.LazilyIndexedArray(NetCDF4EnumedArrayWrapper(name, self)) + # masked_data.array.replace_mask(fill_value) + # masked_data.array.unmask() + # va_va[mask] = fill_value + # filtered_reversed_enum_meaning = { + # v: k + # for k, v in enum_meaning.items() + # if v == fill_value + # } + # # TODO: manage fill_value, see todo comment taged with [enum][missing_value] below + # if filtered_reversed_enum_meaning.get(fill_value) is None: + # enum_meaning[DEFAULT_UNDEFINED_ENUM_MEANING] = fill_value + attributes["enum_name"] = enum_name + attributes["enum_meaning"] = enum_meaning + data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) + # if enum_meaning is not None: + # data.array.replace_mask(mask=mask, replacing_value=fill_value) + _ensure_fill_value_valid(data, attributes) # netCDF4 specific encoding; save _FillValue for later encoding = {} @@ -478,7 +539,7 @@ def encode_variable(self, variable): return variable def prepare_variable( - self, name, variable, check_encoding=False, unlimited_dims=None + self, name, variable: Variable, check_encoding=False, unlimited_dims=None ): _ensure_no_forward_slash_in_name(name) @@ -503,6 +564,29 @@ def prepare_variable( variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims ) + enum = None + if attrs.get("enum_meaning") is not None: + enum = self.ds.createEnumType( + variable.dtype, + attrs["enum_name"], + attrs["enum_meaning"], + ) + datatype = enum + del 
attrs["enum_name"] + del attrs["enum_meaning"] + fill_value = None + # TODO [enum][missing_value]: + # What should we do with fill+value on enum ? + # On one hand it makes sens to ensure fill_value is a valid enum + # value. + # On the other hand, HDF and netCDF4 does not enforce this. + # In fact with the current netcdf4 we can end up with variables that + # can be created but not read by ncdump if we set fill_value to a + # value outside the enum range but we can also have clunky files + # with fill_value set to None, that are readable even though + # the default fill_value is outside the enum range too. + # Also, HDF considers the value associated with 0 to be the missing + # value *verify this claim). if name in self.ds.variables: nc4_var = self.ds.variables[name] else: From ab539702835608d44ec64449faad7a5c05a8c5e0 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Wed, 13 Sep 2023 12:22:10 +0200 Subject: [PATCH 03/37] Clean Remove attempts to workaround the fill_value issues. --- xarray/backends/netCDF4_.py | 61 ------------------------------------- 1 file changed, 61 deletions(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 2f624b0b52f..66b4a307ec7 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -49,9 +49,6 @@ # string used by netCDF4. 
_endian_lookup = {"=": "native", ">": "big", "<": "little", "|": "native"} -# according to https://github.com/DennisHeimbigner/netcdf-c/blob/ef94285ac13b011613bb5e905d49b63d2a3bb076/libsrc4/nc4type.c#L486 -DEFAULT_HDF_ENUM_FILL_VALUE = 0 -DEFAULT_UNDEFINED_ENUM_MEANING = "_UNDEFINED" NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK]) @@ -126,31 +123,6 @@ def _getitem(self, key): return array - def replace_mask(self, mask, replacing_value): - ds = self.datastore._acquire(needs_lock=True) - variable = ds.variables[self.variable_name] - variable[mask] = replacing_value - - - -class NetCDF4EnumedArrayWrapper(NetCDF4ArrayWrapper): - __slots__ = () - - def get_array(self, needs_lock=True): - ds = self.datastore._acquire(needs_lock) - variable = ds.variables[self.variable_name] - with suppress(AttributeError): - variable.set_auto_chartostring(False) - return variable - - def unmask(self): - ds = self.datastore._acquire(needs_lock=True) - variable = ds.variables[self.variable_name] - variable.set_auto_maskandscale(False) - - - - def _encode_nc4_variable(var): for coder in [ coding.strings.EncodedStringCoder(allows_unicode=True), @@ -444,29 +416,8 @@ def open_store_variable(self, name: str, var): enum_meaning = None enum_name = None if isinstance(var.datatype, netCDF4.EnumType): - # Workaround for poorly generated variables: - # When creating a variable typed by an enum - # va_va = self._acquire()[name][:] # get masked array - # old_fill_value = va_va.fill_value - # mask = va_va.mask enum_meaning = var.datatype.enum_dict enum_name = var.datatype.name - # Add a meaning to fill_value value if missing - # fill_value = list(var.datatype.enum_dict.values())[0] - # fill_value = attributes.get("_FillValue", DEFAULT_HDF_ENUM_FILL_VALUE) - # attributes["_FillValue"] = fill_value - # masked_data = indexing.LazilyIndexedArray(NetCDF4EnumedArrayWrapper(name, self)) - # masked_data.array.replace_mask(fill_value) - # masked_data.array.unmask() - # va_va[mask] = fill_value - # 
filtered_reversed_enum_meaning = { - # v: k - # for k, v in enum_meaning.items() - # if v == fill_value - # } - # # TODO: manage fill_value, see todo comment taged with [enum][missing_value] below - # if filtered_reversed_enum_meaning.get(fill_value) is None: - # enum_meaning[DEFAULT_UNDEFINED_ENUM_MEANING] = fill_value attributes["enum_name"] = enum_name attributes["enum_meaning"] = enum_meaning data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) @@ -575,18 +526,6 @@ def prepare_variable( del attrs["enum_name"] del attrs["enum_meaning"] fill_value = None - # TODO [enum][missing_value]: - # What should we do with fill+value on enum ? - # On one hand it makes sens to ensure fill_value is a valid enum - # value. - # On the other hand, HDF and netCDF4 does not enforce this. - # In fact with the current netcdf4 we can end up with variables that - # can be created but not read by ncdump if we set fill_value to a - # value outside the enum range but we can also have clunky files - # with fill_value set to None, that are readable even though - # the default fill_value is outside the enum range too. - # Also, HDF considers the value associated with 0 to be the missing - # value *verify this claim). 
if name in self.ds.variables: nc4_var = self.ds.variables[name] else: From 75e00c7c60933a5c8f5856cdfe6a4d0538557651 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Sep 2023 10:23:10 +0000 Subject: [PATCH 04/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/backends/netCDF4_.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 66b4a307ec7..82cdf396a5e 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -409,10 +409,10 @@ def ds(self): def open_store_variable(self, name: str, var): import netCDF4 - + dimensions = var.dimensions attributes = {k: var.getncattr(k) for k in var.ncattrs()} - + enum_meaning = None enum_name = None if isinstance(var.datatype, netCDF4.EnumType): @@ -518,7 +518,7 @@ def prepare_variable( enum = None if attrs.get("enum_meaning") is not None: enum = self.ds.createEnumType( - variable.dtype, + variable.dtype, attrs["enum_name"], attrs["enum_meaning"], ) From e1d51e3505de52d4acf0f7000d19bdbcf6800bc0 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Wed, 13 Sep 2023 16:16:03 +0200 Subject: [PATCH 05/37] wip: fix tests --- xarray/tests/test_backends.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index d54e1004f08..197c8cd4d09 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -4620,8 +4620,8 @@ def new_dataset_and_coord_attrs(): ds, attrs = new_dataset_and_attrs() attrs["test"] = {"a": 5} - with pytest.raises(TypeError, match=r"Invalid value for attr 'test'"): - ds.to_netcdf("test.nc") + with create_tmp_file() as tmp_file: + ds.to_netcdf(tmp_file) ds, attrs = new_dataset_and_attrs() attrs["test"] = MiscObject() From 95e30b2cbd009755bd99c83ac9193bb0e5606c37 Mon Sep 17 00:00:00 2001 From: 
Abel Aoun Date: Fri, 15 Sep 2023 13:18:18 +0200 Subject: [PATCH 06/37] dirty --- xarray/backends/netCDF4_.py | 71 ++++++++++++++++++------------------- 1 file changed, 35 insertions(+), 36 deletions(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 82cdf396a5e..706799038c2 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -321,6 +321,7 @@ class NetCDF4DataStore(WritableCFDataStore): "_group", "_manager", "_mode", + "enum_map", ) def __init__( @@ -348,6 +349,7 @@ def __init__( self.is_remote = is_remote_uri(self._filename) self.lock = ensure_lock(lock) self.autoclose = autoclose + self.enum_map = {} @classmethod def open( @@ -412,21 +414,23 @@ def open_store_variable(self, name: str, var): dimensions = var.dimensions attributes = {k: var.getncattr(k) for k in var.ncattrs()} - - enum_meaning = None + encoding = {} + data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) + enum_dict = None enum_name = None if isinstance(var.datatype, netCDF4.EnumType): - enum_meaning = var.datatype.enum_dict + enum_dict = var.datatype.enum_dict enum_name = var.datatype.name - attributes["enum_name"] = enum_name - attributes["enum_meaning"] = enum_meaning - data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) - # if enum_meaning is not None: - # data.array.replace_mask(mask=mask, replacing_value=fill_value) - + encoding["enum"] = enum_name + attributes["flag_values"] = enum_dict.key() + attributes["flag_meanings"] = enum_dict.values() + if self.enum_map.get("enum_name") is None: + self.enum_map["enum_name"] = [name] + else: + self.enum_map["enum_name"].append(name) _ensure_fill_value_valid(data, attributes) # netCDF4 specific encoding; save _FillValue for later - encoding = {} + filters = var.filters() if filters is not None: encoding.update(filters) @@ -493,39 +497,34 @@ def prepare_variable( self, name, variable: Variable, check_encoding=False, unlimited_dims=None ): 
_ensure_no_forward_slash_in_name(name) - - datatype = _get_datatype( - variable, self.format, raise_on_invalid_encoding=check_encoding - ) attrs = variable.attrs.copy() - fill_value = attrs.pop("_FillValue", None) - - if datatype is str and fill_value is not None: - raise NotImplementedError( - "netCDF4 does not yet support setting a fill value for " - "variable-length strings " - "(https://github.com/Unidata/netcdf4-python/issues/730). " - f"Either remove '_FillValue' from encoding on variable {name!r} " - "or set {'dtype': 'S1'} in encoding to use the fixed width " - "NC_CHAR type." - ) - encoding = _extract_nc4_variable_encoding( variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims ) - - enum = None - if attrs.get("enum_meaning") is not None: - enum = self.ds.createEnumType( + if encoding.get("enum") is not None: + enum_dict = {k:v for k,v in zip(attrs["flag_values"], attrs["flag_meanings"])} + datatype = self.ds.createEnumType( variable.dtype, - attrs["enum_name"], - attrs["enum_meaning"], + encoding["enum"], + enum_dict, ) - datatype = enum - del attrs["enum_name"] - del attrs["enum_meaning"] - fill_value = None + del attrs["flag_values"] + del attrs["flag_meanings"] + else: + datatype = _get_datatype( + variable, self.format, raise_on_invalid_encoding=check_encoding + ) + if datatype is str and fill_value is not None: + raise NotImplementedError( + "netCDF4 does not yet support setting a fill value for " + "variable-length strings " + "(https://github.com/Unidata/netcdf4-python/issues/730). " + f"Either remove '_FillValue' from encoding on variable {name!r} " + "or set {'dtype': 'S1'} in encoding to use the fixed width " + "NC_CHAR type." 
+ ) + if name in self.ds.variables: nc4_var = self.ds.variables[name] else: From a3160c5927a0fddf361ffb087be9ac73ebc73b38 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Fri, 15 Sep 2023 13:18:32 +0200 Subject: [PATCH 07/37] clean --- xarray/backends/netCDF4_.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 706799038c2..514e167af23 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -321,7 +321,6 @@ class NetCDF4DataStore(WritableCFDataStore): "_group", "_manager", "_mode", - "enum_map", ) def __init__( @@ -349,7 +348,6 @@ def __init__( self.is_remote = is_remote_uri(self._filename) self.lock = ensure_lock(lock) self.autoclose = autoclose - self.enum_map = {} @classmethod def open( @@ -424,10 +422,6 @@ def open_store_variable(self, name: str, var): encoding["enum"] = enum_name attributes["flag_values"] = enum_dict.key() attributes["flag_meanings"] = enum_dict.values() - if self.enum_map.get("enum_name") is None: - self.enum_map["enum_name"] = [name] - else: - self.enum_map["enum_name"].append(name) _ensure_fill_value_valid(data, attributes) # netCDF4 specific encoding; save _FillValue for later From 59ef68618cb31ca7267d496c542e9dad8f7a28e6 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Fri, 15 Sep 2023 13:19:04 +0200 Subject: [PATCH 08/37] Remove dict from valid attrs type --- xarray/backends/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 7538e47088b..7be7541a79b 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -172,7 +172,7 @@ def _validate_attrs(dataset, invalid_netcdf=False): `invalid_netcdf=True`. 
""" - valid_types = (str, Number, np.ndarray, np.number, list, tuple, dict) + valid_types = (str, Number, np.ndarray, np.number, list, tuple) if invalid_netcdf: valid_types += (np.bool_,) From d135be2342dd0d593e1210a4a3488c1a59b87beb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 Sep 2023 11:25:14 +0000 Subject: [PATCH 09/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/backends/netCDF4_.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 514e167af23..28923fce035 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -424,7 +424,7 @@ def open_store_variable(self, name: str, var): attributes["flag_meanings"] = enum_dict.values() _ensure_fill_value_valid(data, attributes) # netCDF4 specific encoding; save _FillValue for later - + filters = var.filters() if filters is not None: encoding.update(filters) @@ -497,7 +497,9 @@ def prepare_variable( variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims ) if encoding.get("enum") is not None: - enum_dict = {k:v for k,v in zip(attrs["flag_values"], attrs["flag_meanings"])} + enum_dict = { + k: v for k, v in zip(attrs["flag_values"], attrs["flag_meanings"]) + } datatype = self.ds.createEnumType( variable.dtype, encoding["enum"], @@ -507,7 +509,7 @@ def prepare_variable( del attrs["flag_meanings"] else: datatype = _get_datatype( - variable, self.format, raise_on_invalid_encoding=check_encoding + variable, self.format, raise_on_invalid_encoding=check_encoding ) if datatype is str and fill_value is not None: raise NotImplementedError( From 8c12e50cff15486e784f4a96637408d4d8e50e3f Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Fri, 15 Sep 2023 13:39:43 +0200 Subject: [PATCH 10/37] Fix encoding --- xarray/backends/netCDF4_.py | 5 +++-- 1 file changed, 3 
insertions(+), 2 deletions(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 28923fce035..9d077e2c339 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -256,6 +256,7 @@ def _extract_nc4_variable_encoding( "_FillValue", "dtype", "compression", + "enum", } if lsd_okay: valid_encodings.add("least_significant_digit") @@ -420,8 +421,8 @@ def open_store_variable(self, name: str, var): enum_dict = var.datatype.enum_dict enum_name = var.datatype.name encoding["enum"] = enum_name - attributes["flag_values"] = enum_dict.key() - attributes["flag_meanings"] = enum_dict.values() + attributes["flag_values"] = tuple(enum_dict.keys()) + attributes["flag_meanings"] = tuple(enum_dict.values()) _ensure_fill_value_valid(data, attributes) # netCDF4 specific encoding; save _FillValue for later From e8f48724b261716c0e548fcba6a8122c5e0f4d26 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Fri, 15 Sep 2023 15:42:54 +0200 Subject: [PATCH 11/37] FIX: ordering of flags Added unit test --- xarray/backends/netCDF4_.py | 6 +++--- xarray/tests/test_backends.py | 26 ++++++++++++++++++++++++-- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 9d077e2c339..c3db72c9c67 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -421,8 +421,8 @@ def open_store_variable(self, name: str, var): enum_dict = var.datatype.enum_dict enum_name = var.datatype.name encoding["enum"] = enum_name - attributes["flag_values"] = tuple(enum_dict.keys()) - attributes["flag_meanings"] = tuple(enum_dict.values()) + attributes["flag_values"] = tuple(enum_dict.values()) + attributes["flag_meanings"] = tuple(enum_dict.keys()) _ensure_fill_value_valid(data, attributes) # netCDF4 specific encoding; save _FillValue for later @@ -499,7 +499,7 @@ def prepare_variable( ) if encoding.get("enum") is not None: enum_dict = { - k: v for k, v in zip(attrs["flag_values"], 
attrs["flag_meanings"]) + k: v for k, v in zip(attrs["flag_meanings"], attrs["flag_values"]) } datatype = self.ds.createEnumType( variable.dtype, diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 197c8cd4d09..5fa97955718 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1626,6 +1626,28 @@ def test_raise_on_forward_slashes_in_names(self) -> None: with self.roundtrip(ds): pass + @requires_netCDF4 + def test_encoding_enum__no_fill_value(self): + with create_tmp_file() as tmp_file: + cloud_type_dict = {"clear": 0, "cloudy": 1} + with nc4.Dataset(tmp_file, mode="w") as nc: + nc.createDimension("time", size=2) + cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict) + v = nc.createVariable( + "clouds", + cloud_type, + "time", + fill_value=None, + ) + v[:] = 1 + with open_dataset(tmp_file) as ds: + assert list(ds.clouds.attrs.get("flag_meanings")) == list( + cloud_type_dict.keys() + ) + assert list(ds.clouds.attrs.get("flag_values")) == list( + cloud_type_dict.values() + ) + @requires_netCDF4 class TestNetCDF4Data(NetCDF4Base): @@ -4620,8 +4642,8 @@ def new_dataset_and_coord_attrs(): ds, attrs = new_dataset_and_attrs() attrs["test"] = {"a": 5} - with create_tmp_file() as tmp_file: - ds.to_netcdf(tmp_file) + with pytest.raises(TypeError, match=r"Invalid value for attr 'test'"): + ds.to_netcdf("test.nc") ds, attrs = new_dataset_and_attrs() attrs["test"] = MiscObject() From f481e1f806717c2d420c29e63bcfff4650e67070 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Mon, 18 Sep 2023 15:31:44 +0200 Subject: [PATCH 12/37] FIX: encoding of the same enum twice (or more). 
--- xarray/backends/netCDF4_.py | 14 +++++++++----- xarray/tests/test_backends.py | 23 +++++++++++++++++++++++ 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index c3db72c9c67..de943ff969d 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -501,11 +501,15 @@ def prepare_variable( enum_dict = { k: v for k, v in zip(attrs["flag_meanings"], attrs["flag_values"]) } - datatype = self.ds.createEnumType( - variable.dtype, - encoding["enum"], - enum_dict, - ) + enum_name = encoding["enum"] + if enum_name in self.ds.enumtypes: + datatype = self.ds.enumtypes[enum_name] + else: + datatype = self.ds.createEnumType( + variable.dtype, + enum_name, + enum_dict, + ) del attrs["flag_values"] del attrs["flag_meanings"] else: diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 5fa97955718..812ebe236d9 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1648,6 +1648,29 @@ def test_encoding_enum__no_fill_value(self): cloud_type_dict.values() ) + @requires_netCDF4 + def test_encoding_enum__multiple_enum_usage(self): + with create_tmp_file() as tmp_file: + cloud_type_dict = {"clear": 0, "cloudy": 1} + with nc4.Dataset(tmp_file, mode="w") as nc: + nc.createDimension("time", size=2) + cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict) + nc.createVariable( + "cloud", + cloud_type, + "time", + fill_value=None, + ) + nc.createVariable( + "tifa", + cloud_type, + "time", + fill_value=None, + ) + with open_dataset(tmp_file) as ds: + with create_tmp_file() as tmp_file2: + ds.to_netcdf(tmp_file2) + @requires_netCDF4 class TestNetCDF4Data(NetCDF4Base): From 951ea321a251ef6e0802428fddc5ef2a2b26c2a8 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Fri, 22 Sep 2023 15:33:42 +0200 Subject: [PATCH 13/37] DOC: Add note for Enum on to_netcdf --- xarray/backends/netCDF4_.py | 6 +++--- xarray/core/dataarray.py | 5 +++++ 
xarray/tests/test_backends.py | 10 ++++++---- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index de943ff969d..a44d7180387 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -233,13 +233,13 @@ def _force_native_endianness(var): def _extract_nc4_variable_encoding( - variable, + variable: Variable, raise_on_invalid=False, lsd_okay=True, h5py_okay=False, backend="netCDF4", unlimited_dims=None, -): +) -> dict[str, Any]: if unlimited_dims is None: unlimited_dims = () @@ -302,7 +302,7 @@ def _extract_nc4_variable_encoding( return encoding -def _is_list_of_strings(value): +def _is_list_of_strings(value) -> bool: arr = np.asarray(value) return arr.dtype.kind in ["U", "S"] and arr.size > 1 diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 5a68fc7ffac..99fcd3de71e 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3913,6 +3913,11 @@ def to_netcdf( ) -> bytes | Delayed | None: """Write DataArray contents to a netCDF file. + [netCDF4 backend only] When the CF flag_values/flag_meanings attributes are + set in for this DataArray, you can choose to replace these attributes with + EnumType by updating the encoding dictionary with a key value pair like: + `encoding["enum"] = "enum_name"`. 
+ Parameters ---------- path : str, path-like or None, optional diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 812ebe236d9..0ebdb7c623e 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1647,11 +1647,13 @@ def test_encoding_enum__no_fill_value(self): assert list(ds.clouds.attrs.get("flag_values")) == list( cloud_type_dict.values() ) + with create_tmp_file() as tmp_file2: + ds.to_netcdf(tmp_file2) @requires_netCDF4 - def test_encoding_enum__multiple_enum_usage(self): + def test_encoding_enum__multiple_enum_used(self): with create_tmp_file() as tmp_file: - cloud_type_dict = {"clear": 0, "cloudy": 1} + cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255} with nc4.Dataset(tmp_file, mode="w") as nc: nc.createDimension("time", size=2) cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict) @@ -1659,13 +1661,13 @@ def test_encoding_enum__multiple_enum_usage(self): "cloud", cloud_type, "time", - fill_value=None, + fill_value=255, ) nc.createVariable( "tifa", cloud_type, "time", - fill_value=None, + fill_value=255, ) with open_dataset(tmp_file) as ds: with create_tmp_file() as tmp_file2: From ec3c90a7035978f35c54752c5a326f7587493070 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Wed, 8 Nov 2023 12:41:19 +0100 Subject: [PATCH 14/37] ENH: Raise explicit error on invalid variable Raise explicit error whenever we try to write a variable to netCDF4 but whose type is an enum and whose data is outside the enum range. We assume the enum values are contiguous but there is no guarantee they are. 
--- xarray/backends/netCDF4_.py | 23 +++++++++++++++++++++++ xarray/tests/test_backends.py | 24 +++++++++++++++++++++++- 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index a44d7180387..463dad0eddf 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -73,8 +73,31 @@ def __init__(self, variable_name, datastore): self.dtype = dtype def __setitem__(self, key, value): + import netCDF4 + with self.datastore.lock: data = self.get_array(needs_lock=False) + if isinstance(data.datatype, netCDF4.EnumType): + # Make sure the values we are trying to assign are in enums valid range. + error_message = ( + "Cannot save the variable `{0}` to netCDF4:" + " `{0}` has values, such as `{1}`, which are not in the" + " enum/flag_values valid values: `{2}`. Fix the variable data or edit" + " its flag_values and flag_meanings attributes and try again. Note that" + " if the enum values are not contiguous, there might be other invalid" + " values too" + ) + valid_values = tuple(data.datatype.enum_dict.values()) + max_val = np.max(value) + min_val = np.min(value) + if max_val not in valid_values: + raise ValueError( + error_message.format(data.name, max_val, valid_values) + ) + if min_val not in valid_values: + raise ValueError( + error_message.format(data.name, min_val, valid_values) + ) data[key] = value if self.datastore.autoclose: self.datastore.close(needs_lock=False) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 0ebdb7c623e..88779e76f90 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1651,7 +1651,29 @@ def test_encoding_enum__no_fill_value(self): ds.to_netcdf(tmp_file2) @requires_netCDF4 - def test_encoding_enum__multiple_enum_used(self): + def test_encoding_enum__error_handling(self): + with create_tmp_file() as tmp_file: + cloud_type_dict = {"clear": 0, "cloudy": 1} + with nc4.Dataset(tmp_file, mode="w") as nc: + 
nc.createDimension("time", size=2) + cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict) + nc.createVariable( + "clouds", + cloud_type, + "time", + fill_value=None, + ) + # v is filled with default fill_value of u1 + with open_dataset(tmp_file) as ds: + assert np.all(ds.clouds.values == 255) + with create_tmp_file() as tmp_file2: + with pytest.raises(ValueError) as err: + ds.to_netcdf(tmp_file2) + assert True is False + assert "Cannot save the variable clouds" in err.args[0] + + @requires_netCDF4 + def test_encoding_enum__multiple_variable_with_enum(self): with create_tmp_file() as tmp_file: cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255} with nc4.Dataset(tmp_file, mode="w") as nc: From 55927f160baa0fbbb9839be5152be61405b0ea5f Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Wed, 8 Nov 2023 13:02:23 +0100 Subject: [PATCH 15/37] DOC: Update whats-new --- doc/whats-new.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b6bad62dd7c..413b984130c 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -24,6 +24,11 @@ New Features - Use `opt_einsum `_ for :py:func:`xarray.dot` by default if installed. By `Deepak Cherian `_. (:issue:`7764`, :pull:`8373`). +- Open netCDF4 enums and turn them into CF flag_meanings/flag_values. + This also adds a new encoding key `enum` to DataArray that tells the netCDF4 backend + to turn flag_meanings and flag_values into Enums when calling + :py:meth:`Dataset.to_netcdf`. + By `Abel Aoun _`(:issue:`8144`, :pull:`8147`) Breaking changes ~~~~~~~~~~~~~~~~ From 9e9c62c219c87c4c2cfb6fdd34dbf9c9a5a42235 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Wed, 8 Nov 2023 15:50:47 +0100 Subject: [PATCH 16/37] fix: move enum check __setitem__ was not a good fit for this check as `data` may not always be a netCDF4 variable there. 
--- xarray/backends/netCDF4_.py | 45 ++++++++++++++++------------------- xarray/tests/test_backends.py | 6 ++--- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 8ff170f8f3c..b029a3988f7 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -73,31 +73,8 @@ def __init__(self, variable_name, datastore): self.dtype = dtype def __setitem__(self, key, value): - import netCDF4 - with self.datastore.lock: data = self.get_array(needs_lock=False) - if isinstance(data.datatype, netCDF4.EnumType): - # Make sure the values we are trying to assign are in enums valid range. - error_message = ( - "Cannot save the variable `{0}` to netCDF4:" - " `{0}` has values, such as `{1}`, which are not in the" - " enum/flag_values valid values: `{2}`. Fix the variable data or edit" - " its flag_values and flag_meanings attributes and try again. Note that" - " if the enum values are not contiguous, there might be other invalid" - " values too" - ) - valid_values = tuple(data.datatype.enum_dict.values()) - max_val = np.max(value) - min_val = np.min(value) - if max_val not in valid_values: - raise ValueError( - error_message.format(data.name, max_val, valid_values) - ) - if min_val not in valid_values: - raise ValueError( - error_message.format(data.name, min_val, valid_values) - ) data[key] = value if self.datastore.autoclose: self.datastore.close(needs_lock=False) @@ -521,7 +498,11 @@ def prepare_variable( encoding = _extract_nc4_variable_encoding( variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims ) - if encoding.get("enum") is not None: + if ( + encoding.get("enum") + and attrs.get("flag_meanings") + and attrs.get("flag_values") + ): enum_dict = { k: v for k, v in zip(attrs["flag_meanings"], attrs["flag_values"]) } @@ -534,6 +515,22 @@ def prepare_variable( enum_name, enum_dict, ) + # Make sure the values we are trying to assign are in enums valid range. 
+ error_message = ( + "Cannot save the variable `{0}` to netCDF4: `{0}` has values, such" + " as `{1}`, which are not in the enum/flag_values valid values:" + " `{2}`. Fix the variable data or edit its flag_values and" + " flag_meanings attributes and try again. Note that if the enum" + " values are not contiguous, there might be other invalid values" + " too." + ) + valid_values = tuple(attrs["flag_values"]) + max_val = np.max(variable.data) + min_val = np.min(variable.data) + if max_val not in valid_values: + raise ValueError(error_message.format(name, max_val, valid_values)) + if min_val not in valid_values: + raise ValueError(error_message.format(name, min_val, valid_values)) del attrs["flag_values"] del attrs["flag_meanings"] else: diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 948b5e4a9ba..9734b161d4e 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1740,10 +1740,10 @@ def test_encoding_enum__error_handling(self): with open_dataset(tmp_file) as ds: assert np.all(ds.clouds.values == 255) with create_tmp_file() as tmp_file2: - with pytest.raises(ValueError) as err: + with pytest.raises( + ValueError, match="Cannot save the variable `clouds` .*" + ): ds.to_netcdf(tmp_file2) - assert True is False - assert "Cannot save the variable clouds" in err.args[0] @requires_netCDF4 def test_encoding_enum__multiple_variable_with_enum(self): From 5f1bffc7878b08134a7b9d07c7b4028a015035b5 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Thu, 9 Nov 2023 16:05:27 +0100 Subject: [PATCH 17/37] FIX: unit test for min-all-deps requirements The behavior of the netCDF4 lib has changed in later versions about how fill_values are handled. This commit make sure the unit test has the expected behavior in every version. 
--- xarray/tests/test_backends.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 9734b161d4e..1ef40cd5325 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1734,14 +1734,14 @@ def test_encoding_enum__error_handling(self): "clouds", cloud_type, "time", - fill_value=None, + fill_value=255, ) # v is filled with default fill_value of u1 with open_dataset(tmp_file) as ds: - assert np.all(ds.clouds.values == 255) with create_tmp_file() as tmp_file2: with pytest.raises( - ValueError, match="Cannot save the variable `clouds` .*" + ValueError, + match=("Cannot save the variable `clouds` to" " netCDF4.*"), ): ds.to_netcdf(tmp_file2) From 2410c2ee266d12570bcaf91e481aa0b1a2529990 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Fri, 10 Nov 2023 13:50:33 +0100 Subject: [PATCH 18/37] ENH: Add enum discovery When an existing enum has the same name but a different value, a new enum is created.
--- xarray/backends/netCDF4_.py | 15 +++++++++++++-- xarray/tests/test_backends.py | 36 +++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index b029a3988f7..54e6888d82f 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -503,17 +503,28 @@ def prepare_variable( and attrs.get("flag_meanings") and attrs.get("flag_values") ): - enum_dict = { + var_enum_dict = { k: v for k, v in zip(attrs["flag_meanings"], attrs["flag_values"]) } enum_name = encoding["enum"] if enum_name in self.ds.enumtypes: datatype = self.ds.enumtypes[enum_name] + if datatype.enum_dict != var_enum_dict: + datatype = None + for e_name, e_val in self.ds.enumtypes.items(): + if e_val.enum_dict == var_enum_dict: + datatype = self.ds.enumtypes[e_name] + if datatype is None: + datatype = self.ds.createEnumType( + variable.dtype, + f"{enum_name}_{name}", + var_enum_dict, + ) else: datatype = self.ds.createEnumType( variable.dtype, enum_name, - enum_dict, + var_enum_dict, ) # Make sure the values we are trying to assign are in enums valid range. 
error_message = ( diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 1ef40cd5325..1055699326e 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1768,6 +1768,42 @@ def test_encoding_enum__multiple_variable_with_enum(self): with create_tmp_file() as tmp_file2: ds.to_netcdf(tmp_file2) + @requires_netCDF4 + def test_encoding_enum__multiple_variable_with_changing_enum(self): + with create_tmp_file() as tmp_file: + cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255} + with nc4.Dataset(tmp_file, mode="w") as nc: + nc.createDimension("time", size=2) + cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict) + nc.createVariable( + "cloud", + cloud_type, + "time", + fill_value=255, + ) + nc.createVariable( + "tifa", + cloud_type, + "time", + fill_value=255, + ) + nc.createVariable( + "barret", + cloud_type, + "time", + fill_value=255, + ) + with open_dataset(tmp_file) as ds: + ds.cloud.attrs["flag_values"] += (2,) + ds.cloud.attrs["flag_meanings"] += ("neblig",) + ds.tifa.attrs["flag_values"] += (2,) + ds.tifa.attrs["flag_meanings"] += ("neblig",) + with create_tmp_file() as tmp_file2: + ds.to_netcdf(tmp_file2) + with nc4.Dataset(tmp_file2, mode="r") as nc: + # We want to assert that two different enums are written. + assert len(nc.enumtypes.keys()) == 2 + @requires_netCDF4 class TestNetCDF4Data(NetCDF4Base): From 4b966ba6e3b92b8db7d7456360d167eeb1ec36b3 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Fri, 10 Nov 2023 14:19:09 +0100 Subject: [PATCH 19/37] ENH: Raise error instead of modifying dataset When there are multiple variable that are encoded to use the same enum but actually have different flag_*, we now raise a meaningful error instead of creating an adhoc enum. 
--- xarray/backends/netCDF4_.py | 20 ++++++++++---------- xarray/tests/test_backends.py | 20 ++++++++------------ 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 54e6888d82f..bb09079b669 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -510,16 +510,16 @@ def prepare_variable( if enum_name in self.ds.enumtypes: datatype = self.ds.enumtypes[enum_name] if datatype.enum_dict != var_enum_dict: - datatype = None - for e_name, e_val in self.ds.enumtypes.items(): - if e_val.enum_dict == var_enum_dict: - datatype = self.ds.enumtypes[e_name] - if datatype is None: - datatype = self.ds.createEnumType( - variable.dtype, - f"{enum_name}_{name}", - var_enum_dict, - ) + raise ValueError( + f"Cannot save variable `{name}` because an enum" + f" `{enum_name}` already exists in the Dataset but have" + " a different definition. Enums are created when" + " `encoding['enum']` is set by combining flag_values" + " and flag_meanings attributes. To fix this error, make sure" + " each variable have a unique name for `encoding['enum']` or " + " if they should have the same enum, that their flag_values and" + " flag_meanings are identical." 
+ ) else: datatype = self.ds.createEnumType( variable.dtype, diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 1055699326e..f19d365029e 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1787,22 +1787,18 @@ def test_encoding_enum__multiple_variable_with_changing_enum(self): "time", fill_value=255, ) - nc.createVariable( - "barret", - cloud_type, - "time", - fill_value=255, - ) with open_dataset(tmp_file) as ds: ds.cloud.attrs["flag_values"] += (2,) ds.cloud.attrs["flag_meanings"] += ("neblig",) - ds.tifa.attrs["flag_values"] += (2,) - ds.tifa.attrs["flag_meanings"] += ("neblig",) with create_tmp_file() as tmp_file2: - ds.to_netcdf(tmp_file2) - with nc4.Dataset(tmp_file2, mode="r") as nc: - # We want to assert that two different enums are written. - assert len(nc.enumtypes.keys()) == 2 + with pytest.raises( + ValueError, + match=( + "Cannot save variable .*" + " because an enum `cloud_type` already exists in the Dataset .*" + ), + ): + ds.to_netcdf(tmp_file2) @requires_netCDF4 From ca043a750d3474a24ab14334e4e0405154fdaf0b Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Fri, 5 Jan 2024 12:13:18 +0100 Subject: [PATCH 20/37] FIX: pop unnecessary encoding Before calling netcdf4-python createVariable, we need to clear the unwanted arguments from encoding. 
--- xarray/backends/netCDF4_.py | 6 ++++-- xarray/core/dataarray.py | 10 +++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index c328121bf60..afa4a54a0cb 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -512,11 +512,11 @@ def prepare_variable( var_enum_dict = { k: v for k, v in zip(attrs["flag_meanings"], attrs["flag_values"]) } - enum_name = encoding["enum"] + enum_name = encoding.pop("enum") if enum_name in self.ds.enumtypes: datatype = self.ds.enumtypes[enum_name] if datatype.enum_dict != var_enum_dict: - raise ValueError( + error_msg = ( f"Cannot save variable `{name}` because an enum" f" `{enum_name}` already exists in the Dataset but have" " a different definition. Enums are created when" @@ -526,6 +526,7 @@ def prepare_variable( " if they should have the same enum, that their flag_values and" " flag_meanings are identical." ) + raise ValueError(error_msg) else: datatype = self.ds.createEnumType( variable.dtype, @@ -554,6 +555,7 @@ def prepare_variable( datatype = _get_datatype( variable, self.format, raise_on_invalid_encoding=check_encoding ) + encoding.pop("dtype", None) if name in self.ds.variables: nc4_var = self.ds.variables[name] else: diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index c6fc0baed7f..c0b4757be10 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3989,11 +3989,6 @@ def to_netcdf( ) -> bytes | Delayed | None: """Write DataArray contents to a netCDF file. - [netCDF4 backend only] When the CF flag_values/flag_meanings attributes are - set in for this DataArray, you can choose to replace these attributes with - EnumType by updating the encoding dictionary with a key value pair like: - `encoding["enum"] = "enum_name"`. 
- Parameters ---------- path : str, path-like or None, optional @@ -4074,6 +4069,11 @@ def to_netcdf( name is the same as a coordinate name, then it is given the name ``"__xarray_dataarray_variable__"``. + [netCDF4 backend only] When the CF flag_values/flag_meanings attributes are + set in for this DataArray, you can choose to replace these attributes by + a netcdf4 EnumType by updating the encoding dictionary with a key value pair + like: `encoding["enum"] = "enum_name"`. + See Also -------- Dataset.to_netcdf From ee3dc0033f19c5b53ab5b0f10313c247a75be7f1 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Mon, 8 Jan 2024 21:52:55 +0100 Subject: [PATCH 21/37] Add Enum Coder --- xarray/backends/api.py | 7 ++- xarray/backends/netCDF4_.py | 83 ++++++++++++++--------------------- xarray/backends/store.py | 2 + xarray/coding/variables.py | 32 ++++++++++++++ xarray/conventions.py | 22 +++++++++- xarray/core/dataarray.py | 2 +- xarray/tests/test_backends.py | 25 +++++------ 7 files changed, 107 insertions(+), 66 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 1d538bf94ed..9f3adc3a377 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -2,6 +2,7 @@ import os from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence +from enum import EnumType from functools import partial from io import BytesIO from numbers import Number @@ -172,7 +173,7 @@ def _validate_attrs(dataset, invalid_netcdf=False): `invalid_netcdf=True`. """ - valid_types = (str, Number, np.ndarray, np.number, list, tuple) + valid_types = (str, Number, np.ndarray, np.number, list, tuple, EnumType) if invalid_netcdf: valid_types += (np.bool_,) @@ -407,6 +408,7 @@ def open_dataset( chunked_array_type: str | None = None, from_array_kwargs: dict[str, Any] | None = None, backend_kwargs: dict[str, Any] | None = None, + decode_enum: bool | None = None, **kwargs, ) -> Dataset: """Open and decode a dataset from a file or file-like object. 
@@ -512,6 +514,8 @@ def open_dataset( backend_kwargs: dict Additional keyword arguments passed on to the engine open function, equivalent to `**kwargs`. + decode_enum: bool, optional + If True, decode CF flag_values and flag_meanings into a pyton Enum. **kwargs: dict Additional keyword arguments passed on to the engine open function. For example: @@ -566,6 +570,7 @@ def open_dataset( concat_characters=concat_characters, use_cftime=use_cftime, decode_coords=decode_coords, + decode_enum=decode_enum, ) overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index afa4a54a0cb..e2266a13248 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -426,7 +426,7 @@ def open_store_variable(self, name: str, var): if isinstance(var.datatype, netCDF4.EnumType): enum_dict = var.datatype.enum_dict enum_name = var.datatype.name - encoding["enum"] = enum_name + attributes["enum"] = enum_name attributes["flag_values"] = tuple(enum_dict.values()) attributes["flag_meanings"] = tuple(enum_dict.keys()) _ensure_fill_value_valid(data, attributes) @@ -504,58 +504,12 @@ def prepare_variable( encoding = _extract_nc4_variable_encoding( variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims ) - if ( - encoding.get("enum") - and attrs.get("flag_meanings") - and attrs.get("flag_values") - ): - var_enum_dict = { - k: v for k, v in zip(attrs["flag_meanings"], attrs["flag_values"]) - } - enum_name = encoding.pop("enum") - if enum_name in self.ds.enumtypes: - datatype = self.ds.enumtypes[enum_name] - if datatype.enum_dict != var_enum_dict: - error_msg = ( - f"Cannot save variable `{name}` because an enum" - f" `{enum_name}` already exists in the Dataset but have" - " a different definition. Enums are created when" - " `encoding['enum']` is set by combining flag_values" - " and flag_meanings attributes. 
To fix this error, make sure" - " each variable have a unique name for `encoding['enum']` or " - " if they should have the same enum, that their flag_values and" - " flag_meanings are identical." - ) - raise ValueError(error_msg) - else: - datatype = self.ds.createEnumType( - variable.dtype, - enum_name, - var_enum_dict, - ) - # Make sure the values we are trying to assign are in enums valid range. - error_message = ( - "Cannot save the variable `{0}` to netCDF4: `{0}` has values, such" - " as `{1}`, which are not in the enum/flag_values valid values:" - " `{2}`. Fix the variable data or edit its flag_values and" - " flag_meanings attributes and try again. Note that if the enum" - " values are not contiguous, there might be other invalid values" - " too." - ) - valid_values = tuple(attrs["flag_values"]) - max_val = np.max(variable.data) - min_val = np.min(variable.data) - if max_val not in valid_values: - raise ValueError(error_message.format(name, max_val, valid_values)) - if min_val not in valid_values: - raise ValueError(error_message.format(name, min_val, valid_values)) - del attrs["flag_values"] - del attrs["flag_meanings"] + if attrs.get("enum"): + datatype = self._build_and_get_enum(name, attrs, variable.dtype) else: datatype = _get_datatype( variable, self.format, raise_on_invalid_encoding=check_encoding ) - encoding.pop("dtype", None) if name in self.ds.variables: nc4_var = self.ds.variables[name] else: @@ -583,6 +537,35 @@ def prepare_variable( return target, variable.data + def _build_and_get_enum( + self, var_name: str, attributes: dict, dtype: np.dtype + ) -> object: + flag_meanings = attributes.pop("flag_meanings") + flag_values = attributes.pop("flag_values") + enum_name = attributes.pop("enum") + enum_dict = {k: v for k, v in zip(flag_meanings, flag_values)} + if enum_name in self.ds.enumtypes: + datatype = self.ds.enumtypes[enum_name] + if datatype.enum_dict != enum_dict: + error_msg = ( + f"Cannot save variable `{var_name}` because an enum" + f" 
`{enum_name}` already exists in the Dataset but have" + " a different definition. Enums are created when" + " `attrs['enum']` is filled with an enum name, then flag_values" + " and flag_meanings attributes are combined. To fix this error, make sure" + " each variable have a unique name for `attrs['enum']` or " + " if they should be typed with the same enum, that their flag_values and" + " flag_meanings are identical." + ) + raise ValueError(error_msg) + else: + datatype = self.ds.createEnumType( + dtype, + enum_name, + enum_dict, + ) + return datatype + def sync(self): self.ds.sync() @@ -653,6 +636,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti persist=False, lock=None, autoclose=False, + decode_enum: bool | None = None, ) -> Dataset: filename_or_obj = _normalize_path(filename_or_obj) store = NetCDF4DataStore.open( @@ -678,6 +662,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti drop_variables=drop_variables, use_cftime=use_cftime, decode_timedelta=decode_timedelta, + decode_enum=decode_enum, ) return ds diff --git a/xarray/backends/store.py b/xarray/backends/store.py index a507ee37470..7af08aaade7 100644 --- a/xarray/backends/store.py +++ b/xarray/backends/store.py @@ -37,6 +37,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti drop_variables: str | Iterable[str] | None = None, use_cftime=None, decode_timedelta=None, + decode_enum: bool | None = None, ) -> Dataset: assert isinstance(filename_or_obj, AbstractDataStore) @@ -53,6 +54,7 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti drop_variables=drop_variables, use_cftime=use_cftime, decode_timedelta=decode_timedelta, + decode_enum=decode_enum, ) ds = Dataset(vars, attrs=attrs) diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 487197605e8..455a7919b8a 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -3,6 +3,7 @@ import warnings 
from collections.abc import Hashable, MutableMapping +from enum import Enum from functools import partial from typing import TYPE_CHECKING, Any, Callable, Union @@ -574,3 +575,34 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable: return variable else: return variable + + +class EnumCoder(VariableCoder): + """Encode and decode Enum to CF""" + + def encode(self, variable: Variable, name: T_Name = None) -> Variable: + """From python Enum to CF""" + dims, data, attrs, encoding = unpack_for_encoding(variable) + if attrs.get("enum"): + enum = attrs.pop("enum") + enum_name = enum.__name__ + attrs["flag_meanings"] = tuple(i.name for i in enum) + attrs["flag_values"] = tuple(i.value for i in enum) + attrs["enum"] = enum_name + return Variable(dims, data, attrs, encoding, fastpath=True) + + def decode(self, variable: Variable, name: T_Name = None) -> Variable: + """From CF to python Enum""" + dims, data, attrs, encoding = unpack_for_decoding(variable) + if ( + attrs.get("enum") + and attrs.get("flag_meanings") + and attrs.get("flag_values") + ): + flag_meanings = attrs.pop("flag_meanings") + flag_values = attrs.pop("flag_values") + enum_name = attrs.pop("enum") + enum_dict = {k: v for k, v in zip(flag_meanings, flag_values)} + attrs["enum"] = Enum(enum_name, enum_dict) + return Variable(dims, data, attrs, encoding, fastpath=True) + return variable diff --git a/xarray/conventions.py b/xarray/conventions.py index 2f8d0a893f9..af9f4ac7a99 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -167,13 +167,14 @@ def encode_cf_variable( var: Variable, needs_copy: bool = True, name: T_Name = None ) -> Variable: """ - Converts an Variable into an Variable which follows some + Converts a Variable into a Variable which follows some of the CF conventions: - Nans are masked using _FillValue (or the deprecated missing_value) - Rescaling via: scale_factor and add_offset - datetimes are converted to the CF 'units since time' format - dtype encodings are 
enforced. + - enum is turned into flag_values and flag_meanings Parameters ---------- @@ -196,6 +197,7 @@ def encode_cf_variable( variables.NonStringCoder(), variables.DefaultFillvalueCoder(), variables.BooleanCoder(), + variables.EnumCoder(), ]: var = coder.encode(var, name=name) @@ -217,6 +219,7 @@ def decode_cf_variable( stack_char_dim: bool = True, use_cftime: bool | None = None, decode_timedelta: bool | None = None, + decode_enum: bool | None = None, ) -> Variable: """ Decodes a variable which may hold CF encoded information. @@ -257,6 +260,8 @@ def decode_cf_variable( represented using ``np.datetime64[ns]`` objects. If False, always decode times to ``np.datetime64[ns]`` objects; if this is not possible raise an error. + decode_enum: bool, optional + Turn the CF flag_values and flag_meanings into a python Enum in `attrs['enum']`. Returns ------- @@ -300,6 +305,9 @@ def decode_cf_variable( var = variables.BooleanCoder().decode(var) + if decode_enum: + var = variables.EnumCoder().decode(var) + dimensions, data, attributes, encoding = variables.unpack_for_decoding(var) encoding.setdefault("dtype", original_dtype) @@ -398,6 +406,7 @@ def decode_cf_variables( drop_variables: T_DropVariables = None, use_cftime: bool | None = None, decode_timedelta: bool | None = None, + decode_enum: bool | None = None, ) -> tuple[T_Variables, T_Attrs, set[Hashable]]: """ Decode several CF encoded variables. @@ -450,6 +459,7 @@ def stackable(dim: Hashable) -> bool: stack_char_dim=stack_char_dim, use_cftime=use_cftime, decode_timedelta=decode_timedelta, + decode_enum=decode_enum, ) except Exception as e: raise type(e)(f"Failed to decode variable {k!r}: {e}") @@ -514,6 +524,7 @@ def decode_cf( drop_variables: T_DropVariables = None, use_cftime: bool | None = None, decode_timedelta: bool | None = None, + decode_enum: bool = True, ) -> Dataset: """Decode the given Dataset or Datastore according to CF conventions into a new Dataset. 
@@ -592,6 +603,7 @@ def decode_cf( drop_variables=drop_variables, use_cftime=use_cftime, decode_timedelta=decode_timedelta, + decode_enum=decode_enum, ) ds = Dataset(vars, attrs=attrs) ds = ds.set_coords(coord_names.union(extra_coords).intersection(vars)) @@ -607,6 +619,7 @@ def cf_decoder( concat_characters: bool = True, mask_and_scale: bool = True, decode_times: bool = True, + decode_enum: bool = True, ) -> tuple[T_Variables, T_Attrs]: """ Decode a set of CF encoded variables and attributes. @@ -638,7 +651,12 @@ def cf_decoder( decode_cf_variable """ variables, attributes, _ = decode_cf_variables( - variables, attributes, concat_characters, mask_and_scale, decode_times + variables, + attributes, + concat_characters, + mask_and_scale, + decode_times, + decode_enum=decode_enum, ) return variables, attributes diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index c0b4757be10..d5ac5a79b7f 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -4072,7 +4072,7 @@ def to_netcdf( [netCDF4 backend only] When the CF flag_values/flag_meanings attributes are set in for this DataArray, you can choose to replace these attributes by a netcdf4 EnumType by updating the encoding dictionary with a key value pair - like: `encoding["enum"] = "enum_name"`. + like: `da.attrs["enum"] = "enum_name"`. 
See Also -------- diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index aa8748b5829..3233564938c 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -15,6 +15,7 @@ import warnings from collections.abc import Generator, Iterator from contextlib import ExitStack +from enum import Enum from io import BytesIO from os import listdir from pathlib import Path @@ -1720,13 +1721,10 @@ def test_encoding_enum__no_fill_value(self): fill_value=None, ) v[:] = 1 - with open_dataset(tmp_file) as ds: - assert list(ds.clouds.attrs.get("flag_meanings")) == list( - cloud_type_dict.keys() - ) - assert list(ds.clouds.attrs.get("flag_values")) == list( - cloud_type_dict.values() - ) + with open_dataset(tmp_file, decode_enum=True) as ds: + assert { + i.name: i.value for i in ds.clouds.attrs["enum"] + } == cloud_type_dict with create_tmp_file() as tmp_file2: ds.to_netcdf(tmp_file2) @@ -1744,11 +1742,11 @@ def test_encoding_enum__error_handling(self): fill_value=255, ) # v is filled with default fill_value of u1 - with open_dataset(tmp_file) as ds: + with open_dataset(tmp_file, decode_enum=True) as ds: with create_tmp_file() as tmp_file2: with pytest.raises( ValueError, - match=("Cannot save the variable `clouds` to" " netCDF4.*"), + match=("trying to assign illegal value to Enum variable"), ): ds.to_netcdf(tmp_file2) @@ -1771,7 +1769,7 @@ def test_encoding_enum__multiple_variable_with_enum(self): "time", fill_value=255, ) - with open_dataset(tmp_file) as ds: + with open_dataset(tmp_file, decode_enum=True) as ds: with create_tmp_file() as tmp_file2: ds.to_netcdf(tmp_file2) @@ -1794,9 +1792,10 @@ def test_encoding_enum__multiple_variable_with_changing_enum(self): "time", fill_value=255, ) - with open_dataset(tmp_file) as ds: - ds.cloud.attrs["flag_values"] += (2,) - ds.cloud.attrs["flag_meanings"] += ("neblig",) + with open_dataset(tmp_file, decode_enum=True) as ds: + cloud_enum_dict = {e.name: e.value for e in ds.cloud.attrs["enum"]} 
+ cloud_enum_dict.update({"neblig": 2}) + ds.cloud.attrs["enum"] = Enum("cloud_type", cloud_enum_dict) with create_tmp_file() as tmp_file2: with pytest.raises( ValueError, From 9ab1ad16d0002dcd5fa6780523f5c5b6e18304fa Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Tue, 9 Jan 2024 09:49:48 +0100 Subject: [PATCH 22/37] DOC: Update what's new --- doc/whats-new.rst | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0cae8fe4808..ce2925009d9 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -222,9 +222,8 @@ New Features - Use `opt_einsum `_ for :py:func:`xarray.dot` by default if installed. By `Deepak Cherian `_. (:issue:`7764`, :pull:`8373`). - Open netCDF4 enums and turn them into CF flag_meanings/flag_values. - This also adds a new encoding key `enum` to DataArray that tells the netCDF4 backend - to turn flag_meanings and flag_values into Enums when calling - :py:meth:`Dataset.to_netcdf`. + This also gives a special meaning to the 'enum' attribute in DataArrays, when it is set, this tells the netCDF4 backend + to turn flag_meanings and flag_values into a netCDF4 Enum named using ``attrs["enum"]`` content. By `Abel Aoun _`(:issue:`8144`, :pull:`8147`) - Add ``DataArray.dt.total_seconds()`` method to match the Pandas API. (:pull:`8435`). By `Ben Mares `_. 
From 7219b9976206949930946c17de15dfa84399a051 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Tue, 9 Jan 2024 09:50:35 +0100 Subject: [PATCH 23/37] FIX: Use EnumMeta instead of EnumType for py<3.11 --- xarray/backends/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 9f3adc3a377..9cee5529b0c 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -2,7 +2,7 @@ import os from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence -from enum import EnumType +from enum import EnumMeta from functools import partial from io import BytesIO from numbers import Number @@ -173,7 +173,7 @@ def _validate_attrs(dataset, invalid_netcdf=False): `invalid_netcdf=True`. """ - valid_types = (str, Number, np.ndarray, np.number, list, tuple, EnumType) + valid_types = (str, Number, np.ndarray, np.number, list, tuple, EnumMeta) if invalid_netcdf: valid_types += (np.bool_,) From 2aa119f455a860d61df9f78c28656d55c4883b69 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Tue, 9 Jan 2024 09:54:42 +0100 Subject: [PATCH 24/37] ENH: Improve error message --- xarray/backends/netCDF4_.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index e2266a13248..bde6aab9561 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -550,12 +550,10 @@ def _build_and_get_enum( error_msg = ( f"Cannot save variable `{var_name}` because an enum" f" `{enum_name}` already exists in the Dataset but have" - " a different definition. Enums are created when" - " `attrs['enum']` is filled with an enum name, then flag_values" - " and flag_meanings attributes are combined. To fix this error, make sure" - " each variable have a unique name for `attrs['enum']` or " - " if they should be typed with the same enum, that their flag_values and" - " flag_meanings are identical." + " a different definition.
To fix this error, make sure" + " each variable have a unique name for their `attrs['enum']`" + " or, if they should share same enum type, make sure" + " their flag_values and flag_meanings are identical." ) raise ValueError(error_msg) else: From da43a10f3044790c135b053cb82c395ca2bd0c19 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Tue, 9 Jan 2024 09:54:59 +0100 Subject: [PATCH 25/37] Remove unnecessary test --- xarray/tests/test_backends.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index a3e6f884283..b3476454d66 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1726,28 +1726,6 @@ def test_encoding_enum__no_fill_value(self): with create_tmp_file() as tmp_file2: ds.to_netcdf(tmp_file2) - @requires_netCDF4 - def test_encoding_enum__error_handling(self): - with create_tmp_file() as tmp_file: - cloud_type_dict = {"clear": 0, "cloudy": 1} - with nc4.Dataset(tmp_file, mode="w") as nc: - nc.createDimension("time", size=2) - cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict) - nc.createVariable( - "clouds", - cloud_type, - "time", - fill_value=255, - ) - # v is filled with default fill_value of u1 - with open_dataset(tmp_file, decode_enum=True) as ds: - with create_tmp_file() as tmp_file2: - with pytest.raises( - ValueError, - match=("trying to assign illegal value to Enum variable"), - ): - ds.to_netcdf(tmp_file2) - @requires_netCDF4 def test_encoding_enum__multiple_variable_with_enum(self): with create_tmp_file() as tmp_file: From 096f021e6b9b362dbd9a37fef1f7325d1342fe5f Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Tue, 9 Jan 2024 11:26:41 +0100 Subject: [PATCH 26/37] Update enum Coder --- xarray/backends/netCDF4_.py | 11 ++++------- xarray/coding/variables.py | 14 ++++++++------ xarray/tests/test_coding.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 13 deletions(-) diff --git 
a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index bde6aab9561..ca7f052455c 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -419,10 +419,7 @@ def open_store_variable(self, name: str, var): dimensions = var.dimensions attributes = {k: var.getncattr(k) for k in var.ncattrs()} - encoding = {} data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) - enum_dict = None - enum_name = None if isinstance(var.datatype, netCDF4.EnumType): enum_dict = var.datatype.enum_dict enum_name = var.datatype.name @@ -431,7 +428,7 @@ def open_store_variable(self, name: str, var): attributes["flag_meanings"] = tuple(enum_dict.keys()) _ensure_fill_value_valid(data, attributes) # netCDF4 specific encoding; save _FillValue for later - + encoding = {} filters = var.filters() if filters is not None: encoding.update(filters) @@ -501,15 +498,15 @@ def prepare_variable( _ensure_no_forward_slash_in_name(name) attrs = variable.attrs.copy() fill_value = attrs.pop("_FillValue", None) - encoding = _extract_nc4_variable_encoding( - variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims - ) if attrs.get("enum"): datatype = self._build_and_get_enum(name, attrs, variable.dtype) else: datatype = _get_datatype( variable, self.format, raise_on_invalid_encoding=check_encoding ) + encoding = _extract_nc4_variable_encoding( + variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims + ) if name in self.ds.variables: nc4_var = self.ds.variables[name] else: diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 17b62d0d0dd..1f1caf4bad5 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -3,7 +3,7 @@ import warnings from collections.abc import Hashable, MutableMapping -from enum import Enum +from enum import Enum, EnumMeta from functools import partial from typing import TYPE_CHECKING, Any, Callable, Union @@ -581,18 +581,18 @@ class EnumCoder(VariableCoder): """Encode and decode 
Enum to CF""" def encode(self, variable: Variable, name: T_Name = None) -> Variable: - """From python Enum to CF""" + """From python Enum to CF flag_*""" dims, data, attrs, encoding = unpack_for_encoding(variable) - if attrs.get("enum"): + if isinstance(attrs.get("enum"), EnumMeta): enum = attrs.pop("enum") enum_name = enum.__name__ - attrs["flag_meanings"] = tuple(i.name for i in enum) - attrs["flag_values"] = tuple(i.value for i in enum) + attrs["flag_meanings"] = " ".join(i.name for i in enum) + attrs["flag_values"] = ", ".join(str(i.value) for i in enum) attrs["enum"] = enum_name return Variable(dims, data, attrs, encoding, fastpath=True) def decode(self, variable: Variable, name: T_Name = None) -> Variable: - """From CF to python Enum""" + """From CF flag_* to python Enum""" dims, data, attrs, encoding = unpack_for_decoding(variable) if ( attrs.get("enum") @@ -600,7 +600,9 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable: and attrs.get("flag_values") ): flag_meanings = attrs.pop("flag_meanings") + flag_meanings = flag_meanings.split(" ") flag_values = attrs.pop("flag_values") + flag_values = [int(v) for v in flag_values.split(", ")] enum_name = attrs.pop("enum") enum_dict = {k: v for k, v in zip(flag_meanings, flag_values)} attrs["enum"] = Enum(enum_name, enum_dict) diff --git a/xarray/tests/test_coding.py b/xarray/tests/test_coding.py index f7579c4b488..7152ffd8638 100644 --- a/xarray/tests/test_coding.py +++ b/xarray/tests/test_coding.py @@ -1,6 +1,7 @@ from __future__ import annotations from contextlib import suppress +from enum import Enum, EnumMeta import numpy as np import pandas as pd @@ -146,3 +147,31 @@ def test_decode_signed_from_unsigned(bits) -> None: decoded = coder.decode(encoded) assert decoded.dtype == signed_dtype assert decoded.values == original_values + + +def test_decode_enum() -> None: + encoded = xr.Variable( + ("x",), + [42], + attrs={ + "flag_values": "0, 1", + "flag_meanings": "flag galf", + "enum": 
"a_flag_name", + }, + ) + coder = variables.EnumCoder() + decoded = coder.decode(encoded) + assert isinstance(decoded.attrs["enum"], EnumMeta) + assert decoded.attrs["enum"].flag.value == 0 + assert decoded.attrs["enum"].galf.value == 1 + + +def test_encode_enum() -> None: + decoded = xr.Variable( + ("x",), [42], attrs={"enum": Enum("an_enum", {"flag": 0, "galf": 1})} + ) + coder = variables.EnumCoder() + encoded = coder.encode(decoded) + assert encoded.attrs["enum"] == "an_enum" + assert encoded.attrs["flag_values"] == "0, 1" + assert encoded.attrs["flag_meanings"] == "flag galf" From 26bb8ce49110bb4480f4d1019faee862cd0ad120 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Tue, 9 Jan 2024 11:46:36 +0100 Subject: [PATCH 27/37] ENH: Update error handling of decoding Preserve stacktrace using `from` when raise is used to wrap an existing exception. --- xarray/conventions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/conventions.py b/xarray/conventions.py index d5903752d94..aa3a6fc4fc4 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -453,7 +453,7 @@ def stackable(dim: Hashable) -> bool: decode_enum=decode_enum, ) except Exception as e: - raise type(e)(f"Failed to decode variable {k!r}: {e}") + raise type(e)(f"Failed to decode variable {k!r}: {e}") from e if decode_coords in [True, "coordinates", "all"]: var_attrs = new_vars[k].attrs if "coordinates" in var_attrs: From d21d73a3194150ca97bc6b68d20f1abcaf9d00c4 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Tue, 9 Jan 2024 14:01:43 +0100 Subject: [PATCH 28/37] ENH: Avoid encoding enum to CF --- toto.nc | Bin 0 -> 393 bytes xarray/backends/netCDF4_.py | 18 +++++++----------- xarray/coding/variables.py | 16 ++++------------ xarray/conventions.py | 2 -- xarray/tests/test_coding.py | 13 +------------ 5 files changed, 12 insertions(+), 37 deletions(-) create mode 100644 toto.nc diff --git a/toto.nc b/toto.nc new file mode 100644 index 
0000000000000000000000000000000000000000..cb541d389f12dbb119d78a0fc4558a1564557b59 GIT binary patch literal 393 zcmeD5aB<`1lHy|G;9!7(|4`7$2oW)WN*M8(Rr!0k1Tpb!VNwE%F+)`_z_g&#TucmL zC2Ue4^^7b~lNi`@D^v4IbB*3Z1w>^SKypCy&`p5x1q47!&7c+rL-jGJ!5D<;bAR@} z06Smdw))&uWURxi!N6+4#J~Uw0|6ib2@3~?M1~}|xiGK3>P_G7@9g2t;|h{uX5a&= z2D&}o&pDtdzaX`!Br~;`K^Q0n3L>DH3}vZB#hLkewnjR6sU^uNX|^VMmU>1y87XO| PwuX9!CVEEAem#N!>U?GE literal 0 HcmV?d00001 diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index ca7f052455c..029a47151f0 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -5,6 +5,7 @@ import os from collections.abc import Iterable from contextlib import suppress +from enum import Enum from typing import TYPE_CHECKING, Any import numpy as np @@ -421,11 +422,7 @@ def open_store_variable(self, name: str, var): attributes = {k: var.getncattr(k) for k in var.ncattrs()} data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) if isinstance(var.datatype, netCDF4.EnumType): - enum_dict = var.datatype.enum_dict - enum_name = var.datatype.name - attributes["enum"] = enum_name - attributes["flag_values"] = tuple(enum_dict.values()) - attributes["flag_meanings"] = tuple(enum_dict.keys()) + attributes["enum"] = Enum(var.datatype.name, var.datatype.enum_dict) _ensure_fill_value_valid(data, attributes) # netCDF4 specific encoding; save _FillValue for later encoding = {} @@ -537,10 +534,9 @@ def prepare_variable( def _build_and_get_enum( self, var_name: str, attributes: dict, dtype: np.dtype ) -> object: - flag_meanings = attributes.pop("flag_meanings") - flag_values = attributes.pop("flag_values") - enum_name = attributes.pop("enum") - enum_dict = {k: v for k, v in zip(flag_meanings, flag_values)} + enum = attributes.pop("enum") + enum_dict = {e.name: e.value for e in enum} + enum_name = enum.__name__ if enum_name in self.ds.enumtypes: datatype = self.ds.enumtypes[enum_name] if datatype.enum_dict != enum_dict: @@ -548,9 +544,9 @@ def 
_build_and_get_enum( f"Cannot save variable `{var_name}` because an enum" f" `{enum_name}` already exists in the Dataset but have" " a different definition. To fix this error, make sure" - " each variable have a unique name for their `attrs['enum']`" + " each variable have a unique name in `attrs['enum']`" " or, if they should share same enum type, make sure" - " their flag_values and flag_meanings are identical." + " the enums are identical." ) raise ValueError(error_msg) else: diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 1f1caf4bad5..1d0830dfd4c 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -3,7 +3,7 @@ import warnings from collections.abc import Hashable, MutableMapping -from enum import Enum, EnumMeta +from enum import Enum from functools import partial from typing import TYPE_CHECKING, Any, Callable, Union @@ -567,7 +567,7 @@ def decode(self): class ObjectVLenStringCoder(VariableCoder): def encode(self): - return NotImplementedError + raise NotImplementedError def decode(self, variable: Variable, name: T_Name = None) -> Variable: if variable.dtype == object and variable.encoding.get("dtype", False) == str: @@ -578,18 +578,10 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable: class EnumCoder(VariableCoder): - """Encode and decode Enum to CF""" + """Decode CF flag_* to python Enum""" def encode(self, variable: Variable, name: T_Name = None) -> Variable: - """From python Enum to CF flag_*""" - dims, data, attrs, encoding = unpack_for_encoding(variable) - if isinstance(attrs.get("enum"), EnumMeta): - enum = attrs.pop("enum") - enum_name = enum.__name__ - attrs["flag_meanings"] = " ".join(i.name for i in enum) - attrs["flag_values"] = ", ".join(str(i.value) for i in enum) - attrs["enum"] = enum_name - return Variable(dims, data, attrs, encoding, fastpath=True) + raise NotImplementedError def decode(self, variable: Variable, name: T_Name = None) -> Variable: """From CF flag_* to 
python Enum""" diff --git a/xarray/conventions.py b/xarray/conventions.py index aa3a6fc4fc4..c007c29cae6 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -165,7 +165,6 @@ def encode_cf_variable( - Rescaling via: scale_factor and add_offset - datetimes are converted to the CF 'units since time' format - dtype encodings are enforced. - - enum is turned into flag_values and flag_meanings Parameters ---------- @@ -188,7 +187,6 @@ def encode_cf_variable( variables.NonStringCoder(), variables.DefaultFillvalueCoder(), variables.BooleanCoder(), - variables.EnumCoder(), ]: var = coder.encode(var, name=name) diff --git a/xarray/tests/test_coding.py b/xarray/tests/test_coding.py index 7152ffd8638..13344630056 100644 --- a/xarray/tests/test_coding.py +++ b/xarray/tests/test_coding.py @@ -1,7 +1,7 @@ from __future__ import annotations from contextlib import suppress -from enum import Enum, EnumMeta +from enum import EnumMeta import numpy as np import pandas as pd @@ -164,14 +164,3 @@ def test_decode_enum() -> None: assert isinstance(decoded.attrs["enum"], EnumMeta) assert decoded.attrs["enum"].flag.value == 0 assert decoded.attrs["enum"].galf.value == 1 - - -def test_encode_enum() -> None: - decoded = xr.Variable( - ("x",), [42], attrs={"enum": Enum("an_enum", {"flag": 0, "galf": 1})} - ) - coder = variables.EnumCoder() - encoded = coder.encode(decoded) - assert encoded.attrs["enum"] == "an_enum" - assert encoded.attrs["flag_values"] == "0, 1" - assert encoded.attrs["flag_meanings"] == "flag galf" From 81a4beceddc8449126a8d8b11fc978109cfbe3a8 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Wed, 10 Jan 2024 16:07:12 +0100 Subject: [PATCH 29/37] ENH: encode netcdf4 enum within dtype --- xarray/backends/netCDF4_.py | 63 ++++++++++++++++++++--------------- xarray/tests/test_backends.py | 52 +++++++++++++++++------------ 2 files changed, 67 insertions(+), 48 deletions(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 
029a47151f0..4cf90e4c9b8 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -5,7 +5,6 @@ import os from collections.abc import Iterable from contextlib import suppress -from enum import Enum from typing import TYPE_CHECKING, Any import numpy as np @@ -141,7 +140,9 @@ def _check_encoding_dtype_is_vlen_string(dtype): ) -def _get_datatype(var, nc_format="NETCDF4", raise_on_invalid_encoding=False): +def _get_datatype( + var, nc_format="NETCDF4", raise_on_invalid_encoding=False +) -> np.dtype | None: if nc_format == "NETCDF4": return _nc4_dtype(var) if "dtype" in var.encoding: @@ -421,11 +422,19 @@ def open_store_variable(self, name: str, var): dimensions = var.dimensions attributes = {k: var.getncattr(k) for k in var.ncattrs()} data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) + encoding = {} if isinstance(var.datatype, netCDF4.EnumType): - attributes["enum"] = Enum(var.datatype.name, var.datatype.enum_dict) + encoding["dtype"] = np.dtype( + data.dtype, + metadata={ + "enum_dict": var.datatype.enum_dict, + "enum_name": var.datatype.name, + }, + ) + else: + encoding["dtype"] = var.dtype _ensure_fill_value_valid(data, attributes) # netCDF4 specific encoding; save _FillValue for later - encoding = {} filters = var.filters() if filters is not None: encoding.update(filters) @@ -445,7 +454,6 @@ def open_store_variable(self, name: str, var): # save source so __repr__ can detect if it's local or not encoding["source"] = self._filename encoding["original_shape"] = var.shape - encoding["dtype"] = var.dtype return Variable(dimensions, data, attributes, encoding) @@ -495,8 +503,12 @@ def prepare_variable( _ensure_no_forward_slash_in_name(name) attrs = variable.attrs.copy() fill_value = attrs.pop("_FillValue", None) - if attrs.get("enum"): - datatype = self._build_and_get_enum(name, attrs, variable.dtype) + if ( + variable.dtype.metadata + and variable.dtype.metadata.get("enum_name") + and variable.dtype.metadata.get("enum_dict") + ): + 
datatype = self._build_and_get_enum(name, variable.dtype) else: datatype = _get_datatype( variable, self.format, raise_on_invalid_encoding=check_encoding @@ -531,30 +543,27 @@ def prepare_variable( return target, variable.data - def _build_and_get_enum( - self, var_name: str, attributes: dict, dtype: np.dtype - ) -> object: - enum = attributes.pop("enum") - enum_dict = {e.name: e.value for e in enum} - enum_name = enum.__name__ - if enum_name in self.ds.enumtypes: - datatype = self.ds.enumtypes[enum_name] - if datatype.enum_dict != enum_dict: - error_msg = ( - f"Cannot save variable `{var_name}` because an enum" - f" `{enum_name}` already exists in the Dataset but have" - " a different definition. To fix this error, make sure" - " each variable have a unique name in `attrs['enum']`" - " or, if they should share same enum type, make sure" - " the enums are identical." - ) - raise ValueError(error_msg) - else: - datatype = self.ds.createEnumType( + def _build_and_get_enum(self, var_name: str, dtype: np.dtype) -> object: + """Add or get the netCDF4 Enum based on the dtype in encoding.""" + enum_dict = dtype.metadata["enum_dict"] + enum_name = dtype.metadata["enum_name"] + if enum_name not in self.ds.enumtypes: + return self.ds.createEnumType( dtype, enum_name, enum_dict, ) + datatype = self.ds.enumtypes[enum_name] + if datatype.enum_dict != enum_dict: + error_msg = ( + f"Cannot save variable `{var_name}` because an enum" + f" `{enum_name}` already exists in the Dataset but have" + " a different definition. To fix this error, make sure" + " each variable have a uniquely named enum in their" + " `encoding['dtype']` or, if they should share same enum type," + " make sure the enums are identical." 
+ ) + raise ValueError(error_msg) return datatype def sync(self): diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index b3476454d66..22ba82a5cda 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -15,7 +15,6 @@ import warnings from collections.abc import Generator, Iterator from contextlib import ExitStack -from enum import Enum from io import BytesIO from os import listdir from pathlib import Path @@ -1720,9 +1719,10 @@ def test_encoding_enum__no_fill_value(self): ) v[:] = 1 with open_dataset(tmp_file, decode_enum=True) as ds: - assert { - i.name: i.value for i in ds.clouds.attrs["enum"] - } == cloud_type_dict + assert ( + ds.clouds.encoding["dtype"].metadata["enum_dict"] == cloud_type_dict + ) + assert ds.clouds.encoding["dtype"].metadata["enum_name"] == "cloud_type" with create_tmp_file() as tmp_file2: ds.to_netcdf(tmp_file2) @@ -1745,12 +1745,18 @@ def test_encoding_enum__multiple_variable_with_enum(self): "time", fill_value=255, ) - with open_dataset(tmp_file, decode_enum=True) as ds: - with create_tmp_file() as tmp_file2: - ds.to_netcdf(tmp_file2) + with open_dataset( + tmp_file, decode_enum=True + ) as ds, create_tmp_file() as tmp_file2: + # nothing to assert ; just make sure round trip ca be done + ds.to_netcdf(tmp_file2) @requires_netCDF4 - def test_encoding_enum__multiple_variable_with_changing_enum(self): + def test_encoding_enum__error_multiple_variable_with_changing_enum(self): + """ + Given 2 variables, if they share the same enum type, + the 2 enum definition should be identical. 
+ """ with create_tmp_file() as tmp_file: cloud_type_dict = {"clear": 0, "cloudy": 1, "missing": 255} with nc4.Dataset(tmp_file, mode="w") as nc: @@ -1768,19 +1774,23 @@ def test_encoding_enum__multiple_variable_with_changing_enum(self): "time", fill_value=255, ) - with open_dataset(tmp_file, decode_enum=True) as ds: - cloud_enum_dict = {e.name: e.value for e in ds.cloud.attrs["enum"]} - cloud_enum_dict.update({"neblig": 2}) - ds.cloud.attrs["enum"] = Enum("cloud_type", cloud_enum_dict) - with create_tmp_file() as tmp_file2: - with pytest.raises( - ValueError, - match=( - "Cannot save variable .*" - " because an enum `cloud_type` already exists in the Dataset .*" - ), - ): - ds.to_netcdf(tmp_file2) + with open_dataset( + tmp_file, decode_enum=True + ) as ds, create_tmp_file() as tmp_file2: + modified_enum = ds.cloud.encoding["dtype"].metadata["enum_dict"] + modified_enum.update({"neblig": 2}) + ds.cloud.encoding["dtype"] = np.dtype( + "u1", + metadata={"enum_dict": modified_enum, "enum_name": "cloud_type"}, + ) + with pytest.raises( + ValueError, + match=( + "Cannot save variable .*" + " because an enum `cloud_type` already exists in the Dataset .*" + ), + ): + ds.to_netcdf(tmp_file2) @requires_netCDF4 From b114ccc0e1c20a58c731abe267676a0ab20735d9 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Wed, 10 Jan 2024 16:17:04 +0100 Subject: [PATCH 30/37] MAINT: Remove CF flag_* encoding --- xarray/backends/api.py | 4 ---- xarray/backends/netCDF4_.py | 2 -- xarray/backends/store.py | 2 -- xarray/coding/variables.py | 26 -------------------------- xarray/conventions.py | 12 ------------ xarray/tests/test_backends.py | 10 +++------- xarray/tests/test_coding.py | 18 ------------------ 7 files changed, 3 insertions(+), 71 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 9cee5529b0c..415a60a5e0d 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -408,7 +408,6 @@ def open_dataset( chunked_array_type: str | None = None, 
from_array_kwargs: dict[str, Any] | None = None, backend_kwargs: dict[str, Any] | None = None, - decode_enum: bool | None = None, **kwargs, ) -> Dataset: """Open and decode a dataset from a file or file-like object. @@ -514,8 +513,6 @@ def open_dataset( backend_kwargs: dict Additional keyword arguments passed on to the engine open function, equivalent to `**kwargs`. - decode_enum: bool, optional - If True, decode CF flag_values and flag_meanings into a pyton Enum. **kwargs: dict Additional keyword arguments passed on to the engine open function. For example: @@ -570,7 +567,6 @@ def open_dataset( concat_characters=concat_characters, use_cftime=use_cftime, decode_coords=decode_coords, - decode_enum=decode_enum, ) overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 4cf90e4c9b8..12f8e469c30 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -636,7 +636,6 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti persist=False, lock=None, autoclose=False, - decode_enum: bool | None = None, ) -> Dataset: filename_or_obj = _normalize_path(filename_or_obj) store = NetCDF4DataStore.open( @@ -662,7 +661,6 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti drop_variables=drop_variables, use_cftime=use_cftime, decode_timedelta=decode_timedelta, - decode_enum=decode_enum, ) return ds diff --git a/xarray/backends/store.py b/xarray/backends/store.py index 7af08aaade7..a507ee37470 100644 --- a/xarray/backends/store.py +++ b/xarray/backends/store.py @@ -37,7 +37,6 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti drop_variables: str | Iterable[str] | None = None, use_cftime=None, decode_timedelta=None, - decode_enum: bool | None = None, ) -> Dataset: assert isinstance(filename_or_obj, AbstractDataStore) @@ -54,7 +53,6 @@ def open_dataset( # type: ignore[override] # allow LSP 
violation, not supporti drop_variables=drop_variables, use_cftime=use_cftime, decode_timedelta=decode_timedelta, - decode_enum=decode_enum, ) ds = Dataset(vars, attrs=attrs) diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 1d0830dfd4c..fb4da537f42 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -3,7 +3,6 @@ import warnings from collections.abc import Hashable, MutableMapping -from enum import Enum from functools import partial from typing import TYPE_CHECKING, Any, Callable, Union @@ -575,28 +574,3 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable: return variable else: return variable - - -class EnumCoder(VariableCoder): - """Decode CF flag_* to python Enum""" - - def encode(self, variable: Variable, name: T_Name = None) -> Variable: - raise NotImplementedError - - def decode(self, variable: Variable, name: T_Name = None) -> Variable: - """From CF flag_* to python Enum""" - dims, data, attrs, encoding = unpack_for_decoding(variable) - if ( - attrs.get("enum") - and attrs.get("flag_meanings") - and attrs.get("flag_values") - ): - flag_meanings = attrs.pop("flag_meanings") - flag_meanings = flag_meanings.split(" ") - flag_values = attrs.pop("flag_values") - flag_values = [int(v) for v in flag_values.split(", ")] - enum_name = attrs.pop("enum") - enum_dict = {k: v for k, v in zip(flag_meanings, flag_values)} - attrs["enum"] = Enum(enum_name, enum_dict) - return Variable(dims, data, attrs, encoding, fastpath=True) - return variable diff --git a/xarray/conventions.py b/xarray/conventions.py index c007c29cae6..6f9d96c8663 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -208,7 +208,6 @@ def decode_cf_variable( stack_char_dim: bool = True, use_cftime: bool | None = None, decode_timedelta: bool | None = None, - decode_enum: bool | None = None, ) -> Variable: """ Decodes a variable which may hold CF encoded information. 
@@ -249,8 +248,6 @@ def decode_cf_variable( represented using ``np.datetime64[ns]`` objects. If False, always decode times to ``np.datetime64[ns]`` objects; if this is not possible raise an error. - decode_enum: bool, optional - Turn the CF flag_values and flag_meanings into a python Enum in `attrs['enum']`. Returns ------- @@ -294,9 +291,6 @@ def decode_cf_variable( var = variables.BooleanCoder().decode(var) - if decode_enum: - var = variables.EnumCoder().decode(var) - dimensions, data, attributes, encoding = variables.unpack_for_decoding(var) encoding.setdefault("dtype", original_dtype) @@ -395,7 +389,6 @@ def decode_cf_variables( drop_variables: T_DropVariables = None, use_cftime: bool | None = None, decode_timedelta: bool | None = None, - decode_enum: bool | None = None, ) -> tuple[T_Variables, T_Attrs, set[Hashable]]: """ Decode several CF encoded variables. @@ -448,7 +441,6 @@ def stackable(dim: Hashable) -> bool: stack_char_dim=stack_char_dim, use_cftime=use_cftime, decode_timedelta=decode_timedelta, - decode_enum=decode_enum, ) except Exception as e: raise type(e)(f"Failed to decode variable {k!r}: {e}") from e @@ -513,7 +505,6 @@ def decode_cf( drop_variables: T_DropVariables = None, use_cftime: bool | None = None, decode_timedelta: bool | None = None, - decode_enum: bool = True, ) -> Dataset: """Decode the given Dataset or Datastore according to CF conventions into a new Dataset. @@ -592,7 +583,6 @@ def decode_cf( drop_variables=drop_variables, use_cftime=use_cftime, decode_timedelta=decode_timedelta, - decode_enum=decode_enum, ) ds = Dataset(vars, attrs=attrs) ds = ds.set_coords(coord_names.union(extra_coords).intersection(vars)) @@ -608,7 +598,6 @@ def cf_decoder( concat_characters: bool = True, mask_and_scale: bool = True, decode_times: bool = True, - decode_enum: bool = True, ) -> tuple[T_Variables, T_Attrs]: """ Decode a set of CF encoded variables and attributes. 
@@ -645,7 +634,6 @@ def cf_decoder( concat_characters, mask_and_scale, decode_times, - decode_enum=decode_enum, ) return variables, attributes diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 22ba82a5cda..0677d8b60bb 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1718,7 +1718,7 @@ def test_encoding_enum__no_fill_value(self): fill_value=None, ) v[:] = 1 - with open_dataset(tmp_file, decode_enum=True) as ds: + with open_dataset(tmp_file) as ds: assert ( ds.clouds.encoding["dtype"].metadata["enum_dict"] == cloud_type_dict ) @@ -1745,9 +1745,7 @@ def test_encoding_enum__multiple_variable_with_enum(self): "time", fill_value=255, ) - with open_dataset( - tmp_file, decode_enum=True - ) as ds, create_tmp_file() as tmp_file2: + with open_dataset(tmp_file) as ds, create_tmp_file() as tmp_file2: # nothing to assert ; just make sure round trip ca be done ds.to_netcdf(tmp_file2) @@ -1774,9 +1772,7 @@ def test_encoding_enum__error_multiple_variable_with_changing_enum(self): "time", fill_value=255, ) - with open_dataset( - tmp_file, decode_enum=True - ) as ds, create_tmp_file() as tmp_file2: + with open_dataset(tmp_file) as ds, create_tmp_file() as tmp_file2: modified_enum = ds.cloud.encoding["dtype"].metadata["enum_dict"] modified_enum.update({"neblig": 2}) ds.cloud.encoding["dtype"] = np.dtype( diff --git a/xarray/tests/test_coding.py b/xarray/tests/test_coding.py index 13344630056..f7579c4b488 100644 --- a/xarray/tests/test_coding.py +++ b/xarray/tests/test_coding.py @@ -1,7 +1,6 @@ from __future__ import annotations from contextlib import suppress -from enum import EnumMeta import numpy as np import pandas as pd @@ -147,20 +146,3 @@ def test_decode_signed_from_unsigned(bits) -> None: decoded = coder.decode(encoded) assert decoded.dtype == signed_dtype assert decoded.values == original_values - - -def test_decode_enum() -> None: - encoded = xr.Variable( - ("x",), - [42], - attrs={ - "flag_values": "0, 1", - 
"flag_meanings": "flag galf", - "enum": "a_flag_name", - }, - ) - coder = variables.EnumCoder() - decoded = coder.decode(encoded) - assert isinstance(decoded.attrs["enum"], EnumMeta) - assert decoded.attrs["enum"].flag.value == 0 - assert decoded.attrs["enum"].galf.value == 1 From 6376a136713a87d1dcee8298b84ab0c11b8e51c2 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Wed, 10 Jan 2024 22:07:17 +0100 Subject: [PATCH 31/37] Add assertion after roundtrip in enum tests --- doc/whats-new.rst | 6 +++--- xarray/backends/api.py | 3 +-- xarray/backends/netCDF4_.py | 11 +++++------ xarray/core/dataarray.py | 6 ++---- xarray/tests/test_backends.py | 13 +++++++------ 5 files changed, 18 insertions(+), 21 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ce2925009d9..db7556aec15 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -221,9 +221,9 @@ New Features - Use `opt_einsum `_ for :py:func:`xarray.dot` by default if installed. By `Deepak Cherian `_. (:issue:`7764`, :pull:`8373`). -- Open netCDF4 enums and turn them into CF flag_meanings/flag_values. - This also gives a special meaning to the 'enum' attribute in DataArrays, when it is set, this tells the netCDF4 backend - to turn flag_meanings and flag_values into a netCDF4 Enum named using ``attrs["enum"]`` content. +- Decode/Encode netCDF4 enums and store the enum definition in dataarrays' dtype metadata. + If multiple variables share the same enum in netCDF4, each dataarray will have its own + enum definition in their respective dtype metadata. By `Abel Aoun _`(:issue:`8144`, :pull:`8147`) - Add ``DataArray.dt.total_seconds()`` method to match the Pandas API. (:pull:`8435`). By `Ben Mares `_. 
diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 415a60a5e0d..1d538bf94ed 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -2,7 +2,6 @@ import os from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence -from enum import EnumMeta from functools import partial from io import BytesIO from numbers import Number @@ -173,7 +172,7 @@ def _validate_attrs(dataset, invalid_netcdf=False): `invalid_netcdf=True`. """ - valid_types = (str, Number, np.ndarray, np.number, list, tuple, EnumMeta) + valid_types = (str, Number, np.ndarray, np.number, list, tuple) if invalid_netcdf: valid_types += (np.bool_,) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 12f8e469c30..78ec795dbb0 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -258,7 +258,6 @@ def _extract_nc4_variable_encoding( "_FillValue", "dtype", "compression", - "enum", "significant_digits", "quantize_mode", "blosc_shuffle", @@ -427,7 +426,7 @@ def open_store_variable(self, name: str, var): encoding["dtype"] = np.dtype( data.dtype, metadata={ - "enum_dict": var.datatype.enum_dict, + "enum": var.datatype.enum_dict, "enum_name": var.datatype.name, }, ) @@ -506,7 +505,7 @@ def prepare_variable( if ( variable.dtype.metadata and variable.dtype.metadata.get("enum_name") - and variable.dtype.metadata.get("enum_dict") + and variable.dtype.metadata.get("enum") ): datatype = self._build_and_get_enum(name, variable.dtype) else: @@ -545,7 +544,7 @@ def prepare_variable( def _build_and_get_enum(self, var_name: str, dtype: np.dtype) -> object: """Add or get the netCDF4 Enum based on the dtype in encoding.""" - enum_dict = dtype.metadata["enum_dict"] + enum_dict = dtype.metadata["enum"] enum_name = dtype.metadata["enum_name"] if enum_name not in self.ds.enumtypes: return self.ds.createEnumType( @@ -560,8 +559,8 @@ def _build_and_get_enum(self, var_name: str, dtype: np.dtype) -> object: f" `{enum_name}` already exists in the 
Dataset but have" " a different definition. To fix this error, make sure" " each variable have a uniquely named enum in their" - " `encoding['dtype']` or, if they should share same enum type," - " make sure the enums are identical." + " `encoding['dtype'].metadata` or, if they should share same" + " the same enum type, make sure the enums are identical." ) raise ValueError(error_msg) return datatype diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 98c989a55cd..2c2ec0c2fc1 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -4069,10 +4069,8 @@ def to_netcdf( name is the same as a coordinate name, then it is given the name ``"__xarray_dataarray_variable__"``. - [netCDF4 backend only] When the CF flag_values/flag_meanings attributes are - set in for this DataArray, you can choose to replace these attributes by - a netcdf4 EnumType by updating the encoding dictionary with a key value pair - like: `da.attrs["enum"] = "enum_name"`. + [netCDF4 backend only] netCDF4 enums are decoded into the + dataarray dtype metadata. 
See Also -------- diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 0677d8b60bb..3751e3131c1 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1719,12 +1719,12 @@ def test_encoding_enum__no_fill_value(self): ) v[:] = 1 with open_dataset(tmp_file) as ds: - assert ( - ds.clouds.encoding["dtype"].metadata["enum_dict"] == cloud_type_dict - ) + assert ds.clouds.encoding["dtype"].metadata["enum"] == cloud_type_dict assert ds.clouds.encoding["dtype"].metadata["enum_name"] == "cloud_type" with create_tmp_file() as tmp_file2: ds.to_netcdf(tmp_file2) + with nc4.Dataset(tmp_file2, "r") as nc: + assert nc.enumtypes["cloud_type"] == cloud_type_dict @requires_netCDF4 def test_encoding_enum__multiple_variable_with_enum(self): @@ -1746,8 +1746,9 @@ def test_encoding_enum__multiple_variable_with_enum(self): fill_value=255, ) with open_dataset(tmp_file) as ds, create_tmp_file() as tmp_file2: - # nothing to assert ; just make sure round trip ca be done ds.to_netcdf(tmp_file2) + with nc4.Dataset(tmp_file2, "r") as nc: + assert nc.enumtypes["cloud_type"] == cloud_type_dict @requires_netCDF4 def test_encoding_enum__error_multiple_variable_with_changing_enum(self): @@ -1773,11 +1774,11 @@ def test_encoding_enum__error_multiple_variable_with_changing_enum(self): fill_value=255, ) with open_dataset(tmp_file) as ds, create_tmp_file() as tmp_file2: - modified_enum = ds.cloud.encoding["dtype"].metadata["enum_dict"] + modified_enum = ds.cloud.encoding["dtype"].metadata["enum"] modified_enum.update({"neblig": 2}) ds.cloud.encoding["dtype"] = np.dtype( "u1", - metadata={"enum_dict": modified_enum, "enum_name": "cloud_type"}, + metadata={"enum": modified_enum, "enum_name": "cloud_type"}, ) with pytest.raises( ValueError, From 89a8751997b388237a7a15914ce245f081b11bcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kai=20M=C3=BChlbauer?= Date: Thu, 11 Jan 2024 09:44:08 +0100 Subject: [PATCH 32/37] add NativeEnumCoder, adapt tests --- 
xarray/coding/variables.py | 19 +++++++++ xarray/conventions.py | 1 + xarray/tests/test_backends.py | 77 ++++++++++++++++++++++++----------- 3 files changed, 73 insertions(+), 24 deletions(-) diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index fb4da537f42..c3d57ad1903 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -574,3 +574,22 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable: return variable else: return variable + + +class NativeEnumCoder(VariableCoder): + """Encode Enum into variable dtype metadata.""" + + def encode(self, variable: Variable, name: T_Name = None) -> Variable: + if ( + "dtype" in variable.encoding + and np.dtype(variable.encoding["dtype"]).metadata + and "enum" in variable.encoding["dtype"].metadata + ): + dims, data, attrs, encoding = unpack_for_encoding(variable) + data = data.astype(dtype=variable.encoding.pop("dtype")) + return Variable(dims, data, attrs, encoding, fastpath=True) + else: + return variable + + def decode(self, variable: Variable, name: T_Name = None) -> Variable: + raise NotImplementedError() diff --git a/xarray/conventions.py b/xarray/conventions.py index 6f9d96c8663..1d8e81e1bf2 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -184,6 +184,7 @@ def encode_cf_variable( variables.CFScaleOffsetCoder(), variables.CFMaskCoder(), variables.UnsignedIntegerCoder(), + variables.NativeEnumCoder(), variables.NonStringCoder(), variables.DefaultFillvalueCoder(), variables.BooleanCoder(), diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 3751e3131c1..f279e9ea7d4 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1718,13 +1718,19 @@ def test_encoding_enum__no_fill_value(self): fill_value=None, ) v[:] = 1 - with open_dataset(tmp_file) as ds: - assert ds.clouds.encoding["dtype"].metadata["enum"] == cloud_type_dict - assert ds.clouds.encoding["dtype"].metadata["enum_name"] == "cloud_type" 
- with create_tmp_file() as tmp_file2: - ds.to_netcdf(tmp_file2) - with nc4.Dataset(tmp_file2, "r") as nc: - assert nc.enumtypes["cloud_type"] == cloud_type_dict + with open_dataset(tmp_file) as original: + with self.roundtrip(original) as actual: + assert_equal(original, actual) + assert ( + actual.clouds.encoding["dtype"].metadata["enum"] + == cloud_type_dict + ) + if self.engine != "h5netcdf": + # not implemented in h5netcdf yet + assert ( + actual.clouds.encoding["dtype"].metadata["enum_name"] + == "cloud_type" + ) @requires_netCDF4 def test_encoding_enum__multiple_variable_with_enum(self): @@ -1734,7 +1740,7 @@ def test_encoding_enum__multiple_variable_with_enum(self): nc.createDimension("time", size=2) cloud_type = nc.createEnumType("u1", "cloud_type", cloud_type_dict) nc.createVariable( - "cloud", + "clouds", cloud_type, "time", fill_value=255, @@ -1745,10 +1751,26 @@ def test_encoding_enum__multiple_variable_with_enum(self): "time", fill_value=255, ) - with open_dataset(tmp_file) as ds, create_tmp_file() as tmp_file2: - ds.to_netcdf(tmp_file2) - with nc4.Dataset(tmp_file2, "r") as nc: - assert nc.enumtypes["cloud_type"] == cloud_type_dict + with open_dataset(tmp_file) as original: + with self.roundtrip(original) as actual: + assert_equal(original, actual) + assert ( + actual.clouds.encoding["dtype"] == actual.tifa.encoding["dtype"] + ) + assert ( + actual.clouds.encoding["dtype"].metadata + == actual.tifa.encoding["dtype"].metadata + ) + assert ( + actual.clouds.encoding["dtype"].metadata["enum"] + == cloud_type_dict + ) + if self.engine != "h5netcdf": + # not implemented in h5netcdf yet + assert ( + actual.clouds.encoding["dtype"].metadata["enum_name"] + == "cloud_type" + ) @requires_netCDF4 def test_encoding_enum__error_multiple_variable_with_changing_enum(self): @@ -1762,7 +1784,7 @@ def test_encoding_enum__error_multiple_variable_with_changing_enum(self): nc.createDimension("time", size=2) cloud_type = nc.createEnumType("u1", "cloud_type", 
cloud_type_dict) nc.createVariable( - "cloud", + "clouds", cloud_type, "time", fill_value=255, @@ -1773,21 +1795,28 @@ def test_encoding_enum__error_multiple_variable_with_changing_enum(self): "time", fill_value=255, ) - with open_dataset(tmp_file) as ds, create_tmp_file() as tmp_file2: - modified_enum = ds.cloud.encoding["dtype"].metadata["enum"] + with open_dataset(tmp_file) as original: + assert ( + original.clouds.encoding["dtype"].metadata + == original.tifa.encoding["dtype"].metadata + ) + modified_enum = original.clouds.encoding["dtype"].metadata["enum"] modified_enum.update({"neblig": 2}) - ds.cloud.encoding["dtype"] = np.dtype( + original.clouds.encoding["dtype"] = np.dtype( "u1", metadata={"enum": modified_enum, "enum_name": "cloud_type"}, ) - with pytest.raises( - ValueError, - match=( - "Cannot save variable .*" - " because an enum `cloud_type` already exists in the Dataset .*" - ), - ): - ds.to_netcdf(tmp_file2) + if self.engine != "h5netcdf": + # not implemented yet in h5netcdf + with pytest.raises( + ValueError, + match=( + "Cannot save variable .*" + " because an enum `cloud_type` already exists in the Dataset .*" + ), + ): + with self.roundtrip(original): + pass @requires_netCDF4 From ac20a40f07b4cd6d90793dd2514734b268a5b098 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kai=20M=C3=BChlbauer?= Date: Thu, 11 Jan 2024 09:46:05 +0100 Subject: [PATCH 33/37] remove test-file --- toto.nc | Bin 393 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 toto.nc diff --git a/toto.nc b/toto.nc deleted file mode 100644 index cb541d389f12dbb119d78a0fc4558a1564557b59..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 393 zcmeD5aB<`1lHy|G;9!7(|4`7$2oW)WN*M8(Rr!0k1Tpb!VNwE%F+)`_z_g&#TucmL zC2Ue4^^7b~lNi`@D^v4IbB*3Z1w>^SKypCy&`p5x1q47!&7c+rL-jGJ!5D<;bAR@} z06Smdw))&uWURxi!N6+4#J~Uw0|6ib2@3~?M1~}|xiGK3>P_G7@9g2t;|h{uX5a&= z2D&}o&pDtdzaX`!Br~;`K^Q0n3L>DH3}vZB#hLkewnjR6sU^uNX|^VMmU>1y87XO| PwuX9!CVEEAem#N!>U?GE 
From d515e0d017c4ca2ce91898da41d4458b85684d30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kai=20M=C3=BChlbauer?= Date: Thu, 11 Jan 2024 09:55:57 +0100 Subject: [PATCH 34/37] restructure datatype extraction --- xarray/backends/netCDF4_.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 78ec795dbb0..90402d732fc 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -502,16 +502,16 @@ def prepare_variable( _ensure_no_forward_slash_in_name(name) attrs = variable.attrs.copy() fill_value = attrs.pop("_FillValue", None) + datatype = _get_datatype( + variable, self.format, raise_on_invalid_encoding=check_encoding + ) + # check enum metadata and use netCDF4.EnumType if ( - variable.dtype.metadata - and variable.dtype.metadata.get("enum_name") - and variable.dtype.metadata.get("enum") + np.dtype(datatype).metadata + and datatype.metadata.get("enum_name") + and datatype.metadata.get("enum") ): - datatype = self._build_and_get_enum(name, variable.dtype) - else: - datatype = _get_datatype( - variable, self.format, raise_on_invalid_encoding=check_encoding - ) + datatype = self._build_and_get_enum(name, datatype) encoding = _extract_nc4_variable_encoding( variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims ) From 5c665639adf30916b157dbd0c239e380abadd39c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kai=20M=C3=BChlbauer?= Date: Thu, 11 Jan 2024 10:05:55 +0100 Subject: [PATCH 35/37] use invalid_netcdf for h5netcdf tests --- xarray/tests/test_backends.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index f279e9ea7d4..d01cfd7ff55 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1719,7 +1719,10 @@ def test_encoding_enum__no_fill_value(self): ) v[:] = 1 with open_dataset(tmp_file) as original: - with self.roundtrip(original) as 
actual: + save_kwargs = {} + if self.engine == "h5netcdf": + save_kwargs["invalid_netcdf"] = True + with self.roundtrip(original, save_kwargs=save_kwargs) as actual: assert_equal(original, actual) assert ( actual.clouds.encoding["dtype"].metadata["enum"] @@ -1752,7 +1755,10 @@ def test_encoding_enum__multiple_variable_with_enum(self): fill_value=255, ) with open_dataset(tmp_file) as original: - with self.roundtrip(original) as actual: + save_kwargs = {} + if self.engine == "h5netcdf": + save_kwargs["invalid_netcdf"] = True + with self.roundtrip(original, save_kwargs=save_kwargs) as actual: assert_equal(original, actual) assert ( actual.clouds.encoding["dtype"] == actual.tifa.encoding["dtype"] From d62ac2908fb7b251d92e38bec09f240e7fb10f62 Mon Sep 17 00:00:00 2001 From: Abel Aoun Date: Thu, 11 Jan 2024 11:55:31 +0100 Subject: [PATCH 36/37] FIX: encoding typing --- xarray/backends/netCDF4_.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 90402d732fc..45f7d38651b 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -142,7 +142,7 @@ def _check_encoding_dtype_is_vlen_string(dtype): def _get_datatype( var, nc_format="NETCDF4", raise_on_invalid_encoding=False -) -> np.dtype | None: +) -> np.dtype: if nc_format == "NETCDF4": return _nc4_dtype(var) if "dtype" in var.encoding: @@ -421,7 +421,7 @@ def open_store_variable(self, name: str, var): dimensions = var.dimensions attributes = {k: var.getncattr(k) for k in var.ncattrs()} data = indexing.LazilyIndexedArray(NetCDF4ArrayWrapper(name, self)) - encoding = {} + encoding: dict[str, Any] = {} if isinstance(var.datatype, netCDF4.EnumType): encoding["dtype"] = np.dtype( data.dtype, @@ -507,11 +507,11 @@ def prepare_variable( ) # check enum metadata and use netCDF4.EnumType if ( - np.dtype(datatype).metadata - and datatype.metadata.get("enum_name") - and datatype.metadata.get("enum") + (meta := 
np.dtype(datatype).metadata) + and (e_name := meta.get("enum_name")) + and (e_dict := meta.get("enum")) ): - datatype = self._build_and_get_enum(name, datatype) + datatype = self._build_and_get_enum(name, datatype, e_name, e_dict) encoding = _extract_nc4_variable_encoding( variable, raise_on_invalid=check_encoding, unlimited_dims=unlimited_dims ) @@ -542,10 +542,14 @@ def prepare_variable( return target, variable.data - def _build_and_get_enum(self, var_name: str, dtype: np.dtype) -> object: - """Add or get the netCDF4 Enum based on the dtype in encoding.""" - enum_dict = dtype.metadata["enum"] - enum_name = dtype.metadata["enum_name"] + def _build_and_get_enum( + self, var_name: str, dtype: np.dtype, enum_name: str, enum_dict: dict[str, int] + ) -> Any: + """ + Add or get the netCDF4 Enum based on the dtype in encoding. + The return type should be ``netCDF4.EnumType``, + but we avoid importing netCDF4 globally for performances. + """ if enum_name not in self.ds.enumtypes: return self.ds.createEnumType( dtype, From f834edef001cb493b31084cc56960e13a08ecffa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kai=20M=C3=BChlbauer?= Date: Sun, 14 Jan 2024 17:17:17 +0100 Subject: [PATCH 37/37] Update xarray/backends/netCDF4_.py --- xarray/backends/netCDF4_.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 45f7d38651b..d3845568709 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -563,7 +563,7 @@ def _build_and_get_enum( f" `{enum_name}` already exists in the Dataset but have" " a different definition. To fix this error, make sure" " each variable have a uniquely named enum in their" - " `encoding['dtype'].metadata` or, if they should share same" + " `encoding['dtype'].metadata` or, if they should share" " the same enum type, make sure the enums are identical." ) raise ValueError(error_msg)