From 2a18b9cc6b641871d6a0b2be366622c9bffd2fec Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Mon, 30 Sep 2024 13:54:40 -0700 Subject: [PATCH 1/3] /* PR_START p--short-term-perf 13 */ Cache frequently-used methods in `MetricLookup`. --- .../model/semantics/metric_lookup.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/metricflow-semantics/metricflow_semantics/model/semantics/metric_lookup.py b/metricflow-semantics/metricflow_semantics/model/semantics/metric_lookup.py index e39e9063ec..02e7ea3a85 100644 --- a/metricflow-semantics/metricflow_semantics/model/semantics/metric_lookup.py +++ b/metricflow-semantics/metricflow_semantics/model/semantics/metric_lookup.py @@ -54,6 +54,14 @@ def __init__( max_entity_links=MAX_JOIN_HOPS, ) + # Cache for `get_min_queryable_time_granularity()` + self._metric_reference_to_min_metric_time_grain: Dict[MetricReference, TimeGranularity] = {} + + # Cache for `get_valid_agg_time_dimensions_for_metric()`. + self._metric_reference_to_valid_agg_time_dimension_specs: Dict[ + MetricReference, Sequence[TimeDimensionSpec] + ] = {} + @functools.lru_cache def linkable_elements_for_measure( self, @@ -183,6 +191,18 @@ def get_valid_agg_time_dimensions_for_metric( self, metric_reference: MetricReference ) -> Sequence[TimeDimensionSpec]: """Get the agg time dimension specs that can be used in place of metric time for this metric, if applicable.""" + result = self._metric_reference_to_valid_agg_time_dimension_specs.get(metric_reference) + if result is not None: + return result + + result = self._get_valid_agg_time_dimensions_for_metric(metric_reference) + self._metric_reference_to_valid_agg_time_dimension_specs[metric_reference] = result + + return result + + def _get_valid_agg_time_dimensions_for_metric( + self, metric_reference: MetricReference + ) -> Sequence[TimeDimensionSpec]: agg_time_dimension_specs = self._get_agg_time_dimension_specs_for_metric(metric_reference) distinct_agg_time_dimension_identifiers = set( [(spec.reference, spec.entity_links) for spec in agg_time_dimension_specs] @@ -204,6 +224,15 @@ def get_min_queryable_time_granularity(self, metric_reference: MetricReference) Maps to the largest granularity defined for any of the metric's agg_time_dimensions. """ + result = self._metric_reference_to_min_metric_time_grain.get(metric_reference) + if result is not None: + return result + + result = self._get_min_queryable_time_granularity(metric_reference) + self._metric_reference_to_min_metric_time_grain[metric_reference] = result + return result + + def _get_min_queryable_time_granularity(self, metric_reference: MetricReference) -> TimeGranularity: agg_time_dimension_specs = self._get_agg_time_dimension_specs_for_metric(metric_reference) assert ( agg_time_dimension_specs From 4b37c74a4048e9278286a916ce9e5ec7ba113011 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Mon, 30 Sep 2024 14:01:06 -0700 Subject: [PATCH 2/3] Cache `SemanticModelLookup.get_agg_time_dimension_specs_for_measure()`. This is also a common method that is called during query resolution. The memory used by the cache is small since the number of measures is on the order of ~100 in large manifests. --- .../model/semantics/semantic_model_lookup.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/metricflow-semantics/metricflow_semantics/model/semantics/semantic_model_lookup.py b/metricflow-semantics/metricflow_semantics/model/semantics/semantic_model_lookup.py index 91ec266172..a3c3ef8388 100644 --- a/metricflow-semantics/metricflow_semantics/model/semantics/semantic_model_lookup.py +++ b/metricflow-semantics/metricflow_semantics/model/semantics/semantic_model_lookup.py @@ -67,6 +67,9 @@ def __init__(self, model: SemanticManifest, custom_granularities: Dict[str, Expa # Cache for defined time granularity. self._time_dimension_to_defined_time_granularity: Dict[TimeDimensionReference, TimeGranularity] = {} + # Cache for agg. time dimension for measure. + self._measure_reference_to_agg_time_dimension_specs: Dict[MeasureReference, Sequence[TimeDimensionSpec]] = {} + def get_dimension_references(self) -> Sequence[DimensionReference]: """Retrieve all dimension references from the collection of semantic models.""" return tuple(self._dimension_index.keys()) @@ -364,6 +367,17 @@ def get_agg_time_dimension_specs_for_measure( self, measure_reference: MeasureReference ) -> Sequence[TimeDimensionSpec]: """Get the agg time dimension specs that can be used in place of metric time for this measure.""" + result = self._measure_reference_to_agg_time_dimension_specs.get(measure_reference) + if result is not None: + return result + + result = self._get_agg_time_dimension_specs_for_measure(measure_reference) + self._measure_reference_to_agg_time_dimension_specs[measure_reference] = result + return result + + def _get_agg_time_dimension_specs_for_measure( + self, measure_reference: MeasureReference + ) -> Sequence[TimeDimensionSpec]: agg_time_dimension = self.get_agg_time_dimension_for_measure(measure_reference) # A measure's agg_time_dimension is required to be in the same semantic model as the measure, # so we can assume the same semantic model for both measure and dimension. From 765996598f032c5fd5863f64e3d8680f4e92e6eb Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Mon, 30 Sep 2024 14:53:21 -0700 Subject: [PATCH 3/3] /* PR_START p--short-term-perf 15 */ Reduce recursive-call overhead in `MetricTimeQueryValidationRule`. The check in `MetricTimeQueryValidationRule` is called for every metric in a derived metric's ancestors, so this moves expensive parts of the check to only where it's needed and caches results when possible. --- .../query/query_resolver.py | 10 +++++-- .../validation_rules/base_validation_rule.py | 7 +++-- .../validation_rules/duplicate_metric.py | 7 ----- .../metric_time_requirements.py | 28 ++++++++++--------- .../query/validation_rules/query_validator.py | 19 ++++--------- 5 files changed, 31 insertions(+), 40 deletions(-) diff --git a/metricflow-semantics/metricflow_semantics/query/query_resolver.py b/metricflow-semantics/metricflow_semantics/query/query_resolver.py index 5181ebf642..c12b95f1c1 100644 --- a/metricflow-semantics/metricflow_semantics/query/query_resolver.py +++ b/metricflow-semantics/metricflow_semantics/query/query_resolver.py @@ -54,6 +54,8 @@ ResolverInputForWhereFilterIntersection, ) from metricflow_semantics.query.suggestion_generator import QueryItemSuggestionGenerator +from metricflow_semantics.query.validation_rules.duplicate_metric import DuplicateMetricValidationRule +from metricflow_semantics.query.validation_rules.metric_time_requirements import MetricTimeQueryValidationRule from metricflow_semantics.query.validation_rules.query_validator import PostResolutionQueryValidator from metricflow_semantics.specs.instance_spec import InstanceSpec, LinkableInstanceSpec from metricflow_semantics.specs.metric_spec import MetricSpec @@ -123,9 +125,7 @@ def __init__( # noqa: D107 where_filter_pattern_factory: WhereFilterPatternFactory, ) -> None: self._manifest_lookup = manifest_lookup - self._post_resolution_query_validator = PostResolutionQueryValidator( - manifest_lookup=self._manifest_lookup, - ) + self._post_resolution_query_validator = PostResolutionQueryValidator() self._where_filter_pattern_factory = where_filter_pattern_factory @staticmethod @@ -491,6 +491,10 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met query_level_issue_set = self._post_resolution_query_validator.validate_query( resolution_dag=resolution_dag, resolver_input_for_query=resolver_input_for_query, + validation_rules=( + MetricTimeQueryValidationRule(self._manifest_lookup, resolver_input_for_query), + DuplicateMetricValidationRule(self._manifest_lookup, resolver_input_for_query), + ), ) if query_level_issue_set.has_issues: diff --git a/metricflow-semantics/metricflow_semantics/query/validation_rules/base_validation_rule.py b/metricflow-semantics/metricflow_semantics/query/validation_rules/base_validation_rule.py index abfc9ab42f..7de6ebc161 100644 --- a/metricflow-semantics/metricflow_semantics/query/validation_rules/base_validation_rule.py +++ b/metricflow-semantics/metricflow_semantics/query/validation_rules/base_validation_rule.py @@ -15,8 +15,11 @@ class PostResolutionQueryValidationRule(ABC): """A validation rule that runs after all query inputs have been resolved to specs.""" - def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107 + def __init__( # noqa: D107 + self, manifest_lookup: SemanticManifestLookup, resolver_input_for_query: ResolverInputForQuery + ) -> None: self._manifest_lookup = manifest_lookup + self._resolver_input_for_query = resolver_input_for_query def _get_metric(self, metric_reference: MetricReference) -> Metric: return self._manifest_lookup.metric_lookup.get_metric(metric_reference) @@ -25,7 +28,6 @@ def _get_metric(self, metric_reference: MetricReference) -> Metric: def validate_metric_in_resolution_dag( self, metric_reference: MetricReference, - resolver_input_for_query: ResolverInputForQuery, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: """Given a metric that exists in a resolution DAG, check that the query is valid. @@ -39,7 +41,6 @@ def validate_query_in_resolution_dag( self, metrics_in_query: Sequence[MetricReference], where_filter_intersection: WhereFilterIntersection, - resolver_input_for_query: ResolverInputForQuery, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: """Validate the parameters to the query are valid. diff --git a/metricflow-semantics/metricflow_semantics/query/validation_rules/duplicate_metric.py b/metricflow-semantics/metricflow_semantics/query/validation_rules/duplicate_metric.py index 7a595f4be6..249fe57943 100644 --- a/metricflow-semantics/metricflow_semantics/query/validation_rules/duplicate_metric.py +++ b/metricflow-semantics/metricflow_semantics/query/validation_rules/duplicate_metric.py @@ -7,11 +7,9 @@ from dbt_semantic_interfaces.references import MetricReference from typing_extensions import override -from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow_semantics.query.group_by_item.resolution_path import MetricFlowQueryResolutionPath from metricflow_semantics.query.issues.issues_base import MetricFlowQueryResolutionIssueSet from metricflow_semantics.query.issues.parsing.duplicate_metric import DuplicateMetricIssue -from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule logger = logging.getLogger(__name__) @@ -20,14 +18,10 @@ class DuplicateMetricValidationRule(PostResolutionQueryValidationRule): """Validates that a query does not include the same metric multiple times.""" - def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107 - super().__init__(manifest_lookup=manifest_lookup) - @override def validate_metric_in_resolution_dag( self, metric_reference: MetricReference, - resolver_input_for_query: ResolverInputForQuery, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: return MetricFlowQueryResolutionIssueSet.empty_instance() @@ -37,7 +31,6 @@ def validate_query_in_resolution_dag( self, metrics_in_query: Sequence[MetricReference], where_filter_intersection: WhereFilterIntersection, - resolver_input_for_query: ResolverInputForQuery, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: duplicate_metric_references = tuple( diff --git a/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py b/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py index 572b6f0ee7..bbb0fd37c7 100644 --- a/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py +++ b/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import cached_property from typing import Sequence from dbt_semantic_interfaces.enum_extension import assert_values_exhausted @@ -32,8 +33,10 @@ class MetricTimeQueryValidationRule(PostResolutionQueryValidationRule): * Derived metrics with an offset time.g """ - def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107 - super().__init__(manifest_lookup=manifest_lookup) + def __init__( # noqa: D107 + self, manifest_lookup: SemanticManifestLookup, resolver_input_for_query: ResolverInputForQuery + ) -> None: + super().__init__(manifest_lookup=manifest_lookup, resolver_input_for_query=resolver_input_for_query) self._metric_time_specs = tuple( TimeDimensionSpec.generate_possible_specs_for_time_dimension( @@ -43,13 +46,19 @@ def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D1 ) ) - def _group_by_items_include_metric_time(self, query_resolver_input: ResolverInputForQuery) -> bool: - for group_by_item_input in query_resolver_input.group_by_item_inputs: + @cached_property + def _group_by_items_include_metric_time(self) -> bool: + for group_by_item_input in self._resolver_input_for_query.group_by_item_inputs: if group_by_item_input.spec_pattern.matches_any(self._metric_time_specs): return True return False + def _query_includes_metric_time_or_agg_time_dimension(self, metric_reference: MetricReference) -> bool: + return self._group_by_items_include_metric_time or self._group_by_items_include_agg_time_dimension( + query_resolver_input=self._resolver_input_for_query, metric_reference=metric_reference + ) + def _group_by_items_include_agg_time_dimension( self, query_resolver_input: ResolverInputForQuery, metric_reference: MetricReference ) -> bool: @@ -66,15 +75,9 @@ def _group_by_items_include_agg_time_dimension( def validate_metric_in_resolution_dag( self, metric_reference: MetricReference, - resolver_input_for_query: ResolverInputForQuery, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: metric = self._get_metric(metric_reference) - query_includes_metric_time_or_agg_time_dimension = self._group_by_items_include_metric_time( - resolver_input_for_query - ) or self._group_by_items_include_agg_time_dimension( - query_resolver_input=resolver_input_for_query, metric_reference=metric_reference - ) if metric.type is MetricType.SIMPLE or metric.type is MetricType.CONVERSION: return MetricFlowQueryResolutionIssueSet.empty_instance() @@ -86,7 +89,7 @@ def validate_metric_in_resolution_dag( metric.type_params.cumulative_type_params.window is not None or metric.type_params.cumulative_type_params.grain_to_date is not None ) - and not query_includes_metric_time_or_agg_time_dimension + and not self._query_includes_metric_time_or_agg_time_dimension(metric_reference) ): return MetricFlowQueryResolutionIssueSet.from_issue( CumulativeMetricRequiresMetricTimeIssue.from_parameters( @@ -102,7 +105,7 @@ def validate_metric_in_resolution_dag( for input_metric in metric.input_metrics ) - if has_time_offset and not query_includes_metric_time_or_agg_time_dimension: + if has_time_offset and not self._query_includes_metric_time_or_agg_time_dimension(metric_reference): return MetricFlowQueryResolutionIssueSet.from_issue( OffsetMetricRequiresMetricTimeIssue.from_parameters( metric_reference=metric_reference, @@ -119,7 +122,6 @@ def validate_query_in_resolution_dag( self, metrics_in_query: Sequence[MetricReference], where_filter_intersection: WhereFilterIntersection, - resolver_input_for_query: ResolverInputForQuery, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: return MetricFlowQueryResolutionIssueSet.empty_instance() diff --git a/metricflow-semantics/metricflow_semantics/query/validation_rules/query_validator.py b/metricflow-semantics/metricflow_semantics/query/validation_rules/query_validator.py index 2885ed0c74..2aff826e7e 100644 --- a/metricflow-semantics/metricflow_semantics/query/validation_rules/query_validator.py +++ b/metricflow-semantics/metricflow_semantics/query/validation_rules/query_validator.py @@ -4,7 +4,6 @@ from typing_extensions import override -from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow_semantics.query.group_by_item.candidate_push_down.push_down_visitor import DagTraversalPathTracker from metricflow_semantics.query.group_by_item.resolution_dag.dag import GroupByItemResolutionDag from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.base_node import ( @@ -26,27 +25,21 @@ from metricflow_semantics.query.issues.issues_base import MetricFlowQueryResolutionIssueSet from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule -from metricflow_semantics.query.validation_rules.duplicate_metric import DuplicateMetricValidationRule -from metricflow_semantics.query.validation_rules.metric_time_requirements import MetricTimeQueryValidationRule class PostResolutionQueryValidator: """Runs query validation rules after query resolution is complete.""" - def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107 - self._manifest_lookup = manifest_lookup - self._validation_rules = ( - MetricTimeQueryValidationRule(self._manifest_lookup), - DuplicateMetricValidationRule(self._manifest_lookup), - ) - def validate_query( - self, resolution_dag: GroupByItemResolutionDag, resolver_input_for_query: ResolverInputForQuery + self, + resolution_dag: GroupByItemResolutionDag, + resolver_input_for_query: ResolverInputForQuery, + validation_rules: Sequence[PostResolutionQueryValidationRule], ) -> MetricFlowQueryResolutionIssueSet: """Validate according to the list of configured validation rules and return a set containing issues found.""" validation_visitor = _PostResolutionQueryValidationVisitor( resolver_input_for_query=resolver_input_for_query, - validation_rules=self._validation_rules, + validation_rules=validation_rules, ) return resolution_dag.sink_node.accept(validation_visitor) @@ -83,7 +76,6 @@ def visit_metric_node(self, node: MetricGroupByItemResolutionNode) -> MetricFlow issue_sets_to_merge.append( validation_rule.validate_metric_in_resolution_dag( metric_reference=node.metric_reference, - resolver_input_for_query=self._resolver_input_for_query, resolution_path=current_traversal_path, ) ) @@ -100,7 +92,6 @@ def visit_query_node(self, node: QueryGroupByItemResolutionNode) -> MetricFlowQu validation_rule.validate_query_in_resolution_dag( metrics_in_query=node.metrics_in_query, where_filter_intersection=node.where_filter_intersection, - resolver_input_for_query=self._resolver_input_for_query, resolution_path=current_traversal_path, ) )