diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index fed318e75c..85aea277c9 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -1649,43 +1649,11 @@ def _build_aggregated_measure_from_measure_source_node( custom_granularity_specs=custom_granularity_specs_to_join, where_filter_specs=metric_input_measure_spec.filter_spec_set.all_filter_specs, time_range_constraint=time_range_constraint_to_apply, - ) - - non_additive_dimension_spec = measure_properties.non_additive_dimension_spec - if non_additive_dimension_spec is not None: - # Apply semi additive join on the node - agg_time_dimension = measure_properties.agg_time_dimension - non_additive_dimension_grain = measure_properties.agg_time_dimension_grain - queried_time_dimension_spec: Optional[ - TimeDimensionSpec - ] = self._find_non_additive_dimension_in_linkable_specs( - agg_time_dimension=agg_time_dimension, - linkable_specs=queried_linkable_specs.as_tuple, - non_additive_dimension_spec=non_additive_dimension_spec, - ) - time_dimension_spec = TimeDimensionSpec( - # The NonAdditiveDimensionSpec name property is a plain element name - element_name=non_additive_dimension_spec.name, - entity_links=(), - time_granularity=ExpandedTimeGranularity.from_time_granularity(non_additive_dimension_grain), - ) - window_groupings = tuple( - LinklessEntitySpec.from_element_name(name) for name in non_additive_dimension_spec.window_groupings - ) - unaggregated_measure_node = SemiAdditiveJoinNode.create( - parent_node=unaggregated_measure_node, - entity_specs=window_groupings, - time_dimension_spec=time_dimension_spec, - agg_by_function=non_additive_dimension_spec.window_choice, - queried_time_dimension_spec=queried_time_dimension_spec, - ) - - # Filter to just the required measure and the requested group bys so that aggregations work correctly. - unaggregated_measure_node = FilterElementsNode.create( - parent_node=unaggregated_measure_node, - include_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge( + filter_to_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge( InstanceSpecSet.create_from_specs(queried_linkable_specs.as_tuple) ), + measure_properties=measure_properties, + queried_linkable_specs=queried_linkable_specs, ) aggregate_measures_node = AggregateMeasuresNode.create( @@ -1762,7 +1730,9 @@ def _build_pre_aggregation_plan( custom_granularity_specs: Sequence[TimeDimensionSpec], where_filter_specs: Sequence[WhereFilterSpec], time_range_constraint: Optional[TimeRangeConstraint], - filter_to_specs: Optional[InstanceSpecSet] = None, + filter_to_specs: InstanceSpecSet, + measure_properties: Optional[MeasureSpecProperties] = None, + queried_linkable_specs: Optional[LinkableSpecSet] = None, distinct: bool = False, ) -> DataflowPlanNode: # TODO: docstring @@ -1783,9 +1753,55 @@ def _build_pre_aggregation_plan( parent_node=output_node, time_range_constraint=time_range_constraint ) - if filter_to_specs: - output_node = FilterElementsNode.create( - parent_node=output_node, include_specs=filter_to_specs, distinct=distinct + if measure_properties and measure_properties.non_additive_dimension_spec: + if queried_linkable_specs is None: + raise ValueError( + "`queried_linkable_specs` must be provided in _build_pre_aggregation_plan() if " + "`non_additive_dimension_spec` is present." + ) + output_node = self._build_semi_additive_join_node( + measure_properties=measure_properties, + queried_linkable_specs=queried_linkable_specs, + current_node=output_node, ) + output_node = FilterElementsNode.create( + parent_node=output_node, include_specs=filter_to_specs, distinct=distinct + ) + return output_node + + def _build_semi_additive_join_node( + self, + measure_properties: MeasureSpecProperties, + queried_linkable_specs: LinkableSpecSet, + current_node: DataflowPlanNode, + ) -> SemiAdditiveJoinNode: + non_additive_dimension_spec = measure_properties.non_additive_dimension_spec + assert ( + non_additive_dimension_spec + ), "_build_semi_additive_join_node() should only be called if there is a non_additive_dimension_spec." + # Apply semi additive join on the node + agg_time_dimension = measure_properties.agg_time_dimension + non_additive_dimension_grain = measure_properties.agg_time_dimension_grain + queried_time_dimension_spec: Optional[TimeDimensionSpec] = self._find_non_additive_dimension_in_linkable_specs( + agg_time_dimension=agg_time_dimension, + linkable_specs=queried_linkable_specs.as_tuple, + non_additive_dimension_spec=non_additive_dimension_spec, + ) + time_dimension_spec = TimeDimensionSpec( + # The NonAdditiveDimensionSpec name property is a plain element name + element_name=non_additive_dimension_spec.name, + entity_links=(), + time_granularity=ExpandedTimeGranularity.from_time_granularity(non_additive_dimension_grain), + ) + window_groupings = tuple( + LinklessEntitySpec.from_element_name(name) for name in non_additive_dimension_spec.window_groupings + ) + return SemiAdditiveJoinNode.create( + parent_node=current_node, + entity_specs=window_groupings, + time_dimension_spec=time_dimension_spec, + agg_by_function=non_additive_dimension_spec.window_choice, + queried_time_dimension_spec=queried_time_dimension_spec, + )