Handle S3 credential expiration more gracefully
jrbourbeau committed Nov 14, 2023
1 parent 7db2e59 commit 64c0976
Showing 1 changed file with 41 additions and 33 deletions.
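The fix replaces a single lru_cache-decorated get_s3fs_session (which could keep serving an s3fs session built from expired credentials) with an explicit per-location cache: temporary S3 credentials are stored keyed by (concept_id, daac, provider, endpoint) together with the time they were fetched, and are discarded once they are older than 55 minutes, since NASA's temporary S3 credentials expire after roughly an hour. A minimal standalone sketch of the same pattern, where fetch_credentials is a hypothetical stand-in for auth.get_s3_credentials:

    import datetime
    from typing import Dict, Tuple

    # Hypothetical stand-in for Auth.get_s3_credentials: returns short-lived AWS keys.
    def fetch_credentials(location: Tuple) -> Dict[str, str]:
        return {"accessKeyId": "...", "secretAccessKey": "...", "sessionToken": "..."}

    _cache: Dict[Tuple, Tuple[datetime.datetime, Dict[str, str]]] = {}

    def get_credentials(location: Tuple, max_age_minutes: float = 55) -> Dict[str, str]:
        """Return cached credentials for `location`, re-fetching them near expiry."""
        try:
            fetched_at, creds = _cache[location]
        except KeyError:
            pass  # nothing cached yet
        else:
            age = datetime.datetime.now() - fetched_at
            if age.total_seconds() / 60 <= max_age_minutes:
                return creds  # still fresh: reuse
            _cache.pop(location)  # stale: drop and fall through to re-fetch
        creds = fetch_credentials(location)
        _cache[location] = (datetime.datetime.now(), creds)
        return creds

(The sketch uses total_seconds(); the committed code compares delta.seconds, the seconds component of the timedelta, which is equivalent for deltas under a day.)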
74 changes: 41 additions & 33 deletions earthaccess/store.py
@@ -2,7 +2,6 @@
 import os
 import shutil
 import traceback
-from copy import deepcopy
 from functools import lru_cache
 from itertools import chain
 from pathlib import Path
@@ -97,8 +96,9 @@ def __init__(self, auth: Any, pre_authorize: bool = False) -> None:
         """
         if auth.authenticated is True:
             self.auth = auth
-            self.s3_fs = None
-            self.initial_ts = datetime.datetime.now()
+            self._s3_credentials: Dict[
+                tuple, tuple[datetime.datetime, Dict[str, str]]
+            ] = {}
             oauth_profile = "https://urs.earthdata.nasa.gov/profile"
             # sets the initial URS cookie
             self._requests_cookies: Dict[str, Any] = {}
@@ -182,7 +182,6 @@ def set_requests_session(
         elif resp.status_code >= 500:
             resp.raise_for_status()
 
-    @lru_cache
     def get_s3fs_session(
         self,
         daac: Optional[str] = None,
@@ -200,40 +199,49 @@ def get_s3fs_session(
         Returns:
             a s3fs file instance
         """
-        if self.auth is not None:
-            if not any([concept_id, daac, provider, endpoint]):
-                raise ValueError(
-                    "At least one of the concept_id, daac, provider or endpoint"
-                    "parameters must be specified. "
-                )
+        if self.auth is None:
+            raise ValueError(
+                "A valid Earthdata login instance is required to retrieve S3 credentials"
+            )
+        if not any([concept_id, daac, provider, endpoint]):
+            raise ValueError(
+                "At least one of the concept_id, daac, provider or endpoint"
+                "parameters must be specified. "
+            )
+
+        # Get existing S3 credentials if we already have them
+        location = (concept_id, daac, provider, endpoint)
+        need_new_creds = False
+        try:
+            dt_init, creds = self._s3_credentials[location]
+        except KeyError:
+            need_new_creds = True
+        else:
+            delta = datetime.datetime.now() - dt_init
+            if round(delta.seconds / 60, 2) > 55:
+                need_new_creds = True
+                self._s3_credentials.pop(location)
+
+        if need_new_creds:
+            # Don't have existing valid S3 credentials, so get new ones
+            now = datetime.datetime.now()
             if endpoint is not None:
-                s3_credentials = self.auth.get_s3_credentials(endpoint=endpoint)
+                creds = self.auth.get_s3_credentials(endpoint=endpoint)
             elif concept_id is not None:
                 provider = self._derive_concept_provider(concept_id)
-                s3_credentials = self.auth.get_s3_credentials(provider=provider)
+                creds = self.auth.get_s3_credentials(provider=provider)
             elif daac is not None:
-                s3_credentials = self.auth.get_s3_credentials(daac=daac)
+                creds = self.auth.get_s3_credentials(daac=daac)
             elif provider is not None:
-                s3_credentials = self.auth.get_s3_credentials(provider=provider)
-            now = datetime.datetime.now()
-            delta_minutes = now - self.initial_ts
-            # TODO: test this mocking the time or use https://github.com/dbader/schedule
-            # if we exceed 1 hour
-            if (
-                self.s3_fs is None or round(delta_minutes.seconds / 60, 2) > 59
-            ) and s3_credentials is not None:
-                self.s3_fs = s3fs.S3FileSystem(
-                    key=s3_credentials["accessKeyId"],
-                    secret=s3_credentials["secretAccessKey"],
-                    token=s3_credentials["sessionToken"],
-                )
-                self.initial_ts = datetime.datetime.now()
-            return deepcopy(self.s3_fs)
-        else:
-            print(
-                "A valid Earthdata login instance is required to retrieve S3 credentials"
-            )
-            return None
+                creds = self.auth.get_s3_credentials(provider=provider)
+            # Include new credentials in the cache
+            self._s3_credentials[location] = now, creds
+
+        return s3fs.S3FileSystem(
+            key=creds["accessKeyId"],
+            secret=creds["secretAccessKey"],
+            token=creds["sessionToken"],
+        )
 
     @lru_cache
     def get_fsspec_session(self) -> fsspec.AbstractFileSystem:
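For context, a usage sketch (assuming a valid Earthdata login; PODAAC is just an example DAAC): repeated calls for the same location within the 55-minute window reuse the cached credentials, and every call now returns a fresh s3fs.S3FileSystem, which is why the deepcopy import could be dropped.

    import earthaccess
    from earthaccess.store import Store

    auth = earthaccess.login()  # authenticate against Earthdata Login
    store = Store(auth)

    # First call fetches temporary S3 credentials from the DAAC and caches them.
    fs1 = store.get_s3fs_session(daac="PODAAC")

    # A second call within ~55 minutes reuses the cached credentials but still
    # constructs a new S3FileSystem instead of deep-copying a shared one.
    fs2 = store.get_s3fs_session(daac="PODAAC")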