Skip to content

Commit

Permalink
more tests and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
RondeauG committed Nov 11, 2024
1 parent 82bf43c commit 0bfe1bd
Show file tree
Hide file tree
Showing 5 changed files with 378 additions and 68 deletions.
1 change: 1 addition & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies:
- netCDF4
- numcodecs
- numpy >=1.24
- openpyxl
- pandas >=2.2
- parse
- pyyaml
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies:
- netCDF4
- numcodecs
- numpy >=1.24
- openpyxl
- pandas >=2.2
- parse
- pyyaml
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ dependencies = [
"netCDF4",
"numcodecs",
"numpy >=1.24",
"openpyxl",
"pandas >=2.2",
"parse",
# Used when opening catalogs.
Expand Down
44 changes: 34 additions & 10 deletions src/xscen/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,8 @@ def save_to_zarr( # noqa: C901
if 'o', removes the existing variables.
if 'a', skip existing variables, writes the others.
encoding : dict, optional
If given, skipped variables are popped in place.
If given here instead of 'zarr_kwargs', encoding will only be applied to the variables that are being written,
skipping those that are already in the zarr.
bitround : bool or int or dict
If not False, float variables are bit-rounded by dropping a certain number of bits from their mantissa,
allowing for a much better compression.
Expand Down Expand Up @@ -512,17 +513,29 @@ def _skip(var):
return False

if mode == "a":
# In all cases, we need to skip the encoding of existing variables.
if exists:
if encoding:
encoding.pop(var, None)

# If we are not appending, we need to skip the writing of existing variables.
if "append_dim" not in zarr_kwargs:
return exists

# If we are appending, we need to raise an error if there are new variables.
elif exists is False:
raise ValueError(
f"When 'append_dim' is set in zarr_kwargs, all variables must already exist in the dataset."
)

return False

for var in list(ds.data_vars.keys()):
if _skip(var):
msg = f"Skipping {var} in {path}."
logger.info(msg)
ds = ds.drop_vars(var)
if encoding:
encoding.pop(var)
continue
if keepbits := _get_keepbits(bitround, var, ds[var].dtype):
ds = ds.assign({var: round_bits(ds[var], keepbits)})
# Remove original_shape from encoding, since it can cause issues with some engines.
Expand Down Expand Up @@ -773,7 +786,6 @@ def make_toc(ds: xr.Dataset | xr.DataArray, loc: str | None = None) -> pd.DataFr
for vv, da in ds.data_vars.items()
],
).set_index(_("Variable"))
toc.attrs["name"] = _("Content")

# Add global attributes by using a fake variable and description
if len(ds.attrs) > 0:
Expand All @@ -793,28 +805,29 @@ def make_toc(ds: xr.Dataset | xr.DataArray, loc: str | None = None) -> pd.DataFr
toc = pd.concat([toc, pd.DataFrame(index=[""])])
toc = pd.concat([toc, pd.DataFrame(index=[_("Global attributes")])])
toc = pd.concat([toc, globattr])
toc.attrs["name"] = _("Content")

return toc


TABLE_FORMATS = {".csv": "csv", ".xls": "excel", ".xlsx": "excel"}


def save_to_table(
def save_to_table( # noqa: C901
ds: xr.Dataset | xr.DataArray,
filename: str | os.PathLike,
output_format: str | None = None,
*,
row: str | Sequence[str] | None = None,
column: None | str | Sequence[str] = "variable",
column: None | str | Sequence[str] = None,
sheet: str | Sequence[str] | None = None,
coords: bool | Sequence[str] = True,
col_sep: str = "_",
row_sep: str | None = None,
add_toc: bool | pd.DataFrame = False,
**kwargs,
):
"""Save the dataset to a tabular file (csv, excel, ...).
r"""Save the dataset to a tabular file (csv, excel, ...).
This function will trigger a computation of the dataset.
Expand All @@ -835,7 +848,8 @@ def save_to_table(
Default is all data dimensions.
column : str or sequence of str, optional
Name of the dimension(s) to use as columns.
Default is "variable", i.e. the name of the variable(s).
When using a Dataset with more than 1 variable, default is "variable", i.e. the name of the variable(s).
When using a DataArray, default is None.
sheet : str or sequence of str, optional
Name of the dimension(s) to use as sheet names.
Only valid if the output format is excel.
Expand All @@ -851,7 +865,7 @@ def save_to_table(
A table of content to add as the first sheet. Only valid if the output format is excel.
If True, :py:func:`make_toc` is used to generate the toc.
The sheet name of the toc can be given through the "name" attribute of the DataFrame, otherwise "Content" is used.
kwargs:
\*\*kwargs:
Other arguments passed to the pandas function.
If the output format is excel, kwargs to :py:class:`pandas.ExcelWriter` can be given here as well.
"""
Expand All @@ -864,6 +878,9 @@ def save_to_table(
f"Output format could not be inferred from filename {filename.name}. Please pass `output_format`."
)

if column is None and isinstance(ds, xr.Dataset) and len(ds.data_vars) > 1:
column = "variable"

if sheet is not None and output_format != "excel":
raise ValueError(
f"Argument `sheet` is only valid with excel as the output format. Got {output_format}."
Expand All @@ -882,15 +899,22 @@ def save_to_table(
add_toc = make_toc(ds)
out = {(add_toc.attrs.get("name", "Content"),): add_toc, **out}

if sheet or (add_toc is not False):
# Get engine_kwargs
if output_format == "excel":
engine_kwargs = {} # Extract engine kwargs
for arg in signature(pd.ExcelWriter).parameters:
if arg in kwargs:
engine_kwargs[arg] = kwargs.pop(arg)
else:
engine_kwargs = {}

if sheet or (add_toc is not False):
with pd.ExcelWriter(filename, **engine_kwargs) as writer:
for sheet_name, df in out.items():
df.to_excel(writer, sheet_name=col_sep.join(sheet_name), **kwargs)
elif len(engine_kwargs) > 0:
with pd.ExcelWriter(filename, **engine_kwargs) as writer:
out.to_excel(writer, **kwargs)
else:
if output_format != "excel" and isinstance(out.columns, pd.MultiIndex):
out.columns = out.columns.map(lambda lvls: col_sep.join(map(str, lvls)))
Expand Down
Loading

0 comments on commit 0bfe1bd

Please sign in to comment.