From a5d257005f249be9ac131544b0cd4da1d9a69c12 Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Wed, 9 Oct 2024 17:11:05 -0700 Subject: [PATCH] Optimizer bug fix: don't prune any columns whose aliases are used in the where filter or join on condition No tests currently trigger this bug, and there might not be any production use cases that would encounter it, but the new logic added in this PR would trigger it: we would apply a where filter to the time spine table, the optimizer would prune the select column that the filter references, and the optimized query would then fail with a SQL error because that column no longer exists. --- metricflow/sql/optimizer/column_pruner.py | 9 ++- metricflow/sql/sql_exprs.py | 93 ++++++++++++++++++++++- 2 files changed, 99 insertions(+), 3 deletions(-) diff --git a/metricflow/sql/optimizer/column_pruner.py b/metricflow/sql/optimizer/column_pruner.py index 2f44a1277f..d6c3e760ef 100644 --- a/metricflow/sql/optimizer/column_pruner.py +++ b/metricflow/sql/optimizer/column_pruner.py @@ -119,11 +119,16 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP select_column for select_column in node.select_columns if select_column.column_alias in self._required_column_aliases + or select_column.column_alias in (node.where.used_column_aliases if node.where else ()) + or select_column.column_alias + in { + column_alias + for join_desc in node.join_descs + for column_alias in (join_desc.on_condition.used_column_aliases if join_desc.on_condition else ()) + } or select_column in node.group_bys or node.distinct ) - # TODO: don't prune columns used in join condition! Tricky to derive since the join condition can be any - # SqlExpressionNode. 
if len(pruned_select_columns) == 0: raise RuntimeError("All columns have been pruned - this indicates an bug in the pruner or in the inputs.") diff --git a/metricflow/sql/sql_exprs.py b/metricflow/sql/sql_exprs.py index 7afdf8c026..4392347839 100644 --- a/metricflow/sql/sql_exprs.py +++ b/metricflow/sql/sql_exprs.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from typing import Dict, Generic, List, Mapping, Optional, Sequence, Tuple +from typing import Dict, Generic, List, Mapping, Optional, Sequence, Set, Tuple import more_itertools from dbt_semantic_interfaces.enum_extension import assert_values_exhausted @@ -98,6 +98,12 @@ def matches(self, other: SqlExpressionNode) -> bool: """Similar to equals - returns true if these expressions are equivalent.""" pass + @property + @abstractmethod + def used_column_aliases(self) -> Set[str]: + """All column aliases used in the expression.""" + pass + @dataclass(frozen=True) class SqlExpressionTreeLineage: @@ -313,6 +319,10 @@ def as_string_expression(self) -> Optional[SqlStringExpression]: """If this is a string expression, return self.""" return self + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return set(self.used_columns) if self.used_columns else set() + @dataclass(frozen=True) class SqlStringLiteralExpression(SqlExpressionNode): @@ -366,6 +376,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.literal_value == other.literal_value + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return set() + @dataclass(frozen=True) class SqlColumnReference: @@ -474,6 +488,14 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 def from_table_and_column_names(table_alias: str, column_name: str) -> SqlColumnReferenceExpression: # noqa: D102 return SqlColumnReferenceExpression.create(SqlColumnReference(table_alias=table_alias, column_name=column_name)) + @property + 
def used_column_aliases(self) -> Set[str]: # noqa: D102 + return { + f"{self.col_ref.table_alias}.{self.col_ref.column_name}" + if self.should_render_table_alias + else self.col_ref.column_name + } + @dataclass(frozen=True) class SqlColumnAliasReferenceExpression(SqlExpressionNode): @@ -535,6 +557,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.column_alias == other.column_alias + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return {self.column_alias} + class SqlComparison(Enum): # noqa: D101 LESS_THAN = "<" @@ -613,6 +639,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.comparison == other.comparison and self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return self.left_expr.used_column_aliases.union(self.right_expr.used_column_aliases) + class SqlFunction(Enum): """Names of known SQL functions like SUM() in SELECT SUM(...). @@ -811,6 +841,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.sql_function == other.sql_function and self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return {column_alias for arg in self.sql_function_args for column_alias in arg.used_column_aliases} + class SqlPercentileFunctionType(Enum): """Type of percentile function used.""" @@ -931,6 +965,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.percentile_args == other.percentile_args and self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return self.order_by_arg.used_column_aliases + class SqlWindowFunction(Enum): """Names of known SQL window functions like SUM(), RANK(), ROW_NUMBER(). 
@@ -1100,6 +1138,15 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 and self._parents_match(other) ) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + args = ( + list(self.sql_function_args) + + list(self.partition_by_args) + + [order_by_arg.expr for order_by_arg in self.order_by_args] + ) + return {column_alias for arg in args for column_alias in arg.used_column_aliases} + @dataclass(frozen=True) class SqlNullExpression(SqlExpressionNode): @@ -1140,6 +1187,10 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return isinstance(other, SqlNullExpression) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return set() + class SqlLogicalOperator(Enum): """List all supported binary logical operator expressions. @@ -1202,6 +1253,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.operator == other.operator and self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return {column_alias for arg in self.args for column_alias in arg.used_column_aliases} + @dataclass(frozen=True) class SqlIsNullExpression(SqlExpressionNode): @@ -1247,6 +1302,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return self.arg.used_column_aliases + @dataclass(frozen=True) class SqlSubtractTimeIntervalExpression(SqlExpressionNode): @@ -1312,6 +1371,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.count == other.count and self.granularity == other.granularity and self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return self.arg.used_column_aliases + @dataclass(frozen=True) class SqlCastToTimestampExpression(SqlExpressionNode): @@ -1359,6 
+1422,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return self.arg.used_column_aliases + @dataclass(frozen=True) class SqlDateTruncExpression(SqlExpressionNode): @@ -1410,6 +1477,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.time_granularity == other.time_granularity and self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return self.arg.used_column_aliases + @dataclass(frozen=True) class SqlExtractExpression(SqlExpressionNode): @@ -1469,6 +1540,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.date_part == other.date_part and self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return self.arg.used_column_aliases + @dataclass(frozen=True) class SqlRatioComputationExpression(SqlExpressionNode): @@ -1534,6 +1609,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return self.numerator.used_column_aliases.union(self.denominator.used_column_aliases) + @dataclass(frozen=True) class SqlBetweenExpression(SqlExpressionNode): @@ -1599,6 +1678,14 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return { + column_alias + for arg in (self.column_arg, self.start_expr, self.end_expr) + for column_alias in arg.used_column_aliases + } + @dataclass(frozen=True) class SqlGenerateUuidExpression(SqlExpressionNode): @@ -1649,3 +1736,7 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False + + 
@property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return set()