From 3ae56dcbb6f4a1e84a9dbe5360f409d1b978555c Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Mon, 9 Dec 2024 21:01:07 -0800 Subject: [PATCH] WIP --- .../dataflow/builder/dataflow_plan_builder.py | 2 +- metricflow/dataflow/dataflow_plan_visitor.py | 9 + .../nodes/custom_granularity_bounds.py | 22 +- .../optimizer/predicate_pushdown_optimizer.py | 6 + .../source_scan/cm_branch_combiner.py | 7 + .../source_scan/source_scan_optimizer.py | 7 + metricflow/execution/dataflow_to_execution.py | 5 + metricflow/plan_conversion/dataflow_to_sql.py | 209 +++++++++++++++--- metricflow/sql/sql_exprs.py | 86 ++++++- metricflow/sql/sql_plan.py | 12 +- .../source_scan/test_source_scan_optimizer.py | 4 + x.sql | 18 +- 12 files changed, 326 insertions(+), 61 deletions(-) diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 7b63e72d6..cbd29d4f1 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -1913,7 +1913,7 @@ def _build_time_spine_node( pass else: time_spine_node: DataflowPlanNode = CustomGranularityBoundsNode.create( - parent_node=time_spine_read_node, custom_granularity_name=offset_window.granularity + parent_node=time_spine_read_node, offset_window=offset_window ) # # need to add a property to these specs to indicate that they are offset or bounds or something # filtered_bounds_node = FilterElementsNode.create( diff --git a/metricflow/dataflow/dataflow_plan_visitor.py b/metricflow/dataflow/dataflow_plan_visitor.py index 06d88c21d..0efc57cdb 100644 --- a/metricflow/dataflow/dataflow_plan_visitor.py +++ b/metricflow/dataflow/dataflow_plan_visitor.py @@ -14,6 +14,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode + from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -126,6 +127,10 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod def visit_transform_time_dimensions_node(self, node: TransformTimeDimensionsNode) -> VisitorOutputT: # noqa: D102 raise NotImplementedError + @abstractmethod + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> VisitorOutputT: # noqa: D102 + raise NotImplementedError + class DataflowPlanNodeVisitorWithDefaultHandler(DataflowPlanNodeVisitor[VisitorOutputT], Generic[VisitorOutputT]): """Similar to `DataflowPlanNodeVisitor`, but with an abstract default handler that gets called for each node. @@ -222,3 +227,7 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod @override def visit_transform_time_dimensions_node(self, node: TransformTimeDimensionsNode) -> VisitorOutputT: # noqa: D102 return self._default_handler(node) + + @override + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> VisitorOutputT: # noqa: D102 + return self._default_handler(node) diff --git a/metricflow/dataflow/nodes/custom_granularity_bounds.py b/metricflow/dataflow/nodes/custom_granularity_bounds.py index 03529d35e..92f24040a 100644 --- a/metricflow/dataflow/nodes/custom_granularity_bounds.py +++ b/metricflow/dataflow/nodes/custom_granularity_bounds.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Sequence +from dbt_semantic_interfaces.protocols.metric import MetricTimeWindow from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DisplayedProperty from metricflow_semantics.visitor import VisitorOutputT @@ -12,11 +13,12 @@ from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor +# TODO: rename node & file probably & docstring @dataclass(frozen=True, eq=False) class CustomGranularityBoundsNode(DataflowPlanNode, ABC): """Calculate the start and end of a custom granularity period and each row number within that period.""" - custom_granularity_name: str + offset_window: MetricTimeWindow def __post_init__(self) -> None: # noqa: D105 super().__post_init__() @@ -24,16 +26,15 @@ def __post_init__(self) -> None: # noqa: D105 @staticmethod def create( # noqa: D102 - parent_node: DataflowPlanNode, custom_granularity_name: str + parent_node: DataflowPlanNode, offset_window: MetricTimeWindow ) -> CustomGranularityBoundsNode: - return CustomGranularityBoundsNode(parent_nodes=(parent_node,), custom_granularity_name=custom_granularity_name) + return CustomGranularityBoundsNode(parent_nodes=(parent_node,), offset_window=offset_window) @classmethod def id_prefix(cls) -> IdPrefix: # noqa: D102 return StaticIdPrefix.DATAFLOW_NODE_CUSTOM_GRANULARITY_BOUNDS_ID_PREFIX def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 - # Type checking not working here return visitor.visit_custom_granularity_bounds_node(self) @property @@ -42,24 +43,17 @@ def description(self) -> str: # noqa: D102 @property def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 - return tuple(super().displayed_properties) + ( - DisplayedProperty("custom_granularity_name", self.custom_granularity_name), - ) + return tuple(super().displayed_properties) + (DisplayedProperty("offset_window", self.offset_window),) @property def parent_node(self) -> DataflowPlanNode: # noqa: D102 return self.parent_nodes[0] def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102 - return ( - isinstance(other_node, self.__class__) - and other_node.custom_granularity_name == self.custom_granularity_name - ) + return isinstance(other_node, self.__class__) and other_node.offset_window == self.offset_window def with_new_parents( # noqa: D102 self, new_parent_nodes: Sequence[DataflowPlanNode] ) -> CustomGranularityBoundsNode: assert len(new_parent_nodes) == 1 - return CustomGranularityBoundsNode.create( - parent_node=new_parent_nodes[0], custom_granularity_name=self.custom_granularity_name - ) + return CustomGranularityBoundsNode.create(parent_node=new_parent_nodes[0], offset_window=self.offset_window) diff --git a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py index b008c6d2a..1623537fc 100644 --- a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py +++ b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py @@ -22,6 +22,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -474,6 +475,11 @@ def visit_transform_time_dimensions_node( # noqa: D102 ) -> OptimizeBranchResult: raise NotImplementedError + def visit_custom_granularity_bounds_node( # noqa: D102 + self, node: CustomGranularityBoundsNode + ) -> OptimizeBranchResult: + raise NotImplementedError + def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranchResult: """Handles pushdown state propagation for the standard join node type. diff --git a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py index 6b384f72f..b28e56be6 100644 --- a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py +++ b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py @@ -16,6 +16,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -467,3 +468,9 @@ def visit_transform_time_dimensions_node( # noqa: D102 ) -> ComputeMetricsBranchCombinerResult: self._log_visit_node_type(node) return self._default_handler(node) + + def visit_custom_granularity_bounds_node( # noqa: D102 + self, node: CustomGranularityBoundsNode + ) -> ComputeMetricsBranchCombinerResult: + self._log_visit_node_type(node) + return self._default_handler(node) diff --git a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py index baa946d8b..9ee3b4019 100644 --- a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py +++ b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py @@ -18,6 +18,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -331,3 +332,9 @@ def visit_transform_time_dimensions_node( # noqa: D102 ) -> OptimizeBranchResult: self._log_visit_node_type(node) return self._default_base_output_handler(node) + + def visit_custom_granularity_bounds_node( # noqa: D102 + self, node: CustomGranularityBoundsNode + ) -> OptimizeBranchResult: + self._log_visit_node_type(node) + return self._default_base_output_handler(node) diff --git a/metricflow/execution/dataflow_to_execution.py b/metricflow/execution/dataflow_to_execution.py index bdefd2413..e387ec609 100644 --- a/metricflow/execution/dataflow_to_execution.py +++ b/metricflow/execution/dataflow_to_execution.py @@ -15,6 +15,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -205,3 +206,7 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod @override def visit_transform_time_dimensions_node(self, node: TransformTimeDimensionsNode) -> ConvertToExecutionPlanResult: raise NotImplementedError + + @override + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> ConvertToExecutionPlanResult: + raise NotImplementedError diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index 1b2723570..8549cead9 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -6,12 +6,12 @@ from typing import Callable, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, TypeVar from dbt_semantic_interfaces.enum_extension import assert_values_exhausted +from dbt_semantic_interfaces.naming.keywords import DUNDER from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType from dbt_semantic_interfaces.references import MetricModelReference, SemanticModelElementReference from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType from dbt_semantic_interfaces.type_enums.conversion_calculation_type import ConversionCalculationType from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation -from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from dbt_semantic_interfaces.validations.unique_valid_name import MetricFlowReservedKeywords from metricflow_semantics.aggregation_properties import AggregationState from metricflow_semantics.dag.id_prefix import StaticIdPrefix @@ -56,6 +56,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -117,6 +118,7 @@ from metricflow.sql.sql_exprs import ( SqlAggregateFunctionExpression, SqlBetweenExpression, + SqlCaseExpression, SqlColumnReference, SqlColumnReferenceExpression, SqlComparison, @@ -132,6 +134,7 @@ SqlRatioComputationExpression, SqlStringExpression, SqlStringLiteralExpression, + SqlSubtractTimeIntervalExpression, SqlWindowFunction, SqlWindowFunctionExpression, SqlWindowOrderByArgument, @@ -1953,62 +1956,194 @@ def strip_time_from_dt(ts: dt.datetime) -> dt.datetime: ), ) - def visit_cutom_granularity_bounds_node( # noqa: D102 - self, - node: JoinToTimeSpineNode, # TODO: replace with actual node when built - ) -> SqlDataSet: + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> SqlDataSet: # noqa: D102 from_data_set = node.parent_node.accept(self) parent_instance_set = from_data_set.instance_set parent_data_set_alias = self._next_unique_table_alias() - window_grain = ExpandedTimeGranularity( - name="martian_day", base_granularity=TimeGranularity.DAY - ) # will be node.time_granularity - # Build new columns & instances that calculate the start and end of the custom grain. - parent_window_instance = from_data_set.instance_from_time_dimension_grain_and_date_part( - time_granularity=window_grain, date_part=None - ) + # Get the column names needed to query the custom grain from the time spine where it's defined. + offset_grain_name = node.offset_window.granularity + time_spine = self._get_time_spine_for_custom_granularity(offset_grain_name) + window_column_name = self._get_custom_granularity_column_name(offset_grain_name) window_column_expr = SqlColumnReferenceExpression.from_table_and_column_names( - table_alias=parent_data_set_alias, column_name=self._get_custom_granularity_column_name(window_grain.name) + table_alias=parent_data_set_alias, column_name=window_column_name ) - base_column_expr = SqlColumnReferenceExpression.from_table_and_column_names( - table_alias=parent_data_set_alias, - column_name=self._get_time_spine_for_custom_granularity(window_grain.name).base_column, + table_alias=parent_data_set_alias, column_name=time_spine.base_column ) + # Build subquery to get start and end of custom grain period, as well as row number within the period. + parent_window_instance = from_data_set.instance_from_time_dimension_grain_and_date_part( + time_granularity=ExpandedTimeGranularity( + name=offset_grain_name, base_granularity=time_spine.base_granularity + ), + date_part=None, + ) agg_state_to_func_args: Dict[AggregationState, Tuple[SqlExpressionNode, ...]] = { - AggregationState.ROW_NUMBER: (), AggregationState.FIRST_VALUE: (base_column_expr,), AggregationState.LAST_VALUE: (base_column_expr,), + AggregationState.ROW_NUMBER: (), } - new_instances: Tuple[TimeDimensionInstance, ...] = () - new_select_columns: Tuple[SqlSelectColumn, ...] = () - for agg_state, func_args in agg_state_to_func_args.items(): - new_instance = parent_window_instance.with_new_spec( - new_spec=parent_window_instance.spec.with_aggregation_state(agg_state), - column_association_resolver=self._column_association_resolver, - ) - new_instances += (new_instance,) - new_select_column = SqlSelectColumn( + bounds_columns = tuple( + SqlSelectColumn( expr=SqlWindowFunctionExpression.create( sql_function=agg_state.sql_function, sql_function_args=func_args, partition_by_args=(window_column_expr,), order_by_args=(SqlWindowOrderByArgument(base_column_expr),), ), - column_alias=new_instance.associated_column.column_name, + column_alias=self._column_association_resolver.resolve_spec( + parent_window_instance.spec.with_aggregation_state(agg_state) + ).column_name, ) - new_select_columns += (new_select_column,) - - return SqlDataSet( - instance_set=InstanceSet.merge([InstanceSet(time_dimension_instances=new_instances), parent_instance_set]), - sql_select_node=SqlSelectStatementNode.create( - description=node.description, - select_columns=new_select_columns + from_data_set.checked_sql_select_node.select_columns, + for agg_state, func_args in agg_state_to_func_args.items() + ) + bounds_cte_alias = self._next_unique_table_alias() + bounds_cte = SqlCteNode.create( + SqlSelectStatementNode.create( + description=node.description, # TODO + select_columns=from_data_set.checked_sql_select_node.select_columns + bounds_columns, from_source=from_data_set.checked_sql_select_node, from_source_alias=parent_data_set_alias, ), + cte_alias=bounds_cte_alias, + ) + + # Build a subquery to get a unique row for each custom grain along with its start date & end date. + unique_bounds_columns = tuple( + SqlSelectColumn.from_table_and_column_names(table_alias=bounds_cte_alias, column_name=alias) + for alias in [offset_grain_name] + [column.column_alias for column in bounds_columns[:-1]] + ) + unique_bounds_subquery_alias = self._next_unique_table_alias() + unique_bounds_subquery = SqlSelectStatementNode.create( + description=node.description, # TODO + select_columns=unique_bounds_columns, + from_source=bounds_cte, # need? can I make this optional if CTEs are present? + from_source_alias=bounds_cte_alias, + cte_sources=(bounds_cte,), + group_bys=unique_bounds_columns, + ) + + # Build a subquery to offset the start and end dates by the requested offset_window. + custom_grain_column = SqlSelectColumn.from_table_and_column_names( + column_name=offset_grain_name, table_alias=unique_bounds_subquery_alias + ) + offset_bounds_columns = tuple( + SqlSelectColumn( + expr=SqlWindowFunctionExpression.create( + sql_function=SqlWindowFunction.LAG, + sql_function_args=( + SqlColumnReferenceExpression.from_table_and_column_names( + column_name=column.column_alias, table_alias=unique_bounds_subquery_alias + ), + SqlStringExpression.create(str(node.offset_window.count)), + ), + order_by_args=(SqlWindowOrderByArgument(custom_grain_column.expr),), + ), + column_alias=f"{column.column_alias}{DUNDER}offset", # TODO: finalize this alias + ) + for column in unique_bounds_columns + ) + offset_bounds_subquery_alias = self._next_unique_table_alias() + offset_bounds_subquery = SqlSelectStatementNode.create( + description=node.description, # TODO + select_columns=custom_grain_column + offset_bounds_columns, + from_source=unique_bounds_subquery, + from_source_alias=unique_bounds_subquery_alias, + ) + + # Use the row number calculated above to offset the time spine's base column by the requested window. + # If the offset date is not within the offset custom grain period, default to the last value in that period. + custom_grain_column_2 = SqlSelectColumn.from_table_and_column_names( + column_name=offset_grain_name, table_alias=unique_bounds_subquery_alias + ) # TODO: better variable name + # TODO: Get time spine specs in this node. If any have the base grain, use any of those specs. + # Else, default to metric time as below. + base_grain_spec = DataSet.metric_time_dimension_spec(time_spine.base_granularity) + base_grain_spec_column_name = self._column_association_resolver.resolve_spec(base_grain_spec) + offset_start, offset_end = [ + SqlColumnReferenceExpression.from_table_and_column_names( + column_name=offset_bound_column.column_alias, table_alias=offset_bounds_subquery_alias + ) + for offset_bound_column in offset_bounds_columns + ] + add_row_number_expr = SqlSubtractTimeIntervalExpression.create_with_count_expr( + arg=offset_start, + count_expr=SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=bounds_cte_alias, column_name=bounds_columns[-1].column_alias + ), + ) + below_end_date_expr = SqlComparisonExpression.create( + left_expr=add_row_number_expr, comparison=SqlComparison.LESS_THAN_OR_EQUALS, right_expr=add_row_number_expr + ) + offset_base_column = SqlSelectColumn( + expr=SqlCaseExpression.create( + when_to_then_exprs={below_end_date_expr: add_row_number_expr}, + else_expr=offset_end, + column_alias=base_grain_spec_column_name, + ) + ) + join_desc = SqlJoinDescription( + right_source=offset_bounds_subquery, + right_source_alias=offset_bounds_subquery_alias, + join_type=SqlJoinType.INNER, + on_condition=SqlComparisonExpression.create( + left_expr=custom_grain_column_2.expr, + comparison=SqlComparison.EQUALS, + right_expr=SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=bounds_cte_alias, column_name=offset_grain_name + ), + ), + ) + offset_base_column_subquery_alias = self._next_unique_table_alias() + output_select_node = SqlSelectStatementNode.create( + description=node.description, # TODO + select_columns=(custom_grain_column_2, offset_base_column), + from_source=bounds_cte, # need? + from_source_alias=bounds_cte_alias, + join_descs=(join_desc,), + cte_sources=(bounds_cte,), + ) + + # Apply any standard grains that were requested. + # TODO: add a conditional here if there are standard grains requested besides the base grain + base_grain_column = SqlSelectColumn.from_table_and_column_names( + table_alias=offset_base_column_subquery_alias, column_name=base_grain_spec_column_name + ) + standard_grain_columns = tuple( + SqlSelectColumn( + expr=SqlDateTruncExpression.create( + time_granularity=time_spine_spec.time_granularity.base_granularity, arg=base_grain_column.expr + ), + column_alias=self._column_association_resolver.resolve_spec(time_spine_spec), + ) + for time_spine_spec in node.requested_time_spine_specs + if not time_spine_spec.time_granularity.is_custom_granularity + ) + if standard_grain_columns: + output_select_node = SqlSelectStatementNode.create( + description=node.description, # TODO + select_columns=(base_grain_column,) + standard_grain_columns, + from_source=output_select_node, + from_source_alias=offset_base_column_subquery_alias, + ) + + # Build output instance set. + output_instance_set = InstanceSet( + time_dimension_instances=tuple( + TimeDimensionInstance( + spec=spec, + defined_from=parent_window_instance.defined_from, + associated_columns=self._column_association_resolver.resolve_spec(spec), + ) + for spec in node.requested_time_spine_specs + ) + ) + return SqlDataSet( + instance_set=InstanceSet.merge( + [InstanceSet(time_dimension_instances=output_instance_set), parent_instance_set] + ), + sql_select_node=x, ) @@ -2208,5 +2343,11 @@ def visit_transform_time_dimensions_node(self, node: TransformTimeDimensionsNode node=node, node_to_select_subquery_function=super().visit_transform_time_dimensions_node ) + @override + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> SqlDataSet: # noqa: D102 + return self._default_handler( + node=node, node_to_select_subquery_function=super().visit_custom_granularity_bounds_node + ) + DataflowNodeT = TypeVar("DataflowNodeT", bound=DataflowPlanNode) diff --git a/metricflow/sql/sql_exprs.py b/metricflow/sql/sql_exprs.py index 2eb0dae5c..78fe3b5c2 100644 --- a/metricflow/sql/sql_exprs.py +++ b/metricflow/sql/sql_exprs.py @@ -229,6 +229,10 @@ def visit_window_function_expr(self, node: SqlWindowFunctionExpression) -> Visit def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> VisitorOutputT: # noqa: D102 pass + @abstractmethod + def visit_case_expr(self, node: SqlCaseExpression) -> VisitorOutputT: # noqa: D102 + pass + @dataclass(frozen=True, eq=False) class SqlStringExpression(SqlExpressionNode): @@ -943,6 +947,7 @@ class SqlWindowFunction(Enum): LAST_VALUE = "LAST_VALUE" AVERAGE = "AVG" ROW_NUMBER = "ROW_NUMBER" + LAG = "LAG" @property def requires_ordering(self) -> bool: @@ -951,6 +956,7 @@ def requires_ordering(self) -> bool: self is SqlWindowFunction.FIRST_VALUE or self is SqlWindowFunction.LAST_VALUE or self is SqlWindowFunction.ROW_NUMBER + or self is SqlWindowFunction.LAG ): return True elif self is SqlWindowFunction.AVERAGE: @@ -1254,6 +1260,8 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return self._parents_match(other) +# TODO: make this SqlTimeDeltaExpression. Renderers will check if the number is positive or negative and update syntax accordingly. +# Put this in a separate commit @dataclass(frozen=True, eq=False) class SqlSubtractTimeIntervalExpression(SqlExpressionNode): """Represents an interval subtraction from a given timestamp. @@ -1265,7 +1273,7 @@ class SqlSubtractTimeIntervalExpression(SqlExpressionNode): """ arg: SqlExpressionNode - count: int + count_expr: SqlExpressionNode granularity: TimeGranularity @staticmethod @@ -1277,7 +1285,20 @@ def create( # noqa: D102 return SqlSubtractTimeIntervalExpression( parent_nodes=(arg,), arg=arg, - count=count, + count_expr=SqlStringExpression.create(str(count)), + granularity=granularity, + ) + + @staticmethod + def create_with_count_expr( # noqa: D102 + arg: SqlExpressionNode, + count_expr: SqlExpressionNode, + granularity: TimeGranularity, + ) -> SqlSubtractTimeIntervalExpression: + return SqlSubtractTimeIntervalExpression( + parent_nodes=(arg, count_expr), + arg=arg, + count_expr=count_expr, granularity=granularity, ) @@ -1655,3 +1676,64 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False + + +@dataclass(frozen=True, eq=False) +class SqlCaseExpression(SqlExpressionNode): + """Renders a CASE WHEN expression.""" + + when_to_then_exprs: Dict[SqlExpressionNode, SqlExpressionNode] + else_expr: Optional[SqlExpressionNode] + + @staticmethod + def create( # noqa: D102 + when_to_then_exprs: Dict[SqlExpressionNode, SqlExpressionNode], else_expr: Optional[SqlExpressionNode] = None + ) -> SqlCaseExpression: + parent_nodes = (else_expr,) + for when, then in when_to_then_exprs.items(): + parent_nodes += (when,) + parent_nodes += (then,) + + return SqlCaseExpression(parent_nodes=parent_nodes, when_to_then_exprs=when_to_then_exprs, else_expr=else_expr) + + @classmethod + def id_prefix(cls) -> IdPrefix: # noqa: D102 + return StaticIdPrefix.SQL_EXPR_CASE_PREFIX + + def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 + return visitor.visit_case_expr(self) + + @property + def description(self) -> str: # noqa: D102 + return "Case expression" + + @property + def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 + return super().displayed_properties + + @property + def requires_parenthesis(self) -> bool: # noqa: D102 + return False + + @property + def bind_parameter_set(self) -> SqlBindParameterSet: # noqa: D102 + return SqlBindParameterSet() + + def __repr__(self) -> str: # noqa: D105 + return f"{self.__class__.__name__}(node_id={self.node_id})" + + def rewrite( # noqa: D102 + self, + column_replacements: Optional[SqlColumnReplacements] = None, + should_render_table_alias: Optional[bool] = None, + ) -> SqlExpressionNode: + return self + + @property + def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 + return SqlExpressionTreeLineage(other_exprs=(self,)) + + def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 + if not isinstance(other, SqlCaseExpression): + return False + return self.when_to_then_exprs == other.when_to_then_exprs and self.else_expr == other.else_expr diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py index e45b8bd79..d63e476d5 100644 --- a/metricflow/sql/sql_plan.py +++ b/metricflow/sql/sql_plan.py @@ -14,7 +14,7 @@ from metricflow_semantics.visitor import VisitorOutputT from typing_extensions import override -from metricflow.sql.sql_exprs import SqlExpressionNode +from metricflow.sql.sql_exprs import SqlColumnReferenceExpression, SqlExpressionNode logger = logging.getLogger(__name__) @@ -103,6 +103,16 @@ class SqlSelectColumn: # Always require a column alias for simplicity. column_alias: str + @staticmethod + def from_table_and_column_names(table_alias: str, column_name: str) -> SqlSelectColumn: + """Create a column that selects a column from a table by name.""" + return SqlSelectColumn( + expr=SqlColumnReferenceExpression.from_table_and_column_names( + column_name=column_name, table_alias=table_alias + ), + column_alias=column_name, + ) + @dataclass(frozen=True) class SqlJoinDescription: diff --git a/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py b/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py index 0bf63ac3a..00c74655c 100644 --- a/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py +++ b/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py @@ -23,6 +23,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -114,6 +115,9 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod def visit_transform_time_dimensions_node(self, node: TransformTimeDimensionsNode) -> int: # noqa: D102 return self._sum_parents(node) + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> int: # noqa: D102 + return self._sum_parents(node) + def count_source_nodes(self, dataflow_plan: DataflowPlan) -> int: # noqa: D102 return dataflow_plan.sink_node.accept(self) diff --git a/x.sql b/x.sql index 55980845d..85127566b 100644 --- a/x.sql +++ b/x.sql @@ -10,8 +10,8 @@ with cte as ( date_day, fiscal_quarter, row_number() over (partition by fiscal_quarter order by date_day) - 1 as days_from_start_of_fiscal_quarter - , first_value(date_day) over (partition by fiscal_quarter order by date_day) as fiscal_quarter_start_date - , last_value(date_day) over (partition by fiscal_quarter order by date_day) as fiscal_quarter_end_date + , first_value(date_day) over (partition by fiscal_quarter order by date_day) as fiscal_quarter_start + , last_value(date_day) over (partition by fiscal_quarter order by date_day) as fiscal_quarter_end FROM ANALYTICS_DEV.DBT_JSTEIN.ALL_DAYS ) @@ -31,23 +31,23 @@ INNER JOIN ( select fiscal_quarter , case - when dateadd(day, days_from_start_of_fiscal_quarter, fiscal_quarter_start_date__offset_by_1) <= fiscal_quarter_end_date__offset_by_1 - then dateadd(day, days_from_start_of_fiscal_quarter, fiscal_quarter_start_date__offset_by_1) - else fiscal_quarter_end_date__offset_by_1 + when dateadd(day, days_from_start_of_fiscal_quarter, fiscal_quarter_start__offset_by_1) <= fiscal_quarter_end__offset_by_1 + then dateadd(day, days_from_start_of_fiscal_quarter, fiscal_quarter_start__offset_by_1) + else fiscal_quarter_end__offset_by_1 end as date_day from cte -- CustomGranularityBoundsNode inner join ( -- OffsetCustomGranularityBoundsNode select fiscal_quarter, - lag(fiscal_quarter_start_date, 1) over (order by fiscal_quarter) as fiscal_quarter_start_date__offset_by_1, - lag(fiscal_quarter_end_date, 1) over (order by fiscal_quarter) as fiscal_quarter_end_date__offset_by_1 + lag(fiscal_quarter_start, 1) over (order by fiscal_quarter) as fiscal_quarter_start__offset_by_1, + lag(fiscal_quarter_end, 1) over (order by fiscal_quarter) as fiscal_quarter_end__offset_by_1 from ( -- FilterEelementsNode select fiscal_quarter, - fiscal_quarter_start_date, - fiscal_quarter_end_date + fiscal_quarter_start, + fiscal_quarter_end from cte -- CustomGranularityBoundsNode GROUP BY 1, 2, 3 ) ts_distinct