Skip to content

Commit

Permalink
Rename index to row to avoid name collision with pandas args
Browse files Browse the repository at this point in the history
  • Loading branch information
aulemahal committed Sep 28, 2023
1 parent ed97cd8 commit 3d47138
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 38 deletions.
2 changes: 1 addition & 1 deletion tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_normal(self):

# Variable in the index, thus no coords
tab = xs.io.to_table(
self.ds, index=["time", "variable"], column=["season", "site"], coords=False
self.ds, row=["time", "variable"], column=["season", "site"], coords=False
)
assert tab.shape == (15, 24)
assert tab.columns.names == ["season", "site"]
Expand Down
80 changes: 43 additions & 37 deletions xscen/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def _skip(var):

def _to_dataframe(
data: xr.DataArray,
index: list[str],
row: list[str],
column: list[str],
coords: list[str],
coords_dims: dict,
Expand All @@ -505,18 +505,15 @@ def _to_dataframe(
df_data = (
df[[data.name]]
.reset_index()
.pivot(index=index, columns=column)
.pivot(index=row, columns=column)
.droplevel(None, axis=1)
)
dfs = []
for v in coords:
drop_cols = [c for c in column if c not in coords_dims[v]]
cols = [c for c in column if c in coords_dims[v]]
dfc = (
df[[v]]
.reset_index()
.drop(columns=drop_cols)
.pivot(index=index, columns=cols)
df[[v]].reset_index().drop(columns=drop_cols).pivot(index=row, columns=cols)
)
cols = dfc.columns
# The "None" level holds the aux coord name; we want it either at the same level as "variable", or at the lowest missing level otherwise.
Expand All @@ -542,13 +539,13 @@ def _to_dataframe(
dfc[~dfc.index.duplicated()]
) # We dropped columns thus the index is not unique anymore
dfs.append(df_data)
return pd.concat(dfs, axis=1).sort_index(level=index, key=season_sort_key)
return pd.concat(dfs, axis=1).sort_index(level=row, key=season_sort_key)


def to_table(
ds: Union[xr.Dataset, xr.DataArray],
*,
index: Union[None, str, Sequence[str]] = None,
row: Union[None, str, Sequence[str]] = None,
column: Union[None, str, Sequence[str]] = None,
sheet: Union[None, str, Sequence[str]] = None,
coords: Union[bool, Sequence[str]] = True,
Expand All @@ -562,12 +559,12 @@ def to_table(
ds : xr.Dataset or xr.DataArray
Dataset or DataArray to be saved.
If a Dataset with more than one variable is given, the dimension "variable"
must appear in one of `index`, `column` or `sheet`.
index : str or sequence of str, optional
Name of the dimension(s) to use as index.
must appear in one of `row`, `column` or `sheet`.
row : str or sequence of str, optional
Name of the dimension(s) to use as indexes (rows).
Default is all data dimensions.
column : str or sequence of str, optional
Name of the dimension(s) to use as index.
Name of the dimension(s) to use as columns.
Default is "variable", i.e. the name of the variable(s).
sheet : str or sequence of str, optional
Name of the dimension(s) to use as sheet names.
Expand All @@ -578,9 +575,9 @@ def to_table(
Returns
-------
pd.DataFrame or dict
DataFrame with a MultiIndex with levels `index` and MultiColumn with levels `column`.
DataFrame with a MultiIndex with levels `row` and MultiColumn with levels `column`.
If `sheet` is given, the output is dictionary with keys for each unique "sheet" dimensions tuple, values are DataFrames.
The DataFrames are always sorted with level priority as given in `index` and in ascending order,.
The DataFrames are always sorted with level priority as given in `row` and in ascending order.
"""
if isinstance(ds, xr.Dataset):
da = ds.to_array(name="data")
Expand All @@ -592,19 +589,19 @@ def _ensure_list(seq):
return [seq]
return list(seq)

index = _ensure_list(index or (set(da.dims) - {"variable"}))
row = _ensure_list(row or (set(da.dims) - {"variable"}))
column = _ensure_list(column or (["variable"] if len(ds) > 1 else []))
sheet = _ensure_list(sheet or [])

needed_dims = index + column + sheet
needed_dims = row + column + sheet
if len(set(needed_dims)) != len(needed_dims):
raise ValueError(
f"Repeated dimension names. Got index={index}, column={column} and sheet={sheet}."
f"Repeated dimension names. Got row={row}, column={column} and sheet={sheet}."
"Each dimension should appear only once."
)
if set(needed_dims) != set(da.dims):
raise ValueError(
f"Passed index, column and sheet do not match available dimensions. Got {needed_dims}, data has {da.dims}."
f"Passed row, column and sheet do not match available dimensions. Got {needed_dims}, data has {da.dims}."
)

coords = coords or []
Expand All @@ -613,13 +610,14 @@ def _ensure_list(seq):
da = da.drop_vars(drop)
else:
coords = list(set(ds.coords.keys()) - set(da.dims))
if len(coords) > 1 and "variable" in index:
if len(coords) > 1 and ("variable" in row or "variable" in sheet):
raise NotImplementedError(
"Keeping auxiliary coords is not implemented when 'variable' is in the index. Pass `coords=False` or put 'variable' in `column` instead."
"Keeping auxiliary coords is not implemented when 'variable' is in the row or in the sheets."
"Pass `coords=False` or put 'variable' in `column` instead."
)

table_kwargs = dict(
index=index,
row=row,
column=column,
coords=coords,
coords_dims={c: ds[c].dims for c in coords},
Expand All @@ -643,11 +641,12 @@ def save_to_table(
filename: str,
output_format: Optional[str] = None,
*,
index: Union[None, str, Sequence[str]] = None,
row: Union[None, str, Sequence[str]] = None,
column: Union[None, str, Sequence[str]] = "variable",
sheet: Union[None, str, Sequence[str]] = None,
coords: Union[bool, Sequence[str]] = True,
sep: str = "_",
col_sep: str = "_",
row_sep: str = None,
**kwargs,
):
"""Save the dataset to a tabular file (csv, excel, ...).
Expand All @@ -659,29 +658,30 @@ def save_to_table(
ds : xr.Dataset or xr.DataArray
Dataset or DataArray to be saved.
If a Dataset with more than one variable is given, the dimension "variable"
must appear in one of `index`, `column` or `sheet`.
must appear in one of `row`, `column` or `sheet`.
filename : str
Name of the file to be saved.
output_format: {'csv', 'excel', ...}, optional
The output format. If None (default), it is inferred
from the extension of `filename`. Not all possible output formats are supported for inference.
Valid values are any that match a :py:class:`pandas.DataFrame` method like "df.to_{format}".
index : str or sequence of str, optional
Name of the dimension(s) to use as index.
row : str or sequence of str, optional
Name of the dimension(s) to use as indexes (rows).
Default is all data dimensions.
column : str or sequence of str, optional
Name of the dimension(s) to use as index.
Name of the dimension(s) to use as columns.
Default is "variable", i.e. the name of the variable(s).
sheet : str or sequence of str, optional
Name of the dimension(s) to use as sheet names.
Only valid if the output format is excel.
coords : bool or sequence of str
A list of auxiliary coordinates to add to the columns (as variables would be).
If True, all (if any) are added.
sep : str
For output formats other than excel and for sheet names,
index, column and sheet names from multiple dimensions are
constructed by concatenating values with this separator.
col_sep : str
Multi-columns (except in excel) and sheet names are concatenated with this separator.
row_sep : str, optional
Multi-index names are concatenated with this separator, except in excel.
If None (default), each level is written in its own column.
kwargs:
Other arguments passed to the pandas function.
"""
Expand All @@ -699,17 +699,23 @@ def save_to_table(
f"Argument `sheet` is only valid with excel as the output format. Got {output_format}."
)

out = to_table(ds, index=index, column=column, sheet=sheet, coords=coords)
out = to_table(ds, row=row, column=column, sheet=sheet, coords=coords)

if sheet:
with pd.ExcelWriter(filename, engine=kwargs.get("engine")) as writer:
for sheet_name, df in out.items():
df.to_excel(writer, sheet_name=sep.join(sheet_name), **kwargs)
df.to_excel(writer, sheet_name=col_sep.join(sheet_name), **kwargs)
else:
if isinstance(out.columns, pd.MultiIndex):
out.columns = out.columns.map(lambda lvls: sep.join(map(str, lvls)))
if isinstance(out.index, pd.MultiIndex):
out.index = out.index.map(lambda lvls: sep.join(map(str, lvls)))
if output_format != "excel" and isinstance(out.columns, pd.MultiIndex):
out.columns = out.columns.map(lambda lvls: col_sep.join(map(str, lvls)))
if (
output_format != "excel"
and row_sep is not None
and isinstance(out.index, pd.MultiIndex)
):
new_name = row_sep.join(out.index.names)
out.index = out.index.map(lambda lvls: row_sep.join(map(str, lvls)))
out.index.name = new_name
getattr(out, f"to_{output_format}")(filename, **kwargs)


Expand Down

0 comments on commit 3d47138

Please sign in to comment.