Skip to content

Commit

Permalink
Update SqlRewritingSubQueryReducer to support CTEs.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Nov 9, 2024
1 parent 08e26d4 commit e0ac2d9
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 33 deletions.
65 changes: 34 additions & 31 deletions metricflow/sql/optimizer/rewriting_sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,12 @@ def _reduce_parents(
),
join_descs=tuple(
SqlJoinDescription(
right_source=x.right_source.accept(self),
right_source_alias=x.right_source_alias,
on_condition=x.on_condition,
join_type=x.join_type,
right_source=join_desc.right_source.accept(self),
right_source_alias=join_desc.right_source_alias,
on_condition=join_desc.on_condition,
join_type=join_desc.join_type,
)
for x in node.join_descs
for join_desc in node.join_descs
),
group_bys=node.group_bys,
order_bys=node.order_bys,
Expand Down Expand Up @@ -199,7 +199,7 @@ def _is_simple_source(node: SqlSelectStatementNode) -> bool:
if select_column.expr.lineage.contains_aggregate_exprs:
return False
return (
len(node.parent_nodes) <= 1
len(node.join_descs) == 0
and len(node.group_bys) == 0
and len(node.order_bys) == 0
and not node.limit
Expand All @@ -212,20 +212,15 @@ def _current_node_can_be_reduced(self, node: SqlSelectStatementNode) -> bool:
Reducing this node means eliminating the SELECT of this node and merging it with the parent SELECT. This
checks for the cases where we are able to reduce.
"""
# If this node has multiple parents (i.e. a join) that are complex, then this can't be collapsed.
is_join = len(node.join_descs) > 0
has_multiple_parent_nodes = len(node.parent_nodes) > 1
if has_multiple_parent_nodes or is_join:
# If this node has joins, then don't collapse this as it can be complex.
if len(node.join_descs) > 0:
return False

assert len(node.parent_nodes) == 1
parent_node = node.parent_nodes[0]

# If the parent node is not a SELECT statement, then this can't be collapsed. e.g. with a table as a parent like
# SELECT foo FROM bar
if not parent_node.as_select_node:
from_source_node_as_select_node = node.from_source.as_select_node
if from_source_node_as_select_node is None:
return False
parent_select_node = parent_node.as_select_node

# More conditions where we don't want to collapse. It's not impossible with these cases, but not reducing in
# these cases for simplicity.
Expand All @@ -235,15 +230,15 @@ def _current_node_can_be_reduced(self, node: SqlSelectStatementNode) -> bool:
return False

# Don't reduce distinct selects
if parent_select_node.distinct:
if from_source_node_as_select_node.distinct:
return False

# Skip this case for simplicity of reasoning.
if len(node.order_bys) > 0 and len(parent_select_node.order_bys) > 0:
if len(node.order_bys) > 0 and len(from_source_node_as_select_node.order_bys) > 0:
return False

# Skip this case for simplicity of reasoning.
if len(parent_select_node.group_bys) > 0 and len(node.group_bys) > 0:
if len(from_source_node_as_select_node.group_bys) > 0 and len(node.group_bys) > 0:
return False

# If there is a column in the parent group by that is not used in the current select statement, don't reduce or it
Expand All @@ -270,9 +265,9 @@ def _current_node_can_be_reduced(self, node: SqlSelectStatementNode) -> bool:
if select_column.expr.as_column_reference_expression
}
all_parent_group_bys_used_in_current_select = True
for group_by in parent_select_node.group_bys:
for group_by in from_source_node_as_select_node.group_bys:
parent_group_by_select = SqlGroupByRewritingVisitor._find_matching_select(
expr=group_by.expr, select_columns=parent_select_node.select_columns
expr=group_by.expr, select_columns=from_source_node_as_select_node.select_columns
)
if parent_group_by_select and parent_group_by_select.column_alias not in current_select_column_refs:
all_parent_group_bys_used_in_current_select = False
Expand All @@ -293,13 +288,13 @@ def _current_node_can_be_reduced(self, node: SqlSelectStatementNode) -> bool:

# If the parent has a GROUP BY and this has a WHERE, avoid reducing as the WHERE could reference an
# aggregation expression.
if len(parent_select_node.group_bys) > 0 and node.where:
if len(from_source_node_as_select_node.group_bys) > 0 and node.where:
return False

# If the parent has a GROUP BY, the case where it's easiest to merge this with the parent is if all select
# columns are column references.
if len(
parent_select_node.group_bys
from_source_node_as_select_node.group_bys
) > 0 and not SqlRewritingSubQueryReducerVisitor._select_columns_are_column_references(node.select_columns):
return False

Expand All @@ -308,7 +303,7 @@ def _current_node_can_be_reduced(self, node: SqlSelectStatementNode) -> bool:
parent_column_aliases_with_window_functions = {
select_column.column_alias
for select_column in SqlRewritingSubQueryReducerVisitor._select_columns_with_window_functions(
parent_select_node.select_columns
from_source_node_as_select_node.select_columns
)
}
if len(node.group_bys) > 0 and [
Expand Down Expand Up @@ -355,7 +350,7 @@ def _current_node_can_be_reduced(self, node: SqlSelectStatementNode) -> bool:
# item in the SELECT to group by.

if len(node.group_bys) > 0 and SqlRewritingSubQueryReducerVisitor._select_columns_contain_string_expressions(
select_columns=parent_select_node.select_columns,
select_columns=from_source_node_as_select_node.select_columns,
):
return False

Expand Down Expand Up @@ -419,8 +414,7 @@ def _find_matching_select_column(
return select_column
return None

@staticmethod
def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementNode:
def _rewrite_node_with_join(self, node: SqlSelectStatementNode) -> SqlSelectStatementNode:
"""Reduces nodes with joins if the join source is simple to reduce.
Converts this:
Expand Down Expand Up @@ -558,7 +552,7 @@ def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementN

clauses_to_rewrite.rewrite(column_replacements=column_replacements)
# This was already checked in _is_simple_source().
assert len(from_source_select.parent_nodes) == 1
assert len(from_source_select.join_descs) == 0
from_source = from_source_select.from_source
from_source_alias = from_source_select.from_source_alias

Expand All @@ -579,6 +573,9 @@ def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementN
select_columns=tuple(clauses_to_rewrite.select_columns),
from_source=from_source,
from_source_alias=from_source_alias,
cte_sources=tuple(
cte_source.with_new_select(cte_source.select_statement.accept(self)) for cte_source in node.cte_sources
),
join_descs=tuple(new_join_descs),
group_bys=tuple(clauses_to_rewrite.group_bys),
order_bys=tuple(clauses_to_rewrite.order_bys),
Expand All @@ -595,7 +592,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
node_with_reduced_parents = self._reduce_parents(node)

if len(node_with_reduced_parents.join_descs) > 0:
return SqlRewritingSubQueryReducerVisitor._rewrite_node_with_join(node_with_reduced_parents)
return self._rewrite_node_with_join(node_with_reduced_parents)

if not self._current_node_can_be_reduced(node_with_reduced_parents):
return node_with_reduced_parents
Expand All @@ -615,7 +612,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
# JOIN dim_listings c
# ON a.listing_id = b.listing_id

from_source_node = node_with_reduced_parents.parent_nodes[0]
from_source_node = node_with_reduced_parents.from_source
from_source_select_node = from_source_node.as_select_node
assert (
from_source_select_node is not None
Expand Down Expand Up @@ -699,6 +696,9 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
),
from_source=from_source_select_node.from_source,
from_source_alias=from_source_select_node.from_source_alias,
cte_sources=tuple(
cte_source.with_new_select(cte_source.select_statement.accept(self)) for cte_source in node.cte_sources
),
join_descs=from_source_select_node.join_descs,
group_bys=new_group_bys,
order_bys=tuple(new_order_bys),
Expand Down Expand Up @@ -733,13 +733,13 @@ def _find_matching_select(
) -> Optional[SqlSelectColumn]:
"""Given an expression, find the SELECT column that has the same expression."""
for select_column in select_columns:
if select_column.expr == expr:
if select_column.expr.matches(expr):
return select_column
return None

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError
return node.with_new_select(node.select_statement.accept(self))

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
new_group_bys = []
Expand Down Expand Up @@ -767,6 +767,9 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
select_columns=node.select_columns,
from_source=node.from_source.accept(self),
from_source_alias=node.from_source_alias,
cte_sources=tuple(
cte_source.with_new_select(cte_source.select_statement.accept(self)) for cte_source in node.cte_sources
),
join_descs=tuple(
SqlJoinDescription(
right_source=x.right_source.accept(self),
Expand Down
4 changes: 3 additions & 1 deletion metricflow/sql/optimizer/tag_required_column_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
raise RuntimeError(
"No columns are required in this node - this indicates a bug in this visitor or in the inputs."
)
# It's possible for `required_select_columns_in_this_node` to be empty because we traverse through the ancestors
# of a CTE node whenever a CTE node is updated. See `test_multi_child_pruning`.

# Based on the expressions in this select statement, figure out what column aliases are needed in the sources of
# this query (i.e. tables or sub-queries in the FROM or JOIN clauses).
Expand All @@ -178,7 +180,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
nodes_to_retain_all_columns.append(join_desc.right_source)

for node_to_retain_all_columns in nodes_to_retain_all_columns:
nearest_select_columns = node_to_retain_all_columns.nearest_select_columns({})
nearest_select_columns = node_to_retain_all_columns.nearest_select_columns(self._cte_alias_to_cte_node)
for select_column in nearest_select_columns or ():
self._current_required_column_alias_mapping.add_alias(
node=node_to_retain_all_columns, column_alias=select_column.column_alias
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ def test_reducing_join_statement(
mf_test_configuration: MetricFlowTestConfiguration,
reducing_join_statement: SqlSelectStatementNode,
) -> None:
"""Tests a case where a join query should not reduced an aggregate."""
"""Tests a case where a join query should not reduce an aggregate."""
assert_default_rendered_sql_equal(
request=request,
mf_test_configuration=mf_test_configuration,
Expand Down

0 comments on commit e0ac2d9

Please sign in to comment.