Add Trino adapter support for MetricFlow
Lint

Handle timestamp literals for range operator

Update test cases to handle group by for Trino in metrics calculation

Handle dry run and approx percentile for Trino along with more explicit castings

Add changelog

Fix percentile function and rename sqltimedelta function in Trino client

Handle extract for Trino

Fix lint issues

Fix additional test case snapshots related to extract function

Use positional arguments in group by to fix tests

Fix time delta expression for Trino and regenerate snapshot

Add trino to cd-sql-engine-tests

Address inline review comments

Bump version for dependencies
sarbmeetka authored and tlento committed Dec 16, 2023
1 parent fe807f2 commit ed71dc2
Showing 247 changed files with 39,644 additions and 112 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231008-195608.yaml
@@ -0,0 +1,6 @@
kind: Features
body: Add Trino support to MetricFlow.
time: 2023-10-08T19:56:08.427006-06:00
custom:
Author: sarbmeetka
Issue: "207"
3 changes: 3 additions & 0 deletions Makefile
@@ -59,6 +59,9 @@ test-snowflake:
populate-persistent-source-schema-snowflake:
	hatch -v run snowflake-env:pytest -vv $(ADDITIONAL_PYTEST_OPTIONS) $(USE_PERSISTENT_SOURCE_SCHEMA) $(POPULATE_PERSISTENT_SOURCE_SCHEMA)

.PHONY: test-trino
test-trino:
	hatch -v run trino-env:pytest -vv -n $(PARALLELISM) $(ADDITIONAL_PYTEST_OPTIONS) metricflow/test/

.PHONY: lint
lint:
5 changes: 4 additions & 1 deletion dbt-metricflow/pyproject.toml
@@ -50,6 +50,9 @@ redshift = [
snowflake = [
"dbt-snowflake~=1.7.0"
]
trino = [
"dbt-trino~=1.7.0"
]

[tool.hatch.build.targets.sdist]
exclude = [
@@ -60,4 +63,4 @@ exclude = [
".pre-commit-config.yaml",
"CONTRIBUTING.md",
"MAKEFILE",
]
19 changes: 18 additions & 1 deletion metricflow/cli/dbt_connectors/adapter_backed_client.py
@@ -21,6 +21,7 @@
from metricflow.sql.render.redshift import RedshiftSqlQueryPlanRenderer
from metricflow.sql.render.snowflake import SnowflakeSqlQueryPlanRenderer
from metricflow.sql.render.sql_plan_renderer import SqlQueryPlanRenderer
from metricflow.sql.render.trino import TrinoSqlQueryPlanRenderer
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql_request.sql_request_attributes import SqlJsonTag, SqlRequestId, SqlRequestTagSet
from metricflow.sql_request.sql_statement_metadata import CombinedSqlTags, SqlStatementCommentMetadata
@@ -42,6 +43,7 @@ class SupportedAdapterTypes(enum.Enum):
    REDSHIFT = "redshift"
    BIGQUERY = "bigquery"
    DUCKDB = "duckdb"
    TRINO = "trino"

    @property
    def sql_engine_type(self) -> SqlEngine:
@@ -58,6 +60,8 @@ def sql_engine_type(self) -> SqlEngine:
            return SqlEngine.SNOWFLAKE
        elif self is SupportedAdapterTypes.DUCKDB:
            return SqlEngine.DUCKDB
        elif self is SupportedAdapterTypes.TRINO:
            return SqlEngine.TRINO
        else:
            assert_values_exhausted(self)

@@ -76,6 +80,8 @@ def sql_query_plan_renderer(self) -> SqlQueryPlanRenderer:
            return SnowflakeSqlQueryPlanRenderer()
        elif self is SupportedAdapterTypes.DUCKDB:
            return DuckDbSqlQueryPlanRenderer()
        elif self is SupportedAdapterTypes.TRINO:
            return TrinoSqlQueryPlanRenderer()
        else:
            assert_values_exhausted(self)

@@ -213,7 +219,18 @@ def dry_run(
        request_id = SqlRequestId(f"mf_rid__{random_id()}")
        connection_name = f"MetricFlow_dry_run_request_{request_id}"
        # TODO - consolidate to self._adapter.validate_sql() once all implementations work from within MetricFlow
        if self.sql_engine_type is SqlEngine.BIGQUERY:

        # Trino has a bug where the EXPLAIN command actually creates the table, so wrap
        # the statement with EXPLAIN (TYPE VALIDATE) to avoid this.
        # See https://github.com/trinodb/trino/issues/130
        if self.sql_engine_type is SqlEngine.TRINO:
            with self._adapter.connection_named(connection_name):
                # The response is either a bool value or a string with an error message from Trino.
                result = self._adapter.execute(f"EXPLAIN (type validate) {stmt}", auto_begin=True, fetch=True)
                has_error = str(result[0]) != "SUCCESS"
                if has_error:
                    raise DbtDatabaseError("Encountered error in Trino dry run.")

        elif self.sql_engine_type is SqlEngine.BIGQUERY:
            with self._adapter.connection_named(connection_name):
                self._adapter.validate_sql(stmt)
        else:
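The Trino branch of this dry run can be read as a standalone helper. Below is a minimal sketch, assuming only the adapter calls used in the hunk above (connection_named and execute) and the same "SUCCESS" check; the helper name and the dbt-core import path are illustrative, not part of this commit.

from dbt.exceptions import DbtDatabaseError


def trino_dry_run(adapter, stmt: str, connection_name: str) -> None:
    """Validate stmt on Trino without executing it.

    EXPLAIN (TYPE VALIDATE) only checks the query, sidestepping the
    plain-EXPLAIN table-creation bug (trinodb/trino#130) noted above.
    """
    with adapter.connection_named(connection_name):
        # adapter.execute returns (adapter_response, table); a response that
        # stringifies to "SUCCESS" means the statement validated cleanly.
        result = adapter.execute(f"EXPLAIN (type validate) {stmt}", auto_begin=True, fetch=True)
        if str(result[0]) != "SUCCESS":
            raise DbtDatabaseError("Encountered error in Trino dry run.")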
1 change: 1 addition & 0 deletions metricflow/protocols/sql_client.py
@@ -23,6 +23,7 @@ class SqlEngine(Enum):
    POSTGRES = "Postgres"
    SNOWFLAKE = "Snowflake"
    DATABRICKS = "Databricks"
    TRINO = "Trino"


class SqlClient(Protocol):
126 changes: 126 additions & 0 deletions metricflow/sql/render/trino.py
@@ -0,0 +1,126 @@
from __future__ import annotations

from typing import Collection

from dateutil.parser import parse
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from typing_extensions import override

from metricflow.sql.render.expr_renderer import (
    DefaultSqlExpressionRenderer,
    SqlExpressionRenderer,
    SqlExpressionRenderResult,
)
from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer
from metricflow.sql.sql_bind_parameters import SqlBindParameters
from metricflow.sql.sql_exprs import (
    SqlBetweenExpression,
    SqlGenerateUuidExpression,
    SqlPercentileExpression,
    SqlPercentileFunctionType,
    SqlSubtractTimeIntervalExpression,
)


class TrinoSqlExpressionRenderer(DefaultSqlExpressionRenderer):
    """Expression renderer for the Trino engine."""

    @property
    @override
    def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctionType]:
        return {
            SqlPercentileFunctionType.APPROXIMATE_CONTINUOUS,
        }

    @override
    def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult:
        return SqlExpressionRenderResult(
            sql="uuid()",
            bind_parameters=SqlBindParameters(),
        )

    @override
    def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult:
        """Render a time delta for Trino, which requires the granularity in quotes and a different function name."""
        arg_rendered = node.arg.accept(self)

        count = node.count
        granularity = node.granularity
        if granularity == TimeGranularity.QUARTER:
            granularity = TimeGranularity.MONTH
            count *= 3
        return SqlExpressionRenderResult(
            sql=f"DATE_ADD('{granularity.value}', -{count}, {arg_rendered.sql})",
            bind_parameters=arg_rendered.bind_parameters,
        )

    @override
    def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionRenderResult:
        """Render a percentile expression for Trino."""
        arg_rendered = self.render_sql_expr(node.order_by_arg)
        params = arg_rendered.bind_parameters
        percentile = node.percentile_args.percentile

        if node.percentile_args.function_type is SqlPercentileFunctionType.APPROXIMATE_CONTINUOUS:
            return SqlExpressionRenderResult(
                sql=f"approx_percentile({arg_rendered.sql}, {percentile})",
                bind_parameters=params,
            )
        elif (
            node.percentile_args.function_type is SqlPercentileFunctionType.APPROXIMATE_DISCRETE
            or node.percentile_args.function_type is SqlPercentileFunctionType.DISCRETE
            or node.percentile_args.function_type is SqlPercentileFunctionType.CONTINUOUS
        ):
            raise RuntimeError(
                "Discrete, continuous, and approximate discrete percentile aggregates are not supported for Trino. Set "
                + "use_approximate_percentile and disable use_discrete_percentile in all percentile measures."
            )
        else:
            assert_values_exhausted(node.percentile_args.function_type)

    @override
    def visit_between_expr(self, node: SqlBetweenExpression) -> SqlExpressionRenderResult:
        """Render a BETWEEN expression for Trino, wrapping timestamp literals with the timestamp keyword."""
        rendered_column_arg = self.render_sql_expr(node.column_arg)
        rendered_start_expr = self.render_sql_expr(node.start_expr)
        rendered_end_expr = self.render_sql_expr(node.end_expr)

        bind_parameters = SqlBindParameters()
        bind_parameters = bind_parameters.combine(rendered_column_arg.bind_parameters)
        bind_parameters = bind_parameters.combine(rendered_start_expr.bind_parameters)
        bind_parameters = bind_parameters.combine(rendered_end_expr.bind_parameters)

        # Handle timestamp literals differently.
        try:
            parse(rendered_start_expr.sql)  # Raises ValueError if the operand is not a date/time literal.
            sql = f"{rendered_column_arg.sql} BETWEEN timestamp {rendered_start_expr.sql} AND timestamp {rendered_end_expr.sql}"
        except ValueError:
            sql = f"{rendered_column_arg.sql} BETWEEN {rendered_start_expr.sql} AND {rendered_end_expr.sql}"

        return SqlExpressionRenderResult(
            sql=sql,
            bind_parameters=bind_parameters,
        )

    @override
    def render_date_part(self, date_part: DatePart) -> str:
        """Render the date part for an EXTRACT expression.

        Override DOW to Trino's ISO DAY_OF_WEEK part to ensure all engines return consistent results.
        """
        if date_part is DatePart.DOW:
            return "DAY_OF_WEEK"

        return date_part.value


class TrinoSqlQueryPlanRenderer(DefaultSqlQueryPlanRenderer):
    """Plan renderer for the Trino engine."""

    EXPR_RENDERER = TrinoSqlExpressionRenderer()

    @property
    @override
    def expr_renderer(self) -> SqlExpressionRenderer:
        return self.EXPR_RENDERER
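The quarter handling in visit_time_delta_expr is easy to sanity-check in isolation. A minimal sketch restating that rule; the helper is illustrative and not part of the module:

def trino_date_add(granularity: str, count: int, expr: str) -> str:
    # Mirrors visit_time_delta_expr: the unit is quoted, and a quarter
    # delta is rendered as three month deltas.
    if granularity == "quarter":
        granularity, count = "month", count * 3
    return f"DATE_ADD('{granularity}', -{count}, {expr})"


# Two quarters back becomes six months back.
assert trino_date_add("quarter", 2, "ds") == "DATE_ADD('month', -6, ds)"
# Other granularities pass through with the unit quoted.
assert trino_date_add("day", 7, "ds") == "DATE_ADD('day', -7, ds)"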
@@ -67,3 +67,14 @@ duckdb:
    dev:
      type: duckdb
      schema: "{{ env_var('DBT_ENV_SECRET_SCHEMA') }}"
trino:
  target: dev
  outputs:
    dev:
      type: trino
      host: "{{ env_var('DBT_ENV_SECRET_HOST') }}"
      port: "{{ env_var('DBT_PROFILE_PORT') | int }}"
      user: "{{ env_var('DBT_ENV_SECRET_USER') }}"
      password: "{{ env_var('DBT_ENV_SECRET_PASSWORD') }}"
      catalog: "{{ env_var('DBT_ENV_SECRET_CATALOG') }}"
      schema: "{{ env_var('DBT_ENV_SECRET_SCHEMA') }}"
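The profile resolves every connection detail from environment variables. A hypothetical illustration of what it expects; the variable names come from the YAML above, and all values are invented:

import os

os.environ.update(
    {
        "DBT_ENV_SECRET_HOST": "localhost",
        "DBT_PROFILE_PORT": "8080",  # rendered through `| int` in the profile
        "DBT_ENV_SECRET_USER": "mf_test",
        "DBT_ENV_SECRET_PASSWORD": "mf_test_password",
        "DBT_ENV_SECRET_CATALOG": "memory",
        "DBT_ENV_SECRET_SCHEMA": "mf_demo",
    }
)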
7 changes: 7 additions & 0 deletions metricflow/test/fixtures/sql_client_fixtures.py
@@ -42,6 +42,9 @@
DBT_ENV_SECRET_PROJECT_ID = "DBT_ENV_SECRET_PROJECT_ID"
DBT_ENV_SECRET_TOKEN_URI = "DBT_ENV_SECRET_TOKEN_URI"

# Trino is special, so it gets its own set of env vars. Keeping them split out here for consistency.
DBT_ENV_SECRET_CATALOG = "DBT_ENV_SECRET_CATALOG"


def __configure_test_env_from_url(url: str, password: str, schema: str) -> sqlalchemy.engine.URL:
"""Populates default env var mapping from a sqlalchemy URL string.
@@ -163,6 +166,10 @@ def make_test_sql_client(url: str, password: str, schema: str) -> SqlClientWithDDLMethods:
        __configure_databricks_env_from_url(url, password=password, schema=schema)
        __initialize_dbt()
        return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("databricks"))
    elif dialect == SqlDialect.TRINO:
        __configure_test_env_from_url(url, password=password, schema=schema)
        __initialize_dbt()
        return AdapterBackedDDLSqlClient(adapter=get_adapter_by_type("trino"))
    else:
        raise ValueError(f"Unknown dialect: `{dialect}` in URL {url}")

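For illustration, a hypothetical call into the new branch; make_test_sql_client and its signature appear in the diff above, while the URL shape and credentials are assumptions:

client = make_test_sql_client(
    url="trino://mf_test@localhost:8080/memory",  # assumed URL shape
    password="",
    schema="mf_demo",
)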
@@ -39,6 +39,7 @@ def create_table_from_dataframe(
        # This mirrors the SQLAlchemy schema detection logic in pandas.io.sql
        df = df.convert_dtypes()
        columns = df.columns

        columns_to_insert = []
        for i in range(len(df.columns)):
            # Format as "column_name column_type"
@@ -63,7 +64,12 @@
            elif type(cell) in [str, pd.Timestamp]:
                # Wrap the cell in quotes & escape existing single quotes
                escaped_cell = self._quote_escape_value(str(cell))
                cells.append(f"'{escaped_cell}'")
                # Trino requires timestamp literals to be prefixed with the timestamp keyword.
                # There is probably a better way to handle this.
                if self.sql_engine_type is SqlEngine.TRINO and type(cell) is pd.Timestamp:
                    cells.append(f"timestamp '{escaped_cell}'")
                else:
                    cells.append(f"'{escaped_cell}'")
            else:
                cells.append(str(cell))
@@ -93,6 +99,8 @@ def _get_type_from_pandas_dtype(self, dtype: str) -> str:
        if dtype == "string" or dtype == "object":
            if self.sql_engine_type is SqlEngine.DATABRICKS or self.sql_engine_type is SqlEngine.BIGQUERY:
                return "string"
            if self.sql_engine_type is SqlEngine.TRINO:
                return "varchar"
            return "text"
        elif dtype == "boolean" or dtype == "bool":
            return "boolean"
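The Trino literal rule added above can be stated on its own. A minimal sketch; the helper name is invented, and it assumes _quote_escape_value doubles single quotes in the standard SQL way:

import pandas as pd


def format_trino_literal(cell: object) -> str:
    # Timestamps need the timestamp keyword prefix on Trino;
    # other strings are plain quoted literals.
    escaped = str(cell).replace("'", "''")
    if isinstance(cell, pd.Timestamp):
        return f"timestamp '{escaped}'"
    return f"'{escaped}'"


assert format_trino_literal(pd.Timestamp("2020-01-01")) == "timestamp '2020-01-01 00:00:00'"
assert format_trino_literal("it's") == "'it''s'"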
1 change: 1 addition & 0 deletions metricflow/test/fixtures/sql_clients/common_client.py
@@ -14,6 +14,7 @@ class SqlDialect(ExtendedEnum):
    SNOWFLAKE = "snowflake"
    BIGQUERY = "bigquery"
    DATABRICKS = "databricks"
    TRINO = "trino"


T = TypeVar("T")
10 changes: 10 additions & 0 deletions metricflow/test/generate_snapshots.py
@@ -28,6 +28,10 @@
        "engine_url": "postgres://...",
        "engine_password": "..."
    },
    "trino": {
        "engine_url": "trino://...",
        "engine_password": "..."
    },
}
EOF
)
@@ -69,6 +73,7 @@ class MetricFlowTestCredentialSetForAllEngines(FrozenBaseModel):  # noqa: D
    big_query: MetricFlowTestCredentialSet
    databricks: MetricFlowTestCredentialSet
    postgres: MetricFlowTestCredentialSet
    trino: MetricFlowTestCredentialSet

    @property
    def as_configurations(self) -> Sequence[MetricFlowTestConfiguration]:  # noqa: D
@@ -97,6 +102,10 @@ def as_configurations(self) -> Sequence[MetricFlowTestConfiguration]:  # noqa: D
                engine=SqlEngine.POSTGRES,
                credential_set=self.postgres,
            ),
            MetricFlowTestConfiguration(
                engine=SqlEngine.TRINO,
                credential_set=self.trino,
            ),
        )


@@ -137,6 +146,7 @@ def run_tests(test_configuration: MetricFlowTestConfiguration) -> None:  # noqa: D
        or test_configuration.engine is SqlEngine.BIGQUERY
        or test_configuration.engine is SqlEngine.DATABRICKS
        or test_configuration.engine is SqlEngine.POSTGRES
        or test_configuration.engine is SqlEngine.TRINO
    ):
        engine_name = test_configuration.engine.value.lower()
        os.environ["MF_TEST_ADAPTER_TYPE"] = engine_name
4 changes: 2 additions & 2 deletions metricflow/test/integration/test_cases/itest_constraints.yaml
@@ -56,7 +56,7 @@ integration_test:
SELECT SUM(booking_value) AS booking_value
, ds AS metric_time__day
FROM {{ source_schema }}.fct_bookings b
WHERE ds = '2020-01-01'
WHERE {{ render_time_constraint("ds", "2020-01-01", "2020-01-01") }}
GROUP BY
ds
---
@@ -73,7 +73,7 @@
, ds AS metric_time__day
FROM {{ source_schema }}.fct_bookings b
WHERE is_instant
and ds = '2020-01-01'
and {{ render_time_constraint("ds", "2020-01-01", "2020-01-01") }}
GROUP BY ds
---
integration_test:
@@ -325,8 +325,8 @@ integration_test:
  group_bys: ["metric_time__day"]
  order_bys: ["metric_time__day"]
  where_filter: |
    {{ render_time_dimension_template('metric_time', 'day') }} = '2019-12-20'
    or {{ render_time_dimension_template('metric_time', 'day') }} = '2020-01-04'
    {{ render_time_dimension_template('metric_time', 'day') }} = {{ cast_to_ts('2019-12-20') }}
    or {{ render_time_dimension_template('metric_time', 'day') }} = {{ cast_to_ts('2020-01-04') }}
  check_query: |
    SELECT
      COUNT (DISTINCT(b.guest_id)) as every_two_days_bookers