Skip to content

Commit

Permalink
Align Dataflow Plans
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Nov 7, 2024
1 parent 5f2ab0b commit b892fda
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 102 deletions.
144 changes: 71 additions & 73 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
Expand Down Expand Up @@ -328,34 +328,23 @@ def _build_aggregated_conversion_node(
)

# Build the unaggregated base measure node for computing conversions
unaggregated_base_measure_node = base_measure_recipe.source_node
if base_measure_recipe.join_targets:
unaggregated_base_measure_node = JoinOnEntitiesNode.create(
left_node=unaggregated_base_measure_node, join_targets=base_measure_recipe.join_targets
)
for time_dimension_spec in base_required_linkable_specs.time_dimension_specs:
if time_dimension_spec.time_granularity.is_custom_granularity:
unaggregated_base_measure_node = JoinToCustomGranularityNode.create(
parent_node=unaggregated_base_measure_node, time_dimension_spec=time_dimension_spec
)
if len(base_measure_spec.filter_spec_set.all_filter_specs) > 0:
unaggregated_base_measure_node = WhereConstraintNode.create(
parent_node=unaggregated_base_measure_node,
where_specs=base_measure_spec.filter_spec_set.all_filter_specs,
)
filtered_unaggregated_base_node = FilterElementsNode.create(
parent_node=unaggregated_base_measure_node,
include_specs=group_specs_by_type(required_local_specs)
unaggregated_base_measure_node = self._build_pre_aggregation_plan(
source_node=base_measure_recipe.source_node,
join_targets=base_measure_recipe.join_targets,
filter_to_specs=group_specs_by_type(required_local_specs)
.merge(base_required_linkable_specs.as_instance_spec_set)
.dedupe(),
custom_granularity_specs=base_required_linkable_specs.time_dimension_specs_with_custom_grain,
time_range_constraint=None,
where_filter_specs=base_measure_spec.filter_spec_set.all_filter_specs,
)

# Gets the successful conversions using JoinConversionEventsNode
# The conversion events are joined by the base events which are already time constrained. However, this could
# be still be constrained, where we adjust the time range to the window size similar to cumulative, but
# adjusted in the opposite direction.
join_conversion_node = JoinConversionEventsNode.create(
base_node=filtered_unaggregated_base_node,
base_node=unaggregated_base_measure_node,
base_time_dimension_spec=metric_time_dimension_spec,
conversion_node=unaggregated_conversion_measure_node,
conversion_measure_spec=conversion_measure_spec.measure_spec,
Expand Down Expand Up @@ -833,26 +822,13 @@ def _build_plan_for_distinct_values(
if not dataflow_recipe:
raise UnableToSatisfyQueryError(f"Unable to join all items in request: {required_linkable_specs}")

output_node = dataflow_recipe.source_node
if dataflow_recipe.join_targets:
output_node = JoinOnEntitiesNode.create(left_node=output_node, join_targets=dataflow_recipe.join_targets)

for time_dimension_spec in required_linkable_specs.time_dimension_specs:
if time_dimension_spec.time_granularity.is_custom_granularity:
output_node = JoinToCustomGranularityNode.create(
parent_node=output_node, time_dimension_spec=time_dimension_spec
)

if len(query_level_filter_specs) > 0:
output_node = WhereConstraintNode.create(parent_node=output_node, where_specs=query_level_filter_specs)
if query_spec.time_range_constraint:
output_node = ConstrainTimeRangeNode.create(
parent_node=output_node, time_range_constraint=query_spec.time_range_constraint
)

output_node = FilterElementsNode.create(
parent_node=output_node,
include_specs=InstanceSpecSet.create_from_specs(query_spec.linkable_specs.as_tuple),
output_node = self._build_pre_aggregation_plan(
source_node=dataflow_recipe.source_node,
join_targets=dataflow_recipe.join_targets,
filter_to_specs=InstanceSpecSet.create_from_specs(query_spec.linkable_specs.as_tuple),
custom_granularity_specs=required_linkable_specs.time_dimension_specs_with_custom_grain,
where_filter_specs=query_level_filter_specs,
time_range_constraint=query_spec.time_range_constraint,
distinct=True,
)

Expand Down Expand Up @@ -1650,43 +1626,30 @@ def _build_aggregated_measure_from_measure_source_node(
join_type=before_aggregation_time_spine_join_description.join_type,
)

join_targets = measure_recipe.join_targets
if len(join_targets) > 0:
unaggregated_measure_node = JoinOnEntitiesNode.create(
left_node=unaggregated_measure_node, join_targets=join_targets
)

for time_dimension_spec in required_linkable_specs.time_dimension_specs:
if (
time_dimension_spec.time_granularity.is_custom_granularity
# If this is the second layer of aggregation for a conversion metric, we have already joined the custom granularity.
and time_dimension_spec not in measure_recipe.all_linkable_specs_required_for_source_nodes.as_tuple
):
unaggregated_measure_node = JoinToCustomGranularityNode.create(
parent_node=unaggregated_measure_node, time_dimension_spec=time_dimension_spec
)

custom_granularity_specs_to_join = [
spec
for spec in required_linkable_specs.time_dimension_specs_with_custom_grain
# If this is the second layer of aggregation for a conversion metric, we have already joined the custom granularity.
if spec not in measure_recipe.all_linkable_specs_required_for_source_nodes.as_tuple
]
# If time constraint was previously adjusted for cumulative window or grain, apply original time constraint
# here. Can skip if metric is being aggregated over all time.
# TODO - Pushdown: Encapsulate all of this window sliding bookkeeping in the pushdown params object
if (
cumulative_metric_adjusted_time_constraint is not None
and predicate_pushdown_state.time_range_constraint is not None
):
assert (
queried_linkable_specs.contains_metric_time
), "Using time constraints currently requires querying with metric_time."
unaggregated_measure_node = ConstrainTimeRangeNode.create(
parent_node=unaggregated_measure_node,
time_range_constraint=predicate_pushdown_state.time_range_constraint,
)

if len(metric_input_measure_spec.filter_spec_set.all_filter_specs) > 0:
# Apply where constraint on the node
unaggregated_measure_node = WhereConstraintNode.create(
parent_node=unaggregated_measure_node,
where_specs=metric_input_measure_spec.filter_spec_set.all_filter_specs,
time_range_constraint_to_apply = (
predicate_pushdown_state.time_range_constraint
if (
cumulative_metric_adjusted_time_constraint is not None
and predicate_pushdown_state.time_range_constraint is not None
)
else None
)
unaggregated_measure_node = self._build_pre_aggregation_plan(
source_node=unaggregated_measure_node,
join_targets=measure_recipe.join_targets,
custom_granularity_specs=custom_granularity_specs_to_join,
where_filter_specs=metric_input_measure_spec.filter_spec_set.all_filter_specs,
time_range_constraint=time_range_constraint_to_apply,
)

non_additive_dimension_spec = measure_properties.non_additive_dimension_spec
if non_additive_dimension_spec is not None:
Expand Down Expand Up @@ -1791,3 +1754,38 @@ def _build_aggregated_measure_from_measure_source_node(
return output_node

return aggregate_measures_node

def _build_pre_aggregation_plan(
self,
source_node: DataflowPlanNode,
join_targets: List[JoinDescription],
custom_granularity_specs: Sequence[TimeDimensionSpec],
where_filter_specs: Sequence[WhereFilterSpec],
time_range_constraint: Optional[TimeRangeConstraint],
filter_to_specs: Optional[InstanceSpecSet] = None,
distinct: bool = False,
) -> DataflowPlanNode:
# TODO: docstring
output_node = source_node
if join_targets:
output_node = JoinOnEntitiesNode.create(left_node=output_node, join_targets=join_targets)

for custom_granularity_spec in custom_granularity_specs:
output_node = JoinToCustomGranularityNode.create(
parent_node=output_node, time_dimension_spec=custom_granularity_spec
)

if len(where_filter_specs) > 0:
output_node = WhereConstraintNode.create(parent_node=output_node, where_specs=where_filter_specs)

if time_range_constraint:
output_node = ConstrainTimeRangeNode.create(
parent_node=output_node, time_range_constraint=time_range_constraint
)

if filter_to_specs:
output_node = FilterElementsNode.create(
parent_node=output_node, include_specs=filter_to_specs, distinct=distinct
)

return output_node
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<!-- node_id = NodeId(id_str='wrd_0') -->
<FilterElementsNode>
<!-- description = "Pass Only Elements: ['listing__is_lux_latest', 'metric_time__month']" -->
<!-- node_id = NodeId(id_str='pfe_2') -->
<!-- node_id = NodeId(id_str='pfe_1') -->
<!-- include_spec = -->
<!-- DimensionSpec(element_name='is_lux_latest', entity_links=(EntityReference(element_name='listing'),)) -->
<!-- include_spec = -->
Expand All @@ -21,16 +21,16 @@
<JoinOnEntitiesNode>
<!-- description = 'Join Standard Outputs' -->
<!-- node_id = NodeId(id_str='jso_0') -->
<!-- join0_for_node_id_pfe_1 = -->
<!-- JoinDescription(join_node=FilterElementsNode(node_id=pfe_1), join_type=CROSS_JOIN) -->
<!-- join0_for_node_id_pfe_0 = -->
<!-- JoinDescription(join_node=FilterElementsNode(node_id=pfe_0), join_type=CROSS_JOIN) -->
<ReadSqlSourceNode>
<!-- description = "Read From SemanticModelDataSet('listings_latest')" -->
<!-- node_id = NodeId(id_str='rss_28024') -->
<!-- data_set = SemanticModelDataSet('listings_latest') -->
</ReadSqlSourceNode>
<FilterElementsNode>
<!-- description = "Pass Only Elements: ['metric_time__month',]" -->
<!-- node_id = NodeId(id_str='pfe_1') -->
<!-- node_id = NodeId(id_str='pfe_0') -->
<!-- include_spec = -->
<!-- TimeDimensionSpec( -->
<!-- element_name='metric_time', -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
<!-- limit = '100' -->
<FilterElementsNode>
<!-- description = "Pass Only Elements: ['user__home_state_latest', 'listing__is_lux_latest']" -->
<!-- node_id = NodeId(id_str='pfe_2') -->
<!-- node_id = NodeId(id_str='pfe_1') -->
<!-- include_spec = -->
<!-- DimensionSpec( -->
<!-- element_name='home_state_latest', -->
Expand Down Expand Up @@ -73,9 +73,9 @@
<JoinOnEntitiesNode>
<!-- description = 'Join Standard Outputs' -->
<!-- node_id = NodeId(id_str='jso_0') -->
<!-- join0_for_node_id_pfe_1 = -->
<!-- join0_for_node_id_pfe_0 = -->
<!-- JoinDescription( -->
<!-- join_node=FilterElementsNode(node_id=pfe_1), -->
<!-- join_node=FilterElementsNode(node_id=pfe_0), -->
<!-- join_on_entity=LinklessEntitySpec(element_name='user'), -->
<!-- join_type=FULL_OUTER, -->
<!-- ) -->
Expand All @@ -86,7 +86,7 @@
</ReadSqlSourceNode>
<FilterElementsNode>
<!-- description = "Pass Only Elements: ['home_state_latest', 'user']" -->
<!-- node_id = NodeId(id_str='pfe_1') -->
<!-- node_id = NodeId(id_str='pfe_0') -->
<!-- include_spec = DimensionSpec(element_name='home_state_latest') -->
<!-- include_spec = LinklessEntitySpec(element_name='user') -->
<!-- distinct = False -->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<!-- description = -->
<!-- ("Pass Only Elements: ['user__home_state_latest', 'listing__is_lux_latest', 'metric_time__day', " -->
<!-- "'metric_time__month']") -->
<!-- node_id = NodeId(id_str='pfe_4') -->
<!-- node_id = NodeId(id_str='pfe_2') -->
<!-- include_spec = -->
<!-- DimensionSpec(element_name='home_state_latest', entity_links=(EntityReference(element_name='user'),)) -->
<!-- include_spec = -->
Expand All @@ -25,11 +25,11 @@
<JoinOnEntitiesNode>
<!-- description = 'Join Standard Outputs' -->
<!-- node_id = NodeId(id_str='jso_0') -->
<!-- join0_for_node_id_pfe_2 = -->
<!-- JoinDescription(join_node=FilterElementsNode(node_id=pfe_2), join_type=CROSS_JOIN) -->
<!-- join1_for_node_id_pfe_3 = -->
<!-- join0_for_node_id_pfe_0 = -->
<!-- JoinDescription(join_node=FilterElementsNode(node_id=pfe_0), join_type=CROSS_JOIN) -->
<!-- join1_for_node_id_pfe_1 = -->
<!-- JoinDescription( -->
<!-- join_node=FilterElementsNode(node_id=pfe_3), -->
<!-- join_node=FilterElementsNode(node_id=pfe_1), -->
<!-- join_on_entity=LinklessEntitySpec(element_name='user'), -->
<!-- join_type=FULL_OUTER, -->
<!-- ) -->
Expand All @@ -40,7 +40,7 @@
</ReadSqlSourceNode>
<FilterElementsNode>
<!-- description = "Pass Only Elements: ['metric_time__day', 'metric_time__month']" -->
<!-- node_id = NodeId(id_str='pfe_2') -->
<!-- node_id = NodeId(id_str='pfe_0') -->
<!-- include_spec = -->
<!-- TimeDimensionSpec( -->
<!-- element_name='metric_time', -->
Expand All @@ -65,7 +65,7 @@
</FilterElementsNode>
<FilterElementsNode>
<!-- description = "Pass Only Elements: ['home_state_latest', 'user']" -->
<!-- node_id = NodeId(id_str='pfe_3') -->
<!-- node_id = NodeId(id_str='pfe_1') -->
<!-- include_spec = DimensionSpec(element_name='home_state_latest') -->
<!-- include_spec = LinklessEntitySpec(element_name='user') -->
<!-- distinct = False -->
Expand Down
Loading

0 comments on commit b892fda

Please sign in to comment.