Skip to content

Commit

Permalink
Remove duplicated WhereConstraintNodes in predicate pushdown (#1304)
Browse files Browse the repository at this point in the history
The original implementation of predicate pushdown generated duplicated
where filter constraints, applying them once as a result of pushdown and
again in the original where constraint node.

This change removes the duplication, and does so at the level of the
individual filter spec. Logically, this means the following:

1. If all filter specs for a where constraint node can be pushed down,
the node itself will be moved past the join.
2. If some filter specs for a where constraint node can be pushed down,
the node itself will remain in place, but it will only apply the filters
that cannot be pushed down.
3. Any WhereConstraintNode added to the DataflowPlan for the purposes of
cleaning up spurious rows added as the result of an outer join, such as
we have in one particular JoinToTimeSpineNode scenario, must be annotated
to indicate that we should apply that constraint filter no matter what.

This change updates all of our logic to ensure that we effectively apply
each where filter exactly once within a branch in the DAG, and only re-apply
filters where explicitly requested at build time.

Note this change is all based on certain assumptions about how the
DataflowPlanBuilder constructs the plan. Specifically, we assume that we will
never need to re-apply filters upstream of a node unless otherwise
indicated. If we start to do fancy things with nesting where constraints
or swapping join sides - particularly if we push groupable metric filters
to an inner join or if we begin adding query time constraints outside of
an aggregated output join - this could get dodgy. Conversion metrics are
of particular concern here.
  • Loading branch information
tlento authored Jun 26, 2024
1 parent 69feadc commit 981b396
Show file tree
Hide file tree
Showing 104 changed files with 2,965 additions and 3,913 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20240625-114914.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Remove extraneous where filter subqueries added by predicate pushdown
time: 2024-06-25T11:49:14.837794-07:00
custom:
Author: tlento
Issue: "1011"
4 changes: 3 additions & 1 deletion metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1602,7 +1602,9 @@ def _build_aggregated_measure_from_measure_source_node(
if set(filter_spec.linkable_specs).issubset(set(queried_linkable_specs.as_tuple))
]
if len(queried_filter_specs) > 0:
output_node = WhereConstraintNode(parent_node=output_node, where_specs=queried_filter_specs)
output_node = WhereConstraintNode(
parent_node=output_node, where_specs=queried_filter_specs, always_apply=True
)

# TODO: this will break if you query by agg_time_dimension but apply a time constraint on metric_time.
# To fix when enabling time range constraints for agg_time_dimension.
Expand Down
30 changes: 25 additions & 5 deletions metricflow/dataflow/nodes/where_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,27 @@
class WhereConstraintNode(DataflowPlanNode):
"""Remove rows using a WHERE clause."""

def __init__( # noqa: D107
def __init__(
self,
parent_node: DataflowPlanNode,
where_specs: Sequence[WhereFilterSpec],
always_apply: bool = False,
) -> None:
"""Initializer.
WhereConstraintNodes must always have exactly one parent, since they always wrap a single subquery input.
The always_apply parameter serves as an indicator for a WhereConstraintNode that is added to a plan in order
to clean up null outputs from a pre-join filter. For example, when doing time spine joins to fill null values
for metric outputs, sometimes that join will result in rows with null values for various dimension attributes.
By re-applying the filter expression after the join step we will discard those unexpected output rows created
by the join (rather than the underlying inputs). In this case, we must ensure that the filters defined in this
node are always applied at the moment this node is processed, regardless of whether or not they've been pushed
down through the DAG.
"""
self._where_specs = where_specs
self.parent_node = parent_node
self.always_apply = always_apply
super().__init__(node_id=self.create_unique_id(), parent_nodes=(parent_node,))

@classmethod
Expand Down Expand Up @@ -52,14 +66,20 @@ def description(self) -> str: # noqa: D102

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (DisplayedProperty("where_condition", self.where),)
properties = tuple(super().displayed_properties) + (DisplayedProperty("where_condition", self.where),)
if self.always_apply:
properties = properties + (DisplayedProperty("All filters always applied:", self.always_apply),)
return properties

def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102
return isinstance(other_node, self.__class__) and other_node.where == self.where
return (
isinstance(other_node, self.__class__)
and other_node.where == self.where
and other_node.always_apply == self.always_apply
)

def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> WhereConstraintNode: # noqa: D102
assert len(new_parent_nodes) == 1
return WhereConstraintNode(
parent_node=new_parent_nodes[0],
where_specs=self.input_where_specs,
parent_node=new_parent_nodes[0], where_specs=self.input_where_specs, always_apply=self.always_apply
)
30 changes: 23 additions & 7 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,14 +362,10 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBran

with self._predicate_pushdown_tracker.track_pushdown_state(updated_pushdown_state):
optimized_parent: OptimizeBranchResult = node.parent_node.accept(self)
# TODO: Update to only apply filters that have not been successfully pushed down
optimized_node = OptimizeBranchResult(
optimized_branch=node.with_new_parents((optimized_parent.optimized_branch,))
)

pushdown_state_updated_by_parent = self._predicate_pushdown_tracker.last_pushdown_state
# Override the pushdown state for this node and allow all upstream propagation to be handled by the tracker
if len(pushdown_state_updated_by_parent.applied_where_filter_specs) > 0:
applied_filter_specs = pushdown_state_updated_by_parent.applied_where_filter_specs
filter_specs_to_apply = [spec for spec in node.input_where_specs if spec not in applied_filter_specs]
if len(applied_filter_specs) > 0:
updated_specs = frozenset.union(
frozenset(node.input_where_specs),
pushdown_state_updated_by_parent.applied_where_filter_specs,
Expand All @@ -388,6 +384,26 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBran
),
)

if node.always_apply:
logger.log(
level=self._log_level,
msg=(
"Applying original filter spec set based on node-level override directive. Additional specs "
+ f"applied:\n{[spec for spec in node.input_where_specs if spec not in filter_specs_to_apply]}"
),
)
optimized_node = OptimizeBranchResult(
optimized_branch=node.with_new_parents((optimized_parent.optimized_branch,))
)
elif len(filter_specs_to_apply) > 0:
optimized_node = OptimizeBranchResult(
optimized_branch=WhereConstraintNode(
parent_node=optimized_parent.optimized_branch, where_specs=filter_specs_to_apply
)
)
else:
optimized_node = optimized_parent

return optimized_node

# Join nodes - these may affect pushdown state based on join type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,10 @@ def test_simple_join_categorical_pushdown(
query_parser: MetricFlowQueryParser,
dataflow_plan_builder: DataflowPlanBuilder,
) -> None:
"""Tests pushdown optimization for a simple predicate through a single join."""
"""Tests pushdown optimization for a simple predicate through a single join.
In this case the entire constraint should be moved inside of the join.
"""
query_spec = query_parser.parse_and_validate_query(
metric_names=("bookings",),
group_by_names=("listing__country_latest",),
Expand All @@ -205,7 +208,7 @@ def test_simple_join_categorical_pushdown(
mf_test_configuration=mf_test_configuration,
dataflow_plan_builder=dataflow_plan_builder,
query_spec=query_spec,
expected_additional_constraint_nodes_in_optimized=1,
expected_additional_constraint_nodes_in_optimized=0,
)


Expand Down Expand Up @@ -256,7 +259,7 @@ def test_conversion_metric_predicate_pushdown(
mf_test_configuration=mf_test_configuration,
dataflow_plan_builder=dataflow_plan_builder,
query_spec=query_spec,
expected_additional_constraint_nodes_in_optimized=1, # TODO: Remove superfluous where constraint nodes
expected_additional_constraint_nodes_in_optimized=0,
)


Expand All @@ -272,6 +275,8 @@ def test_cumulative_metric_predicate_pushdown(
since supporting time filter pushdown for cumulative metrics requires filter expansion to ensure we capture the
full set of inputs for the initial cumulative window.
For the query listed here the entire constraint will be moved past the dimension join.
TODO: Add metric time filters
"""
query_spec = query_parser.parse_and_validate_query(
Expand All @@ -284,7 +289,7 @@ def test_cumulative_metric_predicate_pushdown(
mf_test_configuration=mf_test_configuration,
dataflow_plan_builder=dataflow_plan_builder,
query_spec=query_spec,
expected_additional_constraint_nodes_in_optimized=1,
expected_additional_constraint_nodes_in_optimized=0,
)


Expand Down Expand Up @@ -323,6 +328,8 @@ def test_offset_metric_predicate_pushdown(
As with cumulative metrics, at this time categorical dimension predicates may be pushed down, but metric_time
predicates should not be, since we need to capture the union of the filter window and the offset span.
For the query listed here the entire constraint will be moved past the dimension join.
TODO: Add metric time filters
"""
query_spec = query_parser.parse_and_validate_query(
Expand All @@ -335,7 +342,7 @@ def test_offset_metric_predicate_pushdown(
mf_test_configuration=mf_test_configuration,
dataflow_plan_builder=dataflow_plan_builder,
query_spec=query_spec,
expected_additional_constraint_nodes_in_optimized=1,
expected_additional_constraint_nodes_in_optimized=0,
)


Expand All @@ -349,6 +356,8 @@ def test_fill_nulls_time_spine_metric_predicate_pushdown(
Until time dimension pushdown is supported we will only see the categorical dimension entry pushed down here.
For the query listed here the entire constraint will be moved past the dimension join.
TODO: Add metric time filters
"""
query_spec = query_parser.parse_and_validate_query(
Expand All @@ -361,7 +370,7 @@ def test_fill_nulls_time_spine_metric_predicate_pushdown(
mf_test_configuration=mf_test_configuration,
dataflow_plan_builder=dataflow_plan_builder,
query_spec=query_spec,
expected_additional_constraint_nodes_in_optimized=1,
expected_additional_constraint_nodes_in_optimized=0,
)


Expand All @@ -377,7 +386,9 @@ def test_fill_nulls_time_spine_metric_with_post_agg_join_predicate_pushdown(
against the time spine, which should preclude predicate pushdown for query-time filters at that state, but
will allow for pushdown within the JoinToTimeSpine operation. This will still do predicate pushdown as before,
but only exactly as before - the added constraint outside of the JoinToTimeSpine operation must still be
applied.
applied in its entirety, and so we expect 0 additional constraint nodes. If we failed to account for the
repeated constraint outside of the JoinToTimeSpine in our pushdown handling this would remove one of the
WhereConstraintNodes from the original query altogether.
Until time dimension pushdown is supported we will only see the categorical dimension entry pushed down here.
Expand All @@ -393,5 +404,5 @@ def test_fill_nulls_time_spine_metric_with_post_agg_join_predicate_pushdown(
mf_test_configuration=mf_test_configuration,
dataflow_plan_builder=dataflow_plan_builder,
query_spec=query_spec,
expected_additional_constraint_nodes_in_optimized=1,
expected_additional_constraint_nodes_in_optimized=0,
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,40 +6,32 @@ SELECT
FROM (
-- Combine Aggregated Outputs
SELECT
COALESCE(subq_21.metric_time__day, subq_31.metric_time__day) AS metric_time__day
, COALESCE(subq_21.visit__referrer_id, subq_31.visit__referrer_id) AS visit__referrer_id
, MAX(subq_21.visits) AS visits
, MAX(subq_31.buys) AS buys
COALESCE(subq_20.metric_time__day, subq_30.metric_time__day) AS metric_time__day
, COALESCE(subq_20.visit__referrer_id, subq_30.visit__referrer_id) AS visit__referrer_id
, MAX(subq_20.visits) AS visits
, MAX(subq_30.buys) AS buys
FROM (
-- Constrain Output with WHERE
-- Pass Only Elements: ['visits', 'visit__referrer_id', 'metric_time__day']
-- Aggregate Measures
SELECT
metric_time__day
, visit__referrer_id
, SUM(visits) AS visits
FROM (
-- Constrain Output with WHERE
-- Pass Only Elements: ['visits', 'visit__referrer_id', 'metric_time__day']
-- Read Elements From Semantic Model 'visits_source'
-- Metric Time Dimension 'ds'
SELECT
metric_time__day
, visit__referrer_id
, visits
FROM (
-- Read Elements From Semantic Model 'visits_source'
-- Metric Time Dimension 'ds'
SELECT
DATETIME_TRUNC(ds, day) AS metric_time__day
, referrer_id AS visit__referrer_id
, 1 AS visits
FROM ***************************.fct_visits visits_source_src_28000
) subq_17
WHERE visit__referrer_id = 'ref_id_01'
) subq_19
DATETIME_TRUNC(ds, day) AS metric_time__day
, referrer_id AS visit__referrer_id
, 1 AS visits
FROM ***************************.fct_visits visits_source_src_28000
) subq_17
WHERE visit__referrer_id = 'ref_id_01'
GROUP BY
metric_time__day
, visit__referrer_id
) subq_21
) subq_20
FULL OUTER JOIN (
-- Find conversions for user within the range of INF
-- Pass Only Elements: ['buys', 'visit__referrer_id', 'metric_time__day']
Expand All @@ -51,48 +43,48 @@ FROM (
FROM (
-- Dedupe the fanout with mf_internal_uuid in the conversion data set
SELECT DISTINCT
FIRST_VALUE(subq_24.visits) OVER (
FIRST_VALUE(subq_23.visits) OVER (
PARTITION BY
subq_27.user
, subq_27.ds__day
, subq_27.mf_internal_uuid
ORDER BY subq_24.ds__day DESC
subq_26.user
, subq_26.ds__day
, subq_26.mf_internal_uuid
ORDER BY subq_23.ds__day DESC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS visits
, FIRST_VALUE(subq_24.visit__referrer_id) OVER (
, FIRST_VALUE(subq_23.visit__referrer_id) OVER (
PARTITION BY
subq_27.user
, subq_27.ds__day
, subq_27.mf_internal_uuid
ORDER BY subq_24.ds__day DESC
subq_26.user
, subq_26.ds__day
, subq_26.mf_internal_uuid
ORDER BY subq_23.ds__day DESC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS visit__referrer_id
, FIRST_VALUE(subq_24.ds__day) OVER (
, FIRST_VALUE(subq_23.ds__day) OVER (
PARTITION BY
subq_27.user
, subq_27.ds__day
, subq_27.mf_internal_uuid
ORDER BY subq_24.ds__day DESC
subq_26.user
, subq_26.ds__day
, subq_26.mf_internal_uuid
ORDER BY subq_23.ds__day DESC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS ds__day
, FIRST_VALUE(subq_24.metric_time__day) OVER (
, FIRST_VALUE(subq_23.metric_time__day) OVER (
PARTITION BY
subq_27.user
, subq_27.ds__day
, subq_27.mf_internal_uuid
ORDER BY subq_24.ds__day DESC
subq_26.user
, subq_26.ds__day
, subq_26.mf_internal_uuid
ORDER BY subq_23.ds__day DESC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS metric_time__day
, FIRST_VALUE(subq_24.user) OVER (
, FIRST_VALUE(subq_23.user) OVER (
PARTITION BY
subq_27.user
, subq_27.ds__day
, subq_27.mf_internal_uuid
ORDER BY subq_24.ds__day DESC
subq_26.user
, subq_26.ds__day
, subq_26.mf_internal_uuid
ORDER BY subq_23.ds__day DESC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS user
, subq_27.mf_internal_uuid AS mf_internal_uuid
, subq_27.buys AS buys
, subq_26.mf_internal_uuid AS mf_internal_uuid
, subq_26.buys AS buys
FROM (
-- Read Elements From Semantic Model 'visits_source'
-- Metric Time Dimension 'ds'
Expand All @@ -104,7 +96,7 @@ FROM (
, referrer_id AS visit__referrer_id
, 1 AS visits
FROM ***************************.fct_visits visits_source_src_28000
) subq_24
) subq_23
INNER JOIN (
-- Read Elements From Semantic Model 'buys_source'
-- Metric Time Dimension 'ds'
Expand All @@ -115,25 +107,25 @@ FROM (
, 1 AS buys
, GENERATE_UUID() AS mf_internal_uuid
FROM ***************************.fct_buys buys_source_src_28000
) subq_27
) subq_26
ON
(
subq_24.user = subq_27.user
subq_23.user = subq_26.user
) AND (
(subq_24.ds__day <= subq_27.ds__day)
(subq_23.ds__day <= subq_26.ds__day)
)
) subq_28
) subq_27
GROUP BY
metric_time__day
, visit__referrer_id
) subq_31
) subq_30
ON
(
subq_21.visit__referrer_id = subq_31.visit__referrer_id
subq_20.visit__referrer_id = subq_30.visit__referrer_id
) AND (
subq_21.metric_time__day = subq_31.metric_time__day
subq_20.metric_time__day = subq_30.metric_time__day
)
GROUP BY
metric_time__day
, visit__referrer_id
) subq_32
) subq_31
Loading

0 comments on commit 981b396

Please sign in to comment.