Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rewrite MultiZarrToZarr to always use xr.concat #33

Merged
merged 1 commit into from
Jun 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 38 additions & 25 deletions fsspec_reference_maker/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,25 @@
import numcodecs
import xarray as xr
import zarr
logging = logging.getLogger('reference-combine')
logger = logging.getLogger('reference-combine')


class MultiZarrToZarr:

def __init__(self, path, remote_protocol,
remote_options=None, xarray_kwargs=None, storage_options=None,
with_mf=True):
remote_options=None, xarray_open_kwargs=None, xarray_concat_args=None,
preprocess=None, storage_options=None):
"""

:param path: a URL containing multiple JSONs
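:param xarray_open_kwargs: keyword arguments passed to xr.open_dataset for each input
:param xarray_concat_args: keyword arguments passed to xr.concat when combining the inputs
:param preprocess: optional callable applied to each opened dataset before concatenation
:param storage_options: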
"""
xarray_kwargs = xarray_kwargs or {}
self.path = path
self.with_mf = with_mf
self.xr_kwargs = xarray_kwargs
self.xr_kwargs = xarray_open_kwargs or {}
self.concat_kwargs = xarray_concat_args or {}
self.storage_options = storage_options or {}
self.preprocess = preprocess
self.remote_protocol = remote_protocol
self.remote_options = remote_options or {}

Expand All @@ -36,6 +36,7 @@ def translate(self, outpath):
self.output = self._consolidate(out)

self._write(self.output, outpath)
# TODO: return new zarr dataset?

@staticmethod
def _write(refs, outpath, filetype=None):
Expand Down Expand Up @@ -86,7 +87,7 @@ def _write(refs, outpath, filetype=None):
compression="ZSTD"
)

def _consolidate(self, mapping, inline_threashold=100, template_count=5):
def _consolidate(self, mapping, inline_threshold=100, template_count=5):
counts = Counter(v[0] for v in mapping.values() if isinstance(v, list))

def letter_sets():
Expand All @@ -104,7 +105,7 @@ def letter_sets():

out = {}
for k, v in mapping.items():
if isinstance(v, list) and v[2] < inline_threashold:
if isinstance(v, list) and v[2] < inline_threshold:
v = self.fs.cat_file(v[0], start=v[1], end=v[1] + v[2])
if isinstance(v, bytes):
try:
Expand Down Expand Up @@ -158,15 +159,17 @@ def _determine_dims(self):
self.fs = fss[0].fs
mappers = [fs.get_mapper("") for fs in fss]

if self.with_mf is True:
ds = xr.open_mfdataset(mappers, engine="zarr", chunks={}, **self.xr_kwargs)
ds0 = xr.open_mfdataset(mappers[:1], engine="zarr", chunks={}, **self.xr_kwargs)
else:
dss = [xr.open_dataset(m, engine="zarr", chunks={}, **self.xr_kwargs) for m in mappers]
ds = xr.concat(dss, dim=self.with_mf)
ds0 = dss[0]
dss = [xr.open_dataset(m, engine="zarr", chunks={}, **self.xr_kwargs)
for m in mappers]
if self.preprocess:
dss = [self.preprocess(d) for d in dss]
ds = xr.concat(dss, **self.concat_kwargs)
ds0 = dss[0]
self.extra_dims = set(ds.dims) - set(ds0.dims)
self.concat_dims = set(k for k, v in ds.dims.items() if k in ds0.dims and v / ds0.dims[k] == len(mappers))
self.concat_dims = set(
k for k, v in ds.dims.items()
if k in ds0.dims and v / ds0.dims[k] == len(mappers)
)
self.same_dims = set(ds.dims) - self.extra_dims - self.concat_dims
return ds, ds0, fss

Expand All @@ -180,19 +183,29 @@ def drop_coords(ds):
ds = ds.drop(['reference_time', 'crs'])
return ds.reset_coords(drop=True)

xarray_open_kwargs = {
"decode_cf": False,
"mask_and_scale": False,
"decode_times": False,
"decode_timedelta": False,
"use_cftime": False,
"decode_coords": False
}
concat_kwargs = {
"data_vars": "minimal",
"coords": "minimal",
"compat": "override",
"join": "override",
"combine_attrs": "override",
"dim": "time"
}
mzz = MultiZarrToZarr(
"zip://*.json::out.zip",
remote_protocol="s3",
remote_options={'anon': True},
xarray_kwargs={
"preprocess": drop_coords,
"decode_cf": False,
"mask_and_scale": False,
"decode_times": False,
"decode_timedelta": False,
"use_cftime": False,
"decode_coords": False
},
preprocess=drop_coords,
xarray_open_kwargs=xarray_open_kwargs,
xarray_concat_args=concat_kwargs
)
mzz.translate("output.zarr")

3 changes: 2 additions & 1 deletion fsspec_reference_maker/grib2.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,5 +197,6 @@ def example_multi(filter={'typeOfLevel': 'heightAboveGround', 'level': 2}):
# 'hrrr.t04z.wrfsfcf01.json',
# 'hrrr.t05z.wrfsfcf01.json',
# 'hrrr.t06z.wrfsfcf01.json']
# mzz = MultiZarrToZarr(files, remote_protocol="s3", remote_options={"anon": True}, with_mf='time')
# mzz = MultiZarrToZarr(files, remote_protocol="s3", remote_options={"anon": True},
# xarray_concat_args={"dim": 'time'})
# mzz.translate("hrrr.total.json")
2 changes: 1 addition & 1 deletion fsspec_reference_maker/hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def example_single():
)
fsspec.utils.setup_logging(logger=lggr)
with fsspec.open(url, **so) as f:
h5chunks = SingleHdf5ToZarr(f, url, xarray=True)
h5chunks = SingleHdf5ToZarr(f, url)
return h5chunks.translate()


Expand Down