From ec58d8f1181c35ee936a7693fcb56cac89b1b2e2 Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Wed, 11 Dec 2024 15:42:55 -0800 Subject: [PATCH] WIP --- .../metricflow_semantics/dag/id_prefix.py | 2 + .../metricflow_semantics/instances.py | 6 +- .../specs/time_dimension_spec.py | 22 ++ .../metricflow_semantics/sql/sql_exprs.py | 80 ++++++- .../simple_manifest/metrics.yaml | 21 ++ .../collection_helpers/test_pretty_print.py | 1 + .../dataflow/builder/dataflow_plan_builder.py | 105 +++++++-- metricflow/dataflow/dataflow_plan_visitor.py | 9 + .../nodes/custom_granularity_bounds.py | 79 +++++++ .../optimizer/predicate_pushdown_optimizer.py | 6 + .../source_scan/cm_branch_combiner.py | 7 + .../source_scan/source_scan_optimizer.py | 7 + metricflow/dataset/sql_dataset.py | 12 +- metricflow/execution/dataflow_to_execution.py | 5 + metricflow/plan_conversion/dataflow_to_sql.py | 201 +++++++++++++++++- .../plan_conversion/sql_join_builder.py | 2 +- metricflow/sql/render/expr_renderer.py | 16 ++ metricflow/sql/sql_plan.py | 12 +- .../source_scan/test_source_scan_optimizer.py | 4 + .../test_custom_granularity.py | 24 +++ x.sql | 94 ++++++++ 21 files changed, 678 insertions(+), 37 deletions(-) create mode 100644 metricflow/dataflow/nodes/custom_granularity_bounds.py create mode 100644 x.sql diff --git a/metricflow-semantics/metricflow_semantics/dag/id_prefix.py b/metricflow-semantics/metricflow_semantics/dag/id_prefix.py index 8c2a6d1b4..dbf62ab4e 100644 --- a/metricflow-semantics/metricflow_semantics/dag/id_prefix.py +++ b/metricflow-semantics/metricflow_semantics/dag/id_prefix.py @@ -56,6 +56,7 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper): DATAFLOW_NODE_JOIN_CONVERSION_EVENTS_PREFIX = "jce" DATAFLOW_NODE_WINDOW_REAGGREGATION_ID_PREFIX = "wr" DATAFLOW_NODE_ALIAS_SPECS_ID_PREFIX = "as" + DATAFLOW_NODE_CUSTOM_GRANULARITY_BOUNDS_ID_PREFIX = "cgb" SQL_EXPR_COLUMN_REFERENCE_ID_PREFIX = "cr" SQL_EXPR_COMPARISON_ID_PREFIX = "cmp" @@ -75,6 +76,7 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper): SQL_EXPR_BETWEEN_PREFIX = "betw" SQL_EXPR_WINDOW_FUNCTION_ID_PREFIX = "wfnc" SQL_EXPR_GENERATE_UUID_PREFIX = "uuid" + SQL_EXPR_CASE_PREFIX = "case" SQL_PLAN_SELECT_STATEMENT_ID_PREFIX = "ss" SQL_PLAN_TABLE_FROM_CLAUSE_ID_PREFIX = "tfc" diff --git a/metricflow-semantics/metricflow_semantics/instances.py b/metricflow-semantics/metricflow_semantics/instances.py index 6cd85fcbd..45d9560e5 100644 --- a/metricflow-semantics/metricflow_semantics/instances.py +++ b/metricflow-semantics/metricflow_semantics/instances.py @@ -164,11 +164,7 @@ def with_entity_prefix( ) -> TimeDimensionInstance: """Returns a new instance with the entity prefix added to the entity links.""" transformed_spec = self.spec.with_entity_prefix(entity_prefix) - return TimeDimensionInstance( - associated_columns=(column_association_resolver.resolve_spec(transformed_spec),), - defined_from=self.defined_from, - spec=transformed_spec, - ) + return self.with_new_spec(transformed_spec, column_association_resolver) def with_new_defined_from(self, defined_from: Sequence[SemanticModelElementReference]) -> TimeDimensionInstance: """Returns a new instance with the defined_from field replaced.""" diff --git a/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py b/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py index fd47c80a6..dec834adc 100644 --- a/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py +++ 
b/metricflow-semantics/metricflow_semantics/specs/time_dimension_spec.py @@ -15,6 +15,7 @@ from metricflow_semantics.naming.linkable_spec_name import StructuredLinkableSpecName from metricflow_semantics.specs.dimension_spec import DimensionSpec from metricflow_semantics.specs.instance_spec import InstanceSpecVisitor +from metricflow_semantics.sql.sql_exprs import SqlWindowFunction from metricflow_semantics.time.granularity import ExpandedTimeGranularity from metricflow_semantics.visitor import VisitorOutputT @@ -91,6 +92,8 @@ class TimeDimensionSpec(DimensionSpec): # noqa: D101 # Used for semi-additive joins. Some more thought is needed, but this may be useful in InstanceSpec. aggregation_state: Optional[AggregationState] = None + window_function: Optional[SqlWindowFunction] = None + @property def without_first_entity_link(self) -> TimeDimensionSpec: # noqa: D102 assert len(self.entity_links) > 0, f"Spec does not have any entity links: {self}" @@ -99,6 +102,8 @@ def without_first_entity_link(self) -> TimeDimensionSpec: # noqa: D102 entity_links=self.entity_links[1:], time_granularity=self.time_granularity, date_part=self.date_part, + aggregation_state=self.aggregation_state, + window_function=self.window_function, ) @property @@ -108,6 +113,8 @@ def without_entity_links(self) -> TimeDimensionSpec: # noqa: D102 time_granularity=self.time_granularity, date_part=self.date_part, entity_links=(), + aggregation_state=self.aggregation_state, + window_function=self.window_function, ) @property @@ -153,6 +160,7 @@ def with_grain(self, time_granularity: ExpandedTimeGranularity) -> TimeDimension time_granularity=time_granularity, date_part=self.date_part, aggregation_state=self.aggregation_state, + window_function=self.window_function, ) def with_base_grain(self) -> TimeDimensionSpec: # noqa: D102 @@ -162,6 +170,7 @@ def with_base_grain(self) -> TimeDimensionSpec: # noqa: D102 time_granularity=ExpandedTimeGranularity.from_time_granularity(self.time_granularity.base_granularity), date_part=self.date_part, aggregation_state=self.aggregation_state, + window_function=self.window_function, ) def with_grain_and_date_part( # noqa: D102 @@ -173,6 +182,7 @@ def with_grain_and_date_part( # noqa: D102 time_granularity=time_granularity, date_part=date_part, aggregation_state=self.aggregation_state, + window_function=self.window_function, ) def with_aggregation_state(self, aggregation_state: AggregationState) -> TimeDimensionSpec: # noqa: D102 @@ -182,6 +192,17 @@ def with_aggregation_state(self, aggregation_state: AggregationState) -> TimeDim time_granularity=self.time_granularity, date_part=self.date_part, aggregation_state=aggregation_state, + window_function=self.window_function, + ) + + def with_window_function(self, window_function: SqlWindowFunction) -> TimeDimensionSpec: # noqa: D102 + return TimeDimensionSpec( + element_name=self.element_name, + entity_links=self.entity_links, + time_granularity=self.time_granularity, + date_part=self.date_part, + aggregation_state=self.aggregation_state, + window_function=window_function, ) def comparison_key(self, exclude_fields: Sequence[TimeDimensionSpecField] = ()) -> TimeDimensionSpecComparisonKey: @@ -243,6 +264,7 @@ def with_entity_prefix(self, entity_prefix: EntityReference) -> TimeDimensionSpe time_granularity=self.time_granularity, date_part=self.date_part, aggregation_state=self.aggregation_state, + window_function=self.window_function, ) @staticmethod diff --git a/metricflow-semantics/metricflow_semantics/sql/sql_exprs.py 
b/metricflow-semantics/metricflow_semantics/sql/sql_exprs.py index 15b7268c5..391892d84 100644 --- a/metricflow-semantics/metricflow_semantics/sql/sql_exprs.py +++ b/metricflow-semantics/metricflow_semantics/sql/sql_exprs.py @@ -15,11 +15,12 @@ from dbt_semantic_interfaces.type_enums.date_part import DatePart from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity +from typing_extensions import override + from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DagNode, DisplayedProperty from metricflow_semantics.sql.sql_bind_parameters import SqlBindParameterSet from metricflow_semantics.visitor import Visitable, VisitorOutputT -from typing_extensions import override @dataclass(frozen=True, eq=False) @@ -235,6 +236,10 @@ def visit_window_function_expr(self, node: SqlWindowFunctionExpression) -> Visit def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> VisitorOutputT: # noqa: D102 pass + @abstractmethod + def visit_case_expr(self, node: SqlCaseExpression) -> VisitorOutputT: # noqa: D102 + pass + @dataclass(frozen=True, eq=False) class SqlStringExpression(SqlExpressionNode): @@ -948,11 +953,18 @@ class SqlWindowFunction(Enum): FIRST_VALUE = "FIRST_VALUE" LAST_VALUE = "LAST_VALUE" AVERAGE = "AVG" + ROW_NUMBER = "ROW_NUMBER" + LAG = "LAG" @property def requires_ordering(self) -> bool: """Asserts whether or not ordering the window function will have an impact on the resulting value.""" - if self is SqlWindowFunction.FIRST_VALUE or self is SqlWindowFunction.LAST_VALUE: + if ( + self is SqlWindowFunction.FIRST_VALUE + or self is SqlWindowFunction.LAST_VALUE + or self is SqlWindowFunction.ROW_NUMBER + or self is SqlWindowFunction.LAG + ): return True elif self is SqlWindowFunction.AVERAGE: return False @@ -1715,3 +1727,67 @@ def lineage(self) -> SqlExpressionTreeLineage: # noqa: D102 def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102 return False + + +@dataclass(frozen=True, eq=False) +class SqlCaseExpression(SqlExpressionNode): + """Renders a CASE WHEN expression.""" + + when_to_then_exprs: Dict[SqlExpressionNode, SqlExpressionNode] + else_expr: Optional[SqlExpressionNode] + + @staticmethod + def create( # noqa: D102 + when_to_then_exprs: Dict[SqlExpressionNode, SqlExpressionNode], else_expr: Optional[SqlExpressionNode] = None + ) -> SqlCaseExpression: + parent_nodes: Tuple[SqlExpressionNode, ...] 
= ()
+        for when, then in when_to_then_exprs.items():
+            parent_nodes += (when,)
+            parent_nodes += (then,)
+
+        if else_expr:
+            parent_nodes += (else_expr,)
+
+        return SqlCaseExpression(parent_nodes=parent_nodes, when_to_then_exprs=when_to_then_exprs, else_expr=else_expr)
+
+    @classmethod
+    def id_prefix(cls) -> IdPrefix:  # noqa: D102
+        return StaticIdPrefix.SQL_EXPR_CASE_PREFIX
+
+    def accept(self, visitor: SqlExpressionNodeVisitor[VisitorOutputT]) -> VisitorOutputT:  # noqa: D102
+        return visitor.visit_case_expr(self)
+
+    @property
+    def description(self) -> str:  # noqa: D102
+        return "Case expression"
+
+    @property
+    def displayed_properties(self) -> Sequence[DisplayedProperty]:  # noqa: D102
+        return super().displayed_properties
+
+    @property
+    def requires_parenthesis(self) -> bool:  # noqa: D102
+        return False
+
+    @property
+    def bind_parameter_set(self) -> SqlBindParameterSet:  # noqa: D102
+        return SqlBindParameterSet()
+
+    def __repr__(self) -> str:  # noqa: D105
+        return f"{self.__class__.__name__}(node_id={self.node_id})"
+
+    def rewrite(  # noqa: D102
+        self,
+        column_replacements: Optional[SqlColumnReplacements] = None,
+        should_render_table_alias: Optional[bool] = None,
+    ) -> SqlExpressionNode:
+        return self
+
+    @property
+    def lineage(self) -> SqlExpressionTreeLineage:  # noqa: D102
+        return SqlExpressionTreeLineage(other_exprs=(self,))
+
+    def matches(self, other: SqlExpressionNode) -> bool:  # noqa: D102
+        if not isinstance(other, SqlCaseExpression):
+            return False
+        return self.when_to_then_exprs == other.when_to_then_exprs and self.else_expr == other.else_expr
diff --git a/metricflow-semantics/metricflow_semantics/test_helpers/semantic_manifest_yamls/simple_manifest/metrics.yaml b/metricflow-semantics/metricflow_semantics/test_helpers/semantic_manifest_yamls/simple_manifest/metrics.yaml
index 7e34bebef..4cdad77e5 100644
--- a/metricflow-semantics/metricflow_semantics/test_helpers/semantic_manifest_yamls/simple_manifest/metrics.yaml
+++ b/metricflow-semantics/metricflow_semantics/test_helpers/semantic_manifest_yamls/simple_manifest/metrics.yaml
@@ -860,3 +860,24 @@ metric:
       - name: instant_bookings
         alias: shared_alias
 ---
+metric:
+  name: bookings_offset_one_martian_day
+  description: bookings offset by one martian_day
+  type: derived
+  type_params:
+    expr: bookings
+    metrics:
+      - name: bookings
+        offset_window: 1 martian_day
+---
+metric:
+  name: bookings_martian_day_over_martian_day
+  description: bookings growth martian day over martian day
+  type: derived
+  type_params:
+    expr: (bookings - bookings_offset) / NULLIF(bookings_offset, 0)
+    metrics:
+      - name: bookings
+        offset_window: 1 martian_day
+        alias: bookings_offset
+      - name: bookings
diff --git a/metricflow-semantics/tests_metricflow_semantics/collection_helpers/test_pretty_print.py b/metricflow-semantics/tests_metricflow_semantics/collection_helpers/test_pretty_print.py
index c09422caa..86a4c446c 100644
--- a/metricflow-semantics/tests_metricflow_semantics/collection_helpers/test_pretty_print.py
+++ b/metricflow-semantics/tests_metricflow_semantics/collection_helpers/test_pretty_print.py
@@ -47,6 +47,7 @@ def test_classes() -> None:  # noqa: D103
             time_granularity=ExpandedTimeGranularity(name='day', base_granularity=DAY),
             date_part=None,
             aggregation_state=None,
+            window_function=None,
         )
         """
     ).rstrip()
diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py
index 348ba5b4e..f88382fae 100644
--- a/metricflow/dataflow/builder/dataflow_plan_builder.py
+++
b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -84,6 +84,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -658,7 +659,10 @@ def _build_derived_metric_output_node( ) if metric_spec.has_time_offset and queried_agg_time_dimension_specs: # TODO: move this to a helper method - time_spine_node = self._build_time_spine_node(queried_agg_time_dimension_specs) + time_spine_node = self._build_time_spine_node( + queried_time_spine_specs=queried_agg_time_dimension_specs, + offset_window=metric_spec.offset_window, + ) output_node = JoinToTimeSpineNode.create( parent_node=output_node, time_spine_node=time_spine_node, @@ -1649,7 +1653,10 @@ def _build_aggregated_measure_from_measure_source_node( measure_properties=measure_properties, required_time_spine_specs=base_queried_agg_time_dimension_specs ) required_time_spine_specs = (join_on_time_dimension_spec,) + base_queried_agg_time_dimension_specs - time_spine_node = self._build_time_spine_node(required_time_spine_specs) + time_spine_node = self._build_time_spine_node( + queried_time_spine_specs=required_time_spine_specs, + offset_window=before_aggregation_time_spine_join_description.offset_window, + ) unaggregated_measure_node = JoinToTimeSpineNode.create( parent_node=unaggregated_measure_node, time_spine_node=time_spine_node, @@ -1862,6 +1869,7 @@ def _build_time_spine_node( queried_time_spine_specs: Sequence[TimeDimensionSpec], where_filter_specs: Sequence[WhereFilterSpec] = (), time_range_constraint: Optional[TimeRangeConstraint] = None, + offset_window: Optional[MetricTimeWindow] = None, ) -> DataflowPlanNode: """Return the time spine node needed to satisfy the specs.""" required_time_spine_spec_set = self.__get_required_linkable_specs( @@ -1870,28 +1878,81 @@ def _build_time_spine_node( ) required_time_spine_specs = required_time_spine_spec_set.time_dimension_specs - # TODO: support multiple time spines here. Build node on the one with the smallest base grain. - # Then, pass custom_granularity_specs into _build_pre_aggregation_plan if they aren't satisfied by smallest time spine. - time_spine_source = self._choose_time_spine_source(required_time_spine_specs) - read_node = self._choose_time_spine_read_node(time_spine_source) - time_spine_data_set = self._node_data_set_resolver.get_output_data_set(read_node) - - # Change the column aliases to match the specs that were requested in the query. - time_spine_node = AliasSpecsNode.create( - parent_node=read_node, - change_specs=tuple( - SpecToAlias( - input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part(required_spec).spec, - output_spec=required_spec, + should_dedupe = False + if offset_window and offset_window in self._semantic_model_lookup.custom_granularity_names: + # Are sets the right choice here? + all_queried_grains: Set[ExpandedTimeGranularity] = set() + queried_custom_specs: Tuple[TimeDimensionSpec, ...] = () + queried_standard_specs: Tuple[TimeDimensionSpec, ...] 
= ()
+            for spec in queried_time_spine_specs:
+                all_queried_grains.add(spec.time_granularity)
+                if spec.time_granularity.is_custom_granularity:
+                    queried_custom_specs += (spec,)
+                else:
+                    queried_standard_specs += (spec,)
+
+            custom_grain_metric_time_spec = DataSet.metric_time_dimension_spec(
+                ExpandedTimeGranularity(name="martian_day", base_granularity=TimeGranularity.DAY)
+            )  # this would be offset_window.granularity
+            time_spine_source = self._choose_time_spine_source((custom_grain_metric_time_spec,))
+            time_spine_read_node = self._choose_time_spine_read_node(time_spine_source)
+            # TODO: make sure this is checking the correct granularity type once DSI is updated
+            if {spec.time_granularity for spec in queried_time_spine_specs} == {offset_window.granularity}:
+                # If querying with only the same grain as is used in the offset_window, can use a simpler plan.
+                # offset_node = OffsetCustomGranularityNode.create(
+                #     parent_node=time_spine_read_node, offset_window=offset_window
+                # )
+                # time_spine_node: DataflowPlanNode = JoinToTimeSpineNode.create(
+                #     parent_node=offset_node,
+                #     # TODO: need to make sure we apply both agg time and metric time
+                #     requested_agg_time_dimension_specs=queried_time_spine_specs,
+                #     time_spine_node=time_spine_read_node,
+                #     join_type=SqlJoinType.INNER,
+                #     join_on_time_dimension_spec=custom_grain_metric_time_spec,
+                # )
+                # Not implemented yet. Raise here rather than `pass` so this branch can't fall through
+                # and leave `time_spine_node` unbound below.
+                raise NotImplementedError
+            else:
+                time_spine_node: DataflowPlanNode = CustomGranularityBoundsNode.create(
+                    parent_node=time_spine_read_node,
+                    offset_window=offset_window,
+                    requested_time_spine_specs=required_time_spine_specs,
                 )
+                # if queried_standard_specs:
+                #     time_spine_node = ApplyStandardGranularityNode.create(
+                #         parent_node=time_spine_node, time_dimension_specs=queried_standard_specs
+                #     )
+                # TODO: check if this join is needed for the same grain as is used in offset window. Later
+                for custom_spec in queried_custom_specs:
+                    time_spine_node = JoinToCustomGranularityNode.create(
+                        parent_node=time_spine_node, time_dimension_spec=custom_spec
+                    )
+        else:
+            # TODO: support multiple time spines here. Build node on the one with the smallest base grain.
+            # Then, pass custom_granularity_specs into _build_pre_aggregation_plan if they aren't satisfied by smallest time spine.
+            time_spine_source = self._choose_time_spine_source(required_time_spine_specs)
+            read_node = self._choose_time_spine_read_node(time_spine_source)
+            time_spine_data_set = self._node_data_set_resolver.get_output_data_set(read_node)
+
+            # Change the column aliases to match the specs that were requested in the query.
+            time_spine_node = AliasSpecsNode.create(
+                parent_node=read_node,
+                change_specs=tuple(
+                    SpecToAlias(
+                        input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part(
+                            time_granularity=required_spec.time_granularity, date_part=required_spec.date_part
+                        ).spec,
+                        output_spec=required_spec,
+                    )
+                    for required_spec in required_time_spine_specs
+                ),
+            )
 
-        # If the base grain of the time spine isn't selected, it will have duplicate rows that need deduping.
+ should_dedupe = ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity) not in { + spec.time_granularity for spec in queried_time_spine_specs + } + + # -- JoinToCustomGranularityNode -- if needed to support another custom grain not covered by initial time spine return self._build_pre_aggregation_plan( source_node=time_spine_node, diff --git a/metricflow/dataflow/dataflow_plan_visitor.py b/metricflow/dataflow/dataflow_plan_visitor.py index 412170a53..3fcd9ec8d 100644 --- a/metricflow/dataflow/dataflow_plan_visitor.py +++ b/metricflow/dataflow/dataflow_plan_visitor.py @@ -15,6 +15,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode + from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -126,6 +127,10 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noqa: D102 raise NotImplementedError + @abstractmethod + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> VisitorOutputT: # noqa: D102 + raise NotImplementedError + class DataflowPlanNodeVisitorWithDefaultHandler(DataflowPlanNodeVisitor[VisitorOutputT], Generic[VisitorOutputT]): """Similar to `DataflowPlanNodeVisitor`, but with an abstract default handler that gets called for each node. @@ -222,3 +227,7 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod @override def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noqa: D102 return self._default_handler(node) + + @override + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> VisitorOutputT: # noqa: D102 + return self._default_handler(node) diff --git a/metricflow/dataflow/nodes/custom_granularity_bounds.py b/metricflow/dataflow/nodes/custom_granularity_bounds.py new file mode 100644 index 000000000..4dd19c71a --- /dev/null +++ b/metricflow/dataflow/nodes/custom_granularity_bounds.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from abc import ABC +from dataclasses import dataclass +from typing import Sequence, Tuple + +from dbt_semantic_interfaces.protocols.metric import MetricTimeWindow +from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix +from metricflow_semantics.dag.mf_dag import DisplayedProperty +from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec +from metricflow_semantics.visitor import VisitorOutputT + +from metricflow.dataflow.dataflow_plan import DataflowPlanNode +from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor + + +# TODO: rename node & file probably & docstring +@dataclass(frozen=True, eq=False) +class CustomGranularityBoundsNode(DataflowPlanNode, ABC): + """Calculate the start and end of a custom granularity period and each row number within that period.""" + + offset_window: MetricTimeWindow + requested_time_spine_specs: Tuple[TimeDimensionSpec, ...] 
+ + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 + + @staticmethod + def create( # noqa: D102 + parent_node: DataflowPlanNode, + offset_window: MetricTimeWindow, + requested_time_spine_specs: Tuple[TimeDimensionSpec, ...], + ) -> CustomGranularityBoundsNode: + return CustomGranularityBoundsNode( + parent_nodes=(parent_node,), + offset_window=offset_window, + requested_time_spine_specs=requested_time_spine_specs, + ) + + @classmethod + def id_prefix(cls) -> IdPrefix: # noqa: D102 + return StaticIdPrefix.DATAFLOW_NODE_CUSTOM_GRANULARITY_BOUNDS_ID_PREFIX + + def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 + return visitor.visit_custom_granularity_bounds_node(self) + + @property + def description(self) -> str: # noqa: D102 + return """Calculate Custom Granularity Bounds""" + + @property + def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102 + return ( + tuple(super().displayed_properties) + + (DisplayedProperty("offset_window", self.offset_window),) + + (DisplayedProperty("requested_time_spine_specs", self.requested_time_spine_specs),) + ) + + @property + def parent_node(self) -> DataflowPlanNode: # noqa: D102 + return self.parent_nodes[0] + + def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102 + return ( + isinstance(other_node, self.__class__) + and other_node.offset_window == self.offset_window + and self.requested_time_spine_specs == other_node.requested_time_spine_specs + ) + + def with_new_parents( # noqa: D102 + self, new_parent_nodes: Sequence[DataflowPlanNode] + ) -> CustomGranularityBoundsNode: + assert len(new_parent_nodes) == 1 + return CustomGranularityBoundsNode.create( + parent_node=new_parent_nodes[0], + offset_window=self.offset_window, + requested_time_spine_specs=self.requested_time_spine_specs, + ) diff --git a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py index 223964af4..0c21ff612 100644 --- a/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py +++ b/metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py @@ -23,6 +23,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -472,6 +473,11 @@ def visit_join_to_custom_granularity_node( # noqa: D102 def visit_alias_specs_node(self, node: AliasSpecsNode) -> OptimizeBranchResult: # noqa: D102 raise NotImplementedError + def visit_custom_granularity_bounds_node( # noqa: D102 + self, node: CustomGranularityBoundsNode + ) -> OptimizeBranchResult: + raise NotImplementedError + def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranchResult: """Handles pushdown state propagation for the standard join node type. 
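
For reference, the builder in dataflow_plan_builder.py constructs this node as sketched below
(local names taken from `_build_time_spine_node`; the assert restates the single-parent check
enforced in `__post_init__`):

    bounds_node = CustomGranularityBoundsNode.create(
        parent_node=time_spine_read_node,  # read node over the custom grain's time spine
        offset_window=offset_window,  # e.g. "1 martian_day"
        requested_time_spine_specs=required_time_spine_specs,
    )
    assert bounds_node.parent_node is time_spine_read_node
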
diff --git a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py index 233629e7a..e2db90d08 100644 --- a/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py +++ b/metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py @@ -17,6 +17,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -472,3 +473,9 @@ def visit_min_max_node(self, node: MinMaxNode) -> ComputeMetricsBranchCombinerRe def visit_alias_specs_node(self, node: AliasSpecsNode) -> ComputeMetricsBranchCombinerResult: # noqa: D102 self._log_visit_node_type(node) return self._default_handler(node) + + def visit_custom_granularity_bounds_node( # noqa: D102 + self, node: CustomGranularityBoundsNode + ) -> ComputeMetricsBranchCombinerResult: + self._log_visit_node_type(node) + return self._default_handler(node) diff --git a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py index c84035335..b2fa4b5f7 100644 --- a/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py +++ b/metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py @@ -19,6 +19,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -356,3 +357,9 @@ def visit_min_max_node(self, node: MinMaxNode) -> OptimizeBranchResult: # noqa: def visit_alias_specs_node(self, node: AliasSpecsNode) -> OptimizeBranchResult: # noqa: D102 self._log_visit_node_type(node) return self._default_base_output_handler(node) + + def visit_custom_granularity_bounds_node( # noqa: D102 + self, node: CustomGranularityBoundsNode + ) -> OptimizeBranchResult: + self._log_visit_node_type(node) + return self._default_base_output_handler(node) diff --git a/metricflow/dataset/sql_dataset.py b/metricflow/dataset/sql_dataset.py index afa559387..c5707f012 100644 --- a/metricflow/dataset/sql_dataset.py +++ b/metricflow/dataset/sql_dataset.py @@ -4,6 +4,7 @@ from typing import List, Optional, Sequence, Tuple from dbt_semantic_interfaces.references import SemanticModelReference +from dbt_semantic_interfaces.type_enums import DatePart from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set from metricflow_semantics.instances import EntityInstance, InstanceSet, MdoInstance, TimeDimensionInstance from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat @@ -12,6 +13,7 @@ from metricflow_semantics.specs.entity_spec import EntitySpec from 
metricflow_semantics.specs.instance_spec import InstanceSpec from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec +from metricflow_semantics.time.granularity import ExpandedTimeGranularity from typing_extensions import override from metricflow.dataset.dataset_classes import DataSet @@ -165,18 +167,18 @@ def instance_for_spec(self, spec: InstanceSpec) -> MdoInstance: ) def instance_from_time_dimension_grain_and_date_part( - self, time_dimension_spec: TimeDimensionSpec + self, time_granularity: ExpandedTimeGranularity, date_part: Optional[DatePart] ) -> TimeDimensionInstance: - """Find instance in dataset that matches the grain and date part of the given time dimension spec.""" + """Find instance in dataset that matches the given grain and date part.""" for time_dimension_instance in self.instance_set.time_dimension_instances: if ( - time_dimension_instance.spec.time_granularity == time_dimension_spec.time_granularity - and time_dimension_instance.spec.date_part == time_dimension_spec.date_part + time_dimension_instance.spec.time_granularity == time_granularity + and time_dimension_instance.spec.date_part == date_part ): return time_dimension_instance raise RuntimeError( - f"Did not find a time dimension instance with matching grain and date part for spec: {time_dimension_spec}\n" + f"Did not find a time dimension instance with grain {time_granularity} and date part {date_part}\n" f"Instances available: {self.instance_set.time_dimension_instances}" ) diff --git a/metricflow/execution/dataflow_to_execution.py b/metricflow/execution/dataflow_to_execution.py index b5369f735..c192234ba 100644 --- a/metricflow/execution/dataflow_to_execution.py +++ b/metricflow/execution/dataflow_to_execution.py @@ -16,6 +16,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -205,3 +206,7 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod @override def visit_alias_specs_node(self, node: AliasSpecsNode) -> ConvertToExecutionPlanResult: raise NotImplementedError + + @override + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> ConvertToExecutionPlanResult: + raise NotImplementedError diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index 7e40df875..e0931bd07 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -6,6 +6,7 @@ from typing import Callable, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, TypeVar from dbt_semantic_interfaces.enum_extension import assert_values_exhausted +from dbt_semantic_interfaces.naming.keywords import DUNDER from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType from dbt_semantic_interfaces.references import MetricModelReference, SemanticModelElementReference from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType @@ -38,8 +39,10 @@ from metricflow_semantics.specs.spec_set import InstanceSpecSet from 
metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec from metricflow_semantics.sql.sql_exprs import ( + SqlAddTimeExpression, SqlAggregateFunctionExpression, SqlBetweenExpression, + SqlCaseExpression, SqlColumnReference, SqlColumnReferenceExpression, SqlComparison, @@ -77,6 +80,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -1827,7 +1831,7 @@ def visit_join_conversion_events_node(self, node: JoinConversionEventsNode) -> S def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> SqlDataSet: # noqa: D102 from_data_set = node.parent_node.accept(self) - parent_instance_set = from_data_set.instance_set # remove order by col + parent_instance_set = from_data_set.instance_set parent_data_set_alias = self._next_unique_table_alias() metric_instance = None @@ -1954,6 +1958,195 @@ def strip_time_from_dt(ts: dt.datetime) -> dt.datetime: ), ) + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> SqlDataSet: # noqa: D102 + from_data_set = node.parent_node.accept(self) + parent_instance_set = from_data_set.instance_set + parent_data_set_alias = self._next_unique_table_alias() + + # Get the column names needed to query the custom grain from the time spine where it's defined. + offset_grain_name = node.offset_window.granularity + time_spine = self._get_time_spine_for_custom_granularity(offset_grain_name) + window_column_name = self._get_custom_granularity_column_name(offset_grain_name) + window_column_expr = SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=parent_data_set_alias, column_name=window_column_name + ) + base_column_expr = SqlColumnReferenceExpression.from_table_and_column_names( + table_alias=parent_data_set_alias, column_name=time_spine.base_column + ) + + # Build subquery to get start and end of custom grain period, as well as row number within the period. 
+ parent_window_instance = from_data_set.instance_from_time_dimension_grain_and_date_part( + time_granularity=ExpandedTimeGranularity( + name=offset_grain_name, base_granularity=time_spine.base_granularity + ), + date_part=None, + ) + window_func_to_args: Dict[SqlWindowFunction, Tuple[SqlExpressionNode, ...]] = { + SqlWindowFunction.FIRST_VALUE: (base_column_expr,), + SqlWindowFunction.LAST_VALUE: (base_column_expr,), + SqlWindowFunction.ROW_NUMBER: (), + } + bounds_columns = tuple( + SqlSelectColumn( + expr=SqlWindowFunctionExpression.create( + sql_function=window_func, + sql_function_args=func_args, + partition_by_args=(window_column_expr,), + order_by_args=(SqlWindowOrderByArgument(base_column_expr),), + ), + column_alias=self._column_association_resolver.resolve_spec( + parent_window_instance.spec.with_window_function(window_func) + ).column_name, + ) + for window_func, func_args in window_func_to_args.items() + ) + bounds_cte_alias = self._next_unique_table_alias() + bounds_cte = SqlCteNode.create( + SqlSelectStatementNode.create( + description=node.description, # TODO + select_columns=from_data_set.checked_sql_select_node.select_columns + bounds_columns, + from_source=from_data_set.checked_sql_select_node, + from_source_alias=parent_data_set_alias, + ), + cte_alias=bounds_cte_alias, + ) + + # Build a subquery to get a unique row for each custom grain along with its start date & end date. + unique_bounds_columns = tuple( + SqlSelectColumn.from_table_and_column_names(table_alias=bounds_cte_alias, column_name=alias) + for alias in [offset_grain_name] + [column.column_alias for column in bounds_columns[:-1]] + ) + unique_bounds_subquery_alias = self._next_unique_table_alias() + unique_bounds_subquery = SqlSelectStatementNode.create( + description=node.description, # TODO + select_columns=unique_bounds_columns, + from_source=bounds_cte, # need? can I make this optional if CTEs are present? + from_source_alias=bounds_cte_alias, + cte_sources=(bounds_cte,), + group_bys=unique_bounds_columns, + ) + + # Build a subquery to offset the start and end dates by the requested offset_window. + custom_grain_column = SqlSelectColumn.from_table_and_column_names( + column_name=offset_grain_name, table_alias=unique_bounds_subquery_alias + ) + offset_bounds_columns = tuple( + SqlSelectColumn( + expr=SqlWindowFunctionExpression.create( + sql_function=SqlWindowFunction.LAG, + sql_function_args=( + SqlColumnReferenceExpression.from_table_and_column_names( + column_name=column.column_alias, table_alias=unique_bounds_subquery_alias + ), + SqlStringExpression.create(str(node.offset_window.count)), + ), + order_by_args=(SqlWindowOrderByArgument(custom_grain_column.expr),), + ), + column_alias=f"{column.column_alias}{DUNDER}offset", # TODO: finalize this alias + ) + for column in unique_bounds_columns + ) + offset_bounds_subquery_alias = self._next_unique_table_alias() + offset_bounds_subquery = SqlSelectStatementNode.create( + description=node.description, # TODO + select_columns=(custom_grain_column,) + offset_bounds_columns, + from_source=unique_bounds_subquery, + from_source_alias=unique_bounds_subquery_alias, + ) + + # Use the row number calculated above to offset the time spine's base column by the requested window. + # If the offset date is not within the offset custom grain period, default to the last value in that period. 
+        custom_grain_column_2 = SqlSelectColumn.from_table_and_column_names(
+            column_name=offset_grain_name, table_alias=offset_bounds_subquery_alias
+        )  # TODO: better variable name
+        # TODO: Get time spine specs in this node. If any have the base grain, use any of those specs.
+        # Else, default to metric time as below.
+        base_grain_spec = DataSet.metric_time_dimension_spec(
+            ExpandedTimeGranularity.from_time_granularity(time_spine.base_granularity)
+        )
+        base_grain_spec_column_name = self._column_association_resolver.resolve_spec(base_grain_spec).column_name
+        offset_start, offset_end = [
+            SqlColumnReferenceExpression.from_table_and_column_names(
+                column_name=offset_bound_column.column_alias, table_alias=offset_bounds_subquery_alias
+            )
+            for offset_bound_column in offset_bounds_columns
+        ]
+        add_row_number_expr = SqlAddTimeExpression.create(
+            arg=offset_start,
+            count_expr=SqlColumnReferenceExpression.from_table_and_column_names(
+                table_alias=bounds_cte_alias, column_name=bounds_columns[-1].column_alias
+            ),
+            granularity=time_spine.base_granularity,
+        )
+        below_end_date_expr = SqlComparisonExpression.create(
+            left_expr=add_row_number_expr, comparison=SqlComparison.LESS_THAN_OR_EQUALS, right_expr=offset_end
+        )
+        offset_base_column = SqlSelectColumn(
+            expr=SqlCaseExpression.create(
+                when_to_then_exprs={below_end_date_expr: add_row_number_expr},
+                else_expr=offset_end,
+            ),
+            column_alias=base_grain_spec_column_name,
+        )
+        join_desc = SqlJoinDescription(
+            right_source=offset_bounds_subquery,
+            right_source_alias=offset_bounds_subquery_alias,
+            join_type=SqlJoinType.INNER,
+            on_condition=SqlComparisonExpression.create(
+                left_expr=custom_grain_column_2.expr,
+                comparison=SqlComparison.EQUALS,
+                right_expr=SqlColumnReferenceExpression.from_table_and_column_names(
+                    table_alias=bounds_cte_alias, column_name=offset_grain_name
+                ),
+            ),
+        )
+        offset_base_column_subquery_alias = self._next_unique_table_alias()
+        output_select_node = SqlSelectStatementNode.create(
+            description=node.description,  # TODO
+            select_columns=(custom_grain_column_2, offset_base_column),
+            from_source=bounds_cte,  # need?
+            from_source_alias=bounds_cte_alias,
+            join_descs=(join_desc,),
+            cte_sources=(bounds_cte,),
+        )
+
+        # Apply any standard grains that were requested.
+        # TODO: add a conditional here if there are standard grains requested besides the base grain
+        base_grain_column = SqlSelectColumn.from_table_and_column_names(
+            table_alias=offset_base_column_subquery_alias, column_name=base_grain_spec_column_name
+        )
+        standard_grain_columns = tuple(
+            SqlSelectColumn(
+                expr=SqlDateTruncExpression.create(
+                    time_granularity=time_spine_spec.time_granularity.base_granularity, arg=base_grain_column.expr
+                ),
+                column_alias=self._column_association_resolver.resolve_spec(time_spine_spec).column_name,
+            )
+            for time_spine_spec in node.requested_time_spine_specs
+            if not time_spine_spec.time_granularity.is_custom_granularity
+        )
+        if standard_grain_columns:
+            output_select_node = SqlSelectStatementNode.create(
+                description=node.description,  # TODO
+                select_columns=(base_grain_column,) + standard_grain_columns,
+                from_source=output_select_node,
+                from_source_alias=offset_base_column_subquery_alias,
+            )
+
+        # Build output instance set.
+ time_spine_instance_set = InstanceSet( + time_dimension_instances=tuple( + parent_window_instance.with_new_spec( + new_spec=spec, column_association_resolver=self._column_association_resolver + ) + for spec in node.requested_time_spine_specs + ) + ) + return SqlDataSet( + instance_set=InstanceSet.merge([time_spine_instance_set, parent_instance_set]), + sql_select_node=output_select_node, + ) + class DataflowNodeToSqlCteVisitor(DataflowNodeToSqlSubqueryVisitor): """Similar to `DataflowNodeToSqlSubqueryVisitor`, except that this converts specific nodes to CTEs. @@ -2149,5 +2342,11 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod def visit_alias_specs_node(self, node: AliasSpecsNode) -> SqlDataSet: # noqa: D102 return self._default_handler(node=node, node_to_select_subquery_function=super().visit_alias_specs_node) + @override + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> SqlDataSet: # noqa: D102 + return self._default_handler( + node=node, node_to_select_subquery_function=super().visit_custom_granularity_bounds_node + ) + DataflowNodeT = TypeVar("DataflowNodeT", bound=DataflowPlanNode) diff --git a/metricflow/plan_conversion/sql_join_builder.py b/metricflow/plan_conversion/sql_join_builder.py index f80cdf228..682599ab6 100644 --- a/metricflow/plan_conversion/sql_join_builder.py +++ b/metricflow/plan_conversion/sql_join_builder.py @@ -535,7 +535,7 @@ def make_join_to_time_spine_join_description( left_expr: SqlExpressionNode = SqlColumnReferenceExpression.create( col_ref=SqlColumnReference(table_alias=time_spine_alias, column_name=agg_time_dimension_column_name) ) - if node.offset_window: + if node.offset_window: # and not node.offset_window.granularity.is_custom_granularity: left_expr = SqlSubtractTimeIntervalExpression.create( arg=left_expr, count=node.offset_window.count, diff --git a/metricflow/sql/render/expr_renderer.py b/metricflow/sql/render/expr_renderer.py index f7cac9efb..b0461a1aa 100644 --- a/metricflow/sql/render/expr_renderer.py +++ b/metricflow/sql/render/expr_renderer.py @@ -16,6 +16,7 @@ SqlAddTimeExpression, SqlAggregateFunctionExpression, SqlBetweenExpression, + SqlCaseExpression, SqlCastToTimestampExpression, SqlColumnAliasReferenceExpression, SqlColumnReferenceExpression, @@ -438,3 +439,18 @@ def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> SqlExpres sql="UUID()", bind_parameter_set=SqlBindParameterSet(), ) + + def visit_case_expr(self, node: SqlCaseExpression) -> SqlExpressionRenderResult: # noqa: D102 + sql = "CASE\n" + for when, then in node.when_to_then_exprs.items(): + sql += indent( + f"WHEN {self.render_sql_expr(when).sql} THEN {self.render_sql_expr(then).sql}\n", + indent_prefix=SqlRenderingConstants.INDENT, + ) + if node.else_expr: + sql += indent( + f"ELSE {self.render_sql_expr(node.else_expr).sql}\n", + indent_prefix=SqlRenderingConstants.INDENT, + ) + sql += "END" + return SqlExpressionRenderResult(sql=sql, bind_parameter_set=SqlBindParameterSet()) diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py index a01eb7a2f..6a75f3015 100644 --- a/metricflow/sql/sql_plan.py +++ b/metricflow/sql/sql_plan.py @@ -9,7 +9,7 @@ from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DagId, DagNode, DisplayedProperty, MetricFlowDag -from metricflow_semantics.sql.sql_exprs import SqlExpressionNode +from metricflow_semantics.sql.sql_exprs import SqlColumnReferenceExpression, SqlExpressionNode from 
metricflow_semantics.sql.sql_join_type import SqlJoinType from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.visitor import VisitorOutputT @@ -102,6 +102,16 @@ class SqlSelectColumn: # Always require a column alias for simplicity. column_alias: str + @staticmethod + def from_table_and_column_names(table_alias: str, column_name: str) -> SqlSelectColumn: + """Create a column that selects a column from a table by name.""" + return SqlSelectColumn( + expr=SqlColumnReferenceExpression.from_table_and_column_names( + column_name=column_name, table_alias=table_alias + ), + column_alias=column_name, + ) + @dataclass(frozen=True) class SqlJoinDescription: diff --git a/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py b/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py index 05770806a..a396ae104 100644 --- a/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py +++ b/tests_metricflow/dataflow/optimizer/source_scan/test_source_scan_optimizer.py @@ -24,6 +24,7 @@ from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode +from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode from metricflow.dataflow.nodes.filter_elements import FilterElementsNode from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode @@ -114,6 +115,9 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod def visit_alias_specs_node(self, node: AliasSpecsNode) -> int: # noqa: D102 return self._sum_parents(node) + def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> int: # noqa: D102 + return self._sum_parents(node) + def count_source_nodes(self, dataflow_plan: DataflowPlan) -> int: # noqa: D102 return dataflow_plan.sink_node.accept(self) diff --git a/tests_metricflow/query_rendering/test_custom_granularity.py b/tests_metricflow/query_rendering/test_custom_granularity.py index 4043c7b97..a87cbc620 100644 --- a/tests_metricflow/query_rendering/test_custom_granularity.py +++ b/tests_metricflow/query_rendering/test_custom_granularity.py @@ -610,3 +610,27 @@ def test_join_to_timespine_metric_with_custom_granularity_filter_not_in_group_by dataflow_plan_builder=dataflow_plan_builder, query_spec=query_spec, ) + + +@pytest.mark.sql_engine_snapshot +def test_custom_offset_window( # noqa: D103 + request: FixtureRequest, + mf_test_configuration: MetricFlowTestConfiguration, + dataflow_plan_builder: DataflowPlanBuilder, + dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter, + sql_client: SqlClient, + query_parser: MetricFlowQueryParser, +) -> None: + query_spec = query_parser.parse_and_validate_query( + metric_names=("bookings_offset_one_martian_day",), + group_by_names=("metric_time__day",), + ).query_spec + + render_and_check( + request=request, + mf_test_configuration=mf_test_configuration, + dataflow_to_sql_converter=dataflow_to_sql_converter, + sql_client=sql_client, + dataflow_plan_builder=dataflow_plan_builder, + query_spec=query_spec, + ) diff --git a/x.sql b/x.sql new file mode 100644 index 000000000..85127566b --- /dev/null +++ b/x.sql @@ -0,0 +1,94 @@ +-- Grouping by a grain that is NOT the same as the custom grain used in the offset window 
+--------------------------------------------------
+-- Use the base grain of the custom grain's time spine in all initial subqueries, apply DATE_TRUNC in final query
+-- This also works for custom grain, since we can just join it to the final subquery like usual.
+-- Also works if there are multiple grains in the group by
+
+with cte as (
+-- CustomGranularityBoundsNode
+    SELECT
+        date_day,
+        fiscal_quarter,
+        row_number() over (partition by fiscal_quarter order by date_day) - 1 as days_from_start_of_fiscal_quarter
+        , first_value(date_day) over (partition by fiscal_quarter order by date_day) as fiscal_quarter_start
+        , last_value(date_day) over (partition by fiscal_quarter order by date_day) as fiscal_quarter_end
+    FROM ANALYTICS_DEV.DBT_JSTEIN.ALL_DAYS
+)
+
+SELECT
+    metric_time__week,
+    metric_time__fiscal_year,
+    SUM(total_price) AS revenue_last_fiscal_quarter
+FROM ANALYTICS_DEV.DBT_JSTEIN.STG_SALESFORCE__ORDER_ITEMS
+INNER JOIN (
+    -- ApplyStandardGranularityNode
+    SELECT
+        ts_offset.date_day,
+        DATE_TRUNC(week, ts_offset.date_day) AS metric_time__week,
+        fiscal_year AS metric_time__fiscal_year
+    FROM (
+        -- OffsetByCustomGranularityNode
+        select
+            fiscal_quarter
+            , case
+                when dateadd(day, days_from_start_of_fiscal_quarter, fiscal_quarter_start__offset_by_1) <= fiscal_quarter_end__offset_by_1
+                    then dateadd(day, days_from_start_of_fiscal_quarter, fiscal_quarter_start__offset_by_1)
+                else fiscal_quarter_end__offset_by_1
+            end as date_day
+        from cte -- CustomGranularityBoundsNode
+        inner join (
+            -- OffsetCustomGranularityBoundsNode
+            select
+                fiscal_quarter,
+                lag(fiscal_quarter_start, 1) over (order by fiscal_quarter) as fiscal_quarter_start__offset_by_1,
+                lag(fiscal_quarter_end, 1) over (order by fiscal_quarter) as fiscal_quarter_end__offset_by_1
+            from (
+                -- FilterElementsNode
+                select
+                    fiscal_quarter,
+                    fiscal_quarter_start,
+                    fiscal_quarter_end
+                from cte -- CustomGranularityBoundsNode
+                GROUP BY 1, 2, 3
+            ) ts_distinct
+        ) ts_with_offset_intervals USING (fiscal_quarter)
+    ) ts_offset
+    -- JoinToCustomGranularityNode
+    LEFT JOIN ANALYTICS_DEV.DBT_JSTEIN.ALL_DAYS custom ON custom.date_day = ts_offset.date_day
+) ts_offset_dates ON ts_offset_dates.date_day = DATE_TRUNC(day, created_at)::date -- always join on base time spine column
+GROUP BY 1, 2
+ORDER BY 1, 2;
+
+
+
+
+
+
+-- Grouping by just the same custom grain as is used in the offset window (and only that grain)
+--------------------------------------------------
+-- Could follow the same SQL as above, but this would be a more optimized version (they appear to give the same results)
+-- This is likely to be most common for period over period, so it might be good to optimize it
+
+
+SELECT -- existing nodes!
+    metric_time__fiscal_quarter,
+    SUM(total_price) AS revenue
+FROM ANALYTICS_DEV.DBT_JSTEIN.STG_SALESFORCE__ORDER_ITEMS
+LEFT JOIN ( -- JoinToTimeSpineNode, no offset, join on custom grain spec
+    SELECT
+        -- JoinToTimeSpineNode
+        -- TransformTimeDimensionsNode??
+        date_day,
+        fiscal_quarter_offset AS metric_time__fiscal_quarter
+    FROM ANALYTICS_DEV.DBT_JSTEIN.ALL_DAYS
+    INNER JOIN (
+        -- OffsetCustomGranularityNode
+        SELECT
+            fiscal_quarter
+            , lag(fiscal_quarter, 1) OVER (ORDER BY fiscal_quarter) as fiscal_quarter_offset
+        FROM ANALYTICS_DEV.DBT_JSTEIN.ALL_DAYS
+        GROUP BY 1
+    ) ts_offset_dates USING (fiscal_quarter)
+) ts ON date_day = DATE_TRUNC(day, created_at)::date
+GROUP BY 1
+ORDER BY 1;
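
For reference, the CASE expression in the sketch above is what the new SqlCaseExpression in
sql_exprs.py produces. A minimal standalone sketch of constructing and rendering one with the
existing default renderer (the `a`, `ds`, and `period_end` names are made up for the example):

    from metricflow.sql.render.expr_renderer import DefaultSqlExpressionRenderer
    from metricflow_semantics.sql.sql_exprs import (
        SqlCaseExpression,
        SqlColumnReferenceExpression,
        SqlComparison,
        SqlComparisonExpression,
    )

    ds = SqlColumnReferenceExpression.from_table_and_column_names(table_alias="a", column_name="ds")
    period_end = SqlColumnReferenceExpression.from_table_and_column_names(
        table_alias="a", column_name="period_end"
    )
    # Cap the offset date at the end of the period, mirroring the dataflow_to_sql logic above.
    case_expr = SqlCaseExpression.create(
        when_to_then_exprs={
            SqlComparisonExpression.create(
                left_expr=ds, comparison=SqlComparison.LESS_THAN_OR_EQUALS, right_expr=period_end
            ): ds,
        },
        else_expr=period_end,
    )
    # Renders roughly: CASE WHEN a.ds <= a.period_end THEN a.ds ELSE a.period_end END
    print(DefaultSqlExpressionRenderer().render_sql_expr(case_expr).sql)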