diff --git a/tests/test_io.py b/tests/test_io.py index 4d9206fb..5e98479e 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -130,7 +130,10 @@ def test_round_bits(datablock_3d): dar = xs.io.round_bits(da, 12) # Close but NOT equal, meaning something happened np.testing.assert_allclose(da, dar, rtol=0.013) - assert not (da == dar).any() + # There's always a chance of having a randomly chosen number with only zeros in the bit rounded part of the mantissa + # Assuming a uniform distribution of binary numbers (which it is not), the chance of this happening should be: + # 2^(23 - 12 + 1) / 2^24 = 2^(-12) ~ 0.02 % (but we'll allow 1% of values to be safe) + assert (da != dar).sum() > (0.99 * da.size) class TestSaveToZarr: diff --git a/tests/test_scripting.py b/tests/test_scripting.py index 10a49b01..6d7e6827 100644 --- a/tests/test_scripting.py +++ b/tests/test_scripting.py @@ -1,5 +1,3 @@ -import shutil as sh - import numpy as np from conftest import notebooks from xclim.testing.helpers import test_timeseries as timeseries diff --git a/xscen/io.py b/xscen/io.py index beba8757..b62c482b 100644 --- a/xscen/io.py +++ b/xscen/io.py @@ -321,7 +321,7 @@ def round_bits(da: xr.DataArray, keepbits: int): return da -def _guess_keepbits(bitround, varname, vartype): +def _get_keepbits(bitround, varname, vartype): # Guess the number of bits to keep depending on how bitround was passed, the var dtype and the var name. if not np.issubdtype(vartype, np.floating) or bitround is False: if isinstance(bitround, dict) and varname in bitround: @@ -391,7 +391,7 @@ def save_to_netcdf( netcdf_kwargs.setdefault("format", "NETCDF4") for var in list(ds.data_vars.keys()): - if keepbits := _guess_keepbits(bitround, var, ds[var].dtype): + if keepbits := _get_keepbits(bitround, var, ds[var].dtype): ds = ds.assign({var: round_bits(ds[var], keepbits)}) _coerce_attrs(ds.attrs) @@ -507,7 +507,7 @@ def _skip(var): ds = ds.drop_vars(var) if encoding: encoding.pop(var) - if keepbits := _guess_keepbits(bitround, var, ds[var].dtype): + if keepbits := _get_keepbits(bitround, var, ds[var].dtype): ds = ds.assign({var: round_bits(ds[var], keepbits)}) if len(ds.data_vars) == 0: