Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split column pruner into two phases #1501

Merged
merged 5 commits into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,9 @@ def _make_time_spine_data_set(
def visit_source_node(self, node: ReadSqlSourceNode) -> SqlDataSet:
"""Generate the SQL to read from the source."""
return SqlDataSet(
sql_select_node=node.data_set.checked_sql_select_node,
# This visitor is assumed to create a unique SELECT node for each dataflow node, so create a copy.
# The column pruner relies on this assumption to keep track of what columns are required at each node.
sql_select_node=node.data_set.checked_sql_select_node.create_copy(),
instance_set=node.data_set.instance_set,
)

Expand Down
216 changes: 55 additions & 161 deletions metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
from __future__ import annotations

import logging
from collections import defaultdict
from typing import Dict, List, Set, Tuple

from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
from typing_extensions import override

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlQueryPlanOptimizer
from metricflow.sql.sql_exprs import (
SqlExpressionTreeLineage,
)
from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping
from metricflow.sql.optimizer.tag_required_column_aliases import SqlMapRequiredColumnAliasesVisitor
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlQueryPlanNode,
SqlQueryPlanNodeVisitor,
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableNode,
Expand All @@ -28,177 +24,55 @@
class SqlColumnPrunerVisitor(SqlQueryPlanNodeVisitor[SqlQueryPlanNode]):
"""Removes unnecessary columns from SELECT statements in the SQL query plan.

As the visitor traverses up to the parents, it pushes the list of required columns and rewrites the parent nodes.
This requires a set of tagged column aliases that should be kept for each SQL node.
"""

def __init__(
self,
required_column_aliases: Set[str],
required_alias_mapping: NodeToColumnAliasMapping,
) -> None:
"""Constructor.

Args:
required_column_aliases: the columns aliases that should not be pruned from the SELECT statements that this
visits.
"""
self._required_column_aliases = required_column_aliases

def _search_for_expressions(
self, select_node: SqlSelectStatementNode, pruned_select_columns: Tuple[SqlSelectColumn, ...]
) -> SqlExpressionTreeLineage:
"""Returns the expressions used in the immediate select statement.

i.e. this does not return expressions used in sub-queries. pruned_select_columns needs to be passed in since the
node may have the select columns pruned.
required_alias_mapping: Describes columns aliases that should be kept / not pruned for each node.
"""
all_expr_search_results: List[SqlExpressionTreeLineage] = []

for select_column in pruned_select_columns:
all_expr_search_results.append(select_column.expr.lineage)

for join_description in select_node.join_descs:
if join_description.on_condition:
all_expr_search_results.append(join_description.on_condition.lineage)

for group_by in select_node.group_bys:
all_expr_search_results.append(group_by.expr.lineage)

for order_by in select_node.order_bys:
all_expr_search_results.append(order_by.expr.lineage)

if select_node.where:
all_expr_search_results.append(select_node.where.lineage)

return SqlExpressionTreeLineage.combine(all_expr_search_results)

def _prune_columns_from_grandparents(
self, node: SqlSelectStatementNode, pruned_select_columns: Tuple[SqlSelectColumn, ...]
) -> SqlSelectStatementNode:
"""Assume that you need all columns from the parent and prune the grandparents."""
pruned_from_source: SqlQueryPlanNode
if node.from_source.as_select_node:
from_visitor = SqlColumnPrunerVisitor(
required_column_aliases={x.column_alias for x in node.from_source.as_select_node.select_columns}
)
pruned_from_source = node.from_source.as_select_node.accept(from_visitor)
else:
pruned_from_source = node.from_source
pruned_join_descriptions: List[SqlJoinDescription] = []
for join_description in node.join_descs:
right_source_as_select_node = join_description.right_source.as_select_node
if right_source_as_select_node:
right_source_visitor = SqlColumnPrunerVisitor(
required_column_aliases={x.column_alias for x in right_source_as_select_node.select_columns}
)
pruned_join_descriptions.append(
SqlJoinDescription(
right_source=join_description.right_source.accept(right_source_visitor),
right_source_alias=join_description.right_source_alias,
on_condition=join_description.on_condition,
join_type=join_description.join_type,
)
)
else:
pruned_join_descriptions.append(join_description)

return SqlSelectStatementNode.create(
description=node.description,
select_columns=pruned_select_columns,
from_source=pruned_from_source,
from_source_alias=node.from_source_alias,
join_descs=tuple(pruned_join_descriptions),
group_bys=node.group_bys,
order_bys=node.order_bys,
where=node.where,
limit=node.limit,
distinct=node.distinct,
)

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError
self._required_alias_mapping = required_alias_mapping

def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlQueryPlanNode: # noqa: D102
# Remove columns that are not needed from this SELECT statement because the parent SELECT statement doesn't
# need them. However, keep columns that are in group bys because that changes the meaning of the query.
# Similarly, if this node is a distinct select node, keep all columns as it may return a different result set.
required_column_aliases = self._required_alias_mapping.get_aliases(node)
if required_column_aliases is None:
logger.error(
f"Did not find {node.node_id=} in the required alias mapping. Returning the non-pruned version "
f"as it should be valid SQL, but this is a bug and should be investigated."
)
return node

if len(required_column_aliases) == 0:
logger.error(
f"Got no required column aliases for {node}. Returning the non-pruned version as it should be valid "
f"SQL, but this is a bug and should be investigated."
)
return node

pruned_select_columns = tuple(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is tangential, but I've frequently read this code and found this variable name confusing (pruned_select_columns). We frequently refer to "pruned columns" when we mean the ones that have been removed, but in this case we mean the columns that have been kept. I think the word pruned can technically be used both ways, but it typically is used to refer to what has been removed. Can we change this to a more clear variable name?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, this could be renamed. To understand where you're coming from, you mention that [pruned] typically is used to refer to what has been removed. What examples were you thinking?

When I think of pruned, I think about an overgrown tree. Once I prune it, I would call it a pruned tree.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's true - the tree has been pruned. But if you're referring to the branches, I think the "pruned" branches would typically refer to the ones removed. In this case the tree is the SQL node and the columns are the branches.

I just did a quick search through the code to see how we use this word, and there are a couple places where we use the opposite meaning of prune:

required_alias_mapping: Describes columns aliases that should be kept / not pruned for each node.

def test_dont_prune_if_in_where(
request: FixtureRequest,

def test_dont_prune_with_str_expr(
request: FixtureRequest,

And this is silly but I just did a quick google search for a gut check here and it does look like the pruned leaves are the ones that have been removed:
Screenshot 2024-11-08 at 1 10 32 PM

Not totally related to this PR so this isn't blocking! But if you don't update the naming here I probably will the next time I come across it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, I'm thinking we use different terms like "removed" and "retained" then, but will have to handle in a follow up.

select_column
for select_column in node.select_columns
if select_column.column_alias in self._required_column_aliases
or select_column in node.group_bys
or node.distinct
)
# TODO: don't prune columns used in join condition! Tricky to derive since the join condition can be any
# SqlExpressionNode.

if len(pruned_select_columns) == 0:
raise RuntimeError(
"All columns have been removed - this indicates an bug in the pruner or in the inputs.\n"
f"Original column aliases: {[col.column_alias for col in node.select_columns]}\n"
f"Required column aliases: {self._required_column_aliases}\n"
f"Group bys: {node.group_bys}\n"
f"Distinct: {node.distinct}"
)

# Based on the expressions in this select statement, figure out what column aliases are needed in the sources of
# this query (i.e. tables or sub-queries in the FROM or JOIN clauses).
exprs_used_in_this_node = self._search_for_expressions(node, pruned_select_columns)

# If any of the string expressions don't have context on what columns are used in the expression, then it's
# impossible to know what columns can be pruned from the parent sources. So return a SELECT statement that
# leaves the parent sources untouched. Columns from the grandparents can be pruned based on the parent node
# though.
if any([string_expr.used_columns is None for string_expr in exprs_used_in_this_node.string_exprs]):
return self._prune_columns_from_grandparents(node, pruned_select_columns)

# Create a mapping from the source alias to the column aliases needed from the corresponding source.
source_alias_to_required_column_alias: Dict[str, Set[str]] = defaultdict(set)
for column_reference_expr in exprs_used_in_this_node.column_reference_exprs:
column_reference = column_reference_expr.col_ref
source_alias_to_required_column_alias[column_reference.table_alias].add(column_reference.column_name)

# For all string columns, assume that they are needed from all sources since we don't have a table alias
# in SqlStringExpression.used_columns
for string_expr in exprs_used_in_this_node.string_exprs:
if string_expr.used_columns:
for column_alias in string_expr.used_columns:
source_alias_to_required_column_alias[node.from_source_alias].add(column_alias)
for join_description in node.join_descs:
source_alias_to_required_column_alias[join_description.right_source_alias].add(column_alias)
# Same with unqualified column references.
for unqualified_column_reference_expr in exprs_used_in_this_node.column_alias_reference_exprs:
column_alias = unqualified_column_reference_expr.column_alias
source_alias_to_required_column_alias[node.from_source_alias].add(column_alias)
for join_description in node.join_descs:
source_alias_to_required_column_alias[join_description.right_source_alias].add(column_alias)

# Once we know which column aliases are required from which source aliases, replace the sources with new SELECT
# statements.
from_source_pruner = SqlColumnPrunerVisitor(
required_column_aliases=source_alias_to_required_column_alias[node.from_source_alias]
if select_column.column_alias in required_column_aliases
)
pruned_from_source = node.from_source.accept(from_source_pruner)
pruned_join_descriptions: List[SqlJoinDescription] = []
for join_description in node.join_descs:
join_source_pruner = SqlColumnPrunerVisitor(
required_column_aliases=source_alias_to_required_column_alias[join_description.right_source_alias]
)
pruned_join_descriptions.append(
SqlJoinDescription(
right_source=join_description.right_source.accept(join_source_pruner),
right_source_alias=join_description.right_source_alias,
on_condition=join_description.on_condition,
join_type=join_description.join_type,
)
)

return SqlSelectStatementNode.create(
description=node.description,
select_columns=tuple(pruned_select_columns),
from_source=pruned_from_source,
select_columns=pruned_select_columns,
from_source=node.from_source.accept(self),
from_source_alias=node.from_source_alias,
join_descs=tuple(pruned_join_descriptions),
# TODO: Handle CTEs.
cte_sources=(),
join_descs=tuple(
join_desc.with_right_source(join_desc.right_source.accept(self)) for join_desc in node.join_descs
),
group_bys=node.group_bys,
order_bys=node.order_bys,
where=node.where,
Expand All @@ -220,17 +94,37 @@ def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlQueryPlan
parent_node=node.parent_node.accept(self),
)

@override
def visit_cte_node(self, node: SqlCteNode) -> SqlQueryPlanNode:
raise NotImplementedError


class SqlColumnPrunerOptimizer(SqlQueryPlanOptimizer):
"""Removes unnecessary columns in the SELECT clauses."""
"""Removes unnecessary columns in the SELECT statements."""

def optimize(self, node: SqlQueryPlanNode) -> SqlQueryPlanNode: # noqa: D102
# Can't prune columns without knowing the structure of the query.
if not node.as_select_node:
# ALl columns in the nearest SELECT node need to be kept as otherwise, the meaning of the query changes.
required_select_columns = node.nearest_select_columns({})

# Can't prune without knowing the structure of the query.
if required_select_columns is None:
logger.error(
LazyFormat(
"The columns required at this node can't be determined, so skipping column pruning",
node=node.structure_text(),
required_select_columns=required_select_columns,
)
)
return node

pruning_visitor = SqlColumnPrunerVisitor(
required_column_aliases={x.column_alias for x in node.as_select_node.select_columns}
map_required_column_aliases_visitor = SqlMapRequiredColumnAliasesVisitor(
start_node=node,
required_column_aliases_in_start_node=frozenset(
[select_column.column_alias for select_column in required_select_columns]
),
)
node.accept(map_required_column_aliases_visitor)

# Re-write the query, removing unnecessary columns in the SELECT statements.
pruning_visitor = SqlColumnPrunerVisitor(map_required_column_aliases_visitor.required_column_alias_mapping)
return node.accept(pruning_visitor)
33 changes: 33 additions & 0 deletions metricflow/sql/optimizer/tag_column_aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

import logging
from collections import defaultdict
from typing import Dict, FrozenSet, Iterable, Set

from metricflow.sql.sql_plan import (
SqlQueryPlanNode,
)

logger = logging.getLogger(__name__)


class NodeToColumnAliasMapping:
"""Mutable class for mapping a SQL node to an arbitrary set of column aliases for that node.

* Alternatively, this can be described as mapping a location in the SQL query plan to a set of column aliases.
* See `SqlMapRequiredColumnAliasesVisitor` for the main use case for this class.
* This is a thin wrapper over a dict to aid readability.
"""

def __init__(self) -> None: # noqa: D107
self._node_to_tagged_aliases: Dict[SqlQueryPlanNode, Set[str]] = defaultdict(set)

def get_aliases(self, node: SqlQueryPlanNode) -> FrozenSet[str]:
"""Return the column aliases added for the given SQL node."""
return frozenset(self._node_to_tagged_aliases[node])

def add_alias(self, node: SqlQueryPlanNode, column_alias: str) -> None: # noqa: D102
return self._node_to_tagged_aliases[node].add(column_alias)

def add_aliases(self, node: SqlQueryPlanNode, column_aliases: Iterable[str]) -> None: # noqa: D102
self._node_to_tagged_aliases[node].update(column_aliases)
Loading
Loading