diff --git a/metricflow/sql/optimizer/column_pruner.py b/metricflow/sql/optimizer/column_pruner.py index 2f44a1277f..d6c3e760ef 100644 --- a/metricflow/sql/optimizer/column_pruner.py +++ b/metricflow/sql/optimizer/column_pruner.py @@ -119,11 +119,16 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryP select_column for select_column in node.select_columns if select_column.column_alias in self._required_column_aliases + or select_column.column_alias in (node.where.used_column_aliases if node.where else ()) + or select_column.column_alias + in { + column_alias + for join_desc in node.join_descs + for column_alias in (join_desc.on_condition.used_column_aliases if join_desc.on_condition else ()) + } or select_column in node.group_bys or node.distinct ) - # TODO: don't prune columns used in join condition! Tricky to derive since the join condition can be any - # SqlExpressionNode. if len(pruned_select_columns) == 0: raise RuntimeError("All columns have been pruned - this indicates an bug in the pruner or in the inputs.") diff --git a/metricflow/sql/sql_exprs.py b/metricflow/sql/sql_exprs.py index 7afdf8c026..4392347839 100644 --- a/metricflow/sql/sql_exprs.py +++ b/metricflow/sql/sql_exprs.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from typing import Dict, Generic, List, Mapping, Optional, Sequence, Tuple +from typing import Dict, Generic, List, Mapping, Optional, Sequence, Set, Tuple import more_itertools from dbt_semantic_interfaces.enum_extension import assert_values_exhausted @@ -98,6 +98,12 @@ def matches(self, other: SqlExpressionNode) -> bool: """Similar to equals - returns true if these expressions are equivalent.""" pass + @property + @abstractmethod + def used_column_aliases(self) -> Set[str]: + """All column aliases used in the expression.""" + pass + @dataclass(frozen=True) class SqlExpressionTreeLineage: @@ -313,6 +319,10 @@ def as_string_expression(self) -> Optional[SqlStringExpression]: """If this is a string expression, return self.""" return self + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return set(self.used_columns) if self.used_columns else set() + @dataclass(frozen=True) class SqlStringLiteralExpression(SqlExpressionNode): @@ -366,6 +376,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.literal_value == other.literal_value + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return set() + @dataclass(frozen=True) class SqlColumnReference: @@ -474,6 +488,14 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 def from_table_and_column_names(table_alias: str, column_name: str) -> SqlColumnReferenceExpression: # noqa: D102 return SqlColumnReferenceExpression.create(SqlColumnReference(table_alias=table_alias, column_name=column_name)) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return { + f"{self.col_ref.table_alias}.{self.col_ref.column_name}" + if self.should_render_table_alias + else self.col_ref.column_name + } + @dataclass(frozen=True) class SqlColumnAliasReferenceExpression(SqlExpressionNode): @@ -535,6 +557,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.column_alias == other.column_alias + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return {self.column_alias} + class SqlComparison(Enum): # noqa: D101 LESS_THAN = "<" @@ -613,6 +639,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.comparison == other.comparison and self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return self.left_expr.used_column_aliases.union(self.right_expr.used_column_aliases) + class SqlFunction(Enum): """Names of known SQL functions like SUM() in SELECT SUM(...). @@ -811,6 +841,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.sql_function == other.sql_function and self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return {column_alias for arg in self.sql_function_args for column_alias in arg.used_column_aliases} + class SqlPercentileFunctionType(Enum): """Type of percentile function used.""" @@ -931,6 +965,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.percentile_args == other.percentile_args and self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return self.order_by_arg.used_column_aliases + class SqlWindowFunction(Enum): """Names of known SQL window functions like SUM(), RANK(), ROW_NUMBER(). @@ -1100,6 +1138,15 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 and self._parents_match(other) ) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + args = ( + list(self.sql_function_args) + + list(self.partition_by_args) + + [order_by_arg.expr for order_by_arg in self.order_by_args] + ) + return {column_alias for arg in args for column_alias in arg.used_column_aliases} + @dataclass(frozen=True) class SqlNullExpression(SqlExpressionNode): @@ -1140,6 +1187,10 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return isinstance(other, SqlNullExpression) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return set() + class SqlLogicalOperator(Enum): """List all supported binary logical operator expressions. @@ -1202,6 +1253,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.operator == other.operator and self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return {column_alias for arg in self.args for column_alias in arg.used_column_aliases} + @dataclass(frozen=True) class SqlIsNullExpression(SqlExpressionNode): @@ -1247,6 +1302,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return self.arg.used_column_aliases + @dataclass(frozen=True) class SqlSubtractTimeIntervalExpression(SqlExpressionNode): @@ -1312,6 +1371,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.count == other.count and self.granularity == other.granularity and self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return self.arg.used_column_aliases + @dataclass(frozen=True) class SqlCastToTimestampExpression(SqlExpressionNode): @@ -1359,6 +1422,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return self.arg.used_column_aliases + @dataclass(frozen=True) class SqlDateTruncExpression(SqlExpressionNode): @@ -1410,6 +1477,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.time_granularity == other.time_granularity and self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return self.arg.used_column_aliases + @dataclass(frozen=True) class SqlExtractExpression(SqlExpressionNode): @@ -1469,6 +1540,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self.date_part == other.date_part and self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return self.arg.used_column_aliases + @dataclass(frozen=True) class SqlRatioComputationExpression(SqlExpressionNode): @@ -1534,6 +1609,10 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return self.numerator.used_column_aliases.union(self.denominator.used_column_aliases) + @dataclass(frozen=True) class SqlBetweenExpression(SqlExpressionNode): @@ -1599,6 +1678,14 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False return self._parents_match(other) + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return { + column_alias + for arg in (self.column_arg, self.start_expr, self.end_expr) + for column_alias in arg.used_column_aliases + } + @dataclass(frozen=True) class SqlGenerateUuidExpression(SqlExpressionNode): @@ -1649,3 +1736,7 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False + + @property + def used_column_aliases(self) -> Set[str]: # noqa: D102 + return set()