Skip to content

Commit

Permalink
fixup! Align Dataflow Plans
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Nov 8, 2024
1 parent b892fda commit fbc9c9f
Showing 1 changed file with 55 additions and 39 deletions.
94 changes: 55 additions & 39 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,43 +1649,11 @@ def _build_aggregated_measure_from_measure_source_node(
custom_granularity_specs=custom_granularity_specs_to_join,
where_filter_specs=metric_input_measure_spec.filter_spec_set.all_filter_specs,
time_range_constraint=time_range_constraint_to_apply,
)

non_additive_dimension_spec = measure_properties.non_additive_dimension_spec
if non_additive_dimension_spec is not None:
# Apply semi additive join on the node
agg_time_dimension = measure_properties.agg_time_dimension
non_additive_dimension_grain = measure_properties.agg_time_dimension_grain
queried_time_dimension_spec: Optional[
TimeDimensionSpec
] = self._find_non_additive_dimension_in_linkable_specs(
agg_time_dimension=agg_time_dimension,
linkable_specs=queried_linkable_specs.as_tuple,
non_additive_dimension_spec=non_additive_dimension_spec,
)
time_dimension_spec = TimeDimensionSpec(
# The NonAdditiveDimensionSpec name property is a plain element name
element_name=non_additive_dimension_spec.name,
entity_links=(),
time_granularity=ExpandedTimeGranularity.from_time_granularity(non_additive_dimension_grain),
)
window_groupings = tuple(
LinklessEntitySpec.from_element_name(name) for name in non_additive_dimension_spec.window_groupings
)
unaggregated_measure_node = SemiAdditiveJoinNode.create(
parent_node=unaggregated_measure_node,
entity_specs=window_groupings,
time_dimension_spec=time_dimension_spec,
agg_by_function=non_additive_dimension_spec.window_choice,
queried_time_dimension_spec=queried_time_dimension_spec,
)

# Filter to just the required measure and the requested group bys so that aggregations work correctly.
unaggregated_measure_node = FilterElementsNode.create(
parent_node=unaggregated_measure_node,
include_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge(
filter_to_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge(
InstanceSpecSet.create_from_specs(queried_linkable_specs.as_tuple)
),
measure_properties=measure_properties,
queried_linkable_specs=queried_linkable_specs,
)

aggregate_measures_node = AggregateMeasuresNode.create(
Expand Down Expand Up @@ -1762,7 +1730,9 @@ def _build_pre_aggregation_plan(
custom_granularity_specs: Sequence[TimeDimensionSpec],
where_filter_specs: Sequence[WhereFilterSpec],
time_range_constraint: Optional[TimeRangeConstraint],
filter_to_specs: Optional[InstanceSpecSet] = None,
filter_to_specs: InstanceSpecSet,
measure_properties: Optional[MeasureSpecProperties] = None,
queried_linkable_specs: Optional[LinkableSpecSet] = None,
distinct: bool = False,
) -> DataflowPlanNode:
# TODO: docstring
Expand All @@ -1783,9 +1753,55 @@ def _build_pre_aggregation_plan(
parent_node=output_node, time_range_constraint=time_range_constraint
)

if filter_to_specs:
output_node = FilterElementsNode.create(
parent_node=output_node, include_specs=filter_to_specs, distinct=distinct
if measure_properties and measure_properties.non_additive_dimension_spec:
if queried_linkable_specs is None:
raise ValueError(
"`queried_linkable_specs` must be provided in _build_pre_aggregation_plan() if "
"`non_additive_dimension_spec` is present."
)
output_node = self._build_semi_additive_join_node(
measure_properties=measure_properties,
queried_linkable_specs=queried_linkable_specs,
current_node=output_node,
)

output_node = FilterElementsNode.create(
parent_node=output_node, include_specs=filter_to_specs, distinct=distinct
)

return output_node

def _build_semi_additive_join_node(
self,
measure_properties: MeasureSpecProperties,
queried_linkable_specs: LinkableSpecSet,
current_node: DataflowPlanNode,
) -> SemiAdditiveJoinNode:
non_additive_dimension_spec = measure_properties.non_additive_dimension_spec
assert (
non_additive_dimension_spec
), "_build_semi_additive_join_node() should only be called if there is a non_additive_dimension_spec."
# Apply semi additive join on the node
agg_time_dimension = measure_properties.agg_time_dimension
non_additive_dimension_grain = measure_properties.agg_time_dimension_grain
queried_time_dimension_spec: Optional[TimeDimensionSpec] = self._find_non_additive_dimension_in_linkable_specs(
agg_time_dimension=agg_time_dimension,
linkable_specs=queried_linkable_specs.as_tuple,
non_additive_dimension_spec=non_additive_dimension_spec,
)
time_dimension_spec = TimeDimensionSpec(
# The NonAdditiveDimensionSpec name property is a plain element name
element_name=non_additive_dimension_spec.name,
entity_links=(),
time_granularity=ExpandedTimeGranularity.from_time_granularity(non_additive_dimension_grain),
)
window_groupings = tuple(
LinklessEntitySpec.from_element_name(name) for name in non_additive_dimension_spec.window_groupings
)
return SemiAdditiveJoinNode.create(
parent_node=current_node,
entity_specs=window_groupings,
time_dimension_spec=time_dimension_spec,
agg_by_function=non_additive_dimension_spec.window_choice,
queried_time_dimension_spec=queried_time_dimension_spec,
)

0 comments on commit fbc9c9f

Please sign in to comment.