diff --git a/intake_esm/core.py b/intake_esm/core.py
index acae1492..13fa7196 100644
--- a/intake_esm/core.py
+++ b/intake_esm/core.py
@@ -1,6 +1,8 @@
+import concurrent.futures
 import json
 import logging
-from collections import namedtuple
+from collections import OrderedDict, namedtuple
+from copy import deepcopy
 from typing import Any, Dict, List, Tuple, Union
 from warnings import warn

@@ -107,14 +109,33 @@ def __init__(
         self._log_level = log_level
         self._datasets = {}
         self.sep = sep
+        self._data_format, self._format_column_name = None, None
+        self._path_column_name = self.esmcol_data['assets']['column_name']
+        if 'format' in self.esmcol_data['assets']:
+            self._data_format = self.esmcol_data['assets']['format']
+        else:
+            self._format_column_name = self.esmcol_data['assets']['format_column_name']
         self.aggregation_info = self._get_aggregation_info()
         self._entries = {}
         self._set_groups_and_keys()
         super(esm_datastore, self).__init__(**kwargs)

     def _set_groups_and_keys(self):
-        self._grouped = self.df.groupby(self.aggregation_info.groupby_attrs)
-        self._keys = list(self._grouped.groups.keys())
+        if self.aggregation_info.groupby_attrs and set(self.df.columns) != set(
+            self.aggregation_info.groupby_attrs
+        ):
+            self._grouped = self.df.groupby(self.aggregation_info.groupby_attrs)
+            internal_keys = self._grouped.groups.keys()
+            public_keys = [self.sep.join(str(v) for v in x) for x in internal_keys]
+
+        else:
+            self._grouped = self.df
+            internal_keys = list(self._grouped.index)
+            public_keys = [
+                self.sep.join(str(v) for v in row.values) for _, row in self._grouped.iterrows()
+            ]
+
+        self._keys = dict(zip(public_keys, internal_keys))

     def _allnan_or_nonan(self, column: str) -> bool:
         """
@@ -153,46 +174,27 @@ def _get_aggregation_info(self):
                 'aggregations',
                 'agg_columns',
                 'aggregation_dict',
-                'path_column_name',
-                'data_format',
-                'format_column_name',
             ],
         )

         groupby_attrs = []
-        data_format = None
-        format_column_name = None
         variable_column_name = None
         aggregations = []
         aggregation_dict = {}
         agg_columns = []
-        path_column_name = self.esmcol_data['assets']['column_name']
-
-        if 'format' in self.esmcol_data['assets']:
-            data_format = self.esmcol_data['assets']['format']
-        else:
-            format_column_name = self.esmcol_data['assets']['format_column_name']

         if 'aggregation_control' in self.esmcol_data:
             variable_column_name = self.esmcol_data['aggregation_control']['variable_column_name']
             groupby_attrs = self.esmcol_data['aggregation_control'].get('groupby_attrs', [])
             aggregations = self.esmcol_data['aggregation_control'].get('aggregations', [])
             aggregations, aggregation_dict, agg_columns = _construct_agg_info(aggregations)
+            groupby_attrs = list(filter(self._allnan_or_nonan, groupby_attrs))

-        if not groupby_attrs:
-            groupby_attrs = self.df.columns.tolist()
-
-        groupby_attrs = list(filter(self._allnan_or_nonan, groupby_attrs))
+        elif not groupby_attrs or 'aggregation_control' not in self.esmcol_data:
+            groupby_attrs = []

         aggregation_info = AggregationInfo(
-            groupby_attrs,
-            variable_column_name,
-            aggregations,
-            agg_columns,
-            aggregation_dict,
-            path_column_name,
-            data_format,
-            format_column_name,
+            groupby_attrs, variable_column_name, aggregations, agg_columns, aggregation_dict,
         )

         return aggregation_info
@@ -205,8 +207,7 @@ def keys(self) -> List:
         list
             keys for the catalog entries
         """
-        keys = [self.sep.join(x) for x in self._keys]
-        return keys
+        return self._keys.keys()

     @property
     def key_template(self) -> str:
@@ -218,7 +219,11 @@ def key_template(self) -> str:
         str
             string template used to create catalog entry keys
         """
-        return self.sep.join(self.aggregation_info.groupby_attrs)
+        if self.aggregation_info.groupby_attrs:
+            template = self.sep.join(self.aggregation_info.groupby_attrs)
+        else:
+            template = self.sep.join(self.df.columns)
+        return template

     @property
     def df(self) -> pd.DataFrame:
@@ -249,6 +254,7 @@ def groupby_attrs(self, value: list) -> None:
         groupby_attrs = list(filter(self._allnan_or_nonan, value))
         self.aggregation_info = self.aggregation_info._replace(groupby_attrs=groupby_attrs)
         self._set_groups_and_keys()
+        self._entries = {}

     @property
     def variable_column_name(self) -> str:
@@ -282,11 +288,11 @@ def path_column_name(self) -> str:
         """
         The name of the column containing the path to the asset.
         """
-        return self.aggregation_info.path_column_name
+        return self._path_column_name

     @path_column_name.setter
     def path_column_name(self, value: str) -> None:
-        self.aggregation_info = self.aggregation_info._replace(path_column_name=value)
+        self._path_column_name = value

     @property
     def data_format(self) -> str:
@@ -294,22 +300,22 @@ def data_format(self) -> str:
         The data format. Valid values are netcdf and zarr.
         If specified, it means that all data assets in the catalog use the same data format.
         """
-        return self.aggregation_info.data_format
+        return self._data_format

     @data_format.setter
     def data_format(self, value: str) -> None:
-        self.aggregation_info = self.aggregation_info._replace(data_format=value)
+        self._data_format = value

     @property
     def format_column_name(self) -> str:
         """
         Name of the column which contains the data format.
         """
-        return self.aggregation_info.format_column_name
+        return self._format_column_name

     @format_column_name.setter
     def format_column_name(self, value: str) -> None:
-        self.aggregation_info = self.aggregation_info._replace(format_column_name=value)
+        self._format_column_name = value

     def __len__(self):
         return len(self.keys())
@@ -353,9 +359,31 @@ def __getitem__(self, key: str):
             return self._entries[key]
         except KeyError:
             if key in self.keys():
-                _key = tuple(key.split(self.sep))
-                df = self._grouped.get_group(_key)
-                self._entries[key] = _make_entry(key, df, self.aggregation_info)
+                internal_key = self._keys[key]
+                if isinstance(self._grouped, pd.DataFrame):
+                    df = self._grouped.loc[internal_key]
+                    args = dict(
+                        key=key,
+                        row=df,
+                        path_column=self.path_column_name,
+                        data_format=self.data_format,
+                        format_column=self.format_column_name,
+                    )
+                    entry = _make_entry(key, 'esm_single_source', args)
+                else:
+                    df = self._grouped.get_group(internal_key)
+                    args = dict(
+                        df=df,
+                        aggregation_dict=self.aggregation_info.aggregation_dict,
+                        path_column=self.path_column_name,
+                        variable_column=self.aggregation_info.variable_column_name,
+                        data_format=self.data_format,
+                        format_column=self.format_column_name,
+                        key=key,
+                    )
+                    entry = _make_entry(key, 'esm_group', args)
+
+                self._entries[key] = entry
                 return self._entries[key]
             raise KeyError(key)
@@ -673,6 +701,7 @@ def to_dataset_dict(
         preprocess: Dict[str, Any] = None,
         storage_options: Dict[str, Any] = None,
         progressbar: bool = None,
+        aggregate: bool = None,
     ) -> Dict[str, xr.Dataset]:
         """
         Load catalog entries into a dictionary of xarray datasets.
@@ -685,14 +714,14 @@
             Keyword arguments to pass to :py:func:`~xarray.open_dataset` function
         preprocess : callable, optional
             If provided, call this function on each dataset prior to aggregation.
-        aggregate : bool, optional
-            If "False", no aggregation will be done.
         storage_options : dict, optional
             Parameters passed to the backend file-system such as Google Cloud Storage,
             Amazon Web Service S3.
         progressbar : bool
             If True, will print a progress bar to standard error (stderr) when loading assets
             into :py:class:`~xarray.Dataset`.
+        aggregate : bool, optional
+            If False, no aggregation will be done.

         Returns
         -------
@@ -725,9 +754,6 @@
                 pr       (member_id, time, lat, lon) float32 dask.array
         """
-        import concurrent.futures
-        from collections import OrderedDict
-
         # Return fast
         if not self.keys():
             warn('There are no datasets to load! Returning an empty dictionary.')
             return {}
@@ -739,7 +765,7 @@
             preprocess=preprocess,
             storage_options=storage_options,
         )
-        token = dask.base.tokenize(source_kwargs)
+        token = dask.base.tokenize([source_kwargs, aggregate])

         if progressbar is not None:
             self.progressbar = progressbar
@@ -749,8 +775,12 @@
         # Avoid re-loading data if nothing has changed since the last call
         if self._datasets and (token == self._to_dataset_args_token):
             return self._datasets
-
         self._to_dataset_args_token = token
+
+        if aggregate is not None and not aggregate:
+            self = deepcopy(self)
+            self.groupby_attrs = []
+
         if self.progressbar:
             print(
                 f"""\n--> The keys in the returned dictionary of datasets are constructed as follows:\n\t'{self.key_template}'"""
             )
@@ -761,37 +791,26 @@ def _load_source(source):

         sources = [source(**source_kwargs) for _, source in self.items()]

+        progress, total = None, None
         if self.progressbar:
             total = len(sources)
             progress = progress_bar(range(total))

         with concurrent.futures.ThreadPoolExecutor(max_workers=len(sources)) as executor:
             future_tasks = [executor.submit(_load_source, source) for source in sources]
-
             for i, task in enumerate(concurrent.futures.as_completed(future_tasks)):
                 ds = task.result()
                 self._datasets[ds.attrs['intake_esm_dataset_key']] = ds
                 if self.progressbar:
                     progress.update(i)
-
         if self.progressbar:
             progress.update(total)
-
         return self._datasets


-def _make_entry(key, df, aggregation_info):
-    args = dict(
-        df=df,
-        aggregation_dict=aggregation_info.aggregation_dict,
-        path_column=aggregation_info.path_column_name,
-        variable_column=aggregation_info.variable_column_name,
-        data_format=aggregation_info.data_format,
-        format_column=aggregation_info.format_column_name,
-        key=key,
-    )
+def _make_entry(key: str, driver: str, args: dict):
     entry = intake.catalog.local.LocalCatalogEntry(
-        name=key, description='', driver='esm_group', args=args, metadata={}
+        name=key, description='', driver=driver, args=args, metadata={}
     )
     return entry.get()
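Taken together, the core.py changes surface as a new `aggregate` keyword on `to_dataset_dict`: when it is explicitly `False`, the store is deep-copied, its `groupby_attrs` are cleared, and each row is loaded through the new `esm_single_source` driver instead of being merged by `esm_group`. A minimal usage sketch follows; the catalog file and search terms are hypothetical stand-ins, not part of this patch:

```python
import intake

# Hypothetical esm-collection-spec catalog; any intake-esm catalog works here.
cat = intake.open_esm_datastore('pangeo-cmip6.json')
subset = cat.search(variable_id='tas', experiment_id='historical')

# Default path: rows are grouped by groupby_attrs and merged ('esm_group').
dsets = subset.to_dataset_dict(zarr_kwargs={'consolidated': True})

# New path: no aggregation -- one xarray.Dataset per asset/row, keyed by the
# sep-joined column values of that row.
raw = subset.to_dataset_dict(zarr_kwargs={'consolidated': True}, aggregate=False)
```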
diff --git a/intake_esm/merge_util.py b/intake_esm/merge_util.py
index 9540faf8..7477fcd5 100644
--- a/intake_esm/merge_util.py
+++ b/intake_esm/merge_util.py
@@ -229,7 +229,7 @@ def apply_aggregation(v, agg_column=None, key=None, level=0):
     return apply_aggregation(v)


-def _open_asset(path, data_format, zarr_kwargs, cdf_kwargs, preprocess, varname):
+def _open_asset(path, data_format, zarr_kwargs, cdf_kwargs, preprocess, varname=None):
     protocol = None
     root = path
     if isinstance(path, fsspec.mapping.FSMap):
@@ -259,8 +259,8 @@ def _open_asset(path, data_format, zarr_kwargs, cdf_kwargs, preprocess, varname):
         except Exception as e:
             logger.error(f'Failed to open netCDF/HDF dataset with cdf_kwargs={cdf_kwargs}')
             raise e
-
-    ds.attrs['intake_esm_varname'] = varname
+    if varname:
+        ds.attrs['intake_esm_varname'] = varname

     if preprocess is None:
         return ds
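`_open_asset` now takes `varname` as an optional argument and only records `intake_esm_varname` when one is supplied, matching the single-source path introduced below, which has no `variable_column` to pass. A standalone sketch of the guard (`tag_varname` is an illustrative name, not part of the patch); a literal `None` in `.attrs` would otherwise fail later when the dataset is written to netCDF:

```python
import xarray as xr

def tag_varname(ds: xr.Dataset, varname=None) -> xr.Dataset:
    # Same guard as in _open_asset: skip None/empty so that .attrs never
    # carries a None value, which netCDF cannot serialize.
    if varname:
        ds.attrs['intake_esm_varname'] = varname
    return ds

ds = tag_varname(xr.Dataset({'tas': ('time', [280.0, 281.5])}))
assert 'intake_esm_varname' not in ds.attrs
```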
diff --git a/intake_esm/source.py b/intake_esm/source.py
index 7b3e3e18..d20b787b 100644
--- a/intake_esm/source.py
+++ b/intake_esm/source.py
@@ -1,12 +1,88 @@
 import copy

+import pandas as pd
 from intake.source.base import DataSource, Schema

-from .merge_util import _aggregate, _path_to_mapper, _to_nested_dict
+from .merge_util import _aggregate, _open_asset, _path_to_mapper, _to_nested_dict

 _DATA_FORMAT_KEY = '_data_format_'


+class ESMDataSource(DataSource):
+    version = '1.0'
+    container = 'xarray'
+    name = 'esm_single_source'
+    partition_access = True
+
+    def __init__(
+        self,
+        key,
+        row,
+        path_column,
+        data_format=None,
+        format_column=None,
+        cdf_kwargs=None,
+        zarr_kwargs=None,
+        storage_options=None,
+        preprocess=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.key = key
+        self.cdf_kwargs = cdf_kwargs or {'chunks': {}}
+        self.zarr_kwargs = zarr_kwargs or {}
+        self.storage_options = storage_options or {}
+        self.preprocess = preprocess
+        if not isinstance(row, pd.Series) or row.empty:
+            raise ValueError('`row` must be a non-empty pandas.Series')
+        self.row = row.copy()
+        self.path_column = path_column
+        self._ds = None
+        if format_column is not None:
+            self.data_format = self.row[format_column]
+        elif data_format:
+            self.data_format = data_format
+        else:
+            raise ValueError('Please specify either `data_format` or `format_column`')
+
+    def __repr__(self):
+        return f'
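The new class is cut off here, but the pieces above already show the intended flow: `__getitem__` packs a single catalog row into the `args` dict, `_make_entry` wraps it in a `LocalCatalogEntry` with the `esm_single_source` driver, and `ESMDataSource` opens that one asset. A sketch of how this looks from the user side; the catalog file is hypothetical, and `to_dask()` assumes `ESMDataSource` implements intake's usual loading method in the part of the class not shown in this hunk:

```python
import intake

cat = intake.open_esm_datastore('catalog.json')  # hypothetical collection file

# With empty groupby_attrs, _set_groups_and_keys keeps the raw dataframe and
# every row becomes its own key (see the core.py hunks above).
cat.groupby_attrs = []

key = list(cat.keys())[0]
source = cat[key]      # LocalCatalogEntry.get() -> ESMDataSource instance
ds = source.to_dask()  # assumed to open the single asset via _open_asset
```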