diff --git a/earthaccess/api.py b/earthaccess/api.py index dca9a5ac..37edba0b 100644 --- a/earthaccess/api.py +++ b/earthaccess/api.py @@ -248,7 +248,12 @@ def download( try: return earthaccess.__store__.get( - granules, local_path, provider, threads, access=access, pqdm_kwargs=pqdm_kwargs + granules, + local_path, + provider, + threads, + access=access, + pqdm_kwargs=pqdm_kwargs, ) except AttributeError as err: logger.error( diff --git a/earthaccess/auth.py b/earthaccess/auth.py index 5b2393fc..c0abf48b 100644 --- a/earthaccess/auth.py +++ b/earthaccess/auth.py @@ -180,7 +180,9 @@ def get_s3_credentials( ) if not auth_url: # Display possible typos in a helpfull error - raise Exception(f'auth_url not found using daac: "{daac}" and provider: "{provider}"') + raise Exception( + f'auth_url not found using daac: "{daac}" and provider: "{provider}"' + ) else: auth_url = endpoint if auth_url.startswith("https://"): diff --git a/earthaccess/results.py b/earthaccess/results.py index f3400268..fbdacab0 100644 --- a/earthaccess/results.py +++ b/earthaccess/results.py @@ -303,15 +303,14 @@ def _derive_s3_link(self, links: List[str]) -> List[str]: s3_links.append(f's3://{links[0].split("nasa.gov/")[1]}') return s3_links - def data_links( - self, access: Optional[str] = None) -> List[str]: + def data_links(self, access: Optional[str] = None) -> List[str]: """Placeholder. Returns the data links from a granule. Parameters: access: direct or external. - Direct means in-region access for cloud-hosted collections. + Direct means in-region access for cloud-hosted collections. Returns: The data links for the requested access type. diff --git a/earthaccess/store.py b/earthaccess/store.py index d6532508..7bc4f458 100644 --- a/earthaccess/store.py +++ b/earthaccess/store.py @@ -2,13 +2,13 @@ import logging import shutil import traceback +import warnings from functools import lru_cache from itertools import chain from pathlib import Path from pickle import dumps, loads from typing import Any, Dict, List, Mapping, Optional, Tuple, Union from uuid import uuid4 -import warnings import fsspec import requests @@ -81,24 +81,26 @@ def multi_thread_open(data: tuple[str, Optional[DataGranule]]) -> EarthAccessFil def make_instance( - cls: Any, granule: DataGranule, auth: Auth, data: Any, out_of_region_handling: Optional[str] = "raise" + cls: Any, + granule: DataGranule, + auth: Auth, + data: Any, + out_of_region_handling: Optional[str] = "raise", ) -> EarthAccessFile: - """ - Creates an EarthAccessFile instance - + """Creates an EarthAccessFile instance + Parameters: cls: the datatype of a file system, such as s3fs.S3File granule: a granule search result auth: earthaccess.auth.Auth object data: dumped buffered file data out_of_region_handling: "raise" to raise an Exception if attempting out of region access or - "handle" (or anything else) to attempt using a granule's first + "handle" (or anything else) to attempt using a granule's first data link upon faliure Return: An EarthAccessFile object """ - # Attempt to re-authenticate if not earthaccess.__auth__.authenticated: earthaccess.__auth__ = auth @@ -121,7 +123,7 @@ def make_instance( "This may be caused by trying to access the data outside the us-west-2 region.\n" "Attempting on_prem access..." ) - + # NOTE: This uses the first data_link listed in the granule. That's not # guaranteed to be the right one. return EarthAccessFile(earthaccess.open([granule])[0], granule) @@ -187,7 +189,6 @@ def _own_s3_credentials(self, links: List[Dict[str, Any]]) -> Union[str, None]: return link["URL"] return None - def set_requests_session( self, url: str, method: str = "get", bearer_token: bool = False ) -> None: @@ -558,7 +559,9 @@ def get( **(pqdm_kwargs or {}), } - return self._get(granules, Path(local_path), provider, access=access, pqdm_kwargs=pqdm_kwargs) + return self._get( + granules, Path(local_path), provider, access=access, pqdm_kwargs=pqdm_kwargs + ) @singledispatchmethod def _get( @@ -642,7 +645,9 @@ def _get_urls( ) # if we are not in AWS - return self._download_onprem_granules(data_links, local_path, pqdm_kwargs=pqdm_kwargs) + return self._download_onprem_granules( + data_links, local_path, pqdm_kwargs=pqdm_kwargs + ) @_get.register def _get_granules( @@ -664,8 +669,7 @@ def _get_granules( access = "direct" if cloud_hosted else "external" data_links = list( chain.from_iterable( - granule.data_links(access=access) - for granule in granules + granule.data_links(access=access) for granule in granules ) ) total_size = round(sum(granule.size() for granule in granules) / 1024, 2) @@ -709,7 +713,9 @@ def _get_granules( # if the data are cloud-based, but we are not in AWS, # it will be downloaded as if it was on prem - return self._download_onprem_granules(data_links, local_path, pqdm_kwargs=pqdm_kwargs) + return self._download_onprem_granules( + data_links, local_path, pqdm_kwargs=pqdm_kwargs + ) def _download_file(self, url: str, directory: Path) -> str: """Download a single file from an on-prem location, a DAAC data center. @@ -765,7 +771,7 @@ def _download_onprem_granules( Returns: A list of local filepaths to which the files were downloaded. - """ + """ if urls is None: raise ValueError("The granules didn't provide a valid GET DATA link") if self.auth is None: diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index ee679c2e..18c6d14d 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -68,7 +68,9 @@ def mock_netrc(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch): username = os.environ["EARTHDATA_USERNAME"] password = os.environ["EARTHDATA_PASSWORD"] else: - raise Exception("Unable to mock a .netrc without EARTHDATA_USERNAME and EARTHDATA_PASSWORD environment variables") + raise Exception( + "Unable to mock a .netrc without EARTHDATA_USERNAME and EARTHDATA_PASSWORD environment variables" + ) netrc.write_text( f"machine urs.earthdata.nasa.gov login {username} password {password}\n" diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py index 8bf4b5fc..3fa7d7b8 100644 --- a/tests/integration/test_api.py +++ b/tests/integration/test_api.py @@ -5,7 +5,6 @@ import earthaccess import pytest -import s3fs logger = logging.getLogger(__name__) @@ -98,10 +97,14 @@ def test_download_immediate_failure(tmp_path: Path, access): # any further downloads. if not access or access == "direct": with patch("s3fs.S3FileSystem", fail_to_download_file): - earthaccess.download(results, tmp_path, access=access, pqdm_kwargs=dict(disable=True)) + earthaccess.download( + results, tmp_path, access=access, pqdm_kwargs=dict(disable=True) + ) elif access == "external": with patch("earthaccess.__store__._download_file", fail_to_download_file): - earthaccess.download(results, tmp_path, access=access, pqdm_kwargs=dict(disable=True)) + earthaccess.download( + results, tmp_path, access=access, pqdm_kwargs=dict(disable=True) + ) @pytest.mark.parametrize("access", [None, "direct", "external"]) @@ -117,9 +120,9 @@ def test_download_deferred_failure(tmp_path: Path, access): with pytest.raises(Exception) as exc_info: # With "deferred" exceptions, pqdm catches all exceptions, then at the end # raises a single generic Exception, passing the sequence of caught exceptions - # as arguments to the Exception constructor. + # as arguments to the Exception constructor. if not access or access == "direct": - with patch("s3fs.S3FileSystem", fail_to_download_file): + with patch("s3fs.S3FileSystem", fail_to_download_file): earthaccess.download( results, tmp_path, @@ -127,7 +130,7 @@ def test_download_deferred_failure(tmp_path: Path, access): pqdm_kwargs=dict(exception_behaviour="deferred", disable=True), ) elif access == "external": - with patch("earthaccess.__store__._download_file", fail_to_download_file): + with patch("earthaccess.__store__._download_file", fail_to_download_file): earthaccess.download( results, tmp_path, @@ -137,7 +140,9 @@ def test_download_deferred_failure(tmp_path: Path, access): errors = exc_info.value.args assert len(errors) == count - assert all(isinstance(e, IOError) and str(e) == "Download failed" for e in errors) + assert all( + isinstance(e, IOError) and str(e) == "Download failed" for e in errors + ) def test_auth_environ():