Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle S3 credential expiration more gracefully #354

Merged
merged 6 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 49 additions & 35 deletions earthaccess/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import os
import shutil
import traceback
from copy import deepcopy
from functools import lru_cache
from itertools import chain
from pathlib import Path
from pickle import dumps, loads
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from uuid import uuid4

import earthaccess
Expand Down Expand Up @@ -97,8 +96,9 @@ def __init__(self, auth: Any, pre_authorize: bool = False) -> None:
"""
if auth.authenticated is True:
self.auth = auth
self.s3_fs = None
self.initial_ts = datetime.datetime.now()
self._s3_credentials: Dict[
Tuple, Tuple[datetime.datetime, Dict[str, str]]
] = {}
MattF-NSIDC marked this conversation as resolved.
Show resolved Hide resolved
oauth_profile = "https://urs.earthdata.nasa.gov/profile"
# sets the initial URS cookie
self._requests_cookies: Dict[str, Any] = {}
Expand Down Expand Up @@ -182,7 +182,6 @@ def set_requests_session(
elif resp.status_code >= 500:
resp.raise_for_status()

@lru_cache
def get_s3fs_session(
self,
daac: Optional[str] = None,
Expand All @@ -200,39 +199,54 @@ def get_s3fs_session(
Returns:
a s3fs file instance
"""
if self.auth is not None:
if not any([concept_id, daac, provider, endpoint]):
raise ValueError(
"At least one of the concept_id, daac, provider or endpoint"
"parameters must be specified. "
)
if endpoint is not None:
s3_credentials = self.auth.get_s3_credentials(endpoint=endpoint)
elif concept_id is not None:
provider = self._derive_concept_provider(concept_id)
s3_credentials = self.auth.get_s3_credentials(provider=provider)
elif daac is not None:
s3_credentials = self.auth.get_s3_credentials(daac=daac)
elif provider is not None:
s3_credentials = self.auth.get_s3_credentials(provider=provider)
now = datetime.datetime.now()
delta_minutes = now - self.initial_ts
# TODO: test this mocking the time or use https://github.com/dbader/schedule
# if we exceed 1 hour
if (
self.s3_fs is None or round(delta_minutes.seconds / 60, 2) > 59
) and s3_credentials is not None:
self.s3_fs = s3fs.S3FileSystem(
key=s3_credentials["accessKeyId"],
secret=s3_credentials["secretAccessKey"],
token=s3_credentials["sessionToken"],
)
MattF-NSIDC marked this conversation as resolved.
Show resolved Hide resolved
self.initial_ts = datetime.datetime.now()
return deepcopy(self.s3_fs)
else:
if self.auth is None:
raise ValueError(
"A valid Earthdata login instance is required to retrieve S3 credentials"
)
if not any([concept_id, daac, provider, endpoint]):
raise ValueError(
"At least one of the concept_id, daac, provider or endpoint"
"parameters must be specified. "
)

if concept_id is not None:
provider = self._derive_concept_provider(concept_id)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💯


# Get existing S3 credentials if we already have them
location = (
daac,
provider,
endpoint,
) # Identifier for where to get S3 credentials from
need_new_creds = False
try:
dt_init, creds = self._s3_credentials[location]
MattF-NSIDC marked this conversation as resolved.
Show resolved Hide resolved
except KeyError:
need_new_creds = True
else:
# If cached credentials are expired, invalidate the cache
delta = datetime.datetime.now() - dt_init
jrbourbeau marked this conversation as resolved.
Show resolved Hide resolved
if round(delta.seconds / 60, 2) > 55:
need_new_creds = True
self._s3_credentials.pop(location)

if need_new_creds:
# Don't have existing valid S3 credentials, so get new ones
now = datetime.datetime.now()
if endpoint is not None:
creds = self.auth.get_s3_credentials(endpoint=endpoint)
MattF-NSIDC marked this conversation as resolved.
Show resolved Hide resolved
elif daac is not None:
creds = self.auth.get_s3_credentials(daac=daac)
elif provider is not None:
creds = self.auth.get_s3_credentials(provider=provider)
# Include new credentials in the cache
self._s3_credentials[location] = now, creds

return s3fs.S3FileSystem(
key=creds["accessKeyId"],
secret=creds["secretAccessKey"],
token=creds["sessionToken"],
)

@lru_cache
def get_fsspec_session(self) -> fsspec.AbstractFileSystem:
Expand Down
32 changes: 21 additions & 11 deletions tests/unit/test_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import fsspec
import pytest
import responses
import s3fs
from earthaccess import Auth, Store


Expand Down Expand Up @@ -60,12 +61,22 @@ def test_store_can_create_s3_fsspec_session(self):
"https://api.giovanni.earthdata.nasa.gov/s3credentials",
"https://data.laadsdaac.earthdatacloud.nasa.gov/s3credentials",
]
mock_creds = {
"accessKeyId": "sure",
"secretAccessKey": "correct",
"sessionToken": "whynot",
}
expected_storage_options = {
"key": mock_creds["accessKeyId"],
"secret": mock_creds["secretAccessKey"],
"token": mock_creds["sessionToken"],
}

for endpoint in custom_endpoints:
responses.add(
responses.GET,
endpoint,
json={},
json=mock_creds,
status=200,
)

Expand All @@ -74,40 +85,39 @@ def test_store_can_create_s3_fsspec_session(self):
responses.add(
responses.GET,
daac["s3-credentials"],
json={
"accessKeyId": "sure",
"secretAccessKey": "correct",
"sessionToken": "whynot",
},
json=mock_creds,
status=200,
)
responses.add(
responses.GET,
"https://urs.earthdata.nasa.gov/profile",
json={},
json=mock_creds,
status=200,
)

store = Store(self.auth)
self.assertTrue(isinstance(store.auth, Auth))
for daac in ["NSIDC", "PODAAC", "LPDAAC", "ORNLDAAC", "GES_DISC", "ASF"]:
s3_fs = store.get_s3fs_session(daac=daac)
self.assertEqual(type(s3_fs), type(fsspec.filesystem("s3")))
assert isinstance(s3_fs, s3fs.S3FileSystem)
assert s3_fs.storage_options == expected_storage_options
MattF-NSIDC marked this conversation as resolved.
Show resolved Hide resolved

for endpoint in custom_endpoints:
s3_fs = store.get_s3fs_session(endpoint=endpoint)
self.assertEqual(type(s3_fs), type(fsspec.filesystem("s3")))
assert isinstance(s3_fs, s3fs.S3FileSystem)
assert s3_fs.storage_options == expected_storage_options

for provider in [
"NSIDC_CPRD",
"POCLOUD",
"LPCLOUD",
"ORNLCLOUD",
"ORNL_CLOUD",
MattF-NSIDC marked this conversation as resolved.
Show resolved Hide resolved
"GES_DISC",
"ASF",
]:
s3_fs = store.get_s3fs_session(provider=provider)
assert isinstance(s3_fs, fsspec.AbstractFileSystem)
assert isinstance(s3_fs, s3fs.S3FileSystem)
assert s3_fs.storage_options == expected_storage_options

# Ensure informative error is raised
with pytest.raises(ValueError, match="parameters must be specified"):
Expand Down
Loading