Remove Dataflow Plan Node Types (#1205)
### Description

The original class hierarchy for `DataflowPlanNode`s included subclasses
that described the data output by a node. However, those distinctions
turned out not to be useful in practice (e.g. `BaseOutput` covered the
majority of use cases), so this PR removes them.
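
For illustration, a minimal sketch of the shape of the change (simplified; the
real classes are abstract, carry node-specific fields, and implement visitor
plumbing):

```python
# Sketch only: simplified from metricflow/dataflow/dataflow_plan.py.

# Before this PR: marker subclasses categorized nodes by their output.
class DataflowPlanNode: ...                # common base for all plan nodes
class BaseOutput(DataflowPlanNode): ...    # node whose output feeds other nodes
class SinkOutput(DataflowPlanNode): ...    # terminal node, e.g. writes results

# After this PR: the marker subclasses are gone, and signatures use the base
# class directly, matching the updated code in the diff below, e.g.
#   def build_sink_node(parent_node: DataflowPlanNode, ...) -> DataflowPlanNode: ...
```

As the diffs below show, the change is largely mechanical: annotations that
referenced `BaseOutput` or `SinkOutput` now use `DataflowPlanNode`,
`JoinToBaseOutputNode` is renamed to `JoinOnEntitiesNode`, and `DataflowPlan`
takes `sink_nodes` rather than `sink_output_nodes`.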


plypaul authored May 16, 2024
1 parent 03d7650 commit 8ba3897
Showing 63 changed files with 466 additions and 502 deletions.
73 changes: 36 additions & 37 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -61,9 +61,8 @@
)
from metricflow.dataflow.builder.source_node import SourceNodeBuilder, SourceNodeSet
from metricflow.dataflow.dataflow_plan import (
BaseOutput,
DataflowPlan,
SinkOutput,
DataflowPlanNode,
)
from metricflow.dataflow.nodes.add_generated_uuid import AddGeneratedUuidColumnNode
from metricflow.dataflow.nodes.aggregate_measures import AggregateMeasuresNode
@@ -73,7 +72,7 @@
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinToBaseOutputNode
from metricflow.dataflow.nodes.join_to_base import JoinDescription, JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
@@ -93,7 +92,7 @@
class DataflowRecipe:
"""Get a recipe for how to build a dataflow plan node that outputs measures and linkable instances as needed."""

source_node: BaseOutput
source_node: DataflowPlanNode
required_local_linkable_specs: Tuple[LinkableInstanceSpec, ...]
join_linkable_instances_recipes: Tuple[JoinLinkableInstancesRecipe, ...]

@@ -151,7 +150,7 @@ def build_plan(

def _build_query_output_node(
self, query_spec: MetricFlowQuerySpec, for_group_by_source_node: bool = False
) -> BaseOutput:
) -> DataflowPlanNode:
"""Build SQL output node from query inputs. May be used to build query DFP or source node."""
for metric_spec in query_spec.metric_specs:
if (
@@ -211,7 +210,7 @@ def _build_plan(
)

plan_id = DagId.from_id_prefix(StaticIdPrefix.DATAFLOW_PLAN_PREFIX)
plan = DataflowPlan(sink_output_nodes=[sink_node], plan_id=plan_id)
plan = DataflowPlan(sink_nodes=[sink_node], plan_id=plan_id)
for optimizer in optimizers:
logger.info(f"Applying {optimizer.__class__.__name__}")
try:
@@ -234,7 +233,7 @@ def _build_aggregated_conversion_node(
queried_linkable_specs: LinkableSpecSet,
time_range_constraint: Optional[TimeRangeConstraint] = None,
constant_properties: Optional[Sequence[ConstantPropertyInput]] = None,
) -> BaseOutput:
) -> DataflowPlanNode:
"""Builds a node that contains aggregated values of conversions and opportunities."""
# Build measure recipes
base_required_linkable_specs, _ = self.__get_required_and_extraneous_linkable_specs(
@@ -302,7 +301,7 @@ def _build_aggregated_conversion_node(
# Build the unaggregated base measure node for computing conversions
unaggregated_base_measure_node = base_measure_recipe.source_node
if base_measure_recipe.join_targets:
unaggregated_base_measure_node = JoinToBaseOutputNode(
unaggregated_base_measure_node = JoinOnEntitiesNode(
left_node=unaggregated_base_measure_node, join_targets=base_measure_recipe.join_targets
)
filtered_unaggregated_base_node = FilterElementsNode(
@@ -455,7 +454,7 @@ def _build_derived_metric_output_node(
filter_spec_factory: WhereSpecFactory,
time_range_constraint: Optional[TimeRangeConstraint] = None,
for_group_by_source_node: bool = False,
) -> BaseOutput:
) -> DataflowPlanNode:
"""Builds a node to compute a metric defined from other metrics."""
metric = self._metric_lookup.get_metric(metric_spec.reference)
metric_input_specs = self._build_input_metric_specs_for_derived_metric(
@@ -470,7 +469,7 @@
queried_linkable_specs=queried_linkable_specs, filter_specs=metric_spec.filter_specs
)

parent_nodes: List[BaseOutput] = []
parent_nodes: List[DataflowPlanNode] = []

# This is the filter that's defined for the metric in the configs.
metric_definition_filter_specs = filter_spec_factory.create_from_where_filter_intersection(
@@ -509,7 +508,7 @@
parent_node = (
parent_nodes[0] if len(parent_nodes) == 1 else CombineAggregatedOutputsNode(parent_nodes=parent_nodes)
)
output_node: BaseOutput = ComputeMetricsNode(
output_node: DataflowPlanNode = ComputeMetricsNode(
parent_node=parent_node,
metric_specs=[metric_spec],
for_group_by_source_node=for_group_by_source_node,
@@ -553,7 +552,7 @@ def _build_any_metric_output_node(
filter_spec_factory: WhereSpecFactory,
time_range_constraint: Optional[TimeRangeConstraint] = None,
for_group_by_source_node: bool = False,
) -> BaseOutput:
) -> DataflowPlanNode:
"""Builds a node to compute a metric of any type."""
metric = self._metric_lookup.get_metric(metric_spec.reference)

@@ -592,7 +591,7 @@ def _build_metrics_output_node(
filter_spec_factory: WhereSpecFactory,
time_range_constraint: Optional[TimeRangeConstraint] = None,
for_group_by_source_node: bool = False,
) -> BaseOutput:
) -> DataflowPlanNode:
"""Builds a node that computes all requested metrics.
Args:
@@ -602,7 +601,7 @@
filter_spec_factory: Constructs WhereFilterSpecs with the resolved ambiguous group-by-items in the filter.
time_range_constraint: Time range constraint used to compute the metric.
"""
output_nodes: List[BaseOutput] = []
output_nodes: List[DataflowPlanNode] = []

for metric_spec in metric_specs:
logger.info(f"Generating compute metrics node for:\n{indent(mf_pformat(metric_spec))}")
@@ -661,7 +660,7 @@ def _build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> Da

output_node = dataflow_recipe.source_node
if dataflow_recipe.join_targets:
output_node = JoinToBaseOutputNode(left_node=output_node, join_targets=dataflow_recipe.join_targets)
output_node = JoinOnEntitiesNode(left_node=output_node, join_targets=dataflow_recipe.join_targets)

if len(query_level_filter_specs) > 0:
output_node = WhereConstraintNode(
@@ -683,18 +682,18 @@ def _build_plan_for_distinct_values(self, query_spec: MetricFlowQuerySpec) -> Da
parent_node=output_node, order_by_specs=query_spec.order_by_specs, limit=query_spec.limit
)

return DataflowPlan(sink_output_nodes=[sink_node])
return DataflowPlan(sink_nodes=[sink_node])

@staticmethod
def build_sink_node(
parent_node: BaseOutput,
parent_node: DataflowPlanNode,
order_by_specs: Sequence[OrderBySpec],
output_sql_table: Optional[SqlTable] = None,
limit: Optional[int] = None,
output_selection_specs: Optional[InstanceSpecSet] = None,
) -> SinkOutput:
) -> DataflowPlanNode:
"""Adds order by / limit / write nodes."""
pre_result_node: Optional[BaseOutput] = None
pre_result_node: Optional[DataflowPlanNode] = None

if order_by_specs or limit:
pre_result_node = OrderByLimitNode(
@@ -706,7 +705,7 @@ def build_sink_node(
parent_node=pre_result_node or parent_node, include_specs=output_selection_specs
)

write_result_node: SinkOutput
write_result_node: DataflowPlanNode
if not output_sql_table:
write_result_node = WriteToResultDataframeNode(parent_node=pre_result_node or parent_node)
else:
@@ -721,21 +720,21 @@ def _contains_multihop_linkables(linkable_specs: Sequence[LinkableInstanceSpec])
"""Returns true if any of the linkable specs requires a multi-hop join to realize."""
return any(len(x.entity_links) > 1 for x in linkable_specs)

def _sort_by_suitability(self, nodes: Sequence[BaseOutput]) -> Sequence[BaseOutput]:
def _sort_by_suitability(self, nodes: Sequence[DataflowPlanNode]) -> Sequence[DataflowPlanNode]:
"""Sort nodes by the number of linkable specs.
The lower the number of linkable specs means less aggregation required.
"""

def sort_function(node: BaseOutput) -> int:
def sort_function(node: DataflowPlanNode) -> int:
data_set = self._node_data_set_resolver.get_output_data_set(node)
return len(data_set.instance_set.spec_set.linkable_specs)

return sorted(nodes, key=sort_function)

def _select_source_nodes_with_measures(
self, measure_specs: Set[MeasureSpec], source_nodes: Sequence[BaseOutput]
) -> Sequence[BaseOutput]:
self, measure_specs: Set[MeasureSpec], source_nodes: Sequence[DataflowPlanNode]
) -> Sequence[DataflowPlanNode]:
nodes = []
measure_specs_set = set(measure_specs)
for source_node in source_nodes:
@@ -747,11 +746,11 @@ def _select_source_nodes_with_measures(
return nodes

def _select_source_nodes_with_linkable_specs(
self, linkable_specs: LinkableSpecSet, source_nodes: Sequence[BaseOutput]
) -> Sequence[BaseOutput]:
self, linkable_specs: LinkableSpecSet, source_nodes: Sequence[DataflowPlanNode]
) -> Sequence[DataflowPlanNode]:
"""Find source nodes with requested linkable specs and no measures."""
# Use a dictionary to dedupe for consistent ordering.
selected_nodes: Dict[BaseOutput, None] = {}
selected_nodes: Dict[DataflowPlanNode, None] = {}
requested_linkable_specs_set = set(linkable_specs.as_tuple)
for source_node in source_nodes:
output_spec_set = self._node_data_set_resolver.get_output_data_set(source_node).instance_set.spec_set
@@ -823,8 +822,8 @@ def _find_dataflow_recipe(
time_range_constraint: Optional[TimeRangeConstraint] = None,
) -> Optional[DataflowRecipe]:
linkable_specs = linkable_spec_set.as_tuple
candidate_nodes_for_left_side_of_join: List[BaseOutput] = []
candidate_nodes_for_right_side_of_join: List[BaseOutput] = []
candidate_nodes_for_left_side_of_join: List[DataflowPlanNode] = []
candidate_nodes_for_right_side_of_join: List[DataflowPlanNode] = []

if measure_spec_properties:
candidate_nodes_for_right_side_of_join += self._source_node_set.source_nodes_for_metric_queries
@@ -909,7 +908,7 @@ def _find_dataflow_recipe(
)

# Dict from the node that contains the source node to the evaluation results.
node_to_evaluation: Dict[BaseOutput, LinkableInstanceSatisfiabilityEvaluation] = {}
node_to_evaluation: Dict[DataflowPlanNode, LinkableInstanceSatisfiabilityEvaluation] = {}

for node in self._sort_by_suitability(candidate_nodes_for_left_side_of_join):
data_set = self._node_data_set_resolver.get_output_data_set(node)
@@ -1009,7 +1008,7 @@ def _find_dataflow_recipe(
def build_computed_metrics_node(
self,
metric_spec: MetricSpec,
aggregated_measures_node: Union[AggregateMeasuresNode, BaseOutput],
aggregated_measures_node: Union[AggregateMeasuresNode, DataflowPlanNode],
aggregated_to_elements: Set[LinkableInstanceSpec],
for_group_by_source_node: bool = False,
) -> ComputeMetricsNode:
@@ -1182,7 +1181,7 @@ def build_aggregated_measure(
queried_linkable_specs: LinkableSpecSet,
time_range_constraint: Optional[TimeRangeConstraint] = None,
measure_recipe: Optional[DataflowRecipe] = None,
) -> BaseOutput:
) -> DataflowPlanNode:
"""Returns a node where the measures are aggregated by the linkable specs and constrained appropriately.
This might be a node representing a single aggregation over one semantic model, or a node representing
@@ -1234,7 +1233,7 @@ def _build_aggregated_measure_from_measure_source_node(
queried_linkable_specs: LinkableSpecSet,
time_range_constraint: Optional[TimeRangeConstraint] = None,
measure_recipe: Optional[DataflowRecipe] = None,
) -> BaseOutput:
) -> DataflowPlanNode:
measure_spec = metric_input_measure_spec.measure_spec
cumulative = metric_input_measure_spec.cumulative_description is not None
cumulative_window = (
@@ -1359,9 +1358,9 @@ def _build_aggregated_measure_from_measure_source_node(
)

join_targets = measure_recipe.join_targets
unaggregated_measure_node: BaseOutput
unaggregated_measure_node: DataflowPlanNode
if len(join_targets) > 0:
filtered_measures_with_joined_elements = JoinToBaseOutputNode(
filtered_measures_with_joined_elements = JoinOnEntitiesNode(
left_node=filtered_measure_source_node,
join_targets=join_targets,
)
@@ -1388,7 +1387,7 @@ def _build_aggregated_measure_from_measure_source_node(
unaggregated_measure_node, time_range_constraint
)

pre_aggregate_node: BaseOutput = cumulative_metric_constrained_node or unaggregated_measure_node
pre_aggregate_node: DataflowPlanNode = cumulative_metric_constrained_node or unaggregated_measure_node
merged_where_filter_spec = WhereFilterSpec.merge_iterable(metric_input_measure_spec.filter_specs)
if len(metric_input_measure_spec.filter_specs) > 0:
# Apply where constraint on the node
@@ -1444,7 +1443,7 @@ def _build_aggregated_measure_from_measure_source_node(
f"Expected {SqlJoinType.LEFT_OUTER} for joining to time spine after aggregation. Remove this if "
f"there's a new use case."
)
output_node: BaseOutput = JoinToTimeSpineNode(
output_node: DataflowPlanNode = JoinToTimeSpineNode(
parent_node=aggregate_measures_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time,
4 changes: 2 additions & 2 deletions metricflow/dataflow/builder/node_data_set.py
@@ -41,13 +41,13 @@ class DataflowPlanNodeOutputDataSetResolver(DataflowToSqlQueryPlanConverter):
generate a set of nodes that already include the multi-hop dimensions available, the join resolution logic becomes
much simpler. For example, a node like:
<JoinToBaseOutputNode>
<JoinOnEntitiesNode>
<!-- Join dim_users and dim_devices by device_id -->
<ReadSqlSourceNode>
<!-- Read from dim_users to get user_id, device_id -->
<ReadSqlSourceNodes>
<!-- Read from dim_devices device_id, platform -->
</JoinToBaseOutputNode>
</JoinOnEntitiesNode>
would have the dimension user_id__device_id__platform available, so to NodeEvaluatorForLinkableInstances,
it's the same problem as doing a single-hop join. This simplifies the join resolution logic, though now the input
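In Python terms, the docstring's example corresponds roughly to the node
composition below (a sketch: the variable names and the join description are
hypothetical, and the real `JoinOnEntitiesNode` takes fully-specified
`JoinDescription` objects):

```python
# Hypothetical sketch of the dim_users / dim_devices example above.
users_node = ReadSqlSourceNode(dim_users_data_set)      # user_id, device_id
devices_node = ReadSqlSourceNode(dim_devices_data_set)  # device_id, platform

joined_node = JoinOnEntitiesNode(
    left_node=users_node,
    join_targets=[join_devices_on_device_id],  # assumed JoinDescription
)
# The joined node's output includes user_id__device_id__platform, so the node
# evaluator can treat the multi-hop dimension like a single-hop join.
```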
8 changes: 4 additions & 4 deletions metricflow/dataflow/builder/node_evaluator.py
@@ -39,7 +39,7 @@
PartitionJoinResolver,
PartitionTimeDimensionJoinDescription,
)
from metricflow.dataflow.dataflow_plan import BaseOutput
from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_to_base import JoinDescription, ValidityWindowJoinDescription
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
@@ -59,7 +59,7 @@ class JoinLinkableInstancesRecipe:
satisfiable_linkable_specs.
"""

node_to_join: BaseOutput
node_to_join: DataflowPlanNode
# The entity to join "node_to_join" on. Only nullable if using CROSS JOIN.
join_on_entity: Optional[LinklessEntitySpec]
# The linkable instances from the query that can be satisfied if we join this node. Note that this is different from
@@ -168,7 +167,7 @@ class NodeEvaluatorForLinkableInstances:
def __init__(
self,
semantic_model_lookup: SemanticModelLookup,
nodes_available_for_joins: Sequence[BaseOutput],
nodes_available_for_joins: Sequence[DataflowPlanNode],
node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver,
time_spine_node: MetricTimeDimensionTransformNode,
) -> None:
@@ -387,7 +387,7 @@ def _update_candidates_that_can_satisfy_linkable_specs(

def evaluate_node(
self,
left_node: BaseOutput,
left_node: DataflowPlanNode,
required_linkable_specs: Sequence[LinkableInstanceSpec],
default_join_type: SqlJoinType,
) -> LinkableInstanceSatisfiabilityEvaluation:
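For reference, a sketch of calling the evaluator after this change (the call
site and argument values are assumed; only `evaluate_node` and its parameters
come from the diff above):

```python
# Hypothetical call site: any DataflowPlanNode may now be the left side.
evaluation = node_evaluator.evaluate_node(
    left_node=candidate_source_node,  # previously annotated as BaseOutput
    required_linkable_specs=required_specs,
    default_join_type=SqlJoinType.LEFT_OUTER,
)
```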
12 changes: 6 additions & 6 deletions metricflow/dataflow/builder/source_node.py
@@ -10,7 +10,7 @@
from metricflow_semantics.specs.query_spec import MetricFlowQuerySpec
from metricflow_semantics.specs.spec_classes import GroupByMetricSpec

from metricflow.dataflow.dataflow_plan import BaseOutput
from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataset.convert_semantic_model import SemanticModelToDataSetConverter
@@ -30,17 +30,17 @@ class SourceNodeSet:
# mapped to components with a transformation node to add `metric_time` / to support multiple aggregation time
# dimensions. Each semantic model containing measures with k different aggregation time dimensions is mapped to k
# components.
source_nodes_for_metric_queries: Tuple[BaseOutput, ...]
source_nodes_for_metric_queries: Tuple[DataflowPlanNode, ...]

# Semantic models are 1:1 mapped to a ReadSqlSourceNode. The tuple also contains the same `time_spine_node` as
# below. See usage in `DataflowPlanBuilder`.
source_nodes_for_group_by_item_queries: Tuple[BaseOutput, ...]
source_nodes_for_group_by_item_queries: Tuple[DataflowPlanNode, ...]

# Provides the time spine.
time_spine_node: MetricTimeDimensionTransformNode

@property
def all_nodes(self) -> Sequence[BaseOutput]: # noqa: D102
def all_nodes(self) -> Sequence[DataflowPlanNode]: # noqa: D102
return (
self.source_nodes_for_metric_queries + self.source_nodes_for_group_by_item_queries + (self.time_spine_node,)
)
@@ -67,8 +67,8 @@ def __init__( # noqa: D107

def create_from_data_sets(self, data_sets: Sequence[SemanticModelDataSet]) -> SourceNodeSet:
"""Creates a `SourceNodeSet` from SemanticModelDataSets."""
group_by_item_source_nodes: List[BaseOutput] = []
source_nodes_for_metric_queries: List[BaseOutput] = []
group_by_item_source_nodes: List[DataflowPlanNode] = []
source_nodes_for_metric_queries: List[DataflowPlanNode] = []

for data_set in data_sets:
read_node = ReadSqlSourceNode(data_set)
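And a sketch of consuming the retyped `SourceNodeSet` (assumed usage; the
actual wiring lives in `DataflowPlanBuilder`):

```python
# Hypothetical consumer: every entry is a plain DataflowPlanNode now.
source_node_set = source_node_builder.create_from_data_sets(data_sets)
for node in source_node_set.all_nodes:
    output_data_set = node_data_set_resolver.get_output_data_set(node)
```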