Databricks Support (#260)
courtneyholcomb authored Oct 2, 2022
1 parent 5bc51e1 commit 7ae092e
Showing 107 changed files with 7,020 additions and 120 deletions.
33 changes: 33 additions & 0 deletions .github/workflows/cd-sql-engine-tests.yaml
@@ -107,6 +107,39 @@ jobs:
MF_SQL_ENGINE_URL: ${{ secrets.MF_BIGQUERY_URL }}
MF_SQL_ENGINE_PASSWORD: ${{ secrets.MF_BIGQUERY_PWD }}

databricks-tests:
environment: DW_INTEGRATION_TESTS
name: Databricks Tests
if: ${{ github.event.action != 'labeled' || github.event.label.name == 'run_mf_sql_engine_tests' }}
runs-on: ubuntu-latest
steps:
- name: Check-out the repo
uses: actions/checkout@v2

- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: "3.8"

- uses: actions/cache@v2
with:
path: |
${{ env.pythonLocation }}
~/.cache/pypoetry
key: ${{ env.pythonLocation }}-${{ hashFiles('metricflow/poetry.lock') }}

- name: Install Poetry
run: pip install poetry==1.1.15 && poetry config virtualenvs.create false

- name: Install Deps
run: cd metricflow && poetry install

- name: Run MetricFlow unit tests with Databricks configs
run: pytest metricflow/test/
env:
MF_SQL_ENGINE_URL: ${{ secrets.MF_DATABRICKS_URL }}
MF_SQL_ENGINE_PASSWORD: ${{ secrets.MF_DATABRICKS_PWD }}

postgres-tests:
name: PostgreSQL Tests
if: ${{ github.event.action != 'labeled' || github.event.label.name == 'run_mf_sql_engine_tests' }}
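The `databricks-tests` job above follows the same engine-agnostic convention as the other warehouse jobs: each engine's credentials reach the suite only as `MF_SQL_ENGINE_URL` and `MF_SQL_ENGINE_PASSWORD`. A minimal sketch of how those two variables feed a client, assuming a hypothetical URL in the format parsed by `databricks.py`:

```python
import os

from metricflow.sql_clients.sql_utils import make_sql_client

# MF_DATABRICKS_URL is assumed to look like (hypothetical values):
#   databricks://my-workspace.cloud.databricks.com:443;HttpPath=/sql/1.0/endpoints/abc123
# MF_DATABRICKS_PWD is assumed to hold a Databricks personal access token.
sql_client = make_sql_client(
    url=os.environ["MF_SQL_ENGINE_URL"],
    password=os.environ["MF_SQL_ENGINE_PASSWORD"],
)
```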
2 changes: 1 addition & 1 deletion .gitignore
@@ -108,7 +108,7 @@ celerybeat.pid
.env
.venv
env/
venv/
*venv
ENV/
env.bak/
venv.bak/
2 changes: 2 additions & 0 deletions metricflow/cli/main.py
@@ -33,6 +33,7 @@
start_end_time_options,
generate_duckdb_demo_keys,
MF_POSTGRESQL_KEYS,
MF_DATABRICKS_KEYS,
)
from metricflow.configuration.config_builder import YamlTemplateBuilder
from metricflow.dataflow.sql_table import SqlTable
@@ -144,6 +145,7 @@ def setup(cfg: CLIContext, restart: bool) -> None:
SqlDialect.REDSHIFT.value: MF_REDSHIFT_KEYS,
SqlDialect.POSTGRESQL.value: MF_POSTGRESQL_KEYS,
SqlDialect.DUCKDB.value: generate_duckdb_demo_keys(config_dir=cfg.config.dir_path),
SqlDialect.DATABRICKS.value: MF_DATABRICKS_KEYS,
}

click.echo("Please enter your data warehouse dialect.")
12 changes: 11 additions & 1 deletion metricflow/cli/utils.py
@@ -25,6 +25,8 @@
CONFIG_DWH_WAREHOUSE,
CONFIG_EMAIL,
CONFIG_MODEL_PATH,
CONFIG_DWH_HTTP_PATH,
CONFIG_DWH_ACCESS_TOKEN,
)
from metricflow.sql_clients.common_client import SqlDialect

@@ -51,7 +53,7 @@
ConfigKey(key=CONFIG_DWH_DIALECT, value=SqlDialect.BIGQUERY.value),
)

# Redshift config keys
# Postgres config keys
MF_POSTGRESQL_KEYS = (
ConfigKey(key=CONFIG_DWH_DB),
ConfigKey(key=CONFIG_DWH_PASSWORD, comment="Password associated with the provided user"),
@@ -80,6 +82,14 @@
ConfigKey(key=CONFIG_DWH_DIALECT, value=SqlDialect.SNOWFLAKE.value),
)

# Databricks config keys
MF_DATABRICKS_KEYS = (
ConfigKey(key=CONFIG_DWH_HTTP_PATH),
ConfigKey(key=CONFIG_DWH_HOST),
ConfigKey(key=CONFIG_DWH_ACCESS_TOKEN),
ConfigKey(key=CONFIG_DWH_DIALECT, value=SqlDialect.DATABRICKS.value),
)
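For illustration, a minimal sketch of a completed Databricks config using the key names from `metricflow/configuration/constants.py` (all values hypothetical; `CONFIG_DWH_HOST`'s string value is assumed to be `dwh_host`):

```python
# Hypothetical values a user might enter during the interactive setup flow.
databricks_config = {
    "dwh_dialect": "databricks",
    "dwh_host": "dbc-12345-ab67.cloud.databricks.com",   # workspace hostname
    "dwh_http_path": "/sql/1.0/endpoints/abc123def456",  # SQL endpoint HTTP path
    "dwh_access_token": "dapi...",                       # personal access token (elided)
}
```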


def generate_duckdb_demo_keys(config_dir: str) -> Tuple[ConfigKey, ...]:
"""Generate configuration keys for DuckDB with a file in the config_dir."""
2 changes: 2 additions & 0 deletions metricflow/configuration/constants.py
@@ -13,3 +13,5 @@
CONFIG_DWH_PROJECT_ID = "dwh_project_id"
CONFIG_EMAIL = "email"
CONFIG_MODEL_PATH = "model_path"
CONFIG_DWH_HTTP_PATH = "dwh_http_path"
CONFIG_DWH_ACCESS_TOKEN = "dwh_access_token"
6 changes: 6 additions & 0 deletions metricflow/protocols/sql_client.py
@@ -18,6 +18,7 @@ class SupportedSqlEngine(Enum):
REDSHIFT = "Redshift"
POSTGRES = "Postgres"
SNOWFLAKE = "Snowflake"
DATABRICKS = "Databricks"


class SqlClient(Protocol):
@@ -143,6 +144,11 @@ def cancel_submitted_queries(self) -> None: # noqa: D
"""Cancel queries submitted through this client (that may be still running) with best-effort."""
raise NotImplementedError

@abstractmethod
def render_execution_param_key(self, execution_param_key: str) -> str:
"""Wrap execution parameter key with syntax accepted by engine."""
raise NotImplementedError


class SqlEngineAttributes(Protocol):
"""Base interface for SQL engine-specific attributes and features
4 changes: 4 additions & 0 deletions metricflow/sql_clients/base_sql_client_implementation.py
@@ -203,3 +203,7 @@ def drop_table(self, sql_table: SqlTable) -> None: # noqa: D

def close(self) -> None: # noqa: D
pass

def render_execution_param_key(self, execution_param_key: str) -> str:
"""Wrap execution parameter key with syntax accepted by engine."""
return f":{execution_param_key}"
1 change: 1 addition & 0 deletions metricflow/sql_clients/common_client.py
@@ -12,6 +12,7 @@ class SqlDialect(ExtendedEnum):
MYSQL = "mysql"
SNOWFLAKE = "snowflake"
BIGQUERY = "bigquery"
DATABRICKS = "databricks"


T = TypeVar("T")
176 changes: 176 additions & 0 deletions metricflow/sql_clients/databricks.py
@@ -0,0 +1,176 @@
from __future__ import annotations
from typing import Optional, List, ClassVar, Dict
import pandas as pd
import logging
import time
import sqlalchemy
from databricks import sql
from metricflow.sql_clients.common_client import SqlDialect
from metricflow.sql_clients.base_sql_client_implementation import BaseSqlClientImplementation
from metricflow.protocols.sql_client import SqlEngineAttributes, SupportedSqlEngine
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.dataflow.sql_table import SqlTable
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer, SqlQueryPlanRenderer

logger = logging.getLogger(__name__)

HTTP_PATH_KEY = "httppath="
PANDAS_TO_SQL_DTYPES = {
"object": "string",
"float64": "double",
"bool": "boolean",
"int64": "int",
"datetime64[ns]": "timestamp",
}


class DatabricksEngineAttributes(SqlEngineAttributes):
"""SQL engine attributes for Databricks."""

sql_engine_type: ClassVar[SupportedSqlEngine] = SupportedSqlEngine.DATABRICKS

# SQL Engine capabilities
date_trunc_supported: ClassVar[bool] = True
full_outer_joins_supported: ClassVar[bool] = True
indexes_supported: ClassVar[bool] = True
multi_threading_supported: ClassVar[bool] = True
timestamp_type_supported: ClassVar[bool] = True
timestamp_to_string_comparison_supported: ClassVar[bool] = True
# So far the only clear way to cancel a query is through the Databricks UI.
cancel_submitted_queries_supported: ClassVar[bool] = False

# SQL Dialect replacement strings
double_data_type_name: ClassVar[str] = "DOUBLE"
timestamp_type_name: ClassVar[Optional[str]] = "TIMESTAMP"

# MetricFlow attributes
sql_query_plan_renderer: ClassVar[SqlQueryPlanRenderer] = DefaultSqlQueryPlanRenderer()


class DatabricksSqlClient(BaseSqlClientImplementation):
"""Client used to connect to Databricks engine."""

def __init__(self, host: str, http_path: str, access_token: str) -> None: # noqa: D
self.host = host
self.http_path = http_path
self.access_token = access_token

@staticmethod
def from_connection_details(url: str, password: Optional[str]) -> DatabricksSqlClient: # noqa: D
"""Parse MF_SQL_ENGINE_URL & MF_SQL_ENGINE_PASSWORD into useful connection params.
Using just these 2 env variables ensures uniformity across engines.
"""
try:
split_url = url.split(";")
parsed_url = sqlalchemy.engine.url.make_url(split_url[0])
http_path = ""
for piece in split_url[1:]:
if HTTP_PATH_KEY in piece.lower():
__, http_path = piece.split("=")
break
dialect = SqlDialect.DATABRICKS.value
if not http_path or parsed_url.drivername != dialect or not parsed_url.host:
raise ValueError
except ValueError:
# If any errors in parsing URL, show user what expected URL looks like.
raise ValueError(
"Unexpected format for MF_SQL_ENGINE_URL. Expected: `databricks://<HOST>:443;HttpPath=<HTTP PATH>"
)

if not password:
raise ValueError(f"Password not supplied for {url}")

return DatabricksSqlClient(host=parsed_url.host, http_path=http_path, access_token=password)
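To make the parsing above concrete, a hedged walk-through with a hypothetical URL:

```python
# Hypothetical URL, split the same way as in from_connection_details.
url = "databricks://my-workspace.cloud.databricks.com:443;HttpPath=/sql/1.0/endpoints/abc123"
split_url = url.split(";")
# split_url[0] == "databricks://my-workspace.cloud.databricks.com:443"
#   -> make_url() yields drivername "databricks" and host "my-workspace.cloud.databricks.com"
# split_url[1] lowercased contains "httppath=", so splitting on "=" gives
#   http_path == "/sql/1.0/endpoints/abc123"
```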

def get_connection(self) -> sql.client.Connection:
"""Get connection to Databricks cluster/warehouse."""
return sql.connect(server_hostname=self.host, http_path=self.http_path, access_token=self.access_token)

@property
def sql_engine_attributes(self) -> SqlEngineAttributes:
"""Databricks engine attributes."""
return DatabricksEngineAttributes()

@staticmethod
def params_or_none(bind_params: SqlBindParameters) -> Optional[Dict[str, str]]:
"""If there are no parameters, use None to prevent collision with `%` wildcard."""
return None if bind_params == SqlBindParameters() else bind_params.param_dict

def _engine_specific_query_implementation(self, stmt: str, bind_params: SqlBindParameters) -> pd.DataFrame:
with self.get_connection() as connection:
with connection.cursor() as cursor:
cursor.execute(operation=stmt, parameters=self.params_or_none(bind_params))
logger.info("Fetching query results as PyArrow Table.")
pyarrow_df = cursor.fetchall_arrow()

logger.info("Beginning conversion of PyArrow Table to pandas DataFrame.")
pandas_df = pyarrow_df.to_pandas()
logger.info("Completed conversion of PyArrow Table to pandas DataFrame.")
return pandas_df

def _engine_specific_execute_implementation(self, stmt: str, bind_params: SqlBindParameters) -> None:
"""Execute statement, returning nothing."""
with self.get_connection() as connection:
with connection.cursor() as cursor:
logger.info(f"Executing SQL statment: {stmt}")
cursor.execute(operation=stmt, parameters=self.params_or_none(bind_params))

def _engine_specific_dry_run_implementation(self, stmt: str, bind_params: SqlBindParameters) -> None:
"""Check that query will run successfully without actually running the query, error if not."""
stmt = f"EXPLAIN {stmt}"

with self.get_connection() as connection:
with connection.cursor() as cursor:
logger.info(f"Executing SQL statment: {stmt}")
cursor.execute(operation=stmt, parameters=self.params_or_none(bind_params))

# If the plan contains errors, they won't be raised. Parse plan string to find & raise errors.
result = str(cursor.fetchall_arrow()["plan"][0])
if "org.apache.spark.sql.AnalysisException" in result:
error = result.split("== Physical Plan ==")[1].split(";")[0]
raise sql.exc.ServerOperationError(error)
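The connector returns a failing `EXPLAIN` as plan text rather than raising, so the code above slices the exception out of the plan string. A sketch with a hypothetical plan value:

```python
# Hypothetical (abbreviated) shape of a failing EXPLAIN result.
plan = (
    "== Physical Plan ==\n"
    "org.apache.spark.sql.AnalysisException: Table or view not found: foo; line 1 pos 14"
)
error = plan.split("== Physical Plan ==")[1].split(";")[0]
# error == "\norg.apache.spark.sql.AnalysisException: Table or view not found: foo"
```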

def create_table_from_dataframe( # noqa: D
self, sql_table: SqlTable, df: pd.DataFrame, chunk_size: Optional[int] = None
) -> None:
logger.info(f"Creating table '{sql_table.sql}' from a DataFrame with {df.shape[0]} row(s)")
start_time = time.time()
with self.get_connection() as connection:
with connection.cursor() as cursor:
# Create table
columns = df.columns
columns_to_insert = []
for i in range(len(df.columns)):
# Format as "column_name column_type"
columns_to_insert.append(f"{columns[i]} {PANDAS_TO_SQL_DTYPES[str(df[columns[i]].dtype)]}")
cursor.execute(f"CREATE TABLE IF NOT EXISTS {sql_table.sql} ({', '.join(columns_to_insert)})")

# Insert rows
values = []
for row in df.itertuples(index=False, name=None):
cells = []
for cell in row:
if type(cell) in [str, pd.Timestamp]:
# Wrap cell in quotes & escape existing single quotes
escaped_cell = str(cell).replace("'", "\\'")
cells.append(f"'{escaped_cell}'")
else:
cells.append(str(cell))
values.append(f"({', '.join(cells)})")
cursor.execute(f"INSERT INTO {sql_table.sql} VALUES {', '.join(values)}")

logger.info(f"Created table '{sql_table.sql}' from a DataFrame in {time.time() - start_time:.2f}s")

def list_tables(self, schema_name: str) -> List[str]: # noqa: D
with self.get_connection() as connection:
with connection.cursor() as cursor:
cursor.tables(schema_name=schema_name)
return [table.TABLE_NAME for table in cursor.fetchall()]

def cancel_submitted_queries(self) -> None: # noqa: D
pass

def render_execution_param_key(self, execution_param_key: str) -> str:
"""Wrap execution parameter key with syntax accepted by engine."""
return f"%({execution_param_key})s"
2 changes: 1 addition & 1 deletion metricflow/sql_clients/postgres.py
@@ -83,7 +83,7 @@ def __init__( # noqa: D

@property
def sql_engine_attributes(self) -> SqlEngineAttributes:
"""Collection of attributes and features specific to the Snowflake SQL engine"""
"""Collection of attributes and features specific to the Postgres SQL engine"""
return PostgresEngineAttributes()

def cancel_submitted_queries(self) -> None: # noqa: D
15 changes: 13 additions & 2 deletions metricflow/sql_clients/sql_utils.py
@@ -15,6 +15,8 @@
CONFIG_DWH_PROJECT_ID,
CONFIG_DWH_USER,
CONFIG_DWH_WAREHOUSE,
CONFIG_DWH_ACCESS_TOKEN,
CONFIG_DWH_HTTP_PATH,
)
from metricflow.configuration.yaml_handler import YamlFileHandler
from metricflow.protocols.sql_client import SqlClient
@@ -24,6 +26,7 @@
from metricflow.sql_clients.postgres import PostgresSqlClient
from metricflow.sql_clients.redshift import RedshiftSqlClient
from metricflow.sql_clients.snowflake import SnowflakeSqlClient
from metricflow.sql_clients.databricks import DatabricksSqlClient


def make_df( # type: ignore [misc]
@@ -56,8 +59,9 @@ def make_df( # type: ignore [misc]
)


def make_sql_client(url: str, password: str) -> SqlClient: # noqa: D
dialect_protocol = make_url(url).drivername.split("+")
def make_sql_client(url: str, password: str) -> SqlClient:
"""Build SQL client based on env configs. Used only in tests."""
dialect_protocol = make_url(url.split(";")[0]).drivername.split("+")
dialect = SqlDialect(dialect_protocol[0])
if len(dialect_protocol) > 2:
raise ValueError(f"Invalid # of +'s in {url}")
@@ -72,6 +76,8 @@ def make_sql_client(url: str, password: str) -> SqlClient: # noqa: D
return PostgresSqlClient.from_connection_details(url, password)
elif dialect == SqlDialect.DUCKDB:
return DuckDbSqlClient.from_connection_details(url, password)
elif dialect == SqlDialect.DATABRICKS:
return DatabricksSqlClient.from_connection_details(url, password)
else:
raise ValueError(f"Unknown dialect: `{dialect}` in URL {url}")

@@ -139,6 +145,11 @@ def make_sql_client_from_config(handler: YamlFileHandler) -> SqlClient:
password=password,
database=database,
)
elif dialect == SqlDialect.DATABRICKS.value:
host = not_empty(handler.get_value(CONFIG_DWH_HOST), CONFIG_DWH_HOST, url)
access_token = not_empty(handler.get_value(CONFIG_DWH_ACCESS_TOKEN), CONFIG_DWH_ACCESS_TOKEN, url)
http_path = not_empty(handler.get_value(CONFIG_DWH_HTTP_PATH), CONFIG_DWH_HTTP_PATH, url)
return DatabricksSqlClient(host=host, access_token=access_token, http_path=http_path)
else:
supported_dialects = [x.value for x in SqlDialect]
raise ValueError(f"Invalid dialect '{dialect}', must be one of {supported_dialects} in {url}")
10 changes: 9 additions & 1 deletion metricflow/test/compare_df.py
@@ -1,5 +1,4 @@
import logging

import math
import pandas as pd

@@ -26,6 +25,15 @@ def _dataframes_contain_same_data(
elif isinstance(expected.iloc[c, r], float) and isinstance(actual.iloc[c, r], float):
if not math.isclose(expected.iloc[c, r], actual.iloc[c, r]):
return False
elif (
isinstance(expected.iloc[c, r], pd.Timestamp)
and isinstance(actual.iloc[c, r], pd.Timestamp)
# If expected has no tz but actual is UTC, remove timezone. Some engines add UTC by default.
and actual.iloc[c, r].tzname() == "UTC"
and expected.iloc[c, r].tzname() is None
):
if actual.iloc[c, r].tz_localize(None) != expected.iloc[c, r].tz_localize(None):
return False
elif expected.iloc[c, r] != actual.iloc[c, r]:
return False
return True
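A short sketch of the normalization this adds (timestamps hypothetical):

```python
import pandas as pd

expected_ts = pd.Timestamp("2022-01-01 00:00:00")          # tz-naive fixture value
actual_ts = pd.Timestamp("2022-01-01 00:00:00", tz="UTC")  # engine attached UTC
assert actual_ts.tzname() == "UTC" and expected_ts.tzname() is None
# Dropping the timezone makes the two values comparable:
assert actual_ts.tz_localize(None) == expected_ts.tz_localize(None)
```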