diff --git a/metricflow/sql/optimizer/column_pruner.py b/metricflow/sql/optimizer/column_pruner.py index dde57a12f..a61c8db8f 100644 --- a/metricflow/sql/optimizer/column_pruner.py +++ b/metricflow/sql/optimizer/column_pruner.py @@ -68,8 +68,9 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP select_columns=pruned_select_columns, from_source=node.from_source.accept(self), from_source_alias=node.from_source_alias, - # TODO: Handle CTEs. - cte_sources=(), + cte_sources=tuple( + cte_source.with_new_select(cte_source.select_statement.accept(self)) for cte_source in node.cte_sources + ), join_descs=tuple( join_desc.with_right_source(join_desc.right_source.accept(self)) for join_desc in node.join_descs ), @@ -96,7 +97,7 @@ def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlan @override def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode: - raise NotImplementedError + return node.with_new_select(node.select_statement.accept(self)) class SqlColumnPrunerOptimizer(SqlQueryPlanOptimizer): diff --git a/metricflow/sql/optimizer/tag_required_column_aliases.py b/metricflow/sql/optimizer/tag_required_column_aliases.py index dcddc5f0c..6b34fb545 100644 --- a/metricflow/sql/optimizer/tag_required_column_aliases.py +++ b/metricflow/sql/optimizer/tag_required_column_aliases.py @@ -55,6 +55,7 @@ def __init__(self, tagged_column_alias_set: TaggedColumnAliasSet) -> None: traverses the SQL-query representation DAG. """ self._column_alias_tagger = tagged_column_alias_set + self._cte_alias_to_cte_node: Dict[str, SqlCteNode] = {} def _search_for_expressions( self, select_node: SqlSelectStatementNode, pruned_select_columns: Tuple[SqlSelectColumn, ...] @@ -86,7 +87,13 @@ def _search_for_expressions( @override def visit_cte_node(self, node: SqlCteNode) -> None: - raise NotImplementedError + select_statement = node.select_statement + # Copy the tagged aliases from the CTE to the SELECT since when visiting a SELECT, the CTE node (not the SELECT + # in the CTE) was tagged with the required aliases. + required_column_aliases_in_this_node = self._column_alias_tagger.get_tagged_aliases(node) + self._column_alias_tagger.tag_aliases(select_statement, required_column_aliases_in_this_node) + # Visit parent nodes. + select_statement.accept(self) def _visit_parents(self, node: SqlQueryPlanNode) -> None: """Default recursive handler to visit the parents of the given node.""" @@ -94,9 +101,18 @@ def _visit_parents(self, node: SqlQueryPlanNode) -> None: parent_node.accept(self) return + def _tag_potential_cte_node(self, table_name: str, column_aliases: Set[str]) -> None: + """A reference to a SQL table might be a CTE. If so, tag the appropriate aliases in the CTEs.""" + cte_node = self._cte_alias_to_cte_node.get(table_name) + if cte_node is not None: + self._column_alias_tagger.tag_aliases(cte_node, column_aliases) + # Propagate the required aliases to parents, which could be other CTEs. + cte_node.accept(self) + def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: # noqa: D102 # Based on column aliases that are tagged in this SELECT statement, tag corresponding column aliases in # parent nodes. + self._cte_alias_to_cte_node.update({cte_source.cte_alias: cte_source for cte_source in node.cte_sources}) initial_required_column_aliases_in_this_node = self._column_alias_tagger.get_tagged_aliases(node) @@ -155,17 +171,35 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: # column_reference = column_reference_expr.col_ref source_alias_to_required_column_alias[column_reference.table_alias].add(column_reference.column_name) + logger.debug( + LazyFormat( + "Collected required column names from sources", + source_alias_to_required_column_alias=source_alias_to_required_column_alias, + ) + ) # Appropriately tag the columns required in the parent nodes. if node.from_source_alias in source_alias_to_required_column_alias: aliases_required_in_parent = source_alias_to_required_column_alias[node.from_source_alias] self._column_alias_tagger.tag_aliases(node=node.from_source, column_aliases=aliases_required_in_parent) + from_source_as_sql_table_node = node.from_source.as_sql_table_node + if from_source_as_sql_table_node is not None: + self._tag_potential_cte_node( + table_name=from_source_as_sql_table_node.sql_table.table_name, + column_aliases=aliases_required_in_parent, + ) + for join_desc in node.join_descs: if join_desc.right_source_alias in source_alias_to_required_column_alias: aliases_required_in_parent = source_alias_to_required_column_alias[join_desc.right_source_alias] self._column_alias_tagger.tag_aliases( node=join_desc.right_source, column_aliases=aliases_required_in_parent ) - # TODO: Handle CTEs parent nodes. + right_source_as_sql_table_node = join_desc.right_source.as_sql_table_node + if right_source_as_sql_table_node is not None: + self._tag_potential_cte_node( + table_name=right_source_as_sql_table_node.sql_table.table_name, + column_aliases=aliases_required_in_parent, + ) # For all string columns, assume that they are needed from all sources since we don't have a table alias # in SqlStringExpression.used_columns diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py index d00ed6e6b..e6fb088a5 100644 --- a/metricflow/sql/sql_plan.py +++ b/metricflow/sql/sql_plan.py @@ -369,6 +369,13 @@ def create(select_statement: SqlQueryPlanNode, cte_alias: str) -> SqlCteNode: # cte_alias=cte_alias, ) + def with_new_select(self, new_select_statement: SqlQueryPlanNode) -> SqlCteNode: + """Return a node with the same attributes but with the new SELECT statement.""" + return SqlCteNode.create( + select_statement=new_select_statement, + cte_alias=self.cte_alias, + ) + @override def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: return visitor.visit_cte_node(self)