diff --git a/src/xscen/catalog.py b/src/xscen/catalog.py
index 20e88568..841f12d3 100644
--- a/src/xscen/catalog.py
+++ b/src/xscen/catalog.py
@@ -541,6 +541,7 @@ def copy_files(
         dest: Union[str, os.PathLike],
         flat: bool = True,
         unzip: bool = False,
+        zipzarr: bool = False,
         inplace: bool = False,
     ):
         """Copy each file of the catalog to another location, unzipping datasets along the way if requested.
@@ -558,6 +559,8 @@ def copy_files(
             Nothing is done in case of duplicates in that case.
         unzip: bool
             If True, any datasets with a `.zip` suffix are unzipped during the copy (or rather instead of a copy).
+        zipzarr : bool
+            If True, any dataset with a `.zarr` suffix is zipped instead of being copied.
         inplace : bool
             If True, the catalog is updated in place. If False (default), a copy is returned.
 
@@ -568,7 +571,7 @@ def copy_files(
         """
         # Local imports to avoid circular imports
         from .catutils import build_path
-        from .io import unzip_directory
+        from .io import unzip_directory, zip_directory
 
         dest = Path(dest)
         data = self.esmcat._df.copy()
@@ -577,6 +580,8 @@ def copy_files(
             for path in map(Path, data.path.values):
                 if unzip and path.suffix == ".zip":
                     new = dest / path.with_suffix("").name
+                elif zipzarr and path.suffix == ".zarr":  # zipped to "<name>.zarr.zip"
+                    new = dest / path.with_suffix(".zarr.zip").name
                 else:
                     new = dest / path.name
                 if new in new_paths:
@@ -598,6 +603,9 @@ def copy_files(
             if unzip and old.suffix == ".zip":
                 logger.info(f"Unzipping {old} to {new}.")
                 unzip_directory(old, new)
+            elif zipzarr and old.suffix == ".zarr":  # zip instead of copying
+                logger.info(f"Zipping {old} to {new}.")
+                zip_directory(old, new)
             elif old.is_dir():
                 logger.info(f"Copying directory tree {old} to {new}.")
                 sh.copytree(old, new)
diff --git a/src/xscen/io.py b/src/xscen/io.py
index 40073947..e7350c31 100644
--- a/src/xscen/io.py
+++ b/src/xscen/io.py
@@ -1000,7 +1000,10 @@ def rechunk(
 
 
 def zip_directory(
-    root: Union[str, os.PathLike], zipfile: Union[str, os.PathLike], **zip_args
+    root: Union[str, os.PathLike],
+    zipfile: Union[str, os.PathLike],
+    delete: bool = False,
+    **zip_args,
 ):
     r"""Make a zip archive of the content of a directory.
 
@@ -1010,6 +1013,8 @@ def zip_directory(
         The directory with the content to archive.
     zipfile : path
         The zip file to create.
+    delete : bool
+        If True, the original directory (`root`) is deleted after zipping. Default is False.
     \*\*zip_args
         Any other arguments to pass to :py:mod:`zipfile.ZipFile`, such as "compression".
         The default is to make no compression (``compression=ZIP_STORED``).
@@ -1026,6 +1031,9 @@ def _add_to_zip(zf, path, root):
         for file in root.iterdir():
             _add_to_zip(zf, file, root)
 
+    if delete:  # drop the source once it is archived
+        sh.rmtree(root)
+
 
 def unzip_directory(zipfile: Union[str, os.PathLike], root: Union[str, os.PathLike]):
     r"""Unzip an archive to a directory.
diff --git a/tests/test_catalog.py b/tests/test_catalog.py
index abb324ac..8895c25d 100644
--- a/tests/test_catalog.py
+++ b/tests/test_catalog.py
@@ -1,7 +1,10 @@
+from pathlib import Path
+
 import pandas as pd
+import xarray as xr
 from conftest import SAMPLES_DIR
 
-from xscen import catalog
+from xscen import catalog, extract
 
 
 def test_subset_file_coverage():
@@ -35,3 +38,58 @@ def test_subset_file_coverage():
 def test_xrfreq_fix():
     cat = catalog.DataCatalog(SAMPLES_DIR.parent / "pangeo-cmip6.json")
     assert set(cat.df.xrfreq) == {"3h", "D", "fx"}
+
+
+class TestCopyFiles:
+    def test_flat(self, samplecat, tmp_path):
+        newcat = samplecat.copy_files(tmp_path, flat=True)
+        assert len(list(tmp_path.glob("*.nc"))) == len(newcat.df)  # one file per row
+
+    def test_inplace(self, samplecat, tmp_path):
+        dsid, scat = extract.search_data_catalogs(
+            data_catalogs=[samplecat],
+            variables_and_freqs={"tas": "MS"},
+            allow_resampling=True,
+            other_search_criteria={
+                "experiment": "ssp585",
+                "source": "NorESM.*",
+                "member": "r1i1p1f1",
+            },
+        ).popitem()
+        scat.copy_files(tmp_path, inplace=True)  # scat itself is updated
+        assert len(list(tmp_path.glob("*.nc"))) == len(scat.df)
+
+        _, ds = extract.extract_dataset(scat).popitem()
+        frq = xr.infer_freq(ds.time)
+        assert frq == "MS"
+
+    def test_zipunzip(self, samplecat, tmp_path):
+        dsid, scat = extract.search_data_catalogs(
+            data_catalogs=[samplecat],
+            variables_and_freqs={"tas": "D"},
+            allow_resampling=True,
+            other_search_criteria={
+                "experiment": "ssp585",
+                "source": "NorESM.*",
+                "member": "r1i1p1f1",
+            },
+        ).popitem()
+        _, ds = extract.extract_dataset(scat).popitem()
+        ds.to_zarr(tmp_path / "temp.zarr")  # make a zarr store to zip
+        scat.esmcat.df.loc[0, "path"] = tmp_path / "temp.zarr"
+
+        rz = tmp_path / "zipped"
+        rz.mkdir()
+        scat_z = scat.copy_files(rz, zipzarr=True)
+        f = Path(scat_z.df.path.iloc[0])
+        assert f.suffix == ".zip"
+        assert f.parent.name == rz.name
+        assert f.is_file()  # a single archive file
+
+        ru = tmp_path / "unzipped"
+        ru.mkdir()
+        scat_uz = scat_z.copy_files(ru, unzip=True)  # unzip the zipped copy
+        f = Path(scat_uz.df.path.iloc[0])
+        assert f.suffix == ".zarr"
+        assert f.parent.name == ru.name
+        assert f.is_dir()  # extracted back to a directory
diff --git a/tests/test_indicators.py b/tests/test_indicators.py
index 3b7dd0b2..79878eb0 100644
--- a/tests/test_indicators.py
+++ b/tests/test_indicators.py
@@ -208,3 +208,19 @@ def test_select_inds_for_avail_vars(self, indicator_iter):
         )
         assert len(list(inds_for_avail_vars.iter_indicators())) == 0
         assert [(n, i) for n, i in inds_for_avail_vars.iter_indicators()] == []
+
+
+@pytest.mark.parametrize(
+    "ind,expvars,expfrq",
+    [
+        ("wind_vector_from_speed", ["uas", "vas"], "D"),
+        ("fit", ["params"], "fx"),
+        ("tg_mean", ["tg_mean"], "YS-JAN"),
+    ],
+)
+def test_get_indicator_outputs(ind, expvars, expfrq):
+    """Check the variables and frequency returned by get_indicator_outputs for "D"."""
+    ind = xclim.core.indicator.registry[ind.upper()].get_instance()
+    outvars, outfrq = xs.indicators.get_indicator_outputs(ind, "D")
+    assert outvars == expvars
+    assert outfrq == expfrq