Skip to content

Commit

Permalink
chore: lazy init keys with lazy refresh (#1110)
Browse files Browse the repository at this point in the history
RSA key-pair generation should be done on first connection
attempt when lazy refresh strategy is configured.
  • Loading branch information
jackwotherspoon authored Jun 3, 2024
1 parent deb732e commit 5712e9e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
28 changes: 19 additions & 9 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,20 +108,32 @@ def __init__(
RefreshStrategy.BACKGROUND ("BACKGROUND").
Default: RefreshStrategy.BACKGROUND
"""
# if refresh_strategy is str, convert to RefreshStrategy enum
if isinstance(refresh_strategy, str):
refresh_strategy = RefreshStrategy._from_str(refresh_strategy)
self._refresh_strategy = refresh_strategy
# if event loop is given, use for background tasks
if loop:
self._loop: asyncio.AbstractEventLoop = loop
self._thread: Optional[Thread] = None
self._keys: asyncio.Future = loop.create_task(generate_keys())
# if lazy refresh is specified we should lazy init keys
if self._refresh_strategy == RefreshStrategy.LAZY:
self._keys: Optional[asyncio.Future] = None
else:
self._keys = loop.create_task(generate_keys())
# if no event loop is given, spin up new loop in background thread
else:
self._loop = asyncio.new_event_loop()
self._thread = Thread(target=self._loop.run_forever, daemon=True)
self._thread.start()
self._keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), self._loop),
loop=self._loop,
)
# if lazy refresh is specified we should lazy init keys
if self._refresh_strategy == RefreshStrategy.LAZY:
self._keys = None
else:
self._keys = asyncio.wrap_future(
asyncio.run_coroutine_threadsafe(generate_keys(), self._loop),
loop=self._loop,
)
self._cache: Dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {}
self._client: Optional[CloudSQLClient] = None

Expand All @@ -148,10 +160,6 @@ def __init__(
if isinstance(ip_type, str):
ip_type = IPTypes._from_str(ip_type)
self._ip_type = ip_type
# if refresh_strategy is str, convert to RefreshStrategy enum
if isinstance(refresh_strategy, str):
refresh_strategy = RefreshStrategy._from_str(refresh_strategy)
self._refresh_strategy = refresh_strategy
self._universe_domain = universe_domain
# construct service endpoint for Cloud SQL Admin API calls
if not sqladmin_api_endpoint:
Expand Down Expand Up @@ -258,6 +266,8 @@ async def connect_async(
DnsNameResolutionError: Could not resolve PSC IP address from DNS
host name.
"""
if self._keys is None:
self._keys = asyncio.create_task(generate_keys())
if self._client is None:
# lazy init client as it has to be initialized in async context
self._client = CloudSQLClient(
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,12 @@ def test_Connector_Init(fake_credentials: Credentials) -> None:
connector.close()


def test_Connector_Init_with_lazy_refresh(fake_credentials: Credentials) -> None:
"""Test that Connector with lazy refresh sets keys to None."""
with Connector(credentials=fake_credentials, refresh_strategy="lazy") as connector:
assert connector._keys is None


def test_Connector_Init_with_credentials(fake_credentials: Credentials) -> None:
"""Test that Connector uses custom credentials when given them."""
with patch(
Expand Down

0 comments on commit 5712e9e

Please sign in to comment.