Skip to content

Commit

Permalink
Split case handling for different metric types in `MetricTimeQueryVal…
Browse files Browse the repository at this point in the history
…idationRule` (#1478)

To improve readability and shorten methods, this PR splits the case
handling for different metric types in `MetricTimeQueryValidationRule`.

Due to an incorrect merge, this also includes changes in #1479.
  • Loading branch information
plypaul authored Oct 29, 2024
1 parent 8c2e064 commit 7a715b5
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,12 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met
resolution_dag=resolution_dag,
resolver_input_for_query=resolver_input_for_query,
validation_rules=(
MetricTimeQueryValidationRule(self._manifest_lookup, resolver_input_for_query),
DuplicateMetricValidationRule(self._manifest_lookup, resolver_input_for_query),
MetricTimeQueryValidationRule(
self._manifest_lookup, resolver_input_for_query, resolve_group_by_item_result
),
DuplicateMetricValidationRule(
self._manifest_lookup, resolver_input_for_query, resolve_group_by_item_result
),
),
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import typing
from abc import ABC, abstractmethod
from typing import Sequence

Expand All @@ -11,15 +12,22 @@
from metricflow_semantics.query.issues.issues_base import MetricFlowQueryResolutionIssueSet
from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery

if typing.TYPE_CHECKING:
from metricflow_semantics.query.query_resolver import ResolveGroupByItemsResult


class PostResolutionQueryValidationRule(ABC):
"""A validation rule that runs after all query inputs have been resolved to specs."""

def __init__( # noqa: D107
self, manifest_lookup: SemanticManifestLookup, resolver_input_for_query: ResolverInputForQuery
self,
manifest_lookup: SemanticManifestLookup,
resolver_input_for_query: ResolverInputForQuery,
resolve_group_by_item_result: ResolveGroupByItemsResult,
) -> None:
self._manifest_lookup = manifest_lookup
self._resolver_input_for_query = resolver_input_for_query
self._resolve_group_by_item_result = resolve_group_by_item_result

@abstractmethod
def validate_metric_in_resolution_dag(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import typing
from dataclasses import dataclass
from typing import List, Sequence, Tuple

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.naming.keywords import METRIC_TIME_ELEMENT_NAME
from dbt_semantic_interfaces.protocols import WhereFilterIntersection
from dbt_semantic_interfaces.protocols import Metric, WhereFilterIntersection
from dbt_semantic_interfaces.references import (
MetricReference,
TimeDimensionReference,
Expand All @@ -16,7 +17,10 @@
from metricflow_semantics.collection_helpers.lru_cache import LruCache
from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.query.group_by_item.resolution_path import MetricFlowQueryResolutionPath
from metricflow_semantics.query.issues.issues_base import MetricFlowQueryResolutionIssueSet
from metricflow_semantics.query.issues.issues_base import (
MetricFlowQueryResolutionIssue,
MetricFlowQueryResolutionIssueSet,
)
from metricflow_semantics.query.issues.parsing.cumulative_metric_requires_metric_time import (
CumulativeMetricRequiresMetricTimeIssue,
)
Expand All @@ -26,13 +30,14 @@
from metricflow_semantics.query.issues.parsing.scd_requires_metric_time import (
ScdRequiresMetricTimeIssue,
)
from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import (
ResolverInputForQuery,
)
from metricflow_semantics.query.resolver_inputs.query_resolver_inputs import ResolverInputForQuery
from metricflow_semantics.query.validation_rules.base_validation_rule import PostResolutionQueryValidationRule
from metricflow_semantics.specs.instance_spec import InstanceSpec
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec

if typing.TYPE_CHECKING:
from metricflow_semantics.query.query_resolver import ResolveGroupByItemsResult


@dataclass(frozen=True)
class QueryItemsAnalysis:
Expand All @@ -54,9 +59,16 @@ class MetricTimeQueryValidationRule(PostResolutionQueryValidationRule):
"""

def __init__( # noqa: D107
self, manifest_lookup: SemanticManifestLookup, resolver_input_for_query: ResolverInputForQuery
self,
manifest_lookup: SemanticManifestLookup,
resolver_input_for_query: ResolverInputForQuery,
resolve_group_by_item_result: ResolveGroupByItemsResult,
) -> None:
super().__init__(manifest_lookup=manifest_lookup, resolver_input_for_query=resolver_input_for_query)
super().__init__(
manifest_lookup=manifest_lookup,
resolver_input_for_query=resolver_input_for_query,
resolve_group_by_item_result=resolve_group_by_item_result,
)

self._metric_time_specs = tuple(
TimeDimensionSpec.generate_possible_specs_for_time_dimension(
Expand Down Expand Up @@ -109,6 +121,54 @@ def _uncached_query_items_analysis(
has_agg_time_dimension=has_agg_time_dimension,
)

def _validate_cumulative_metric(
self,
metric_reference: MetricReference,
metric: Metric,
query_items_analysis: QueryItemsAnalysis,
resolution_path: MetricFlowQueryResolutionPath,
) -> Sequence[MetricFlowQueryResolutionIssue]:
if (
metric.type_params is not None
and metric.type_params.cumulative_type_params is not None
and (
metric.type_params.cumulative_type_params.window is not None
or metric.type_params.cumulative_type_params.grain_to_date is not None
)
and not (query_items_analysis.has_metric_time or query_items_analysis.has_agg_time_dimension)
):
return (
CumulativeMetricRequiresMetricTimeIssue.from_parameters(
metric_reference=metric_reference,
query_resolution_path=resolution_path,
),
)
return ()

def _validate_derived_metric(
self,
metric_reference: MetricReference,
metric: Metric,
resolution_path: MetricFlowQueryResolutionPath,
query_items_analysis: QueryItemsAnalysis,
) -> Sequence[MetricFlowQueryResolutionIssue]:
has_time_offset = any(
input_metric.offset_window is not None or input_metric.offset_to_grain is not None
for input_metric in metric.input_metrics
)

if has_time_offset and not (
query_items_analysis.has_metric_time or query_items_analysis.has_agg_time_dimension
):
return (
OffsetMetricRequiresMetricTimeIssue.from_parameters(
metric_reference=metric_reference,
input_metrics=metric.input_metrics,
query_resolution_path=resolution_path,
),
)
return ()

@override
def validate_metric_in_resolution_dag(
self,
Expand All @@ -119,54 +179,45 @@ def validate_metric_in_resolution_dag(

query_items_analysis = self._get_query_items_analysis(self._resolver_input_for_query, metric_reference)

issues = MetricFlowQueryResolutionIssueSet.empty_instance()
issues: List[MetricFlowQueryResolutionIssue] = []

# Queries that join to an SCD don't support direct references to agg_time_dimension, so we
# only check for metric_time. If we decide to support agg_time_dimension, we should add a check
if len(query_items_analysis.scds) > 0 and not query_items_analysis.has_metric_time:
issues = issues.add_issue(
issues.append(
ScdRequiresMetricTimeIssue.from_parameters(
scds_in_query=query_items_analysis.scds,
query_resolution_path=resolution_path,
)
)

if metric.type is MetricType.CUMULATIVE:
if (
metric.type_params is not None
and metric.type_params.cumulative_type_params is not None
and (
metric.type_params.cumulative_type_params.window is not None
or metric.type_params.cumulative_type_params.grain_to_date is not None
)
and not (query_items_analysis.has_metric_time or query_items_analysis.has_agg_time_dimension)
):
issues = issues.add_issue(
CumulativeMetricRequiresMetricTimeIssue.from_parameters(
metric_reference=metric_reference,
query_resolution_path=resolution_path,
)
issues.extend(
self._validate_cumulative_metric(
metric_reference=metric_reference,
metric=metric,
query_items_analysis=query_items_analysis,
resolution_path=resolution_path,
)
elif metric.type is MetricType.RATIO or metric.type is MetricType.DERIVED:
has_time_offset = any(
input_metric.offset_window is not None or input_metric.offset_to_grain is not None
for input_metric in metric.input_metrics
)

if has_time_offset and not (
query_items_analysis.has_metric_time or query_items_analysis.has_agg_time_dimension
):
issues = issues.add_issue(
OffsetMetricRequiresMetricTimeIssue.from_parameters(
metric_reference=metric_reference,
input_metrics=metric.input_metrics,
query_resolution_path=resolution_path,
)
elif metric.type is MetricType.RATIO or metric.type is MetricType.DERIVED:
issues.extend(
self._validate_derived_metric(
metric_reference=metric_reference,
metric=metric,
query_items_analysis=query_items_analysis,
resolution_path=resolution_path,
)
elif metric.type is not MetricType.SIMPLE and metric.type is not MetricType.CONVERSION:
)
elif metric.type is MetricType.SIMPLE:
pass
elif metric.type is MetricType.CONVERSION:
pass
else:
assert_values_exhausted(metric.type)

return issues
return MetricFlowQueryResolutionIssueSet(issues=tuple(issues))

@override
def validate_query_in_resolution_dag(
Expand Down

0 comments on commit 7a715b5

Please sign in to comment.