-
Notifications
You must be signed in to change notification settings - Fork 97
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove SQLAlchemy
dependency
#1252
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,94 @@ | ||||||
from __future__ import annotations | ||||||
|
||||||
import urllib.parse | ||||||
from dataclasses import dataclass | ||||||
from typing import Optional, Sequence, Tuple | ||||||
|
||||||
|
||||||
@dataclass(frozen=True)
class UrlQueryField:
    """A single field from the query part of a URL.

    A query field may appear multiple times in a URL (e.g. `?a=1&a=2`), so all
    values for the field name are kept, in order, as an immutable tuple.
    """

    # Name of the field as it appears in the query string.
    field_name: str
    # All values given for this field, in the order they appear in the URL.
    values: Tuple[str, ...]
@dataclass(frozen=True)
class SqlEngineConnectionParameterSet:
    """Describes how to connect to a SQL engine."""

    # The original, unparsed URL string.
    url_str: str
    # Engine dialect, e.g. `redshift` from `redshift+psycopg2://...`.
    dialect: str
    # Fields parsed from the `?key=value` query part of the URL.
    query_fields: Tuple[UrlQueryField, ...]
    # Optional driver, e.g. `psycopg2` from `redshift+psycopg2://...`.
    driver: Optional[str]
    username: Optional[str]
    password: Optional[str]
    hostname: Optional[str]
    port: Optional[int]
    database: Optional[str]

    # Custom handling for Databricks: the `http_path` parameter that follows the `;`
    # in the connection URL (Databricks JDBC-style URIs carry extra parameters after
    # a semicolon).
    http_path: Optional[str]

    @staticmethod
    def create_from_url(url_str: str) -> SqlEngineConnectionParameterSet:
        """Parse a connection URL into its component parts.

        The URL roughly follows the format used by SQLAlchemy.

        * This implementation is used to avoid having to specify SQLAlchemy as a dependency.
        * Databricks has a URL with a semicolon that separates an additional parameter. e.g.
          `databricks://host:port/database;http_path=a/b/c`.

        Raises:
            ValueError: if the URL contains more than one `;`, or the scheme contains
                more than one `+`.
        """
        url_separator = ";"
        url_split = url_str.split(url_separator)
        if len(url_split) > 2:
            raise ValueError(f"Expected at most 1 {repr(url_separator)} in {url_str}")

        parsed_url = urllib.parse.urlparse(url_split[0])
        url_extra = url_split[1] if len(url_split) > 1 else None

        # The scheme is `dialect[+driver]`, e.g. `redshift+psycopg2`.
        dialect_driver = parsed_url.scheme.split("+")
        if len(dialect_driver) > 2:
            raise ValueError(f"Expected at most one + in {repr(parsed_url.scheme)}")
        dialect = dialect_driver[0]
        driver = dialect_driver[1] if len(dialect_driver) > 1 else None

        # redshift://../dbname -> /dbname -> dbname
        database = parsed_url.path.lstrip("/")

        query_fields = tuple(
            UrlQueryField(field_name, tuple(values))
            for field_name, values in urllib.parse.parse_qs(parsed_url.query).items()
        )

        http_path = None
        if url_extra is not None:
            field_name_value_separator = "="
            url_extra_split = url_extra.split(field_name_value_separator)
            # Bug fix: the original checked `len(field_name_value_seperator) == 2` — the
            # length of the one-character separator string, which is always 1 — so this
            # branch never executed and `http_path` was never populated. It also read the
            # value from `url_split` instead of `url_extra_split`. Check the number of
            # split parts and read the value from the correct list.
            if len(url_extra_split) == 2:
                field_name = url_extra_split[0]
                value = url_extra_split[1]
                if field_name.lower() == "http_path":
                    http_path = value

        return SqlEngineConnectionParameterSet(
            url_str=url_str,
            dialect=dialect,
            driver=driver,
            username=parsed_url.username,
            password=parsed_url.password,
            hostname=parsed_url.hostname,
            port=parsed_url.port,
            database=database,
            query_fields=query_fields,
            http_path=http_path,
        )

    def get_query_field_values(self, field_name: str) -> Sequence[str]:
        """In the URL query, return the values for the field with the given name.

        Returns an empty sequence if the field name is not specified.
        """
        for field in self.query_fields:
            if field.field_name == field_name:
                return field.values
        return ()
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,15 +7,14 @@ | |
from typing import Generator | ||
|
||
import pytest | ||
import sqlalchemy | ||
import sqlalchemy.util | ||
from _pytest.fixtures import FixtureRequest | ||
from dbt.adapters.factory import get_adapter_by_type | ||
from dbt.cli.main import dbtRunner | ||
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration | ||
|
||
from metricflow.protocols.sql_client import SqlClient | ||
from tests_metricflow.fixtures.setup_fixtures import dbt_project_dir, dialect_from_url | ||
from tests_metricflow.fixtures.connection_url import SqlEngineConnectionParameterSet | ||
from tests_metricflow.fixtures.setup_fixtures import dbt_project_dir | ||
from tests_metricflow.fixtures.sql_clients.adapter_backed_ddl_client import AdapterBackedDDLSqlClient | ||
from tests_metricflow.fixtures.sql_clients.common_client import SqlDialect | ||
from tests_metricflow.fixtures.sql_clients.ddl_sql_client import SqlClientWithDDLMethods | ||
|
@@ -50,33 +49,34 @@ | |
DBT_ENV_SECRET_CATALOG = "DBT_ENV_SECRET_CATALOG" | ||
|
||
|
||
def __configure_test_env_from_url(url: str, password: str, schema: str) -> SqlEngineConnectionParameterSet:
    """Populates default env var mapping from an engine connection URL string.

    This is used to configure the test environment from the original MF_SQL_ENGINE_URL environment variable in
    a manner compatible with the dbt profile configurations laid out for most supported engines. We return
    the parsed connection parameters so that individual engine configurations can override the environment
    variables as needed to match their dbt profile configuration.
    """
    connection_parameters = SqlEngineConnectionParameterSet.create_from_url(url)
    # DuckDB runs in-process, so no host is expected in its connection URL.
    if connection_parameters.dialect != "duckdb":
        assert connection_parameters.hostname, "Engine host is not set in engine connection URL!"
        os.environ[DBT_ENV_SECRET_HOST] = connection_parameters.hostname

    if connection_parameters.username:
        os.environ[DBT_ENV_SECRET_USER] = connection_parameters.username

    if connection_parameters.database:
        os.environ[DBT_ENV_SECRET_DATABASE] = connection_parameters.database

    if connection_parameters.port:
        os.environ[DBT_PROFILE_PORT] = str(connection_parameters.port)

    os.environ[DBT_ENV_SECRET_PASSWORD] = password
    os.environ[DBT_ENV_SECRET_SCHEMA] = schema

    return connection_parameters
|
||
|
||
def __configure_bigquery_env_from_credential_string(password: str, schema: str) -> None: | ||
|
@@ -140,37 +140,38 @@ def __initialize_dbt() -> None: | |
def make_test_sql_client(url: str, password: str, schema: str) -> SqlClientWithDDLMethods:
    """Build SQL client based on env configs."""
    # TODO: Switch on an enum of adapter type when all engines are cut over
    dialect = SqlDialect(SqlEngineConnectionParameterSet.create_from_url(url).dialect)

    # `is` comparisons are safe here since SqlDialect members are singletons.
    if dialect is SqlDialect.REDSHIFT:
        __configure_test_env_from_url(url, password=password, schema=schema)
        __initialize_dbt()
        return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("redshift"))
    elif dialect is SqlDialect.SNOWFLAKE:
        connection_parameters = __configure_test_env_from_url(url, password=password, schema=schema)
        # Snowflake requires a warehouse, supplied via the `warehouse` URL query field.
        warehouse_names = connection_parameters.get_query_field_values("warehouse")
        assert (
            len(warehouse_names) == 1
        ), f"SQL engine URL did not specify exactly 1 Snowflake warehouse! Got {warehouse_names}"
        os.environ[DBT_ENV_SECRET_WAREHOUSE] = warehouse_names[0]
        __initialize_dbt()
        return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("snowflake"))
    elif dialect is SqlDialect.BIGQUERY:
        __configure_bigquery_env_from_credential_string(password=password, schema=schema)
        __initialize_dbt()
        return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("bigquery"))
    elif dialect is SqlDialect.POSTGRESQL:
        __configure_test_env_from_url(url, password=password, schema=schema)
        __initialize_dbt()
        return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("postgres"))
    elif dialect is SqlDialect.DUCKDB:
        __configure_test_env_from_url(url, password=password, schema=schema)
        __initialize_dbt()
        return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("duckdb"))
    elif dialect is SqlDialect.DATABRICKS:
        __configure_databricks_env_from_url(url, password=password, schema=schema)
        __initialize_dbt()
        return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("databricks"))
    elif dialect is SqlDialect.TRINO:
        __configure_test_env_from_url(url, password=password, schema=schema)
        __initialize_dbt()
        return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("trino"))
|
@@ -237,7 +238,7 @@ def warn_user_about_slow_tests_without_parallelism( # noqa: D103 | |
) from e | ||
|
||
num_items = len(request.session.items) | ||
dialect = dialect_from_url(mf_test_configuration.sql_engine_url) | ||
dialect = SqlDialect(SqlEngineConnectionParameterSet.create_from_url(mf_test_configuration.sql_engine_url).dialect) | ||
|
||
# If already running in parallel or if there's not many test items, no need to print the warning. Picking 10/30 as | ||
# the thresholds, but not much thought has been put into it. | ||
|
@@ -254,10 +255,3 @@ def warn_user_about_slow_tests_without_parallelism( # noqa: D103 | |
f'Consider using the pytest-xdist option "-n <number of workers>" to parallelize execution and speed ' | ||
f"up the session." | ||
) | ||
|
||
|
||
@pytest.fixture(scope="session", autouse=True)
def disable_sql_alchemy_deprecation_warning() -> None:
    """Silence SQLAlchemy 2.0 deprecation warnings for the whole test session.

    MF is pinned to SQLAlchemy 1.x.x because of the Snowflake connector, so the 2.0
    uber-warning is pure noise in test output.
    """
    # The type checker reports 'Module has no attribute "SILENCE_UBER_WARNING"', but
    # setting the attribute works at runtime.
    sqlalchemy.util.deprecations.SILENCE_UBER_WARNING = True  # type:ignore
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Our original Databricks client was built before they added an officially supported SQLAlchemy client, so we used the JDBC connection URI. https://docs.databricks.com/en/integrations/jdbc/authentication.html
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added in comment.