Skip to content

Commit

Permalink
Optimizer bug fix: don't prune any columns whose aliases are used in …
Browse files Browse the repository at this point in the history
…the where filter or join on condition

No tests currently trigger this bug, and there might not be any production use cases that would encounter it. But the new logic added in this PR would trigger this bug. We would apply a where filter to the time spine table, the optimizer would remove the select column, and then we would get a SQL error for the optimized query because the column used in the filter does not exist.
  • Loading branch information
courtneyholcomb committed Oct 10, 2024
1 parent a773399 commit a5d2570
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 3 deletions.
9 changes: 7 additions & 2 deletions metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,16 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP
select_column
for select_column in node.select_columns
if select_column.column_alias in self._required_column_aliases
or select_column.column_alias in (node.where.used_column_aliases if node.where else ())
or select_column.column_alias
in {
column_alias
for join_desc in node.join_descs
for column_alias in (join_desc.on_condition.used_column_aliases if join_desc.on_condition else ())
}
or select_column in node.group_bys
or node.distinct
)
# TODO: don't prune columns used in join condition! Tricky to derive since the join condition can be any
# SqlExpressionNode.

if len(pruned_select_columns) == 0:
raise RuntimeError("All columns have been pruned - this indicates an bug in the pruner or in the inputs.")
Expand Down
93 changes: 92 additions & 1 deletion metricflow/sql/sql_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Generic, List, Mapping, Optional, Sequence, Tuple
from typing import Dict, Generic, List, Mapping, Optional, Sequence, Set, Tuple

import more_itertools
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
Expand Down Expand Up @@ -98,6 +98,12 @@ def matches(self, other: SqlExpressionNode) -> bool:
"""Similar to equals - returns true if these expressions are equivalent."""
pass

@property
@abstractmethod
def used_column_aliases(self) -> Set[str]:
"""All column aliases used in the expression."""
pass


@dataclass(frozen=True)
class SqlExpressionTreeLineage:
Expand Down Expand Up @@ -313,6 +319,10 @@ def as_string_expression(self) -> Optional[SqlStringExpression]:
"""If this is a string expression, return self."""
return self

@property
def used_column_aliases(self) -> Set[str]: # noqa: D102
return set(self.used_columns) if self.used_columns else set()


@dataclass(frozen=True)
class SqlStringLiteralExpression(SqlExpressionNode):
Expand Down Expand Up @@ -366,6 +376,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False
return self.literal_value == other.literal_value

@property
def used_column_aliases(self) -> Set[str]: # noqa: D102
return set()


@dataclass(frozen=True)
class SqlColumnReference:
Expand Down Expand Up @@ -474,6 +488,14 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
def from_table_and_column_names(table_alias: str, column_name: str) -> SqlColumnReferenceExpression: # noqa: D102
return SqlColumnReferenceExpression.create(SqlColumnReference(table_alias=table_alias, column_name=column_name))

@property
def used_column_aliases(self) -> Set[str]: # noqa: D102
return {
f"{self.col_ref.table_alias}.{self.col_ref.column_name}"
if self.should_render_table_alias
else self.col_ref.column_name
}


@dataclass(frozen=True)
class SqlColumnAliasReferenceExpression(SqlExpressionNode):
Expand Down Expand Up @@ -535,6 +557,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False
return self.column_alias == other.column_alias

@property
def used_column_aliases(self) -> Set[str]: # noqa: D102
return {self.column_alias}


class SqlComparison(Enum): # noqa: D101
LESS_THAN = "<"
Expand Down Expand Up @@ -613,6 +639,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False
return self.comparison == other.comparison and self._parents_match(other)

@property
def used_column_aliases(self) -> Set[str]: # noqa: D102
return self.left_expr.used_column_aliases.union(self.right_expr.used_column_aliases)


class SqlFunction(Enum):
"""Names of known SQL functions like SUM() in SELECT SUM(...).
Expand Down Expand Up @@ -811,6 +841,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False
return self.sql_function == other.sql_function and self._parents_match(other)

@property
def used_column_aliases(self) -> Set[str]: # noqa: D102
return {column_alias for arg in self.sql_function_args for column_alias in arg.used_column_aliases}


class SqlPercentileFunctionType(Enum):
"""Type of percentile function used."""
Expand Down Expand Up @@ -931,6 +965,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False
return self.percentile_args == other.percentile_args and self._parents_match(other)

@property
def used_column_aliases(self) -> Set[str]: # noqa: D102
return self.order_by_arg.used_column_aliases


class SqlWindowFunction(Enum):
"""Names of known SQL window functions like SUM(), RANK(), ROW_NUMBER().
Expand Down Expand Up @@ -1100,6 +1138,15 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
and self._parents_match(other)
)

@property
def used_column_aliases(self) -> Set[str]: # noqa: D102
args = (
list(self.sql_function_args)
+ list(self.partition_by_args)
+ [order_by_arg.expr for order_by_arg in self.order_by_args]
)
return {column_alias for arg in args for column_alias in arg.used_column_aliases}


@dataclass(frozen=True)
class SqlNullExpression(SqlExpressionNode):
Expand Down Expand Up @@ -1140,6 +1187,10 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102
def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return isinstance(other, SqlNullExpression)

@property
def used_column_aliases(self) -> Set[str]: # noqa: D102
return set()


class SqlLogicalOperator(Enum):
"""List all supported binary logical operator expressions.
Expand Down Expand Up @@ -1202,6 +1253,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False
return self.operator == other.operator and self._parents_match(other)

@property
def used_column_aliases(self) -> Set[str]: # noqa: D102
return {column_alias for arg in self.args for column_alias in arg.used_column_aliases}


@dataclass(frozen=True)
class SqlIsNullExpression(SqlExpressionNode):
Expand Down Expand Up @@ -1247,6 +1302,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False
return self._parents_match(other)

@property
def used_column_aliases(self) -> Set[str]: # noqa: D102
return self.arg.used_column_aliases


@dataclass(frozen=True)
class SqlSubtractTimeIntervalExpression(SqlExpressionNode):
Expand Down Expand Up @@ -1312,6 +1371,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False
return self.count == other.count and self.granularity == other.granularity and self._parents_match(other)

@property
def used_column_aliases(self) -> Set[str]: # noqa: D102
return self.arg.used_column_aliases


@dataclass(frozen=True)
class SqlCastToTimestampExpression(SqlExpressionNode):
Expand Down Expand Up @@ -1359,6 +1422,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False
return self._parents_match(other)

@property
def used_column_aliases(self) -> Set[str]: # noqa: D102
return self.arg.used_column_aliases


@dataclass(frozen=True)
class SqlDateTruncExpression(SqlExpressionNode):
Expand Down Expand Up @@ -1410,6 +1477,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False
return self.time_granularity == other.time_granularity and self._parents_match(other)

@property
def used_column_aliases(self) -> Set[str]: # noqa: D102
return self.arg.used_column_aliases


@dataclass(frozen=True)
class SqlExtractExpression(SqlExpressionNode):
Expand Down Expand Up @@ -1469,6 +1540,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False
return self.date_part == other.date_part and self._parents_match(other)

@property
def used_column_aliases(self) -> Set[str]: # noqa: D102
return self.arg.used_column_aliases


@dataclass(frozen=True)
class SqlRatioComputationExpression(SqlExpressionNode):
Expand Down Expand Up @@ -1534,6 +1609,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False
return self._parents_match(other)

@property
def used_column_aliases(self) -> Set[str]: # noqa: D102
return self.numerator.used_column_aliases.union(self.denominator.used_column_aliases)


@dataclass(frozen=True)
class SqlBetweenExpression(SqlExpressionNode):
Expand Down Expand Up @@ -1599,6 +1678,14 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False
return self._parents_match(other)

@property
def used_column_aliases(self) -> Set[str]: # noqa: D102
return {
column_alias
for arg in (self.column_arg, self.start_expr, self.end_expr)
for column_alias in arg.used_column_aliases
}


@dataclass(frozen=True)
class SqlGenerateUuidExpression(SqlExpressionNode):
Expand Down Expand Up @@ -1649,3 +1736,7 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102

def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return False

@property
def used_column_aliases(self) -> Set[str]: # noqa: D102
return set()

0 comments on commit a5d2570

Please sign in to comment.