From 1a5b5e4cd29e40da2401946188bc96f9b4464bc7 Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Wed, 18 Dec 2024 13:25:08 -0800 Subject: [PATCH] Build DataflowPlan for custom offset window with most grains This is the dataflow plan that will be used if the custom grain is queried with any grains that aren't the same as the grain used in the offset window. --- .../model/semantics/test_metric_lookup.py | 7 +- .../dataflow/builder/dataflow_plan_builder.py | 121 ++++++++++++++---- 2 files changed, 94 insertions(+), 34 deletions(-) diff --git a/metricflow-semantics/tests_metricflow_semantics/model/semantics/test_metric_lookup.py b/metricflow-semantics/tests_metricflow_semantics/model/semantics/test_metric_lookup.py index d9942eeb4..b69c82d62 100644 --- a/metricflow-semantics/tests_metricflow_semantics/model/semantics/test_metric_lookup.py +++ b/metricflow-semantics/tests_metricflow_semantics/model/semantics/test_metric_lookup.py @@ -27,12 +27,7 @@ def test_min_queryable_time_granularity_for_different_agg_time_grains( # noqa: def test_custom_offset_window_for_metric( simple_semantic_manifest_lookup: SemanticManifestLookup, ) -> None: - """Test offset window with custom grain supplied. - - TODO: As of now, the functionality of an offset window with a custom grain is not supported in MF. - This test is added to show that at least the parsing is successful using a custom grain offset window. - Once support for that is added in MF + relevant tests, this test can be removed. - """ + """Test offset window with custom grain supplied.""" metric = simple_semantic_manifest_lookup.metric_lookup.get_metric(MetricReference("bookings_offset_martian_day")) assert len(metric.input_metrics) == 1 diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index e7456d777..f66667cbf 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -54,6 +54,7 @@ from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec from metricflow_semantics.specs.where_filter.where_filter_spec_set import WhereFilterSpecSet from metricflow_semantics.specs.where_filter.where_filter_transform import WhereSpecFactory +from metricflow_semantics.sql.sql_exprs import SqlWindowFunction from metricflow_semantics.sql.sql_join_type import SqlJoinType from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.time.dateutil_adjuster import DateutilTimePeriodAdjuster @@ -84,6 +85,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -92,6 +94,7 @@ from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode +from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode @@ -658,13 +661,22 @@ def _build_derived_metric_output_node( ) if metric_spec.has_time_offset and queried_agg_time_dimension_specs: # TODO: move this to a helper method - time_spine_node = self._build_time_spine_node(queried_agg_time_dimension_specs) + time_spine_node = self._build_time_spine_node( + queried_time_spine_specs=queried_agg_time_dimension_specs, + offset_window=metric_spec.offset_window, + ) output_node = JoinToTimeSpineNode.create( metric_source_node=output_node, time_spine_node=time_spine_node, requested_agg_time_dimension_specs=queried_agg_time_dimension_specs, join_on_time_dimension_spec=self._sort_by_base_granularity(queried_agg_time_dimension_specs)[0], - offset_window=metric_spec.offset_window, + offset_window=( + metric_spec.offset_window + if metric_spec.offset_window + and metric_spec.offset_window.granularity + not in self._semantic_model_lookup.custom_granularity_names + else None + ), offset_to_grain=metric_spec.offset_to_grain, join_type=SqlJoinType.INNER, ) @@ -1651,13 +1663,22 @@ def _build_aggregated_measure_from_measure_source_node( required_time_spine_specs = base_queried_agg_time_dimension_specs if join_on_time_dimension_spec not in required_time_spine_specs: required_time_spine_specs = (join_on_time_dimension_spec,) + required_time_spine_specs - time_spine_node = self._build_time_spine_node(required_time_spine_specs) + time_spine_node = self._build_time_spine_node( + queried_time_spine_specs=required_time_spine_specs, + offset_window=before_aggregation_time_spine_join_description.offset_window, + ) unaggregated_measure_node = JoinToTimeSpineNode.create( metric_source_node=unaggregated_measure_node, time_spine_node=time_spine_node, requested_agg_time_dimension_specs=base_queried_agg_time_dimension_specs, join_on_time_dimension_spec=join_on_time_dimension_spec, - offset_window=before_aggregation_time_spine_join_description.offset_window, + offset_window=( + before_aggregation_time_spine_join_description.offset_window + if before_aggregation_time_spine_join_description.offset_window + and before_aggregation_time_spine_join_description.offset_window.granularity + not in self._semantic_model_lookup.custom_granularity_names + else None + ), offset_to_grain=before_aggregation_time_spine_join_description.offset_to_grain, join_type=before_aggregation_time_spine_join_description.join_type, ) @@ -1864,6 +1885,7 @@ def _build_time_spine_node( queried_time_spine_specs: Sequence[TimeDimensionSpec], where_filter_specs: Sequence[WhereFilterSpec] = (), time_range_constraint: Optional[TimeRangeConstraint] = None, + offset_window: Optional[MetricTimeWindow] = None, ) -> DataflowPlanNode: """Return the time spine node needed to satisfy the specs.""" required_time_spine_spec_set = self.__get_required_linkable_specs( @@ -1872,30 +1894,35 @@ def _build_time_spine_node( ) required_time_spine_specs = required_time_spine_spec_set.time_dimension_specs - # TODO: support multiple time spines here. Build node on the one with the smallest base grain. - # Then, pass custom_granularity_specs into _build_pre_aggregation_plan if they aren't satisfied by smallest time spine. - time_spine_source = self._choose_time_spine_source(required_time_spine_specs) - read_node = self._choose_time_spine_read_node(time_spine_source) - time_spine_data_set = self._node_data_set_resolver.get_output_data_set(read_node) - - # Change the column aliases to match the specs that were requested in the query. - time_spine_node = AliasSpecsNode.create( - parent_node=read_node, - change_specs=tuple( - SpecToAlias( - input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part( - time_granularity_name=required_spec.time_granularity.name, date_part=required_spec.date_part - ).spec, - output_spec=required_spec, - ) - for required_spec in required_time_spine_specs - ), - ) - - # If the base grain of the time spine isn't selected, it will have duplicate rows that need deduping. - should_dedupe = ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity) not in { - spec.time_granularity for spec in queried_time_spine_specs - } + should_dedupe = False + if offset_window and offset_window.granularity in self._semantic_model_lookup._custom_granularities: + time_spine_node = self._build_custom_offset_time_spine_node( + offset_window=offset_window, required_time_spine_specs=required_time_spine_specs + ) + else: + # For simpler time spine queries, choose the appropriate time spine node and apply requested aliases. + time_spine_source = self._choose_time_spine_source(required_time_spine_specs) + # TODO: support multiple time spines here. Build node on the one with the smallest base grain. + # Then, pass custom_granularity_specs into _build_pre_aggregation_plan if they aren't satisfied by smallest time spine. + read_node = self._choose_time_spine_read_node(time_spine_source) + time_spine_data_set = self._node_data_set_resolver.get_output_data_set(read_node) + # Change the column aliases to match the specs that were requested in the query. + time_spine_node = AliasSpecsNode.create( + parent_node=read_node, + change_specs=tuple( + SpecToAlias( + input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part( + time_granularity_name=required_spec.time_granularity.name, date_part=required_spec.date_part + ).spec, + output_spec=required_spec, + ) + for required_spec in required_time_spine_specs + ), + ) + # If the base grain of the time spine isn't selected, it will have duplicate rows that need deduping. + should_dedupe = ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity) not in { + spec.time_granularity for spec in queried_time_spine_specs + } return self._build_pre_aggregation_plan( source_node=time_spine_node, @@ -1905,6 +1932,44 @@ def _build_time_spine_node( distinct=should_dedupe, ) + def _build_custom_offset_time_spine_node( + self, offset_window: MetricTimeWindow, required_time_spine_specs: Tuple[TimeDimensionSpec, ...] + ) -> DataflowPlanNode: + # Build time spine node that offsets agg time dimensions by a custom grain. + custom_grain = self._semantic_model_lookup._custom_granularities[offset_window.granularity] + time_spine_source = self._choose_time_spine_source((DataSet.metric_time_dimension_spec(custom_grain),)) + time_spine_read_node = self._choose_time_spine_read_node(time_spine_source) + if {spec.time_granularity for spec in required_time_spine_specs} == {custom_grain}: + # If querying with only the same grain as is used in the offset_window, can use a simpler plan. + raise NotImplementedError + else: + # For custom offset windows queried with other granularities, first, build CustomGranularityBoundsNode. + # This will be used twice in the output node, and ideally will be turned into a CTE. + bounds_node = CustomGranularityBoundsNode.create( + parent_node=time_spine_read_node, custom_granularity_name=custom_grain.name + ) + # Build a FilterElementsNode from bounds node to get required unique rows. + bounds_data_set = self._node_data_set_resolver.get_output_data_set(bounds_node) + bounds_specs = tuple( + bounds_data_set.instance_from_window_function(window_func).spec + for window_func in (SqlWindowFunction.FIRST_VALUE, SqlWindowFunction.LAST_VALUE) + ) + custom_grain_spec = bounds_data_set.instance_from_time_dimension_grain_and_date_part( + time_granularity_name=custom_grain.name, date_part=None + ).spec + filter_elements_node = FilterElementsNode.create( + parent_node=bounds_node, + include_specs=InstanceSpecSet(time_dimension_specs=(custom_grain_spec,) + bounds_specs), + distinct=True, + ) + # Pass both the CustomGranularityBoundsNode and the FilterElementsNode into the OffsetByCustomGranularityNode. + return OffsetByCustomGranularityNode.create( + custom_granularity_bounds_node=bounds_node, + filter_elements_node=filter_elements_node, + offset_window=offset_window, + required_time_spine_specs=required_time_spine_specs, + ) + def _sort_by_base_granularity(self, time_dimension_specs: Sequence[TimeDimensionSpec]) -> List[TimeDimensionSpec]: """Sort the time dimensions by their base granularity.