From 01b4b2216fc6eeebeee8dcdbb4a4cb0eb2f35de6 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Sun, 27 Oct 2024 16:36:09 -0700 Subject: [PATCH] /* PR_START p--short-term-perf 28 */ Use a measure-based validation for SCDs. --- .../model/semantics/linkable_element_set.py | 36 ++++ .../metric_time_requirements.py | 156 +++++++----------- 2 files changed, 100 insertions(+), 92 deletions(-) diff --git a/metricflow-semantics/metricflow_semantics/model/semantics/linkable_element_set.py b/metricflow-semantics/metricflow_semantics/model/semantics/linkable_element_set.py index 16cdef4f5b..c09f8f62ee 100644 --- a/metricflow-semantics/metricflow_semantics/model/semantics/linkable_element_set.py +++ b/metricflow-semantics/metricflow_semantics/model/semantics/linkable_element_set.py @@ -427,3 +427,39 @@ def filter_by_spec_patterns(self, spec_patterns: Sequence[SpecPattern]) -> Linka ) logger.debug(LazyFormat(lambda: f"Filtering valid linkable elements took: {time.time() - start_time:.2f}s")) return filtered_elements + + def filter_by_left_semantic_model( + self, left_semantic_model_reference: SemanticModelReference + ) -> LinkableElementSet: + """Return a `LinkableElementSet` with only elements that have the given left semantic model in the join path.""" + path_key_to_linkable_dimensions: Dict[ElementPathKey, Tuple[LinkableDimension, ...]] = {} + path_key_to_linkable_entities: Dict[ElementPathKey, Tuple[LinkableEntity, ...]] = {} + path_key_to_linkable_metrics: Dict[ElementPathKey, Tuple[LinkableMetric, ...]] = {} + + for path_key, linkable_dimensions in self.path_key_to_linkable_dimensions.items(): + path_key_to_linkable_dimensions[path_key] = tuple( + linkable_dimension + for linkable_dimension in linkable_dimensions + if linkable_dimension.join_path.left_semantic_model_reference == left_semantic_model_reference + ) + + for path_key, linkable_entities in self.path_key_to_linkable_entities.items(): + path_key_to_linkable_entities[path_key] = tuple( + linkable_entity + for 
linkable_entity in linkable_entities + if linkable_entity.join_path.left_semantic_model_reference == left_semantic_model_reference + ) + + for path_key, linkable_metrics in self.path_key_to_linkable_metrics.items(): + path_key_to_linkable_metrics[path_key] = tuple( + linkable_metric + for linkable_metric in linkable_metrics + if linkable_metric.join_path.semantic_model_join_path.left_semantic_model_reference + == left_semantic_model_reference + ) + + return LinkableElementSet( + path_key_to_linkable_dimensions=path_key_to_linkable_dimensions, + path_key_to_linkable_entities=path_key_to_linkable_entities, + path_key_to_linkable_metrics=path_key_to_linkable_metrics, + ) diff --git a/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py b/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py index c831a3f3b7..7d09efc1e1 100644 --- a/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py +++ b/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py @@ -1,22 +1,21 @@ from __future__ import annotations import typing -from dataclasses import dataclass -from typing import List, Sequence, Tuple +from typing import List, Sequence from dbt_semantic_interfaces.enum_extension import assert_values_exhausted -from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME from dbt_semantic_interfaces.protocols import Metric, WhereFilterIntersection from dbt_semantic_interfaces.references import ( MeasureReference, MetricReference, - TimeDimensionReference, ) from dbt_semantic_interfaces.type_enums import MetricType from typing_extensions import override -from metricflow_semantics.collection_helpers.lru_cache import LruCache +from metricflow_semantics.model.linkable_element_property import LinkableElementProperty from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup +from 
metricflow_semantics.model.semantics.element_filter import LinkableElementFilter +from metricflow_semantics.model.semantics.linkable_element_set import LinkableElementSet from metricflow_semantics.query.group_by_item.resolution_path import MetricFlowQueryResolutionPath from metricflow_semantics.query.issues.issues_base import ( MetricFlowQueryResolutionIssue, @@ -33,22 +32,11 @@ ) from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule -from metricflow_semantics.specs.instance_spec import InstanceSpec -from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec if typing.TYPE_CHECKING: from metricflow_semantics.query.query_resolver import ResolveGroupByItemsResult -@dataclass(frozen=True) -class QueryItemsAnalysis: - """Contains data about which items a query contains.""" - - scds: Sequence[InstanceSpec] - has_metric_time: bool - has_agg_time_dimension: bool - - class MetricTimeQueryValidationRule(PostResolutionQueryValidationRule): """Validates cases where a query requires metric_time to be specified as a group-by-item. 
@@ -71,64 +59,33 @@ def __init__( # noqa: D107 resolve_group_by_item_result=resolve_group_by_item_result, ) - self._metric_time_specs = tuple( - TimeDimensionSpec.generate_possible_specs_for_time_dimension( - time_dimension_reference=TimeDimensionReference(element_name=METRIC_TIME_ELEMENT_NAME), - entity_links=(), - custom_granularities=self._manifest_lookup.custom_granularities, - ) - ) - self._query_items_analysis_cache: LruCache[ - Tuple[ResolverInputForQuery, MetricReference], QueryItemsAnalysis - ] = LruCache(128) - - def _get_query_items_analysis( - self, query_resolver_input: ResolverInputForQuery, metric_reference: MetricReference - ) -> QueryItemsAnalysis: - cache_key = (query_resolver_input, metric_reference) - result = self._query_items_analysis_cache.get(cache_key) - if result is not None: - return result - result = self._uncached_query_items_analysis(query_resolver_input, metric_reference) - self._query_items_analysis_cache.set(cache_key, result) - return result - - def _uncached_query_items_analysis( - self, query_resolver_input: ResolverInputForQuery, metric_reference: MetricReference - ) -> QueryItemsAnalysis: - has_agg_time_dimension = False - has_metric_time = False - scds: List[InstanceSpec] = [] - - valid_agg_time_dimension_specs = self._manifest_lookup.metric_lookup.get_valid_agg_time_dimensions_for_metric( - metric_reference + self._query_includes_metric_time = ( + self._resolve_group_by_item_result.linkable_element_set.filter( + LinkableElementFilter(with_any_of=frozenset({LinkableElementProperty.METRIC_TIME})) + ).spec_count + > 0 ) - scd_specs = self._manifest_lookup.metric_lookup.get_joinable_scd_specs_for_metric(metric_reference) - - for group_by_item_input in query_resolver_input.group_by_item_inputs: - if group_by_item_input.spec_pattern.matches_any(self._metric_time_specs): - has_metric_time = True - - if group_by_item_input.spec_pattern.matches_any(valid_agg_time_dimension_specs): - has_agg_time_dimension = True - - scd_matches = 
group_by_item_input.spec_pattern.match(scd_specs) - scds.extend(scd_matches) + self._scd_linkable_element_set = self._resolve_group_by_item_result.linkable_element_set.filter( + LinkableElementFilter(with_any_of=frozenset({LinkableElementProperty.SCD_HOP})) + ) - return QueryItemsAnalysis( - scds=scds, - has_metric_time=has_metric_time, - has_agg_time_dimension=has_agg_time_dimension, + def _query_includes_agg_time_dimension_of_metric(self, metric_reference: MetricReference) -> bool: + valid_agg_time_dimensions = self._manifest_lookup.metric_lookup.get_valid_agg_time_dimensions_for_metric( + metric_reference + ) + return ( + len(set(valid_agg_time_dimensions).intersection(self._resolve_group_by_item_result.group_by_item_specs)) > 0 ) def _validate_cumulative_metric( self, metric_reference: MetricReference, metric: Metric, - query_items_analysis: QueryItemsAnalysis, resolution_path: MetricFlowQueryResolutionPath, ) -> Sequence[MetricFlowQueryResolutionIssue]: + # A cumulative metric with a window or grain-to-date specified requires `metric_time` or the aggregation time + # dimension for the metric.
if ( metric.type_params is not None and metric.type_params.cumulative_type_params is not None @@ -136,8 +93,10 @@ def _validate_cumulative_metric( metric.type_params.cumulative_type_params.window is not None or metric.type_params.cumulative_type_params.grain_to_date is not None ) - and not (query_items_analysis.has_metric_time or query_items_analysis.has_agg_time_dimension) ): + if self._query_includes_metric_time or self._query_includes_agg_time_dimension_of_metric(metric_reference): + return () + return ( CumulativeMetricRequiresMetricTimeIssue.from_parameters( metric_reference=metric_reference, @@ -151,24 +110,37 @@ def _validate_derived_metric( metric_reference: MetricReference, metric: Metric, resolution_path: MetricFlowQueryResolutionPath, - query_items_analysis: QueryItemsAnalysis, ) -> Sequence[MetricFlowQueryResolutionIssue]: has_time_offset = any( input_metric.offset_window is not None or input_metric.offset_to_grain is not None for input_metric in metric.input_metrics ) - if has_time_offset and not ( - query_items_analysis.has_metric_time or query_items_analysis.has_agg_time_dimension - ): - return ( - OffsetMetricRequiresMetricTimeIssue.from_parameters( - metric_reference=metric_reference, - input_metrics=metric.input_metrics, - query_resolution_path=resolution_path, - ), - ) - return () + # If a derived metric does not define a time offset, then there are no requirements on what's in the group-by + # items. + if not has_time_offset: + return () + + # If a derived metric has a time offset, then the query needs to include `metric_time` or the aggregation time + # dimension of a metric. 
+ if self._query_includes_metric_time or self._query_includes_agg_time_dimension_of_metric(metric_reference): + return () + + return ( + OffsetMetricRequiresMetricTimeIssue.from_parameters( + metric_reference=metric_reference, + input_metrics=metric.input_metrics, + query_resolution_path=resolution_path, + ), + ) + + def _scd_linkable_element_set_for_measure(self, measure_reference: MeasureReference) -> LinkableElementSet: + """Returns subset of the query's `LinkableElements` that are SCDs and associated with the measure.""" + measure_semantic_model = self._manifest_lookup.semantic_model_lookup.get_semantic_model_for_measure( + measure_reference + ) + + return self._scd_linkable_element_set.filter_by_left_semantic_model(measure_semantic_model.reference) @override def validate_metric_in_resolution_dag( @@ -177,27 +149,13 @@ def validate_metric_in_resolution_dag( resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: metric = self._manifest_lookup.metric_lookup.get_metric(metric_reference) - - query_items_analysis = self._get_query_items_analysis(self._resolver_input_for_query, metric_reference) - issues: List[MetricFlowQueryResolutionIssue] = [] - # Queries that join to an SCD don't support direct references to agg_time_dimension, so we - # only check for metric_time. 
If we decide to support agg_time_dimension, we should add a check - if len(query_items_analysis.scds) > 0 and not query_items_analysis.has_metric_time: - issues.append( - ScdRequiresMetricTimeIssue.from_parameters( - scds_in_query=query_items_analysis.scds, - query_resolution_path=resolution_path, - ) - ) - if metric.type is MetricType.CUMULATIVE: issues.extend( self._validate_cumulative_metric( metric_reference=metric_reference, metric=metric, - query_items_analysis=query_items_analysis, resolution_path=resolution_path, ) ) @@ -207,7 +165,6 @@ def validate_metric_in_resolution_dag( self._validate_derived_metric( metric_reference=metric_reference, metric=metric, - query_items_analysis=query_items_analysis, resolution_path=resolution_path, ) ) @@ -235,4 +192,19 @@ def validate_measure_in_resolution_dag( measure_reference: MeasureReference, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: - return MetricFlowQueryResolutionIssueSet.empty_instance() + scd_linkable_elemenent_set_for_measure = self._scd_linkable_element_set_for_measure(measure_reference) + + if scd_linkable_elemenent_set_for_measure.spec_count == 0: + return MetricFlowQueryResolutionIssueSet.empty_instance() + + if self._query_includes_metric_time: + return MetricFlowQueryResolutionIssueSet.empty_instance() + + # Queries that join to an SCD don't support direct references to agg_time_dimension, so we + # only check for metric_time. If we decide to support agg_time_dimension, we should add a check + + return MetricFlowQueryResolutionIssueSet.from_issue( + ScdRequiresMetricTimeIssue.from_parameters( + scds_in_query=scd_linkable_elemenent_set_for_measure.specs, query_resolution_path=resolution_path + ) + )