Skip to content

Commit

Permalink
Stub out simple DataFlowPlan visitor methods
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Sep 5, 2024
1 parent 2aaf239 commit e27d822
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 0 deletions.
5 changes: 5 additions & 0 deletions metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
Expand Down Expand Up @@ -187,6 +188,10 @@ def visit_add_generated_uuid_column_node(self, node: AddGeneratedUuidColumnNode)
def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> VisitorOutputT: # noqa: D102
pass


class DataflowPlan(MetricFlowDag[DataflowPlanNode]):
"""Describes the flow of metric data as it goes from source nodes to sink nodes in the graph."""
Expand Down
7 changes: 7 additions & 0 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
Expand Down Expand Up @@ -458,6 +459,12 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> O
)
)

def visit_join_to_custom_granularity_node( # noqa: D102
self, node: JoinToCustomGranularityNode
) -> OptimizeBranchResult:
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranchResult:
"""Handles pushdown state propagation for the standard join node type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
Expand Down Expand Up @@ -450,6 +451,12 @@ def visit_join_conversion_events_node( # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_join_to_custom_granularity_node( # noqa: D102
self, node: JoinToCustomGranularityNode
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_min_max_node(self, node: MinMaxNode) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
Expand Down Expand Up @@ -315,6 +316,12 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> O
self._log_visit_node_type(node)
return self._default_base_output_handler(node)

def visit_join_to_custom_granularity_node( # noqa: D102
self, node: JoinToCustomGranularityNode
) -> OptimizeBranchResult:
self._log_visit_node_type(node)
return self._default_base_output_handler(node)

def visit_min_max_node(self, node: MinMaxNode) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_base_output_handler(node)
5 changes: 5 additions & 0 deletions metricflow/execution/dataflow_to_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
Expand Down Expand Up @@ -189,3 +190,7 @@ def visit_add_generated_uuid_column_node(self, node: AddGeneratedUuidColumnNode)
@override
def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> ConvertToExecutionPlanResult:
raise NotImplementedError

@override
def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> ConvertToExecutionPlanResult:
raise NotImplementedError
4 changes: 4 additions & 0 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
Expand Down Expand Up @@ -1435,6 +1436,9 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
),
)

def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> SqlDataSet: # noqa: D102
raise NotImplementedError # TODO in later commit

def visit_min_max_node(self, node: MinMaxNode) -> SqlDataSet: # noqa: D102
parent_data_set = node.parent_node.accept(self)
parent_table_alias = self._next_unique_table_alias()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode
from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
Expand Down Expand Up @@ -110,6 +111,9 @@ def visit_add_generated_uuid_column_node(self, node: AddGeneratedUuidColumnNode)
def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> int: # noqa: D102
return self._sum_parents(node)

def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> int: # noqa: D102
return self._sum_parents(node)

def count_source_nodes(self, dataflow_plan: DataflowPlan) -> int: # noqa: D102
return dataflow_plan.sink_node.accept(self)

Expand Down

0 comments on commit e27d822

Please sign in to comment.