Skip to content

Commit

Permalink
/* PR_START p--cte 14 */ Move DataflowPlanNodeVisitor to a separate…
Browse files Browse the repository at this point in the history
… file.
  • Loading branch information
plypaul committed Nov 10, 2024
1 parent 7f0c36f commit b18d395
Show file tree
Hide file tree
Showing 28 changed files with 162 additions and 139 deletions.
115 changes: 2 additions & 113 deletions metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import typing
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import FrozenSet, Generic, Optional, Sequence, Set, Type, TypeVar
from typing import FrozenSet, Optional, Sequence, Set, Type, TypeVar

import more_itertools
from metricflow_semantics.dag.id_prefix import StaticIdPrefix
Expand All @@ -17,27 +17,7 @@
from dbt_semantic_interfaces.references import SemanticModelReference
from metricflow_semantics.specs.instance_spec import LinkableInstanceSpec

from metricflow.dataflow.nodes.add_generated_uuid import AddGeneratedUuidColumnNode
from metricflow.dataflow.nodes.aggregate_measures import AggregateMeasuresNode
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode

from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -102,97 +82,6 @@ def aggregated_to_elements(self) -> Set[LinkableInstanceSpec]:
return set()


class DataflowPlanNodeVisitor(Generic[VisitorOutputT], ABC):
"""An object that can be used to visit the nodes of a dataflow plan.
Follows the visitor pattern: https://en.wikipedia.org/wiki/Visitor_pattern
All visit* methods are similar and one exists for every type of node in the dataflow plan. The appropriate method
will be called with DataflowPlanNode.accept().
"""

@abstractmethod
def visit_source_node(self, node: ReadSqlSourceNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_order_by_limit_node(self, node: OrderByLimitNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_where_constraint_node(self, node: WhereConstraintNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_write_to_result_data_table_node(self, node: WriteToResultDataTableNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_filter_elements_node(self, node: FilterElementsNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_combine_aggregated_outputs_node(self, node: CombineAggregatedOutputsNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_metric_time_dimension_transform_node( # noqa: D102
self, node: MetricTimeDimensionTransformNode
) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_min_max_node(self, node: MinMaxNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_add_generated_uuid_column_node(self, node: AddGeneratedUuidColumnNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> VisitorOutputT: # noqa: D102
pass


class DataflowPlan(MetricFlowDag[DataflowPlanNode]):
"""Describes the flow of metric data as it goes from source nodes to sink nodes in the graph."""

Expand Down
120 changes: 120 additions & 0 deletions metricflow/dataflow/dataflow_plan_visitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from __future__ import annotations

import typing
from abc import ABC, abstractmethod
from typing import Generic

from metricflow_semantics.visitor import VisitorOutputT

if typing.TYPE_CHECKING:
from metricflow.dataflow.nodes.add_generated_uuid import AddGeneratedUuidColumnNode
from metricflow.dataflow.nodes.aggregate_measures import AggregateMeasuresNode
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode


class DataflowPlanNodeVisitor(Generic[VisitorOutputT], ABC):
"""An object that can be used to visit the nodes of a dataflow plan.
Follows the visitor pattern: https://en.wikipedia.org/wiki/Visitor_pattern
All visit* methods are similar and one exists for every type of node in the dataflow plan. The appropriate method
will be called with DataflowPlanNode.accept().
"""

@abstractmethod
def visit_source_node(self, node: ReadSqlSourceNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_order_by_limit_node(self, node: OrderByLimitNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_where_constraint_node(self, node: WhereConstraintNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_write_to_result_data_table_node(self, node: WriteToResultDataTableNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_filter_elements_node(self, node: FilterElementsNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_combine_aggregated_outputs_node(self, node: CombineAggregatedOutputsNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_metric_time_dimension_transform_node( # noqa: D102
self, node: MetricTimeDimensionTransformNode
) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_min_max_node(self, node: MinMaxNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_add_generated_uuid_column_node(self, node: AddGeneratedUuidColumnNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> VisitorOutputT: # noqa: D102
pass
3 changes: 2 additions & 1 deletion metricflow/dataflow/nodes/add_generated_uuid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor
from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
Expand Down
3 changes: 2 additions & 1 deletion metricflow/dataflow/nodes/aggregate_measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from metricflow_semantics.specs.measure_spec import MetricInputMeasureSpec
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor
from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/combine_aggregated_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from metricflow.dataflow.dataflow_plan import (
DataflowPlanNode,
DataflowPlanNodeVisitor,
)
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/compute_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

from metricflow.dataflow.dataflow_plan import (
DataflowPlanNode,
DataflowPlanNodeVisitor,
)
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/constrain_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from metricflow_semantics.filters.time_constraint import TimeRangeConstraint
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNodeVisitor
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor
from metricflow.dataflow.nodes.aggregate_measures import DataflowPlanNode


Expand Down
3 changes: 2 additions & 1 deletion metricflow/dataflow/nodes/filter_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from metricflow_semantics.specs.spec_set import InstanceSpecSet
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor
from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
Expand Down
3 changes: 2 additions & 1 deletion metricflow/dataflow/nodes/join_conversion_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor
from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
Expand Down
3 changes: 2 additions & 1 deletion metricflow/dataflow/nodes/join_over_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor
from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
Expand Down
3 changes: 2 additions & 1 deletion metricflow/dataflow/nodes/join_to_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
PartitionDimensionJoinDescription,
PartitionTimeDimensionJoinDescription,
)
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor
from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True)
Expand Down
3 changes: 2 additions & 1 deletion metricflow/dataflow/nodes/join_to_custom_granularity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor
from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
Expand Down
3 changes: 2 additions & 1 deletion metricflow/dataflow/nodes/join_to_time_spine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from metricflow_semantics.sql.sql_join_type import SqlJoinType
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor
from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
Expand Down
3 changes: 2 additions & 1 deletion metricflow/dataflow/nodes/metric_time_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor
from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
Expand Down
3 changes: 2 additions & 1 deletion metricflow/dataflow/nodes/min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor
from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/order_by_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

from metricflow.dataflow.dataflow_plan import (
DataflowPlanNode,
DataflowPlanNodeVisitor,
)
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
Expand Down
Loading

0 comments on commit b18d395

Please sign in to comment.