diff --git a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py index 1e1437c76..053f0d7f4 100644 --- a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py +++ b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py @@ -106,12 +106,12 @@ def _reduce_parents( ), join_descs=tuple( SqlJoinDescription( - right_source=x.right_source.accept(self), - right_source_alias=x.right_source_alias, - on_condition=x.on_condition, - join_type=x.join_type, + right_source=join_desc.right_source.accept(self), + right_source_alias=join_desc.right_source_alias, + on_condition=join_desc.on_condition, + join_type=join_desc.join_type, ) - for x in node.join_descs + for join_desc in node.join_descs ), group_bys=node.group_bys, order_bys=node.order_bys, @@ -199,7 +199,7 @@ def _is_simple_source(node: SqlSelectStatementNode) -> bool: if select_column.expr.lineage.contains_aggregate_exprs: return False return ( - len(node.parent_nodes) <= 1 + len(node.join_descs) == 0 and len(node.group_bys) == 0 and len(node.order_bys) == 0 and not node.limit @@ -212,20 +212,15 @@ def _current_node_can_be_reduced(self, node: SqlSelectStatementNode) -> bool: Reducing this node means eliminating the SELECT of this node and merging it with the parent SELECT. This checks for the cases where we are able to reduce. """ - # If this node has multiple parents (i.e. a join) that are complex, then this can't be collapsed. - is_join = len(node.join_descs) > 0 - has_multiple_parent_nodes = len(node.parent_nodes) > 1 - if has_multiple_parent_nodes or is_join: + # If this node has joins, then don't collapse this as it can be complex. + if len(node.join_descs) > 0: return False - assert len(node.parent_nodes) == 1 - parent_node = node.parent_nodes[0] - # If the parent node is not a SELECT statement, then this can't be collapsed. e.g. with a table as a parent like # SELECT foo FROM bar - if not parent_node.as_select_node: + from_source_node_as_select_node = node.from_source.as_select_node + if from_source_node_as_select_node is None: return False - parent_select_node = parent_node.as_select_node # More conditions where we don't want to collapse. It's not impossible with these cases, but not reducing in # these cases for simplicity. @@ -235,15 +230,15 @@ def _current_node_can_be_reduced(self, node: SqlSelectStatementNode) -> bool: return False # Don't reduce distinct selects - if parent_select_node.distinct: + if from_source_node_as_select_node.distinct: return False # Skip this case for simplicity of reasoning. - if len(node.order_bys) > 0 and len(parent_select_node.order_bys) > 0: + if len(node.order_bys) > 0 and len(from_source_node_as_select_node.order_bys) > 0: return False # Skip this case for simplicity of reasoning. - if len(parent_select_node.group_bys) > 0 and len(node.group_bys) > 0: + if len(from_source_node_as_select_node.group_bys) > 0 and len(node.group_bys) > 0: return False # If there is a column in the parent group by that is not used in the current select statement, don't reduce or it @@ -270,9 +265,9 @@ def _current_node_can_be_reduced(self, node: SqlSelectStatementNode) -> bool: if select_column.expr.as_column_reference_expression } all_parent_group_bys_used_in_current_select = True - for group_by in parent_select_node.group_bys: + for group_by in from_source_node_as_select_node.group_bys: parent_group_by_select = SqlGroupByRewritingVisitor._find_matching_select( - expr=group_by.expr, select_columns=parent_select_node.select_columns + expr=group_by.expr, select_columns=from_source_node_as_select_node.select_columns ) if parent_group_by_select and parent_group_by_select.column_alias not in current_select_column_refs: all_parent_group_bys_used_in_current_select = False @@ -293,13 +288,13 @@ def _current_node_can_be_reduced(self, node: SqlSelectStatementNode) -> bool: # If the parent has a GROUP BY and this has a WHERE, avoid reducing as the WHERE could reference an # aggregation expression. - if len(parent_select_node.group_bys) > 0 and node.where: + if len(from_source_node_as_select_node.group_bys) > 0 and node.where: return False # If the parent has a GROUP BY, the case where it's easiest to merge this with the parent is if all select # columns are column references. if len( - parent_select_node.group_bys + from_source_node_as_select_node.group_bys ) > 0 and not SqlRewritingSubQueryReducerVisitor._select_columns_are_column_references(node.select_columns): return False @@ -308,7 +303,7 @@ def _current_node_can_be_reduced(self, node: SqlSelectStatementNode) -> bool: parent_column_aliases_with_window_functions = { select_column.column_alias for select_column in SqlRewritingSubQueryReducerVisitor._select_columns_with_window_functions( - parent_select_node.select_columns + from_source_node_as_select_node.select_columns ) } if len(node.group_bys) > 0 and [ @@ -355,7 +350,7 @@ def _current_node_can_be_reduced(self, node: SqlSelectStatementNode) -> bool: # item in the SELECT to group by. if len(node.group_bys) > 0 and SqlRewritingSubQueryReducerVisitor._select_columns_contain_string_expressions( - select_columns=parent_select_node.select_columns, + select_columns=from_source_node_as_select_node.select_columns, ): return False @@ -419,8 +414,7 @@ def _find_matching_select_column( return select_column return None - @staticmethod - def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementNode: + def _rewrite_node_with_join(self, node: SqlSelectStatementNode) -> SqlSelectStatementNode: """Reduces nodes with joins if the join source is simple to reduce. Converts this: @@ -558,7 +552,7 @@ def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementN clauses_to_rewrite.rewrite(column_replacements=column_replacements) # This was already checked in _is_simple_source(). - assert len(from_source_select.parent_nodes) == 1 + assert len(from_source_select.join_descs) == 0 from_source = from_source_select.from_source from_source_alias = from_source_select.from_source_alias @@ -579,6 +573,9 @@ def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementN select_columns=tuple(clauses_to_rewrite.select_columns), from_source=from_source, from_source_alias=from_source_alias, + cte_sources=tuple( + cte_source.with_new_select(cte_source.select_statement.accept(self)) for cte_source in node.cte_sources + ), join_descs=tuple(new_join_descs), group_bys=tuple(clauses_to_rewrite.group_bys), order_bys=tuple(clauses_to_rewrite.order_bys), @@ -595,7 +592,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP node_with_reduced_parents = self._reduce_parents(node) if len(node_with_reduced_parents.join_descs) > 0: - return SqlRewritingSubQueryReducerVisitor._rewrite_node_with_join(node_with_reduced_parents) + return self._rewrite_node_with_join(node_with_reduced_parents) if not self._current_node_can_be_reduced(node_with_reduced_parents): return node_with_reduced_parents @@ -615,7 +612,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP # JOIN dim_listings c # ON a.listing_id = b.listing_id - from_source_node = node_with_reduced_parents.parent_nodes[0] + from_source_node = node_with_reduced_parents.from_source from_source_select_node = from_source_node.as_select_node assert ( from_source_select_node is not None @@ -699,6 +696,9 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP ), from_source=from_source_select_node.from_source, from_source_alias=from_source_select_node.from_source_alias, + cte_sources=tuple( + cte_source.with_new_select(cte_source.select_statement.accept(self)) for cte_source in node.cte_sources + ), join_descs=from_source_select_node.join_descs, group_bys=new_group_bys, order_bys=tuple(new_order_bys), @@ -733,13 +733,13 @@ def _find_matching_select( ) -> Optional[SqlSelectColumn]: """Given an expression, find the SELECT column that has the same expression.""" for select_column in select_columns: - if select_column.expr == expr: + if select_column.expr.matches(expr): return select_column return None @override def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode: - raise NotImplementedError + return node.with_new_select(node.select_statement.accept(self)) def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102 new_group_bys = [] @@ -767,6 +767,9 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP select_columns=node.select_columns, from_source=node.from_source.accept(self), from_source_alias=node.from_source_alias, + cte_sources=tuple( + cte_source.with_new_select(cte_source.select_statement.accept(self)) for cte_source in node.cte_sources + ), join_descs=tuple( SqlJoinDescription( right_source=x.right_source.accept(self), diff --git a/metricflow/sql/optimizer/tag_required_column_aliases.py b/metricflow/sql/optimizer/tag_required_column_aliases.py index a1c08f733..81a09f009 100644 --- a/metricflow/sql/optimizer/tag_required_column_aliases.py +++ b/metricflow/sql/optimizer/tag_required_column_aliases.py @@ -165,6 +165,8 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: raise RuntimeError( "No columns are required in this node - this indicates a bug in this visitor or in the inputs." ) + # It's possible for `required_select_columns_in_this_node` to be empty because we traverse through the ancestors + # of a CTE node whenever a CTE node is updated. See `test_multi_child_pruning`. # Based on the expressions in this select statement, figure out what column aliases are needed in the sources of # this query (i.e. tables or sub-queries in the FROM or JOIN clauses). @@ -178,7 +180,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None: nodes_to_retain_all_columns.append(join_desc.right_source) for node_to_retain_all_columns in nodes_to_retain_all_columns: - nearest_select_columns = node_to_retain_all_columns.nearest_select_columns({}) + nearest_select_columns = node_to_retain_all_columns.nearest_select_columns(self._cte_alias_to_cte_node) for select_column in nearest_select_columns or (): self._current_required_column_alias_mapping.add_alias( node=node_to_retain_all_columns, column_alias=select_column.column_alias diff --git a/tests_metricflow/sql/optimizer/test_rewriting_sub_query_reducer.py b/tests_metricflow/sql/optimizer/test_rewriting_sub_query_reducer.py index e4673891f..221594f36 100644 --- a/tests_metricflow/sql/optimizer/test_rewriting_sub_query_reducer.py +++ b/tests_metricflow/sql/optimizer/test_rewriting_sub_query_reducer.py @@ -822,7 +822,7 @@ def test_reducing_join_statement( mf_test_configuration: MetricFlowTestConfiguration, reducing_join_statement: SqlSelectStatementNode, ) -> None: - """Tests a case where a join query should not reduced an aggregate.""" + """Tests a case where a join query should not reduce an aggregate.""" assert_default_rendered_sql_equal( request=request, mf_test_configuration=mf_test_configuration,