Skip to content

Commit

Permalink
update optimizer for distinct select
Browse files Browse the repository at this point in the history
  • Loading branch information
WilliamDee committed Dec 9, 2022
1 parent 4a37893 commit cce12a8
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 1 deletion.
7 changes: 6 additions & 1 deletion metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,19 @@ def _prune_columns_from_grandparents(
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D
# Remove columns that are not needed from this SELECT statement because the parent SELECT statement doesn't
# need them. However, keep columns that are in group bys because that changes the meaning of the query.
# Similarly, if this node is a distinct select node, keep all columns as it may return a different result set.
pruned_select_columns = tuple(
select_column
for select_column in node.select_columns
if select_column.column_alias in self._required_column_aliases or select_column in node.group_bys
if select_column.column_alias in self._required_column_aliases
or select_column in node.group_bys
or node.distinct
)

if len(pruned_select_columns) == 0:
Expand Down Expand Up @@ -184,6 +188,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D
Expand Down
9 changes: 9 additions & 0 deletions metricflow/sql/optimizer/rewriting_sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _reduce_parents(
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

@staticmethod
Expand Down Expand Up @@ -218,6 +219,10 @@ def _current_node_can_be_reduced(self, node: SqlSelectStatementNode) -> bool: #
if SqlRewritingSubQueryReducerVisitor._statement_contains_difficult_expressions(node):
return False

# Don't reduce distinct selects
if parent_select_node.distinct:
return False

# Skip this case for simplicity of reasoning.
if len(node.order_bys) > 0 and len(parent_select_node.order_bys) > 0:
return False
Expand Down Expand Up @@ -522,9 +527,11 @@ def _rewrite_node_with_join(node: SqlSelectStatementNode) -> SqlSelectStatementN
order_bys=tuple(clauses_to_rewrite.order_bys),
where=clauses_to_rewrite.combine_wheres(additional_where_clauses),
limit=node.limit,
distinct=node.distinct,
)

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D
# print(node.description)
node_with_reduced_parents = self._reduce_parents(node)

if len(node_with_reduced_parents.parent_nodes) > 1:
Expand Down Expand Up @@ -640,6 +647,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
parent_node_where=parent_select_node.where,
),
limit=new_limit,
distinct=parent_select_node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D
Expand Down Expand Up @@ -697,6 +705,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D
Expand Down
6 changes: 6 additions & 0 deletions metricflow/sql/optimizer/sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def _reduce_parents(
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

def _reduce_is_possible(self, node: SqlSelectStatementNode) -> bool: # noqa: D
Expand Down Expand Up @@ -70,6 +71,10 @@ def _reduce_is_possible(self, node: SqlSelectStatementNode) -> bool: # noqa: D
# More conditions where we don't want to collapse. It's not impossible with these cases, but not reducing in
# these cases for simplicity.

# Don't reduce distinct selects
if parent_select_node.distinct:
return False

# Reducing a where is tricky as it requires the expressions to be re-written.
if node.where:
return False
Expand Down Expand Up @@ -178,6 +183,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
order_bys=tuple(new_order_by),
where=parent_select_node.where,
limit=new_limit,
distinct=parent_select_node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D
Expand Down
2 changes: 2 additions & 0 deletions metricflow/sql/optimizer/table_alias_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
),
where=node.where.rewrite(should_render_table_alias=False) if node.where else None,
limit=node.limit,
distinct=node.distinct,
)

return SqlSelectStatementNode(
Expand All @@ -62,6 +63,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

def visit_table_from_clause_node(self, node: SqlTableFromClauseNode) -> SqlQueryPlanNode: # noqa: D
Expand Down

0 comments on commit cce12a8

Please sign in to comment.