diff --git a/tests_metricflow/fixtures/sql_clients/base_sql_client_implementation.py b/tests_metricflow/fixtures/sql_clients/base_sql_client_implementation.py deleted file mode 100644 index 8f7df405fb..0000000000 --- a/tests_metricflow/fixtures/sql_clients/base_sql_client_implementation.py +++ /dev/null @@ -1,145 +0,0 @@ -from __future__ import annotations - -import logging -import time -from abc import ABC, abstractmethod -from typing import Optional - -import pandas as pd -from metricflow_semantics.mf_logging.formatting import indent -from metricflow_semantics.mf_logging.pretty_print import mf_pformat -from metricflow_semantics.random_id import random_id -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters - -from metricflow.protocols.sql_client import ( - SqlClient, -) -from metricflow.sql.sql_table import SqlTable -from metricflow.sql_request.sql_request_attributes import SqlRequestId - -logger = logging.getLogger(__name__) - - -class SqlClientException(Exception): - """Raised when an interaction with the SQL engine has an error.""" - - pass - - -class BaseSqlClientImplementation(ABC, SqlClient): - """Abstract implementation that other SQL clients are based on.""" - - @staticmethod - def _format_run_query_log_message(statement: str, sql_bind_parameters: SqlBindParameters) -> str: - message = f"Running query:\n\n{indent(statement)}" - if len(sql_bind_parameters.param_dict) > 0: - message += f"\n\nwith parameters:\n\n{indent(mf_pformat(sql_bind_parameters.param_dict))}" - return message - - def query( - self, - stmt: str, - sql_bind_parameters: SqlBindParameters = SqlBindParameters(), - ) -> pd.DataFrame: - """Query statement; result expected to be data which will be returned as a DataFrame. - - Args: - stmt: The SQL query statement to run. This should produce output via a SELECT - sql_bind_parameters: The parameter replacement mapping for filling in - concrete values for SQL query parameters. - """ - start = time.time() - SqlRequestId(f"mf_rid__{random_id()}") - logger.info(BaseSqlClientImplementation._format_run_query_log_message(stmt, sql_bind_parameters)) - df = self._engine_specific_query_implementation( - stmt=stmt, - bind_params=sql_bind_parameters, - ) - if not isinstance(df, pd.DataFrame): - raise RuntimeError(f"Expected query to return a DataFrame, got {type(df)}") - stop = time.time() - logger.info(f"Finished running the query in {stop - start:.2f}s with {df.shape[0]} row(s) returned") - return df - - def execute( # noqa: D102 - self, - stmt: str, - sql_bind_parameters: SqlBindParameters = SqlBindParameters(), - ) -> None: - start = time.time() - logger.info(BaseSqlClientImplementation._format_run_query_log_message(stmt, sql_bind_parameters)) - self._engine_specific_execute_implementation( - stmt=stmt, - bind_params=sql_bind_parameters, - ) - stop = time.time() - logger.info(f"Finished running the query in {stop - start:.2f}s") - return None - - def dry_run( - self, - stmt: str, - sql_bind_parameters: SqlBindParameters = SqlBindParameters(), - ) -> None: - """Dry run statement; checks that the 'stmt' is queryable. Returns None. Raises an exception if the 'stmt' isn't queryable. - - Args: - stmt: The SQL query statement to dry run. - sql_bind_parameters: The parameter replacement mapping for filling in - concrete values for SQL query parameters. - """ - start = time.time() - logger.info( - f"Running dry_run of:" - f"\n\n{indent(stmt)}\n" - + (f"\nwith parameters: {dict(sql_bind_parameters.param_dict)}" if sql_bind_parameters.param_dict else "") - ) - results = self._engine_specific_dry_run_implementation(stmt, sql_bind_parameters) - stop = time.time() - logger.info(f"Finished running the dry_run in {stop - start:.2f}s") - return results - - @abstractmethod - def _engine_specific_query_implementation( - self, - stmt: str, - bind_params: SqlBindParameters, - ) -> pd.DataFrame: - """Sub-classes should implement this to query the engine.""" - pass - - @abstractmethod - def _engine_specific_execute_implementation( - self, - stmt: str, - bind_params: SqlBindParameters, - ) -> None: - """Sub-classes should implement this to execute a statement that doesn't return results.""" - pass - - @abstractmethod - def _engine_specific_dry_run_implementation(self, stmt: str, bind_params: SqlBindParameters) -> None: - """Sub-classes should implement this to check a query will run successfully without actually running the query.""" - pass - - @abstractmethod - def create_table_from_dataframe( # noqa: D102 - self, - sql_table: SqlTable, - df: pd.DataFrame, - chunk_size: Optional[int] = None, - ) -> None: - pass - - def create_schema(self, schema_name: str) -> None: # noqa: D102 - self.execute(f"CREATE SCHEMA IF NOT EXISTS {schema_name}") - - def drop_schema(self, schema_name: str, cascade: bool = True) -> None: # noqa: D102 - self.execute(f"DROP SCHEMA IF EXISTS {schema_name}{' CASCADE' if cascade else ''}") - - def close(self) -> None: # noqa: D102 - pass - - def render_bind_parameter_key(self, bind_parameter_key: str) -> str: - """Wrap execution parameter key with syntax accepted by engine.""" - return f":{bind_parameter_key}" diff --git a/tests_metricflow/fixtures/sql_clients/sqlalchemy_dialect.py b/tests_metricflow/fixtures/sql_clients/sqlalchemy_dialect.py deleted file mode 100644 index b5ea8f2440..0000000000 --- a/tests_metricflow/fixtures/sql_clients/sqlalchemy_dialect.py +++ /dev/null @@ -1,143 +0,0 @@ -from __future__ import annotations - -import logging -import time -from abc import ABC -from contextlib import contextmanager -from typing import Iterator, Mapping, Optional, Sequence, Set, Union - -import pandas as pd -import sqlalchemy -from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters - -from metricflow.sql.sql_table import SqlTable -from tests_metricflow.fixtures.sql_clients.base_sql_client_implementation import BaseSqlClientImplementation - -logger = logging.getLogger(__name__) - - -class SqlAlchemySqlClient(BaseSqlClientImplementation, ABC): - """Base class for to create DBClients for engines supported by SQLAlchemy.""" - - def __init__(self, engine: sqlalchemy.engine.Engine) -> None: # noqa: D107 - self._engine = engine - super().__init__() - - @staticmethod - def build_engine_url( # noqa: D102 - dialect: str, - database: str, - username: str, - password: Optional[str], - host: str, - port: Optional[int] = None, - query: Optional[Mapping[str, Union[str, Sequence[str]]]] = None, - driver: Optional[str] = None, - ) -> sqlalchemy.engine.url.URL: - return sqlalchemy.engine.url.URL.create( - f"{dialect}+{driver}" if driver else f"{dialect}", - username=username, - password=password, - host=host, - port=port, - database=database, - **({"query": query} if query is not None else {}), - ) - - @staticmethod - def create_engine( # noqa: D102 - dialect: str, - port: int, - database: str, - username: str, - password: str, - host: str, - driver: Optional[str] = None, - query: Optional[Mapping[str, Union[str, Sequence[str]]]] = None, - ) -> sqlalchemy.engine.Engine: - connect_url = SqlAlchemySqlClient.build_engine_url( - dialect=dialect, - driver=driver, - username=username, - password=password, - host=host, - port=port, - database=database, - query=query, - ) - # Without pool_pre_ping, it's possible for timed-out connections to be returned to the client and cause errors. - # However, this can cause increase latency for slow engines. - return sqlalchemy.create_engine( - connect_url, - pool_size=10, - max_overflow=10, - pool_pre_ping=True, - ) - - @contextmanager - def _engine_connection( - self, - engine: sqlalchemy.engine.Engine, - ) -> Iterator[sqlalchemy.engine.Connection]: - """Context Manager for providing a configured connection.""" - conn = engine.connect() - try: - yield conn - finally: - conn.close() - - def _engine_specific_query_implementation( - self, - stmt: str, - bind_params: SqlBindParameters, - ) -> pd.DataFrame: - with self._engine_connection(self._engine) as conn: - return pd.read_sql_query(sqlalchemy.text(stmt), conn, params=bind_params.param_dict) - - def _engine_specific_execute_implementation( - self, - stmt: str, - bind_params: SqlBindParameters, - ) -> None: - with self._engine_connection(self._engine) as conn: - conn.execute(sqlalchemy.text(stmt), bind_params.param_dict) - - def _engine_specific_dry_run_implementation(self, stmt: str, bind_params: SqlBindParameters) -> None: - with self._engine_connection(self._engine) as conn: - s = "EXPLAIN " + stmt - conn.execute(sqlalchemy.text(s), bind_params.param_dict) - - def create_table_from_dataframe( # noqa: D102 - self, sql_table: SqlTable, df: pd.DataFrame, chunk_size: Optional[int] = None - ) -> None: - logger.info(f"Creating table '{sql_table.sql}' from a DataFrame with {df.shape[0]} row(s)") - start_time = time.time() - with self._engine_connection(self._engine) as conn: - pd.io.sql.to_sql( - frame=df, - name=sql_table.table_name, - con=conn, - schema=sql_table.schema_name, - index=False, - if_exists="fail", - method="multi", - chunksize=chunk_size, - ) - logger.info(f"Created table '{sql_table.sql}' from a DataFrame in {time.time() - start_time:.2f}s") - - @staticmethod - def validate_query_params( - url: sqlalchemy.engine.url.URL, - required_parameters: Set[str], - optional_parameters: Set[str], - ) -> None: - """Checks that the query parameters in the URL only include the valid parameters specified.""" - query_keys = set(url.query.keys()) - errors = [] - if not query_keys.issuperset(required_parameters): - errors.append(f"Missing required parameters {required_parameters - query_keys}") - if not query_keys.issubset(required_parameters.union(optional_parameters)): - errors.append(f"Found extra parameters {query_keys - required_parameters.union(optional_parameters)}") - - if errors: - raise ValueError(f"Found errors in the URL: {url}\n" + "\n".join(errors))