Skip to content

Commit

Permalink
Remove SQLAlchemy dependency (#1252)
Browse files Browse the repository at this point in the history
### Description

The `SQLAlchemy` dependency is only needed because we need to parse the
URL for the SQL engine configuration. Since it's relatively
straightforward to parse, this PR uses built-in `urllib` to do the
parsing and removes the `SQLAlchemy` dependency.

<!--- 
  Before requesting review, please make sure you have:
1. read [the contributing
guide](https://github.com/dbt-labs/metricflow/blob/main/CONTRIBUTING.md),
2. signed the
[CLA](https://docs.getdbt.com/docs/contributor-license-agreements)
3. run `changie new` to [create a changelog
entry](https://github.com/dbt-labs/metricflow/blob/main/CONTRIBUTING.md#adding-a-changelog-entry)
-->
  • Loading branch information
plypaul authored Jun 6, 2024
1 parent ef8cbac commit 72aaec0
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 63 deletions.

This file was deleted.

1 change: 0 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ disallow_any_explicit = True
disallow_untyped_defs = True
warn_redundant_casts = True
namespace_packages = True
plugins = sqlalchemy.ext.mypy.plugin


# Overrides for missing imports
Expand Down
10 changes: 0 additions & 10 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ dev-packages = [
"metricflow-semantics/extra-hatch-configuration/requirements-dev-packages.txt",
"dbt-metricflow/extra-hatch-configuration/requirements-cli.txt"
]
sql-client-packages = [
"extra-hatch-configuration/requirements-sql-client-packages.txt"
]
trino-sql-client-packages = [
"extra-hatch-configuration/requirements-trino-sql-client-packages.txt"
]
Expand Down Expand Up @@ -96,7 +93,6 @@ run = "run-coverage --no-cov"

features = [
"dev-packages",
"sql-client-packages",
"dbt-duckdb",
]

Expand All @@ -119,7 +115,6 @@ description = "Dev environment for working with Postgres adapter"

features = [
"dev-packages",
"sql-client-packages",
"dbt-postgres",
]

Expand All @@ -136,7 +131,6 @@ description = "Dev environment for working with the BigQuery adapter"

features = [
"dev-packages",
"sql-client-packages",
"dbt-bigquery",
]

Expand All @@ -151,7 +145,6 @@ description = "Dev environment for working with the Databricks adapter"

features = [
"dev-packages",
"sql-client-packages",
"dbt-databricks",
]

Expand All @@ -165,7 +158,6 @@ description = "Dev environment for working with the Redshift adapter"

features = [
"dev-packages",
"sql-client-packages",
"dbt-redshift"
]

Expand All @@ -179,7 +171,6 @@ description = "Dev environment for working with Snowflake adapter"

features = [
"dev-packages",
"sql-client-packages",
"dbt-snowflake",
]

Expand All @@ -195,7 +186,6 @@ description = "Dev environment for working with the Trino adapter"

features = [
"dev-packages",
"sql-client-packages",
"trino-sql-client-packages",
"dbt-trino"
]
Expand Down
96 changes: 96 additions & 0 deletions tests_metricflow/fixtures/connection_url.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from __future__ import annotations

import urllib.parse
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class UrlQueryField:
    """Field name / values specified in the query part of a URL.

    A field can appear multiple times in a query string (e.g. `?a=1&a=2`), so all values
    associated with a given field name are collected into a tuple.
    """

    # Name of the field as it appears to the left of `=` in the query string.
    field_name: str
    # All values given for this field, in the order they appear in the URL.
    values: Tuple[str, ...]


@dataclass(frozen=True)
class SqlEngineConnectionParameterSet:
    """Describes how to connect to a SQL engine.

    Fields are parsed out of a SQLAlchemy-style connection URL by `create_from_url`. Fields
    that are not present in the URL are `None` (or empty, for `query_fields` / `database`).
    """

    # The original, unparsed URL string.
    url_str: str
    # e.g. "redshift" in "redshift+psycopg2://...".
    dialect: str
    # Fields parsed from the `?key=value&...` query part of the URL.
    query_fields: Tuple[UrlQueryField, ...]
    # e.g. "psycopg2" in "redshift+psycopg2://...".
    driver: Optional[str]
    username: Optional[str]
    password: Optional[str]
    hostname: Optional[str]
    port: Optional[int]
    database: Optional[str]

    # Custom-handling for Databricks: the `http_path` parameter that follows the `;` in the URL.
    http_path: Optional[str]

    @staticmethod
    def create_from_url(url_str: str) -> SqlEngineConnectionParameterSet:
        """Parse a connection URL into its component parameters.

        The URL roughly follows the format used by SQLAlchemy.
        * This implementation is used to avoid having to specify SQLAlchemy as a dependency.
        * Databricks has a URL with a semicolon that separates an additional parameter. e.g.
          `databricks://host:port/database;http_path=a/b/c`. From @tlento: "Our original Databricks client was built
          before they added an officially supported SQLAlchemy client, so we used the JDBC connection URI.
          https://docs.databricks.com/en/integrations/jdbc/authentication.html"

        Raises:
            ValueError: if the URL contains more than one `;`, or the scheme more than one `+`.
        """
        url_separator = ";"
        url_split = url_str.split(url_separator)
        if len(url_split) > 2:
            raise ValueError(f"Expected at most 1 {repr(url_separator)} in {url_str}")

        parsed_url = urllib.parse.urlparse(url_split[0])
        url_extra = url_split[1] if len(url_split) > 1 else None

        # The scheme is of the form `dialect` or `dialect+driver`.
        dialect_driver = parsed_url.scheme.split("+")
        if len(dialect_driver) > 2:
            raise ValueError(f"Expected at most one + in {repr(parsed_url.scheme)}")
        dialect = dialect_driver[0]
        driver = dialect_driver[1] if len(dialect_driver) > 1 else None

        # redshift://../dbname -> /dbname -> dbname
        database = parsed_url.path.lstrip("/")

        query_fields = tuple(
            UrlQueryField(field_name, tuple(values))
            for field_name, values in urllib.parse.parse_qs(parsed_url.query).items()
        )

        http_path = None
        if url_extra is not None:
            # Bug fix: the original code checked `len(field_name_value_separator) == 2` — i.e.
            # `len("=")`, which is always 1 — and read the value from `url_split[1]` (the whole
            # `;` segment) instead of the split, so `http_path` was never populated.
            field_name_value_separator = "="
            # Split only on the first `=` so that values containing `=` are preserved.
            url_extra_split = url_extra.split(field_name_value_separator, 1)
            if len(url_extra_split) == 2:
                field_name = url_extra_split[0]
                value = url_extra_split[1]
                if field_name.lower() == "http_path":
                    http_path = value

        return SqlEngineConnectionParameterSet(
            url_str=url_str,
            dialect=dialect,
            driver=driver,
            username=parsed_url.username,
            password=parsed_url.password,
            hostname=parsed_url.hostname,
            port=parsed_url.port,
            database=database,
            query_fields=query_fields,
            http_path=http_path,
        )

    def get_query_field_values(self, field_name: str) -> Sequence[str]:
        """In the URL query, return the values for the field with the given name.

        Returns an empty sequence if the field name is not specified.
        """
        for field in self.query_fields:
            if field.field_name == field_name:
                return field.values
        return ()
11 changes: 0 additions & 11 deletions tests_metricflow/fixtures/setup_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@
add_display_snapshots_cli_flag,
add_overwrite_snapshots_cli_flag,
)
from sqlalchemy.engine import make_url

from tests_metricflow import TESTS_METRICFLOW_DIRECTORY_ANCHOR
from tests_metricflow.fixtures.sql_clients.common_client import SqlDialect
from tests_metricflow.snapshots import METRICFLOW_SNAPSHOT_DIRECTORY_ANCHOR
from tests_metricflow.table_snapshot.table_snapshots import SqlTableSnapshotHash, SqlTableSnapshotRepository

Expand Down Expand Up @@ -84,7 +82,6 @@ def check_sql_engine_snapshot_marker(request: FixtureRequest) -> None:
@pytest.fixture(scope="session")
def mf_test_configuration( # noqa: D103
request: FixtureRequest,
disable_sql_alchemy_deprecation_warning: None,
source_table_snapshot_repository: SqlTableSnapshotRepository,
) -> MetricFlowTestConfiguration:
engine_url = os.environ.get("MF_SQL_ENGINE_URL")
Expand Down Expand Up @@ -154,14 +151,6 @@ def mf_test_configuration( # noqa: D103
)


def dialect_from_url(url: str) -> SqlDialect:
"""Return the SQL dialect specified in the URL in the configuration."""
dialect_protocol = make_url(url.split(";")[0]).drivername.split("+")
if len(dialect_protocol) > 2:
raise ValueError(f"Invalid # of +'s in {url}")
return SqlDialect(dialect_protocol[0])


def dbt_project_dir() -> str:
"""Return the canonical path string for the dbt project dir in the test package.
Expand Down
66 changes: 29 additions & 37 deletions tests_metricflow/fixtures/sql_client_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@
from typing import Generator

import pytest
import sqlalchemy
import sqlalchemy.util
from _pytest.fixtures import FixtureRequest
from dbt.adapters.factory import get_adapter_by_type
from dbt.cli.main import dbtRunner
from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration

from metricflow.protocols.sql_client import SqlClient
from tests_metricflow.fixtures.setup_fixtures import dbt_project_dir, dialect_from_url
from tests_metricflow.fixtures.connection_url import SqlEngineConnectionParameterSet
from tests_metricflow.fixtures.setup_fixtures import dbt_project_dir
from tests_metricflow.fixtures.sql_clients.adapter_backed_ddl_client import AdapterBackedDDLSqlClient
from tests_metricflow.fixtures.sql_clients.common_client import SqlDialect
from tests_metricflow.fixtures.sql_clients.ddl_sql_client import SqlClientWithDDLMethods
Expand Down Expand Up @@ -50,33 +49,32 @@
DBT_ENV_SECRET_CATALOG = "DBT_ENV_SECRET_CATALOG"


def __configure_test_env_from_url(url: str, password: str, schema: str) -> sqlalchemy.engine.URL:
def __configure_test_env_from_url(url: str, password: str, schema: str) -> SqlEngineConnectionParameterSet:
"""Populates default env var mapping from a sqlalchemy URL string.
This is used to configure the test environment from the original MF_SQL_ENGINE_URL environment variable in
a manner compatible with the dbt profile configurations laid out for most supported engines. We return
the parsed URL object so that individual engine configurations can override the environment variables
as needed to match their dbt profile configuration.
"""
parsed_url = sqlalchemy.engine.make_url(url)
connection_parameters = SqlEngineConnectionParameterSet.create_from_url(url)
if connection_parameters.dialect != "duckdb":
assert connection_parameters.hostname, "Engine host is not set in engine connection URL!"
os.environ[DBT_ENV_SECRET_HOST] = connection_parameters.hostname

if parsed_url.drivername != "duckdb":
assert parsed_url.host, "Engine host is not set in engine connection URL!"
os.environ[DBT_ENV_SECRET_HOST] = parsed_url.host
if connection_parameters.username:
os.environ[DBT_ENV_SECRET_USER] = connection_parameters.username

if parsed_url.username:
os.environ[DBT_ENV_SECRET_USER] = parsed_url.username
if connection_parameters.database:
os.environ[DBT_ENV_SECRET_DATABASE] = connection_parameters.database

if parsed_url.database:
os.environ[DBT_ENV_SECRET_DATABASE] = parsed_url.database

if parsed_url.port:
os.environ[DBT_PROFILE_PORT] = str(parsed_url.port)
if connection_parameters.port:
os.environ[DBT_PROFILE_PORT] = str(connection_parameters.port)

os.environ[DBT_ENV_SECRET_PASSWORD] = password
os.environ[DBT_ENV_SECRET_SCHEMA] = schema

return parsed_url
return connection_parameters


def __configure_bigquery_env_from_credential_string(password: str, schema: str) -> None:
Expand Down Expand Up @@ -140,37 +138,38 @@ def __initialize_dbt() -> None:
def make_test_sql_client(url: str, password: str, schema: str) -> SqlClientWithDDLMethods:
"""Build SQL client based on env configs."""
# TODO: Switch on an enum of adapter type when all engines are cut over
dialect = dialect_from_url(url=url)
dialect = SqlDialect(SqlEngineConnectionParameterSet.create_from_url(url).dialect)

if dialect == SqlDialect.REDSHIFT:
if dialect is SqlDialect.REDSHIFT:
__configure_test_env_from_url(url, password=password, schema=schema)
__initialize_dbt()
return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("redshift"))
elif dialect == SqlDialect.SNOWFLAKE:
parsed_url = __configure_test_env_from_url(url, password=password, schema=schema)
assert "warehouse" in parsed_url.normalized_query, "Sql engine URL params did not include Snowflake warehouse!"
warehouses = parsed_url.normalized_query["warehouse"]
assert len(warehouses) == 1, f"Found more than 1 warehouse in Snowflake URL: `{warehouses}`"
os.environ[DBT_ENV_SECRET_WAREHOUSE] = warehouses[0]
elif dialect is SqlDialect.SNOWFLAKE:
connection_parameters = __configure_test_env_from_url(url, password=password, schema=schema)
warehouse_names = connection_parameters.get_query_field_values("warehouse")
assert (
len(warehouse_names) == 1
), f"SQL engine URL did not specify exactly 1 Snowflake warehouse! Got {warehouse_names}"
os.environ[DBT_ENV_SECRET_WAREHOUSE] = warehouse_names[0]
__initialize_dbt()
return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("snowflake"))
elif dialect == SqlDialect.BIGQUERY:
elif dialect is SqlDialect.BIGQUERY:
__configure_bigquery_env_from_credential_string(password=password, schema=schema)
__initialize_dbt()
return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("bigquery"))
elif dialect == SqlDialect.POSTGRESQL:
elif dialect is SqlDialect.POSTGRESQL:
__configure_test_env_from_url(url, password=password, schema=schema)
__initialize_dbt()
return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("postgres"))
elif dialect == SqlDialect.DUCKDB:
elif dialect is SqlDialect.DUCKDB:
__configure_test_env_from_url(url, password=password, schema=schema)
__initialize_dbt()
return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("duckdb"))
elif dialect == SqlDialect.DATABRICKS:
elif dialect is SqlDialect.DATABRICKS:
__configure_databricks_env_from_url(url, password=password, schema=schema)
__initialize_dbt()
return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("databricks"))
elif dialect == SqlDialect.TRINO:
elif dialect is SqlDialect.TRINO:
__configure_test_env_from_url(url, password=password, schema=schema)
__initialize_dbt()
return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("trino"))
Expand Down Expand Up @@ -237,7 +236,7 @@ def warn_user_about_slow_tests_without_parallelism( # noqa: D103
) from e

num_items = len(request.session.items)
dialect = dialect_from_url(mf_test_configuration.sql_engine_url)
dialect = SqlDialect(SqlEngineConnectionParameterSet.create_from_url(mf_test_configuration.sql_engine_url).dialect)

# If already running in parallel or if there's not many test items, no need to print the warning. Picking 10/30 as
# the thresholds, but not much thought has been put into it.
Expand All @@ -254,10 +253,3 @@ def warn_user_about_slow_tests_without_parallelism( # noqa: D103
f'Consider using the pytest-xdist option "-n <number of workers>" to parallelize execution and speed '
f"up the session."
)


@pytest.fixture(scope="session", autouse=True)
def disable_sql_alchemy_deprecation_warning() -> None:
"""Since MF is tied to using SQLAlchemy 1.x.x due to the Snowflake connector, silence 2.0 deprecation warnings."""
# Seeing 'error: Module has no attribute "SILENCE_UBER_WARNING"' in the type checker, but this seems to work.
sqlalchemy.util.deprecations.SILENCE_UBER_WARNING = True # type:ignore

0 comments on commit 72aaec0

Please sign in to comment.