From 8bb00499da325b29ad628d2ab34cefd5a5dff34d Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Sun, 12 Nov 2023 12:23:15 -0800 Subject: [PATCH] Consolidate conditional logic for time spine joins into MetricInputMeasureSpec. --- .../dataflow/builder/dataflow_plan_builder.py | 217 ++++++++++++------ metricflow/dataflow/dataflow_plan.py | 6 +- metricflow/specs/specs.py | 46 +++- .../test/test_instance_serialization.py | 4 + 4 files changed, 197 insertions(+), 76 deletions(-) diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index 755f5c0358..2c1f97c6be 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -53,9 +53,11 @@ from metricflow.specs.column_assoc import ColumnAssociationResolver from metricflow.specs.specs import ( InstanceSpecSet, + JoinToTimeSpineDescription, LinkableInstanceSpec, LinkableSpecSet, LinklessEntitySpec, + MeasureCulminationDescription, MeasureSpec, MetricFlowQuerySpec, MetricInputMeasureSpec, @@ -174,12 +176,19 @@ def _build_base_metric_output_node( """Builds a node to compute a metric that is not defined from other metrics.""" metric_reference = metric_spec.reference metric = self._metric_lookup.get_metric(metric_reference) - metric_input_measure_specs = self._measures_for_metric( + metric_input_measure_spec = self._build_metric_input_measure_for_base_metric( metric_reference=metric_reference, column_association_resolver=self._column_association_resolver, + query_contains_metric_time=queried_linkable_specs.contains_metric_time, + child_metric_offset_window=metric_spec.offset_window, + child_metric_offset_to_grain=metric_spec.offset_to_grain, + culmination_description=MeasureCulminationDescription( + cumulative_window=metric.type_params.window, + cumulative_grain_to_date=metric.type_params.grain_to_date, + ) + if metric.type is MetricType.CUMULATIVE + else None, ) - assert len(metric_input_measure_specs) == 1, "Simple and cumulative metrics must have one input measure." - metric_input_measure_spec = metric_input_measure_specs[0] logger.info( f"For {metric_spec}, needed measure is:\n" @@ -190,19 +199,13 @@ def _build_base_metric_output_node( combined_where = ( combined_where.combine(metric_spec.constraint) if combined_where else metric_spec.constraint ) + aggregated_measures_node = self.build_aggregated_measure( metric_input_measure_spec=metric_input_measure_spec, - metric_spec=metric_spec, queried_linkable_specs=queried_linkable_specs, where_constraint=combined_where, time_range_constraint=time_range_constraint, - cumulative=metric.type == MetricType.CUMULATIVE, - cumulative_window=metric.type_params.window if metric.type == MetricType.CUMULATIVE else None, - cumulative_grain_to_date=( - metric.type_params.grain_to_date if metric.type == MetricType.CUMULATIVE else None - ), ) - return self.build_computed_metrics_node( metric_spec=metric_spec, aggregated_measures_node=aggregated_measures_node, @@ -225,13 +228,42 @@ def _build_derived_metric_output_node( f"For {metric.type} metric: {metric_spec}, needed metrics are:\n" f"{pformat_big_objects(metric_input_specs=metric_input_specs)}" ) + + parent_nodes: List[BaseOutput] = [] + + # TODO: Think I found an edge case that was not handled. + for metric_input_spec in metric_input_specs: + if (metric_spec.offset_to_grain is not None or metric_spec.offset_to_grain is not None) and ( + metric_input_spec.offset_window is not None or metric_input_spec.offset_to_grain is not None + ): + raise NotImplementedError( + f"Multiple descendent metrics in a derived metric hierarchy are not yet supported. " + f"For {metric_spec}, the parent metric input is {metric_input_spec}" + ) + + parent_nodes.append( + self._build_any_metric_output_node( + metric_spec=MetricSpec( + element_name=metric_input_spec.element_name, + constraint=metric_input_spec.constraint, + alias=metric_input_spec.alias, + offset_window=metric_input_spec.offset_window, + offset_to_grain=metric_input_spec.offset_to_grain, + ), + queried_linkable_specs=queried_linkable_specs, + where_constraint=where_constraint, + time_range_constraint=time_range_constraint, + ) + ) + + if len(parent_nodes) == 1: + return ComputeMetricsNode( + parent_node=parent_nodes[0], + metric_specs=[metric_spec], + ) + return ComputeMetricsNode( - parent_node=self._build_metrics_output_node( - metric_specs=metric_input_specs, - queried_linkable_specs=queried_linkable_specs, - where_constraint=where_constraint, - time_range_constraint=time_range_constraint, - ), + parent_node=CombineMetricsNode(parent_nodes=parent_nodes), metric_specs=[metric_spec], ) @@ -645,45 +677,76 @@ def build_computed_metrics_node( metric_specs=[metric_spec], ) - def _measures_for_metric( + def _build_metric_input_measure_for_base_metric( self, metric_reference: MetricReference, column_association_resolver: ColumnAssociationResolver, - ) -> Sequence[MetricInputMeasureSpec]: - """Return the measure specs required to compute the metric.""" + child_metric_offset_window: Optional[MetricTimeWindow], + child_metric_offset_to_grain: Optional[TimeGranularity], + query_contains_metric_time: bool, + culmination_description: Optional[MeasureCulminationDescription], + ) -> MetricInputMeasureSpec: + """Return the measure specs required to compute the base metric.""" metric = self._metric_lookup.get_metric(metric_reference) - input_measure_specs: List[MetricInputMeasureSpec] = [] - for input_measure in metric.input_measures: - measure_spec = MeasureSpec( - element_name=input_measure.name, - non_additive_dimension_spec=self._semantic_model_lookup.non_additive_dimension_specs_by_measure.get( - input_measure.measure_reference - ), - ) - spec = MetricInputMeasureSpec( - measure_spec=measure_spec, - constraint=WhereSpecFactory( - column_association_resolver=column_association_resolver, - ).create_from_where_filter_intersection(input_measure.filter), - alias=input_measure.alias, - join_to_timespine=input_measure.join_to_timespine, - fill_nulls_with=input_measure.fill_nulls_with, + if metric.type is MetricType.SIMPLE or metric.type is MetricType.CUMULATIVE: + pass + elif metric.type is MetricType.RATIO or metric.type is MetricType.DERIVED: + raise ValueError("This should only be called for base metrics.") + else: + assert_values_exhausted(metric.type) + + assert ( + len(metric.input_measures) == 1 + ), f"A base metric should not have multiple measures. Got{metric.input_measures}" + + input_measure = metric.input_measures[0] + + measure_spec = MeasureSpec( + element_name=input_measure.name, + non_additive_dimension_spec=self._semantic_model_lookup.non_additive_dimension_specs_by_measure.get( + input_measure.measure_reference + ), + ) + + before_aggregation_time_spine_join_description = None + # If querying an offset metric, join to time spine. + if child_metric_offset_window is not None or child_metric_offset_to_grain is not None: + before_aggregation_time_spine_join_description = JoinToTimeSpineDescription( + join_type=SqlJoinType.INNER, + offset_window=child_metric_offset_window, + offset_to_grain=child_metric_offset_to_grain, ) - input_measure_specs.append(spec) - return tuple(input_measure_specs) + # Even if the measure is configured to join to time spine, if there's no metric_time in the query, + # there's no need to join to the time spine since all metric_time will be aggregated. + after_aggregation_time_spine_join_description = None + if input_measure.join_to_timespine and query_contains_metric_time: + after_aggregation_time_spine_join_description = JoinToTimeSpineDescription( + join_type=SqlJoinType.LEFT_OUTER, + offset_window=None, + offset_to_grain=None, + ) + + return MetricInputMeasureSpec( + measure_spec=measure_spec, + offset_window=child_metric_offset_window, + offset_to_grain=child_metric_offset_to_grain, + culmination_description=culmination_description, + constraint=WhereSpecFactory( + column_association_resolver=column_association_resolver, + ).create_from_where_filter_intersection(input_measure.filter), + alias=input_measure.alias, + before_aggregation_time_spine_join_description=before_aggregation_time_spine_join_description, + after_aggregation_time_spine_join_description=after_aggregation_time_spine_join_description, + ) def build_aggregated_measure( self, metric_input_measure_spec: MetricInputMeasureSpec, - metric_spec: MetricSpec, queried_linkable_specs: LinkableSpecSet, where_constraint: Optional[WhereFilterSpec] = None, time_range_constraint: Optional[TimeRangeConstraint] = None, - cumulative: Optional[bool] = False, - cumulative_window: Optional[MetricTimeWindow] = None, - cumulative_grain_to_date: Optional[TimeGranularity] = None, ) -> BaseOutput: """Returns a node where the measures are aggregated by the linkable specs and constrained appropriately. @@ -691,8 +754,10 @@ def build_aggregated_measure( a composite set of aggregations originating from multiple semantic models, and joined into a single aggregated set of measures. """ + measure_spec = metric_input_measure_spec.measure_spec measure_constraint = metric_input_measure_spec.constraint - logger.info(f"Building aggregated measure: {metric_input_measure_spec} with constraint: {measure_constraint}") + + logger.info(f"Building aggregated measure: {measure_spec} with constraint: {measure_constraint}") if measure_constraint is None: node_where_constraint = where_constraint elif where_constraint is None: @@ -702,33 +767,30 @@ def build_aggregated_measure( return self._build_aggregated_measure_from_measure_source_node( metric_input_measure_spec=metric_input_measure_spec, - metric_spec=metric_spec, queried_linkable_specs=queried_linkable_specs, where_constraint=node_where_constraint, time_range_constraint=time_range_constraint, - cumulative=cumulative, - cumulative_window=cumulative_window, - cumulative_grain_to_date=cumulative_grain_to_date, ) def _build_aggregated_measure_from_measure_source_node( self, metric_input_measure_spec: MetricInputMeasureSpec, - metric_spec: MetricSpec, queried_linkable_specs: LinkableSpecSet, where_constraint: Optional[WhereFilterSpec] = None, time_range_constraint: Optional[TimeRangeConstraint] = None, - cumulative: Optional[bool] = False, - cumulative_window: Optional[MetricTimeWindow] = None, - cumulative_grain_to_date: Optional[TimeGranularity] = None, ) -> BaseOutput: - metric_time_dimension_specs = [ - time_dimension_spec - for time_dimension_spec in queried_linkable_specs.time_dimension_specs - if time_dimension_spec.element_name == self._metric_time_dimension_reference.element_name - ] - metric_time_dimension_requested = len(metric_time_dimension_specs) > 0 measure_spec = metric_input_measure_spec.measure_spec + cumulative = metric_input_measure_spec.culmination_description is not None + cumulative_window = ( + metric_input_measure_spec.culmination_description.cumulative_window + if metric_input_measure_spec.culmination_description is not None + else None + ) + cumulative_grain_to_date = ( + metric_input_measure_spec.culmination_description.cumulative_grain_to_date + if metric_input_measure_spec.culmination_description + else None + ) measure_properties = self._build_measure_spec_properties([measure_spec]) non_additive_dimension_spec = measure_properties.non_additive_dimension_spec @@ -787,7 +849,7 @@ def _build_aggregated_measure_from_measure_source_node( # If a cumulative metric is queried with metric_time, join over time range. # Otherwise, the measure will be aggregated over all time. time_range_node: Optional[JoinOverTimeRangeNode] = None - if cumulative and metric_time_dimension_requested: + if cumulative and queried_linkable_specs.contains_metric_time: time_range_node = JoinOverTimeRangeNode( parent_node=measure_recipe.source_node, window=cumulative_window, @@ -797,15 +859,25 @@ def _build_aggregated_measure_from_measure_source_node( # If querying an offset metric, join to time spine. join_to_time_spine_node: Optional[JoinToTimeSpineNode] = None - if metric_spec.offset_window or metric_spec.offset_to_grain: - assert metric_time_dimension_specs, "Joining to time spine requires querying with metric time." + + before_aggregation_time_spine_join_description = ( + metric_input_measure_spec.before_aggregation_time_spine_join_description + ) + if before_aggregation_time_spine_join_description is not None: + assert ( + queried_linkable_specs.contains_metric_time + ), "Joining to time spine requires querying with metric time." + assert before_aggregation_time_spine_join_description.join_type is SqlJoinType.INNER, ( + f"Expected {SqlJoinType.INNER} for joining to time spine before aggregation. Remove this if there's a " + f"new use case." + ) join_to_time_spine_node = JoinToTimeSpineNode( parent_node=time_range_node or measure_recipe.source_node, - requested_metric_time_dimension_specs=metric_time_dimension_specs, + requested_metric_time_dimension_specs=list(queried_linkable_specs.metric_time_specs), time_range_constraint=time_range_constraint, - offset_window=metric_spec.offset_window, - offset_to_grain=metric_spec.offset_to_grain, - join_type=SqlJoinType.INNER, + offset_window=before_aggregation_time_spine_join_description.offset_window, + offset_to_grain=before_aggregation_time_spine_join_description.offset_to_grain, + join_type=before_aggregation_time_spine_join_description.join_type, ) # Only get the required measure and the local linkable instances so that aggregations work correctly. @@ -839,7 +911,7 @@ def _build_aggregated_measure_from_measure_source_node( if ( cumulative_metric_adjusted_time_constraint is not None and time_range_constraint is not None - and metric_time_dimension_requested + and queried_linkable_specs.contains_metric_time ): cumulative_metric_constrained_node = ConstrainTimeRangeNode( unaggregated_measure_node, time_range_constraint @@ -890,14 +962,21 @@ def _build_aggregated_measure_from_measure_source_node( parent_node=pre_aggregate_node, metric_input_measure_specs=(metric_input_measure_spec,), ) - - # Only join to time spine if metric time was requested in the query. - if metric_input_measure_spec.join_to_timespine and metric_time_dimension_requested: + after_aggregation_time_spine_join_description = ( + metric_input_measure_spec.after_aggregation_time_spine_join_description + ) + if after_aggregation_time_spine_join_description is not None: + assert after_aggregation_time_spine_join_description.join_type is SqlJoinType.LEFT_OUTER, ( + f"Expected {SqlJoinType.LEFT_OUTER} for joining to time spine after aggregation. Remove this if " + f"there's a new use case." + ) return JoinToTimeSpineNode( parent_node=aggregate_measures_node, - requested_metric_time_dimension_specs=metric_time_dimension_specs, + requested_metric_time_dimension_specs=list(queried_linkable_specs.metric_time_specs), + join_type=after_aggregation_time_spine_join_description.join_type, time_range_constraint=time_range_constraint, - join_type=SqlJoinType.LEFT_OUTER, + offset_window=after_aggregation_time_spine_join_description.offset_window, + offset_to_grain=after_aggregation_time_spine_join_description.offset_to_grain, ) else: return aggregate_measures_node diff --git a/metricflow/dataflow/dataflow_plan.py b/metricflow/dataflow/dataflow_plan.py index 94404d662a..c2efb0a942 100644 --- a/metricflow/dataflow/dataflow_plan.py +++ b/metricflow/dataflow/dataflow_plan.py @@ -445,7 +445,11 @@ class AggregateMeasuresNode(AggregatedMeasuresOutput): constraints applied to the measure. """ - def __init__(self, parent_node: BaseOutput, metric_input_measure_specs: Tuple[MetricInputMeasureSpec, ...]) -> None: + def __init__( + self, + parent_node: BaseOutput, + metric_input_measure_specs: Tuple[MetricInputMeasureSpec, ...], + ) -> None: """Initializer for AggregateMeasuresNode. The input measure specs are required for downstream nodes to be aware of any input measures with diff --git a/metricflow/specs/specs.py b/metricflow/specs/specs.py index 2c7f95e5ad..8b32542700 100644 --- a/metricflow/specs/specs.py +++ b/metricflow/specs/specs.py @@ -18,6 +18,8 @@ from dbt_semantic_interfaces.dataclass_serialization import SerializableDataclass from dbt_semantic_interfaces.implementations.metric import PydanticMetricTimeWindow +from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME +from dbt_semantic_interfaces.protocols import MetricTimeWindow from dbt_semantic_interfaces.references import ( DimensionReference, EntityReference, @@ -36,6 +38,7 @@ from metricflow.naming.linkable_spec_name import StructuredLinkableSpecName from metricflow.sql.sql_bind_parameters import SqlBindParameters from metricflow.sql.sql_column_type import SqlColumnType +from metricflow.sql.sql_plan import SqlJoinType from metricflow.visitor import VisitorOutputT @@ -469,25 +472,33 @@ def reference(self) -> MetricReference: return MetricReference(element_name=self.element_name) +@dataclass(frozen=True) +class MeasureCulminationDescription: + """If a measure is a part of a cumulative metric, this represents the associated parameters.""" + + cumulative_window: Optional[MetricTimeWindow] + cumulative_grain_to_date: Optional[TimeGranularity] + + @dataclass(frozen=True) class MetricInputMeasureSpec(SerializableDataclass): - """The spec for a measure defined as a metric input. + """The spec for a measure defined as a base metric input. This is necessary because the MeasureSpec is used as a key linking the measures used in the query to the measures defined in the semantic models. Adding metric-specific information, like constraints, causes lookups connecting query -> semantic model to fail in strange ways. This spec, then, provides both the key (in the form of a MeasureSpec) along with whatever measure-specific attributes a user might specify in a metric definition or query accessing the metric itself. - - Note - when specifying a metric comprised of two input instances of the same measure, at least one - must have a distinct alias, otherwise SQL exceptions may occur. This should be enforced via validation. """ measure_spec: MeasureSpec + offset_window: Optional[MetricTimeWindow] = None + offset_to_grain: Optional[TimeGranularity] = None + culmination_description: Optional[MeasureCulminationDescription] = None constraint: Optional[WhereFilterSpec] = None alias: Optional[str] = None - join_to_timespine: bool = False - fill_nulls_with: Optional[int] = None + before_aggregation_time_spine_join_description: Optional[JoinToTimeSpineDescription] = None + after_aggregation_time_spine_join_description: Optional[JoinToTimeSpineDescription] = None @property def post_aggregation_spec(self) -> MeasureSpec: @@ -521,6 +532,20 @@ class LinkableSpecSet(Mergeable, SerializableDataclass): time_dimension_specs: Tuple[TimeDimensionSpec, ...] = () entity_specs: Tuple[EntitySpec, ...] = () + @property + def contains_metric_time(self) -> bool: + """Returns true if this set contains a spec referring to metric time at any grain.""" + return len(self.metric_time_specs) > 0 + + @property + def metric_time_specs(self) -> Sequence[TimeDimensionSpec]: + """Returns any specs referring to metric time at any grain.""" + return tuple( + time_dimension_spec + for time_dimension_spec in self.time_dimension_specs + if time_dimension_spec.element_name == METRIC_TIME_ELEMENT_NAME + ) + @property def as_tuple(self) -> Tuple[LinkableInstanceSpec, ...]: # noqa: D return tuple(itertools.chain(self.dimension_specs, self.time_dimension_specs, self.entity_specs)) @@ -764,3 +789,12 @@ def combine(self, other: WhereFilterSpec) -> WhereFilterSpec: # noqa: D bind_parameters=self.bind_parameters.combine(other.bind_parameters), linkable_spec_set=self.linkable_spec_set.merge(other.linkable_spec_set), ) + + +@dataclass(frozen=True) +class JoinToTimeSpineDescription: + """Describes how a time spine join should be performed.""" + + join_type: SqlJoinType + offset_window: Optional[MetricTimeWindow] + offset_to_grain: Optional[TimeGranularity] diff --git a/metricflow/test/test_instance_serialization.py b/metricflow/test/test_instance_serialization.py index 1d5ce4bd74..d74aebabbb 100644 --- a/metricflow/test/test_instance_serialization.py +++ b/metricflow/test/test_instance_serialization.py @@ -1,11 +1,15 @@ from __future__ import annotations +import logging + import pytest from dbt_semantic_interfaces.dataclass_serialization import DataClassDeserializer, DataclassSerializer from metricflow.instances import InstanceSet from metricflow.test.fixtures.model_fixtures import ConsistentIdObjectRepository +logger = logging.getLogger(__name__) + @pytest.fixture def serializer() -> DataclassSerializer: # noqa: D