From 64c0976f93557109c6c346babc7ad9676f6e6cf4 Mon Sep 17 00:00:00 2001
From: James Bourbeau
Date: Tue, 14 Nov 2023 13:38:16 -0600
Subject: [PATCH] Handle S3 credential expiration more gracefully

---
 earthaccess/store.py | 74 ++++++++++++++++++++++++--------------------
 1 file changed, 41 insertions(+), 33 deletions(-)

diff --git a/earthaccess/store.py b/earthaccess/store.py
index cfe7bc79..1e089aaa 100644
--- a/earthaccess/store.py
+++ b/earthaccess/store.py
@@ -2,7 +2,6 @@
 import os
 import shutil
 import traceback
-from copy import deepcopy
 from functools import lru_cache
 from itertools import chain
 from pathlib import Path
@@ -97,8 +96,9 @@ def __init__(self, auth: Any, pre_authorize: bool = False) -> None:
         """
         if auth.authenticated is True:
             self.auth = auth
-            self.s3_fs = None
-            self.initial_ts = datetime.datetime.now()
+            self._s3_credentials: Dict[
+                tuple, tuple[datetime.datetime, Dict[str, str]]
+            ] = {}
             oauth_profile = "https://urs.earthdata.nasa.gov/profile"
             # sets the initial URS cookie
             self._requests_cookies: Dict[str, Any] = {}
@@ -182,7 +182,6 @@ def set_requests_session(
         elif resp.status_code >= 500:
             resp.raise_for_status()
 
-    @lru_cache
     def get_s3fs_session(
         self,
         daac: Optional[str] = None,
@@ -200,40 +199,49 @@ def get_s3fs_session(
         Returns:
             a s3fs file instance
         """
-        if self.auth is not None:
-            if not any([concept_id, daac, provider, endpoint]):
-                raise ValueError(
-                    "At least one of the concept_id, daac, provider or endpoint"
-                    "parameters must be specified. "
-                )
+        if self.auth is None:
+            raise ValueError(
+                "A valid Earthdata login instance is required to retrieve S3 credentials"
+            )
+        if not any([concept_id, daac, provider, endpoint]):
+            raise ValueError(
+                "At least one of the concept_id, daac, provider or endpoint"
+                "parameters must be specified. "
+            )
+
+        # Get existing S3 credentials if we already have them
+        location = (concept_id, daac, provider, endpoint)
+        need_new_creds = False
+        try:
+            dt_init, creds = self._s3_credentials[location]
+        except KeyError:
+            need_new_creds = True
+        else:
+            delta = datetime.datetime.now() - dt_init
+            if round(delta.seconds / 60, 2) > 55:
+                need_new_creds = True
+                self._s3_credentials.pop(location)
+
+        if need_new_creds:
+            # Don't have existing valid S3 credentials, so get new ones
+            now = datetime.datetime.now()
             if endpoint is not None:
-                s3_credentials = self.auth.get_s3_credentials(endpoint=endpoint)
+                creds = self.auth.get_s3_credentials(endpoint=endpoint)
             elif concept_id is not None:
                 provider = self._derive_concept_provider(concept_id)
-                s3_credentials = self.auth.get_s3_credentials(provider=provider)
+                creds = self.auth.get_s3_credentials(provider=provider)
             elif daac is not None:
-                s3_credentials = self.auth.get_s3_credentials(daac=daac)
+                creds = self.auth.get_s3_credentials(daac=daac)
             elif provider is not None:
-                s3_credentials = self.auth.get_s3_credentials(provider=provider)
-            now = datetime.datetime.now()
-            delta_minutes = now - self.initial_ts
-            # TODO: test this mocking the time or use https://github.com/dbader/schedule
-            # if we exceed 1 hour
-            if (
-                self.s3_fs is None or round(delta_minutes.seconds / 60, 2) > 59
-            ) and s3_credentials is not None:
-                self.s3_fs = s3fs.S3FileSystem(
-                    key=s3_credentials["accessKeyId"],
-                    secret=s3_credentials["secretAccessKey"],
-                    token=s3_credentials["sessionToken"],
-                )
-                self.initial_ts = datetime.datetime.now()
-            return deepcopy(self.s3_fs)
-        else:
-            print(
-                "A valid Earthdata login instance is required to retrieve S3 credentials"
-            )
-            return None
+                creds = self.auth.get_s3_credentials(provider=provider)
+            # Include new credentials in the cache
+            self._s3_credentials[location] = now, creds
+
+        return s3fs.S3FileSystem(
+            key=creds["accessKeyId"],
+            secret=creds["secretAccessKey"],
+            token=creds["sessionToken"],
+        )
 
     @lru_cache
     def get_fsspec_session(self) -> fsspec.AbstractFileSystem: