diff --git a/.cruft.json b/.cruft.json index ab6834b4..a7dc6521 100644 --- a/.cruft.json +++ b/.cruft.json @@ -11,7 +11,7 @@ "project_slug": "xscen", "project_short_description": "A climate change scenario-building analysis framework, built with xclim/xarray.", "pypi_username": "RondeauG", - "version": "0.7.14-beta", + "version": "0.7.15-beta", "use_pytest": "y", "use_black": "y", "add_pyup_badge": "n", diff --git a/HISTORY.rst b/HISTORY.rst index a710f6fa..423e584e 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -18,9 +18,12 @@ New features and enhancements * Added documentation for `require_all_on` in `search_data_catalogs`. (:pull:`263`). * ``xs.save_to_table`` and ``xs.io.to_table`` to transform datasets and arrays to DataFrames, but with support for multi-columns, multi-sheets and localized table of content generation. * Better ``xs.extract.resample`` : support for weighted resampling operations when starting with frequencies coarser than daily and missing timesteps/values handling. (:issue:`80`, :issue:`93`, :pull:`265`). +* New argument ``attribute_weights`` to ``generate_weights`` to allow for custom weights. (:pull:`252`). +* ``xs.io.round_bits`` to round floating point variable up to a number of bits, allowing for a better compression. This can be combined with the saving step through argument "bitround" of ``save_to_netcdf`` and ``save_to_zarr``. (:pull:`266`). * Added annual global tas timeseries for CMIP6's models CMCC-ESM2 (ssp245, ssp370, ssp585), EC-Earth3-CC (ssp245, ssp585), KACE-1-0-G (ssp245, ssp370, ssp585) and TaiESM1 (ssp245, ssp370). Moved global tas database to a netCDF file. (:issue:`268`, :pull:`270`). * Implemented support for multiple levels and models in ``xs.subset_warming_level``. Better support for `DataArray` and `DataFrame` in ``xs.get_warming_level``. (:pull:`270`). + Breaking changes ^^^^^^^^^^^^^^^^ * New argument ``attribute_weights`` to ``generate_weights`` to allow for custom weights. (:pull:`252`). diff --git a/setup.cfg b/setup.cfg index a7f2fbc4..d52fe214 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.7.14-beta +current_version = 0.7.15-beta commit = True tag = False parse = (?P\d+)\.(?P\d+).(?P\d+)(\-(?P[a-z]+))? diff --git a/setup.py b/setup.py index ae4e22b6..4952b2ab 100644 --- a/setup.py +++ b/setup.py @@ -102,6 +102,6 @@ def run(self): test_suite="tests", extras_require={"dev": dev_requirements}, url="https://github.com/Ouranosinc/xscen", - version="0.7.14-beta", + version="0.7.15-beta", zip_safe=False, ) diff --git a/tests/test_io.py b/tests/test_io.py index c8083246..5d971eca 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -116,3 +116,41 @@ def test_normal(self): np.testing.assert_array_equal( tab.loc[("1993", "pr"), ("JFM",)], self.ds.pr.sel(time="1993", season="JFM") ) + + +def test_round_bits(datablock_3d): + da = datablock_3d( + np.random.random((30, 30, 50)), + variable="tas", + x="lon", + x_start=-70, + y="lat", + y_start=45, + ) + dar = xs.io.round_bits(da, 12) + # Close but NOT equal, meaning something happened + np.testing.assert_allclose(da, dar, rtol=0.013) + # There's always a chance of having a randomly chosen number with only zeros in the bit rounded part of the mantissa + # Assuming a uniform distribution of binary numbers (which it is not), the chance of this happening should be: + # 2^(23 - 12 + 1) / 2^24 = 2^(-12) ~ 0.02 % (but we'll allow 1% of values to be safe) + assert (da != dar).sum() > (0.99 * da.size) + + +class TestSaveToZarr: + @pytest.mark.parametrize( + "vname,vtype,bitr,exp", + [ + ("tas", np.float32, 12, 12), + ("tas", np.float32, False, None), + ("tas", np.int32, 12, None), + ("tas", np.int32, {"tas": 2}, "error"), + ("tas", object, {"pr": 2}, None), + ("tas", np.float64, True, 12), + ], + ) + def test_guess_bitround(self, vname, vtype, bitr, exp): + if exp == "error": + with pytest.raises(ValueError): + xs.io._get_keepbits(bitr, vname, vtype) + else: + assert xs.io._get_keepbits(bitr, vname, vtype) == exp diff --git a/tests/test_scripting.py b/tests/test_scripting.py index 10a49b01..6d7e6827 100644 --- a/tests/test_scripting.py +++ b/tests/test_scripting.py @@ -1,5 +1,3 @@ -import shutil as sh - import numpy as np from conftest import notebooks from xclim.testing.helpers import test_timeseries as timeseries diff --git a/tests/test_xscen.py b/tests/test_xscen.py index e8a4e2f2..30c400e2 100644 --- a/tests/test_xscen.py +++ b/tests/test_xscen.py @@ -28,4 +28,4 @@ def test_package_metadata(self): contents = f.read() assert """Gabriel Rondeau-Genesse""" in contents assert '__email__ = "rondeau-genesse.gabriel@ouranos.ca"' in contents - assert '__version__ = "0.7.14-beta"' in contents + assert '__version__ = "0.7.15-beta"' in contents diff --git a/xscen/__init__.py b/xscen/__init__.py index b8c5e817..55280871 100644 --- a/xscen/__init__.py +++ b/xscen/__init__.py @@ -52,7 +52,7 @@ __author__ = """Gabriel Rondeau-Genesse""" __email__ = "rondeau-genesse.gabriel@ouranos.ca" -__version__ = "0.7.14-beta" +__version__ = "0.7.15-beta" # monkeypatch so that warnings.warn() doesn't mention itself diff --git a/xscen/io.py b/xscen/io.py index 87235767..b62c482b 100644 --- a/xscen/io.py +++ b/xscen/io.py @@ -1,7 +1,9 @@ # noqa: D100 +import datetime import logging import os import shutil as sh +from collections import defaultdict from collections.abc import Sequence from inspect import signature from pathlib import Path @@ -13,6 +15,7 @@ import pandas as pd import xarray as xr import zarr +from numcodecs.bitround import BitRound from rechunker import rechunk as _rechunk from xclim.core.calendar import get_calendar from xclim.core.options import METADATA_LOCALES @@ -23,6 +26,7 @@ from .utils import TRANSLATOR, season_sort_key, translate_time_chunk logger = logging.getLogger(__name__) +KEEPBITS = defaultdict(lambda: 12) __all__ = [ @@ -32,6 +36,7 @@ "make_toc", "rechunk", "rechunk_for_saving", + "round_bits", "save_to_table", "save_to_netcdf", "save_to_zarr", @@ -292,12 +297,54 @@ def _coerce_attrs(attrs): attrs[k] = str(attrs[k]) +def _np_bitround(array, keepbits): + """Bitround for Arrays.""" + codec = BitRound(keepbits=keepbits) + data = array.copy() # otherwise overwrites the input + encoded = codec.encode(data) + return codec.decode(encoded) + + +def round_bits(da: xr.DataArray, keepbits: int): + """Round floating point variable by keeping a given number of bits in the mantissa, dropping the rest.""" + da = xr.apply_ufunc( + _np_bitround, da, keepbits, dask="parallelized", keep_attrs=True + ) + da.attrs["_QuantizeBitRoundNumberOfSignificantDigits"] = keepbits + new_history = f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Data compressed with BitRound by keeping {keepbits} bits." + history = ( + new_history + " \n " + da.attrs["history"] + if "history" in da.attrs + else new_history + ) + da.attrs["history"] = history + return da + + +def _get_keepbits(bitround, varname, vartype): + # Guess the number of bits to keep depending on how bitround was passed, the var dtype and the var name. + if not np.issubdtype(vartype, np.floating) or bitround is False: + if isinstance(bitround, dict) and varname in bitround: + raise ValueError( + f"A keepbits value was given for variable {varname} even though it is not of a floating dtype." + ) + return None + if bitround is True: + return KEEPBITS[varname] + if isinstance(bitround, int): + return bitround + if isinstance(bitround, dict): + return bitround.get(varname, KEEPBITS[varname]) + return None + + @parse_config def save_to_netcdf( ds: xr.Dataset, filename: str, *, rechunk: Optional[dict] = None, + bitround: Union[bool, int, dict] = False, compute: bool = True, netcdf_kwargs: Optional[dict] = None, ) -> None: @@ -314,6 +361,11 @@ def save_to_netcdf( Spatial dimensions can be generalized as 'X' and 'Y', which will be mapped to the actual grid type's dimension names. Rechunking is only done on *data* variables sharing dimensions with this argument. + bitround : bool or int or dict + If not False, float variables are bit-rounded by dropping a certain number of bits from their mantissa, allowing for a much better compression. + If an int, this is the number of bits to keep for all float variables. + If a dict, a mapping from variable name to the number of bits to keep. + If True, the number of bits to keep is guessed based on the variable's name, defaulting to 12, which yields a relative error below 0.013%. compute : bool Whether to start the computation or return a delayed object. netcdf_kwargs : dict, optional @@ -338,6 +390,10 @@ def save_to_netcdf( netcdf_kwargs.setdefault("engine", "h5netcdf") netcdf_kwargs.setdefault("format", "NETCDF4") + for var in list(ds.data_vars.keys()): + if keepbits := _get_keepbits(bitround, var, ds[var].dtype): + ds = ds.assign({var: round_bits(ds[var], keepbits)}) + _coerce_attrs(ds.attrs) for var in ds.variables.values(): _coerce_attrs(var.attrs) @@ -354,6 +410,7 @@ def save_to_zarr( zarr_kwargs: Optional[dict] = None, compute: bool = True, encoding: dict = None, + bitround: Union[bool, int, dict] = False, mode: str = "f", itervar: bool = False, timeout_cleanup: bool = True, @@ -383,6 +440,11 @@ def save_to_zarr( if 'a', skip existing variables, writes the others. encoding : dict, optional If given, skipped variables are popped in place. + bitround : bool or int or dict + If not False, float variables are bit-rounded by dropping a certain number of bits from their mantissa, allowing for a much better compression. + If an int, this is the number of bits to keep for all float variables. + If a dict, a mapping from variable name to the number of bits to keep. + If True, the number of bits to keep is guessed based on the variable's name, defaulting to 12, which yields a relative error of 0.012%. itervar : bool If True, (data) variables are written one at a time, appending to the zarr. If False, this function computes, no matter what was passed to kwargs. @@ -430,7 +492,7 @@ def _skip(var): if mode == "o": if exists: var_path = path / var - print(f"Removing {var_path} to overwrite.") + logger.warning(f"Removing {var_path} to overwrite.") sh.rmtree(var_path) return False @@ -439,12 +501,14 @@ def _skip(var): return exists return False - for var in ds.data_vars.keys(): + for var in list(ds.data_vars.keys()): if _skip(var): logger.info(f"Skipping {var} in {path}.") ds = ds.drop_vars(var) if encoding: encoding.pop(var) + if keepbits := _get_keepbits(bitround, var, ds[var].dtype): + ds = ds.assign({var: round_bits(ds[var], keepbits)}) if len(ds.data_vars) == 0: return None