Skip to content

Commit

Permalink
/* PR_START p--cte 08 */ Support CTEs in the column pruner.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Nov 4, 2024
1 parent c7c7b6b commit e326ad3
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 5 deletions.
7 changes: 4 additions & 3 deletions metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
select_columns=pruned_select_columns,
from_source=node.from_source.accept(self),
from_source_alias=node.from_source_alias,
# TODO: Handle CTEs.
cte_sources=(),
cte_sources=tuple(
cte_source.with_new_select(cte_source.select_statement.accept(self)) for cte_source in node.cte_sources
),
join_descs=tuple(
join_desc.with_right_source(join_desc.right_source.accept(self)) for join_desc in node.join_descs
),
Expand All @@ -96,7 +97,7 @@ def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlan

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError
return node.with_new_select(node.select_statement.accept(self))


class SqlColumnPrunerOptimizer(SqlQueryPlanOptimizer):
Expand Down
38 changes: 36 additions & 2 deletions metricflow/sql/optimizer/tag_required_column_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self, tagged_column_alias_set: TaggedColumnAliasSet) -> None:
traverses the SQL-query representation DAG.
"""
self._column_alias_tagger = tagged_column_alias_set
self._cte_alias_to_cte_node: Dict[str, SqlCteNode] = {}

def _search_for_expressions(
self, select_node: SqlSelectStatementNode, pruned_select_columns: Tuple[SqlSelectColumn, ...]
Expand Down Expand Up @@ -86,17 +87,32 @@ def _search_for_expressions(

@override
def visit_cte_node(self, node: SqlCteNode) -> None:
raise NotImplementedError
select_statement = node.select_statement
# Copy the tagged aliases from the CTE to the SELECT since when visiting a SELECT, the CTE node (not the SELECT
# in the CTE) was tagged with the required aliases.
required_column_aliases_in_this_node = self._column_alias_tagger.get_tagged_aliases(node)
self._column_alias_tagger.tag_aliases(select_statement, required_column_aliases_in_this_node)
# Visit parent nodes.
select_statement.accept(self)

def _visit_parents(self, node: SqlQueryPlanNode) -> None:
"""Default recursive handler to visit the parents of the given node."""
for parent_node in node.parent_nodes:
parent_node.accept(self)
return

def _tag_potential_cte_node(self, table_name: str, column_aliases: Set[str]) -> None:
"""A reference to a SQL table might be a CTE. If so, tag the appropriate aliases in the CTEs."""
cte_node = self._cte_alias_to_cte_node.get(table_name)
if cte_node is not None:
self._column_alias_tagger.tag_aliases(cte_node, column_aliases)
# Propagate the required aliases to parents, which could be other CTEs.
cte_node.accept(self)

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: # noqa: D102
# Based on column aliases that are tagged in this SELECT statement, tag corresponding column aliases in
# parent nodes.
self._cte_alias_to_cte_node.update({cte_source.cte_alias: cte_source for cte_source in node.cte_sources})

initial_required_column_aliases_in_this_node = self._column_alias_tagger.get_tagged_aliases(node)

Expand Down Expand Up @@ -155,17 +171,35 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: #
column_reference = column_reference_expr.col_ref
source_alias_to_required_column_alias[column_reference.table_alias].add(column_reference.column_name)

logger.debug(
LazyFormat(
"Collected required column names from sources",
source_alias_to_required_column_alias=source_alias_to_required_column_alias,
)
)
# Appropriately tag the columns required in the parent nodes.
if node.from_source_alias in source_alias_to_required_column_alias:
aliases_required_in_parent = source_alias_to_required_column_alias[node.from_source_alias]
self._column_alias_tagger.tag_aliases(node=node.from_source, column_aliases=aliases_required_in_parent)
from_source_as_sql_table_node = node.from_source.as_sql_table_node
if from_source_as_sql_table_node is not None:
self._tag_potential_cte_node(
table_name=from_source_as_sql_table_node.sql_table.table_name,
column_aliases=aliases_required_in_parent,
)

for join_desc in node.join_descs:
if join_desc.right_source_alias in source_alias_to_required_column_alias:
aliases_required_in_parent = source_alias_to_required_column_alias[join_desc.right_source_alias]
self._column_alias_tagger.tag_aliases(
node=join_desc.right_source, column_aliases=aliases_required_in_parent
)
# TODO: Handle CTEs parent nodes.
right_source_as_sql_table_node = join_desc.right_source.as_sql_table_node
if right_source_as_sql_table_node is not None:
self._tag_potential_cte_node(
table_name=right_source_as_sql_table_node.sql_table.table_name,
column_aliases=aliases_required_in_parent,
)

# For all string columns, assume that they are needed from all sources since we don't have a table alias
# in SqlStringExpression.used_columns
Expand Down
7 changes: 7 additions & 0 deletions metricflow/sql/sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,13 @@ def create(select_statement: SqlQueryPlanNode, cte_alias: str) -> SqlCteNode: #
cte_alias=cte_alias,
)

def with_new_select(self, new_select_statement: SqlQueryPlanNode) -> SqlCteNode:
"""Return a node with the same attributes but with the new SELECT statement."""
return SqlCteNode.create(
select_statement=new_select_statement,
cte_alias=self.cte_alias,
)

@override
def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
return visitor.visit_cte_node(self)
Expand Down

0 comments on commit e326ad3

Please sign in to comment.