diff --git a/metricflow-semantics/metricflow_semantics/query/query_resolver.py b/metricflow-semantics/metricflow_semantics/query/query_resolver.py index 5181ebf642..c12b95f1c1 100644 --- a/metricflow-semantics/metricflow_semantics/query/query_resolver.py +++ b/metricflow-semantics/metricflow_semantics/query/query_resolver.py @@ -54,6 +54,8 @@ ResolverInputForWhereFilterIntersection, ) from metricflow_semantics.query.suggestion_generator import QueryItemSuggestionGenerator +from metricflow_semantics.query.validation_rules.duplicate_metric import DuplicateMetricValidationRule +from metricflow_semantics.query.validation_rules.metric_time_requirements import MetricTimeQueryValidationRule from metricflow_semantics.query.validation_rules.query_validator import PostResolutionQueryValidator from metricflow_semantics.specs.instance_spec import InstanceSpec, LinkableInstanceSpec from metricflow_semantics.specs.metric_spec import MetricSpec @@ -123,9 +125,7 @@ def __init__( # noqa: D107 where_filter_pattern_factory: WhereFilterPatternFactory, ) -> None: self._manifest_lookup = manifest_lookup - self._post_resolution_query_validator = PostResolutionQueryValidator( - manifest_lookup=self._manifest_lookup, - ) + self._post_resolution_query_validator = PostResolutionQueryValidator() self._where_filter_pattern_factory = where_filter_pattern_factory @staticmethod @@ -491,6 +491,10 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met query_level_issue_set = self._post_resolution_query_validator.validate_query( resolution_dag=resolution_dag, resolver_input_for_query=resolver_input_for_query, + validation_rules=( + MetricTimeQueryValidationRule(self._manifest_lookup, resolver_input_for_query), + DuplicateMetricValidationRule(self._manifest_lookup, resolver_input_for_query), + ), ) if query_level_issue_set.has_issues: diff --git a/metricflow-semantics/metricflow_semantics/query/validation_rules/base_validation_rule.py b/metricflow-semantics/metricflow_semantics/query/validation_rules/base_validation_rule.py index abfc9ab42f..7de6ebc161 100644 --- a/metricflow-semantics/metricflow_semantics/query/validation_rules/base_validation_rule.py +++ b/metricflow-semantics/metricflow_semantics/query/validation_rules/base_validation_rule.py @@ -15,8 +15,11 @@ class PostResolutionQueryValidationRule(ABC): """A validation rule that runs after all query inputs have been resolved to specs.""" - def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107 + def __init__( # noqa: D107 + self, manifest_lookup: SemanticManifestLookup, resolver_input_for_query: ResolverInputForQuery + ) -> None: self._manifest_lookup = manifest_lookup + self._resolver_input_for_query = resolver_input_for_query def _get_metric(self, metric_reference: MetricReference) -> Metric: return self._manifest_lookup.metric_lookup.get_metric(metric_reference) @@ -25,7 +28,6 @@ def _get_metric(self, metric_reference: MetricReference) -> Metric: def validate_metric_in_resolution_dag( self, metric_reference: MetricReference, - resolver_input_for_query: ResolverInputForQuery, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: """Given a metric that exists in a resolution DAG, check that the query is valid. @@ -39,7 +41,6 @@ def validate_query_in_resolution_dag( self, metrics_in_query: Sequence[MetricReference], where_filter_intersection: WhereFilterIntersection, - resolver_input_for_query: ResolverInputForQuery, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: """Validate the parameters to the query are valid. diff --git a/metricflow-semantics/metricflow_semantics/query/validation_rules/duplicate_metric.py b/metricflow-semantics/metricflow_semantics/query/validation_rules/duplicate_metric.py index 7a595f4be6..249fe57943 100644 --- a/metricflow-semantics/metricflow_semantics/query/validation_rules/duplicate_metric.py +++ b/metricflow-semantics/metricflow_semantics/query/validation_rules/duplicate_metric.py @@ -7,11 +7,9 @@ from dbt_semantic_interfaces.references import MetricReference from typing_extensions import override -from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow_semantics.query.group_by_item.resolution_path import MetricFlowQueryResolutionPath from metricflow_semantics.query.issues.issues_base import MetricFlowQueryResolutionIssueSet from metricflow_semantics.query.issues.parsing.duplicate_metric import DuplicateMetricIssue -from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule logger = logging.getLogger(__name__) @@ -20,14 +18,10 @@ class DuplicateMetricValidationRule(PostResolutionQueryValidationRule): """Validates that a query does not include the same metric multiple times.""" - def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107 - super().__init__(manifest_lookup=manifest_lookup) - @override def validate_metric_in_resolution_dag( self, metric_reference: MetricReference, - resolver_input_for_query: ResolverInputForQuery, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: return MetricFlowQueryResolutionIssueSet.empty_instance() @@ -37,7 +31,6 @@ def validate_query_in_resolution_dag( self, metrics_in_query: Sequence[MetricReference], where_filter_intersection: WhereFilterIntersection, - resolver_input_for_query: ResolverInputForQuery, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: duplicate_metric_references = tuple( diff --git a/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py b/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py index 572b6f0ee7..bbb0fd37c7 100644 --- a/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py +++ b/metricflow-semantics/metricflow_semantics/query/validation_rules/metric_time_requirements.py @@ -1,5 +1,6 @@ from __future__ import annotations +from functools import cached_property from typing import Sequence from dbt_semantic_interfaces.enum_extension import assert_values_exhausted @@ -32,8 +33,10 @@ class MetricTimeQueryValidationRule(PostResolutionQueryValidationRule): * Derived metrics with an offset time.g """ - def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107 - super().__init__(manifest_lookup=manifest_lookup) + def __init__( # noqa: D107 + self, manifest_lookup: SemanticManifestLookup, resolver_input_for_query: ResolverInputForQuery + ) -> None: + super().__init__(manifest_lookup=manifest_lookup, resolver_input_for_query=resolver_input_for_query) self._metric_time_specs = tuple( TimeDimensionSpec.generate_possible_specs_for_time_dimension( @@ -43,13 +46,19 @@ def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D1 ) ) - def _group_by_items_include_metric_time(self, query_resolver_input: ResolverInputForQuery) -> bool: - for group_by_item_input in query_resolver_input.group_by_item_inputs: + @cached_property + def _group_by_items_include_metric_time(self) -> bool: + for group_by_item_input in self._resolver_input_for_query.group_by_item_inputs: if group_by_item_input.spec_pattern.matches_any(self._metric_time_specs): return True return False + def _query_includes_metric_time_or_agg_time_dimension(self, metric_reference: MetricReference) -> bool: + return self._group_by_items_include_metric_time or self._group_by_items_include_agg_time_dimension( + query_resolver_input=self._resolver_input_for_query, metric_reference=metric_reference + ) + def _group_by_items_include_agg_time_dimension( self, query_resolver_input: ResolverInputForQuery, metric_reference: MetricReference ) -> bool: @@ -66,15 +75,9 @@ def _group_by_items_include_agg_time_dimension( def validate_metric_in_resolution_dag( self, metric_reference: MetricReference, - resolver_input_for_query: ResolverInputForQuery, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: metric = self._get_metric(metric_reference) - query_includes_metric_time_or_agg_time_dimension = self._group_by_items_include_metric_time( - resolver_input_for_query - ) or self._group_by_items_include_agg_time_dimension( - query_resolver_input=resolver_input_for_query, metric_reference=metric_reference - ) if metric.type is MetricType.SIMPLE or metric.type is MetricType.CONVERSION: return MetricFlowQueryResolutionIssueSet.empty_instance() @@ -86,7 +89,7 @@ def validate_metric_in_resolution_dag( metric.type_params.cumulative_type_params.window is not None or metric.type_params.cumulative_type_params.grain_to_date is not None ) - and not query_includes_metric_time_or_agg_time_dimension + and not self._query_includes_metric_time_or_agg_time_dimension(metric_reference) ): return MetricFlowQueryResolutionIssueSet.from_issue( CumulativeMetricRequiresMetricTimeIssue.from_parameters( @@ -102,7 +105,7 @@ def validate_metric_in_resolution_dag( for input_metric in metric.input_metrics ) - if has_time_offset and not query_includes_metric_time_or_agg_time_dimension: + if has_time_offset and not self._query_includes_metric_time_or_agg_time_dimension(metric_reference): return MetricFlowQueryResolutionIssueSet.from_issue( OffsetMetricRequiresMetricTimeIssue.from_parameters( metric_reference=metric_reference, @@ -119,7 +122,6 @@ def validate_query_in_resolution_dag( self, metrics_in_query: Sequence[MetricReference], where_filter_intersection: WhereFilterIntersection, - resolver_input_for_query: ResolverInputForQuery, resolution_path: MetricFlowQueryResolutionPath, ) -> MetricFlowQueryResolutionIssueSet: return MetricFlowQueryResolutionIssueSet.empty_instance() diff --git a/metricflow-semantics/metricflow_semantics/query/validation_rules/query_validator.py b/metricflow-semantics/metricflow_semantics/query/validation_rules/query_validator.py index 2885ed0c74..2aff826e7e 100644 --- a/metricflow-semantics/metricflow_semantics/query/validation_rules/query_validator.py +++ b/metricflow-semantics/metricflow_semantics/query/validation_rules/query_validator.py @@ -4,7 +4,6 @@ from typing_extensions import override -from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup from metricflow_semantics.query.group_by_item.candidate_push_down.push_down_visitor import DagTraversalPathTracker from metricflow_semantics.query.group_by_item.resolution_dag.dag import GroupByItemResolutionDag from metricflow_semantics.query.group_by_item.resolution_dag.resolution_nodes.base_node import ( @@ -26,27 +25,21 @@ from metricflow_semantics.query.issues.issues_base import MetricFlowQueryResolutionIssueSet from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule -from metricflow_semantics.query.validation_rules.duplicate_metric import DuplicateMetricValidationRule -from metricflow_semantics.query.validation_rules.metric_time_requirements import MetricTimeQueryValidationRule class PostResolutionQueryValidator: """Runs query validation rules after query resolution is complete.""" - def __init__(self, manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107 - self._manifest_lookup = manifest_lookup - self._validation_rules = ( - MetricTimeQueryValidationRule(self._manifest_lookup), - DuplicateMetricValidationRule(self._manifest_lookup), - ) - def validate_query( - self, resolution_dag: GroupByItemResolutionDag, resolver_input_for_query: ResolverInputForQuery + self, + resolution_dag: GroupByItemResolutionDag, + resolver_input_for_query: ResolverInputForQuery, + validation_rules: Sequence[PostResolutionQueryValidationRule], ) -> MetricFlowQueryResolutionIssueSet: """Validate according to the list of configured validation rules and return a set containing issues found.""" validation_visitor = _PostResolutionQueryValidationVisitor( resolver_input_for_query=resolver_input_for_query, - validation_rules=self._validation_rules, + validation_rules=validation_rules, ) return resolution_dag.sink_node.accept(validation_visitor) @@ -83,7 +76,6 @@ def visit_metric_node(self, node: MetricGroupByItemResolutionNode) -> MetricFlow issue_sets_to_merge.append( validation_rule.validate_metric_in_resolution_dag( metric_reference=node.metric_reference, - resolver_input_for_query=self._resolver_input_for_query, resolution_path=current_traversal_path, ) ) @@ -100,7 +92,6 @@ def visit_query_node(self, node: QueryGroupByItemResolutionNode) -> MetricFlowQu validation_rule.validate_query_in_resolution_dag( metrics_in_query=node.metrics_in_query, where_filter_intersection=node.where_filter_intersection, - resolver_input_for_query=self._resolver_input_for_query, resolution_path=current_traversal_path, ) )