From f3291f44df8ee21c1f51e8cf1801adaebc2b6fae Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Mon, 4 Nov 2024 08:40:24 -0800 Subject: [PATCH] /* PR_START p--cte 08 */ Support CTEs in the column pruner. --- metricflow/sql/optimizer/column_pruner.py | 7 +++--- .../optimizer/tag_required_column_aliases.py | 22 +++++++++++++++++-- metricflow/sql/sql_plan.py | 7 ++++++ 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/metricflow/sql/optimizer/column_pruner.py b/metricflow/sql/optimizer/column_pruner.py index cd06a8fd8..3f4bdb102 100644 --- a/metricflow/sql/optimizer/column_pruner.py +++ b/metricflow/sql/optimizer/column_pruner.py @@ -68,8 +68,9 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP select_columns=pruned_select_columns, from_source=node.from_source.accept(self), from_source_alias=node.from_source_alias, - # TODO: Handle CTEs. - cte_sources=(), + cte_sources=tuple( + cte_source.with_new_select(cte_source.select_statement.accept(self)) for cte_source in node.cte_sources + ), join_descs=tuple( join_desc.with_right_source(join_desc.right_source.accept(self)) for join_desc in node.join_descs ), @@ -96,7 +97,7 @@ def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlan @override def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode: - raise NotImplementedError + return node.with_new_select(node.select_statement.accept(self)) class SqlColumnPrunerOptimizer(SqlQueryPlanOptimizer): diff --git a/metricflow/sql/optimizer/tag_required_column_aliases.py b/metricflow/sql/optimizer/tag_required_column_aliases.py index 8a9add74d..d2557cf9d 100644 --- a/metricflow/sql/optimizer/tag_required_column_aliases.py +++ b/metricflow/sql/optimizer/tag_required_column_aliases.py @@ -106,7 +106,13 @@ def _search_for_expressions( @override def visit_cte_node(self, node: SqlCteNode) -> None: - raise NotImplementedError + select_statement = node.select_statement + # Copy the tagged aliases from the CTE to the SELECT since when visiting a SELECT, the CTE node (not the SELECT + # in the CTE) was tagged with the required aliases. + required_column_aliases_in_this_node = self._current_required_column_alias_mapping.get_aliases(node) + self._current_required_column_alias_mapping.add_aliases(select_statement, required_column_aliases_in_this_node) + # Visit parent nodes. + select_statement.accept(self) def _visit_parents(self, node: SqlQueryPlanNode) -> None: """Default recursive handler to visit the parents of the given node.""" @@ -192,13 +198,25 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: self._current_required_column_alias_mapping.add_aliases( node=node.from_source, column_aliases=aliases_required_in_parent ) + from_source_as_sql_table_node = node.from_source.as_sql_table_node + if from_source_as_sql_table_node is not None: + self._tag_potential_cte_node( + table_name=from_source_as_sql_table_node.sql_table.table_name, + column_aliases=aliases_required_in_parent, + ) + for join_desc in node.join_descs: if join_desc.right_source_alias in source_alias_to_required_column_aliases: aliases_required_in_parent = source_alias_to_required_column_aliases[join_desc.right_source_alias] self._current_required_column_alias_mapping.add_aliases( node=join_desc.right_source, column_aliases=aliases_required_in_parent ) - # TODO: Handle CTEs parent nodes. + right_source_as_sql_table_node = join_desc.right_source.as_sql_table_node + if right_source_as_sql_table_node is not None: + self._tag_potential_cte_node( + table_name=right_source_as_sql_table_node.sql_table.table_name, + column_aliases=aliases_required_in_parent, + ) # For all string columns, assume that they are needed from all sources since we don't have a table alias # in SqlStringExpression.used_columns diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py index a0a532b0c..9a40a9ccc 100644 --- a/metricflow/sql/sql_plan.py +++ b/metricflow/sql/sql_plan.py @@ -401,6 +401,13 @@ def create(select_statement: SqlQueryPlanNode, cte_alias: str) -> SqlCteNode: # cte_alias=cte_alias, ) + def with_new_select(self, new_select_statement: SqlQueryPlanNode) -> SqlCteNode: + """Return a node with the same attributes but with the new SELECT statement.""" + return SqlCteNode.create( + select_statement=new_select_statement, + cte_alias=self.cte_alias, + ) + @override def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: return visitor.visit_cte_node(self)