Skip to content

Commit

Permalink
/* PR_START p--short-term-perf 15 */ Reduce recursive-call overhead i…
Browse files Browse the repository at this point in the history
…n `MetricTimeQueryValidationRule`.

The check in `MetricTimeQueryValidationRule` runs for every metric among a
derived metric's ancestors, so this change runs the expensive parts of the
check only where they are needed and caches results where possible.
  • Loading branch information
plypaul committed Oct 2, 2024
1 parent 4b37c74 commit 7659965
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
ResolverInputForWhereFilterIntersection,
)
from metricflow_semantics.query.suggestion_generator import QueryItemSuggestionGenerator
from metricflow_semantics.query.validation_rules.duplicate_metric import DuplicateMetricValidationRule
from metricflow_semantics.query.validation_rules.metric_time_requirements import MetricTimeQueryValidationRule
from metricflow_semantics.query.validation_rules.query_validator import PostResolutionQueryValidator
from metricflow_semantics.specs.instance_spec import InstanceSpec, LinkableInstanceSpec
from metricflow_semantics.specs.metric_spec import MetricSpec
Expand Down Expand Up @@ -123,9 +125,7 @@ def __init__( # noqa: D107
where_filter_pattern_factory: WhereFilterPatternFactory,
) -> None:
self._manifest_lookup = manifest_lookup
self._post_resolution_query_validator = PostResolutionQueryValidator(
manifest_lookup=self._manifest_lookup,
)
self._post_resolution_query_validator = PostResolutionQueryValidator()
self._where_filter_pattern_factory = where_filter_pattern_factory

@staticmethod
Expand Down Expand Up @@ -491,6 +491,10 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met
query_level_issue_set = self._post_resolution_query_validator.validate_query(
resolution_dag=resolution_dag,
resolver_input_for_query=resolver_input_for_query,
validation_rules=(
MetricTimeQueryValidationRule(self._manifest_lookup, resolver_input_for_query),
DuplicateMetricValidationRule(self._manifest_lookup, resolver_input_for_query),
),
)

if query_level_issue_set.has_issues:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
class PostResolutionQueryValidationRule(ABC):
"""A validation rule that runs after all query inputs have been resolved to specs."""

def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107
def __init__( # noqa: D107
self, manifest_lookup: SemanticManifestLookup, resolver_input_for_query: ResolverInputForQuery
) -> None:
self._manifest_lookup = manifest_lookup
self._resolver_input_for_query = resolver_input_for_query

def _get_metric(self, metric_reference: MetricReference) -> Metric:
return self._manifest_lookup.metric_lookup.get_metric(metric_reference)
Expand All @@ -25,7 +28,6 @@ def _get_metric(self, metric_reference: MetricReference) -> Metric:
def validate_metric_in_resolution_dag(
self,
metric_reference: MetricReference,
resolver_input_for_query: ResolverInputForQuery,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
"""Given a metric that exists in a resolution DAG, check that the query is valid.
Expand All @@ -39,7 +41,6 @@ def validate_query_in_resolution_dag(
self,
metrics_in_query: Sequence[MetricReference],
where_filter_intersection: WhereFilterIntersection,
resolver_input_for_query: ResolverInputForQuery,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
"""Validate the parameters to the query are valid.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@
from dbt_semantic_interfaces.references import MetricReference
from typing_extensions import override

from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.query.group_by_item.resolution_path import MetricFlowQueryResolutionPath
from metricflow_semantics.query.issues.issues_base import MetricFlowQueryResolutionIssueSet
from metricflow_semantics.query.issues.parsing.duplicate_metric import DuplicateMetricIssue
from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery
from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule

logger = logging.getLogger(__name__)
Expand All @@ -20,14 +18,10 @@
class DuplicateMetricValidationRule(PostResolutionQueryValidationRule):
"""Validates that a query does not include the same metric multiple times."""

def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107
super().__init__(manifest_lookup=manifest_lookup)

@override
def validate_metric_in_resolution_dag(
self,
metric_reference: MetricReference,
resolver_input_for_query: ResolverInputForQuery,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
return MetricFlowQueryResolutionIssueSet.empty_instance()
Expand All @@ -37,7 +31,6 @@ def validate_query_in_resolution_dag(
self,
metrics_in_query: Sequence[MetricReference],
where_filter_intersection: WhereFilterIntersection,
resolver_input_for_query: ResolverInputForQuery,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
duplicate_metric_references = tuple(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import cached_property
from typing import Sequence

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
Expand Down Expand Up @@ -32,8 +33,10 @@ class MetricTimeQueryValidationRule(PostResolutionQueryValidationRule):
* Derived metrics with an offset time.
"""

def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107
super().__init__(manifest_lookup=manifest_lookup)
def __init__( # noqa: D107
self, manifest_lookup: SemanticManifestLookup, resolver_input_for_query: ResolverInputForQuery
) -> None:
super().__init__(manifest_lookup=manifest_lookup, resolver_input_for_query=resolver_input_for_query)

self._metric_time_specs = tuple(
TimeDimensionSpec.generate_possible_specs_for_time_dimension(
Expand All @@ -43,13 +46,19 @@ def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D1
)
)

def _group_by_items_include_metric_time(self, query_resolver_input: ResolverInputForQuery) -> bool:
for group_by_item_input in query_resolver_input.group_by_item_inputs:
@cached_property
def _group_by_items_include_metric_time(self) -> bool:
for group_by_item_input in self._resolver_input_for_query.group_by_item_inputs:
if group_by_item_input.spec_pattern.matches_any(self._metric_time_specs):
return True

return False

def _query_includes_metric_time_or_agg_time_dimension(self, metric_reference: MetricReference) -> bool:
return self._group_by_items_include_metric_time or self._group_by_items_include_agg_time_dimension(
query_resolver_input=self._resolver_input_for_query, metric_reference=metric_reference
)

def _group_by_items_include_agg_time_dimension(
self, query_resolver_input: ResolverInputForQuery, metric_reference: MetricReference
) -> bool:
Expand All @@ -66,15 +75,9 @@ def _group_by_items_include_agg_time_dimension(
def validate_metric_in_resolution_dag(
self,
metric_reference: MetricReference,
resolver_input_for_query: ResolverInputForQuery,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
metric = self._get_metric(metric_reference)
query_includes_metric_time_or_agg_time_dimension = self._group_by_items_include_metric_time(
resolver_input_for_query
) or self._group_by_items_include_agg_time_dimension(
query_resolver_input=resolver_input_for_query, metric_reference=metric_reference
)

if metric.type is MetricType.SIMPLE or metric.type is MetricType.CONVERSION:
return MetricFlowQueryResolutionIssueSet.empty_instance()
Expand All @@ -86,7 +89,7 @@ def validate_metric_in_resolution_dag(
metric.type_params.cumulative_type_params.window is not None
or metric.type_params.cumulative_type_params.grain_to_date is not None
)
and not query_includes_metric_time_or_agg_time_dimension
and not self._query_includes_metric_time_or_agg_time_dimension(metric_reference)
):
return MetricFlowQueryResolutionIssueSet.from_issue(
CumulativeMetricRequiresMetricTimeIssue.from_parameters(
Expand All @@ -102,7 +105,7 @@ def validate_metric_in_resolution_dag(
for input_metric in metric.input_metrics
)

if has_time_offset and not query_includes_metric_time_or_agg_time_dimension:
if has_time_offset and not self._query_includes_metric_time_or_agg_time_dimension(metric_reference):
return MetricFlowQueryResolutionIssueSet.from_issue(
OffsetMetricRequiresMetricTimeIssue.from_parameters(
metric_reference=metric_reference,
Expand All @@ -119,7 +122,6 @@ def validate_query_in_resolution_dag(
self,
metrics_in_query: Sequence[MetricReference],
where_filter_intersection: WhereFilterIntersection,
resolver_input_for_query: ResolverInputForQuery,
resolution_path: MetricFlowQueryResolutionPath,
) -> MetricFlowQueryResolutionIssueSet:
return MetricFlowQueryResolutionIssueSet.empty_instance()
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from typing_extensions import override

from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.query.group_by_item.candidate_push_down.push_down_visitor import DagTraversalPathTracker
from metricflow_semantics.query.group_by_item.resolution_dag.dag import GroupByItemResolutionDag
from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.base_node import (
Expand All @@ -26,27 +25,21 @@
from metricflow_semantics.query.issues.issues_base import MetricFlowQueryResolutionIssueSet
from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery
from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule
from metricflow_semantics.query.validation_rules.duplicate_metric import DuplicateMetricValidationRule
from metricflow_semantics.query.validation_rules.metric_time_requirements import MetricTimeQueryValidationRule


class PostResolutionQueryValidator:
"""Runs query validation rules after query resolution is complete."""

def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107
self._manifest_lookup = manifest_lookup
self._validation_rules = (
MetricTimeQueryValidationRule(self._manifest_lookup),
DuplicateMetricValidationRule(self._manifest_lookup),
)

def validate_query(
self, resolution_dag: GroupByItemResolutionDag, resolver_input_for_query: ResolverInputForQuery
self,
resolution_dag: GroupByItemResolutionDag,
resolver_input_for_query: ResolverInputForQuery,
validation_rules: Sequence[PostResolutionQueryValidationRule],
) -> MetricFlowQueryResolutionIssueSet:
"""Validate according to the list of configured validation rules and return a set containing issues found."""
validation_visitor = _PostResolutionQueryValidationVisitor(
resolver_input_for_query=resolver_input_for_query,
validation_rules=self._validation_rules,
validation_rules=validation_rules,
)

return resolution_dag.sink_node.accept(validation_visitor)
Expand Down Expand Up @@ -83,7 +76,6 @@ def visit_metric_node(self, node: MetricGroupByItemResolutionNode) -> MetricFlow
issue_sets_to_merge.append(
validation_rule.validate_metric_in_resolution_dag(
metric_reference=node.metric_reference,
resolver_input_for_query=self._resolver_input_for_query,
resolution_path=current_traversal_path,
)
)
Expand All @@ -100,7 +92,6 @@ def visit_query_node(self, node: QueryGroupByItemResolutionNode) -> MetricFlowQu
validation_rule.validate_query_in_resolution_dag(
metrics_in_query=node.metrics_in_query,
where_filter_intersection=node.where_filter_intersection,
resolver_input_for_query=self._resolver_input_for_query,
resolution_path=current_traversal_path,
)
)
Expand Down

0 comments on commit 7659965

Please sign in to comment.