From 5f14ab26eef861f3e559567b151c597381417548 Mon Sep 17 00:00:00 2001
From: Paul Yang
Date: Fri, 8 Nov 2024 13:05:55 -0800
Subject: [PATCH] Update `DataflowToSqlQueryPlanConverter` to use CTEs.

---
 metricflow/plan_conversion/dataflow_to_sql.py | 283 +++++++++++++++++-
 .../sql/optimizer/optimization_levels.py      |   6 +-
 2 files changed, 274 insertions(+), 15 deletions(-)

diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py
index 116930a20..a53702e3d 100644
--- a/metricflow/plan_conversion/dataflow_to_sql.py
+++ b/metricflow/plan_conversion/dataflow_to_sql.py
@@ -3,7 +3,7 @@
 import datetime as dt
 import logging
 from collections import OrderedDict
-from typing import List, Optional, Sequence, Set, Tuple, Union
+from typing import Callable, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, TypeVar, Union
 
 from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
 from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
@@ -39,12 +39,15 @@
 from metricflow_semantics.specs.spec_set import InstanceSpecSet
 from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
 from metricflow_semantics.sql.sql_join_type import SqlJoinType
+from metricflow_semantics.sql.sql_table import SqlTable
 from metricflow_semantics.time.time_constants import ISO8601_PYTHON_FORMAT, ISO8601_PYTHON_TS_FORMAT
 from metricflow_semantics.time.time_spine_source import TIME_SPINE_DATA_SET_DESCRIPTION, TimeSpineSource
+from typing_extensions import override
 
 from metricflow.dataflow.dataflow_plan import (
     DataflowPlanNode,
 )
+from metricflow.dataflow.dataflow_plan_analyzer import DataflowPlanAnalyzer
 from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor
 from metricflow.dataflow.nodes.add_generated_uuid import AddGeneratedUuidColumnNode
 from metricflow.dataflow.nodes.aggregate_measures import AggregateMeasuresNode
@@ -104,9 +107,10 @@
 )
 from metricflow.protocols.sql_client import SqlEngine
 from metricflow.sql.optimizer.optimization_levels import (
-    SqlQueryGenerationOptionSet,
+    SqlGenerationOptionSet,
     SqlQueryOptimizationLevel,
 )
+from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
 from metricflow.sql.sql_exprs import (
     SqlAggregateFunctionExpression,
     SqlBetweenExpression,
@@ -131,6 +135,7 @@
 )
 from metricflow.sql.sql_plan import (
     SqlCreateTableAsNode,
+    SqlCteNode,
     SqlJoinDescription,
     SqlOrderByDescription,
     SqlQueryPlan,
@@ -181,23 +186,72 @@ def convert_to_sql_query_plan(
         sql_query_plan_id: Optional[DagId] = None,
     ) -> ConvertToSqlPlanResult:
         """Create an SQL query plan that represents the computation up to the given dataflow plan node."""
-        # TODO: Handle generation with CTE.
-        to_sql_subquery_visitor = DataflowNodeToSqlSubqueryVisitor(
-            column_association_resolver=self.column_association_resolver,
-            semantic_manifest_lookup=self._semantic_manifest_lookup,
-        )
-        data_set = dataflow_plan_node.accept(to_sql_subquery_visitor)
-
-        sql_node: SqlQueryPlanNode = data_set.sql_node
         # TODO: Make this a more generally accessible attribute instead of checking against the
         # BigQuery-ness of the engine
         use_column_alias_in_group_by = sql_engine_type is SqlEngine.BIGQUERY
 
-        option_set = SqlQueryGenerationOptionSet.options_for_level(
+        option_set = SqlGenerationOptionSet.options_for_level(
             optimization_level, use_column_alias_in_group_by=use_column_alias_in_group_by
         )
 
-        for optimizer in option_set.optimizers:
+        nodes_to_convert_to_cte: FrozenSet[DataflowPlanNode] = frozenset()
+        if option_set.allow_cte:
+            nodes_to_convert_to_cte = self._get_nodes_to_convert_to_cte(dataflow_plan_node)
+
+        return self.convert_using_specifics(
+            dataflow_plan_node=dataflow_plan_node,
+            sql_query_plan_id=sql_query_plan_id,
+            nodes_to_convert_to_cte=nodes_to_convert_to_cte,
+            optimizers=option_set.optimizers,
+        )
+
+    def convert_using_specifics(
+        self,
+        dataflow_plan_node: DataflowPlanNode,
+        sql_query_plan_id: Optional[DagId],
+        nodes_to_convert_to_cte: FrozenSet[DataflowPlanNode],
+        optimizers: Sequence[SqlQueryPlanOptimizer],
+    ) -> ConvertToSqlPlanResult:
+        """Helper method to convert using specific options. The main use case is tests."""
+        logger.debug(LazyFormat("Converting to SQL", nodes_to_convert_to_cte=nodes_to_convert_to_cte))
+
+        if len(nodes_to_convert_to_cte) == 0:
+            # Avoid the `DataflowNodeToSqlCteVisitor` code path for better isolation during rollout.
+            # Later, this branch can be removed, as `DataflowNodeToSqlCteVisitor` should handle an empty
+            # `nodes_to_convert_to_cte`.
+            to_sql_subquery_visitor = DataflowNodeToSqlSubqueryVisitor(
+                column_association_resolver=self.column_association_resolver,
+                semantic_manifest_lookup=self._semantic_manifest_lookup,
+            )
+            data_set = dataflow_plan_node.accept(to_sql_subquery_visitor)
+        else:
+            to_sql_cte_visitor = DataflowNodeToSqlCteVisitor(
+                column_association_resolver=self.column_association_resolver,
+                semantic_manifest_lookup=self._semantic_manifest_lookup,
+                nodes_to_convert_to_cte=nodes_to_convert_to_cte,
+            )
+            data_set = dataflow_plan_node.accept(to_sql_cte_visitor)
+            select_statement = data_set.checked_sql_select_node
+            data_set = SqlDataSet(
+                instance_set=data_set.instance_set,
+                sql_select_node=SqlSelectStatementNode.create(
+                    description=select_statement.description,
+                    select_columns=select_statement.select_columns,
+                    from_source=select_statement.from_source,
+                    from_source_alias=select_statement.from_source_alias,
+                    cte_sources=tuple(to_sql_cte_visitor.generated_cte_nodes()),
+                    join_descs=select_statement.join_descs,
+                    group_bys=select_statement.group_bys,
+                    order_bys=select_statement.order_bys,
+                    where=select_statement.where,
+                    limit=select_statement.limit,
+                    distinct=select_statement.distinct,
+                ),
+            )
+
+        sql_node: SqlQueryPlanNode = data_set.sql_node
+
+        for optimizer in optimizers:
             logger.debug(LazyFormat(lambda: f"Applying optimizer: {optimizer.__class__.__name__}"))
             sql_node = optimizer.optimize(sql_node)
             logger.debug(
@@ -212,6 +266,17 @@ def convert_to_sql_query_plan(
             sql_plan=SqlQueryPlan(render_node=sql_node, plan_id=sql_query_plan_id),
         )
 
+    def _get_nodes_to_convert_to_cte(
+        self,
+        dataflow_plan_node: DataflowPlanNode,
+    ) -> FrozenSet[DataflowPlanNode]:
+        """Handles logic for selecting which nodes to convert to CTEs based on the request."""
+        dataflow_plan = dataflow_plan_node.as_plan()
+        nodes_to_convert_to_cte: Set[DataflowPlanNode] = set(DataflowPlanAnalyzer.find_common_branches(dataflow_plan))
+        # Additional nodes will be added later.
+
+        return frozenset(nodes_to_convert_to_cte)
+
 
 class DataflowNodeToSqlSubqueryVisitor(DataflowPlanNodeVisitor[SqlDataSet]):
     """Generates a SQL query plan by converting a node's parents to sub-queries.
@@ -1961,3 +2026,197 @@ def strip_time_from_dt(ts: dt.datetime) -> dt.datetime:
             literal_value=time_range_constraint.end_time.strftime(time_format_to_render),
         ),
     )
+
+
+class DataflowNodeToSqlCteVisitor(DataflowNodeToSqlSubqueryVisitor):
+    """Similar to `DataflowNodeToSqlSubqueryVisitor`, except that this converts specific nodes to CTEs.
+
+    This is implemented as a subclass of `DataflowNodeToSqlSubqueryVisitor` so that it has the same behavior by
+    default; for nodes that should be converted to CTEs, alternate methods are used.
+
+    The generated CTE nodes are collected rather than incorporated into the SQL query plan generated at each node,
+    so that they can be attached to the top-level SELECT statement.
+
+    # TODO: Move these visitors to separate files at the end of the stack.
+ """ + + def __init__( # noqa: D107 + self, + column_association_resolver: ColumnAssociationResolver, + semantic_manifest_lookup: SemanticManifestLookup, + nodes_to_convert_to_cte: FrozenSet[DataflowPlanNode], + ) -> None: + super().__init__( + column_association_resolver=column_association_resolver, semantic_manifest_lookup=semantic_manifest_lookup + ) + self._nodes_to_convert_to_cte = nodes_to_convert_to_cte + self._generated_cte_nodes: List[SqlCteNode] = [] + + # If a given node is supposed to use a CTE, map the node to the generated dataset that uses a CTE. + self._node_to_cte_dataset: Dict[DataflowPlanNode, SqlDataSet] = {} + + def generated_cte_nodes(self) -> Sequence[SqlCteNode]: + """Returns the CTE nodes that have been generated while traversing the dataflow plan.""" + return self._generated_cte_nodes + + def _default_handler( + self, node: DataflowNodeT, node_to_select_subquery_function: Callable[[DataflowNodeT], SqlDataSet] + ) -> SqlDataSet: + """Default handler that is called for each node as the dataflow plan is traversed. + + Args: + node: The current node in traversal. + node_to_select_subquery_function: A function that converts the given node to a `SqlDataSet` where the + SELECT statement source is a subquery. This should be a method in `DataflowNodeToSqlSubqueryVisitor` as this + was the default behavior before CTEs were supported. + + Returns: The `SqlDataSet` that produces the data for the given node. + """ + # For the given node, if there is already a generated dataset that uses a SELECT from a CTE, return it. + select_from_cte_dataset = self._node_to_cte_dataset.get(node) + if select_from_cte_dataset is not None: + logger.debug(LazyFormat("Handling node via existing CTE", node=node)) + return select_from_cte_dataset + + # If the given node is supposed to use a CTE, generate one for it. Otherwise, use the default subquery as the + # source for the SELECT. + select_from_subquery_dataset = node_to_select_subquery_function(node) + if node not in self._nodes_to_convert_to_cte: + logger.debug(LazyFormat("Handling node via subquery", node=node)) + return select_from_subquery_dataset + logger.debug(LazyFormat("Handling node via new CTE", node=node)) + + cte_alias = node.node_id.id_str + "_cte" + + if cte_alias in set(node.cte_alias for node in self._generated_cte_nodes): + raise ValueError( + f"{cte_alias=} is a duplicate of one that already exists. " + f"This implies a bug that is generating a CTE for the same dataflow plan node multiple times." 
+ ) + + cte_source = SqlCteNode.create( + select_statement=select_from_subquery_dataset.sql_node, + cte_alias=cte_alias, + ) + self._generated_cte_nodes.append(cte_source) + node_id = node.node_id + select_from_cte_dataset = SqlDataSet( + instance_set=select_from_subquery_dataset.instance_set, + sql_select_node=SqlSelectStatementNode.create( + description=f"Read From CTE For {node_id=}", + select_columns=CreateSelectColumnsForInstances( + table_alias=cte_alias, + column_resolver=self._column_association_resolver, + ) + .transform(select_from_subquery_dataset.instance_set) + .as_tuple(), + from_source=SqlTableNode.create(SqlTable(schema_name=None, table_name=cte_alias)), + from_source_alias=cte_alias, + ), + ) + self._node_to_cte_dataset[node] = select_from_cte_dataset + + return select_from_cte_dataset + + @override + def visit_source_node(self, node: ReadSqlSourceNode) -> SqlDataSet: + return self._default_handler(node=node, node_to_select_subquery_function=super().visit_source_node) + + @override + def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> SqlDataSet: + return self._default_handler(node=node, node_to_select_subquery_function=super().visit_join_on_entities_node) + + @override + def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> SqlDataSet: + return self._default_handler(node=node, node_to_select_subquery_function=super().visit_aggregate_measures_node) + + @override + def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet: + return self._default_handler(node=node, node_to_select_subquery_function=super().visit_compute_metrics_node) + + @override + def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> SqlDataSet: + return self._default_handler( + node=node, node_to_select_subquery_function=super().visit_window_reaggregation_node + ) + + @override + def visit_order_by_limit_node(self, node: OrderByLimitNode) -> SqlDataSet: + return self._default_handler(node=node, node_to_select_subquery_function=super().visit_order_by_limit_node) + + @override + def visit_where_constraint_node(self, node: WhereConstraintNode) -> SqlDataSet: + return self._default_handler(node=node, node_to_select_subquery_function=super().visit_where_constraint_node) + + @override + def visit_write_to_result_data_table_node(self, node: WriteToResultDataTableNode) -> SqlDataSet: + return self._default_handler( + node=node, node_to_select_subquery_function=super().visit_write_to_result_data_table_node + ) + + @override + def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> SqlDataSet: + return self._default_handler( + node=node, node_to_select_subquery_function=super().visit_write_to_result_table_node + ) + + @override + def visit_filter_elements_node(self, node: FilterElementsNode) -> SqlDataSet: + return self._default_handler(node=node, node_to_select_subquery_function=super().visit_filter_elements_node) + + @override + def visit_combine_aggregated_outputs_node(self, node: CombineAggregatedOutputsNode) -> SqlDataSet: + return self._default_handler( + node=node, node_to_select_subquery_function=super().visit_combine_aggregated_outputs_node + ) + + @override + def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> SqlDataSet: + return self._default_handler( + node=node, node_to_select_subquery_function=super().visit_constrain_time_range_node + ) + + @override + def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDataSet: + return self._default_handler( + node=node, 
+            node_to_select_subquery_function=super().visit_join_over_time_range_node,
+        )
+
+    @override
+    def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSet:
+        return self._default_handler(node=node, node_to_select_subquery_function=super().visit_semi_additive_join_node)
+
+    @override
+    def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTransformNode) -> SqlDataSet:
+        return self._default_handler(
+            node=node, node_to_select_subquery_function=super().visit_metric_time_dimension_transform_node
+        )
+
+    @override
+    def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet:
+        return self._default_handler(node=node, node_to_select_subquery_function=super().visit_join_to_time_spine_node)
+
+    @override
+    def visit_min_max_node(self, node: MinMaxNode) -> SqlDataSet:
+        return self._default_handler(node=node, node_to_select_subquery_function=super().visit_min_max_node)
+
+    @override
+    def visit_add_generated_uuid_column_node(self, node: AddGeneratedUuidColumnNode) -> SqlDataSet:
+        return self._default_handler(
+            node=node, node_to_select_subquery_function=super().visit_add_generated_uuid_column_node
+        )
+
+    @override
+    def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> SqlDataSet:
+        return self._default_handler(
+            node=node, node_to_select_subquery_function=super().visit_join_conversion_events_node
+        )
+
+    @override
+    def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> SqlDataSet:
+        return self._default_handler(
+            node=node, node_to_select_subquery_function=super().visit_join_to_custom_granularity_node
+        )
+
+
+DataflowNodeT = TypeVar("DataflowNodeT", bound=DataflowPlanNode)
diff --git a/metricflow/sql/optimizer/optimization_levels.py b/metricflow/sql/optimizer/optimization_levels.py
index 067019eb9..f1f965d73 100644
--- a/metricflow/sql/optimizer/optimization_levels.py
+++ b/metricflow/sql/optimizer/optimization_levels.py
@@ -25,7 +25,7 @@ class SqlQueryOptimizationLevel(Enum):
 
 
 @dataclass(frozen=True)
-class SqlQueryGenerationOptionSet:
+class SqlGenerationOptionSet:
     """Defines the different SQL generation optimizers / options that should be used at each level."""
 
     optimizers: Tuple[SqlQueryPlanOptimizer, ...]
@@ -36,7 +36,7 @@ class SqlQueryGenerationOptionSet:
     @staticmethod
     def options_for_level(  # noqa: D102
         level: SqlQueryOptimizationLevel, use_column_alias_in_group_by: bool
-    ) -> SqlQueryGenerationOptionSet:
+    ) -> SqlGenerationOptionSet:
         optimizers: Tuple[SqlQueryPlanOptimizer, ...] = ()
         allow_cte = False
         if level is SqlQueryOptimizationLevel.O0:
@@ -63,7 +63,7 @@ def options_for_level(  # noqa: D102
         else:
             assert_values_exhausted(level)
 
-        return SqlQueryGenerationOptionSet(
+        return SqlGenerationOptionSet(
             optimizers=optimizers,
             allow_cte=allow_cte,
         )
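
The core technique in this patch: `DataflowPlanAnalyzer.find_common_branches` identifies dataflow subtrees that feed more than one downstream node, and `DataflowNodeToSqlCteVisitor._default_handler` renders each such subtree once as a CTE, memoizing the resulting dataset so that every later reference selects from the CTE alias instead of repeating the subquery. Below is a minimal, self-contained sketch of that idea; `PlanNode`, `render`, and the `{0}` templating are hypothetical illustrations for this note, not MetricFlow's API.

from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Set, Tuple


@dataclass(frozen=True)
class PlanNode:
    """Toy stand-in for a dataflow plan node; `template` refers to parent sources as {0}, {1}, ..."""

    node_id: str
    template: str
    parents: Tuple["PlanNode", ...] = ()


def find_common_branches(sink: PlanNode) -> Set[PlanNode]:
    """A node reachable along more than one path is a common branch, i.e. a CTE candidate."""
    path_counts: Dict[PlanNode, int] = {}

    def walk(node: PlanNode) -> None:
        path_counts[node] = path_counts.get(node, 0) + 1
        for parent in node.parents:
            walk(parent)

    walk(sink)
    return {node for node, count in path_counts.items() if count > 1}


def render(sink: PlanNode) -> str:
    """Render `sink` to SQL, hoisting each common branch into a single top-level CTE."""
    common_branches = find_common_branches(sink)
    cte_definitions: List[str] = []
    # Memoized node -> CTE alias, analogous to `_node_to_cte_dataset` in the patch.
    alias_for_node: Dict[PlanNode, str] = {}

    def to_source(node: PlanNode) -> str:
        if node in alias_for_node:
            # The node was already rendered as a CTE: reuse the alias instead of re-rendering.
            return alias_for_node[node]
        body = node.template.format(*(to_source(parent) for parent in node.parents))
        if node not in common_branches:
            return f"({body})"
        alias = f"{node.node_id}_cte"  # Mirrors `node.node_id.id_str + "_cte"` in the patch.
        cte_definitions.append(f"{alias} AS ({body})")
        alias_for_node[node] = alias
        return alias

    query = f"SELECT * FROM {to_source(sink)} AS final_query"
    if cte_definitions:
        return "WITH " + ",\n".join(cte_definitions) + "\n" + query
    return query


# A diamond-shaped plan: both aggregations read from the same source before a join.
source = PlanNode("read_orders", "SELECT * FROM orders")
revenue = PlanNode("revenue", "SELECT day, SUM(amount) AS revenue FROM {0} GROUP BY day", (source,))
counts = PlanNode("counts", "SELECT day, COUNT(*) AS n FROM {0} GROUP BY day", (source,))
joined = PlanNode("joined", "SELECT * FROM {0} AS a JOIN {1} AS b USING (day)", (revenue, counts))

print(render(joined))  # `read_orders_cte` is defined once and referenced by both branches.

Run on the diamond-shaped plan above, the sketch emits a single `WITH read_orders_cte AS (SELECT * FROM orders)` definition referenced by both branches, which is the duplication that `allow_cte` is intended to remove from generated SQL. Note that, as in the patch's `convert_using_specifics`, the collected CTE definitions must be hoisted onto the outermost SELECT statement, since standard SQL scopes a `WITH` list to the single statement it prefixes.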