diff --git a/metricflow/sql/optimizer/column_pruner.py b/metricflow/sql/optimizer/column_pruner.py index 55a746ef11..92b0e64309 100644 --- a/metricflow/sql/optimizer/column_pruner.py +++ b/metricflow/sql/optimizer/column_pruner.py @@ -108,15 +108,19 @@ def _prune_columns_from_grandparents( order_bys=node.order_bys, where=node.where, limit=node.limit, + distinct=node.distinct, ) def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D # Remove columns that are not needed from this SELECT statement because the parent SELECT statement doesn't # need them. However, keep columns that are in group bys because that changes the meaning of the query. + # Similarly, if this node is a distinct select node, keep all columns as it may return a different result set. pruned_select_columns = tuple( select_column for select_column in node.select_columns - if select_column.column_alias in self._required_column_aliases or select_column in node.group_bys + if select_column.column_alias in self._required_column_aliases + or select_column in node.group_bys + or node.distinct ) if len(pruned_select_columns) == 0: @@ -184,6 +188,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP order_bys=node.order_bys, where=node.where, limit=node.limit, + distinct=node.distinct, ) def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D diff --git a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py index 745a4172c2..0ec1637087 100644 --- a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py +++ b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py @@ -108,6 +108,7 @@ def _reduce_parents( order_bys=node.order_bys, where=node.where, limit=node.limit, + distinct=node.distinct, ) @staticmethod @@ -218,6 +219,10 @@ def _current_node_can_be_reduced(self, node: SqlSelectStatementNode) -> bool: # if SqlRewritingSubQueryReducerVisitor._statement_contains_difficult_expressions(node): return False + # Don't reduce distinct selects + if parent_select_node.distinct: + return False + # Skip this case for simplicity of reasoning. if len(node.order_bys) > 0 and len(parent_select_node.order_bys) > 0: return False @@ -522,9 +527,11 @@ def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementN order_bys=tuple(clauses_to_rewrite.order_bys), where=clauses_to_rewrite.combine_wheres(additional_where_clauses), limit=node.limit, + distinct=node.distinct, ) def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D + # print(node.description) node_with_reduced_parents = self._reduce_parents(node) if len(node_with_reduced_parents.parent_nodes) > 1: @@ -640,6 +647,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP parent_node_where=parent_select_node.where, ), limit=new_limit, + distinct=parent_select_node.distinct, ) def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D @@ -697,6 +705,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP order_bys=node.order_bys, where=node.where, limit=node.limit, + distinct=node.distinct, ) def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D diff --git a/metricflow/sql/optimizer/sub_query_reducer.py b/metricflow/sql/optimizer/sub_query_reducer.py index a745660ed7..06e303950e 100644 --- a/metricflow/sql/optimizer/sub_query_reducer.py +++ b/metricflow/sql/optimizer/sub_query_reducer.py @@ -42,6 +42,7 @@ def _reduce_parents( order_bys=node.order_bys, where=node.where, limit=node.limit, + distinct=node.distinct, ) def _reduce_is_possible(self, node: SqlSelectStatementNode) -> bool: # noqa: D @@ -70,6 +71,10 @@ def _reduce_is_possible(self, node: SqlSelectStatementNode) -> bool: # noqa: D # More conditions where we don't want to collapse. It's not impossible with these cases, but not reducing in # these cases for simplicity. + # Don't reduce distinct selects + if parent_select_node.distinct: + return False + # Reducing a where is tricky as it requires the expressions to be re-written. if node.where: return False @@ -178,6 +183,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP order_bys=tuple(new_order_by), where=parent_select_node.where, limit=new_limit, + distinct=parent_select_node.distinct, ) def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D diff --git a/metricflow/sql/optimizer/table_alias_simplifier.py b/metricflow/sql/optimizer/table_alias_simplifier.py index 21521e0e12..a36a339f3f 100644 --- a/metricflow/sql/optimizer/table_alias_simplifier.py +++ b/metricflow/sql/optimizer/table_alias_simplifier.py @@ -42,6 +42,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP ), where=node.where.rewrite(should_render_table_alias=False) if node.where else None, limit=node.limit, + distinct=node.distinct, ) return SqlSelectStatementNode( @@ -62,6 +63,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP order_bys=node.order_bys, where=node.where, limit=node.limit, + distinct=node.distinct, ) def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D