From 0c12518ee661f4b5d2ac56e29a1da4286c54a1ad Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Wed, 18 Sep 2024 14:21:05 -0700 Subject: [PATCH] DataflowPlan for custom granularities --- .../dataflow/builder/dataflow_plan_builder.py | 39 ++++++++++++++++++- metricflow/dataflow/builder/node_evaluator.py | 9 ++++- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index c42bffb2bc..d98bcd4eae 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -77,6 +77,7 @@ from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinOnEntitiesNode +from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.min_max import MinMaxNode from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode @@ -795,6 +796,15 @@ def _build_plan_for_distinct_values( if dataflow_recipe.join_targets: output_node = JoinOnEntitiesNode.create(left_node=output_node, join_targets=dataflow_recipe.join_targets) + for time_dimension_spec in required_linkable_specs.time_dimension_specs: + if time_dimension_spec.time_granularity.is_custom_granularity: + include_base_grain = time_dimension_spec.with_base_grain in required_linkable_specs.time_dimension_specs + output_node = JoinToCustomGranularityNode.create( + parent_node=output_node, + time_dimension_spec=time_dimension_spec, + include_base_grain=include_base_grain, + ) + if len(query_level_filter_specs) > 0: output_node = WhereConstraintNode.create(parent_node=output_node, where_specs=query_level_filter_specs) if query_spec.time_range_constraint: @@ -885,11 +895,25 @@ def _select_source_nodes_with_linkable_specs( """Find source nodes with requested linkable specs and no measures.""" # Use a dictionary to dedupe for consistent ordering. selected_nodes: Dict[DataflowPlanNode, None] = {} - requested_linkable_specs_set = set(linkable_specs.as_tuple) + + # Find the source node that will satisfy the base granularity. Custom granularities will be joined in later. + linkable_specs_set_with_base_granularities: Set[LinkableInstanceSpec] = set() + # TODO: Add support for no-metrics queries for custom grains without a join (i.e., select directly from time spine). + for linkable_spec in linkable_specs.as_tuple: + if isinstance(linkable_spec, TimeDimensionSpec) and linkable_spec.time_granularity.is_custom_granularity: + linkable_spec_with_base_grain = linkable_spec.with_grain( + ExpandedTimeGranularity.from_time_granularity(linkable_spec.time_granularity.base_granularity) + ) + linkable_specs_set_with_base_granularities.add(linkable_spec_with_base_grain) + else: + linkable_specs_set_with_base_granularities.add(linkable_spec) + for source_node in source_nodes: output_spec_set = self._node_data_set_resolver.get_output_data_set(source_node).instance_set.spec_set all_linkable_specs_in_node = set(output_spec_set.linkable_specs) - requested_linkable_specs_in_node = requested_linkable_specs_set.intersection(all_linkable_specs_in_node) + requested_linkable_specs_in_node = linkable_specs_set_with_base_granularities.intersection( + all_linkable_specs_in_node + ) if requested_linkable_specs_in_node: selected_nodes[source_node] = None @@ -1020,10 +1044,12 @@ def _find_dataflow_recipe( metric_time_dimension_reference=self._metric_time_dimension_reference, time_spine_nodes=self._source_node_set.time_spine_nodes_tuple, ) + logger.info( f"After removing unnecessary nodes, there are {len(candidate_nodes_for_right_side_of_join)} candidate " f"nodes for the right side of the join" ) + # TODO: test multi-hop with custom grains if DataflowPlanBuilder._contains_multihop_linkables(linkable_specs): candidate_nodes_for_right_side_of_join = list( node_processor.add_multi_hop_joins( @@ -1544,6 +1570,15 @@ def _build_aggregated_measure_from_measure_source_node( else: unaggregated_measure_node = filtered_measure_source_node + for time_dimension_spec in queried_linkable_specs.time_dimension_specs: + if time_dimension_spec.time_granularity.is_custom_granularity: + include_base_grain = time_dimension_spec.with_base_grain in required_linkable_specs.time_dimension_specs + unaggregated_measure_node = JoinToCustomGranularityNode.create( + parent_node=unaggregated_measure_node, + time_dimension_spec=time_dimension_spec, + include_base_grain=include_base_grain, + ) + # If time constraint was previously adjusted for cumulative window or grain, apply original time constraint # here. Can skip if metric is being aggregated over all time. cumulative_metric_constrained_node: Optional[ConstrainTimeRangeNode] = None diff --git a/metricflow/dataflow/builder/node_evaluator.py b/metricflow/dataflow/builder/node_evaluator.py index b5a11fb2e8..442952922b 100644 --- a/metricflow/dataflow/builder/node_evaluator.py +++ b/metricflow/dataflow/builder/node_evaluator.py @@ -29,6 +29,7 @@ from metricflow_semantics.specs.entity_spec import LinklessEntitySpec from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec from metricflow_semantics.specs.spec_set import group_specs_by_type +from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec from metricflow_semantics.sql.sql_join_type import SqlJoinType from metricflow.dataflow.builder.node_data_set import DataflowPlanNodeOutputDataSetResolver @@ -406,6 +407,10 @@ def evaluate_node( logger.debug(f"Candidate spec set is:\n{mf_pformat(candidate_spec_set)}") data_set_linkable_specs = candidate_spec_set.linkable_specs + # Look for which nodes can satisfy the linkable specs at their base grains. Custom grains will be joined later. + required_linkable_specs_with_base_grains = [ + spec.with_base_grain if isinstance(spec, TimeDimensionSpec) else spec for spec in required_linkable_specs + ] # These are linkable specs in the start node data set. Those are considered "local". local_linkable_specs: List[LinkableInstanceSpec] = [] @@ -415,10 +420,10 @@ def evaluate_node( # Group required_linkable_specs into local / un-joinable / or possibly joinable. unjoinable_linkable_specs = [] - for required_linkable_spec in required_linkable_specs: + for required_linkable_spec in required_linkable_specs_with_base_grains: is_metric_time = required_linkable_spec.element_name == DataSet.metric_time_dimension_name() is_local = required_linkable_spec in data_set_linkable_specs - is_unjoinable = not is_metric_time and ( + is_unjoinable = (not is_metric_time) and ( len(required_linkable_spec.entity_links) == 0 or LinklessEntitySpec.from_reference(required_linkable_spec.entity_links[0]) not in data_set_linkable_specs