Skip to content

Commit

Permalink
/* PR_START p--cte 08 */ Support CTEs in the column pruner.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Nov 8, 2024
1 parent 92ce60c commit f3291f4
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 5 deletions.
7 changes: 4 additions & 3 deletions metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
select_columns=pruned_select_columns,
from_source=node.from_source.accept(self),
from_source_alias=node.from_source_alias,
# TODO: Handle CTEs.
cte_sources=(),
cte_sources=tuple(
cte_source.with_new_select(cte_source.select_statement.accept(self)) for cte_source in node.cte_sources
),
join_descs=tuple(
join_desc.with_right_source(join_desc.right_source.accept(self)) for join_desc in node.join_descs
),
Expand All @@ -96,7 +97,7 @@ def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlan

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
# NOTE(review): this text is a rendered diff with the +/- markers stripped. Going by the hunk header
# (7 lines before, 7 lines after), the `raise` below appears to be the removed pre-commit line and the
# `return` its replacement; only one of the two exists in the real file -- confirm against the repository.
raise NotImplementedError
# Run this same visitor over the statement wrapped by the CTE, then rebuild the CTE node around the result.
return node.with_new_select(node.select_statement.accept(self))


class SqlColumnPrunerOptimizer(SqlQueryPlanOptimizer):
Expand Down
22 changes: 20 additions & 2 deletions metricflow/sql/optimizer/tag_required_column_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,13 @@ def _search_for_expressions(

@override
def visit_cte_node(self, node: SqlCteNode) -> None:
# NOTE(review): this text is a rendered diff with the +/- markers stripped. Going by the hunk header
# (`-106,7 +106,13`), the `raise` below appears to be the removed pre-commit line and everything after it
# in this method is newly added -- confirm against the repository.
raise NotImplementedError
select_statement = node.select_statement
# When a SELECT that reads from this CTE was visited, the required aliases were recorded against the CTE node
# itself rather than against the SELECT nested inside it. Copy those aliases onto the nested SELECT so that
# visiting it below sees them.
required_column_aliases_in_this_node = self._current_required_column_alias_mapping.get_aliases(node)
self._current_required_column_alias_mapping.add_aliases(select_statement, required_column_aliases_in_this_node)
# Continue the traversal into the CTE's SELECT (and, through it, that statement's parent nodes).
select_statement.accept(self)

def _visit_parents(self, node: SqlQueryPlanNode) -> None:
"""Default recursive handler to visit the parents of the given node."""
Expand Down Expand Up @@ -192,13 +198,25 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
self._current_required_column_alias_mapping.add_aliases(
node=node.from_source, column_aliases=aliases_required_in_parent
)
from_source_as_sql_table_node = node.from_source.as_sql_table_node
if from_source_as_sql_table_node is not None:
self._tag_potential_cte_node(
table_name=from_source_as_sql_table_node.sql_table.table_name,
column_aliases=aliases_required_in_parent,
)

for join_desc in node.join_descs:
if join_desc.right_source_alias in source_alias_to_required_column_aliases:
aliases_required_in_parent = source_alias_to_required_column_aliases[join_desc.right_source_alias]
self._current_required_column_alias_mapping.add_aliases(
node=join_desc.right_source, column_aliases=aliases_required_in_parent
)
# TODO: Handle CTEs parent nodes.
right_source_as_sql_table_node = join_desc.right_source.as_sql_table_node
if right_source_as_sql_table_node is not None:
self._tag_potential_cte_node(
table_name=right_source_as_sql_table_node.sql_table.table_name,
column_aliases=aliases_required_in_parent,
)

# For all string columns, assume that they are needed from all sources since we don't have a table alias
# in SqlStringExpression.used_columns
Expand Down
7 changes: 7 additions & 0 deletions metricflow/sql/sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,13 @@ def create(select_statement: SqlQueryPlanNode, cte_alias: str) -> SqlCteNode: #
cte_alias=cte_alias,
)

def with_new_select(self, new_select_statement: SqlQueryPlanNode) -> SqlCteNode:
    """Build a CTE node that keeps this node's alias but wraps `new_select_statement` instead."""
    unchanged_alias = self.cte_alias
    return SqlCteNode.create(select_statement=new_select_statement, cte_alias=unchanged_alias)

@override
def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
    """Hand this node to the visitor's CTE handler and return whatever that handler produces."""
    visit_result = visitor.visit_cte_node(self)
    return visit_result
Expand Down

0 comments on commit f3291f4

Please sign in to comment.