Commit 672dc4e

feat: invalidate cache on bad connection info and IP lookup (#1118)
The Connector caches connection info for future connections and
schedules background refresh operations. For unrecoverable errors or
state, however, we should invalidate the cache so that no further
doomed refreshes are attempted.

The cache is now invalidated on all failed calls to the Cloud SQL Admin
APIs, as well as on a failed IP lookup (the preferred IP type does not
exist on the instance).

Added a `._remove_cached` method to the Connector to facilitate
invalidating the cache.
jackwotherspoon authored Jun 25, 2024
1 parent fe2ccb8 commit 672dc4e
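
To make the commit's intended behavior concrete, here is a rough usage sketch (not part of this commit; the instance name is made up, `pg8000` is just an example driver, and `_cache` is an internal attribute inspected purely for illustration, mirroring how the unit tests below observe it):

```python
import asyncio

from google.cloud.sql.connector import Connector


async def main() -> None:
    async with Connector(loop=asyncio.get_running_loop()) as connector:
        conn_name = "bad-project:bad-region:bad-inst"  # hypothetical bad instance
        try:
            # the Cloud SQL Admin API lookup fails for a nonexistent instance...
            await connector.connect_async(conn_name, "pg8000")
        except Exception:
            pass  # the error still propagates to callers; swallowed here only for the demo
        # ...and with this change the Connector drops its cached entry, so no
        # background refresh keeps burning cycles on an unconnectable instance
        assert conn_name not in connector._cache


asyncio.run(main())
```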
Showing 3 changed files with 101 additions and 29 deletions.
2 changes: 1 addition & 1 deletion google/cloud/sql/connector/connection_info.py
```diff
@@ -101,5 +101,5 @@ def get_preferred_ip(self, ip_type: IPTypes) -> str:
             return self.ip_addrs[ip_type.value]
         raise CloudSQLIPTypeError(
             "Cloud SQL instance does not have any IP addresses matching "
-            f"preference: {ip_type.value})"
+            f"preference: {ip_type.value}"
         )
```
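
For context, the failure path this one-character fix touches can be mimicked standalone (a sketch that mirrors, rather than imports, `ConnectionInfo.get_preferred_ip`; the `ip_addrs` mapping is invented):

```python
from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError

ip_addrs = {"PRIMARY": "127.0.0.1"}  # instance exposing only a public IP


def get_preferred_ip(ip_type_value: str) -> str:
    """Mimics ConnectionInfo.get_preferred_ip's failure path."""
    if ip_type_value in ip_addrs:
        return ip_addrs[ip_type_value]
    raise CloudSQLIPTypeError(
        "Cloud SQL instance does not have any IP addresses matching "
        f"preference: {ip_type_value}"  # fixed: no stray trailing ')'
    )


# get_preferred_ip("PRIVATE") now raises:
#   CloudSQLIPTypeError: Cloud SQL instance does not have any IP addresses
#   matching preference: PRIVATE
```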
72 changes: 44 additions & 28 deletions google/cloud/sql/connector/connector.py
```diff
@@ -340,41 +340,46 @@ async def connect_async(
         kwargs.pop("ssl", None)
         kwargs.pop("port", None)
 
-        # attempt to make connection to Cloud SQL instance
+        # attempt to get connection info for Cloud SQL instance
         try:
             conn_info = await cache.connect_info()
             # validate driver matches intended database engine
             DriverMapping.validate_engine(driver, conn_info.database_version)
             ip_address = conn_info.get_preferred_ip(ip_type)
-            # resolve DNS name into IP address for PSC
-            if ip_type.value == "PSC":
-                addr_info = await self._loop.getaddrinfo(
-                    ip_address, None, family=socket.AF_INET, type=socket.SOCK_STREAM
-                )
-                # getaddrinfo returns a list of 5-tuples that contain socket
-                # connection info in the form
-                # (family, type, proto, canonname, sockaddr), where sockaddr is a
-                # 2-tuple in the form (ip_address, port)
-                try:
-                    ip_address = addr_info[0][4][0]
-                except IndexError as e:
-                    raise DnsNameResolutionError(
-                        f"['{instance_connection_string}']: DNS name could not be resolved into IP address"
-                    ) from e
-            logger.debug(
-                f"['{instance_connection_string}']: Connecting to {ip_address}:3307"
-            )
-            # format `user` param for automatic IAM database authn
-            if enable_iam_auth:
-                formatted_user = format_database_user(
-                    conn_info.database_version, kwargs["user"]
-                )
-                if formatted_user != kwargs["user"]:
-                    logger.debug(
-                        f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}"
-                    )
-                kwargs["user"] = formatted_user
+        except Exception:
+            # with an error from Cloud SQL Admin API call or IP type, invalidate
+            # the cache and re-raise the error
+            await self._remove_cached(instance_connection_string)
+            raise
+        # resolve DNS name into IP address for PSC
+        if ip_type.value == "PSC":
+            addr_info = await self._loop.getaddrinfo(
+                ip_address, None, family=socket.AF_INET, type=socket.SOCK_STREAM
+            )
+            # getaddrinfo returns a list of 5-tuples that contain socket
+            # connection info in the form
+            # (family, type, proto, canonname, sockaddr), where sockaddr is a
+            # 2-tuple in the form (ip_address, port)
+            try:
+                ip_address = addr_info[0][4][0]
+            except IndexError as e:
+                raise DnsNameResolutionError(
+                    f"['{instance_connection_string}']: DNS name could not be resolved into IP address"
+                ) from e
+        logger.debug(
+            f"['{instance_connection_string}']: Connecting to {ip_address}:3307"
+        )
+        # format `user` param for automatic IAM database authn
+        if enable_iam_auth:
+            formatted_user = format_database_user(
+                conn_info.database_version, kwargs["user"]
+            )
+            if formatted_user != kwargs["user"]:
+                logger.debug(
+                    f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}"
+                )
+            kwargs["user"] = formatted_user
 
+        try:
             # async drivers are unblocking and can be awaited directly
             if driver in ASYNC_DRIVERS:
                 return await connector(
@@ -396,6 +401,17 @@ async def connect_async(
             await cache.force_refresh()
             raise
 
+    async def _remove_cached(self, instance_connection_string: str) -> None:
+        """Stops all background refreshes and deletes the connection
+        info cache from the map of caches.
+        """
+        logger.debug(
+            f"['{instance_connection_string}']: Removing connection info from cache"
+        )
+        # remove cache from stored caches and close it
+        cache = self._cache.pop(instance_connection_string)
+        await cache.close()
+
     def __enter__(self) -> Any:
         """Enter context manager by returning Connector object"""
         return self
```
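
The `addr_info[0][4][0]` indexing above relies on `getaddrinfo`'s return shape. A standalone demonstration with the synchronous `socket` API, which returns the same structure as the event loop's `getaddrinfo` (the hostname is just an example):

```python
import socket

# getaddrinfo returns a list of 5-tuples:
#   (family, type, proto, canonname, sockaddr)
# where, for AF_INET, sockaddr is a 2-tuple: (ip_address, port)
addr_info = socket.getaddrinfo(
    "example.com", None, family=socket.AF_INET, type=socket.SOCK_STREAM
)
first_record = addr_info[0]   # first resolved record
sockaddr = first_record[4]    # e.g. ('93.184.215.14', 0) -- the IP varies
ip_address = sockaddr[0]      # the piece connect_async keeps
print(ip_address)

# Resolution failures normally raise socket.gaierror rather than returning
# an empty list; the Connector's IndexError guard converts the rare
# empty-result case into a DnsNameResolutionError instead.
```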
56 changes: 56 additions & 0 deletions tests/unit/test_connector.py
```diff
@@ -17,6 +17,7 @@
 import asyncio
 from typing import Union
 
+from aiohttp import ClientResponseError
 from google.auth.credentials import Credentials
 from mock import patch
 import pytest  # noqa F401 Needed to run the tests
@@ -25,6 +26,7 @@
 from google.cloud.sql.connector import create_async_connector
 from google.cloud.sql.connector import IPTypes
 from google.cloud.sql.connector.client import CloudSQLClient
+from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError
 from google.cloud.sql.connector.exceptions import ConnectorLoopError
 from google.cloud.sql.connector.exceptions import IncompatibleDriverError
 from google.cloud.sql.connector.instance import RefreshAheadCache
@@ -305,6 +307,60 @@ def test_Connector_close_called_multiple_times(fake_credentials: Credentials) ->
     connector.close()
 
 
+async def test_Connector_remove_cached_bad_instance(
+    fake_credentials: Credentials, fake_client: CloudSQLClient
+) -> None:
+    """When a Connector attempts to retrieve connection info for a
+    non-existent instance, it should delete the instance from
+    the cache and ensure no background refresh happens (which would be
+    wasted cycles).
+    """
+    async with Connector(
+        credentials=fake_credentials, loop=asyncio.get_running_loop()
+    ) as connector:
+        conn_name = "bad-project:bad-region:bad-inst"
+        # populate cache
+        cache = RefreshAheadCache(conn_name, fake_client, connector._keys)
+        connector._cache[conn_name] = cache
+        # aiohttp client should throw a 404 ClientResponseError
+        with pytest.raises(ClientResponseError):
+            await connector.connect_async(
+                conn_name,
+                "pg8000",
+            )
+        # check that cache has been removed from dict
+        assert conn_name not in connector._cache
+
+
+async def test_Connector_remove_cached_no_ip_type(
+    fake_credentials: Credentials, fake_client: CloudSQLClient
+) -> None:
+    """When a Connector attempts to connect and preferred IP type is not present,
+    it should delete the instance from the cache and ensure no background refresh
+    happens (which would be wasted cycles).
+    """
+    # set instance to only have public IP
+    fake_client.instance.ip_addrs = {"PRIMARY": "127.0.0.1"}
+    async with Connector(
+        credentials=fake_credentials, loop=asyncio.get_running_loop()
+    ) as connector:
+        conn_name = "test-project:test-region:test-instance"
+        # populate cache
+        cache = RefreshAheadCache(conn_name, fake_client, connector._keys)
+        connector._cache[conn_name] = cache
+        # test instance does not have Private IP, thus should invalidate cache
+        with pytest.raises(CloudSQLIPTypeError):
+            await connector.connect_async(
+                conn_name,
+                "pg8000",
+                user="my-user",
+                password="my-pass",
+                ip_type="private",
+            )
+        # check that cache has been removed from dict
+        assert conn_name not in connector._cache
+
+
 def test_default_universe_domain(fake_credentials: Credentials) -> None:
     """Test that default universe domain and constructed service endpoint are
     formatted correctly.
```
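
Both tests seed `connector._cache` by hand with a `RefreshAheadCache` and then assert the entry is gone, so they directly observe `_remove_cached` doing its job. The pop-then-close pattern that method implements generalizes to any refresh-ahead cache; a minimal self-contained sketch (toy names, not the library's classes):

```python
import asyncio
from typing import Dict


class RefreshAheadEntry:
    """Toy stand-in for a cache entry that refreshes itself in the background."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._task = asyncio.create_task(self._refresh_forever())

    async def _refresh_forever(self) -> None:
        while True:
            await asyncio.sleep(1)  # pretend to refresh certs/metadata

    async def close(self) -> None:
        self._task.cancel()  # stop the background refresh for good
        try:
            await self._task
        except asyncio.CancelledError:
            pass


async def remove_cached(caches: Dict[str, RefreshAheadEntry], name: str) -> None:
    # pop first so no new connection attempt can grab the dying entry,
    # then close it to cancel its background refresh task
    entry = caches.pop(name)
    await entry.close()


async def main() -> None:
    caches = {"p:r:i": RefreshAheadEntry("p:r:i")}
    await remove_cached(caches, "p:r:i")
    assert "p:r:i" not in caches


asyncio.run(main())
```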
