diff --git a/metricflow/dataflow/dataflow_plan.py b/metricflow/dataflow/dataflow_plan.py index ecf2fc427b..8b113a5de8 100644 --- a/metricflow/dataflow/dataflow_plan.py +++ b/metricflow/dataflow/dataflow_plan.py @@ -32,6 +32,7 @@ from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode from metricflow.dataflow.nodes.where_filter import WhereConstraintNode + from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode @@ -137,6 +138,10 @@ def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> VisitorO def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> VisitorOutputT: # noqa: D102 pass + @abstractmethod + def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> VisitorOutputT: # noqa: D102 + pass + @abstractmethod def visit_order_by_limit_node(self, node: OrderByLimitNode) -> VisitorOutputT: # noqa: D102 pass diff --git a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py index 5d18f3121c..1124f00498 100644 --- a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py +++ b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py @@ -26,6 +26,7 @@ from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode from metricflow.dataflow.nodes.where_filter import WhereConstraintNode +from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode from 
metricflow.dataflow.optimizer.source_scan.matching_linkable_specs import MatchingLinkableSpecsTransform @@ -327,13 +328,19 @@ def _handle_unsupported_node(self, current_right_node: DataflowPlanNode) -> Comp ) return ComputeMetricsBranchCombinerResult() + def visit_window_reaggregation_node( # noqa: D102 + self, node: WindowReaggregationNode + ) -> ComputeMetricsBranchCombinerResult: + self._log_visit_node_type(node) + return self._handle_unsupported_node(node) + def visit_order_by_limit_node(self, node: OrderByLimitNode) -> ComputeMetricsBranchCombinerResult: # noqa: D102 self._log_visit_node_type(node) return self._handle_unsupported_node(node) def visit_where_constraint_node( # noqa: D102 self, node: WhereConstraintNode - ) -> ComputeMetricsBranchCombinerResult: # noqa: D102 + ) -> ComputeMetricsBranchCombinerResult: self._log_visit_node_type(node) return self._default_handler(node) @@ -345,13 +352,11 @@ def visit_write_to_result_data_table_node( # noqa: D102 def visit_write_to_result_table_node( # noqa: D102 self, node: WriteToResultTableNode - ) -> ComputeMetricsBranchCombinerResult: # noqa: D102 + ) -> ComputeMetricsBranchCombinerResult: self._log_visit_node_type(node) return self._handle_unsupported_node(node) - def visit_filter_elements_node( # noqa: D102 - self, node: FilterElementsNode - ) -> ComputeMetricsBranchCombinerResult: # noqa: D102 + def visit_filter_elements_node(self, node: FilterElementsNode) -> ComputeMetricsBranchCombinerResult: # noqa: D102 self._log_visit_node_type(node) current_right_node = node @@ -403,19 +408,19 @@ def visit_combine_aggregated_outputs_node( # noqa: D102 def visit_constrain_time_range_node( # noqa: D102 self, node: ConstrainTimeRangeNode - ) -> ComputeMetricsBranchCombinerResult: # noqa: D102 + ) -> ComputeMetricsBranchCombinerResult: self._log_visit_node_type(node) return self._default_handler(node) def visit_join_over_time_range_node( # noqa: D102 self, node: JoinOverTimeRangeNode - ) -> 
ComputeMetricsBranchCombinerResult: # noqa: D102 + ) -> ComputeMetricsBranchCombinerResult: self._log_visit_node_type(node) return self._default_handler(node) def visit_semi_additive_join_node( # noqa: D102 self, node: SemiAdditiveJoinNode - ) -> ComputeMetricsBranchCombinerResult: # noqa: D102 + ) -> ComputeMetricsBranchCombinerResult: self._log_visit_node_type(node) return self._default_handler(node) @@ -427,7 +432,7 @@ def visit_metric_time_dimension_transform_node( # noqa: D102 def visit_join_to_time_spine_node( # noqa: D102 self, node: JoinToTimeSpineNode - ) -> ComputeMetricsBranchCombinerResult: # noqa: D102 + ) -> ComputeMetricsBranchCombinerResult: self._log_visit_node_type(node) return self._default_handler(node) @@ -439,7 +444,7 @@ def visit_add_generated_uuid_column_node( # noqa: D102 def visit_join_conversion_events_node( # noqa: D102 self, node: JoinConversionEventsNode - ) -> ComputeMetricsBranchCombinerResult: # noqa: D102 + ) -> ComputeMetricsBranchCombinerResult: self._log_visit_node_type(node) return self._default_handler(node) diff --git a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py index 1fd708da21..e7f87aa992 100644 --- a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py +++ b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py @@ -28,6 +28,7 @@ from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode from metricflow.dataflow.nodes.where_filter import WhereConstraintNode +from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode from metricflow.dataflow.optimizer.dataflow_plan_optimizer import DataflowPlanOptimizer @@ -137,6 +138,10 @@ def 
visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> Optimize self._log_visit_node_type(node) return self._default_base_output_handler(node) + def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> OptimizeBranchResult: # noqa: D102 + self._log_visit_node_type(node) + return self._default_base_output_handler(node) + def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> OptimizeBranchResult: # noqa: D102 self._log_visit_node_type(node) # Run the optimizer on the parent branch to handle derived metrics, which are defined recursively in the DAG. diff --git a/metricflow/execution/dataflow_to_execution.py b/metricflow/execution/dataflow_to_execution.py index fa0b7ccbbf..b553590465 100644 --- a/metricflow/execution/dataflow_to_execution.py +++ b/metricflow/execution/dataflow_to_execution.py @@ -25,6 +25,7 @@ from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode from metricflow.dataflow.nodes.where_filter import WhereConstraintNode +from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode from metricflow.execution.convert_to_execution_plan import ConvertToExecutionPlanResult @@ -133,6 +134,10 @@ def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> ConvertT def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> ConvertToExecutionPlanResult: raise NotImplementedError + @override + def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> ConvertToExecutionPlanResult: + raise NotImplementedError + @override def visit_order_by_limit_node(self, node: OrderByLimitNode) -> ConvertToExecutionPlanResult: raise NotImplementedError diff --git a/metricflow/plan_conversion/dataflow_to_sql.py 
b/metricflow/plan_conversion/dataflow_to_sql.py index ac5ad7739c..8d5a4f68d4 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -60,6 +60,7 @@ from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode from metricflow.dataflow.nodes.where_filter import WhereConstraintNode +from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode from metricflow.dataset.dataset_classes import DataSet @@ -1610,3 +1611,105 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S from_source_alias=output_data_set_alias, ), ) + + def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> SqlDataSet: # noqa: D102 + from_data_set = node.parent_node.accept(self) + parent_instance_set = from_data_set.instance_set # remove order by col + parent_data_set_alias = self._next_unique_table_alias() + + metric_instance = None + for parent_metric_instance in parent_instance_set.metric_instances: + if parent_metric_instance.spec == node.metric_spec: + metric_instance = parent_metric_instance + break + assert metric_instance, ( + "Did not receive appropriate metric instance to render SQL for WindowReaggregationNode. " + f"Got: {parent_instance_set.metric_instances}. 
Expected instance matching spec: " + f"{node.metric_spec}" + ) + + order_by_instance = None + partition_by_instance = None + for time_dimension_instance in parent_instance_set.time_dimension_instances: + if time_dimension_instance.spec == node.order_by_time_dimension_spec: + order_by_instance = time_dimension_instance + if time_dimension_instance.spec == node.partition_by_time_dimension_spec: + partition_by_instance = time_dimension_instance + assert order_by_instance and partition_by_instance, ( + "Did not receive appropriate time dimension instances to render SQL for WindowReaggregationNode. " + f"Got: {parent_instance_set.time_dimension_instances}. Expected instances matching specs: " + f"{[node.order_by_time_dimension_spec, node.partition_by_time_dimension_spec]}" + ) + + # Pending DSI upgrade: + # sql_window_function = SqlWindowFunction[ + # self._metric_lookup.get_metric( + # metric_instance.spec.reference + # ).type_params.cumulative_type_params.period_agg.name + # ] + sql_window_function = SqlWindowFunction.FIRST_VALUE # placeholder for now + order_by_args = [] + if sql_window_function.requires_ordering: + order_by_args.append( + SqlWindowOrderByArgument( + expr=SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=parent_data_set_alias, + column_name=order_by_instance.associated_column.column_name, + ), + ) + ) + metric_select_column = SqlSelectColumn( + expr=SqlWindowFunctionExpression( + sql_function=sql_window_function, + sql_function_args=[ + SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=parent_data_set_alias, column_name=metric_instance.associated_column.column_name + ) + ], + partition_by_args=[ + SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=parent_data_set_alias, + column_name=partition_by_instance.associated_column.column_name, + ) + ], + order_by_args=order_by_args, + ), + column_alias=metric_instance.associated_column.column_name, + ) + + # Order by instance should not be 
included in the output dataset unless it was also included in the request, + # in which case it is the partition by instance. + output_instance_set = parent_instance_set + if order_by_instance != partition_by_instance: + output_instance_set = output_instance_set.transform( + FilterElements(exclude_specs=InstanceSpecSet(time_dimension_specs=(order_by_instance.spec,)))) + + # Can't include window function in a group by, so we use a subquery and apply group by in the outer query. + subquery_select_columns = output_instance_set.transform( + FilterElements(exclude_specs=InstanceSpecSet(metric_specs=(metric_instance.spec,))) + ).transform( + CreateSelectColumnsForInstances(parent_data_set_alias, self._column_association_resolver) + ).as_tuple() + ( + metric_select_column, + ) + subquery = SqlSelectStatementNode( + description="", # description included in outer query + select_columns=subquery_select_columns, + from_source=from_data_set.checked_sql_select_node, + from_source_alias=parent_data_set_alias, + ) + subquery_alias = self._next_unique_table_alias() + + outer_query_select_columns = output_instance_set.transform( + CreateSelectColumnsForInstances(subquery_alias, self._column_association_resolver) + ).as_tuple() + return SqlDataSet( + instance_set=output_instance_set, + sql_select_node=SqlSelectStatementNode( + description=node.description, + select_columns=outer_query_select_columns, + from_source=subquery, + from_source_alias=subquery_alias, + group_bys=outer_query_select_columns, + ), + ) diff --git a/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py b/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py index dba5380f04..f6e10745dd 100644 --- a/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py +++ b/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py @@ -39,6 +39,7 @@ from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataflow.nodes.semi_additive_join import
SemiAdditiveJoinNode from metricflow.dataflow.nodes.where_filter import WhereConstraintNode +from metricflow.dataflow.nodes.window_reaggregation_node import WindowReaggregationNode from metricflow.dataflow.nodes.write_to_data_table import WriteToResultDataTableNode from metricflow.dataflow.nodes.write_to_table import WriteToResultTableNode from metricflow.dataflow.optimizer.source_scan.source_scan_optimizer import SourceScanOptimizer @@ -66,6 +67,9 @@ def visit_aggregate_measures_node(self, node: AggregateMeasuresNode) -> int: # def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> int: # noqa: D102 return self._sum_parents(node) + def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> int: # noqa: D102 + return self._sum_parents(node) + def visit_order_by_limit_node(self, node: OrderByLimitNode) -> int: # noqa: D102 return self._sum_parents(node)