diff --git a/metricflow-semantics/metricflow_semantics/instances.py b/metricflow-semantics/metricflow_semantics/instances.py
index 779a2c879..c6a0f8e2b 100644
--- a/metricflow-semantics/metricflow_semantics/instances.py
+++ b/metricflow-semantics/metricflow_semantics/instances.py
@@ -49,7 +49,9 @@ class MdoInstance(ABC, Generic[SpecT]):
     @property
     def associated_column(self) -> ColumnAssociation:
         """Helper for getting the associated column until support for multiple associated columns is added."""
-        assert len(self.associated_columns) == 1
+        assert (
+            len(self.associated_columns) == 1
+        ), f"Expected exactly one column for {self.__class__.__name__}, but got {self.associated_columns}"
         return self.associated_columns[0]
 
     def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT:
diff --git a/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py b/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py
index 40cc5fe81..07b0d73f4 100644
--- a/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py
+++ b/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py
@@ -102,13 +102,28 @@ def as_tuple(self) -> Tuple[LinkableInstanceSpec, ...]:  # noqa: D102
             )
         )
 
+    def add_specs(
+        self,
+        dimension_specs: Tuple[DimensionSpec, ...] = (),
+        time_dimension_specs: Tuple[TimeDimensionSpec, ...] = (),
+        entity_specs: Tuple[EntitySpec, ...] = (),
+        group_by_metric_specs: Tuple[GroupByMetricSpec, ...] = (),
+    ) -> LinkableSpecSet:
+        """Return a new set with the new specs in addition to the existing ones."""
+        return LinkableSpecSet(
+            dimension_specs=self.dimension_specs + dimension_specs,
+            time_dimension_specs=self.time_dimension_specs + time_dimension_specs,
+            entity_specs=self.entity_specs + entity_specs,
+            group_by_metric_specs=self.group_by_metric_specs + group_by_metric_specs,
+        )
+
     @override
     def merge(self, other: LinkableSpecSet) -> LinkableSpecSet:
-        return LinkableSpecSet(
-            dimension_specs=self.dimension_specs + other.dimension_specs,
-            time_dimension_specs=self.time_dimension_specs + other.time_dimension_specs,
-            entity_specs=self.entity_specs + other.entity_specs,
-            group_by_metric_specs=self.group_by_metric_specs + other.group_by_metric_specs,
+        return self.add_specs(
+            dimension_specs=other.dimension_specs,
+            time_dimension_specs=other.time_dimension_specs,
+            entity_specs=other.entity_specs,
+            group_by_metric_specs=other.group_by_metric_specs,
         )
 
     @classmethod
diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py
index c7cdc5ebb..0ad7e9ed3 100644
--- a/metricflow/dataflow/builder/dataflow_plan_builder.py
+++ b/metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -182,7 +182,6 @@ def _build_query_output_node(
             where_filter_specs=(),
             pushdown_enabled_types=frozenset({PredicateInputType.TIME_RANGE_CONSTRAINT}),
         )
-
         return self._build_metrics_output_node(
             metric_specs=tuple(
                 MetricSpec(
@@ -236,6 +235,13 @@ def _optimize_plan(self, plan: DataflowPlan, optimizations: FrozenSet[DataflowPl
 
         return plan
 
+    def _get_minimum_metric_time_spec_for_metric(self, metric_reference: MetricReference) -> TimeDimensionSpec:
+        """Get the metric_time spec with the minimum queryable granularity for the given metric."""
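+        # This is the metric's "default" grain: the one used when a query does not specify a
+        # granularity for metric_time.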
+        min_granularity = ExpandedTimeGranularity.from_time_granularity(
+            self._metric_lookup.get_min_queryable_time_granularity(metric_reference)
+        )
+        return DataSet.metric_time_dimension_spec(min_granularity)
+
     def _build_aggregated_conversion_node(
         self,
         metric_spec: MetricSpec,
@@ -307,14 +313,11 @@ def _build_aggregated_conversion_node(
         # Get the time dimension used to calculate the conversion window
-        # Currently, both the base/conversion measure uses metric_time as it's the default agg time dimension.
+        # Currently, both the base and conversion measures use metric_time, as it's the default agg time dimension.
         # However, eventually, there can be user-specified time dimensions used for this calculation.
-        default_granularity = ExpandedTimeGranularity.from_time_granularity(
-            self._metric_lookup.get_min_queryable_time_granularity(metric_spec.reference)
-        )
-        metric_time_dimension_spec = DataSet.metric_time_dimension_spec(default_granularity)
+        min_metric_time_spec = self._get_minimum_metric_time_spec_for_metric(metric_spec.reference)
 
         # Filter the source nodes with only the required specs needed for the calculation
         constant_property_specs = []
-        required_local_specs = [base_measure_spec.measure_spec, entity_spec, metric_time_dimension_spec] + list(
+        required_local_specs = [base_measure_spec.measure_spec, entity_spec, min_metric_time_spec] + list(
             base_measure_recipe.required_local_linkable_specs.as_tuple
         )
         for constant_property in constant_properties or []:
@@ -345,10 +348,10 @@ def _build_aggregated_conversion_node(
         # adjusted in the opposite direction.
         join_conversion_node = JoinConversionEventsNode.create(
             base_node=unaggregated_base_measure_node,
-            base_time_dimension_spec=metric_time_dimension_spec,
+            base_time_dimension_spec=min_metric_time_spec,
             conversion_node=unaggregated_conversion_measure_node,
             conversion_measure_spec=conversion_measure_spec.measure_spec,
-            conversion_time_dimension_spec=metric_time_dimension_spec,
+            conversion_time_dimension_spec=min_metric_time_spec,
             unique_identifier_keys=(MetadataSpec(MetricFlowReservedKeywords.MF_INTERNAL_UUID.value),),
             entity_spec=entity_spec,
             window=window,
@@ -444,21 +447,19 @@ def _build_cumulative_metric_output_node(
         predicate_pushdown_state: PredicatePushdownState,
         for_group_by_source_node: bool = False,
     ) -> DataflowPlanNode:
-        # TODO: [custom granularity] Figure out how to support custom granularities as defaults
-        default_granularity = ExpandedTimeGranularity.from_time_granularity(
-            self._metric_lookup.get_min_queryable_time_granularity(metric_spec.reference)
-        )
+        min_metric_time_spec = self._get_minimum_metric_time_spec_for_metric(metric_spec.reference)
+        min_granularity = min_metric_time_spec.time_granularity
 
         queried_agg_time_dimensions = queried_linkable_specs.included_agg_time_dimension_specs_for_metric(
             metric_reference=metric_spec.reference, metric_lookup=self._metric_lookup
         )
-        query_includes_agg_time_dimension_with_default_granularity = False
+        query_includes_agg_time_dimension_with_min_granularity = False
         for time_dimension_spec in queried_agg_time_dimensions:
-            if time_dimension_spec.time_granularity == default_granularity:
-                query_includes_agg_time_dimension_with_default_granularity = True
+            if time_dimension_spec.time_granularity == min_granularity:
+                query_includes_agg_time_dimension_with_min_granularity = True
                 break
 
-        if query_includes_agg_time_dimension_with_default_granularity or not queried_agg_time_dimensions:
+        if query_includes_agg_time_dimension_with_min_granularity or len(queried_agg_time_dimensions) == 0:
             return self._build_base_metric_output_node(
                 metric_spec=metric_spec,
                 queried_linkable_specs=queried_linkable_specs,
@@ -467,14 +468,11 @@ def _build_cumulative_metric_output_node(
                 for_group_by_source_node=for_group_by_source_node,
             )
 
-        # If a cumulative metric is queried without default granularity, it will need to be aggregated twice -
+        # If a cumulative metric is queried without its minimum granularity, it will need to be aggregated twice:
         # once as a normal metric, and again using a window function to narrow down to one row per granularity period.
-        # In this case, add metric time at the default granularity to the linkable specs. It will be used in the order by
+        # In this case, add metric time at the minimum granularity to the linkable specs. It will be used in the order by
         # clause of the window function and later excluded from the output selections.
-        default_metric_time = DataSet.metric_time_dimension_spec(default_granularity)
-        include_linkable_specs = queried_linkable_specs.merge(
-            LinkableSpecSet(time_dimension_specs=(default_metric_time,))
-        )
+        include_linkable_specs = queried_linkable_specs.add_specs(time_dimension_specs=(min_metric_time_spec,))
         compute_metrics_node = self._build_base_metric_output_node(
             metric_spec=metric_spec,
             queried_linkable_specs=include_linkable_specs,
@@ -485,7 +483,7 @@ def _build_cumulative_metric_output_node(
         return WindowReaggregationNode.create(
             parent_node=compute_metrics_node,
             metric_spec=metric_spec,
-            order_by_spec=default_metric_time,
+            order_by_spec=min_metric_time_spec,
             partition_by_specs=queried_linkable_specs.as_tuple,
         )
 
@@ -1613,10 +1611,6 @@ def _build_aggregated_measure_from_measure_source_node(
 
         # If querying an offset metric, join to time spine before aggregation.
         if before_aggregation_time_spine_join_description is not None:
-            assert queried_agg_time_dimension_specs, (
-                "Joining to time spine requires querying with metric time or the appropriate agg_time_dimension."
-                "This should have been caught by validations."
-            )
             assert before_aggregation_time_spine_join_description.join_type is SqlJoinType.INNER, (
                 f"Expected {SqlJoinType.INNER} for joining to time spine before aggregation. Remove this if there's a "
                 f"new use case."
diff --git a/metricflow/dataflow/nodes/join_over_time.py b/metricflow/dataflow/nodes/join_over_time.py
index c766bb7dc..82efccf88 100644
--- a/metricflow/dataflow/nodes/join_over_time.py
+++ b/metricflow/dataflow/nodes/join_over_time.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Tuple
 
 from dbt_semantic_interfaces.protocols import MetricTimeWindow
 from dbt_semantic_interfaces.type_enums import TimeGranularity
@@ -26,7 +26,7 @@ class JoinOverTimeRangeNode(DataflowPlanNode):
         time_range_constraint: Time range to aggregate over.
     """
 
-    queried_agg_time_dimension_specs: Sequence[TimeDimensionSpec]
+    queried_agg_time_dimension_specs: Tuple[TimeDimensionSpec, ...]
     window: Optional[MetricTimeWindow]
     grain_to_date: Optional[TimeGranularity]
     time_range_constraint: Optional[TimeRangeConstraint]
 
@@ -38,7 +38,7 @@ def __post_init__(self) -> None:  # noqa: D105
 
     @staticmethod
     def create(  # noqa: D102
         parent_node: DataflowPlanNode,
-        queried_agg_time_dimension_specs: Sequence[TimeDimensionSpec],
+        queried_agg_time_dimension_specs: Tuple[TimeDimensionSpec, ...],
         window: Optional[MetricTimeWindow] = None,
         grain_to_date: Optional[TimeGranularity] = None,
        time_range_constraint: Optional[TimeRangeConstraint] = None,
diff --git a/metricflow/dataflow/nodes/join_to_time_spine.py b/metricflow/dataflow/nodes/join_to_time_spine.py
index a17b2e428..18ab681f0 100644
--- a/metricflow/dataflow/nodes/join_to_time_spine.py
+++ b/metricflow/dataflow/nodes/join_to_time_spine.py
@@ -31,7 +31,9 @@ class JoinToTimeSpineNode(DataflowPlanNode, ABC):
         offset_to_grain: Granularity period to offset the parent dataset to when joining to time spine.
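+        requested_agg_time_dimension_specs: Agg time dimension specs to select from the time spine.
+        use_custom_agg_time_dimension: If True, join on the metric's custom agg_time_dimension rather than metric_time.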
""" + # TODO: rename property to required_agg_time_dimension_specs requested_agg_time_dimension_specs: Sequence[TimeDimensionSpec] + # TODO remove this property use_custom_agg_time_dimension: bool join_type: SqlJoinType time_range_constraint: Optional[TimeRangeConstraint] diff --git a/metricflow/dataset/sql_dataset.py b/metricflow/dataset/sql_dataset.py index 363dbac33..2b86eb479 100644 --- a/metricflow/dataset/sql_dataset.py +++ b/metricflow/dataset/sql_dataset.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import List, Optional, Sequence +from dataclasses import dataclass +from typing import List, Optional, Sequence, Tuple from dbt_semantic_interfaces.references import SemanticModelReference from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set -from metricflow_semantics.instances import EntityInstance, InstanceSet +from metricflow_semantics.instances import EntityInstance, InstanceSet, TimeDimensionInstance from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat from metricflow_semantics.specs.column_assoc import ColumnAssociation from metricflow_semantics.specs.dimension_spec import DimensionSpec @@ -122,32 +123,57 @@ def column_association_for_dimension( return column_associations_to_return[0] - def column_association_for_time_dimension( - self, - time_dimension_spec: TimeDimensionSpec, - ) -> ColumnAssociation: - """Given the name of the time dimension, return the set of columns associated with it in the data set.""" + def instances_for_time_dimensions( + self, time_dimension_specs: Sequence[TimeDimensionSpec] + ) -> Tuple[TimeDimensionInstance, ...]: + """Return the instances associated with these specs in the data set.""" + time_dimension_specs_set = set(time_dimension_specs) matching_instances = 0 - column_associations_to_return = None + instances_to_return: Tuple[TimeDimensionInstance, ...] 
         matching_instances = 0
-        column_associations_to_return = None
+        instances_to_return: Tuple[TimeDimensionInstance, ...] = ()
         for time_dimension_instance in self.instance_set.time_dimension_instances:
-            if time_dimension_instance.spec == time_dimension_spec:
-                column_associations_to_return = time_dimension_instance.associated_columns
+            if time_dimension_instance.spec in time_dimension_specs_set:
+                instances_to_return += (time_dimension_instance,)
                 matching_instances += 1
-        if matching_instances > 1:
+        if matching_instances != len(time_dimension_specs_set):
             raise RuntimeError(
-                f"More than one time dimension instance with spec {time_dimension_spec} in "
-                f"instance set: {self.instance_set}"
+                f"Unexpected number of time dimension instances found matching specs.\nSpecs: {time_dimension_specs_set}\n"
+                f"Instances: {instances_to_return}"
             )
-        if not column_associations_to_return:
-            raise RuntimeError(
-                f"No time dimension instances with spec {time_dimension_spec} in instance set: {self.instance_set}"
-            )
+        return instances_to_return
 
-        return column_associations_to_return[0]
+    def instance_for_time_dimension(self, time_dimension_spec: TimeDimensionSpec) -> TimeDimensionInstance:
+        """Given a time dimension spec, return the instance associated with it in the data set."""
+        return self.instances_for_time_dimensions((time_dimension_spec,))[0]
+
+    def column_association_for_time_dimension(self, time_dimension_spec: TimeDimensionSpec) -> ColumnAssociation:
+        """Given a time dimension spec, return the column associated with it in the data set."""
+        return self.instance_for_time_dimension(time_dimension_spec).associated_column
 
     @property
     @override
     def semantic_model_reference(self) -> Optional[SemanticModelReference]:
         return None
+
+    def annotate(self, alias: str, metric_time_spec: TimeDimensionSpec) -> AnnotatedSqlDataSet:
+        """Convert to an AnnotatedSqlDataSet with specified metadata."""
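+        # Raises via the instance lookup if metric_time_spec has no matching instance in this data set.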
+        metric_time_column_name = self.column_association_for_time_dimension(metric_time_spec).column_name
+        return AnnotatedSqlDataSet(data_set=self, alias=alias, _metric_time_column_name=metric_time_column_name)
+
+
+@dataclass(frozen=True)
+class AnnotatedSqlDataSet:
+    """Class to bind a DataSet to transient properties associated with it at a given point in the SqlQueryPlan."""
+
+    data_set: SqlDataSet
+    alias: str
+    _metric_time_column_name: Optional[str] = None
+
+    @property
+    def metric_time_column_name(self) -> str:
+        """Direct accessor for the optional metric time name, only safe to call when we know that value is set."""
+        assert (
+            self._metric_time_column_name
+        ), "Expected a valid metric time dimension name to be associated with this dataset, but did not get one!"
+        return self._metric_time_column_name
diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py
index 61d3d510e..d3a4abb71 100644
--- a/metricflow/plan_conversion/dataflow_to_sql.py
+++ b/metricflow/plan_conversion/dataflow_to_sql.py
@@ -6,9 +6,8 @@
 from typing import Callable, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, TypeVar, Union
 
 from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
-from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
 from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType
-from dbt_semantic_interfaces.references import EntityReference, MetricModelReference, SemanticModelElementReference
+from dbt_semantic_interfaces.references import MetricModelReference, SemanticModelElementReference
 from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType
 from dbt_semantic_interfaces.type_enums.conversion_calculation_type import ConversionCalculationType
 from dbt_semantic_interfaces.type_enums.period_agg import PeriodAggregation
@@ -466,70 +465,41 @@ def visit_source_node(self, node: ReadSqlSourceNode) -> SqlDataSet:
 
     def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode) -> SqlDataSet:
         """Generate time range join SQL."""
         table_alias_to_instance_set: OrderedDict[str, InstanceSet] = OrderedDict()
-        input_data_set = node.parent_node.accept(self)
-        input_data_set_alias = self._next_unique_table_alias()
+        parent_data_set = node.parent_node.accept(self)
+        parent_data_set_alias = self._next_unique_table_alias()
 
-        # Find requested agg_time_dimensions in parent instance set.
-        # Will use instance with the smallest base granularity in time spine join.
-        agg_time_dimension_instance_for_join: Optional[TimeDimensionInstance] = None
-        requested_agg_time_dimension_instances: Tuple[TimeDimensionInstance, ...] = ()
-        for instance in input_data_set.instance_set.time_dimension_instances:
-            if instance.spec in node.queried_agg_time_dimension_specs:
-                requested_agg_time_dimension_instances += (instance,)
-                if not agg_time_dimension_instance_for_join or (
-                    instance.spec.time_granularity.base_granularity.to_int()
-                    < agg_time_dimension_instance_for_join.spec.time_granularity.base_granularity.to_int()
-                ):
-                    agg_time_dimension_instance_for_join = instance
-        assert (
-            agg_time_dimension_instance_for_join
-        ), "Specified metric time spec not found in parent data set. This should have been caught by validations."
+        # For the purposes of this node, use base grains. Custom grains will be joined later in the dataflow plan.
+        agg_time_dimension_specs = tuple({spec.with_base_grain() for spec in node.queried_agg_time_dimension_specs})
 
+        # Assemble time_spine dataset with a column for each agg_time_dimension requested.
+        agg_time_dimension_instances = parent_data_set.instances_for_time_dimensions(agg_time_dimension_specs)
         time_spine_data_set_alias = self._next_unique_table_alias()
-
-        # Assemble time_spine dataset with requested agg time dimension instances selected.
         time_spine_data_set = self._make_time_spine_data_set(
-            agg_time_dimension_instances=requested_agg_time_dimension_instances,
-            time_range_constraint=node.time_range_constraint,
+            agg_time_dimension_instances=agg_time_dimension_instances, time_range_constraint=node.time_range_constraint
         )
-        table_alias_to_instance_set[time_spine_data_set_alias] = time_spine_data_set.instance_set
 
+        # Build the join description.
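+        # Annotate both sides with the same agg time dimension spec (smallest base grain, no date part)
+        # so the join builder knows which time column to join on in each input.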
+        join_spec = self._choose_instance_for_time_spine_join(agg_time_dimension_instances).spec
+        annotated_parent = parent_data_set.annotate(alias=parent_data_set_alias, metric_time_spec=join_spec)
+        annotated_time_spine = time_spine_data_set.annotate(alias=time_spine_data_set_alias, metric_time_spec=join_spec)
         join_desc = SqlQueryPlanJoinBuilder.make_cumulative_metric_time_range_join_description(
-            node=node,
-            metric_data_set=AnnotatedSqlDataSet(
-                data_set=input_data_set,
-                alias=input_data_set_alias,
-                _metric_time_column_name=input_data_set.column_association_for_time_dimension(
-                    agg_time_dimension_instance_for_join.spec
-                ).column_name,
-            ),
-            time_spine_data_set=AnnotatedSqlDataSet(
-                data_set=time_spine_data_set,
-                alias=time_spine_data_set_alias,
-                _metric_time_column_name=time_spine_data_set.column_association_for_time_dimension(
-                    agg_time_dimension_instance_for_join.spec
-                ).column_name,
-            ),
+            node=node, metric_data_set=annotated_parent, time_spine_data_set=annotated_time_spine
         )
 
-        # Remove instances of agg_time_dimension from input data set. They'll be replaced with time spine instances.
-        agg_time_dimension_specs = tuple(dim.spec for dim in requested_agg_time_dimension_instances)
-        modified_input_instance_set = input_data_set.instance_set.transform(
+        # Build select columns, replacing agg_time_dimensions from the parent node with columns from the time spine.
+        table_alias_to_instance_set[time_spine_data_set_alias] = time_spine_data_set.instance_set
+        table_alias_to_instance_set[parent_data_set_alias] = parent_data_set.instance_set.transform(
             FilterElements(exclude_specs=InstanceSpecSet(time_dimension_specs=agg_time_dimension_specs))
         )
-        table_alias_to_instance_set[input_data_set_alias] = modified_input_instance_set
-
-        # The output instances are the same as the input instances.
-        output_instance_set = ChangeAssociatedColumns(self._column_association_resolver).transform(
-            input_data_set.instance_set
+        select_columns = create_simple_select_columns_for_instance_sets(
+            column_resolver=self._column_association_resolver, table_alias_to_instance_set=table_alias_to_instance_set
         )
+
         return SqlDataSet(
-            instance_set=output_instance_set,
+            instance_set=parent_data_set.instance_set,  # The output instances are the same as the input instances.
             sql_select_node=SqlSelectStatementNode.create(
                 description=node.description,
-                select_columns=create_simple_select_columns_for_instance_sets(
-                    self._column_association_resolver, table_alias_to_instance_set
-                ),
+                select_columns=select_columns,
                 from_source=time_spine_data_set.checked_sql_select_node,
                 from_source_alias=time_spine_data_set_alias,
                 join_descs=(join_desc,),
@@ -1390,35 +1360,29 @@ def visit_semi_additive_join_node(self, node: SemiAdditiveJoinNode) -> SqlDataSe
             ),
         )
 
+    def _choose_instance_for_time_spine_join(
+        self, agg_time_dimension_instances: Sequence[TimeDimensionInstance]
+    ) -> TimeDimensionInstance:
+        """Find the agg_time_dimension instance with the smallest grain to use for the time spine join."""
+        # We can't use a date part spec to join to the time spine, so filter those out.
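+        # (Date part values, e.g. month-of-year, repeat across periods, so they don't identify time spine rows.)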
+        agg_time_dimension_instances = [
+            instance for instance in agg_time_dimension_instances if not instance.spec.date_part
+        ]
+        assert len(agg_time_dimension_instances) > 0, (
+            "No appropriate agg_time_dimension was found to join to the time spine. "
+            "This indicates that the dataflow plan was configured incorrectly."
+        )
+        agg_time_dimension_instances.sort(key=lambda instance: instance.spec.time_granularity.base_granularity.to_int())
+        return agg_time_dimension_instances[0]
+
     def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet:  # noqa: D102
         parent_data_set = node.parent_node.accept(self)
         parent_alias = self._next_unique_table_alias()
 
-        if node.use_custom_agg_time_dimension:
-            agg_time_dimension = node.requested_agg_time_dimension_specs[0]
-            agg_time_element_name = agg_time_dimension.element_name
-            agg_time_entity_links: Tuple[EntityReference, ...] = agg_time_dimension.entity_links
-        else:
-            agg_time_element_name = METRIC_TIME_ELEMENT_NAME
-            agg_time_entity_links = ()
-
-        # Find the time dimension instances in the parent data set that match the one we want to join with.
-        agg_time_dimension_instances: List[TimeDimensionInstance] = []
-        for instance in parent_data_set.instance_set.time_dimension_instances:
-            if (
-                instance.spec.date_part is None  # Ensure we don't join using an instance with date part
-                and instance.spec.element_name == agg_time_element_name
-                and instance.spec.entity_links == agg_time_entity_links
-            ):
-                agg_time_dimension_instances.append(instance)
-
-        # Choose the instance with the smallest base granularity available.
-        agg_time_dimension_instances.sort(key=lambda instance: instance.spec.time_granularity.base_granularity.to_int())
-        assert len(agg_time_dimension_instances) > 0, (
-            "Couldn't find requested agg_time_dimension in parent data set. The dataflow plan may have been "
-            "configured incorrectly."
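+        # Raises if any requested agg time dimension spec is missing from the parent data set.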
+        agg_time_dimension_instances = parent_data_set.instances_for_time_dimensions(
+            node.requested_agg_time_dimension_specs
         )
-        agg_time_dimension_instance_for_join = agg_time_dimension_instances[0]
+        agg_time_dimension_instance_for_join = self._choose_instance_for_time_spine_join(agg_time_dimension_instances)
 
         # Build time spine data set using the requested agg_time_dimension name.
         time_spine_alias = self._next_unique_table_alias()
@@ -1439,47 +1403,18 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
             parent_alias=parent_alias,
         )
 
-        # Select all instances from the parent data set, EXCEPT agg_time_dimensions.
-        # The agg_time_dimensions will be selected from the time spine data set.
-        time_dimensions_to_select_from_parent: Tuple[TimeDimensionInstance, ...] = ()
-        time_dimensions_to_select_from_time_spine: Tuple[TimeDimensionInstance, ...] = ()
-        for time_dimension_instance in parent_data_set.instance_set.time_dimension_instances:
-            if (
-                time_dimension_instance.spec.element_name == agg_time_element_name
-                and time_dimension_instance.spec.entity_links == agg_time_entity_links
-            ):
-                time_dimensions_to_select_from_time_spine += (time_dimension_instance,)
-            else:
-                time_dimensions_to_select_from_parent += (time_dimension_instance,)
-        parent_instance_set = InstanceSet(
-            measure_instances=parent_data_set.instance_set.measure_instances,
-            dimension_instances=parent_data_set.instance_set.dimension_instances,
-            time_dimension_instances=tuple(
-                time_dimension_instance
-                for time_dimension_instance in parent_data_set.instance_set.time_dimension_instances
-                if not (
-                    time_dimension_instance.spec.element_name == agg_time_element_name
-                    and time_dimension_instance.spec.entity_links == agg_time_entity_links
-                )
-            ),
-            entity_instances=parent_data_set.instance_set.entity_instances,
-            metric_instances=parent_data_set.instance_set.metric_instances,
-            metadata_instances=parent_data_set.instance_set.metadata_instances,
+        # Select all instances from the parent data set EXCEPT agg time dimensions, which will be selected from the time spine.
+        parent_instance_set = parent_data_set.instance_set.transform(
+            FilterElements(
+                exclude_specs=InstanceSpecSet(time_dimension_specs=tuple(node.requested_agg_time_dimension_specs))
+            )
         )
         parent_select_columns = create_simple_select_columns_for_instance_sets(
             self._column_association_resolver, OrderedDict({parent_alias: parent_instance_set})
         )
 
-        # Select matching instance from time spine data set (using base grain - custom grain will be joined in a later node).
-        original_time_spine_dim_instance: Optional[TimeDimensionInstance] = None
-        for time_dimension_instance in time_spine_dataset.instance_set.time_dimension_instances:
-            if time_dimension_instance.spec == agg_time_dimension_instance_for_join.spec:
-                original_time_spine_dim_instance = time_dimension_instance
-                break
-        assert original_time_spine_dim_instance, (
-            "Couldn't find requested agg_time_dimension_instance_for_join in time spine data set, which "
-            f"indicates it may have been configured incorrectly. Expected: {agg_time_dimension_instance_for_join.spec};"
-            f" Got: {[instance.spec for instance in time_spine_dataset.instance_set.time_dimension_instances]}"
+        original_time_spine_dim_instance = time_spine_dataset.instance_for_time_dimension(
+            agg_time_dimension_instance_for_join.spec
        )
         time_spine_column_select_expr: Union[
             SqlColumnReferenceExpression, SqlDateTruncExpression
@@ -1500,19 +1435,9 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
             and original_time_spine_dim_instance.spec not in node.requested_agg_time_dimension_specs
         )
 
+        # TODO: column-building is handled in 2 different places (here and _make_time_spine_data_set)
         # Add requested granularities (if different from time_spine) and date_parts to time spine column.
-        for time_dimension_instance in time_dimensions_to_select_from_time_spine:
-            time_dimension_spec = time_dimension_instance.spec
-            if (
-                time_dimension_spec.time_granularity.base_granularity.to_int()
-                < original_time_spine_dim_instance.spec.time_granularity.base_granularity.to_int()
-            ):
-                raise RuntimeError(
-                    f"Can't join to time spine for a time dimension with a smaller granularity than that of the time "
-                    f"spine column. Got {time_dimension_spec.time_granularity} for time dimension, "
-                    f"{original_time_spine_dim_instance.spec.time_granularity} for time spine."
-                )
-
+        for time_dimension_spec in node.requested_agg_time_dimension_specs:
             # Apply grain to time spine select expression, unless grain already matches original time spine column.
             should_skip_date_trunc = (
                 time_dimension_spec.time_granularity == original_time_spine_dim_instance.spec.time_granularity
@@ -1541,6 +1466,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
             # Apply date_part to time spine column select expression.
             if time_dimension_spec.date_part:
                 select_expr = SqlExtractExpression.create(date_part=time_dimension_spec.date_part, arg=select_expr)
+
             time_dim_spec = original_time_spine_dim_instance.spec.with_grain_and_date_part(
                 time_granularity=time_dimension_spec.time_granularity, date_part=time_dimension_spec.date_part
             )
@@ -1590,17 +1516,10 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
         # New dataset will be joined to parent dataset without a subquery, so use the same FROM alias as the parent node.
         parent_alias = parent_data_set.checked_sql_select_node.from_source_alias
 
-        parent_time_dimension_instance: Optional[TimeDimensionInstance] = None
-        for instance in parent_data_set.instance_set.time_dimension_instances:
-            if instance.spec == node.time_dimension_spec.with_base_grain():
-                parent_time_dimension_instance = instance
-                break
-        parent_column: Optional[SqlSelectColumn] = None
-        assert parent_time_dimension_instance, (
-            "JoinToCustomGranularityNode's expected time_dimension_spec not found in parent dataset instances. "
-            f"This indicates internal misconfiguration. Expected: {node.time_dimension_spec.with_base_grain()}; "
-            f"Got: {[instance.spec for instance in parent_data_set.instance_set.time_dimension_instances]}"
+        parent_time_dimension_instance = parent_data_set.instance_for_time_dimension(
+            node.time_dimension_spec.with_base_grain()
         )
+        parent_column: Optional[SqlSelectColumn] = None
         for select_column in parent_data_set.checked_sql_select_node.select_columns:
             if select_column.column_alias == parent_time_dimension_instance.associated_column.column_name:
                 parent_column = select_column
diff --git a/metricflow/plan_conversion/instance_converters.py b/metricflow/plan_conversion/instance_converters.py
index c6e25835b..b801d958d 100644
--- a/metricflow/plan_conversion/instance_converters.py
+++ b/metricflow/plan_conversion/instance_converters.py
@@ -803,6 +803,7 @@ def transform(self, instance_set: InstanceSet) -> SelectColumnSet:  # noqa: D102
         )
 
 
+# TODO: delete this class & all uses. It doesn't do anything.
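+# (It re-resolves each instance's column associations through the same resolver, so the output
+# instance set matches the input.)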
 class ChangeAssociatedColumns(InstanceSetTransform[InstanceSet]):
     """Change the columns associated with instances to the one specified by the resolver.
 
diff --git a/metricflow/plan_conversion/sql_join_builder.py b/metricflow/plan_conversion/sql_join_builder.py
index 13954157f..4da3945e3 100644
--- a/metricflow/plan_conversion/sql_join_builder.py
+++ b/metricflow/plan_conversion/sql_join_builder.py
@@ -12,7 +12,7 @@
 from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
 from metricflow.dataflow.nodes.join_to_base import JoinDescription
 from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
-from metricflow.dataset.sql_dataset import SqlDataSet
+from metricflow.dataset.sql_dataset import AnnotatedSqlDataSet
 from metricflow.plan_conversion.sql_expression_builders import make_coalesced_expr
 from metricflow.sql.sql_exprs import (
     SqlColumnReference,
@@ -45,23 +45,6 @@ class ColumnEqualityDescription:
     treat_nulls_as_equal: bool = False
 
 
-@dataclass(frozen=True)
-class AnnotatedSqlDataSet:
-    """Class to bind a DataSet to transient properties associated with it at a given point in the SqlQueryPlan."""
-
-    data_set: SqlDataSet
-    alias: str
-    _metric_time_column_name: Optional[str] = None
-
-    @property
-    def metric_time_column_name(self) -> str:
-        """Direct accessor for the optional metric time name, only safe to call when we know that value is set."""
-        assert (
-            self._metric_time_column_name
-        ), "Expected a valid metric time dimension name to be associated with this dataset, but did not get one!"
-        return self._metric_time_column_name
-
-
 class SqlQueryPlanJoinBuilder:
     """Helper class for constructing various join components in a SqlQueryPlan."""
 
diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py
index e45b8bd79..9f0c80f89 100644
--- a/metricflow/sql/sql_plan.py
+++ b/metricflow/sql/sql_plan.py
@@ -243,6 +243,8 @@ def create_copy(self) -> SqlSelectStatementNode:  # noqa: D102
             distinct=self.distinct,
         )
 
+    # TODO: add helper to get column from spec
+
 
 @dataclass(frozen=True, eq=False)
 class SqlTableNode(SqlQueryPlanNode):
diff --git a/tests_metricflow/query_rendering/test_cumulative_metric_rendering.py b/tests_metricflow/query_rendering/test_cumulative_metric_rendering.py
index 0dbfae51f..c8384b85e 100644
--- a/tests_metricflow/query_rendering/test_cumulative_metric_rendering.py
+++ b/tests_metricflow/query_rendering/test_cumulative_metric_rendering.py
@@ -611,6 +611,36 @@ def test_derived_cumulative_metric_with_non_default_grains(
     )
 
 
+@pytest.mark.sql_engine_snapshot
+def test_cumulative_metric_with_metric_time_where_filter_not_in_group_by(
+    request: FixtureRequest,
+    mf_test_configuration: MetricFlowTestConfiguration,
+    dataflow_plan_builder: DataflowPlanBuilder,
+    dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
+    query_parser: MetricFlowQueryParser,
+    sql_client: SqlClient,
+) -> None:
+    """Test a derived metric with a cumulative input, where-filtered on a metric_time grain not in the group by."""
+    query_spec = query_parser.parse_and_validate_query(
+        metric_names=("trailing_2_months_revenue_sub_10",),
+        group_by_names=("metric_time__week",),
+        where_constraints=[
+            PydanticWhereFilter(where_sql_template=("{{ TimeDimension('metric_time', 'day') }} >= '2020-01-03' "))
+        ],
+    ).query_spec
+
+    render_and_check(
+        request=request,
+        mf_test_configuration=mf_test_configuration,
+        dataflow_to_sql_converter=dataflow_to_sql_converter,
+        sql_client=sql_client,
+        dataflow_plan_builder=dataflow_plan_builder,
+        query_spec=query_spec,
+    )
+
+
 # TODO: write the following tests when unblocked
 # - Query cumulative metric with non-day default_grain (using default grain and non-default grain)
 # - Query 2 metrics with different default_grains using metric_time (no grain specified)