Merge pull request #164 from matt-long/no-aggregation
Add option to disable aggregation
andersy005 authored Oct 18, 2019
2 parents b08f446 + 4b78e2a commit cf8eab2
Showing 3 changed files with 102 additions and 43 deletions.
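In practice the new flag is just an extra keyword to `to_dataset_dict`. A minimal usage sketch, assuming a local copy of the `cmip6-netcdf.json` collection used by the tests below (the search values are illustrative, not prescriptive):

    import intake

    col = intake.open_esm_datastore('cmip6-netcdf.json')
    cat = col.search(variable_id=['pr'], experiment_id='ssp370')

    # Default behaviour: assets are merged/concatenated according to the
    # collection's aggregation_control section.
    dsets = cat.to_dataset_dict(cdf_kwargs={'chunks': {}})

    # New in this PR: skip aggregation entirely and get back one dataset
    # per asset, keyed by the asset's path.
    raw = cat.to_dataset_dict(cdf_kwargs={'chunks': {}}, aggregate=False)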
92 changes: 64 additions & 28 deletions intake_esm/core.py
@@ -14,10 +14,10 @@
 import requests
 
 from .merge_util import (
+    _aggregate,
     _create_asset_info_lookup,
     _restore_non_dim_coords,
-    aggregate,
-    to_nested_dict,
+    _to_nested_dict,
 )
 
 logger = logging.getLogger(__name__)
@@ -74,6 +74,7 @@ def __init__(self, esmcol_path, **kwargs):
         self.zarr_kwargs = None
         self.cdf_kwargs = None
         self.preprocess = None
+        self.aggregate = None
         self.metadata = {}
         super().__init__(**kwargs)
 
@@ -231,7 +232,9 @@ def _get_subset(self, **query):
         query_results = self.df.loc[condition]
         return query_results
 
-    def to_dataset_dict(self, zarr_kwargs={}, cdf_kwargs={'chunks': {}}, preprocess=None):
+    def to_dataset_dict(
+        self, zarr_kwargs={}, cdf_kwargs={'chunks': {}}, preprocess=None, aggregate=True
+    ):
         """Load catalog entries into a dictionary of xarray datasets.
 
         Parameters
@@ -242,7 +245,8 @@ def to_dataset_dict(self, zarr_kwargs={}, cdf_kwargs={'chunks': {}}, preprocess=None):
             Keyword arguments to pass to `xarray.open_dataset()` function
         preprocess : (callable, optional)
             If provided, call this function on each dataset prior to aggregation.
-
+        aggregate : (boolean, optional)
+            If "False", no aggregation will be done.
         Returns
         -------
         dsets : dict
@@ -329,6 +333,7 @@ def to_dataset_dict(self, zarr_kwargs={}, cdf_kwargs={'chunks': {}}, preprocess=None):
 
         self.zarr_kwargs = zarr_kwargs
         self.cdf_kwargs = cdf_kwargs
+        self.aggregate = aggregate
 
         if preprocess is not None and not callable(preprocess):
             raise ValueError('preprocess argument must be callable')
@@ -358,25 +363,50 @@ def _open_dataset(self):
             path: _path_to_mapper(path) for path in self.df[path_column_name]
         }  # replace path column with mapper (dependent on filesystem type)
 
-        groupby_attrs = self._col_data['aggregation_control'].get('groupby_attrs', [])
-        aggregations = self._col_data['aggregation_control'].get('aggregations', [])
-        variable_column_name = self._col_data['aggregation_control']['variable_column_name']
+        groupby_attrs = []
+        variable_column_name = None
+        aggregations = []
         aggregation_dict = {}
-        for agg in aggregations:
-            key = agg['attribute_name']
-            rest = agg.copy()
-            del rest['attribute_name']
-            aggregation_dict[key] = rest
-
-        agg_columns = list(aggregation_dict.keys())
-
-        if groupby_attrs:
-            groups = self.df.groupby(groupby_attrs)
+        agg_columns = []
+        if self.aggregate:
+            if 'aggregation_control' in self._col_data:
+                variable_column_name = self._col_data['aggregation_control']['variable_column_name']
+                groupby_attrs = self._col_data['aggregation_control'].get('groupby_attrs', [])
+                aggregations = self._col_data['aggregation_control'].get('aggregations', [])
+
+                for agg in aggregations:
+                    key = agg['attribute_name']
+                    rest = agg.copy()
+                    del rest['attribute_name']
+                    aggregation_dict[key] = rest
+
+                agg_columns = list(aggregation_dict.keys())
+
+        if not groupby_attrs:
+            groupby_attrs = self.df.columns.tolist()
+
+        # filter groupby_attrs to ensure no columns with all nans
+        def _allnan_or_nonan(column):
+            if self.df[column].isnull().all():
+                return False
+            elif self.df[column].isnull().any():
+                raise ValueError(
+                    f'The data in the {column} column should either be all NaN or there should be no NaNs'
+                )
+            else:
+                return True
+
+        groupby_attrs = list(filter(_allnan_or_nonan, groupby_attrs))
+
+        groups = self.df.groupby(groupby_attrs)
+
+        if agg_columns:
+            keys = '.'.join(groupby_attrs)
         else:
-            groups = self.df.groupby(self.df.columns.tolist())
+            keys = path_column_name
 
         print(
-            f"""--> The keys in the returned dictionary of datasets are constructed as follows:\n\t'{".".join(groupby_attrs)}'"""
+            f"""--> The keys in the returned dictionary of datasets are constructed as follows:\n\t'{keys}'"""
         )
         print(f'\n--> There will be {len(groups)} group(s)')
 
@@ -399,9 +429,8 @@ def _open_dataset(self):
             ]
 
         dsets = dask.compute(*dsets)
-        del mapper_dict
 
-        self._ds = {dset[0]: dset[1] for dset in dsets}
+        self._ds = {group_id: ds for (group_id, ds) in dsets}
 
 
 def _unique(df, columns):
@@ -449,21 +478,25 @@ def _load_group_dataset(
     # the number of aggregation columns determines the level of recursion
     n_agg = len(agg_columns)
 
-    mi = df.set_index(agg_columns)
-    nd = to_nested_dict(mi[path_column_name])
+    if agg_columns:
+        mi = df.set_index(agg_columns)
+        nd = _to_nested_dict(mi[path_column_name])
+        group_id = '.'.join(key)
+    else:
+        nd = df.iloc[0][path_column_name]
+        group_id = nd
 
     if use_format_column:
         format_column_name = col_data['assets']['format_column_name']
         lookup = _create_asset_info_lookup(
             df, path_column_name, variable_column_name, format_column_name=format_column_name
         )
     else:
-
         lookup = _create_asset_info_lookup(
             df, path_column_name, variable_column_name, data_format=col_data['assets']['format']
         )
 
-    ds = aggregate(
+    ds = _aggregate(
         aggregation_dict,
         agg_columns,
         n_agg,
@@ -474,8 +507,11 @@ def _load_group_dataset(
         cdf_kwargs,
         preprocess,
     )
-    group_id = '.'.join(key)
-    return group_id, _restore_non_dim_coords(ds)
+
+    if variable_column_name is None:
+        return group_id, ds
+    else:
+        return group_id, _restore_non_dim_coords(ds)
 
 
 def _is_valid_url(url):
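The grouping change above is the heart of the feature: with `aggregate=False` the collection's `aggregation_control` section is never consulted, `groupby_attrs` falls back to every column of the catalog, and each group therefore collapses to a single asset. A toy pandas sketch of that mechanism (the dataframe and its columns are hypothetical, not the real catalog schema):

    import pandas as pd

    df = pd.DataFrame(
        {
            'source_id': ['CNRM-CM6-1', 'CNRM-CM6-1'],
            'member_id': ['r1i1p1f2', 'r2i1p1f2'],
            'path': ['/data/pr_r1.nc', '/data/pr_r2.nc'],
        }
    )

    # aggregate=True: group by the configured groupby_attrs -> one dataset per group
    print(df.groupby(['source_id']).ngroups)  # 1

    # aggregate=False: groupby_attrs defaults to df.columns.tolist()
    # -> one group, and hence one dataset, per asset
    print(df.groupby(df.columns.tolist()).ngroups)  # 2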
34 changes: 20 additions & 14 deletions intake_esm/merge_util.py
@@ -14,33 +14,35 @@ def union(dsets, options={}):
     return xr.merge(dsets, **options)
 
 
-def to_nested_dict(df):
+def _to_nested_dict(df):
     """Converts a multiindex series to nested dict"""
     if hasattr(df.index, 'levels') and len(df.index.levels) > 1:
         ret = {}
         for k, v in df.groupby(level=0):
-            ret[k] = to_nested_dict(v.droplevel(0))
+            ret[k] = _to_nested_dict(v.droplevel(0))
         return ret
     else:
         return df.to_dict()
 
 
 def _create_asset_info_lookup(
-    df, path_column_name, variable_column_name, data_format=None, format_column_name=None
+    df, path_column_name, variable_column_name=None, data_format=None, format_column_name=None
 ):
 
     if data_format:
-        return dict(
-            zip(df[path_column_name], tuple(zip(df[variable_column_name], [data_format] * len(df))))
-        )
-
+        data_format_list = [data_format] * len(df)
     elif format_column_name is not None:
-        return dict(
-            zip(df[path_column_name], tuple(zip(df[variable_column_name], df[format_column_name])))
-        )
+        data_format_list = df[format_column_name]
+
+    if variable_column_name is None:
+        varname_list = [None] * len(df)
+    else:
+        varname_list = df[variable_column_name]
+
+    return dict(zip(df[path_column_name], tuple(zip(varname_list, data_format_list))))
 
 
-def aggregate(
+def _aggregate(
     aggregation_dict,
     agg_columns,
     n_agg,
@@ -62,9 +64,9 @@ def apply_aggregation(v, agg_column=None, key=None, level=0):
             # return open_dataset(v)
             varname = lookup[v][0]
             data_format = lookup[v][1]
-            return open_dataset(
+            return _open_asset(
                 mapper_dict[v],
-                varname=[varname],
+                varname=varname,
                 data_format=data_format,
                 zarr_kwargs=zarr_kwargs,
                 cdf_kwargs=cdf_kwargs,
@@ -124,7 +126,7 @@ def apply_aggregation(v, agg_column=None, key=None, level=0):
     return apply_aggregation(v)
 
 
-def open_dataset(path, varname, data_format, zarr_kwargs, cdf_kwargs, preprocess):
+def _open_asset(path, varname, data_format, zarr_kwargs, cdf_kwargs, preprocess):
 
     if data_format == 'zarr':
         ds = xr.open_zarr(path, **zarr_kwargs)
@@ -147,9 +149,13 @@ def _restore_non_dim_coords(ds):
 
 def _set_coords(ds, varname):
     """Set all variables except varname to be coords."""
+    if varname is None:
+        return ds
+
     if isinstance(varname, str):
         varname = [varname]
     coord_vars = set(ds.data_vars) - set(varname)
+
     return ds.set_coords(coord_vars)
 
 
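The reworked `_create_asset_info_lookup` above now always returns a mapping from asset path to a `(varname, data_format)` pair, leaving `varname` as `None` when no variable column is passed, which is what the non-aggregating path in core.py relies on. A condensed, self-contained sketch of the new behaviour (paths and column names are made up):

    import pandas as pd

    df = pd.DataFrame(
        {'path': ['/data/pr_r1.nc', '/data/tas_r1.nc'], 'variable_id': ['pr', 'tas']}
    )

    def _create_asset_info_lookup(
        df, path_column_name, variable_column_name=None, data_format=None, format_column_name=None
    ):
        # condensed copy of the helper shown above
        if data_format:
            data_format_list = [data_format] * len(df)
        elif format_column_name is not None:
            data_format_list = df[format_column_name]
        if variable_column_name is None:
            varname_list = [None] * len(df)
        else:
            varname_list = df[variable_column_name]
        return dict(zip(df[path_column_name], tuple(zip(varname_list, data_format_list))))

    # With a variable column: {'/data/pr_r1.nc': ('pr', 'netcdf'), '/data/tas_r1.nc': ('tas', 'netcdf')}
    print(_create_asset_info_lookup(df, 'path', 'variable_id', data_format='netcdf'))

    # Without one (the aggregate=False case): varname is None for every asset
    print(_create_asset_info_lookup(df, 'path', data_format='netcdf'))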
19 changes: 18 additions & 1 deletion tests/cmip6/test_cmip6.py
@@ -3,10 +3,12 @@
 import intake
 import pandas as pd
 import pytest
+import xarray as xr
 
 here = os.path.abspath(os.path.dirname(__file__))
 zarr_col = os.path.join(here, 'pangeo-cmip6-zarr.json')
 cdf_col = os.path.join(here, 'cmip6-netcdf.json')
+
 zarr_query = dict(
     variable_id=['pr'],
     experiment_id='ssp370',
@@ -35,12 +37,27 @@ def test_to_dataset_dict(esmcol_path, query, kwargs):
     cat = col.search(**query)
     if kwargs:
         _, ds = cat.to_dataset_dict(cdf_kwargs=kwargs).popitem()
-    _, ds = cat.to_dataset_dict().popitem()
+    else:
+        _, ds = cat.to_dataset_dict().popitem()
     assert 'member_id' in ds.dims
     assert len(ds.__dask_keys__()) > 0
     assert ds.time.encoding
 
 
+@pytest.mark.parametrize('esmcol_path, query', [(cdf_col, cdf_query)])
+def test_to_dataset_dict_aggfalse(esmcol_path, query):
+    col = intake.open_esm_datastore(esmcol_path)
+    cat = col.search(**query)
+    nds = len(cat.df)
+
+    dsets = cat.to_dataset_dict(aggregate=False)
+    assert len(dsets.keys()) == nds
+    path, ds = dsets.popitem()
+
+    xr_ds = xr.open_dataset(path)
+    xr.testing.assert_identical(xr_ds, ds)
+
+
 @pytest.mark.parametrize(
     'esmcol_path, query, kwargs',
     [(zarr_col, zarr_query, {}), (cdf_col, cdf_query, {'chunks': {'time': 1}})],
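Assuming a standard pytest setup for the repository, the new case can be run on its own with `pytest tests/cmip6/test_cmip6.py::test_to_dataset_dict_aggfalse`; it checks that every asset in the filtered catalog comes back as its own dataset, identical to opening the underlying file directly with xarray.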
