diff --git a/metricflow/dag/id_generation.py b/metricflow/dag/id_generation.py index e08fca06de..1fd5e03581 100644 --- a/metricflow/dag/id_generation.py +++ b/metricflow/dag/id_generation.py @@ -35,6 +35,7 @@ SQL_EXPR_IS_NULL_PREFIX = "isn" SQL_EXPR_CAST_TO_TIMESTAMP_PREFIX = "ctt" SQL_EXPR_DATE_TRUNC = "dt" +SQL_EXPR_SUBTRACT_TIME_INTERVAL_PREFIX = "sti" SQL_EXPR_EXTRACT = "ex" SQL_EXPR_RATIO_COMPUTATION = "rc" SQL_EXPR_BETWEEN_PREFIX = "betw" diff --git a/metricflow/plan_conversion/sql_join_builder.py b/metricflow/plan_conversion/sql_join_builder.py index 92ca575721..20b98d6ab5 100644 --- a/metricflow/plan_conversion/sql_join_builder.py +++ b/metricflow/plan_conversion/sql_join_builder.py @@ -15,7 +15,7 @@ SqlIsNullExpression, SqlLogicalExpression, SqlLogicalOperator, - SqlTimeDeltaExpression, + SqlSubtractTimeIntervalExpression, ) from metricflow.sql.sql_plan import SqlExpressionNode, SqlJoinDescription, SqlJoinType, SqlSelectStatementNode @@ -441,7 +441,7 @@ def make_cumulative_metric_time_range_join_description( start_of_range_comparison_expr = SqlComparisonExpression( left_expr=metric_time_column_expr, comparison=SqlComparison.GREATER_THAN, - right_expr=SqlTimeDeltaExpression( + right_expr=SqlSubtractTimeIntervalExpression( arg=time_spine_column_expr, count=node.window.count, granularity=node.window.granularity, @@ -481,7 +481,7 @@ def make_join_to_time_spine_join_description( col_ref=SqlColumnReference(table_alias=time_spine_alias, column_name=metric_time_dimension_column_name) ) if node.offset_window: - left_expr = SqlTimeDeltaExpression( + left_expr = SqlSubtractTimeIntervalExpression( arg=left_expr, count=node.offset_window.count, granularity=node.offset_window.granularity ) elif node.offset_to_grain: diff --git a/metricflow/sql/render/big_query.py b/metricflow/sql/render/big_query.py index 4dcb42aaa3..aa360e154a 100644 --- a/metricflow/sql/render/big_query.py +++ b/metricflow/sql/render/big_query.py @@ -21,7 +21,7 @@ SqlGenerateUuidExpression, SqlPercentileExpression, SqlPercentileFunctionType, - SqlTimeDeltaExpression, + SqlSubtractTimeIntervalExpression, ) from metricflow.sql.sql_plan import SqlSelectColumn from metricflow.time.date_part import DatePart @@ -142,19 +142,9 @@ def render_date_part(self, date_part: DatePart) -> str: return super().render_date_part(date_part) @override - def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult: + def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: """Render time delta for BigQuery, which requires ISO prefixing for the WEEK granularity value.""" column = node.arg.accept(self) - if node.grain_to_date: - granularity = node.granularity - if granularity == TimeGranularity.WEEK: - granularity_value = "ISO" + granularity.value.upper() - else: - granularity_value = granularity.value - return SqlExpressionRenderResult( - sql=f"DATE_TRUNC({column.sql}, {granularity_value})", - bind_parameters=column.bind_parameters, - ) return SqlExpressionRenderResult( sql=f"DATE_SUB(CAST({column.sql} AS {self.timestamp_data_type}), INTERVAL {node.count} {node.granularity.value})", diff --git a/metricflow/sql/render/duckdb_renderer.py b/metricflow/sql/render/duckdb_renderer.py index 7158d84de6..451fb05398 100644 --- a/metricflow/sql/render/duckdb_renderer.py +++ b/metricflow/sql/render/duckdb_renderer.py @@ -17,7 +17,7 @@ SqlGenerateUuidExpression, SqlPercentileExpression, SqlPercentileFunctionType, - SqlTimeDeltaExpression, + SqlSubtractTimeIntervalExpression, ) @@ -34,14 +34,9 @@ def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctio } @override - def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult: + def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: """Render time delta expression for DuckDB, which requires slightly different syntax from other engines.""" arg_rendered = node.arg.accept(self) - if node.grain_to_date: - return SqlExpressionRenderResult( - sql=f"DATE_TRUNC('{node.granularity.value}', {arg_rendered.sql}::timestamp)", - bind_parameters=arg_rendered.bind_parameters, - ) count = node.count granularity = node.granularity diff --git a/metricflow/sql/render/expr_renderer.py b/metricflow/sql/render/expr_renderer.py index 9a5e58a465..84df8d2395 100644 --- a/metricflow/sql/render/expr_renderer.py +++ b/metricflow/sql/render/expr_renderer.py @@ -33,7 +33,7 @@ SqlRatioComputationExpression, SqlStringExpression, SqlStringLiteralExpression, - SqlTimeDeltaExpression, + SqlSubtractTimeIntervalExpression, SqlWindowFunctionExpression, ) from metricflow.sql.sql_plan import SqlSelectColumn @@ -281,13 +281,8 @@ def render_date_part(self, date_part: DatePart) -> str: """Render DATE PART for an EXTRACT expression.""" return date_part.value - def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult: # noqa: D + def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: # noqa: D arg_rendered = node.arg.accept(self) - if node.grain_to_date: - return SqlExpressionRenderResult( - sql=f"DATE_TRUNC('{node.granularity.value}', {arg_rendered.sql}::timestamp)", - bind_parameters=arg_rendered.bind_parameters, - ) count = node.count granularity = node.granularity diff --git a/metricflow/sql/render/postgres.py b/metricflow/sql/render/postgres.py index f77708d7fa..9ffc08dbf1 100644 --- a/metricflow/sql/render/postgres.py +++ b/metricflow/sql/render/postgres.py @@ -18,7 +18,7 @@ SqlGenerateUuidExpression, SqlPercentileExpression, SqlPercentileFunctionType, - SqlTimeDeltaExpression, + SqlSubtractTimeIntervalExpression, ) @@ -37,14 +37,9 @@ def supported_percentile_function_types(self) -> Collection[SqlPercentileFunctio return {SqlPercentileFunctionType.CONTINUOUS, SqlPercentileFunctionType.DISCRETE} @override - def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> SqlExpressionRenderResult: + def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> SqlExpressionRenderResult: """Render time delta operations for PostgreSQL, which needs custom support for quarterly granularity.""" arg_rendered = node.arg.accept(self) - if node.grain_to_date: - return SqlExpressionRenderResult( - sql=f"DATE_TRUNC('{node.granularity.value}', {arg_rendered.sql}::timestamp)", - bind_parameters=arg_rendered.bind_parameters, - ) count = node.count granularity = node.granularity diff --git a/metricflow/sql/sql_exprs.py b/metricflow/sql/sql_exprs.py index 8ab96a29e9..2206691e7c 100644 --- a/metricflow/sql/sql_exprs.py +++ b/metricflow/sql/sql_exprs.py @@ -16,6 +16,7 @@ from metricflow.dag.id_generation import ( SQL_EXPR_BETWEEN_PREFIX, + SQL_EXPR_CAST_TO_TIMESTAMP_PREFIX, SQL_EXPR_COLUMN_REFERENCE_ID_PREFIX, SQL_EXPR_COMPARISON_ID_PREFIX, SQL_EXPR_DATE_TRUNC, @@ -29,6 +30,7 @@ SQL_EXPR_RATIO_COMPUTATION, SQL_EXPR_STRING_ID_PREFIX, SQL_EXPR_STRING_LITERAL_PREFIX, + SQL_EXPR_SUBTRACT_TIME_INTERVAL_PREFIX, SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX, ) from metricflow.dag.mf_dag import DagNode, DisplayedProperty, NodeId @@ -225,7 +227,7 @@ def visit_extract_expr(self, node: SqlExtractExpression) -> VisitorOutputT: # n pass @abstractmethod - def visit_time_delta_expr(self, node: SqlTimeDeltaExpression) -> VisitorOutputT: # noqa: D + def visit_time_delta_expr(self, node: SqlSubtractTimeIntervalExpression) -> VisitorOutputT: # noqa: D pass @abstractmethod @@ -1243,25 +1245,29 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D return self._parents_match(other) -class SqlTimeDeltaExpression(SqlExpressionNode): - """create time delta between eg `DATE_SUB(ds, 2, month)`.""" +class SqlSubtractTimeIntervalExpression(SqlExpressionNode): + """Represents an interval subtraction from a given timestamp. + + This node contains the information required to produce a SQL statement which subtracts an interval with the given + count and granularity (which together define the interval duration) from the input timestamp expression. The return + value from the SQL rendering for this expression should be a timestamp expression offset from the initial input + value. + """ def __init__( # noqa: D self, arg: SqlExpressionNode, count: int, granularity: TimeGranularity, - grain_to_date: Optional[TimeGranularity] = None, ) -> None: super().__init__(node_id=self.create_unique_id(), parent_nodes=[arg]) self._count = count self._time_granularity = granularity self._arg = arg - self._grain_to_date = grain_to_date @classmethod def id_prefix(cls) -> str: # noqa: D - return SQL_EXPR_IS_NULL_PREFIX + return SQL_EXPR_SUBTRACT_TIME_INTERVAL_PREFIX @property def requires_parenthesis(self) -> bool: # noqa: D @@ -1278,10 +1284,6 @@ def description(self) -> str: # noqa: D def arg(self) -> SqlExpressionNode: # noqa: D return self._arg - @property - def grain_to_date(self) -> Optional[TimeGranularity]: # noqa: D - return self._grain_to_date - @property def count(self) -> int: # noqa: D return self._count @@ -1295,11 +1297,10 @@ def rewrite( # noqa: D column_replacements: Optional[SqlColumnReplacements] = None, should_render_table_alias: Optional[bool] = None, ) -> SqlExpressionNode: - return SqlTimeDeltaExpression( + return SqlSubtractTimeIntervalExpression( arg=self.arg.rewrite(column_replacements, should_render_table_alias), count=self.count, granularity=self.granularity, - grain_to_date=self.grain_to_date, ) @property @@ -1309,14 +1310,9 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D ) def matches(self, other: SqlExpressionNode) -> bool: # noqa: D - if not isinstance(other, SqlTimeDeltaExpression): + if not isinstance(other, SqlSubtractTimeIntervalExpression): return False - return ( - self.count == other.count - and self.granularity == other.granularity - and self.grain_to_date == other.grain_to_date - and self._parents_match(other) - ) + return self.count == other.count and self.granularity == other.granularity and self._parents_match(other) class SqlCastToTimestampExpression(SqlExpressionNode): @@ -1327,7 +1323,7 @@ def __init__(self, arg: SqlExpressionNode) -> None: # noqa: D @classmethod def id_prefix(cls) -> str: # noqa: D - return SQL_EXPR_IS_NULL_PREFIX + return SQL_EXPR_CAST_TO_TIMESTAMP_PREFIX @property def requires_parenthesis(self) -> bool: # noqa: D diff --git a/metricflow/test/integration/test_configured_cases.py b/metricflow/test/integration/test_configured_cases.py index eec106d555..95579a11d1 100644 --- a/metricflow/test/integration/test_configured_cases.py +++ b/metricflow/test/integration/test_configured_cases.py @@ -31,7 +31,7 @@ SqlPercentileExpressionArgument, SqlPercentileFunctionType, SqlStringExpression, - SqlTimeDeltaExpression, + SqlSubtractTimeIntervalExpression, ) from metricflow.test.compare_df import assert_dataframes_equal from metricflow.test.fixtures.setup_fixtures import MetricFlowTestSessionState @@ -85,7 +85,7 @@ def render_date_sub( granularity: TimeGranularity, ) -> str: """Renders a date subtract expression.""" - expr = SqlTimeDeltaExpression( + expr = SqlSubtractTimeIntervalExpression( arg=SqlColumnReferenceExpression(SqlColumnReference(table_alias, column_alias)), count=count, granularity=granularity,