From 00fb89ec88ac1afdedba27cc78cd78163683a87f Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Tue, 19 Nov 2024 18:18:16 -0800 Subject: [PATCH] WIP --- .../metricflow_semantics/instances.py | 13 ++ .../dataflow/builder/dataflow_plan_builder.py | 20 ++- metricflow/dataflow/nodes/join_over_time.py | 7 +- .../dataflow/nodes/join_to_time_spine.py | 2 + metricflow/dataset/sql_dataset.py | 40 ++++- metricflow/plan_conversion/dataflow_to_sql.py | 157 +++++------------- .../plan_conversion/instance_converters.py | 1 + .../plan_conversion/sql_join_builder.py | 19 +-- metricflow/sql/sql_plan.py | 2 + .../test_cumulative_metric_rendering.py | 30 ++++ 10 files changed, 145 insertions(+), 146 deletions(-) diff --git a/metricflow-semantics/metricflow_semantics/instances.py b/metricflow-semantics/metricflow_semantics/instances.py index c6a0f8e2b0..e853aa3b1f 100644 --- a/metricflow-semantics/metricflow_semantics/instances.py +++ b/metricflow-semantics/metricflow_semantics/instances.py @@ -145,6 +145,19 @@ def with_entity_prefix( spec=transformed_spec, ) + @staticmethod + def from_properties( + spec: TimeDimensionSpec, + defined_from: Tuple[SemanticModelElementReference, ...], + column_association_resolver: ColumnAssociationResolver, + ) -> TimeDimensionInstance: + """Create a TimeDimensionInstance from specified properties.""" + return TimeDimensionInstance( + associated_columns=(column_association_resolver.resolve_spec(spec),), + spec=spec, + defined_from=defined_from, + ) + @dataclass(frozen=True) class EntityInstance(LinkableInstance[EntitySpec], SemanticModelElementInstance): # noqa: D101 diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 0ad7e9ed38..3af0dc8ef2 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -643,15 +643,17 @@ def _build_derived_metric_output_node( # For ratio / derived metrics with time offset, apply offset & where constraint after metric computation. if metric_spec.has_time_offset: - queried_agg_time_dimension_specs = queried_linkable_specs.included_agg_time_dimension_specs_for_metric( + required_agg_time_dimension_specs = queried_linkable_specs.included_agg_time_dimension_specs_for_metric( metric_reference=metric_spec.reference, metric_lookup=self._metric_lookup ) assert ( - queried_agg_time_dimension_specs + required_agg_time_dimension_specs ), "Joining to time spine requires querying with metric_time or the appropriate agg_time_dimension." 
            output_node = JoinToTimeSpineNode.create(
                parent_node=output_node,
-                requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
+                requested_agg_time_dimension_specs=[
+                    spec.with_base_grain() for spec in required_agg_time_dimension_specs
+                ],
                 use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time,
                 time_range_constraint=predicate_pushdown_state.time_range_constraint,
                 offset_window=metric_spec.offset_window,
@@ -1590,10 +1592,14 @@ def _build_aggregated_measure_from_measure_source_node(
         queried_agg_time_dimension_specs = queried_linkable_specs.included_agg_time_dimension_specs_for_measure(
             measure_reference=measure_spec.reference, semantic_model_lookup=self._semantic_model_lookup
         )
+        required_agg_time_dimension_specs = required_linkable_specs.included_agg_time_dimension_specs_for_measure(
+            measure_reference=measure_spec.reference, semantic_model_lookup=self._semantic_model_lookup
+        )
 
         # If a cumulative metric is queried with metric_time / agg_time_dimension, join over time range.
         # Otherwise, the measure will be aggregated over all time.
         unaggregated_measure_node: DataflowPlanNode = measure_recipe.source_node
+        # TODO: can we use required here, too?
         if cumulative and queried_agg_time_dimension_specs:
             unaggregated_measure_node = JoinOverTimeRangeNode.create(
                 parent_node=unaggregated_measure_node,
@@ -1619,7 +1625,11 @@ def _build_aggregated_measure_from_measure_source_node(
             # in join rendering
             unaggregated_measure_node = JoinToTimeSpineNode.create(
                 parent_node=unaggregated_measure_node,
-                requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
+                requested_agg_time_dimension_specs=[
+                    spec
+                    for spec in required_agg_time_dimension_specs
+                    if not spec.time_granularity.is_custom_granularity
+                ],
                 use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time,
                 time_range_constraint=predicate_pushdown_state.time_range_constraint,
                 offset_window=before_aggregation_time_spine_join_description.offset_window,
@@ -1691,7 +1701,7 @@ def _build_aggregated_measure_from_measure_source_node(
             # like JoinToCustomGranularityNode, WhereConstraintNode, etc.
             output_node: DataflowPlanNode = JoinToTimeSpineNode.create(
                 parent_node=aggregate_measures_node,
-                requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
+                requested_agg_time_dimension_specs=required_agg_time_dimension_specs,
                 use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time,
                 join_type=after_aggregation_time_spine_join_description.join_type,
                 time_range_constraint=predicate_pushdown_state.time_range_constraint,
diff --git a/metricflow/dataflow/nodes/join_over_time.py b/metricflow/dataflow/nodes/join_over_time.py
index c766bb7dcb..99c3b095c3 100644
--- a/metricflow/dataflow/nodes/join_over_time.py
+++ b/metricflow/dataflow/nodes/join_over_time.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Tuple
 
 from dbt_semantic_interfaces.protocols import MetricTimeWindow
 from dbt_semantic_interfaces.type_enums import TimeGranularity
@@ -15,6 +15,7 @@
 from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor
 
 
+# TODO: Should this class be combined with JoinToTimeSpineNode?
 @dataclass(frozen=True, eq=False)
 class JoinOverTimeRangeNode(DataflowPlanNode):
     """A node that allows for cumulative metric computation by doing a self join across a cumulative date range.
@@ -26,7 +27,7 @@ class JoinOverTimeRangeNode(DataflowPlanNode): time_range_constraint: Time range to aggregate over. """ - queried_agg_time_dimension_specs: Sequence[TimeDimensionSpec] + queried_agg_time_dimension_specs: Tuple[TimeDimensionSpec, ...] window: Optional[MetricTimeWindow] grain_to_date: Optional[TimeGranularity] time_range_constraint: Optional[TimeRangeConstraint] @@ -38,7 +39,7 @@ def __post_init__(self) -> None: # noqa: D105 @staticmethod def create( # noqa: D102 parent_node: DataflowPlanNode, - queried_agg_time_dimension_specs: Sequence[TimeDimensionSpec], + queried_agg_time_dimension_specs: Tuple[TimeDimensionSpec, ...], window: Optional[MetricTimeWindow] = None, grain_to_date: Optional[TimeGranularity] = None, time_range_constraint: Optional[TimeRangeConstraint] = None, diff --git a/metricflow/dataflow/nodes/join_to_time_spine.py b/metricflow/dataflow/nodes/join_to_time_spine.py index a17b2e4283..18ab681f07 100644 --- a/metricflow/dataflow/nodes/join_to_time_spine.py +++ b/metricflow/dataflow/nodes/join_to_time_spine.py @@ -31,7 +31,9 @@ class JoinToTimeSpineNode(DataflowPlanNode, ABC): offset_to_grain: Granularity period to offset the parent dataset to when joining to time spine. """ + # TODO: rename property to required_agg_time_dimension_specs requested_agg_time_dimension_specs: Sequence[TimeDimensionSpec] + # TODO remove this property use_custom_agg_time_dimension: bool join_type: SqlJoinType time_range_constraint: Optional[TimeRangeConstraint] diff --git a/metricflow/dataset/sql_dataset.py b/metricflow/dataset/sql_dataset.py index 214d101eb1..2b86eb4799 100644 --- a/metricflow/dataset/sql_dataset.py +++ b/metricflow/dataset/sql_dataset.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import List, Optional, Sequence +from dataclasses import dataclass +from typing import List, Optional, Sequence, Tuple from dbt_semantic_interfaces.references import SemanticModelReference from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set @@ -124,18 +125,19 @@ def column_association_for_dimension( def instances_for_time_dimensions( self, time_dimension_specs: Sequence[TimeDimensionSpec] - ) -> List[TimeDimensionInstance]: + ) -> Tuple[TimeDimensionInstance, ...]: """Return the instances associated with these specs in the data set.""" + time_dimension_specs_set = set(time_dimension_specs) matching_instances = 0 - instances_to_return: List[TimeDimensionInstance] = [] + instances_to_return: Tuple[TimeDimensionInstance, ...] 
= () for time_dimension_instance in self.instance_set.time_dimension_instances: - if time_dimension_instance.spec in time_dimension_specs: - instances_to_return.append(time_dimension_instance) + if time_dimension_instance.spec in time_dimension_specs_set: + instances_to_return += (time_dimension_instance,) matching_instances += 1 - if matching_instances != len(time_dimension_specs): + if matching_instances != len(time_dimension_specs_set): raise RuntimeError( - f"Unexpected number of time dimension instances found matching specs.\nSpecs: {time_dimension_specs}\n" + f"Unexpected number of time dimension instances found matching specs.\nSpecs: {time_dimension_specs_set}\n" f"Instances: {instances_to_return}" ) @@ -143,7 +145,7 @@ def instances_for_time_dimensions( def instance_for_time_dimension(self, time_dimension_spec: TimeDimensionSpec) -> TimeDimensionInstance: """Given the name of the time dimension, return the instance associated with it in the data set.""" - return self.instances_for_time_dimensions([time_dimension_spec])[0] + return self.instances_for_time_dimensions((time_dimension_spec,))[0] def column_association_for_time_dimension(self, time_dimension_spec: TimeDimensionSpec) -> ColumnAssociation: """Given the name of the time dimension, return the set of columns associated with it in the data set.""" @@ -153,3 +155,25 @@ def column_association_for_time_dimension(self, time_dimension_spec: TimeDimensi @override def semantic_model_reference(self) -> Optional[SemanticModelReference]: return None + + def annotate(self, alias: str, metric_time_spec: TimeDimensionSpec) -> AnnotatedSqlDataSet: + """Convert to an AnnotatedSqlDataSet with specified metadata.""" + metric_time_column_name = self.column_association_for_time_dimension(metric_time_spec).column_name + return AnnotatedSqlDataSet(data_set=self, alias=alias, _metric_time_column_name=metric_time_column_name) + + +@dataclass(frozen=True) +class AnnotatedSqlDataSet: + """Class to bind a DataSet to transient properties associated with it at a given point in the SqlQueryPlan.""" + + data_set: SqlDataSet + alias: str + _metric_time_column_name: Optional[str] = None + + @property + def metric_time_column_name(self) -> str: + """Direct accessor for the optional metric time name, only safe to call when we know that value is set.""" + assert ( + self._metric_time_column_name + ), "Expected a valid metric time dimension name to be associated with this dataset, but did not get one!" 
+ return self._metric_time_column_name diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index e85eb2af26..d6169cf477 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -6,9 +6,8 @@ from typing import Callable, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, TypeVar, Union from dbt_semantic_interfaces.enum_extension import assert_values_exhausted -from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType -from dbt_semantic_interfaces.references import EntityReference, MetricModelReference, SemanticModelElementReference +from dbt_semantic_interfaces.references import MetricModelReference, SemanticModelElementReference from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType from dbt_semantic_interfaces.type_enums.conversion_calculation_type import ConversionCalculationType from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation @@ -466,65 +465,35 @@ def visit_source_node(self, node: ReadSqlSourceNode) -> SqlDataSet: def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDataSet: """Generate time range join SQL.""" table_alias_to_instance_set: OrderedDict[str, InstanceSet] = OrderedDict() - input_data_set = node.parent_node.accept(self) - input_data_set_alias = self._next_unique_table_alias() - - # Find requested agg_time_dimensions in parent instance set. - # Will use instance with the smallest base granularity in time spine join. - agg_time_dimension_instance_for_join: Optional[TimeDimensionInstance] = None - requested_agg_time_dimension_instances: Tuple[TimeDimensionInstance, ...] = () - for instance in input_data_set.instance_set.time_dimension_instances: - if instance.spec in node.queried_agg_time_dimension_specs: - requested_agg_time_dimension_instances += (instance,) - if not agg_time_dimension_instance_for_join or ( - instance.spec.time_granularity.base_granularity.to_int() - < agg_time_dimension_instance_for_join.spec.time_granularity.base_granularity.to_int() - ): - agg_time_dimension_instance_for_join = instance - assert ( - agg_time_dimension_instance_for_join - ), "Specified metric time spec not found in parent data set. This should have been caught by validations." + parent_data_set = node.parent_node.accept(self) + parent_data_set_alias = self._next_unique_table_alias() + # Assemble time_spine dataset. + agg_time_dimension_instances = parent_data_set.instances_for_time_dimensions( + node.queried_agg_time_dimension_specs + ) + join_on_instance = self._choose_instance_for_time_spine_join(agg_time_dimension_instances) time_spine_data_set_alias = self._next_unique_table_alias() - - # Assemble time_spine dataset with requested agg time dimension instances selected. time_spine_data_set = self._make_time_spine_data_set( - agg_time_dimension_instances=requested_agg_time_dimension_instances, - time_range_constraint=node.time_range_constraint, + agg_time_dimension_instances=agg_time_dimension_instances, time_range_constraint=node.time_range_constraint ) table_alias_to_instance_set[time_spine_data_set_alias] = time_spine_data_set.instance_set + # Build the join description. 
+ join_spec = join_on_instance.spec + annotated_parent = parent_data_set.annotate(alias=parent_data_set_alias, metric_time_spec=join_spec) + annotated_time_spine = time_spine_data_set.annotate(alias=time_spine_data_set_alias, metric_time_spec=join_spec) join_desc = SqlQueryPlanJoinBuilder.make_cumulative_metric_time_range_join_description( - node=node, - metric_data_set=AnnotatedSqlDataSet( - data_set=input_data_set, - alias=input_data_set_alias, - _metric_time_column_name=input_data_set.column_association_for_time_dimension( - agg_time_dimension_instance_for_join.spec - ).column_name, - ), - time_spine_data_set=AnnotatedSqlDataSet( - data_set=time_spine_data_set, - alias=time_spine_data_set_alias, - _metric_time_column_name=time_spine_data_set.column_association_for_time_dimension( - agg_time_dimension_instance_for_join.spec - ).column_name, - ), + node=node, metric_data_set=annotated_parent, time_spine_data_set=annotated_time_spine ) # Remove instances of agg_time_dimension from input data set. They'll be replaced with time spine instances. - agg_time_dimension_specs = tuple(dim.spec for dim in requested_agg_time_dimension_instances) - modified_input_instance_set = input_data_set.instance_set.transform( - FilterElements(exclude_specs=InstanceSpecSet(time_dimension_specs=agg_time_dimension_specs)) + table_alias_to_instance_set[parent_data_set_alias] = parent_data_set.instance_set.transform( + FilterElements(exclude_specs=InstanceSpecSet(time_dimension_specs=node.queried_agg_time_dimension_specs)) ) - table_alias_to_instance_set[input_data_set_alias] = modified_input_instance_set - # The output instances are the same as the input instances. - output_instance_set = ChangeAssociatedColumns(self._column_association_resolver).transform( - input_data_set.instance_set - ) return SqlDataSet( - instance_set=output_instance_set, + instance_set=parent_data_set.instance_set, # The output instances are the same as the input instances. sql_select_node=SqlSelectStatementNode.create( description=node.description, select_columns=create_simple_select_columns_for_instance_sets( @@ -1390,35 +1359,29 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe ), ) + def _choose_instance_for_time_spine_join( + self, agg_time_dimension_instances: Sequence[TimeDimensionInstance] + ) -> TimeDimensionInstance: + """Find the agg_time_dimension instance with the smallest grain to use for the time spine join.""" + # We can't use a date part spec to join to the time spine, so filter those out. + agg_time_dimension_instances = [ + instance for instance in agg_time_dimension_instances if not instance.spec.date_part + ] + assert len(agg_time_dimension_instances) > 0, ( + "No appropriate agg_time_dimension was found to join to the time spine. " + "This indicates that the dataflow plan was configured incorrectly." + ) + agg_time_dimension_instances.sort(key=lambda instance: instance.spec.time_granularity.base_granularity.to_int()) + return agg_time_dimension_instances[0] + def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet: # noqa: D102 parent_data_set = node.parent_node.accept(self) parent_alias = self._next_unique_table_alias() - if node.use_custom_agg_time_dimension: - agg_time_dimension = node.requested_agg_time_dimension_specs[0] - agg_time_element_name = agg_time_dimension.element_name - agg_time_entity_links: Tuple[EntityReference, ...] 
= agg_time_dimension.entity_links - else: - agg_time_element_name = METRIC_TIME_ELEMENT_NAME - agg_time_entity_links = () - - # Find the time dimension instances in the parent data set that match the one we want to join with. - agg_time_dimension_instances: List[TimeDimensionInstance] = [] - for instance in parent_data_set.instance_set.time_dimension_instances: - if ( - instance.spec.date_part is None # Ensure we don't join using an instance with date part - and instance.spec.element_name == agg_time_element_name - and instance.spec.entity_links == agg_time_entity_links - ): - agg_time_dimension_instances.append(instance) - - # Choose the instance with the smallest base granularity available. - agg_time_dimension_instances.sort(key=lambda instance: instance.spec.time_granularity.base_granularity.to_int()) - assert len(agg_time_dimension_instances) > 0, ( - "Couldn't find requested agg_time_dimension in parent data set. The dataflow plan may have been " - "configured incorrectly." + agg_time_dimension_instances = parent_data_set.instances_for_time_dimensions( + node.requested_agg_time_dimension_specs ) - agg_time_dimension_instance_for_join = agg_time_dimension_instances[0] + agg_time_dimension_instance_for_join = self._choose_instance_for_time_spine_join(agg_time_dimension_instances) # Build time spine data set using the requested agg_time_dimension name. time_spine_alias = self._next_unique_table_alias() @@ -1439,32 +1402,11 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet parent_alias=parent_alias, ) - # Select all instances from the parent data set, EXCEPT agg_time_dimensions. - # The agg_time_dimensions will be selected from the time spine data set. - time_dimensions_to_select_from_parent: Tuple[TimeDimensionInstance, ...] = () - time_dimensions_to_select_from_time_spine: Tuple[TimeDimensionInstance, ...] 
= () - for time_dimension_instance in parent_data_set.instance_set.time_dimension_instances: - if ( - time_dimension_instance.spec.element_name == agg_time_element_name - and time_dimension_instance.spec.entity_links == agg_time_entity_links - ): - time_dimensions_to_select_from_time_spine += (time_dimension_instance,) - else: - time_dimensions_to_select_from_parent += (time_dimension_instance,) - parent_instance_set = InstanceSet( - measure_instances=parent_data_set.instance_set.measure_instances, - dimension_instances=parent_data_set.instance_set.dimension_instances, - time_dimension_instances=tuple( - time_dimension_instance - for time_dimension_instance in parent_data_set.instance_set.time_dimension_instances - if not ( - time_dimension_instance.spec.element_name == agg_time_element_name - and time_dimension_instance.spec.entity_links == agg_time_entity_links - ) - ), - entity_instances=parent_data_set.instance_set.entity_instances, - metric_instances=parent_data_set.instance_set.metric_instances, - metadata_instances=parent_data_set.instance_set.metadata_instances, + # Select all instances from the parent data set EXCEPT agg time dimensions, which will be selected from the time spine + parent_instance_set = parent_data_set.instance_set.transform( + FilterElements( + exclude_specs=InstanceSpecSet(time_dimension_specs=tuple(node.requested_agg_time_dimension_specs)) + ) ) parent_select_columns = create_simple_select_columns_for_instance_sets( self._column_association_resolver, OrderedDict({parent_alias: parent_instance_set}) @@ -1492,19 +1434,9 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet and original_time_spine_dim_instance.spec not in node.requested_agg_time_dimension_specs ) + # TODO: column-building is handled in 2 different places (here and _make_time_spine_data_set) # Add requested granularities (if different from time_spine) and date_parts to time spine column. - for time_dimension_instance in time_dimensions_to_select_from_time_spine: - time_dimension_spec = time_dimension_instance.spec - if ( - time_dimension_spec.time_granularity.base_granularity.to_int() - < original_time_spine_dim_instance.spec.time_granularity.base_granularity.to_int() - ): - raise RuntimeError( - f"Can't join to time spine for a time dimension with a smaller granularity than that of the time " - f"spine column. Got {time_dimension_spec.time_granularity} for time dimension, " - f"{original_time_spine_dim_instance.spec.time_granularity} for time spine." - ) - + for time_dimension_spec in node.requested_agg_time_dimension_specs: # Apply grain to time spine select expression, unless grain already matches original time spine column. should_skip_date_trunc = ( time_dimension_spec.time_granularity == original_time_spine_dim_instance.spec.time_granularity @@ -1533,6 +1465,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet # Apply date_part to time spine column select expression. if time_dimension_spec.date_part: select_expr = SqlExtractExpression.create(date_part=time_dimension_spec.date_part, arg=select_expr) + time_dim_spec = original_time_spine_dim_instance.spec.with_grain_and_date_part( time_granularity=time_dimension_spec.time_granularity, date_part=time_dimension_spec.date_part ) @@ -1615,10 +1548,10 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod ) # Build output time spine instances and columns. 
- time_spine_instance = TimeDimensionInstance( + time_spine_instance = TimeDimensionInstance.from_properties( defined_from=parent_time_dimension_instance.defined_from, - associated_columns=(self._column_association_resolver.resolve_spec(node.time_dimension_spec),), spec=node.time_dimension_spec, + column_association_resolver=self._column_association_resolver, ) time_spine_instance_set = InstanceSet(time_dimension_instances=(time_spine_instance,)) time_spine_select_columns = ( diff --git a/metricflow/plan_conversion/instance_converters.py b/metricflow/plan_conversion/instance_converters.py index c6e25835b5..b801d958d3 100644 --- a/metricflow/plan_conversion/instance_converters.py +++ b/metricflow/plan_conversion/instance_converters.py @@ -803,6 +803,7 @@ def transform(self, instance_set: InstanceSet) -> SelectColumnSet: # noqa: D102 ) +# TODO: delete this class & all uses. It doesn't do anything. class ChangeAssociatedColumns(InstanceSetTransform[InstanceSet]): """Change the columns associated with instances to the one specified by the resolver. diff --git a/metricflow/plan_conversion/sql_join_builder.py b/metricflow/plan_conversion/sql_join_builder.py index 13954157f7..4da3945e3a 100644 --- a/metricflow/plan_conversion/sql_join_builder.py +++ b/metricflow/plan_conversion/sql_join_builder.py @@ -12,7 +12,7 @@ from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode from metricflow.dataflow.nodes.join_to_base import JoinDescription from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode -from metricflow.dataset.sql_dataset import SqlDataSet +from metricflow.dataset.sql_dataset import AnnotatedSqlDataSet from metricflow.plan_conversion.sql_expression_builders import make_coalesced_expr from metricflow.sql.sql_exprs import ( SqlColumnReference, @@ -45,23 +45,6 @@ class ColumnEqualityDescription: treat_nulls_as_equal: bool = False -@dataclass(frozen=True) -class AnnotatedSqlDataSet: - """Class to bind a DataSet to transient properties associated with it at a given point in the SqlQueryPlan.""" - - data_set: SqlDataSet - alias: str - _metric_time_column_name: Optional[str] = None - - @property - def metric_time_column_name(self) -> str: - """Direct accessor for the optional metric time name, only safe to call when we know that value is set.""" - assert ( - self._metric_time_column_name - ), "Expected a valid metric time dimension name to be associated with this dataset, but did not get one!" 
-        return self._metric_time_column_name
-
-
 class SqlQueryPlanJoinBuilder:
     """Helper class for constructing various join components in a SqlQueryPlan."""
diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py
index e45b8bd792..9f0c80f891 100644
--- a/metricflow/sql/sql_plan.py
+++ b/metricflow/sql/sql_plan.py
@@ -243,6 +243,8 @@ def create_copy(self) -> SqlSelectStatementNode:  # noqa: D102
             distinct=self.distinct,
         )
 
+    # TODO: add helper to get column from spec
+
 
 @dataclass(frozen=True, eq=False)
 class SqlTableNode(SqlQueryPlanNode):
diff --git a/tests_metricflow/query_rendering/test_cumulative_metric_rendering.py b/tests_metricflow/query_rendering/test_cumulative_metric_rendering.py
index 0dbfae51fc..c8384b85ef 100644
--- a/tests_metricflow/query_rendering/test_cumulative_metric_rendering.py
+++ b/tests_metricflow/query_rendering/test_cumulative_metric_rendering.py
@@ -611,6 +611,34 @@ def test_derived_cumulative_metric_with_non_default_grains(
     )
 
 
+@pytest.mark.sql_engine_snapshot
+def test_cumulative_metric_with_metric_time_where_filter_not_in_group_by(
+    request: FixtureRequest,
+    mf_test_configuration: MetricFlowTestConfiguration,
+    dataflow_plan_builder: DataflowPlanBuilder,
+    dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
+    query_parser: MetricFlowQueryParser,
+    sql_client: SqlClient,
+) -> None:
+    """Test querying a cumulative metric with a metric_time where filter that is not in the group by."""
+    query_spec = query_parser.parse_and_validate_query(
+        metric_names=("trailing_2_months_revenue_sub_10",),
+        group_by_names=("metric_time__week",),
+        where_constraints=[
+            PydanticWhereFilter(where_sql_template=("{{ TimeDimension('metric_time', 'day') }} >= '2020-01-03'"))
+        ],
+    ).query_spec
+
+    render_and_check(
+        request=request,
+        mf_test_configuration=mf_test_configuration,
+        dataflow_to_sql_converter=dataflow_to_sql_converter,
+        sql_client=sql_client,
+        dataflow_plan_builder=dataflow_plan_builder,
+        query_spec=query_spec,
+    )
+
+
 # TODO: write the following tests when unblocked
 # - Query cumulative metric with non-day default_grain (using default grain and non-default grain)
 # - Query 2 metrics with different default_grains using metric_time (no grain specified)
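
Note on the `_choose_instance_for_time_spine_join` helper introduced above: it centralizes a rule that `visit_join_over_time_range_node` and `visit_join_to_time_spine_node` previously each implemented inline, namely filter out any date_part instances (a date part spec can't anchor a time spine join) and then join on the candidate with the smallest base grain. Below is a minimal sketch of that selection rule; the `SpecStub` type and its integer grain ranks are illustrative assumptions for the sketch, not MetricFlow's real `TimeDimensionSpec` API.

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass(frozen=True)
class SpecStub:
    """Illustrative stand-in for TimeDimensionSpec, reduced to the fields the selection rule reads."""

    name: str
    base_granularity_rank: int  # smaller rank = finer grain, e.g. day < week < month
    date_part: Optional[str] = None  # e.g. "dow"; set when the spec extracts a date part


def choose_spec_for_time_spine_join(specs: Sequence[SpecStub]) -> SpecStub:
    """Pick the finest-grained candidate, excluding date_part specs (they can't anchor the join)."""
    candidates = [spec for spec in specs if spec.date_part is None]
    assert candidates, "No appropriate agg_time_dimension was found to join to the time spine."
    # Stable sort, then take the first element, mirroring the helper in the patch;
    # min() with the same key would behave identically.
    ordered = sorted(candidates, key=lambda spec: spec.base_granularity_rank)
    return ordered[0]


if __name__ == "__main__":
    specs = [
        SpecStub("metric_time__month", base_granularity_rank=5),
        SpecStub("metric_time__day", base_granularity_rank=3),
        SpecStub("metric_time__extract_dow", base_granularity_rank=3, date_part="dow"),
    ]
    assert choose_spec_for_time_spine_join(specs).name == "metric_time__day"

Keeping the finest available grain is the safe choice here: coarser requested grains and date parts can still be derived from it downstream (via DATE_TRUNC / EXTRACT on the time spine column), while the reverse is not possible.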