Skip to content

Commit

Permalink
fix: retry 50x errors with exponential backoff (#1125)
Browse files Browse the repository at this point in the history
This commit adds retry behavior to the two SQL Admin API calls.

Any response that results in a 50x error will now be retried up to
5 times with exponential backoff and jitter between each attempt.

The formula used to calculate the duration to wait is:

200ms * 1.618^(attempt + jitter)

This calculation matches what the Cloud SQL Proxy v1 did and
will not trigger any significant change in load on the backend.
  • Loading branch information
jackwotherspoon authored Jul 10, 2024
1 parent 076da83 commit 2da9128
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 6 deletions.
14 changes: 9 additions & 5 deletions google/cloud/sql/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported
from google.cloud.sql.connector.refresh_utils import _downscope_credentials
from google.cloud.sql.connector.refresh_utils import retry_50x
from google.cloud.sql.connector.version import __version__ as version

if TYPE_CHECKING:
Expand Down Expand Up @@ -124,7 +125,10 @@ async def _get_metadata(

url = f"{self._sqladmin_api_endpoint}/sql/{API_VERSION}/projects/{project}/instances/{instance}/connectSettings"

resp = await self._client.get(url, headers=headers, raise_for_status=True)
resp = await self._client.get(url, headers=headers)
if resp.status >= 500:
resp = await retry_50x(self._client.get, url, headers=headers)
resp.raise_for_status()
ret_dict = await resp.json()

if ret_dict["region"] != region:
Expand Down Expand Up @@ -188,10 +192,10 @@ async def _get_ephemeral(
login_creds = _downscope_credentials(self._credentials)
data["access_token"] = login_creds.token

resp = await self._client.post(
url, headers=headers, json=data, raise_for_status=True
)

resp = await self._client.post(url, headers=headers, json=data)
if resp.status >= 500:
resp = await retry_50x(self._client.post, url, headers=headers, json=data)
resp.raise_for_status()
ret_dict = await resp.json()

ephemeral_cert: str = ret_dict["ephemeralCert"]["cert"]
Expand Down
49 changes: 48 additions & 1 deletion google/cloud/sql/connector/refresh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import copy
import datetime
import logging
from typing import List
import random
from typing import Any, Callable, List

import aiohttp
from google.auth.credentials import Credentials
from google.auth.credentials import Scoped
import google.auth.transport.requests
Expand Down Expand Up @@ -105,3 +107,48 @@ def _downscope_credentials(
request = google.auth.transport.requests.Request()
scoped_creds.refresh(request)
return scoped_creds


def _exponential_backoff(attempt: int) -> float:
"""Calculates a duration to backoff in milliseconds based on the attempt i.
The formula is:
base * multi^(attempt + 1 + random)
With base = 200ms and multi = 1.618, and random = [0.0, 1.0),
the backoff values would fall between the following low and high ends:
Attempt Low (ms) High (ms)
0 324 524
1 524 847
2 847 1371
3 1371 2218
4 2218 3588
The theoretical worst case scenario would have a client wait 8.5s in total
for an API request to complete (with the first four attempts failing, and
the fifth succeeding).
"""
base = 200
multi = 1.618
exp = attempt + 1 + random.random()
return base * pow(multi, exp)


async def retry_50x(
request_coro: Callable, *args: Any, **kwargs: Any
) -> aiohttp.ClientResponse:
"""Retry any 50x HTTP response up to X number of times."""
max_retries = 5
for i in range(max_retries):
resp = await request_coro(*args, **kwargs)
# backoff for any 50X errors
if resp.status >= 500 and i < max_retries:
# calculate backoff time
backoff = _exponential_backoff(i)
await asyncio.sleep(backoff / 1000)
else:
break
return resp
51 changes: 51 additions & 0 deletions tests/unit/test_refresh_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
limitations under the License.
"""

from __future__ import annotations

import asyncio
import datetime

Expand All @@ -27,8 +29,10 @@
import pytest # noqa F401 Needed to run the tests

from google.cloud.sql.connector.refresh_utils import _downscope_credentials
from google.cloud.sql.connector.refresh_utils import _exponential_backoff
from google.cloud.sql.connector.refresh_utils import _is_valid
from google.cloud.sql.connector.refresh_utils import _seconds_until_refresh
from google.cloud.sql.connector.refresh_utils import retry_50x


@pytest.fixture
Expand Down Expand Up @@ -148,3 +152,50 @@ def test_seconds_until_refresh_under_4_mins() -> None:
)
== 0
)


@pytest.mark.parametrize(
"attempt, low, high",
[
(0, 324, 524),
(1, 524, 847),
(2, 847, 1371),
(3, 1371, 2218),
(4, 2218, 3588),
],
)
def test_exponential_backoff(attempt: int, low: int, high: int) -> None:
"""
Test _exponential_backoff produces times (in ms) in the proper range.
"""
backoff = _exponential_backoff(attempt)
assert backoff >= low
assert backoff <= high


class RetryClass:
def __init__(self) -> None:
self.attempts = 0

async def fake_request(self, status: int) -> RetryClass:
self.status = status
self.attempts += 1
return self


async def test_retry_50x_with_503() -> None:
fake_client = RetryClass()
resp = await retry_50x(fake_client.fake_request, 503)
assert resp.attempts == 5


async def test_retry_50x_with_200() -> None:
fake_client = RetryClass()
resp = await retry_50x(fake_client.fake_request, 200)
assert resp.attempts == 1


async def test_retry_50x_with_400() -> None:
fake_client = RetryClass()
resp = await retry_50x(fake_client.fake_request, 400)
assert resp.attempts == 1

0 comments on commit 2da9128

Please sign in to comment.