diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index e5af54b2..1c805814 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -141,10 +141,13 @@ async def _get_metadata( if "ipAddresses" in ret_dict else {} ) - # Remove trailing period from PSC DNS name. - psc_dns = ret_dict.get("dnsName") - if psc_dns: - ip_addresses["PSC"] = psc_dns.rstrip(".") + # resolve dnsName into IP address for PSC + # Note that we have to check for PSC enablement also because CAS + # instances also set the dnsName field. + # Remove trailing period from DNS name. Required for SSL in Python + dns_name = ret_dict.get("dnsName", "").rstrip(".") + if dns_name and ret_dict.get("pscEnabled"): + ip_addresses["PSC"] = dns_name return { "ip_addresses": ip_addresses, diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 03ad5a6e..0f25f1c1 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -234,6 +234,7 @@ def __init__( self.name = name self.db_version = db_version self.ip_addrs = ip_addrs + self.psc_enabled = False self.cert_before = cert_before self.cert_expiration = cert_expiration # create self signed CA cert @@ -255,6 +256,7 @@ async def connect_settings(self, request: Any) -> web.Response: "expirationTime": str(self.cert_expiration), }, "dnsName": "abcde.12345.us-central1.sql.goog", + "pscEnabled": self.psc_enabled, "ipAddresses": ip_addrs, "region": self.region, "databaseVersion": self.db_version, diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index da1c97ea..046e8e51 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -24,9 +24,9 @@ @pytest.mark.asyncio -async def test_get_metadata(fake_client: CloudSQLClient) -> None: +async def test_get_metadata_no_psc(fake_client: CloudSQLClient) -> None: """ - Test _get_metadata returns successfully. + Test _get_metadata returns successfully and does not include PSC IP type. """ resp = await fake_client._get_metadata( "test-project", @@ -34,6 +34,26 @@ async def test_get_metadata(fake_client: CloudSQLClient) -> None: "test-instance", ) assert resp["database_version"] == "POSTGRES_15" + assert resp["ip_addresses"] == { + "PRIMARY": "127.0.0.1", + "PRIVATE": "10.0.0.1", + } + assert isinstance(resp["server_ca_cert"], str) + + +@pytest.mark.asyncio +async def test_get_metadata_with_psc(fake_client: CloudSQLClient) -> None: + """ + Test _get_metadata returns successfully with PSC IP type. + """ + # set PSC to enabled on test instance + fake_client.instance.psc_enabled = True + resp = await fake_client._get_metadata( + "test-project", + "test-region", + "test-instance", + ) + assert resp["database_version"] == "POSTGRES_15" assert resp["ip_addresses"] == { "PRIMARY": "127.0.0.1", "PRIVATE": "10.0.0.1",