diff --git a/HISTORY.rst b/HISTORY.rst
index d4be1543..2849ba51 100644
--- a/HISTORY.rst
+++ b/HISTORY.rst
@@ -19,6 +19,7 @@ New features and enhancements
 * ``xs.save_to_table`` and ``xs.io.to_table`` to transform datasets and arrays to DataFrames, but with support for multi-columns, multi-sheets and localized table of content generation.
 * Better ``xs.extract.resample`` : support for weighted resampling operations when starting with frequencies coarser than daily and missing timesteps/values handling. (:issue:`80`, :issue:`93`, :pull:`265`).
 * New argument ``attribute_weights`` to ``generate_weights`` to allow for custom weights. (:pull:`252`).
+* ``xs.io.round_bits`` to round floating point variables to a given number of mantissa bits, allowing for better compression. This can be combined with the saving step through the ``bitround`` argument of ``save_to_netcdf`` and ``save_to_zarr``. (:pull:`266`).
 
 Breaking changes
 ^^^^^^^^^^^^^^^^
diff --git a/tests/test_io.py b/tests/test_io.py
index c8083246..5d971eca 100644
--- a/tests/test_io.py
+++ b/tests/test_io.py
@@ -116,3 +116,41 @@ def test_normal(self):
         np.testing.assert_array_equal(
             tab.loc[("1993", "pr"), ("JFM",)], self.ds.pr.sel(time="1993", season="JFM")
         )
+
+
+def test_round_bits(datablock_3d):
+    da = datablock_3d(
+        np.random.random((30, 30, 50)),
+        variable="tas",
+        x="lon",
+        x_start=-70,
+        y="lat",
+        y_start=45,
+    )
+    dar = xs.io.round_bits(da, 12)
+    # Close but NOT equal, meaning something happened
+    np.testing.assert_allclose(da, dar, rtol=0.013)
+    # There's always a chance that a randomly chosen number has only zeros in the bit-rounded part of the mantissa.
+    # Assuming a uniform distribution of bit patterns (which it is not), the chance that the 23 - 12 = 11 dropped
+    # mantissa bits are all zero is 2^(-11) ~ 0.05 % (but we'll allow 1% of values to be safe)
+    assert (da != dar).sum() > (0.99 * da.size)
+
+
+class TestSaveToZarr:
+    @pytest.mark.parametrize(
+        "vname,vtype,bitr,exp",
+        [
+            ("tas", np.float32, 12, 12),
+            ("tas", np.float32, False, None),
+            ("tas", np.int32, 12, None),
+            ("tas", np.int32, {"tas": 2}, "error"),
+            ("tas", object, {"pr": 2}, None),
+            ("tas", np.float64, True, 12),
+        ],
+    )
+    def test_guess_bitround(self, vname, vtype, bitr, exp):
+        if exp == "error":
+            with pytest.raises(ValueError):
+                xs.io._get_keepbits(bitr, vname, vtype)
+        else:
+            assert xs.io._get_keepbits(bitr, vname, vtype) == exp
diff --git a/tests/test_scripting.py b/tests/test_scripting.py
index 10a49b01..6d7e6827 100644
--- a/tests/test_scripting.py
+++ b/tests/test_scripting.py
@@ -1,5 +1,3 @@
-import shutil as sh
-
 import numpy as np
 from conftest import notebooks
 from xclim.testing.helpers import test_timeseries as timeseries
diff --git a/xscen/io.py b/xscen/io.py
index 87235767..b62c482b 100644
--- a/xscen/io.py
+++ b/xscen/io.py
@@ -1,7 +1,9 @@
 # noqa: D100
+import datetime
 import logging
 import os
 import shutil as sh
+from collections import defaultdict
 from collections.abc import Sequence
 from inspect import signature
 from pathlib import Path
@@ -13,6 +15,7 @@
 import pandas as pd
 import xarray as xr
 import zarr
+from numcodecs.bitround import BitRound
 from rechunker import rechunk as _rechunk
 from xclim.core.calendar import get_calendar
 from xclim.core.options import METADATA_LOCALES
@@ -23,6 +26,7 @@
 from .utils import TRANSLATOR, season_sort_key, translate_time_chunk
 
 logger = logging.getLogger(__name__)
+KEEPBITS = defaultdict(lambda: 12)
 
 
 __all__ = [
@@ -32,6 +36,7 @@
     "make_toc",
     "rechunk",
     "rechunk_for_saving",
+    "round_bits",
     "save_to_table",
     "save_to_netcdf",
     "save_to_zarr",
@@ -292,12 +297,54 @@ def _coerce_attrs(attrs):
             attrs[k] = str(attrs[k])
 
 
+def _np_bitround(array, keepbits):
+    """Bitround for Arrays."""
+    codec = BitRound(keepbits=keepbits)
+    data = array.copy()  # otherwise overwrites the input
+    encoded = codec.encode(data)
+    return codec.decode(encoded)
+
+
+def round_bits(da: xr.DataArray, keepbits: int):
+    """Round a floating point DataArray by keeping a given number of bits in the mantissa, dropping the rest."""
+    da = xr.apply_ufunc(
+        _np_bitround, da, keepbits, dask="parallelized", keep_attrs=True
+    )
+    da.attrs["_QuantizeBitRoundNumberOfSignificantDigits"] = keepbits
+    new_history = f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Data compressed with BitRound by keeping {keepbits} bits."
+    history = (
+        new_history + " \n " + da.attrs["history"]
+        if "history" in da.attrs
+        else new_history
+    )
+    da.attrs["history"] = history
+    return da
+
+
+def _get_keepbits(bitround, varname, vartype):
+    # Guess the number of bits to keep depending on how bitround was passed, the variable's dtype and its name.
+    if not np.issubdtype(vartype, np.floating) or bitround is False:
+        if isinstance(bitround, dict) and varname in bitround:
+            raise ValueError(
+                f"A keepbits value was given for variable {varname} even though it is not of a floating dtype."
+            )
+        return None
+    if bitround is True:
+        return KEEPBITS[varname]
+    if isinstance(bitround, int):
+        return bitround
+    if isinstance(bitround, dict):
+        return bitround.get(varname, KEEPBITS[varname])
+    return None
+
+
 @parse_config
 def save_to_netcdf(
     ds: xr.Dataset,
     filename: str,
     *,
     rechunk: Optional[dict] = None,
+    bitround: Union[bool, int, dict] = False,
     compute: bool = True,
     netcdf_kwargs: Optional[dict] = None,
 ) -> None:
@@ -314,6 +361,11 @@
         Spatial dimensions can be generalized as 'X' and 'Y', which will be mapped to the actual grid type's dimension names.
         Rechunking is only done on *data* variables sharing dimensions with this argument.
+    bitround : bool or int or dict
+        If not False, float variables are bit-rounded by dropping a certain number of bits from their mantissa, allowing for much better compression.
+        If an int, this is the number of bits to keep for all float variables.
+        If a dict, a mapping from variable name to the number of bits to keep.
+        If True, the number of bits to keep is guessed based on the variable's name, defaulting to 12, which yields a relative error below 0.013%.
     compute : bool
        Whether to start the computation or return a delayed object.
     netcdf_kwargs : dict, optional
@@ -338,6 +390,10 @@
     netcdf_kwargs.setdefault("engine", "h5netcdf")
     netcdf_kwargs.setdefault("format", "NETCDF4")
 
+    for var in list(ds.data_vars.keys()):
+        if keepbits := _get_keepbits(bitround, var, ds[var].dtype):
+            ds = ds.assign({var: round_bits(ds[var], keepbits)})
+
     _coerce_attrs(ds.attrs)
     for var in ds.variables.values():
         _coerce_attrs(var.attrs)
@@ -354,6 +410,7 @@
     zarr_kwargs: Optional[dict] = None,
     compute: bool = True,
     encoding: dict = None,
+    bitround: Union[bool, int, dict] = False,
    mode: str = "f",
     itervar: bool = False,
     timeout_cleanup: bool = True,
@@ -383,6 +440,11 @@
         if 'a', skip existing variables, writes the others.
     encoding : dict, optional
         If given, skipped variables are popped in place.
+    bitround : bool or int or dict
+        If not False, float variables are bit-rounded by dropping a certain number of bits from their mantissa, allowing for much better compression.
+        If an int, this is the number of bits to keep for all float variables.
+        If a dict, a mapping from variable name to the number of bits to keep.
+        If True, the number of bits to keep is guessed based on the variable's name, defaulting to 12, which yields a relative error below 0.013%.
     itervar : bool
         If True, (data) variables are written one at a time, appending to the zarr.
         If False, this function computes, no matter what was passed to kwargs.
@@ -430,7 +492,7 @@ def _skip(var):
         if mode == "o":
             if exists:
                 var_path = path / var
-                print(f"Removing {var_path} to overwrite.")
+                logger.warning(f"Removing {var_path} to overwrite.")
                 sh.rmtree(var_path)
             return False
 
@@ -439,12 +501,14 @@ def _skip(var):
             return exists
         return False
 
-    for var in ds.data_vars.keys():
+    for var in list(ds.data_vars.keys()):
         if _skip(var):
             logger.info(f"Skipping {var} in {path}.")
             ds = ds.drop_vars(var)
             if encoding:
                 encoding.pop(var)
+        elif keepbits := _get_keepbits(bitround, var, ds[var].dtype):
+            ds = ds.assign({var: round_bits(ds[var], keepbits)})
 
     if len(ds.data_vars) == 0:
         return None
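
For reviewers, a minimal usage sketch of the new ``bitround`` option; the dataset, variable name and output path below are made up:

import numpy as np
import xarray as xr
import xscen as xs

ds = xr.Dataset({"tas": (("y", "x"), np.random.rand(10, 10).astype(np.float32))})

# Round a single variable explicitly: keep 12 mantissa bits, drop the other 11.
tas_rounded = xs.io.round_bits(ds.tas, keepbits=12)

# Or let the saving step do it: an int applies to every float variable,
# a dict maps variable names to keepbits, and True uses the defaults in KEEPBITS (12 bits).
xs.io.save_to_zarr(ds, "example.zarr", bitround={"tas": 12}, mode="o")

Non-float variables are left untouched, and explicitly requesting a keepbits value for a non-float variable raises a ValueError, as exercised by the new test_guess_bitround cases.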
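
The "12 bits kept, relative error below 0.013 %" figure in the docstrings follows from the rounding being round-to-nearest on the mantissa: the error is at most half of the last kept bit, i.e. 2**-(keepbits + 1), about 0.012 % of the value for keepbits = 12. A self-contained sanity check with numcodecs, mirroring the encode/decode round-trip of the patch's _np_bitround helper:

import numpy as np
from numcodecs.bitround import BitRound

keepbits = 12
codec = BitRound(keepbits=keepbits)
data = np.random.uniform(0.5, 1.5, 100_000).astype(np.float32)
# Round-trip on a copy so `data` itself is not overwritten.
rounded = codec.decode(codec.encode(data.copy()))

rel_err = np.abs(rounded - data) / np.abs(data)
# Half of the last kept mantissa bit bounds the relative error.
assert rel_err.max() <= 2.0 ** -(keepbits + 1)
print(f"max relative error: {rel_err.max():.2e}")  # on the order of 1e-4, i.e. ~0.012 %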