diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 84281d1f..48e4dade 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -108,20 +108,32 @@ def __init__( RefreshStrategy.BACKGROUND ("BACKGROUND"). Default: RefreshStrategy.BACKGROUND """ + # if refresh_strategy is str, convert to RefreshStrategy enum + if isinstance(refresh_strategy, str): + refresh_strategy = RefreshStrategy._from_str(refresh_strategy) + self._refresh_strategy = refresh_strategy # if event loop is given, use for background tasks if loop: self._loop: asyncio.AbstractEventLoop = loop self._thread: Optional[Thread] = None - self._keys: asyncio.Future = loop.create_task(generate_keys()) + # if lazy refresh is specified we should lazy init keys + if self._refresh_strategy == RefreshStrategy.LAZY: + self._keys: Optional[asyncio.Future] = None + else: + self._keys = loop.create_task(generate_keys()) # if no event loop is given, spin up new loop in background thread else: self._loop = asyncio.new_event_loop() self._thread = Thread(target=self._loop.run_forever, daemon=True) self._thread.start() - self._keys = asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(generate_keys(), self._loop), - loop=self._loop, - ) + # if lazy refresh is specified we should lazy init keys + if self._refresh_strategy == RefreshStrategy.LAZY: + self._keys = None + else: + self._keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), self._loop), + loop=self._loop, + ) self._cache: Dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {} self._client: Optional[CloudSQLClient] = None @@ -148,10 +160,6 @@ def __init__( if isinstance(ip_type, str): ip_type = IPTypes._from_str(ip_type) self._ip_type = ip_type - # if refresh_strategy is str, convert to RefreshStrategy enum - if isinstance(refresh_strategy, str): - refresh_strategy = RefreshStrategy._from_str(refresh_strategy) - self._refresh_strategy = refresh_strategy self._universe_domain = universe_domain # construct service endpoint for Cloud SQL Admin API calls if not sqladmin_api_endpoint: @@ -258,6 +266,8 @@ async def connect_async( DnsNameResolutionError: Could not resolve PSC IP address from DNS host name. """ + if self._keys is None: + self._keys = asyncio.create_task(generate_keys()) if self._client is None: # lazy init client as it has to be initialized in async context self._client = CloudSQLClient( diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 3621d89e..0ff3fd63 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -114,6 +114,12 @@ def test_Connector_Init(fake_credentials: Credentials) -> None: connector.close() +def test_Connector_Init_with_lazy_refresh(fake_credentials: Credentials) -> None: + """Test that Connector with lazy refresh sets keys to None.""" + with Connector(credentials=fake_credentials, refresh_strategy="lazy") as connector: + assert connector._keys is None + + def test_Connector_Init_with_credentials(fake_credentials: Credentials) -> None: """Test that Connector uses custom credentials when given them.""" with patch(