diff --git a/metricflow/protocols/sql_client.py b/metricflow/protocols/sql_client.py index c2a9151116..ca1822246d 100644 --- a/metricflow/protocols/sql_client.py +++ b/metricflow/protocols/sql_client.py @@ -2,8 +2,10 @@ from abc import abstractmethod from enum import Enum -from typing import Protocol +from typing import Protocol, Set +from dbt_semantic_interfaces.enum_extension import assert_values_exhausted +from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters from metricflow.data_table.mf_table import MetricFlowDataTable @@ -24,6 +26,26 @@ class SqlEngine(Enum): DATABRICKS = "Databricks" TRINO = "Trino" + @property + def unsupported_granularities(self) -> Set[TimeGranularity]: + """Granularities that can't be used with this SqlEngine.""" + if self is SqlEngine.SNOWFLAKE: + return set() + elif self is SqlEngine.BIGQUERY: + return {TimeGranularity.NANOSECOND} + elif self is SqlEngine.DATABRICKS: + return {TimeGranularity.NANOSECOND} + elif self is SqlEngine.DUCKDB: + return {TimeGranularity.NANOSECOND} + elif self is SqlEngine.POSTGRES: + return {TimeGranularity.NANOSECOND} + elif self is SqlEngine.REDSHIFT: + return {TimeGranularity.NANOSECOND} + elif self is SqlEngine.TRINO: + return {TimeGranularity.NANOSECOND, TimeGranularity.MICROSECOND} + else: + assert_values_exhausted(self) + class SqlClient(Protocol): """Base interface for SqlClient instances used inside MetricFlow. diff --git a/metricflow/sql/render/big_query.py b/metricflow/sql/render/big_query.py index 0bc4222e66..e1aa091062 100644 --- a/metricflow/sql/render/big_query.py +++ b/metricflow/sql/render/big_query.py @@ -10,6 +10,7 @@ from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters from typing_extensions import override +from metricflow.protocols.sql_client import SqlEngine from metricflow.sql.render.expr_renderer import ( DefaultSqlExpressionRenderer, SqlExpressionRenderer, @@ -31,6 +32,8 @@ class BigQuerySqlExpressionRenderer(DefaultSqlExpressionRenderer): """Expression renderer for the BigQuery engine.""" + sql_engine = SqlEngine.BIGQUERY + @property @override def double_data_type(self) -> str: @@ -120,14 +123,18 @@ def visit_cast_to_timestamp_expr(self, node: SqlCastToTimestampExpression) -> Sq @override def visit_date_trunc_expr(self, node: SqlDateTruncExpression) -> SqlExpressionRenderResult: """Render DATE_TRUNC for BigQuery, which takes the opposite argument order from Snowflake and Redshift.""" + self._validate_granularity_for_engine(node.time_granularity) + arg_rendered = self.render_sql_expr(node.arg) prefix = "" if node.time_granularity == TimeGranularity.WEEK: prefix = "iso" + trunc_expr = "DATE_TRUNC" if node.time_granularity.to_int() >= TimeGranularity.DAY.to_int() else "TIME_TRUNC" + return SqlExpressionRenderResult( - sql=f"DATE_TRUNC({arg_rendered.sql}, {prefix}{node.time_granularity.value})", + sql=f"{trunc_expr}({arg_rendered.sql}, {prefix}{node.time_granularity.value})", bind_parameters=arg_rendered.bind_parameters, ) diff --git a/metricflow/sql/render/databricks.py b/metricflow/sql/render/databricks.py index 01ffd9d219..a4f0510062 100644 --- a/metricflow/sql/render/databricks.py +++ b/metricflow/sql/render/databricks.py @@ -7,6 +7,7 @@ from metricflow_semantics.errors.error_classes import UnsupportedEngineFeatureError from typing_extensions import override +from metricflow.protocols.sql_client import SqlEngine from metricflow.sql.render.expr_renderer import ( DefaultSqlExpressionRenderer, SqlExpressionRenderer, @@ -19,6 +20,8 @@ class DatabricksSqlExpressionRenderer(DefaultSqlExpressionRenderer): """Expression renderer for the Databricks engine.""" + sql_engine = SqlEngine.DATABRICKS + @property @override def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctionType]: diff --git a/metricflow/sql/render/duckdb_renderer.py b/metricflow/sql/render/duckdb_renderer.py index ad90f2c53d..53094de945 100644 --- a/metricflow/sql/render/duckdb_renderer.py +++ b/metricflow/sql/render/duckdb_renderer.py @@ -7,6 +7,7 @@ from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters from typing_extensions import override +from metricflow.protocols.sql_client import SqlEngine from metricflow.sql.render.expr_renderer import ( DefaultSqlExpressionRenderer, SqlExpressionRenderer, @@ -24,6 +25,8 @@ class DuckDbSqlExpressionRenderer(DefaultSqlExpressionRenderer): """Expression renderer for the DuckDB engine.""" + sql_engine = SqlEngine.DUCKDB + @property @override def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctionType]: diff --git a/metricflow/sql/render/expr_renderer.py b/metricflow/sql/render/expr_renderer.py index 795ea02b9e..71abb10b9d 100644 --- a/metricflow/sql/render/expr_renderer.py +++ b/metricflow/sql/render/expr_renderer.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from collections import namedtuple from dataclasses import dataclass -from typing import Collection, List +from typing import TYPE_CHECKING, Collection, List, Optional import jinja2 from dbt_semantic_interfaces.type_enums.date_part import DatePart @@ -41,6 +41,10 @@ ) from metricflow.sql.sql_plan import SqlSelectColumn +if TYPE_CHECKING: + from metricflow.protocols.sql_client import SqlEngine + + logger = logging.getLogger(__name__) @@ -93,6 +97,10 @@ def can_render_percentile_function(self, percentile_type: SqlPercentileFunctionT class DefaultSqlExpressionRenderer(SqlExpressionRenderer): """Renders the SQL query plan assuming ANSI SQL.""" + @property + def sql_engine(self) -> Optional[SqlEngine]: # noqa: D102 + return None + @property @override def double_data_type(self) -> str: @@ -263,7 +271,13 @@ def visit_cast_to_timestamp_expr( # noqa: D102 bind_parameters=arg_rendered.bind_parameters, ) + def _validate_granularity_for_engine(self, time_granularity: TimeGranularity) -> None: + if self.sql_engine and time_granularity in self.sql_engine.unsupported_granularities: + raise RuntimeError(f"{self.sql_engine.name} does not support time granularity {time_granularity.name}.") + def visit_date_trunc_expr(self, node: SqlDateTruncExpression) -> SqlExpressionRenderResult: # noqa: D102 + self._validate_granularity_for_engine(node.time_granularity) + arg_rendered = self.render_sql_expr(node.arg) return SqlExpressionRenderResult( diff --git a/metricflow/sql/render/postgres.py b/metricflow/sql/render/postgres.py index 57dbf1dfd2..04eb2b9bba 100644 --- a/metricflow/sql/render/postgres.py +++ b/metricflow/sql/render/postgres.py @@ -8,6 +8,7 @@ from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters from typing_extensions import override +from metricflow.protocols.sql_client import SqlEngine from metricflow.sql.render.expr_renderer import ( DefaultSqlExpressionRenderer, SqlExpressionRenderer, @@ -25,6 +26,8 @@ class PostgresSqlExpressionRenderer(DefaultSqlExpressionRenderer): """Expression renderer for the PostgreSQL engine.""" + sql_engine = SqlEngine.POSTGRES + @property @override def double_data_type(self) -> str: diff --git a/metricflow/sql/render/redshift.py b/metricflow/sql/render/redshift.py index 1306c5080b..8074871169 100644 --- a/metricflow/sql/render/redshift.py +++ b/metricflow/sql/render/redshift.py @@ -8,6 +8,7 @@ from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters from typing_extensions import override +from metricflow.protocols.sql_client import SqlEngine from metricflow.sql.render.expr_renderer import ( DefaultSqlExpressionRenderer, SqlExpressionRenderer, @@ -25,6 +26,8 @@ class RedshiftSqlExpressionRenderer(DefaultSqlExpressionRenderer): """Expression renderer for the Redshift engine.""" + sql_engine = SqlEngine.REDSHIFT + @property @override def double_data_type(self) -> str: diff --git a/metricflow/sql/render/snowflake.py b/metricflow/sql/render/snowflake.py index 402aff0a45..f623d89691 100644 --- a/metricflow/sql/render/snowflake.py +++ b/metricflow/sql/render/snowflake.py @@ -8,6 +8,7 @@ from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters from typing_extensions import override +from metricflow.protocols.sql_client import SqlEngine from metricflow.sql.render.expr_renderer import ( DefaultSqlExpressionRenderer, SqlExpressionRenderer, @@ -24,6 +25,8 @@ class SnowflakeSqlExpressionRenderer(DefaultSqlExpressionRenderer): """Expression renderer for the Snowflake engine.""" + sql_engine = SqlEngine.SNOWFLAKE + @property @override def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctionType]: diff --git a/metricflow/sql/render/trino.py b/metricflow/sql/render/trino.py index 8445b294ea..5bfda74fd5 100644 --- a/metricflow/sql/render/trino.py +++ b/metricflow/sql/render/trino.py @@ -9,6 +9,7 @@ from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameters from typing_extensions import override +from metricflow.protocols.sql_client import SqlEngine from metricflow.sql.render.expr_renderer import ( DefaultSqlExpressionRenderer, SqlExpressionRenderer, @@ -27,6 +28,8 @@ class TrinoSqlExpressionRenderer(DefaultSqlExpressionRenderer): """Expression renderer for the Trino engine.""" + sql_engine = SqlEngine.TRINO + @property @override def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctionType]: