Remove Dataflow Plan Node Types #1205

Merged: 8 commits, May 16, 2024
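This PR collapses the redundant abstract node types in the dataflow plan into the single `DataflowPlanNode` base class and renames the associated identifiers: `BaseOutput` and `SinkOutput` annotations become `DataflowPlanNode`, `JoinToBaseOutputNode` becomes `JoinOnEntitiesNode`, and `DataflowPlan(sink_output_nodes=...)` becomes `DataflowPlan(sink_nodes=...)`. A rough before/after sketch of the type hierarchy (simplified for illustration; the real classes carry more structure):

```python
# Before (sketch): intermediate abstract types distinguished generic
# "outputs" and "sinks" from other plan nodes.
class DataflowPlanNode: ...
class BaseOutput(DataflowPlanNode): ...  # removed by this PR
class SinkOutput(BaseOutput): ...        # removed by this PR

# After (sketch): every node, sinks included, is annotated with the one base
# class, e.g. build_sink_node() now returns DataflowPlanNode, not SinkOutput.
def build_sink_node(parent_node: DataflowPlanNode) -> DataflowPlanNode: ...
```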
73 changes: 36 additions & 37 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -61,9 +61,8 @@
)
from metricflow.dataflow.builder.source_node import SourceNodeBuilder, SourceNodeSet
from metricflow.dataflow.dataflow_plan import (
- BaseOutput,
DataflowPlan,
- SinkOutput,
+ DataflowPlanNode,
)
from metricflow.dataflow.nodes.add_generated_uuid import AddGeneratedUuidColumnNode
from metricflow.dataflow.nodes.aggregate_measures import AggregateMeasuresNode
@@ -73,7 +72,7 @@
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
- from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinToBaseOutputNode
+ from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
@@ -93,7 +92,7 @@
class DataflowRecipe:
"""Get a recipe for how to build a dataflow plan node that outputs measures and linkable instances as needed."""

- source_node: BaseOutput
+ source_node: DataflowPlanNode
required_local_linkable_specs: Tuple[LinkableInstanceSpec, ...]
join_linkable_instances_recipes: Tuple[JoinLinkableInstancesRecipe, ...]

@@ -151,7 +150,7 @@ def build_plan(

def _build_query_output_node(
self, query_spec: MetricFlowQuerySpec, for_group_by_source_node: bool = False
- ) -> BaseOutput:
+ ) -> DataflowPlanNode:
"""Build SQL output node from query inputs. May be used to build query DFP or source node."""
for metric_spec in query_spec.metric_specs:
if (
@@ -211,7 +210,7 @@ def _build_plan(
)

plan_id = DagId.from_id_prefix(StaticIdPrefix.DATAFLOW_PLAN_PREFIX)
- plan = DataflowPlan(sink_output_nodes=[sink_node], plan_id=plan_id)
+ plan = DataflowPlan(sink_nodes=[sink_node], plan_id=plan_id)
for optimizer in optimizers:
logger.info(f"Applying {optimizer.__class__.__name__}")
try:
@@ -234,7 +233,7 @@ def _build_aggregated_conversion_node(
queried_linkable_specs: LinkableSpecSet,
time_range_constraint: Optional[TimeRangeConstraint] = None,
constant_properties: Optional[Sequence[ConstantPropertyInput]] = None,
- ) -> BaseOutput:
+ ) -> DataflowPlanNode:
"""Builds a node that contains aggregated values of conversions and opportunities."""
# Build measure recipes
base_required_linkable_specs, _ = self.__get_required_and_extraneous_linkable_specs(
@@ -302,7 +301,7 @@ def _build_aggregated_conversion_node(
# Build the unaggregated base measure node for computing conversions
unaggregated_base_measure_node = base_measure_recipe.source_node
if base_measure_recipe.join_targets:
- unaggregated_base_measure_node = JoinToBaseOutputNode(
+ unaggregated_base_measure_node = JoinOnEntitiesNode(
left_node=unaggregated_base_measure_node, join_targets=base_measure_recipe.join_targets
)
filtered_unaggregated_base_node = FilterElementsNode(
@@ -455,7 +454,7 @@ def _build_derived_metric_output_node(
filter_spec_factory: WhereSpecFactory,
time_range_constraint: Optional[TimeRangeConstraint] = None,
for_group_by_source_node: bool = False,
- ) -> BaseOutput:
+ ) -> DataflowPlanNode:
"""Builds a node to compute a metric defined from other metrics."""
metric = self._metric_lookup.get_metric(metric_spec.reference)
metric_input_specs = self._build_input_metric_specs_for_derived_metric(
@@ -470,7 +469,7 @@
queried_linkable_specs=queried_linkable_specs, filter_specs=metric_spec.filter_specs
)

- parent_nodes: List[BaseOutput] = []
+ parent_nodes: List[DataflowPlanNode] = []

# This is the filter that's defined for the metric in the configs.
metric_definition_filter_specs = filter_spec_factory.create_from_where_filter_intersection(
@@ -509,7 +508,7 @@
parent_node = (
parent_nodes[0] if len(parent_nodes) == 1 else CombineAggregatedOutputsNode(parent_nodes=parent_nodes)
)
- output_node: BaseOutput = ComputeMetricsNode(
+ output_node: DataflowPlanNode = ComputeMetricsNode(
parent_node=parent_node,
metric_specs=[metric_spec],
for_group_by_source_node=for_group_by_source_node,
@@ -553,7 +552,7 @@ def _build_any_metric_output_node(
filter_spec_factory: WhereSpecFactory,
time_range_constraint: Optional[TimeRangeConstraint] = None,
for_group_by_source_node: bool = False,
- ) -> BaseOutput:
+ ) -> DataflowPlanNode:
"""Builds a node to compute a metric of any type."""
metric = self._metric_lookup.get_metric(metric_spec.reference)

@@ -592,7 +591,7 @@ def _build_metrics_output_node(
filter_spec_factory: WhereSpecFactory,
time_range_constraint: Optional[TimeRangeConstraint] = None,
for_group_by_source_node: bool = False,
- ) -> BaseOutput:
+ ) -> DataflowPlanNode:
"""Builds a node that computes all requested metrics.

Args:
@@ -602,7 +601,7 @@
filter_spec_factory: Constructs WhereFilterSpecs with the resolved ambiguous group-by-items in the filter.
time_range_constraint: Time range constraint used to compute the metric.
"""
- output_nodes: List[BaseOutput] = []
+ output_nodes: List[DataflowPlanNode] = []

for metric_spec in metric_specs:
logger.info(f"Generating compute metrics node for:\n{indent(mf_pformat(metric_spec))}")
@@ -661,7 +660,7 @@ def _build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> DataflowPlan:

output_node = dataflow_recipe.source_node
if dataflow_recipe.join_targets:
- output_node = JoinToBaseOutputNode(left_node=output_node, join_targets=dataflow_recipe.join_targets)
+ output_node = JoinOnEntitiesNode(left_node=output_node, join_targets=dataflow_recipe.join_targets)

if len(query_level_filter_specs) > 0:
output_node = WhereConstraintNode(
@@ -683,18 +682,18 @@
parent_node=output_node, order_by_specs=query_spec.order_by_specs, limit=query_spec.limit
)

- return DataflowPlan(sink_output_nodes=[sink_node])
+ return DataflowPlan(sink_nodes=[sink_node])

@staticmethod
def build_sink_node(
- parent_node: BaseOutput,
+ parent_node: DataflowPlanNode,
order_by_specs: Sequence[OrderBySpec],
output_sql_table: Optional[SqlTable] = None,
limit: Optional[int] = None,
output_selection_specs: Optional[InstanceSpecSet] = None,
- ) -> SinkOutput:
+ ) -> DataflowPlanNode:
"""Adds order by / limit / write nodes."""
- pre_result_node: Optional[BaseOutput] = None
+ pre_result_node: Optional[DataflowPlanNode] = None

if order_by_specs or limit:
pre_result_node = OrderByLimitNode(
@@ -706,7 +705,7 @@ def build_sink_node(
parent_node=pre_result_node or parent_node, include_specs=output_selection_specs
)

- write_result_node: SinkOutput
+ write_result_node: DataflowPlanNode
if not output_sql_table:
write_result_node = WriteToResultDataframeNode(parent_node=pre_result_node or parent_node)
else:
@@ -721,21 +720,21 @@ def _contains_multihop_linkables(linkable_specs: Sequence[LinkableInstanceSpec]) -> bool:
"""Returns true if any of the linkable specs requires a multi-hop join to realize."""
return any(len(x.entity_links) > 1 for x in linkable_specs)
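A concrete illustration of the `len(x.entity_links) > 1` test (spec names taken from the `node_data_set.py` docstring below; values illustrative):

```python
# "user_id__device_id__platform" reaches platform through two entity links
# (user_id, then device_id), so realizing it needs a multi-hop join; a
# single-link spec such as a hypothetical "user_id__country" resolves in one hop.
multi_hop_entity_links = ("user_id", "device_id")
single_hop_entity_links = ("user_id",)
assert len(multi_hop_entity_links) > 1
assert len(single_hop_entity_links) <= 1
```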

- def _sort_by_suitability(self, nodes: Sequence[BaseOutput]) -> Sequence[BaseOutput]:
+ def _sort_by_suitability(self, nodes: Sequence[DataflowPlanNode]) -> Sequence[DataflowPlanNode]:
"""Sort nodes by the number of linkable specs.

A lower number of linkable specs means less aggregation is required.
"""

- def sort_function(node: BaseOutput) -> int:
+ def sort_function(node: DataflowPlanNode) -> int:
data_set = self._node_data_set_resolver.get_output_data_set(node)
return len(data_set.instance_set.spec_set.linkable_specs)

return sorted(nodes, key=sort_function)
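A toy illustration of the resulting order (spec counts assumed):

```python
# Nodes exposing fewer linkable specs sort first, so candidate selection
# prefers the node that needs the least join/aggregation work.
linkable_spec_counts = {"node_a": 3, "node_b": 1, "node_c": 2}  # illustrative
ordered = sorted(linkable_spec_counts, key=lambda n: linkable_spec_counts[n])
assert ordered == ["node_b", "node_c", "node_a"]
```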

def _select_source_nodes_with_measures(
- self, measure_specs: Set[MeasureSpec], source_nodes: Sequence[BaseOutput]
- ) -> Sequence[BaseOutput]:
+ self, measure_specs: Set[MeasureSpec], source_nodes: Sequence[DataflowPlanNode]
+ ) -> Sequence[DataflowPlanNode]:
nodes = []
measure_specs_set = set(measure_specs)
for source_node in source_nodes:
@@ -747,11 +746,11 @@ def _select_source_nodes_with_measures(
return nodes

def _select_source_nodes_with_linkable_specs(
- self, linkable_specs: LinkableSpecSet, source_nodes: Sequence[BaseOutput]
- ) -> Sequence[BaseOutput]:
+ self, linkable_specs: LinkableSpecSet, source_nodes: Sequence[DataflowPlanNode]
+ ) -> Sequence[DataflowPlanNode]:
"""Find source nodes with requested linkable specs and no measures."""
# Use a dictionary to dedupe for consistent ordering.
- selected_nodes: Dict[BaseOutput, None] = {}
+ selected_nodes: Dict[DataflowPlanNode, None] = {}
requested_linkable_specs_set = set(linkable_specs.as_tuple)
for source_node in source_nodes:
output_spec_set = self._node_data_set_resolver.get_output_data_set(source_node).instance_set.spec_set
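The `Dict[DataflowPlanNode, None]` above is a standard trick: unlike a `set`, a dict dedupes while keeping iteration order deterministic. A minimal demonstration:

```python
# dict keys keep first-insertion order (guaranteed since Python 3.7) and
# collapse duplicates, unlike set(), whose iteration order is unspecified.
seen: dict = {}
for node in ["a", "b", "a", "c", "b"]:
    seen[node] = None
assert list(seen) == ["a", "b", "c"]
```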
@@ -823,8 +822,8 @@ def _find_dataflow_recipe(
time_range_constraint: Optional[TimeRangeConstraint] = None,
) -> Optional[DataflowRecipe]:
linkable_specs = linkable_spec_set.as_tuple
- candidate_nodes_for_left_side_of_join: List[BaseOutput] = []
- candidate_nodes_for_right_side_of_join: List[BaseOutput] = []
+ candidate_nodes_for_left_side_of_join: List[DataflowPlanNode] = []
+ candidate_nodes_for_right_side_of_join: List[DataflowPlanNode] = []

if measure_spec_properties:
candidate_nodes_for_right_side_of_join += self._source_node_set.source_nodes_for_metric_queries
@@ -909,7 +908,7 @@ def _find_dataflow_recipe(
)

# Dict from the node that contains the source node to the evaluation results.
- node_to_evaluation: Dict[BaseOutput, LinkableInstanceSatisfiabilityEvaluation] = {}
+ node_to_evaluation: Dict[DataflowPlanNode, LinkableInstanceSatisfiabilityEvaluation] = {}

for node in self._sort_by_suitability(candidate_nodes_for_left_side_of_join):
data_set = self._node_data_set_resolver.get_output_data_set(node)
@@ -1009,7 +1008,7 @@ def _find_dataflow_recipe(
def build_computed_metrics_node(
self,
metric_spec: MetricSpec,
- aggregated_measures_node: Union[AggregateMeasuresNode, BaseOutput],
+ aggregated_measures_node: Union[AggregateMeasuresNode, DataflowPlanNode],
aggregated_to_elements: Set[LinkableInstanceSpec],
for_group_by_source_node: bool = False,
) -> ComputeMetricsNode:
@@ -1182,7 +1181,7 @@ def build_aggregated_measure(
queried_linkable_specs: LinkableSpecSet,
time_range_constraint: Optional[TimeRangeConstraint] = None,
measure_recipe: Optional[DataflowRecipe] = None,
- ) -> BaseOutput:
+ ) -> DataflowPlanNode:
"""Returns a node where the measures are aggregated by the linkable specs and constrained appropriately.

This might be a node representing a single aggregation over one semantic model, or a node representing
@@ -1234,7 +1233,7 @@ def _build_aggregated_measure_from_measure_source_node(
queried_linkable_specs: LinkableSpecSet,
time_range_constraint: Optional[TimeRangeConstraint] = None,
measure_recipe: Optional[DataflowRecipe] = None,
- ) -> BaseOutput:
+ ) -> DataflowPlanNode:
measure_spec = metric_input_measure_spec.measure_spec
cumulative = metric_input_measure_spec.cumulative_description is not None
cumulative_window = (
@@ -1359,9 +1358,9 @@ def _build_aggregated_measure_from_measure_source_node(
)

join_targets = measure_recipe.join_targets
- unaggregated_measure_node: BaseOutput
+ unaggregated_measure_node: DataflowPlanNode
if len(join_targets) > 0:
- filtered_measures_with_joined_elements = JoinToBaseOutputNode(
+ filtered_measures_with_joined_elements = JoinOnEntitiesNode(
left_node=filtered_measure_source_node,
join_targets=join_targets,
)
@@ -1388,7 +1387,7 @@ def _build_aggregated_measure_from_measure_source_node(
unaggregated_measure_node, time_range_constraint
)

- pre_aggregate_node: BaseOutput = cumulative_metric_constrained_node or unaggregated_measure_node
+ pre_aggregate_node: DataflowPlanNode = cumulative_metric_constrained_node or unaggregated_measure_node
merged_where_filter_spec = WhereFilterSpec.merge_iterable(metric_input_measure_spec.filter_specs)
if len(metric_input_measure_spec.filter_specs) > 0:
# Apply where constraint on the node
Expand Down Expand Up @@ -1444,7 +1443,7 @@ def _build_aggregated_measure_from_measure_source_node(
f"Expected {SqlJoinType.LEFT_OUTER} for joining to time spine after aggregation. Remove this if "
f"there's a new use case."
)
- output_node: BaseOutput = JoinToTimeSpineNode(
+ output_node: DataflowPlanNode = JoinToTimeSpineNode(
parent_node=aggregate_measures_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time,
4 changes: 2 additions & 2 deletions metricflow/dataflow/builder/node_data_set.py
@@ -41,13 +41,13 @@ class DataflowPlanNodeOutputDataSetResolver(DataflowToSqlQueryPlanConverter):
generate a set of nodes that already include the multi-hop dimensions available, the join resolution logic becomes
much simpler. For example, a node like:

- <JoinToBaseOutputNode>
+ <JoinOnEntitiesNode>
<!-- Join dim_users and dim_devices by device_id -->
<ReadSqlSourceNode>
<!-- Read from dim_users to get user_id, device_id -->
<ReadSqlSourceNodes>
<!-- Read from dim_devices device_id, platform -->
- </JoinToBaseOutputNode>
+ </JoinOnEntitiesNode>

would have the dimension user_id__device_id__platform available, so to NodeEvaluatorForLinkableInstances,
it's the same problem as doing a single-hop join. This simplifies the join resolution logic, though now the input
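Rendered as node-building code, the docstring's example corresponds roughly to the sketch below (the data sets are hypothetical, and the join-target construction is elided since its exact `JoinDescription` shape isn't part of this diff):

```python
from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode

users_node = ReadSqlSourceNode(dim_users_data_set)      # user_id, device_id
devices_node = ReadSqlSourceNode(dim_devices_data_set)  # device_id, platform
joined = JoinOnEntitiesNode(
    left_node=users_node,
    join_targets=[...],  # joins devices_node on device_id (details elided)
)
# The joined output now exposes user_id__device_id__platform.
```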
8 changes: 4 additions & 4 deletions metricflow/dataflow/builder/node_evaluator.py
@@ -39,7 +39,7 @@
PartitionJoinResolver,
PartitionTimeDimensionJoinDescription,
)
- from metricflow.dataflow.dataflow_plan import BaseOutput
+ from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_to_base import JoinDescription, ValidityWindowJoinDescription
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
@@ -59,7 +59,7 @@ class JoinLinkableInstancesRecipe:
satisfiable_linkable_specs.
"""

- node_to_join: BaseOutput
+ node_to_join: DataflowPlanNode
# The entity to join "node_to_join" on. Only nullable if using CROSS JOIN.
join_on_entity: Optional[LinklessEntitySpec]
# The linkable instances from the query that can be satisfied if we join this node. Note that this is different from
@@ -168,7 +168,7 @@ class NodeEvaluatorForLinkableInstances:
def __init__(
self,
semantic_model_lookup: SemanticModelLookup,
- nodes_available_for_joins: Sequence[BaseOutput],
+ nodes_available_for_joins: Sequence[DataflowPlanNode],
node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver,
time_spine_node: MetricTimeDimensionTransformNode,
) -> None:
@@ -387,7 +387,7 @@ def _update_candidates_that_can_satisfy_linkable_specs(

def evaluate_node(
self,
- left_node: BaseOutput,
+ left_node: DataflowPlanNode,
required_linkable_specs: Sequence[LinkableInstanceSpec],
default_join_type: SqlJoinType,
) -> LinkableInstanceSatisfiabilityEvaluation:
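A hedged sketch of how the builder drives this evaluator (variable values are illustrative; compare the `node_to_evaluation` loop in `dataflow_plan_builder.py` above):

```python
# candidate_node / required_specs / node_evaluator are hypothetical values.
evaluation = node_evaluator.evaluate_node(
    left_node=candidate_node,                # a DataflowPlanNode candidate
    required_linkable_specs=required_specs,  # specs the query still needs
    default_join_type=SqlJoinType.LEFT_OUTER,
)
# The LinkableInstanceSatisfiabilityEvaluation result records which specs the
# node satisfies locally, which become satisfiable via join recipes, and which
# remain unsatisfiable (field names per the surrounding code, not this diff).
```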
12 changes: 6 additions & 6 deletions metricflow/dataflow/builder/source_node.py
@@ -10,7 +10,7 @@
from metricflow_semantics.specs.query_spec import MetricFlowQuerySpec
from metricflow_semantics.specs.spec_classes import GroupByMetricSpec

- from metricflow.dataflow.dataflow_plan import BaseOutput
+ from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataset.convert_semantic_model import SemanticModelToDataSetConverter
@@ -30,17 +30,17 @@ class SourceNodeSet:
# mapped to components with a transformation node to add `metric_time` / to support multiple aggregation time
# dimensions. Each semantic model containing measures with k different aggregation time dimensions is mapped to k
# components.
- source_nodes_for_metric_queries: Tuple[BaseOutput, ...]
+ source_nodes_for_metric_queries: Tuple[DataflowPlanNode, ...]

# Semantic models are 1:1 mapped to a ReadSqlSourceNode. The tuple also contains the same `time_spine_node` as
# below. See usage in `DataflowPlanBuilder`.
- source_nodes_for_group_by_item_queries: Tuple[BaseOutput, ...]
+ source_nodes_for_group_by_item_queries: Tuple[DataflowPlanNode, ...]

# Provides the time spine.
time_spine_node: MetricTimeDimensionTransformNode

@property
- def all_nodes(self) -> Sequence[BaseOutput]: # noqa: D102
+ def all_nodes(self) -> Sequence[DataflowPlanNode]: # noqa: D102
return (
self.source_nodes_for_metric_queries + self.source_nodes_for_group_by_item_queries + (self.time_spine_node,)
)
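A sketch of the k-way mapping described in the comments above: a semantic model whose measures aggregate over two time dimensions contributes two metric-query source nodes. The dimension names and constructor argument names are assumptions, not taken from this diff:

```python
from dbt_semantic_interfaces.references import TimeDimensionReference

# Each node wraps the same read node but maps a different aggregation time
# dimension to metric_time ("ds" and "ds_month" are illustrative names).
read_node = ReadSqlSourceNode(data_set)
source_nodes_for_metric_queries = [
    MetricTimeDimensionTransformNode(
        parent_node=read_node,
        aggregation_time_dimension_reference=TimeDimensionReference("ds"),
    ),
    MetricTimeDimensionTransformNode(
        parent_node=read_node,
        aggregation_time_dimension_reference=TimeDimensionReference("ds_month"),
    ),
]
```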
@@ -67,8 +67,8 @@ def __init__( # noqa: D107

def create_from_data_sets(self, data_sets: Sequence[SemanticModelDataSet]) -> SourceNodeSet:
"""Creates a `SourceNodeSet` from SemanticModelDataSets."""
- group_by_item_source_nodes: List[BaseOutput] = []
- source_nodes_for_metric_queries: List[BaseOutput] = []
+ group_by_item_source_nodes: List[DataflowPlanNode] = []
+ source_nodes_for_metric_queries: List[DataflowPlanNode] = []

for data_set in data_sets:
read_node = ReadSqlSourceNode(data_set)