Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better save_to_zarr #266

Merged
merged 6 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ New features and enhancements
* ``xs.save_to_table`` and ``xs.io.to_table`` to transform datasets and arrays to DataFrames, but with support for multi-columns, multi-sheets and localized table of content generation.
* Better ``xs.extract.resample`` : support for weighted resampling operations when starting with frequencies coarser than daily and missing timesteps/values handling. (:issue:`80`, :issue:`93`, :pull:`265`).
* New argument ``attribute_weights`` to ``generate_weights`` to allow for custom weights. (:pull:`252`).
* ``xs.io.round_bits`` to round floating point variable up to a number of bits, allowing for a better compression. This can be combined with the saving step through argument "bitround" of ``save_to_netcdf`` and ``save_to_zarr``. (:pull:`266`).

Breaking changes
^^^^^^^^^^^^^^^^
Expand Down
38 changes: 38 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,41 @@ def test_normal(self):
np.testing.assert_array_equal(
tab.loc[("1993", "pr"), ("JFM",)], self.ds.pr.sel(time="1993", season="JFM")
)


def test_round_bits(datablock_3d):
da = datablock_3d(
np.random.random((30, 30, 50)),
variable="tas",
x="lon",
x_start=-70,
y="lat",
y_start=45,
)
dar = xs.io.round_bits(da, 12)
# Close but NOT equal, meaning something happened
np.testing.assert_allclose(da, dar, rtol=0.013)
# There's always a chance of having a randomly chosen number with only zeros in the bit rounded part of the mantissa
# Assuming a uniform distribution of binary numbers (which it is not), the chance of this happening should be:
# 2^(23 - 12 + 1) / 2^24 = 2^(-12) ~ 0.02 % (but we'll allow 1% of values to be safe)
assert (da != dar).sum() > (0.99 * da.size)


class TestSaveToZarr:
@pytest.mark.parametrize(
"vname,vtype,bitr,exp",
[
("tas", np.float32, 12, 12),
("tas", np.float32, False, None),
("tas", np.int32, 12, None),
("tas", np.int32, {"tas": 2}, "error"),
("tas", object, {"pr": 2}, None),
("tas", np.float64, True, 12),
],
)
def test_guess_bitround(self, vname, vtype, bitr, exp):
if exp == "error":
with pytest.raises(ValueError):
xs.io._get_keepbits(bitr, vname, vtype)
else:
assert xs.io._get_keepbits(bitr, vname, vtype) == exp
2 changes: 0 additions & 2 deletions tests/test_scripting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import shutil as sh

import numpy as np
from conftest import notebooks
from xclim.testing.helpers import test_timeseries as timeseries
Expand Down
68 changes: 66 additions & 2 deletions xscen/io.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# noqa: D100
import datetime
import logging
import os
import shutil as sh
from collections import defaultdict
from collections.abc import Sequence
from inspect import signature
from pathlib import Path
Expand All @@ -13,6 +15,7 @@
import pandas as pd
import xarray as xr
import zarr
from numcodecs.bitround import BitRound
from rechunker import rechunk as _rechunk
from xclim.core.calendar import get_calendar
from xclim.core.options import METADATA_LOCALES
Expand All @@ -23,6 +26,7 @@
from .utils import TRANSLATOR, season_sort_key, translate_time_chunk

logger = logging.getLogger(__name__)
KEEPBITS = defaultdict(lambda: 12)


__all__ = [
Expand All @@ -32,6 +36,7 @@
"make_toc",
"rechunk",
"rechunk_for_saving",
"round_bits",
"save_to_table",
"save_to_netcdf",
"save_to_zarr",
Expand Down Expand Up @@ -292,12 +297,54 @@ def _coerce_attrs(attrs):
attrs[k] = str(attrs[k])


def _np_bitround(array, keepbits):
"""Bitround for Arrays."""
codec = BitRound(keepbits=keepbits)
data = array.copy() # otherwise overwrites the input
encoded = codec.encode(data)
return codec.decode(encoded)


def round_bits(da: xr.DataArray, keepbits: int):
"""Round floating point variable by keeping a given number of bits in the mantissa, dropping the rest."""
da = xr.apply_ufunc(
_np_bitround, da, keepbits, dask="parallelized", keep_attrs=True
)
da.attrs["_QuantizeBitRoundNumberOfSignificantDigits"] = keepbits
new_history = f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Data compressed with BitRound by keeping {keepbits} bits."
history = (
new_history + " \n " + da.attrs["history"]
if "history" in da.attrs
else new_history
)
da.attrs["history"] = history
return da


def _get_keepbits(bitround, varname, vartype):
# Guess the number of bits to keep depending on how bitround was passed, the var dtype and the var name.
if not np.issubdtype(vartype, np.floating) or bitround is False:
if isinstance(bitround, dict) and varname in bitround:
raise ValueError(
f"A keepbits value was given for variable {varname} even though it is not of a floating dtype."
)
return None
if bitround is True:
return KEEPBITS[varname]
if isinstance(bitround, int):
return bitround
if isinstance(bitround, dict):
return bitround.get(varname, KEEPBITS[varname])
return None


@parse_config
def save_to_netcdf(
ds: xr.Dataset,
filename: str,
*,
rechunk: Optional[dict] = None,
bitround: Union[bool, int, dict] = False,
compute: bool = True,
netcdf_kwargs: Optional[dict] = None,
) -> None:
Expand All @@ -314,6 +361,11 @@ def save_to_netcdf(
Spatial dimensions can be generalized as 'X' and 'Y', which will be mapped to the actual grid type's
dimension names.
Rechunking is only done on *data* variables sharing dimensions with this argument.
bitround : bool or int or dict
If not False, float variables are bit-rounded by dropping a certain number of bits from their mantissa, allowing for a much better compression.
If an int, this is the number of bits to keep for all float variables.
If a dict, a mapping from variable name to the number of bits to keep.
If True, the number of bits to keep is guessed based on the variable's name, defaulting to 12, which yields a relative error below 0.013%.
compute : bool
Whether to start the computation or return a delayed object.
netcdf_kwargs : dict, optional
Expand All @@ -338,6 +390,10 @@ def save_to_netcdf(
netcdf_kwargs.setdefault("engine", "h5netcdf")
netcdf_kwargs.setdefault("format", "NETCDF4")

for var in list(ds.data_vars.keys()):
if keepbits := _get_keepbits(bitround, var, ds[var].dtype):
ds = ds.assign({var: round_bits(ds[var], keepbits)})

_coerce_attrs(ds.attrs)
for var in ds.variables.values():
_coerce_attrs(var.attrs)
Expand All @@ -354,6 +410,7 @@ def save_to_zarr(
zarr_kwargs: Optional[dict] = None,
compute: bool = True,
encoding: dict = None,
bitround: Union[bool, int, dict] = False,
mode: str = "f",
itervar: bool = False,
timeout_cleanup: bool = True,
Expand Down Expand Up @@ -383,6 +440,11 @@ def save_to_zarr(
if 'a', skip existing variables, writes the others.
encoding : dict, optional
If given, skipped variables are popped in place.
bitround : bool or int or dict
If not False, float variables are bit-rounded by dropping a certain number of bits from their mantissa, allowing for a much better compression.
If an int, this is the number of bits to keep for all float variables.
If a dict, a mapping from variable name to the number of bits to keep.
If True, the number of bits to keep is guessed based on the variable's name, defaulting to 12, which yields a relative error of 0.012%.
itervar : bool
If True, (data) variables are written one at a time, appending to the zarr.
If False, this function computes, no matter what was passed to kwargs.
Expand Down Expand Up @@ -430,7 +492,7 @@ def _skip(var):
if mode == "o":
if exists:
var_path = path / var
print(f"Removing {var_path} to overwrite.")
logger.warning(f"Removing {var_path} to overwrite.")
sh.rmtree(var_path)
return False

Expand All @@ -439,12 +501,14 @@ def _skip(var):
return exists
return False

for var in ds.data_vars.keys():
for var in list(ds.data_vars.keys()):
if _skip(var):
logger.info(f"Skipping {var} in {path}.")
ds = ds.drop_vars(var)
if encoding:
encoding.pop(var)
if keepbits := _get_keepbits(bitround, var, ds[var].dtype):
ds = ds.assign({var: round_bits(ds[var], keepbits)})

if len(ds.data_vars) == 0:
return None
Expand Down