Support turning aggregations off #269

Merged · 3 commits · Aug 4, 2020
137 changes: 78 additions & 59 deletions intake_esm/core.py
@@ -1,6 +1,8 @@
+import concurrent.futures
 import json
 import logging
-from collections import namedtuple
+from collections import OrderedDict, namedtuple
+from copy import deepcopy
 from typing import Any, Dict, List, Tuple, Union
 from warnings import warn

@@ -107,14 +109,33 @@ def __init__(
         self._log_level = log_level
         self._datasets = {}
         self.sep = sep
+        self._data_format, self._format_column_name = None, None
+        self._path_column_name = self.esmcol_data['assets']['column_name']
+        if 'format' in self.esmcol_data['assets']:
+            self._data_format = self.esmcol_data['assets']['format']
+        else:
+            self._format_column_name = self.esmcol_data['assets']['format_column_name']
         self.aggregation_info = self._get_aggregation_info()
         self._entries = {}
         self._set_groups_and_keys()
         super(esm_datastore, self).__init__(**kwargs)

     def _set_groups_and_keys(self):
-        self._grouped = self.df.groupby(self.aggregation_info.groupby_attrs)
-        self._keys = list(self._grouped.groups.keys())
+        if self.aggregation_info.groupby_attrs and set(self.df.columns) != set(
+            self.aggregation_info.groupby_attrs
+        ):
+            self._grouped = self.df.groupby(self.aggregation_info.groupby_attrs)
+            internal_keys = self._grouped.groups.keys()
+            public_keys = [self.sep.join(str(v) for v in x) for x in internal_keys]
+
+        else:
+            self._grouped = self.df
+            internal_keys = list(self._grouped.index)
+            public_keys = [
+                self.sep.join(str(v) for v in row.values) for _, row in self._grouped.iterrows()
+            ]
+
+        self._keys = dict(zip(public_keys, internal_keys))

     def _allnan_or_nonan(self, column: str) -> bool:
         """
@@ -153,46 +174,27 @@ def _get_aggregation_info(self):
                 'aggregations',
                 'agg_columns',
                 'aggregation_dict',
-                'path_column_name',
-                'data_format',
-                'format_column_name',
             ],
         )

         groupby_attrs = []
-        data_format = None
-        format_column_name = None
         variable_column_name = None
         aggregations = []
         aggregation_dict = {}
         agg_columns = []
-        path_column_name = self.esmcol_data['assets']['column_name']
-
-        if 'format' in self.esmcol_data['assets']:
-            data_format = self.esmcol_data['assets']['format']
-        else:
-            format_column_name = self.esmcol_data['assets']['format_column_name']

         if 'aggregation_control' in self.esmcol_data:
             variable_column_name = self.esmcol_data['aggregation_control']['variable_column_name']
             groupby_attrs = self.esmcol_data['aggregation_control'].get('groupby_attrs', [])
             aggregations = self.esmcol_data['aggregation_control'].get('aggregations', [])
             aggregations, aggregation_dict, agg_columns = _construct_agg_info(aggregations)
+            groupby_attrs = list(filter(self._allnan_or_nonan, groupby_attrs))

-        if not groupby_attrs:
-            groupby_attrs = self.df.columns.tolist()
-
-        groupby_attrs = list(filter(self._allnan_or_nonan, groupby_attrs))
+        elif not groupby_attrs or 'aggregation_control' not in self.esmcol_data:
+            groupby_attrs = []

         aggregation_info = AggregationInfo(
-            groupby_attrs,
-            variable_column_name,
-            aggregations,
-            agg_columns,
-            aggregation_dict,
-            path_column_name,
-            data_format,
-            format_column_name,
+            groupby_attrs, variable_column_name, aggregations, agg_columns, aggregation_dict,
         )
         return aggregation_info
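
The `AggregationInfo` namedtuple is now purely about aggregation; asset-level settings (path column, data format, format column) live directly on the datastore. A sketch of the slimmed-down tuple, with the first two field names inferred from the constructor call above:

    from collections import namedtuple

    AggregationInfo = namedtuple(
        'AggregationInfo',
        ['groupby_attrs', 'variable_column_name', 'aggregations', 'agg_columns', 'aggregation_dict'],
    )

    # An empty groupby_attrs list is the new "aggregations off" signal.
    info = AggregationInfo([], 'variable', [], [], {})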

@@ -205,8 +207,7 @@ def keys(self) -> List:
         list
             keys for the catalog entries
         """
-        keys = [self.sep.join(x) for x in self._keys]
-        return keys
+        return self._keys.keys()

     @property
     def key_template(self) -> str:
@@ -218,7 +219,11 @@ def key_template(self) -> str:
         str
             string template used to create catalog entry keys
         """
-        return self.sep.join(self.aggregation_info.groupby_attrs)
+        if self.aggregation_info.groupby_attrs:
+            template = self.sep.join(self.aggregation_info.groupby_attrs)
+        else:
+            template = self.sep.join(self.df.columns)
+        return template

     @property
     def df(self) -> pd.DataFrame:
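
So `key_template` always describes whatever `keys()` returns: the groupby attributes when aggregating, every catalog column otherwise. A standalone mirror of that logic (a sketch, not the property itself):

    def key_template(groupby_attrs, columns, sep='.'):
        # Prefer the groupby attrs; fall back to all DataFrame columns.
        return sep.join(groupby_attrs) if groupby_attrs else sep.join(columns)

    assert key_template(['source_id', 'experiment_id'], []) == 'source_id.experiment_id'
    assert key_template([], ['source_id', 'experiment_id', 'path']) == 'source_id.experiment_id.path'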
@@ -249,6 +254,7 @@ def groupby_attrs(self, value: list) -> None:
         groupby_attrs = list(filter(self._allnan_or_nonan, value))
         self.aggregation_info = self.aggregation_info._replace(groupby_attrs=groupby_attrs)
         self._set_groups_and_keys()
+        self._entries = {}

     @property
     def variable_column_name(self) -> str:
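
Because the setter re-derives the groups and clears memoized entries, aggregation can also be toggled on a live catalog, not only via `to_dataset_dict`. Illustrative usage (the catalog JSON path is a placeholder):

    import intake

    cat = intake.open_esm_datastore('collection.json')  # placeholder path
    len(cat)                 # number of aggregated groups

    cat.groupby_attrs = []   # turn aggregation off: one key per row/asset
    len(cat)                 # now equals len(cat.df)

    cat.groupby_attrs = ['source_id', 'experiment_id']  # re-enable grouping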
@@ -282,34 +288,34 @@ def path_column_name(self) -> str:
         """
         The name of the column containing the path to the asset.
         """
-        return self.aggregation_info.path_column_name
+        return self._path_column_name

     @path_column_name.setter
     def path_column_name(self, value: str) -> None:
-        self.aggregation_info = self.aggregation_info._replace(path_column_name=value)
+        self._path_column_name = value

     @property
     def data_format(self) -> str:
         """
         The data format. Valid values are netcdf and zarr.
         If specified, it means that all data assets in the catalog use the same data format.
         """
-        return self.aggregation_info.data_format
+        return self._data_format

     @data_format.setter
     def data_format(self, value: str) -> None:
-        self.aggregation_info = self.aggregation_info._replace(data_format=value)
+        self._data_format = value

     @property
     def format_column_name(self) -> str:
         """
         Name of the column which contains the data format.
         """
-        return self.aggregation_info.format_column_name
+        return self._format_column_name

     @format_column_name.setter
     def format_column_name(self, value: str) -> None:
-        self.aggregation_info = self.aggregation_info._replace(format_column_name=value)
+        self._format_column_name = value

     def __len__(self):
         return len(self.keys())
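
These three properties are now backed by plain `_path_column_name` / `_data_format` / `_format_column_name` attributes set in `__init__`, so assigning them is a simple attribute write rather than a `namedtuple._replace` round-trip on `aggregation_info`. A usage sketch (values shown are examples, not defaults):

    cat.data_format          # e.g. 'zarr', or None when a format column is used
    cat.format_column_name   # e.g. 'format', or None when the format is fixed
    cat.path_column_name = 'path'   # no aggregation_info rebuild needed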
@@ -353,9 +359,31 @@ def __getitem__(self, key: str):
             return self._entries[key]
         except KeyError:
             if key in self.keys():
-                _key = tuple(key.split(self.sep))
-                df = self._grouped.get_group(_key)
-                self._entries[key] = _make_entry(key, df, self.aggregation_info)
+                internal_key = self._keys[key]
+                if isinstance(self._grouped, pd.DataFrame):
+                    df = self._grouped.loc[internal_key]
+                    args = dict(
+                        key=key,
+                        row=df,
+                        path_column=self.path_column_name,
+                        data_format=self.data_format,
+                        format_column=self.format_column_name,
+                    )
+                    entry = _make_entry(key, 'esm_single_source', args)
+                else:
+                    df = self._grouped.get_group(internal_key)
+                    args = dict(
+                        df=df,
+                        aggregation_dict=self.aggregation_info.aggregation_dict,
+                        path_column=self.path_column_name,
+                        variable_column=self.aggregation_info.variable_column_name,
+                        data_format=self.data_format,
+                        format_column=self.format_column_name,
+                        key=key,
+                    )
+                    entry = _make_entry(key, 'esm_group', args)
+
+                self._entries[key] = entry
                 return self._entries[key]
             raise KeyError(key)
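
Lookup therefore dispatches on whether `_grouped` is a bare DataFrame (aggregations off) or a GroupBy object. Illustrative access under both modes (the keys shown are hypothetical):

    # Aggregation on: the key names a group and yields an 'esm_group' source.
    ds = cat['CESM2.historical'].to_dask()      # merged/concatenated dataset

    # Aggregation off: each key names a single row; the 'esm_single_source'
    # driver opens exactly one asset, with no merge or concat step.
    cat.groupby_attrs = []
    key = list(cat.keys())[0]
    ds = cat[key].to_dask()                     # one raw asset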

@@ -673,6 +701,7 @@ def to_dataset_dict(
         preprocess: Dict[str, Any] = None,
         storage_options: Dict[str, Any] = None,
         progressbar: bool = None,
+        aggregate: bool = None,
     ) -> Dict[str, xr.Dataset]:
         """
         Load catalog entries into a dictionary of xarray datasets.
@@ -685,14 +714,14 @@
             Keyword arguments to pass to :py:func:`~xarray.open_dataset` function
         preprocess : callable, optional
             If provided, call this function on each dataset prior to aggregation.
-        aggregate : bool, optional
-            If "False", no aggregation will be done.
         storage_options : dict, optional
             Parameters passed to the backend file-system such as Google Cloud Storage,
             Amazon Web Service S3.
         progressbar : bool
             If True, will print a progress bar to standard error (stderr)
             when loading assets into :py:class:`~xarray.Dataset`.
+        aggregate : bool, optional
+            If False, no aggregation will be done.

         Returns
         -------
@@ -725,9 +754,6 @@
                 pr (member_id, time, lat, lon) float32 dask.array<chunksize=(1, 600, 160, 320), meta=np.ndarray>
         """

-        import concurrent.futures
-        from collections import OrderedDict
-
         # Return fast
         if not self.keys():
             warn('There are no datasets to load! Returning an empty dictionary.')
@@ -739,7 +765,7 @@
             preprocess=preprocess,
             storage_options=storage_options,
         )
-        token = dask.base.tokenize(source_kwargs)
+        token = dask.base.tokenize([source_kwargs, aggregate])
         if progressbar is not None:
             self.progressbar = progressbar
@@ -749,8 +775,12 @@
         # Avoid re-loading data if nothing has changed since the last call
         if self._datasets and (token == self._to_dataset_args_token):
             return self._datasets
-
         self._to_dataset_args_token = token
+
+        if aggregate is not None and not aggregate:
+            self = deepcopy(self)
+            self.groupby_attrs = []
+
         if self.progressbar:
             print(
                 f"""\n--> The keys in the returned dictionary of datasets are constructed as follows:\n\t'{self.key_template}'"""
@@ -761,37 +791,26 @@ def _load_source(source):

         sources = [source(**source_kwargs) for _, source in self.items()]

+        progress, total = None, None
         if self.progressbar:
             total = len(sources)
             progress = progress_bar(range(total))

         with concurrent.futures.ThreadPoolExecutor(max_workers=len(sources)) as executor:
             future_tasks = [executor.submit(_load_source, source) for source in sources]

             for i, task in enumerate(concurrent.futures.as_completed(future_tasks)):
                 ds = task.result()
                 self._datasets[ds.attrs['intake_esm_dataset_key']] = ds
                 if self.progressbar:
                     progress.update(i)

             if self.progressbar:
                 progress.update(total)

         return self._datasets


-def _make_entry(key, df, aggregation_info):
-    args = dict(
-        df=df,
-        aggregation_dict=aggregation_info.aggregation_dict,
-        path_column=aggregation_info.path_column_name,
-        variable_column=aggregation_info.variable_column_name,
-        data_format=aggregation_info.data_format,
-        format_column=aggregation_info.format_column_name,
-        key=key,
-    )
+def _make_entry(key: str, driver: str, args: dict):
     entry = intake.catalog.local.LocalCatalogEntry(
-        name=key, description='', driver='esm_group', args=args, metadata={}
+        name=key, description='', driver=driver, args=args, metadata={}
     )
     return entry.get()
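
`_make_entry` is now driver-agnostic: `__getitem__` chooses `'esm_group'` or `'esm_single_source'` and passes args shaped for that driver. A hypothetical single-asset call, mirroring the branch above:

    args = dict(
        key='CESM2.historical.tas',   # hypothetical public key
        row=cat.df.iloc[0],           # the single catalog row behind that key
        path_column='path',
        data_format='zarr',
        format_column=None,
    )
    source = _make_entry('CESM2.historical.tas', 'esm_single_source', args)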

6 changes: 3 additions & 3 deletions intake_esm/merge_util.py
@@ -229,7 +229,7 @@ def apply_aggregation(v, agg_column=None, key=None, level=0):
         return apply_aggregation(v)


-def _open_asset(path, data_format, zarr_kwargs, cdf_kwargs, preprocess, varname):
+def _open_asset(path, data_format, zarr_kwargs, cdf_kwargs, preprocess, varname=None):
     protocol = None
     root = path
     if isinstance(path, fsspec.mapping.FSMap):
@@ -259,8 +259,8 @@ def _open_asset(path, data_format, zarr_kwargs, cdf_kwargs, preprocess, varname)
         except Exception as e:
             logger.error(f'Failed to open netCDF/HDF dataset with cdf_kwargs={cdf_kwargs}')
             raise e
-
-    ds.attrs['intake_esm_varname'] = varname
+    if varname:
+        ds.attrs['intake_esm_varname'] = varname

     if preprocess is None:
         return ds
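
With `varname` defaulting to `None`, single-asset opens (which have no aggregation variable) simply skip the attribute instead of writing `None` into `ds.attrs`. A sketch of the guard's effect, with a placeholder store path:

    ds = _open_asset('store.zarr', 'zarr', zarr_kwargs={}, cdf_kwargs={}, preprocess=None)
    assert 'intake_esm_varname' not in ds.attrs   # no varname was supplied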