From 7b7b13240b6cd1e869761dcb30f8121f0a0a98f4 Mon Sep 17 00:00:00 2001
From: Chuck Daniels
Date: Mon, 4 Nov 2024 16:07:51 -0500
Subject: [PATCH] Fix pqdm_kwargs integration tests

---
 earthaccess/api.py     | 40 +++++++++---------
 earthaccess/store.py   | 94 +++++++++++++++++++++++-------------------
 tests/unit/test_api.py | 52 +++++++++++------------
 3 files changed, 97 insertions(+), 89 deletions(-)

diff --git a/earthaccess/api.py b/earthaccess/api.py
index 5ad75ed5..6b758aa2 100644
--- a/earthaccess/api.py
+++ b/earthaccess/api.py
@@ -1,4 +1,5 @@
 import logging
+from pathlib import Path
 
 import requests
 import s3fs
@@ -202,9 +203,10 @@ def login(strategy: str = "all", persist: bool = False, system: System = PROD) -
 
 def download(
     granules: Union[DataGranule, List[DataGranule], str, List[str]],
-    local_path: Optional[str],
+    local_path: Optional[Union[Path, str]] = None,
     provider: Optional[str] = None,
     threads: int = 8,
+    *,
     pqdm_kwargs: Optional[Mapping[str, Any]] = None,
 ) -> List[str]:
     """Retrieves data granules from a remote storage system.
@@ -215,7 +217,11 @@
     Parameters:
         granules: a granule, list of granules, a granule link (HTTP), or a list of
             granule links (HTTP)
-        local_path: local directory to store the remote data granules
+        local_path: Local directory to store the remote data granules. If not
+            supplied, defaults to a subdirectory of the current working directory
+            of the form `data/YYYY-MM-DD-UUID`, where `YYYY-MM-DD` is the year,
+            month, and day of the current date, and `UUID` is the last 6 digits
+            of a UUID4 value.
         provider: if we download a list of URLs, we need to specify the provider.
         threads: parallel number of threads to use to download the files, adjust as necessary, default = 8
         pqdm_kwargs: Additional keyword arguments to pass to pqdm, a parallel processing library.
@@ -228,31 +234,29 @@
     Raises:
         Exception: A file download failed.
     """
-    provider = _normalize_location(provider)
-    pqdm_kwargs = {
-        "exception_behavior": "immediate",
-        "n_jobs": threads,
-        **(pqdm_kwargs or {}),
-    }
+    provider = _normalize_location(str(provider))
+
     if isinstance(granules, DataGranule):
         granules = [granules]
     elif isinstance(granules, str):
         granules = [granules]
+
     try:
-        results = earthaccess.__store__.get(
-            granules, local_path, provider, threads, pqdm_kwargs
+        return earthaccess.__store__.get(
+            granules, local_path, provider, threads, pqdm_kwargs=pqdm_kwargs
         )
     except AttributeError as err:
         logger.error(
             f"{err}: You must call earthaccess.login() before you can download data"
         )
-        return []
-    return results
+
+    return []
 
 
 def open(
     granules: Union[List[str], List[DataGranule]],
     provider: Optional[str] = None,
+    *,
     pqdm_kwargs: Optional[Mapping[str, Any]] = None,
 ) -> List[AbstractFileSystem]:
     """Returns a list of file-like objects that can be used to access files
@@ -269,15 +273,11 @@
     Returns:
         A list of "file pointers" to remote (i.e. s3 or https) files.
     """
-    provider = _normalize_location(provider)
-    pqdm_kwargs = {
-        "exception_behavior": "immediate",
-        **(pqdm_kwargs or {}),
-    }
-    results = earthaccess.__store__.open(
-        granules=granules, provider=provider, pqdm_kwargs=pqdm_kwargs
+    return earthaccess.__store__.open(
+        granules=granules,
+        provider=_normalize_location(provider),
+        pqdm_kwargs=pqdm_kwargs,
     )
-    return results
 
 
 def get_s3_credentials(
diff --git a/earthaccess/store.py b/earthaccess/store.py
index f7b5c85e..58ac9f59 100644
--- a/earthaccess/store.py
+++ b/earthaccess/store.py
@@ -63,7 +63,7 @@ def __repr__(self) -> str:
 def _open_files(
     url_mapping: Mapping[str, Union[DataGranule, None]],
     fs: fsspec.AbstractFileSystem,
-    threads: int = 8,
+    *,
     pqdm_kwargs: Optional[Mapping[str, Any]] = None,
 ) -> List[fsspec.spec.AbstractBufferedFile]:
     def multi_thread_open(data: tuple[str, Optional[DataGranule]]) -> EarthAccessFile:
@@ -71,14 +71,12 @@ def multi_thread_open(data: tuple[str, Optional[DataGranule]]) -> EarthAccessFil
         url, granule = data
         return EarthAccessFile(fs.open(url), granule)  # type: ignore
 
     pqdm_kwargs = {
-        "exception_behavior": "immediate",
+        "exception_behaviour": "immediate",
+        "n_jobs": 8,
         **(pqdm_kwargs or {}),
     }
 
-    fileset = pqdm(
-        url_mapping.items(), multi_thread_open, n_jobs=threads, **pqdm_kwargs
-    )
-    return fileset
+    return pqdm(url_mapping.items(), multi_thread_open, **pqdm_kwargs)
 
 def make_instance(
@@ -344,6 +342,7 @@ def open(
         self,
         granules: Union[List[str], List[DataGranule]],
         provider: Optional[str] = None,
+        *,
         pqdm_kwargs: Optional[Mapping[str, Any]] = None,
     ) -> List[fsspec.spec.AbstractBufferedFile]:
         """Returns a list of file-like objects that can be used to access files
@@ -361,7 +360,7 @@
             A list of "file pointers" to remote (i.e. s3 or https) files.
         """
         if len(granules):
-            return self._open(granules, provider, pqdm_kwargs)
+            return self._open(granules, provider, pqdm_kwargs=pqdm_kwargs)
         return []
 
     @singledispatchmethod
@@ -369,6 +368,7 @@ def _open(
         self,
         granules: Union[List[str], List[DataGranule]],
         provider: Optional[str] = None,
+        *,
         pqdm_kwargs: Optional[Mapping[str, Any]] = None,
     ) -> List[Any]:
         raise NotImplementedError("granules should be a list of DataGranule or URLs")
@@ -378,7 +378,8 @@ def _open_granules(
         self,
         granules: List[DataGranule],
         provider: Optional[str] = None,
-        threads: int = 8,
+        *,
+        pqdm_kwargs: Optional[Mapping[str, Any]] = None,
     ) -> List[Any]:
         fileset: List = []
         total_size = round(sum([granule.size() for granule in granules]) / 1024, 2)
@@ -411,7 +412,7 @@
                     fileset = _open_files(
                         url_mapping,
                         fs=s3_fs,
-                        threads=threads,
+                        pqdm_kwargs=pqdm_kwargs,
                     )
                 except Exception as e:
                     raise RuntimeError(
@@ -420,19 +421,19 @@
                         f"Exception: {traceback.format_exc()}"
                     ) from e
             else:
-                fileset = self._open_urls_https(url_mapping, threads=threads)
-            return fileset
+                fileset = self._open_urls_https(url_mapping, pqdm_kwargs=pqdm_kwargs)
         else:
             url_mapping = _get_url_granule_mapping(granules, access="on_prem")
-            fileset = self._open_urls_https(url_mapping, threads=threads)
-            return fileset
+            fileset = self._open_urls_https(url_mapping, pqdm_kwargs=pqdm_kwargs)
+
+        return fileset
 
     @_open.register
     def _open_urls(
         self,
         granules: List[str],
         provider: Optional[str] = None,
-        threads: int = 8,
+        *,
         pqdm_kwargs: Optional[Mapping[str, Any]] = None,
     ) -> List[Any]:
         fileset: List = []
@@ -460,7 +461,6 @@
                     fileset = _open_files(
                         url_mapping,
                         fs=s3_fs,
-                        threads=threads,
                         pqdm_kwargs=pqdm_kwargs,
                     )
                 except Exception as e:
@@ -481,15 +481,16 @@ def _open_urls(
                 raise ValueError(
                     "We cannot open S3 links when we are not in-region, try using HTTPS links"
                 )
-            fileset = self._open_urls_https(url_mapping, threads, pqdm_kwargs)
+            fileset = self._open_urls_https(url_mapping, pqdm_kwargs=pqdm_kwargs)
         return fileset
 
     def get(
         self,
         granules: Union[List[DataGranule], List[str]],
-        local_path: Union[Path, str, None] = None,
+        local_path: Optional[Union[Path, str]] = None,
         provider: Optional[str] = None,
         threads: int = 8,
+        *,
         pqdm_kwargs: Optional[Mapping[str, Any]] = None,
     ) -> List[str]:
         """Retrieves data granules from a remote storage system.
@@ -503,7 +504,11 @@
         Parameters:
             granules: A list of granules(DataGranule) instances or a list of granule links (HTTP).
-            local_path: Local directory to store the remote data granules.
+            local_path: Local directory to store the remote data granules. If not
+                supplied, defaults to a subdirectory of the current working directory
+                of the form `data/YYYY-MM-DD-UUID`, where `YYYY-MM-DD` is the year,
+                month, and day of the current date, and `UUID` is the last 6 digits
+                of a UUID4 value.
             provider: a valid cloud provider, each DAAC has a provider code for their cloud distributions
             threads: Parallel number of threads to use to download the files; adjust as necessary, default = 8.
@@ -514,18 +519,20 @@
         Returns:
             List of downloaded files
         """
+        if not granules:
+            raise ValueError("List of URLs or DataGranule instances expected")
+
         if local_path is None:
-            today = datetime.datetime.today().strftime("%Y-%m-%d")
+            today = datetime.datetime.now().strftime("%Y-%m-%d")
             uuid = uuid4().hex[:6]
             local_path = Path.cwd() / "data" / f"{today}-{uuid}"
-        elif isinstance(local_path, str):
-            local_path = Path(local_path)
 
-        if len(granules):
-            files = self._get(granules, local_path, provider, threads, pqdm_kwargs)
-            return files
-        else:
-            raise ValueError("List of URLs or DataGranule instances expected")
+        pqdm_kwargs = {
+            "n_jobs": threads,
+            **(pqdm_kwargs or {}),
+        }
+
+        return self._get(granules, Path(local_path), provider, pqdm_kwargs=pqdm_kwargs)
 
     @singledispatchmethod
     def _get(
@@ -533,7 +540,7 @@ def _get(
         granules: Union[List[DataGranule], List[str]],
         local_path: Path,
         provider: Optional[str] = None,
-        threads: int = 8,
+        *,
         pqdm_kwargs: Optional[Mapping[str, Any]] = None,
     ) -> List[str]:
         """Retrieves data granules from a remote storage system.
@@ -566,7 +573,7 @@ def _get_urls(
         granules: List[str],
         local_path: Path,
         provider: Optional[str] = None,
-        threads: int = 8,
+        *,
         pqdm_kwargs: Optional[Mapping[str, Any]] = None,
     ) -> List[str]:
         data_links = granules
@@ -590,7 +597,7 @@
         else:
             # if we are not in AWS
             return self._download_onprem_granules(
-                data_links, local_path, threads, pqdm_kwargs
+                data_links, local_path, pqdm_kwargs=pqdm_kwargs
             )
 
     @_get.register
@@ -599,7 +606,7 @@ def _get_granules(
         self,
         granules: List[DataGranule],
         local_path: Path,
         provider: Optional[str] = None,
-        threads: int = 8,
+        *,
         pqdm_kwargs: Optional[Mapping[str, Any]] = None,
     ) -> List[str]:
         data_links: List = []
@@ -615,7 +622,7 @@
                 for granule in granules
             )
         )
-        total_size = round(sum([granule.size() for granule in granules]) / 1024, 2)
+        total_size = round(sum(granule.size() for granule in granules) / 1024, 2)
         logger.info(
             f" Getting {len(granules)} granules, approx download size: {total_size} GB"
         )
@@ -642,7 +649,7 @@
             # if the data are cloud-based, but we are not in AWS,
             # it will be downloaded as if it was on prem
             return self._download_onprem_granules(
-                data_links, local_path, threads, pqdm_kwargs
+                data_links, local_path, pqdm_kwargs=pqdm_kwargs
             )
 
     def _download_file(self, url: str, directory: Path) -> str:
@@ -684,7 +691,7 @@ def _download_onprem_granules(
         self,
         urls: List[str],
         directory: Path,
-        threads: int = 8,
+        *,
         pqdm_kwargs: Optional[Mapping[str, Any]] = None,
     ) -> List[Any]:
         """Downloads a list of URLS into the data directory.
@@ -711,25 +718,26 @@
 
         arguments = [(url, directory) for url in urls]
 
-        results = pqdm(
-            arguments,
-            self._download_file,
-            n_jobs=threads,
-            argument_type="args",
-            **pqdm_kwargs,
-        )
-        return results
+        pqdm_kwargs = {
+            "exception_behaviour": "immediate",
+            **(pqdm_kwargs or {}),
+            # We don't want a user to be able to override the following kwargs,
+            # which is why they appear *after* spreading pqdm_kwargs above.
+            "argument_type": "args",
+        }
+
+        return pqdm(arguments, self._download_file, **pqdm_kwargs)
 
     def _open_urls_https(
         self,
         url_mapping: Mapping[str, Union[DataGranule, None]],
-        threads: int = 8,
+        *,
         pqdm_kwargs: Optional[Mapping[str, Any]] = None,
     ) -> List[fsspec.AbstractFileSystem]:
         https_fs = self.get_fsspec_session()
         try:
-            return _open_files(url_mapping, https_fs, threads, pqdm_kwargs)
+            return _open_files(url_mapping, https_fs, pqdm_kwargs=pqdm_kwargs)
         except Exception:
             logger.exception(
                 "An exception occurred while trying to access remote files via HTTPS"
diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py
index 20980e35..0f50fe93 100644
--- a/tests/unit/test_api.py
+++ b/tests/unit/test_api.py
@@ -1,10 +1,15 @@
-from unittest.mock import Mock
+from pathlib import Path
+from unittest.mock import patch
 
 import earthaccess
 import pytest
 
 
-def test_download_immediate_failure(monkeypatch):
+def fail_to_download_file(*args, **kwargs):
+    raise IOError("Download failed")
+
+
+def test_download_immediate_failure(tmp_path: Path):
     earthaccess.login()
 
     results = earthaccess.search_data(
@@ -14,37 +19,32 @@ def test_download_immediate_failure(monkeypatch):
         count=10,
     )
 
-    def mock_get(*args, **kwargs):
-        raise Exception("Download failed")
-
-    mock_store = Mock()
-    monkeypatch.setattr(earthaccess, "__store__", mock_store)
-    monkeypatch.setattr(mock_store, "get", mock_get)
+    with patch.object(earthaccess.__store__, "_download_file", fail_to_download_file):
+        with pytest.raises(IOError, match="Download failed"):
+            earthaccess.download(results, str(tmp_path))
 
-    with pytest.raises(Exception, match="Download failed"):
-        earthaccess.download(results, "/home/download-folder")
 
-
-def test_download_deferred_failure(monkeypatch):
+def test_download_deferred_failure(tmp_path: Path):
     earthaccess.login()
+    count = 3
 
     results = earthaccess.search_data(
         short_name="ATL06",
         bounding_box=(-10, 20, 10, 50),
         temporal=("1999-02", "2019-03"),
-        count=10,
+        count=count,
     )
 
-    def mock_get(*args, **kwargs):
-        return [Exception("Download failed")] * len(results)
-
-    mock_store = Mock()
-    monkeypatch.setattr(earthaccess, "__store__", mock_store)
-    monkeypatch.setattr(mock_store, "get", mock_get)
-
-    results = earthaccess.download(
-        results, "/home/download-folder", None, 8, {"exception_behavior": "deferred"}
-    )
-
-    assert all(isinstance(e, Exception) for e in results)
-    assert len(results) == 10
+    with patch.object(earthaccess.__store__, "_download_file", fail_to_download_file):
+        with pytest.raises(Exception) as exc_info:
+            earthaccess.download(
+                results,
+                tmp_path,
+                None,
+                8,
+                pqdm_kwargs={"exception_behaviour": "deferred"},
+            )
+
+    errors = exc_info.value.args
+
+    assert len(errors) == count
+    assert all(isinstance(e, IOError) and str(e) == "Download failed" for e in errors)