Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 16, 2024
1 parent 3652637 commit 9f34d5a
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 28 deletions.
7 changes: 6 additions & 1 deletion earthaccess/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,12 @@ def download(

try:
return earthaccess.__store__.get(
granules, local_path, provider, threads, access=access, pqdm_kwargs=pqdm_kwargs
granules,
local_path,
provider,
threads,
access=access,
pqdm_kwargs=pqdm_kwargs,
)
except AttributeError as err:
logger.error(
Expand Down
4 changes: 3 additions & 1 deletion earthaccess/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,9 @@ def get_s3_credentials(
)
if not auth_url:
# Display possible typos in a helpfull error
raise Exception(f'auth_url not found using daac: "{daac}" and provider: "{provider}"')
raise Exception(
f'auth_url not found using daac: "{daac}" and provider: "{provider}"'
)
else:
auth_url = endpoint
if auth_url.startswith("https://"):
Expand Down
5 changes: 2 additions & 3 deletions earthaccess/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,15 +303,14 @@ def _derive_s3_link(self, links: List[str]) -> List[str]:
s3_links.append(f's3://{links[0].split("nasa.gov/")[1]}')
return s3_links

def data_links(
self, access: Optional[str] = None) -> List[str]:
def data_links(self, access: Optional[str] = None) -> List[str]:
"""Placeholder.
Returns the data links from a granule.
Parameters:
access: direct or external.
Direct means in-region access for cloud-hosted collections.
Direct means in-region access for cloud-hosted collections.
Returns:
The data links for the requested access type.
Expand Down
36 changes: 21 additions & 15 deletions earthaccess/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import logging
import shutil
import traceback
import warnings
from functools import lru_cache
from itertools import chain
from pathlib import Path
from pickle import dumps, loads
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
from uuid import uuid4
import warnings

import fsspec
import requests
Expand Down Expand Up @@ -81,24 +81,26 @@ def multi_thread_open(data: tuple[str, Optional[DataGranule]]) -> EarthAccessFil


def make_instance(
cls: Any, granule: DataGranule, auth: Auth, data: Any, out_of_region_handling: Optional[str] = "raise"
cls: Any,
granule: DataGranule,
auth: Auth,
data: Any,
out_of_region_handling: Optional[str] = "raise",
) -> EarthAccessFile:
"""
Creates an EarthAccessFile instance
"""Creates an EarthAccessFile instance
Parameters:
cls: the datatype of a file system, such as s3fs.S3File
granule: a granule search result
auth: earthaccess.auth.Auth object
data: dumped buffered file data
out_of_region_handling: "raise" to raise an Exception if attempting out of region access or
"handle" (or anything else) to attempt using a granule's first
"handle" (or anything else) to attempt using a granule's first
data link upon faliure
Return: An EarthAccessFile object
"""

# Attempt to re-authenticate
if not earthaccess.__auth__.authenticated:
earthaccess.__auth__ = auth
Expand All @@ -121,7 +123,7 @@ def make_instance(
"This may be caused by trying to access the data outside the us-west-2 region.\n"
"Attempting on_prem access..."
)

# NOTE: This uses the first data_link listed in the granule. That's not
# guaranteed to be the right one.
return EarthAccessFile(earthaccess.open([granule])[0], granule)
Expand Down Expand Up @@ -187,7 +189,6 @@ def _own_s3_credentials(self, links: List[Dict[str, Any]]) -> Union[str, None]:
return link["URL"]
return None


def set_requests_session(
self, url: str, method: str = "get", bearer_token: bool = False
) -> None:
Expand Down Expand Up @@ -558,7 +559,9 @@ def get(
**(pqdm_kwargs or {}),
}

return self._get(granules, Path(local_path), provider, access=access, pqdm_kwargs=pqdm_kwargs)
return self._get(
granules, Path(local_path), provider, access=access, pqdm_kwargs=pqdm_kwargs
)

@singledispatchmethod
def _get(
Expand Down Expand Up @@ -642,7 +645,9 @@ def _get_urls(
)

# if we are not in AWS
return self._download_onprem_granules(data_links, local_path, pqdm_kwargs=pqdm_kwargs)
return self._download_onprem_granules(
data_links, local_path, pqdm_kwargs=pqdm_kwargs
)

@_get.register
def _get_granules(
Expand All @@ -664,8 +669,7 @@ def _get_granules(
access = "direct" if cloud_hosted else "external"
data_links = list(
chain.from_iterable(
granule.data_links(access=access)
for granule in granules
granule.data_links(access=access) for granule in granules
)
)
total_size = round(sum(granule.size() for granule in granules) / 1024, 2)
Expand Down Expand Up @@ -709,7 +713,9 @@ def _get_granules(

# if the data are cloud-based, but we are not in AWS,
# it will be downloaded as if it was on prem
return self._download_onprem_granules(data_links, local_path, pqdm_kwargs=pqdm_kwargs)
return self._download_onprem_granules(
data_links, local_path, pqdm_kwargs=pqdm_kwargs
)

def _download_file(self, url: str, directory: Path) -> str:
"""Download a single file from an on-prem location, a DAAC data center.
Expand Down Expand Up @@ -765,7 +771,7 @@ def _download_onprem_granules(
Returns:
A list of local filepaths to which the files were downloaded.
"""
"""
if urls is None:
raise ValueError("The granules didn't provide a valid GET DATA link")
if self.auth is None:
Expand Down
4 changes: 3 additions & 1 deletion tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def mock_netrc(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch):
username = os.environ["EARTHDATA_USERNAME"]
password = os.environ["EARTHDATA_PASSWORD"]
else:
raise Exception("Unable to mock a .netrc without EARTHDATA_USERNAME and EARTHDATA_PASSWORD environment variables")
raise Exception(
"Unable to mock a .netrc without EARTHDATA_USERNAME and EARTHDATA_PASSWORD environment variables"
)

netrc.write_text(
f"machine urs.earthdata.nasa.gov login {username} password {password}\n"
Expand Down
19 changes: 12 additions & 7 deletions tests/integration/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import earthaccess
import pytest
import s3fs

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -98,10 +97,14 @@ def test_download_immediate_failure(tmp_path: Path, access):
# any further downloads.
if not access or access == "direct":
with patch("s3fs.S3FileSystem", fail_to_download_file):
earthaccess.download(results, tmp_path, access=access, pqdm_kwargs=dict(disable=True))
earthaccess.download(
results, tmp_path, access=access, pqdm_kwargs=dict(disable=True)
)
elif access == "external":
with patch("earthaccess.__store__._download_file", fail_to_download_file):
earthaccess.download(results, tmp_path, access=access, pqdm_kwargs=dict(disable=True))
earthaccess.download(
results, tmp_path, access=access, pqdm_kwargs=dict(disable=True)
)


@pytest.mark.parametrize("access", [None, "direct", "external"])
Expand All @@ -117,17 +120,17 @@ def test_download_deferred_failure(tmp_path: Path, access):
with pytest.raises(Exception) as exc_info:
# With "deferred" exceptions, pqdm catches all exceptions, then at the end
# raises a single generic Exception, passing the sequence of caught exceptions
# as arguments to the Exception constructor.
# as arguments to the Exception constructor.
if not access or access == "direct":
with patch("s3fs.S3FileSystem", fail_to_download_file):
with patch("s3fs.S3FileSystem", fail_to_download_file):
earthaccess.download(
results,
tmp_path,
access=access,
pqdm_kwargs=dict(exception_behaviour="deferred", disable=True),
)
elif access == "external":
with patch("earthaccess.__store__._download_file", fail_to_download_file):
with patch("earthaccess.__store__._download_file", fail_to_download_file):
earthaccess.download(
results,
tmp_path,
Expand All @@ -137,7 +140,9 @@ def test_download_deferred_failure(tmp_path: Path, access):

errors = exc_info.value.args
assert len(errors) == count
assert all(isinstance(e, IOError) and str(e) == "Download failed" for e in errors)
assert all(
isinstance(e, IOError) and str(e) == "Download failed" for e in errors
)


def test_auth_environ():
Expand Down

0 comments on commit 9f34d5a

Please sign in to comment.