Merge pull request #944 from dbt-labs/add-trino-support-with-cd-config
Add Trino Support
tlento authored Dec 19, 2023
2 parents 92a517e + 7e931bf commit aa7724b
Showing 227 changed files with 30,629 additions and 88 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231008-195608.yaml
@@ -0,0 +1,6 @@
kind: Features
body: Add Trino support to MetricFlow.
time: 2023-10-08T19:56:08.427006-06:00
custom:
Author: sarbmeetka
Issue: "207"
23 changes: 23 additions & 0 deletions .github/workflows/cd-sql-engine-tests.yaml
@@ -91,6 +91,29 @@ jobs:
additional-pytest-options: ${{ env.ADDITIONAL_PYTEST_OPTIONS }}
make-target: "test-databricks"

trino-tests:
# Trino tests run on a local service container, which obviates the need for separate Environment hosting.
# We run them here instead of in the CI unit test suite because they are a bit slower, and because in future
# we may choose to execute them against a hosted instance, at which point this config will look like the other
# engine configs in this file.
name: Trino Tests
if: ${{ github.event.action != 'labeled' || github.event.label.name == 'Run Tests With Other SQL Engines' }}
runs-on: ubuntu-latest
services:
trino:
image: trinodb/trino
ports:
- 8080:8080
steps:
- name: Check-out the repo
uses: actions/checkout@v3

- name: Test w/ Python 3.11
uses: ./.github/actions/run-mf-tests
with:
python-version: "3.11"
make-target: "test-trino"

remove-label:
name: Remove Label After Running Tests
runs-on: ubuntu-latest
3 changes: 3 additions & 0 deletions Makefile
@@ -59,6 +59,9 @@ test-snowflake:
populate-persistent-source-schema-snowflake:
hatch -v run snowflake-env:pytest -vv $(ADDITIONAL_PYTEST_OPTIONS) $(USE_PERSISTENT_SOURCE_SCHEMA) $(POPULATE_PERSISTENT_SOURCE_SCHEMA)

.PHONY: test-trino
test-trino:
hatch -v run trino-env:pytest -vv -n $(PARALLELISM) $(ADDITIONAL_PYTEST_OPTIONS) metricflow/test/

.PHONY: lint
lint:
3 changes: 3 additions & 0 deletions dbt-metricflow/pyproject.toml
@@ -50,6 +50,9 @@ redshift = [
snowflake = [
"dbt-snowflake~=1.7.0"
]
trino = [
"dbt-trino~=1.7.0"
]

[tool.hatch.build.targets.sdist]
exclude = [
19 changes: 18 additions & 1 deletion metricflow/cli/dbt_connectors/adapter_backed_client.py
@@ -21,6 +21,7 @@
from metricflow.sql.render.redshift import RedshiftSqlQueryPlanRenderer
from metricflow.sql.render.snowflake import SnowflakeSqlQueryPlanRenderer
from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer
from metricflow.sql.render.trino import TrinoSqlQueryPlanRenderer
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql_request.sql_request_attributes import SqlJsonTag, SqlRequestId, SqlRequestTagSet
from metricflow.sql_request.sql_statement_metadata import CombinedSqlTags, SqlStatementCommentMetadata
@@ -42,6 +43,7 @@ class SupportedAdapterTypes(enum.Enum):
REDSHIFT = "redshift"
BIGQUERY = "bigquery"
DUCKDB = "duckdb"
TRINO = "trino"

@property
def sql_engine_type(self) -> SqlEngine:
@@ -58,6 +60,8 @@ def sql_engine_type(self) -> SqlEngine:
return SqlEngine.SNOWFLAKE
elif self is SupportedAdapterTypes.DUCKDB:
return SqlEngine.DUCKDB
elif self is SupportedAdapterTypes.TRINO:
return SqlEngine.TRINO
else:
assert_values_exhausted(self)

@@ -76,6 +80,8 @@ def sql_query_plan_renderer(self) -> SqlQueryPlanRenderer:
return SnowflakeSqlQueryPlanRenderer()
elif self is SupportedAdapterTypes.DUCKDB:
return DuckDbSqlQueryPlanRenderer()
elif self is SupportedAdapterTypes.TRINO:
return TrinoSqlQueryPlanRenderer()
else:
assert_values_exhausted(self)
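The assert_values_exhausted calls above turn these if/elif chains into checked exhaustive matches over the enum: adding a new adapter type without handling it fails type checking. A minimal sketch of the pattern (the real helper lives in dbt_semantic_interfaces.enum_extension; this version is only illustrative):

from typing import NoReturn

def assert_values_exhausted(value: NoReturn) -> NoReturn:
    """If any enum member is unhandled, type checkers flag this call site."""
    raise AssertionError(f"Should be unreachable, got: {value}")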

@@ -213,7 +219,18 @@ def dry_run(
request_id = SqlRequestId(f"mf_rid__{random_id()}")
connection_name = f"MetricFlow_dry_run_request_{request_id}"
# TODO - consolidate to self._adapter.validate_sql() when all implementations will work from within MetricFlow
if self.sql_engine_type is SqlEngine.BIGQUERY:

# Trino has a bug where the EXPLAIN command actually creates the table. Wrapping the query in EXPLAIN (TYPE VALIDATE) avoids this.
# See https://github.com/trinodb/trino/issues/130
if self.sql_engine_type is SqlEngine.TRINO:
with self._adapter.connection_named(connection_name):
# The response will be either a boolean value or a string with an error message from Trino.
result = self._adapter.execute(f"EXPLAIN (type validate) {stmt}", auto_begin=True, fetch=True)
has_error = str(result[0]) != "SUCCESS"
if has_error:
raise DbtDatabaseError("Encountered error in Trino dry run.")

elif self.sql_engine_type is SqlEngine.BIGQUERY:
with self._adapter.connection_named(connection_name):
self._adapter.validate_sql(stmt)
else:
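For illustration, a minimal sketch of the validation trick above run directly against a local Trino with the trino client package (connection values assume the workflow's localhost:8080 service container; MetricFlow itself goes through the dbt adapter instead):

import trino  # assumes the `trino` Python client is installed

conn = trino.dbapi.connect(host="localhost", port=8080, user="mf_test", catalog="memory", schema="default")
cursor = conn.cursor()
# EXPLAIN (TYPE VALIDATE) checks the statement without executing it or creating tables.
cursor.execute("EXPLAIN (TYPE VALIDATE) SELECT 1")
print(cursor.fetchall())  # one row indicating the statement is valid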
1 change: 1 addition & 0 deletions metricflow/protocols/sql_client.py
@@ -23,6 +23,7 @@ class SqlEngine(Enum):
POSTGRES = "Postgres"
SNOWFLAKE = "Snowflake"
DATABRICKS = "Databricks"
TRINO = "Trino"


class SqlClient(Protocol):
126 changes: 126 additions & 0 deletions metricflow/sql/render/trino.py
@@ -0,0 +1,126 @@
from __future__ import annotations

from typing import Collection

from dateutil.parser import parse
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from typing_extensions import override

from metricflow.sql.render.expr_renderer import (
DefaultSqlExpressionRenderer,
SqlExpressionRenderer,
SqlExpressionRenderResult,
)
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql.sql_exprs import (
SqlBetweenExpression,
SqlGenerateUuidExpression,
SqlPercentileExpression,
SqlPercentileFunctionType,
SqlSubtractTimeIntervalExpression,
)


class TrinoSqlExpressionRenderer(DefaultSqlExpressionRenderer):
"""Expression renderer for the Trino engine."""

@property
@override
def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctionType]:
return {
SqlPercentileFunctionType.APPROXIMATE_CONTINUOUS,
}

@override
def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult:
return SqlExpressionRenderResult(
sql="uuid()",
bind_parameters=SqlBindParameters(),
)

@override
def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult:
"""Render time delta for Trino, require granularity in quotes and function name change."""
arg_rendered = node.arg.accept(self)

count = node.count
granularity = node.granularity
if granularity == TimeGranularity.QUARTER:
granularity = TimeGranularity.MONTH
count *= 3
return SqlExpressionRenderResult(
sql=f"DATE_ADD('{granularity.value}', -{count}, {arg_rendered.sql})",
bind_parameters=arg_rendered.bind_parameters,
)

@override
def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionRenderResult:
"""Render a percentile expression for Trino."""
arg_rendered = self.render_sql_expr(node.order_by_arg)
params = arg_rendered.bind_parameters
percentile = node.percentile_args.percentile

if node.percentile_args.function_type is SqlPercentileFunctionType.APPROXIMATE_CONTINUOUS:
return SqlExpressionRenderResult(
sql=f"approx_percentile({arg_rendered.sql}, {percentile})",
bind_parameters=params,
)
elif (
node.percentile_args.function_type is SqlPercentileFunctionType.APPROXIMATE_DISCRETE
or node.percentile_args.function_type is SqlPercentileFunctionType.DISCRETE
or node.percentile_args.function_type is SqlPercentileFunctionType.CONTINUOUS
):
raise RuntimeError(
"Discrete, Continuous and Approximate discrete percentile aggregates are not supported for Trino. Set "
+ "use_approximate_percentile and disable use_discrete_percentile in all percentile measures."
)
else:
assert_values_exhausted(node.percentile_args.function_type)

@override
def visit_between_expr(self, node: SqlBetweenExpression) -> SqlExpressionRenderResult:
"""Render a between expression for Trino. If the expression is a timestamp literal then wrap literals with timestamp."""
rendered_column_arg = self.render_sql_expr(node.column_arg)
rendered_start_expr = self.render_sql_expr(node.start_expr)
rendered_end_expr = self.render_sql_expr(node.end_expr)

bind_parameters = SqlBindParameters()
bind_parameters = bind_parameters.combine(rendered_column_arg.bind_parameters)
bind_parameters = bind_parameters.combine(rendered_start_expr.bind_parameters)
bind_parameters = bind_parameters.combine(rendered_end_expr.bind_parameters)

# Handle timestamp literals differently: Trino needs the timestamp keyword before the literal.
# dateutil's parse raises on non-date strings, so a parse failure means "not a timestamp literal".
try:
parse(rendered_start_expr.sql)
sql = f"{rendered_column_arg.sql} BETWEEN timestamp {rendered_start_expr.sql} AND timestamp {rendered_end_expr.sql}"
except ValueError:
sql = f"{rendered_column_arg.sql} BETWEEN {rendered_start_expr.sql} AND {rendered_end_expr.sql}"

return SqlExpressionRenderResult(
sql=sql,
bind_parameters=bind_parameters,
)

@override
def render_date_part(self, date_part: DatePart) -> str:
"""Render DATE PART for an EXTRACT expression.
Map DatePart.DOW to Trino's DAY_OF_WEEK, which uses ISO numbering, to ensure all engines return consistent results.
"""
if date_part is DatePart.DOW:
return "DAY_OF_WEEK"

return date_part.value


class TrinoSqlQueryPlanRenderer(DefaultSqlQueryPlanRenderer):
"""Plan renderer for the Trino engine."""

EXPR_RENDERER = TrinoSqlExpressionRenderer()

@property
@override
def expr_renderer(self) -> SqlExpressionRenderer:
return self.EXPR_RENDERER
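To make the quarter handling in visit_time_delta_expr concrete: Trino's DATE_ADD has no quarter unit, so the renderer rewrites quarters as months. A standalone sketch of that rewrite (the names here are illustrative, not MetricFlow API):

def render_time_delta(column_sql: str, count: int, granularity: str) -> str:
    """Mirrors the quarter-to-month rewrite in TrinoSqlExpressionRenderer."""
    if granularity == "quarter":  # DATE_ADD has no 'quarter' unit in Trino
        granularity, count = "month", count * 3
    return f"DATE_ADD('{granularity}', -{count}, {column_sql})"

assert render_time_delta("ds", 2, "quarter") == "DATE_ADD('month', -6, ds)"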
(diff for an additional file; file path not shown in this capture)
@@ -67,3 +67,14 @@ duckdb:
dev:
type: duckdb
schema: "{{ env_var('DBT_ENV_SECRET_SCHEMA') }}"
trino:
target: dev
outputs:
dev:
type: trino
host: "{{ env_var('DBT_ENV_SECRET_HOST') }}"
port: "{{ env_var('DBT_PROFILE_PORT') | int }}"
user: "{{ env_var('DBT_ENV_SECRET_USER') }}"
password: "{{ env_var('DBT_ENV_SECRET_PASSWORD') }}"
catalog: "{{ env_var('DBT_ENV_SECRET_CATALOG') }}"
schema: "{{ env_var('DBT_ENV_SECRET_SCHEMA') }}"
7 changes: 7 additions & 0 deletions metricflow/test/fixtures/sql_client_fixtures.py
@@ -42,6 +42,9 @@
DBT_ENV_SECRET_PROJECT_ID = "DBT_ENV_SECRET_PROJECT_ID"
DBT_ENV_SECRET_TOKEN_URI = "DBT_ENV_SECRET_TOKEN_URI"

# Trino requires a catalog in addition to the standard connection parameters, so it gets an extra env var. Keeping it split out here for consistency.
DBT_ENV_SECRET_CATALOG = "DBT_ENV_SECRET_CATALOG"


def __configure_test_env_from_url(url: str, password: str, schema: str) -> sqlalchemy.engine.URL:
"""Populates default env var mapping from a sqlalchemy URL string.
@@ -163,6 +166,10 @@ def make_test_sql_client(url: str, password: str, schema: str) -> SqlClientWithD
__configure_databricks_env_from_url(url, password=password, schema=schema)
__initialize_dbt()
return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("databricks"))
elif dialect == SqlDialect.TRINO:
__configure_test_env_from_url(url, password=password, schema=schema)
__initialize_dbt()
return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("trino"))
else:
raise ValueError(f"Unknown dialect: `{dialect}` in URL {url}")

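The fixture above derives its env vars from a SQLAlchemy-style engine URL. A sketch of how a Trino test URL decomposes (the URL shape follows the trino SQLAlchemy dialect and is an assumption, not a value taken from this PR):

import sqlalchemy

url = sqlalchemy.engine.make_url("trino://mf_test@localhost:8080/memory/default")
print(url.drivername, url.username, url.host, url.port)  # trino mf_test localhost 8080
print(url.database)  # 'memory/default' -- catalog and schema packed into one field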
(diff for an additional file; file path not shown in this capture)
@@ -39,6 +39,7 @@ def create_table_from_dataframe(
# This mirrors the SQLAlchemy schema detection logic in pandas.io.sql
df = df.convert_dtypes()
columns = df.columns

columns_to_insert = []
for i in range(len(df.columns)):
# Format as "column_name column_type"
@@ -63,7 +64,12 @@
elif type(cell) in [str, pd.Timestamp]:
# Wrap cell in quotes & escape existing single quotes
escaped_cell = self._quote_escape_value(str(cell))
cells.append(f"'{escaped_cell}'")
# Trino requires timestamp literals to be prefixed with the timestamp keyword.
# There is probably a better way to handle this.
if self.sql_engine_type is SqlEngine.TRINO and type(cell) is pd.Timestamp:
cells.append(f"timestamp '{escaped_cell}'")
else:
cells.append(f"'{escaped_cell}'")
else:
cells.append(str(cell))

@@ -93,6 +99,8 @@ def _get_type_from_pandas_dtype(self, dtype: str) -> str:
if dtype == "string" or dtype == "object":
if self.sql_engine_type is SqlEngine.DATABRICKS or self.sql_engine_type is SqlEngine.BIGQUERY:
return "string"
if self.sql_engine_type is SqlEngine.TRINO:
return "varchar"
return "text"
elif dtype == "boolean" or dtype == "bool":
return "boolean"
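A self-contained sketch of the cell rendering above, showing why Trino timestamps need the typed-literal prefix (the helper name and escaping rule are illustrative):

import pandas as pd

def render_cell(cell: object, is_trino: bool) -> str:
    if isinstance(cell, (str, pd.Timestamp)):
        escaped = str(cell).replace("'", "''")  # assumes simple single-quote escaping
        if is_trino and isinstance(cell, pd.Timestamp):
            return f"timestamp '{escaped}'"  # Trino can't compare bare varchar literals to timestamps
        return f"'{escaped}'"
    return str(cell)

print(render_cell(pd.Timestamp("2020-01-01"), is_trino=True))  # timestamp '2020-01-01 00:00:00'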
1 change: 1 addition & 0 deletions metricflow/test/fixtures/sql_clients/common_client.py
@@ -14,6 +14,7 @@ class SqlDialect(ExtendedEnum):
SNOWFLAKE = "snowflake"
BIGQUERY = "bigquery"
DATABRICKS = "databricks"
TRINO = "trino"


T = TypeVar("T")
10 changes: 10 additions & 0 deletions metricflow/test/generate_snapshots.py
@@ -28,6 +28,10 @@
"engine_url": postgres://...",
"engine_password": "..."
},
"trino": {
"engine_url": trino://...",
"engine_password": "..."
},
}
EOF
)
@@ -69,6 +73,7 @@ class MetricFlowTestCredentialSetForAllEngines(FrozenBaseModel): # noqa: D
big_query: MetricFlowTestCredentialSet
databricks: MetricFlowTestCredentialSet
postgres: MetricFlowTestCredentialSet
trino: MetricFlowTestCredentialSet

@property
def as_configurations(self) -> Sequence[MetricFlowTestConfiguration]: # noqa: D
@@ -97,6 +102,10 @@ def as_configurations(self) -> Sequence[MetricFlowTestConfiguration]: # noqa: D
engine=SqlEngine.POSTGRES,
credential_set=self.postgres,
),
MetricFlowTestConfiguration(
engine=SqlEngine.TRINO,
credential_set=self.trino,
),
)


@@ -137,6 +146,7 @@ def run_tests(test_configuration: MetricFlowTestConfiguration) -> None: # noqa:
or test_configuration.engine is SqlEngine.BIGQUERY
or test_configuration.engine is SqlEngine.DATABRICKS
or test_configuration.engine is SqlEngine.POSTGRES
or test_configuration.engine is SqlEngine.TRINO
):
engine_name = test_configuration.engine.value.lower()
os.environ["MF_TEST_ADAPTER_TYPE"] = engine_name
(diff for an additional file; file path not shown in this capture)
@@ -133,15 +133,18 @@ def test_cumulative_metric_with_non_adjustable_filter(
it_helpers: IntegrationTestHelpers,
) -> None:
"""Tests a cumulative metric with a filter that cannot be adjusted to ensure all data is included."""
# Build the ds expressions with engine-appropriate timestamp casts so the comparison works on Trino.
first_ds_expr = f"CAST('2020-03-15' AS {sql_client.sql_query_plan_renderer.expr_renderer.timestamp_data_type})"
second_ds_expr = f"CAST('2020-04-30' AS {sql_client.sql_query_plan_renderer.expr_renderer.timestamp_data_type})"
where_constraint = f"{{{{ TimeDimension('metric_time', 'day') }}}} = {first_ds_expr} or"
where_constraint += f" {{{{ TimeDimension('metric_time', 'day') }}}} = {second_ds_expr}"

query_result = it_helpers.mf_engine.query(
MetricFlowQueryRequest.create_with_random_request_id(
metric_names=["trailing_2_months_revenue"],
group_by_names=["metric_time"],
order_by_names=["metric_time"],
where_constraint=(
"{{ TimeDimension('metric_time', 'day') }} = '2020-03-15' or "
"{{ TimeDimension('metric_time', 'day') }} = '2020-04-30'"
),
where_constraint=where_constraint,
time_constraint_end=as_datetime("2020-12-31"),
)
)
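Concretely, the engine-aware expressions above render along these lines (assuming an engine whose reported timestamp type is TIMESTAMP; the real value comes from the expression renderer):

timestamp_data_type = "TIMESTAMP"  # assumption: what the engine's renderer reports
first_ds_expr = f"CAST('2020-03-15' AS {timestamp_data_type})"
where_constraint = f"{{{{ TimeDimension('metric_time', 'day') }}}} = {first_ds_expr}"
print(where_constraint)  # {{ TimeDimension('metric_time', 'day') }} = CAST('2020-03-15' AS TIMESTAMP)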