From aeb12d01d51232e5e4e59bfb13d2e10e82c4212f Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Wed, 18 Dec 2024 12:32:27 -0800 Subject: [PATCH] fixup! Add new dataflow plan nodes for custom offset windows --- metricflow/execution/dataflow_to_execution.py | 12 + metricflow/plan_conversion/dataflow_to_sql.py | 306 +++++++++++++++++- 2 files changed, 314 insertions(+), 4 deletions(-) diff --git a/metricflow/execution/dataflow_to_execution.py b/metricflow/execution/dataflow_to_execution.py index 1c2aa2fd9..8d1f5cfb3 100644 --- a/metricflow/execution/dataflow_to_execution.py +++ b/metricflow/execution/dataflow_to_execution.py @@ -16,6 +16,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -24,6 +25,7 @@ from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode +from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode @@ -205,3 +207,13 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod @override def visit_alias_specs_node(self, node: AliasSpecsNode) -> ConvertToExecutionPlanResult: raise 
NotImplementedError + + @override + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> ConvertToExecutionPlanResult: + raise NotImplementedError + + @override + def visit_offset_by_custom_granularity_node( + self, node: OffsetByCustomGranularityNode + ) -> ConvertToExecutionPlanResult: + raise NotImplementedError diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index cd6b92a2b..aa4e34e61 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -6,6 +6,7 @@ from typing import Callable, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, TypeVar from dbt_semantic_interfaces.enum_extension import assert_values_exhausted +from dbt_semantic_interfaces.naming.keywords import DUNDER from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType from dbt_semantic_interfaces.references import MetricModelReference, SemanticModelElementReference from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType @@ -38,8 +39,12 @@ from metricflow_semantics.specs.spec_set import InstanceSpecSet from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec from metricflow_semantics.sql.sql_exprs import ( + SqlAddTimeExpression, SqlAggregateFunctionExpression, + SqlArithmeticExpression, + SqlArithmeticOperator, SqlBetweenExpression, + SqlCaseExpression, SqlColumnReference, SqlColumnReferenceExpression, SqlComparison, @@ -50,6 +55,7 @@ SqlFunction, SqlFunctionExpression, SqlGenerateUuidExpression, + SqlIntegerExpression, SqlLogicalExpression, SqlLogicalOperator, SqlRatioComputationExpression, @@ -77,6 +83,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from 
metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -85,6 +92,7 @@ from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode from metricflow.dataflow.nodes.min_max import MinMaxNode +from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode @@ -1268,9 +1276,9 @@ def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTr spec=metric_time_dimension_spec, ) ) - output_column_to_input_column[ - metric_time_dimension_column_association.column_name - ] = matching_time_dimension_instance.associated_column.column_name + output_column_to_input_column[metric_time_dimension_column_association.column_name] = ( + matching_time_dimension_instance.associated_column.column_name + ) output_instance_set = InstanceSet( measure_instances=tuple(output_measure_instances), @@ -1888,7 +1896,7 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> SqlDataSet: # noqa: D102 from_data_set = node.parent_node.accept(self) - parent_instance_set = from_data_set.instance_set # remove order by col + parent_instance_set = from_data_set.instance_set parent_data_set_alias = self._next_unique_table_alias() metric_instance = None @@ -2015,6 +2023,284 @@ def strip_time_from_dt(ts: dt.datetime) -> dt.datetime: ), ) + def 
visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> SqlDataSet: + """Build columns that will be needed for custom offset windows. + + This includes columns that represent the start and end of a custom grain period, as well as the row number of the base + grain within each period. For example, the columns might look like: + + SELECT + {{ existing columns }}, + FIRST_VALUE(ds) OVER (PARTITION BY fiscal_quarter ORDER BY ds) AS ds__fiscal_quarter__first_value, + LAST_VALUE(ds) OVER (PARTITION BY fiscal_quarter ORDER BY ds) AS ds__fiscal_quarter__last_value, + ROW_NUMBER() OVER (PARTITION BY fiscal_quarter ORDER BY ds) AS ds__day__row_number + FROM time_spine_read_node + """ + parent_data_set = node.parent_node.accept(self) + parent_instance_set = parent_data_set.instance_set + parent_data_set_alias = self._next_unique_table_alias() + + custom_granularity_name = node.custom_granularity_name + time_spine = self._get_time_spine_for_custom_granularity(custom_granularity_name) + custom_grain_instance_from_parent = parent_data_set.instance_from_time_dimension_grain_and_date_part( + time_granularity_name=custom_granularity_name, date_part=None + ) + base_grain_instance_from_parent = parent_data_set.instance_from_time_dimension_grain_and_date_part( + time_granularity_name=time_spine.base_granularity.value, date_part=None + ) + custom_column_expr = SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=parent_data_set_alias, + column_name=custom_grain_instance_from_parent.associated_column.column_name, + ) + base_column_expr = SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=parent_data_set_alias, column_name=base_grain_instance_from_parent.associated_column.column_name + ) + + new_instances: Tuple[TimeDimensionInstance, ...] = tuple() + new_select_columns: Tuple[SqlSelectColumn, ...] = tuple() + + # Build columns that get start and end of the custom grain period. 
+ # Ex: "FIRST_VALUE(ds) OVER (PARTITION BY martian_day ORDER BY ds) AS ds__fiscal_quarter__first_value" + for window_func in (SqlWindowFunction.FIRST_VALUE, SqlWindowFunction.LAST_VALUE): + new_instance = custom_grain_instance_from_parent.with_new_spec( + new_spec=custom_grain_instance_from_parent.spec.with_window_function(window_func), + column_association_resolver=self._column_association_resolver, + ) + select_column = SqlSelectColumn( + expr=SqlWindowFunctionExpression.create( + sql_function=window_func, + sql_function_args=(base_column_expr,), + partition_by_args=(custom_column_expr,), + order_by_args=(SqlWindowOrderByArgument(base_column_expr),), + ), + column_alias=new_instance.associated_column.column_name, + ) + new_instances += (new_instance,) + new_select_columns += (select_column,) + + # Build a column that tracks the row number for the base grain column within the custom grain period. + # This will be offset by 1 to represent the number of base grain periods since the start of the custom grain period. 
+ # Ex: "ROW_NUMBER() OVER (PARTITION BY martian_day ORDER BY ds) AS ds__day__row_number" + new_instance = base_grain_instance_from_parent.with_new_spec( + new_spec=base_grain_instance_from_parent.spec.with_window_function(SqlWindowFunction.ROW_NUMBER), + column_association_resolver=self._column_association_resolver, + ) + window_func_expr = SqlWindowFunctionExpression.create( + sql_function=SqlWindowFunction.ROW_NUMBER, + partition_by_args=(custom_column_expr,), + order_by_args=(SqlWindowOrderByArgument(base_column_expr),), + ) + new_select_column = SqlSelectColumn( + expr=window_func_expr, + column_alias=new_instance.associated_column.column_name, + ) + new_instances += (new_instance,) + new_select_columns += (new_select_column,) + + return SqlDataSet( + instance_set=InstanceSet.merge([InstanceSet(time_dimension_instances=new_instances), parent_instance_set]), + sql_select_node=SqlSelectStatementNode.create( + description=node.description, + select_columns=parent_data_set.checked_sql_select_node.select_columns + new_select_columns, + from_source=parent_data_set.checked_sql_select_node, + from_source_alias=parent_data_set_alias, + ), + ) + + def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularityNode) -> SqlDataSet: + """For a given custom grain, offset its base grain by the requested number of custom grain periods. 
+ + Example: if the custom grain is `fiscal_quarter` with a base grain of DAY and we're offsetting by 1 period, the + output SQL should look something like this: + + SELECT + CASE + WHEN DATEADD(day, ds__day__row_number - 1, ds__fiscal_quarter__first_value__offset) <= ds__fiscal_quarter__last_value__offset + THEN DATEADD(day, ds__day__row_number - 1, ds__fiscal_quarter__first_value__offset) + ELSE ds__fiscal_quarter__last_value__offset + END AS date_day + FROM custom_granularity_bounds_node + INNER JOIN filter_elements_node ON filter_elements_node.fiscal_quarter = custom_granularity_bounds_node.fiscal_quarter + """ + bounds_data_set = node.custom_granularity_bounds_node.accept(self) + bounds_instance_set = bounds_data_set.instance_set + bounds_data_set_alias = self._next_unique_table_alias() + filter_elements_data_set = node.filter_elements_node.accept(self) + filter_elements_instance_set = filter_elements_data_set.instance_set + filter_elements_data_set_alias = self._next_unique_table_alias() + offset_window = node.offset_window + custom_grain_name = offset_window.granularity + base_grain = ExpandedTimeGranularity.from_time_granularity( + self._get_time_spine_for_custom_granularity(custom_grain_name).base_granularity + ) + + # Find the required instances in the parent data sets. 
+ first_value_instance: Optional[TimeDimensionInstance] = None + last_value_instance: Optional[TimeDimensionInstance] = None + row_number_instance: Optional[TimeDimensionInstance] = None + custom_grain_instance: Optional[TimeDimensionInstance] = None + base_grain_instance: Optional[TimeDimensionInstance] = None + for instance in filter_elements_instance_set.time_dimension_instances: + if instance.spec.window_function is SqlWindowFunction.FIRST_VALUE: + first_value_instance = instance + elif instance.spec.window_function is SqlWindowFunction.LAST_VALUE: + last_value_instance = instance + elif instance.spec.time_granularity.name == custom_grain_name: + custom_grain_instance = instance + if custom_grain_instance and first_value_instance and last_value_instance: + break + for instance in bounds_instance_set.time_dimension_instances: + if instance.spec.window_function is SqlWindowFunction.ROW_NUMBER: + row_number_instance = instance + elif instance.spec.time_granularity == base_grain and instance.spec.date_part is None: + base_grain_instance = instance + if base_grain_instance and row_number_instance: + break + assert ( + custom_grain_instance + and base_grain_instance + and first_value_instance + and last_value_instance + and row_number_instance + ), ( + "Did not find all required time dimension instances in parent data sets for OffsetByCustomGranularityNode. " + f"This indicates internal misconfiguration. Got custom grain instance: {custom_grain_instance}; base grain " + f"instance: {base_grain_instance}; first value instance: {first_value_instance}; last value instance: " + f"{last_value_instance}; row number instance: {row_number_instance}\n" + f"Available instances:{bounds_instance_set.time_dimension_instances}." + ) + + # First, build a subquery that offsets the first and last value columns. 
+ custom_grain_column_name = custom_grain_instance.associated_column.column_name + custom_grain_column = SqlSelectColumn.from_table_and_column_names( + column_name=custom_grain_column_name, table_alias=filter_elements_data_set_alias + ) + first_value_offset_column, last_value_offset_column = tuple( + SqlSelectColumn( + expr=SqlWindowFunctionExpression.create( + sql_function=SqlWindowFunction.LAG, + sql_function_args=( + SqlColumnReferenceExpression.from_table_and_column_names( + column_name=instance.associated_column.column_name, + table_alias=filter_elements_data_set_alias, + ), + SqlIntegerExpression.create(node.offset_window.count), + ), + order_by_args=(SqlWindowOrderByArgument(custom_grain_column.expr),), + ), + column_alias=f"{instance.associated_column.column_name}{DUNDER}offset", + ) + for instance in (first_value_instance, last_value_instance) + ) + # Wrap the offset bound columns in a subquery so they can be joined back to the bounds data set. + offset_bounds_subquery = SqlSelectStatementNode.create( + description="Offset Custom Granularity Bounds", + select_columns=(custom_grain_column, first_value_offset_column, last_value_offset_column), + from_source=filter_elements_data_set.checked_sql_select_node, + from_source_alias=filter_elements_data_set_alias, + ) + offset_bounds_subquery_alias = self._next_unique_table_alias() + + # Offset the base column by the requested window. If the offset date is not within the offset custom grain period, + # default to the last value in that period. 
+ new_custom_grain_column = SqlSelectColumn.from_table_and_column_names( + column_name=custom_grain_column_name, table_alias=bounds_data_set_alias + ) + first_value_offset_expr, last_value_offset_expr = [ + SqlColumnReferenceExpression.from_table_and_column_names( + column_name=offset_column.column_alias, table_alias=offset_bounds_subquery_alias + ) + for offset_column in (first_value_offset_column, last_value_offset_column) + ] + offset_base_grain_expr = SqlAddTimeExpression.create( + arg=first_value_offset_expr, + count_expr=SqlArithmeticExpression.create( + left_expr=SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=bounds_data_set_alias, column_name=row_number_instance.associated_column.column_name + ), + operator=SqlArithmeticOperator.SUBTRACT, + right_expr=SqlIntegerExpression.create(1), + ), + granularity=base_grain.base_granularity, + ) + is_below_last_value_expr = SqlComparisonExpression.create( + left_expr=offset_base_grain_expr, + comparison=SqlComparison.LESS_THAN_OR_EQUALS, + right_expr=last_value_offset_expr, + ) + offset_base_column = SqlSelectColumn( + expr=SqlCaseExpression.create( + when_to_then_exprs={is_below_last_value_expr: offset_base_grain_expr}, + else_expr=last_value_offset_expr, + ), + column_alias=base_grain_instance.associated_column.column_name, + ) + join_desc = SqlJoinDescription( + right_source=offset_bounds_subquery, + right_source_alias=offset_bounds_subquery_alias, + join_type=SqlJoinType.INNER, + on_condition=SqlComparisonExpression.create( + left_expr=SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=bounds_data_set_alias, column_name=custom_grain_column_name + ), + comparison=SqlComparison.EQUALS, + right_expr=SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=offset_bounds_subquery_alias, column_name=custom_grain_column_name + ), + ), + ) + offset_base_grain_subquery = SqlSelectStatementNode.create( + description=node.description, + 
select_columns=(new_custom_grain_column, offset_base_column), + from_source=bounds_data_set.checked_sql_select_node, + from_source_alias=bounds_data_set_alias, + join_descs=(join_desc,), + ) + offset_base_grain_subquery_alias = self._next_unique_table_alias() + + # Apply standard grains & date parts requested in the query. Use base grain for any custom grains. + standard_grain_instances: Tuple[TimeDimensionInstance, ...] = () + standard_grain_columns: Tuple[SqlSelectColumn, ...] = () + base_column = SqlSelectColumn( + expr=SqlColumnReferenceExpression.from_table_and_column_names( + column_name=base_grain_instance.associated_column.column_name, + table_alias=offset_base_grain_subquery_alias, + ), + column_alias=base_grain_instance.associated_column.column_name, + ) + base_grain_requested = False + for spec in node.required_time_spine_specs: + new_instance = base_grain_instance.with_new_spec( + new_spec=spec, column_association_resolver=self._column_association_resolver + ) + standard_grain_instances += (new_instance,) + if spec.date_part: + expr: SqlExpressionNode = SqlExtractExpression.create(date_part=spec.date_part, arg=base_column.expr) + elif spec.time_granularity.base_granularity == base_grain.base_granularity: + expr = base_column.expr + base_grain_requested = True + else: + expr = SqlDateTruncExpression.create( + time_granularity=spec.time_granularity.base_granularity, arg=base_column.expr + ) + standard_grain_columns += ( + SqlSelectColumn(expr=expr, column_alias=new_instance.associated_column.column_name), + ) + if not base_grain_requested: + # Prepend the offset base grain column so it is always available to downstream nodes. + standard_grain_instances = (base_grain_instance,) + standard_grain_instances + standard_grain_columns = (base_column,) + standard_grain_columns + + return SqlDataSet( + instance_set=InstanceSet(time_dimension_instances=standard_grain_instances), + sql_select_node=SqlSelectStatementNode.create( + description="Apply Requested Granularities", + select_columns=standard_grain_columns, + 
from_source=offset_base_grain_subquery, + from_source_alias=offset_base_grain_subquery_alias, + ), + ) + class DataflowNodeToSqlCteVisitor(DataflowNodeToSqlSubqueryVisitor): """Similar to `DataflowNodeToSqlSubqueryVisitor`, except that this converts specific nodes to CTEs. @@ -2210,5 +2496,17 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod def visit_alias_specs_node(self, node: AliasSpecsNode) -> SqlDataSet: # noqa: D102 return self._default_handler(node=node, node_to_select_subquery_function=super().visit_alias_specs_node) + @override + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> SqlDataSet: # noqa: D102 + return self._default_handler( + node=node, node_to_select_subquery_function=super().visit_custom_granularity_bounds_node + ) + + @override + def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularityNode) -> SqlDataSet: # noqa: D102 + return self._default_handler( + node=node, node_to_select_subquery_function=super().visit_offset_by_custom_granularity_node + ) + DataflowNodeT = TypeVar("DataflowNodeT", bound=DataflowPlanNode)