From e27d822e92d1722a64625d7d9da91c4be9e037b5 Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Thu, 5 Sep 2024 11:53:31 -0700 Subject: [PATCH] Stub out simple DataFlowPlan visitor methods --- metricflow/dataflow/dataflow_plan.py | 5 +++++ .../dataflow/optimizer/predicate_pushdown_optimizer.py | 7 +++++++ .../dataflow/optimizer/source_scan/cm_branch_combiner.py | 7 +++++++ .../optimizer/source_scan/source_scan_optimizer.py | 7 +++++++ metricflow/execution/dataflow_to_execution.py | 5 +++++ metricflow/plan_conversion/dataflow_to_sql.py | 4 ++++ .../optimizer/source_scan/test_source_scan_optimizer.py | 4 ++++ 7 files changed, 39 insertions(+) diff --git a/metricflow/dataflow/dataflow_plan.py b/metricflow/dataflow/dataflow_plan.py index e324fdd4bb..b986d7b8f6 100644 --- a/metricflow/dataflow/dataflow_plan.py +++ b/metricflow/dataflow/dataflow_plan.py @@ -26,6 +26,7 @@ from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode + from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode @@ -187,6 +188,10 @@ def visit_add_generated_uuid_column_node(self, node: AddGeneratedUuidColumnNode) def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> VisitorOutputT: # noqa: D102 pass + @abstractmethod + def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> VisitorOutputT: # noqa: D102 + pass + class DataflowPlan(MetricFlowDag[DataflowPlanNode]): """Describes the flow of metric data as it goes from source nodes to sink nodes in the graph.""" diff --git a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py index 93ae2e1958..4fcf4666e7 100644 --- a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py +++ b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py @@ -25,6 +25,7 @@ from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode +from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode @@ -458,6 +459,12 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> O ) ) + def visit_join_to_custom_granularity_node( # noqa: D102 + self, node: JoinToCustomGranularityNode + ) -> OptimizeBranchResult: + self._log_visit_node_type(node) + return self._default_handler(node) + def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranchResult: """Handles pushdown state propagation for the standard join node type. diff --git a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py index 1882898860..4e786e8f48 100644 --- a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py +++ b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py @@ -19,6 +19,7 @@ from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode +from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode @@ -450,6 +451,12 @@ def visit_join_conversion_events_node( # noqa: D102 self._log_visit_node_type(node) return self._default_handler(node) + def visit_join_to_custom_granularity_node( # noqa: D102 + self, node: JoinToCustomGranularityNode + ) -> ComputeMetricsBranchCombinerResult: + self._log_visit_node_type(node) + return self._default_handler(node) + def visit_min_max_node(self, node: MinMaxNode) -> ComputeMetricsBranchCombinerResult: # noqa: D102 self._log_visit_node_type(node) return self._default_handler(node) diff --git a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py index 5fa9fc4602..016da5d0cc 100644 --- a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py +++ b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py @@ -21,6 +21,7 @@ from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode +from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode @@ -315,6 +316,12 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> O self._log_visit_node_type(node) return self._default_base_output_handler(node) + def visit_join_to_custom_granularity_node( # noqa: D102 + self, node: JoinToCustomGranularityNode + ) -> OptimizeBranchResult: + self._log_visit_node_type(node) + return self._default_base_output_handler(node) + def visit_min_max_node(self, node: MinMaxNode) -> OptimizeBranchResult: # noqa: D102 self._log_visit_node_type(node) return self._default_base_output_handler(node) diff --git a/metricflow/execution/dataflow_to_execution.py b/metricflow/execution/dataflow_to_execution.py index 4df6cc0f21..8f2bb1de80 100644 --- a/metricflow/execution/dataflow_to_execution.py +++ b/metricflow/execution/dataflow_to_execution.py @@ -18,6 +18,7 @@ from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode +from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode @@ -189,3 +190,7 @@ def visit_add_generated_uuid_column_node(self, node: AddGeneratedUuidColumnNode) @override def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> ConvertToExecutionPlanResult: raise NotImplementedError + + @override + def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> ConvertToExecutionPlanResult: + raise NotImplementedError diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index 79c72fee2b..7b168c4f0c 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -55,6 +55,7 @@ from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode +from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode @@ -1435,6 +1436,9 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet ), ) + def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> SqlDataSet: # noqa: D102 + raise NotImplementedError # TODO in later commit + def visit_min_max_node(self, node: MinMaxNode) -> SqlDataSet: # noqa: D102 parent_data_set = node.parent_node.accept(self) parent_table_alias = self._next_unique_table_alias() diff --git a/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py b/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py index c8cbb851c6..974e9a1ef3 100644 --- a/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py +++ b/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py @@ -30,6 +30,7 @@ from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode from metricflow.dataflow.nodes.join_to_base import JoinOnEntitiesNode +from metricflow.dataflow.nodes.join_to_custom_granularity import JoinToCustomGranularityNode from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode @@ -110,6 +111,9 @@ def visit_add_generated_uuid_column_node(self, node: AddGeneratedUuidColumnNode) def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> int: # noqa: D102 return self._sum_parents(node) + def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNode) -> int: # noqa: D102 + return self._sum_parents(node) + def count_source_nodes(self, dataflow_plan: DataflowPlan) -> int: # noqa: D102 return dataflow_plan.sink_node.accept(self)