diff --git a/requests_oauthlib/oauth2_session.py b/requests_oauthlib/oauth2_session.py index bfec8e8d..d49e6adf 100644 --- a/requests_oauthlib/oauth2_session.py +++ b/requests_oauthlib/oauth2_session.py @@ -418,6 +418,18 @@ def refresh_token( "Adding auto refresh key word arguments %s.", self.auto_refresh_kwargs ) kwargs.update(self.auto_refresh_kwargs) + + auth = auth or kwargs.pop("auth", None) + client_id = kwargs.get("client_id") + client_secret = kwargs.get("client_secret", "") + + if client_id and (auth is None): + log.debug( + 'Encoding client_id "%s" with client_secret as Basic auth credentials.', + client_id, + ) + auth = requests.auth.HTTPBasicAuth(client_id, client_secret) + body = self._client.prepare_refresh_body( body=body, refresh_token=refresh_token, scope=self.scope, **kwargs ) @@ -491,16 +503,11 @@ def request( self.auto_refresh_url, ) - # We mustn't pass auth twice. - auth = kwargs.pop("auth", None) - if client_id and client_secret and (auth is None): - log.debug( - 'Encoding client_id "%s" with client_secret as Basic auth credentials.', - client_id, - ) - auth = requests.auth.HTTPBasicAuth(client_id, client_secret) token = self.refresh_token( - self.auto_refresh_url, auth=auth, **kwargs + self.auto_refresh_url, + client_id=client_id, + client_secret=client_secret, + **kwargs ) if self.token_updater: log.debug( diff --git a/tests/test_oauth2_session.py b/tests/test_oauth2_session.py index cfc62368..1fca7cc5 100644 --- a/tests/test_oauth2_session.py +++ b/tests/test_oauth2_session.py @@ -194,6 +194,21 @@ def fake_refresh_with_auth(r, **kwargs): client_secret=self.client_secret, ) + # auto refresh with auth from auto_refresh_kwargs + for client in self.clients: + sess = OAuth2Session( + client=client, + token=self.expired_token, + auto_refresh_url="https://i.b/refresh", + token_updater=token_updater, + auto_refresh_kwargs={ + "client_id": self.client_id, + "client_secret": self.client_secret, + }, + ) + sess.send = fake_refresh_with_auth + sess.get("https://i.b") + @mock.patch("time.time", new=lambda: fake_time) def test_token_from_fragment(self): mobile = MobileApplicationClient(self.client_id)