diff --git a/metricflow-semantics/metricflow_semantics/dag/id_prefix.py b/metricflow-semantics/metricflow_semantics/dag/id_prefix.py index a9ae0df0e..8c2a6d1b4 100644 --- a/metricflow-semantics/metricflow_semantics/dag/id_prefix.py +++ b/metricflow-semantics/metricflow_semantics/dag/id_prefix.py @@ -69,6 +69,7 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper): SQL_EXPR_CAST_TO_TIMESTAMP_PREFIX = "ctt" SQL_EXPR_DATE_TRUNC = "dt" SQL_EXPR_SUBTRACT_TIME_INTERVAL_PREFIX = "sti" + SQL_EXPR_ADD_TIME_PREFIX = "ati" SQL_EXPR_EXTRACT = "ex" SQL_EXPR_RATIO_COMPUTATION = "rc" SQL_EXPR_BETWEEN_PREFIX = "betw" diff --git a/metricflow/sql/render/big_query.py b/metricflow/sql/render/big_query.py index bef39e99f..a63b2d06c 100644 --- a/metricflow/sql/render/big_query.py +++ b/metricflow/sql/render/big_query.py @@ -18,6 +18,7 @@ ) from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer from metricflow.sql.sql_exprs import ( + SqlAddTimeExpression, SqlCastToTimestampExpression, SqlDateTruncExpression, SqlExtractExpression, @@ -167,7 +168,7 @@ def visit_extract_expr(self, node: SqlExtractExpression) -> SqlExpressionRenderR ) @override - def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: + def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: """Render time delta for BigQuery, which requires ISO prefixing for the WEEK granularity value.""" column = node.arg.accept(self) @@ -176,6 +177,17 @@ def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlE bind_parameter_set=column.bind_parameter_set, ) + @override + def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult: + """Render time delta for BigQuery, which requires ISO prefixing for the WEEK granularity value.""" + column = node.arg.accept(self) + count = node.count_expr.accept(self) + + return SqlExpressionRenderResult( + sql=f"DATE_ADD(CAST({column.sql} AS {self.timestamp_data_type}), INTERVAL {count} {node.granularity.value})", + bind_parameter_set=column.bind_parameter_set, + ) + @override def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult: return SqlExpressionRenderResult( diff --git a/metricflow/sql/render/duckdb_renderer.py b/metricflow/sql/render/duckdb_renderer.py index 6ff7fb9e7..3e03b7eca 100644 --- a/metricflow/sql/render/duckdb_renderer.py +++ b/metricflow/sql/render/duckdb_renderer.py @@ -15,6 +15,7 @@ ) from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer from metricflow.sql.sql_exprs import ( + SqlAddTimeExpression, SqlGenerateUuidExpression, SqlPercentileExpression, SqlPercentileFunctionType, @@ -37,13 +38,13 @@ def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctio } @override - def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: + def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: """Render time delta expression for DuckDB, which requires slightly different syntax from other engines.""" arg_rendered = node.arg.accept(self) count = node.count granularity = node.granularity - if granularity == TimeGranularity.QUARTER: + if granularity is TimeGranularity.QUARTER: granularity = TimeGranularity.MONTH count *= 3 @@ -52,6 +53,22 @@ def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlE bind_parameter_set=arg_rendered.bind_parameter_set, ) + @override + def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult: + """Render time delta expression for DuckDB, which requires slightly different syntax from other engines.""" + arg_rendered = node.arg.accept(self) + count_rendered = node.count_expr.accept(self).sql + + granularity = node.granularity + if granularity is TimeGranularity.QUARTER: + granularity = TimeGranularity.MONTH + count_rendered = f"({count_rendered} * 3)" + + return SqlExpressionRenderResult( + sql=f"{arg_rendered.sql} + INTERVAL {count_rendered} {granularity.value}", + bind_parameter_set=arg_rendered.bind_parameter_set, + ) + @override def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult: return SqlExpressionRenderResult( diff --git a/metricflow/sql/render/expr_renderer.py b/metricflow/sql/render/expr_renderer.py index 0fd9e81e8..a387a2da0 100644 --- a/metricflow/sql/render/expr_renderer.py +++ b/metricflow/sql/render/expr_renderer.py @@ -16,6 +16,7 @@ from metricflow.sql.render.rendering_constants import SqlRenderingConstants from metricflow.sql.sql_exprs import ( + SqlAddTimeExpression, SqlAggregateFunctionExpression, SqlBetweenExpression, SqlCastToTimestampExpression, @@ -303,12 +304,14 @@ def render_date_part(self, date_part: DatePart) -> str: return date_part.value - def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: # noqa: D102 + def visit_subtract_time_interval_expr( # noqa: D102 + self, node: SqlSubtractTimeIntervalExpression + ) -> SqlExpressionRenderResult: arg_rendered = node.arg.accept(self) count = node.count granularity = node.granularity - if granularity == TimeGranularity.QUARTER: + if granularity is TimeGranularity.QUARTER: granularity = TimeGranularity.MONTH count *= 3 return SqlExpressionRenderResult( @@ -316,6 +319,20 @@ def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlE bind_parameter_set=arg_rendered.bind_parameter_set, ) + def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult: # noqa: D102 + arg_rendered = node.arg.accept(self) + count_rendered = node.count_expr.accept(self).sql + + granularity = node.granularity + if granularity is TimeGranularity.QUARTER: + granularity = TimeGranularity.MONTH + count_rendered = f"({count_rendered} * 3)" + + return SqlExpressionRenderResult( + sql=f"DATEADD({granularity.value}, {count_rendered}, {arg_rendered.sql})", + bind_parameter_set=arg_rendered.bind_parameter_set, + ) + def visit_ratio_computation_expr(self, node: SqlRatioComputationExpression) -> SqlExpressionRenderResult: """Render the ratio computation for a ratio metric. diff --git a/metricflow/sql/render/postgres.py b/metricflow/sql/render/postgres.py index 36a1b687e..8ced43020 100644 --- a/metricflow/sql/render/postgres.py +++ b/metricflow/sql/render/postgres.py @@ -16,6 +16,7 @@ ) from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer from metricflow.sql.sql_exprs import ( + SqlAddTimeExpression, SqlGenerateUuidExpression, SqlPercentileExpression, SqlPercentileFunctionType, @@ -40,13 +41,13 @@ def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctio return {SqlPercentileFunctionType.CONTINUOUS, SqlPercentileFunctionType.DISCRETE} @override - def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: + def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: """Render time delta operations for PostgreSQL, which needs custom support for quarterly granularity.""" arg_rendered = node.arg.accept(self) count = node.count granularity = node.granularity - if granularity == TimeGranularity.QUARTER: + if granularity is TimeGranularity.QUARTER: granularity = TimeGranularity.MONTH count *= 3 return SqlExpressionRenderResult( @@ -54,6 +55,22 @@ def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlE bind_parameter_set=arg_rendered.bind_parameter_set, ) + @override + def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult: + """Render time delta operations for PostgreSQL, which needs custom support for quarterly granularity.""" + arg_rendered = node.arg.accept(self) + count_rendered = node.count_expr.accept(self).sql + + granularity = node.granularity + if granularity is TimeGranularity.QUARTER: + granularity = TimeGranularity.MONTH + count_rendered = f"({count_rendered} * 3)" + + return SqlExpressionRenderResult( + sql=f"{arg_rendered.sql} + MAKE_INTERVAL({granularity.value}s => {count_rendered})", + bind_parameter_set=arg_rendered.bind_parameter_set, + ) + @override def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpressionRenderResult: return SqlExpressionRenderResult( diff --git a/metricflow/sql/render/trino.py b/metricflow/sql/render/trino.py index 23bd65ab6..e6aff3150 100644 --- a/metricflow/sql/render/trino.py +++ b/metricflow/sql/render/trino.py @@ -17,6 +17,7 @@ ) from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer from metricflow.sql.sql_exprs import ( + SqlAddTimeExpression, SqlBetweenExpression, SqlGenerateUuidExpression, SqlPercentileExpression, @@ -45,13 +46,13 @@ def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpres ) @override - def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: + def visit_subtract_time_interval_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: """Render time delta for Trino, require granularity in quotes and function name change.""" arg_rendered = node.arg.accept(self) count = node.count granularity = node.granularity - if granularity == TimeGranularity.QUARTER: + if granularity is TimeGranularity.QUARTER: granularity = TimeGranularity.MONTH count *= 3 return SqlExpressionRenderResult( @@ -59,6 +60,22 @@ def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlE bind_parameter_set=arg_rendered.bind_parameter_set, ) + @override + def visit_add_time_expr(self, node: SqlAddTimeExpression) -> SqlExpressionRenderResult: + """Render time delta for Trino, require granularity in quotes and function name change.""" + arg_rendered = node.arg.accept(self) + count_rendered = node.count_expr.accept(self).sql + + granularity = node.granularity + if granularity is TimeGranularity.QUARTER: + granularity = TimeGranularity.MONTH + count_rendered = f"({count_rendered} * 3)" + + return SqlExpressionRenderResult( + sql=f"DATE_ADD('{granularity.value}', {count_rendered}, {arg_rendered.sql})", + bind_parameter_set=arg_rendered.bind_parameter_set, + ) + @override def visit_percentile_expr(self, node: SqlPercentileExpression) -> SqlExpressionRenderResult: """Render a percentile expression for Trino.""" diff --git a/metricflow/sql/sql_exprs.py b/metricflow/sql/sql_exprs.py index bde085260..ec7866f00 100644 --- a/metricflow/sql/sql_exprs.py +++ b/metricflow/sql/sql_exprs.py @@ -212,7 +212,13 @@ def visit_extract_expr(self, node: SqlExtractExpression) -> VisitorOutputT: # n pass @abstractmethod - def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> VisitorOutputT: # noqa: D102 + def visit_subtract_time_interval_expr( # noqa: D102 + self, node: SqlSubtractTimeIntervalExpression + ) -> VisitorOutputT: + pass + + @abstractmethod + def visit_add_time_expr(self, node: SqlAddTimeExpression) -> VisitorOutputT: # noqa: D102 pass @abstractmethod @@ -1289,11 +1295,11 @@ def requires_parenthesis(self) -> bool: # noqa: D102 return False def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 - return visitor.visit_time_delta_expr(self) + return visitor.visit_subtract_time_interval_expr(self) @property def description(self) -> str: # noqa: D102 - return "Time delta" + return "Subtract time interval" def rewrite( # noqa: D102 self, @@ -1318,6 +1324,65 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self.count == other.count and self.granularity == other.granularity and self._parents_match(other) +@dataclass(frozen=True, eq=False) +class SqlAddTimeExpression(SqlExpressionNode): + """Add a time interval expr to a timestamp.""" + + arg: SqlExpressionNode + count_expr: SqlExpressionNode + granularity: TimeGranularity + + @staticmethod + def create( # noqa: D102 + arg: SqlExpressionNode, + count_expr: SqlExpressionNode, + granularity: TimeGranularity, + ) -> SqlAddTimeExpression: + return SqlAddTimeExpression( + parent_nodes=(arg, count_expr), + arg=arg, + count_expr=count_expr, + granularity=granularity, + ) + + @classmethod + def id_prefix(cls) -> IdPrefix: # noqa: D102 + return StaticIdPrefix.SQL_EXPR_ADD_TIME_PREFIX + + @property + def requires_parenthesis(self) -> bool: # noqa: D102 + return False + + def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 + return visitor.visit_add_time_expr(self) + + @property + def description(self) -> str: # noqa: D102 + return "Add time interval" + + def rewrite( # noqa: D102 + self, + column_replacements: Optional[SqlColumnReplacements] = None, + should_render_table_alias: Optional[bool] = None, + ) -> SqlExpressionNode: + return SqlAddTimeExpression.create( + arg=self.arg.rewrite(column_replacements, should_render_table_alias), + count_expr=self.count_expr, + granularity=self.granularity, + ) + + @property + def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 + return SqlExpressionTreeLineage.merge_iterable( + tuple(x.lineage for x in self.parent_nodes) + (SqlExpressionTreeLineage(other_exprs=(self,)),) + ) + + def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 + if not isinstance(other, SqlAddTimeExpression): + return False + return self.count_expr == other.count_expr and self.granularity == other.granularity and self.arg == other.arg + + @dataclass(frozen=True, eq=False) class SqlCastToTimestampExpression(SqlExpressionNode): """Cast to the timestamp type like CAST('2020-01-01' AS TIMESTAMP).""" diff --git a/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/BigQuery/test_add_time_expr__plan0.sql b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/BigQuery/test_add_time_expr__plan0.sql new file mode 100644 index 000000000..90fc09ace --- /dev/null +++ b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/BigQuery/test_add_time_expr__plan0.sql @@ -0,0 +1,10 @@ +test_name: test_add_time_expr +test_filename: test_engine_specific_rendering.py +docstring: + Tests rendering of the SqlAddTimeExpr in a query. +sql_engine: BigQuery +--- +-- Test Add Time Expression +SELECT + DATE_ADD(CAST('2020-01-01' AS DATETIME), INTERVAL SqlExpressionRenderResult(sql='1', bind_parameter_set=SqlBindParameterSet(param_items=())) quarter) AS add_time +FROM foo.bar a diff --git a/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/Databricks/test_add_time_expr__plan0.sql b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/Databricks/test_add_time_expr__plan0.sql new file mode 100644 index 000000000..3f6f5e12d --- /dev/null +++ b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/Databricks/test_add_time_expr__plan0.sql @@ -0,0 +1,10 @@ +test_name: test_add_time_expr +test_filename: test_engine_specific_rendering.py +docstring: + Tests rendering of the SqlAddTimeExpr in a query. +sql_engine: Databricks +--- +-- Test Add Time Expression +SELECT + DATEADD(month, (1 * 3), '2020-01-01') AS add_time +FROM foo.bar a diff --git a/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/DuckDB/test_add_time_expr__plan0.sql b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/DuckDB/test_add_time_expr__plan0.sql new file mode 100644 index 000000000..984e2096f --- /dev/null +++ b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/DuckDB/test_add_time_expr__plan0.sql @@ -0,0 +1,10 @@ +test_name: test_add_time_expr +test_filename: test_engine_specific_rendering.py +docstring: + Tests rendering of the SqlAddTimeExpr in a query. +sql_engine: DuckDB +--- +-- Test Add Time Expression +SELECT + '2020-01-01' + INTERVAL (1 * 3) month AS add_time +FROM foo.bar a diff --git a/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/Postgres/test_add_time_expr__plan0.sql b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/Postgres/test_add_time_expr__plan0.sql new file mode 100644 index 000000000..2be701942 --- /dev/null +++ b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/Postgres/test_add_time_expr__plan0.sql @@ -0,0 +1,10 @@ +test_name: test_add_time_expr +test_filename: test_engine_specific_rendering.py +docstring: + Tests rendering of the SqlAddTimeExpr in a query. +sql_engine: Postgres +--- +-- Test Add Time Expression +SELECT + '2020-01-01' + MAKE_INTERVAL(months => (1 * 3)) AS add_time +FROM foo.bar a diff --git a/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/Redshift/test_add_time_expr__plan0.sql b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/Redshift/test_add_time_expr__plan0.sql new file mode 100644 index 000000000..bac4dc733 --- /dev/null +++ b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/Redshift/test_add_time_expr__plan0.sql @@ -0,0 +1,10 @@ +test_name: test_add_time_expr +test_filename: test_engine_specific_rendering.py +docstring: + Tests rendering of the SqlAddTimeExpr in a query. +sql_engine: Redshift +--- +-- Test Add Time Expression +SELECT + DATEADD(month, (1 * 3), '2020-01-01') AS add_time +FROM foo.bar a diff --git a/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/Snowflake/test_add_time_expr__plan0.sql b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/Snowflake/test_add_time_expr__plan0.sql new file mode 100644 index 000000000..b83e17338 --- /dev/null +++ b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/Snowflake/test_add_time_expr__plan0.sql @@ -0,0 +1,10 @@ +test_name: test_add_time_expr +test_filename: test_engine_specific_rendering.py +docstring: + Tests rendering of the SqlAddTimeExpr in a query. +sql_engine: Snowflake +--- +-- Test Add Time Expression +SELECT + DATEADD(month, (1 * 3), '2020-01-01') AS add_time +FROM foo.bar a diff --git a/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/Trino/test_add_time_expr__plan0.sql b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/Trino/test_add_time_expr__plan0.sql new file mode 100644 index 000000000..b6bb21279 --- /dev/null +++ b/tests_metricflow/snapshots/test_engine_specific_rendering.py/SqlQueryPlan/Trino/test_add_time_expr__plan0.sql @@ -0,0 +1,10 @@ +test_name: test_add_time_expr +test_filename: test_engine_specific_rendering.py +docstring: + Tests rendering of the SqlAddTimeExpr in a query. +sql_engine: Trino +--- +-- Test Add Time Expression +SELECT + DATE_ADD('month', (1 * 3), '2020-01-01') AS add_time +FROM foo.bar a diff --git a/tests_metricflow/sql/test_engine_specific_rendering.py b/tests_metricflow/sql/test_engine_specific_rendering.py index 60c5a97ca..987762006 100644 --- a/tests_metricflow/sql/test_engine_specific_rendering.py +++ b/tests_metricflow/sql/test_engine_specific_rendering.py @@ -4,11 +4,13 @@ import pytest from _pytest.fixtures import FixtureRequest +from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration from metricflow.protocols.sql_client import SqlClient from metricflow.sql.sql_exprs import ( + SqlAddTimeExpression, SqlCastToTimestampExpression, SqlColumnReference, SqlColumnReferenceExpression, @@ -16,6 +18,7 @@ SqlPercentileExpression, SqlPercentileExpressionArgument, SqlPercentileFunctionType, + SqlStringExpression, SqlStringLiteralExpression, ) from metricflow.sql.sql_plan import ( @@ -295,3 +298,42 @@ def test_approximate_discrete_percentile_expr( plan_id="plan0", sql_client=sql_client, ) + + +@pytest.mark.sql_engine_snapshot +def test_add_time_expr( + request: FixtureRequest, + mf_test_configuration: MetricFlowTestConfiguration, + sql_client: SqlClient, +) -> None: + """Tests rendering of the SqlAddTimeExpr in a query.""" + select_columns = [ + SqlSelectColumn( + expr=SqlAddTimeExpression.create( + arg=SqlStringLiteralExpression.create( + "2020-01-01", + ), + count_expr=SqlStringExpression.create( + "1", + ), + granularity=TimeGranularity.QUARTER, + ), + column_alias="add_time", + ), + ] + + from_source = SqlTableNode.create(sql_table=SqlTable(schema_name="foo", table_name="bar")) + from_source_alias = "a" + + assert_rendered_sql_equal( + request=request, + mf_test_configuration=mf_test_configuration, + sql_plan_node=SqlSelectStatementNode.create( + description="Test Add Time Expression", + select_columns=tuple(select_columns), + from_source=from_source, + from_source_alias=from_source_alias, + ), + plan_id="plan0", + sql_client=sql_client, + )