Skip to content

Commit

Permalink
Visitor methods for WindowReaggregation node - primarily SQL rendering logic
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Jun 13, 2024
1 parent 53dc266 commit efde7f5
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 10 deletions.
5 changes: 5 additions & 0 deletions metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode

Expand Down Expand Up @@ -137,6 +138,10 @@ def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> VisitorO
def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> VisitorOutputT: # noqa: D102
pass

@abstractmethod
def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> VisitorOutputT:  # noqa: D102
    # Visit hook for the node that re-aggregates a computed metric via a SQL window
    # function (see the SQL-rendering visitor) — concrete visitors must implement.
    pass

@abstractmethod
def visit_order_by_limit_node(self, node: OrderByLimitNode) -> VisitorOutputT:  # noqa: D102
    # Visit hook for the order-by/limit node — presumably applies ORDER BY / LIMIT to its
    # parent's output; concrete visitors must implement.
    pass
Expand Down
25 changes: 15 additions & 10 deletions metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode
from metricflow.dataflow.optimizer.source_scan.matching_linkable_specs import MatchingLinkableSpecsTransform
Expand Down Expand Up @@ -327,13 +328,19 @@ def _handle_unsupported_node(self, current_right_node: DataflowPlanNode) -> Comp
)
return ComputeMetricsBranchCombinerResult()

def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> ComputeMetricsBranchCombinerResult:
    """Window reaggregation nodes are not combinable across branches; report unsupported."""
    self._log_visit_node_type(node)
    return self._handle_unsupported_node(node)

def visit_order_by_limit_node(
    self, node: OrderByLimitNode
) -> ComputeMetricsBranchCombinerResult:
    """Order-by/limit nodes are not combinable across branches; report unsupported."""
    self._log_visit_node_type(node)
    return self._handle_unsupported_node(node)

def visit_where_constraint_node(  # noqa: D102
    self, node: WhereConstraintNode
) -> ComputeMetricsBranchCombinerResult:
    # Where-filter nodes go through the default combination handler.
    self._log_visit_node_type(node)
    return self._default_handler(node)

Expand All @@ -345,13 +352,11 @@ def visit_write_to_result_data_table_node( # noqa: D102

def visit_write_to_result_table_node(  # noqa: D102
    self, node: WriteToResultTableNode
) -> ComputeMetricsBranchCombinerResult:
    # Combining branches at a write node is not supported.
    self._log_visit_node_type(node)
    return self._handle_unsupported_node(node)

def visit_filter_elements_node( # noqa: D102
self, node: FilterElementsNode
) -> ComputeMetricsBranchCombinerResult: # noqa: D102
def visit_filter_elements_node(self, node: FilterElementsNode) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)

current_right_node = node
Expand Down Expand Up @@ -403,19 +408,19 @@ def visit_combine_aggregated_outputs_node( # noqa: D102

def visit_constrain_time_range_node(  # noqa: D102
    self, node: ConstrainTimeRangeNode
) -> ComputeMetricsBranchCombinerResult:
    # Time-range constraint nodes go through the default combination handler.
    self._log_visit_node_type(node)
    return self._default_handler(node)

def visit_join_over_time_range_node(  # noqa: D102
    self, node: JoinOverTimeRangeNode
) -> ComputeMetricsBranchCombinerResult:
    # Join-over-time-range nodes go through the default combination handler.
    self._log_visit_node_type(node)
    return self._default_handler(node)

def visit_semi_additive_join_node(  # noqa: D102
    self, node: SemiAdditiveJoinNode
) -> ComputeMetricsBranchCombinerResult:
    # Semi-additive join nodes go through the default combination handler.
    self._log_visit_node_type(node)
    return self._default_handler(node)

Expand All @@ -427,7 +432,7 @@ def visit_metric_time_dimension_transform_node( # noqa: D102

def visit_join_to_time_spine_node(  # noqa: D102
    self, node: JoinToTimeSpineNode
) -> ComputeMetricsBranchCombinerResult:
    # Join-to-time-spine nodes go through the default combination handler.
    self._log_visit_node_type(node)
    return self._default_handler(node)

Expand All @@ -439,7 +444,7 @@ def visit_add_generated_uuid_column_node( # noqa: D102

def visit_join_conversion_events_node(  # noqa: D102
    self, node: JoinConversionEventsNode
) -> ComputeMetricsBranchCombinerResult:
    # Conversion-event join nodes go through the default combination handler.
    self._log_visit_node_type(node)
    return self._default_handler(node)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode
from metricflow.dataflow.optimizer.dataflow_plan_optimizer import DataflowPlanOptimizer
Expand Down Expand Up @@ -137,6 +138,10 @@ def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> Optimize
self._log_visit_node_type(node)
return self._default_base_output_handler(node)

def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> OptimizeBranchResult:
    """Window reaggregation nodes use the default base-output optimization handling."""
    self._log_visit_node_type(node)
    return self._default_base_output_handler(node)

def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
# Run the optimizer on the parent branch to handle derived metrics, which are defined recursively in the DAG.
Expand Down
5 changes: 5 additions & 0 deletions metricflow/execution/dataflow_to_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode
from metricflow.execution.convert_to_execution_plan import ConvertToExecutionPlanResult
Expand Down Expand Up @@ -133,6 +134,10 @@ def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> ConvertT
def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> ConvertToExecutionPlanResult:
raise NotImplementedError

@override
def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> ConvertToExecutionPlanResult:
    # Conversion of a window-reaggregation node to an execution plan is not implemented here.
    raise NotImplementedError

@override
def visit_order_by_limit_node(self, node: OrderByLimitNode) -> ConvertToExecutionPlanResult:
    # Conversion of an order-by/limit node to an execution plan is not implemented here.
    raise NotImplementedError
Expand Down
103 changes: 103 additions & 0 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode
from metricflow.dataset.dataset_classes import DataSet
Expand Down Expand Up @@ -1610,3 +1611,105 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S
from_source_alias=output_data_set_alias,
),
)

def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> SqlDataSet:
    """Render SQL for a WindowReaggregationNode.

    Re-aggregates an already-computed metric with a SQL window function, partitioned by one
    time dimension and (when the function requires ordering) ordered by another. Since a
    window function can't appear in a GROUP BY, the window expression is computed in a
    subquery and the outer query groups by every selected column to dedupe the rows.

    Bug fix vs. the original: `output_instance_set` was only assigned when the order-by and
    partition-by instances differed, so the equal case raised UnboundLocalError. An explicit
    else branch now keeps the full parent instance set on that path.
    """
    from_data_set = node.parent_node.accept(self)
    parent_instance_set = from_data_set.instance_set
    parent_data_set_alias = self._next_unique_table_alias()

    # Find the metric instance in the parent output that matches this node's metric spec.
    metric_instance = None
    for parent_metric_instance in parent_instance_set.metric_instances:
        if parent_metric_instance.spec == node.metric_spec:
            metric_instance = parent_metric_instance
            break
    assert metric_instance, (
        "Did not receive appropriate metric instance to render SQL for WindowReaggregationNode. "
        f"Got: {parent_instance_set.metric_instances}. Expected instance matching spec: "
        f"{node.metric_spec}"
    )

    # Find the time dimension instances used for the window's ORDER BY and PARTITION BY.
    order_by_instance = None
    partition_by_instance = None
    for time_dimension_instance in parent_instance_set.time_dimension_instances:
        if time_dimension_instance.spec == node.order_by_time_dimension_spec:
            order_by_instance = time_dimension_instance
        if time_dimension_instance.spec == node.partition_by_time_dimension_spec:
            partition_by_instance = time_dimension_instance
    assert order_by_instance and partition_by_instance, (
        "Did not receive appropriate time dimension instances to render SQL for WindowReaggregationNode. "
        f"Got: {parent_instance_set.time_dimension_instances}. Expected instances matching specs: "
        f"{[node.order_by_time_dimension_spec, node.partition_by_time_dimension_spec]}"
    )

    # Pending DSI upgrade:
    # sql_window_function = SqlWindowFunction[
    #     self._metric_lookup.get_metric(
    #         metric_instance.spec.reference
    #     ).type_params.cumulative_type_params.period_agg.name
    # ]
    sql_window_function = SqlWindowFunction.FIRST_VALUE  # placeholder for now
    order_by_args = []
    if sql_window_function.requires_ordering:
        order_by_args.append(
            SqlWindowOrderByArgument(
                expr=SqlColumnReferenceExpression.from_table_and_column_names(
                    table_alias=parent_data_set_alias,
                    column_name=order_by_instance.associated_column.column_name,
                ),
            )
        )
    # Window expression aliased to the metric's column name so it replaces the metric column.
    metric_select_column = SqlSelectColumn(
        expr=SqlWindowFunctionExpression(
            sql_function=sql_window_function,
            sql_function_args=[
                SqlColumnReferenceExpression.from_table_and_column_names(
                    table_alias=parent_data_set_alias, column_name=metric_instance.associated_column.column_name
                )
            ],
            partition_by_args=[
                SqlColumnReferenceExpression.from_table_and_column_names(
                    table_alias=parent_data_set_alias,
                    column_name=partition_by_instance.associated_column.column_name,
                )
            ],
            order_by_args=order_by_args,
        ),
        column_alias=metric_instance.associated_column.column_name,
    )

    # Order by instance should not be included in the output dataset unless it was also included in the request,
    # in which case it is the partition by instance.
    if order_by_instance != partition_by_instance:
        output_instance_set = parent_instance_set.transform(
            FilterElements(exclude_specs=InstanceSpecSet(time_dimension_specs=(order_by_instance.spec,)))
        )
    else:
        # Order-by and partition-by are the same instance: keep the full parent instance set.
        # (Previously this path left `output_instance_set` unassigned -> UnboundLocalError.)
        output_instance_set = parent_instance_set

    # Can't include window function in a group by, so we use a subquery and apply group by in the outer query.
    subquery_select_columns = output_instance_set.transform(
        FilterElements(exclude_specs=InstanceSpecSet(metric_specs=(metric_instance.spec,)))
    ).transform(
        CreateSelectColumnsForInstances(parent_data_set_alias, self._column_association_resolver)
    ).as_tuple() + (
        metric_select_column,
    )
    subquery = SqlSelectStatementNode(
        description="",  # description included in outer query
        select_columns=subquery_select_columns,
        from_source=from_data_set.checked_sql_select_node,
        from_source_alias=parent_data_set_alias,
    )
    subquery_alias = self._next_unique_table_alias()

    # Outer query selects every output column from the subquery and groups by all of them
    # (including the metric column) to dedupe rows produced by the window function.
    outer_query_select_columns = output_instance_set.transform(
        CreateSelectColumnsForInstances(subquery_alias, self._column_association_resolver)
    ).as_tuple()
    return SqlDataSet(
        instance_set=output_instance_set,
        sql_select_node=SqlSelectStatementNode(
            description=node.description,
            select_columns=outer_query_select_columns,
            from_source=subquery,
            from_source_alias=subquery_alias,
            group_bys=outer_query_select_columns,
        ),
    )
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
from metricflow.dataflow.nodes.where_filter import WhereConstraintNode
from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode
from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode
from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode
from metricflow.dataflow.optimizer.source_scan.source_scan_optimizer import SourceScanOptimizer
Expand Down Expand Up @@ -66,6 +67,9 @@ def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> int: #
def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> int:
    """Sum the counts from this node's parent branches."""
    return self._sum_parents(node)

def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> int:
    """Sum the counts from this node's parent branches."""
    return self._sum_parents(node)

def visit_order_by_limit_node(self, node: OrderByLimitNode) -> int:
    """Sum the counts from this node's parent branches."""
    return self._sum_parents(node)

Expand Down

0 comments on commit efde7f5

Please sign in to comment.