Skip to content

Commit

Permalink
Engine-specific rendering for sub-daily granularity options
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Jun 7, 2024
1 parent 588a6f0 commit a0b69ac
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 3 deletions.
24 changes: 23 additions & 1 deletion metricflow/protocols/sql_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from abc import abstractmethod
from enum import Enum
from typing import Protocol
from typing import Protocol, Set

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters

from metricflow.data_table.mf_table import MetricFlowDataTable
Expand All @@ -24,6 +26,26 @@ class SqlEngine(Enum):
DATABRICKS = "Databricks"
TRINO = "Trino"

@property
def unsupported_granularities(self) -> Set[TimeGranularity]:
"""Granularities that can't be used with this SqlEngine."""
if self is SqlEngine.SNOWFLAKE:
return set()
elif self is SqlEngine.BIGQUERY:
return {TimeGranularity.NANOSECOND}
elif self is SqlEngine.DATABRICKS:
return {TimeGranularity.NANOSECOND}
elif self is SqlEngine.DUCKDB:
return {TimeGranularity.NANOSECOND}
elif self is SqlEngine.POSTGRES:
return {TimeGranularity.NANOSECOND}
elif self is SqlEngine.REDSHIFT:
return {TimeGranularity.NANOSECOND}
elif self is SqlEngine.TRINO:
return {TimeGranularity.NANOSECOND, TimeGranularity.MICROSECOND}
else:
assert_values_exhausted(self)


class SqlClient(Protocol):
"""Base interface for SqlClient instances used inside MetricFlow.
Expand Down
9 changes: 8 additions & 1 deletion metricflow/sql/render/big_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from typing_extensions import override

from metricflow.protocols.sql_client import SqlEngine
from metricflow.sql.render.expr_renderer import (
DefaultSqlExpressionRenderer,
SqlExpressionRenderer,
Expand All @@ -31,6 +32,8 @@
class BigQuerySqlExpressionRenderer(DefaultSqlExpressionRenderer):
"""Expression renderer for the BigQuery engine."""

sql_engine = SqlEngine.BIGQUERY

@property
@override
def double_data_type(self) -> str:
Expand Down Expand Up @@ -120,14 +123,18 @@ def visit_cast_to_timestamp_expr(self, node: SqlCastToTimestampExpression) -> Sq
@override
def visit_date_trunc_expr(self, node: SqlDateTruncExpression) -> SqlExpressionRenderResult:
"""Render DATE_TRUNC for BigQuery, which takes the opposite argument order from Snowflake and Redshift."""
self._validate_granularity_for_engine(node.time_granularity)

arg_rendered = self.render_sql_expr(node.arg)

prefix = ""
if node.time_granularity == TimeGranularity.WEEK:
prefix = "iso"

trunc_expr = "DATE_TRUNC" if node.time_granularity.to_int() >= TimeGranularity.DAY.to_int() else "TIME_TRUNC"

return SqlExpressionRenderResult(
sql=f"DATE_TRUNC({arg_rendered.sql}, {prefix}{node.time_granularity.value})",
sql=f"{trunc_expr}({arg_rendered.sql}, {prefix}{node.time_granularity.value})",
bind_parameters=arg_rendered.bind_parameters,
)

Expand Down
3 changes: 3 additions & 0 deletions metricflow/sql/render/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from metricflow_semantics.errors.error_classes import UnsupportedEngineFeatureError
from typing_extensions import override

from metricflow.protocols.sql_client import SqlEngine
from metricflow.sql.render.expr_renderer import (
DefaultSqlExpressionRenderer,
SqlExpressionRenderer,
Expand All @@ -19,6 +20,8 @@
class DatabricksSqlExpressionRenderer(DefaultSqlExpressionRenderer):
"""Expression renderer for the Databricks engine."""

sql_engine = SqlEngine.DATABRICKS

@property
@override
def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctionType]:
Expand Down
3 changes: 3 additions & 0 deletions metricflow/sql/render/duckdb_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from typing_extensions import override

from metricflow.protocols.sql_client import SqlEngine
from metricflow.sql.render.expr_renderer import (
DefaultSqlExpressionRenderer,
SqlExpressionRenderer,
Expand All @@ -24,6 +25,8 @@
class DuckDbSqlExpressionRenderer(DefaultSqlExpressionRenderer):
"""Expression renderer for the DuckDB engine."""

sql_engine = SqlEngine.DUCKDB

@property
@override
def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctionType]:
Expand Down
16 changes: 15 additions & 1 deletion metricflow/sql/render/expr_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abc import ABC, abstractmethod
from collections import namedtuple
from dataclasses import dataclass
from typing import Collection, List
from typing import TYPE_CHECKING, Collection, List, Optional

import jinja2
from dbt_semantic_interfaces.type_enums.date_part import DatePart
Expand Down Expand Up @@ -41,6 +41,10 @@
)
from metricflow.sql.sql_plan import SqlSelectColumn

if TYPE_CHECKING:
from metricflow.protocols.sql_client import SqlEngine


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -93,6 +97,10 @@ def can_render_percentile_function(self, percentile_type: SqlPercentileFunctionT
class DefaultSqlExpressionRenderer(SqlExpressionRenderer):
"""Renders the SQL query plan assuming ANSI SQL."""

@property
def sql_engine(self) -> Optional[SqlEngine]: # noqa: D102
return None

@property
@override
def double_data_type(self) -> str:
Expand Down Expand Up @@ -263,7 +271,13 @@ def visit_cast_to_timestamp_expr( # noqa: D102
bind_parameters=arg_rendered.bind_parameters,
)

def _validate_granularity_for_engine(self, time_granularity: TimeGranularity) -> None:
if self.sql_engine and time_granularity in self.sql_engine.unsupported_granularities:
raise RuntimeError(f"{self.sql_engine.name} does not support time granularity {time_granularity.name}.")

def visit_date_trunc_expr(self, node: SqlDateTruncExpression) -> SqlExpressionRenderResult: # noqa: D102
self._validate_granularity_for_engine(node.time_granularity)

arg_rendered = self.render_sql_expr(node.arg)

return SqlExpressionRenderResult(
Expand Down
3 changes: 3 additions & 0 deletions metricflow/sql/render/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from typing_extensions import override

from metricflow.protocols.sql_client import SqlEngine
from metricflow.sql.render.expr_renderer import (
DefaultSqlExpressionRenderer,
SqlExpressionRenderer,
Expand All @@ -25,6 +26,8 @@
class PostgresSqlExpressionRenderer(DefaultSqlExpressionRenderer):
"""Expression renderer for the PostgreSQL engine."""

sql_engine = SqlEngine.POSTGRES

@property
@override
def double_data_type(self) -> str:
Expand Down
3 changes: 3 additions & 0 deletions metricflow/sql/render/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from typing_extensions import override

from metricflow.protocols.sql_client import SqlEngine
from metricflow.sql.render.expr_renderer import (
DefaultSqlExpressionRenderer,
SqlExpressionRenderer,
Expand All @@ -25,6 +26,8 @@
class RedshiftSqlExpressionRenderer(DefaultSqlExpressionRenderer):
"""Expression renderer for the Redshift engine."""

sql_engine = SqlEngine.REDSHIFT

@property
@override
def double_data_type(self) -> str:
Expand Down
3 changes: 3 additions & 0 deletions metricflow/sql/render/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from typing_extensions import override

from metricflow.protocols.sql_client import SqlEngine
from metricflow.sql.render.expr_renderer import (
DefaultSqlExpressionRenderer,
SqlExpressionRenderer,
Expand All @@ -24,6 +25,8 @@
class SnowflakeSqlExpressionRenderer(DefaultSqlExpressionRenderer):
"""Expression renderer for the Snowflake engine."""

sql_engine = SqlEngine.SNOWFLAKE

@property
@override
def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctionType]:
Expand Down
3 changes: 3 additions & 0 deletions metricflow/sql/render/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters
from typing_extensions import override

from metricflow.protocols.sql_client import SqlEngine
from metricflow.sql.render.expr_renderer import (
DefaultSqlExpressionRenderer,
SqlExpressionRenderer,
Expand All @@ -27,6 +28,8 @@
class TrinoSqlExpressionRenderer(DefaultSqlExpressionRenderer):
"""Expression renderer for the Trino engine."""

sql_engine = SqlEngine.TRINO

@property
@override
def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctionType]:
Expand Down

0 comments on commit a0b69ac

Please sign in to comment.