From 5e73334f8560d852d77a05d144b4f737f8e54e29 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Mon, 2 Oct 2023 18:13:48 -0400 Subject: [PATCH 1/5] Add bitrounding --- tests/test_io.py | 35 +++++++++++++++++++++++++++++ xscen/io.py | 58 ++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 91 insertions(+), 2 deletions(-) diff --git a/tests/test_io.py b/tests/test_io.py index c8083246..4d9206fb 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -116,3 +116,38 @@ def test_normal(self): np.testing.assert_array_equal( tab.loc[("1993", "pr"), ("JFM",)], self.ds.pr.sel(time="1993", season="JFM") ) + + +def test_round_bits(datablock_3d): + da = datablock_3d( + np.random.random((30, 30, 50)), + variable="tas", + x="lon", + x_start=-70, + y="lat", + y_start=45, + ) + dar = xs.io.round_bits(da, 12) + # Close but NOT equal, meaning something happened + np.testing.assert_allclose(da, dar, rtol=0.013) + assert not (da == dar).any() + + +class TestSaveToZarr: + @pytest.mark.parametrize( + "vname,vtype,bitr,exp", + [ + ("tas", np.float32, 12, 12), + ("tas", np.float32, False, None), + ("tas", np.int32, 12, None), + ("tas", np.int32, {"tas": 2}, "error"), + ("tas", object, {"pr": 2}, None), + ("tas", np.float64, True, 12), + ], + ) + def test_guess_bitround(self, vname, vtype, bitr, exp): + if exp == "error": + with pytest.raises(ValueError): + xs.io._guess_keepbits(bitr, vname, vtype) + else: + assert xs.io._guess_keepbits(bitr, vname, vtype) == exp diff --git a/xscen/io.py b/xscen/io.py index 87235767..e2d5958f 100644 --- a/xscen/io.py +++ b/xscen/io.py @@ -1,7 +1,9 @@ # noqa: D100 +import datetime import logging import os import shutil as sh +from collections import defaultdict from collections.abc import Sequence from inspect import signature from pathlib import Path @@ -13,6 +15,7 @@ import pandas as pd import xarray as xr import zarr +from numcodecs.bitround import BitRound from rechunker import rechunk as _rechunk from xclim.core.calendar import 
get_calendar from xclim.core.options import METADATA_LOCALES @@ -23,6 +26,7 @@ from .utils import TRANSLATOR, season_sort_key, translate_time_chunk logger = logging.getLogger(__name__) +KEEPBITS = defaultdict(lambda: 12) __all__ = [ @@ -32,6 +36,7 @@ "make_toc", "rechunk", "rechunk_for_saving", + "round_bits", "save_to_table", "save_to_netcdf", "save_to_zarr", @@ -345,6 +350,47 @@ def save_to_netcdf( return ds.to_netcdf(filename, compute=compute, **netcdf_kwargs) +def _np_bitround(array, keepbits): + """Bitround for Arrays.""" + codec = BitRound(keepbits=keepbits) + data = array.copy() # otherwise overwrites the input + encoded = codec.encode(data) + return codec.decode(encoded) + + +def round_bits(da: xr.DataArray, keepbits: int): + """Round floating point variable by keeping a given number of bits in the mantissa, dropping the rest.""" + da = xr.apply_ufunc( + _np_bitround, da, keepbits, dask="parallelized", keep_attrs=True + ) + da.attrs["_QuantizeBitRoundNumberOfSignificantDigits"] = keepbits + new_history = f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Data compressed with BitRound by keeping {keepbits} bits." + history = ( + new_history + " \n " + da.attrs["history"] + if "history" in da.attrs + else new_history + ) + da.attrs["history"] = history + return da + + +def _guess_keepbits(bitround, varname, vartype): + # Guess the number of bits to keep depending on how bitround was passed, the var dtype and the var name. + if not np.issubdtype(vartype, np.floating) or bitround is False: + if isinstance(bitround, dict) and varname in bitround: + raise ValueError( + f"A keepbits value was given for variable {varname} even though it is not of a floatig dtype." 
+ ) + return None + if bitround is True: + return KEEPBITS[varname] + if isinstance(bitround, int): + return bitround + if isinstance(bitround, dict): + return bitround.get(varname, KEEPBITS[varname]) + return None + + @parse_config def save_to_zarr( ds: xr.Dataset, @@ -354,6 +400,7 @@ def save_to_zarr( zarr_kwargs: Optional[dict] = None, compute: bool = True, encoding: dict = None, + bitround: Union[bool, int, dict] = False, mode: str = "f", itervar: bool = False, timeout_cleanup: bool = True, @@ -383,6 +430,11 @@ def save_to_zarr( if 'a', skip existing variables, writes the others. encoding : dict, optional If given, skipped variables are popped in place. + bitround : bool or int or dict + If not False, float variables are bit-rounded by dropping a certain number of bits from their mantissa, allowing for a much better compression. + If an int, this is the number of bits to keep for all float variables. + If a dict, a mapping from variable name to the number of bits to keep. + If True, the number of bits to keep is guessed based on the variable's name, defaulting to 12, which yields a relative error below 0.013%. itervar : bool If True, (data) variables are written one at a time, appending to the zarr. If False, this function computes, no matter what was passed to kwargs.
@@ -430,7 +482,7 @@ def _skip(var): if mode == "o": if exists: var_path = path / var - print(f"Removing {var_path} to overwrite.") + logger.warning(f"Removing {var_path} to overwrite.") sh.rmtree(var_path) return False @@ -439,12 +491,14 @@ def _skip(var): return exists return False - for var in ds.data_vars.keys(): + for var in list(ds.data_vars.keys()): if _skip(var): logger.info(f"Skipping {var} in {path}.") ds = ds.drop_vars(var) if encoding: encoding.pop(var) + if keepbits := _guess_keepbits(bitround, var, ds[var].dtype): + ds = ds.assign({var: round_bits(ds[var], keepbits)}) if len(ds.data_vars) == 0: return None From 3987f47f7f2e40f5dcaa89d01c4c7bc7261333f5 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Mon, 2 Oct 2023 18:20:38 -0400 Subject: [PATCH 2/5] add to netcdf - upd hist --- HISTORY.rst | 1 + xscen/io.py | 92 +++++++++++++++++++++++++++++------------------------ 2 files changed, 52 insertions(+), 41 deletions(-) diff --git a/HISTORY.rst b/HISTORY.rst index 67515d77..618dd3a5 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -17,6 +17,7 @@ New features and enhancements * ``xs.spatial_mean`` with ``method='xESMF'`` will also automatically segmentize polygons (down to a 1° resolution) to ensure a correct average (:pull:`260`). * Added documentation for `require_all_on` in `search_data_catalogs`. (:pull:`263`). * ``xs.save_to_table`` and ``xs.io.to_table`` to transform datasets and arrays to DataFrames, but with support for multi-columns, multi-sheets and localized table of content generation. +* ``xs.io.round_bits`` to round floating point variable up to a number of bits, allowing for a better compression. This can be combined with the saving step through argument "bitround" of ``save_to_netcdf`` and ``save_to_zarr``. (:pull:`266`). 
Breaking changes ^^^^^^^^^^^^^^^^ diff --git a/xscen/io.py b/xscen/io.py index e2d5958f..12b7c0ee 100644 --- a/xscen/io.py +++ b/xscen/io.py @@ -297,12 +297,54 @@ def _coerce_attrs(attrs): attrs[k] = str(attrs[k]) +def _np_bitround(array, keepbits): + """Bitround for Arrays.""" + codec = BitRound(keepbits=keepbits) + data = array.copy() # otherwise overwrites the input + encoded = codec.encode(data) + return codec.decode(encoded) + + +def round_bits(da: xr.DataArray, keepbits: int): + """Round floating point variable by keeping a given number of bits in the mantissa, dropping the rest.""" + da = xr.apply_ufunc( + _np_bitround, da, keepbits, dask="parallelized", keep_attrs=True + ) + da.attrs["_QuantizeBitRoundNumberOfSignificantDigits"] = keepbits + new_history = f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Data compressed with BitRound by keeping {keepbits} bits." + history = ( + new_history + " \n " + da.attrs["history"] + if "history" in da.attrs + else new_history + ) + da.attrs["history"] = history + return da + + +def _guess_keepbits(bitround, varname, vartype): + # Guess the number of bits to keep depending on how bitround was passed, the var dtype and the var name. + if not np.issubdtype(vartype, np.floating) or bitround is False: + if isinstance(bitround, dict) and varname in bitround: + raise ValueError( + f"A keepbits value was given for variable {varname} even though it is not of a floatig dtype." 
+ ) + return None + if bitround is True: + return KEEPBITS[varname] + if isinstance(bitround, int): + return bitround + if isinstance(bitround, dict): + return bitround.get(varname, KEEPBITS[varname]) + return None + + @parse_config def save_to_netcdf( ds: xr.Dataset, filename: str, *, rechunk: Optional[dict] = None, + bitround: Union[bool, int, dict] = False, compute: bool = True, netcdf_kwargs: Optional[dict] = None, ) -> None: @@ -319,6 +361,11 @@ def save_to_netcdf( Spatial dimensions can be generalized as 'X' and 'Y', which will be mapped to the actual grid type's dimension names. Rechunking is only done on *data* variables sharing dimensions with this argument. + bitround : bool or int or dict + If not False, float variables are bit-rounded by dropping a certain number of bits from their mantissa, allowing for a much better compression. + If an int, this is the number of bits to keep for all float variables. + If a dict, a mapping from variable name to the number of bits to keep. + If True, the number of bits to keep is guessed based on the variable's name, defaulting to 12, which yields a relative error of 0.012%. compute : bool Whether to start the computation or return a delayed object. 
netcdf_kwargs : dict, optional @@ -343,6 +390,10 @@ def save_to_netcdf( netcdf_kwargs.setdefault("engine", "h5netcdf") netcdf_kwargs.setdefault("format", "NETCDF4") + for var in list(ds.data_vars.keys()): + if keepbits := _guess_keepbits(bitround, var, ds[var].dtype): + ds = ds.assign({var: round_bits(ds[var], keepbits)}) + _coerce_attrs(ds.attrs) for var in ds.variables.values(): _coerce_attrs(var.attrs) @@ -350,47 +401,6 @@ def save_to_netcdf( return ds.to_netcdf(filename, compute=compute, **netcdf_kwargs) -def _np_bitround(array, keepbits): - """Bitround for Arrays.""" - codec = BitRound(keepbits=keepbits) - data = array.copy() # otherwise overwrites the input - encoded = codec.encode(data) - return codec.decode(encoded) - - -def round_bits(da: xr.DataArray, keepbits: int): - """Round floating point variable by keeping a given number of bits in the mantissa, dropping the rest.""" - da = xr.apply_ufunc( - _np_bitround, da, keepbits, dask="parallelized", keep_attrs=True - ) - da.attrs["_QuantizeBitRoundNumberOfSignificantDigits"] = keepbits - new_history = f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Data compressed with BitRound by keeping {keepbits} bits." - history = ( - new_history + " \n " + da.attrs["history"] - if "history" in da.attrs - else new_history - ) - da.attrs["history"] = history - return da - - -def _guess_keepbits(bitround, varname, vartype): - # Guess the number of bits to keep depending on how bitround was passed, the var dtype and the var name. - if not np.issubdtype(vartype, np.floating) or bitround is False: - if isinstance(bitround, dict) and varname in bitround: - raise ValueError( - f"A keepbits value was given for variable {varname} even though it is not of a floatig dtype." 
- ) - return None - if bitround is True: - return KEEPBITS[varname] - if isinstance(bitround, int): - return bitround - if isinstance(bitround, dict): - return bitround.get(varname, KEEPBITS[varname]) - return None - - @parse_config def save_to_zarr( ds: xr.Dataset, From 2b31b70be715a4b892069d54678c9348479dba43 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Mon, 16 Oct 2023 14:39:13 -0400 Subject: [PATCH 3/5] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Éric Dupuis <71575674+coxipi@users.noreply.github.com> --- xscen/io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xscen/io.py b/xscen/io.py index 12b7c0ee..beba8757 100644 --- a/xscen/io.py +++ b/xscen/io.py @@ -326,7 +326,7 @@ def _guess_keepbits(bitround, varname, vartype): if not np.issubdtype(vartype, np.floating) or bitround is False: if isinstance(bitround, dict) and varname in bitround: raise ValueError( - f"A keepbits value was given for variable {varname} even though it is not of a floatig dtype." + f"A keepbits value was given for variable {varname} even though it is not of a floating dtype." ) return None if bitround is True: @@ -365,7 +365,7 @@ def save_to_netcdf( If not False, float variables are bit-rounded by dropping a certain number of bits from their mantissa, allowing for a much better compression. If an int, this is the number of bits to keep for all float variables. If a dict, a mapping from variable name to the number of bits to keep. - If True, the number of bits to keep is guessed based on the variable's name, defaulting to 12, which yields a relative error of 0.012%. + If True, the number of bits to keep is guessed based on the variable's name, defaulting to 12, which yields a relative error below 0.013%. compute : bool Whether to start the computation or return a delayed object. 
netcdf_kwargs : dict, optional From e028ffa032eee82dbe6b753b3dfb805e6545706b Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Mon, 16 Oct 2023 14:58:52 -0400 Subject: [PATCH 4/5] Fix test - rename _get_keepbits --- tests/test_io.py | 5 ++++- tests/test_scripting.py | 2 -- xscen/io.py | 6 +++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/test_io.py b/tests/test_io.py index 4d9206fb..5e98479e 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -130,7 +130,10 @@ def test_round_bits(datablock_3d): dar = xs.io.round_bits(da, 12) # Close but NOT equal, meaning something happened np.testing.assert_allclose(da, dar, rtol=0.013) - assert not (da == dar).any() + # There's always a chance of having a randomly chosen number with only zeros in the bit rounded part of the mantissa + # Assuming a uniform distribution of binary numbers (which it is not), the chance of this happening should be: + # 2^(23 - 12 + 1) / 2^24 = 2^(-12) ~ 0.02 % (but we'll allow 1% of values to be safe) + assert (da != dar).sum() > (0.99 * da.size) class TestSaveToZarr: diff --git a/tests/test_scripting.py b/tests/test_scripting.py index 10a49b01..6d7e6827 100644 --- a/tests/test_scripting.py +++ b/tests/test_scripting.py @@ -1,5 +1,3 @@ -import shutil as sh - import numpy as np from conftest import notebooks from xclim.testing.helpers import test_timeseries as timeseries diff --git a/xscen/io.py b/xscen/io.py index beba8757..b62c482b 100644 --- a/xscen/io.py +++ b/xscen/io.py @@ -321,7 +321,7 @@ def round_bits(da: xr.DataArray, keepbits: int): return da -def _guess_keepbits(bitround, varname, vartype): +def _get_keepbits(bitround, varname, vartype): # Guess the number of bits to keep depending on how bitround was passed, the var dtype and the var name. 
if not np.issubdtype(vartype, np.floating) or bitround is False: if isinstance(bitround, dict) and varname in bitround: @@ -391,7 +391,7 @@ def save_to_netcdf( netcdf_kwargs.setdefault("format", "NETCDF4") for var in list(ds.data_vars.keys()): - if keepbits := _guess_keepbits(bitround, var, ds[var].dtype): + if keepbits := _get_keepbits(bitround, var, ds[var].dtype): ds = ds.assign({var: round_bits(ds[var], keepbits)}) _coerce_attrs(ds.attrs) @@ -507,7 +507,7 @@ def _skip(var): ds = ds.drop_vars(var) if encoding: encoding.pop(var) - if keepbits := _guess_keepbits(bitround, var, ds[var].dtype): + if keepbits := _get_keepbits(bitround, var, ds[var].dtype): ds = ds.assign({var: round_bits(ds[var], keepbits)}) if len(ds.data_vars) == 0: From 09b319058cd1923f3aaa995232dd6b6c59c4068c Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Mon, 16 Oct 2023 15:19:27 -0400 Subject: [PATCH 5/5] rename func in tests too --- tests/test_io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_io.py b/tests/test_io.py index 5e98479e..5d971eca 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -151,6 +151,6 @@ class TestSaveToZarr: def test_guess_bitround(self, vname, vtype, bitr, exp): if exp == "error": with pytest.raises(ValueError): - xs.io._guess_keepbits(bitr, vname, vtype) + xs.io._get_keepbits(bitr, vname, vtype) else: - assert xs.io._guess_keepbits(bitr, vname, vtype) == exp + assert xs.io._get_keepbits(bitr, vname, vtype) == exp