diff --git a/google/cloud/sql/connector/connection_info.py b/google/cloud/sql/connector/connection_info.py index 40ca0848..06a0b976 100644 --- a/google/cloud/sql/connector/connection_info.py +++ b/google/cloud/sql/connector/connection_info.py @@ -101,5 +101,5 @@ def get_preferred_ip(self, ip_type: IPTypes) -> str: return self.ip_addrs[ip_type.value] raise CloudSQLIPTypeError( "Cloud SQL instance does not have any IP addresses matching " - f"preference: {ip_type.value})" + f"preference: {ip_type.value}" ) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 48e4dade..3f82e69a 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -340,41 +340,46 @@ async def connect_async( kwargs.pop("ssl", None) kwargs.pop("port", None) - # attempt to make connection to Cloud SQL instance + # attempt to get connection info for Cloud SQL instance try: conn_info = await cache.connect_info() # validate driver matches intended database engine DriverMapping.validate_engine(driver, conn_info.database_version) ip_address = conn_info.get_preferred_ip(ip_type) - # resolve DNS name into IP address for PSC - if ip_type.value == "PSC": - addr_info = await self._loop.getaddrinfo( - ip_address, None, family=socket.AF_INET, type=socket.SOCK_STREAM - ) - # getaddrinfo returns a list of 5-tuples that contain socket - # connection info in the form - # (family, type, proto, canonname, sockaddr), where sockaddr is a - # 2-tuple in the form (ip_address, port) - try: - ip_address = addr_info[0][4][0] - except IndexError as e: - raise DnsNameResolutionError( - f"['{instance_connection_string}']: DNS name could not be resolved into IP address" - ) from e - logger.debug( - f"['{instance_connection_string}']: Connecting to {ip_address}:3307" + except Exception: + # with an error from Cloud SQL Admin API call or IP type, invalidate + # the cache and re-raise the error + await self._remove_cached(instance_connection_string) + raise + # resolve DNS name into IP address for PSC + if ip_type.value == "PSC": + addr_info = await self._loop.getaddrinfo( + ip_address, None, family=socket.AF_INET, type=socket.SOCK_STREAM + ) + # getaddrinfo returns a list of 5-tuples that contain socket + # connection info in the form + # (family, type, proto, canonname, sockaddr), where sockaddr is a + # 2-tuple in the form (ip_address, port) + try: + ip_address = addr_info[0][4][0] + except IndexError as e: + raise DnsNameResolutionError( + f"['{instance_connection_string}']: DNS name could not be resolved into IP address" + ) from e + logger.debug( + f"['{instance_connection_string}']: Connecting to {ip_address}:3307" + ) + # format `user` param for automatic IAM database authn + if enable_iam_auth: + formatted_user = format_database_user( + conn_info.database_version, kwargs["user"] ) - # format `user` param for automatic IAM database authn - if enable_iam_auth: - formatted_user = format_database_user( - conn_info.database_version, kwargs["user"] + if formatted_user != kwargs["user"]: + logger.debug( + f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}" ) - if formatted_user != kwargs["user"]: - logger.debug( - f"['{instance_connection_string}']: Truncated IAM database username from {kwargs['user']} to {formatted_user}" - ) - kwargs["user"] = formatted_user - + kwargs["user"] = formatted_user + try: # async drivers are unblocking and can be awaited directly if driver in ASYNC_DRIVERS: return await connector( @@ -396,6 +401,17 @@ async def connect_async( await cache.force_refresh() raise + async def _remove_cached(self, instance_connection_string: str) -> None: + """Stops all background refreshes and deletes the connection + info cache from the map of caches. + """ + logger.debug( + f"['{instance_connection_string}']: Removing connection info from cache" + ) + # remove cache from stored caches and close it + cache = self._cache.pop(instance_connection_string) + await cache.close() + def __enter__(self) -> Any: """Enter context manager by returning Connector object""" return self diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 0ff3fd63..7a59b448 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -17,6 +17,7 @@ import asyncio from typing import Union +from aiohttp import ClientResponseError from google.auth.credentials import Credentials from mock import patch import pytest # noqa F401 Needed to run the tests @@ -25,6 +26,7 @@ from google.cloud.sql.connector import create_async_connector from google.cloud.sql.connector import IPTypes from google.cloud.sql.connector.client import CloudSQLClient +from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError from google.cloud.sql.connector.exceptions import ConnectorLoopError from google.cloud.sql.connector.exceptions import IncompatibleDriverError from google.cloud.sql.connector.instance import RefreshAheadCache @@ -305,6 +307,60 @@ def test_Connector_close_called_multiple_times(fake_credentials: Credentials) -> connector.close() +async def test_Connector_remove_cached_bad_instance( + fake_credentials: Credentials, fake_client: CloudSQLClient +) -> None: + """When a Connector attempts to retrieve connection info for a + non-existent instance, it should delete the instance from + the cache and ensure no background refresh happens (which would be + wasted cycles). + """ + async with Connector( + credentials=fake_credentials, loop=asyncio.get_running_loop() + ) as connector: + conn_name = "bad-project:bad-region:bad-inst" + # populate cache + cache = RefreshAheadCache(conn_name, fake_client, connector._keys) + connector._cache[conn_name] = cache + # aiohttp client should throw a 404 ClientResponseError + with pytest.raises(ClientResponseError): + await connector.connect_async( + conn_name, + "pg8000", + ) + # check that cache has been removed from dict + assert conn_name not in connector._cache + + +async def test_Connector_remove_cached_no_ip_type( + fake_credentials: Credentials, fake_client: CloudSQLClient +) -> None: + """When a Connector attempts to connect and preferred IP type is not present, + it should delete the instance from the cache and ensure no background refresh + happens (which would be wasted cycles). + """ + # set instance to only have public IP + fake_client.instance.ip_addrs = {"PRIMARY": "127.0.0.1"} + async with Connector( + credentials=fake_credentials, loop=asyncio.get_running_loop() + ) as connector: + conn_name = "test-project:test-region:test-instance" + # populate cache + cache = RefreshAheadCache(conn_name, fake_client, connector._keys) + connector._cache[conn_name] = cache + # test instance does not have Private IP, thus should invalidate cache + with pytest.raises(CloudSQLIPTypeError): + await connector.connect_async( + conn_name, + "pg8000", + user="my-user", + password="my-pass", + ip_type="private", + ) + # check that cache has been removed from dict + assert conn_name not in connector._cache + + def test_default_universe_domain(fake_credentials: Credentials) -> None: """Test that default universe domain and constructed service endpoint are formatted correctly.