Skip to content

Commit

Permalink
use h5netcdf xarray engine
Browse files Browse the repository at this point in the history
  • Loading branch information
SarahAlidoost committed Dec 9, 2024
1 parent d95d0a5 commit ddef620
Show file tree
Hide file tree
Showing 10 changed files with 50 additions and 27 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ dependencies = [
"numpy",
"pandas",
"matplotlib",
"xarray",
"xarray[io]",
"scipy", # required for xarray.interpolate
"rioxarray", # required for TIFF files
"tqdm",
Expand Down
4 changes: 2 additions & 2 deletions src/zampy/datasets/cds_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def convert_to_zampy(
ds = parse_nc_file(file)
# Rename the vswl data:
ncfile = Path(str(ncfile).replace("volumetric_soil_water", "soil_moisture"))
ds.to_netcdf(path=ncfile)
ds.to_netcdf(path=ncfile, engine="h5netcdf")


var_reference_ecmwf_to_zampy = {
Expand Down Expand Up @@ -444,7 +444,7 @@ def parse_nc_file(file: Path) -> xr.Dataset:
CF/Zampy formatted xarray Dataset
"""
# Open chunked: will be dask array -> file writing can be parallelized.
ds = xr.open_dataset(file, chunks={"x": 50, "y": 50})
ds = xr.open_dataset(file, chunks={"x": 50, "y": 50}, engine="h5netcdf")

for variable in ds.variables:
if variable in var_reference_ecmwf_to_zampy:
Expand Down
6 changes: 4 additions & 2 deletions src/zampy/datasets/ecmwf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ def load(
if var in variable_names:
files += (ingest_dir / self.name).glob(f"{self.name}_{var}*.nc")

ds = xr.open_mfdataset(files, chunks={"latitude": 200, "longitude": 200})
ds = xr.open_mfdataset(
files, chunks={"latitude": 200, "longitude": 200}, engine="h5netcdf"
)

# rename valid_time to time
if "valid_time" in ds.dims:
Expand Down Expand Up @@ -152,7 +154,7 @@ def convert(
for file in data_files:
# start conversion process
print(f"Start processing file `{file.name}`.")
ds = xr.open_dataset(file, chunks={"x": 50, "y": 50})
ds = xr.open_dataset(file, chunks={"x": 50, "y": 50}, engine="h5netcdf")
ds = converter.convert(ds, dataset=self, convention=convention)
# TODO: support derived variables
# TODO: other calculations
Expand Down
11 changes: 5 additions & 6 deletions src/zampy/datasets/eth_canopy_height.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ def load(
if self.variable_names[1] in variable_names:
files += (ingest_dir / self.name).glob("*Map_SD.nc")

ds = xr.open_mfdataset(files, chunks={"latitude": 2000, "longitude": 2000})
ds = xr.open_mfdataset(
files, chunks={"latitude": 2000, "longitude": 2000}, engine="h5netcdf"
)
ds = ds.sel(time=slice(time_bounds.start, time_bounds.end))

grid = xarray_regrid.create_regridding_dataset(
Expand All @@ -159,7 +161,7 @@ def convert(
for file in data_files:
# start conversion process
print(f"Start processing file `{file.name}`.")
ds = xr.open_dataset(file, chunks={"x": 2000, "y": 2000})
ds = xr.open_dataset(file, chunks={"x": 2000, "y": 2000}, engine="h5netcdf")
ds = converter.convert(ds, dataset=self, convention=convention)
# TODO: support derived variables
# TODO: other calculations
Expand Down Expand Up @@ -249,10 +251,7 @@ def convert_tiff_to_netcdf(
ds = ds.interpolate_na(dim="longitude", limit=1)
ds = ds.interpolate_na(dim="latitude", limit=1)

ds.to_netcdf(
path=ncfile,
encoding=ds.encoding,
)
ds.to_netcdf(path=ncfile, encoding=ds.encoding, engine="h5netcdf")


def parse_tiff_file(file: Path, sd_file: bool = False) -> xr.Dataset:
Expand Down
11 changes: 8 additions & 3 deletions src/zampy/datasets/fapar_lai.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def load(
variable_names: list[str],
) -> xr.Dataset:
files = list((ingest_dir / self.name).glob("*.nc"))
ds = xr.open_mfdataset(files) # see issue 65
ds = xr.open_mfdataset(files, engine="h5netcdf") # see issue 65
ds = ds.sel(time=slice(time_bounds.start, time_bounds.end))

grid = xarray_regrid.create_regridding_dataset(
Expand All @@ -175,7 +175,9 @@ def convert( # Will be removed, see issue #43.
for file in data_files:
# start conversion process
print(f"Start processing file `{file.name}`.")
ds = xr.open_dataset(file, chunks={"latitude": 2000, "longitude": 2000})
ds = xr.open_dataset(
file, chunks={"latitude": 2000, "longitude": 2000}, engine="h5netcdf"
)
ds = converter.convert(ds, dataset=self, convention=convention)

return True
Expand Down Expand Up @@ -248,7 +250,9 @@ def download_fapar_lai(
def ingest_ncfile(ncfile: Path, ingest_folder: Path) -> None:
"""Ingest the 'raw' netCDF file to the Zampy standard format."""
print(f"Converting file {ncfile.name}...")
ds = xr.open_dataset(ncfile, decode_times=False, chunks={"lat": 5000, "lon": 5000})
ds = xr.open_dataset(
ncfile, decode_times=False, chunks={"lat": 5000, "lon": 5000}, engine="h5netcdf"
)
ds = ds.rename(
{
"LAI": "leaf_area_index",
Expand All @@ -260,6 +264,7 @@ def ingest_ncfile(ncfile: Path, ingest_folder: Path) -> None:
ds[["leaf_area_index"]].to_netcdf(
path=ingest_folder / ncfile.name,
encoding={"leaf_area_index": {"zlib": True, "complevel": 3}},
engine="h5netcdf",
)
ds.close() # explicitly close to release file to system (for Windows)

Expand Down
10 changes: 6 additions & 4 deletions src/zampy/datasets/land_cover.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ def load(
)
raise ValueError(msg)
files = list((ingest_dir / self.name).glob(f"{self.name}_*.nc"))
ds = xr.open_mfdataset(files, chunks={"latitude": 200, "longitude": 200})
ds = xr.open_mfdataset(
files, chunks={"latitude": 200, "longitude": 200}, engine="h5netcdf"
)
ds = ds.sel(time=slice(time_bounds.start, time_bounds.end))

grid = xarray_regrid.create_regridding_dataset(
Expand Down Expand Up @@ -173,7 +175,7 @@ def convert(

for file in data_files:
print(f"Start processing file `{file.name}`.")
ds = xr.open_dataset(file)
ds = xr.open_dataset(file, engine="h5netcdf")
ds = converter.convert(ds, dataset=self, convention=convention)

return True
Expand All @@ -197,7 +199,7 @@ def unzip_raw_to_netcdf(
print(f"File '{ncfile.name}' already exists, skipping...")
else:
ds = extract_netcdf_to_zampy(file)
ds.to_netcdf(path=ncfile)
ds.to_netcdf(path=ncfile, engine="h5netcdf")


def extract_netcdf_to_zampy(file: Path) -> xr.Dataset:
Expand All @@ -220,7 +222,7 @@ def extract_netcdf_to_zampy(file: Path) -> xr.Dataset:
zip_object.extract(zipped_file_name, path=unzip_folder)

# only keep land cover class variable
with xr.open_dataset(unzip_folder / zipped_file_name) as ds:
with xr.open_dataset(unzip_folder / zipped_file_name, engine="h5netcdf") as ds:
var_list = list(ds.data_vars)
raw_variable = "lccs_class"
var_list.remove(raw_variable)
Expand Down
5 changes: 3 additions & 2 deletions src/zampy/datasets/prism_dem.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def preproc(ds: xr.Dataset) -> xr.Dataset:
"""Remove overlapping coordinates on the edges."""
return ds.isel(latitude=slice(None, -1), longitude=slice(None, -1))

ds = xr.open_mfdataset(files, preprocess=preproc)
ds = xr.open_mfdataset(files, preprocess=preproc, engine="h5netcdf")

grid = xarray_regrid.create_regridding_dataset(
utils.make_grid(spatial_bounds, resolution)
Expand All @@ -168,7 +168,7 @@ def convert(
for file in data_files:
# start conversion process
print(f"Start processing file `{file.name}`.")
ds = xr.open_dataset(file)
ds = xr.open_dataset(file, engine="h5netcdf")
ds = converter.convert(ds, dataset=self, convention=convention)

return True
Expand Down Expand Up @@ -216,6 +216,7 @@ def convert_raw_dem_to_netcdf(
ds.to_netcdf(
path=ncfile,
encoding=ds.encoding,
engine="h5netcdf",
)


Expand Down
4 changes: 3 additions & 1 deletion src/zampy/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ def run(self) -> None:
time_end = str(self.timebounds.end.astype("datetime64[Y]"))
# e.g. "era5_2010-2020.nc"
fname = f"{dataset_name.lower()}_{time_start}-{time_end}.nc"
ds.to_netcdf(path=self.data_dir / fname, encoding=encoding)
ds.to_netcdf(
path=self.data_dir / fname, encoding=encoding, engine="h5netcdf"
)
del ds

print(
Expand Down
4 changes: 3 additions & 1 deletion tests/test_datasets/test_fapar_lai.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ def test_ingest(self, dummy_dir):
lai_dataset = FaparLAI()
lai_dataset.ingest(download_dir=data_folder, ingest_dir=dummy_dir)

ds = xr.open_mfdataset((dummy_dir / "fapar-lai").glob("*.nc"))
ds = xr.open_mfdataset(
(dummy_dir / "fapar-lai").glob("*.nc"), engine="h5netcdf"
)
assert isinstance(ds, xr.Dataset)

def test_load(self, dummy_dir):
Expand Down
20 changes: 15 additions & 5 deletions tests/test_recipes/test_simple_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def test_recipe(tmp_path: Path, mocker):

rm.run()

ds = xr.open_mfdataset(str(tmp_path / "output" / "era5_recipe" / "*.nc"))
ds = xr.open_mfdataset(
str(tmp_path / "output" / "era5_recipe" / "*.nc"), engine="h5netcdf"
)
assert all(var in ds.data_vars for var in ["Psurf", "Wind_N"])
# Check if time frequency is correct
assert ds.time.diff("time").min() == np.timedelta64(1, "h")
Expand Down Expand Up @@ -85,7 +87,9 @@ def test_recipe_with_lower_frequency(tmp_path: Path, mocker):

rm.run()

ds = xr.open_mfdataset(str(tmp_path / "output" / "era5_recipe" / "*.nc"))
ds = xr.open_mfdataset(
str(tmp_path / "output" / "era5_recipe" / "*.nc"), engine="h5netcdf"
)
# check the lenght of the time dimension, mean values are used
assert len(ds.time) == 4

Expand Down Expand Up @@ -121,7 +125,9 @@ def test_recipe_with_higher_frequency(tmp_path: Path, mocker):

rm.run()

ds = xr.open_mfdataset(str(tmp_path / "output" / "era5_recipe" / "*.nc"))
ds = xr.open_mfdataset(
str(tmp_path / "output" / "era5_recipe" / "*.nc"), engine="h5netcdf"
)
# check the lenght of the time dimension, data is interpolated
assert len(ds.time) == 47

Expand Down Expand Up @@ -156,7 +162,9 @@ def test_recipe_with_two_time_values(tmp_path: Path, mocker):

rm.run()

ds = xr.open_mfdataset(str(tmp_path / "output" / "era5_recipe" / "*.nc"))
ds = xr.open_mfdataset(
str(tmp_path / "output" / "era5_recipe" / "*.nc"), engine="h5netcdf"
)
# check the lenght of the time dimension
assert len(ds.time) == 2

Expand Down Expand Up @@ -191,7 +199,9 @@ def test_recipe_with_one_time_values(tmp_path: Path, mocker):

rm.run()

ds = xr.open_mfdataset(str(tmp_path / "output" / "era5_recipe" / "*.nc"))
ds = xr.open_mfdataset(
str(tmp_path / "output" / "era5_recipe" / "*.nc"), engine="h5netcdf"
)
# check the lenght of the time dimension, should not do interpolation or
# extrapolation in time
assert len(ds.time) == 1
Expand Down

0 comments on commit ddef620

Please sign in to comment.