diff --git a/gcsfs/credentials.py b/gcsfs/credentials.py index bbd9ea57..7cd5ac3b 100644 --- a/gcsfs/credentials.py +++ b/gcsfs/credentials.py @@ -5,6 +5,7 @@ import textwrap import threading import warnings +from datetime import datetime import google.auth as gauth import google.auth.compute_engine @@ -99,7 +100,6 @@ def _connect_cloud(self): raise ValueError("Invalid gcloud credentials") from error def _connect_cache(self): - if len(self.tokens) == 0: raise ValueError("No cached tokens") @@ -167,17 +167,30 @@ def _connect_token(self, token): if self.credentials.valid: self.credentials.apply(self.heads) - def maybe_refresh(self): - # this uses requests and is blocking + def _credentials_valid(self, refresh_buffer): + return ( + self.credentials.valid + # In addition to checking current validity, we ensure that there is + # not a near-future expiry to avoid errors when expiration hits. + and self.credentials.expiry + and (self.credentials.expiry - datetime.utcnow()).total_seconds() + > refresh_buffer + ) + + def maybe_refresh(self, refresh_buffer=300): + """Check and refresh credentials if needed""" if self.credentials is None: return # anon - if self.credentials.valid: - return # still good + + if self._credentials_valid(refresh_buffer): + return # still good, with buffer + with requests.Session() as session: req = Request(session) with self.lock: - if self.credentials.valid: + if self._credentials_valid(refresh_buffer): return # repeat to avoid race (but don't want lock in common case) + logger.debug("GCS refresh") self.credentials.refresh(req) # https://github.com/fsspec/filesystem_spec/issues/565 diff --git a/gcsfs/retry.py b/gcsfs/retry.py index 0b02fbc2..c5062173 100644 --- a/gcsfs/retry.py +++ b/gcsfs/retry.py @@ -71,6 +71,9 @@ def is_retriable(exception): """Returns True if this exception is retriable.""" if isinstance(exception, HttpError): + # Add 401 to retriable errors when it's an auth expiration issue + if exception.code == 401 and "Invalid Credentials" in str(exception.message): + return True return exception.code in errs return isinstance(exception, RETRIABLE_EXCEPTIONS)