diff --git a/intake_esm/core.py b/intake_esm/core.py
index 2ed5c638..f1d0ac51 100644
--- a/intake_esm/core.py
+++ b/intake_esm/core.py
@@ -14,10 +14,10 @@
 import requests
 
 from .merge_util import (
+    _aggregate,
     _create_asset_info_lookup,
     _restore_non_dim_coords,
-    aggregate,
-    to_nested_dict,
+    _to_nested_dict,
 )
 
 logger = logging.getLogger(__name__)
@@ -74,6 +74,7 @@ def __init__(self, esmcol_path, **kwargs):
         self.zarr_kwargs = None
         self.cdf_kwargs = None
         self.preprocess = None
+        self.aggregate = None
         self.metadata = {}
         super().__init__(**kwargs)
 
@@ -231,7 +232,9 @@ def _get_subset(self, **query):
         query_results = self.df.loc[condition]
         return query_results
 
-    def to_dataset_dict(self, zarr_kwargs={}, cdf_kwargs={'chunks': {}}, preprocess=None):
+    def to_dataset_dict(
+        self, zarr_kwargs={}, cdf_kwargs={'chunks': {}}, preprocess=None, aggregate=True
+    ):
         """Load catalog entries into a dictionary of xarray datasets.
 
         Parameters
@@ -242,7 +245,8 @@ def to_dataset_dict(self, zarr_kwargs={}, cdf_kwargs={'chunks': {}}, preprocess=
             Keyword arguments to pass to `xarray.open_dataset()` function
         preprocess : (callable, optional)
            If provided, call this function on each dataset prior to aggregation.
-
+        aggregate : (boolean, optional)
+            If "False", no aggregation will be done.
         Returns
         -------
         dsets : dict
@@ -329,6 +333,7 @@ def to_dataset_dict(self, zarr_kwargs={}, cdf_kwargs={'chunks': {}}, preprocess=
 
         self.zarr_kwargs = zarr_kwargs
         self.cdf_kwargs = cdf_kwargs
+        self.aggregate = aggregate
         if preprocess is not None and not callable(preprocess):
             raise ValueError('preprocess argument must be callable')
 
@@ -358,25 +363,50 @@ def _open_dataset(self):
             path: _path_to_mapper(path) for path in self.df[path_column_name]
         }  # replace path column with mapper (dependent on filesystem type)
 
-        groupby_attrs = self._col_data['aggregation_control'].get('groupby_attrs', [])
-        aggregations = self._col_data['aggregation_control'].get('aggregations', [])
-        variable_column_name = self._col_data['aggregation_control']['variable_column_name']
-
+        groupby_attrs = []
+        variable_column_name = None
+        aggregations = []
         aggregation_dict = {}
-        for agg in aggregations:
-            key = agg['attribute_name']
-            rest = agg.copy()
-            del rest['attribute_name']
-            aggregation_dict[key] = rest
-
-        agg_columns = list(aggregation_dict.keys())
-
-        if groupby_attrs:
-            groups = self.df.groupby(groupby_attrs)
+        agg_columns = []
+        if self.aggregate:
+            if 'aggregation_control' in self._col_data:
+                variable_column_name = self._col_data['aggregation_control']['variable_column_name']
+                groupby_attrs = self._col_data['aggregation_control'].get('groupby_attrs', [])
+                aggregations = self._col_data['aggregation_control'].get('aggregations', [])
+
+            for agg in aggregations:
+                key = agg['attribute_name']
+                rest = agg.copy()
+                del rest['attribute_name']
+                aggregation_dict[key] = rest
+
+            agg_columns = list(aggregation_dict.keys())
+
+        if not groupby_attrs:
+            groupby_attrs = self.df.columns.tolist()
+
+        # filter groupby_attrs to ensure no columns with all nans
+        def _allnan_or_nonan(column):
+            if self.df[column].isnull().all():
+                return False
+            elif self.df[column].isnull().any():
+                raise ValueError(
+                    f'The data in the {column} column should either be all NaN or there should be no NaNs'
+                )
+            else:
+                return True
+
+        groupby_attrs = list(filter(_allnan_or_nonan, groupby_attrs))
+
+        groups = self.df.groupby(groupby_attrs)
+
+        if agg_columns:
+            keys = '.'.join(groupby_attrs)
         else:
-            groups = self.df.groupby(self.df.columns.tolist())
+            keys = path_column_name
+
         print(
-            f"""--> The keys in the returned dictionary of datasets are constructed as follows:\n\t'{".".join(groupby_attrs)}'"""
+            f"""--> The keys in the returned dictionary of datasets are constructed as follows:\n\t'{keys}'"""
         )
 
         print(f'\n--> There will be {len(groups)} group(s)')
@@ -399,9 +429,8 @@ def _open_dataset(self):
         ]
 
         dsets = dask.compute(*dsets)
-        del mapper_dict
 
-        self._ds = {dset[0]: dset[1] for dset in dsets}
+        self._ds = {group_id: ds for (group_id, ds) in dsets}
 
 
 def _unique(df, columns):
@@ -449,8 +478,13 @@ def _load_group_dataset(
     # the number of aggregation columns determines the level of recursion
     n_agg = len(agg_columns)
 
-    mi = df.set_index(agg_columns)
-    nd = to_nested_dict(mi[path_column_name])
+    if agg_columns:
+        mi = df.set_index(agg_columns)
+        nd = _to_nested_dict(mi[path_column_name])
+        group_id = '.'.join(key)
+    else:
+        nd = df.iloc[0][path_column_name]
+        group_id = nd
 
     if use_format_column:
         format_column_name = col_data['assets']['format_column_name']
@@ -458,12 +492,11 @@ def _load_group_dataset(
             df, path_column_name, variable_column_name, format_column_name=format_column_name
         )
     else:
-
         lookup = _create_asset_info_lookup(
            df, path_column_name, variable_column_name, data_format=col_data['assets']['format']
         )
 
-    ds = aggregate(
+    ds = _aggregate(
         aggregation_dict,
         agg_columns,
         n_agg,
@@ -474,8 +507,11 @@ def _load_group_dataset(
         cdf_kwargs,
         preprocess,
     )
-    group_id = '.'.join(key)
-    return group_id, _restore_non_dim_coords(ds)
+
+    if variable_column_name is None:
+        return group_id, ds
+    else:
+        return group_id, _restore_non_dim_coords(ds)
 
 
 def _is_valid_url(url):
diff --git a/intake_esm/merge_util.py b/intake_esm/merge_util.py
index b7e77e03..31667933 100644
--- a/intake_esm/merge_util.py
+++ b/intake_esm/merge_util.py
@@ -14,33 +14,35 @@ def union(dsets, options={}):
     return xr.merge(dsets, **options)
 
 
-def to_nested_dict(df):
+def _to_nested_dict(df):
     """Converts a multiindex series to nested dict"""
     if hasattr(df.index, 'levels') and len(df.index.levels) > 1:
         ret = {}
         for k, v in df.groupby(level=0):
-            ret[k] = to_nested_dict(v.droplevel(0))
+            ret[k] = _to_nested_dict(v.droplevel(0))
         return ret
     else:
         return df.to_dict()
 
 
 def _create_asset_info_lookup(
-    df, path_column_name, variable_column_name, data_format=None, format_column_name=None
+    df, path_column_name, variable_column_name=None, data_format=None, format_column_name=None
 ):
     if data_format:
-        return dict(
-            zip(df[path_column_name], tuple(zip(df[variable_column_name], [data_format] * len(df))))
-        )
-
+        data_format_list = [data_format] * len(df)
     elif format_column_name is not None:
-        return dict(
-            zip(df[path_column_name], tuple(zip(df[variable_column_name], df[format_column_name])))
-        )
+        data_format_list = df[format_column_name]
+
+    if variable_column_name is None:
+        varname_list = [None] * len(df)
+    else:
+        varname_list = df[variable_column_name]
+    return dict(zip(df[path_column_name], tuple(zip(varname_list, data_format_list))))
 
-def aggregate(
+
+def _aggregate(
     aggregation_dict,
     agg_columns,
     n_agg,
@@ -62,9 +64,9 @@ def apply_aggregation(v, agg_column=None, key=None, level=0):
             # return open_dataset(v)
             varname = lookup[v][0]
             data_format = lookup[v][1]
-            return open_dataset(
+            return _open_asset(
                 mapper_dict[v],
-                varname=[varname],
+                varname=varname,
                 data_format=data_format,
                 zarr_kwargs=zarr_kwargs,
                 cdf_kwargs=cdf_kwargs,
@@ -124,7 +126,7 @@ def apply_aggregation(v, agg_column=None, key=None, level=0):
     return apply_aggregation(v)
 
 
-def open_dataset(path, varname, data_format, zarr_kwargs, cdf_kwargs, preprocess):
+def _open_asset(path, varname, data_format, zarr_kwargs, cdf_kwargs, preprocess):
     if data_format == 'zarr':
         ds = xr.open_zarr(path, **zarr_kwargs)
 
@@ -147,9 +149,13 @@ def _restore_non_dim_coords(ds):
 
 def _set_coords(ds, varname):
     """Set all variables except varname to be coords."""
+    if varname is None:
+        return ds
+
     if isinstance(varname, str):
         varname = [varname]
     coord_vars = set(ds.data_vars) - set(varname)
+
     return ds.set_coords(coord_vars)
diff --git a/tests/cmip6/test_cmip6.py b/tests/cmip6/test_cmip6.py
index bb0af218..d77f180d 100644
--- a/tests/cmip6/test_cmip6.py
+++ b/tests/cmip6/test_cmip6.py
@@ -3,10 +3,12 @@
 import intake
 import pandas as pd
 import pytest
+import xarray as xr
 
 here = os.path.abspath(os.path.dirname(__file__))
 zarr_col = os.path.join(here, 'pangeo-cmip6-zarr.json')
 cdf_col = os.path.join(here, 'cmip6-netcdf.json')
+
 zarr_query = dict(
     variable_id=['pr'],
     experiment_id='ssp370',
@@ -35,12 +37,27 @@ def test_to_dataset_dict(esmcol_path, query, kwargs):
     cat = col.search(**query)
     if kwargs:
         _, ds = cat.to_dataset_dict(cdf_kwargs=kwargs).popitem()
-    _, ds = cat.to_dataset_dict().popitem()
+    else:
+        _, ds = cat.to_dataset_dict().popitem()
     assert 'member_id' in ds.dims
     assert len(ds.__dask_keys__()) > 0
     assert ds.time.encoding
 
 
+@pytest.mark.parametrize('esmcol_path, query', [(cdf_col, cdf_query)])
+def test_to_dataset_dict_aggfalse(esmcol_path, query):
+    col = intake.open_esm_datastore(esmcol_path)
+    cat = col.search(**query)
+    nds = len(cat.df)
+
+    dsets = cat.to_dataset_dict(aggregate=False)
+    assert len(dsets.keys()) == nds
+    path, ds = dsets.popitem()
+
+    xr_ds = xr.open_dataset(path)
+    xr.testing.assert_identical(xr_ds, ds)
+
+
 @pytest.mark.parametrize(
     'esmcol_path, query, kwargs',
     [(zarr_col, zarr_query, {}), (cdf_col, cdf_query, {'chunks': {'time': 1}})],
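Below is a minimal usage sketch of the `aggregate` switch this patch introduces, mirroring the new test_to_dataset_dict_aggfalse test above. The catalog file name and query values are illustrative placeholders taken from the test fixtures, not part of the patch itself.

import intake

# Open an ESM collection and subset it (catalog path and query are examples only).
col = intake.open_esm_datastore('pangeo-cmip6-zarr.json')
cat = col.search(variable_id=['pr'], experiment_id='ssp370')

# Default behaviour (aggregate=True): assets are merged/concatenated into
# aggregated datasets, keyed by the catalog's groupby attributes.
dsets = cat.to_dataset_dict()

# With aggregate=False, no aggregation is done: the returned dictionary holds
# one xarray dataset per asset, keyed by the asset path.
raw = cat.to_dataset_dict(aggregate=False)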