diff --git a/earthaccess/api.py b/earthaccess/api.py index 4ef8598d..843aba65 100644 --- a/earthaccess/api.py +++ b/earthaccess/api.py @@ -12,6 +12,21 @@ from .utils import _validation as validate +def _normalize_location(location: Union[str, None]) -> Union[str, None]: + """Handle user-provided `daac` and `provider` values + + These values must have a capital letter as the first character + followed by capital letters, numbers, or an underscore. Here we + uppercase all strings to handle the case when users provide + lowercase values (e.g. "pocloud" instead of "POCLOUD"). + + https://wiki.earthdata.nasa.gov/display/ED/CMR+Data+Partner+User+Guide?src=contextnavpagetreemode + """ + if location is not None: + location = location.upper() + return location + + def search_datasets( count: int = -1, **kwargs: Any ) -> List[earthaccess.results.DataCollection]: @@ -170,6 +185,7 @@ def download( Returns: List of downloaded files """ + provider = _normalize_location(provider) if isinstance(granules, DataGranule): granules = [granules] try: @@ -194,6 +210,7 @@ def open( Returns: a list of s3fs "file pointers" to s3 files. """ + provider = _normalize_location(provider) results = earthaccess.__store__.open(granules=granules, provider=provider) return results @@ -215,10 +232,8 @@ def get_s3_credentials( Returns: a dictionary with S3 credentials for the DAAC or provider """ - if daac is not None: - daac = daac.upper() - if provider is not None: - provider = provider.upper() + daac = _normalize_location(daac) + provider = _normalize_location(provider) if results is not None: endpoint = results[0].get_s3_credentials_endpoint() return earthaccess.__auth__.get_s3_credentials(endpoint=endpoint) @@ -315,6 +330,8 @@ def get_s3fs_session( Returns: class s3fs.S3FileSystem: an authenticated s3fs session valid for 1 hour """ + daac = _normalize_location(daac) + provider = _normalize_location(provider) if results is not None: endpoint = results[0].get_s3_credentials_endpoint() if endpoint is not None: diff --git a/tests/integration/test_auth.py b/tests/integration/test_auth.py index a61f3513..d1bfae1e 100644 --- a/tests/integration/test_auth.py +++ b/tests/integration/test_auth.py @@ -6,6 +6,7 @@ import earthaccess import pytest +import s3fs logger = logging.getLogger(__name__) assertions = unittest.TestCase("__init__") @@ -94,3 +95,24 @@ def test_auth_can_fetch_s3_credentials(): print( f"An error occured while trying to fetch S3 credentials for {daac['short-name']}: {e}" ) + + +@pytest.mark.parametrize("location", ({"daac": "podaac"}, {"provider": "pocloud"})) +def test_get_s3_credentials_lowercase_location(location): + activate_environment() + earthaccess.login(strategy="environment") + creds = earthaccess.get_s3_credentials(**location) + assert creds + assert all( + creds[key] + for key in ["accessKeyId", "secretAccessKey", "sessionToken", "expiration"] + ) + + +@pytest.mark.parametrize("location", ({"daac": "podaac"}, {"provider": "pocloud"})) +def test_get_s3fs_session_lowercase_location(location): + activate_environment() + earthaccess.login(strategy="environment") + fs = earthaccess.get_s3fs_session(**location) + assert isinstance(fs, s3fs.S3FileSystem) + assert all(fs.storage_options[key] for key in ["key", "secret", "token"])