Remove SQLAlchemy dependency #1252

Merged: 4 commits, Jun 6, 2024
Changes from 3 commits

This file was deleted.

1 change: 0 additions & 1 deletion mypy.ini
@@ -5,7 +5,6 @@ disallow_any_explicit = True
 disallow_untyped_defs = True
 warn_redundant_casts = True
 namespace_packages = True
-plugins = sqlalchemy.ext.mypy.plugin


 # Overrides for missing imports
10 changes: 0 additions & 10 deletions pyproject.toml
@@ -66,9 +66,6 @@ dev-packages = [
     "metricflow-semantics/extra-hatch-configuration/requirements-dev-packages.txt",
     "dbt-metricflow/extra-hatch-configuration/requirements-cli.txt"
 ]
-sql-client-packages = [
-    "extra-hatch-configuration/requirements-sql-client-packages.txt"
-]
 trino-sql-client-packages = [
     "extra-hatch-configuration/requirements-trino-sql-client-packages.txt"
 ]
@@ -96,7 +93,6 @@ run = "run-coverage --no-cov"

 features = [
     "dev-packages",
-    "sql-client-packages",
     "dbt-duckdb",
 ]

@@ -119,7 +115,6 @@ description = "Dev environment for working with Postgres adapter"

 features = [
     "dev-packages",
-    "sql-client-packages",
     "dbt-postgres",
 ]

@@ -136,7 +131,6 @@ description = "Dev environment for working with the BigQuery adapter"

 features = [
     "dev-packages",
-    "sql-client-packages",
     "dbt-bigquery",
 ]

@@ -151,7 +145,6 @@ description = "Dev environment for working with the Databricks adapter"

 features = [
     "dev-packages",
-    "sql-client-packages",
     "dbt-databricks",
 ]

@@ -165,7 +158,6 @@ description = "Dev environment for working with the Redshift adapter"

 features = [
     "dev-packages",
-    "sql-client-packages",
     "dbt-redshift"
 ]

@@ -179,7 +171,6 @@ description = "Dev environment for working with Snowflake adapter"

 features = [
     "dev-packages",
-    "sql-client-packages",
     "dbt-snowflake",
 ]

@@ -195,7 +186,6 @@ description = "Dev environment for working with the Trino adapter"

 features = [
     "dev-packages",
-    "sql-client-packages",
     "trino-sql-client-packages",
     "dbt-trino"
 ]
94 changes: 94 additions & 0 deletions tests_metricflow/fixtures/connection_url.py
@@ -0,0 +1,94 @@
from __future__ import annotations

import urllib.parse
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass(frozen=True)
class UrlQueryField:
    """Field name / values specified in the query part of a URL."""

    field_name: str
    values: Tuple[str, ...]


@dataclass(frozen=True)
class SqlEngineConnectionParameterSet:
    """Describes how to connect to a SQL engine."""

    url_str: str
    dialect: str
    query_fields: Tuple[UrlQueryField, ...]
    driver: Optional[str]
    username: Optional[str]
    password: Optional[str]
    hostname: Optional[str]
    port: Optional[int]
    database: Optional[str]

    # Custom handling for Databricks.
    http_path: Optional[str]

    @staticmethod
    def create_from_url(url_str: str) -> SqlEngineConnectionParameterSet:
"""The URL roughly follows the format used by SQLAlchemy.

* This implementation is used to avoid having to specify SQLAlchemy as a dependency.
* Databricks has a URL with a semicolon that separates an additional parameter. e.g.
`databricks://host:port/database;http_path=a/b/c`. Need additional context for why this is done.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our original Databricks client was built before they added an officially supported SQLAlchemy client, so we used the JDBC connection URI. https://docs.databricks.com/en/integrations/jdbc/authentication.html

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in comment.

"""
url_seperator = ";"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
url_seperator = ";"
url_separator = ";"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

url_split = url_str.split(url_seperator)
if len(url_split) > 2:
raise ValueError(f"Expected at most 1 {repr(url_seperator)} in {url_str}")

        parsed_url = urllib.parse.urlparse(url_split[0])
        url_extra = url_split[1] if len(url_split) > 1 else None

        dialect_driver = parsed_url.scheme.split("+")
        if len(dialect_driver) > 2:
            raise ValueError(f"Expected at most one + in {repr(parsed_url.scheme)}")
        dialect = dialect_driver[0]
        driver = dialect_driver[1] if len(dialect_driver) > 1 else None

        # redshift://../dbname -> /dbname -> dbname
        database = parsed_url.path.lstrip("/")

        query_fields = tuple(
            UrlQueryField(field_name, tuple(values))
            for field_name, values in urllib.parse.parse_qs(parsed_url.query).items()
        )

        # Handle the Databricks-style `;http_path=...` suffix described in the docstring.
        http_path = None
        if url_extra is not None:
            field_name_value_seperator = "="
            url_extra_split = url_extra.split(field_name_value_seperator)
            if len(url_extra_split) == 2:
                field_name = url_extra_split[0]
                value = url_extra_split[1]
                if field_name.lower() == "http_path":
                    http_path = value

        return SqlEngineConnectionParameterSet(
            url_str=url_str,
            dialect=dialect,
            driver=driver,
            username=parsed_url.username,
            password=parsed_url.password,
            hostname=parsed_url.hostname,
            port=parsed_url.port,
            database=database,
            query_fields=query_fields,
            http_path=http_path,
        )

    def get_query_field_values(self, field_name: str) -> Sequence[str]:
        """In the URL query, return the values for the field with the given name.

        Returns an empty sequence if the field name is not specified.
        """
        for field in self.query_fields:
            if field.field_name == field_name:
                return field.values
        return ()
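
For illustration, a short usage sketch of the parser added above. Both URLs are made-up examples, not real hosts or credentials:

from tests_metricflow.fixtures.connection_url import SqlEngineConnectionParameterSet

# A Snowflake-style URL: the warehouse rides in the query string.
snowflake = SqlEngineConnectionParameterSet.create_from_url(
    "snowflake://user:password@account/db?warehouse=test_wh"
)
assert snowflake.dialect == "snowflake"
assert snowflake.database == "db"
assert snowflake.get_query_field_values("warehouse") == ("test_wh",)
assert snowflake.get_query_field_values("missing") == ()

# A Databricks-style URL: `http_path` follows the semicolon, per the
# docstring and review thread above.
databricks = SqlEngineConnectionParameterSet.create_from_url(
    "databricks://host:443/database;http_path=sql/protocolv1/abc"
)
assert databricks.dialect == "databricks"
assert databricks.port == 443
assert databricks.http_path == "sql/protocolv1/abc"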
11 changes: 0 additions & 11 deletions tests_metricflow/fixtures/setup_fixtures.py
@@ -16,10 +16,8 @@
     add_display_snapshots_cli_flag,
     add_overwrite_snapshots_cli_flag,
 )
-from sqlalchemy.engine import make_url

 from tests_metricflow import TESTS_METRICFLOW_DIRECTORY_ANCHOR
-from tests_metricflow.fixtures.sql_clients.common_client import SqlDialect
 from tests_metricflow.snapshots import METRICFLOW_SNAPSHOT_DIRECTORY_ANCHOR
 from tests_metricflow.table_snapshot.table_snapshots import SqlTableSnapshotHash, SqlTableSnapshotRepository

@@ -84,7 +82,6 @@ def check_sql_engine_snapshot_marker(request: FixtureRequest) -> None:
 @pytest.fixture(scope="session")
 def mf_test_configuration(  # noqa: D103
     request: FixtureRequest,
-    disable_sql_alchemy_deprecation_warning: None,
     source_table_snapshot_repository: SqlTableSnapshotRepository,
 ) -> MetricFlowTestConfiguration:
     engine_url = os.environ.get("MF_SQL_ENGINE_URL")

@@ -154,14 +151,6 @@ def mf_test_configuration(  # noqa: D103
     )


-def dialect_from_url(url: str) -> SqlDialect:
-    """Return the SQL dialect specified in the URL in the configuration."""
-    dialect_protocol = make_url(url.split(";")[0]).drivername.split("+")
-    if len(dialect_protocol) > 2:
-        raise ValueError(f"Invalid # of +'s in {url}")
-    return SqlDialect(dialect_protocol[0])
-
-
 def dbt_project_dir() -> str:
     """Return the canonical path string for the dbt project dir in the test package.
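
For reference, a minimal sketch of how the deleted helper maps onto its replacement in sql_client_fixtures.py below. The URL is a hypothetical example, and this assumes the SqlDialect enum accepts the bare dialect string, as both the old and new call sites do:

from tests_metricflow.fixtures.connection_url import SqlEngineConnectionParameterSet
from tests_metricflow.fixtures.sql_clients.common_client import SqlDialect

url = "postgresql+psycopg2://user:password@host:5432/db"  # hypothetical example

# Old: make_url(url).drivername yields "postgresql+psycopg2"; the dialect is the
# part before "+". New: create_from_url() parses the scheme without SQLAlchemy.
dialect = SqlDialect(SqlEngineConnectionParameterSet.create_from_url(url).dialect)
assert dialect is SqlDialect.POSTGRESQL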
66 changes: 30 additions & 36 deletions tests_metricflow/fixtures/sql_client_fixtures.py
@@ -7,15 +7,14 @@
 from typing import Generator

 import pytest
-import sqlalchemy
-import sqlalchemy.util
 from _pytest.fixtures import FixtureRequest
 from dbt.adapters.factory import get_adapter_by_type
 from dbt.cli.main import dbtRunner
 from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration

 from metricflow.protocols.sql_client import SqlClient
-from tests_metricflow.fixtures.setup_fixtures import dbt_project_dir, dialect_from_url
+from tests_metricflow.fixtures.connection_url import SqlEngineConnectionParameterSet
+from tests_metricflow.fixtures.setup_fixtures import dbt_project_dir
 from tests_metricflow.fixtures.sql_clients.adapter_backed_ddl_client import AdapterBackedDDLSqlClient
 from tests_metricflow.fixtures.sql_clients.common_client import SqlDialect
 from tests_metricflow.fixtures.sql_clients.ddl_sql_client import SqlClientWithDDLMethods
@@ -50,33 +49,34 @@
 DBT_ENV_SECRET_CATALOG = "DBT_ENV_SECRET_CATALOG"


-def __configure_test_env_from_url(url: str, password: str, schema: str) -> sqlalchemy.engine.URL:
+def __configure_test_env_from_url(url: str, password: str, schema: str) -> SqlEngineConnectionParameterSet:
     """Populates default env var mapping from a sqlalchemy URL string.

     This is used to configure the test environment from the original MF_SQL_ENGINE_URL environment variable in
     a manner compatible with the dbt profile configurations laid out for most supported engines. We return
     the parsed URL object so that individual engine configurations can override the environment variables
     as needed to match their dbt profile configuration.
     """
-    parsed_url = sqlalchemy.engine.make_url(url)
+    # parsed_url = sqlalchemy.engine.make_url(url)

[Review comment from Contributor]: Delete this.

[Reply from Contributor Author]: Updated.

-    if parsed_url.drivername != "duckdb":
-        assert parsed_url.host, "Engine host is not set in engine connection URL!"
-        os.environ[DBT_ENV_SECRET_HOST] = parsed_url.host
+    connection_parameters = SqlEngineConnectionParameterSet.create_from_url(url)
+    if connection_parameters.dialect != "duckdb":
+        assert connection_parameters.hostname, "Engine host is not set in engine connection URL!"
+        os.environ[DBT_ENV_SECRET_HOST] = connection_parameters.hostname

-    if parsed_url.username:
-        os.environ[DBT_ENV_SECRET_USER] = parsed_url.username
+    if connection_parameters.username:
+        os.environ[DBT_ENV_SECRET_USER] = connection_parameters.username

-    if parsed_url.database:
-        os.environ[DBT_ENV_SECRET_DATABASE] = parsed_url.database
+    if connection_parameters.database:
+        os.environ[DBT_ENV_SECRET_DATABASE] = connection_parameters.database

-    if parsed_url.port:
-        os.environ[DBT_PROFILE_PORT] = str(parsed_url.port)
+    if connection_parameters.port:
+        os.environ[DBT_PROFILE_PORT] = str(connection_parameters.port)

     os.environ[DBT_ENV_SECRET_PASSWORD] = password
     os.environ[DBT_ENV_SECRET_SCHEMA] = schema

-    return parsed_url
+    return connection_parameters


 def __configure_bigquery_env_from_credential_string(password: str, schema: str) -> None:

@@ -140,37 +140,38 @@ def __initialize_dbt() -> None:
 def make_test_sql_client(url: str, password: str, schema: str) -> SqlClientWithDDLMethods:
     """Build SQL client based on env configs."""
     # TODO: Switch on an enum of adapter type when all engines are cut over
-    dialect = dialect_from_url(url=url)
+    dialect = SqlDialect(SqlEngineConnectionParameterSet.create_from_url(url).dialect)

-    if dialect == SqlDialect.REDSHIFT:
+    if dialect is SqlDialect.REDSHIFT:
         __configure_test_env_from_url(url, password=password, schema=schema)
         __initialize_dbt()
         return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("redshift"))
-    elif dialect == SqlDialect.SNOWFLAKE:
-        parsed_url = __configure_test_env_from_url(url, password=password, schema=schema)
-        assert "warehouse" in parsed_url.normalized_query, "Sql engine URL params did not include Snowflake warehouse!"
-        warehouses = parsed_url.normalized_query["warehouse"]
-        assert len(warehouses) == 1, f"Found more than 1 warehouse in Snowflake URL: `{warehouses}`"
-        os.environ[DBT_ENV_SECRET_WAREHOUSE] = warehouses[0]
+    elif dialect is SqlDialect.SNOWFLAKE:
+        connection_parameters = __configure_test_env_from_url(url, password=password, schema=schema)
+        warehouse_names = connection_parameters.get_query_field_values("warehouse")
+        assert (
+            len(warehouse_names) == 1
+        ), f"SQL engine URL did not specify exactly 1 Snowflake warehouse! Got {warehouse_names}"
+        os.environ[DBT_ENV_SECRET_WAREHOUSE] = warehouse_names[0]
         __initialize_dbt()
         return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("snowflake"))
-    elif dialect == SqlDialect.BIGQUERY:
+    elif dialect is SqlDialect.BIGQUERY:
         __configure_bigquery_env_from_credential_string(password=password, schema=schema)
         __initialize_dbt()
         return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("bigquery"))
-    elif dialect == SqlDialect.POSTGRESQL:
+    elif dialect is SqlDialect.POSTGRESQL:
         __configure_test_env_from_url(url, password=password, schema=schema)
         __initialize_dbt()
         return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("postgres"))
-    elif dialect == SqlDialect.DUCKDB:
+    elif dialect is SqlDialect.DUCKDB:
         __configure_test_env_from_url(url, password=password, schema=schema)
         __initialize_dbt()
         return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("duckdb"))
-    elif dialect == SqlDialect.DATABRICKS:
+    elif dialect is SqlDialect.DATABRICKS:
         __configure_databricks_env_from_url(url, password=password, schema=schema)
         __initialize_dbt()
         return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("databricks"))
-    elif dialect == SqlDialect.TRINO:
+    elif dialect is SqlDialect.TRINO:
         __configure_test_env_from_url(url, password=password, schema=schema)
         __initialize_dbt()
         return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("trino"))
@@ -237,7 +238,7 @@ def warn_user_about_slow_tests_without_parallelism(  # noqa: D103
         ) from e

     num_items = len(request.session.items)
-    dialect = dialect_from_url(mf_test_configuration.sql_engine_url)
+    dialect = SqlDialect(SqlEngineConnectionParameterSet.create_from_url(mf_test_configuration.sql_engine_url).dialect)

     # If already running in parallel or if there's not many test items, no need to print the warning. Picking 10/30 as
     # the thresholds, but not much thought has been put into it.
@@ -254,10 +255,3 @@ def warn_user_about_slow_tests_without_parallelism(  # noqa: D103
         f'Consider using the pytest-xdist option "-n <number of workers>" to parallelize execution and speed '
         f"up the session."
     )
-
-
-@pytest.fixture(scope="session", autouse=True)
-def disable_sql_alchemy_deprecation_warning() -> None:
-    """Since MF is tied to using SQLAlchemy 1.x.x due to the Snowflake connector, silence 2.0 deprecation warnings."""
-    # Seeing 'error: Module has no attribute "SILENCE_UBER_WARNING"' in the type checker, but this seems to work.
-    sqlalchemy.util.deprecations.SILENCE_UBER_WARNING = True  # type:ignore
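
To make the env-var plumbing concrete, here is a minimal sketch of what `__configure_test_env_from_url` would export for a made-up Postgres URL. The URL and values are hypothetical; the env var names are the module constants shown above:

from tests_metricflow.fixtures.connection_url import SqlEngineConnectionParameterSet

url = "postgresql://mf_user@pg.example.com:5432/mf_db"  # hypothetical URL
connection_parameters = SqlEngineConnectionParameterSet.create_from_url(url)

# Mirrors the guard above: every non-duckdb engine must carry a host.
assert connection_parameters.dialect == "postgresql"
assert connection_parameters.hostname == "pg.example.com"

# After __configure_test_env_from_url(url, password="...", schema="test_schema"),
# the environment would hold:
#   os.environ[DBT_ENV_SECRET_HOST]     -> "pg.example.com"
#   os.environ[DBT_ENV_SECRET_USER]     -> "mf_user"
#   os.environ[DBT_ENV_SECRET_DATABASE] -> "mf_db"
#   os.environ[DBT_PROFILE_PORT]        -> "5432"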