Skip to content

Commit

Permalink
/* PR_START p--cte 08 */ Support CTEs in the column pruner.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Nov 9, 2024
1 parent 57c5ede commit 01b57a1
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 6 deletions.
7 changes: 4 additions & 3 deletions metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
select_columns=pruned_select_columns,
from_source=node.from_source.accept(self),
from_source_alias=node.from_source_alias,
# TODO: Handle CTEs.
cte_sources=(),
cte_sources=tuple(
cte_source.with_new_select(cte_source.select_statement.accept(self)) for cte_source in node.cte_sources
),
join_descs=tuple(
join_desc.with_right_source(join_desc.right_source.accept(self)) for join_desc in node.join_descs
),
Expand All @@ -96,7 +97,7 @@ def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlan

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError
return node.with_new_select(node.select_statement.accept(self))


class SqlColumnPrunerOptimizer(SqlQueryPlanOptimizer):
Expand Down
30 changes: 27 additions & 3 deletions metricflow/sql/optimizer/tag_required_column_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,28 @@ def _search_for_expressions(

@override
def visit_cte_node(self, node: SqlCteNode) -> None:
raise NotImplementedError
select_statement = node.select_statement
# Copy the tagged aliases from the CTE to the SELECT since when visiting a SELECT, the CTE node (not the SELECT
# in the CTE) was tagged with the required aliases.
required_column_aliases_in_this_node = self._current_required_column_alias_mapping.get_aliases(node)
self._current_required_column_alias_mapping.add_aliases(select_statement, required_column_aliases_in_this_node)
# Visit parent nodes.
select_statement.accept(self)

def _visit_parents(self, node: SqlQueryPlanNode) -> None:
"""Default recursive handler to visit the parents of the given node."""
for parent_node in node.parent_nodes:
parent_node.accept(self)
return

def _tag_potential_cte_node(self, table_name: str, column_aliases: Set[str]) -> None:
"""A reference to a SQL table might be a CTE. If so, tag the appropriate aliases in the CTEs."""
cte_node = self._cte_alias_to_cte_node.get(table_name)
if cte_node is not None:
self._current_required_column_alias_mapping.add_aliases(cte_node, column_aliases)
# Propagate the required aliases to parents, which could be other CTEs.
cte_node.accept(self)

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
"""Based on required column aliases for this SELECT, figure out required column aliases in parents."""
initial_required_column_aliases_in_this_node = self._current_required_column_alias_mapping.get_aliases(node)
Expand Down Expand Up @@ -191,14 +205,24 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
self._current_required_column_alias_mapping.add_aliases(
node=node.from_source, column_aliases=aliases_required_in_parent
)

from_source_as_sql_table_node = node.from_source.as_sql_table_node
if from_source_as_sql_table_node is not None:
self._tag_potential_cte_node(
table_name=from_source_as_sql_table_node.sql_table.table_name,
column_aliases=aliases_required_in_parent,
)
for join_desc in node.join_descs:
if join_desc.right_source_alias in source_alias_to_required_column_aliases:
aliases_required_in_parent = source_alias_to_required_column_aliases[join_desc.right_source_alias]
self._current_required_column_alias_mapping.add_aliases(
node=join_desc.right_source, column_aliases=aliases_required_in_parent
)
# TODO: Handle CTEs parent nodes.
right_source_as_sql_table_node = join_desc.right_source.as_sql_table_node
if right_source_as_sql_table_node is not None:
self._tag_potential_cte_node(
table_name=right_source_as_sql_table_node.sql_table.table_name,
column_aliases=aliases_required_in_parent,
)

# For all string columns, assume that they are needed from all sources since we don't have a table alias
# in SqlStringExpression.used_columns
Expand Down
7 changes: 7 additions & 0 deletions metricflow/sql/sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,13 @@ def create(select_statement: SqlQueryPlanNode, cte_alias: str) -> SqlCteNode: #
cte_alias=cte_alias,
)

def with_new_select(self, new_select_statement: SqlQueryPlanNode) -> SqlCteNode:
"""Return a node with the same attributes but with the new SELECT statement."""
return SqlCteNode.create(
select_statement=new_select_statement,
cte_alias=self.cte_alias,
)

@override
def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
return visitor.visit_cte_node(self)
Expand Down

0 comments on commit 01b57a1

Please sign in to comment.