Skip to content

Commit

Permalink
/* PR_START p--cte 09 */ Rename parent node to from-source node.
Browse files Browse the repository at this point in the history
Previously, "parent nodes" was used for either nodes in the FROM clause or the
JOIN clause. With the addition of CTEs that are also considered parent nodes,
this renames a few variables / updates some conditionals for clarity.
  • Loading branch information
plypaul committed Nov 9, 2024
1 parent 05e5747 commit b526055
Showing 1 changed file with 23 additions and 19 deletions.
42 changes: 23 additions & 19 deletions metricflow/sql/optimizer/rewriting_sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,9 @@ def _reduce_parents(
select_columns=node.select_columns,
from_source=node.from_source.accept(self),
from_source_alias=node.from_source_alias,
cte_sources=tuple(
cte_source.with_new_select(cte_source.select_statement.accept(self)) for cte_source in node.cte_sources
),
join_descs=tuple(
SqlJoinDescription(
right_source=x.right_source.accept(self),
Expand Down Expand Up @@ -150,7 +153,7 @@ def _select_column_for_alias(column_alias: str, select_columns: Sequence[SqlSele
for select_column in select_columns:
if select_column.column_alias == column_alias:
return select_column
raise RuntimeError(f"Column alias '{column_alias}' not in SELECT columns: {select_columns}")
raise RuntimeError(f"Column alias {repr(column_alias)} not in SELECT columns: {select_columns}")

@staticmethod
def _is_simple_source(node: SqlSelectStatementNode) -> bool:
Expand Down Expand Up @@ -591,7 +594,7 @@ def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
node_with_reduced_parents = self._reduce_parents(node)

if len(node_with_reduced_parents.parent_nodes) > 1:
if len(node_with_reduced_parents.join_descs) > 0:
return SqlRewritingSubQueryReducerVisitor._rewrite_node_with_join(node_with_reduced_parents)

if not self._current_node_can_be_reduced(node_with_reduced_parents):
Expand All @@ -612,10 +615,11 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
# JOIN dim_listings c
# ON a.listing_id = b.listing_id

assert len(node_with_reduced_parents.parent_nodes) == 1
parent_node = node_with_reduced_parents.parent_nodes[0]
parent_select_node = parent_node.as_select_node
assert parent_select_node
from_source_node = node_with_reduced_parents.parent_nodes[0]
from_source_select_node = from_source_node.as_select_node
assert (
from_source_select_node is not None
), f"{from_source_select_node=} should be set as `_current_node_can_be_reduced()` returned True"

# At this point, the query should look similar to
#
Expand All @@ -631,7 +635,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
# The ORDER BY in the parent doesn't matter since the order by in this node will "overwrite" the order in the
# parent as long as the parent has no limits.
column_replacements = SqlRewritingSubQueryReducerVisitor._get_column_replacements(
parent_node=parent_select_node,
parent_node=from_source_select_node,
parent_node_alias=node.from_source_alias,
)
new_order_bys: List[SqlOrderByDescription] = []
Expand Down Expand Up @@ -671,12 +675,12 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
# The limit should be the min of this SELECT limit and the parent SELECT limit.
new_limit: Optional[int] = node_with_reduced_parents.limit
if new_limit is None:
new_limit = parent_select_node.limit
elif parent_select_node.limit is not None:
new_limit = min(new_limit, parent_select_node.limit)
new_limit = from_source_select_node.limit
elif from_source_select_node.limit is not None:
new_limit = min(new_limit, from_source_select_node.limit)

new_group_bys: Tuple[SqlSelectColumn, ...] = ()
if node.group_bys and parent_select_node.group_bys:
if node.group_bys and from_source_select_node.group_bys:
raise RuntimeError(
"Attempting to reduce sub-queries when this and the parent have GROUP BYs. This should have been "
"prevent by _should_reduce()"
Expand All @@ -685,26 +689,26 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
new_group_bys = SqlRewritingSubQueryReducerVisitor._rewrite_select_columns(
old_select_columns=node.group_bys, column_replacements=column_replacements
)
elif parent_select_node.group_bys:
new_group_bys = parent_select_node.group_bys
elif from_source_select_node.group_bys:
new_group_bys = from_source_select_node.group_bys

return SqlSelectStatementNode.create(
description="\n".join([parent_select_node.description, node_with_reduced_parents.description]),
description="\n".join([from_source_select_node.description, node_with_reduced_parents.description]),
select_columns=SqlRewritingSubQueryReducerVisitor._rewrite_select_columns(
old_select_columns=node.select_columns, column_replacements=column_replacements
),
from_source=parent_select_node.from_source,
from_source_alias=parent_select_node.from_source_alias,
join_descs=parent_select_node.join_descs,
from_source=from_source_select_node.from_source,
from_source_alias=from_source_select_node.from_source_alias,
join_descs=from_source_select_node.join_descs,
group_bys=new_group_bys,
order_bys=tuple(new_order_bys),
where=SqlRewritingSubQueryReducerVisitor._rewrite_where(
column_replacements=column_replacements,
node_where=node.where,
parent_node_where=parent_select_node.where,
parent_node_where=from_source_select_node.where,
),
limit=new_limit,
distinct=parent_select_node.distinct,
distinct=from_source_select_node.distinct,
)

def visit_table_node(self, node: SqlTableNode) -> SqlQueryPlanNode: # noqa: D102
Expand Down

0 comments on commit b526055

Please sign in to comment.