Skip to content

Commit

Permalink
chore: encapsulate driver and engine validation
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon committed Jun 3, 2024
1 parent f59810c commit 0dd112b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
9 changes: 1 addition & 8 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from google.cloud.sql.connector.enums import DriverMapping
from google.cloud.sql.connector.exceptions import ConnectorLoopError
from google.cloud.sql.connector.exceptions import DnsNameResolutionError
from google.cloud.sql.connector.exceptions import IncompatibleDriverError
from google.cloud.sql.connector.instance import IPTypes
from google.cloud.sql.connector.instance import RefreshAheadCache
from google.cloud.sql.connector.instance import RefreshStrategy
Expand Down Expand Up @@ -335,13 +334,7 @@ async def connect_async(
try:
conn_info = await cache.connect_info()
# validate driver matches intended database engine
mapping = DriverMapping[driver.upper()]
if not conn_info.database_version.startswith(mapping.value):
raise IncompatibleDriverError(
f"Database driver '{driver}' is incompatible with database "
f"version '{conn_info.database_version}'. Given driver can "
f"only be used with Cloud SQL {mapping.value} databases."
)
DriverMapping.validate_engine(driver, conn_info.database_version)
ip_address = conn_info.get_preferred_ip(ip_type)
# resolve DNS name into IP address for PSC
if ip_type.value == "PSC":
Expand Down
22 changes: 22 additions & 0 deletions google/cloud/sql/connector/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from enum import Enum

from google.cloud.sql.connector.exceptions import IncompatibleDriverError


class DriverMapping(Enum):
"""Maps a given database driver to it's corresponding database engine."""
Expand All @@ -22,3 +24,23 @@ class DriverMapping(Enum):
PG8000 = "POSTGRES"
PYMYSQL = "MYSQL"
PYTDS = "SQLSERVER"

@staticmethod
def validate_engine(driver: str, engine_version: str) -> None:
"""Validate that the given driver is compatible with the given engine.
Args:
driver (str): Database driver being used. (i.e. "pg8000")
engine_version (str): Database engine version. (i.e. "POSTGRES_16")
Raises:
IncompatibleDriverError: If the given driver is not compatible with
the given engine.
"""
mapping = DriverMapping[driver.upper()]
if not engine_version.startswith(mapping.value):
raise IncompatibleDriverError(
f"Database driver '{driver}' is incompatible with database "
f"version '{engine_version}'. Given driver can "
f"only be used with Cloud SQL {mapping.value} databases."
)

0 comments on commit 0dd112b

Please sign in to comment.