Update DataflowToSqlQueryPlanConverter to use CTEs.
plypaul committed Nov 13, 2024
1 parent 0bff032 commit 5f14ab2
Showing 2 changed files with 274 additions and 15 deletions.
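
In effect, a dataflow branch that feeds multiple consumers can now be rendered once as a CTE and referenced by alias, instead of being duplicated as a subquery at each consumer. A hand-written sketch of the intended shape of the generated SQL (illustrative only; table and alias names are made up, not actual MetricFlow output):

# Illustrative only: hand-written SQL sketching the intent of CTE conversion;
# `fct_orders` and the aliases are invented for this example.
SUBQUERY_FORM = """
SELECT a.user_id, a.revenue, b.revenue AS prior_revenue
FROM (SELECT user_id, revenue FROM fct_orders) a
JOIN (SELECT user_id, revenue FROM fct_orders) b ON a.user_id = b.user_id
"""

CTE_FORM = """
WITH read_orders_cte AS (SELECT user_id, revenue FROM fct_orders)
SELECT a.user_id, a.revenue, b.revenue AS prior_revenue
FROM read_orders_cte a
JOIN read_orders_cte b ON a.user_id = b.user_id
"""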
283 changes: 271 additions & 12 deletions metricflow/plan_conversion/dataflow_to_sql.py
@@ -3,7 +3,7 @@
import datetime as dt
import logging
from collections import OrderedDict
-from typing import List, Optional, Sequence, Set, Tuple, Union
+from typing import Callable, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, TypeVar, Union

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
@@ -39,12 +39,15 @@
from metricflow_semantics.specs.spec_set import InstanceSpecSet
from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
from metricflow_semantics.sql.sql_join_type import SqlJoinType
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.time.time_constants import ISO8601_PYTHON_FORMAT, ISO8601_PYTHON_TS_FORMAT
from metricflow_semantics.time.time_spine_source import TIME_SPINE_DATA_SET_DESCRIPTION, TimeSpineSource
from typing_extensions import override

from metricflow.dataflow.dataflow_plan import (
DataflowPlanNode,
)
from metricflow.dataflow.dataflow_plan_analyzer import DataflowPlanAnalyzer
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor
from metricflow.dataflow.nodes.add_generated_uuid import AddGeneratedUuidColumnNode
from metricflow.dataflow.nodes.aggregate_measures import AggregateMeasuresNode
@@ -104,9 +107,10 @@
)
from metricflow.protocols.sql_client import SqlEngine
from metricflow.sql.optimizer.optimization_levels import (
-    SqlQueryGenerationOptionSet,
+    SqlGenerationOptionSet,
SqlQueryOptimizationLevel,
)
from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
from metricflow.sql.sql_exprs import (
SqlAggregateFunctionExpression,
SqlBetweenExpression,
@@ -131,6 +135,7 @@
)
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlQueryPlan,
@@ -181,23 +186,72 @@ def convert_to_sql_query_plan(
sql_query_plan_id: Optional[DagId] = None,
) -> ConvertToSqlPlanResult:
"""Create an SQL query plan that represents the computation up to the given dataflow plan node."""
-        # TODO: Handle generation with CTE.
-        to_sql_subquery_visitor = DataflowNodeToSqlSubqueryVisitor(
-            column_association_resolver=self.column_association_resolver,
-            semantic_manifest_lookup=self._semantic_manifest_lookup,
-        )
-        data_set = dataflow_plan_node.accept(to_sql_subquery_visitor)
-
-        sql_node: SqlQueryPlanNode = data_set.sql_node
# TODO: Make this a more generally accessible attribute instead of checking against the
# BigQuery-ness of the engine
use_column_alias_in_group_by = sql_engine_type is SqlEngine.BIGQUERY

-        option_set = SqlQueryGenerationOptionSet.options_for_level(
+        option_set = SqlGenerationOptionSet.options_for_level(
optimization_level, use_column_alias_in_group_by=use_column_alias_in_group_by
)

-        for optimizer in option_set.optimizers:
nodes_to_convert_to_cte: FrozenSet[DataflowPlanNode] = frozenset()
if option_set.allow_cte:
nodes_to_convert_to_cte = self._get_nodes_to_convert_to_cte(dataflow_plan_node)

return self.convert_using_specifics(
dataflow_plan_node=dataflow_plan_node,
sql_query_plan_id=sql_query_plan_id,
nodes_to_convert_to_cte=nodes_to_convert_to_cte,
optimizers=option_set.optimizers,
)

def convert_using_specifics(
self,
dataflow_plan_node: DataflowPlanNode,
sql_query_plan_id: Optional[DagId],
nodes_to_convert_to_cte: FrozenSet[DataflowPlanNode],
optimizers: Sequence[SqlQueryPlanOptimizer],
) -> ConvertToSqlPlanResult:
"""Helper method to convert using specific options. Main use case are tests."""
logger.debug(LazyFormat("Converting to SQL", nodes_to_convert_to_cte=nodes_to_convert_to_cte))

if len(nodes_to_convert_to_cte) == 0:
            # Avoid the `DataflowNodeToSqlCteVisitor` code path for better isolation during rollout.
            # Later this branch can be removed, as `DataflowNodeToSqlCteVisitor` should handle an empty
            # `nodes_to_convert_to_cte`.
to_sql_subquery_visitor = DataflowNodeToSqlSubqueryVisitor(
column_association_resolver=self.column_association_resolver,
semantic_manifest_lookup=self._semantic_manifest_lookup,
)
data_set = dataflow_plan_node.accept(to_sql_subquery_visitor)
else:
to_sql_cte_visitor = DataflowNodeToSqlCteVisitor(
column_association_resolver=self.column_association_resolver,
semantic_manifest_lookup=self._semantic_manifest_lookup,
nodes_to_convert_to_cte=nodes_to_convert_to_cte,
)
data_set = dataflow_plan_node.accept(to_sql_cte_visitor)
select_statement = data_set.checked_sql_select_node
data_set = SqlDataSet(
instance_set=data_set.instance_set,
sql_select_node=SqlSelectStatementNode.create(
description=select_statement.description,
select_columns=select_statement.select_columns,
from_source=select_statement.from_source,
from_source_alias=select_statement.from_source_alias,
cte_sources=tuple(to_sql_cte_visitor.generated_cte_nodes()),
join_descs=select_statement.join_descs,
group_bys=select_statement.group_bys,
order_bys=select_statement.order_bys,
where=select_statement.where,
limit=select_statement.limit,
distinct=select_statement.distinct,
),
)

sql_node: SqlQueryPlanNode = data_set.sql_node

for optimizer in optimizers:
logger.debug(LazyFormat(lambda: f"Applying optimizer: {optimizer.__class__.__name__}"))
sql_node = optimizer.optimize(sql_node)
logger.debug(
@@ -212,6 +266,17 @@ def convert_to_sql_query_plan(
sql_plan=SqlQueryPlan(render_node=sql_node, plan_id=sql_query_plan_id),
)
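
As a usage sketch, a test can pin down exactly which nodes become CTEs by calling `convert_using_specifics` directly. A minimal sketch, assuming a converter instance, a plan sink node, and a node known to feed multiple consumers (all names below are hypothetical):

# Hypothetical names: `converter`, `sink_node`, and `shared_branch_node` are stand-ins.
result = converter.convert_using_specifics(
    dataflow_plan_node=sink_node,
    sql_query_plan_id=None,
    nodes_to_convert_to_cte=frozenset({shared_branch_node}),
    optimizers=(),  # Skip optimizers to inspect the raw CTE structure.
)
sql_plan = result.sql_plan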

def _get_nodes_to_convert_to_cte(
self,
dataflow_plan_node: DataflowPlanNode,
) -> FrozenSet[DataflowPlanNode]:
"""Handles logic for selecting which nodes to convert to CTEs based on the request."""
dataflow_plan = dataflow_plan_node.as_plan()
nodes_to_convert_to_cte: Set[DataflowPlanNode] = set(DataflowPlanAnalyzer.find_common_branches(dataflow_plan))
# Additional nodes will be added later.

return frozenset(nodes_to_convert_to_cte)
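
`DataflowPlanAnalyzer.find_common_branches` is not shown in this diff; conceptually, it returns branches that more than one path in the plan DAG reads from. A rough sketch of the idea, assuming each node exposes `parent_nodes` (the real analyzer likely also reduces the result to the highest such nodes):

from collections import Counter
from typing import FrozenSet

def find_common_branches_sketch(sink_node: DataflowPlanNode) -> FrozenSet[DataflowPlanNode]:
    # Count how many consumers reach each node; a node reached more than once
    # heads a branch that is worth sharing as a CTE.
    visit_counts: Counter = Counter()

    def _visit(node: DataflowPlanNode) -> None:
        visit_counts[node] += 1
        if visit_counts[node] > 1:
            return  # This subtree was already counted on the first visit.
        for parent_node in node.parent_nodes:
            _visit(parent_node)

    _visit(sink_node)
    return frozenset(node for node, count in visit_counts.items() if count > 1)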


class DataflowNodeToSqlSubqueryVisitor(DataflowPlanNodeVisitor[SqlDataSet]):
"""Generates a SQL query plan by converting a node's parents to sub-queries.
@@ -1961,3 +2026,197 @@ def strip_time_from_dt(ts: dt.datetime) -> dt.datetime:
literal_value=time_range_constraint.end_time.strftime(time_format_to_render),
),
)


class DataflowNodeToSqlCteVisitor(DataflowNodeToSqlSubqueryVisitor):
"""Similar to `DataflowNodeToSqlSubqueryVisitor`, except that this converts specific nodes to CTEs.
This is implemented as a subclass of `DataflowNodeToSqlSubqueryVisitor` so that by default, it has the same behavior
but in cases where there are nodes that should be converted to CTEs, alternate methods can be used.
The generated CTE nodes are collected instead of getting incorporated into the associated SQL query plan generated
at each node so that the CTE nodes can be included at the top-level SELECT statement.
# TODO: Move these visitors to separate files at the end of the stack.
"""

def __init__( # noqa: D107
self,
column_association_resolver: ColumnAssociationResolver,
semantic_manifest_lookup: SemanticManifestLookup,
nodes_to_convert_to_cte: FrozenSet[DataflowPlanNode],
) -> None:
super().__init__(
column_association_resolver=column_association_resolver, semantic_manifest_lookup=semantic_manifest_lookup
)
self._nodes_to_convert_to_cte = nodes_to_convert_to_cte
self._generated_cte_nodes: List[SqlCteNode] = []

# If a given node is supposed to use a CTE, map the node to the generated dataset that uses a CTE.
self._node_to_cte_dataset: Dict[DataflowPlanNode, SqlDataSet] = {}

def generated_cte_nodes(self) -> Sequence[SqlCteNode]:
"""Returns the CTE nodes that have been generated while traversing the dataflow plan."""
return self._generated_cte_nodes

def _default_handler(
self, node: DataflowNodeT, node_to_select_subquery_function: Callable[[DataflowNodeT], SqlDataSet]
) -> SqlDataSet:
"""Default handler that is called for each node as the dataflow plan is traversed.
Args:
node: The current node in traversal.
node_to_select_subquery_function: A function that converts the given node to a `SqlDataSet` where the
SELECT statement source is a subquery. This should be a method in `DataflowNodeToSqlSubqueryVisitor` as this
was the default behavior before CTEs were supported.
Returns: The `SqlDataSet` that produces the data for the given node.
"""
# For the given node, if there is already a generated dataset that uses a SELECT from a CTE, return it.
select_from_cte_dataset = self._node_to_cte_dataset.get(node)
if select_from_cte_dataset is not None:
logger.debug(LazyFormat("Handling node via existing CTE", node=node))
return select_from_cte_dataset

# If the given node is supposed to use a CTE, generate one for it. Otherwise, use the default subquery as the
# source for the SELECT.
select_from_subquery_dataset = node_to_select_subquery_function(node)
if node not in self._nodes_to_convert_to_cte:
logger.debug(LazyFormat("Handling node via subquery", node=node))
return select_from_subquery_dataset
logger.debug(LazyFormat("Handling node via new CTE", node=node))

cte_alias = node.node_id.id_str + "_cte"

        if cte_alias in set(generated_node.cte_alias for generated_node in self._generated_cte_nodes):
raise ValueError(
f"{cte_alias=} is a duplicate of one that already exists. "
f"This implies a bug that is generating a CTE for the same dataflow plan node multiple times."
)

cte_source = SqlCteNode.create(
select_statement=select_from_subquery_dataset.sql_node,
cte_alias=cte_alias,
)
self._generated_cte_nodes.append(cte_source)
node_id = node.node_id
select_from_cte_dataset = SqlDataSet(
instance_set=select_from_subquery_dataset.instance_set,
sql_select_node=SqlSelectStatementNode.create(
description=f"Read From CTE For {node_id=}",
select_columns=CreateSelectColumnsForInstances(
table_alias=cte_alias,
column_resolver=self._column_association_resolver,
)
.transform(select_from_subquery_dataset.instance_set)
.as_tuple(),
from_source=SqlTableNode.create(SqlTable(schema_name=None, table_name=cte_alias)),
from_source_alias=cte_alias,
),
)
self._node_to_cte_dataset[node] = select_from_cte_dataset

return select_from_cte_dataset
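
One consequence of the caching above: a node designated for CTE conversion is converted exactly once, and every later visit returns the same dataset object. Illustratively, with a hypothetical visitor and a source node in `nodes_to_convert_to_cte`:

# Hypothetical sketch: `read_node` is a ReadSqlSourceNode in `nodes_to_convert_to_cte`.
first_data_set = visitor.visit_source_node(read_node)  # Registers e.g. "read_0_cte".
second_data_set = visitor.visit_source_node(read_node)  # Cache hit in `_node_to_cte_dataset`.
assert first_data_set is second_data_set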

@override
def visit_source_node(self, node: ReadSqlSourceNode) -> SqlDataSet:
return self._default_handler(node=node, node_to_select_subquery_function=super().visit_source_node)

@override
def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> SqlDataSet:
return self._default_handler(node=node, node_to_select_subquery_function=super().visit_join_on_entities_node)

@override
def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> SqlDataSet:
return self._default_handler(node=node, node_to_select_subquery_function=super().visit_aggregate_measures_node)

@override
def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> SqlDataSet:
return self._default_handler(node=node, node_to_select_subquery_function=super().visit_compute_metrics_node)

@override
def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> SqlDataSet:
return self._default_handler(
node=node, node_to_select_subquery_function=super().visit_window_reaggregation_node
)

@override
def visit_order_by_limit_node(self, node: OrderByLimitNode) -> SqlDataSet:
return self._default_handler(node=node, node_to_select_subquery_function=super().visit_order_by_limit_node)

@override
def visit_where_constraint_node(self, node: WhereConstraintNode) -> SqlDataSet:
return self._default_handler(node=node, node_to_select_subquery_function=super().visit_where_constraint_node)

@override
def visit_write_to_result_data_table_node(self, node: WriteToResultDataTableNode) -> SqlDataSet:
return self._default_handler(
node=node, node_to_select_subquery_function=super().visit_write_to_result_data_table_node
)

@override
def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> SqlDataSet:
return self._default_handler(
node=node, node_to_select_subquery_function=super().visit_write_to_result_table_node
)

@override
def visit_filter_elements_node(self, node: FilterElementsNode) -> SqlDataSet:
return self._default_handler(node=node, node_to_select_subquery_function=super().visit_filter_elements_node)

@override
def visit_combine_aggregated_outputs_node(self, node: CombineAggregatedOutputsNode) -> SqlDataSet:
return self._default_handler(
node=node, node_to_select_subquery_function=super().visit_combine_aggregated_outputs_node
)

@override
def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> SqlDataSet:
return self._default_handler(
node=node, node_to_select_subquery_function=super().visit_constrain_time_range_node
)

@override
def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDataSet:
return self._default_handler(
node=node, node_to_select_subquery_function=super().visit_join_over_time_range_node
)

@override
def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSet:
return self._default_handler(node=node, node_to_select_subquery_function=super().visit_semi_additive_join_node)

@override
def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTransformNode) -> SqlDataSet:
return self._default_handler(
node=node, node_to_select_subquery_function=super().visit_metric_time_dimension_transform_node
)

@override
def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet:
return self._default_handler(node=node, node_to_select_subquery_function=super().visit_join_to_time_spine_node)

@override
def visit_min_max_node(self, node: MinMaxNode) -> SqlDataSet:
return self._default_handler(node=node, node_to_select_subquery_function=super().visit_min_max_node)

@override
def visit_add_generated_uuid_column_node(self, node: AddGeneratedUuidColumnNode) -> SqlDataSet:
return self._default_handler(
node=node, node_to_select_subquery_function=super().visit_add_generated_uuid_column_node
)

@override
def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> SqlDataSet:
return self._default_handler(
node=node, node_to_select_subquery_function=super().visit_join_conversion_events_node
)

@override
def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> SqlDataSet:
return self._default_handler(
node=node, node_to_select_subquery_function=super().visit_join_to_custom_granularity_node
)


DataflowNodeT = TypeVar("DataflowNodeT", bound=DataflowPlanNode)
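
Putting the pieces together: the CTE visitor is driven like the subquery visitor, but the CTEs it collects are attached to the top-level SELECT afterwards, as `convert_using_specifics` does above. A sketch with hypothetical resolver, lookup, and node variables:

visitor = DataflowNodeToSqlCteVisitor(
    column_association_resolver=column_association_resolver,
    semantic_manifest_lookup=semantic_manifest_lookup,
    nodes_to_convert_to_cte=frozenset({shared_branch_node}),
)
data_set = sink_node.accept(visitor)
# CTEs are collected during traversal rather than embedded in per-node SELECTs;
# the caller stitches them into the top-level statement via `cte_sources`.
cte_nodes = visitor.generated_cte_nodes()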
6 changes: 3 additions & 3 deletions metricflow/sql/optimizer/optimization_levels.py
@@ -25,7 +25,7 @@ class SqlQueryOptimizationLevel(Enum):


@dataclass(frozen=True)
-class SqlQueryGenerationOptionSet:
+class SqlGenerationOptionSet:
"""Defines the different SQL generation optimizers / options that should be used at each level."""

optimizers: Tuple[SqlQueryPlanOptimizer, ...]
@@ -36,7 +36,7 @@ class SqlQueryGenerationOptionSet:
@staticmethod
def options_for_level( # noqa: D102
level: SqlQueryOptimizationLevel, use_column_alias_in_group_by: bool
-) -> SqlQueryGenerationOptionSet:
+) -> SqlGenerationOptionSet:
optimizers: Tuple[SqlQueryPlanOptimizer, ...] = ()
allow_cte = False
if level is SqlQueryOptimizationLevel.O0:
@@ -63,7 +63,7 @@ def options_for_level( # noqa: D102
else:
assert_values_exhausted(level)

-    return SqlQueryGenerationOptionSet(
+    return SqlGenerationOptionSet(
optimizers=optimizers,
allow_cte=allow_cte,
)
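
A small sketch of how callers consume the renamed option set; only `O0` is visible in this hunk, so the behavior assumed for other levels is a guess:

option_set = SqlGenerationOptionSet.options_for_level(
    SqlQueryOptimizationLevel.O0, use_column_alias_in_group_by=False
)
if option_set.allow_cte:
    # Presumably reached only at higher optimization levels.
    ...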
