Remove duplicated WhereConstraintNodes in predicate pushdown #1304

Merged · 3 commits · Jun 26, 2024
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20240625-114914.yaml
@@ -0,0 +1,6 @@
+ kind: Fixes
+ body: Remove extraneous where filter subqueries added by predicate pushdown
+ time: 2024-06-25T11:49:14.837794-07:00
+ custom:
+   Author: tlento
+   Issue: "1011"
4 changes: 3 additions & 1 deletion metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -1602,7 +1602,9 @@ def _build_aggregated_measure_from_measure_source_node(
if set(filter_spec.linkable_specs).issubset(set(queried_linkable_specs.as_tuple))
]
if len(queried_filter_specs) > 0:
- output_node = WhereConstraintNode(parent_node=output_node, where_specs=queried_filter_specs)
+ output_node = WhereConstraintNode(
+     parent_node=output_node, where_specs=queried_filter_specs, always_apply=True
+ )

# TODO: this will break if you query by agg_time_dimension but apply a time constraint on metric_time.
# To fix when enabling time range constraints for agg_time_dimension.
30 changes: 25 additions & 5 deletions metricflow/dataflow/nodes/where_filter.py
@@ -13,13 +13,27 @@
class WhereConstraintNode(DataflowPlanNode):
"""Remove rows using a WHERE clause."""

- def __init__(  # noqa: D107
+ def __init__(
self,
parent_node: DataflowPlanNode,
where_specs: Sequence[WhereFilterSpec],
+ always_apply: bool = False,
) -> None:
"""Initializer.

WhereConstraintNodes must always have exactly one parent, since they always wrap a single subquery input.

The always_apply parameter serves as an indicator for a WhereConstraintNode that is added to a plan in order
to clean up null outputs from a pre-join filter. For example, when doing time spine joins to fill null values
for metric outputs sometimes that join will result in rows with null values for various dimension attributes.
By re-applying the filter expression after the join step we will discard those unexpected output rows created
by the join (rather than the underlying inputs). In this case, we must ensure that the filters defined in this
node are always applied at the moment this node is processed, regardless of whether or not they've been pushed
down through the DAG.
"""
self._where_specs = where_specs
self.parent_node = parent_node
+ self.always_apply = always_apply
super().__init__(node_id=self.create_unique_id(), parent_nodes=(parent_node,))

@classmethod
Expand Down Expand Up @@ -52,14 +66,20 @@ def description(self) -> str: # noqa: D102

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
- return tuple(super().displayed_properties) + (DisplayedProperty("where_condition", self.where),)
+ properties = tuple(super().displayed_properties) + (DisplayedProperty("where_condition", self.where),)
+ if self.always_apply:
+     properties = properties + (DisplayedProperty("All filters always applied:", self.always_apply),)
+ return properties

def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102
- return isinstance(other_node, self.__class__) and other_node.where == self.where
+ return (
+     isinstance(other_node, self.__class__)
+     and other_node.where == self.where
+     and other_node.always_apply == self.always_apply
+ )

def with_new_parents(self, new_parent_nodes: Sequence[DataflowPlanNode]) -> WhereConstraintNode: # noqa: D102
assert len(new_parent_nodes) == 1
return WhereConstraintNode(
-     parent_node=new_parent_nodes[0],
-     where_specs=self.input_where_specs,
+     parent_node=new_parent_nodes[0], where_specs=self.input_where_specs, always_apply=self.always_apply
)
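To make the identity semantics above concrete, here is a minimal, self-contained toy model. This is not the real WhereConstraintNode from metricflow/dataflow/nodes/where_filter.py, just a sketch reduced to the fields the diff compares:

from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ToyWhereConstraintNode:
    """Toy stand-in for WhereConstraintNode, reduced to the compared fields."""

    where_specs: Tuple[str, ...]
    always_apply: bool = False

    def functionally_identical(self, other: "ToyWhereConstraintNode") -> bool:
        # Mirrors the diff: same filter expression AND same always_apply flag.
        return self.where_specs == other.where_specs and self.always_apply == other.always_apply


plain = ToyWhereConstraintNode(where_specs=("visit__referrer_id = 'ref_id_01'",))
post_join_cleanup = ToyWhereConstraintNode(
    where_specs=("visit__referrer_id = 'ref_id_01'",), always_apply=True
)
# Same WHERE expression, but the two nodes are no longer interchangeable in an
# optimized plan, because always_apply now participates in the comparison.
assert not plain.functionally_identical(post_join_cleanup)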
30 changes: 23 additions & 7 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
@@ -362,14 +362,10 @@ def visit_where_constraint_node(self, node: WhereConstraintNode) -> OptimizeBranchResult:

with self._predicate_pushdown_tracker.track_pushdown_state(updated_pushdown_state):
optimized_parent: OptimizeBranchResult = node.parent_node.accept(self)
- # TODO: Update to only apply filters that have not been successfully pushed down
- optimized_node = OptimizeBranchResult(
-     optimized_branch=node.with_new_parents((optimized_parent.optimized_branch,))
- )
-
pushdown_state_updated_by_parent = self._predicate_pushdown_tracker.last_pushdown_state
- # Override the pushdown state for this node and allow all upstream propagation to be handled by the tracker
- if len(pushdown_state_updated_by_parent.applied_where_filter_specs) > 0:
+ applied_filter_specs = pushdown_state_updated_by_parent.applied_where_filter_specs
+ filter_specs_to_apply = [spec for spec in node.input_where_specs if spec not in applied_filter_specs]
+ if len(applied_filter_specs) > 0:
updated_specs = frozenset.union(
frozenset(node.input_where_specs),
pushdown_state_updated_by_parent.applied_where_filter_specs,
@@ -388,6 +384,26 @@
),
)

+ if node.always_apply:
+     logger.log(
+         level=self._log_level,
+         msg=(
+             "Applying original filter spec set based on node-level override directive. Additional specs "
+             + f"applied:\n{[spec for spec in node.input_where_specs if spec not in filter_specs_to_apply]}"
+         ),
+     )
+     optimized_node = OptimizeBranchResult(
+         optimized_branch=node.with_new_parents((optimized_parent.optimized_branch,))
+     )
+ elif len(filter_specs_to_apply) > 0:
+     optimized_node = OptimizeBranchResult(
+         optimized_branch=WhereConstraintNode(
+             parent_node=optimized_parent.optimized_branch, where_specs=filter_specs_to_apply
+         )
+     )
+ else:
+     optimized_node = optimized_parent

return optimized_node

# Join nodes - these may affect pushdown state based on join type
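The three-way branch above can be summarized with a small, hedged sketch. This is a toy model of the decision only; the real implementation operates on WhereFilterSpec objects and OptimizeBranchResult wrappers rather than strings:

from typing import FrozenSet, List, Optional


def specs_to_retain(
    input_where_specs: List[str],
    pushed_down_specs: FrozenSet[str],
    always_apply: bool,
) -> Optional[List[str]]:
    """Toy model of the visit_where_constraint_node branch above.

    Returns the filter specs the visited node should still apply, or None when
    the node is redundant and can be dropped from the optimized plan.
    """
    if always_apply:
        # Post-join cleanup nodes re-apply every input spec, pushed down or not.
        return list(input_where_specs)
    remaining = [spec for spec in input_where_specs if spec not in pushed_down_specs]
    return remaining if remaining else None


# The duplication this PR removes: previously the node was always rebuilt with
# all of its input specs, so a spec that had been pushed down appeared twice.
assert specs_to_retain(["a", "b"], frozenset({"a"}), always_apply=False) == ["b"]
assert specs_to_retain(["a", "b"], frozenset({"a", "b"}), always_apply=False) is None
assert specs_to_retain(["a", "b"], frozenset({"a", "b"}), always_apply=True) == ["a", "b"]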
@@ -194,7 +194,10 @@ def test_simple_join_categorical_pushdown(
query_parser: MetricFlowQueryParser,
dataflow_plan_builder: DataflowPlanBuilder,
) -> None:
"""Tests pushdown optimization for a simple predicate through a single join."""
"""Tests pushdown optimization for a simple predicate through a single join.

In this case the entire constraint should be moved inside of the join.
"""
query_spec = query_parser.parse_and_validate_query(
metric_names=("bookings",),
group_by_names=("listing__country_latest",),
@@ -205,7 +208,7 @@
mf_test_configuration=mf_test_configuration,
dataflow_plan_builder=dataflow_plan_builder,
query_spec=query_spec,
- expected_additional_constraint_nodes_in_optimized=1,
+ expected_additional_constraint_nodes_in_optimized=0,
)
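For context, expected_additional_constraint_nodes_in_optimized compares WhereConstraintNode counts between the optimized and original plans. A hedged sketch of that counting follows; the traversal helper is an assumption for illustration, not the actual test utility in this repository:

def count_where_constraint_nodes(node) -> int:
    """Counts WhereConstraintNodes reachable from a dataflow plan node.

    Assumes only that each node exposes parent_nodes, as DataflowPlanNode does
    in the diff above; this helper is illustrative, not the repo's test code.
    """
    own = 1 if type(node).__name__ == "WhereConstraintNode" else 0
    return own + sum(count_where_constraint_nodes(parent) for parent in node.parent_nodes)


# expected_additional_constraint_nodes_in_optimized == (
#     count_where_constraint_nodes(optimized_plan_sink)
#     - count_where_constraint_nodes(original_plan_sink)
# )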


@@ -256,7 +259,7 @@ def test_conversion_metric_predicate_pushdown(
mf_test_configuration=mf_test_configuration,
dataflow_plan_builder=dataflow_plan_builder,
query_spec=query_spec,
- expected_additional_constraint_nodes_in_optimized=1,  # TODO: Remove superfluous where constraint nodes
+ expected_additional_constraint_nodes_in_optimized=0,
)


@@ -272,6 +275,8 @@
since supporting time filter pushdown for cumulative metrics requires filter expansion to ensure we capture the
full set of inputs for the initial cumulative window.

+ For the query listed here, the entire constraint will be moved past the dimension join.

TODO: Add metric time filters
"""
query_spec = query_parser.parse_and_validate_query(
@@ -284,7 +289,7 @@
mf_test_configuration=mf_test_configuration,
dataflow_plan_builder=dataflow_plan_builder,
query_spec=query_spec,
- expected_additional_constraint_nodes_in_optimized=1,
+ expected_additional_constraint_nodes_in_optimized=0,
)
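To make the cumulative-metric caveat in the docstring above concrete, here is a hedged toy example (made-up dates and values) of why a metric_time filter cannot be applied below a trailing 7-day cumulative aggregation:

# Toy illustration only; not MetricFlow code.
rows = [("2023-12-28", 10), ("2024-01-02", 5)]

# A pushed-down filter would drop 2023-12-28 *before* aggregation...
pre_filtered = [r for r in rows if r[0] >= "2024-01-01"]

# ...but the 7-day window ending 2024-01-02 legitimately includes 2023-12-28,
# so the cumulative total for 2024-01-02 should be 15, not 5.
window_inputs = [v for d, v in rows if "2023-12-27" <= d <= "2024-01-02"]
assert sum(window_inputs) == 15
assert sum(v for _, v in pre_filtered) == 5  # wrong if used as the window input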


@@ -323,6 +328,8 @@ def test_offset_metric_predicate_pushdown(
As with cumulative metrics, at this time categorical dimension predicates may be pushed down, but metric_time
predicates should not be, since we need to capture the union of the filter window and the offset span.

+ For the query listed here, the entire constraint will be moved past the dimension join.

TODO: Add metric time filters
"""
query_spec = query_parser.parse_and_validate_query(
Expand All @@ -335,7 +342,7 @@ def test_offset_metric_predicate_pushdown(
mf_test_configuration=mf_test_configuration,
dataflow_plan_builder=dataflow_plan_builder,
query_spec=query_spec,
- expected_additional_constraint_nodes_in_optimized=1,
+ expected_additional_constraint_nodes_in_optimized=0,
)


@@ -349,6 +356,8 @@

Until time dimension pushdown is supported we will only see the categorical dimension entry pushed down here.

+ For the query listed here, the entire constraint will be moved past the dimension join.

TODO: Add metric time filters
"""
query_spec = query_parser.parse_and_validate_query(
Expand All @@ -361,7 +370,7 @@ def test_fill_nulls_time_spine_metric_predicate_pushdown(
mf_test_configuration=mf_test_configuration,
dataflow_plan_builder=dataflow_plan_builder,
query_spec=query_spec,
- expected_additional_constraint_nodes_in_optimized=1,
+ expected_additional_constraint_nodes_in_optimized=0,
)


@@ -377,7 +386,9 @@ def test_fill_nulls_time_spine_metric_with_post_agg_join_predicate_pushdown(
against the time spine, which should preclude predicate pushdown for query-time filters at that state, but
will allow for pushdown within the JoinToTimeSpine operation. This will still do predicate pushdown as before,
but only exactly as before - the added constraint outside of the JoinToTimeSpine operation must still be
- applied.
+ applied in its entirety, and so we expect 0 additional constraint nodes. If we failed to account for the
+ repeated constraint outside of the JoinToTimeSpine in our pushdown handling, this would remove one of the
+ WhereConstraintNodes from the original query altogether.

Until time dimension pushdown is supported we will only see the categorical dimension entry pushed down here.

@@ -393,5 +404,5 @@
mf_test_configuration=mf_test_configuration,
dataflow_plan_builder=dataflow_plan_builder,
query_spec=query_spec,
- expected_additional_constraint_nodes_in_optimized=1,
+ expected_additional_constraint_nodes_in_optimized=0,
)
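A hedged sketch of the scenario this last docstring describes, using the same retention rule as the optimizer section above (the filter name is illustrative, not taken from this PR):

# Illustrative values only: one query-time filter that pushdown has already
# applied inside the JoinToTimeSpine input branch.
outer_node_specs = ["booking__is_instant"]
pushed_down_specs = frozenset({"booking__is_instant"})
outer_node_always_apply = True  # post-aggregation null cleanup must re-filter

if outer_node_always_apply:
    retained = list(outer_node_specs)
else:
    # Without the always_apply override the outer node would be deduplicated
    # away here, silently dropping a required WhereConstraintNode.
    retained = [spec for spec in outer_node_specs if spec not in pushed_down_specs]

# The outer constraint survives intact, so the optimized plan keeps the same
# constraint-node count as the original (hence the expected value of 0).
assert retained == outer_node_specs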
@@ -6,40 +6,32 @@ SELECT
FROM (
-- Combine Aggregated Outputs
SELECT
- COALESCE(subq_21.metric_time__day, subq_31.metric_time__day) AS metric_time__day
- , COALESCE(subq_21.visit__referrer_id, subq_31.visit__referrer_id) AS visit__referrer_id
- , MAX(subq_21.visits) AS visits
- , MAX(subq_31.buys) AS buys
+ COALESCE(subq_20.metric_time__day, subq_30.metric_time__day) AS metric_time__day
+ , COALESCE(subq_20.visit__referrer_id, subq_30.visit__referrer_id) AS visit__referrer_id
+ , MAX(subq_20.visits) AS visits
+ , MAX(subq_30.buys) AS buys
FROM (
-- Constrain Output with WHERE
-- Pass Only Elements: ['visits', 'visit__referrer_id', 'metric_time__day']
-- Aggregate Measures
SELECT
metric_time__day
, visit__referrer_id
, SUM(visits) AS visits
FROM (
- -- Constrain Output with WHERE
- -- Pass Only Elements: ['visits', 'visit__referrer_id', 'metric_time__day']
-- Read Elements From Semantic Model 'visits_source'
-- Metric Time Dimension 'ds'
SELECT
- metric_time__day
- , visit__referrer_id
- , visits
- FROM (
- -- Read Elements From Semantic Model 'visits_source'
- -- Metric Time Dimension 'ds'
- SELECT
- DATETIME_TRUNC(ds, day) AS metric_time__day
- , referrer_id AS visit__referrer_id
- , 1 AS visits
- FROM ***************************.fct_visits visits_source_src_28000
- ) subq_17
- WHERE visit__referrer_id = 'ref_id_01'
- ) subq_19
+ DATETIME_TRUNC(ds, day) AS metric_time__day
+ , referrer_id AS visit__referrer_id
+ , 1 AS visits
+ FROM ***************************.fct_visits visits_source_src_28000
+ ) subq_17
+ WHERE visit__referrer_id = 'ref_id_01'
GROUP BY
metric_time__day
, visit__referrer_id
- ) subq_21
+ ) subq_20
FULL OUTER JOIN (
-- Find conversions for user within the range of INF
-- Pass Only Elements: ['buys', 'visit__referrer_id', 'metric_time__day']
@@ -51,48 +43,48 @@
FROM (
-- Dedupe the fanout with mf_internal_uuid in the conversion data set
SELECT DISTINCT
- FIRST_VALUE(subq_24.visits) OVER (
+ FIRST_VALUE(subq_23.visits) OVER (
PARTITION BY
- subq_27.user
- , subq_27.ds__day
- , subq_27.mf_internal_uuid
- ORDER BY subq_24.ds__day DESC
+ subq_26.user
+ , subq_26.ds__day
+ , subq_26.mf_internal_uuid
+ ORDER BY subq_23.ds__day DESC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS visits
- , FIRST_VALUE(subq_24.visit__referrer_id) OVER (
+ , FIRST_VALUE(subq_23.visit__referrer_id) OVER (
PARTITION BY
- subq_27.user
- , subq_27.ds__day
- , subq_27.mf_internal_uuid
- ORDER BY subq_24.ds__day DESC
+ subq_26.user
+ , subq_26.ds__day
+ , subq_26.mf_internal_uuid
+ ORDER BY subq_23.ds__day DESC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS visit__referrer_id
- , FIRST_VALUE(subq_24.ds__day) OVER (
+ , FIRST_VALUE(subq_23.ds__day) OVER (
PARTITION BY
- subq_27.user
- , subq_27.ds__day
- , subq_27.mf_internal_uuid
- ORDER BY subq_24.ds__day DESC
+ subq_26.user
+ , subq_26.ds__day
+ , subq_26.mf_internal_uuid
+ ORDER BY subq_23.ds__day DESC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS ds__day
- , FIRST_VALUE(subq_24.metric_time__day) OVER (
+ , FIRST_VALUE(subq_23.metric_time__day) OVER (
PARTITION BY
- subq_27.user
- , subq_27.ds__day
- , subq_27.mf_internal_uuid
- ORDER BY subq_24.ds__day DESC
+ subq_26.user
+ , subq_26.ds__day
+ , subq_26.mf_internal_uuid
+ ORDER BY subq_23.ds__day DESC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS metric_time__day
- , FIRST_VALUE(subq_24.user) OVER (
+ , FIRST_VALUE(subq_23.user) OVER (
PARTITION BY
- subq_27.user
- , subq_27.ds__day
- , subq_27.mf_internal_uuid
- ORDER BY subq_24.ds__day DESC
+ subq_26.user
+ , subq_26.ds__day
+ , subq_26.mf_internal_uuid
+ ORDER BY subq_23.ds__day DESC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS user
- , subq_27.mf_internal_uuid AS mf_internal_uuid
- , subq_27.buys AS buys
+ , subq_26.mf_internal_uuid AS mf_internal_uuid
+ , subq_26.buys AS buys
FROM (
-- Read Elements From Semantic Model 'visits_source'
-- Metric Time Dimension 'ds'
Expand All @@ -104,7 +96,7 @@ FROM (
, referrer_id AS visit__referrer_id
, 1 AS visits
FROM ***************************.fct_visits visits_source_src_28000
- ) subq_24
+ ) subq_23
INNER JOIN (
-- Read Elements From Semantic Model 'buys_source'
-- Metric Time Dimension 'ds'
Expand All @@ -115,25 +107,25 @@ FROM (
, 1 AS buys
, GENERATE_UUID() AS mf_internal_uuid
FROM ***************************.fct_buys buys_source_src_28000
- ) subq_27
+ ) subq_26
ON
(
- subq_24.user = subq_27.user
+ subq_23.user = subq_26.user
) AND (
- (subq_24.ds__day <= subq_27.ds__day)
+ (subq_23.ds__day <= subq_26.ds__day)
)
- ) subq_28
+ ) subq_27
GROUP BY
metric_time__day
, visit__referrer_id
- ) subq_31
+ ) subq_30
ON
(
- subq_21.visit__referrer_id = subq_31.visit__referrer_id
+ subq_20.visit__referrer_id = subq_30.visit__referrer_id
) AND (
- subq_21.metric_time__day = subq_31.metric_time__day
+ subq_20.metric_time__day = subq_30.metric_time__day
)
GROUP BY
metric_time__day
, visit__referrer_id
- ) subq_32
+ ) subq_31