From 50fc672b5b37ab2bd9594d9725d2f737264d724a Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Tue, 4 Jun 2024 17:00:42 -0700 Subject: [PATCH 1/4] /* PR_START p--py312 15 */ Use `urlparse` instead of SQLAlchemy. --- tests_metricflow/fixtures/connection_url.py | 94 +++++++++++++++++++ tests_metricflow/fixtures/setup_fixtures.py | 10 -- .../fixtures/sql_client_fixtures.py | 57 +++++------ 3 files changed, 124 insertions(+), 37 deletions(-) create mode 100644 tests_metricflow/fixtures/connection_url.py diff --git a/tests_metricflow/fixtures/connection_url.py b/tests_metricflow/fixtures/connection_url.py new file mode 100644 index 0000000000..c74113fed2 --- /dev/null +++ b/tests_metricflow/fixtures/connection_url.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import urllib.parse +from dataclasses import dataclass +from typing import Optional, Sequence, Tuple + + +@dataclass(frozen=True) +class UrlQueryField: + """Field name / values specified in the query part of a URL.""" + + field_name: str + values: Tuple[str, ...] + + +@dataclass(frozen=True) +class SqlEngineConnectionParameterSet: + """Describes how to connect to a SQL engine.""" + + url_str: str + dialect: str + query_fields: Tuple[UrlQueryField, ...] + driver: Optional[str] + username: Optional[str] + password: Optional[str] + hostname: Optional[str] + port: Optional[int] + database: Optional[str] + + # Custom-handling for Databricks. + http_path: Optional[str] + + @staticmethod + def create_from_url(url_str: str) -> SqlEngineConnectionParameterSet: + """The URL roughly follows the format used by SQLAlchemy. + + * This implementation is used to avoid having to specify SQLAlchemy as a dependency. + * Databricks has a URL with a semicolon that separates an additional parameter. e.g. + `databricks://host:port/database;http_path=a/b/c`. Need additional context for why this is done. 
+ """ + url_seperator = ";" + url_split = url_str.split(url_seperator) + if len(url_split) > 2: + raise ValueError(f"Expected at most 1 {repr(url_seperator)} in {url_str}") + + parsed_url = urllib.parse.urlparse(url_split[0]) + url_extra = url_split[1] if len(url_split) > 1 else None + + dialect_driver = parsed_url.scheme.split("+") + if len(dialect_driver) > 2: + raise ValueError(f"Expected at most one + in {repr(parsed_url.scheme)}") + dialect = dialect_driver[0] + driver = dialect_driver[1] if len(dialect_driver) > 1 else None + + # redshift://../dbname -> /dbname -> dbname + database = parsed_url.path.lstrip("/") + + query_fields = tuple( + UrlQueryField(field_name, tuple(values)) + for field_name, values in urllib.parse.parse_qs(parsed_url.query).items() + ) + + http_path = None + if url_extra is not None: + field_name_value_seperator = "=" + url_extra_split = url_extra.split(field_name_value_seperator) + if len(field_name_value_seperator) == 2: + field_name = url_extra_split[0] + value = url_split[1] + if field_name.lower() == "http_path": + http_path = value + + return SqlEngineConnectionParameterSet( + url_str=url_str, + dialect=dialect, + driver=driver, + username=parsed_url.username, + password=parsed_url.password, + hostname=parsed_url.hostname, + port=parsed_url.port, + database=database, + query_fields=query_fields, + http_path=http_path, + ) + + def get_query_field_values(self, field_name: str) -> Sequence[str]: + """In the URL query, return the values for the field with the given name. + + Returns an empty sequence if the field name is not specified. 
+ """ + for field in self.query_fields: + if field.field_name == field_name: + return field.values + return () diff --git a/tests_metricflow/fixtures/setup_fixtures.py b/tests_metricflow/fixtures/setup_fixtures.py index 649ae5dacf..d809550793 100644 --- a/tests_metricflow/fixtures/setup_fixtures.py +++ b/tests_metricflow/fixtures/setup_fixtures.py @@ -16,10 +16,8 @@ add_display_snapshots_cli_flag, add_overwrite_snapshots_cli_flag, ) -from sqlalchemy.engine import make_url from tests_metricflow import TESTS_METRICFLOW_DIRECTORY_ANCHOR -from tests_metricflow.fixtures.sql_clients.common_client import SqlDialect from tests_metricflow.snapshots import METRICFLOW_SNAPSHOT_DIRECTORY_ANCHOR from tests_metricflow.table_snapshot.table_snapshots import SqlTableSnapshotHash, SqlTableSnapshotRepository @@ -154,14 +152,6 @@ def mf_test_configuration( # noqa: D103 ) -def dialect_from_url(url: str) -> SqlDialect: - """Return the SQL dialect specified in the URL in the configuration.""" - dialect_protocol = make_url(url.split(";")[0]).drivername.split("+") - if len(dialect_protocol) > 2: - raise ValueError(f"Invalid # of +'s in {url}") - return SqlDialect(dialect_protocol[0]) - - def dbt_project_dir() -> str: """Return the canonical path string for the dbt project dir in the test package. 
diff --git a/tests_metricflow/fixtures/sql_client_fixtures.py b/tests_metricflow/fixtures/sql_client_fixtures.py index 428578e527..b4d8d61af2 100644 --- a/tests_metricflow/fixtures/sql_client_fixtures.py +++ b/tests_metricflow/fixtures/sql_client_fixtures.py @@ -15,7 +15,8 @@ from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration from metricflow.protocols.sql_client import SqlClient -from tests_metricflow.fixtures.setup_fixtures import dbt_project_dir, dialect_from_url +from tests_metricflow.fixtures.connection_url import SqlEngineConnectionParameterSet +from tests_metricflow.fixtures.setup_fixtures import dbt_project_dir from tests_metricflow.fixtures.sql_clients.adapter_backed_ddl_client import AdapterBackedDDLSqlClient from tests_metricflow.fixtures.sql_clients.common_client import SqlDialect from tests_metricflow.fixtures.sql_clients.ddl_sql_client import SqlClientWithDDLMethods @@ -50,7 +51,7 @@ DBT_ENV_SECRET_CATALOG = "DBT_ENV_SECRET_CATALOG" -def __configure_test_env_from_url(url: str, password: str, schema: str) -> sqlalchemy.engine.URL: +def __configure_test_env_from_url(url: str, password: str, schema: str) -> SqlEngineConnectionParameterSet: """Populates default env var mapping from a sqlalchemy URL string. This is used to configure the test environment from the original MF_SQL_ENGINE_URL environment variable in @@ -58,25 +59,26 @@ def __configure_test_env_from_url(url: str, password: str, schema: str) -> sqlal the parsed URL object so that individual engine configurations can override the environment variables as needed to match their dbt profile configuration. """ - parsed_url = sqlalchemy.engine.make_url(url) + # parsed_url = sqlalchemy.engine.make_url(url) - if parsed_url.drivername != "duckdb": - assert parsed_url.host, "Engine host is not set in engine connection URL!" 
- os.environ[DBT_ENV_SECRET_HOST] = parsed_url.host + connection_parameters = SqlEngineConnectionParameterSet.create_from_url(url) + if connection_parameters.dialect != "duckdb": + assert connection_parameters.hostname, "Engine host is not set in engine connection URL!" + os.environ[DBT_ENV_SECRET_HOST] = connection_parameters.hostname - if parsed_url.username: - os.environ[DBT_ENV_SECRET_USER] = parsed_url.username + if connection_parameters.username: + os.environ[DBT_ENV_SECRET_USER] = connection_parameters.username - if parsed_url.database: - os.environ[DBT_ENV_SECRET_DATABASE] = parsed_url.database + if connection_parameters.database: + os.environ[DBT_ENV_SECRET_DATABASE] = connection_parameters.database - if parsed_url.port: - os.environ[DBT_PROFILE_PORT] = str(parsed_url.port) + if connection_parameters.port: + os.environ[DBT_PROFILE_PORT] = str(connection_parameters.port) os.environ[DBT_ENV_SECRET_PASSWORD] = password os.environ[DBT_ENV_SECRET_SCHEMA] = schema - return parsed_url + return connection_parameters def __configure_bigquery_env_from_credential_string(password: str, schema: str) -> None: @@ -140,37 +142,38 @@ def __initialize_dbt() -> None: def make_test_sql_client(url: str, password: str, schema: str) -> SqlClientWithDDLMethods: """Build SQL client based on env configs.""" # TODO: Switch on an enum of adapter type when all engines are cut over - dialect = dialect_from_url(url=url) + dialect = SqlDialect(SqlEngineConnectionParameterSet.create_from_url(url).dialect) - if dialect == SqlDialect.REDSHIFT: + if dialect is SqlDialect.REDSHIFT: __configure_test_env_from_url(url, password=password, schema=schema) __initialize_dbt() return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("redshift")) - elif dialect == SqlDialect.SNOWFLAKE: - parsed_url = __configure_test_env_from_url(url, password=password, schema=schema) - assert "warehouse" in parsed_url.normalized_query, "Sql engine URL params did not include Snowflake warehouse!" 
- warehouses = parsed_url.normalized_query["warehouse"] - assert len(warehouses) == 1, f"Found more than 1 warehouse in Snowflake URL: `{warehouses}`" - os.environ[DBT_ENV_SECRET_WAREHOUSE] = warehouses[0] + elif dialect is SqlDialect.SNOWFLAKE: + connection_parameters = __configure_test_env_from_url(url, password=password, schema=schema) + warehouse_names = connection_parameters.get_query_field_values("warehouse") + assert ( + len(warehouse_names) == 1 + ), f"SQL engine URL did not specify exactly 1 Snowflake warehouse! Got {warehouse_names}" + os.environ[DBT_ENV_SECRET_WAREHOUSE] = warehouse_names[0] __initialize_dbt() return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("snowflake")) - elif dialect == SqlDialect.BIGQUERY: + elif dialect is SqlDialect.BIGQUERY: __configure_bigquery_env_from_credential_string(password=password, schema=schema) __initialize_dbt() return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("bigquery")) - elif dialect == SqlDialect.POSTGRESQL: + elif dialect is SqlDialect.POSTGRESQL: __configure_test_env_from_url(url, password=password, schema=schema) __initialize_dbt() return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("postgres")) - elif dialect == SqlDialect.DUCKDB: + elif dialect is SqlDialect.DUCKDB: __configure_test_env_from_url(url, password=password, schema=schema) __initialize_dbt() return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("duckdb")) - elif dialect == SqlDialect.DATABRICKS: + elif dialect is SqlDialect.DATABRICKS: __configure_databricks_env_from_url(url, password=password, schema=schema) __initialize_dbt() return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("databricks")) - elif dialect == SqlDialect.TRINO: + elif dialect is SqlDialect.TRINO: __configure_test_env_from_url(url, password=password, schema=schema) __initialize_dbt() return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("trino")) @@ -237,7 +240,7 @@ def warn_user_about_slow_tests_without_parallelism( # noqa: 
D103 ) from e num_items = len(request.session.items) - dialect = dialect_from_url(mf_test_configuration.sql_engine_url) + dialect = SqlDialect(SqlEngineConnectionParameterSet.create_from_url(mf_test_configuration.sql_engine_url).dialect) # If already running in parallel or if there's not many test items, no need to print the warning. Picking 10/30 as # the thresholds, but not much thought has been put into it. From 51b9b0c3e203c4b0083cd549be5c9466eeb843e2 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Tue, 4 Jun 2024 17:01:16 -0700 Subject: [PATCH 2/4] Remove SQLAlchemy deprecation warning. --- tests_metricflow/fixtures/setup_fixtures.py | 1 - tests_metricflow/fixtures/sql_client_fixtures.py | 9 --------- 2 files changed, 10 deletions(-) diff --git a/tests_metricflow/fixtures/setup_fixtures.py b/tests_metricflow/fixtures/setup_fixtures.py index d809550793..57962f0961 100644 --- a/tests_metricflow/fixtures/setup_fixtures.py +++ b/tests_metricflow/fixtures/setup_fixtures.py @@ -82,7 +82,6 @@ def check_sql_engine_snapshot_marker(request: FixtureRequest) -> None: @pytest.fixture(scope="session") def mf_test_configuration( # noqa: D103 request: FixtureRequest, - disable_sql_alchemy_deprecation_warning: None, source_table_snapshot_repository: SqlTableSnapshotRepository, ) -> MetricFlowTestConfiguration: engine_url = os.environ.get("MF_SQL_ENGINE_URL") diff --git a/tests_metricflow/fixtures/sql_client_fixtures.py b/tests_metricflow/fixtures/sql_client_fixtures.py index b4d8d61af2..8fd0ce8c51 100644 --- a/tests_metricflow/fixtures/sql_client_fixtures.py +++ b/tests_metricflow/fixtures/sql_client_fixtures.py @@ -7,8 +7,6 @@ from typing import Generator import pytest -import sqlalchemy -import sqlalchemy.util from _pytest.fixtures import FixtureRequest from dbt.adapters.factory import get_adapter_by_type from dbt.cli.main import dbtRunner @@ -257,10 +255,3 @@ def warn_user_about_slow_tests_without_parallelism( # noqa: D103 f'Consider using the pytest-xdist option "-n " to 
parallelize execution and speed ' f"up the session." ) - - -@pytest.fixture(scope="session", autouse=True) -def disable_sql_alchemy_deprecation_warning() -> None: - """Since MF is tied to using SQLAlchemy 1.x.x due to the Snowflake connector, silence 2.0 deprecation warnings.""" - # Seeing 'error: Module has no attribute "SILENCE_UBER_WARNING"' in the type checker, but this seems to work. - sqlalchemy.util.deprecations.SILENCE_UBER_WARNING = True # type:ignore From c76f7cc902835cfcee1365b8d234608b4b8f3af6 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Tue, 4 Jun 2024 17:05:56 -0700 Subject: [PATCH 3/4] Remove SQLAlchemy dependency. --- .../requirements-sql-client-packages.txt | 4 ---- mypy.ini | 1 - pyproject.toml | 10 ---------- 3 files changed, 15 deletions(-) delete mode 100644 extra-hatch-configuration/requirements-sql-client-packages.txt diff --git a/extra-hatch-configuration/requirements-sql-client-packages.txt b/extra-hatch-configuration/requirements-sql-client-packages.txt deleted file mode 100644 index 8fa34decb8..0000000000 --- a/extra-hatch-configuration/requirements-sql-client-packages.txt +++ /dev/null @@ -1,4 +0,0 @@ -# These are currently separate for ease of removal, but due to the way Python -# handles import statements they are required in all test environments -SQLAlchemy>=1.4.42, <1.5.0 -sqlalchemy2-stubs>=0.0.2a21, <0.0.3 diff --git a/mypy.ini b/mypy.ini index df423ccab4..0becb6e708 100644 --- a/mypy.ini +++ b/mypy.ini @@ -5,7 +5,6 @@ disallow_any_explicit = True disallow_untyped_defs = True warn_redundant_casts = True namespace_packages = True -plugins = sqlalchemy.ext.mypy.plugin # Overrides for missing imports diff --git a/pyproject.toml b/pyproject.toml index b7ce6ecef9..fc365bd5ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,9 +66,6 @@ dev-packages = [ "metricflow-semantics/extra-hatch-configuration/requirements-dev-packages.txt", "dbt-metricflow/extra-hatch-configuration/requirements-cli.txt" ] -sql-client-packages = [ - 
"extra-hatch-configuration/requirements-sql-client-packages.txt" -] trino-sql-client-packages = [ "extra-hatch-configuration/requirements-trino-sql-client-packages.txt" ] @@ -96,7 +93,6 @@ run = "run-coverage --no-cov" features = [ "dev-packages", - "sql-client-packages", "dbt-duckdb", ] @@ -119,7 +115,6 @@ description = "Dev environment for working with Postgres adapter" features = [ "dev-packages", - "sql-client-packages", "dbt-postgres", ] @@ -136,7 +131,6 @@ description = "Dev environment for working with the BigQuery adapter" features = [ "dev-packages", - "sql-client-packages", "dbt-bigquery", ] @@ -151,7 +145,6 @@ description = "Dev environment for working with the Databricks adapter" features = [ "dev-packages", - "sql-client-packages", "dbt-databricks", ] @@ -165,7 +158,6 @@ description = "Dev environment for working with the Redshift adapter" features = [ "dev-packages", - "sql-client-packages", "dbt-redshift" ] @@ -179,7 +171,6 @@ description = "Dev environment for working with Snowflake adapter" features = [ "dev-packages", - "sql-client-packages", "dbt-snowflake", ] @@ -195,7 +186,6 @@ description = "Dev environment for working with the Trino adapter" features = [ "dev-packages", - "sql-client-packages", "trino-sql-client-packages", "dbt-trino" ] From 70bc6ec48461916efe8e3614e8483aa62670c35b Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Wed, 5 Jun 2024 23:57:05 -0700 Subject: [PATCH 4/4] Address comments. 
--- tests_metricflow/fixtures/connection_url.py | 18 ++++++++++-------- tests_metricflow/fixtures/sql_client_fixtures.py | 2 -- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests_metricflow/fixtures/connection_url.py b/tests_metricflow/fixtures/connection_url.py index c74113fed2..12941b32ed 100644 --- a/tests_metricflow/fixtures/connection_url.py +++ b/tests_metricflow/fixtures/connection_url.py @@ -36,12 +36,14 @@ def create_from_url(url_str: str) -> SqlEngineConnectionParameterSet: * This implementation is used to avoid having to specify SQLAlchemy as a dependency. * Databricks has a URL with a semicolon that separates an additional parameter. e.g. - `databricks://host:port/database;http_path=a/b/c`. Need additional context for why this is done. + `databricks://host:port/database;http_path=a/b/c`. From @tlento: "Our original Databricks client was built + before they added an officially supported SQLAlchemy client, so we used the JDBC connection URI. + https://docs.databricks.com/en/integrations/jdbc/authentication.html" """ - url_seperator = ";" - url_split = url_str.split(url_seperator) + url_separator = ";" + url_split = url_str.split(url_separator) if len(url_split) > 2: - raise ValueError(f"Expected at most 1 {repr(url_seperator)} in {url_str}") + raise ValueError(f"Expected at most 1 {repr(url_separator)} in {url_str}") parsed_url = urllib.parse.urlparse(url_split[0]) url_extra = url_split[1] if len(url_split) > 1 else None @@ -62,9 +64,9 @@ def create_from_url(url_str: str) -> SqlEngineConnectionParameterSet: http_path = None if url_extra is not None: - field_name_value_seperator = "=" - url_extra_split = url_extra.split(field_name_value_seperator) - if len(field_name_value_seperator) == 2: + field_name_value_separator = "=" + url_extra_split = url_extra.split(field_name_value_separator) + if len(url_extra_split) == 2: field_name = url_extra_split[0] - value = url_split[1] + value = url_extra_split[1] if field_name.lower() == "http_path": diff --git 
a/tests_metricflow/fixtures/sql_client_fixtures.py b/tests_metricflow/fixtures/sql_client_fixtures.py index 8fd0ce8c51..d19e760c44 100644 --- a/tests_metricflow/fixtures/sql_client_fixtures.py +++ b/tests_metricflow/fixtures/sql_client_fixtures.py @@ -57,8 +57,6 @@ def __configure_test_env_from_url(url: str, password: str, schema: str) -> SqlEn the parsed URL object so that individual engine configurations can override the environment variables as needed to match their dbt profile configuration. """ - # parsed_url = sqlalchemy.engine.make_url(url) - connection_parameters = SqlEngineConnectionParameterSet.create_from_url(url) if connection_parameters.dialect != "duckdb": assert connection_parameters.hostname, "Engine host is not set in engine connection URL!"